From 889b67c9d30e3024a1317431d66c22599f6c2011 Mon Sep 17 00:00:00 2001
From: donald chen
Date: Sat, 26 Oct 2024 08:07:51 +0800
Subject: [PATCH] [mlir] [memref] add more checks to the memref.reinterpret_cast
 (#112669)

Operation memref.reinterpret_cast used to accept input like:

  %out = memref.reinterpret_cast %in to offset: [%offset], sizes: [10],
    strides: [1] : memref<?xf32> to memref<10xf32>

A problem arises: while lowering, the true offset of %out is %offset, but its
data type indicates an offset of 0. Permitting this inconsistency can result
in incorrect outcomes, as certain passes might erroneously extract the offset
from the data type of %out.

This patch fixes this by enforcing that the return value's data type aligns
with the input parameters.
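
With this check in place, a dynamic offset has to surface in the result
type's layout attribute. As a sketch mirroring the strided layouts in the
updated tests below (the exact layout here is illustrative, not taken from
the patch), the same cast becomes valid when written as:

  // The offset is `?` in the result type, so passes that read the type
  // no longer see a bogus static offset of 0.
  %out = memref.reinterpret_cast %in to offset: [%offset], sizes: [10],
    strides: [1] : memref<?xf32> to memref<10xf32, strided<[1], offset: ?>>
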
std::string("dynamic") + : std::to_string(expectedSize)) + << " instead of " << resultSize << " in dim = " << idx; } // Match offset and strides in static_offset and static_strides attributes. If @@ -1910,20 +1911,22 @@ LogicalResult ReinterpretCastOp::verify() { // Match offset in result memref type and in static_offsets attribute. int64_t expectedOffset = getStaticOffsets().front(); - if (!ShapedType::isDynamic(resultOffset) && - !ShapedType::isDynamic(expectedOffset) && resultOffset != expectedOffset) + if (!ShapedType::isDynamic(resultOffset) && resultOffset != expectedOffset) return emitError("expected result type with offset = ") - << expectedOffset << " instead of " << resultOffset; + << (ShapedType::isDynamic(expectedOffset) + ? std::string("dynamic") + : std::to_string(expectedOffset)) + << " instead of " << resultOffset; // Match strides in result memref type and in static_strides attribute. for (auto [idx, resultStride, expectedStride] : llvm::enumerate(resultStrides, getStaticStrides())) { - if (!ShapedType::isDynamic(resultStride) && - !ShapedType::isDynamic(expectedStride) && - resultStride != expectedStride) + if (!ShapedType::isDynamic(resultStride) && resultStride != expectedStride) return emitError("expected result type with stride = ") - << expectedStride << " instead of " << resultStride - << " in dim = " << idx; + << (ShapedType::isDynamic(expectedStride) + ? std::string("dynamic") + : std::to_string(expectedStride)) + << " instead of " << resultStride << " in dim = " << idx; } return success(); diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp index faba12f5bf82f8..83683c7e617bf8 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp @@ -89,7 +89,8 @@ struct MemRefReshapeOpConverter : public OpRewritePattern { strides.resize(rank); Location loc = op.getLoc(); - Value stride = rewriter.create(loc, 1); + Value stride = nullptr; + int64_t staticStride = 1; for (int i = rank - 1; i >= 0; --i) { Value size; // Load dynamic sizes from the shape input, use constants for static dims. @@ -105,9 +106,22 @@ struct MemRefReshapeOpConverter : public OpRewritePattern { size = rewriter.create(loc, sizeAttr); sizes[i] = sizeAttr; } - strides[i] = stride; - if (i > 0) - stride = rewriter.create(loc, stride, size); + if (stride) + strides[i] = stride; + else + strides[i] = rewriter.getIndexAttr(staticStride); + + if (i > 0) { + if (stride) { + stride = rewriter.create(loc, stride, size); + } else if (op.getType().isDynamicDim(i)) { + stride = rewriter.create( + loc, rewriter.create(loc, staticStride), + size); + } else { + staticStride *= op.getType().getDimSize(i); + } + } } rewriter.replaceOpWithNewOp( op, op.getType(), op.getSource(), /*offset=*/rewriter.getIndexAttr(0), diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp index a2049ba4a4924d..087d1fcc2b23ae 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp @@ -507,6 +507,8 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder, SmallVector groupStrides; ArrayRef srcShape = sourceType.getShape(); + + OpFoldResult lastValidStride = nullptr; for (int64_t currentDim : reassocGroup) { // Skip size-of-1 dimensions, since right now their strides may be // meaningless. 
@@ -517,11 +519,11 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
       continue;

     int64_t currentStride = strides[currentDim];
-    groupStrides.push_back(ShapedType::isDynamic(currentStride)
-                               ? origStrides[currentDim]
-                               : builder.getIndexAttr(currentStride));
+    lastValidStride = ShapedType::isDynamic(currentStride)
+                          ? origStrides[currentDim]
+                          : builder.getIndexAttr(currentStride);
   }
-  if (groupStrides.empty()) {
+  if (!lastValidStride) {
     // We're dealing with a 1x1x...x1 shape. The stride is meaningless,
     // but we still have to make the type system happy.
     MemRefType collapsedType = collapseShape.getResultType();
@@ -543,12 +545,7 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
     return {builder.getIndexAttr(finalStride)};
   }

-  // For the general case, we just want the minimum stride
-  // since the collapsed dimensions are contiguous.
-  auto minMap = AffineMap::getMultiDimIdentityMap(groupStrides.size(),
-                                                  builder.getContext());
-  return {makeComposedFoldedAffineMin(builder, collapseShape.getLoc(), minMap,
-                                      groupStrides)};
+  return {lastValidStride};
 }

 /// From `reshape_like(memref, subSizes, subStrides))` compute
diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
index 55b1bc9c545a85..ec5ceae57ccb33 100644
--- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
@@ -425,8 +425,6 @@ func.func @collapse_shape_dynamic_with_non_identity_layout(
 // CHECK: %[[SIZE1:.*]] = llvm.extractvalue %[[MEM]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 // CHECK: %[[SIZE2:.*]] = llvm.extractvalue %[[MEM]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 // CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: %[[STRIDE0_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[STRIDE0]] : i64 to index
-// CHECK: %[[STRIDE0:.*]] = builtin.unrealized_conversion_cast %[[STRIDE0_TO_IDX]] : index to i64
 // CHECK: %[[FINAL_SIZE1:.*]] = llvm.mul %[[SIZE1]], %[[SIZE2]] : i64
 // CHECK: %[[SIZE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE1]] : i64 to index
 // CHECK: %[[FINAL_SIZE1:.*]] = builtin.unrealized_conversion_cast %[[SIZE1_TO_IDX]] : index to i64
@@ -548,23 +546,19 @@ func.func @collapse_shape_dynamic(%arg0 : memref<1x2x?xf32>) -> memref<1x?xf32> {
 // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
 // CHECK: %[[SIZE2:.*]] = llvm.extractvalue %[[MEM]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 // CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEM]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 // CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : index) : i64
 // CHECK: %[[FINAL_SIZE1:.*]] = llvm.mul %[[SIZE2]], %[[C2]] : i64
 // CHECK: %[[SIZE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE1]] : i64 to index
 // CHECK: %[[FINAL_SIZE1:.*]] = builtin.unrealized_conversion_cast %[[SIZE1_TO_IDX]] : index to i64
-// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
-// CHECK: %[[MIN_STRIDE1:.*]] = llvm.intr.smin(%[[STRIDE1]], %[[C1]]) : (i64, i64) -> i64
-// CHECK: %[[MIN_STRIDE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[MIN_STRIDE1]] : i64 to index
-// CHECK: %[[MIN_STRIDE1:.*]] = builtin.unrealized_conversion_cast %[[MIN_STRIDE1_TO_IDX]] : index to i64
 // CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[C0]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
 // CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[C1]], %[[DESC2]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[STRIDE0]], %[[DESC3]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[FINAL_SIZE1]], %[[DESC4]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[DESC6:.*]] = llvm.insertvalue %[[MIN_STRIDE1]], %[[DESC5]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[DESC6:.*]] = llvm.insertvalue %[[C1]], %[[DESC5]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[DESC6]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<1x?xf32>
 // CHECK: return %[[RES]] : memref<1x?xf32>
 // CHECK: }
diff --git a/mlir/test/Dialect/GPU/decompose-memrefs.mlir b/mlir/test/Dialect/GPU/decompose-memrefs.mlir
index 56fc9a66b7ace7..1a192219484517 100644
--- a/mlir/test/Dialect/GPU/decompose-memrefs.mlir
+++ b/mlir/test/Dialect/GPU/decompose-memrefs.mlir
@@ -7,8 +7,8 @@
 // CHECK: gpu.launch
 // CHECK-SAME: threads(%[[TX:.*]], %[[TY:.*]], %[[TZ:.*]]) in
 // CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]]]
-// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32>
-// CHECK: memref.store %[[VAL]], %[[PTR]][] : memref<f32>
+// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32, strided<[], offset: ?>>
+// CHECK: memref.store %[[VAL]], %[[PTR]][] : memref<f32, strided<[], offset: ?>>
 func.func @decompose_store(%arg0 : f32, %arg1 : memref<?x?x?xf32>) {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
@@ -33,8 +33,8 @@ func.func @decompose_store(%arg0 : f32, %arg1 : memref<?x?x?xf32>) {
 // CHECK: gpu.launch
 // CHECK-SAME: threads(%[[TX:.*]], %[[TY:.*]], %[[TZ:.*]]) in
 // CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[OFFSET]], %[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]], %[[STRIDES]]#2]
-// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32>
-// CHECK: memref.store %[[VAL]], %[[PTR]][] : memref<f32>
+// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32, strided<[], offset: ?>>
+// CHECK: memref.store %[[VAL]], %[[PTR]][] : memref<f32, strided<[], offset: ?>>
 func.func @decompose_store_strided(%arg0 : f32, %arg1 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>) {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
@@ -59,8 +59,8 @@ func.func @decompose_store_strided(%arg0 : f32, %arg1 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>) {
 // CHECK: gpu.launch
 // CHECK-SAME: threads(%[[TX:.*]], %[[TY:.*]], %[[TZ:.*]]) in
 // CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]]]
-// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32>
-// CHECK: %[[RES:.*]] = memref.load %[[PTR]][] : memref<f32>
+// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32, strided<[], offset: ?>>
+// CHECK: %[[RES:.*]] = memref.load %[[PTR]][] : memref<f32, strided<[], offset: ?>>
 // CHECK: "test.test"(%[[RES]]) : (f32) -> ()
 func.func @decompose_load(%arg0 : memref<?x?x?xf32>) {
   %c0 = arith.constant 0 : index
diff --git a/mlir/test/Dialect/MemRef/expand-ops.mlir b/mlir/test/Dialect/MemRef/expand-ops.mlir
index f958a92b751a4a..65932b5814a668 100644
--- a/mlir/test/Dialect/MemRef/expand-ops.mlir
+++ b/mlir/test/Dialect/MemRef/expand-ops.mlir
@@ -52,14 +52,13 @@ func.func @memref_reshape(%input: memref<*xf32>,
 // CHECK-SAME: [[SRC:%.*]]: memref<*xf32>,
 // CHECK-SAME: [[SHAPE:%.*]]: memref<3xi32>) -> memref<?x?x8xf32> {

-// CHECK: [[C1:%.*]] = arith.constant 1 : index
 // CHECK: [[C8:%.*]] = arith.constant 8 : index
-// CHECK: [[STRIDE_1:%.*]] = arith.muli [[C1]], [[C8]] : index
-
-// CHECK: [[C1_:%.*]] = arith.constant 1 : index
-// CHECK: [[DIM_1:%.*]] = memref.load [[SHAPE]]{{\[}}[[C1_]]] : memref<3xi32>
+// CHECK: [[C1:%.*]] = arith.constant 1 : index
+// CHECK: [[DIM_1:%.*]] = memref.load [[SHAPE]]{{\[}}[[C1]]] : memref<3xi32>
 // CHECK: [[SIZE_1:%.*]] = arith.index_cast [[DIM_1]] : i32 to index
-// CHECK: [[STRIDE_0:%.*]] = arith.muli [[STRIDE_1]], [[SIZE_1]] : index
+
+// CHECK: [[C8_:%.*]] = arith.constant 8 : index
+// CHECK: [[STRIDE_0:%.*]] = arith.muli [[C8_]], [[SIZE_1]] : index

 // CHECK: [[C0:%.*]] = arith.constant 0 : index
 // CHECK: [[DIM_0:%.*]] = memref.load [[SHAPE]]{{\[}}[[C0]]] : memref<3xi32>
@@ -67,5 +66,5 @@ func.func @memref_reshape(%input: memref<*xf32>,
 // CHECK: [[RESULT:%.*]] = memref.reinterpret_cast [[SRC]]
 // CHECK-SAME: to offset: [0], sizes: {{\[}}[[SIZE_0]], [[SIZE_1]], 8],
-// CHECK-SAME: strides: {{\[}}[[STRIDE_0]], [[STRIDE_1]], [[C1]]]
+// CHECK-SAME: strides: {{\[}}[[STRIDE_0]], 8, 1]
 // CHECK-SAME: : memref<*xf32> to memref<?x?x8xf32>
diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
index 8aac802ba10ae9..647731db439c08 100644
--- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
@@ -931,19 +931,15 @@ func.func @extract_aligned_pointer_as_index_of_unranked_source(%arg0: memref<*xf32>) -> index {
 //   = min(7, 1)
 //   = 1
 //
-// CHECK-DAG: #[[$STRIDE0_MIN_MAP:.*]] = affine_map<()[s0] -> (s0)>
-// CHECK-DAG: #[[$SIZE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
-// CHECK-DAG: #[[$STRIDE1_MIN_MAP:.*]] = affine_map<()[s0, s1] -> (s0, s1, 42)>
+// CHECK: #[[$SIZE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
 // CHECK-LABEL: func @simplify_collapse(
 // CHECK-SAME: %[[ARG:.*]]: memref<?x?x4x?x6x7xi32>)
 //
 // CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:6, %[[STRIDES:.*]]:6 = memref.extract_strided_metadata %[[ARG]] : memref<?x?x4x?x6x7xi32>
 //
-// CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.min #[[$STRIDE0_MIN_MAP]]()[%[[STRIDES]]#0]
-// CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$SIZE0_MAP]]()[%[[SIZES]]#1, %[[SIZES]]#3]
-// CHECK-DAG: %[[DYN_STRIDE1:.*]] = affine.min #[[$STRIDE1_MIN_MAP]]()[%[[STRIDES]]#1, %[[STRIDES]]#2]
+// CHECK: %[[DYN_SIZE1:.*]] = affine.apply #[[$SIZE0_MAP]]()[%[[SIZES]]#1, %[[SIZES]]#3]
 //
-// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [0], sizes: [%[[SIZES]]#0, %[[DYN_SIZE1]], 42], strides: [%[[DYN_STRIDE0]], %[[DYN_STRIDE1]], 1]
+// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [0], sizes: [%[[SIZES]]#0, %[[DYN_SIZE1]], 42], strides: [%[[STRIDES]]#0, 42, 1]
 func.func @simplify_collapse(%arg : memref<?x?x4x?x6x7xi32>) -> memref<?x?x42xi32> {
@@ -1046,15 +1042,12 @@ func.func @simplify_collapse_with_dim_of_size1_and_non_1_stride
 //    We just return the first dynamic one for this group.
 //
 //
-// CHECK-DAG: #[[$STRIDE0_MIN_MAP:.*]] = affine_map<()[s0, s1] -> (s0, s1)>
 // CHECK-LABEL: func @simplify_collapse_with_dim_of_size1_and_resulting_dyn_stride(
 // CHECK-SAME: %[[ARG:.*]]: memref<2x3x1x1x1xi32, strided<[?, ?, ?, ?, 2]
 //
 // CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:5, %[[STRIDES:.*]]:5 = memref.extract_strided_metadata %[[ARG]] : memref<2x3x1x1x1xi32, strided<[?, ?, ?, ?, 2], offset: ?>>
 //
-// CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.min #[[$STRIDE0_MIN_MAP]]()[%[[STRIDES]]#0, %[[STRIDES]]#1]
-//
-// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [6, 1], strides: [%[[DYN_STRIDE0]], %[[STRIDES]]#2]
+// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [6, 1], strides: [%[[STRIDES]]#1, %[[STRIDES]]#2]
 func.func @simplify_collapse_with_dim_of_size1_and_resulting_dyn_stride
     (%arg0: memref<2x3x1x1x1xi32, strided<[?, ?, ?, ?, 2], offset: ?>>)
     -> memref<6x1xi32, strided<[?, ?], offset: ?>> {
@@ -1083,8 +1076,7 @@ func.func @simplify_collapse_with_dim_of_size1_and_resulting_dyn_stride
 //   Stride 2 = origStride5
 //            = 1
 //
-// CHECK-DAG: #[[$SIZE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
-// CHECK-DAG: #[[$STRIDE0_MAP:.*]] = affine_map<()[s0] -> (s0)>
+// CHECK: #[[$SIZE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
 // CHECK-LABEL: func @extract_strided_metadata_of_collapse(
 // CHECK-SAME: %[[ARG:.*]]: memref<?x?x4x?x6x7xi32>)
 //
 // CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:6, %[[STRIDES:.*]]:6 = memref.extract_strided_metadata %[[ARG]] : memref<?x?x4x?x6x7xi32>
 //
-// CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.min #[[$STRIDE0_MAP]]()[%[[STRIDES]]#0]
 // CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$SIZE0_MAP]]()[%[[SIZES]]#1, %[[SIZES]]#3]
 //
-// CHECK: return %[[BASE]], %[[C0]], %[[SIZES]]#0, %[[DYN_SIZE1]], %[[C42]], %[[DYN_STRIDE0]], %[[C42]], %[[C1]]
+// CHECK: return %[[BASE]], %[[C0]], %[[SIZES]]#0, %[[DYN_SIZE1]], %[[C42]], %[[STRIDES]]#0, %[[C42]], %[[C1]]
 func.func @extract_strided_metadata_of_collapse(%arg : memref<?x?x4x?x6x7xi32>)
     -> (memref<i32>, index,
         index, index, index,
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 0f533cb95a0ca9..51c4781c9022b2 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -217,6 +217,15 @@ func.func @memref_reinterpret_cast_no_map_but_offset(%in: memref<?xf32>) {

 // -----

+func.func @memref_reinterpret_cast_offset_mismatch_dynamic(%in: memref<?xf32>, %offset : index) {
+  // expected-error @+1 {{expected result type with offset = dynamic instead of 0}}
+  %out = memref.reinterpret_cast %in to offset: [%offset], sizes: [10], strides: [1]
+         : memref<?xf32> to memref<10xf32>
+  return
+}
+
+// -----
+
 func.func @memref_reinterpret_cast_no_map_but_stride(%in: memref<?xf32>) {
   // expected-error @+1 {{expected result type with stride = 10 instead of 1 in dim = 0}}
   %out = memref.reinterpret_cast %in to offset: [0], sizes: [10], strides: [10]