Skip to content

Commit

Permalink
Implement IList<T> on some LINQ iterators (#88249)
Browse files Browse the repository at this point in the history
`ICollection<T>` provides both a Count and a CopyTo, and `IList<T>` an indexer, all of which can make various consumption mechanisms more efficient. We only implement the interfaces when the underlying collection has a fixed size and all of the interface implementations are side-effect free (in particular, while appealing to do so, we don't implement them on various Select iterators).

Some of the serialization tests need to be fixed as a result. The state of Queue's array is a bit different based on how its initialized, and such private details show up in BinaryFormatter output.  Rather than special-casing the output per framework and core, I've just changed the test itself to ensure Queue can't see the size of the input collection.
  • Loading branch information
stephentoub authored Jul 1, 2023
1 parent 0185099 commit 0d77cf0
Show file tree
Hide file tree
Showing 9 changed files with 302 additions and 10 deletions.
77 changes: 73 additions & 4 deletions src/libraries/System.Linq/src/System/Linq/Partition.SpeedOpt.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace System.Linq
/// that an operation will result in zero elements.
/// </remarks>
[DebuggerDisplay("Count = 0")]
internal sealed class EmptyPartition<TElement> : IPartition<TElement>, IEnumerator<TElement>
internal sealed class EmptyPartition<TElement> : IPartition<TElement>, IEnumerator<TElement>, IList<TElement>, IReadOnlyList<TElement>
{
/// <summary>
/// A cached, immutable instance of an empty enumerable.
Expand Down Expand Up @@ -77,6 +77,32 @@ void IDisposable.Dispose()
public List<TElement> ToList() => new List<TElement>();

public int GetCount(bool onlyIfCheap) => 0;

public int Count => 0;

public bool Contains(TElement item) => false;

public int IndexOf(TElement item) => -1;

public void CopyTo(TElement[] array, int arrayIndex) { }

public TElement this[int index]
{
get
{
ThrowHelper.ThrowArgumentOutOfRangeException(ExceptionArgument.index);
return default!;
}
set => ThrowHelper.ThrowNotSupportedException();
}

public bool IsReadOnly => true;

void ICollection<TElement>.Add(TElement item) => ThrowHelper.ThrowNotSupportedException();
void ICollection<TElement>.Clear() => ThrowHelper.ThrowNotSupportedException();
void IList<TElement>.Insert(int index, TElement item) => ThrowHelper.ThrowNotSupportedException();
bool ICollection<TElement>.Remove(TElement item) => ThrowHelper.ThrowNotSupportedException_Boolean();
void IList<TElement>.RemoveAt(int index) => ThrowHelper.ThrowNotSupportedException();
}

internal sealed class OrderedPartition<TElement> : IPartition<TElement>
Expand Down Expand Up @@ -143,7 +169,7 @@ public static partial class Enumerable
/// </summary>
/// <typeparam name="TSource">The type of the source list.</typeparam>
[DebuggerDisplay("Count = {Count}")]
private sealed class ListPartition<TSource> : Iterator<TSource>, IPartition<TSource>
private sealed class ListPartition<TSource> : Iterator<TSource>, IPartition<TSource>, IList<TSource>, IReadOnlyList<TSource>
{
private readonly IList<TSource> _source;
private readonly int _minIndexInclusive;
Expand Down Expand Up @@ -231,7 +257,7 @@ public IPartition<TSource> Take(int count)
return default;
}

private int Count
public int Count
{
get
{
Expand All @@ -245,6 +271,8 @@ private int Count
}
}

public int GetCount(bool onlyIfCheap) => Count;

public TSource[] ToArray()
{
int count = Count;
Expand All @@ -271,6 +299,9 @@ public List<TSource> ToList()
return list;
}

public void CopyTo(TSource[] array, int arrayIndex) =>
Fill(_source, array.AsSpan(arrayIndex, Count), _minIndexInclusive);

private static void Fill(IList<TSource> source, Span<TSource> destination, int sourceIndex)
{
for (int i = 0; i < destination.Length; i++, sourceIndex++)
Expand All @@ -279,7 +310,45 @@ private static void Fill(IList<TSource> source, Span<TSource> destination, int s
}
}

public int GetCount(bool onlyIfCheap) => Count;
public bool Contains(TSource item) => IndexOf(item) >= 0;

public int IndexOf(TSource item)
{
IList<TSource> source = _source;

int end = _minIndexInclusive + Count;
for (int i = _minIndexInclusive; i < end; i++)
{
if (EqualityComparer<TSource>.Default.Equals(source[i], item))
{
return i - _minIndexInclusive;
}
}

return -1;
}

public TSource this[int index]
{
get
{
if ((uint)index >= (uint)Count)
{
ThrowHelper.ThrowArgumentOutOfRangeException(ExceptionArgument.index);
}

return _source[_minIndexInclusive + index];
}
set => ThrowHelper.ThrowNotSupportedException();
}

public bool IsReadOnly => true;

void ICollection<TSource>.Add(TSource item) => ThrowHelper.ThrowNotSupportedException();
void ICollection<TSource>.Clear() => ThrowHelper.ThrowNotSupportedException();
void IList<TSource>.Insert(int index, TSource item) => ThrowHelper.ThrowNotSupportedException();
bool ICollection<TSource>.Remove(TSource item) => ThrowHelper.ThrowNotSupportedException_Boolean();
void IList<TSource>.RemoveAt(int index) => ThrowHelper.ThrowNotSupportedException();
}

/// <summary>
Expand Down
35 changes: 34 additions & 1 deletion src/libraries/System.Linq/src/System/Linq/Range.SpeedOpt.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ namespace System.Linq
{
public static partial class Enumerable
{
private sealed partial class RangeIterator : IPartition<int>
private sealed partial class RangeIterator : IPartition<int>, IList<int>, IReadOnlyList<int>
{
public override IEnumerable<TResult> Select<TResult>(Func<int, TResult> selector)
{
Expand All @@ -28,6 +28,9 @@ public List<int> ToList()
return list;
}

public void CopyTo(int[] array, int arrayIndex) =>
Fill(array.AsSpan(arrayIndex, _end - _start), _start);

private static void Fill(Span<int> destination, int value)
{
for (int i = 0; i < destination.Length; i++, value++)
Expand All @@ -38,6 +41,8 @@ private static void Fill(Span<int> destination, int value)

public int GetCount(bool onlyIfCheap) => unchecked(_end - _start);

public int Count => _end - _start;

public IPartition<int> Skip(int count)
{
if (count >= _end - _start)
Expand Down Expand Up @@ -82,6 +87,34 @@ public int TryGetLast(out bool found)
found = true;
return _end - 1;
}

public bool Contains(int item) =>
(uint)(item - _start) < (uint)(_end - _start);

public int IndexOf(int item) =>
Contains(item) ? item - _start : -1;

public int this[int index]
{
get
{
if ((uint)index >= (uint)(_end - _start))
{
ThrowHelper.ThrowArgumentOutOfRangeException(ExceptionArgument.index);
}

return _start + index;
}
set => ThrowHelper.ThrowNotSupportedException();
}

public bool IsReadOnly => true;

void ICollection<int>.Add(int item) => ThrowHelper.ThrowNotSupportedException();
void ICollection<int>.Clear() => ThrowHelper.ThrowNotSupportedException();
void IList<int>.Insert(int index, int item) => ThrowHelper.ThrowNotSupportedException();
bool ICollection<int>.Remove(int item) => ThrowHelper.ThrowNotSupportedException_Boolean();
void IList<int>.RemoveAt(int index) => ThrowHelper.ThrowNotSupportedException();
}
}
}
38 changes: 36 additions & 2 deletions src/libraries/System.Linq/src/System/Linq/Repeat.SpeedOpt.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@

using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;

namespace System.Linq
{
public static partial class Enumerable
{
private sealed partial class RepeatIterator<TResult> : IPartition<TResult>
private sealed partial class RepeatIterator<TResult> : IPartition<TResult>, IList<TResult>, IReadOnlyList<TResult>
{
public override IEnumerable<TResult2> Select<TResult2>(Func<TResult, TResult2> selector) =>
new SelectIPartitionIterator<TResult, TResult2>(this, selector);
Expand All @@ -35,6 +34,8 @@ public List<TResult> ToList()

public int GetCount(bool onlyIfCheap) => _count;

public int Count => _count;

public IPartition<TResult> Skip(int count)
{
Debug.Assert(count > 0);
Expand Down Expand Up @@ -82,6 +83,39 @@ public TResult TryGetLast(out bool found)
found = true;
return _current;
}

public bool Contains(TResult item)
{
Debug.Assert(_count > 0);
return EqualityComparer<TResult>.Default.Equals(_current, item);
}

public int IndexOf(TResult item) => Contains(item) ? 0 : -1;

public void CopyTo(TResult[] array, int arrayIndex) =>
array.AsSpan(arrayIndex, _count).Fill(_current);

public TResult this[int index]
{
get
{
if ((uint)index >= (uint)_count)
{
ThrowHelper.ThrowArgumentOutOfRangeException(ExceptionArgument.index);
}

return _current;
}
set => ThrowHelper.ThrowNotSupportedException();
}

public bool IsReadOnly => true;

void ICollection<TResult>.Add(TResult item) => ThrowHelper.ThrowNotSupportedException();
void ICollection<TResult>.Clear() => ThrowHelper.ThrowNotSupportedException();
void IList<TResult>.Insert(int index, TResult item) => ThrowHelper.ThrowNotSupportedException();
bool ICollection<TResult>.Remove(TResult item) => ThrowHelper.ThrowNotSupportedException_Boolean();
void IList<TResult>.RemoveAt(int index) => ThrowHelper.ThrowNotSupportedException();
}
}
}
3 changes: 3 additions & 0 deletions src/libraries/System.Linq/src/System/Linq/ThrowHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ internal static class ThrowHelper
[DoesNotReturn]
internal static void ThrowNotSupportedException() => throw new NotSupportedException();

[DoesNotReturn]
internal static bool ThrowNotSupportedException_Boolean() => throw new NotSupportedException();

[DoesNotReturn]
internal static void ThrowOverflowException() => throw new OverflowException();

Expand Down
32 changes: 31 additions & 1 deletion src/libraries/System.Linq/tests/EmptyEnumerable.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using Xunit;

namespace System.Linq.Tests
Expand All @@ -12,7 +13,7 @@ private void TestEmptyCached<T>()
var enumerable1 = Enumerable.Empty<T>();
var enumerable2 = Enumerable.Empty<T>();

Assert.Same(enumerable1, enumerable2); // Enumerable.Empty is not cached if not the same.
Assert.Same(enumerable1, enumerable2);
}

[Fact]
Expand All @@ -39,5 +40,34 @@ public void EmptyEnumerableIsIndeedEmpty()
TestEmptyEmpty<object>();
TestEmptyEmpty<EmptyEnumerableTest>();
}

[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsSpeedOptimized))]
public void IListImplementationIsValid()
{
IList<int> list = Assert.IsAssignableFrom<IList<int>>(Enumerable.Empty<int>());
IReadOnlyList<int> roList = Assert.IsAssignableFrom<IReadOnlyList<int>>(Enumerable.Empty<int>());

Assert.Throws<NotSupportedException>(() => list.Add(42));
Assert.Throws<NotSupportedException>(() => list.Insert(0, 42));
Assert.Throws<NotSupportedException>(() => list.Clear());
Assert.Throws<NotSupportedException>(() => list.Remove(42));
Assert.Throws<NotSupportedException>(() => list.RemoveAt(0));
Assert.Throws<NotSupportedException>(() => list[0] = 42);
AssertExtensions.Throws<ArgumentOutOfRangeException>("index", () => list[0]);
AssertExtensions.Throws<ArgumentOutOfRangeException>("index", () => roList[0]);

Assert.True(list.IsReadOnly);
Assert.Equal(0, list.Count);
Assert.Equal(0, roList.Count);

Assert.False(list.Contains(42));
Assert.Equal(-1, list.IndexOf(42));

list.CopyTo(Array.Empty<int>(), 0);
list.CopyTo(Array.Empty<int>(), 1);
int[] array = new int[1] { 42 };
list.CopyTo(array, 0);
Assert.Equal(42, array[0]);
}
}
}
29 changes: 29 additions & 0 deletions src/libraries/System.Linq/tests/EmptyPartitionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -102,5 +102,34 @@ public void ResetIsNop()
en.Reset();
en.Reset();
}

[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsSpeedOptimized))]
public void IListImplementationIsValid()
{
IList<int> list = Assert.IsAssignableFrom<IList<int>>(Enumerable.Empty<int>());
IReadOnlyList<int> roList = Assert.IsAssignableFrom<IReadOnlyList<int>>(Enumerable.Empty<int>());

Assert.Throws<NotSupportedException>(() => list.Add(42));
Assert.Throws<NotSupportedException>(() => list.Insert(0, 42));
Assert.Throws<NotSupportedException>(() => list.Clear());
Assert.Throws<NotSupportedException>(() => list.Remove(42));
Assert.Throws<NotSupportedException>(() => list.RemoveAt(0));
Assert.Throws<NotSupportedException>(() => list[0] = 42);
AssertExtensions.Throws<ArgumentOutOfRangeException>("index", () => list[0]);
AssertExtensions.Throws<ArgumentOutOfRangeException>("index", () => roList[0]);

Assert.True(list.IsReadOnly);
Assert.Equal(0, list.Count);
Assert.Equal(0, roList.Count);

Assert.False(list.Contains(42));
Assert.Equal(-1, list.IndexOf(42));

list.CopyTo(Array.Empty<int>(), 0);
list.CopyTo(Array.Empty<int>(), 1);
int[] array = new int[1] { 42 };
list.CopyTo(array, 0);
Assert.Equal(42, array[0]);
}
}
}
43 changes: 43 additions & 0 deletions src/libraries/System.Linq/tests/RangeTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -222,5 +222,48 @@ public void LastOrDefault()
{
Assert.Equal(int.MaxValue - 101, Enumerable.Range(-100, int.MaxValue).LastOrDefault());
}

[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsSpeedOptimized))]
public void IListImplementationIsValid()
{
Validate(Enumerable.Range(42, 10), new[] { 42, 43, 44, 45, 46, 47, 48, 49, 50, 51 });
Validate(Enumerable.Range(42, 10).Skip(3).Take(4), new[] { 45, 46, 47, 48 });

static void Validate(IEnumerable<int> e, int[] expected)
{
IList<int> list = Assert.IsAssignableFrom<IList<int>>(e);
IReadOnlyList<int> roList = Assert.IsAssignableFrom<IReadOnlyList<int>>(e);

Assert.Throws<NotSupportedException>(() => list.Add(42));
Assert.Throws<NotSupportedException>(() => list.Insert(0, 42));
Assert.Throws<NotSupportedException>(() => list.Clear());
Assert.Throws<NotSupportedException>(() => list.Remove(42));
Assert.Throws<NotSupportedException>(() => list[0] = 42);
AssertExtensions.Throws<ArgumentOutOfRangeException>("index", () => list[-1]);
AssertExtensions.Throws<ArgumentOutOfRangeException>("index", () => list[expected.Length]);
AssertExtensions.Throws<ArgumentOutOfRangeException>("index", () => roList[-1]);
AssertExtensions.Throws<ArgumentOutOfRangeException>("index", () => roList[expected.Length]);

Assert.True(list.IsReadOnly);
Assert.Equal(expected.Length, list.Count);
Assert.Equal(expected.Length, roList.Count);

Assert.False(list.Contains(expected[0] - 1));
Assert.False(list.Contains(expected[^1] + 1));
Assert.All(expected, i => Assert.True(list.Contains(i)));
Assert.All(expected, i => Assert.Equal(Array.IndexOf(expected, i), list.IndexOf(i)));
for (int i = 0; i < expected.Length; i++)
{
Assert.Equal(expected[i], list[i]);
Assert.Equal(expected[i], roList[i]);
}

int[] actual = new int[expected.Length + 2];
list.CopyTo(actual, 1);
Assert.Equal(0, actual[0]);
Assert.Equal(0, actual[^1]);
AssertExtensions.SequenceEqual(expected, actual.AsSpan(1, expected.Length));
}
}
}
}
Loading

0 comments on commit 0d77cf0

Please sign in to comment.