Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tidy the VMVX ukernels matmul interface #10211

Merged
merged 2 commits into from
Aug 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// TODO: move these flags to a header file shared with runtime/.
#define IREE_VMVX_MATMUL_FLAG_ACCUMULATE 1

namespace mlir {
namespace iree_compiler {

Expand Down Expand Up @@ -873,20 +876,6 @@ struct LinalgMatmulConversion
rhs = contract.rhs();
out = op.outputs().front();
}

Value getOneValue(PatternRewriter &rewriter) {
Location loc = op.getLoc();
Type elementType = out.getType().cast<MemRefType>().getElementType();
if (auto floatType = elementType.dyn_cast<FloatType>()) {
return rewriter.create<arith::ConstantOp>(
loc, FloatAttr::get(floatType, 1.0));
} else if (elementType.isa<IntegerType>()) {
return rewriter.create<arith::ConstantIntOp>(loc, 1, elementType);
}

assert(false && "unknown element type");
return nullptr;
}
};

LogicalResult matchAndRewrite(linalg::ContractionOpInterface op,
Expand All @@ -912,8 +901,7 @@ struct LinalgMatmulConversion
}

// Switch on contraction type.
if (info.contract.isRowMajorMatmul() ||
info.contract.isColumnMajorMatmul()) {
if (info.contract.isRowMajorMatmul()) {
if (succeeded(handleConformingMatmul2D(info, rewriter))) {
return success();
}
Expand All @@ -929,27 +917,14 @@ struct LinalgMatmulConversion
auto &lhsDesc = info.lhsAnal.getDesc(rewriter);
auto &rhsDesc = info.rhsAnal.getDesc(rewriter);
auto &outDesc = info.outAnal.getDesc(rewriter);
int flags = IREE_VMVX_MATMUL_FLAG_ACCUMULATE;
// Determine m, n, k based on dims.
int flags = 0;
Value m, n, k;
if (info.contract.isRowMajorMatmul()) {
m = lhsDesc.sizes[0];
k = rhsDesc.sizes[0];
n = rhsDesc.sizes[1];
} else if (info.contract.isColumnMajorMatmul()) {
m = lhsDesc.sizes[0];
k = rhsDesc.sizes[1];
n = rhsDesc.sizes[0];
// TODO: Flag constants somewhere.
flags |= 1;
} else {
if (!info.contract.isRowMajorMatmul()) {
return failure();
}

// Alpha/beta: We always start the lowering with alpha/beta set to 1.
// Simplification patterns within VMVX will simplify this if possible.
Value alpha = info.getOneValue(rewriter);
Value beta = alpha;
Value m = lhsDesc.sizes[0];
Value k = rhsDesc.sizes[0];
Value n = rhsDesc.sizes[1];

auto lhsBuffer = lhsDesc.castToLinear(loc, rewriter);
auto rhsBuffer = rhsDesc.castToLinear(loc, rewriter);
Expand All @@ -965,8 +940,6 @@ struct LinalgMatmulConversion
outBuffer, outDesc.offset, outDesc.strides[0],
// m,n,k
m, n, k,
// alpha, beta
alpha, beta,
// flags
lhsDesc.getElementTypeAttr(), rhsDesc.getElementTypeAttr(),
outDesc.getElementTypeAttr(), rewriter.getI32IntegerAttr(flags));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,14 @@ func.func @fill2d(%arg0 : memref<384x128xf32>, %arg1 : f32) {
}

// CHECK-LABEL: @matmul_row_major
// CHECK-DAG: %[[SCALE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK-DAG: %[[BB0:.*]], %[[OFFSET0:.*]], %[[SIZES0:.*]]:2, %[[STRIDES0:.*]]:2 = vmvx.get_buffer_descriptor %arg0
// CHECK-DAG: %[[BB1:.*]], %[[OFFSET1:.*]], %[[SIZES1:.*]]:2, %[[STRIDES1:.*]]:2 = vmvx.get_buffer_descriptor %arg1
// CHECK-DAG: %[[BB2:.*]], %[[OFFSET2:.*]], %[[SIZES2:.*]]:2, %[[STRIDES2:.*]]:2 = vmvx.get_buffer_descriptor %arg2
// CHECK: vmvx.matmul lhs(%[[BB1]] offset %[[OFFSET1]] row_stride %[[STRIDES1]]#0 : !util.buffer)
// CHECK-SAME: rhs(%[[BB2]] offset %[[OFFSET2]] row_stride %[[STRIDES2]]#0 : !util.buffer)
// CHECK-SAME: out(%[[BB0]] offset %[[OFFSET0]] row_stride %[[STRIDES0]]#0 : !util.buffer)
// CHECK-SAME: mnk(%[[SIZES1]]#0, %[[SIZES2]]#1, %[[SIZES2]]#0) scale(%[[SCALE]] : f32, %[[SCALE]] : f32)
// CHECK-SAME: flags(0)
// CHECK-SAME: mnk(%[[SIZES1]]#0, %[[SIZES2]]#1, %[[SIZES2]]#0)
// CHECK-SAME: flags(1)
func.func @matmul_row_major(%arg0 : memref<64x64xf32>, %arg1 : memref<64x384xf32>, %arg2 : memref<384x64xf32>) {
linalg.matmul
ins(%arg1, %arg2 : memref<64x384xf32>, memref<384x64xf32>)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,18 @@ func.func @matmul_f32f32f32(
// OUT
%arg6 : !util.buffer, %arg7 : index, %arg8 : index,
// SIZE
%arg9 : index, %arg10 : index, %arg11 : index,
// SCALE
%arg12 : f32, %arg13 : f32) {
%arg9 : index, %arg10 : index, %arg11 : index) {

// CHECK-DAG: %[[ZERO:.*]] = vm.const.i32.zero
// CHECK-DAG: %[[FLAGS:.*]] = vm.const.i32 1
// CHECK: vm.call @vmvx.matmul.f32f32f32(
// CHECK-SAME: %arg0, %arg1, %arg2,
// CHECK-SAME: %arg3, %arg4, %arg5,
// CHECK-SAME: %arg6, %arg7, %arg8,
// CHECK-SAME: %arg9, %arg10, %arg11, %arg12, %arg13, %[[ZERO]]) : (!vm.buffer, i64, i64, !vm.buffer, i64, i64, !vm.buffer, i64, i64, i64, i64, i64, f32, f32, i32) -> ()
// CHECK-SAME: %arg9, %arg10, %arg11, %[[FLAGS]]) : (!vm.buffer, i64, i64, !vm.buffer, i64, i64, !vm.buffer, i64, i64, i64, i64, i64, i32) -> ()
vmvx.matmul lhs(%arg0 offset %arg1 row_stride %arg2 : !util.buffer)
rhs(%arg3 offset %arg4 row_stride %arg5 : !util.buffer)
out(%arg6 offset %arg7 row_stride %arg8 : !util.buffer)
mnk(%arg9, %arg10, %arg11)
scale(%arg12 : f32, %arg13 : f32)
flags(0) : (f32, f32, f32)
flags(1) : (f32, f32, f32)
func.return
}
5 changes: 0 additions & 5 deletions compiler/src/iree/compiler/Dialect/VMVX/IR/VMVXOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,6 @@ def VMVX_MatmulOp : VMVX_Op<"matmul"> {
VMVX_Index:$n,
VMVX_Index:$k,

// Scale factors.
VMVX_ElementType:$alpha,
VMVX_ElementType:$beta,

// Type and flag attributes.
VMVX_ElementTypeAttr:$lhs_type,
VMVX_ElementTypeAttr:$rhs_type,
Expand All @@ -214,7 +210,6 @@ def VMVX_MatmulOp : VMVX_Op<"matmul"> {
`rhs` `` `(` $rhs_buffer `offset` $rhs_offset `row_stride` $rhs_row_stride `:` type($rhs_buffer)`)`
`out` `` `(` $out_buffer `offset` $out_offset `row_stride` $out_row_stride `:` type($out_buffer) `)`
`mnk` `` `(` $m `,` $n `,` $k `)`
`scale` `` `(` $alpha `:` type($alpha) `,` $beta `:` type($beta) `)`
`flags` `` `(` $flags `)`
`:` `(` $lhs_type `,` $rhs_type `,` $out_type `)`
attr-dict
Expand Down
2 changes: 0 additions & 2 deletions compiler/src/iree/compiler/Dialect/VMVX/vmvx.imports.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -426,8 +426,6 @@ vm.import @matmul.f32f32f32(
%m : i64,
%n : i64,
%k : i64,
%alpha : f32,
%beta : f32,
%flags : i32
)

Expand Down
6 changes: 3 additions & 3 deletions runtime/src/iree/modules/vmvx/exports.inl
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ EXPORT_FN("exp.2d.f32", iree_ukernel_x32u_expf_2d, ukernel_x32u_2d, rIIIrIIIII,
EXPORT_FN("fill.2d.x32", iree_vmvx_fill2d_x32, fill2d_x32, irIIII, v)
EXPORT_FN("floor.2d.f32", iree_ukernel_x32u_floorf_2d, ukernel_x32u_2d, rIIIrIIIII, v)
EXPORT_FN("log.2d.f32", iree_ukernel_x32u_logf_2d, ukernel_x32u_2d, rIIIrIIIII, v)
EXPORT_FN("matmul.f32f32f32", iree_vmvx_matmul_f32f32f32, matmul_f32, rIIrIIrIIIIIffi, v)
EXPORT_FN("matmul.f32f32f32", iree_vmvx_matmul_f32f32f32, matmul_f32, rIIrIIrIIIIIi, v)
// NOTE: must still be in alphabetical order with all other exports.
#if defined(IREE_HAVE_MMT4D_BUILTINS)
#if 0 // TODO: implement mmt4d ukernel
EXPORT_FN("mmt4d.f32f32f32", iree_vmvx_mmt4d_f32f32f32, mmt4d_f32, rIIrIIrIIIIIffi, v)
#endif // IREE_HAVE_MMT4D_BUILTINS
#endif
EXPORT_FN("mul.2d.f32", iree_ukernel_x32b_mulf_2d, ukernel_x32b_2d, rIIIrIIIrIIIII, v)
EXPORT_FN("mul.2d.i32", iree_ukernel_x32b_muli_2d, ukernel_x32b_2d, rIIIrIIIrIIIII, v)
EXPORT_FN("neg.2d.f32", iree_ukernel_x32u_negf_2d, ukernel_x32u_2d, rIIIrIIIII, v)
Expand Down
49 changes: 22 additions & 27 deletions runtime/src/iree/modules/vmvx/module.c
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,17 @@
// Include the ukernel support library so that we can use its implementations
// as fixed-function components of the runtime.
#include "iree/builtins/ukernel/elementwise.h"
#include "iree/builtins/ukernel/mmt4d.h"

// Temporary switch between a blas-style matmul kernel and an mmt4d style.
// We omit the latter for the time being since it is experimental and keeps
// the binary size down.
#define IREE_VMVX_USE_BLAS_MATMUL 1
#if 0 // TODO: implement mmt4d ukernel
#include "iree/builtins/ukernel/mmt4d.h"
#endif

#define IREE_VMVX_MODULE_VERSION_0_0 0x00000000u
#define IREE_VMVX_MODULE_VERSION_LATEST IREE_VMVX_MODULE_VERSION_0_0

// TODO: move these flags to a header file shared with compiler/.
#define IREE_VMVX_MATMUL_FLAG_ACCUMULATE 1

//===----------------------------------------------------------------------===//
// Module type definitions
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -490,8 +491,6 @@ IREE_VMVX_ABI_EXPORT(iree_vmvx_fill2d_x32, fill2d_x32, v) {
// Exported matmul function definitions
//===----------------------------------------------------------------------===//

#if IREE_VMVX_USE_BLAS_MATMUL

IREE_VMVX_ABI_FIXED_STRUCT(matmul_f32, rIIrIIrIIIIIffi, {
iree_vm_ref_t lhs_ref;
int64_t lhs_offset;
Expand All @@ -505,8 +504,6 @@ IREE_VMVX_ABI_FIXED_STRUCT(matmul_f32, rIIrIIrIIIIIffi, {
int64_t m;
int64_t n;
int64_t k;
float alpha;
float beta;
int32_t flags;
});
IREE_VMVX_ABI_DEFINE_SHIM(matmul_f32, v);
Expand Down Expand Up @@ -538,36 +535,34 @@ IREE_VMVX_ABI_EXPORT(iree_vmvx_matmul_f32f32f32, matmul_f32, v) {
iree_host_size_t N = (iree_host_size_t)args->n;
iree_host_size_t K = (iree_host_size_t)args->k;

// TODO: define flags more robustly.
if (args->flags == 0) {
// Row major.
for (iree_host_size_t i = 0; i < M; ++i) {
// TODO: define flags more robustly
unsigned accumulate_flag = args->flags & IREE_VMVX_MATMUL_FLAG_ACCUMULATE;
unsigned unhandled_flags = args->flags ^ accumulate_flag;
if (unhandled_flags) {
IREE_TRACE_ZONE_END(z0);
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"unsupported matmul flags: 0x%x", unhandled_flags);
}
for (iree_host_size_t i = 0; i < M; ++i) {
for (iree_host_size_t j = 0; j < N; ++j) {
float* out_ptr = out + i * out_stride0 + j;
float acc = accumulate_flag ? *out_ptr : 0.f;
for (iree_host_size_t k = 0; k < K; ++k) {
float apart = args->alpha * lhs[i * lhs_stride0 + k];
for (iree_host_size_t j = 0; j < N; ++j) {
out[i * out_stride0 + j] +=
args->beta * apart * rhs[k * rhs_stride0 + j];
}
acc += lhs[i * lhs_stride0 + k] * rhs[k * rhs_stride0 + j];
}
*out_ptr = acc;
}
} else {
IREE_TRACE_ZONE_END(z0);
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"unsupported matmul flags: %x",
(unsigned)args->flags);
}

IREE_TRACE_ZONE_END(z0);
return iree_ok_status();
}

#endif // IREE_VMVX_USE_BLAS_MATMUL

//===----------------------------------------------------------------------===//
// MMT4D
//===----------------------------------------------------------------------===//

#if !IREE_VMVX_USE_BLAS_MATMUL
#if 0 // TODO: implement mmt4d ukernel

// NOTE: for demo purposes this reuses the matmul signature.
IREE_VMVX_ABI_FIXED_STRUCT(mmt4d_f32, rIIrIIrIIIIIffi, {
Expand Down Expand Up @@ -624,7 +619,7 @@ IREE_VMVX_ABI_EXPORT(iree_vmvx_mmt4d_f32f32f32, mmt4d_f32, v) {
"unsupported mmt4d parameters");
}

#endif // !IREE_VMVX_USE_BLAS_MATMUL
#endif // TODO: implement mmt4d ukernel

//===----------------------------------------------------------------------===//
// VM module interface implementation
Expand Down