Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Light up BitArray APIs with Vector512 code path #91903

Merged
merged 13 commits into from
Sep 21, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 95 additions & 33 deletions src/libraries/System.Collections/src/System/Collections/BitArray.cs
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,6 @@ public BitArray(byte[] bytes)
_version = 0;
}

private const uint Vector128ByteCount = 16;
private const uint Vector128IntCount = 4;
private const uint Vector256ByteCount = 32;
private const uint Vector256IntCount = 8;
public unsafe BitArray(bool[] values)
{
ArgumentNullException.ThrowIfNull(values);
Expand All @@ -138,10 +134,21 @@ public unsafe BitArray(bool[] values)
// Instead, we compare with zeroes (== false), then negate the result to ensure compatibility.

ref byte value = ref Unsafe.As<bool, byte>(ref MemoryMarshal.GetArrayDataReference<bool>(values));
if (Vector512.IsHardwareAccelerated)
{
for (; i <= (uint)values.Length - Vector512<byte>.Count; i += (uint)Vector512<byte>.Count)
{
Vector512<byte> vector = Vector512.LoadUnsafe(ref value, i);
Vector512<byte> isFalse = Vector512.Equals(vector, Vector512<byte>.Zero);

if (Vector256.IsHardwareAccelerated)
ulong result = isFalse.ExtractMostSignificantBits();
m_array[i / 32u] = (int)(~result & 0x00000000FFFFFFFF);
m_array[(i / 32u) + 1] = (int)((~result >> 32) & 0x00000000FFFFFFFF);
}
}
else if (Vector256.IsHardwareAccelerated)
{
for (; (i + Vector256ByteCount) <= (uint)values.Length; i += Vector256ByteCount)
for (; i <= (uint)values.Length - Vector256<byte>.Count; i += (uint)Vector256<byte>.Count)
{
Vector256<byte> vector = Vector256.LoadUnsafe(ref value, i);
Vector256<byte> isFalse = Vector256.Equals(vector, Vector256<byte>.Zero);
Expand All @@ -152,13 +159,13 @@ public unsafe BitArray(bool[] values)
}
else if (Vector128.IsHardwareAccelerated)
{
for (; (i + Vector128ByteCount * 2u) <= (uint)values.Length; i += Vector128ByteCount * 2u)
for (; i <= (uint)values.Length - Vector128<byte>.Count * 2u; i += (uint)Vector128<byte>.Count * 2u)
{
Vector128<byte> lowerVector = Vector128.LoadUnsafe(ref value, i);
Vector128<byte> lowerIsFalse = Vector128.Equals(lowerVector, Vector128<byte>.Zero);
uint lowerResult = lowerIsFalse.ExtractMostSignificantBits();

Vector128<byte> upperVector = Vector128.LoadUnsafe(ref value, i + Vector128ByteCount);
Vector128<byte> upperVector = Vector128.LoadUnsafe(ref value, i + (uint)Vector128<byte>.Count);
Vector128<byte> upperIsFalse = Vector128.Equals(upperVector, Vector128<byte>.Zero);
uint upperResult = upperIsFalse.ExtractMostSignificantBits();

Expand Down Expand Up @@ -339,18 +346,25 @@ public unsafe BitArray And(BitArray value)

ref int left = ref MemoryMarshal.GetArrayDataReference<int>(thisArray);
ref int right = ref MemoryMarshal.GetArrayDataReference<int>(valueArray);

if (Vector256.IsHardwareAccelerated)
if (Vector512.IsHardwareAccelerated && (uint)count >= Vector512<int>.Count)
{
for (; i < (uint)count - (Vector256IntCount - 1u); i += Vector256IntCount)
for (; i < (uint)count - (Vector512<int>.Count - 1u); i += (uint)Vector512<int>.Count)
{
Vector512<int> result = Vector512.LoadUnsafe(ref left, i) & Vector512.LoadUnsafe(ref right, i);
result.StoreUnsafe(ref left, i);
}
}
else if (Vector256.IsHardwareAccelerated && (uint)count >= Vector256<int>.Count)
{
for (; i < (uint)count - (Vector256<int>.Count - 1u); i += (uint)Vector256<int>.Count)
{
Vector256<int> result = Vector256.LoadUnsafe(ref left, i) & Vector256.LoadUnsafe(ref right, i);
result.StoreUnsafe(ref left, i);
}
}
else if (Vector128.IsHardwareAccelerated)
else if (Vector128.IsHardwareAccelerated && (uint)count >= Vector128<int>.Count)
{
for (; i < (uint)count - (Vector128IntCount - 1u); i += Vector128IntCount)
for (; i < (uint)count - (Vector128<int>.Count - 1u); i += (uint)Vector128<int>.Count)
{
Vector128<int> result = Vector128.LoadUnsafe(ref left, i) & Vector128.LoadUnsafe(ref right, i);
result.StoreUnsafe(ref left, i);
Expand Down Expand Up @@ -405,18 +419,25 @@ public unsafe BitArray Or(BitArray value)

ref int left = ref MemoryMarshal.GetArrayDataReference<int>(thisArray);
ref int right = ref MemoryMarshal.GetArrayDataReference<int>(valueArray);

if (Vector256.IsHardwareAccelerated)
if (Vector512.IsHardwareAccelerated && (uint)count >= Vector512<int>.Count)
{
for (; i < (uint)count - (Vector512<int>.Count - 1u); i += (uint)Vector512<int>.Count)
{
Vector512<int> result = Vector512.LoadUnsafe(ref left, i) | Vector512.LoadUnsafe(ref right, i);
result.StoreUnsafe(ref left, i);
}
}
else if (Vector256.IsHardwareAccelerated && (uint)count >= Vector256<int>.Count)
{
for (; i < (uint)count - (Vector256IntCount - 1u); i += Vector256IntCount)
for (; i < (uint)count - (Vector256<int>.Count - 1u); i += (uint)Vector256<int>.Count)
{
Vector256<int> result = Vector256.LoadUnsafe(ref left, i) | Vector256.LoadUnsafe(ref right, i);
result.StoreUnsafe(ref left, i);
}
}
else if (Vector128.IsHardwareAccelerated)
else if (Vector128.IsHardwareAccelerated && (uint)count >= Vector128<int>.Count)
{
for (; i < (uint)count - (Vector128IntCount - 1u); i += Vector128IntCount)
for (; i < (uint)count - (Vector128<int>.Count - 1u); i += (uint)Vector128<int>.Count)
{
Vector128<int> result = Vector128.LoadUnsafe(ref left, i) | Vector128.LoadUnsafe(ref right, i);
result.StoreUnsafe(ref left, i);
Expand Down Expand Up @@ -472,17 +493,25 @@ public unsafe BitArray Xor(BitArray value)
ref int left = ref MemoryMarshal.GetArrayDataReference<int>(thisArray);
ref int right = ref MemoryMarshal.GetArrayDataReference<int>(valueArray);

if (Vector256.IsHardwareAccelerated)
if (Vector512.IsHardwareAccelerated && (uint)count >= Vector512<int>.Count)
{
for (; i < (uint)count - (Vector512<int>.Count - 1u); i += (uint)Vector512<int>.Count)
{
Vector512<int> result = Vector512.LoadUnsafe(ref left, i) ^ Vector512.LoadUnsafe(ref right, i);
result.StoreUnsafe(ref left, i);
}
}
else if (Vector256.IsHardwareAccelerated && (uint)count >= Vector256<int>.Count)
{
for (; i < (uint)count - (Vector256IntCount - 1u); i += Vector256IntCount)
for (; i < (uint)count - (Vector256<int>.Count - 1u); i += (uint)Vector256<int>.Count)
{
Vector256<int> result = Vector256.LoadUnsafe(ref left, i) ^ Vector256.LoadUnsafe(ref right, i);
result.StoreUnsafe(ref left, i);
}
}
else if (Vector128.IsHardwareAccelerated)
else if (Vector128.IsHardwareAccelerated && (uint)count >= Vector128<int>.Count)
{
for (; i < (uint)count - (Vector128IntCount - 1u); i += Vector128IntCount)
for (; i < (uint)count - (Vector128<int>.Count - 1u); i += (uint)Vector128<int>.Count)
{
Vector128<int> result = Vector128.LoadUnsafe(ref left, i) ^ Vector128.LoadUnsafe(ref right, i);
result.StoreUnsafe(ref left, i);
Expand Down Expand Up @@ -529,18 +558,25 @@ public unsafe BitArray Not()
uint i = 0;

ref int value = ref MemoryMarshal.GetArrayDataReference<int>(thisArray);

if (Vector256.IsHardwareAccelerated)
if (Vector512.IsHardwareAccelerated && (uint)count >= Vector512<int>.Count)
{
for (; i < (uint)count - (Vector256IntCount - 1u); i += Vector256IntCount)
for (; i < (uint)count - (Vector512<int>.Count - 1u); i += (uint)Vector512<int>.Count)
{
Vector512<int> result = ~Vector512.LoadUnsafe(ref value, i);
result.StoreUnsafe(ref value, i);
}
}
else if (Vector256.IsHardwareAccelerated && (uint)count >= Vector256<int>.Count)
{
for (; i < (uint)count - (Vector256<int>.Count - 1u); i += (uint)Vector256<int>.Count)
{
Vector256<int> result = ~Vector256.LoadUnsafe(ref value, i);
result.StoreUnsafe(ref value, i);
}
}
else if (Vector128.IsHardwareAccelerated)
else if (Vector128.IsHardwareAccelerated && (uint)count >= Vector128<int>.Count)
{
for (; i < (uint)count - (Vector128IntCount - 1u); i += Vector128IntCount)
for (; i < (uint)count - (Vector128<int>.Count - 1u); i += (uint)Vector128<int>.Count)
{
Vector128<int> result = ~Vector128.LoadUnsafe(ref value, i);
result.StoreUnsafe(ref value, i);
Expand Down Expand Up @@ -797,21 +833,47 @@ public unsafe void CopyTo(Array array, int index)
if (m_length < BitsPerInt32)
goto LessThan32;

// The mask used when shuffling a single int into Vector128/256.
// The mask used when shuffling a single int into Vector128/256/512.
// On little endian machines, the lower 8 bits of int belong in the first byte, next lower 8 in the second and so on.
// We replicate each source byte across the eight destination bytes that take their bit from it, so that we can mask out only the relevant bit later.
Vector128<byte> lowerShuffleMask_CopyToBoolArray = Vector128.Create(0, 0x01010101_01010101).AsByte();
Vector128<byte> upperShuffleMask_CopyToBoolArray = Vector128.Create(0x02020202_02020202, 0x03030303_03030303).AsByte();

if (Avx2.IsSupported)
if (Avx512F.IsSupported && (uint)m_length >= Vector512<byte>.Count)
{
Vector256<byte> upperShuffleMask_CopyToBoolArray256 = Vector256.Create(0x04040404_04040404, 0x05050505_05050505,
0x06060606_06060606, 0x07070707_07070707).AsByte();
Vector256<byte> lowerShuffleMask_CopyToBoolArray256 = Vector256.Create(lowerShuffleMask_CopyToBoolArray, upperShuffleMask_CopyToBoolArray);
Vector512<byte> shuffleMask = Vector512.Create(lowerShuffleMask_CopyToBoolArray256, upperShuffleMask_CopyToBoolArray256);
Vector512<byte> bitMask = Vector512.Create(0x80402010_08040201).AsByte();
Vector512<byte> ones = Vector512.Create((byte)1);

fixed (bool* destination = &boolArray[index])
{
for (; (i + Vector512<byte>.Count) <= (uint)m_length; i += (uint)Vector512<byte>.Count)
{
ulong bits = (ulong)(uint)m_array[i / (uint)BitsPerInt32] + ((ulong)m_array[(i / (uint)BitsPerInt32) + 1] << BitsPerInt32);
Vector512<ulong> scalar = Vector512.Create(bits);
Vector512<byte> shuffled = Avx512BW.Shuffle(scalar.AsByte(), shuffleMask);
Vector512<byte> extracted = Avx512F.And(shuffled, bitMask);

// After masking, each extracted byte is either 0 or has a single bit set (1, 2, 4, ..., 128), so we normalise the value to either 0 or 1
// to ensure compatibility with "C# bool" (0 for false, 1 for true, rest undefined)
Vector512<byte> normalized = Avx512BW.Min(extracted, ones);
Avx512F.Store((byte*)destination + i, normalized);
}
}
}
else if (Avx2.IsSupported && (uint)m_length >= Vector256<byte>.Count)
{
Vector256<byte> shuffleMask = Vector256.Create(lowerShuffleMask_CopyToBoolArray, upperShuffleMask_CopyToBoolArray);
Vector256<byte> bitMask = Vector256.Create(0x80402010_08040201).AsByte();
//Internal.Console.WriteLine(bitMask);
Vector256<byte> ones = Vector256.Create((byte)1);

fixed (bool* destination = &boolArray[index])
{
for (; (i + Vector256ByteCount) <= (uint)m_length; i += Vector256ByteCount)
for (; (i + Vector256<byte>.Count) <= (uint)m_length; i += (uint)Vector256<byte>.Count)
{
int bits = m_array[i / (uint)BitsPerInt32];
Vector256<int> scalar = Vector256.Create(bits);
Expand All @@ -825,7 +887,7 @@ public unsafe void CopyTo(Array array, int index)
}
}
}
else if (Ssse3.IsSupported)
else if (Ssse3.IsSupported && ((uint)m_length >= Vector512<byte>.Count * 2u))
{
Vector128<byte> lowerShuffleMask = lowerShuffleMask_CopyToBoolArray;
Vector128<byte> upperShuffleMask = upperShuffleMask_CopyToBoolArray;
Expand All @@ -836,7 +898,7 @@ public unsafe void CopyTo(Array array, int index)

fixed (bool* destination = &boolArray[index])
{
for (; (i + Vector128ByteCount * 2u) <= (uint)m_length; i += Vector128ByteCount * 2u)
for (; (i + Vector128<byte>.Count * 2u) <= (uint)m_length; i += (uint)Vector128<byte>.Count * 2u)
{
int bits = m_array[i / (uint)BitsPerInt32];
Vector128<int> scalar = Vector128.CreateScalarUnsafe(bits);
Expand All @@ -862,7 +924,7 @@ public unsafe void CopyTo(Array array, int index)

fixed (bool* destination = &boolArray[index])
{
for (; (i + Vector128ByteCount * 2u) <= (uint)m_length; i += Vector128ByteCount * 2u)
for (; (i + Vector128<byte>.Count * 2u) <= (uint)m_length; i += (uint)Vector128<byte>.Count * 2u)
{
int bits = m_array[i / (uint)BitsPerInt32];
// Same logic as SSSE3 path, except we do not have Shuffle instruction.
Expand Down