[AArch64] Implement intrinsics for F1CVTL/F2CVTL and BF1CVTL/BF2CVTL #116959

Merged · 7 commits · Nov 28, 2024
Changes from 2 commits
1 change: 1 addition & 0 deletions clang/include/clang/Basic/TargetBuiltins.h
@@ -336,6 +336,7 @@ namespace clang {
bool isTupleSet() const { return Flags & IsTupleSet; }
bool isReadZA() const { return Flags & IsReadZA; }
bool isWriteZA() const { return Flags & IsWriteZA; }
bool setsFPMR() const { return Flags & SetsFPMR; }
bool isReductionQV() const { return Flags & IsReductionQV; }
uint64_t getBits() const { return Flags; }
bool isFlagSet(uint64_t Flag) const { return Flags & Flag; }
7 changes: 7 additions & 0 deletions clang/include/clang/Basic/arm_sme.td
@@ -824,4 +824,11 @@ let SMETargetGuard = "sme-lutv2" in {
def SVLUTI4_ZT_X4 : SInst<"svluti4_zt_{d}_x4", "4i2.u", "cUc", MergeNone, "aarch64_sme_luti4_zt_x4", [IsStreaming, IsInZT0], [ImmCheck<0, ImmCheck0_0>]>;
}

let SMETargetGuard = "sme2,fp8" in {
// Convert from FP8 to deinterleaved half-precision/BFloat16 multi-vector
def SVF1CVTL : Inst<"svcvtl1_f16[_mf8]_x2_fpm", "2~n", "h", MergeNone, "aarch64_sme_fp8_f1cvtl_x2", [IsStreaming, IsOverloadNone, SetsFPMR], []>;
def SVF1CVTL_BF : Inst<"svcvtl1_bf16[_mf8]_x2_fpm", "2~n", "b", MergeNone, "aarch64_sme_fp8_bf1cvtl_x2", [IsStreaming, IsOverloadNone, SetsFPMR], []>;
def SVF2CVTL : Inst<"svcvtl2_f16[_mf8]_x2_fpm", "2~n", "h", MergeNone, "aarch64_sme_fp8_f2cvtl_x2", [IsStreaming, IsOverloadNone, SetsFPMR], []>;
def SVF2CVTL_BF : Inst<"svcvtl2_bf16[_mf8]_x2_fpm", "2~n", "b", MergeNone, "aarch64_sme_fp8_bf2cvtl_x2", [IsStreaming, IsOverloadNone, SetsFPMR], []>;
}
} // let SVETargetGuard = InvalidMode
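
For orientation, the prototype string "2~n" decodes, per the modifier legend in arm_sve_sme_incl.td below, as a two-vector return tuple ("2"), an svmfloat8_t source ("~", the modifier this patch adds), and a trailing 64-bit FPMR operand ("n"). A hedged sketch of the C declarations these definitions should expand to — the authoritative versions are generated into arm_sme.h; these are inferred from the codegen tests later in the patch:

// Sketch only, not the generated header. Assumes +sme2 and +fp8; the
// FPMR operand is shown as uint64_t, matching the tests below (ACLE
// spells this type fpm_t). Each intrinsic widens an FP8 vector into a
// deinterleaved pair of 16-bit float vectors; the trailing argument
// supplies the FPMR contents that select the FP8 encoding.
svfloat16x2_t  svcvtl1_f16_mf8_x2_fpm(svmfloat8_t zn, uint64_t fpm) __arm_streaming;
svfloat16x2_t  svcvtl2_f16_mf8_x2_fpm(svmfloat8_t zn, uint64_t fpm) __arm_streaming;
svbfloat16x2_t svcvtl1_bf16_mf8_x2_fpm(svmfloat8_t zn, uint64_t fpm) __arm_streaming;
svbfloat16x2_t svcvtl2_bf16_mf8_x2_fpm(svmfloat8_t zn, uint64_t fpm) __arm_streaming;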
2 changes: 2 additions & 0 deletions clang/include/clang/Basic/arm_sve_sme_incl.td
@@ -103,6 +103,7 @@ include "arm_immcheck_incl.td"
// M: svfloat32_t
// N: svfloat64_t
// $: svbfloat16_t
// ~: svmfloat8_t

// J: Prefetch type (sv_prfop)

@@ -235,6 +236,7 @@ def IsInOutZA : FlagType<0x200000000000>;
def IsInZT0 : FlagType<0x400000000000>;
def IsOutZT0 : FlagType<0x800000000000>;
def IsInOutZT0 : FlagType<0x1000000000000>;
def SetsFPMR : FlagType<0x2000000000000>;

defvar InvalidMode = "";

4 changes: 4 additions & 0 deletions clang/lib/CodeGen/CGBuiltin.cpp
@@ -11182,6 +11182,10 @@ Value *CodeGenFunction::EmitAArch64SMEBuiltinExpr(unsigned BuiltinID,
BuiltinID == SME::BI__builtin_sme_svstr_za)
return EmitSMELdrStr(TypeFlags, Ops, Builtin->LLVMIntrinsic);

// Emit set FPMR for intrinsics that require it
if (TypeFlags.setsFPMR())
Builder.CreateCall(CGM.getIntrinsic(Intrinsic::aarch64_set_fpmr),
Ops.pop_back_val());
// Handle builtins which require their multi-vector operands to be swapped
swapCommutativeSMEOperands(BuiltinID, Ops);

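The user-visible contract of this hook, which the tests below verify: each _fpm intrinsic takes one extra trailing argument, and Clang pops it off and lowers it to a standalone FPMR write ahead of the conversion intrinsic. A minimal sketch (the wrapper name is hypothetical; the IR shape in the comment is what the CHECK lines below expect):

#include <arm_sme.h>

// A source-level call like this...
svfloat16x2_t lower_example(svmfloat8_t zn, uint64_t fpm) __arm_streaming {
  return svcvtl1_f16_mf8_x2_fpm(zn, fpm);
}
// ...is emitted as two LLVM intrinsic calls, roughly:
//   call void @llvm.aarch64.set.fpmr(i64 %fpm)
//   call { <vscale x 8 x half>, <vscale x 8 x half> }
//        @llvm.aarch64.sme.fp8.f1cvtl.x2(<vscale x 16 x i8> %zn)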
81 changes: 81 additions & 0 deletions clang/test/CodeGen/AArch64/fp8-intrinsics/acle_sme2_fp8_cvt.c
@@ -0,0 +1,81 @@
// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py

// REQUIRES: aarch64-registered-target

// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +sme2 -target-feature +fp8 -disable-O0-optnone -Werror -Wall -emit-llvm -o - %s | opt -S -p mem2reg,instcombine,tailcallelim | FileCheck %s
// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +sme2 -target-feature +fp8 -disable-O0-optnone -Werror -Wall -emit-llvm -o - -x c++ %s | opt -S -p mem2reg,instcombine,tailcallelim | FileCheck %s -check-prefix=CPP-CHECK
// RUN: %clang_cc1 -DSME_OVERLOADED_FORMS -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +sme2 -target-feature +fp8 -disable-O0-optnone -Werror -Wall -emit-llvm -o - %s | opt -S -p mem2reg,instcombine,tailcallelim | FileCheck %s
// RUN: %clang_cc1 -DSME_OVERLOADED_FORMS -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +sme2 -target-feature +fp8 -disable-O0-optnone -Werror -Wall -emit-llvm -o - -x c++ %s | opt -S -p mem2reg,instcombine,tailcallelim | FileCheck %s -check-prefix=CPP-CHECK
// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +sme2 -target-feature +fp8 -S -disable-O0-optnone -Werror -Wall -o /dev/null %s

#include <arm_sme.h>

#ifdef SME_OVERLOADED_FORMS
#define SME_ACLE_FUNC(A1,A2_UNUSED,A3) A1##A3
#else
#define SME_ACLE_FUNC(A1,A2,A3) A1##A2##A3
#endif

// CHECK-LABEL: @test_cvt1l_f16_x2(
// CHECK-NEXT: entry:
// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
// CHECK-NEXT: [[TMP0:%.*]] = tail call { <vscale x 8 x half>, <vscale x 8 x half> } @llvm.aarch64.sme.fp8.f1cvtl.x2(<vscale x 16 x i8> [[ZN:%.*]])
// CHECK-NEXT: ret { <vscale x 8 x half>, <vscale x 8 x half> } [[TMP0]]
//
// CPP-CHECK-LABEL: @_Z17test_cvt1l_f16_x2u13__SVMfloat8_tm(
// CPP-CHECK-NEXT: entry:
// CPP-CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call { <vscale x 8 x half>, <vscale x 8 x half> } @llvm.aarch64.sme.fp8.f1cvtl.x2(<vscale x 16 x i8> [[ZN:%.*]])
// CPP-CHECK-NEXT: ret { <vscale x 8 x half>, <vscale x 8 x half> } [[TMP0]]
//
svfloat16x2_t test_cvt1l_f16_x2(svmfloat8_t zn, uint64_t fpmr) __arm_streaming {
return SME_ACLE_FUNC(svcvtl1_f16,_mf8,_x2_fpm)(zn, fpmr);
}

// CHECK-LABEL: @test_cvt2l_f16_x2(
// CHECK-NEXT: entry:
// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
// CHECK-NEXT: [[TMP0:%.*]] = tail call { <vscale x 8 x half>, <vscale x 8 x half> } @llvm.aarch64.sme.fp8.f2cvtl.x2(<vscale x 16 x i8> [[ZN:%.*]])
// CHECK-NEXT: ret { <vscale x 8 x half>, <vscale x 8 x half> } [[TMP0]]
//
// CPP-CHECK-LABEL: @_Z17test_cvt2l_f16_x2u13__SVMfloat8_tm(
// CPP-CHECK-NEXT: entry:
// CPP-CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call { <vscale x 8 x half>, <vscale x 8 x half> } @llvm.aarch64.sme.fp8.f2cvtl.x2(<vscale x 16 x i8> [[ZN:%.*]])
// CPP-CHECK-NEXT: ret { <vscale x 8 x half>, <vscale x 8 x half> } [[TMP0]]
//
svfloat16x2_t test_cvt2l_f16_x2(svmfloat8_t zn, uint64_t fpmr) __arm_streaming {
return SME_ACLE_FUNC(svcvtl2_f16,_mf8,_x2_fpm)(zn, fpmr);
}

// CHECK-LABEL: @test_cvt1l_bf16_x2(
// CHECK-NEXT: entry:
// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
// CHECK-NEXT: [[TMP0:%.*]] = tail call { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @llvm.aarch64.sme.fp8.bf1cvtl.x2(<vscale x 16 x i8> [[ZN:%.*]])
// CHECK-NEXT: ret { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } [[TMP0]]
//
// CPP-CHECK-LABEL: @_Z18test_cvt1l_bf16_x2u13__SVMfloat8_tm(
// CPP-CHECK-NEXT: entry:
// CPP-CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @llvm.aarch64.sme.fp8.bf1cvtl.x2(<vscale x 16 x i8> [[ZN:%.*]])
// CPP-CHECK-NEXT: ret { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } [[TMP0]]
//
svbfloat16x2_t test_cvt1l_bf16_x2(svmfloat8_t zn, uint64_t fpmr) __arm_streaming {
return SME_ACLE_FUNC(svcvtl1_bf16,_mf8,_x2_fpm)(zn, fpmr);
}

// CHECK-LABEL: @test_cvt2l_bf16_x2(
// CHECK-NEXT: entry:
// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
// CHECK-NEXT: [[TMP0:%.*]] = tail call { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @llvm.aarch64.sme.fp8.bf2cvtl.x2(<vscale x 16 x i8> [[ZN:%.*]])
// CHECK-NEXT: ret { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } [[TMP0]]
//
// CPP-CHECK-LABEL: @_Z18test_cvt2l_bf16_x2u13__SVMfloat8_tm(
// CPP-CHECK-NEXT: entry:
// CPP-CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @llvm.aarch64.sme.fp8.bf2cvtl.x2(<vscale x 16 x i8> [[ZN:%.*]])
// CPP-CHECK-NEXT: ret { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } [[TMP0]]
//
svbfloat16x2_t test_cvt2l_bf16_x2(svmfloat8_t zn, uint64_t fpmr) __arm_streaming {
return SME_ACLE_FUNC(svcvtl2_bf16,_mf8,_x2_fpm)(zn, fpmr);
}
6 changes: 6 additions & 0 deletions clang/utils/TableGen/SveEmitter.cpp
@@ -926,6 +926,12 @@ void SVEType::applyModifier(char Mod) {
Float = false;
BFloat = false;
break;
case '~':
Float = false;
BFloat = false;
MFloat = true;
ElementBitwidth = 8;
break;
case '.':
llvm_unreachable(". is never a type in itself");
break;
17 changes: 17 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsAArch64.td
@@ -3813,6 +3813,23 @@ let TargetPrefix = "aarch64" in {
LLVMVectorOfBitcastsToInt<0>, LLVMVectorOfBitcastsToInt<0>, LLVMVectorOfBitcastsToInt<0>, LLVMVectorOfBitcastsToInt<0>],
[IntrNoMem]>;

class SME2_FP8_CVT_X2_Single_Intrinsic
: DefaultAttrsIntrinsic<[llvm_nxv8f16_ty, llvm_nxv8f16_ty],
[llvm_nxv16i8_ty],
[IntrReadMem, IntrInaccessibleMemOnly]>;

class SME2_FP8_CVT_X2_Single_BF16_Intrinsic
: DefaultAttrsIntrinsic<[llvm_nxv8bf16_ty, llvm_nxv8bf16_ty],
[llvm_nxv16i8_ty],
[IntrReadMem, IntrInaccessibleMemOnly]>;
//
// CVT from FP8 to deinterleaved half-precision/BFloat16 multi-vectors
//
def int_aarch64_sme_fp8_f1cvtl_x2 : SME2_FP8_CVT_X2_Single_Intrinsic;
def int_aarch64_sme_fp8_f2cvtl_x2 : SME2_FP8_CVT_X2_Single_Intrinsic;

def int_aarch64_sme_fp8_bf1cvtl_x2 : SME2_FP8_CVT_X2_Single_BF16_Intrinsic;
def int_aarch64_sme_fp8_bf2cvtl_x2 : SME2_FP8_CVT_X2_Single_BF16_Intrinsic;
}

// SVE2.1 - ZIPQ1, ZIPQ2, UZPQ1, UZPQ2
34 changes: 34 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -383,6 +383,7 @@ class AArch64DAGToDAGISel : public SelectionDAGISel {
void SelectPExtPair(SDNode *N, unsigned Opc);
void SelectWhilePair(SDNode *N, unsigned Opc);
void SelectCVTIntrinsic(SDNode *N, unsigned NumVecs, unsigned Opcode);
void SelectCVTIntrinsicFP8(SDNode *N, unsigned NumVecs, unsigned Opcode);
void SelectClamp(SDNode *N, unsigned NumVecs, unsigned Opcode);
void SelectUnaryMultiIntrinsic(SDNode *N, unsigned NumOutVecs,
bool IsTupleInput, unsigned Opc);
@@ -1866,6 +1867,27 @@ void AArch64DAGToDAGISel::SelectCVTIntrinsic(SDNode *N, unsigned NumVecs,
CurDAG->RemoveDeadNode(N);
}

void AArch64DAGToDAGISel::SelectCVTIntrinsicFP8(SDNode *N, unsigned NumVecs,
unsigned Opcode) {
SDLoc DL(N);
EVT VT = N->getValueType(0);
SmallVector<SDValue, 4> Ops(N->op_begin() + 2, N->op_end());
Ops.push_back(/*Chain*/ N->getOperand(0));

SDNode *Instruction =
CurDAG->getMachineNode(Opcode, DL, {MVT::Untyped, MVT::Other}, Ops);
SDValue SuperReg = SDValue(Instruction, 0);

for (unsigned i = 0; i < NumVecs; ++i)
ReplaceUses(SDValue(N, i), CurDAG->getTargetExtractSubreg(
AArch64::zsub0 + i, DL, VT, SuperReg));

// Copy chain
unsigned ChainIdx = NumVecs;
ReplaceUses(SDValue(N, ChainIdx), SDValue(Instruction, 1));
CurDAG->RemoveDeadNode(N);
}

void AArch64DAGToDAGISel::SelectDestructiveMultiIntrinsic(SDNode *N,
unsigned NumVecs,
bool IsZmMulti,
@@ -5547,6 +5569,18 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) {
SelectMultiVectorLuti(Node, 4, AArch64::LUTI4_4ZZT2Z);
return;
}
case Intrinsic::aarch64_sme_fp8_bf1cvtl_x2:
SelectCVTIntrinsicFP8(Node, 2, AArch64::BF1CVTL_2ZZ_BtoH);
return;
case Intrinsic::aarch64_sme_fp8_f1cvtl_x2:
SelectCVTIntrinsicFP8(Node, 2, AArch64::F1CVTL_2ZZ_BtoH);
return;
case Intrinsic::aarch64_sme_fp8_bf2cvtl_x2:
SelectCVTIntrinsicFP8(Node, 2, AArch64::BF2CVTL_2ZZ_BtoH);
return;
case Intrinsic::aarch64_sme_fp8_f2cvtl_x2:
SelectCVTIntrinsicFP8(Node, 2, AArch64::F2CVTL_2ZZ_BtoH);
return;
}
} break;
case ISD::INTRINSIC_WO_CHAIN: {
2 changes: 1 addition & 1 deletion llvm/lib/Target/AArch64/SMEInstrFormats.td
@@ -2412,7 +2412,7 @@ multiclass sme2p1_fp_cvt_vector_vg2_single<string mnemonic, bit l> {

// SME2 multi-vec FP8 up convert two registers
multiclass sme2p1_fp8_cvt_vector_vg2_single<string mnemonic, bits<2> opc, bit L> {
-def _NAME : sme2_cvt_unpk_vector_vg2<opc, 0b110, L, ZZ_h_mul_r, ZPR8, mnemonic>{
+def NAME : sme2_cvt_unpk_vector_vg2<opc, 0b110, L, ZZ_h_mul_r, ZPR8, mnemonic>{
let Uses = [FPMR, FPCR];
}
}
48 changes: 48 additions & 0 deletions llvm/test/CodeGen/AArch64/sme2-fp8-intrinsics-cvt.ll
@@ -0,0 +1,48 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 2
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme2,+fp8 -verify-machineinstrs -force-streaming < %s | FileCheck %s

; F1CVTL / F2CVTL

define { <vscale x 8 x half>, <vscale x 8 x half> } @f1cvtl(<vscale x 16 x i8> %zm) {
; CHECK-LABEL: f1cvtl:
; CHECK: // %bb.0:
; CHECK-NEXT: f1cvtl { z0.h, z1.h }, z0.b
; CHECK-NEXT: ret
%res = call { <vscale x 8 x half>, <vscale x 8 x half> } @llvm.aarch64.sme.fp8.f1cvtl.x2(<vscale x 16 x i8> %zm)
ret { <vscale x 8 x half>, <vscale x 8 x half> } %res
}

define { <vscale x 8 x half>, <vscale x 8 x half> } @f2cvtl(<vscale x 16 x i8> %zm) {
; CHECK-LABEL: f2cvtl:
; CHECK: // %bb.0:
; CHECK-NEXT: f2cvtl { z0.h, z1.h }, z0.b
; CHECK-NEXT: ret
%res = call { <vscale x 8 x half>, <vscale x 8 x half> } @llvm.aarch64.sme.fp8.f2cvtl.x2(<vscale x 16 x i8> %zm)
ret { <vscale x 8 x half>, <vscale x 8 x half> } %res
}

; BF1CVTL / BF2CVTL

define { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @bf1cvtl(<vscale x 16 x i8> %zm) {
; CHECK-LABEL: bf1cvtl:
; CHECK: // %bb.0:
; CHECK-NEXT: bf1cvtl { z0.h, z1.h }, z0.b
; CHECK-NEXT: ret
%res = call { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @llvm.aarch64.sme.fp8.bf1cvtl.x2(<vscale x 16 x i8> %zm)
ret { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } %res
}

define { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @bf2cvtl(<vscale x 16 x i8> %zm) {
; CHECK-LABEL: bf2cvtl:
; CHECK: // %bb.0:
; CHECK-NEXT: bf2cvtl { z0.h, z1.h }, z0.b
; CHECK-NEXT: ret
%res = call { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @llvm.aarch64.sme.fp8.bf2cvtl.x2(<vscale x 16 x i8> %zm)
ret { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } %res
}


declare { <vscale x 8 x half>, <vscale x 8 x half> } @llvm.aarch64.sme.fp8.f1cvtl.x2(<vscale x 16 x i8>)
declare { <vscale x 8 x half>, <vscale x 8 x half> } @llvm.aarch64.sme.fp8.f2cvtl.x2(<vscale x 16 x i8>)
declare { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @llvm.aarch64.sme.fp8.bf1cvtl.x2(<vscale x 16 x i8>)
declare { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @llvm.aarch64.sme.fp8.bf2cvtl.x2(<vscale x 16 x i8>)