Skip to content

Commit

Permalink
Wrap ZHigh dialect into onnx_mlir::zhigh namespace (llvm#1221)
Browse files Browse the repository at this point in the history
Signed-off-by: Ettore Tiotto <etiotto@ca.ibm.com>
  • Loading branch information
Ettore Tiotto authored Mar 8, 2022
1 parent 0a40daa commit 09986d9
Show file tree
Hide file tree
Showing 30 changed files with 400 additions and 292 deletions.
13 changes: 8 additions & 5 deletions src/Accelerators/Accelerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,27 @@
//
// =============================================================================
//
// Accelerator base class
// Accelerator base class.
//
//===----------------------------------------------------------------------===//

#include "src/Accelerators/Accelerator.hpp"
#include <iostream>
#include <vector>

namespace mlir {
namespace onnx_mlir {

std::vector<Accelerator *> *Accelerator::acceleratorTargets;

/// Constructs an accelerator. If the static acceleratorTargets pointer is
/// still null, a new empty std::vector<Accelerator *> is allocated for it;
/// otherwise it is left untouched. The code visible here does not add `this`
/// to that vector.
/// NOTE(review): this text comes from a rendered diff, so the two identical
/// null checks below are presumably the pre- and post-commit forms of a single
/// statement rather than intentional nesting -- confirm against the repository.
Accelerator::Accelerator() {
if (acceleratorTargets == NULL) {
if (acceleratorTargets == NULL)
acceleratorTargets = new std::vector<Accelerator *>();
}
}

/// Destructor. The body is empty, so the static acceleratorTargets vector that
/// the constructor allocates with `new` is not released here.
Accelerator::~Accelerator() {}

/// Returns the static acceleratorTargets pointer unchanged. Nothing is
/// allocated here: in the code visible in this file only the Accelerator
/// constructor assigns the pointer, so callers should expect null if no
/// Accelerator has been constructed yet.
std::vector<Accelerator *> *Accelerator::getAcceleratorList() {
return acceleratorTargets;
}

} // namespace mlir
} // namespace onnx_mlir
17 changes: 11 additions & 6 deletions src/Accelerators/Accelerator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,28 @@
// Accelerator base class
//
//===----------------------------------------------------------------------===//

#pragma once

#include "include/onnx-mlir/Compiler/OMCompilerTypes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "onnx-mlir/Compiler/OMCompilerTypes.h"
#include <vector>

namespace mlir {
namespace onnx_mlir {

class Accelerator {
public:
Accelerator();
virtual ~Accelerator();
static std::vector<Accelerator *> *getAcceleratorList();
virtual bool isActive() = 0;
virtual void prepareAccelerator(mlir::OwningOpRef<ModuleOp> &module,
virtual bool isActive() const = 0;
virtual void prepareAccelerator(mlir::OwningOpRef<mlir::ModuleOp> &module,
mlir::MLIRContext &context, mlir::PassManager &pm,
onnx_mlir::EmissionTargetType emissionTarget) = 0;
onnx_mlir::EmissionTargetType emissionTarget) const = 0;

private:
static std::vector<Accelerator *> *acceleratorTargets;
};
} // namespace mlir

} // namespace onnx_mlir
51 changes: 30 additions & 21 deletions src/Accelerators/NNPA/Compiler/DLCompilerUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,26 +23,32 @@
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"

#include "Compiler/DLCompilerUtils.hpp"
#include "Dialect/ZHigh/ZHighOps.hpp"
#include "Dialect/ZLow/ZLowOps.hpp"
#include "Pass/DLCPasses.hpp"
#include "Support/OMDLCOptions.hpp"
#include "src/Accelerators/NNPA/Compiler/DLCompilerUtils.hpp"
#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.hpp"
#include "src/Accelerators/NNPA/Dialect/ZLow/ZLowOps.hpp"
#include "src/Accelerators/NNPA/Pass/DLCPasses.hpp"
#include "src/Accelerators/NNPA/Support/OMDLCOptions.hpp"
#include "src/Compiler/CompilerUtils.hpp"

#define DEBUG_TYPE "DLCompiler"

using namespace std;
using namespace mlir;
using namespace onnx_mlir;

extern llvm::cl::OptionCategory OnnxMlirOptions;

llvm::cl::opt<DLCEmissionTargetType> dlcEmissionTarget(
llvm::cl::desc("[Optional] Choose Z-related target to emit "
"(once selected it will cancel the other targets):"),
llvm::cl::values(
clEnumVal(EmitZHighIR, "Lower model to ZHigh IR (ZHigh dialect)"),
clEnumVal(EmitZLowIR, "Lower model to ZLow IR (ZLow dialect)"),
clEnumVal(EmitZNONE, "Do not emit Z-related target (default)")),
llvm::cl::init(EmitZNONE), llvm::cl::cat(OnnxMlirOptions));
llvm::cl::values(clEnumVal(DLCEmissionTargetType::EmitZHighIR,
"Lower model to ZHigh IR (ZHigh dialect)"),
clEnumVal(DLCEmissionTargetType::EmitZLowIR,
"Lower model to ZLow IR (ZLow dialect)"),
clEnumVal(DLCEmissionTargetType::EmitZNONE,
"Do not emit Z-related target (default)")),
llvm::cl::init(DLCEmissionTargetType::EmitZNONE),
llvm::cl::cat(OnnxMlirOptions));

llvm::cl::list<std::string> execNodesOnCpu{"execNodesOnCpu",
llvm::cl::desc("Comma-separated list of node names in an onnx graph. The "
Expand All @@ -52,15 +58,17 @@ llvm::cl::list<std::string> execNodesOnCpu{"execNodesOnCpu",
llvm::cl::CommaSeparated, llvm::cl::ZeroOrMore,
llvm::cl::cat(OnnxMlirOptions)};

namespace onnx_mlir {

void addONNXToZHighPasses(
mlir::PassManager &pm, ArrayRef<std::string> execNodesOnCpu) {
pm.addPass(mlir::createRewriteONNXForZHighPass(execNodesOnCpu));
pm.addPass(onnx_mlir::createRewriteONNXForZHighPass(execNodesOnCpu));
pm.addPass(onnx_mlir::createShapeInferencePass());
pm.addNestedPass<FuncOp>(onnx_mlir::createConstPropONNXToONNXPass());
// Add instrumentation for Onnx Ops in the same way as onnx-mlir.
if (instrumentZHighOps == "" || instrumentZHighOps == "NONE")
pm.addNestedPass<FuncOp>(onnx_mlir::createInstrumentONNXPass());
pm.addPass(mlir::createONNXToZHighPass(execNodesOnCpu));
pm.addPass(onnx_mlir::createONNXToZHighPass(execNodesOnCpu));
pm.addPass(onnx_mlir::createShapeInferencePass());
// There are more opportunities for const propagation once all ZHigh ops have
// been generated.
Expand All @@ -86,7 +94,7 @@ void addONNXToZHighPasses(
void addZHighToZLowPasses(mlir::PassManager &pm, int optLevel) {
// Add instrumentation for ZHigh Ops
pm.addNestedPass<FuncOp>(mlir::createInstrumentZHighPass());
pm.addPass(mlir::createZHighToZLowPass(optLevel));
pm.addPass(onnx_mlir::zhigh::createZHighToZLowPass(optLevel));
pm.addNestedPass<FuncOp>(onnx_mlir::createLowerKrnlShapePass());
pm.addNestedPass<FuncOp>(onnx_mlir::createDisconnectKrnlDimFromAllocPass());
pm.addPass(mlir::memref::createNormalizeMemRefsPass());
Expand Down Expand Up @@ -122,21 +130,20 @@ void addPassesDLC(mlir::OwningOpRef<ModuleOp> &module, mlir::PassManager &pm,
// InputIRLevelType inputIRLevel = determineInputIRLevel(module);

// LLVM_DEBUG(llvm::dbgs() << "Adding DLC passes" << std::endl;);
if (emissionTarget >= onnx_mlir::EmitONNXIR) {
if (emissionTarget >= EmitONNXIR)
addONNXToMLIRPasses(pm);
}

if (emissionTarget >= onnx_mlir::EmitMLIR) {
// Lower zAIU-compatible ONNX ops to ZHigh dialect where possible.
addONNXToZHighPasses(pm, execNodesOnCpu);

if (dlcEmissionTarget >= EmitZHighIR)
emissionTarget = onnx_mlir::EmitMLIR;
if (dlcEmissionTarget >= DLCEmissionTargetType::EmitZHighIR)
emissionTarget = EmitMLIR;
else {
pm.addPass(mlir::createCanonicalizerPass());
// Add instrumentation for remaining Onnx Ops
if (instrumentZHighOps != "" && instrumentZHighOps != "NONE")
pm.addNestedPass<FuncOp>(onnx_mlir::createInstrumentONNXPass());
pm.addNestedPass<FuncOp>(createInstrumentONNXPass());
// Lower all ONNX and ZHigh ops.
std::string optStr = getCompilerOption(OptionKind::CompilerOptLevel);
OptLevel optLevel = OptLevel::O0;
Expand All @@ -151,16 +158,18 @@ void addPassesDLC(mlir::OwningOpRef<ModuleOp> &module, mlir::PassManager &pm,
addZHighToZLowPasses(pm, optLevel); // Constant folding for std.alloc.
pm.addNestedPass<FuncOp>(mlir::createFoldStdAllocPass());

if (dlcEmissionTarget >= EmitZLowIR)
emissionTarget = onnx_mlir::EmitMLIR;
if (dlcEmissionTarget >= DLCEmissionTargetType::EmitZLowIR)
emissionTarget = EmitMLIR;
else {
// Partially lower Krnl ops to Affine dialect.
addKrnlToAffinePasses(pm);
}
}
}

if (emissionTarget >= onnx_mlir::EmitLLVMIR)
if (emissionTarget >= EmitLLVMIR)
// Lower the remaining Krnl and all ZLow ops to LLVM dialect.
addAllToLLVMPasses(pm);
}

} // namespace onnx_mlir
12 changes: 8 additions & 4 deletions src/Accelerators/NNPA/Compiler/DLCompilerUtils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,24 @@
* SPDX-License-Identifier: Apache-2.0
*/

//===-------------------------- DLCompilerUtils.hpp
//------------------------===//
//===-------------------------- DLCompilerUtils.hpp -----------------------===//
//
// Copyright 2022 The IBM Research Authors.
//
// =============================================================================
//
//
//===----------------------------------------------------------------------===//

#pragma once

#include "mlir/Pass/PassManager.h"
#include "src/Compiler/CompilerUtils.hpp"
#include "src/Support/OMOptions.hpp"

enum DLCEmissionTargetType {
namespace onnx_mlir {

enum class DLCEmissionTargetType {
EmitZNONE,
EmitZLowIR,
EmitZHighIR,
Expand All @@ -39,5 +41,7 @@ void addPassesDLC(mlir::OwningOpRef<mlir::ModuleOp> &module,
int compileModuleDLC(mlir::OwningOpRef<mlir::ModuleOp> &module,
mlir::MLIRContext &context, std::string outputBaseName,
onnx_mlir::EmissionTargetType emissionTarget,
DLCEmissionTargetType dlcEmissionTarget = EmitZNONE,
DLCEmissionTargetType dlcEmissionTarget = DLCEmissionTargetType::EmitZNONE,
mlir::ArrayRef<std::string> execNodesOnCpu = mlir::ArrayRef<std::string>());

} // namespace onnx_mlir
21 changes: 13 additions & 8 deletions src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

//====------ ONNXToZHigh.cpp - ONNX dialect to ZHigh lowering -------------===//
//
// Copyright 2019-2020 The IBM Research Authors.
// Copyright 2019-2022 The IBM Research Authors.
//
// =============================================================================
//
Expand All @@ -26,6 +26,8 @@ using namespace mlir;
// LSTM/GRU specific functions
//

namespace onnx_mlir {

ArrayAttr getLSTMGRUBiasSplitShape(
Location loc, PatternRewriter &rewriter, ArrayRef<int64_t> shapeR) {
int64_t hiddenSize = shapeR[2];
Expand Down Expand Up @@ -73,7 +75,7 @@ Value getLSTMGRUZDNNWeightFromONNXWeight(
Value o_gate = splitOp.getResults()[1];
Value f_gate = splitOp.getResults()[2];
Value c_gate = splitOp.getResults()[3];
stickForOp = rewriter.create<ZHighStickForLSTMOp>(
stickForOp = rewriter.create<zhigh::ZHighStickForLSTMOp>(
loc, f_gate, i_gate, c_gate, o_gate);
} else { // GRU
SmallVector<Type, 3> splitTypes(splitNum, splitType);
Expand All @@ -83,7 +85,7 @@ Value getLSTMGRUZDNNWeightFromONNXWeight(
Value r_gate = splitOp.getResults()[1];
Value h_gate = splitOp.getResults()[2];
stickForOp =
rewriter.create<ZHighStickForGRUOp>(loc, z_gate, r_gate, h_gate);
rewriter.create<zhigh::ZHighStickForGRUOp>(loc, z_gate, r_gate, h_gate);
}
return stickForOp;
}
Expand Down Expand Up @@ -153,8 +155,9 @@ Value getLSTMGRUGetYc(
Value noneValue;
if (isNoneType(resYc))
return noneValue;
ZHighUnstickOp unstickOp =
rewriter.create<ZHighUnstickOp>(loc, val.getType(), val);

auto unstickOp =
rewriter.create<zhigh::ZHighUnstickOp>(loc, val.getType(), val);
return rewriter.create<ONNXSqueezeV11Op>(
loc, resYc.getType(), unstickOp.getResult(), rewriter.getI64ArrayAttr(0));
}
Expand Down Expand Up @@ -270,7 +273,7 @@ void ONNXToZHighLoweringPass::runOnOperation() {

// We define the specific operations, or dialects, that are legal targets for
// this lowering.
target.addLegalDialect<ONNXDialect, ZHighDialect, KrnlOpsDialect,
target.addLegalDialect<ONNXDialect, zhigh::ZHighDialect, KrnlOpsDialect,
StandardOpsDialect, arith::ArithmeticDialect>();

// Combined ONNX ops to ZHigh lowering.
Expand Down Expand Up @@ -330,11 +333,13 @@ void ONNXToZHighLoweringPass::runOnOperation() {
signalPassFailure();
}

std::unique_ptr<Pass> mlir::createONNXToZHighPass() {
std::unique_ptr<Pass> createONNXToZHighPass() {
return std::make_unique<ONNXToZHighLoweringPass>();
}

std::unique_ptr<Pass> mlir::createONNXToZHighPass(
std::unique_ptr<Pass> createONNXToZHighPass(
mlir::ArrayRef<std::string> execNodesOnCpu) {
return std::make_unique<ONNXToZHighLoweringPass>(execNodesOnCpu);
}

} // namespace onnx_mlir
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,16 @@
//
//===----------------------------------------------------------------------===//

#include "Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp"
#include "Dialect/ZHigh/ZHighOps.hpp"
#include "Pass/DLCPasses.hpp"

#include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp"
#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.hpp"
#include "src/Accelerators/NNPA/Pass/DLCPasses.hpp"
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

using namespace mlir;

namespace onnx_mlir {

/// Calculate sqrt(var + epsilon) for batchnorm op A.
/// A = scale / sqrt(var + epsilon)
Value getSqrtResultBatchNormA(
Expand All @@ -58,7 +59,6 @@ Value getSqrtResultBatchNormA(
// Rewrite ONNX ops to ZHigh ops and ONNX ops for ZHigh.
//===----------------------------------------------------------------------===//

namespace {
/// Include the patterns defined in the Declarative Rewrite framework.
#include "Conversion/ONNXToZHigh/RewriteONNXForZHigh.inc"

Expand All @@ -72,16 +72,13 @@ struct RewriteONNXForZHighPass
}

RewriteONNXForZHighPass() = default;
RewriteONNXForZHighPass(const RewriteONNXForZHighPass &pass) {}
RewriteONNXForZHighPass(mlir::ArrayRef<std::string> execNodesOnCpu) {
this->execNodesOnCpu = execNodesOnCpu;
}
RewriteONNXForZHighPass(mlir::ArrayRef<std::string> execNodesOnCpu)
: execNodesOnCpu(execNodesOnCpu) {}
void runOnOperation() final;

public:
mlir::ArrayRef<std::string> execNodesOnCpu = mlir::ArrayRef<std::string>();
};
} // end anonymous namespace.

void RewriteONNXForZHighPass::runOnOperation() {
ModuleOp module = getOperation();
Expand All @@ -92,7 +89,8 @@ void RewriteONNXForZHighPass::runOnOperation() {

// We define the specific operations, or dialects, that are legal targets for
// this lowering.
target.addLegalDialect<ONNXDialect, ZHighDialect, StandardOpsDialect>();
target
.addLegalDialect<ONNXDialect, zhigh::ZHighDialect, StandardOpsDialect>();

// Single ONNX to ZHigh operation lowering.
RewritePatternSet patterns(&getContext());
Expand All @@ -111,11 +109,13 @@ void RewriteONNXForZHighPass::runOnOperation() {
signalPassFailure();
}

std::unique_ptr<Pass> mlir::createRewriteONNXForZHighPass() {
std::unique_ptr<Pass> createRewriteONNXForZHighPass() {
return std::make_unique<RewriteONNXForZHighPass>();
}

std::unique_ptr<Pass> mlir::createRewriteONNXForZHighPass(
std::unique_ptr<Pass> createRewriteONNXForZHighPass(
mlir::ArrayRef<std::string> execNodesOnCpu) {
return std::make_unique<RewriteONNXForZHighPass>(execNodesOnCpu);
}

} // namespace onnx_mlir
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

#ifndef OP_BASE
include "src/Dialect/ONNX/ONNX.td"
include "Dialect/ZHigh/ZHighOps.td"
include "Conversion/ONNXToZHigh/ONNXToZHighCommon.td"
include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.td"
include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.td"
#endif // OP_BASE

/// Note: The DRR definition used for defining patterns is shown below:
Expand Down
Loading

0 comments on commit 09986d9

Please sign in to comment.