Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LLVM][IR] Add textual shorthand for specifying constant vector splats. #74620

Merged
merged 1 commit into from
Dec 8, 2023

Conversation

paulwalker-arm
Copy link
Collaborator

Add LL parsing for <N x ty> splat(ty <imm>) that lowers onto ConstantInt::get() for integer types and ConstantFP::get() for floating-point types.

The intent is to extend ConstantInt/FP classes to support vector types rather than redirecting to other constant classes as the get() methods do today.

This patch gives IR writers the convenience of using the shorthand today, thus allowing existing tests to be ported.

@llvmbot llvmbot added the llvm:ir label Dec 6, 2023
@llvmbot
Copy link
Member

llvmbot commented Dec 6, 2023

@llvm/pr-subscribers-llvm-ir

Author: Paul Walker (paulwalker-arm)

Changes

Add LL parsing for &lt;N x ty&gt; splat(ty &lt;imm&gt;) that lowers onto ConstantInt::get() for integer types and ConstantFP::get() for floating-point types.

The intent is to extend ConstantInt/FP classes to support vector types rather than redirecting to other constant classes as the get() methods do today.

This patch gives IR writers the convenience of using the shorthand today, thus allowing existing tests to be ported.


Full diff: https://github.com/llvm/llvm-project/pull/74620.diff

5 Files Affected:

  • (modified) llvm/docs/LangRef.rst (+5)
  • (modified) llvm/include/llvm/AsmParser/LLToken.h (+1)
  • (modified) llvm/lib/AsmParser/LLLexer.cpp (+1)
  • (modified) llvm/lib/AsmParser/LLParser.cpp (+39-11)
  • (added) llvm/test/Assembler/constant-splat.ll (+57)
diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index f5e8065ca1dc6..f7977488a54fd 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -4299,6 +4299,11 @@ constants and smaller complex constants.
     "``< i32 42, i32 11, i32 74, i32 100 >``". Vector constants
     must have :ref:`vector type <t_vector>`, and the number and types of
     elements must match those specified by the type.
+
+    When creating a vector whose elements have the same constant value, the
+    prefered syntax is ``splat (<Ty> Val)``. For example: "``splat (i32 11)``".
+    These vector constants must have ::ref:`vector type <t_vector>` with an
+    element type that matches the ``splat`` operand.
 **Zero initialization**
     The string '``zeroinitializer``' can be used to zero initialize a
     value to zero of *any* type, including scalar and
diff --git a/llvm/include/llvm/AsmParser/LLToken.h b/llvm/include/llvm/AsmParser/LLToken.h
index 0aa0093e8efbd..147cf56c821aa 100644
--- a/llvm/include/llvm/AsmParser/LLToken.h
+++ b/llvm/include/llvm/AsmParser/LLToken.h
@@ -336,6 +336,7 @@ enum Kind {
   kw_extractelement,
   kw_insertelement,
   kw_shufflevector,
+  kw_splat,
   kw_extractvalue,
   kw_insertvalue,
   kw_blockaddress,
diff --git a/llvm/lib/AsmParser/LLLexer.cpp b/llvm/lib/AsmParser/LLLexer.cpp
index bf01b39e6f971..919c69fe2783e 100644
--- a/llvm/lib/AsmParser/LLLexer.cpp
+++ b/llvm/lib/AsmParser/LLLexer.cpp
@@ -698,6 +698,7 @@ lltok::Kind LLLexer::LexIdentifier() {
   KEYWORD(uinc_wrap);
   KEYWORD(udec_wrap);
 
+  KEYWORD(splat);
   KEYWORD(vscale);
   KEYWORD(x);
   KEYWORD(blockaddress);
diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index 5c7d551b2b31a..76c6694ae2b2a 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -3981,6 +3981,31 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
     return false;
   }
 
+  case lltok::kw_splat: {
+    Lex.Lex();
+    if (parseToken(lltok::lparen, "expected '(' after vector splat"))
+      return true;
+    Constant *C;
+    if (parseGlobalTypeAndValue(C))
+      return true;
+    if (parseToken(lltok::rparen, "expected ')' at end of vector splat"))
+      return true;
+
+    if (auto *CI = dyn_cast<ConstantInt>(C)) {
+      ID.APSIntVal = CI->getValue();
+      ID.Kind = ValID::t_APSInt;
+      return false;
+    }
+
+    if (auto *CFP = dyn_cast<ConstantFP>(C)) {
+      ID.APFloatVal = CFP->getValue();
+      ID.Kind = ValID::t_APFloat;
+      return false;
+    }
+
+    return tokError("invalid splat operand");
+  }
+
   case lltok::kw_getelementptr:
   case lltok::kw_shufflevector:
   case lltok::kw_insertelement:
@@ -5740,14 +5765,15 @@ bool LLParser::convertValIDToValue(Type *Ty, ValID &ID, Value *&V,
       V = NoCFIValue::get(cast<GlobalValue>(V));
     return V == nullptr;
   case ValID::t_APSInt:
-    if (!Ty->isIntegerTy())
+    if (!Ty->isIntOrIntVectorTy())
       return error(ID.Loc, "integer constant must have integer type");
-    ID.APSIntVal = ID.APSIntVal.extOrTrunc(Ty->getPrimitiveSizeInBits());
-    V = ConstantInt::get(Context, ID.APSIntVal);
+    ID.APSIntVal = ID.APSIntVal.extOrTrunc(Ty->getScalarSizeInBits());
+    V = ConstantInt::get(Ty, ID.APSIntVal);
     return false;
-  case ValID::t_APFloat:
-    if (!Ty->isFloatingPointTy() ||
-        !ConstantFP::isValueValidForType(Ty, ID.APFloatVal))
+  case ValID::t_APFloat: {
+    Type *ScalarTy = Ty->getScalarType();
+    if (!ScalarTy->isFloatingPointTy() ||
+        !ConstantFP::isValueValidForType(ScalarTy, ID.APFloatVal))
       return error(ID.Loc, "floating point constant invalid for type");
 
     // The lexer has no type info, so builds all half, bfloat, float, and double
@@ -5756,13 +5782,13 @@ bool LLParser::convertValIDToValue(Type *Ty, ValID &ID, Value *&V,
       // Check for signaling before potentially converting and losing that info.
       bool IsSNAN = ID.APFloatVal.isSignaling();
       bool Ignored;
-      if (Ty->isHalfTy())
+      if (ScalarTy->isHalfTy())
         ID.APFloatVal.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven,
                               &Ignored);
-      else if (Ty->isBFloatTy())
+      else if (ScalarTy->isBFloatTy())
         ID.APFloatVal.convert(APFloat::BFloat(), APFloat::rmNearestTiesToEven,
                               &Ignored);
-      else if (Ty->isFloatTy())
+      else if (ScalarTy->isFloatTy())
         ID.APFloatVal.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven,
                               &Ignored);
       if (IsSNAN) {
@@ -5774,13 +5800,15 @@ bool LLParser::convertValIDToValue(Type *Ty, ValID &ID, Value *&V,
                                          ID.APFloatVal.isNegative(), &Payload);
       }
     }
-    V = ConstantFP::get(Context, ID.APFloatVal);
 
-    if (V->getType() != Ty)
+    if (Type::getFloatingPointTy(Context, ID.APFloatVal.getSemantics()) !=
+        ScalarTy)
       return error(ID.Loc, "floating point constant does not have type '" +
                                getTypeString(Ty) + "'");
 
+    V = ConstantFP::get(Ty, ID.APFloatVal);
     return false;
+  }
   case ValID::t_Null:
     if (!Ty->isPointerTy())
       return error(ID.Loc, "null must be a pointer type");
diff --git a/llvm/test/Assembler/constant-splat.ll b/llvm/test/Assembler/constant-splat.ll
new file mode 100644
index 0000000000000..25d83f4a1488c
--- /dev/null
+++ b/llvm/test/Assembler/constant-splat.ll
@@ -0,0 +1,57 @@
+; RUN: llvm-as < %s | llvm-dis | FileCheck %s
+
+; NOTE: Tests the expansion of the "splat" shorthand method to create vector
+; constants.  Future work will change how "splat" is expanded, ultimately
+; leading to a point where "splat" is emitted as the disassembly.
+
+; CHECK: @constant.splat.i1 = constant <1 x i1> <i1 true>
+@constant.splat.i1 = constant <1 x i1> splat (i1 true)
+
+; CHECK: @constant.splat.i32 = constant <5 x i32> <i32 7, i32 7, i32 7, i32 7, i32 7>
+@constant.splat.i32 = constant <5 x i32> splat (i32 7)
+
+; CHECK: @constant.splat.i128 = constant <2 x i128> <i128 85070591730234615870450834276742070272, i128 85070591730234615870450834276742070272>
+@constant.splat.i128 = constant <2 x i128> splat (i128 85070591730234615870450834276742070272)
+
+; CHECK: @constant.splat.f16 = constant <4 x half> <half 0xHBC00, half 0xHBC00, half 0xHBC00, half 0xHBC00>
+@constant.splat.f16 = constant <4 x half> splat (half 0xHBC00)
+
+; CHECK: @constant.splat.f32 = constant <5 x float> <float -2.000000e+00, float -2.000000e+00, float -2.000000e+00, float -2.000000e+00, float -2.000000e+00>
+@constant.splat.f32 = constant <5 x float> splat (float -2.000000e+00)
+
+; CHECK: @constant.splat.f64 = constant <3 x double> <double -3.000000e+00, double -3.000000e+00, double -3.000000e+00>
+@constant.splat.f64 = constant <3 x double> splat (double -3.000000e+00)
+
+; CHECK: @constant.splat.128 = constant <2 x fp128> <fp128 0xL00000000000000018000000000000000, fp128 0xL00000000000000018000000000000000>
+@constant.splat.128 = constant <2 x fp128> splat (fp128 0xL00000000000000018000000000000000)
+
+; CHECK: @constant.splat.bf16 = constant <4 x bfloat> <bfloat 0xRC0A0, bfloat 0xRC0A0, bfloat 0xRC0A0, bfloat 0xRC0A0>
+@constant.splat.bf16 = constant <4 x bfloat> splat (bfloat 0xRC0A0)
+
+; CHECK: @constant.splat.x86_fp80 = constant <3 x x86_fp80> <x86_fp80 0xK4000C8F5C28F5C28F800, x86_fp80 0xK4000C8F5C28F5C28F800, x86_fp80 0xK4000C8F5C28F5C28F800>
+@constant.splat.x86_fp80 = constant <3 x x86_fp80> splat (x86_fp80 0xK4000C8F5C28F5C28F800)
+
+; CHECK: @constant.splat.ppc_fp128 = constant <1 x ppc_fp128> <ppc_fp128 0xM80000000000000000000000000000000>
+@constant.splat.ppc_fp128 = constant <1 x ppc_fp128> splat (ppc_fp128 0xM80000000000000000000000000000000)
+
+define void @add_fixed_lenth_vector_splat_i32(<4 x i32> %a) {
+; CHECK: %add = add <4 x i32> %a, <i32 137, i32 137, i32 137, i32 137>
+  %add = add <4 x i32> %a, splat (i32 137)
+  ret void
+}
+
+define <4 x i32> @ret_fixed_lenth_vector_splat_i32() {
+; CHECK: ret <4 x i32> <i32 56, i32 56, i32 56, i32 56>
+  ret <4 x i32> splat (i32 56)
+}
+
+define void @add_fixed_lenth_vector_splat_double(<vscale x 2 x double> %a) {
+; CHECK: %add = fadd <vscale x 2 x double> %a, shufflevector (<vscale x 2 x double> insertelement (<vscale x 2 x double> poison, double 5.700000e+00, i64 0), <vscale x 2 x double> poison, <vscale x 2 x i32> zeroinitializer)
+  %add = fadd <vscale x 2 x double> %a, splat (double 5.700000e+00)
+  ret void
+}
+
+define <vscale x 4 x i32> @ret_scalable_vector_splat_i32() {
+; CHECK: ret <vscale x 4 x i32> shufflevector (<vscale x 4 x i32> insertelement (<vscale x 4 x i32> poison, i32 78, i64 0), <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer)
+  ret <vscale x 4 x i32> splat (i32 78)
+}

@@ -4299,6 +4299,11 @@ constants and smaller complex constants.
"``< i32 42, i32 11, i32 74, i32 100 >``". Vector constants
must have :ref:`vector type <t_vector>`, and the number and types of
elements must match those specified by the type.

When creating a vector whose elements have the same constant value, the
prefered syntax is ``splat (<Ty> Val)``. For example: "``splat (i32 11)``".
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Am I being too adventurous here by suggesting this be the preferred syntax?

@nikic
Copy link
Contributor

nikic commented Dec 6, 2023

My main question here is whether this should be limited to immediates, or work with everything that ConstantVector::getSplat() supports. That is, is there any reason why this syntax supports splat (i32 123) but not for example splat (ptr @g) or splat (i64 ptrtoint (ptr @g to i64))?

Ultimately we will need some way to represent all splats we can currently represent -- I assume that the splat (ptr @g) case for scalable vectors is going to be covered by getelementptr i8, ptr @g, <vscale x N x i64> zeroinitializer. (Are there any other tricky cases?)

@paulwalker-arm
Copy link
Collaborator Author

It depends how far down the rabbit hole we want to go. From a textual point of view the change looks easy but given the goal is to move away from ConstantExpr I'd have to ask what the value is to supporting splat (ptr @g) for scalable vectors? We don't support scalable vector globals so the only use is within a function where non-constant IR should be sufficient? I'm generally not that familiar with the benefits and the existing scalable vector support comes from it being a logical extension of existing code rather than having a massive amount of thought behind it.

For today perhaps it's best I extend the textual support to fallback to ConstantVector::getSplat() rather than erroring and defer the decision whether to deprecate non-immediate scalable vector constants. My guess is it'll take some time to migrate the immediate cases anyway.

@nikic
Copy link
Contributor

nikic commented Dec 7, 2023

It depends how far down the rabbit hole we want to go. From a textual point of view the change looks easy but given the goal is to move away from ConstantExpr I'd have to ask what the value is to supporting splat (ptr @g) for scalable vectors? We don't support scalable vector globals so the only use is within a function where non-constant IR should be sufficient? I'm generally not that familiar with the benefits and the existing scalable vector support comes from it being a logical extension of existing code rather than having a massive amount of thought behind it.

My general thought here was that this is necessary for "closure", so that the ConstantVector::getSplat() can safely be used with any Constant * and we don't get edge case assertion failures where scalable vectors are combined with constant expressions. But after looking at some of the uses of ConstantVector::getSplat(), I'm coming around to the idea that we probably don't need this.

llvm/docs/LangRef.rst Outdated Show resolved Hide resolved
llvm/lib/AsmParser/LLParser.cpp Outdated Show resolved Hide resolved
llvm/lib/AsmParser/LLParser.cpp Outdated Show resolved Hide resolved
llvm/lib/AsmParser/LLParser.cpp Outdated Show resolved Hide resolved
Copy link

github-actions bot commented Dec 7, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Add LL parsing for `<N x ty> splat(ty <imm>)` that lowers onto
ConstantInt::get() for integer types and ConstantFP::get() for
floating-point types.

The intent is to extend ConstantInt/FP classes to support vector
types rather than redirecting to other constant classes as the
get() methods do today.

This patch gives IR writers the convenience of using the shorthand
today, thus allowing existing tests to be ported.
@paulwalker-arm paulwalker-arm merged commit 9c6693f into llvm:main Dec 8, 2023
5 checks passed
lukel97 added a commit to lukel97/llvm-project that referenced this pull request Apr 4, 2024
…. NFC

A handy shorthand for specifying the shufflevector(insertelement(poison, foo, 0), poison, zeroinitializer) splat pattern was introduced in llvm#74620.

Some of the RISC-V tests were converted over to use this new form in dbb65dd, this patch handles the rest.

This converts some constant expressions to the new form, but it also converts instruction sequences that weren't previously constant expressions to constant expressions as well. In some cases this causes some minor code motion in the test output, but it's not significant to the test itself. E.g. see llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-{gather,scatter}.ll
lukel97 added a commit that referenced this pull request Apr 9, 2024
…#87616)

A handy shorthand for specifying the shufflevector(insertelement(poison,
foo, 0), poison, zeroinitializer) splat pattern was introduced in
#74620.

Some of the RISC-V tests were converted over to use this new form in
dbb65dd, this patch handles the rest
which didn't have any codegen diffs.

This not only converts some constant expressions to the new form, but
also instruction sequences that weren't previously constant expressions
to constant expressions as well. In some cases this affects codegen, but
these have been omitted here and will be handled in a separate PR.
@paulwalker-arm paulwalker-arm deleted the llvm-ir-splat-macro branch October 4, 2024 17:41
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants