(Ab)using gf2p8affineqb to turn indices into bits
@geofflangdale posed the question on Twitter of how to vectorise this:
__mmask64 reference_impl(__m512i indices, __mmask64 valids) {
  __mmask64 result = 0;
  for (int i = 0; i < 64; ++i) {
    if (valids.bit[i]) {
      result ^= 1ull << indices.byte[i];
    }
  }
  return result;
}
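(The .bit[i] and .byte[i] accesses above are pseudo-code rather than legal C, but the intent should be clear.) For readers who want something executable to test against, here is a minimal scalar model in plain Python; the name reference_model and the argument conventions (indices as a list of 64 integers in the range 0-63, valids as a 64-bit integer) are mine, not part of the original problem statement:

def reference_model(indices, valids):
    # indices: 64 integers, each in range(64); valids: 64-bit integer mask.
    result = 0
    for i in range(64):
        if (valids >> i) & 1:
            result ^= 1 << indices[i]
    return result

# Two valid lanes that map to the same bit cancel out, because this is xor:
assert reference_model([7] * 64, 0b101000) == 0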
After a week of code golf also involving @HaroldAptroot, we ended up with:
__mmask64 simd_impl(__m512i indices, __mmask64 valids) {
  // Convert indices to bits within each qword lane.
  __m512i khi = _mm512_setr_epi8(
    0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
    0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
    0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04,
    0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08,
    0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10,
    0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
    0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40,
    0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80
  );
  __m512i hi0 = _mm512_permutexvar_epi8(indices, khi);
  __m512i klo = _mm512_set1_epi64(0x0102040810204080);
  __m512i lo0 = _mm512_maskz_shuffle_epi8(valids, klo, indices);
  __m512i kid = _mm512_set1_epi64(0x8040201008040201);
  __m512i hi1 = _mm512_gf2p8affine_epi64_epi8(kid, hi0, 0);
  __m512i lo1 = _mm512_gf2p8affine_epi64_epi8(kid, lo0, 0);
  __m512i x0 = _mm512_gf2p8affine_epi64_epi8(hi1, lo1, 0);
  // Combine results from various qword lanes.
  __m512i ktr = _mm512_setr_epi8(
    0, 8, 16, 24, 32, 40, 48, 56,
    1, 9, 17, 25, 33, 41, 49, 57,
    2, 10, 18, 26, 34, 42, 50, 58,
    3, 11, 19, 27, 35, 43, 51, 59,
    4, 12, 20, 28, 36, 44, 52, 60,
    5, 13, 21, 29, 37, 45, 53, 61,
    6, 14, 22, 30, 38, 46, 54, 62,
    7, 15, 23, 31, 39, 47, 55, 63);
  __m512i x1 = _mm512_permutexvar_epi8(ktr, x0);
  __m512i x2 = _mm512_gf2p8affine_epi64_epi8(kid, x1, 0);
  // Reduce 64 bytes down to 64 bits.
  __m512i kff = _mm512_set1_epi8(0xff);
  __m512i x3 = _mm512_gf2p8affine_epi64_epi8(x2, kff, 0);
  return _mm512_movepi8_mask(x3);
}
NB: If the valid indices can be assumed to be distinct, then the final reduction from 64 bytes to 64 bits can instead be:
  return _mm512_cmpneq_epi8_mask(x2, _mm512_setzero_si512());
As is often the case, simd_impl looks nothing like reference_impl, despite doing the same thing. In particular, simd_impl contains no shifts, and instead contains alternating shuffles and invocations of the mysterious _mm512_gf2p8affine_epi64_epi8, which is the intrinsic function corresponding to the gf2p8affineqb assembly instruction. To understand how simd_impl works, we're going to have to first understand what gf2p8affineqb does.
There are various ways of understanding what gf2p8affineqb does, but for the purposes of this blog post, I think the following Python pseudo-code is most useful:
def gf2p8affineqb(src1 : vector, src2 : vector, imm8 : u8) -> vector:
    assert len(src1.byte) == len(src2.byte)
    dst = vector()
    for i in range(len(src1.byte)):
        munged_src2 = munge(src2.qword[i // 8])
        dst.byte[i] = xor_selected(src1.byte[i], munged_src2, imm8)
    return dst

def xor_selected(src1 : u8, munged_src2 : u64, imm8 : u8) -> u8:
    result = imm8
    for i in range(8):
        if src1.bit[i]:
            result ^= munged_src2.byte[i]
    return result

def munge(x : u64) -> u64:
    return transpose8x8(byte_swap(x))
    # Or equivalently:
    return bitrev_in_each_byte(transpose8x8(x))

def transpose8x8(x : u64) -> u64:
    result = 0
    for i in range(8):
        for j in range(8):
            result.byte[i].bit[j] = x.byte[j].bit[i]
    return result

def byte_swap(x : u64) -> u64:
    result = 0
    for i in range(8):
        result.byte[i] = x.byte[7 - i]
    return result

def bitrev_in_each_byte(x : u64) -> u64:
    result = 0
    for i in range(8):
        result.byte[i] = bitrev(x.byte[i])
    return result

def bitrev(x : u8) -> u8:
    result = 0
    for i in range(8):
        result.bit[i] = x.bit[7 - i]
    return result
The mathematically inclined might notice that the above is in fact doing matrix multiplication of two 8x8 matrices of bits:
def gf2p8affineqb(src1 : vector, src2 : vector, imm8 : u8) -> vector:
    assert len(src1.byte) == len(src2.byte)
    dst = vector()
    for i in range(len(src1.qword)):
        dst.qword[i] = matmul(src1.qword[i], munge(src2.qword[i]))
    for i in range(len(src1.byte)):
        dst.byte[i] ^= imm8
    return dst

def matmul(lhs : u64, rhs : u64) -> u64:
    result = 0
    for i in range(8):
        for j in range(8):
            for k in range(8):
                b = lhs.byte[i].bit[j] * rhs.byte[j].bit[k] # * or &
                result.byte[i].bit[k] += b # + or ^
    return result

def munge(x : u64) -> u64:
    # Same as previously
The xor_selected view of gf2p8affineqb and the matmul view of gf2p8affineqb are complementary: I think that the xor_selected view makes it clearer what is going on, but the matmul view is useful for higher level transformations and optimisations. As a middle ground between the two views, matmul can be re-expressed as byte-level operations by unrolling the k loop:
def matmul(lhs : u64, rhs : u64) -> u64:
    result = 0
    for i in range(8):
        for j in range(8):
            if lhs.byte[i].bit[j]:
                result.byte[i] ^= rhs.byte[j]
    return result
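To convince yourself that the xor_selected view and the matmul view really do agree, here is a runnable Python model of a single 64-bit lane (my own transcription of the pseudo-code onto plain integers, treating a u64 as eight little-endian bytes; all the function names below are mine):

import random

def byte(x, i): return (x >> (8 * i)) & 0xFF

def byte_swap(x):
    return int.from_bytes(x.to_bytes(8, "little"), "big")

def transpose8x8(x):
    # result.byte[i].bit[j] = x.byte[j].bit[i]
    return sum(((byte(x, j) >> i) & 1) << (8 * i + j)
               for i in range(8) for j in range(8))

def munge(x):
    return transpose8x8(byte_swap(x))

def xor_selected(src1, munged_src2, imm8):
    result = imm8
    for i in range(8):
        if (src1 >> i) & 1:
            result ^= byte(munged_src2, i)
    return result

def matmul(lhs, rhs):
    # 8x8 bit-matrix product: row i of the result is the xor of the rows of
    # rhs selected by the set bits of row i of lhs.
    result = 0
    for i in range(8):
        row = 0
        for j in range(8):
            if (byte(lhs, i) >> j) & 1:
                row ^= byte(rhs, j)
        result |= row << (8 * i)
    return result

def affine_xor_selected(src1, src2, imm8):  # one 64-bit lane, first view
    m = munge(src2)
    return sum(xor_selected(byte(src1, i), m, imm8) << (8 * i) for i in range(8))

def affine_matmul(src1, src2, imm8):        # one 64-bit lane, second view
    return matmul(src1, munge(src2)) ^ (imm8 * 0x0101010101010101)

for _ in range(1000):
    a, b, k = random.getrandbits(64), random.getrandbits(64), random.getrandbits(8)
    assert affine_xor_selected(a, b, k) == affine_matmul(a, b, k)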
One observation from the matmul view is that when src1.qword[i] is the identity matrix, we end up with dst.qword[i] being munge(src2.qword[i]). As a 64-bit integer, said identity matrix is 0x8040201008040201 (i.e. in byte i, just bit i is set). This explains __m512i kid = _mm512_set1_epi64(0x8040201008040201) in simd_impl (kid is just an identity matrix) and also explains __m512i hi1 = _mm512_gf2p8affine_epi64_epi8(kid, hi0, 0) and __m512i lo1 = _mm512_gf2p8affine_epi64_epi8(kid, lo0, 0) - these are just applying munge to every qword (as for what said munges are achieving, we'll get to later).
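As a quick sanity check that 0x8040201008040201 really is the identity matrix, the following self-contained Python sketch (helper names are mine) confirms that left-multiplying by it leaves the other operand unchanged, which is exactly why gf2p8affineqb(kid, x, 0) reduces to munge(x):

import random

def byte(x, i): return (x >> (8 * i)) & 0xFF

def matmul(lhs, rhs):
    # 8x8 bit-matrix product, as in the pseudo-code above.
    result = 0
    for i in range(8):
        row = 0
        for j in range(8):
            if (byte(lhs, i) >> j) & 1:
                row ^= byte(rhs, j)
        result |= row << (8 * i)
    return result

kid = 0x8040201008040201  # byte i has only bit i set
for _ in range(1000):
    y = random.getrandbits(64)
    assert matmul(kid, y) == y  # left-multiplying by kid is a no-op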
Changing tack somewhat, it is time to gradually transform reference_impl to make it look more like matmul. For this, we'll start with a simplified version of reference_impl that takes 8 indices rather than 64:
__mmask64 reference_impl_1(__m64i indices, __mmask8 valids) {
  __mmask64 result = 0;
  for (int i = 0; i < 8; ++i) {
    if (valids.bit[i]) {
      result ^= 1ull << indices.byte[i];
    }
  }
  return result;
}
The first transformation is to split each 6-bit index into its low 3 bits and high 3 bits, so that we can address bytes of result:
__mmask64 reference_impl_2(__m64i indices, __mmask8 valids) {
  __mmask64 result = 0;
  for (int i = 0; i < 8; ++i) {
    if (valids.bit[i]) {
      uint8_t b = indices.byte[i];
      uint8_t hi = b >> 3;
      uint8_t lo = b & 7;
      result.byte[hi] ^= 1 << lo;
    }
  }
  return result;
}
Next up we perform loop fission; doing the exact same work, but using two loops rather than one (so that we can focus on the loops separately):
__mmask64 reference_impl_3(__m64i indices, __mmask8 valids) {
  __m64i hi;
  __m64i lo;
  for (int i = 0; i < 8; ++i) {
    uint8_t b = indices.byte[i];
    hi.byte[i] = b >> 3;
    lo.byte[i] = b & 7;
  }
  __mmask64 result = 0;
  for (int i = 0; i < 8; ++i) {
    if (valids.bit[i]) {
      result.byte[hi.byte[i]] ^= 1 << lo.byte[i];
    }
  }
  return result;
}
Then the if and the 1 << can also be moved from the 2nd loop to the 1st loop:
__mmask64 reference_impl_4(__m64i indices, __mmask8 valids) {
  __m64i hi;
  __m64i lo;
  for (int i = 0; i < 8; ++i) {
    uint8_t b = indices.byte[i];
    hi.byte[i] = b >> 3;
    lo.byte[i] = valids.bit[i] ? 1 << (b & 7) : 0;
  }
  __mmask64 result = 0;
  for (int i = 0; i < 8; ++i) {
    result.byte[hi.byte[i]] ^= lo.byte[i];
  }
  return result;
}
Then a transformation that looks utterly deranged, but is key to the SIMD transformation; rather than directly indexing using hi.byte[i], we'll loop over the 8 possible values of hi.byte[i] and act when we find the right value:
__mmask64 reference_impl_5(__m64i indices, __mmask8 valids) {
  __m64i hi;
  __m64i lo;
  for (int i = 0; i < 8; ++i) {
    uint8_t b = indices.byte[i];
    hi.byte[i] = b >> 3;
    lo.byte[i] = valids.bit[i] ? 1 << (b & 7) : 0;
  }
  __mmask64 result = 0;
  for (int i = 0; i < 8; ++i) {
    for (int j = 0; j < 8; ++j) {
      if (hi.byte[i] == j) {
        result.byte[j] ^= lo.byte[i];
      }
    }
  }
  return result;
}
Next up we perform loop interchange of the two nested loops:
__mmask64 reference_impl_6(__m64i indices, __mmask8 valids) {
  __m64i hi;
  __m64i lo;
  for (int i = 0; i < 8; ++i) {
    uint8_t b = indices.byte[i];
    hi.byte[i] = b >> 3;
    lo.byte[i] = valids.bit[i] ? 1 << (b & 7) : 0;
  }
  __mmask64 result = 0;
  for (int i = 0; i < 8; ++i) {
    for (int j = 0; j < 8; ++j) {
      if (hi.byte[j] == i) {
        result.byte[i] ^= lo.byte[j];
      }
    }
  }
  return result;
}
Then another transformation that initially looks deranged; the == in hi.byte[j] == i is annoying, and can be replaced by a bit test if we one-hot encode hi:
__mmask64 reference_impl_7(__m64i indices, __mmask8 valids) {
  __m64i hi;
  __m64i lo;
  for (int i = 0; i < 8; ++i) {
    uint8_t b = indices.byte[i];
    hi.byte[i] = 1 << (b >> 3);
    lo.byte[i] = valids.bit[i] ? 1 << (b & 7) : 0;
  }
  __mmask64 result = 0;
  for (int i = 0; i < 8; ++i) {
    for (int j = 0; j < 8; ++j) {
      if (hi.byte[j].bit[i]) {
        result.byte[i] ^= lo.byte[j];
      }
    }
  }
  return result;
}
Then one final transformation to get where we want to be; apply transpose8x8 to hi, and undo it by changing .byte[j].bit[i] to .byte[i].bit[j]:
__mmask64 reference_impl_8(__m64i indices, __mmask8 valids) {
  __m64i hi;
  __m64i lo;
  for (int i = 0; i < 8; ++i) {
    uint8_t b = indices.byte[i];
    hi.byte[i] = 1 << (b >> 3);
    lo.byte[i] = valids.bit[i] ? 1 << (b & 7) : 0;
  }
  __mmask64 result = 0;
  for (int i = 0; i < 8; ++i) {
    for (int j = 0; j < 8; ++j) {
      if (transpose8x8(hi).byte[i].bit[j]) {
        result.byte[i] ^= lo.byte[j];
      }
    }
  }
  return result;
}
A number of these transformations seemed pointless or even unhelpful, but having done them all, the latter half of reference_impl_8 is exactly result = matmul(transpose8x8(hi), lo).
The expression matmul(transpose8x8(A), B) looks deceptively similar to the matmul(A, munge(B)) done by gf2p8affineqb(A, B, 0), and if munge was just transpose8x8, then gf2p8affineqb(munge(A), munge(B), 0) would be exactly matmul(transpose8x8(A), B). Unfortunately, munge also does a bit or byte reversal, causing gf2p8affineqb(munge(A), munge(B), 0) to actually be matmul(transpose8x8(A), bitrev_in_each_byte(B)) (if deriving this, note that munge(A) is bitrev_in_each_byte(transpose8x8(A)), munge(munge(B)) is byte_swap(bitrev_in_each_byte(B)), and then the bitrev_in_each_byte on A cancels out with the byte_swap on B).
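This derivation is easy to fumble, so here is a self-contained Python check of both the intermediate facts and the end result; the helper functions mirror the earlier pseudo-code (a u64 is modelled as eight little-endian bytes), while the test harness itself is my own sketch rather than anything from the original code:

import random

def byte(x, i): return (x >> (8 * i)) & 0xFF

def byte_swap(x): return int.from_bytes(x.to_bytes(8, "little"), "big")

def bitrev(b): return sum(((b >> i) & 1) << (7 - i) for i in range(8))

def bitrev_in_each_byte(x):
    return sum(bitrev(byte(x, i)) << (8 * i) for i in range(8))

def transpose8x8(x):
    return sum(((byte(x, j) >> i) & 1) << (8 * i + j)
               for i in range(8) for j in range(8))

def munge(x): return transpose8x8(byte_swap(x))

def matmul(lhs, rhs):
    result = 0
    for i in range(8):
        row = 0
        for j in range(8):
            if (byte(lhs, i) >> j) & 1:
                row ^= byte(rhs, j)
        result |= row << (8 * i)
    return result

def gf2p8affineqb_qword(a, b):  # one 64-bit lane of gf2p8affineqb(a, b, 0)
    return matmul(a, munge(b))

for _ in range(1000):
    a, b = random.getrandbits(64), random.getrandbits(64)
    # The intermediate facts used in the derivation:
    assert munge(a) == bitrev_in_each_byte(transpose8x8(a))
    assert munge(munge(b)) == byte_swap(bitrev_in_each_byte(b))
    # The end result:
    assert gf2p8affineqb_qword(munge(a), munge(b)) == \
           matmul(transpose8x8(a), bitrev_in_each_byte(b))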
The expression matmul(transpose8x8(A), bitrev_in_each_byte(B)) is very close to what we want, and the errant bitrev_in_each_byte can be cancelled out by doing another bitrev_in_each_byte on B:
__mmask64 reference_impl_9(__m64i indices, __mmask8 valids) {
  __m64i hi;
  __m64i lo;
  for (int i = 0; i < 8; ++i) {
    uint8_t b = indices.byte[i];
    hi.byte[i] = 1 << (b >> 3);
    lo.byte[i] = bitrev(valids.bit[i] ? 1 << (b & 7) : 0);
  }
  __mmask64 result = gf2p8affineqb(munge(hi), munge(lo), 0);
  return result;
}
The 1st loop is easy to express in a SIMD manner via a pair of table lookups, thereby giving us the first chunk of simd_impl:
__mmask64 simd_impl(__m512i indices, __mmask64 valids) {
  // Convert indices to bits within each qword lane.
  __m512i khi = _mm512_setr_epi8(
    0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
    0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
    0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04,
    0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08,
    0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10,
    0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
    0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40,
    0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80
  );
  __m512i hi0 = _mm512_permutexvar_epi8(indices, khi);
  __m512i klo = _mm512_set1_epi64(0x0102040810204080);
  __m512i lo0 = _mm512_maskz_shuffle_epi8(valids, klo, indices);
  __m512i kid = _mm512_set1_epi64(0x8040201008040201);
  __m512i hi1 = _mm512_gf2p8affine_epi64_epi8(kid, hi0, 0); // munge
  __m512i lo1 = _mm512_gf2p8affine_epi64_epi8(kid, lo0, 0); // munge
  __m512i x0 = _mm512_gf2p8affine_epi64_epi8(hi1, lo1, 0);
At this point, x0.qword[i] contains reference_impl_9(indices.qword[i], valids.byte[i]). To finish up, "all" we need to do is xor together the eight qwords of x0. The traditional way of doing this would be a shuffle followed by a xor to reduce eight to four, another shuffle followed by a xor to reduce four to two, and yet another shuffle followed by a xor to reduce two to one. We can do better than the traditional approach though. The first step is to do one big shuffle rather than three sequential shuffles, where the result of the big shuffle moves the eight bytes qword[i].byte[0] to be contiguous, then the eight bytes qword[i].byte[1] to be contiguous, and so on. Seen differently, the big shuffle is a transpose of an 8x8 matrix of bytes. After this big shuffle, the remaining problem is to take each contiguous group of eight bytes and xor them together. If we wanted to add together each contiguous group of eight bytes, then _mm512_sad_epu8 against zero would be one option, but we want xor rather than add. There are a few different ways of approaching the problem, but one cute way is to apply transpose8x8 to each contiguous group of eight bytes, after which we just need to xor together each contiguous group of eight bits. Applying transpose8x8 on its own is hard, but we can apply munge fairly easily, which does transpose8x8 followed by bitrev_in_each_byte, and the bitrev_in_each_byte is harmless given that we're about to xor together the bits in each byte. This gives us the next chunk of simd_impl:
  // Combine results from various qword lanes.
  __m512i ktr = _mm512_setr_epi8(
    0, 8, 16, 24, 32, 40, 48, 56,
    1, 9, 17, 25, 33, 41, 49, 57,
    2, 10, 18, 26, 34, 42, 50, 58,
    3, 11, 19, 27, 35, 43, 51, 59,
    4, 12, 20, 28, 36, 44, 52, 60,
    5, 13, 21, 29, 37, 45, 53, 61,
    6, 14, 22, 30, 38, 46, 54, 62,
    7, 15, 23, 31, 39, 47, 55, 63);
  __m512i x1 = _mm512_permutexvar_epi8(ktr, x0); // transpose bytes
  __m512i x2 = _mm512_gf2p8affine_epi64_epi8(kid, x1, 0); // munge
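Incidentally, ktr is nothing more than the 8x8 byte-transpose permutation written out longhand; if you would rather generate it than type it, something along these lines (a sketch, not from the original code) produces the same table:

# Entry 8*i + j of the table is 8*j + i, i.e. byte (i, j) of the 8x8 byte
# matrix is fetched from byte (j, i): a transpose.
ktr = [8 * j + i for i in range(8) for j in range(8)]
assert ktr[:8] == [0, 8, 16, 24, 32, 40, 48, 56]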
If the valid indices can be assumed to be distinct, then we can or (rather than xor) together the bits in each byte, which is just _mm512_cmpneq_epi8_mask against zero. If we really do need to xor the bits together, then what we want is this function applied to every byte:
def xor_together_bits(x : u8) -> u8:
    result = 0
    for i in range(8):
        if x.bit[i]:
            result ^= 0xff
    return result
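Put differently (my paraphrase), xor_together_bits(x) is just the parity of x broadcast to all eight bits: 0xff when x has an odd number of set bits, and 0 otherwise. A tiny Python model makes that concrete:

def xor_together_bits(x):  # x is an 8-bit integer
    return 0xFF if bin(x).count("1") % 2 else 0x00

assert xor_together_bits(0b00000110) == 0      # two bits set: even parity
assert xor_together_bits(0b00010110) == 0xFF   # three bits set: odd parity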
If you're thinking that xor_together_bits looks very similar to xor_selected, then you'd be right: xor_together_bits is just xor_selected where every byte of munged_src2 is 0xff, and it so happens that if every byte of src2 is 0xff, then the same is true for munged_src2. This gives the final chunk of simd_impl:
  // Reduce 64 bytes down to 64 bits.
  __m512i kff = _mm512_set1_epi8(0xff);
  __m512i x3 = _mm512_gf2p8affine_epi64_epi8(x2, kff, 0);
  return _mm512_movepi8_mask(x3);
}