diff --git a/compiler/src/iree/compiler/Codegen/VMVX/LowerLinalgMicrokernels.cpp b/compiler/src/iree/compiler/Codegen/VMVX/LowerLinalgMicrokernels.cpp index fda34b369338..52118af6d266 100644 --- a/compiler/src/iree/compiler/Codegen/VMVX/LowerLinalgMicrokernels.cpp +++ b/compiler/src/iree/compiler/Codegen/VMVX/LowerLinalgMicrokernels.cpp @@ -21,6 +21,9 @@ #include "mlir/Pass/PassRegistry.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +// TODO: move these flags to a header file shared with runtime/. +#define IREE_VMVX_MATMUL_FLAG_ACCUMULATE 1 + namespace mlir { namespace iree_compiler { @@ -873,20 +876,6 @@ struct LinalgMatmulConversion rhs = contract.rhs(); out = op.outputs().front(); } - - Value getOneValue(PatternRewriter &rewriter) { - Location loc = op.getLoc(); - Type elementType = out.getType().cast().getElementType(); - if (auto floatType = elementType.dyn_cast()) { - return rewriter.create( - loc, FloatAttr::get(floatType, 1.0)); - } else if (elementType.isa()) { - return rewriter.create(loc, 1, elementType); - } - - assert(false && "unknown element type"); - return nullptr; - } }; LogicalResult matchAndRewrite(linalg::ContractionOpInterface op, @@ -912,8 +901,7 @@ struct LinalgMatmulConversion } // Switch on contraction type. - if (info.contract.isRowMajorMatmul() || - info.contract.isColumnMajorMatmul()) { + if (info.contract.isRowMajorMatmul()) { if (succeeded(handleConformingMatmul2D(info, rewriter))) { return success(); } @@ -929,27 +917,14 @@ struct LinalgMatmulConversion auto &lhsDesc = info.lhsAnal.getDesc(rewriter); auto &rhsDesc = info.rhsAnal.getDesc(rewriter); auto &outDesc = info.outAnal.getDesc(rewriter); + int flags = IREE_VMVX_MATMUL_FLAG_ACCUMULATE; // Determine m, n, k based on dims. - int flags = 0; - Value m, n, k; - if (info.contract.isRowMajorMatmul()) { - m = lhsDesc.sizes[0]; - k = rhsDesc.sizes[0]; - n = rhsDesc.sizes[1]; - } else if (info.contract.isColumnMajorMatmul()) { - m = lhsDesc.sizes[0]; - k = rhsDesc.sizes[1]; - n = rhsDesc.sizes[0]; - // TODO: Flag constants somewhere. - flags |= 1; - } else { + if (!info.contract.isRowMajorMatmul()) { return failure(); } - - // Alpha/beta: We always start the lowering with alpha/beta set to 1. - // Simplification patterns within VMVX will simplify this if possible. - Value alpha = info.getOneValue(rewriter); - Value beta = alpha; + Value m = lhsDesc.sizes[0]; + Value k = rhsDesc.sizes[0]; + Value n = rhsDesc.sizes[1]; auto lhsBuffer = lhsDesc.castToLinear(loc, rewriter); auto rhsBuffer = rhsDesc.castToLinear(loc, rewriter); @@ -965,8 +940,6 @@ struct LinalgMatmulConversion outBuffer, outDesc.offset, outDesc.strides[0], // m,n,k m, n, k, - // alpha, beta - alpha, beta, // flags lhsDesc.getElementTypeAttr(), rhsDesc.getElementTypeAttr(), outDesc.getElementTypeAttr(), rewriter.getI32IntegerAttr(flags)); diff --git a/compiler/src/iree/compiler/Codegen/VMVX/test/lower_linalg_microkernels.mlir b/compiler/src/iree/compiler/Codegen/VMVX/test/lower_linalg_microkernels.mlir index cf687ff2434f..a706e06ad735 100644 --- a/compiler/src/iree/compiler/Codegen/VMVX/test/lower_linalg_microkernels.mlir +++ b/compiler/src/iree/compiler/Codegen/VMVX/test/lower_linalg_microkernels.mlir @@ -54,15 +54,14 @@ func.func @fill2d(%arg0 : memref<384x128xf32>, %arg1 : f32) { } // CHECK-LABEL: @matmul_row_major -// CHECK-DAG: %[[SCALE:.*]] = arith.constant 1.000000e+00 : f32 // CHECK-DAG: %[[BB0:.*]], %[[OFFSET0:.*]], %[[SIZES0:.*]]:2, %[[STRIDES0:.*]]:2 = vmvx.get_buffer_descriptor %arg0 // CHECK-DAG: %[[BB1:.*]], %[[OFFSET1:.*]], %[[SIZES1:.*]]:2, %[[STRIDES1:.*]]:2 = vmvx.get_buffer_descriptor %arg1 // CHECK-DAG: %[[BB2:.*]], %[[OFFSET2:.*]], %[[SIZES2:.*]]:2, %[[STRIDES2:.*]]:2 = vmvx.get_buffer_descriptor %arg2 // CHECK: vmvx.matmul lhs(%[[BB1]] offset %[[OFFSET1]] row_stride %[[STRIDES1]]#0 : !util.buffer) // CHECK-SAME: rhs(%[[BB2]] offset %[[OFFSET2]] row_stride %[[STRIDES2]]#0 : !util.buffer) // CHECK-SAME: out(%[[BB0]] offset %[[OFFSET0]] row_stride %[[STRIDES0]]#0 : !util.buffer) -// CHECK-SAME: mnk(%[[SIZES1]]#0, %[[SIZES2]]#1, %[[SIZES2]]#0) scale(%[[SCALE]] : f32, %[[SCALE]] : f32) -// CHECK-SAME: flags(0) +// CHECK-SAME: mnk(%[[SIZES1]]#0, %[[SIZES2]]#1, %[[SIZES2]]#0) +// CHECK-SAME: flags(1) func.func @matmul_row_major(%arg0 : memref<64x64xf32>, %arg1 : memref<64x384xf32>, %arg2 : memref<384x64xf32>) { linalg.matmul ins(%arg1, %arg2 : memref<64x384xf32>, memref<384x64xf32>) diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Conversion/VMVXToVM/test/matmul.mlir b/compiler/src/iree/compiler/Dialect/VMVX/Conversion/VMVXToVM/test/matmul.mlir index a8734c1df7a4..d1c0a4753153 100644 --- a/compiler/src/iree/compiler/Dialect/VMVX/Conversion/VMVXToVM/test/matmul.mlir +++ b/compiler/src/iree/compiler/Dialect/VMVX/Conversion/VMVXToVM/test/matmul.mlir @@ -10,21 +10,18 @@ func.func @matmul_f32f32f32( // OUT %arg6 : !util.buffer, %arg7 : index, %arg8 : index, // SIZE - %arg9 : index, %arg10 : index, %arg11 : index, - // SCALE - %arg12 : f32, %arg13 : f32) { + %arg9 : index, %arg10 : index, %arg11 : index) { - // CHECK-DAG: %[[ZERO:.*]] = vm.const.i32.zero + // CHECK-DAG: %[[FLAGS:.*]] = vm.const.i32 1 // CHECK: vm.call @vmvx.matmul.f32f32f32( // CHECK-SAME: %arg0, %arg1, %arg2, // CHECK-SAME: %arg3, %arg4, %arg5, // CHECK-SAME: %arg6, %arg7, %arg8, - // CHECK-SAME: %arg9, %arg10, %arg11, %arg12, %arg13, %[[ZERO]]) : (!vm.buffer, i64, i64, !vm.buffer, i64, i64, !vm.buffer, i64, i64, i64, i64, i64, f32, f32, i32) -> () + // CHECK-SAME: %arg9, %arg10, %arg11, %[[FLAGS]]) : (!vm.buffer, i64, i64, !vm.buffer, i64, i64, !vm.buffer, i64, i64, i64, i64, i64, i32) -> () vmvx.matmul lhs(%arg0 offset %arg1 row_stride %arg2 : !util.buffer) rhs(%arg3 offset %arg4 row_stride %arg5 : !util.buffer) out(%arg6 offset %arg7 row_stride %arg8 : !util.buffer) mnk(%arg9, %arg10, %arg11) - scale(%arg12 : f32, %arg13 : f32) - flags(0) : (f32, f32, f32) + flags(1) : (f32, f32, f32) func.return } diff --git a/compiler/src/iree/compiler/Dialect/VMVX/IR/VMVXOps.td b/compiler/src/iree/compiler/Dialect/VMVX/IR/VMVXOps.td index 438612606586..4868d80cf242 100644 --- a/compiler/src/iree/compiler/Dialect/VMVX/IR/VMVXOps.td +++ b/compiler/src/iree/compiler/Dialect/VMVX/IR/VMVXOps.td @@ -198,10 +198,6 @@ def VMVX_MatmulOp : VMVX_Op<"matmul"> { VMVX_Index:$n, VMVX_Index:$k, - // Scale factors. - VMVX_ElementType:$alpha, - VMVX_ElementType:$beta, - // Type and flag attributes. VMVX_ElementTypeAttr:$lhs_type, VMVX_ElementTypeAttr:$rhs_type, @@ -214,7 +210,6 @@ def VMVX_MatmulOp : VMVX_Op<"matmul"> { `rhs` `` `(` $rhs_buffer `offset` $rhs_offset `row_stride` $rhs_row_stride `:` type($rhs_buffer)`)` `out` `` `(` $out_buffer `offset` $out_offset `row_stride` $out_row_stride `:` type($out_buffer) `)` `mnk` `` `(` $m `,` $n `,` $k `)` - `scale` `` `(` $alpha `:` type($alpha) `,` $beta `:` type($beta) `)` `flags` `` `(` $flags `)` `:` `(` $lhs_type `,` $rhs_type `,` $out_type `)` attr-dict diff --git a/compiler/src/iree/compiler/Dialect/VMVX/vmvx.imports.mlir b/compiler/src/iree/compiler/Dialect/VMVX/vmvx.imports.mlir index 92d226c44252..ca388f6dcfd7 100644 --- a/compiler/src/iree/compiler/Dialect/VMVX/vmvx.imports.mlir +++ b/compiler/src/iree/compiler/Dialect/VMVX/vmvx.imports.mlir @@ -426,8 +426,6 @@ vm.import @matmul.f32f32f32( %m : i64, %n : i64, %k : i64, - %alpha : f32, - %beta : f32, %flags : i32 ) diff --git a/runtime/src/iree/modules/vmvx/exports.inl b/runtime/src/iree/modules/vmvx/exports.inl index a3963c9ced4e..270afd0dd239 100644 --- a/runtime/src/iree/modules/vmvx/exports.inl +++ b/runtime/src/iree/modules/vmvx/exports.inl @@ -40,11 +40,11 @@ EXPORT_FN("exp.2d.f32", iree_ukernel_x32u_expf_2d, ukernel_x32u_2d, rIIIrIIIII, EXPORT_FN("fill.2d.x32", iree_vmvx_fill2d_x32, fill2d_x32, irIIII, v) EXPORT_FN("floor.2d.f32", iree_ukernel_x32u_floorf_2d, ukernel_x32u_2d, rIIIrIIIII, v) EXPORT_FN("log.2d.f32", iree_ukernel_x32u_logf_2d, ukernel_x32u_2d, rIIIrIIIII, v) -EXPORT_FN("matmul.f32f32f32", iree_vmvx_matmul_f32f32f32, matmul_f32, rIIrIIrIIIIIffi, v) +EXPORT_FN("matmul.f32f32f32", iree_vmvx_matmul_f32f32f32, matmul_f32, rIIrIIrIIIIIi, v) // NOTE: must still be in alphabetical order with all other exports. -#if defined(IREE_HAVE_MMT4D_BUILTINS) +#if 0 // TODO: implement mmt4d ukernel EXPORT_FN("mmt4d.f32f32f32", iree_vmvx_mmt4d_f32f32f32, mmt4d_f32, rIIrIIrIIIIIffi, v) -#endif // IREE_HAVE_MMT4D_BUILTINS +#endif EXPORT_FN("mul.2d.f32", iree_ukernel_x32b_mulf_2d, ukernel_x32b_2d, rIIIrIIIrIIIII, v) EXPORT_FN("mul.2d.i32", iree_ukernel_x32b_muli_2d, ukernel_x32b_2d, rIIIrIIIrIIIII, v) EXPORT_FN("neg.2d.f32", iree_ukernel_x32u_negf_2d, ukernel_x32u_2d, rIIIrIIIII, v) diff --git a/runtime/src/iree/modules/vmvx/module.c b/runtime/src/iree/modules/vmvx/module.c index 6eb7c05be460..0d0df0a3637a 100644 --- a/runtime/src/iree/modules/vmvx/module.c +++ b/runtime/src/iree/modules/vmvx/module.c @@ -18,16 +18,17 @@ // Include the ukernel support library so that we can use its implementations // as fixed-function components of the runtime. #include "iree/builtins/ukernel/elementwise.h" -#include "iree/builtins/ukernel/mmt4d.h" -// Temporary switch between a blas-style matmul kernel and an mmt4d style. -// We omit the latter for the time being since it is experimental and keeps -// the binary size down. -#define IREE_VMVX_USE_BLAS_MATMUL 1 +#if 0 // TODO: implement mmt4d ukernel +#include "iree/builtins/ukernel/mmt4d.h" +#endif #define IREE_VMVX_MODULE_VERSION_0_0 0x00000000u #define IREE_VMVX_MODULE_VERSION_LATEST IREE_VMVX_MODULE_VERSION_0_0 +// TODO: move these flags to a header file shared with compiler/. +#define IREE_VMVX_MATMUL_FLAG_ACCUMULATE 1 + //===----------------------------------------------------------------------===// // Module type definitions //===----------------------------------------------------------------------===// @@ -490,8 +491,6 @@ IREE_VMVX_ABI_EXPORT(iree_vmvx_fill2d_x32, fill2d_x32, v) { // Exported matmul function definitions //===----------------------------------------------------------------------===// -#if IREE_VMVX_USE_BLAS_MATMUL - IREE_VMVX_ABI_FIXED_STRUCT(matmul_f32, rIIrIIrIIIIIffi, { iree_vm_ref_t lhs_ref; int64_t lhs_offset; @@ -505,8 +504,6 @@ IREE_VMVX_ABI_FIXED_STRUCT(matmul_f32, rIIrIIrIIIIIffi, { int64_t m; int64_t n; int64_t k; - float alpha; - float beta; int32_t flags; }); IREE_VMVX_ABI_DEFINE_SHIM(matmul_f32, v); @@ -538,36 +535,34 @@ IREE_VMVX_ABI_EXPORT(iree_vmvx_matmul_f32f32f32, matmul_f32, v) { iree_host_size_t N = (iree_host_size_t)args->n; iree_host_size_t K = (iree_host_size_t)args->k; - // TODO: define flags more robustly. - if (args->flags == 0) { - // Row major. - for (iree_host_size_t i = 0; i < M; ++i) { + // TODO: define flags more robustly + unsigned accumulate_flag = args->flags & IREE_VMVX_MATMUL_FLAG_ACCUMULATE; + unsigned unhandled_flags = args->flags ^ accumulate_flag; + if (unhandled_flags) { + IREE_TRACE_ZONE_END(z0); + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "unsupported matmul flags: 0x%x", unhandled_flags); + } + for (iree_host_size_t i = 0; i < M; ++i) { + for (iree_host_size_t j = 0; j < N; ++j) { + float* out_ptr = out + i * out_stride0 + j; + float acc = accumulate_flag ? *out_ptr : 0.f; for (iree_host_size_t k = 0; k < K; ++k) { - float apart = args->alpha * lhs[i * lhs_stride0 + k]; - for (iree_host_size_t j = 0; j < N; ++j) { - out[i * out_stride0 + j] += - args->beta * apart * rhs[k * rhs_stride0 + j]; - } + acc += lhs[i * lhs_stride0 + k] * rhs[k * rhs_stride0 + j]; } + *out_ptr = acc; } - } else { - IREE_TRACE_ZONE_END(z0); - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "unsupported matmul flags: %x", - (unsigned)args->flags); } IREE_TRACE_ZONE_END(z0); return iree_ok_status(); } -#endif // IREE_VMVX_USE_BLAS_MATMUL - //===----------------------------------------------------------------------===// // MMT4D //===----------------------------------------------------------------------===// -#if !IREE_VMVX_USE_BLAS_MATMUL +#if 0 // TODO: implement mmt4d ukernel // NOTE: for demo purposes this reuses the matmul signature. IREE_VMVX_ABI_FIXED_STRUCT(mmt4d_f32, rIIrIIrIIIIIffi, { @@ -624,7 +619,7 @@ IREE_VMVX_ABI_EXPORT(iree_vmvx_mmt4d_f32f32f32, mmt4d_f32, v) { "unsupported mmt4d parameters"); } -#endif // !IREE_VMVX_USE_BLAS_MATMUL +#endif // TODO: implement mmt4d ukernel //===----------------------------------------------------------------------===// // VM module interface implementation