Skip to content

Commit

Permalink
[CLANG]Add Neon vectors for fpm8_t
Browse files Browse the repository at this point in the history
This patch adds two new Neon vector types, fpm8x16_t and fpm8x8_t,

as specified in ARM ACLE PR#323 [1].

[1] ARM-software/acle#323
  • Loading branch information
CarolineConcatto committed Jul 22, 2024
1 parent 1fc6bf3 commit fd4d8da
Show file tree
Hide file tree
Showing 12 changed files with 183 additions and 17 deletions.
14 changes: 14 additions & 0 deletions clang/include/clang/Basic/arm_fpm8.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
//===--- arm_fpm8.td - ARM FPM8 compiler interface ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the TableGen definitions from which the ARM FPM8 header
// file will be generated.
//
//===----------------------------------------------------------------------===//

include "arm_neon_incl.td"
2 changes: 2 additions & 0 deletions clang/include/clang/Basic/arm_neon_incl.td
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ def OP_UNAVAILABLE : Operation {
// h: half-float
// d: double
// b: bfloat16
// m: fpm8
//
// Typespec modifiers
// ------------------
Expand All @@ -240,6 +241,7 @@ def OP_UNAVAILABLE : Operation {
// B: change to BFloat16
// P: change to polynomial category.
// p: change polynomial to equivalent integer category. Otherwise nop.
// M: change to Fpm8.
//
// >: double element width (vector size unchanged).
// <: half element width (vector size unchanged).
Expand Down
3 changes: 3 additions & 0 deletions clang/lib/Basic/Targets/AArch64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,9 @@ void AArch64TargetInfo::getTargetDefines(const LangOptions &Opts,
Builder.defineMacro("__ARM_FEATURE_BF16_SCALAR_ARITHMETIC", "1");
}

if (HasFpm8) {
Builder.defineMacro("__ARM_FEATURE_FP8", "1");
}
if ((FPU & SveMode) && HasBFloat16) {
Builder.defineMacro("__ARM_FEATURE_SVE_BF16", "1");
}
Expand Down
2 changes: 2 additions & 0 deletions clang/lib/Basic/Targets/ARM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,8 @@ bool ARMTargetInfo::hasBFloat16Type() const {
return HasBFloat16 || (FPU && !SoftFloat);
}

// Answers the hasFpm8Type() query for this target. The result is always true:
// no CPU, FPU or feature flag is consulted here.
// NOTE(review): the neighbouring hasBFloat16Type() is gated on target state;
// confirm that an unconditional answer is intended for Fpm8.
bool ARMTargetInfo::hasFpm8Type() const { return true; }

bool ARMTargetInfo::isValidCPUName(StringRef Name) const {
return Name == "generic" ||
llvm::ARM::parseCPUArch(Name) != llvm::ARM::ArchKind::INVALID;
Expand Down
2 changes: 2 additions & 0 deletions clang/lib/Basic/Targets/ARM.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ class LLVM_LIBRARY_VISIBILITY ARMTargetInfo : public TargetInfo {

bool hasBFloat16Type() const override;

bool hasFpm8Type() const override;

bool isValidCPUName(StringRef Name) const override;
void fillValidCPUList(SmallVectorImpl<StringRef> &Values) const override;

Expand Down
3 changes: 3 additions & 0 deletions clang/lib/Headers/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,8 @@ if(ARM IN_LIST LLVM_TARGETS_TO_BUILD OR AArch64 IN_LIST LLVM_TARGETS_TO_BUILD)
clang_generate_header(-gen-arm-sme-header arm_sme.td arm_sme.h)
# Generate arm_bf16.h
clang_generate_header(-gen-arm-bf16 arm_bf16.td arm_bf16.h)
# Generate arm_fpm8.h
clang_generate_header(-gen-arm-fpm8 arm_fpm8.td arm_fpm8.h)
# Generate arm_mve.h
clang_generate_header(-gen-arm-mve-header arm_mve.td arm_mve.h)
# Generate arm_cde.h
Expand All @@ -414,6 +416,7 @@ if(ARM IN_LIST LLVM_TARGETS_TO_BUILD OR AArch64 IN_LIST LLVM_TARGETS_TO_BUILD)
"${CMAKE_CURRENT_BINARY_DIR}/arm_sme.h"
"${CMAKE_CURRENT_BINARY_DIR}/arm_bf16.h"
"${CMAKE_CURRENT_BINARY_DIR}/arm_vector_types.h"
"${CMAKE_CURRENT_BINARY_DIR}/arm_fpm8.h"
)
endif()
if(RISCV IN_LIST LLVM_TARGETS_TO_BUILD)
Expand Down
5 changes: 5 additions & 0 deletions clang/lib/Sema/SemaExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10215,6 +10215,11 @@ QualType Sema::CheckVectorOperands(ExprResult &LHS, ExprResult &RHS,
const VectorType *RHSVecType = RHSType->getAs<VectorType>();
assert(LHSVecType || RHSVecType);

// Operations on Fpm8 vectors are only available through C intrinsics, so
// reject any vector operation where either operand has Fpm8 elements.
if ((LHSVecType && LHSVecType->getElementType()->isFpm8Type()) ||
(RHSVecType && RHSVecType->getElementType()->isFpm8Type()))
return InvalidOperands(Loc, LHS, RHS);

// AltiVec-style "vector bool op vector bool" combinations are allowed
// for some operators but not others.
if (!AllowBothBool && LHSVecType &&
Expand Down
85 changes: 71 additions & 14 deletions clang/test/CodeGen/arm-fpm8.c
Original file line number Diff line number Diff line change
@@ -1,20 +1,34 @@
// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 4
// RUN: %clang_cc1 -emit-llvm -triple aarch64-arm-none-eabi -target-feature -fp8 -o - %s | FileCheck %s
// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
// RUN: %clang_cc1 -emit-llvm -triple aarch64-arm-none-eabi -target-feature -fp8 -target-feature +neon -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-C
// RUN: %clang_cc1 -emit-llvm -triple aarch64-arm-none-eabi -target-feature -fp8 -target-feature +neon -o - -x c++ %s | FileCheck %s --check-prefixes=CHECK,CHECK-CXX

// REQUIRES: aarch64-registered-target

// CHECK-LABEL: define dso_local i8 @func1n(
// CHECK-SAME: i8 noundef [[FPM8:%.*]]) #[[ATTR0:[0-9]+]] {
// CHECK-NEXT: entry:
// CHECK-NEXT: [[FPM8_ADDR:%.*]] = alloca i8, align 1
// CHECK-NEXT: [[F1N:%.*]] = alloca [10 x i8], align 1
// CHECK-NEXT: store i8 [[FPM8]], ptr [[FPM8_ADDR]], align 1
// CHECK-NEXT: [[TMP0:%.*]] = load i8, ptr [[FPM8_ADDR]], align 1
// CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds [10 x i8], ptr [[F1N]], i64 0, i64 2
// CHECK-NEXT: store i8 [[TMP0]], ptr [[ARRAYIDX]], align 1
// CHECK-NEXT: [[ARRAYIDX1:%.*]] = getelementptr inbounds [10 x i8], ptr [[F1N]], i64 0, i64 2
// CHECK-NEXT: [[TMP1:%.*]] = load i8, ptr [[ARRAYIDX1]], align 1
// CHECK-NEXT: ret i8 [[TMP1]]
// CHECK-C-LABEL: define dso_local i8 @func1n(
// CHECK-C-SAME: i8 noundef [[FPM8:%.*]]) #[[ATTR0:[0-9]+]] {
// CHECK-C-NEXT: [[ENTRY:.*:]]
// CHECK-C-NEXT: [[FPM8_ADDR:%.*]] = alloca i8, align 1
// CHECK-C-NEXT: [[F1N:%.*]] = alloca [10 x i8], align 1
// CHECK-C-NEXT: store i8 [[FPM8]], ptr [[FPM8_ADDR]], align 1
// CHECK-C-NEXT: [[TMP0:%.*]] = load i8, ptr [[FPM8_ADDR]], align 1
// CHECK-C-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds [10 x i8], ptr [[F1N]], i64 0, i64 2
// CHECK-C-NEXT: store i8 [[TMP0]], ptr [[ARRAYIDX]], align 1
// CHECK-C-NEXT: [[ARRAYIDX1:%.*]] = getelementptr inbounds [10 x i8], ptr [[F1N]], i64 0, i64 2
// CHECK-C-NEXT: [[TMP1:%.*]] = load i8, ptr [[ARRAYIDX1]], align 1
// CHECK-C-NEXT: ret i8 [[TMP1]]
//
// CHECK-CXX-LABEL: define dso_local noundef i8 @_Z6func1nw(
// CHECK-CXX-SAME: i8 noundef [[FPM8:%.*]]) #[[ATTR0:[0-9]+]] {
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
// CHECK-CXX-NEXT: [[FPM8_ADDR:%.*]] = alloca i8, align 1
// CHECK-CXX-NEXT: [[F1N:%.*]] = alloca [10 x i8], align 1
// CHECK-CXX-NEXT: store i8 [[FPM8]], ptr [[FPM8_ADDR]], align 1
// CHECK-CXX-NEXT: [[TMP0:%.*]] = load i8, ptr [[FPM8_ADDR]], align 1
// CHECK-CXX-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds [10 x i8], ptr [[F1N]], i64 0, i64 2
// CHECK-CXX-NEXT: store i8 [[TMP0]], ptr [[ARRAYIDX]], align 1
// CHECK-CXX-NEXT: [[ARRAYIDX1:%.*]] = getelementptr inbounds [10 x i8], ptr [[F1N]], i64 0, i64 2
// CHECK-CXX-NEXT: [[TMP1:%.*]] = load i8, ptr [[ARRAYIDX1]], align 1
// CHECK-CXX-NEXT: ret i8 [[TMP1]]
//
__fpm8 func1n(__fpm8 fpm8) {
__fpm8 f1n[10];
Expand All @@ -23,4 +37,47 @@ __fpm8 func1n(__fpm8 fpm8) {
}


#include <arm_neon.h>

// CHECK-C-LABEL: define dso_local <16 x i8> @test_ret_fpm8x16_t(
// CHECK-C-SAME: <16 x i8> noundef [[V:%.*]]) #[[ATTR0]] {
// CHECK-C-NEXT: [[ENTRY:.*:]]
// CHECK-C-NEXT: [[V_ADDR:%.*]] = alloca <16 x i8>, align 16
// CHECK-C-NEXT: store <16 x i8> [[V]], ptr [[V_ADDR]], align 16
// CHECK-C-NEXT: [[TMP0:%.*]] = load <16 x i8>, ptr [[V_ADDR]], align 16
// CHECK-C-NEXT: ret <16 x i8> [[TMP0]]
//
// CHECK-CXX-LABEL: define dso_local noundef <16 x i8> @_Z18test_ret_fpm8x16_t13__Fpm8_tx16_t(
// CHECK-CXX-SAME: <16 x i8> noundef [[V:%.*]]) #[[ATTR0]] {
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
// CHECK-CXX-NEXT: [[V_ADDR:%.*]] = alloca <16 x i8>, align 16
// CHECK-CXX-NEXT: store <16 x i8> [[V]], ptr [[V_ADDR]], align 16
// CHECK-CXX-NEXT: [[TMP0:%.*]] = load <16 x i8>, ptr [[V_ADDR]], align 16
// CHECK-CXX-NEXT: ret <16 x i8> [[TMP0]]
//
fpm8x16_t test_ret_fpm8x16_t(fpm8x16_t v) {
// Hand the argument straight back: the FileCheck assertions preceding this
// function expect the value to be passed and returned as <16 x i8>.
return v;
}

// CHECK-C-LABEL: define dso_local <8 x i8> @test_ret_fpm8x8_t(
// CHECK-C-SAME: <8 x i8> noundef [[V:%.*]]) #[[ATTR0]] {
// CHECK-C-NEXT: [[ENTRY:.*:]]
// CHECK-C-NEXT: [[V_ADDR:%.*]] = alloca <8 x i8>, align 8
// CHECK-C-NEXT: store <8 x i8> [[V]], ptr [[V_ADDR]], align 8
// CHECK-C-NEXT: [[TMP0:%.*]] = load <8 x i8>, ptr [[V_ADDR]], align 8
// CHECK-C-NEXT: ret <8 x i8> [[TMP0]]
//
// CHECK-CXX-LABEL: define dso_local noundef <8 x i8> @_Z17test_ret_fpm8x8_t12__Fpm8_tx8_t(
// CHECK-CXX-SAME: <8 x i8> noundef [[V:%.*]]) #[[ATTR0]] {
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
// CHECK-CXX-NEXT: [[V_ADDR:%.*]] = alloca <8 x i8>, align 8
// CHECK-CXX-NEXT: store <8 x i8> [[V]], ptr [[V_ADDR]], align 8
// CHECK-CXX-NEXT: [[TMP0:%.*]] = load <8 x i8>, ptr [[V_ADDR]], align 8
// CHECK-CXX-NEXT: ret <8 x i8> [[TMP0]]
//
fpm8x8_t test_ret_fpm8x8_t(fpm8x8_t v) {
// Hand the argument straight back: the FileCheck assertions preceding this
// function expect the value to be passed and returned as <8 x i8>.
return v;
}

//// NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
// CHECK: {{.*}}
20 changes: 19 additions & 1 deletion clang/test/Sema/arm-fpm8.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// RUN: %clang_cc1 -fsyntax-only -verify=scalar -triple aarch64-arm-none-eabi -target-feature -fp8 %s
// RUN: %clang_cc1 -fsyntax-only -verify=scalar,neon -triple aarch64-arm-none-eabi \
// RUN: -target-feature -fp8 -target-feature +neon %s

// REQUIRES: aarch64-registered-target
__fpm8 test_static_cast_from_char(char in) {
Expand Down Expand Up @@ -33,3 +34,20 @@ void test(bool b) {
fpm8 + (b ? u8 : fpm8); // scalar-error {{incompatible operand types ('char' and '__fpm8')}}
}

#include <arm_neon.h>

// Checks that the binary arithmetic operators +, -, * and / are diagnosed as
// "invalid operands" whenever an fpm8x8_t vector is involved: first with two
// fpm8x8_t operands, then with fpm8x8_t on the left and uint8x8_t on the
// right, and finally with the operand order swapped. Every diagnostic uses the
// 'neon' verify prefix, so it is only expected in the RUN line that passes
// -verify=scalar,neon.
void test_vector(fpm8x8_t a, fpm8x8_t b, uint8x8_t c) {
a + b; // neon-error {{invalid operands to binary expression ('fpm8x8_t' (vector of 8 'fpm8_t' values) and 'fpm8x8_t')}}
a - b; // neon-error {{invalid operands to binary expression ('fpm8x8_t' (vector of 8 'fpm8_t' values) and 'fpm8x8_t')}}
a * b; // neon-error {{invalid operands to binary expression ('fpm8x8_t' (vector of 8 'fpm8_t' values) and 'fpm8x8_t')}}
a / b; // neon-error {{invalid operands to binary expression ('fpm8x8_t' (vector of 8 'fpm8_t' values) and 'fpm8x8_t')}}

a + c; // neon-error {{invalid operands to binary expression ('fpm8x8_t' (vector of 8 'fpm8_t' values) and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
a - c; // neon-error {{invalid operands to binary expression ('fpm8x8_t' (vector of 8 'fpm8_t' values) and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
a * c; // neon-error {{invalid operands to binary expression ('fpm8x8_t' (vector of 8 'fpm8_t' values) and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
a / c; // neon-error {{invalid operands to binary expression ('fpm8x8_t' (vector of 8 'fpm8_t' values) and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
c + b; // neon-error {{invalid operands to binary expression ('uint8x8_t' (vector of 8 'uint8_t' values) and 'fpm8x8_t' (vector of 8 'fpm8_t' values))}}
c - b; // neon-error {{invalid operands to binary expression ('uint8x8_t' (vector of 8 'uint8_t' values) and 'fpm8x8_t' (vector of 8 'fpm8_t' values))}}
c * b; // neon-error {{invalid operands to binary expression ('uint8x8_t' (vector of 8 'uint8_t' values) and 'fpm8x8_t' (vector of 8 'fpm8_t' values))}}
c / b; // neon-error {{invalid operands to binary expression ('uint8x8_t' (vector of 8 'uint8_t' values) and 'fpm8x8_t' (vector of 8 'fpm8_t' values))}}
}
58 changes: 56 additions & 2 deletions clang/utils/TableGen/NeonEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ enum EltType {
Float16,
Float32,
Float64,
BFloat16
BFloat16,
Fpm8
};

} // end namespace NeonTypeFlags
Expand Down Expand Up @@ -148,6 +149,7 @@ class Type {
UInt,
Poly,
BFloat16,
Fpm8,
};
TypeKind Kind;
bool Immediate, Constant, Pointer;
Expand Down Expand Up @@ -201,6 +203,7 @@ class Type {
bool isLong() const { return isInteger() && ElementBitwidth == 64; }
bool isVoid() const { return Kind == Void; }
bool isBFloat16() const { return Kind == BFloat16; }
bool isFpm8() const { return Kind == Fpm8; }
unsigned getNumElements() const { return Bitwidth / ElementBitwidth; }
unsigned getSizeInBits() const { return Bitwidth; }
unsigned getElementSizeInBits() const { return ElementBitwidth; }
Expand Down Expand Up @@ -595,6 +598,8 @@ class NeonEmitter {
// Emit arm_bf16.h.inc
void runBF16(raw_ostream &o);

void runFpm8(raw_ostream &o);

void runVectorTypes(raw_ostream &o);

// Emit all the __builtin prototypes used in arm_neon.h, arm_fp16.h and
Expand Down Expand Up @@ -622,6 +627,8 @@ std::string Type::str() const {
S += "float";
else if (isBFloat16())
S += "bfloat";
else if (isFpm8())
S += "fpm";
else
S += "int";

Expand Down Expand Up @@ -664,6 +671,8 @@ std::string Type::builtin_str() const {
else if (isBFloat16()) {
assert(ElementBitwidth == 16 && "BFloat16 can only be 16 bits");
S += "y";
} else if (isFpm8()) {
S += "c";
} else
switch (ElementBitwidth) {
case 16: S += "h"; break;
Expand Down Expand Up @@ -718,6 +727,11 @@ unsigned Type::getNeonEnum() const {
Base = (unsigned)NeonTypeFlags::Float16 + (Addend - 1);
}

if (isFpm8()) {
assert(Addend == 1 && "Fpm8 is only 8 bit");
Base = (unsigned)NeonTypeFlags::Fpm8;
}

if (isBFloat16()) {
assert(Addend == 1 && "BFloat16 is only 16 bit");
Base = (unsigned)NeonTypeFlags::BFloat16;
Expand All @@ -744,6 +758,8 @@ Type Type::fromTypedefName(StringRef Name) {
T.Kind = Poly;
} else if (Name.consume_front("bfloat")) {
T.Kind = BFloat16;
} else if (Name.consume_front("fpm")) {
T.Kind = Fpm8;
} else {
assert(Name.starts_with("int"));
Name = Name.drop_front(3);
Expand Down Expand Up @@ -840,6 +856,10 @@ void Type::applyTypespec(bool &Quad) {
if (isPoly())
NumVectors = 0;
break;
case 'm':
Kind = Fpm8;
ElementBitwidth = 8;
break;
case 'b':
Kind = BFloat16;
ElementBitwidth = 16;
Expand Down Expand Up @@ -874,6 +894,10 @@ void Type::applyModifiers(StringRef Mods) {
Kind = BFloat16;
ElementBitwidth = 16;
break;
case 'M':
Kind = Fpm8;
ElementBitwidth = 8;
break;
case 'F':
Kind = Float;
break;
Expand Down Expand Up @@ -958,6 +982,9 @@ std::string Intrinsic::getInstTypeCode(Type T, ClassKind CK) const {
if (T.isBFloat16())
return "bf16";

if (T.isFpm8())
return "fpm8";

if (T.isPoly())
typeCode = 'p';
else if (T.isInteger())
Expand Down Expand Up @@ -995,7 +1022,7 @@ std::string Intrinsic::getBuiltinTypeStr() {

Type RetT = getReturnType();
if ((LocalCK == ClassI || LocalCK == ClassW) && RetT.isScalar() &&
!RetT.isFloating() && !RetT.isBFloat16())
!RetT.isFloating() && !RetT.isBFloat16() && !RetT.isFpm8())
RetT.makeInteger(RetT.getElementSizeInBits(), false);

// Since the return value must be one type, return a vector type of the
Expand Down Expand Up @@ -2378,6 +2405,8 @@ void NeonEmitter::run(raw_ostream &OS) {

OS << "#include <arm_bf16.h>\n";

OS << "#include <arm_fpm8.h>\n";

OS << "#include <arm_vector_types.h>\n";

// For now, signedness of polynomial types depends on target
Expand Down Expand Up @@ -2560,6 +2589,27 @@ void NeonEmitter::runFP16(raw_ostream &OS) {
OS << "#endif /* __ARM_FP16_H */\n";
}

// Write the generated Fpm8 header text to OS. The output consists of, in
// order: a license banner, the opening of the __ARM_FPM8_H include guard, the
// scalar `fpm8_t` typedef, whatever emitNeonTypeDefs produces for the "mQm"
// typespec string, and the closing of the include guard.
void NeonEmitter::runFpm8(raw_ostream &OS) {
  // Comment banner placed at the very top of the generated header.
  static const char LicenseBanner[] =
      "/*===---- arm_fpm8 - ARM vector type "
      "------===\n"
      " *\n"
      " *\n"
      " * Part of the LLVM Project, under the Apache License v2.0 with LLVM "
      "Exceptions.\n"
      " * See https://llvm.org/LICENSE.txt for license information.\n"
      " * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n"
      " *\n"
      " *===-----------------------------------------------------------------"
      "------===\n"
      " */\n\n";

  // Include-guard opening followed by the scalar typedef.
  static const char GuardAndScalarTypedef[] = "#ifndef __ARM_FPM8_H\n"
                                              "#define __ARM_FPM8_H\n\n"
                                              "typedef __fpm8 fpm8_t;\n";

  OS << LicenseBanner << GuardAndScalarTypedef;

  // The vector typedefs go between the scalar typedef and the guard's end.
  emitNeonTypeDefs("mQm", OS);

  OS << "#endif // __ARM_FPM8_H\n";
}

void NeonEmitter::runVectorTypes(raw_ostream &OS) {
OS << "/*===---- arm_vector_types - ARM vector type "
"------===\n"
Expand Down Expand Up @@ -2682,6 +2732,10 @@ void clang::EmitBF16(RecordKeeper &Records, raw_ostream &OS) {
NeonEmitter(Records).runBF16(OS);
}

// Generator entry point: construct a NeonEmitter over Records and have it
// write its runFpm8() output to OS.
void clang::EmitFpm8(RecordKeeper &Records, raw_ostream &OS) {
  NeonEmitter Emitter(Records);
  Emitter.runFpm8(OS);
}

// Generator entry point: construct a NeonEmitter over Records and have it
// write its runHeader() output to OS.
void clang::EmitNeonSema(RecordKeeper &Records, raw_ostream &OS) {
  NeonEmitter Emitter(Records);
  Emitter.runHeader(OS);
}
Expand Down
Loading

0 comments on commit fd4d8da

Please sign in to comment.