Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve vectorization of IndexOf(chars, StringComparison.OrdinalIgnoreCase) #85437

Merged
merged 3 commits into from
May 1, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 200 additions & 13 deletions src/libraries/System.Private.CoreLib/src/System/Globalization/Ordinal.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics;
using System.Text.Unicode;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;
using System.Text.Unicode;

namespace System.Globalization
{
Expand Down Expand Up @@ -295,7 +297,6 @@ internal static int IndexOfOrdinalIgnoreCase(ReadOnlySpan<char> source, ReadOnly
// A non-linguistic search compares chars directly against one another, so large
// target strings can never be found inside small search spaces. This check also
// handles empty 'source' spans.

return -1;
}

Expand All @@ -309,25 +310,38 @@ internal static int IndexOfOrdinalIgnoreCase(ReadOnlySpan<char> source, ReadOnly
return CompareInfo.NlsIndexOfOrdinalCore(source, value, ignoreCase: true, fromBeginning: true);
}

// If value starts with an ASCII char, we can use a vectorized path
// If value doesn't start with ASCII, fall back to a non-vectorized non-ASCII friendly version.
ref char valueRef = ref MemoryMarshal.GetReference(value);
char valueChar = valueRef;

if (!char.IsAscii(valueChar))
{
// Fallback to a more non-ASCII friendly version
return OrdinalCasing.IndexOf(source, value);
}

// Hoist some expressions from the loop
int valueTailLength = value.Length - 1;
int searchSpaceLength = source.Length - valueTailLength;
int searchSpaceMinusValueTailLength = source.Length - valueTailLength;
ref char searchSpace = ref MemoryMarshal.GetReference(source);
char valueCharU = default;
char valueCharL = default;
nint offset = 0;
bool isLetter = false;

// If the input is long enough and the value ends with ASCII, we can take a special vectorized
// path that compares both the beginning and the end at the same time.
if (Vector128.IsHardwareAccelerated && searchSpaceMinusValueTailLength >= Vector128<ushort>.Count)
{
valueCharU = Unsafe.Add(ref valueRef, valueTailLength);
if (char.IsAscii(valueCharU))
{
goto SearchTwoChars;
}
}

// We're searching for the first character and it's known to be ASCII. If it's not a letter,
// then IgnoreCase doesn't impact what it matches and we just need to do a normal search
// for that single character. If it is a letter, then we need to search for both its upper
// and lower-case variants.
if (char.IsAsciiLetter(valueChar))
{
valueCharU = (char)(valueChar & ~0x20);
Expand All @@ -340,16 +354,16 @@ internal static int IndexOfOrdinalIgnoreCase(ReadOnlySpan<char> source, ReadOnly
// Do a quick search for the first element of "value".
int relativeIndex = isLetter ?
PackedSpanHelpers.PackedIndexOfIsSupported
? PackedSpanHelpers.IndexOfAny(ref Unsafe.Add(ref searchSpace, offset), valueCharU, valueCharL, searchSpaceLength)
: SpanHelpers.IndexOfAnyChar(ref Unsafe.Add(ref searchSpace, offset), valueCharU, valueCharL, searchSpaceLength) :
SpanHelpers.IndexOfChar(ref Unsafe.Add(ref searchSpace, offset), valueChar, searchSpaceLength);
? PackedSpanHelpers.IndexOfAny(ref Unsafe.Add(ref searchSpace, offset), valueCharU, valueCharL, searchSpaceMinusValueTailLength)
: SpanHelpers.IndexOfAnyChar(ref Unsafe.Add(ref searchSpace, offset), valueCharU, valueCharL, searchSpaceMinusValueTailLength) :
SpanHelpers.IndexOfChar(ref Unsafe.Add(ref searchSpace, offset), valueChar, searchSpaceMinusValueTailLength);
if (relativeIndex < 0)
{
break;
}

searchSpaceLength -= relativeIndex;
if (searchSpaceLength <= 0)
searchSpaceMinusValueTailLength -= relativeIndex;
if (searchSpaceMinusValueTailLength <= 0)
{
break;
}
Expand All @@ -364,12 +378,185 @@ ref Unsafe.Add(ref valueRef, 1), valueTailLength))
return (int)offset; // The tail matched. Return a successful find.
}

searchSpaceLength--;
searchSpaceMinusValueTailLength--;
offset++;
}
while (searchSpaceLength > 0);
while (searchSpaceMinusValueTailLength > 0);

return -1;

// Based on SpanHelpers.IndexOf(ref char, int, ref char, int), which was in turn based on
// http://0x80.pl/articles/simd-strfind.html#algorithm-1-generic-simd. This version has additional
// modifications to support case-insensitive searches.
SearchTwoChars:
// Both the first character in value (valueChar) and the last character in value (valueCharU) are ASCII. Get their lowercase variants.
valueChar = (char)(valueChar | 0x20);
valueCharU = (char)(valueCharU | 0x20);

// The search is more efficient if the two characters being searched for are different. As long as they are equal, walk backwards
// from the last character in the search value until we find a character that's different. Since we're dealing with IgnoreCase,
// we compare the lowercase variants, as that's what we'll be comparing against in the main loop.
nint ch1ch2Distance = valueTailLength;
while (valueCharU == valueChar && ch1ch2Distance > 1)
{
char tmp = Unsafe.Add(ref valueRef, ch1ch2Distance - 1);
if (!char.IsAscii(tmp))
{
break;
}
--ch1ch2Distance;
valueCharU = (char)(tmp | 0x20);
}

// Use Vector256 if the input is long enough.
if (Vector256.IsHardwareAccelerated && searchSpaceMinusValueTailLength - Vector256<ushort>.Count >= 0)
{
// Create a vector for each of the lowercase ASCII characters we're searching for.
Vector256<ushort> ch1 = Vector256.Create((ushort)valueChar);
Vector256<ushort> ch2 = Vector256.Create((ushort)valueCharU);

nint searchSpaceMinusValueTailLengthAndVector = searchSpaceMinusValueTailLength - (nint)Vector256<ushort>.Count;
do
{
// Make sure we don't go out of bounds.
Debug.Assert(offset + ch1ch2Distance + Vector256<ushort>.Count <= source.Length);

// Load a vector from the current search space offset and another from the offset plus the distance between the two characters.
// For each, | with 0x20 so that letters are lowercased, then & those together to get a mask. If the mask is all zeros, there
// was no match. If it wasn't, we have to do more work to check for a match.
Vector256<ushort> cmpCh2 = Vector256.Equals(ch2, Vector256.BitwiseOr(Vector256.LoadUnsafe(ref searchSpace, (nuint)(offset + ch1ch2Distance)), Vector256.Create((ushort)0x20)));
Vector256<ushort> cmpCh1 = Vector256.Equals(ch1, Vector256.BitwiseOr(Vector256.LoadUnsafe(ref searchSpace, (nuint)offset), Vector256.Create((ushort)0x20)));
Vector256<byte> cmpAnd = (cmpCh1 & cmpCh2).AsByte();
if (cmpAnd != Vector256<byte>.Zero)
{
goto CandidateFound;
}

LoopFooter:
// No match. Advance to the next vector.
offset += Vector256<ushort>.Count;

// If we've reached the end of the search space, bail.
if (offset == searchSpaceMinusValueTailLength)
{
return -1;
}

// If we're within a vector's length of the end of the search space, adjust the offset
// to point to the last vector so that our next iteration will process it.
if (offset > searchSpaceMinusValueTailLengthAndVector)
{
offset = searchSpaceMinusValueTailLengthAndVector;
}

continue;

CandidateFound:
// Possible matches at the current location. Extract the bits for each element.
// For each set bits, we'll check if it's a match at that location.
uint mask = cmpAnd.ExtractMostSignificantBits();
do
{
// Do a full IgnoreCase equality comparison. SpanHelpers.IndexOf skips comparing the two characters in some cases,
// but we don't actually know that the two characters are equal, since we compared with | 0x20. So we just compare
// the full string always.
int bitPos = BitOperations.TrailingZeroCount(mask);
nint charPos = (nint)((uint)bitPos / 2); // div by 2 (shr) because we work with 2-byte chars
if (EqualsIgnoreCase(ref Unsafe.Add(ref searchSpace, offset + charPos), ref valueRef, value.Length))
{
// Match! Return the index.
return (int)(offset + charPos);
}

// Clear the two lowest set bits in the mask. If there are no more set bits, we're done.
// If any remain, we loop around to do the next comparison.
if (Bmi1.IsSupported)
{
mask = Bmi1.ResetLowestSetBit(Bmi1.ResetLowestSetBit(mask));
}
else
{
mask &= ~(uint)(0b11 << bitPos);
}
} while (mask != 0);
goto LoopFooter;

} while (true);
}
else // 128bit vector path (SSE2 or AdvSimd)
{
// Create a vector for each of the lowercase ASCII characters we're searching for.
Vector128<ushort> ch1 = Vector128.Create((ushort)valueChar);
Vector128<ushort> ch2 = Vector128.Create((ushort)valueCharU);

nint searchSpaceMinusValueTailLengthAndVector = searchSpaceMinusValueTailLength - (nint)Vector128<ushort>.Count;
do
{
// Make sure we don't go out of bounds.
Debug.Assert(offset + ch1ch2Distance + Vector128<ushort>.Count <= source.Length);

// Load a vector from the current search space offset and another from the offset plus the distance between the two characters.
// For each, | with 0x20 so that letters are lowercased, then & those together to get a mask. If the mask is all zeros, there
// was no match. If it wasn't, we have to do more work to check for a match.
Vector128<ushort> cmpCh2 = Vector128.Equals(ch2, Vector128.BitwiseOr(Vector128.LoadUnsafe(ref searchSpace, (nuint)(offset + ch1ch2Distance)), Vector128.Create((ushort)0x20)));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nit: Vector128.BitwiseOr -> |

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just style, right? Happy to change it, just questioning whether it's worth rerunning ci.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

definitely not worth it 🙂

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right its "just style". There is also the general considerations of "methods" vs "operators" (such as precedence and readability) but we're not super consistent today just due to the operators being relatively new.

Vector128<ushort> cmpCh1 = Vector128.Equals(ch1, Vector128.BitwiseOr(Vector128.LoadUnsafe(ref searchSpace, (nuint)offset), Vector128.Create((ushort)0x20)));
Vector128<byte> cmpAnd = (cmpCh1 & cmpCh2).AsByte();
if (cmpAnd != Vector128<byte>.Zero)
{
goto CandidateFound;
}

LoopFooter:
// No match. Advance to the next vector.
offset += Vector128<ushort>.Count;

// If we've reached the end of the search space, bail.
if (offset == searchSpaceMinusValueTailLength)
{
return -1;
}

// If we're within a vector's length of the end of the search space, adjust the offset
// to point to the last vector so that our next iteration will process it.
if (offset > searchSpaceMinusValueTailLengthAndVector)
{
offset = searchSpaceMinusValueTailLengthAndVector;
}

continue;

CandidateFound:
// Possible matches at the current location. Extract the bits for each element.
// For each set bits, we'll check if it's a match at that location.
uint mask = cmpAnd.ExtractMostSignificantBits();
do
{
// Do a full IgnoreCase equality comparison. SpanHelpers.IndexOf skips comparing the two characters in some cases,
// but we don't actually know that the two characters are equal, since we compared with | 0x20. So we just compare
// the full string always.
int bitPos = BitOperations.TrailingZeroCount(mask);
int charPos = (int)((uint)bitPos / 2); // div by 2 (shr) because we work with 2-byte chars
if (EqualsIgnoreCase(ref Unsafe.Add(ref searchSpace, offset + charPos), ref valueRef, value.Length))
{
// Match! Return the index.
return (int)(offset + charPos);
}

// Clear the two lowest set bits in the mask. If there are no more set bits, we're done.
// If any remain, we loop around to do the next comparison.
if (Bmi1.IsSupported)
{
mask = Bmi1.ResetLowestSetBit(Bmi1.ResetLowestSetBit(mask));
}
else
{
mask &= ~(uint)(0b11 << bitPos);
}
} while (mask != 0);
goto LoopFooter;

} while (true);
}
}

internal static int LastIndexOf(string source, string value, int startIndex, int count)
Expand Down