Skip to content

Commit

Permalink
Improve vectorization of IndexOf(chars, StringComparison.OrdinalIgnor…
Browse files Browse the repository at this point in the history
…eCase) (#85437)

* Improve vectorization of IndexOf(chars, StringComparison.OrdinalIgnoreCase)

Use the same general "Algorithm 1: Generic SIMD" that we do for StringComparison.Ordinal, adapter for OrdinalIgnoreCase.

* Fix duplicate local
  • Loading branch information
stephentoub authored May 1, 2023
1 parent 049acec commit 80cd72e
Showing 1 changed file with 200 additions and 13 deletions.
213 changes: 200 additions & 13 deletions src/libraries/System.Private.CoreLib/src/System/Globalization/Ordinal.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics;
using System.Text.Unicode;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;
using System.Text.Unicode;

namespace System.Globalization
{
Expand Down Expand Up @@ -295,7 +297,6 @@ internal static int IndexOfOrdinalIgnoreCase(ReadOnlySpan<char> source, ReadOnly
// A non-linguistic search compares chars directly against one another, so large
// target strings can never be found inside small search spaces. This check also
// handles empty 'source' spans.

return -1;
}

Expand All @@ -309,25 +310,38 @@ internal static int IndexOfOrdinalIgnoreCase(ReadOnlySpan<char> source, ReadOnly
return CompareInfo.NlsIndexOfOrdinalCore(source, value, ignoreCase: true, fromBeginning: true);
}

// If value starts with an ASCII char, we can use a vectorized path
// If value doesn't start with ASCII, fall back to a non-vectorized non-ASCII friendly version.
ref char valueRef = ref MemoryMarshal.GetReference(value);
char valueChar = valueRef;

if (!char.IsAscii(valueChar))
{
// Fallback to a more non-ASCII friendly version
return OrdinalCasing.IndexOf(source, value);
}

// Hoist some expressions from the loop
int valueTailLength = value.Length - 1;
int searchSpaceLength = source.Length - valueTailLength;
int searchSpaceMinusValueTailLength = source.Length - valueTailLength;
ref char searchSpace = ref MemoryMarshal.GetReference(source);
char valueCharU = default;
char valueCharL = default;
nint offset = 0;
bool isLetter = false;

// If the input is long enough and the value ends with ASCII, we can take a special vectorized
// path that compares both the beginning and the end at the same time.
if (Vector128.IsHardwareAccelerated && searchSpaceMinusValueTailLength >= Vector128<ushort>.Count)
{
valueCharU = Unsafe.Add(ref valueRef, valueTailLength);
if (char.IsAscii(valueCharU))
{
goto SearchTwoChars;
}
}

// We're searching for the first character and it's known to be ASCII. If it's not a letter,
// then IgnoreCase doesn't impact what it matches and we just need to do a normal search
// for that single character. If it is a letter, then we need to search for both its upper
// and lower-case variants.
if (char.IsAsciiLetter(valueChar))
{
valueCharU = (char)(valueChar & ~0x20);
Expand All @@ -340,16 +354,16 @@ internal static int IndexOfOrdinalIgnoreCase(ReadOnlySpan<char> source, ReadOnly
// Do a quick search for the first element of "value".
int relativeIndex = isLetter ?
PackedSpanHelpers.PackedIndexOfIsSupported
? PackedSpanHelpers.IndexOfAny(ref Unsafe.Add(ref searchSpace, offset), valueCharU, valueCharL, searchSpaceLength)
: SpanHelpers.IndexOfAnyChar(ref Unsafe.Add(ref searchSpace, offset), valueCharU, valueCharL, searchSpaceLength) :
SpanHelpers.IndexOfChar(ref Unsafe.Add(ref searchSpace, offset), valueChar, searchSpaceLength);
? PackedSpanHelpers.IndexOfAny(ref Unsafe.Add(ref searchSpace, offset), valueCharU, valueCharL, searchSpaceMinusValueTailLength)
: SpanHelpers.IndexOfAnyChar(ref Unsafe.Add(ref searchSpace, offset), valueCharU, valueCharL, searchSpaceMinusValueTailLength) :
SpanHelpers.IndexOfChar(ref Unsafe.Add(ref searchSpace, offset), valueChar, searchSpaceMinusValueTailLength);
if (relativeIndex < 0)
{
break;
}

searchSpaceLength -= relativeIndex;
if (searchSpaceLength <= 0)
searchSpaceMinusValueTailLength -= relativeIndex;
if (searchSpaceMinusValueTailLength <= 0)
{
break;
}
Expand All @@ -364,12 +378,185 @@ ref Unsafe.Add(ref valueRef, 1), valueTailLength))
return (int)offset; // The tail matched. Return a successful find.
}

searchSpaceLength--;
searchSpaceMinusValueTailLength--;
offset++;
}
while (searchSpaceLength > 0);
while (searchSpaceMinusValueTailLength > 0);

return -1;

// Based on SpanHelpers.IndexOf(ref char, int, ref char, int), which was in turn based on
// http://0x80.pl/articles/simd-strfind.html#algorithm-1-generic-simd. This version has additional
// modifications to support case-insensitive searches.
SearchTwoChars:
// Both the first character in value (valueChar) and the last character in value (valueCharU) are ASCII. Get their lowercase variants.
valueChar = (char)(valueChar | 0x20);
valueCharU = (char)(valueCharU | 0x20);

// The search is more efficient if the two characters being searched for are different. As long as they are equal, walk backwards
// from the last character in the search value until we find a character that's different. Since we're dealing with IgnoreCase,
// we compare the lowercase variants, as that's what we'll be comparing against in the main loop.
nint ch1ch2Distance = valueTailLength;
while (valueCharU == valueChar && ch1ch2Distance > 1)
{
char tmp = Unsafe.Add(ref valueRef, ch1ch2Distance - 1);
if (!char.IsAscii(tmp))
{
break;
}
--ch1ch2Distance;
valueCharU = (char)(tmp | 0x20);
}

// Use Vector256 if the input is long enough.
if (Vector256.IsHardwareAccelerated && searchSpaceMinusValueTailLength - Vector256<ushort>.Count >= 0)
{
// Create a vector for each of the lowercase ASCII characters we're searching for.
Vector256<ushort> ch1 = Vector256.Create((ushort)valueChar);
Vector256<ushort> ch2 = Vector256.Create((ushort)valueCharU);

nint searchSpaceMinusValueTailLengthAndVector = searchSpaceMinusValueTailLength - (nint)Vector256<ushort>.Count;
do
{
// Make sure we don't go out of bounds.
Debug.Assert(offset + ch1ch2Distance + Vector256<ushort>.Count <= source.Length);

// Load a vector from the current search space offset and another from the offset plus the distance between the two characters.
// For each, | with 0x20 so that letters are lowercased, then & those together to get a mask. If the mask is all zeros, there
// was no match. If it wasn't, we have to do more work to check for a match.
Vector256<ushort> cmpCh2 = Vector256.Equals(ch2, Vector256.BitwiseOr(Vector256.LoadUnsafe(ref searchSpace, (nuint)(offset + ch1ch2Distance)), Vector256.Create((ushort)0x20)));
Vector256<ushort> cmpCh1 = Vector256.Equals(ch1, Vector256.BitwiseOr(Vector256.LoadUnsafe(ref searchSpace, (nuint)offset), Vector256.Create((ushort)0x20)));
Vector256<byte> cmpAnd = (cmpCh1 & cmpCh2).AsByte();
if (cmpAnd != Vector256<byte>.Zero)
{
goto CandidateFound;
}

LoopFooter:
// No match. Advance to the next vector.
offset += Vector256<ushort>.Count;

// If we've reached the end of the search space, bail.
if (offset == searchSpaceMinusValueTailLength)
{
return -1;
}

// If we're within a vector's length of the end of the search space, adjust the offset
// to point to the last vector so that our next iteration will process it.
if (offset > searchSpaceMinusValueTailLengthAndVector)
{
offset = searchSpaceMinusValueTailLengthAndVector;
}

continue;

CandidateFound:
// Possible matches at the current location. Extract the bits for each element.
// For each set bits, we'll check if it's a match at that location.
uint mask = cmpAnd.ExtractMostSignificantBits();
do
{
// Do a full IgnoreCase equality comparison. SpanHelpers.IndexOf skips comparing the two characters in some cases,
// but we don't actually know that the two characters are equal, since we compared with | 0x20. So we just compare
// the full string always.
int bitPos = BitOperations.TrailingZeroCount(mask);
nint charPos = (nint)((uint)bitPos / 2); // div by 2 (shr) because we work with 2-byte chars
if (EqualsIgnoreCase(ref Unsafe.Add(ref searchSpace, offset + charPos), ref valueRef, value.Length))
{
// Match! Return the index.
return (int)(offset + charPos);
}

// Clear the two lowest set bits in the mask. If there are no more set bits, we're done.
// If any remain, we loop around to do the next comparison.
if (Bmi1.IsSupported)
{
mask = Bmi1.ResetLowestSetBit(Bmi1.ResetLowestSetBit(mask));
}
else
{
mask &= ~(uint)(0b11 << bitPos);
}
} while (mask != 0);
goto LoopFooter;

} while (true);
}
else // 128bit vector path (SSE2 or AdvSimd)
{
// Create a vector for each of the lowercase ASCII characters we're searching for.
Vector128<ushort> ch1 = Vector128.Create((ushort)valueChar);
Vector128<ushort> ch2 = Vector128.Create((ushort)valueCharU);

nint searchSpaceMinusValueTailLengthAndVector = searchSpaceMinusValueTailLength - (nint)Vector128<ushort>.Count;
do
{
// Make sure we don't go out of bounds.
Debug.Assert(offset + ch1ch2Distance + Vector128<ushort>.Count <= source.Length);

// Load a vector from the current search space offset and another from the offset plus the distance between the two characters.
// For each, | with 0x20 so that letters are lowercased, then & those together to get a mask. If the mask is all zeros, there
// was no match. If it wasn't, we have to do more work to check for a match.
Vector128<ushort> cmpCh2 = Vector128.Equals(ch2, Vector128.BitwiseOr(Vector128.LoadUnsafe(ref searchSpace, (nuint)(offset + ch1ch2Distance)), Vector128.Create((ushort)0x20)));
Vector128<ushort> cmpCh1 = Vector128.Equals(ch1, Vector128.BitwiseOr(Vector128.LoadUnsafe(ref searchSpace, (nuint)offset), Vector128.Create((ushort)0x20)));
Vector128<byte> cmpAnd = (cmpCh1 & cmpCh2).AsByte();
if (cmpAnd != Vector128<byte>.Zero)
{
goto CandidateFound;
}

LoopFooter:
// No match. Advance to the next vector.
offset += Vector128<ushort>.Count;

// If we've reached the end of the search space, bail.
if (offset == searchSpaceMinusValueTailLength)
{
return -1;
}

// If we're within a vector's length of the end of the search space, adjust the offset
// to point to the last vector so that our next iteration will process it.
if (offset > searchSpaceMinusValueTailLengthAndVector)
{
offset = searchSpaceMinusValueTailLengthAndVector;
}

continue;

CandidateFound:
// Possible matches at the current location. Extract the bits for each element.
// For each set bits, we'll check if it's a match at that location.
uint mask = cmpAnd.ExtractMostSignificantBits();
do
{
// Do a full IgnoreCase equality comparison. SpanHelpers.IndexOf skips comparing the two characters in some cases,
// but we don't actually know that the two characters are equal, since we compared with | 0x20. So we just compare
// the full string always.
int bitPos = BitOperations.TrailingZeroCount(mask);
int charPos = (int)((uint)bitPos / 2); // div by 2 (shr) because we work with 2-byte chars
if (EqualsIgnoreCase(ref Unsafe.Add(ref searchSpace, offset + charPos), ref valueRef, value.Length))
{
// Match! Return the index.
return (int)(offset + charPos);
}

// Clear the two lowest set bits in the mask. If there are no more set bits, we're done.
// If any remain, we loop around to do the next comparison.
if (Bmi1.IsSupported)
{
mask = Bmi1.ResetLowestSetBit(Bmi1.ResetLowestSetBit(mask));
}
else
{
mask &= ~(uint)(0b11 << bitPos);
}
} while (mask != 0);
goto LoopFooter;

} while (true);
}
}

internal static int LastIndexOf(string source, string value, int startIndex, int count)
Expand Down

0 comments on commit 80cd72e

Please sign in to comment.