Skip to content

Commit

Permalink
Add fast path for linq count with predicate (#102884)
Browse files Browse the repository at this point in the history
* Add fast path for count with predicate

* Also use TryGetSpan in Aggregate, Any, All, Contains, First, and Single

---------

Co-authored-by: Stephen Toub <stoub@microsoft.com>
  • Loading branch information
neon-sunset and stephentoub authored Jul 22, 2024
1 parent 810d646 commit 4d49539
Show file tree
Hide file tree
Showing 6 changed files with 159 additions and 32 deletions.
52 changes: 44 additions & 8 deletions src/libraries/System.Linq/src/System/Linq/Aggregate.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,37 @@ public static TSource Aggregate<TSource>(this IEnumerable<TSource> source, Func<
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.func);
}

using (IEnumerator<TSource> e = source.GetEnumerator())
TSource result;
if (source.TryGetSpan(out ReadOnlySpan<TSource> span))
{
if (span.IsEmpty)
{
ThrowHelper.ThrowNoElementsException();
}

result = span[0];
for (int i = 1; i < span.Length; i++)
{
result = func(result, span[i]);
}
}
else
{
using IEnumerator<TSource> e = source.GetEnumerator();

if (!e.MoveNext())
{
ThrowHelper.ThrowNoElementsException();
}

TSource result = e.Current;
result = e.Current;
while (e.MoveNext())
{
result = func(result, e.Current);
}

return result;
}

return result;
}

public static TAccumulate Aggregate<TSource, TAccumulate>(this IEnumerable<TSource> source, TAccumulate seed, Func<TAccumulate, TSource, TAccumulate> func)
Expand All @@ -49,9 +65,19 @@ public static TAccumulate Aggregate<TSource, TAccumulate>(this IEnumerable<TSour
}

TAccumulate result = seed;
foreach (TSource element in source)
if (source.TryGetSpan(out ReadOnlySpan<TSource> span))
{
result = func(result, element);
foreach (TSource element in span)
{
result = func(result, element);
}
}
else
{
foreach (TSource element in source)
{
result = func(result, element);
}
}

return result;
Expand All @@ -75,9 +101,19 @@ public static TResult Aggregate<TSource, TAccumulate, TResult>(this IEnumerable<
}

TAccumulate result = seed;
foreach (TSource element in source)
if (source.TryGetSpan(out ReadOnlySpan<TSource> span))
{
result = func(result, element);
foreach (TSource element in span)
{
result = func(result, element);
}
}
else
{
foreach (TSource element in source)
{
result = func(result, element);
}
}

return resultSelector(result);
Expand Down
38 changes: 32 additions & 6 deletions src/libraries/System.Linq/src/System/Linq/AnyAll.cs
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,24 @@ public static bool Any<TSource>(this IEnumerable<TSource> source, Func<TSource,
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.predicate);
}

foreach (TSource element in source)
if (source.TryGetSpan(out ReadOnlySpan<TSource> span))
{
if (predicate(element))
foreach (TSource element in span)
{
return true;
if (predicate(element))
{
return true;
}
}
}
else
{
foreach (TSource element in source)
{
if (predicate(element))
{
return true;
}
}
}

Expand All @@ -78,11 +91,24 @@ public static bool All<TSource>(this IEnumerable<TSource> source, Func<TSource,
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.predicate);
}

foreach (TSource element in source)
if (source.TryGetSpan(out ReadOnlySpan<TSource> span))
{
foreach (TSource element in span)
{
if (!predicate(element))
{
return false;
}
}
}
else
{
if (!predicate(element))
foreach (TSource element in source)
{
return false;
if (!predicate(element))
{
return false;
}
}
}

Expand Down
21 changes: 20 additions & 1 deletion src/libraries/System.Linq/src/System/Linq/Contains.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,28 @@ public static bool Contains<TSource>(this IEnumerable<TSource> source, TSource v

if (comparer is null)
{
// While it's tempting, this must not delegate to ICollection<TSource>.Contains, as the historical semantics
// of a null comparer with this method are to use EqualityComparer<TSource>.Default, and that might differ
// from the semantics encoded in ICollection<TSource>.Contains.

// We don't bother special-casing spans here as explicitly providing a null comparer with a known collection type
// is relatively rare. If you don't care about the comparer, you use the other overload, and while it will delegate
// to this overload with a null comparer, it'll only do so for collections from which we can't extract a span.
// And if you do care about the comparer, you're generally passing in a non-null one.

foreach (TSource element in source)
{
if (EqualityComparer<TSource>.Default.Equals(element, value)) // benefits from devirtualization and likely inlining
if (EqualityComparer<TSource>.Default.Equals(element, value))
{
return true;
}
}
}
else if (source.TryGetSpan(out ReadOnlySpan<TSource> span))
{
foreach (TSource element in span)
{
if (comparer.Equals(element, value))
{
return true;
}
Expand Down
34 changes: 22 additions & 12 deletions src/libraries/System.Linq/src/System/Linq/Count.cs
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,26 @@ public static int Count<TSource>(this IEnumerable<TSource> source, Func<TSource,
}

int count = 0;
foreach (TSource element in source)
if (source.TryGetSpan(out ReadOnlySpan<TSource> span))
{
checked
foreach (TSource element in span)
{
if (predicate(element))
{
count++;
}
}
}
else
{
foreach (TSource element in source)
{
if (predicate(element))
{
checked { count++; }
}
}
}

return count;
}
Expand Down Expand Up @@ -136,15 +146,15 @@ public static long LongCount<TSource>(this IEnumerable<TSource> source)
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source);
}

// TryGetSpan isn't used here because if it's expected that there are more than int.MaxValue elements,
// the source can't possibly be something from which we can extract a span.

long count = 0;
using (IEnumerator<TSource> e = source.GetEnumerator())
{
checked
while (e.MoveNext())
{
while (e.MoveNext())
{
count++;
}
checked { count++; }
}
}

Expand All @@ -163,15 +173,15 @@ public static long LongCount<TSource>(this IEnumerable<TSource> source, Func<TSo
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.predicate);
}

// TryGetSpan isn't used here because if it's expected that there are more than int.MaxValue elements,
// the source can't possibly be something from which we can extract a span.

long count = 0;
foreach (TSource element in source)
{
checked
if (predicate(element))
{
if (predicate(element))
{
count++;
}
checked { count++; }
}
}

Expand Down
22 changes: 18 additions & 4 deletions src/libraries/System.Linq/src/System/Linq/First.cs
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,26 @@ public static TSource FirstOrDefault<TSource>(this IEnumerable<TSource> source,
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.predicate);
}

foreach (TSource element in source)
if (source.TryGetSpan(out ReadOnlySpan<TSource> span))
{
if (predicate(element))
foreach (TSource element in span)
{
found = true;
return element;
if (predicate(element))
{
found = true;
return element;
}
}
}
else
{
foreach (TSource element in source)
{
if (predicate(element))
{
found = true;
return element;
}
}
}

Expand Down
24 changes: 23 additions & 1 deletion src/libraries/System.Linq/src/System/Linq/Single.cs
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,30 @@ public static TSource SingleOrDefault<TSource>(this IEnumerable<TSource> source,
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.predicate);
}

using (IEnumerator<TSource> e = source.GetEnumerator())
if (source.TryGetSpan(out ReadOnlySpan<TSource> span))
{
for (int i = 0; i < span.Length; i++)
{
TSource result = span[i];
if (predicate(result))
{
for (i++; (uint)i < (uint)span.Length; i++)
{
if (predicate(span[i]))
{
ThrowHelper.ThrowMoreThanOneMatchException();
}
}

found = true;
return result;
}
}
}
else
{
using IEnumerator<TSource> e = source.GetEnumerator();

while (e.MoveNext())
{
TSource result = e.Current;
Expand Down

0 comments on commit 4d49539

Please sign in to comment.