Einsum decomposition (#1537)
This PR implements ONNXEinsumOp verify(), inferShapes(), and decomposes it into a set of supported ops (Constant, Squeeze, Unsqueeze, Mul, ReduceSum, Where, Transpose).
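For orientation, a minimal sketch (not the PR's actual code) of how a contraction like einsum "ij,jk->ik" can be built from exactly these ops, using the OnnxBuilder helpers this commit adds. The function name, axes constants, and precomputed result types are hypothetical stand-ins supplied by the caller:

Value einsumAsMulReduce(OnnxBuilder &onnx, Value A /*[I,J]*/, Value B /*[J,K]*/,
    Type tyIJ1, Type ty1JK, Type tyIK, Value axes0, Value axes1, Value axes2) {
  // Assumes the mlir/onnx_mlir context of DialectBuilder.hpp below;
  // axes0/axes1/axes2 stand for ONNXConstantOp results holding {0}/{1}/{2}.
  Value a3 = onnx.unsqueeze(tyIJ1, A, axes2); // [I,J] -> [I,J,1]
  Value b3 = onnx.unsqueeze(ty1JK, B, axes0); // [J,K] -> [1,J,K]
  Value prod = onnx.mul(a3, b3);              // broadcast product: [I,J,K]
  // Sum out the shared index j; keepdims=false drops the contracted axis.
  return onnx.reduceSum(tyIK, prod, axes1, /*keepdims=*/false);
}

Per the docs change below, the decomposition targets MatMul in most cases; Mul + ReduceSum is the fully general pattern sketched here.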

Signed-off-by: Soren Lassen <sorenlassen@gmail.com>
sorenlassen authored Jul 15, 2022
1 parent 39d027a commit a1f2a25
Showing 23 changed files with 1,626 additions and 14 deletions.
4 changes: 2 additions & 2 deletions docs/SupportedONNXOps-cpu.md
@@ -1,5 +1,5 @@
 <!--- Automatically generated, do not edit. -->
-<!--- python documentOps.py --arch cpu --input /workdir/onnx-mlir/test/backend/inference_backend.py --path /workdir/onnx-mlir/utils --notes --unsupported -->
+<!--- python documentOps.py --arch cpu --input test/backend/inference_backend.py --path utils --notes --unsupported -->

# Supported ONNX Operation for Target *cpu*.

@@ -49,7 +49,7 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 16. Limitatio
 | **Div** |14 |No support for short integers. | |
 | **Dropout** |13 |Does not support masked and training. | |
 | **DynamicQuantizeLinear** |11 | | |
-| **EinSum** |unsupported | | |
+| **Einsum** |12 |Limited to the types supported by ReduceSum and MatMul (to which we decompose in most cases), which exclude integers with width < 32. | |
 | **Elu** |6 | | |
 | **Equal** |13 | | |
 | **Erf** |13 | | |
15 changes: 10 additions & 5 deletions src/Conversion/ONNXToKrnl/RNN/GRU.cpp
@@ -419,7 +419,8 @@ void calculateState<GruState, GruActivationPack, GruWeightPack, GruBiasPack>(
       MemRefType::get({batchSize, 3 * hiddenSize}, elementType);
 
   // Common matrix multiplications.
-  Value XtWT = create.onnx.matmul(matrixAllGatesType, Xt, weightPack.WT);
+  Value XtWT = create.onnx.toMemref(
+      create.onnx.matmul(matrixAllGatesType, Xt, weightPack.WT));
   Value one = create.math.constant(elementType, 1);

// Lower and upper bounds derived from Ht tensor.
@@ -437,7 +438,8 @@ void calculateState<GruState, GruActivationPack, GruWeightPack, GruBiasPack>(
   // Ht = (1 - zt) (.) ht + zt (.) Ht-1"
   // In this case, we can do all matrix multiplications first, then fuse all
   // element-wise computations into a single nested loop.
-  Value HtRT = create.onnx.matmul(matrixAllGatesType, Ht, weightPack.RT);
+  Value HtRT = create.onnx.toMemref(
+      create.onnx.matmul(matrixAllGatesType, Ht, weightPack.RT));

// Do element-wise computations. Fuse them into a single nested loop.
ValueRange loops = create.krnl.defineLoops(htRank);
@@ -509,8 +511,10 @@ void calculateState<GruState, GruActivationPack, GruWeightPack, GruBiasPack>(
   // In this case, besides computing matrix multiplications, we need to
   // compute rt and (rt (.) Ht-1) first, then fuse the remaining element-wise
   // computations into a single nested loop.
-  Value HtRz = create.onnx.matmul(matrixType, Ht, weightPack.Rz);
-  Value HtRr = create.onnx.matmul(matrixType, Ht, weightPack.Rr);
+  Value HtRz =
+      create.onnx.toMemref(create.onnx.matmul(matrixType, Ht, weightPack.Rz));
+  Value HtRr =
+      create.onnx.toMemref(create.onnx.matmul(matrixType, Ht, weightPack.Rr));
Value rt, rtHt;
if (hasAllConstantDimensions(matrixType)) {
rt = insertAllocAndDealloc(matrixType, loc, rewriter, false);
@@ -553,7 +557,8 @@ void calculateState<GruState, GruActivationPack, GruWeightPack, GruBiasPack>(
});

   // Emit (rt (.) Ht-1)*(Rh^T)
-  Value rtHtRh = create.onnx.matmul(matrixType, rtHt, weightPack.Rh);
+  Value rtHtRh = create.onnx.toMemref(
+      create.onnx.matmul(matrixType, rtHt, weightPack.Rh));

// Do element-wise computations. Fuse them into a single nested loop.
ValueRange loops2 = create.krnl.defineLoops(htRank);
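For reference, the recurrence these kernels compute (the ONNX GRU equations with the default linear_before_reset = 0, matching the comments above); the batched XtWT and HtRT products cover the z, r, and h gates in one matmul each:

z_t = \sigma(X_t W_z^T + H_{t-1} R_z^T + W_{bz} + R_{bz})
r_t = \sigma(X_t W_r^T + H_{t-1} R_r^T + W_{br} + R_{br})
h_t = \tanh(X_t W_h^T + (r_t \odot H_{t-1}) R_h^T + W_{bh} + R_{bh})
H_t = (1 - z_t) \odot h_t + z_t \odot H_{t-1}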
6 changes: 4 additions & 2 deletions src/Conversion/ONNXToKrnl/RNN/LSTM.cpp
@@ -476,8 +476,10 @@ void calculateState<LstmState, LstmActivationPack, LstmWeightPack,
   // Xt * (Wi^T ++ Wo^T ++ Wf^T ++ Wc^T)
   // Ht * (Ri^T ++ Ro^T ++ Rf^T ++ Rc^T)
   // where '++' is matrix concatenation.
-  Value XtWT = create.onnx.matmul(matrixAllGatesType, Xt, weightPack.WT);
-  Value HtRT = create.onnx.matmul(matrixAllGatesType, Ht, weightPack.RT);
+  Value XtWT = create.onnx.toMemref(
+      create.onnx.matmul(matrixAllGatesType, Xt, weightPack.WT));
+  Value HtRT = create.onnx.toMemref(
+      create.onnx.matmul(matrixAllGatesType, Ht, weightPack.RT));

// Do element-wise computations. Fuse them into a single nested loop.
// Lower and upper bounds derived from Ht tensor.
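The '++' concatenation above works because block-columns of a matrix distribute over the product, so a single [batch, input] x [input, 4*hidden] matmul serves all four gates:

X_t \left[ W_i \,\|\, W_o \,\|\, W_f \,\|\, W_c \right]^T = \left[ X_t W_i^T \,\|\, X_t W_o^T \,\|\, X_t W_f^T \,\|\, X_t W_c^T \right]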
6 changes: 4 additions & 2 deletions src/Conversion/ONNXToKrnl/RNN/RNN.cpp
@@ -320,8 +320,10 @@ void calculateState<RnnState, RnnActivationPack, RnnWeightPack, RnnBiasPack>(
unsigned htRank = matrixType.getRank();

   // Do matrix multiplications.
-  Value XtWi = create.onnx.matmul(matrixType, Xt, weightPack.Wi);
-  Value HtRi = create.onnx.matmul(matrixType, Ht, weightPack.Ri);
+  Value XtWi =
+      create.onnx.toMemref(create.onnx.matmul(matrixType, Xt, weightPack.Wi));
+  Value HtRi =
+      create.onnx.toMemref(create.onnx.matmul(matrixType, Ht, weightPack.Ri));

// Do element-wise computations. Fuse them into a single nested loop.
// Lower and upper bounds derived from Ht tensor.
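The vanilla RNN cell needs only these two products (ONNX spec, default activation f = Tanh):

H_t = f(X_t W_i^T + H_{t-1} R_i^T + W_{bi} + R_{bi})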
1 change: 1 addition & 0 deletions src/Dialect/ONNX/CMakeLists.txt
@@ -16,6 +16,7 @@ add_onnx_mlir_library(OMONNXOps
   ONNXDialect.cpp
   ONNXOps.cpp
   ONNXOpsHelper.cpp
+  ONNXEinsumOpHelper.cpp
   ONNXTypes.cpp
   Rewrite.cpp

26 changes: 25 additions & 1 deletion src/Dialect/ONNX/DialectBuilder.cpp
@@ -86,7 +86,7 @@ Value OnnxBuilder::matmul(Type Y, Value A, Value B, bool useGemm) const {
       /*transB=*/
       IntegerAttr::get(b.getIntegerType(64, /*isSigned=*/true),
           APInt(64, 0, /*isSigned=*/true)));
-  return toMemref(b.create<ONNXMatMulOp>(loc, toTensor(Y), aValue, bValue));
+  return b.create<ONNXMatMulOp>(loc, toTensor(Y), aValue, bValue);
}

Value OnnxBuilder::min(ValueRange inputs) const {
@@ -109,11 +109,24 @@ Value OnnxBuilder::mul(Value A, Value B) const {
   return b.create<ONNXMulOp>(loc, toTensor(A), toTensor(B));
 }
 
+Value OnnxBuilder::reduceSum(Type outputType, Value data, Value axes,
+    bool keepdims, bool noop_with_empty_axes) const {
+  int64_t i_keepdims = keepdims; // 0 if false, 1 if true
+  int64_t i_noop_with_empty_axes = noop_with_empty_axes; // ditto
+  return b.create<ONNXReduceSumOp>(loc, toTensor(outputType), toTensor(data),
+      toTensor(axes), i_keepdims, i_noop_with_empty_axes);
+}
+
 Value OnnxBuilder::reshape(Type outputType, Value input, Value shape) const {
   return b.create<ONNXReshapeOp>(
       loc, toTensor(outputType), toTensor(input), toTensor(shape));
 }
 
+Value OnnxBuilder::squeeze(Type outputType, Value data, Value axes) const {
+  return b.create<ONNXSqueezeOp>(
+      loc, toTensor(outputType), toTensor(data), toTensor(axes));
+}
+
 Value OnnxBuilder::sub(Value A, Value B) const {
assert((A.getType().cast<ShapedType>().getElementType() ==
B.getType().cast<ShapedType>().getElementType()) &&
@@ -162,4 +175,15 @@ Value OnnxBuilder::toMemref(Value input) const {
       .getResult(0);
 }
 
+Value OnnxBuilder::unsqueeze(Type outputType, Value data, Value axes) const {
+  return b.create<ONNXUnsqueezeOp>(
+      loc, toTensor(outputType), toTensor(data), toTensor(axes));
+}
+
+Value OnnxBuilder::where(
+    Type outputType, Value condition, Value X, Value Y) const {
+  return b.create<ONNXWhereOp>(
+      loc, toTensor(outputType), toTensor(condition), toTensor(X), toTensor(Y));
+}
+
} // namespace onnx_mlir
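This excerpt does not show how the decomposition uses Where; one plausible use (an assumption, not necessarily this PR's exact strategy) is extracting the diagonal for a repeated index, e.g. einsum "ii->i", by masking against a constant identity pattern before reducing:

Value diagonalViaWhere(OnnxBuilder &onnx, Value X /*[N,N]*/, Value eyeMask,
    Value zeros, Value axes1, Type tyNN, Type tyN) {
  // Assumptions: eyeMask is an ONNXConstantOp [N,N] boolean identity matrix,
  // zeros an [N,N] zero constant of X's element type, axes1 a constant {1}.
  Value masked = onnx.where(tyNN, eyeMask, X, zeros); // keep diagonal, zero rest
  // Each row now has a single surviving entry; summing over axis 1
  // collapses [N,N] to the diagonal vector [N].
  return onnx.reduceSum(tyN, masked, axes1, /*keepdims=*/false);
}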
19 changes: 19 additions & 0 deletions src/Dialect/ONNX/DialectBuilder.hpp
@@ -52,10 +52,19 @@ struct OnnxBuilder : onnx_mlir::DialectBuilder {
   // ONNXMulOp
   mlir::Value mul(mlir::Value A, mlir::Value B) const;
 
+  // ONNXReduceSumOp
+  mlir::Value reduceSum(mlir::Type outputType, mlir::Value data,
+      mlir::Value axes, bool keepdims = true,
+      bool noop_with_empty_axes = false) const;
+
   // ONNXReshapeOp
   mlir::Value reshape(
       mlir::Type outputType, mlir::Value input, mlir::Value shape) const;
 
+  // ONNXSqueezeOp
+  mlir::Value squeeze(
+      mlir::Type outputType, mlir::Value data, mlir::Value axes) const;
+
   // ONNXSubOp
   mlir::Value sub(mlir::Value A, mlir::Value B) const;

@@ -66,8 +75,18 @@ struct OnnxBuilder : onnx_mlir::DialectBuilder {
   mlir::Type toTensor(mlir::Type input) const;
   // Convert a Value to MemrefType if it is of TensorType.
   mlir::Value toMemref(mlir::Value input) const;
 
   // ONNXTransposeOp
   mlir::Value transpose(
       mlir::Type outputType, mlir::Value input, mlir::ArrayAttr perm) const;
 
+  // ONNXUnsqueezeOp
+  mlir::Value unsqueeze(
+      mlir::Type outputType, mlir::Value data, mlir::Value axes) const;
+
+  // ONNXWhereOp
+  mlir::Value where(mlir::Type outputType, mlir::Value condition, mlir::Value X,
+      mlir::Value Y) const;
 };

 // Recursive class specialized for OnnxBuilder referred to as onnx.
(The remaining 16 changed files in this commit are not shown in this excerpt.)