Skip to content

Commit

Permalink
Issue #24343 Vector Ctor using Span
Browse files Browse the repository at this point in the history
  • Loading branch information
WinCPP committed Jan 31, 2018
1 parent 4852538 commit 5273aad
Show file tree
Hide file tree
Showing 5 changed files with 370 additions and 2 deletions.
251 changes: 251 additions & 0 deletions src/Common/src/CoreLib/System/Numerics/Vector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,257 @@ public unsafe Vector(T[] values, int index)
}
}

/// <summary>
/// Constructs a vector from the given span.
/// The span must contain at least Vector'T.Count elements.
/// </summary>
public unsafe Vector(Span<T> values)
: this()
{
if (values == null)
{
// Match the JIT's exception type here. For perf, a NullReference is thrown instead of an ArgumentNull.
throw new NullReferenceException(SR.Arg_NullArgumentNullRef);
}
if (values.Length < Count)
{
throw new IndexOutOfRangeException();
}

if (Vector.IsHardwareAccelerated)
{
if (typeof(T) == typeof(Byte))
{
fixed (Byte* basePtr = &this.register.byte_0)
{
for (int g = 0; g < Count; g++)
{
*(basePtr + g) = (Byte)(object)values[g];
}
}
}
else if (typeof(T) == typeof(SByte))
{
fixed (SByte* basePtr = &this.register.sbyte_0)
{
for (int g = 0; g < Count; g++)
{
*(basePtr + g) = (SByte)(object)values[g];
}
}
}
else if (typeof(T) == typeof(UInt16))
{
fixed (UInt16* basePtr = &this.register.uint16_0)
{
for (int g = 0; g < Count; g++)
{
*(basePtr + g) = (UInt16)(object)values[g];
}
}
}
else if (typeof(T) == typeof(Int16))
{
fixed (Int16* basePtr = &this.register.int16_0)
{
for (int g = 0; g < Count; g++)
{
*(basePtr + g) = (Int16)(object)values[g];
}
}
}
else if (typeof(T) == typeof(UInt32))
{
fixed (UInt32* basePtr = &this.register.uint32_0)
{
for (int g = 0; g < Count; g++)
{
*(basePtr + g) = (UInt32)(object)values[g];
}
}
}
else if (typeof(T) == typeof(Int32))
{
fixed (Int32* basePtr = &this.register.int32_0)
{
for (int g = 0; g < Count; g++)
{
*(basePtr + g) = (Int32)(object)values[g];
}
}
}
else if (typeof(T) == typeof(UInt64))
{
fixed (UInt64* basePtr = &this.register.uint64_0)
{
for (int g = 0; g < Count; g++)
{
*(basePtr + g) = (UInt64)(object)values[g];
}
}
}
else if (typeof(T) == typeof(Int64))
{
fixed (Int64* basePtr = &this.register.int64_0)
{
for (int g = 0; g < Count; g++)
{
*(basePtr + g) = (Int64)(object)values[g];
}
}
}
else if (typeof(T) == typeof(Single))
{
fixed (Single* basePtr = &this.register.single_0)
{
for (int g = 0; g < Count; g++)
{
*(basePtr + g) = (Single)(object)values[g];
}
}
}
else if (typeof(T) == typeof(Double))
{
fixed (Double* basePtr = &this.register.double_0)
{
for (int g = 0; g < Count; g++)
{
*(basePtr + g) = (Double)(object)values[g];
}
}
}
}
else
{
if (typeof(T) == typeof(Byte))
{
fixed (Byte* basePtr = &this.register.byte_0)
{
*(basePtr + 0) = (Byte)(object)values[0];
*(basePtr + 1) = (Byte)(object)values[1];
*(basePtr + 2) = (Byte)(object)values[2];
*(basePtr + 3) = (Byte)(object)values[3];
*(basePtr + 4) = (Byte)(object)values[4];
*(basePtr + 5) = (Byte)(object)values[5];
*(basePtr + 6) = (Byte)(object)values[6];
*(basePtr + 7) = (Byte)(object)values[7];
*(basePtr + 8) = (Byte)(object)values[8];
*(basePtr + 9) = (Byte)(object)values[9];
*(basePtr + 10) = (Byte)(object)values[10];
*(basePtr + 11) = (Byte)(object)values[11];
*(basePtr + 12) = (Byte)(object)values[12];
*(basePtr + 13) = (Byte)(object)values[13];
*(basePtr + 14) = (Byte)(object)values[14];
*(basePtr + 15) = (Byte)(object)values[15];
}
}
else if (typeof(T) == typeof(SByte))
{
fixed (SByte* basePtr = &this.register.sbyte_0)
{
*(basePtr + 0) = (SByte)(object)values[0];
*(basePtr + 1) = (SByte)(object)values[1];
*(basePtr + 2) = (SByte)(object)values[2];
*(basePtr + 3) = (SByte)(object)values[3];
*(basePtr + 4) = (SByte)(object)values[4];
*(basePtr + 5) = (SByte)(object)values[5];
*(basePtr + 6) = (SByte)(object)values[6];
*(basePtr + 7) = (SByte)(object)values[7];
*(basePtr + 8) = (SByte)(object)values[8];
*(basePtr + 9) = (SByte)(object)values[9];
*(basePtr + 10) = (SByte)(object)values[10];
*(basePtr + 11) = (SByte)(object)values[11];
*(basePtr + 12) = (SByte)(object)values[12];
*(basePtr + 13) = (SByte)(object)values[13];
*(basePtr + 14) = (SByte)(object)values[14];
*(basePtr + 15) = (SByte)(object)values[15];
}
}
else if (typeof(T) == typeof(UInt16))
{
fixed (UInt16* basePtr = &this.register.uint16_0)
{
*(basePtr + 0) = (UInt16)(object)values[0];
*(basePtr + 1) = (UInt16)(object)values[1];
*(basePtr + 2) = (UInt16)(object)values[2];
*(basePtr + 3) = (UInt16)(object)values[3];
*(basePtr + 4) = (UInt16)(object)values[4];
*(basePtr + 5) = (UInt16)(object)values[5];
*(basePtr + 6) = (UInt16)(object)values[6];
*(basePtr + 7) = (UInt16)(object)values[7];
}
}
else if (typeof(T) == typeof(Int16))
{
fixed (Int16* basePtr = &this.register.int16_0)
{
*(basePtr + 0) = (Int16)(object)values[0];
*(basePtr + 1) = (Int16)(object)values[1];
*(basePtr + 2) = (Int16)(object)values[2];
*(basePtr + 3) = (Int16)(object)values[3];
*(basePtr + 4) = (Int16)(object)values[4];
*(basePtr + 5) = (Int16)(object)values[5];
*(basePtr + 6) = (Int16)(object)values[6];
*(basePtr + 7) = (Int16)(object)values[7];
}
}
else if (typeof(T) == typeof(UInt32))
{
fixed (UInt32* basePtr = &this.register.uint32_0)
{
*(basePtr + 0) = (UInt32)(object)values[0];
*(basePtr + 1) = (UInt32)(object)values[1];
*(basePtr + 2) = (UInt32)(object)values[2];
*(basePtr + 3) = (UInt32)(object)values[3];
}
}
else if (typeof(T) == typeof(Int32))
{
fixed (Int32* basePtr = &this.register.int32_0)
{
*(basePtr + 0) = (Int32)(object)values[0];
*(basePtr + 1) = (Int32)(object)values[1];
*(basePtr + 2) = (Int32)(object)values[2];
*(basePtr + 3) = (Int32)(object)values[3];
}
}
else if (typeof(T) == typeof(UInt64))
{
fixed (UInt64* basePtr = &this.register.uint64_0)
{
*(basePtr + 0) = (UInt64)(object)values[0];
*(basePtr + 1) = (UInt64)(object)values[1];
}
}
else if (typeof(T) == typeof(Int64))
{
fixed (Int64* basePtr = &this.register.int64_0)
{
*(basePtr + 0) = (Int64)(object)values[0];
*(basePtr + 1) = (Int64)(object)values[1];
}
}
else if (typeof(T) == typeof(Single))
{
fixed (Single* basePtr = &this.register.single_0)
{
*(basePtr + 0) = (Single)(object)values[0];
*(basePtr + 1) = (Single)(object)values[1];
*(basePtr + 2) = (Single)(object)values[2];
*(basePtr + 3) = (Single)(object)values[3];
}
}
else if (typeof(T) == typeof(Double))
{
fixed (Double* basePtr = &this.register.double_0)
{
*(basePtr + 0) = (Double)(object)values[0];
*(basePtr + 1) = (Double)(object)values[1];
}
}
}
}

#pragma warning disable 3001 // void* is not a CLS-Compliant argument type
internal unsafe Vector(void* dataPointer) : this(dataPointer, 0) { }
#pragma warning restore 3001 // void* is not a CLS-Compliant argument type
Expand Down
61 changes: 61 additions & 0 deletions src/Common/src/CoreLib/System/Numerics/Vector.tt
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,67 @@ namespace System.Numerics
}
}

/// <summary>
/// Constructs a vector from the given span.
/// The span must contain at least Vector'T.Count elements.
/// </summary>
public unsafe Vector(Span<T> values)
: this()
{
if (values == null)
{
// Match the JIT's exception type here. For perf, a NullReference is thrown instead of an ArgumentNull.
throw new NullReferenceException(SR.Arg_NullArgumentNullRef);
}
if (values.Length < Count)
{
throw new IndexOutOfRangeException();
}

if (Vector.IsHardwareAccelerated)
{
<# foreach (Type type in supportedTypes)
{
#>
<#=GenerateIfStatementHeader(type)#>
{
fixed (<#=type.Name#>* basePtr = &this.<#=GetRegisterFieldName(type, 0)#>)
{
for (int g = 0; g < Count; g++)
{
*(basePtr + g) = (<#=type.Name#>)(object)values[g];
}
}
}
<#
}
#>
}
else
{
<# foreach (Type type in supportedTypes)
{
#>
<#=GenerateIfStatementHeader(type)#>
{
fixed (<#=type.Name#>* basePtr = &this.<#=GetRegisterFieldName(type, 0)#>)
{
<#
for (int g = 0; g < GetNumFields(type, totalSize); g++)
{
#>
*(basePtr + <#=g#>) = (<#=type.Name#>)(object)values[<#=g#>];
<#
}
#>
}
}
<#
}
#>
}
}

#pragma warning disable 3001 // void* is not a CLS-Compliant argument type
internal unsafe Vector(void* dataPointer) : this(dataPointer, 0) { }
#pragma warning restore 3001 // void* is not a CLS-Compliant argument type
Expand Down
1 change: 1 addition & 0 deletions src/System.Numerics.Vectors/ref/System.Numerics.Vectors.cs
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@ public partial struct Vector<T> : System.IEquatable<System.Numerics.Vector<T>>,
public Vector(T value) { throw null; }
public Vector(T[] values) { throw null; }
public Vector(T[] values, int index) { throw null; }
public Vector(System.Span<T> values) { throw null; }
public static int Count { get { throw null; } }
public T this[int index] { get { throw null; } }
public static System.Numerics.Vector<T> One { get { throw null; } }
Expand Down
35 changes: 34 additions & 1 deletion src/System.Numerics.Vectors/tests/GenericVectorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,39 @@ private void TestConstructorDefault<T>() where T : struct
});
}

[Fact]
public void ConstructorWithSpanByte() { TestConstructorWithSpan<Byte>(); }
[Fact]
public void ConstructorWithSpanSByte() { TestConstructorWithSpan<SByte>(); }
[Fact]
public void ConstructorWithSpanUInt16() { TestConstructorWithSpan<UInt16>(); }
[Fact]
public void ConstructorWithSpanInt16() { TestConstructorWithSpan<Int16>(); }
[Fact]
public void ConstructorWithSpanUInt32() { TestConstructorWithSpan<UInt32>(); }
[Fact]
public void ConstructorWithSpanInt32() { TestConstructorWithSpan<Int32>(); }
[Fact]
public void ConstructorWithSpanUInt64() { TestConstructorWithSpan<UInt64>(); }
[Fact]
public void ConstructorWithSpanInt64() { TestConstructorWithSpan<Int64>(); }
[Fact]
public void ConstructorWithSpanSingle() { TestConstructorWithSpan<Single>(); }
[Fact]
public void ConstructorWithSpanDouble() { TestConstructorWithSpan<Double>(); }
private void TestConstructorWithSpan<T>() where T : struct
{
T[] values = GenerateRandomValuesForVector<T>().ToArray();
Span<T> valueSpan = new Span<T>(values);

var vector = new Vector<T>(valueSpan);
ValidateVector(vector,
(index, val) =>
{
Assert.Equal(val, values[index]);
});
}

[Fact]
public void ConstructorExceptionByte() { TestConstructorArrayTooSmallException<Byte>(); }
[Fact]
Expand Down Expand Up @@ -2776,4 +2809,4 @@ internal static T GetValueWithAllOnesSet<T>() where T : struct
}
#endregion
}
}
}
Loading

0 comments on commit 5273aad

Please sign in to comment.