From b6eb57086eca17aaac379b401b0d250939e3ca5b Mon Sep 17 00:00:00 2001 From: Steve Sanderson Date: Thu, 7 Mar 2024 19:53:52 +0000 Subject: [PATCH] Hopefully fix vectorization on Mac --- .../VectorCompat.cs | 58 +++++++++++++++++-- 1 file changed, 52 insertions(+), 6 deletions(-) diff --git a/src/SmartComponents.LocalEmbeddings/VectorCompat.cs b/src/SmartComponents.LocalEmbeddings/VectorCompat.cs index 9cc5813..d387a36 100644 --- a/src/SmartComponents.LocalEmbeddings/VectorCompat.cs +++ b/src/SmartComponents.LocalEmbeddings/VectorCompat.cs @@ -58,7 +58,10 @@ public static unsafe Vector256 Vector256Xor(Vector256 lhs, Vector256 #if NET8_0_OR_GREATER return Vector256.Xor(lhs, rhs); #else - return Vector.Xor(lhs.AsVector(), rhs.AsVector()).AsVector256(); + // Assume the target platform at least supports 128-bit vectors + return Vector256.Create( + Vector128Xor(lhs.GetLower(), rhs.GetLower()), + Vector128Xor(lhs.GetUpper(), rhs.GetUpper())); #endif } @@ -68,7 +71,21 @@ public static unsafe Vector256 Vector256Multiply(Vector256 vector, T va #if NET8_0_OR_GREATER return vector * value; #else - return Vector.Multiply(vector.AsVector(), value).AsVector256(); + // Assume the target platform at least supports 128-bit vectors + var lower = Vector.Multiply(vector.GetLower().AsVector(), value).AsVector128(); + var upper = Vector.Multiply(vector.GetUpper().AsVector(), value).AsVector128(); + return Vector256.Create(lower.AsByte(), upper.AsByte()).As(); +#endif + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe Vector128 Vector128Add(Vector128 lhs, Vector128 rhs) + { +#if NET8_0_OR_GREATER + return lhs + rhs; +#else + // Assume the target platform at least supports 128-bit vectors + return Vector.Add(lhs.AsVector(), rhs.AsVector()).AsVector128(); #endif } @@ -78,7 +95,20 @@ public static unsafe Vector256 Vector256Add(Vector256 lhs, Vector256 Vector128Multiply(Vector128 lhs, Vector128 rhs) where T : unmanaged + { +#if NET8_0_OR_GREATER + return lhs * rhs; +#else + // Assume the target platform at least supports 128-bit vectors + return Vector.Multiply(lhs.AsVector(), rhs.AsVector()).AsVector128(); #endif } @@ -88,20 +118,36 @@ public static unsafe Vector256 Vector256Multiply(Vector256 lhs, Vector2 #if NET8_0_OR_GREATER return lhs * rhs; #else - return Vector.Multiply(lhs.AsVector(), rhs.AsVector()).AsVector256(); + return Vector256.Create( + Vector128Multiply(lhs.GetLower(), rhs.GetLower()).AsByte(), + Vector128Multiply(lhs.GetUpper(), rhs.GetUpper()).AsByte()).As(); #endif } [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static unsafe T Vector256Sum(Vector256 vector) where T : unmanaged + public static unsafe T Vector128Sum(Vector128 vector) where T : unmanaged { #if NET8_0_OR_GREATER - return Vector256.Sum(vector); + return Vector128.Sum(vector); #else + // Assume the target platform at least supports 128-bit vectors return Vector.Sum(vector.AsVector()); #endif } + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe T Vector256Sum(Vector256 vector) where T: unmanaged + { +#if NET8_0_OR_GREATER + return Vector256.Sum(vector); +#else + // Assume the target platform at least supports 128-bit vectors + return Vector.Sum(Vector.Add( + vector.GetLower().AsVector(), + vector.GetUpper().AsVector())); +#endif + } + [MethodImpl(MethodImplOptions.AggressiveInlining)] public static unsafe Vector256 Vector256ConvertToInt32(Vector256 vector) {