Add shape inference pass option to analyze all functions (llvm#725)
Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>

Co-authored-by: Alexandre Eichenberger <alexe@us.ibm.com>
Co-authored-by: chentong319 <chentong@us.ibm.com>
3 people authored Jun 21, 2021
1 parent 362b3d3 commit 52639ea
Showing 4 changed files with 73 additions and 44 deletions.
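
This commit adds an analyzeAllFunctions flag to the ONNX shape inference pass: by default the pass still runs only on functions whose names end in "main_graph", but callers can now request inference on every function in the module. Below is a minimal sketch of how a caller might enable the option, modeled on the pass-manager setup this commit adds to test/onnx2mlir/CustomFnTest.cpp; the helper name runShapeInference and the surrounding setup are illustrative, not part of the commit.

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "llvm/Support/raw_ostream.h"

#include "src/Pass/Passes.hpp"

// Hypothetical helper: run ONNX shape inference over every function in the
// module, not just the "*main_graph" entry points.
void runShapeInference(mlir::MLIRContext &context, mlir::ModuleOp module) {
  mlir::PassManager pm(&context, mlir::OpPassManager::Nesting::Implicit);
  // analyzeAllFunctions defaults to false, which keeps the previous behavior.
  pm.addPass(mlir::createShapeInferencePass(/*analyzeAllFunctions=*/true));
  if (mlir::failed(pm.run(module)))
    llvm::errs() << "shape inference failed\n";
}
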
29 changes: 16 additions & 13 deletions src/Builder/FrontendDialectTransformer.cpp
@@ -232,13 +232,14 @@ class FrontendGenImpl {
     return RankedTensorType::get(tensor_dims, elementType);
   }
 
-  Type ConvertOnnxType(const std::string &onnx_name) {
-    auto it = value_info_map.find(onnx_name);
-    if (it != value_info_map.end()) {
-      return ImportTensorType(it->second);
-    } else {
-      return builder_.getNoneType();
+  llvm::Optional<Type> ConvertOnnxType(const std::string &onnx_name) {
+    if (options_.useOnnxModelTypes) {
+      auto it = value_info_map.find(onnx_name);
+      if (it != value_info_map.end()) {
+        return llvm::Optional<Type>(ImportTensorType(it->second));
+      }
     }
+    return llvm::Optional<Type>();
   }
 
   /*!
@@ -531,8 +532,8 @@ class FrontendGenImpl {
       // Optional outputs using empty string.
       if (node.output()[i].empty()) {
         outputTypes.emplace_back(builder_.getNoneType());
-      } else if (options_.useOnnxModelTypes) {
-        outputTypes.emplace_back(ConvertOnnxType(node.output(i)));
+      } else if (auto onnxModelType = ConvertOnnxType(node.output(i))) {
+        outputTypes.emplace_back(onnxModelType.getValue());
       } else {
         auto j = i;
         // Variadic output is a single ODS result.
@@ -589,13 +590,15 @@ class FrontendGenImpl {
         }
       }
     }
-    if (!options_.useOnnxModelTypes)
-      if (auto opWithTypeInference =
-              dyn_cast<ResultTypeInferenceOpInterface>(genericOp)) {
-        auto outTypes = opWithTypeInference.resultTypeInference();
-        for (int i = 0; i < node.output().size(); i++)
+    if (auto opWithTypeInference =
+            dyn_cast<ResultTypeInferenceOpInterface>(genericOp)) {
+      auto outTypes = opWithTypeInference.resultTypeInference();
+      for (int i = 0; i < node.output().size(); i++) {
+        auto result = genericOp->getOpResult(i);
+        if (!options_.useOnnxModelTypes || result.getType().isa<NoneType>())
           genericOp->getOpResult(i).setType(outTypes[i]);
       }
+    }
 
     for (const auto &output : llvm::enumerate(node.output()))
       frontend_symbols_.AddMapping(
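
In the importer change above, ConvertOnnxType now returns llvm::Optional<Type>: a value only when useOnnxModelTypes is set and the ONNX model actually carries type information for the name, and an empty Optional otherwise, so the caller can fall back to the ODS default or to type inference. A minimal, self-contained sketch of that Optional-based lookup pattern follows; the names lookupWidth and widthOrDefault are illustrative and not from the repository.

#include "llvm/ADT/Optional.h"
#include <map>
#include <string>

llvm::Optional<int> lookupWidth(const std::map<std::string, int> &widths,
                                const std::string &name) {
  auto it = widths.find(name);
  if (it != widths.end())
    return it->second; // found: wrap the value in an Optional
  return llvm::None;   // not found: let the caller decide the fallback
}

int widthOrDefault(const std::map<std::string, int> &widths,
                   const std::string &name) {
  if (auto w = lookupWidth(widths, name)) // true only if a value is present
    return w.getValue();
  return 0; // fallback, analogous to leaving the type for shape inference
}
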
3 changes: 2 additions & 1 deletion src/Pass/Passes.hpp
@@ -22,7 +22,8 @@ class Pass;
 /// Pass for rewriting inside frontend dialect.
 std::unique_ptr<Pass> createDecomposeONNXToONNXPass();
 
-std::unique_ptr<Pass> createShapeInferencePass();
+std::unique_ptr<Pass> createShapeInferencePass(
+    bool analyzeAllFunctions = false);
 
 std::unique_ptr<Pass> createConstPropONNXToONNXPass();
 
42 changes: 27 additions & 15 deletions src/Transform/ONNX/ShapeInferencePass.cpp
@@ -12,10 +12,12 @@
 // shapes through function specialization.
 //
 //===----------------------------------------------------------------------===//
+
 #include <regex>
 
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Pass/Pass.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/raw_ostream.h"
@@ -37,6 +39,7 @@ static SmallVector<mlir::FuncOp, 4> lookUpFuncsMatching(
   });
   return matchedFuncs;
 }
+
 /*!
  * FunctionPass that performs shape inference by iterating over a list of
  * candidate operations and propagating the shape information until the list
@@ -57,23 +60,31 @@ static SmallVector<mlir::FuncOp, 4> lookUpFuncsMatching(
  */
 class ShapeInferencePass : public mlir::PassWrapper<ShapeInferencePass,
                                OperationPass<mlir::ModuleOp>> {
+private:
+  bool analyzeAllFunctions;
+
+public:
+  ShapeInferencePass(bool analyzeAllFunctions_)
+      : analyzeAllFunctions(analyzeAllFunctions_) {}
+
   void runOnOperation() override {
     auto module = getOperation();
-    auto matchedFuncs =
-        lookUpFuncsMatching(module, std::regex("[a-zA-Z0-9_]*main_graph"));
-    if (!matchedFuncs.empty()) {
-      for (auto func : matchedFuncs) {
-        if (failed(runShapeInferenceOn(func)))
-          signalPassFailure();
+    if (!analyzeAllFunctions) {
+      auto matchedFuncs =
+          lookUpFuncsMatching(module, std::regex("[a-zA-Z0-9_]*main_graph"));
+      if (!matchedFuncs.empty()) {
+        for (auto func : matchedFuncs) {
+          if (failed(runShapeInferenceOn(func)))
+            signalPassFailure();
+        }
+        return;
       }
-    } else {
-      auto result = module.walk([&](FuncOp funcOp) -> WalkResult {
-        return runShapeInferenceOn(funcOp);
-      });
-      if (result.wasInterrupted())
-        signalPassFailure();
     }
+    auto result = module.walk([&](FuncOp funcOp) -> WalkResult {
+      return runShapeInferenceOn(funcOp);
+    });
+    if (result.wasInterrupted())
+      signalPassFailure();
   }
 
   static LogicalResult runShapeInferenceOnRegion(mlir::Region &r) {
@@ -93,7 +104,7 @@ class ShapeInferencePass : public mlir::PassWrapper<ShapeInferencePass,
         op.emitError("shape inference failed");
         return failure();
       }
-      } else {
+      } else if (!llvm::dyn_cast<CallOpInterface>(op)) {
        op.emitError("unable to infer shape of operation without shape "
                     "inference interface");
        return failure();
@@ -165,6 +176,7 @@ class ShapeInferencePass : public mlir::PassWrapper<ShapeInferencePass,
 /*!
  * Create a Shape Inference pass.
  */
-std::unique_ptr<mlir::Pass> mlir::createShapeInferencePass() {
-  return std::make_unique<ShapeInferencePass>();
+std::unique_ptr<mlir::Pass> mlir::createShapeInferencePass(
+    bool analyzeAllFunctions) {
+  return std::make_unique<ShapeInferencePass>(analyzeAllFunctions);
 }
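
The new analyzeAllFunctions path relies on ModuleOp::walk returning a WalkResult, and on the implicit conversion from LogicalResult to WalkResult (a failed result interrupts the walk), exactly as in the lambda above. The sketch below shows that traversal pattern outside the pass, assuming the same MLIR version used here (FuncOp still lives in the Standard dialect); processFunc and processAllFunctions are stand-ins, not repository code.

#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/Support/raw_ostream.h"

// Stand-in for runShapeInferenceOn: any per-function check that can fail.
// Here we arbitrarily treat bodyless declarations as failures.
static mlir::LogicalResult processFunc(mlir::FuncOp func) {
  return func.isExternal() ? mlir::failure() : mlir::success();
}

// Visit every FuncOp in the module; a failed LogicalResult converts to
// WalkResult::interrupt() and stops the traversal early.
static bool processAllFunctions(mlir::ModuleOp module) {
  auto result = module.walk([&](mlir::FuncOp funcOp) -> mlir::WalkResult {
    return processFunc(funcOp);
  });
  if (result.wasInterrupted()) {
    // Mirrors signalPassFailure() in the pass itself.
    llvm::errs() << "processing stopped on the first failing function\n";
    return false;
  }
  return true;
}
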
43 changes: 28 additions & 15 deletions test/onnx2mlir/CustomFnTest.cpp
@@ -8,12 +8,18 @@
 
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/MLIRContext.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/Passes.h"
 
 #include "onnx/defs/function.h"
 #include "onnx/defs/schema.h"
 
 #include "src/Builder/FrontendDialectTransformer.hpp"
+
+#include "src/Interface/ShapeInferenceOpInterface.hpp"
+#include "src/Pass/Passes.hpp"
 
 using namespace std;
 using namespace ONNX_NAMESPACE;
 
@@ -56,10 +62,20 @@ void check(ModelProto &model) {
   options.useOnnxModelTypes = true;
   onnx_mlir::ImportFrontendModel(model, context, module, options);
 
-  // TODO: use result?
-  mlir::LogicalResult res = module->verify();
-  module->dump();
-  std::cerr << std::endl;
+  mlir::PassManager pm(&context, mlir::OpPassManager::Nesting::Implicit);
+  pm.addPass(mlir::createShapeInferencePass(true));
+  mlir::applyPassManagerCLOptions(pm);
+  if (mlir::failed(pm.run(*module))) {
+    module->dump();
+    std::cerr << "Error applying shape inference!\n";
+    return;
+  }
+
+  if (mlir::failed(module->verify())) {
+    module->dump();
+    std::cerr << "Error verifying module!\n";
+    return;
+  }
 }
 
 void testCustomFunTranslation() {
@@ -77,27 +93,24 @@
 
   auto *x = graph->add_input();
   x->set_name("x");
-  x->mutable_type()->mutable_tensor_type()->set_elem_type(elt_type);
+  auto *x_type = x->mutable_type()->mutable_tensor_type();
+  x_type->set_elem_type(elt_type);
+  auto *x_shape = x_type->mutable_shape();
+  x_shape->add_dim()->set_dim_value(10);
 
   auto *y = graph->add_output();
   y->set_name("y");
-  y->mutable_type()->mutable_tensor_type()->set_elem_type(elt_type);
+  auto *y_type = y->mutable_type()->mutable_tensor_type();
+  y_type->set_elem_type(elt_type);
+  auto *y_shape = y_type->mutable_shape();
+  y_shape->add_dim()->set_dim_value(10);
 
   auto *node = graph->add_node();
   node->add_input("x");
   node->add_output("y");
   node->set_op_type("SquareFn");
-  node->set_name("node1");
 
-  auto *t = graph->add_value_info();
-  t->set_name("t");
-  t->mutable_type()->mutable_tensor_type()->set_elem_type(elt_type);
-
-  node = graph->add_node();
-  node->add_input("x");
-  node->add_output("t");
-  node->set_op_type("SquareFn");
 
   check(model_proto);
 }

Expand Down

0 comments on commit 52639ea

Please sign in to comment.