diff --git a/README.md b/README.md index dd106f6..fb53e14 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ go-and [![Go Reference](https://pkg.go.dev/badge/github.com/bwesterb/go-and.svg)](https://pkg.go.dev/github.com/bwesterb/go-and) -Fast bitwise and, or, andn, popcount and memset for `[]byte` slices. +Fast bitwise and, or, xor, andn, popcount and memset for `[]byte` slices. ```go import "github.com/bwesterb/go-and" diff --git a/and_amd64.s b/and_amd64.s index b4d8b0b..be0a0e4 100644 --- a/and_amd64.s +++ b/and_amd64.s @@ -100,6 +100,54 @@ loop: JNZ loop RET +// func xorAVX2(dst *byte, a *byte, b *byte, l uint64) +// Requires: AVX, AVX2 +TEXT ·xorAVX2(SB), NOSPLIT, $0-32 + MOVQ a+8(FP), AX + MOVQ b+16(FP), CX + MOVQ dst+0(FP), DX + MOVQ l+24(FP), BX + +loop: + VMOVDQU (AX), Y0 + VMOVDQU (CX), Y8 + VMOVDQU 32(AX), Y1 + VMOVDQU 32(CX), Y9 + VMOVDQU 64(AX), Y2 + VMOVDQU 64(CX), Y10 + VMOVDQU 96(AX), Y3 + VMOVDQU 96(CX), Y11 + VMOVDQU 128(AX), Y4 + VMOVDQU 128(CX), Y12 + VMOVDQU 160(AX), Y5 + VMOVDQU 160(CX), Y13 + VMOVDQU 192(AX), Y6 + VMOVDQU 192(CX), Y14 + VMOVDQU 224(AX), Y7 + VMOVDQU 224(CX), Y15 + VPXOR Y8, Y0, Y8 + VPXOR Y9, Y1, Y9 + VPXOR Y10, Y2, Y10 + VPXOR Y11, Y3, Y11 + VPXOR Y12, Y4, Y12 + VPXOR Y13, Y5, Y13 + VPXOR Y14, Y6, Y14 + VPXOR Y15, Y7, Y15 + VMOVDQU Y8, (DX) + VMOVDQU Y9, 32(DX) + VMOVDQU Y10, 64(DX) + VMOVDQU Y11, 96(DX) + VMOVDQU Y12, 128(DX) + VMOVDQU Y13, 160(DX) + VMOVDQU Y14, 192(DX) + VMOVDQU Y15, 224(DX) + ADDQ $0x00000100, AX + ADDQ $0x00000100, CX + ADDQ $0x00000100, DX + SUBQ $0x00000001, BX + JNZ loop + RET + // func andNotAVX2(dst *byte, a *byte, b *byte, l uint64) // Requires: AVX, AVX2 TEXT ·andNotAVX2(SB), NOSPLIT, $0-32 diff --git a/and_arm64.go b/and_arm64.go index 019d1b8..aff13d1 100644 --- a/and_arm64.go +++ b/and_arm64.go @@ -8,6 +8,9 @@ func andNEON(dst, a, b *byte, l uint64) //go:noescape func orNEON(dst, a, b *byte, l uint64) +//go:noescape +func xorNEON(dst, a, b *byte, l uint64) + //go:noescape func popcntNEON(a *byte, l uint64) uint64 @@ -29,6 +32,15 @@ func or(dst, a, b []byte) { orGeneric(dst[l:], a[l:], b[l:]) } +func xor(dst, a, b []byte) { + l := uint64(len(a)) >> 8 + if l != 0 { + xorNEON(&dst[0], &a[0], &b[0], l) + } + l <<= 8 + xorGeneric(dst[l:], a[l:], b[l:]) +} + func andNot(dst, a, b []byte) { // TODO: Write a NEON version for this andNotGeneric(dst, a, b) diff --git a/and_arm64.s b/and_arm64.s index 883de34..296328a 100644 --- a/and_arm64.s +++ b/and_arm64.s @@ -49,6 +49,53 @@ loop: RET +// func xorNEON(dst *byte, a *byte, b *byte, l uint64) +TEXT ·xorNEON(SB), NOSPLIT, $0-32 + MOVD dst+0(FP), R0 + MOVD a+8(FP), R1 + MOVD b+16(FP), R2 + MOVD l+24(FP), R3 + +loop: + VLD1.P 64(R1), [ V0.B16, V1.B16, V2.B16, V3.B16] + VLD1.P 64(R2), [ V4.B16, V5.B16, V6.B16, V7.B16] + VLD1.P 64(R1), [ V8.B16, V9.B16, V10.B16, V11.B16] + VLD1.P 64(R2), [V12.B16, V13.B16, V14.B16, V15.B16] + VLD1.P 64(R1), [V16.B16, V17.B16, V18.B16, V19.B16] + VLD1.P 64(R2), [V20.B16, V21.B16, V22.B16, V23.B16] + VLD1.P 64(R1), [V24.B16, V25.B16, V26.B16, V27.B16] + VLD1.P 64(R2), [V28.B16, V29.B16, V30.B16, V31.B16] + + VEOR V0.B16, V4.B16, V0.B16 + VEOR V1.B16, V5.B16, V1.B16 + VEOR V2.B16, V6.B16, V2.B16 + VEOR V3.B16, V7.B16, V3.B16 + + VEOR V8.B16, V12.B16, V8.B16 + VEOR V9.B16, V13.B16, V9.B16 + VEOR V10.B16, V14.B16, V10.B16 + VEOR V11.B16, V15.B16, V11.B16 + + VEOR V16.B16, V20.B16, V16.B16 + VEOR V17.B16, V21.B16, V17.B16 + VEOR V18.B16, V22.B16, V18.B16 + VEOR V19.B16, V23.B16, V19.B16 + + VEOR V24.B16, V28.B16, V24.B16 + VEOR V25.B16, V29.B16, V25.B16 + VEOR V26.B16, V30.B16, V26.B16 + VEOR V27.B16, V31.B16, V27.B16 + + VST1.P [ V0.B16, V1.B16, V2.B16, V3.B16], 64(R0) + VST1.P [ V8.B16, V9.B16, V10.B16, V11.B16], 64(R0) + VST1.P [V16.B16, V17.B16, V18.B16, V19.B16], 64(R0) + VST1.P [V24.B16, V25.B16, V26.B16, V27.B16], 64(R0) + + SUBS $1, R3, R3 + CBNZ R3, loop + + RET + // func orNEON(dst *byte, a *byte, b *byte, l uint64) TEXT ·orNEON(SB), NOSPLIT, $0-32 MOVD dst+0(FP), R0 diff --git a/and_stubs_amd64.go b/and_stubs_amd64.go index be685eb..fcdb598 100644 --- a/and_stubs_amd64.go +++ b/and_stubs_amd64.go @@ -14,6 +14,11 @@ func andAVX2(dst *byte, a *byte, b *byte, l uint64) //go:noescape func orAVX2(dst *byte, a *byte, b *byte, l uint64) +// Sets dst to the bitwise xor of a and b assuming all are 256*l bytes +// +//go:noescape +func xorAVX2(dst *byte, a *byte, b *byte, l uint64) + // Sets dst to the bitwise and of not(a) and b assuming all are 256*l bytes // //go:noescape diff --git a/and_test.go b/and_test.go index 8369a55..cea64c6 100644 --- a/and_test.go +++ b/and_test.go @@ -8,6 +8,12 @@ import ( "testing" ) +func xorNaive(dst, a, b []byte) { + for i := range dst { + dst[i] = a[i] ^ b[i] + } +} + func andNaive(dst, a, b []byte) { for i := range dst { dst[i] = a[i] & b[i] @@ -55,6 +61,18 @@ func TestAnd(t *testing.T) { } } +func TestXor(t *testing.T) { + for i := 0; i < 20; i++ { + size := 1 << i + testAgainst(t, Xor, xorNaive, size) + testAgainst(t, xorGeneric, xorNaive, size) + for j := 0; j < 10; j++ { + testAgainst(t, Xor, xorNaive, size+rand.IntN(100)) + testAgainst(t, xorGeneric, xorNaive, size+rand.IntN(100)) + } + } +} + func TestOr(t *testing.T) { for i := 0; i < 20; i++ { size := 1 << i @@ -151,6 +169,42 @@ func BenchmarkOrNaive(b *testing.B) { } } +func BenchmarkXor(b *testing.B) { + b.StopTimer() + size := 1000000 + a := make([]byte, size) + bb := make([]byte, size) + b.SetBytes(int64(size)) + b.StartTimer() + for i := 0; i < b.N; i++ { + Xor(a, a, bb) + } +} + +func BenchmarkXorGeneric(b *testing.B) { + b.StopTimer() + size := 1000000 + a := make([]byte, size) + bb := make([]byte, size) + b.SetBytes(int64(size)) + b.StartTimer() + for i := 0; i < b.N; i++ { + xorGeneric(a, a, bb) + } +} + +func BenchmarkXorNaive(b *testing.B) { + b.StopTimer() + size := 1000000 + a := make([]byte, size) + bb := make([]byte, size) + b.SetBytes(int64(size)) + b.StartTimer() + for i := 0; i < b.N; i++ { + xorNaive(a, a, bb) + } +} + func BenchmarkAndNot(b *testing.B) { b.StopTimer() size := 1000000 diff --git a/internal/asm/src.go b/internal/asm/src.go index fca8996..e5dc5c4 100644 --- a/internal/asm/src.go +++ b/internal/asm/src.go @@ -11,6 +11,7 @@ func main() { gen("and", VPAND, "Sets dst to the bitwise and of a and b") gen("or", VPOR, "Sets dst to the bitwise or of a and b") + gen("xor", VPXOR, "Sets dst to the bitwise xor of a and b") gen("andNot", VPANDN, "Sets dst to the bitwise and of not(a) and b") genPopcnt() genMemset() diff --git a/lib.go b/lib.go index f0e6d65..3751425 100644 --- a/lib.go +++ b/lib.go @@ -58,6 +58,32 @@ func orGeneric(dst, a, b []byte) { } } +// Writes bitwise xor of a and b to dst. +// +// Panics if len(a) ≠ len(b), or len(dst) ≠ len(a). +func Xor(dst, a, b []byte) { + if len(a) != len(b) || len(b) != len(dst) { + panic("lengths of a, b and dst must be equal") + } + + xor(dst, a, b) +} + +func xorGeneric(dst, a, b []byte) { + i := 0 + + for ; i <= len(a)-8; i += 8 { + binary.LittleEndian.PutUint64( + dst[i:], + binary.LittleEndian.Uint64(a[i:])^binary.LittleEndian.Uint64(b[i:]), + ) + } + + for ; i < len(a); i++ { + dst[i] = a[i] ^ b[i] + } +} + // Writes bitwise and of not(a) and b to dst. // // Panics if len(a) ≠ len(b), or len(dst) ≠ len(a).