Skip to content

Commit

Permalink
Outer loop parallelization test
Browse files Browse the repository at this point in the history
  • Loading branch information
KavithaTipturMadhu committed Jun 12, 2024
1 parent 73a8d4e commit 1a90608
Showing 1 changed file with 35 additions and 0 deletions.
35 changes: 35 additions & 0 deletions test/Integration/loop-insertion.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,39 @@
// SHUFFLE-CHECK: %[[arg3:.*]] = llvm.inttoptr %[[op3]]
// SHUFFLE-CHECK: func.call @xsmm_brgemm_invoke(%[[c1_i64]], %[[dispatch]], %[[arg1]], %[[offset]], %[[arg2]], %[[offset2]], %[[arg3]], %[[offset3]], %[[c32_i64]])

// RUN: mlir-gen --kernel=args --batch=256 --layers=1024,1024 --tiles=32,32,32 | tpp-run --M-tile-shape=2,4 --N-tile-shape=4,8 --loop-shuffle-order=0,3,2,1 -num-outer-parallel=2 -e=entry -entry-point-result=void -print-mlir=mid 2>&1 | FileCheck %s --check-prefix=PARALLEL-CHECK

// PARALLEL-CHECK: func.func @_entry(%[[ARG0:.*]]: memref<8x32x32x32xf32>, %[[ARG1:.*]]: memref<32x32x32x32xf32>, %[[ARG2:.*]]: memref<8x32x32x32xf32>) {
// PARALLEL-CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index
// PARALLEL-CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
// PARALLEL-CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
// PARALLEL-CHECK-DAG: %[[c2:.*]] = arith.constant 2 : index
// PARALLEL-CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
// PARALLEL-CHECK-DAG: %[[c32_i64:.*]] = arith.constant 32 : i64
// PARALLEL-CHECK-DAG: %[[c1_i64:.*]] = arith.constant 1 : i64
// PARALLEL-CHECK-DAG: %[[c1024_i64:.*]] = arith.constant 1024 : i64
// PARALLEL-CHECK-DAG: %[[c0_i64:.*]] = arith.constant 0 : i64
// PARALLEL-CHECK: %[[dispatch:.*]] = call @xsmm_brgemm_dispatch
// PARALLEL-CHECK: %[[expandshape:.*]] = memref.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3], [4]] output_shape [2, 4, 32, 32, 32] : memref<8x32x32x32xf32> into memref<2x4x32x32x32xf32>
// PARALLEL-CHECK: %[[expandshape0:.*]] = memref.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3], [4]] output_shape [4, 8, 32, 32, 32] : memref<32x32x32x32xf32> into memref<4x8x32x32x32xf32>
// PARALLEL-CHECK: %[[expandshape1:.*]] = memref.expand_shape %[[ARG2]] {{\[}}[0, 1], [2, 3], [4], [5]] output_shape [2, 4, 4, 8, 32, 32] : memref<8x32x32x32xf32> into memref<2x4x4x8x32x32xf32>
// PARALLEL-CHECK: scf.parallel (%[[ARG3:.*]], %[[ARG4:.*]]) = (%[[c0]], %[[c0]]) to (%[[c2]], %[[c8]]) step (%[[c1]], %[[c1]]) {
// PARALLEL-CHECK: scf.for %[[ARG5:.*]] = %[[c0]] to %[[c4]] step %[[c1]] {
// PARALLEL-CHECK: %[[subview:.*]] = memref.subview %[[expandshape0]][%[[ARG5]], %[[ARG4]], 0, 0, 0] [1, 1, 32, 32, 32] [1, 1, 1, 1, 1] : memref<4x8x32x32x32xf32> to memref<32x32x32xf32, strided<[1024, 32, 1], offset: ?>>
// PARALLEL-CHECK: scf.for %[[ARG6:.*]] = %[[c0]] to %[[c4]] step %[[c1]] {
// PARALLEL-CHECK: %[[subview2:.*]] = memref.subview %[[expandshape]][%[[ARG3]], %[[ARG6]], 0, 0, 0] [1, 1, 32, 32, 32] [1, 1, 1, 1, 1] : memref<2x4x32x32x32xf32> to memref<32x32x32xf32, strided<[1024, 32, 1], offset: ?>>
// PARALLEL-CHECK: %[[subview3:.*]] = memref.subview %[[expandshape1]][%[[ARG3]], %[[ARG6]], %[[ARG5]], %[[ARG4]], 0, 0] [1, 1, 1, 1, 32, 32] [1, 1, 1, 1, 1, 1] : memref<2x4x4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
// PARALLEL-CHECK: {{.*}}, %[[offset:.*]], %{{.*}}, %{{.*}} = memref.extract_strided_metadata %[[subview2]] : memref<32x32x32xf32, strided<[1024, 32, 1], offset: ?>>
// PARALLEL-CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[subview2]]
// PARALLEL-CHECK: %[[op1:.*]] = arith.index_cast %[[intptr]]
// PARALLEL-CHECK: %[[arg1:.*]] = llvm.inttoptr %[[op1]]
// PARALLEL-CHECK: %{{.*}}, %[[offset2:.*]], %{{.*}}, %{{.*}} = memref.extract_strided_metadata %[[subview]] : memref<32x32x32xf32, strided<[1024, 32, 1], offset: ?>>
// PARALLEL-CHECK: %[[intptr2:.*]] = memref.extract_aligned_pointer_as_index %[[subview]]
// PARALLEL-CHECK: %[[op2:.*]] = arith.index_cast %[[intptr2]]
// PARALLEL-CHECK: %[[arg2:.*]] = llvm.inttoptr %[[op2]]
// PARALLEL-CHECK: %{{.*}}, %[[offset3:.*]], %{{.*}}, %{{.*}} = memref.extract_strided_metadata %[[subview3]] : memref<32x32xf32, strided<[32, 1], offset: ?>>
// PARALLEL-CHECK: %[[intptr3:.*]] = memref.extract_aligned_pointer_as_index %[[subview3]]
// PARALLEL-CHECK: %[[op3:.*]] = arith.index_cast %[[intptr3]]
// PARALLEL-CHECK: %[[arg3:.*]] = llvm.inttoptr %[[op3]]
// PARALLEL-CHECK: func.call @xsmm_brgemm_invoke(%[[c1_i64]], %[[dispatch]], %[[arg1]], %[[offset]], %[[arg2]], %[[offset2]], %[[arg3]], %[[offset3]], %[[c32_i64]])

0 comments on commit 1a90608

Please sign in to comment.