Skip to content

Commit

Permalink
Add remaining set of TensorPrimitives APIs for .NET 8 (#92154)
Browse files Browse the repository at this point in the history
* Add remaining set of TensorPrimitives APIs for .NET 8

Adds non-vectorized implementations of:
- Max
- Min
- MaxMagnitude
- MinMagnitude
- IndexOfMax
- IndexOfMin
- IndexOfMaxMagnitude
- ConvertToHalf (only on .NET Core)
- ConvertToSingle (only on .NET Core)
- IndexOfMinMagnitude

Adds vectorized implementations of:
- Sum
- SumOfSquares
- SumOfMagnitudes
- Product
- ProductOfSums
- ProductOfDifferences

Also includes the helpers that'll make it trivial to vectorize Dot.

Beyond vectorizing the non-vectorized ones, the vectorized implementations should be improved further, including:
- Handling alignment better
- Vectorizing the remainder that doesn't fit in a vector rather than falling back to scalar

* Cleanup after previous PR, vectorize CosineSimilarity/Dot/L2Normalize/Distance, add tests

* Address PR feedback, and fix a few other issues
  • Loading branch information
stephentoub authored Sep 16, 2023
1 parent dc3d344 commit f5cff04
Show file tree
Hide file tree
Showing 15 changed files with 2,084 additions and 268 deletions.
8 changes: 4 additions & 4 deletions src/coreclr/jit/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2734,7 +2734,7 @@ float FloatingPointUtils::maximumNumber(float x, float y)
//
// It propagates NaN inputs back to the caller and
// otherwise returns the lesser of the inputs. It
// treats +0 as lesser than -0 as per the specification.
// treats +0 as greater than -0 as per the specification.
//
// Arguments:
// val1 - left operand
Expand Down Expand Up @@ -2763,7 +2763,7 @@ double FloatingPointUtils::minimum(double val1, double val2)
//
// It propagates NaN inputs back to the caller and
// otherwise returns the input with a lesser magnitude.
// It treats +0 as lesser than -0 as per the specification.
// It treats +0 as greater than -0 as per the specification.
//
// Arguments:
// x - left operand
Expand Down Expand Up @@ -2856,7 +2856,7 @@ double FloatingPointUtils::minimumNumber(double x, double y)
//
// It propagates NaN inputs back to the caller and
// otherwise returns the lesser of the inputs. It
// treats +0 as lesser than -0 as per the specification.
// treats +0 as greater than -0 as per the specification.
//
// Arguments:
// val1 - left operand
Expand Down Expand Up @@ -2885,7 +2885,7 @@ float FloatingPointUtils::minimum(float val1, float val2)
//
// It propagates NaN inputs back to the caller and
// otherwise returns the input with a lesser magnitude.
// It treats +0 as lesser than -0 as per the specification.
// It treats +0 as greater than -0 as per the specification.
//
// Arguments:
// x - left operand
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ public UnixImplementation(int elementCount)

public override bool IsReadonly => false;

public override int Length => _elementCount;

public override Memory<T> Memory => _memoryManager.Memory;

public override Span<T> Span
Expand Down Expand Up @@ -83,10 +85,7 @@ protected override void Dispose(bool disposing)
// no-op; the handle will be disposed separately
}

public override Span<T> GetSpan()
{
throw new NotImplementedException();
}
public override Span<T> GetSpan() => _impl.Span;

public override MemoryHandle Pin(int elementIndex)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ internal WindowsImplementation(VirtualAllocHandle handle, int byteOffsetIntoHand

public override bool IsReadonly => (Protection != VirtualAllocProtection.PAGE_READWRITE);

public override int Length => _elementCount;

internal VirtualAllocProtection Protection
{
get
Expand Down Expand Up @@ -189,10 +191,7 @@ protected override void Dispose(bool disposing)
// no-op; the handle will be disposed separately
}

public override Span<T> GetSpan()
{
throw new NotImplementedException();
}
public override Span<T> GetSpan() => _impl.Span;

public override MemoryHandle Pin(int elementIndex)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ public abstract class BoundedMemory<T> : IDisposable where T : unmanaged
/// </summary>
public abstract bool IsReadonly { get; }

/// <summary>Gets the length of the <see cref="BoundedMemory{T}"/> instance.</summary>
public abstract int Length { get; }

/// <summary>
/// Gets the <see cref="Memory{Byte}"/> which represents this native memory.
/// This <see cref="BoundedMemory{T}"/> instance must be kept alive while working with the <see cref="Memory{Byte}"/>.
Expand Down Expand Up @@ -44,5 +47,23 @@ public abstract class BoundedMemory<T> : IDisposable where T : unmanaged
/// OS does not support marking the memory block as read+write.
/// </summary>
public abstract void MakeWriteable();

/// <summary>
/// Gets the <see cref="Span{Byte}"/> which represents this native memory.
/// This <see cref="BoundedMemory{T}"/> instance must be kept alive while working with the <see cref="Span{Byte}"/>.
/// </summary>
public static implicit operator Span<T>(BoundedMemory<T> boundedMemory) => boundedMemory.Span;

/// <summary>
/// Gets the <see cref="ReadOnlySpan{Byte}"/> which represents this native memory.
/// This <see cref="BoundedMemory{T}"/> instance must be kept alive while working with the <see cref="ReadOnlySpan{Byte}"/>.
/// </summary>
public static implicit operator ReadOnlySpan<T>(BoundedMemory<T> boundedMemory) => boundedMemory.Span;

/// <summary>
/// Gets a reference to the element at the specified index.
/// This <see cref="BoundedMemory{T}"/> instance must be kept alive while working with the reference.
/// </summary>
public ref T this[int index] => ref Span[index];
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,33 @@ public static void Divide(System.ReadOnlySpan<float> x, System.ReadOnlySpan<floa
public static void Divide(System.ReadOnlySpan<float> x, float y, System.Span<float> destination) { }
public static float Dot(System.ReadOnlySpan<float> x, System.ReadOnlySpan<float> y) { throw null; }
public static void Exp(System.ReadOnlySpan<float> x, System.Span<float> destination) { }
public static int IndexOfMax(System.ReadOnlySpan<float> x) { throw null; }
public static int IndexOfMaxMagnitude(System.ReadOnlySpan<float> x) { throw null; }
public static int IndexOfMin(System.ReadOnlySpan<float> x) { throw null; }
public static int IndexOfMinMagnitude(System.ReadOnlySpan<float> x) { throw null; }
public static float L2Normalize(System.ReadOnlySpan<float> x) { throw null; }
public static void Log(System.ReadOnlySpan<float> x, System.Span<float> destination) { }
public static float Max(System.ReadOnlySpan<float> x) { throw null; }
public static float MaxMagnitude(System.ReadOnlySpan<float> x) { throw null; }
public static float Min(System.ReadOnlySpan<float> x) { throw null; }
public static float MinMagnitude(System.ReadOnlySpan<float> x) { throw null; }
public static void Multiply(System.ReadOnlySpan<float> x, System.ReadOnlySpan<float> y, System.Span<float> destination) { }
public static void Multiply(System.ReadOnlySpan<float> x, float y, System.Span<float> destination) { }
public static void MultiplyAdd(System.ReadOnlySpan<float> x, System.ReadOnlySpan<float> y, System.ReadOnlySpan<float> addend, System.Span<float> destination) { }
public static void MultiplyAdd(System.ReadOnlySpan<float> x, System.ReadOnlySpan<float> y, float addend, System.Span<float> destination) { }
public static void MultiplyAdd(System.ReadOnlySpan<float> x, float y, System.ReadOnlySpan<float> addend, System.Span<float> destination) { }
public static void Negate(System.ReadOnlySpan<float> x, System.Span<float> destination) { }
public static float Product(System.ReadOnlySpan<float> x) { throw null; }
public static float ProductOfDifferences(System.ReadOnlySpan<float> x, System.ReadOnlySpan<float> y) { throw null; }
public static float ProductOfSums(System.ReadOnlySpan<float> x, System.ReadOnlySpan<float> y) { throw null; }
public static void Sigmoid(System.ReadOnlySpan<float> x, System.Span<float> destination) { }
public static void Sinh(System.ReadOnlySpan<float> x, System.Span<float> destination) { }
public static void SoftMax(System.ReadOnlySpan<float> x, System.Span<float> destination) { }
public static void Subtract(System.ReadOnlySpan<float> x, System.ReadOnlySpan<float> y, System.Span<float> destination) { }
public static void Subtract(System.ReadOnlySpan<float> x, float y, System.Span<float> destination) { }
public static float Sum(System.ReadOnlySpan<float> x) { throw null; }
public static float SumOfMagnitudes(System.ReadOnlySpan<float> x) { throw null; }
public static float SumOfSquares(System.ReadOnlySpan<float> x) { throw null; }
public static void Tanh(System.ReadOnlySpan<float> x, System.Span<float> destination) { }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
<Compile Include="System.Numerics.Tensors.cs" />
</ItemGroup>

<ItemGroup Condition="'$(TargetFrameworkIdentifier)' == '.NETCoreApp'">
<Compile Include="System.Numerics.Tensors.netcore.cs" />
</ItemGroup>

<ItemGroup Condition="'$(TargetFrameworkIdentifier)' != '.NETCoreApp'">
<PackageReference Include="System.Memory" Version="$(SystemMemoryVersion)" />
</ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// ------------------------------------------------------------------------------
// Changes to this file must follow the https://aka.ms/api-review process.
// ------------------------------------------------------------------------------

namespace System.Numerics.Tensors
{
public static partial class TensorPrimitives
{
public static void ConvertToHalf(System.ReadOnlySpan<float> source, System.Span<System.Half> destination) { throw null; }
public static void ConvertToSingle(System.ReadOnlySpan<System.Half> source, System.Span<float> destination) { throw null; }
}
}
Loading

0 comments on commit f5cff04

Please sign in to comment.