Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize Ordinal.EqualsIgnoreCase_Vector with AVX2 and AVX512 #93116

Merged
merged 5 commits into from
Oct 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/libraries/Common/src/System/HexConverter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ public static bool TryDecodeFromUtf16_Vector128(ReadOnlySpan<char> chars, Span<b
// or some byte greater than 0x0f.
Vector128<byte> nibbles = Vector128.Min(t2 - Vector128.Create((byte)0xF0), t4);
// Any high bit is a sign that input is not a valid hex data
if (!Utf16Utility.AllCharsInVector128AreAscii(vec1 | vec2) ||
if (!Utf16Utility.AllCharsInVectorAreAscii(vec1 | vec2) ||
Vector128.AddSaturate(nibbles, Vector128.Create((byte)(127 - 15))).ExtractMostSignificantBits() != 0)
{
// Input is either non-ASCII or invalid hex data
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,39 +78,73 @@ internal static int CompareStringIgnoreCaseNonAscii(ref char strA, int lengthA,
return OrdinalCasing.CompareStringIgnoreCase(ref strA, lengthA, ref strB, lengthB);
}

private static bool EqualsIgnoreCase_Vector128(ref char charA, ref char charB, int length)
private static bool EqualsIgnoreCase_Vector<TVector>(ref char charA, ref char charB, int length)
EgorBo marked this conversation as resolved.
Show resolved Hide resolved
where TVector : struct, ISimdVector<TVector, ushort>
{
Debug.Assert(length >= Vector128<ushort>.Count);
Debug.Assert(Vector128.IsHardwareAccelerated);
Debug.Assert(length >= TVector.Count);

nuint lengthU = (nuint)length;
nuint lengthToExamine = lengthU - (nuint)Vector128<ushort>.Count;
nuint lengthToExamine = lengthU - (nuint)TVector.Count;
nuint i = 0;
Vector128<ushort> vec1;
Vector128<ushort> vec2;
TVector vec1;
TVector vec2;
TVector loweringMask = TVector.Create(0x20);
TVector vecA = TVector.Create('a');
TVector vecZMinusA = TVector.Create('z' - 'a');
do
{
vec1 = Vector128.LoadUnsafe(ref charA, i);
vec2 = Vector128.LoadUnsafe(ref charB, i);
vec1 = TVector.LoadUnsafe(ref Unsafe.As<char, ushort>(ref charA), i);
vec2 = TVector.LoadUnsafe(ref Unsafe.As<char, ushort>(ref charB), i);
MihaZupan marked this conversation as resolved.
Show resolved Hide resolved

if (!Utf16Utility.AllCharsInVector128AreAscii(vec1 | vec2))
if (!Utf16Utility.AllCharsInVectorAreAscii(vec1 | vec2))
{
goto NON_ASCII;
}

if (!Utf16Utility.Vector128OrdinalIgnoreCaseAscii(vec1, vec2))
TVector notEquals = ~TVector.Equals(vec1, vec2);
if (!notEquals.Equals(TVector.Zero))
{
return false;
}
// not exact match

i += (nuint)Vector128<ushort>.Count;
vec1 |= loweringMask;
vec2 |= loweringMask;
if (TVector.GreaterThanAny((vec1 - vecA) & notEquals, vecZMinusA) || !vec1.Equals(vec2))
{
return false; // first input isn't in [A-Za-z], and not exact match of lowered
}
}
i += (nuint)TVector.Count;
EgorBo marked this conversation as resolved.
Show resolved Hide resolved
} while (i <= lengthToExamine);

// Use scalar path for trailing elements
return i == lengthU || EqualsIgnoreCase(ref Unsafe.Add(ref charA, i), ref Unsafe.Add(ref charB, i), (int)(lengthU - i));
// Handle trailing elements
if (i != lengthU)
{
i = lengthU - (nuint)TVector.Count;
vec1 = TVector.LoadUnsafe(ref Unsafe.As<char, ushort>(ref charA), i);
vec2 = TVector.LoadUnsafe(ref Unsafe.As<char, ushort>(ref charB), i);

if (!Utf16Utility.AllCharsInVectorAreAscii(vec1 | vec2))
{
goto NON_ASCII;
}

TVector notEquals = ~TVector.Equals(vec1, vec2);
if (!notEquals.Equals(TVector.Zero))
{
// not exact match

vec1 |= loweringMask;
vec2 |= loweringMask;
if (TVector.GreaterThanAny((vec1 - vecA) & notEquals, vecZMinusA) || !vec1.Equals(vec2))
{
return false; // first input isn't in [A-Za-z], and not exact match of lowered
}
}
}
return true;

NON_ASCII:
if (Utf16Utility.AllCharsInVector128AreAscii(vec1) || Utf16Utility.AllCharsInVector128AreAscii(vec2))
if (Utf16Utility.AllCharsInVectorAreAscii(vec1) || Utf16Utility.AllCharsInVectorAreAscii(vec2))
{
// No need to use the fallback if one of the inputs is full-ASCII
return false;
Expand All @@ -129,8 +163,15 @@ internal static bool EqualsIgnoreCase(ref char charA, ref char charB, int length
{
return EqualsIgnoreCase_Scalar(ref charA, ref charB, length);
}

return EqualsIgnoreCase_Vector128(ref charA, ref charB, length);
if (Vector512.IsHardwareAccelerated && length >= Vector512<ushort>.Count)
{
return EqualsIgnoreCase_Vector<Vector512<ushort>>(ref charA, ref charB, length);
}
if (Vector256.IsHardwareAccelerated && length >= Vector256<ushort>.Count)
{
return EqualsIgnoreCase_Vector<Vector256<ushort>>(ref charA, ref charB, length);
}
return EqualsIgnoreCase_Vector<Vector128<ushort>>(ref charA, ref charB, length);
}

internal static bool EqualsIgnoreCase_Scalar(ref char charA, ref char charB, int length)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -278,41 +278,13 @@ internal static bool UInt64OrdinalIgnoreCaseAscii(ulong valueA, ulong valueB)
}

/// <summary>
/// Returns true iff the Vector128 represents 8 ASCII UTF-16 characters in machine endianness.
/// Returns true iff the TVector represents ASCII UTF-16 characters in machine endianness.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static bool AllCharsInVector128AreAscii(Vector128<ushort> vec)
internal static bool AllCharsInVectorAreAscii<TVector>(TVector vec)
where TVector : struct, ISimdVector<TVector, ushort>
{
return (vec & Vector128.Create(unchecked((ushort)~0x007F))) == Vector128<ushort>.Zero;
}

/// <summary>
/// Given two Vector128 that represent 8 ASCII UTF-16 characters each, returns true iff
/// the two inputs are equal using an ordinal case-insensitive comparison.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static bool Vector128OrdinalIgnoreCaseAscii(Vector128<ushort> vec1, Vector128<ushort> vec2)
{
// ASSUMPTION: Caller has validated that input values are ASCII.

// the 0x80 bit of each word of 'lowerIndicator' will be set iff the word has value >= 'A'
Vector128<sbyte> lowIndicator1 = Vector128.Create((sbyte)(0x80 - 'A')) + vec1.AsSByte();
Vector128<sbyte> lowIndicator2 = Vector128.Create((sbyte)(0x80 - 'A')) + vec2.AsSByte();

// the 0x80 bit of each word of 'combinedIndicator' will be set iff the word has value >= 'A' and <= 'Z'
Vector128<sbyte> combIndicator1 =
Vector128.LessThan(Vector128.Create(unchecked((sbyte)(('Z' - 'A') - 0x80))), lowIndicator1);
Vector128<sbyte> combIndicator2 =
Vector128.LessThan(Vector128.Create(unchecked((sbyte)(('Z' - 'A') - 0x80))), lowIndicator2);

// Convert both vectors to lower case by adding 0x20 bit for all [A-Z][a-z] characters
Vector128<sbyte> lcVec1 =
Vector128.AndNot(Vector128.Create((sbyte)0x20), combIndicator1) + vec1.AsSByte();
Vector128<sbyte> lcVec2 =
Vector128.AndNot(Vector128.Create((sbyte)0x20), combIndicator2) + vec2.AsSByte();

// Compare two lowercased vectors
return (lcVec1 ^ lcVec2) == Vector128<sbyte>.Zero;
return (vec & TVector.Create(unchecked((ushort)~0x007F))).Equals(TVector.Zero);
}
}
}
Loading