-
Notifications
You must be signed in to change notification settings - Fork 12.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][LLVM] LLVMTypeConverter
: Tighten materialization checks
#116532
[mlir][LLVM] LLVMTypeConverter
: Tighten materialization checks
#116532
Conversation
@llvm/pr-subscribers-mlir-func @llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesThis commit adds extra checks to the MemRef argument materializations in the LLVM type converter. These materializations construct a The extra checks ensure that the inputs to the materialization function are correct. It is possible that a user added extra type conversion rules that convert MemRef types in a different way and the extra checks ensure that we construct a MemRef descriptor only if the inputs are what we expect. This commit also drops a check around bare pointer materializations:
This check should not be part of the materialization function. Whether a MemRef block argument is converted into a MemRef descriptor or a bare pointer is decided in the lowering pattern. At the point of time when materialization functions are executed, we already made that decision and we should just materialize whatever format is requested. Full diff: https://github.com/llvm/llvm-project/pull/116532.diff 1 Files Affected:
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index ce91424e7a577e..59b0f5c9b09bcd 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -153,6 +153,12 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
type.isVarArg());
});
+ // Helper function that checks if the given value range is a bare pointer.
+ auto isBarePointer = [](ValueRange values) {
+ return values.size() == 1 &&
+ isa<LLVM::LLVMPointerType>(values.front().getType());
+ };
+
// Argument materializations convert from the new block argument types
// (multiple SSA values that make up a memref descriptor) back to the
// original block argument type. The dialect conversion framework will then
@@ -161,11 +167,10 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
addArgumentMaterialization([&](OpBuilder &builder,
UnrankedMemRefType resultType,
ValueRange inputs, Location loc) {
- if (inputs.size() == 1) {
- // Bare pointers are not supported for unranked memrefs because a
- // memref descriptor cannot be built just from a bare pointer.
+ // Note: Bare pointers are not supported for unranked memrefs because a
+ // memref descriptor cannot be built just from a bare pointer.
+ if (TypeRange(inputs) != getUnrankedMemRefDescriptorFields())
return Value();
- }
Value desc =
UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
// An argument materialization must return a value of type
@@ -177,20 +182,17 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
ValueRange inputs, Location loc) {
Value desc;
- if (inputs.size() == 1) {
- // This is a bare pointer. We allow bare pointers only for function entry
- // blocks.
- BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front());
- if (!barePtr)
- return Value();
- Block *block = barePtr.getOwner();
- if (!block->isEntryBlock() ||
- !isa<FunctionOpInterface>(block->getParentOp()))
- return Value();
+ if (isBarePointer(inputs)) {
desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
inputs[0]);
- } else {
+ } else if (TypeRange(inputs) ==
+ getMemRefDescriptorFields(resultType,
+ /*unpackAggregates=*/true)) {
desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
+ } else {
+ // The inputs are neither a bare pointer nor an unpacked memref
+ // descriptor. This materialization function cannot be used.
+ return Value();
}
// An argument materialization must return a value of type `resultType`,
// so insert a cast from the memref descriptor type (!llvm.struct) to the
|
@llvm/pr-subscribers-mlir-llvm Author: Matthias Springer (matthias-springer) ChangesThis commit adds extra checks to the MemRef argument materializations in the LLVM type converter. These materializations construct a The extra checks ensure that the inputs to the materialization function are correct. It is possible that a user added extra type conversion rules that convert MemRef types in a different way and the extra checks ensure that we construct a MemRef descriptor only if the inputs are what we expect. This commit also drops a check around bare pointer materializations:
This check should not be part of the materialization function. Whether a MemRef block argument is converted into a MemRef descriptor or a bare pointer is decided in the lowering pattern. At the point of time when materialization functions are executed, we already made that decision and we should just materialize whatever format is requested. Full diff: https://github.com/llvm/llvm-project/pull/116532.diff 1 Files Affected:
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index ce91424e7a577e..59b0f5c9b09bcd 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -153,6 +153,12 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
type.isVarArg());
});
+ // Helper function that checks if the given value range is a bare pointer.
+ auto isBarePointer = [](ValueRange values) {
+ return values.size() == 1 &&
+ isa<LLVM::LLVMPointerType>(values.front().getType());
+ };
+
// Argument materializations convert from the new block argument types
// (multiple SSA values that make up a memref descriptor) back to the
// original block argument type. The dialect conversion framework will then
@@ -161,11 +167,10 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
addArgumentMaterialization([&](OpBuilder &builder,
UnrankedMemRefType resultType,
ValueRange inputs, Location loc) {
- if (inputs.size() == 1) {
- // Bare pointers are not supported for unranked memrefs because a
- // memref descriptor cannot be built just from a bare pointer.
+ // Note: Bare pointers are not supported for unranked memrefs because a
+ // memref descriptor cannot be built just from a bare pointer.
+ if (TypeRange(inputs) != getUnrankedMemRefDescriptorFields())
return Value();
- }
Value desc =
UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
// An argument materialization must return a value of type
@@ -177,20 +182,17 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
ValueRange inputs, Location loc) {
Value desc;
- if (inputs.size() == 1) {
- // This is a bare pointer. We allow bare pointers only for function entry
- // blocks.
- BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front());
- if (!barePtr)
- return Value();
- Block *block = barePtr.getOwner();
- if (!block->isEntryBlock() ||
- !isa<FunctionOpInterface>(block->getParentOp()))
- return Value();
+ if (isBarePointer(inputs)) {
desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
inputs[0]);
- } else {
+ } else if (TypeRange(inputs) ==
+ getMemRefDescriptorFields(resultType,
+ /*unpackAggregates=*/true)) {
desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
+ } else {
+ // The inputs are neither a bare pointer nor an unpacked memref
+ // descriptor. This materialization function cannot be used.
+ return Value();
}
// An argument materialization must return a value of type `resultType`,
// so insert a cast from the memref descriptor type (!llvm.struct) to the
|
df3d0f2
to
5857c76
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense to me from a defensive programming perspective, but shouldn't a user add their own argument materialization in this case that can handle their type conversion and the dialect conversion will then call that first?
Is there an easy way we can test this by any chance?
Yes that's right. But if users forget to do that, they see a more cryptic error message because we are generating invalid IR. (Inserting a value of incorrect type into an LLVM struct.) If the materialization functions bails, it will trigger the One reason why I'd like to add these extra checks is because they helped me adapting the MemRef type conversion rules in #116524. I think these checks will also be useful for others.
There was no good place to add such tests until now. I added a new |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thank you :))
cf3160e
to
63e90d8
Compare
67c2e5f
to
7025a8c
Compare
8a7a042
to
47e321a
Compare
7025a8c
to
fe68c3c
Compare
fe68c3c
to
1511e50
Compare
This commit adds extra checks to the MemRef argument materializations in the LLVM type converter. These materializations construct a
MemRefType
/UnrankedMemRefType
from the unpacked elements of a MemRef descriptor or from a bare pointer.The extra checks ensure that the inputs to the materialization function are correct. It is possible that a user added extra type conversion rules that convert MemRef types in a different way and the extra checks ensure that we construct a MemRef descriptor only if the inputs are what we expect.
This commit also drops a check around bare pointer materializations:
This check should not be part of the materialization function. Whether a MemRef block argument is converted into a MemRef descriptor or a bare pointer is decided in the lowering pattern. At the point of time when materialization functions are executed, we already made that decision and we should just materialize regardless of the input format.