Einsum decomposition #1537

Merged · 41 commits · Jul 15, 2022

Commits
a737e82
add shape inference for Einsum
sorenlassen May 25, 2022
3d6a515
Added skeleton verifier for Einsum
sorenlassen May 26, 2022
d03365c
Implemented ONNXEinsumOp::verify()
sorenlassen Jul 10, 2022
e2110c4
Replaced ONNXEinsumOpShapeHelper with einsum::inferOutputShape()
sorenlassen Jul 10, 2022
6a63135
Added Einsum decomposition to DecomposeONNXToONNXPass
sorenlassen May 31, 2022
91a4730
ONNXEinsumOp verify(), inferShapes() unit test
sorenlassen Jul 10, 2022
c730c33
Test Einsum shape inference with dynamic dimensions
sorenlassen Jul 10, 2022
1f0b830
Got rid of templates in TestONNXEinsumOp
sorenlassen Jul 11, 2022
91856c6
Beginnings of Decomposer
sorenlassen Jul 11, 2022
15a7c37
Implemented Decomposer
sorenlassen Jul 11, 2022
fe2e8bd
using einsum::Shape/Subscripts to become more concise
sorenlassen Jul 11, 2022
9db4b27
Fixed matmul/mul/reduce bug
sorenlassen Jul 11, 2022
07d6ece
Cleaned up Decomposer::diagonal() implementation
sorenlassen Jul 11, 2022
dda5443
Lifted helper functions out of Decomposer class
sorenlassen Jul 11, 2022
185a364
Simplified set insertion
sorenlassen Jul 11, 2022
8df314e
Added first Einsum decomposition lit test
sorenlassen Jul 11, 2022
a251d95
Separated out Einsum shape inference lit tests
sorenlassen Jul 11, 2022
fc23bc1
Einsum diagonal decomposition lit test
sorenlassen Jul 11, 2022
f883aa7
Implemented matmul() properly
sorenlassen Jul 12, 2022
9870a24
clang formatted everything
sorenlassen Jul 12, 2022
2e49575
Lit test Einsum shape inference errors
sorenlassen Jul 12, 2022
29eb0fd
Lit test the diagonal mask constant
sorenlassen Jul 12, 2022
1b40653
Lit test Einsum matmul decomposition
sorenlassen Jul 12, 2022
a476d5e
Lit test Einsum decomposition failures
sorenlassen Jul 12, 2022
56331df
Lit test Einsum decomposition some more
sorenlassen Jul 12, 2022
c04a7af
Remove noise accidentally added to onnx_shape_inference.mlir
sorenlassen Jul 13, 2022
09d0824
Updated Einsum support in SupportedONNXOps-cpu.md
sorenlassen Jul 13, 2022
ce678f7
Added error messages to assertions
sorenlassen Jul 13, 2022
015ee31
assert messages, concrete types instead of auto
sorenlassen Jul 13, 2022
1907e17
Use MultiDialectBuilder<OnnxBuilder> in DecomposeEinsum
sorenlassen Jul 13, 2022
ba9f5d2
Make better use of getZeroAttr()
sorenlassen Jul 14, 2022
f5d696c
Made Decomposer::remove() private
sorenlassen Jul 14, 2022
29733fa
Simplified DecomposeEinsum helper methods
sorenlassen Jul 14, 2022
fb07d0d
Added OnnxBuilder::squeeze
sorenlassen Jul 14, 2022
e4ede01
Added OnnxBuilder::unsqueeze
sorenlassen Jul 14, 2022
e09268f
Added OnnxBuilder::where
sorenlassen Jul 14, 2022
9fd8089
Removed toMemref() from OnnxBuilder::matmul()
sorenlassen Jul 14, 2022
3d519c0
Added OnnxBuilder::reduceSum
sorenlassen Jul 14, 2022
b2797dc
explicitly calling toMemref when needed, and removing matmulToMemref
sorenlassen Jul 14, 2022
360ac52
Removed lit test for Einsum shape inference errors
sorenlassen Jul 14, 2022
f35c2f2
Merge branch 'main' into einsum-decomp
tungld Jul 14, 2022
4 changes: 2 additions & 2 deletions docs/SupportedONNXOps-cpu.md
@@ -1,5 +1,5 @@
<!--- Automatically generated, do not edit. -->
<!--- python documentOps.py --arch cpu --input /workdir/onnx-mlir/test/backend/inference_backend.py --path /workdir/onnx-mlir/utils --notes --unsupported -->
<!--- python documentOps.py --arch cpu --input test/backend/inference_backend.py --path utils --notes --unsupported -->

# Supported ONNX Operation for Target *cpu*.

@@ -49,7 +49,7 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 16. Limitatio
| **Div** |14 |No support for short integers. | |
| **Dropout** |13 |Does not support masked and training. | |
| **DynamicQuantizeLinear** |11 | | |
| **EinSum** |unsupported | | |
| **Einsum** |12 |Limited to the types supported by ReduceSum and MatMul (which we decompose to in most cases) which exclude integers with width < 32. | |
| **Elu** |6 | | |
| **Equal** |13 | | |
| **Erf** |13 | | |
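For context on the new table note: Einsum multiplies its broadcast inputs elementwise and then sums over every subscript missing from the output, which is why most equations reduce to MatMul plus ReduceSum. A minimal standalone C++ sketch of the semantics for the matrix-product equation `ij,jk->ik` (illustrative only, not code from the pass):

```cpp
#include <cstdio>
#include <vector>

// Reference semantics of Einsum("ij,jk->ik"): for each output (i,k), sum
// products over the contracted subscript j -- exactly what MatMul computes,
// which is why this equation can decompose into a single ONNXMatMulOp.
std::vector<float> einsum_ij_jk_ik(const std::vector<float> &A,
    const std::vector<float> &B, int I, int J, int K) {
  std::vector<float> Y(I * K, 0.0f);
  for (int i = 0; i < I; ++i)
    for (int k = 0; k < K; ++k)
      for (int j = 0; j < J; ++j) // j is absent from "ik": summed away
        Y[i * K + k] += A[i * J + j] * B[j * K + k];
  return Y;
}

int main() {
  std::vector<float> A = {1, 2, 3, 4}; // 2x2, row-major
  std::vector<float> B = {5, 6, 7, 8}; // 2x2, row-major
  std::vector<float> Y = einsum_ij_jk_ik(A, B, 2, 2, 2);
  std::printf("%g %g %g %g\n", Y[0], Y[1], Y[2], Y[3]); // 19 22 43 50
}
```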
15 changes: 10 additions & 5 deletions src/Conversion/ONNXToKrnl/RNN/GRU.cpp
@@ -419,7 +419,8 @@ void calculateState<GruState, GruActivationPack, GruWeightPack, GruBiasPack>(
MemRefType::get({batchSize, 3 * hiddenSize}, elementType);

// Common matrix multiplications.
Value XtWT = create.onnx.matmul(matrixAllGatesType, Xt, weightPack.WT);
Value XtWT = create.onnx.toMemref(
create.onnx.matmul(matrixAllGatesType, Xt, weightPack.WT));
Value one = create.math.constant(elementType, 1);

// Lower and upper bounds derived from Ht tensor.
@@ -437,7 +438,8 @@ void calculateState<GruState, GruActivationPack, GruWeightPack, GruBiasPack>(
// Ht = (1 - zt) (.) ht + zt (.) Ht-1"
// In this case, we can do all matrix multiplications first, then fuse all
// element-wise computations into a single nested loop.
Value HtRT = create.onnx.matmul(matrixAllGatesType, Ht, weightPack.RT);
Value HtRT = create.onnx.toMemref(
create.onnx.matmul(matrixAllGatesType, Ht, weightPack.RT));

// Do element-wise computations. Fuse them into a single nested loop.
ValueRange loops = create.krnl.defineLoops(htRank);
@@ -509,8 +511,10 @@ void calculateState<GruState, GruActivationPack, GruWeightPack, GruBiasPack>(
// In this case, besides computing matrix multiplications, we need to
// compute rt and (rt (.) Ht-1) first, then fuse the remaining element-wise
// computations into a single nested loop.
Value HtRz = create.onnx.matmul(matrixType, Ht, weightPack.Rz);
Value HtRr = create.onnx.matmul(matrixType, Ht, weightPack.Rr);
Value HtRz =
create.onnx.toMemref(create.onnx.matmul(matrixType, Ht, weightPack.Rz));
Value HtRr =
create.onnx.toMemref(create.onnx.matmul(matrixType, Ht, weightPack.Rr));
Value rt, rtHt;
if (hasAllConstantDimensions(matrixType)) {
rt = insertAllocAndDealloc(matrixType, loc, rewriter, false);
@@ -553,7 +557,8 @@ void calculateState<GruState, GruActivationPack, GruWeightPack, GruBiasPack>(
});

// Emit (rt (.) Ht-1)*(Rh^T)
Value rtHtRh = create.onnx.matmul(matrixType, rtHt, weightPack.Rh);
Value rtHtRh = create.onnx.toMemref(
create.onnx.matmul(matrixType, rtHt, weightPack.Rh));

// Do element-wise computations. Fuse them into a single nested loop.
ValueRange loops2 = create.krnl.defineLoops(htRank);
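The hunk above quotes the GRU update Ht = (1 - zt) (.) ht + zt (.) Ht-1, where (.) is elementwise multiplication; each output position depends only on the same position of its inputs, so everything after the matrix multiplications fuses into one loop nest. A standalone scalar sketch of that fused step (illustrative C++, not the krnl lowering):

```cpp
#include <cstdio>

// One GRU elementwise update, per the comment in GRU.cpp:
// Ht = (1 - zt) (.) ht + zt (.) Ht-1, applied independently at each
// (batch, hidden) position -- hence a single fused loop nest suffices.
float gruUpdate(float zt, float ht, float htPrev) {
  return (1.0f - zt) * ht + zt * htPrev;
}

int main() {
  // zt near 1 keeps the previous state; zt near 0 adopts the candidate ht.
  std::printf("%g\n", gruUpdate(/*zt=*/0.9f, /*ht=*/2.0f, /*htPrev=*/1.0f)); // 1.1
}
```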
6 changes: 4 additions & 2 deletions src/Conversion/ONNXToKrnl/RNN/LSTM.cpp
@@ -476,8 +476,10 @@ void calculateState<LstmState, LstmActivationPack, LstmWeightPack,
// Xt * (Wi^T ++ Wo^T ++ Wf^T ++ Wc^T)
// Ht * (Ri^T ++ Ro^T ++ Rf^T ++ Rc^T)
// where '++' is matrix concatenation.
Value XtWT = create.onnx.matmul(matrixAllGatesType, Xt, weightPack.WT);
Value HtRT = create.onnx.matmul(matrixAllGatesType, Ht, weightPack.RT);
Value XtWT = create.onnx.toMemref(
create.onnx.matmul(matrixAllGatesType, Xt, weightPack.WT));
Value HtRT = create.onnx.toMemref(
create.onnx.matmul(matrixAllGatesType, Ht, weightPack.RT));

// Do element-wise computations. Fuse them into a single nested loop.
// Lower and upper bounds derived from Ht tensor.
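The LSTM hunk relies on concatenating the gate weight matrices so that one MatMul produces all gate pre-activations: Xt * (Wi^T ++ Wo^T ++ Wf^T ++ Wc^T). A standalone sketch of why the concatenated product splits cleanly per gate (two gates and tiny made-up dimensions for brevity):

```cpp
#include <cstdio>

// Multiplying Xt (1 x J) by the column-concatenation [Wi^T | Wo^T] (J x 2K)
// yields [Xt*Wi^T | Xt*Wo^T] in one product: each output column reads only
// its own weight column, so all gates can share a single MatMul.
int main() {
  const int J = 2, K = 1;
  float Xt[J] = {1, 2};
  float WiT[J][K] = {{3}, {4}}; // gate i weights, transposed
  float WoT[J][K] = {{5}, {6}}; // gate o weights, transposed

  float concat[J][2 * K]; // [Wi^T ++ Wo^T]
  for (int j = 0; j < J; ++j) {
    concat[j][0] = WiT[j][0];
    concat[j][1] = WoT[j][0];
  }

  float y[2 * K] = {0, 0};
  for (int c = 0; c < 2 * K; ++c)
    for (int j = 0; j < J; ++j)
      y[c] += Xt[j] * concat[j][c];

  std::printf("%g %g\n", y[0], y[1]); // 11 17: Xt*Wi^T and Xt*Wo^T
}
```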
6 changes: 4 additions & 2 deletions src/Conversion/ONNXToKrnl/RNN/RNN.cpp
@@ -320,8 +320,10 @@ void calculateState<RnnState, RnnActivationPack, RnnWeightPack, RnnBiasPack>(
unsigned htRank = matrixType.getRank();

// Do matrix multiplications.
Value XtWi = create.onnx.matmul(matrixType, Xt, weightPack.Wi);
Value HtRi = create.onnx.matmul(matrixType, Ht, weightPack.Ri);
Value XtWi =
create.onnx.toMemref(create.onnx.matmul(matrixType, Xt, weightPack.Wi));
Value HtRi =
create.onnx.toMemref(create.onnx.matmul(matrixType, Ht, weightPack.Ri));

// Do element-wise computations. Fuse them into a single nested loop.
// Lower and upper bounds derived from Ht tensor.
1 change: 1 addition & 0 deletions src/Dialect/ONNX/CMakeLists.txt
@@ -16,6 +16,7 @@ add_onnx_mlir_library(OMONNXOps
ONNXDialect.cpp
ONNXOps.cpp
ONNXOpsHelper.cpp
ONNXEinsumOpHelper.cpp
ONNXTypes.cpp
Rewrite.cpp

26 changes: 25 additions & 1 deletion src/Dialect/ONNX/DialectBuilder.cpp
@@ -86,7 +86,7 @@ Value OnnxBuilder::matmul(Type Y, Value A, Value B, bool useGemm) const {
/*transB=*/
IntegerAttr::get(b.getIntegerType(64, /*isSigned=*/true),
APInt(64, 0, /*isSigned=*/true)));
return toMemref(b.create<ONNXMatMulOp>(loc, toTensor(Y), aValue, bValue));
return b.create<ONNXMatMulOp>(loc, toTensor(Y), aValue, bValue);
}

Value OnnxBuilder::min(ValueRange inputs) const {
@@ -109,11 +109,24 @@ Value OnnxBuilder::mul(Value A, Value B) const {
return b.create<ONNXMulOp>(loc, toTensor(A), toTensor(B));
}

Value OnnxBuilder::reduceSum(Type outputType, Value data, Value axes,
bool keepdims, bool noop_with_empty_axes) const {
int64_t i_keepdims = keepdims; // 0 if false, 1 if true
int64_t i_noop_with_empty_axes = noop_with_empty_axes; // ditto
return b.create<ONNXReduceSumOp>(loc, toTensor(outputType), toTensor(data),
toTensor(axes), i_keepdims, i_noop_with_empty_axes);
}

Value OnnxBuilder::reshape(Type outputType, Value input, Value shape) const {
return b.create<ONNXReshapeOp>(
loc, toTensor(outputType), toTensor(input), toTensor(shape));
}

Value OnnxBuilder::squeeze(Type outputType, Value data, Value axes) const {
return b.create<ONNXSqueezeOp>(
loc, toTensor(outputType), toTensor(data), toTensor(axes));
}

Value OnnxBuilder::sub(Value A, Value B) const {
assert((A.getType().cast<ShapedType>().getElementType() ==
B.getType().cast<ShapedType>().getElementType()) &&
@@ -162,4 +175,15 @@ Value OnnxBuilder::toMemref(Value input) const {
.getResult(0);
}

Value OnnxBuilder::unsqueeze(Type outputType, Value data, Value axes) const {
return b.create<ONNXUnsqueezeOp>(
loc, toTensor(outputType), toTensor(data), toTensor(axes));
}

Value OnnxBuilder::where(
Type outputType, Value condition, Value X, Value Y) const {
return b.create<ONNXWhereOp>(
loc, toTensor(outputType), toTensor(condition), toTensor(X), toTensor(Y));
}

} // namespace onnx_mlir
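Together with the declarations in the header below, these helpers keep decomposition code at the ONNX level. A hedged sketch of a call site (assumes an onnx-mlir rewrite context; `sumAxes` and its arguments are hypothetical, not from this PR):

```cpp
#include "mlir/IR/PatternMatch.h"
#include "src/Dialect/ONNX/DialectBuilder.hpp"

using namespace mlir;
using namespace onnx_mlir;

// Hypothetical helper (not from this PR): reduce `data` over the axes listed
// in the `axes` tensor via the new OnnxBuilder::reduceSum. The builder turns
// the bool flags into the op's i64 attributes, as shown in the .cpp above.
static Value sumAxes(PatternRewriter &rewriter, Location loc, Type outputType,
    Value data, Value axes) {
  MultiDialectBuilder<OnnxBuilder> create(rewriter, loc);
  return create.onnx.reduceSum(outputType, data, axes,
      /*keepdims=*/false, /*noop_with_empty_axes=*/false);
}
```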
19 changes: 19 additions & 0 deletions src/Dialect/ONNX/DialectBuilder.hpp
@@ -52,10 +52,19 @@ struct OnnxBuilder : onnx_mlir::DialectBuilder {
// ONNXMulOp
mlir::Value mul(mlir::Value A, mlir::Value B) const;

// ONNXReduceSumOp
mlir::Value reduceSum(mlir::Type outputType, mlir::Value data,
mlir::Value axes, bool keepdims = true,
bool noop_with_empty_axes = false) const;

// ONNXReshapeOp
mlir::Value reshape(
mlir::Type outputType, mlir::Value input, mlir::Value shape) const;

// ONNXSqueezeOp
mlir::Value squeeze(
mlir::Type outputType, mlir::Value data, mlir::Value axes) const;

// ONNXSubOp
mlir::Value sub(mlir::Value A, mlir::Value B) const;

@@ -66,8 +75,18 @@
mlir::Type toTensor(mlir::Type input) const;
// Convert a Value to MemrefType if it is of TensorType.
mlir::Value toMemref(mlir::Value input) const;

// ONNXTransposeOp
mlir::Value transpose(
mlir::Type outputType, mlir::Value input, mlir::ArrayAttr perm) const;

// ONNXUnsqueezeOp
mlir::Value unsqueeze(
mlir::Type outputType, mlir::Value data, mlir::Value axes) const;

// ONNXWhereOp
mlir::Value where(mlir::Type outputType, mlir::Value condition, mlir::Value X,
mlir::Value Y) const;
};

// Recursive class specialized for OnnxBuilder referred to as onnx.
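The new where/mul/reduceSum helpers chiefly serve the Einsum diagonal case (equations like `ii->i`; see the "diagonal mask" commits above). One plausible standalone illustration of the mask trick, as a reconstruction rather than the pass's exact recipe:

```cpp
#include <cstdio>
#include <vector>

// Mask-based diagonal extraction in the spirit of the "diagonal mask"
// commits (a reconstruction, not the pass's exact recipe): keep entries
// where i == j, zero the rest (Where), then sum over j (ReduceSum) so only
// the diagonal entry survives per row.
std::vector<float> diagonal(const std::vector<float> &A, int N) {
  std::vector<float> d(N, 0.0f);
  for (int i = 0; i < N; ++i)
    for (int j = 0; j < N; ++j) {
      float masked = (i == j) ? A[i * N + j] : 0.0f; // Where(mask, A, 0)
      d[i] += masked;                                // ReduceSum over axis j
    }
  return d;
}

int main() {
  std::vector<float> A = {1, 2, 3, 4}; // [[1,2],[3,4]], row-major
  std::vector<float> d = diagonal(A, 2);
  std::printf("%g %g\n", d[0], d[1]); // 1 4
}
```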