Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add vectorized paths for Span<T>.Reverse #64412

Merged
merged 14 commits into from
Apr 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 3 additions & 11 deletions src/libraries/System.Private.CoreLib/src/System/Array.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1722,7 +1722,8 @@ public static void Reverse<T>(T[] array)
{
if (array == null)
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.array);
Reverse(array, 0, array.Length);
if (array.Length > 1)
SpanHelpers.Reverse(ref MemoryMarshal.GetArrayDataReference(array), (nuint)array.Length);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

before we merge this change, we need to update our tests, as some of them are using Array.Reverse to get the expected output for Span.Reverse:

https://github.com/dotnet/runtime/blob/c663fa429a88d3089b740a89181c2b6f033b2839/src/libraries/System.Memory/tests/Span/Reverse.cs#L134-L137

I am going to send a PR for that in a few minutes.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done: #68493

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, I didn't notice that. Sounds like a good idea to wait on this PR until the tests are updated, just to be sure.

}

public static void Reverse<T>(T[] array, int index, int length)
Expand All @@ -1739,16 +1740,7 @@ public static void Reverse<T>(T[] array, int index, int length)
if (length <= 1)
return;

ref T first = ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(array), index);
ref T last = ref Unsafe.Add(ref Unsafe.Add(ref first, length), -1);
do
{
T temp = first;
first = last;
last = temp;
first = ref Unsafe.Add(ref first, 1);
last = ref Unsafe.Add(ref last, -1);
} while (Unsafe.IsAddressLessThan(ref first, ref last));
SpanHelpers.Reverse(ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(array), index), (nuint)length);
}

// Sorts the elements of an array. The sort compares the elements to each
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;

namespace System
{
Expand Down Expand Up @@ -1543,21 +1545,10 @@ ref MemoryMarshal.GetReference(value),
/// </summary>
public static void Reverse<T>(this Span<T> span)
{
if (span.Length <= 1)
if (span.Length > 1)
{
return;
SpanHelpers.Reverse(ref MemoryMarshal.GetReference(span), (nuint)span.Length);
}

ref T first = ref MemoryMarshal.GetReference(span);
ref T last = ref Unsafe.Add(ref Unsafe.Add(ref first, span.Length), -1);
do
{
T temp = first;
first = last;
last = temp;
first = ref Unsafe.Add(ref first, 1);
last = ref Unsafe.Add(ref last, -1);
} while (Unsafe.IsAddressLessThan(ref first, ref last));
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2238,5 +2238,96 @@ private static uint FindFirstMatchedLane(Vector128<byte> compareResult)
// Find the first lane that is set inside compareResult.
return (uint)BitOperations.TrailingZeroCount(selectedLanes) >> 2;
}

public static void Reverse(ref byte buf, nuint length)
{
if (Avx2.IsSupported && (nuint)Vector256<byte>.Count * 2 <= length)
{
Vector256<byte> reverseMask = Vector256.Create(
(byte)15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, // first 128-bit lane
15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0); // second 128-bit lane
nuint numElements = (nuint)Vector256<byte>.Count;
nuint numIters = (length / numElements) / 2;
for (nuint i = 0; i < numIters; i++)
{
nuint firstOffset = i * numElements;
nuint lastOffset = length - ((1 + i) * numElements);

// Load in values from beginning and end of the array.
Vector256<byte> tempFirst = Vector256.LoadUnsafe(ref buf, firstOffset);
Vector256<byte> tempLast = Vector256.LoadUnsafe(ref buf, lastOffset);

// Avx2 operates on two 128-bit lanes rather than the full 256-bit vector.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thank you for adding a great explanation here 👍

// Perform a shuffle to reverse each 128-bit lane, then permute to finish reversing the vector:
// +-------------------------------------------------------------------------------+
// | A1 | B1 | C1 | D1 | E1 | F1 | G1 | H1 | I1 | J1 | K1 | L1 | M1 | N1 | O1 | P1 |
// +-------------------------------------------------------------------------------+
// | A2 | B2 | C2 | D2 | E2 | F2 | G2 | H2 | I2 | J2 | K2 | L2 | M2 | N2 | O2 | P2 |
// +-------------------------------------------------------------------------------+
// Shuffle --->
// +-------------------------------------------------------------------------------+
// | P1 | O1 | N1 | M1 | L1 | K1 | J1 | I1 | H1 | G1 | F1 | E1 | D1 | C1 | B1 | A1 |
// +-------------------------------------------------------------------------------+
// | P2 | O2 | N2 | M2 | L2 | K2 | J2 | I2 | H2 | G2 | F2 | E2 | D2 | C2 | B2 | A2 |
// +-------------------------------------------------------------------------------+
// Permute --->
// +-------------------------------------------------------------------------------+
// | P2 | O2 | N2 | M2 | L2 | K2 | J2 | I2 | H2 | G2 | F2 | E2 | D2 | C2 | B2 | A2 |
// +-------------------------------------------------------------------------------+
// | P1 | O1 | N1 | M1 | L1 | K1 | J1 | I1 | H1 | G1 | F1 | E1 | D1 | C1 | B1 | A1 |
// +-------------------------------------------------------------------------------+
tempFirst = Avx2.Shuffle(tempFirst, reverseMask);
tempFirst = Avx2.Permute2x128(tempFirst, tempFirst, 0b00_01);
tempLast = Avx2.Shuffle(tempLast, reverseMask);
tempLast = Avx2.Permute2x128(tempLast, tempLast, 0b00_01);

// Store the reversed vectors
tempLast.StoreUnsafe(ref buf, firstOffset);
tempFirst.StoreUnsafe(ref buf, lastOffset);
}
buf = ref Unsafe.Add(ref buf, numIters * numElements);
length -= numIters * numElements * 2;
}
else if (Sse2.IsSupported && (nuint)Vector128<byte>.Count * 2 <= length)
{
Vector128<byte> reverseMask = Vector128.Create((byte)15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
nuint numElements = (nuint)Vector128<byte>.Count;
nuint numIters = (length / numElements) / 2;
for (nuint i = 0; i < numIters; i++)
{
nuint firstOffset = i * numElements;
nuint lastOffset = length - ((1 + i) * numElements);

// Load in values from beginning and end of the array.
Vector128<byte> tempFirst = Vector128.LoadUnsafe(ref buf, firstOffset);
Vector128<byte> tempLast = Vector128.LoadUnsafe(ref buf, lastOffset);

// Shuffle to reverse each vector:
// +---------------------------------------------------------------+
// | A | B | C | D | E | F | G | H | I | J | K | L | M | N | O | P |
// +---------------------------------------------------------------+
// --->
// +---------------------------------------------------------------+
// | P | O | N | M | L | K | J | I | H | G | F | E | D | C | B | A |
// +---------------------------------------------------------------+
tempFirst = Ssse3.Shuffle(tempFirst, reverseMask);
tempLast = Ssse3.Shuffle(tempLast, reverseMask);

// Store the reversed vectors
tempLast.StoreUnsafe(ref buf, firstOffset);
tempFirst.StoreUnsafe(ref buf, lastOffset);
}
buf = ref Unsafe.Add(ref buf, numIters * numElements);
length -= numIters * numElements * 2;
}

// Store any remaining values one-by-one
for (nuint i = 0; i < (length / 2); i++)
{
ref byte first = ref Unsafe.Add(ref buf, i);
ref byte last = ref Unsafe.Add(ref buf, length - 1 - i);
(last, first) = (first, last);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2015,5 +2015,93 @@ private static int FindFirstMatchedLane(Vector128<ushort> compareResult)

return BitOperations.TrailingZeroCount(selectedLanes) >> 3;
}

public static void Reverse(ref char buf, nuint length)
alexcovington marked this conversation as resolved.
Show resolved Hide resolved
{
ref byte bufByte = ref Unsafe.As<char, byte>(ref buf);
nuint byteLength = length * sizeof(char);
if (Avx2.IsSupported && (nuint)Vector256<short>.Count * 2 <= length)
{
Vector256<byte> reverseMask = Vector256.Create(
(byte)14, 15, 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1, // first 128-bit lane
14, 15, 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1); // second 128-bit lane
nuint numElements = (nuint)Vector256<byte>.Count;
nuint numIters = (byteLength / numElements) / 2;
for (nuint i = 0; i < numIters; i++)
{
nuint firstOffset = i * numElements;
nuint lastOffset = byteLength - ((1 + i) * numElements);

// Load in values from beginning and end of the array.
Vector256<byte> tempFirst = Vector256.LoadUnsafe(ref bufByte, firstOffset);
Vector256<byte> tempLast = Vector256.LoadUnsafe(ref bufByte, lastOffset);

// Avx2 operates on two 128-bit lanes rather than the full 256-bit vector.
// Perform a shuffle to reverse each 128-bit lane, then permute to finish reversing the vector:
// +---------------------------------------------------------------+
// | A | B | C | D | E | F | G | H | I | J | K | L | M | N | O | P |
// +---------------------------------------------------------------+
// Shuffle --->
// +---------------------------------------------------------------+
// | H | G | F | E | D | C | B | A | P | O | N | M | L | K | J | I |
// +---------------------------------------------------------------+
// Permute --->
// +---------------------------------------------------------------+
// | P | O | N | M | L | K | J | I | H | G | F | E | D | C | B | A |
// +---------------------------------------------------------------+
tempFirst = Avx2.Shuffle(tempFirst, reverseMask);
tempFirst = Avx2.Permute2x128(tempFirst, tempFirst, 0b00_01);
tempLast = Avx2.Shuffle(tempLast, reverseMask);
tempLast = Avx2.Permute2x128(tempLast, tempLast, 0b00_01);

// Store the reversed vectors
tempLast.StoreUnsafe(ref bufByte, firstOffset);
tempFirst.StoreUnsafe(ref bufByte, lastOffset);
}
bufByte = ref Unsafe.Add(ref bufByte, numIters * numElements);
length -= numIters * (nuint)Vector256<short>.Count * 2;
}
else if (Sse2.IsSupported && (nuint)Vector128<short>.Count * 2 <= length)
{
Vector128<byte> reverseMask = Vector128.Create((byte)14, 15, 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1);
nuint numElements = (nuint)Vector128<byte>.Count;
nuint numIters = ((length * sizeof(char)) / numElements) / 2;
for (nuint i = 0; i < numIters; i++)
{
nuint firstOffset = i * numElements;
nuint lastOffset = byteLength - ((1 + i) * numElements);

// Load in values from beginning and end of the array.
Vector128<byte> tempFirst = Vector128.LoadUnsafe(ref bufByte, firstOffset);
Vector128<byte> tempLast = Vector128.LoadUnsafe(ref bufByte, lastOffset);

// Shuffle to reverse each vector:
// +-------------------------------+
// | A | B | C | D | E | F | G | H |
// +-------------------------------+
// --->
// +-------------------------------+
// | H | G | F | E | D | C | B | A |
// +-------------------------------+
tempFirst = Ssse3.Shuffle(tempFirst, reverseMask);
tempLast = Ssse3.Shuffle(tempLast, reverseMask);

// Store the reversed vectors
tempLast.StoreUnsafe(ref bufByte, firstOffset);
tempFirst.StoreUnsafe(ref bufByte, lastOffset);
}
bufByte = ref Unsafe.Add(ref bufByte, numIters * numElements);
length -= numIters * (nuint)Vector128<short>.Count * 2;
}

// Store any remaining values one-by-one
buf = ref Unsafe.As<byte, char>(ref bufByte);
for (nuint i = 0; i < (length / 2); i++)
{
ref char first = ref Unsafe.Add(ref buf, i);
ref char last = ref Unsafe.Add(ref buf, length - 1 - i);
(last, first) = (first, last);
}
}
}
}
Loading