Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IR] Improve check for target extension types disallowed in globals. #116639

Merged
merged 2 commits into from
Nov 19, 2024

Conversation

jayfoad
Copy link
Contributor

@jayfoad jayfoad commented Nov 18, 2024

For target extension types that are not allowed to be used as globals,
also disallow them nested inside array and structure types.

For target extension types that are not allowed to be used as globals,
also disallow them nested inside array and structure types.
@llvmbot
Copy link
Member

llvmbot commented Nov 18, 2024

@llvm/pr-subscribers-llvm-ir

Author: Jay Foad (jayfoad)

Changes

For target extension types that are not allowed to be used as globals,
also disallow them nested inside array and structure types.


Full diff: https://github.com/llvm/llvm-project/pull/116639.diff

5 Files Affected:

  • (modified) llvm/include/llvm/IR/DerivedTypes.h (+9-1)
  • (modified) llvm/include/llvm/IR/Type.h (+6)
  • (modified) llvm/lib/IR/Type.cpp (+44)
  • (modified) llvm/lib/IR/Verifier.cpp (+11-12)
  • (modified) llvm/test/Assembler/target-type-properties.ll (+13-3)
diff --git a/llvm/include/llvm/IR/DerivedTypes.h b/llvm/include/llvm/IR/DerivedTypes.h
index 19d351456d658e..65f9810776024e 100644
--- a/llvm/include/llvm/IR/DerivedTypes.h
+++ b/llvm/include/llvm/IR/DerivedTypes.h
@@ -225,7 +225,9 @@ class StructType : public Type {
     SCDB_IsLiteral = 4,
     SCDB_IsSized = 8,
     SCDB_ContainsScalableVector = 16,
-    SCDB_NotContainsScalableVector = 32
+    SCDB_NotContainsScalableVector = 32,
+    SCDB_ContainsNonGlobalTargetExtType = 64,
+    SCDB_NotContainsNonGlobalTargetExtType = 128,
   };
 
   /// For a named struct that actually has a name, this is a pointer to the
@@ -294,6 +296,12 @@ class StructType : public Type {
   bool isScalableTy(SmallPtrSetImpl<const Type *> &Visited) const;
   using Type::isScalableTy;
 
+  /// Return true if this type is or contains a target extension type that
+  /// disallows being used as a global.
+  bool
+  containsNonGlobalTargetExtType(SmallPtrSetImpl<const Type *> &Visited) const;
+  using Type::containsNonGlobalTargetExtType;
+
   /// Returns true if this struct contains homogeneous scalable vector types.
   /// Note that the definition of homogeneous scalable vector type is not
   /// recursive here. That means the following structure will return false
diff --git a/llvm/include/llvm/IR/Type.h b/llvm/include/llvm/IR/Type.h
index 7fa940ab347af6..000fdee45bb861 100644
--- a/llvm/include/llvm/IR/Type.h
+++ b/llvm/include/llvm/IR/Type.h
@@ -209,6 +209,12 @@ class Type {
   bool isScalableTy(SmallPtrSetImpl<const Type *> &Visited) const;
   bool isScalableTy() const;
 
+  /// Return true if this type is or contains a target extension type that
+  /// disallows being used as a global.
+  bool
+  containsNonGlobalTargetExtType(SmallPtrSetImpl<const Type *> &Visited) const;
+  bool containsNonGlobalTargetExtType() const;
+
   /// Return true if this is a FP type or a vector of FP.
   bool isFPOrFPVectorTy() const { return getScalarType()->isFloatingPointTy(); }
 
diff --git a/llvm/lib/IR/Type.cpp b/llvm/lib/IR/Type.cpp
index 88ede0d35fa3ee..6def8ac27c0fa4 100644
--- a/llvm/lib/IR/Type.cpp
+++ b/llvm/lib/IR/Type.cpp
@@ -72,6 +72,22 @@ bool Type::isScalableTy() const {
   return isScalableTy(Visited);
 }
 
+bool Type::containsNonGlobalTargetExtType(
+    SmallPtrSetImpl<const Type *> &Visited) const {
+  if (const auto *ATy = dyn_cast<ArrayType>(this))
+    return ATy->getElementType()->containsNonGlobalTargetExtType(Visited);
+  if (const auto *STy = dyn_cast<StructType>(this))
+    return STy->containsNonGlobalTargetExtType(Visited);
+  if (auto *TT = dyn_cast<TargetExtType>(this))
+    return !TT->hasProperty(TargetExtType::CanBeGlobal);
+  return false;
+}
+
+bool Type::containsNonGlobalTargetExtType() const {
+  SmallPtrSet<const Type *, 4> Visited;
+  return containsNonGlobalTargetExtType(Visited);
+}
+
 const fltSemantics &Type::getFltSemantics() const {
   switch (getTypeID()) {
   case HalfTyID: return APFloat::IEEEhalf();
@@ -425,6 +441,34 @@ bool StructType::isScalableTy(SmallPtrSetImpl<const Type *> &Visited) const {
   return false;
 }
 
+bool StructType::containsNonGlobalTargetExtType(
+    SmallPtrSetImpl<const Type *> &Visited) const {
+  if ((getSubclassData() & SCDB_ContainsNonGlobalTargetExtType) != 0)
+    return true;
+
+  if ((getSubclassData() & SCDB_NotContainsNonGlobalTargetExtType) != 0)
+    return false;
+
+  if (!Visited.insert(this).second)
+    return false;
+
+  for (Type *Ty : elements()) {
+    if (Ty->containsNonGlobalTargetExtType(Visited)) {
+      const_cast<StructType *>(this)->setSubclassData(
+          getSubclassData() | SCDB_ContainsNonGlobalTargetExtType);
+      return true;
+    }
+  }
+
+  // For structures that are opaque, return false but do not set the
+  // SCDB_NotContainsNonGlobalTargetExtType flag since it may gain non-global
+  // target extension types when it becomes non-opaque.
+  if (!isOpaque())
+    const_cast<StructType *>(this)->setSubclassData(
+        getSubclassData() | SCDB_NotContainsNonGlobalTargetExtType);
+  return false;
+}
+
 bool StructType::containsHomogeneousScalableVectorTypes() const {
   if (getNumElements() <= 0 || !isa<ScalableVectorType>(elements().front()))
     return false;
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 5cfcd21e508595..5c0ccf734cccbc 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -829,8 +829,10 @@ void Verifier::visitGlobalValue(const GlobalValue &GV) {
 }
 
 void Verifier::visitGlobalVariable(const GlobalVariable &GV) {
+  Type *GVType = GV.getValueType();
+
   if (GV.hasInitializer()) {
-    Check(GV.getInitializer()->getType() == GV.getValueType(),
+    Check(GV.getInitializer()->getType() == GVType,
           "Global variable initializer type does not match global "
           "variable type!",
           &GV);
@@ -854,7 +856,7 @@ void Verifier::visitGlobalVariable(const GlobalVariable &GV) {
 
     // Don't worry about emitting an error for it not being an array,
     // visitGlobalValue will complain on appending non-array.
-    if (ArrayType *ATy = dyn_cast<ArrayType>(GV.getValueType())) {
+    if (ArrayType *ATy = dyn_cast<ArrayType>(GVType)) {
       StructType *STy = dyn_cast<StructType>(ATy->getElementType());
       PointerType *FuncPtrTy =
           PointerType::get(Context, DL.getProgramAddressSpace());
@@ -878,7 +880,6 @@ void Verifier::visitGlobalVariable(const GlobalVariable &GV) {
     Check(GV.materialized_use_empty(),
           "invalid uses of intrinsic global variable", &GV);
 
-    Type *GVType = GV.getValueType();
     if (ArrayType *ATy = dyn_cast<ArrayType>(GVType)) {
       PointerType *PTy = dyn_cast<PointerType>(ATy->getElementType());
       Check(PTy, "wrong type for intrinsic global variable", &GV);
@@ -912,15 +913,13 @@ void Verifier::visitGlobalVariable(const GlobalVariable &GV) {
 
   // Scalable vectors cannot be global variables, since we don't know
   // the runtime size.
-  Check(!GV.getValueType()->isScalableTy(),
-        "Globals cannot contain scalable types", &GV);
-
-  // Check if it's a target extension type that disallows being used as a
-  // global.
-  if (auto *TTy = dyn_cast<TargetExtType>(GV.getValueType()))
-    Check(TTy->hasProperty(TargetExtType::CanBeGlobal),
-          "Global @" + GV.getName() + " has illegal target extension type",
-          TTy);
+  Check(!GVType->isScalableTy(), "Globals cannot contain scalable types", &GV);
+
+  // Check if it is or contains a target extension type that disallows being
+  // used as a global.
+  Check(!GVType->containsNonGlobalTargetExtType(),
+        "Global @" + GV.getName() + " has illegal target extension type",
+        GVType);
 
   if (!GV.hasInitializer()) {
     visitGlobalValue(GV);
diff --git a/llvm/test/Assembler/target-type-properties.ll b/llvm/test/Assembler/target-type-properties.ll
index 49c9d812f1cf4a..60790dbc5c17bb 100644
--- a/llvm/test/Assembler/target-type-properties.ll
+++ b/llvm/test/Assembler/target-type-properties.ll
@@ -1,6 +1,8 @@
 ; RUN: split-file %s %t
 ; RUN: not llvm-as < %t/zeroinit-error.ll -o /dev/null 2>&1 | FileCheck --check-prefix=CHECK-ZEROINIT %s
-; RUN: not llvm-as < %t/global-var.ll -o /dev/null 2>&1 | FileCheck --check-prefix=CHECK-GLOBALVAR %s
+; RUN: not llvm-as < %t/global-var.ll -o /dev/null 2>&1 | FileCheck --check-prefix=CHECK-GLOBAL-VAR %s
+; RUN: not llvm-as < %t/global-array.ll -o /dev/null 2>&1 | FileCheck --check-prefix=CHECK-GLOBAL-ARRAY %s
+; RUN: not llvm-as < %t/global-struct.ll -o /dev/null 2>&1 | FileCheck --check-prefix=CHECK-GLOBAL-STRUCT %s
 ; Check target extension type properties are verified in the assembler.
 
 ;--- zeroinit-error.ll
@@ -12,5 +14,13 @@ define void @foo() {
 }
 
 ;--- global-var.ll
-@global = external global target("unknown_target_type")
-; CHECK-GLOBALVAR: Global @global has illegal target extension type
+@global_var = external global target("unknown_target_type")
+; CHECK-GLOBAL-VAR: Global @global_var has illegal target extension type
+
+;--- global-array.ll
+@global_array = external global [4 x target("unknown_target_type")]
+; CHECK-GLOBAL-ARRAY: Global @global_array has illegal target extension type
+
+;--- global-struct.ll
+@global_struct = external global {target("unknown_target_type")}
+; CHECK-GLOBAL-STRUCT: Global @global_struct has illegal target extension type

Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

// target extension types when it becomes non-opaque.
if (!isOpaque())
const_cast<StructType *>(this)->setSubclassData(
getSubclassData() | SCDB_NotContainsNonGlobalTargetExtType);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:(

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you sad about the long name containing both "not" and "non"?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm sad that our types aren't immutable and we have to do these contortions (https://discourse.llvm.org/t/recursive-types/82707).

@@ -903,7 +947,7 @@ static TargetTypeInfo getTargetTypeInfo(const TargetExtType *Ty) {

// DirectX resources
if (Name.starts_with("dx."))
return TargetTypeInfo(PointerType::get(C, 0));
return TargetTypeInfo(PointerType::get(C, 0), TargetExtType::CanBeGlobal);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@bogner @hekota I had to do this to keep check-clang-codegenhlsl passing. Does it look right?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this looks reasonable to me.

@jayfoad jayfoad requested review from bogner and hekota November 18, 2024 17:15
@jayfoad jayfoad merged commit 497b1ae into llvm:main Nov 19, 2024
8 checks passed
@jayfoad jayfoad deleted the improve-canbeglobal branch November 19, 2024 09:29
swatheesh-mcw pushed a commit to swatheesh-mcw/llvm-project that referenced this pull request Nov 20, 2024
…lvm#116639)

For target extension types that are not allowed to be used as globals,
also disallow them nested inside array and structure types.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants