Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vector512 Support for Enumerable<int>.Min/Max #93369

Merged
merged 9 commits into from
Oct 16, 2023
1 change: 1 addition & 0 deletions src/libraries/System.Linq/src/System/Linq/Max.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ public static partial class Enumerable
public static bool Compare(T left, T right) => left > right;
public static Vector128<T> Compare(Vector128<T> left, Vector128<T> right) => Vector128.Max(left, right);
public static Vector256<T> Compare(Vector256<T> left, Vector256<T> right) => Vector256.Max(left, right);
public static Vector512<T> Compare(Vector512<T> left, Vector512<T> right) => Vector512.Max(left, right);
}

public static int? Max(this IEnumerable<int?> source) => MaxInteger(source);
Expand Down
27 changes: 26 additions & 1 deletion src/libraries/System.Linq/src/System/Linq/MaxMin.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ private interface IMinMaxCalc<T> where T : struct, IBinaryInteger<T>
public static abstract bool Compare(T left, T right);
public static abstract Vector128<T> Compare(Vector128<T> left, Vector128<T> right);
public static abstract Vector256<T> Compare(Vector256<T> left, Vector256<T> right);
public static abstract Vector512<T> Compare(Vector512<T> left, Vector512<T> right);
}

private static T MinMaxInteger<T, TMinMax>(this IEnumerable<T> source)
Expand Down Expand Up @@ -66,7 +67,7 @@ private static T MinMaxInteger<T, TMinMax>(this IEnumerable<T> source)
}
}
}
else
else if (!Vector512.IsHardwareAccelerated || span.Length < Vector512<T>.Count)
{
ref T current = ref MemoryMarshal.GetReference(span);
ref T lastVectorStart = ref Unsafe.Add(ref current, span.Length - Vector256<T>.Count);
Expand All @@ -90,6 +91,30 @@ private static T MinMaxInteger<T, TMinMax>(this IEnumerable<T> source)
}
}
}
else
{
ref T current = ref MemoryMarshal.GetReference(span);
ref T lastVectorStart = ref Unsafe.Add(ref current, span.Length - Vector512<T>.Count);

Vector512<T> best = Vector512.LoadUnsafe(ref current);
current = ref Unsafe.Add(ref current, Vector512<T>.Count);

while (Unsafe.IsAddressLessThan(ref current, ref lastVectorStart))
{
best = TMinMax.Compare(best, Vector512.LoadUnsafe(ref current));
current = ref Unsafe.Add(ref current, Vector512<T>.Count);
}
best = TMinMax.Compare(best, Vector512.LoadUnsafe(ref lastVectorStart));

value = best[0];
for (int i = 1; i < Vector512<T>.Count; i++)
{
if (TMinMax.Compare(best[i], value))
{
value = best[i];
}
}
Spacefish marked this conversation as resolved.
Show resolved Hide resolved
}
}
else
{
Expand Down
1 change: 1 addition & 0 deletions src/libraries/System.Linq/src/System/Linq/Min.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ public static partial class Enumerable
public static bool Compare(T left, T right) => left < right;
public static Vector128<T> Compare(Vector128<T> left, Vector128<T> right) => Vector128.Min(left, right);
public static Vector256<T> Compare(Vector256<T> left, Vector256<T> right) => Vector256.Min(left, right);
public static Vector512<T> Compare(Vector512<T> left, Vector512<T> right) => Vector512.Min(left, right);
}

public static int? Min(this IEnumerable<int?> source) => MinInteger(source);
Expand Down
9 changes: 6 additions & 3 deletions src/libraries/System.Linq/tests/MaxTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,16 @@ public class MaxTests : EnumerableTests
{
public static IEnumerable<object[]> Max_AllTypes_TestData()
{
for (int length = 2; length < 33; length++)
for (int length = 2; length < 65; length++)
{
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (byte)i)), (byte)(length + length - 1) };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (byte)i).ToArray()), (byte)(length + length - 1) };

yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (sbyte)i)), (sbyte)(length + length - 1) };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (sbyte)i).ToArray()), (sbyte)(length + length - 1) };
// Unit Tests does +T.One so we should generate data up to one value below sbyte.MaxValue
if ((length + length) < sbyte.MaxValue) {
eiriktsarpalis marked this conversation as resolved.
Show resolved Hide resolved
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (sbyte)i)), (sbyte)(length + length - 1) };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (sbyte)i).ToArray()), (sbyte)(length + length - 1) };
}

yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (ushort)i)), (ushort)(length + length - 1) };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (ushort)i).ToArray()), (ushort)(length + length - 1) };
Expand Down
9 changes: 6 additions & 3 deletions src/libraries/System.Linq/tests/MinTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,16 @@ public class MinTests : EnumerableTests
{
public static IEnumerable<object[]> Min_AllTypes_TestData()
{
for (int length = 2; length < 33; length++)
for (int length = 2; length < 65; length++)
{
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (byte)i)), (byte)length };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (byte)i).ToArray()), (byte)length };

yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (sbyte)i)), (sbyte)length };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (sbyte)i).ToArray()), (sbyte)length };
// Unit Tests does +T.One so we should generate data up to one value below sbyte.MaxValue, otherwise the type overflows
if ((length + length) < sbyte.MaxValue) {
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (sbyte)i)), (sbyte)length };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (sbyte)i).ToArray()), (sbyte)length };
}

yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (ushort)i)), (ushort)length };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (ushort)i).ToArray()), (ushort)length };
Expand Down