Skip to content

Commit

Permalink
ARM64-SVE: gathervector
Browse files Browse the repository at this point in the history
  • Loading branch information
a74nh committed Jun 7, 2024
1 parent 7393b6e commit 8d0ac71
Show file tree
Hide file tree
Showing 9 changed files with 1,020 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/coreclr/jit/hwintrinsic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1607,6 +1607,12 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
assert(varTypeIsSIMD(op2->TypeGet()));
retNode->AsHWIntrinsic()->SetAuxiliaryJitType(getBaseJitTypeOfSIMDType(sigReader.op2ClsHnd));
}
#elif defined(TARGET_ARM64)
if (intrinsic == NI_Sve_GatherVector)
{
assert(varTypeIsSIMD(op3->TypeGet()));
retNode->AsHWIntrinsic()->SetAuxiliaryJitType(getBaseJitTypeOfSIMDType(sigReader.op3ClsHnd));
}
#endif
break;
}
Expand Down
31 changes: 31 additions & 0 deletions src/coreclr/jit/hwintrinsiccodegenarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1845,6 +1845,37 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
break;
}

case NI_Sve_GatherVector:
{
if (!varTypeIsSIMD(intrin.op2->gtType))
{
// GatherVector(Vector<T> mask, T* address, Vector<T2> indices)

var_types auxType = node->GetAuxiliaryType();
emitAttr auxSize = emitActualTypeSize(auxType);

if (auxSize == EA_8BYTE)
{
opt = varTypeIsUnsigned(auxType) ? INS_OPTS_SCALABLE_D_UXTW : INS_OPTS_SCALABLE_D_SXTW;
}
else
{
assert(auxSize == EA_4BYTE);
opt = varTypeIsUnsigned(auxType) ? INS_OPTS_SCALABLE_S_UXTW : INS_OPTS_SCALABLE_S_SXTW;
}

GetEmitter()->emitIns_R_R_R_R(ins, emitSize, targetReg, op1Reg, op2Reg, op3Reg, opt, INS_SCALABLE_OPTS_MOD_N);
}
else
{
// GatherVector(Vector<T> mask, Vector<T2> addresses)

GetEmitter()->emitIns_R_R_R_I(ins, emitSize, targetReg, op1Reg, op2Reg, 0, opt);
}

break;
}

case NI_Sve_ReverseElement:
// Use non-predicated version explicitly
GetEmitter()->emitIns_R_R(ins, emitSize, targetReg, op1Reg, opt, INS_SCALABLE_OPTS_UNPREDICATED);
Expand Down
3 changes: 3 additions & 0 deletions src/coreclr/jit/hwintrinsiclistarm64sve.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ HARDWARE_INTRINSIC(Sve, FusedMultiplyAddNegated,
HARDWARE_INTRINSIC(Sve, FusedMultiplySubtract, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_fmls, INS_sve_fmls}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_LowMaskedOperation|HW_Flag_FmaIntrinsic|HW_Flag_SpecialCodeGen)
HARDWARE_INTRINSIC(Sve, FusedMultiplySubtractBySelectedScalar, -1, 4, true, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_fmls, INS_sve_fmls}, HW_Category_SIMDByIndexedElement, HW_Flag_Scalable|HW_Flag_HasImmediateOperand|HW_Flag_HasRMWSemantics|HW_Flag_FmaIntrinsic|HW_Flag_LowVectorOperation)
HARDWARE_INTRINSIC(Sve, FusedMultiplySubtractNegated, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_fnmls, INS_sve_fnmls}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_LowMaskedOperation|HW_Flag_FmaIntrinsic|HW_Flag_SpecialCodeGen)

HARDWARE_INTRINSIC(Sve, GatherVector, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_ld1w, INS_sve_ld1w, INS_sve_ld1d, INS_sve_ld1d, INS_sve_ld1w, INS_sve_ld1d}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_SpecialCodeGen|HW_Flag_ExplicitMaskedOperation)

HARDWARE_INTRINSIC(Sve, GetActiveElementCount, -1, 2, true, {INS_sve_cntp, INS_sve_cntp, INS_sve_cntp, INS_sve_cntp, INS_sve_cntp, INS_sve_cntp, INS_sve_cntp, INS_sve_cntp, INS_sve_cntp, INS_sve_cntp}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_ExplicitMaskedOperation)
HARDWARE_INTRINSIC(Sve, LeadingSignCount, -1, -1, false, {INS_sve_cls, INS_invalid, INS_sve_cls, INS_invalid, INS_sve_cls, INS_invalid, INS_sve_cls, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_EmbeddedMaskedOperation|HW_Flag_LowMaskedOperation)
HARDWARE_INTRINSIC(Sve, LeadingZeroCount, -1, -1, false, {INS_sve_clz, INS_sve_clz, INS_sve_clz, INS_sve_clz, INS_sve_clz, INS_sve_clz, INS_sve_clz, INS_sve_clz, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_EmbeddedMaskedOperation|HW_Flag_LowMaskedOperation)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1466,6 +1466,117 @@ internal Arm64() { }
public static unsafe Vector<float> FusedMultiplySubtractNegated(Vector<float> minuend, Vector<float> left, Vector<float> right) { throw new PlatformNotSupportedException(); }


/// Unextended load

/// <summary>
/// svfloat64_t svld1_gather_[s64]index[_f64](svbool_t pg, const float64_t *base, svint64_t indices)
/// LD1D Zresult.D, Pg/Z, [Xbase, Zindices.D, LSL #3]
/// </summary>
public static unsafe Vector<double> GatherVector(Vector<double> mask, double* address, Vector<long> indices) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svfloat64_t svld1_gather[_u64base]_f64(svbool_t pg, svuint64_t bases)
/// LD1D Zresult.D, Pg/Z, [Zbases.D, #0]
/// </summary>
public static unsafe Vector<double> GatherVector(Vector<double> mask, Vector<ulong> addresses) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svfloat64_t svld1_gather_[u64]index[_f64](svbool_t pg, const float64_t *base, svuint64_t indices)
/// LD1D Zresult.D, Pg/Z, [Xbase, Zindices.D, LSL #3]
/// </summary>
public static unsafe Vector<double> GatherVector(Vector<double> mask, double* address, Vector<ulong> indices) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svint32_t svld1_gather_[s32]index[_s32](svbool_t pg, const int32_t *base, svint32_t indices)
/// LD1W Zresult.S, Pg/Z, [Xbase, Zindices.S, SXTW #2]
/// </summary>
public static unsafe Vector<int> GatherVector(Vector<int> mask, int* address, Vector<int> indices) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svint32_t svld1_gather[_u32base]_s32(svbool_t pg, svuint32_t bases)
/// LD1W Zresult.S, Pg/Z, [Zbases.S, #0]
/// </summary>
public static unsafe Vector<int> GatherVector(Vector<int> mask, Vector<uint> addresses) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svint32_t svld1_gather_[u32]index[_s32](svbool_t pg, const int32_t *base, svuint32_t indices)
/// LD1W Zresult.S, Pg/Z, [Xbase, Zindices.S, UXTW #2]
/// </summary>
public static unsafe Vector<int> GatherVector(Vector<int> mask, int* address, Vector<uint> indices) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svint64_t svld1_gather_[s64]index[_s64](svbool_t pg, const int64_t *base, svint64_t indices)
/// LD1D Zresult.D, Pg/Z, [Xbase, Zindices.D, LSL #3]
/// </summary>
public static unsafe Vector<long> GatherVector(Vector<long> mask, long* address, Vector<long> indices) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svint64_t svld1_gather[_u64base]_s64(svbool_t pg, svuint64_t bases)
/// LD1D Zresult.D, Pg/Z, [Zbases.D, #0]
/// </summary>
public static unsafe Vector<long> GatherVector(Vector<long> mask, Vector<ulong> addresses) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svint64_t svld1_gather_[u64]index[_s64](svbool_t pg, const int64_t *base, svuint64_t indices)
/// LD1D Zresult.D, Pg/Z, [Xbase, Zindices.D, LSL #3]
/// </summary>
public static unsafe Vector<long> GatherVector(Vector<long> mask, long* address, Vector<ulong> indices) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svfloat32_t svld1_gather_[s32]index[_f32](svbool_t pg, const float32_t *base, svint32_t indices)
/// LD1W Zresult.S, Pg/Z, [Xbase, Zindices.S, SXTW #2]
/// </summary>
public static unsafe Vector<float> GatherVector(Vector<float> mask, float* address, Vector<int> indices) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svfloat32_t svld1_gather[_u32base]_f32(svbool_t pg, svuint32_t bases)
/// LD1W Zresult.S, Pg/Z, [Zbases.S, #0]
/// </summary>
public static unsafe Vector<float> GatherVector(Vector<float> mask, Vector<uint> addresses) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svfloat32_t svld1_gather_[u32]index[_f32](svbool_t pg, const float32_t *base, svuint32_t indices)
/// LD1W Zresult.S, Pg/Z, [Xbase, Zindices.S, UXTW #2]
/// </summary>
public static unsafe Vector<float> GatherVector(Vector<float> mask, float* address, Vector<uint> indices) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svuint32_t svld1_gather_[s32]index[_u32](svbool_t pg, const uint32_t *base, svint32_t indices)
/// LD1W Zresult.S, Pg/Z, [Xbase, Zindices.S, SXTW #2]
/// </summary>
public static unsafe Vector<uint> GatherVector(Vector<uint> mask, uint* address, Vector<int> indices) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svuint32_t svld1_gather[_u32base]_u32(svbool_t pg, svuint32_t bases)
/// LD1W Zresult.S, Pg/Z, [Zbases.S, #0]
/// </summary>
public static unsafe Vector<uint> GatherVector(Vector<uint> mask, Vector<uint> addresses) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svuint32_t svld1_gather_[u32]index[_u32](svbool_t pg, const uint32_t *base, svuint32_t indices)
/// LD1W Zresult.S, Pg/Z, [Xbase, Zindices.S, UXTW #2]
/// </summary>
public static unsafe Vector<uint> GatherVector(Vector<uint> mask, uint* address, Vector<uint> indices) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svuint64_t svld1_gather_[s64]index[_u64](svbool_t pg, const uint64_t *base, svint64_t indices)
/// LD1D Zresult.D, Pg/Z, [Xbase, Zindices.D, LSL #3]
/// </summary>
public static unsafe Vector<ulong> GatherVector(Vector<ulong> mask, ulong* address, Vector<long> indices) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svuint64_t svld1_gather[_u64base]_u64(svbool_t pg, svuint64_t bases)
/// LD1D Zresult.D, Pg/Z, [Zbases.D, #0]
/// </summary>
public static unsafe Vector<ulong> GatherVector(Vector<ulong> mask, Vector<ulong> addresses) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svuint64_t svld1_gather_[u64]index[_u64](svbool_t pg, const uint64_t *base, svuint64_t indices)
/// LD1D Zresult.D, Pg/Z, [Xbase, Zindices.D, LSL #3]
/// </summary>
public static unsafe Vector<ulong> GatherVector(Vector<ulong> mask, ulong* address, Vector<ulong> indices) { throw new PlatformNotSupportedException(); }


/// Count set predicate bits

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1522,6 +1522,117 @@ internal Arm64() { }
public static unsafe Vector<float> FusedMultiplySubtractNegated(Vector<float> minuend, Vector<float> left, Vector<float> right) => FusedMultiplySubtractNegated(minuend, left, right);


/// Unextended load

/// <summary>
/// svfloat64_t svld1_gather_[s64]index[_f64](svbool_t pg, const float64_t *base, svint64_t indices)
/// LD1D Zresult.D, Pg/Z, [Xbase, Zindices.D, LSL #3]
/// </summary>
public static unsafe Vector<double> GatherVector(Vector<double> mask, double* address, Vector<long> indices) => GatherVector(mask, address, indices);

/// <summary>
/// svfloat64_t svld1_gather[_u64base]_f64(svbool_t pg, svuint64_t bases)
/// LD1D Zresult.D, Pg/Z, [Zbases.D, #0]
/// </summary>
public static unsafe Vector<double> GatherVector(Vector<double> mask, Vector<ulong> addresses) => GatherVector(mask, addresses);

/// <summary>
/// svfloat64_t svld1_gather_[u64]index[_f64](svbool_t pg, const float64_t *base, svuint64_t indices)
/// LD1D Zresult.D, Pg/Z, [Xbase, Zindices.D, LSL #3]
/// </summary>
public static unsafe Vector<double> GatherVector(Vector<double> mask, double* address, Vector<ulong> indices) => GatherVector(mask, address, indices);

/// <summary>
/// svint32_t svld1_gather_[s32]index[_s32](svbool_t pg, const int32_t *base, svint32_t indices)
/// LD1W Zresult.S, Pg/Z, [Xbase, Zindices.S, SXTW #2]
/// </summary>
public static unsafe Vector<int> GatherVector(Vector<int> mask, int* address, Vector<int> indices) => GatherVector(mask, address, indices);

/// <summary>
/// svint32_t svld1_gather[_u32base]_s32(svbool_t pg, svuint32_t bases)
/// LD1W Zresult.S, Pg/Z, [Zbases.S, #0]
/// </summary>
public static unsafe Vector<int> GatherVector(Vector<int> mask, Vector<uint> addresses) => GatherVector(mask, addresses);

/// <summary>
/// svint32_t svld1_gather_[u32]index[_s32](svbool_t pg, const int32_t *base, svuint32_t indices)
/// LD1W Zresult.S, Pg/Z, [Xbase, Zindices.S, UXTW #2]
/// </summary>
public static unsafe Vector<int> GatherVector(Vector<int> mask, int* address, Vector<uint> indices) => GatherVector(mask, address, indices);

/// <summary>
/// svint64_t svld1_gather_[s64]index[_s64](svbool_t pg, const int64_t *base, svint64_t indices)
/// LD1D Zresult.D, Pg/Z, [Xbase, Zindices.D, LSL #3]
/// </summary>
public static unsafe Vector<long> GatherVector(Vector<long> mask, long* address, Vector<long> indices) => GatherVector(mask, address, indices);

/// <summary>
/// svint64_t svld1_gather[_u64base]_s64(svbool_t pg, svuint64_t bases)
/// LD1D Zresult.D, Pg/Z, [Zbases.D, #0]
/// </summary>
public static unsafe Vector<long> GatherVector(Vector<long> mask, Vector<ulong> addresses) => GatherVector(mask, addresses);

/// <summary>
/// svint64_t svld1_gather_[u64]index[_s64](svbool_t pg, const int64_t *base, svuint64_t indices)
/// LD1D Zresult.D, Pg/Z, [Xbase, Zindices.D, LSL #3]
/// </summary>
public static unsafe Vector<long> GatherVector(Vector<long> mask, long* address, Vector<ulong> indices) => GatherVector(mask, address, indices);

/// <summary>
/// svfloat32_t svld1_gather_[s32]index[_f32](svbool_t pg, const float32_t *base, svint32_t indices)
/// LD1W Zresult.S, Pg/Z, [Xbase, Zindices.S, SXTW #2]
/// </summary>
public static unsafe Vector<float> GatherVector(Vector<float> mask, float* address, Vector<int> indices) => GatherVector(mask, address, indices);

/// <summary>
/// svfloat32_t svld1_gather[_u32base]_f32(svbool_t pg, svuint32_t bases)
/// LD1W Zresult.S, Pg/Z, [Zbases.S, #0]
/// </summary>
public static unsafe Vector<float> GatherVector(Vector<float> mask, Vector<uint> addresses) => GatherVector(mask, addresses);

/// <summary>
/// svfloat32_t svld1_gather_[u32]index[_f32](svbool_t pg, const float32_t *base, svuint32_t indices)
/// LD1W Zresult.S, Pg/Z, [Xbase, Zindices.S, UXTW #2]
/// </summary>
public static unsafe Vector<float> GatherVector(Vector<float> mask, float* address, Vector<uint> indices) => GatherVector(mask, address, indices);

/// <summary>
/// svuint32_t svld1_gather_[s32]index[_u32](svbool_t pg, const uint32_t *base, svint32_t indices)
/// LD1W Zresult.S, Pg/Z, [Xbase, Zindices.S, SXTW #2]
/// </summary>
public static unsafe Vector<uint> GatherVector(Vector<uint> mask, uint* address, Vector<int> indices) => GatherVector(mask, address, indices);

/// <summary>
/// svuint32_t svld1_gather[_u32base]_u32(svbool_t pg, svuint32_t bases)
/// LD1W Zresult.S, Pg/Z, [Zbases.S, #0]
/// </summary>
public static unsafe Vector<uint> GatherVector(Vector<uint> mask, Vector<uint> addresses) => GatherVector(mask, addresses);

/// <summary>
/// svuint32_t svld1_gather_[u32]index[_u32](svbool_t pg, const uint32_t *base, svuint32_t indices)
/// LD1W Zresult.S, Pg/Z, [Xbase, Zindices.S, UXTW #2]
/// </summary>
public static unsafe Vector<uint> GatherVector(Vector<uint> mask, uint* address, Vector<uint> indices) => GatherVector(mask, address, indices);

/// <summary>
/// svuint64_t svld1_gather_[s64]index[_u64](svbool_t pg, const uint64_t *base, svint64_t indices)
/// LD1D Zresult.D, Pg/Z, [Xbase, Zindices.D, LSL #3]
/// </summary>
public static unsafe Vector<ulong> GatherVector(Vector<ulong> mask, ulong* address, Vector<long> indices) => GatherVector(mask, address, indices);

/// <summary>
/// svuint64_t svld1_gather[_u64base]_u64(svbool_t pg, svuint64_t bases)
/// LD1D Zresult.D, Pg/Z, [Zbases.D, #0]
/// </summary>
public static unsafe Vector<ulong> GatherVector(Vector<ulong> mask, Vector<ulong> addresses) => GatherVector(mask, addresses);

/// <summary>
/// svuint64_t svld1_gather_[u64]index[_u64](svbool_t pg, const uint64_t *base, svuint64_t indices)
/// LD1D Zresult.D, Pg/Z, [Xbase, Zindices.D, LSL #3]
/// </summary>
public static unsafe Vector<ulong> GatherVector(Vector<ulong> mask, ulong* address, Vector<ulong> indices) => GatherVector(mask, address, indices);


/// Count set predicate bits

/// <summary>
Expand Down
Loading

0 comments on commit 8d0ac71

Please sign in to comment.