Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][LLVM] LLVMTypeConverter: Tighten materialization checks #116532

Merged

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Nov 17, 2024

This commit adds extra checks to the MemRef argument materializations in the LLVM type converter. These materializations construct a MemRefType/UnrankedMemRefType from the unpacked elements of a MemRef descriptor or from a bare pointer.

The extra checks ensure that the inputs to the materialization function are correct. It is possible that a user added extra type conversion rules that convert MemRef types in a different way and the extra checks ensure that we construct a MemRef descriptor only if the inputs are what we expect.

This commit also drops a check around bare pointer materializations:

// This is a bare pointer. We allow bare pointers only for function entry
// blocks.

This check should not be part of the materialization function. Whether a MemRef block argument is converted into a MemRef descriptor or a bare pointer is decided in the lowering pattern. At the point of time when materialization functions are executed, we already made that decision and we should just materialize regardless of the input format.

@llvmbot
Copy link
Member

llvmbot commented Nov 17, 2024

@llvm/pr-subscribers-mlir-func
@llvm/pr-subscribers-mlir-scf
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

This commit adds extra checks to the MemRef argument materializations in the LLVM type converter. These materializations construct a MemRefType/UnrankedMemRefType from the unpacked elements of a MemRef descriptor or from a bare pointer.

The extra checks ensure that the inputs to the materialization function are correct. It is possible that a user added extra type conversion rules that convert MemRef types in a different way and the extra checks ensure that we construct a MemRef descriptor only if the inputs are what we expect.

This commit also drops a check around bare pointer materializations:

// This is a bare pointer. We allow bare pointers only for function entry
// blocks.

This check should not be part of the materialization function. Whether a MemRef block argument is converted into a MemRef descriptor or a bare pointer is decided in the lowering pattern. At the point of time when materialization functions are executed, we already made that decision and we should just materialize whatever format is requested.


Full diff: https://github.com/llvm/llvm-project/pull/116532.diff

1 Files Affected:

  • (modified) mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp (+17-15)
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index ce91424e7a577e..59b0f5c9b09bcd 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -153,6 +153,12 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
                                        type.isVarArg());
   });
 
+  // Helper function that checks if the given value range is a bare pointer.
+  auto isBarePointer = [](ValueRange values) {
+    return values.size() == 1 &&
+           isa<LLVM::LLVMPointerType>(values.front().getType());
+  };
+
   // Argument materializations convert from the new block argument types
   // (multiple SSA values that make up a memref descriptor) back to the
   // original block argument type. The dialect conversion framework will then
@@ -161,11 +167,10 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
   addArgumentMaterialization([&](OpBuilder &builder,
                                  UnrankedMemRefType resultType,
                                  ValueRange inputs, Location loc) {
-    if (inputs.size() == 1) {
-      // Bare pointers are not supported for unranked memrefs because a
-      // memref descriptor cannot be built just from a bare pointer.
+    // Note: Bare pointers are not supported for unranked memrefs because a
+    // memref descriptor cannot be built just from a bare pointer.
+    if (TypeRange(inputs) != getUnrankedMemRefDescriptorFields())
       return Value();
-    }
     Value desc =
         UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
     // An argument materialization must return a value of type
@@ -177,20 +182,17 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
   addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
                                  ValueRange inputs, Location loc) {
     Value desc;
-    if (inputs.size() == 1) {
-      // This is a bare pointer. We allow bare pointers only for function entry
-      // blocks.
-      BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front());
-      if (!barePtr)
-        return Value();
-      Block *block = barePtr.getOwner();
-      if (!block->isEntryBlock() ||
-          !isa<FunctionOpInterface>(block->getParentOp()))
-        return Value();
+    if (isBarePointer(inputs)) {
       desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
                                                inputs[0]);
-    } else {
+    } else if (TypeRange(inputs) ==
+               getMemRefDescriptorFields(resultType,
+                                         /*unpackAggregates=*/true)) {
       desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
+    } else {
+      // The inputs are neither a bare pointer nor an unpacked memref
+      // descriptor. This materialization function cannot be used.
+      return Value();
     }
     // An argument materialization must return a value of type `resultType`,
     // so insert a cast from the memref descriptor type (!llvm.struct) to the

@llvmbot
Copy link
Member

llvmbot commented Nov 17, 2024

@llvm/pr-subscribers-mlir-llvm

Author: Matthias Springer (matthias-springer)

Changes

This commit adds extra checks to the MemRef argument materializations in the LLVM type converter. These materializations construct a MemRefType/UnrankedMemRefType from the unpacked elements of a MemRef descriptor or from a bare pointer.

The extra checks ensure that the inputs to the materialization function are correct. It is possible that a user added extra type conversion rules that convert MemRef types in a different way and the extra checks ensure that we construct a MemRef descriptor only if the inputs are what we expect.

This commit also drops a check around bare pointer materializations:

// This is a bare pointer. We allow bare pointers only for function entry
// blocks.

This check should not be part of the materialization function. Whether a MemRef block argument is converted into a MemRef descriptor or a bare pointer is decided in the lowering pattern. At the point of time when materialization functions are executed, we already made that decision and we should just materialize whatever format is requested.


Full diff: https://github.com/llvm/llvm-project/pull/116532.diff

1 Files Affected:

  • (modified) mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp (+17-15)
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index ce91424e7a577e..59b0f5c9b09bcd 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -153,6 +153,12 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
                                        type.isVarArg());
   });
 
+  // Helper function that checks if the given value range is a bare pointer.
+  auto isBarePointer = [](ValueRange values) {
+    return values.size() == 1 &&
+           isa<LLVM::LLVMPointerType>(values.front().getType());
+  };
+
   // Argument materializations convert from the new block argument types
   // (multiple SSA values that make up a memref descriptor) back to the
   // original block argument type. The dialect conversion framework will then
@@ -161,11 +167,10 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
   addArgumentMaterialization([&](OpBuilder &builder,
                                  UnrankedMemRefType resultType,
                                  ValueRange inputs, Location loc) {
-    if (inputs.size() == 1) {
-      // Bare pointers are not supported for unranked memrefs because a
-      // memref descriptor cannot be built just from a bare pointer.
+    // Note: Bare pointers are not supported for unranked memrefs because a
+    // memref descriptor cannot be built just from a bare pointer.
+    if (TypeRange(inputs) != getUnrankedMemRefDescriptorFields())
       return Value();
-    }
     Value desc =
         UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
     // An argument materialization must return a value of type
@@ -177,20 +182,17 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
   addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
                                  ValueRange inputs, Location loc) {
     Value desc;
-    if (inputs.size() == 1) {
-      // This is a bare pointer. We allow bare pointers only for function entry
-      // blocks.
-      BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front());
-      if (!barePtr)
-        return Value();
-      Block *block = barePtr.getOwner();
-      if (!block->isEntryBlock() ||
-          !isa<FunctionOpInterface>(block->getParentOp()))
-        return Value();
+    if (isBarePointer(inputs)) {
       desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
                                                inputs[0]);
-    } else {
+    } else if (TypeRange(inputs) ==
+               getMemRefDescriptorFields(resultType,
+                                         /*unpackAggregates=*/true)) {
       desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
+    } else {
+      // The inputs are neither a bare pointer nor an unpacked memref
+      // descriptor. This materialization function cannot be used.
+      return Value();
     }
     // An argument materialization must return a value of type `resultType`,
     // so insert a cast from the memref descriptor type (!llvm.struct) to the

@matthias-springer matthias-springer force-pushed the users/matthias-springer/memref_mat_extra_checks branch from df3d0f2 to 5857c76 Compare November 17, 2024 08:16
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir:sparse Sparse compiler in MLIR mlir:scf mlir:func labels Nov 17, 2024
@matthias-springer matthias-springer changed the base branch from main to users/matthias-springer/1n_pattern November 17, 2024 08:16
Copy link
Member

@zero9178 zero9178 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense to me from a defensive programming perspective, but shouldn't a user add their own argument materialization in this case that can handle their type conversion and the dialect conversion will then call that first?
Is there an easy way we can test this by any chance?

@matthias-springer
Copy link
Member Author

shouldn't a user add their own argument materialization in this case that can handle their type conversion and the dialect conversion will then call that first?

Yes that's right. But if users forget to do that, they see a more cryptic error message because we are generating invalid IR. (Inserting a value of incorrect type into an LLVM struct.) If the materialization functions bails, it will trigger the unrealized_conversion_cast materialization, that can be debugged more easily.

One reason why I'd like to add these extra checks is because they helped me adapting the MemRef type conversion rules in #116524. I think these checks will also be useful for others.

Is there an easy way we can test this by any chance?

There was no good place to add such tests until now. I added a new TestPatterns.cpp in test/lib/Dialect/LLVM.

Copy link
Member

@zero9178 zero9178 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thank you :))

@matthias-springer matthias-springer force-pushed the users/matthias-springer/1n_pattern branch 2 times, most recently from cf3160e to 63e90d8 Compare November 18, 2024 11:56
@matthias-springer matthias-springer force-pushed the users/matthias-springer/memref_mat_extra_checks branch from 67c2e5f to 7025a8c Compare November 18, 2024 13:38
@matthias-springer matthias-springer force-pushed the users/matthias-springer/1n_pattern branch 3 times, most recently from 8a7a042 to 47e321a Compare November 23, 2024 05:20
@matthias-springer matthias-springer force-pushed the users/matthias-springer/memref_mat_extra_checks branch from 7025a8c to fe68c3c Compare November 23, 2024 07:45
@matthias-springer matthias-springer changed the base branch from users/matthias-springer/1n_pattern to users/matthias-springer/delete_decompose_call_graph November 23, 2024 07:45
@matthias-springer matthias-springer changed the base branch from users/matthias-springer/delete_decompose_call_graph to main November 24, 2024 03:10
@matthias-springer matthias-springer force-pushed the users/matthias-springer/memref_mat_extra_checks branch from fe68c3c to 1511e50 Compare November 24, 2024 03:10
@matthias-springer matthias-springer merged commit a0ef12c into main Nov 24, 2024
6 of 8 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/memref_mat_extra_checks branch November 24, 2024 03:20
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir:func mlir:llvm mlir:scf mlir:sparse Sparse compiler in MLIR mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants