Skip to content

Commit

Permalink
Symbolic shape expressions in TCP dialect (cruise-automation#78)
Browse files Browse the repository at this point in the history
Support for TorchDynamo captured symbolic shape expressions was added to
Torch-MLIR recently (llvm/torch-mlir#3372). This
PR continues the work to lower these to TCP.

- [x] TCP op definitions, custom printer/parser/verifier
- [x] Dialect and conversion lit tests
- [x] Custom op python lit tests
- [x] Cleanup pass to remove bind ops
  • Loading branch information
sjain-stanford authored Jul 12, 2024
1 parent 0bfd510 commit e6eaf43
Show file tree
Hide file tree
Showing 15 changed files with 508 additions and 3 deletions.
13 changes: 11 additions & 2 deletions BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ td_library(
],
includes = ["include"],
deps = [
"@llvm-project//mlir:BuiltinDialectTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:SideEffectInterfacesTdFiles",
],
Expand Down Expand Up @@ -58,11 +59,17 @@ gentbl_cc_library(
"include/mlir-tcp/Dialect/IR/TcpDialect.cpp.inc",
),
(
["-gen-attrdef-decls"],
[
"-gen-attrdef-decls",
"-attrdefs-dialect=tcp",
],
"include/mlir-tcp/Dialect/IR/TcpAttrs.h.inc",
),
(
["-gen-attrdef-defs"],
[
"-gen-attrdef-defs",
"-attrdefs-dialect=tcp",
],
"include/mlir-tcp/Dialect/IR/TcpAttrs.cpp.inc",
),
(
Expand Down Expand Up @@ -142,6 +149,7 @@ gentbl_cc_library(
cc_library(
name = "TcpDialectPasses",
srcs = [
"lib/Dialect/Transforms/DropSymbolicShapeOpsPass.cpp",
"lib/Dialect/Transforms/FuseTcpOpsPass.cpp",
"lib/Dialect/Transforms/FusionPatterns.cpp",
"lib/Dialect/Transforms/IsolateGroupOpsPass.cpp",
Expand All @@ -151,6 +159,7 @@ cc_library(
"lib/Dialect/Transforms/VerifyTcpBackendContractPass.cpp",
],
hdrs = [
"include/mlir-tcp/Dialect/Transforms/DropSymbolicShapeOpsPass.h",
"include/mlir-tcp/Dialect/Transforms/FuseTcpOpsPass.h",
"include/mlir-tcp/Dialect/Transforms/FusionPatterns.h",
"include/mlir-tcp/Dialect/Transforms/IsolateGroupOpsPass.h",
Expand Down
66 changes: 66 additions & 0 deletions include/mlir-tcp/Dialect/IR/TcpOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#define TCP_OPS

include "mlir/IR/OpBase.td"
include "mlir/IR/BuiltinAttributes.td"

include "mlir-tcp/Dialect/IR/TcpBase.td"
include "mlir-tcp/Dialect/IR/TcpEnums.td"
Expand Down Expand Up @@ -640,4 +641,69 @@ def Tcp_SliceOp : Tcp_Op<"slice", [Pure, AllElementTypesMatch<["in", "out"]>, Sa
let assemblyFormat = "$in `starts` `(` $starts `)` `sizes` `(` $sizes `)` `strides` `(` $strides `)` attr-dict `:` type($in) `->` type($out)";
}

//===----------------------------------------------------------------------===//
// Symbolic shape modeling ops for TorchDynamo frontend.
//===----------------------------------------------------------------------===//

def Tcp_SymbolicIntOp : Tcp_Op<"symbolic_int", [Pure]> {

let summary = "Symbolic int representing a dynamic dimension";

let description = [{
The `tcp.symbolic_int` operation captures a dynamic dimension on the
global function arguments. It associates the shape symbols (i.e. "s0",
"s1") with the global SSA values (i.e. `%0`, `%1`) that is then
referenced to bind shapes on op results.

Additionally, the operation annotates `min_val` and `max_val` attributes
denoting the range constraints for the dynamic dimension. This may be
useful for modeling runtime shape guards, or compile-time optimizations
based on the shape bounds (min, opt, max) on results of ops / regions.

Example:
```
%0 = tcp.symbolic_int "s0" {min_val = 5, max_val = 10} : i64
%1 = tcp.symbolic_int "s1" {min_val = 2, max_val = 20} : i64
```
}];

let arguments = (ins
StrAttr:$symbol_name,
I64Attr:$min_val,
I64Attr:$max_val
);
let results = (outs
AnySignlessInteger:$result
);
let assemblyFormat = [{
$symbol_name ` ` `{` `min_val` `=` $min_val `,` `max_val` `=` $max_val `}` attr-dict `:` type($result)
}];
}

def Tcp_BindSymbolicShapeOp : Tcp_Op<"bind_symbolic_shape", []> {
let summary = "Binds shape expressions to tensors using an affine map indexed by shape symbols";
let description = [{
The `tcp.bind_symbolic_shape` operation binds shape expressions
useful to compute the dynamic dimensions of a tensor. It takes a
variadic of SSA symbols that map 1:1 to the local symbols declared
in the affine map. The affine map contains a list of affine shape
expressions for each dim where the terminals are from the declared
symbols.

Example:
```
tcp.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1, 3)> : tensor<?x?x3xf32>
tcp.bind_symbolic_shape %out0, [%0, %1, %2], affine_map<()[s0, s1, s2] -> (s0, s1 * 2 + s2, 3)> : tensor<?x?x3xf32>
```
}];
let arguments = (ins
Tcp_Tensor:$operand,
Variadic<AnySignlessInteger>:$shape_symbols,
Builtin_AffineMapAttr:$shape_expressions
);
let results = (outs);
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

#endif // TCP_OPS
22 changes: 22 additions & 0 deletions include/mlir-tcp/Dialect/Transforms/DropSymbolicShapeOpsPass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#pragma once

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include <memory>

namespace mlir::tcp {

std::unique_ptr<mlir::OperationPass<func::FuncOp>>
createDropSymbolicShapeOpsPass();

} // namespace mlir::tcp
7 changes: 7 additions & 0 deletions include/mlir-tcp/Dialect/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,11 @@ def DecomposeTensorOps : Pass<"decompose-tensor-ops", "func::FuncOp"> {
let constructor = "mlir::tcp::createDecomposeTensorOpsPass()";
}

// \brief This pass removes any unused symbolic shape ops.
// We discard remaining bind shape ops during backend lowering.
def DropSymbolicShapeOps : Pass<"drop-symbolic-shape-ops", "func::FuncOp"> {
let summary = "Removes all remaining symbolic shape ops.";
let constructor = "mlir::tcp::createDropSymbolicShapeOpsPass()";
}

#endif // TCP_PASSES
39 changes: 39 additions & 0 deletions lib/Conversion/TorchToTcp/Misc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,12 +275,51 @@ class ConvertAtenZerosOnesLikeOp : public OpConversionPattern<AtenOpT> {
}
};

class ConvertSymbolicIntOp : public OpConversionPattern<Torch::SymbolicIntOp> {
public:
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(Torch::SymbolicIntOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type resultType = getTypeConverter()->convertType(op.getType());

rewriter.replaceOpWithNewOp<tcp::SymbolicIntOp>(
op, resultType, adaptor.getSymbolNameAttr(), adaptor.getMinValAttr(),
adaptor.getMaxValAttr());
return success();
}
};

class ConvertBindSymbolicShapeOp
: public OpConversionPattern<Torch::BindSymbolicShapeOp> {
public:
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(Torch::BindSymbolicShapeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

rewriter.replaceOpWithNewOp<tcp::BindSymbolicShapeOp>(
op, adaptor.getOperand(), adaptor.getShapeSymbols(),
adaptor.getShapeExpressionsAttr());
return success();
}
};

} // namespace

void torch_to_tcp::populateMiscPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const llvm::StringSet<> &convertTorchOpsSet) {

torch_to_tcp::addPatternIfOpInConvertTorchOpsSet<ConvertSymbolicIntOp,
Torch::SymbolicIntOp>(
typeConverter, patterns, target, convertTorchOpsSet);
torch_to_tcp::addPatternIfOpInConvertTorchOpsSet<ConvertBindSymbolicShapeOp,
Torch::BindSymbolicShapeOp>(
typeConverter, patterns, target, convertTorchOpsSet);

#define INSERT_ATEN_MISC_OP_PATTERN(AtenOp) \
torch_to_tcp::addPatternIfOpInConvertTorchOpsSet<Convert##AtenOp, AtenOp>( \
typeConverter, patterns, target, convertTorchOpsSet)
Expand Down
63 changes: 63 additions & 0 deletions lib/Dialect/IR/TcpOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,4 +170,67 @@ LogicalResult CastOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// BindSymbolicShapeOp
//===----------------------------------------------------------------------===//

//
// tcp.bind_symbolic_shape %6, [%0, %1, %2], affine_map<()[s0, s1, s2] ->
// (s0, s1 * 2 + s2, 3)> : tensor<?x?x3xf32>
//

ParseResult BindSymbolicShapeOp::parse(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::UnresolvedOperand operand;
SmallVector<OpAsmParser::UnresolvedOperand> shapeSymbols;
AffineMapAttr shapeExpressions;
Type operandType;

if (parser.parseOperand(operand) || parser.parseComma() ||
parser.parseLSquare() || parser.parseOperandList(shapeSymbols) ||
parser.parseRSquare() || parser.parseComma() ||
parser.parseAttribute(shapeExpressions, "shape_expressions",
result.attributes) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(operandType)) {
return failure();
}

if (parser.resolveOperand(operand, operandType, result.operands) ||
parser.resolveOperands(shapeSymbols,
parser.getBuilder().getType<IntegerType>(64),
result.operands)) {
return failure();
}

return success();
}

// Use a custom printer here to avoid the AffineMap from getting hoisted
// when printed. This makes it so the AffineMap is printed inline with the op.
void BindSymbolicShapeOp::print(OpAsmPrinter &p) {
p << " " << getOperand() << ", [";
llvm::interleaveComma(getShapeSymbols(), p);
p << "], "
<< "affine_map<" << getShapeExpressions().getValue() << ">";
p.printOptionalAttrDict((*this)->getAttrs(),
/*elidedAttrs=*/{"shape_expressions"});
p << " : " << getOperand().getType();
}

LogicalResult BindSymbolicShapeOp::verify() {
if (getShapeSymbols().empty())
return emitOpError() << "requires non-empty shapeSymbols";

for (auto symbol : getShapeSymbols()) {
Operation *definingOp = symbol.getDefiningOp();
if (!isa<SymbolicIntOp>(definingOp)) {
return emitOpError()
<< "shape symbol must be produced by a SymbolicIntOp";
}
}

return success();
}

} // namespace mlir::tcp
59 changes: 59 additions & 0 deletions lib/Dialect/Transforms/DropSymbolicShapeOpsPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "mlir-tcp/Dialect/Transforms/DropSymbolicShapeOpsPass.h"

#include "mlir-tcp/Dialect/IR/TcpDialect.h"
#include "mlir-tcp/Dialect/IR/TcpOps.h"
#include "mlir-tcp/Dialect/Transforms/Passes.h"

#include "./PassDetail.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

namespace mlir::tcp {

namespace {

class RemoveBindSymbolicShapeOps
: public OpRewritePattern<tcp::BindSymbolicShapeOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(tcp::BindSymbolicShapeOp op,
PatternRewriter &rewriter) const override {
rewriter.eraseOp(op);
return success();
}
};

class DropSymbolicShapeOpsPass
: public DropSymbolicShapeOpsBase<DropSymbolicShapeOpsPass> {
void runOnOperation() override {
Operation *op = getOperation();
MLIRContext *context = op->getContext();
RewritePatternSet patterns(context);

patterns.add<RemoveBindSymbolicShapeOps>(context);
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
return signalPassFailure();
}
};

} // namespace

std::unique_ptr<OperationPass<func::FuncOp>> createDropSymbolicShapeOpsPass() {
return std::make_unique<DropSymbolicShapeOpsPass>();
}

} // namespace mlir::tcp
1 change: 1 addition & 0 deletions lib/Dialect/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
//===----------------------------------------------------------------------===//

#include "mlir-tcp/Dialect/Transforms/Passes.h"
#include "mlir-tcp/Dialect/Transforms/DropSymbolicShapeOpsPass.h"
#include "mlir-tcp/Dialect/Transforms/FuseTcpOpsPass.h"
#include "mlir-tcp/Dialect/Transforms/IsolateGroupOpsPass.h"
#include "mlir-tcp/Dialect/Transforms/TransformTensorOps.h"
Expand Down
4 changes: 4 additions & 0 deletions lib/Pipeline/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mlir-tcp/Conversion/TcpToTensor/TcpToTensor.h"
#include "mlir-tcp/Conversion/TorchToTcp/TorchToTcp.h"
#include "mlir-tcp/Conversion/TorchToTcp/TorchToTcpCustomOp.h"
#include "mlir-tcp/Dialect/Transforms/DropSymbolicShapeOpsPass.h"
#include "mlir-tcp/Dialect/Transforms/TransformTensorOps.h"
#include "mlir-tcp/Dialect/Transforms/VerifyTcpBackendContractPass.h"

Expand Down Expand Up @@ -65,6 +66,9 @@ static void createTorchBackendToTcpBackendPipeline(OpPassManager &pm) {
}

static void createTcpToLlvmPipeline(OpPassManager &pm) {
// Drop TCP symbolic shape ops for dynamic dims
pm.addNestedPass<func::FuncOp>(tcp::createDropSymbolicShapeOpsPass());

// TCP transformations.
pm.addNestedPass<func::FuncOp>(tcp::createDecomposeTensorOpsPass());

Expand Down
Loading

0 comments on commit e6eaf43

Please sign in to comment.