Skip to content

Commit

Permalink
Use custom generalization in mapping + test case
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-smnk committed Aug 6, 2024
1 parent 351098c commit 5403818
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
2 changes: 1 addition & 1 deletion lib/TPP/PassBundles/TppMapping.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ struct TppMapping : public tpp::impl::TppMappingBase<TppMapping>,

// TODO: Remove when layout propagation and tile-and-fuse have better
// support for named ops.
pm.addNestedPass<func::FuncOp>(createLinalgGeneralizeNamedOpsPass());
pm.addNestedPass<func::FuncOp>(createGeneralizeNamedOps());
pm.addPass(createCanonicalizerPass());

// Postprocess packing.
Expand Down
22 changes: 22 additions & 0 deletions test/Passes/tpp-mapping.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,25 @@ func.func @tile_and_fuse(%arg0: tensor<64x64xf32>, %arg1: tensor<64x64xf32>,
// CHECK-SAME:{{.*}}outs(%{{.+}} : tensor<32x32xf32>)
// CHECK: linalg.generic{{.*}}outs(%{{.+}} : tensor<32x32xf32>)
// CHECK: arith.maximumf

// -----

func.func @tile_and_fuse_named(%arg0: tensor<64x64xf32>, %arg1: tensor<64x64xf32>,
%arg2: tensor<64x64xf32>, %arg3: tensor<64x64xf32>) -> tensor<64x64xf32> {
%e = tensor.empty() : tensor<64x64xf32>
%0 = linalg.matmul ins(%arg0, %arg1 : tensor<64x64xf32>, tensor<64x64xf32>)
outs(%arg2 : tensor<64x64xf32>) -> tensor<64x64xf32>
%1 = linalg.add ins(%0, %arg3 : tensor<64x64xf32>, tensor<64x64xf32>)
outs(%e : tensor<64x64xf32>) -> tensor<64x64xf32>
return %1 : tensor<64x64xf32>
}

// CHECK-LABEL: tile_and_fuse_named(
// CHECK-COUNT-3: tensor.pack
// Fused matmul and relu
// CHECK: scf.forall
// CHECK: linalg.batch_reduce_matmul{{.*}}ins(%{{.+}}, %{{.+}} : tensor<2x32x32xf32>, tensor<2x32x32xf32>)
// CHECK-SAME:{{.*}}outs(%{{.+}} : tensor<32x32xf32>)
// CHECK: linalg.generic{{.*}}outs(%{{.+}} : tensor<32x32xf32>)
// CHECK: arith.addf
// CHECK: tensor.unpack

0 comments on commit 5403818

Please sign in to comment.