From 329aca9ed4fb17dd003281525d109062385a0294 Mon Sep 17 00:00:00 2001 From: Guojin Date: Thu, 14 Nov 2024 14:13:13 -0500 Subject: [PATCH] [CIR][CIRGen][Builtin] Support __builtin_elementwise_abs and extend AbsOp to take vector input (#1099) Extend AbsOp to take vector of int input. With it, we can support __builtin_elementwise_abs. We should in the next PR extend FpUnaryOps to support vector type input so we won't have blocker to implement all elementwise builtins completely. Now just temporarily have missingFeature `fpUnaryOPsSupportVectorType`. Currently, int type UnaryOp support vector type. FYI: [clang's documentation about elementwise builtins](https://clang.llvm.org/docs/LanguageExtensions.html#vector-builtins) --- clang/include/clang/CIR/Dialect/IR/CIROps.td | 5 ++-- clang/include/clang/CIR/Dialect/IR/CIRTypes.h | 1 + .../include/clang/CIR/Dialect/IR/CIRTypes.td | 17 ++++++++++++ clang/include/clang/CIR/MissingFeatures.h | 3 +++ clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp | 19 ++++++++++--- clang/lib/CIR/Dialect/IR/CIRTypes.cpp | 14 +++++++++- clang/test/CIR/CodeGen/builtins-elementwise.c | 27 +++++++++++++++++++ 7 files changed, 80 insertions(+), 6 deletions(-) create mode 100644 clang/test/CIR/CodeGen/builtins-elementwise.c diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index 8e43713b8fe4..d97cbfe47a76 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -4331,8 +4331,8 @@ def SqrtOp : UnaryFPToFPBuiltinOp<"sqrt", "SqrtOp">; def TruncOp : UnaryFPToFPBuiltinOp<"trunc", "FTruncOp">; def AbsOp : CIR_Op<"abs", [Pure, SameOperandsAndResultType]> { - let arguments = (ins PrimitiveSInt:$src, UnitAttr:$poison); - let results = (outs PrimitiveSInt:$result); + let arguments = (ins CIR_AnySignedIntOrVecOfSignedInt:$src, UnitAttr:$poison); + let results = (outs CIR_AnySignedIntOrVecOfSignedInt:$result); let summary = [{ libc builtin equivalent abs, labs, llabs @@ -4345,6 +4345,7 @@ def AbsOp : CIR_Op<"abs", [Pure, SameOperandsAndResultType]> { ```mlir %0 = cir.const #cir.int<-42> : s32i %1 = cir.abs %0 poison : s32i + %2 = cir.abs %3 : !cir.vector ``` }]; let assemblyFormat = "$src ( `poison` $poison^ )? `:` type($src) attr-dict"; diff --git a/clang/include/clang/CIR/Dialect/IR/CIRTypes.h b/clang/include/clang/CIR/Dialect/IR/CIRTypes.h index 4e9902792eca..9f6eab7c7ba9 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIRTypes.h +++ b/clang/include/clang/CIR/Dialect/IR/CIRTypes.h @@ -184,6 +184,7 @@ class StructType bool isAnyFloatingPointType(mlir::Type t); bool isFPOrFPVectorTy(mlir::Type); +bool isIntOrIntVectorTy(mlir::Type); } // namespace cir mlir::ParseResult parseAddrSpaceAttribute(mlir::AsmParser &p, diff --git a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td index 4317aaf3bb01..f73d80402047 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td +++ b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td @@ -60,6 +60,9 @@ def CIR_IntType : CIR_Type<"Int", "int", bool isPrimitive() const { return isValidPrimitiveIntBitwidth(getWidth()); } + bool isSignedPrimitive() const { + return isPrimitive() && isSigned(); + } /// Returns a minimum bitwidth of cir::IntType static unsigned minBitwidth() { return 1; } @@ -538,8 +541,22 @@ def IntegerVector : Type< ]>, "!cir.vector of !cir.int"> { } +// Vector of signed integral type +def SignedIntegerVector : Type< + And<[ + CPred<"::mlir::isa<::cir::VectorType>($_self)">, + CPred<"::mlir::isa<::cir::IntType>(" + "::mlir::cast<::cir::VectorType>($_self).getEltType())">, + CPred<"::mlir::cast<::cir::IntType>(" + "::mlir::cast<::cir::VectorType>($_self).getEltType())" + ".isSignedPrimitive()"> + ]>, "!cir.vector of !cir.int"> { +} + // Constraints def CIR_AnyIntOrVecOfInt: AnyTypeOf<[CIR_IntType, IntegerVector]>; +def CIR_AnySignedIntOrVecOfSignedInt: AnyTypeOf< + [PrimitiveSInt, SignedIntegerVector]>; // Pointer to Arrays def ArrayPtr : Type< diff --git a/clang/include/clang/CIR/MissingFeatures.h b/clang/include/clang/CIR/MissingFeatures.h index 211e0d879595..7d59e10809eb 100644 --- a/clang/include/clang/CIR/MissingFeatures.h +++ b/clang/include/clang/CIR/MissingFeatures.h @@ -328,6 +328,9 @@ struct MissingFeatures { //-- Other missing features + // We need to extend fpUnaryOPs to support vector types. + static bool fpUnaryOPsSupportVectorType() { return false; } + // We need to track the parent record types that represent a field // declaration. This is necessary to determine the layout of a class. static bool fieldDeclAbstraction() { return false; } diff --git a/clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp b/clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp index 1749b0329603..c5fdafe18bb5 100644 --- a/clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp @@ -1255,9 +1255,22 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID, case Builtin::BI__builtin_nondeterministic_value: llvm_unreachable("BI__builtin_nondeterministic_value NYI"); - case Builtin::BI__builtin_elementwise_abs: - llvm_unreachable("BI__builtin_elementwise_abs NYI"); - + case Builtin::BI__builtin_elementwise_abs: { + mlir::Type cirTy = ConvertType(E->getArg(0)->getType()); + bool isIntTy = cir::isIntOrIntVectorTy(cirTy); + if (!isIntTy) { + if (cir::isAnyFloatingPointType(cirTy)) { + return emitUnaryFPBuiltin(*this, *E); + } + assert(!MissingFeatures::fpUnaryOPsSupportVectorType()); + llvm_unreachable("unsupported type for BI__builtin_elementwise_abs"); + } + mlir::Value arg = emitScalarExpr(E->getArg(0)); + auto call = getBuilder().create(getLoc(E->getExprLoc()), + arg.getType(), arg, false); + mlir::Value result = call->getResult(0); + return RValue::get(result); + } case Builtin::BI__builtin_elementwise_acos: llvm_unreachable("BI__builtin_elementwise_acos NYI"); case Builtin::BI__builtin_elementwise_asin: diff --git a/clang/lib/CIR/Dialect/IR/CIRTypes.cpp b/clang/lib/CIR/Dialect/IR/CIRTypes.cpp index 2e262478a733..bfa8ef62f54e 100644 --- a/clang/lib/CIR/Dialect/IR/CIRTypes.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRTypes.cpp @@ -828,7 +828,7 @@ bool cir::isAnyFloatingPointType(mlir::Type t) { } //===----------------------------------------------------------------------===// -// Floating-point and Float-point Vecotr type helpers +// Floating-point and Float-point Vector type helpers //===----------------------------------------------------------------------===// bool cir::isFPOrFPVectorTy(mlir::Type t) { @@ -840,6 +840,18 @@ bool cir::isFPOrFPVectorTy(mlir::Type t) { return isAnyFloatingPointType(t); } +//===----------------------------------------------------------------------===// +// CIR Integer and Integer Vector type helpers +//===----------------------------------------------------------------------===// + +bool cir::isIntOrIntVectorTy(mlir::Type t) { + + if (isa(t)) { + return isa(mlir::dyn_cast(t).getEltType()); + } + return isa(t); +} + //===----------------------------------------------------------------------===// // ComplexType Definitions //===----------------------------------------------------------------------===// diff --git a/clang/test/CIR/CodeGen/builtins-elementwise.c b/clang/test/CIR/CodeGen/builtins-elementwise.c new file mode 100644 index 000000000000..857122db0d08 --- /dev/null +++ b/clang/test/CIR/CodeGen/builtins-elementwise.c @@ -0,0 +1,27 @@ +// RUN: %clang_cc1 -triple aarch64-none-linux-android24 -emit-cir %s -o %t.cir +// RUN: FileCheck --check-prefix=CIR --input-file=%t.cir %s +// RUN: %clang_cc1 -triple aarch64-none-linux-android24 -fclangir \ +// RUN: -emit-llvm %s -o %t.ll +// RUN: FileCheck --check-prefix=LLVM --input-file=%t.ll %s + +typedef int vint4 __attribute__((ext_vector_type(4))); + +void test_builtin_elementwise_abs(vint4 vi4, int i, float f, double d) { + // CIR-LABEL: test_builtin_elementwise_abs + // LLVM-LABEL: test_builtin_elementwise_abs + // CIR: {{%.*}} = cir.fabs {{%.*}} : !cir.float + // LLVM: {{%.*}} = call float @llvm.fabs.f32(float {{%.*}}) + f = __builtin_elementwise_abs(f); + + // CIR: {{%.*}} = cir.fabs {{%.*}} : !cir.double + // LLVM: {{%.*}} = call double @llvm.fabs.f64(double {{%.*}}) + d = __builtin_elementwise_abs(d); + + // CIR: {{%.*}} = cir.abs {{%.*}} : !cir.vector + // LLVM: {{%.*}} = call <4 x i32> @llvm.abs.v4i32(<4 x i32> {{%.*}}, i1 false) + vi4 = __builtin_elementwise_abs(vi4); + + // CIR: {{%.*}} = cir.abs {{%.*}} : !s32 + // LLVM: {{%.*}} = call i32 @llvm.abs.i32(i32 {{%.*}}, i1 false) + i = __builtin_elementwise_abs(i); +}