[MHLO] Support for dynamic shape in basic op conversion by introducing CHLO dialect

Co-authored-by: Bairen Yi <yibairen.byron@bytedance.com>
Co-authored-by: Jiawei Wu <xremold@gmail.com>
Co-authored-by: Tianyou Guo <tianyou.gty@alibaba-inc.com>
Co-authored-by: Xu Yan <yancey.yx@alibaba-inc.com>
Co-authored-by: Ziheng Jiang <ziheng.jiang@bytedance.com>
wujiawei.jw committed Jul 31, 2022
1 parent 2c3b360 commit b8b7d22
Showing 8 changed files with 751 additions and 846 deletions.
469 changes: 187 additions & 282 deletions lib/Conversion/TorchToMhlo/BasicOp.cpp

Large diffs are not rendered by default.
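
Because the BasicOp.cpp diff is not rendered here, the following is only a hedged sketch of the direction this commit takes: lowering a Torch binary elementwise op through a CHLO op whose implicit broadcast semantics are resolved at runtime. The pattern class, the aten.mul.Tensor choice, and the adaptor accessor names are illustrative assumptions, not a verbatim excerpt of the patch.

// Hedged sketch only -- op choice, accessors, and helper calls are assumptions,
// not a verbatim excerpt of BasicOp.cpp.
namespace {
class ConvertAtenMulTensorOp : public OpConversionPattern<AtenMulTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenMulTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value lhs = adaptor.self();
    Value rhs = adaptor.other();
    auto outType =
        getTypeConverter()->convertType(op.getType()).cast<TensorType>();
    // Only the element types are promoted statically; the shapes stay dynamic.
    lhs = mhlo::promoteType(rewriter, lhs, outType);
    rhs = mhlo::promoteType(rewriter, rhs, outType);
    // chlo.broadcast_mul carries implicit broadcast semantics, so the actual
    // broadcast is materialized later by chlo-legalize-to-hlo, which can
    // handle dynamic dimensions.
    rewriter.replaceOpWithNewOp<chlo::BroadcastMulOp>(
        op, outType, lhs, rhs, /*broadcast_dimensions=*/nullptr);
    return success();
  }
};
} // namespace

This is also why the conversion target below marks the CHLO dialect legal: the TorchToMhlo patterns may emit CHLO ops and leave their legalization to a later pass in the pipeline.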

2 changes: 2 additions & 0 deletions lib/Conversion/TorchToMhlo/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_conversion_library(TorchMLIRTorchToMhlo

DEPENDS
MhloDialect
ChloDialect
TorchMLIRConversionPassIncGen

LINK_COMPONENTS
@@ -19,6 +20,7 @@ add_mlir_conversion_library(TorchMLIRTorchToMhlo
MLIRIR
MLIRPass
MhloDialect
ChloDialect
TorchMLIRTorchDialect
)

86 changes: 53 additions & 33 deletions lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp
@@ -7,11 +7,11 @@
//
//===----------------------------------------------------------------------===//

#include "./MhloLegalizeUtils.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#include "./MhloLegalizeUtils.h"

using namespace mlir;
using namespace mlir::torch;
@@ -114,9 +114,9 @@ llvm::Optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
}

template <>
llvm::Optional<Value> getConstTensor<double>(PatternRewriter &rewriter,
Operation *op, ArrayRef<double> vec,
ArrayRef<int64_t> shape) {
llvm::Optional<Value>
getConstTensor<double>(PatternRewriter &rewriter, Operation *op,
ArrayRef<double> vec, ArrayRef<int64_t> shape) {
uint64_t num_total_elements = 1;
for (int64_t a : shape) {
num_total_elements *= a;
@@ -146,7 +146,6 @@ template llvm::Optional<Value> getConstTensor<int64_t>(PatternRewriter &,
ArrayRef<int64_t> vec,
ArrayRef<int64_t> shape);


template <typename T>
static bool isInValidRange(bool isFloat, const double &doubleValue, bool isInt,
const int64_t &intValue) {
@@ -163,20 +162,15 @@ static bool isInValidRange(bool isFloat, const double &doubleValue, bool isInt,
}

template <typename T>
Value getSplatConstTensor(ConversionPatternRewriter &rewriter,
Operation *op,
T val,
Type dtype,
llvm::ArrayRef<int64_t> dshape) {
auto const_type = RankedTensorType::get(
dshape, dtype);
Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op,
T val, Type dtype, llvm::ArrayRef<int64_t> dshape) {
auto const_type = RankedTensorType::get(dshape, dtype);
auto const_attr = SplatElementsAttr::get(const_type, val);
auto const_op =
rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
return const_op.getResult();
}


LogicalResult torchScalarToMhloTensor(ConversionPatternRewriter &rewriter,
Operation *op, Value torchScalarValue,
Value &mhloTensor, Type dtype,
@@ -195,9 +189,8 @@ LogicalResult torchScalarToMhloTensor(ConversionPatternRewriter &rewriter,

if (dtype.isa<mlir::FloatType>()) {
if (doBroadcast) {
mhloTensor = getSplatConstTensor<float>(rewriter, op,
(isFloat ? doubleValue : intValue),
dtype, dshape);
mhloTensor = getSplatConstTensor<float>(
rewriter, op, (isFloat ? doubleValue : intValue), dtype, dshape);
} else {
mhloTensor = mhlo::getConstTensor<float>(
rewriter, op, (isFloat ? doubleValue : intValue), dshape)
@@ -216,7 +209,8 @@ LogicalResult torchScalarToMhloTensor(ConversionPatternRewriter &rewriter,
int32_t d = isFloat ? static_cast<int32_t>(doubleValue)
: static_cast<int32_t>(intValue);
if (doBroadcast) {
mhloTensor = getSplatConstTensor<int32_t>(rewriter, op, d, dtype, dshape);
mhloTensor =
getSplatConstTensor<int32_t>(rewriter, op, d, dtype, dshape);
} else {
mhloTensor =
mhlo::getConstTensor<int32_t>(rewriter, op, {d}, dshape).getValue();
@@ -228,7 +222,8 @@ LogicalResult torchScalarToMhloTensor(ConversionPatternRewriter &rewriter,
}
int64_t d = (isFloat ? static_cast<int64_t>(doubleValue) : intValue);
if (doBroadcast) {
mhloTensor = getSplatConstTensor<int64_t>(rewriter, op, d, dtype, dshape);
mhloTensor =
getSplatConstTensor<int64_t>(rewriter, op, d, dtype, dshape);
} else {
mhloTensor =
mhlo::getConstTensor<int64_t>(rewriter, op, {d}, dshape).getValue();
@@ -240,7 +235,6 @@ LogicalResult torchScalarToMhloTensor(ConversionPatternRewriter &rewriter,
return success();
}


LogicalResult torchAlphaToMhloTensor(ConversionPatternRewriter &rewriter,
Operation *op, Value alphaScalar,
Value &alphaTensor, Type dtype,
@@ -265,20 +259,33 @@ LogicalResult torchAlphaToMhloTensor(ConversionPatternRewriter &rewriter,
return success();
}

Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType) {
Operation *op = input.getDefiningOp();
TensorType in_type = input.getType().dyn_cast<TensorType>();

Value promoteAndBroadcast(ConversionPatternRewriter &rewriter,
Value input, TensorType outType) {
if (in_type.getElementType() != outType.getElementType()) {
TensorType promotedType =
in_type.cloneWith(in_type.getShape(), outType.getElementType());
return rewriter.create<mhlo::ConvertOp>(op->getLoc(), promotedType, input);
}
return input;
}

Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
TensorType outType) {
// Two tensors are “broadcastable” if the following rules hold:
// - Each tensor has at least one dimension.
// - When iterating over the dimension sizes, starting at the trailing dimension,
// the dimension sizes must either be equal, one of them is 1, or one of them
// does not exist.
Operation* op = input.getDefiningOp();
// - When iterating over the dimension sizes, starting at the trailing
// dimension, the dimension sizes must either be equal, one of them is 1, or
// one of them does not exist.
Operation *op = input.getDefiningOp();
TensorType in_type = input.getType().dyn_cast<TensorType>();

if (in_type.getElementType() != outType.getElementType()) {
TensorType promoted_type = in_type.cloneWith(in_type.getShape(), outType.getElementType());
input = rewriter.create<mhlo::ConvertOp>(op->getLoc(), promoted_type, input);
TensorType promoted_type =
in_type.cloneWith(in_type.getShape(), outType.getElementType());
input =
rewriter.create<mhlo::ConvertOp>(op->getLoc(), promoted_type, input);
}

ArrayRef<int64_t> inShape = in_type.getShape();
@@ -298,7 +305,8 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter,
bcastDims.push_back(outPos);
do_bcast = true;
} else {
op->emitError("The size of tensor a (") << inDim << ")"
op->emitError("The size of tensor a (")
<< inDim << ")"
<< "must match the size of tensor b (" << outDim << ")"
<< "at non-singleton dimension " << inPos;
}
@@ -308,11 +316,23 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter,
return input;
}
DenseIntElementsAttr bcast_attr = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<long int>(bcastDims.size())}, rewriter.getI64Type()),
RankedTensorType::get({static_cast<long int>(bcastDims.size())},
rewriter.getI64Type()),
bcastDims);
auto bcast_op =
rewriter.create<mhlo::BroadcastInDimOp>(op->getLoc(), outType, input, bcast_attr);
auto bcast_op = rewriter.create<mhlo::BroadcastInDimOp>(op->getLoc(), outType,
input, bcast_attr);
return bcast_op.getResult();
}

Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
const APFloat &constant, Value shape,
TensorType outType) {
auto constAttr = rewriter.getFloatAttr(outType.getElementType(), constant);
auto constTensor = rewriter.create<mhlo::ConstantOp>(loc, constAttr);
return rewriter
.create<mhlo::DynamicBroadcastInDimOp>(loc, outType, constTensor, shape,
rewriter.getI64TensorAttr({}))
.getResult();
}
} // namespace mhlo
} // namespace mlir
} // namespace mlir
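
The utilities added above (promoteType, promoteAndBroadcast, getConstantOfShape) are meant to be called from conversion patterns. Below is a minimal usage sketch for getConstantOfShape under the assumption that the Shape dialect is available to compute a runtime shape tensor; the helper name buildSplatLike and its details are hypothetical and not part of this commit. For operands whose shapes are still static, promoteType and promoteAndBroadcast remain the simpler path.

// Hedged usage sketch; buildSplatLike is a hypothetical caller, not part of
// this commit.
static Value buildSplatLike(ConversionPatternRewriter &rewriter, Operation *op,
                            Value input, double value, TensorType outType) {
  Location loc = op->getLoc();
  // Runtime shape of `input` as a 1-D index tensor (works for dynamic dims).
  Value shapeTensor = rewriter.create<shape::ShapeOfOp>(loc, input);
  // Convert `value` to the float semantics of the result element type.
  auto elemTy = outType.getElementType().cast<FloatType>();
  APFloat splatVal(value);
  bool losesInfo;
  splatVal.convert(elemTy.getFloatSemantics(), APFloat::rmNearestTiesToEven,
                   &losesInfo);
  // Materialize the scalar and expand it to the runtime shape via
  // mhlo.dynamic_broadcast_in_dim (see getConstantOfShape above).
  return mhlo::getConstantOfShape(rewriter, loc, splatVal, shapeTensor,
                                  outType);
}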
8 changes: 7 additions & 1 deletion lib/Conversion/TorchToMhlo/MhloLegalizeUtils.h
@@ -54,9 +54,15 @@ LogicalResult torchAlphaToMhloTensor(ConversionPatternRewriter &rewriter,
llvm::ArrayRef<int64_t> dshape,
bool checkForUnity);

Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType);

Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
TensorType outType);

Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
const APFloat &constant, Value shape,
TensorType outType);
} // namespace mhlo
} // namespace mlir

#endif // TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H
#endif // TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H
6 changes: 4 additions & 2 deletions lib/Conversion/TorchToMhlo/TorchToMhlo.cpp
@@ -12,6 +12,7 @@
#include "../PassDetail.h"
#include "./PopulatePatterns.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h"
@@ -32,6 +33,7 @@ namespace {
class ConvertTorchToMhlo : public ConvertTorchToMhloBase<ConvertTorchToMhlo> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<chlo::ChloDialect>();
registry.insert<mhlo::MhloDialect>();
registry.insert<tensor::TensorDialect>();
registry.insert<arith::ArithmeticDialect>();
@@ -40,7 +42,7 @@ class ConvertTorchToMhlo : public ConvertTorchToMhloBase<ConvertTorchToMhlo> {
void runOnOperation() override {
MLIRContext *context = &getContext();
ConversionTarget target(*context);
target.addLegalDialect<mhlo::MhloDialect, tensor::TensorDialect,
target.addLegalDialect<chlo::ChloDialect, mhlo::MhloDialect, tensor::TensorDialect,
arith::ArithmeticDialect, Torch::TorchDialect>();

TypeConverter typeConverter;
@@ -68,4 +70,4 @@ class ConvertTorchToMhlo : public ConvertTorchToMhloBase<ConvertTorchToMhlo> {
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchToMhloPass() {
return std::make_unique<ConvertTorchToMhlo>();
}
}
30 changes: 18 additions & 12 deletions lib/Dialect/TorchConversion/Transforms/CMakeLists.txt
@@ -1,3 +1,19 @@
set(LinkedLibs MLIRIR
MLIRPass
MLIRFuncTransforms
TorchMLIRTorchConversionDialect
TorchMLIRTorchDialect
TorchMLIRTorchPasses
TorchMLIRTorchToLinalg
TorchMLIRTorchToTMTensor
TorchMLIRTorchToStd
TorchMLIRTorchToSCF
MLIRMemRefTransforms)

if(TORCH_MLIR_ENABLE_MHLO)
list(APPEND LinkedLibs ChloPasses)
endif()

add_mlir_library(TorchMLIRTorchConversionPasses
BackendTypeConversion.cpp
BackendTypeConversionPasses.cpp
@@ -17,15 +33,5 @@ add_mlir_library(TorchMLIRTorchConversionPasses
Core

LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRFuncTransforms
TorchMLIRTorchConversionDialect
TorchMLIRTorchDialect
TorchMLIRTorchPasses
TorchMLIRTorchToLinalg
TorchMLIRTorchToTMTensor
TorchMLIRTorchToStd
TorchMLIRTorchToSCF
MLIRMemRefTransforms
)
${LinkedLibs}
)
13 changes: 12 additions & 1 deletion lib/Dialect/TorchConversion/Transforms/Passes.cpp
@@ -21,6 +21,7 @@
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#ifdef TORCH_MLIR_ENABLE_MHLO
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#endif
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
@@ -145,10 +146,20 @@ void TorchConversion::createTorchBackendToMhloBackendPipeline(
// The resolution of `dim` ops tends to create identical ops. CSE them.
pm.addNestedPass<func::FuncOp>(createCSEPass());
}

// Convert CHLO ops to MHLO ops
pm.addNestedPass<func::FuncOp>(mhlo::createChloLegalizeToHloPass());
if (options.optimize) {
// Clean up any non-canonical code introduced above..
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// The resolution of `dim` ops tends to create identical ops. CSE them.
pm.addNestedPass<func::FuncOp>(createCSEPass());
}

// Finish the type conversion from `torch` types to the types of the
// MHLO backend contract.
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
pm.addNestedPass<func::FuncOp>(
TorchConversion::createFinalizingBackendTypeConversionPass());
}
#endif
#endif