[MLIR][TORCH] Add decomposition for aten.var.correction op
This commit adds the decomposition for `aten.var.correction` op.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
vivekkhandelwal1 committed Jul 28, 2022
1 parent 4403cfe commit 107f7af
Showing 8 changed files with 510 additions and 128 deletions.
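The decomposition rewrites `aten.var.correction` in terms of existing ops: a mean over the reduction dims (with keepdim), a subtract, a square, a sum, and a scalar divide by the reduced element count minus the correction. In plain PyTorch terms this is roughly the following sketch (illustrative only, not code from this patch):

import torch

x = torch.randn(3, 4, 5)
dims, correction = [1, 2], 1

n = 1
for d in dims:
    n *= x.shape[d]                        # product of the reduced dimension sizes
mean = x.mean(dim=dims, keepdim=True)      # keepdim so the mean broadcasts against x
var = ((x - mean) ** 2).sum(dim=dims) / (n - correction)
print(var.shape)                           # torch.Size([3])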
26 changes: 26 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -4047,6 +4047,32 @@ def Torch_AtenVarDimOp : Torch_Op<"aten.var.dim", [
}];
}

def Torch_AtenVarCorrectionOp : Torch_Op<"aten.var.correction", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::var.correction : (Tensor, int[]?, int?, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchOptionalListOfTorchIntType:$dim,
AnyTorchOptionalIntType:$correction,
Torch_BoolType:$keepdim
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenVarCorrectionOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenVarCorrectionOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}

def Torch_AtenNllLossForwardOp : Torch_Op<"aten.nll_loss_forward", [
AllowsTypeRefinement,
HasValueSemantics,
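The registered schema is `aten::var.correction : (Tensor, int[]?, int?, bool) -> (Tensor)`: an optional dim list, an optional integer correction, and a keepdim flag. A small example of the overload from the PyTorch side (a sketch; it assumes a PyTorch build that exposes the overload through `torch.ops` and accepts list-valued dims):

import torch

x = torch.randn(2, 3, 4)

# Direct call to the aten::var.correction overload: (Tensor, int[]?, int?, bool) -> Tensor.
# correction and keepdim are keyword-only in the schema.
v = torch.ops.aten.var.correction(x, [1, 2], correction=1, keepdim=False)

# correction=1 is Bessel's correction, the same as unbiased=True in the older API.
assert torch.allclose(v, x.var(dim=[1, 2], unbiased=True))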
228 changes: 154 additions & 74 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -229,8 +229,7 @@ class DecomposeAtenSelectIntOp : public OpRewritePattern<AtenSelectIntOp> {
} // namespace

namespace {
class DecomposeAtenZeroOp
: public OpRewritePattern<AtenZeroOp> {
class DecomposeAtenZeroOp : public OpRewritePattern<AtenZeroOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenZeroOp op,
@@ -705,7 +704,8 @@ class DecomposeAtenTOp : public OpRewritePattern<AtenTOp> {
// unsqueezed_sizes += [1, s]
// expanded_sizes += [m, s]
// reshaped_sizes += [m * s]
// return self.view(unsqueezed_sizes).expand(expanded_sizes).view(reshaped_sizes)
// return
// self.view(unsqueezed_sizes).expand(expanded_sizes).view(reshaped_sizes)
//
namespace {
class DecomposeAtenRepeatOp : public OpRewritePattern<AtenRepeatOp> {
@@ -754,7 +754,8 @@ class DecomposeAtenRepeatOp : public OpRewritePattern<AtenRepeatOp> {
assert(leadingRank >= 0 && "leadingRank should greater than 0");
for (size_t i = 0; i < leadingRank; ++i) {
insertDimSizes(unsqueezedSizes, unsqueezedIntSizes, ArrayRef<Value>{one});
insertDimSizes(expandedSizes, expandedIntSizes, ArrayRef<Value>{repeats[i]});
insertDimSizes(expandedSizes, expandedIntSizes,
ArrayRef<Value>{repeats[i]});
reshapedSizes.push_back(repeats[i]);
}

@@ -772,18 +773,20 @@ class DecomposeAtenRepeatOp : public OpRewritePattern<AtenRepeatOp> {
loc, rewriter.getI64IntegerAttr(selfShape[i]));
}

insertDimSizes(unsqueezedSizes, unsqueezedIntSizes, ArrayRef<Value>{one, dimSize});
insertDimSizes(expandedSizes, expandedIntSizes, ArrayRef<Value>{scale, dimSize});
insertDimSizes(unsqueezedSizes, unsqueezedIntSizes,
ArrayRef<Value>{one, dimSize});
insertDimSizes(expandedSizes, expandedIntSizes,
ArrayRef<Value>{scale, dimSize});

Value scaledSize = rewriter.create<AtenMulIntOp>(loc, dimSize, scale);
reshapedSizes.push_back(scaledSize);
}

Type dtype = self.getType().cast<ValueTensorType>().getDtype();
Type unsqueezedType =
ValueTensorType::get(context, llvm::makeArrayRef(unsqueezedIntSizes), dtype);
Type expandedType =
ValueTensorType::get(context, llvm::makeArrayRef(expandedIntSizes), dtype);
Type unsqueezedType = ValueTensorType::get(
context, llvm::makeArrayRef(unsqueezedIntSizes), dtype);
Type expandedType = ValueTensorType::get(
context, llvm::makeArrayRef(expandedIntSizes), dtype);

auto listType = Torch::ListType::get(Torch::IntType::get(op.getContext()));
Value unsqueezedDims =
@@ -792,8 +795,8 @@ class DecomposeAtenRepeatOp : public OpRewritePattern<AtenRepeatOp> {
rewriter.create<PrimListConstructOp>(loc, listType, expandedSizes);
Value reshapedDims =
rewriter.create<PrimListConstructOp>(loc, listType, reshapedSizes);
auto reshaped =
rewriter.create<AtenViewOp>(loc, unsqueezedType, op.self(), unsqueezedDims);
auto reshaped = rewriter.create<AtenViewOp>(loc, unsqueezedType, op.self(),
unsqueezedDims);
auto expanded = rewriter.create<AtenBroadcastToOp>(loc, expandedType,
reshaped, expandedDims);

@@ -1184,23 +1187,23 @@ class DecomposeAtenSoftplusOp : public OpRewritePattern<AtenSoftplusOp> {
Value input = op.self();
BaseTensorType inputType = input.getType().cast<BaseTensorType>();

Value inputTimesBeta = rewriter.create<AtenMulScalarOp>(
loc, inputType, input, op.beta());
Value inputTimesBeta =
rewriter.create<AtenMulScalarOp>(loc, inputType, input, op.beta());

// out = log1p(exp(input * beta)) / beta
Value exp = rewriter.create<AtenExpOp>(loc, inputType, inputTimesBeta);
Value log1p = rewriter.create<AtenLog1pOp>(loc, inputType, exp);
Value out = rewriter.create<AtenDivScalarOp>(
loc, inputType, log1p, op.beta());
Value out =
rewriter.create<AtenDivScalarOp>(loc, inputType, log1p, op.beta());

// Select where x * beta > threshold
auto boolResType = inputType.getWithSizesAndDtype(inputType.getSizes(),
rewriter.getI1Type());
Value condition = rewriter.create<AtenGtScalarOp>(
loc, boolResType, inputTimesBeta, op.threshold());

rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(
op, op.getType(), condition, input, out);
rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, op.getType(), condition,
input, out);
return success();
}
};
@@ -2138,6 +2141,107 @@ class DecomposeAtenNumpyTOp : public OpRewritePattern<AtenNumpyTOp> {
};
} // namespace

template <typename OpTy>
static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
bool unbiased, int64_t correction) {
Location loc = op.getLoc();
Value self = op.self();
Value dimList = op.dim();
Value keepDim = op.keepdim();
BaseTensorType inputTensorTy = self.getType().cast<BaseTensorType>();
Type outputType = op.getType();
BaseTensorType outputTensorType = outputType.cast<BaseTensorType>();
Type newOutputType = outputTensorType.getWithSizesAndDtype(
outputTensorType.getSizes(), rewriter.getF64Type());
if (!inputTensorTy.hasDtype() ||
!inputTensorTy.getDtype().isa<mlir::FloatType>()) {
return rewriter.notifyMatchFailure(
op, "support floating-point type input only");
}

// Upcasting the input tensor to `F64` dtype for higher precision during the
// computation of the result.
if (inputTensorTy.getDtype().getIntOrFloatBitWidth() != 64) {
self = convertTensorToDtype(rewriter, loc, self, rewriter.getF64Type());
inputTensorTy = self.getType().cast<BaseTensorType>();
}

unsigned inputRank = getTensorRank(self);
SmallVector<Value> dimListElements;
bool isNoneOrEmpty = true;
if (!dimList.getType().template isa<Torch::NoneType>()) {
if (!getListConstructElements(dimList, dimListElements))
return rewriter.notifyMatchFailure(
op, "expect dimList to be constructed from list construct");
if (dimListElements.size() != 0 || inputRank == 0)
isNoneOrEmpty = false;
}
if (isNoneOrEmpty) {
for (unsigned i = 0; i < inputRank; i++)
dimListElements.push_back(rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i)));
dimList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
dimListElements);
}
Type meanDimResultType = inputTensorTy;
for (unsigned i = 0; i < dimListElements.size(); i++)
meanDimResultType = computeReductionType(
rewriter, op, meanDimResultType.cast<BaseTensorType>(),
dimListElements[i],
/*keepDim=*/true);

Value constantNone = rewriter.create<ConstantNoneOp>(loc);
Value constantTrue = rewriter.create<ConstantBoolOp>(loc, true);
Value meanAlongDims = rewriter.create<AtenMeanDimOp>(
loc, meanDimResultType, self, dimList, /*keepDim=*/constantTrue,
/*dtype=*/constantNone);
Value subMean =
createTensorSub(rewriter, loc, inputTensorTy, self, meanAlongDims);
Value square = rewriter.create<AtenSquareOp>(loc, inputTensorTy, subMean);

if (!unbiased) {
Value result = rewriter.create<AtenMeanDimOp>(
loc, newOutputType, square, dimList, keepDim, /*dtype=*/constantNone);
result = convertTensorToDtype(rewriter, loc, result,
outputTensorType.getDtype());
rewriter.replaceOp(op, result);
return success();
}
// Divide the square sum by productDimSize - correction.
Value squareSum = rewriter.create<AtenSumDimIntListOp>(
loc, newOutputType, square, dimList, keepDim, /*dtype=*/constantNone);

// `productDimSize` is product of sizes of dimensions to be reduced.
Value constantOne =
rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
Value productDimSize = constantOne;
for (Value dim : dimListElements) {
Value dimSize = rewriter.create<AtenSizeIntOp>(loc, self, dim);
productDimSize =
rewriter.create<AtenMulIntOp>(loc, productDimSize, dimSize);
}
Value cstCorrection = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(correction));
// The `correction` value should be less than or equal to `productDimSize +
// 1`.
Value productDimSizePlusOne =
rewriter.create<AtenAddIntOp>(loc, productDimSize, constantOne);
Value cond =
rewriter.create<AtenGeIntOp>(loc, productDimSizePlusOne, cstCorrection);
rewriter.create<RuntimeAssertOp>(
loc, cond,
"correction value should be less than or equal to productDimSize + 1");
Value productDimSizeSubCorrection =
rewriter.create<AtenSubIntOp>(loc, productDimSize, cstCorrection);
Value result = rewriter.create<AtenDivScalarOp>(loc, newOutputType, squareSum,
productDimSizeSubCorrection);
result =
convertTensorToDtype(rewriter, loc, result, outputTensorType.getDtype());
rewriter.replaceOp(op, result);
return success();
}

// Decompose aten.var(x, dims) into:
// sub = aten.sub(x, aten.mean(x, dims))
// square = aten.square(sub)
@@ -2151,70 +2255,44 @@ class DecomposeAtenVarDimOp : public OpRewritePattern<AtenVarDimOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenVarDimOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value self = op.self();
Value dimList = op.dim();
Value keepDim = op.keepdim();
Type outputType = op.getType();
BaseTensorType inputTensorTy = self.getType().cast<BaseTensorType>();
if (!inputTensorTy.hasDtype() ||
!inputTensorTy.getDtype().isa<mlir::FloatType>()) {
return rewriter.notifyMatchFailure(op,
"support floating type input only");
}

auto dimListConstruct = dimList.getDefiningOp<PrimListConstructOp>();
if (!dimListConstruct) {
return rewriter.notifyMatchFailure(
op, "expect dimList to be constructed from list construct");
}

bool unbiased;
if (!matchPattern(op.unbiased(), m_TorchConstantBool(&unbiased))) {
return rewriter.notifyMatchFailure(
op, "Only support constant unbiased for aten.var");
}
int64_t correction = unbiased ? 1 : 0;
if (failed(calculateVariance<AtenVarDimOp>(op, rewriter, unbiased,
correction)))
return rewriter.notifyMatchFailure(op, "invalid variance parameters");
return success();
}
};
} // namespace

SmallVector<Value> dimListElements = dimListConstruct.elements();
Type meanDimResultType = inputTensorTy;
for (unsigned i = 0; i < dimListElements.size(); i++)
meanDimResultType = computeReductionType(
rewriter, op, meanDimResultType.cast<BaseTensorType>(),
dimListElements[i],
/*keepDim=*/true);

Value constantNone = rewriter.create<ConstantNoneOp>(loc);
Value constantTrue = rewriter.create<ConstantBoolOp>(loc, true);
Value meanAlongDims = rewriter.create<AtenMeanDimOp>(
loc, meanDimResultType, self, dimList, /*keepDim=*/constantTrue,
/*dtype=*/constantNone);
Value subMean =
createTensorSub(rewriter, loc, inputTensorTy, self, meanAlongDims);
Value square = rewriter.create<AtenSquareOp>(loc, inputTensorTy, subMean);
if (unbiased) {
// Bessel’s correction is used. Divide the square sum by
// productDimSize-1.
Value squareSum = rewriter.create<AtenSumDimIntListOp>(
loc, outputType, square, dimList, keepDim, /*dtype=*/constantNone);

// `productDimSize` is product of sizes of dimensions to be reduced.
Value productDimSize = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
for (Value dim : dimListConstruct.elements()) {
Value dimSize = rewriter.create<AtenSizeIntOp>(loc, self, dim);
productDimSize =
rewriter.create<AtenMulIntOp>(loc, productDimSize, dimSize);
}
Value constantOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
Value productDimSizeSubOne =
rewriter.create<AtenSubIntOp>(loc, productDimSize, constantOne);
rewriter.replaceOpWithNewOp<AtenDivScalarOp>(op, outputType, squareSum,
productDimSizeSubOne);
// Decompose aten.var(x, dims) into:
// sub = aten.sub(x, aten.mean(x, dims))
// square = aten.square(sub)
// out = aten.sum(square, dims) / (productDimSize - correction)
namespace {
class DecomposeAtenVarCorrectionOp
: public OpRewritePattern<AtenVarCorrectionOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenVarCorrectionOp op,
PatternRewriter &rewriter) const override {
int64_t correction;
if (!op.correction().getType().isa<Torch::NoneType>()) {
if (!matchPattern(op.correction(), m_TorchConstantInt(&correction)))
return rewriter.notifyMatchFailure(
op, "Only support constant int correction for aten.var");
} else {
rewriter.replaceOpWithNewOp<AtenMeanDimOp>(
op, outputType, square, dimList, keepDim, /*dtype=*/constantNone);
// The default value in case of `correction` being None is 1.
correction = 1;
}
bool unbiased = correction == 0 ? false : true;
if (failed(calculateVariance<AtenVarCorrectionOp>(op, rewriter, unbiased,
correction)))
return rewriter.notifyMatchFailure(op, "invalid variance parameters");
return success();
}
};
@@ -2426,6 +2504,8 @@ class DecomposeComplexOpsPass
target.addIllegalOp<AtenSelectScatterOp>();
patterns.add<DecomposeAtenVarDimOp>(context);
target.addIllegalOp<AtenVarDimOp>();
patterns.add<DecomposeAtenVarCorrectionOp>(context);
target.addIllegalOp<AtenVarCorrectionOp>();

if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
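The shared `calculateVariance` helper introduced above upcasts the input to f64, takes the mean over the reduction dims with keepdim, squares the deviations, and then either takes their mean (unbiased == false) or divides their sum by productDimSize - correction, guarded by a runtime assert that the correction does not exceed productDimSize + 1. A Python sketch mirroring the same control flow (my own reference code against plain PyTorch, not part of this patch):

import torch

def calculate_variance_reference(x, dims, unbiased, correction, keepdim=False):
    x64 = x.to(torch.float64)                   # upcast for the intermediate math
    mean = x64.mean(dim=dims, keepdim=True)
    square = (x64 - mean) ** 2
    if not unbiased:
        out = square.mean(dim=dims, keepdim=keepdim)
    else:
        n = 1
        for d in dims:
            n *= x.shape[d]                     # productDimSize
        # The pass emits a runtime assert for this same condition.
        assert correction <= n + 1, "correction value should be less than or equal to productDimSize + 1"
        out = square.sum(dim=dims, keepdim=keepdim) / (n - correction)
    return out.to(x.dtype)                      # cast back to the original dtype

x = torch.randn(4, 5)
print(torch.allclose(calculate_variance_reference(x, [1], True, 1), x.var(dim=1)))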
2 changes: 1 addition & 1 deletion lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -931,7 +931,7 @@ void TypeAnalysis::visitOperation(Operation *op,
Type dtype = operands[0]->getValue().dtype;
visitReductionAlongAllDimsOp(max, dtype, operands);
return;
} else if (isa<AtenStdOp, AtenVarOp, AtenVarDimOp>(op)) {
} else if (isa<AtenStdOp, AtenVarOp, AtenVarDimOp, AtenVarCorrectionOp>(op)) {
auto input = operands[0]->getValue();
visitReductionAlongAllDimsOp(op, input.dtype, operands);
return;
30 changes: 30 additions & 0 deletions lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
@@ -5570,6 +5570,36 @@ module {
%1 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
return %1 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.var.correction"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<int>, %arg3: !torch.bool) -> !torch.list<int> {
%true = torch.constant.bool true
%none = torch.constant.none
%int0 = torch.constant.int 0
%0 = torch.aten.__is__ %arg1, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
%1 = torch.prim.If %0 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
%6 = torch.aten.len.t %5 : !torch.list<int> -> !torch.int
%7 = torch.aten.eq.int %6, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %7 : !torch.bool
}
%2 = torch.prim.If %1 -> (!torch.list<int>) {
%5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%6 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %5, %true, init() {
^bb0(%arg4: !torch.int):
%7 = torch.aten.append.t %6, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.prim.If.yield %6 : !torch.list<int>
} else {
%5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
torch.prim.If.yield %5 : !torch.list<int>
}
%3 = torch.derefine %none : !torch.none to !torch.any
%4 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg3, %3) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
return %4 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.std"(%arg0: !torch.list<int>, %arg1: !torch.bool) -> !torch.list<int> {
%0 = torch.prim.ListConstruct : () -> !torch.list<int>
return %0 : !torch.list<int>
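The generated shape function for `aten.var.correction` above simply resolves the optional dim list (None or an empty list means reduce over every dimension) and then defers to the mean_dim shape logic. A rough Python reconstruction of that behavior (helper names are approximations, not copied from this patch):

from typing import List, Optional

def var_correction_shape(self: List[int], dim: Optional[List[int]],
                         correction: Optional[int], keepdim: bool) -> List[int]:
    if dim is None or len(dim) == 0:
        dim = list(range(len(self)))            # reduce over every dimension
    return mean_dim_shape(self, dim, keepdim)   # same output shape as the matching mean

def mean_dim_shape(self: List[int], dim: List[int], keepdim: bool) -> List[int]:
    out: List[int] = []
    for i, size in enumerate(self):
        if i in dim or (i - len(self)) in dim:  # accept negative dims as well
            if keepdim:
                out.append(1)
        else:
            out.append(size)
    return out

print(var_correction_shape([3, 4, 5], None, 1, False))   # []
print(var_correction_shape([3, 4, 5], [1, 2], 1, True))  # [3, 1, 1]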