diff --git a/clang/include/clang/Basic/arm_sme.td b/clang/include/clang/Basic/arm_sme.td index 859d5fdfea504d..6b31dec004a1e2 100644 --- a/clang/include/clang/Basic/arm_sme.td +++ b/clang/include/clang/Basic/arm_sme.td @@ -873,6 +873,11 @@ let SMETargetGuard = "sme-f8f32" in { [IsStreaming, IsInOutZA, SetsFPMR, IsOverloadNone], []>; def SVMLA_FP8_SINGLE_ZA32_VG4x4 : Inst<"svmla[_single]_za32[_mf8]_vg4x4_fpm", "vm4d>", "m", MergeNone, "aarch64_sme_fp8_fmlall_single_za32_vg4x4", [IsStreaming, IsInOutZA, SetsFPMR, IsOverloadNone], []>; + // FMLALL (multiple) + def SVMLA_FP8_MULTI_ZA32_VG4x2 : Inst<"svmla_za32[_mf8]_vg4x2_fpm", "vm22>", "m", MergeNone, "aarch64_sme_fp8_fmlall_multi_za32_vg4x2", + [IsStreaming, IsInOutZA, SetsFPMR, IsOverloadNone], []>; + def SVMLA_FP8_MULTI_ZA32_VG4x4 : Inst<"svmla_za32[_mf8]_vg4x4_fpm", "vm44>", "m", MergeNone, "aarch64_sme_fp8_fmlall_multi_za32_vg4x4", + [IsStreaming, IsInOutZA, SetsFPMR, IsOverloadNone], []>; } let SMETargetGuard = "sme-f8f16" in { @@ -892,6 +897,11 @@ let SMETargetGuard = "sme-f8f16" in { [IsStreaming, IsInOutZA, SetsFPMR, IsOverloadNone], []>; def SVMLA_FP8_SINGLE_ZA16_VG2x4 : Inst<"svmla[_single]_za16[_mf8]_vg2x4_fpm", "vm4d>", "m", MergeNone, "aarch64_sme_fp8_fmlal_single_za16_vg2x4", [IsStreaming, IsInOutZA, SetsFPMR, IsOverloadNone], []>; + // FMLAL (multiple) + def SVMLA_FP8_MULTI_ZA16_VG2x2 : Inst<"svmla_za16[_mf8]_vg2x2_fpm", "vm22>", "m", MergeNone, "aarch64_sme_fp8_fmlal_multi_za16_vg2x2", + [IsStreaming, IsInOutZA, SetsFPMR, IsOverloadNone], []>; + def SVMLA_FP8_MULTI_ZA16_VG2x4 : Inst<"svmla_za16[_mf8]_vg2x4_fpm", "vm44>", "m", MergeNone, "aarch64_sme_fp8_fmlal_multi_za16_vg2x4", + [IsStreaming, IsInOutZA, SetsFPMR, IsOverloadNone], []>; } } // let SVETargetGuard = InvalidMode diff --git a/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_sme2_fp8_mla.c b/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_sme2_fp8_mla.c index 942c4a24f77812..d603045edf2824 100644 --- 
a/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_sme2_fp8_mla.c +++ b/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_sme2_fp8_mla.c @@ -1,4 +1,3 @@ - // NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5 // REQUIRES: aarch64-registered-target @@ -239,3 +238,79 @@ void test_svmla_single_za32_vg4x2(uint32_t slice, svmfloat8x2_t zn, svmfloat8_t void test_svmla_single_za32_vg4x4(uint32_t slice, svmfloat8x4_t zn, svmfloat8_t zm, fpm_t fpm) __arm_streaming __arm_inout("za") { SME_ACLE_FUNC(svmla,_single,_za32,_mf8,_vg4x4_fpm)(slice, zn, zm, fpm); } + +// FMLAL (multi) + +// CHECK-LABEL: define dso_local void @test_svmla_multi_za16_vg2x2( +// CHECK-SAME: i32 noundef [[SLICE:%.*]], [[ZN_COERCE0:%.*]], [[ZN_COERCE1:%.*]], [[ZM_COERCE0:%.*]], [[ZM_COERCE1:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0:[0-9]+]] { +// CHECK-NEXT: [[ENTRY:.*:]] +// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]]) +// CHECK-NEXT: tail call void @llvm.aarch64.sme.fp8.fmlal.multi.za16.vg2x2(i32 [[SLICE]], [[ZN_COERCE0]], [[ZN_COERCE1]], [[ZM_COERCE0]], [[ZM_COERCE1]]) +// CHECK-NEXT: ret void +// +// CPP-CHECK-LABEL: define dso_local void @_Z27test_svmla_multi_za16_vg2x2j13svmfloat8x2_tS_m( +// CPP-CHECK-SAME: i32 noundef [[SLICE:%.*]], [[ZN_COERCE0:%.*]], [[ZN_COERCE1:%.*]], [[ZM_COERCE0:%.*]], [[ZM_COERCE1:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0:[0-9]+]] { +// CPP-CHECK-NEXT: [[ENTRY:.*:]] +// CPP-CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]]) +// CPP-CHECK-NEXT: tail call void @llvm.aarch64.sme.fp8.fmlal.multi.za16.vg2x2(i32 [[SLICE]], [[ZN_COERCE0]], [[ZN_COERCE1]], [[ZM_COERCE0]], [[ZM_COERCE1]]) +// CPP-CHECK-NEXT: ret void +// +void test_svmla_multi_za16_vg2x2(uint32_t slice, svmfloat8x2_t zn, svmfloat8x2_t zm, fpm_t fpm) __arm_streaming __arm_inout("za") { + SME_ACLE_FUNC(svmla_za16,_mf8,_vg2x2_fpm,,)(slice, zn, zm, fpm); +} + +// CHECK-LABEL: define dso_local void @test_svmla_multi_za16_vg2x4( +// CHECK-SAME: i32 
noundef [[SLICE:%.*]], [[ZN_COERCE0:%.*]], [[ZN_COERCE1:%.*]], [[ZN_COERCE2:%.*]], [[ZN_COERCE3:%.*]], [[ZM_COERCE0:%.*]], [[ZM_COERCE1:%.*]], [[ZM_COERCE2:%.*]], [[ZM_COERCE3:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] { +// CHECK-NEXT: [[ENTRY:.*:]] +// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]]) +// CHECK-NEXT: tail call void @llvm.aarch64.sme.fp8.fmlal.multi.za16.vg2x4(i32 [[SLICE]], [[ZN_COERCE0]], [[ZN_COERCE1]], [[ZN_COERCE2]], [[ZN_COERCE3]], [[ZM_COERCE0]], [[ZM_COERCE1]], [[ZM_COERCE2]], [[ZM_COERCE3]]) +// CHECK-NEXT: ret void +// +// CPP-CHECK-LABEL: define dso_local void @_Z27test_svmla_multi_za16_vg2x4j13svmfloat8x4_tS_m( +// CPP-CHECK-SAME: i32 noundef [[SLICE:%.*]], [[ZN_COERCE0:%.*]], [[ZN_COERCE1:%.*]], [[ZN_COERCE2:%.*]], [[ZN_COERCE3:%.*]], [[ZM_COERCE0:%.*]], [[ZM_COERCE1:%.*]], [[ZM_COERCE2:%.*]], [[ZM_COERCE3:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] { +// CPP-CHECK-NEXT: [[ENTRY:.*:]] +// CPP-CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]]) +// CPP-CHECK-NEXT: tail call void @llvm.aarch64.sme.fp8.fmlal.multi.za16.vg2x4(i32 [[SLICE]], [[ZN_COERCE0]], [[ZN_COERCE1]], [[ZN_COERCE2]], [[ZN_COERCE3]], [[ZM_COERCE0]], [[ZM_COERCE1]], [[ZM_COERCE2]], [[ZM_COERCE3]]) +// CPP-CHECK-NEXT: ret void +// +void test_svmla_multi_za16_vg2x4(uint32_t slice, svmfloat8x4_t zn, svmfloat8x4_t zm, fpm_t fpm) __arm_streaming __arm_inout("za") { + SME_ACLE_FUNC(svmla_za16,_mf8,_vg2x4_fpm,,)(slice, zn, zm, fpm); +} + +// FMLALL (multi) + +// CHECK-LABEL: define dso_local void @test_svmla_multi_za32_vg4x2( +// CHECK-SAME: i32 noundef [[SLICE:%.*]], [[ZN_COERCE0:%.*]], [[ZN_COERCE1:%.*]], [[ZM_COERCE0:%.*]], [[ZM_COERCE1:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] { +// CHECK-NEXT: [[ENTRY:.*:]] +// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]]) +// CHECK-NEXT: tail call void @llvm.aarch64.sme.fp8.fmlall.multi.za32.vg4x2(i32 [[SLICE]], [[ZN_COERCE0]], [[ZN_COERCE1]], [[ZM_COERCE0]], [[ZM_COERCE1]]) +// CHECK-NEXT: ret 
void +// +// CPP-CHECK-LABEL: define dso_local void @_Z27test_svmla_multi_za32_vg4x2j13svmfloat8x2_tS_m( +// CPP-CHECK-SAME: i32 noundef [[SLICE:%.*]], [[ZN_COERCE0:%.*]], [[ZN_COERCE1:%.*]], [[ZM_COERCE0:%.*]], [[ZM_COERCE1:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] { +// CPP-CHECK-NEXT: [[ENTRY:.*:]] +// CPP-CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]]) +// CPP-CHECK-NEXT: tail call void @llvm.aarch64.sme.fp8.fmlall.multi.za32.vg4x2(i32 [[SLICE]], [[ZN_COERCE0]], [[ZN_COERCE1]], [[ZM_COERCE0]], [[ZM_COERCE1]]) +// CPP-CHECK-NEXT: ret void +// +void test_svmla_multi_za32_vg4x2(uint32_t slice, svmfloat8x2_t zn, svmfloat8x2_t zm, fpm_t fpm) __arm_streaming __arm_inout("za") { + SME_ACLE_FUNC(svmla_za32,_mf8,_vg4x2_fpm,,)(slice, zn, zm, fpm); +} + +// CHECK-LABEL: define dso_local void @test_svmla_multi_za32_vg4x4( +// CHECK-SAME: i32 noundef [[SLICE:%.*]], [[ZN_COERCE0:%.*]], [[ZN_COERCE1:%.*]], [[ZN_COERCE2:%.*]], [[ZN_COERCE3:%.*]], [[ZM_COERCE0:%.*]], [[ZM_COERCE1:%.*]], [[ZM_COERCE2:%.*]], [[ZM_COERCE3:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] { +// CHECK-NEXT: [[ENTRY:.*:]] +// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]]) +// CHECK-NEXT: tail call void @llvm.aarch64.sme.fp8.fmlall.multi.za32.vg4x4(i32 [[SLICE]], [[ZN_COERCE0]], [[ZN_COERCE1]], [[ZN_COERCE2]], [[ZN_COERCE3]], [[ZM_COERCE0]], [[ZM_COERCE1]], [[ZM_COERCE2]], [[ZM_COERCE3]]) +// CHECK-NEXT: ret void +// +// CPP-CHECK-LABEL: define dso_local void @_Z27test_svmla_multi_za32_vg4x4j13svmfloat8x4_tS_m( +// CPP-CHECK-SAME: i32 noundef [[SLICE:%.*]], [[ZN_COERCE0:%.*]], [[ZN_COERCE1:%.*]], [[ZN_COERCE2:%.*]], [[ZN_COERCE3:%.*]], [[ZM_COERCE0:%.*]], [[ZM_COERCE1:%.*]], [[ZM_COERCE2:%.*]], [[ZM_COERCE3:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] { +// CPP-CHECK-NEXT: [[ENTRY:.*:]] +// CPP-CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]]) +// CPP-CHECK-NEXT: tail call void @llvm.aarch64.sme.fp8.fmlall.multi.za32.vg4x4(i32 [[SLICE]], [[ZN_COERCE0]], 
[[ZN_COERCE1]], [[ZN_COERCE2]], [[ZN_COERCE3]], [[ZM_COERCE0]], [[ZM_COERCE1]], [[ZM_COERCE2]], [[ZM_COERCE3]]) +// CPP-CHECK-NEXT: ret void +// +void test_svmla_multi_za32_vg4x4(uint32_t slice, svmfloat8x4_t zn, svmfloat8x4_t zm, fpm_t fpm) __arm_streaming __arm_inout("za") { + SME_ACLE_FUNC(svmla_za32,_mf8,_vg4x4_fpm,,)(slice, zn, zm, fpm); +} diff --git a/clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_mla.c b/clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_mla.c index 9204ca0ae5d4cf..1dbc30196a5544 100644 --- a/clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_mla.c +++ b/clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_mla.c @@ -41,4 +41,16 @@ void test_svmla(uint32_t slice, svmfloat8_t zn, svmfloat8x2_t znx2, svmfloat8x4_ // expected-error@+1 {{'svmla_single_za32_mf8_vg4x4_fpm' needs target feature sme,sme-f8f32}} svmla_single_za32_mf8_vg4x4_fpm(slice, znx4, zn, fpmr); + + // expected-error@+1 {{'svmla_za16_mf8_vg2x2_fpm' needs target feature sme,sme-f8f16}} + svmla_za16_mf8_vg2x2_fpm(slice, znx2, znx2, fpmr); + + // expected-error@+1 {{'svmla_za16_mf8_vg2x4_fpm' needs target feature sme,sme-f8f16}} + svmla_za16_mf8_vg2x4_fpm(slice, znx4, znx4, fpmr); + + // expected-error@+1 {{'svmla_za32_mf8_vg4x2_fpm' needs target feature sme,sme-f8f32}} + svmla_za32_mf8_vg4x2_fpm(slice, znx2, znx2, fpmr); + + // expected-error@+1 {{'svmla_za32_mf8_vg4x4_fpm' needs target feature sme,sme-f8f32}} + svmla_za32_mf8_vg4x4_fpm(slice, znx4, znx4, fpmr); } \ No newline at end of file diff --git a/llvm/include/llvm/IR/IntrinsicsAArch64.td b/llvm/include/llvm/IR/IntrinsicsAArch64.td index 7fa0d421babc5d..53a66099a92bda 100644 --- a/llvm/include/llvm/IR/IntrinsicsAArch64.td +++ b/llvm/include/llvm/IR/IntrinsicsAArch64.td @@ -3856,29 +3856,6 @@ def int_aarch64_sve_famin_u : AdvSIMD_Pred2VectorArg_Intrinsic; def int_aarch64_neon_famax : AdvSIMD_2VectorArg_Intrinsic; def int_aarch64_neon_famin : AdvSIMD_2VectorArg_Intrinsic; - -// SME FP8 FDOT intrinsics -let 
TargetPrefix = "aarch64" in { - -class SME2_FP8_FDOT_MULTI_VG1x2 : - DefaultAttrsIntrinsic<[], [llvm_i32_ty, - llvm_nxv16i8_ty, llvm_nxv16i8_ty, - llvm_nxv16i8_ty, llvm_nxv16i8_ty], - [IntrInaccessibleMemOnly, IntrHasSideEffects]>; - -class SME2_FP8_FDOT_MULTI_VG1x4 : - DefaultAttrsIntrinsic<[], [llvm_i32_ty, - llvm_nxv16i8_ty, llvm_nxv16i8_ty, llvm_nxv16i8_ty, llvm_nxv16i8_ty, - llvm_nxv16i8_ty, llvm_nxv16i8_ty, llvm_nxv16i8_ty, llvm_nxv16i8_ty], - [IntrInaccessibleMemOnly, IntrHasSideEffects]>; - - def int_aarch64_sme_fp8_fdot_multi_za16_vg1x2 : SME2_FP8_FDOT_MULTI_VG1x2; - def int_aarch64_sme_fp8_fdot_multi_za16_vg1x4 : SME2_FP8_FDOT_MULTI_VG1x4; - - def int_aarch64_sme_fp8_fdot_multi_za32_vg1x2 : SME2_FP8_FDOT_MULTI_VG1x2; - def int_aarch64_sme_fp8_fdot_multi_za32_vg1x4 : SME2_FP8_FDOT_MULTI_VG1x4; -} - // // FP8 Intrinsics // @@ -3997,6 +3974,18 @@ let TargetPrefix = "aarch64" in { : DefaultAttrsIntrinsic<[], [llvm_i32_ty, llvm_nxv16i8_ty, llvm_nxv16i8_ty, llvm_nxv16i8_ty, llvm_nxv16i8_ty, llvm_nxv16i8_ty], + [IntrInaccessibleMemOnly, IntrHasSideEffects]>; + + class SME_FP8_ZA_MULTI_VGx2_Intrinsic + : DefaultAttrsIntrinsic<[], [llvm_i32_ty, + llvm_nxv16i8_ty, llvm_nxv16i8_ty, + llvm_nxv16i8_ty, llvm_nxv16i8_ty], + [IntrInaccessibleMemOnly, IntrHasSideEffects]>; + + class SME_FP8_ZA_MULTI_VGx4_Intrinsic + : DefaultAttrsIntrinsic<[], [llvm_i32_ty, + llvm_nxv16i8_ty, llvm_nxv16i8_ty, llvm_nxv16i8_ty, llvm_nxv16i8_ty, + llvm_nxv16i8_ty, llvm_nxv16i8_ty, llvm_nxv16i8_ty, llvm_nxv16i8_ty], [IntrInaccessibleMemOnly, IntrHasSideEffects]>; // // CVT from FP8 to half-precision/BFloat16 multi-vector @@ -4036,6 +4025,9 @@ let TargetPrefix = "aarch64" in { def int_aarch64_sme_fp8_fmlal_single_za16_vg2x1 : SME_FP8_ZA_SINGLE_VGx1_Intrinsic; def int_aarch64_sme_fp8_fmlal_single_za16_vg2x2 : SME_FP8_ZA_SINGLE_VGx2_Intrinsic; def int_aarch64_sme_fp8_fmlal_single_za16_vg2x4 : SME_FP8_ZA_SINGLE_VGx4_Intrinsic; + // Multi + def int_aarch64_sme_fp8_fmlal_multi_za16_vg2x2 : 
SME_FP8_ZA_MULTI_VGx2_Intrinsic; + def int_aarch64_sme_fp8_fmlal_multi_za16_vg2x4 : SME_FP8_ZA_MULTI_VGx4_Intrinsic; // Quad-vector groups (F8F32) def int_aarch64_sme_fp8_fmlall_lane_za32_vg4x1 : SME_FP8_ZA_LANE_VGx1_Intrinsic; @@ -4045,6 +4037,9 @@ let TargetPrefix = "aarch64" in { def int_aarch64_sme_fp8_fmlall_single_za32_vg4x1 : SME_FP8_ZA_SINGLE_VGx1_Intrinsic; def int_aarch64_sme_fp8_fmlall_single_za32_vg4x2 : SME_FP8_ZA_SINGLE_VGx2_Intrinsic; def int_aarch64_sme_fp8_fmlall_single_za32_vg4x4 : SME_FP8_ZA_SINGLE_VGx4_Intrinsic; + // Multi + def int_aarch64_sme_fp8_fmlall_multi_za32_vg4x2 : SME_FP8_ZA_MULTI_VGx2_Intrinsic; + def int_aarch64_sme_fp8_fmlall_multi_za32_vg4x4 : SME_FP8_ZA_MULTI_VGx4_Intrinsic; // // FP8 FDOT intrinsics @@ -4061,6 +4056,12 @@ let TargetPrefix = "aarch64" in { def int_aarch64_sme_fp8_fdot_single_za16_vg1x4 : SME_FP8_ZA_SINGLE_VGx4_Intrinsic; def int_aarch64_sme_fp8_fdot_single_za32_vg1x4 : SME_FP8_ZA_SINGLE_VGx4_Intrinsic; + // Multi + def int_aarch64_sme_fp8_fdot_multi_za16_vg1x2 : SME_FP8_ZA_MULTI_VGx2_Intrinsic; + def int_aarch64_sme_fp8_fdot_multi_za32_vg1x2 : SME_FP8_ZA_MULTI_VGx2_Intrinsic; + + def int_aarch64_sme_fp8_fdot_multi_za16_vg1x4 : SME_FP8_ZA_MULTI_VGx4_Intrinsic; + def int_aarch64_sme_fp8_fdot_multi_za32_vg1x4 : SME_FP8_ZA_MULTI_VGx4_Intrinsic; // FVDOT def int_aarch64_sme_fp8_fvdot_lane_za16_vg1x2 : SME_FP8_ZA_LANE_VGx2_Intrinsic; diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td index 22daaa12e29b98..fedf761f53b647 100644 --- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td @@ -1003,8 +1003,9 @@ defm FMLAL_VG2_MZZ_BtoH : sme2_fp8_fmlal_single_za16<"fmlal", int_aarch64_sme_f defm FMLAL_VG2_M2ZZ_BtoH : sme2_fp_mla_long_array_vg2_single<"fmlal", 0b001, MatrixOp16, ZZ_b, ZPR4b8, nxv16i8, int_aarch64_sme_fp8_fmlal_single_za16_vg2x2, [FPMR, FPCR]>; defm FMLAL_VG4_M4ZZ_BtoH : 
sme2_fp_mla_long_array_vg4_single<"fmlal", 0b001, MatrixOp16, ZZZZ_b, ZPR4b8, nxv16i8, int_aarch64_sme_fp8_fmlal_single_za16_vg2x4, [FPMR, FPCR]>; -defm FMLAL_VG2_M2Z2Z_BtoH : sme2_fp_mla_long_array_vg2_multi<"fmlal", 0b100, MatrixOp16, ZZ_b_mul_r, nxv16i8, null_frag>; -defm FMLAL_VG4_M4Z4Z_BtoH : sme2_fp_mla_long_array_vg4_multi<"fmlal", 0b100, MatrixOp16, ZZZZ_b_mul_r, nxv16i8, null_frag>; +// FP8 FMLAL (multi) +defm FMLAL_VG2_M2Z2Z_BtoH : sme2_fp_mla_long_array_vg2_multi<"fmlal", 0b100, MatrixOp16, ZZ_b_mul_r, nxv16i8, int_aarch64_sme_fp8_fmlal_multi_za16_vg2x2, [FPMR, FPCR]>; +defm FMLAL_VG4_M4Z4Z_BtoH : sme2_fp_mla_long_array_vg4_multi<"fmlal", 0b100, MatrixOp16, ZZZZ_b_mul_r, nxv16i8, int_aarch64_sme_fp8_fmlal_multi_za16_vg2x4, [FPMR, FPCR]>; defm FMOPA_MPPZZ_BtoH : sme2_fp8_fmopa_za16<"fmopa", int_aarch64_sme_fp8_fmopa_za16>; } //[HasSMEF8F16] @@ -1030,8 +1031,9 @@ defm FMLALL_MZZ_BtoS : sme2_mla_ll_array_single<"fmlall", 0b01000, MatrixO defm FMLALL_VG2_M2ZZ_BtoS : sme2_mla_ll_array_vg2_single<"fmlall", 0b000001, MatrixOp32, ZZ_b, ZPR4b8, nxv16i8, int_aarch64_sme_fp8_fmlall_single_za32_vg4x2, [FPMR, FPCR]>; defm FMLALL_VG4_M4ZZ_BtoS : sme2_mla_ll_array_vg4_single<"fmlall", 0b010001, MatrixOp32, ZZZZ_b, ZPR4b8, nxv16i8, int_aarch64_sme_fp8_fmlall_single_za32_vg4x4, [FPMR, FPCR]>; -defm FMLALL_VG2_M2Z2Z_BtoS : sme2_mla_ll_array_vg2_multi<"fmlall", 0b01000, MatrixOp32, ZZ_b_mul_r, nxv16i8, null_frag>; -defm FMLALL_VG4_M4Z4Z_BtoS : sme2_mla_ll_array_vg4_multi<"fmlall", 0b01000, MatrixOp32, ZZZZ_b_mul_r, nxv16i8, null_frag>; +// FP8 FMLALL (multi) +defm FMLALL_VG2_M2Z2Z_BtoS : sme2_mla_ll_array_vg2_multi<"fmlall", 0b01000, MatrixOp32, ZZ_b_mul_r, nxv16i8, int_aarch64_sme_fp8_fmlall_multi_za32_vg4x2, [FPMR, FPCR]>; +defm FMLALL_VG4_M4Z4Z_BtoS : sme2_mla_ll_array_vg4_multi<"fmlall", 0b01000, MatrixOp32, ZZZZ_b_mul_r, nxv16i8, int_aarch64_sme_fp8_fmlall_multi_za32_vg4x4, [FPMR, FPCR]>; defm FMOPA_MPPZZ_BtoS : sme2_fp8_fmopa_za32<"fmopa", 
int_aarch64_sme_fp8_fmopa_za32>; } //[HasSMEF8F32] diff --git a/llvm/lib/Target/AArch64/SMEInstrFormats.td b/llvm/lib/Target/AArch64/SMEInstrFormats.td index 69780b0765c5b0..81004e70dc179b 100644 --- a/llvm/lib/Target/AArch64/SMEInstrFormats.td +++ b/llvm/lib/Target/AArch64/SMEInstrFormats.td @@ -2262,11 +2262,12 @@ class sme2_mla_long_array_vg2_multi op0, bits<3> op, } multiclass sme2_fp_mla_long_array_vg2_multi op, MatrixOperand matrix_ty, - RegisterOperand multi_vector_ty, - ValueType zpr_ty, SDPatternOperator intrinsic> { - + RegisterOperand multi_vector_ty, ValueType zpr_ty, + SDPatternOperator intrinsic, list uses=[]> { def NAME : sme2_mla_long_array_vg2_multi, - SMEPseudo2Instr; + SMEPseudo2Instr { + let Uses = uses; + } def _PSEUDO : sme2_za_array_2op_multi_multi_pseudo; @@ -2309,9 +2310,11 @@ class sme2_mla_long_array_vg4_multi op0, bits<3> op, multiclass sme2_fp_mla_long_array_vg4_multi op, MatrixOperand matrix_ty, RegisterOperand multi_vector_ty, ValueType zpr_ty, - SDPatternOperator intrinsic> { + SDPatternOperator intrinsic, list uses=[]> { def NAME : sme2_mla_long_array_vg4_multi, - SMEPseudo2Instr; + SMEPseudo2Instr { + let Uses = uses; + } def _PSEUDO : sme2_za_array_2op_multi_multi_pseudo; @@ -3341,9 +3344,11 @@ class sme2_mla_ll_array_vg2_multi op, MatrixOperand matrix_ty, multiclass sme2_mla_ll_array_vg2_multi op, MatrixOperand matrix_ty, - RegisterOperand vector_ty, - ValueType vt, SDPatternOperator intrinsic> { - def NAME : sme2_mla_ll_array_vg2_multi, SMEPseudo2Instr; + RegisterOperand vector_ty, ValueType vt, + SDPatternOperator intrinsic, list uses=[]> { + def NAME : sme2_mla_ll_array_vg2_multi, SMEPseudo2Instr { + let Uses = uses; + } def _PSEUDO : sme2_za_array_2op_multi_multi_pseudo; @@ -3385,9 +3390,11 @@ class sme2_mla_ll_array_vg4_multi op,MatrixOperand matrix_ty, multiclass sme2_mla_ll_array_vg4_multi op, MatrixOperand matrix_ty, - RegisterOperand vector_ty, - ValueType vt, SDPatternOperator intrinsic> { - def NAME : 
sme2_mla_ll_array_vg4_multi, SMEPseudo2Instr; + RegisterOperand vector_ty, ValueType vt, + SDPatternOperator intrinsic, list uses=[]> { + def NAME : sme2_mla_ll_array_vg4_multi, SMEPseudo2Instr { + let Uses = uses; + } def _PSEUDO : sme2_za_array_2op_multi_multi_pseudo; diff --git a/llvm/test/CodeGen/AArch64/sme2-fp8-intrinsics-mla.ll b/llvm/test/CodeGen/AArch64/sme2-fp8-intrinsics-mla.ll index 41422e2cf9f560..80009d097e9a5b 100644 --- a/llvm/test/CodeGen/AArch64/sme2-fp8-intrinsics-mla.ll +++ b/llvm/test/CodeGen/AArch64/sme2-fp8-intrinsics-mla.ll @@ -213,3 +213,77 @@ define void @test_fmlall_single_vg4x4(i32 %slice, %zn0, %zm) ret void } + +; FMLAL (multi) + +define void @test_fmlal_multi_vg2x2(i32 %slice, %zn0, %zn1, %zm0, %zm1) { +; CHECK-LABEL: test_fmlal_multi_vg2x2: +; CHECK: // %bb.0: +; CHECK: mov w8, w0 +; CHECK: fmlal za.h[w8, 0:1, vgx2], { z0.b, z1.b }, { z2.b, z3.b } +; CHECK: fmlal za.h[w8, 6:7, vgx2], { z0.b, z1.b }, { z2.b, z3.b } +; CHECK: ret + call void @llvm.aarch64.sme.fp8.fmlal.multi.za16.vg2x2(i32 %slice, + %zn0, %zn1, + %zm0, %zm1) + %add = add i32 %slice, 6 + call void @llvm.aarch64.sme.fp8.fmlal.multi.za16.vg2x2(i32 %add, + %zn0, %zn1, + %zm0, %zm1) + ret void +} + +define void @test_fmlal_multi_vg2x4(i32 %slice, %zn0, %zn1, %zn2, %zn3, +; CHECK-LABEL: test_fmlal_multi_vg2x4: +; CHECK: // %bb.0: +; CHECK: mov w8, w0 +; CHECK: fmlal za.h[w8, 0:1, vgx4], { z0.b - z3.b }, { z4.b - z7.b } +; CHECK: fmlal za.h[w8, 6:7, vgx4], { z0.b - z3.b }, { z4.b - z7.b } +; CHECK: ret + %zm0, %zm1, %zm2, %zm3) { + call void @llvm.aarch64.sme.fp8.fmlal.multi.za16.vg2x4(i32 %slice, + %zn0, %zn1, %zn2, %zn3, + %zm0, %zm1, %zm2, %zm3) + %add = add i32 %slice, 6 + call void @llvm.aarch64.sme.fp8.fmlal.multi.za16.vg2x4(i32 %add, + %zn0, %zn1, %zn2, %zn3, + %zm0, %zm1, %zm2, %zm3) + ret void +} + +; FMLALL (multi) + +define void @test_fmlall_multi_vg4x2(i32 %slice, %zn0, %zn1, %zm0, %zm1) { +; CHECK-LABEL: test_fmlall_multi_vg4x2: +; CHECK: // %bb.0: +; CHECK: mov 
w8, w0 +; CHECK: fmlall za.s[w8, 0:3, vgx2], { z0.b, z1.b }, { z2.b, z3.b } +; CHECK: fmlall za.s[w8, 4:7, vgx2], { z0.b, z1.b }, { z2.b, z3.b } +; CHECK: ret + call void @llvm.aarch64.sme.fp8.fmlall.multi.za32.vg4x2(i32 %slice, + %zn0, %zn1, + %zm0, %zm1) + %add = add i32 %slice, 4 + call void @llvm.aarch64.sme.fp8.fmlall.multi.za32.vg4x2(i32 %add, + %zn0, %zn1, + %zm0, %zm1) + ret void +} + +define void @test_fmlall_multi_vg4x4(i32 %slice, %zn0, %zn1, %zn2, %zn3, +; CHECK-LABEL: test_fmlall_multi_vg4x4: +; CHECK: // %bb.0: +; CHECK: mov w8, w0 +; CHECK: fmlall za.s[w8, 0:3, vgx4], { z0.b - z3.b }, { z4.b - z7.b } +; CHECK: fmlall za.s[w8, 4:7, vgx4], { z0.b - z3.b }, { z4.b - z7.b } +; CHECK: ret + %zm0, %zm1, %zm2, %zm3) { + call void @llvm.aarch64.sme.fp8.fmlall.multi.za32.vg4x4(i32 %slice, + %zn0, %zn1, %zn2, %zn3, + %zm0, %zm1, %zm2, %zm3) + %add = add i32 %slice, 4 + call void @llvm.aarch64.sme.fp8.fmlall.multi.za32.vg4x4(i32 %add, + %zn0, %zn1, %zn2, %zn3, + %zm0, %zm1, %zm2, %zm3) + ret void +}