Cで2つの4x4行列を乗算するための、より高速でトリッキーな方法を探しています。現在の研究は、SIMD拡張機能を備えたx86-64アセンブリに焦点を当てています。これまでのところ、私は関数witchを作成しましたが、これは単純なC実装よりも約6倍高速であり、パフォーマンスの向上に対する私の期待を上回っています。残念ながら、これは、コンパイルに最適化フラグが使用されていない場合にのみ当てはまります(GCC4.7)。 -O2
、Cが速くなり、私の努力は無意味になります。
最近のコンパイラーは、複雑な最適化手法を利用して、ほぼ完璧なコードを達成していることを知っています。通常、独創的な手作りのアセンブリーよりも高速です。しかし、パフォーマンスが重要な少数のケースでは、人間がコンパイラを使用してクロックサイクルを争おうとする場合があります。特に、現代のISA)に裏打ちされたいくつかの数学を探求できる場合(私の場合のように)。
私の関数は次のようになります(AT&T構文、GNUアセンブラ):
.text
.globl matrixMultiplyASM
.type matrixMultiplyASM, @function
matrixMultiplyASM:
movaps (%rdi), %xmm0 # fetch the first matrix (use four registers)
movaps 16(%rdi), %xmm1
movaps 32(%rdi), %xmm2
movaps 48(%rdi), %xmm3
xorq %rcx, %rcx # reset (forward) loop iterator
.ROW:
movss (%rsi), %xmm4 # Compute four values (one row) in parallel:
shufps $0x0, %xmm4, %xmm4 # 4x 4FP mul's, 3x 4FP add's 6x mov's per row,
mulps %xmm0, %xmm4 # expressed in four sequences of 5 instructions,
movaps %xmm4, %xmm5 # executed 4 times for 1 matrix multiplication.
addq $0x4, %rsi
movss (%rsi), %xmm4 # movss + shufps comprise _mm_set1_ps intrinsic
shufps $0x0, %xmm4, %xmm4 #
mulps %xmm1, %xmm4
addps %xmm4, %xmm5
addq $0x4, %rsi # manual pointer arithmetic simplifies addressing
movss (%rsi), %xmm4
shufps $0x0, %xmm4, %xmm4
mulps %xmm2, %xmm4 # actual computation happens here
addps %xmm4, %xmm5 #
addq $0x4, %rsi
movss (%rsi), %xmm4 # one mulps operand fetched per sequence
shufps $0x0, %xmm4, %xmm4 # |
mulps %xmm3, %xmm4 # the other is already waiting in %xmm[0-3]
addps %xmm4, %xmm5
addq $0x4, %rsi # 5 preceding comments stride among the 4 blocks
movaps %xmm5, (%rdx,%rcx) # store the resulting row, actually, a column
addq $0x10, %rcx # (matrices are stored in column-major order)
cmpq $0x40, %rcx
jne .ROW
ret
.size matrixMultiplyASM, .-matrixMultiplyASM
128ビットSSEレジスタにパックされた4つのfloatを処理することにより、反復ごとに結果の行列の列全体を計算します。完全なベクトル化は、少しの数学(演算の並べ替えと集計)とmullps
/addps
4xfloatパッケージの並列乗算/加算の命令。コードはパラメーターを渡すためのレジスターを再利用します(%rdi
、%rsi
、%rdx
:GNU/Linux ABI)、(内部)ループ展開の恩恵を受け、1つの行列を完全にXMMレジスターに保持して、メモリの読み取りを減らします。ご覧のとおり、私はこのトピックを調査し、時間をかけて可能な限り実装しました。
私のコードを征服する素朴なC計算は次のようになります。
void matrixMultiplyNormal(mat4_t *mat_a, mat4_t *mat_b, mat4_t *mat_r) {
for (unsigned int i = 0; i < 16; i += 4)
for (unsigned int j = 0; j < 4; ++j)
mat_r->m[i + j] = (mat_b->m[i + 0] * mat_a->m[j + 0])
+ (mat_b->m[i + 1] * mat_a->m[j + 4])
+ (mat_b->m[i + 2] * mat_a->m[j + 8])
+ (mat_b->m[i + 3] * mat_a->m[j + 12]);
}
上記のCコードの最適化されたアセンブリ出力を調査しました。これは、フロートをXMMレジスタに格納する一方で、並列演算は含まれません –スカラー計算、ポインタ演算、および条件付きジャンプのみです。コンパイラのコードはあまり意図的ではないようですが、それでも、ベクトル化されたバージョンよりも約4倍高速であると予想されるよりもわずかに効果的です。一般的な考え方は正しいと思います。プログラマーは同じようなことを行い、やりがいのある結果をもたらします。しかし、ここで何が問題になっていますか?知らないレジスタ割り当てや命令スケジューリングの問題はありますか?マシンとの戦いをサポートするためのx86-64アセンブリツールまたはトリックを知っていますか?
コードを高速化し、コンパイラーを凌駕する方法があります。高度なパイプライン分析や詳細なコードのマイクロ最適化は含まれていません(これらのメリットをこれ以上享受できないという意味ではありません)。最適化では、次の3つの簡単なトリックを使用します。
この関数は32バイトに調整されました(これによりパフォーマンスが大幅に向上しました)。
メインループは逆になり、ゼロテスト(EFLAGSに基づく)との比較が減ります。
命令レベルのアドレス演算は、「外部」ポインタ計算よりも高速であることが証明されました(ただし、「3/4の場合」に2倍の追加が必要です)。ループ本体を4命令短縮し、実行パス内のデータ依存性を減らしました。 関連する質問を参照 。
さらに、コードは、GCCがインライン化しようとしたときに発生するシンボル再定義エラーを抑制する相対ジャンプ構文を使用します(asm
ステートメント内に配置され、-O3
でコンパイルされた後)。
.text
.align 32 # 1. function entry alignment
.globl matrixMultiplyASM # (for a faster call)
.type matrixMultiplyASM, @function
matrixMultiplyASM:
movaps (%rdi), %xmm0
movaps 16(%rdi), %xmm1
movaps 32(%rdi), %xmm2
movaps 48(%rdi), %xmm3
movq $48, %rcx # 2. loop reversal
1: # (for simpler exit condition)
movss (%rsi, %rcx), %xmm4 # 3. extended address operands
shufps $0, %xmm4, %xmm4 # (faster than pointer calculation)
mulps %xmm0, %xmm4
movaps %xmm4, %xmm5
movss 4(%rsi, %rcx), %xmm4
shufps $0, %xmm4, %xmm4
mulps %xmm1, %xmm4
addps %xmm4, %xmm5
movss 8(%rsi, %rcx), %xmm4
shufps $0, %xmm4, %xmm4
mulps %xmm2, %xmm4
addps %xmm4, %xmm5
movss 12(%rsi, %rcx), %xmm4
shufps $0, %xmm4, %xmm4
mulps %xmm3, %xmm4
addps %xmm4, %xmm5
movaps %xmm5, (%rdx, %rcx)
subq $16, %rcx # one 'sub' (vs 'add' & 'cmp')
jge 1b # SF=OF, idiom: jump if positive
ret
これは、これまでに見た中で最速のx86-64実装です。私は感謝し、投票し、その目的のためのより速いアセンブリを提供する答えを受け入れます!
4x4行列の乗算は、64回の乗算と48回の加算です。 SSEこれは16の乗算と12の加算(および16のブロードキャスト)に減らすことができます。次のコードはこれを行います。必要なのはSSE( #include <xmmintrin.h>
)。配列A
、B
、およびC
は、16バイトに整列する必要があります。 hadd
(SSE3)やdpps
(SSE4.1)などの水平方向の命令を使用すると、 効率が低下します (特にdpps
)になります。ループ展開が役立つかどうかはわかりません。
void M4x4_SSE(float *A, float *B, float *C) {
__m128 row1 = _mm_load_ps(&B[0]);
__m128 row2 = _mm_load_ps(&B[4]);
__m128 row3 = _mm_load_ps(&B[8]);
__m128 row4 = _mm_load_ps(&B[12]);
for(int i=0; i<4; i++) {
__m128 brod1 = _mm_set1_ps(A[4*i + 0]);
__m128 brod2 = _mm_set1_ps(A[4*i + 1]);
__m128 brod3 = _mm_set1_ps(A[4*i + 2]);
__m128 brod4 = _mm_set1_ps(A[4*i + 3]);
__m128 row = _mm_add_ps(
_mm_add_ps(
_mm_mul_ps(brod1, row1),
_mm_mul_ps(brod2, row2)),
_mm_add_ps(
_mm_mul_ps(brod3, row3),
_mm_mul_ps(brod4, row4)));
_mm_store_ps(&C[4*i], row);
}
}
行列の1つを転置することが有益であるかどうか疑問に思います。
次の2つの行列をどのように乗算するかを考えてください...
A1 A2 A3 A4 W1 W2 W3 W4
B1 B2 B3 B4 X1 X2 X3 X4
C1 C2 C3 C4 * Y1 Y2 Y3 Y4
D1 D2 D3 D4 Z1 Z2 Z3 Z4
これにより、...
dot(A,?1) dot(A,?2) dot(A,?3) dot(A,?4)
dot(B,?1) dot(B,?2) dot(B,?3) dot(B,?4)
dot(C,?1) dot(C,?2) dot(C,?3) dot(C,?4)
dot(D,?1) dot(D,?2) dot(D,?3) dot(D,?4)
行と列の内積を行うのは面倒です。
乗算する前に2番目の行列を転置した場合はどうなりますか?
A1 A2 A3 A4 W1 X1 Y1 Z1
B1 B2 B3 B4 W2 X2 Y2 Z2
C1 C2 C3 C4 * W3 X3 Y3 Z3
D1 D2 D3 D4 W4 X4 Y4 Z4
現在、行と列の内積を実行する代わりに、2行の内積を実行しています。これは、SIMD命令のより良い使用に役立つ可能性があります。
お役に立てれば。
上記のSandyBridgeは、8要素のベクトル演算をサポートするように命令セットを拡張します。この実装を検討してください。
struct MATRIX {
union {
float f[4][4];
__m128 m[4];
__m256 n[2];
};
};
MATRIX myMultiply(MATRIX M1, MATRIX M2) {
// Perform a 4x4 matrix multiply by a 4x4 matrix
// Be sure to run in 64 bit mode and set right flags
// Properties, C/C++, Enable Enhanced Instruction, /Arch:AVX
// Having MATRIX on a 32 byte bundry does help performance
MATRIX mResult;
__m256 a0, a1, b0, b1;
__m256 c0, c1, c2, c3, c4, c5, c6, c7;
__m256 t0, t1, u0, u1;
t0 = M1.n[0]; // t0 = a00, a01, a02, a03, a10, a11, a12, a13
t1 = M1.n[1]; // t1 = a20, a21, a22, a23, a30, a31, a32, a33
u0 = M2.n[0]; // u0 = b00, b01, b02, b03, b10, b11, b12, b13
u1 = M2.n[1]; // u1 = b20, b21, b22, b23, b30, b31, b32, b33
a0 = _mm256_shuffle_ps(t0, t0, _MM_SHUFFLE(0, 0, 0, 0)); // a0 = a00, a00, a00, a00, a10, a10, a10, a10
a1 = _mm256_shuffle_ps(t1, t1, _MM_SHUFFLE(0, 0, 0, 0)); // a1 = a20, a20, a20, a20, a30, a30, a30, a30
b0 = _mm256_permute2f128_ps(u0, u0, 0x00); // b0 = b00, b01, b02, b03, b00, b01, b02, b03
c0 = _mm256_mul_ps(a0, b0); // c0 = a00*b00 a00*b01 a00*b02 a00*b03 a10*b00 a10*b01 a10*b02 a10*b03
c1 = _mm256_mul_ps(a1, b0); // c1 = a20*b00 a20*b01 a20*b02 a20*b03 a30*b00 a30*b01 a30*b02 a30*b03
a0 = _mm256_shuffle_ps(t0, t0, _MM_SHUFFLE(1, 1, 1, 1)); // a0 = a01, a01, a01, a01, a11, a11, a11, a11
a1 = _mm256_shuffle_ps(t1, t1, _MM_SHUFFLE(1, 1, 1, 1)); // a1 = a21, a21, a21, a21, a31, a31, a31, a31
b0 = _mm256_permute2f128_ps(u0, u0, 0x11); // b0 = b10, b11, b12, b13, b10, b11, b12, b13
c2 = _mm256_mul_ps(a0, b0); // c2 = a01*b10 a01*b11 a01*b12 a01*b13 a11*b10 a11*b11 a11*b12 a11*b13
c3 = _mm256_mul_ps(a1, b0); // c3 = a21*b10 a21*b11 a21*b12 a21*b13 a31*b10 a31*b11 a31*b12 a31*b13
a0 = _mm256_shuffle_ps(t0, t0, _MM_SHUFFLE(2, 2, 2, 2)); // a0 = a02, a02, a02, a02, a12, a12, a12, a12
a1 = _mm256_shuffle_ps(t1, t1, _MM_SHUFFLE(2, 2, 2, 2)); // a1 = a22, a22, a22, a22, a32, a32, a32, a32
b1 = _mm256_permute2f128_ps(u1, u1, 0x00); // b0 = b20, b21, b22, b23, b20, b21, b22, b23
c4 = _mm256_mul_ps(a0, b1); // c4 = a02*b20 a02*b21 a02*b22 a02*b23 a12*b20 a12*b21 a12*b22 a12*b23
c5 = _mm256_mul_ps(a1, b1); // c5 = a22*b20 a22*b21 a22*b22 a22*b23 a32*b20 a32*b21 a32*b22 a32*b23
a0 = _mm256_shuffle_ps(t0, t0, _MM_SHUFFLE(3, 3, 3, 3)); // a0 = a03, a03, a03, a03, a13, a13, a13, a13
a1 = _mm256_shuffle_ps(t1, t1, _MM_SHUFFLE(3, 3, 3, 3)); // a1 = a23, a23, a23, a23, a33, a33, a33, a33
b1 = _mm256_permute2f128_ps(u1, u1, 0x11); // b0 = b30, b31, b32, b33, b30, b31, b32, b33
c6 = _mm256_mul_ps(a0, b1); // c6 = a03*b30 a03*b31 a03*b32 a03*b33 a13*b30 a13*b31 a13*b32 a13*b33
c7 = _mm256_mul_ps(a1, b1); // c7 = a23*b30 a23*b31 a23*b32 a23*b33 a33*b30 a33*b31 a33*b32 a33*b33
c0 = _mm256_add_ps(c0, c2); // c0 = c0 + c2 (two terms, first two rows)
c4 = _mm256_add_ps(c4, c6); // c4 = c4 + c6 (the other two terms, first two rows)
c1 = _mm256_add_ps(c1, c3); // c1 = c1 + c3 (two terms, second two rows)
c5 = _mm256_add_ps(c5, c7); // c5 = c5 + c7 (the other two terms, second two rose)
// Finally complete addition of all four terms and return the results
mResult.n[0] = _mm256_add_ps(c0, c4); // n0 = a00*b00+a01*b10+a02*b20+a03*b30 a00*b01+a01*b11+a02*b21+a03*b31 a00*b02+a01*b12+a02*b22+a03*b32 a00*b03+a01*b13+a02*b23+a03*b33
// a10*b00+a11*b10+a12*b20+a13*b30 a10*b01+a11*b11+a12*b21+a13*b31 a10*b02+a11*b12+a12*b22+a13*b32 a10*b03+a11*b13+a12*b23+a13*b33
mResult.n[1] = _mm256_add_ps(c1, c5); // n1 = a20*b00+a21*b10+a22*b20+a23*b30 a20*b01+a21*b11+a22*b21+a23*b31 a20*b02+a21*b12+a22*b22+a23*b32 a20*b03+a21*b13+a22*b23+a23*b33
// a30*b00+a31*b10+a32*b20+a33*b30 a30*b01+a31*b11+a32*b21+a33*b31 a30*b02+a31*b12+a32*b22+a33*b32 a30*b03+a31*b13+a32*b23+a33*b33
return mResult;
}