Merge pull request #58 from Xilinx/enadjara.legalize_onnx_conv2d_transpose_to_xten_nn

feat: legalize onnx conv2d transpose to xtennn
ehsan-toosi authored Jul 31, 2024
2 parents 98243be + 5d8725f commit 4f0c873
Showing 3 changed files with 398 additions and 42 deletions.
24 changes: 24 additions & 0 deletions include/xten/Dialect/XTenNN/IR/XTenNNOps.td
@@ -492,4 +492,28 @@ def XtenNN_SignOp: XTenNN_Op<"sign", [Pure, TosaExtension, ElementwiseUnary, Sam
let assemblyFormat = [{ operands attr-dict `:` functional-type(operands, results) }];
}

def XtenNN_ConvTransposeOp: XTenNN_Op<"ConvTranspose",[Pure, TosaExtension]> {
let summary = "Perform ConvTranspose operation";
let description = [{
This operation is equivalent to `onnx.ConvTranspose` and computes a 2-D transposed convolution of the input.
}];

let arguments = (ins
4DTensorOf<[AnyFloat]>:$input,
4DTensorOf<[AnyFloat]>:$weights,
1DTensorOf<[AnyFloat]>:$bias,
ArrayAttr<2>:$pad,
I64DenseArrayAttr<2>:$output_padding,
I64DenseArrayAttr<2>:$stride,
I64DenseArrayAttr<2>:$dilation,
I64Attr:$group
);

let results = (outs
4DTensorOf<[AnyFloat]>:$output
);

let assemblyFormat = [{ operands attr-dict `:` functional-type(operands, results) }];
}

#endif // XTENNN_OPS
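
For illustration, here is a hypothetical instance of the new op in textual IR, following the declared `assemblyFormat` (`operands attr-dict : functional-type`). The `xten_nn` dialect prefix, tensor shapes, and attribute values are assumptions for this example, not taken from the patch or its tests.

// Illustrative only: dialect prefix, shapes, and attribute values are assumed.
%out = xten_nn.ConvTranspose %input, %weights, %bias {
  dilation = array<i64: 1, 1>,
  group = 1 : i64,
  output_padding = array<i64: 0, 0>,
  pad = [array<i64: 1, 1>, array<i64: 1, 1>],
  stride = array<i64: 2, 2>
} : (tensor<1x16x32x32xf32>, tensor<16x8x3x3xf32>, tensor<8xf32>) -> tensor<1x8x63x63xf32>

With these values the spatial output size follows the usual ONNX ConvTranspose formula, (H_in - 1) * stride - pad_begin - pad_end + dilation * (kernel - 1) + output_padding + 1 = (32 - 1) * 2 - 1 - 1 + 1 * (3 - 1) + 0 + 1 = 63.
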
131 changes: 89 additions & 42 deletions lib/Conversion/XTenNNToTorch.cpp
@@ -46,8 +46,13 @@ Type toTorchTensorTypeCast(PatternRewriter &rewriter, ShapedType ty) {
/*isSigned=*/true);
}

return Torch::ValueTensorType::get(ty.getContext(), ty.getShape(),
elementType);
if (!ty.hasRank()) {
return Torch::ValueTensorType::get(ty.getContext(), {}, elementType);
}

return Torch::ValueTensorType::get(
ty.getContext(), Torch::makeShapeTorchCompatible(ty.getShape()),
elementType);
}

Value toTorchTensorTypeCast(PatternRewriter &rewriter, Value input) {
@@ -71,26 +76,54 @@ Value toBuiltinTensorTypeCast(OpBuilder &builder, Value val, Type type) {
type, val);
}

struct Conv2dPadding {
struct Padding2d {
std::array<int64_t, 2> hPadding;
std::array<int64_t, 2> wPadding;

[[nodiscard]] bool isSymmetric() const {
return hPadding[0] == hPadding[1] && wPadding[0] == wPadding[1];
}
};

Conv2dPadding getPadding(GroupConv2dOp::Adaptor adaptor) {
auto pad = adaptor.getPad();
assert(pad.size() == 2 && "expected 2 elements by definition");
template <typename T>
static Padding2d get(T adaptor) {
auto pad = adaptor.getPad();
assert(pad.size() == 2 && "expected 2 elements by definition");

auto hPadding = cast<DenseI64ArrayAttr>(pad[0]);
auto wPadding = cast<DenseI64ArrayAttr>(pad[1]);
assert(hPadding.size() == 2 && "expected 2 elements by definition");
assert(wPadding.size() == 2 && "expected 2 elements by definition");
auto hPadding = cast<DenseI64ArrayAttr>(pad[0]);
auto wPadding = cast<DenseI64ArrayAttr>(pad[1]);
assert(hPadding.size() == 2 && "expected 2 elements by definition");
assert(wPadding.size() == 2 && "expected 2 elements by definition");

return {{hPadding[0], hPadding[1]}, {wPadding[0], wPadding[1]}};
}
return {{hPadding[0], hPadding[1]}, {wPadding[0], wPadding[1]}};
}

// Zero padding of the input value with hPadding and wPadding using
// AtenConstantPadNdOp.
Value createZeroAtenPadOp(ConversionPatternRewriter &rewriter, Location loc,
Value input) {

// Build new vtensor result type
auto ty = cast<Torch::ValueTensorType>(input.getType());
mlir::Type paddingResultTy;
std::optional<llvm::ArrayRef<int64_t>> optSizes = ty.getOptionalSizes();
if (optSizes) {
auto newSizes = ty.getSizes().vec();
newSizes[2] += hPadding[0] + hPadding[1];
newSizes[3] += wPadding[0] + wPadding[1];
paddingResultTy = Torch::ValueTensorType::get(
rewriter.getContext(), newSizes, ty.getOptionalDtype());
} else {
paddingResultTy = Torch::ValueTensorType::get(
rewriter.getContext(), ty.getOptionalSizes(), ty.getOptionalDtype());
}

auto zeroPadValue = rewriter.create<Torch::ConstantIntOp>(loc, 0);
auto pads = Torch::toTorchList(
loc, rewriter, {hPadding[0], hPadding[1], wPadding[0], wPadding[1]});
return rewriter.create<Torch::AtenConstantPadNdOp>(
loc, paddingResultTy, input, pads, zeroPadValue);
}
};

template <typename SrcOpT>
std::optional<ValueRange> oneToOneXTenNNToTorch(SrcOpT op,
@@ -113,33 +146,11 @@ std::optional<ValueRange> groupConv2dToTorch(GroupConv2dOp op, GroupConv2dOp::Ad

auto newInput = values[0];
mlir::Value conv2dPads;
Conv2dPadding structPadding = getPadding(adaptor);
auto structPadding = Padding2d::get(adaptor);
if (!structPadding.isSymmetric()) {
// Padding is not symmetric, but symmetric padding is the only mode the aten
// conv2d op supports. We circumvent this by inserting an explicit padding op.

// Build new vtensor result type
auto ty = cast<Torch::ValueTensorType>(newInput.getType());
mlir::Type paddingResultTy;
std::optional<llvm::ArrayRef<int64_t>> optSizes = ty.getOptionalSizes();
if (optSizes) {
auto newSizes = ty.getSizes().vec();
newSizes[2] += structPadding.hPadding[0] + structPadding.hPadding[1];
newSizes[3] += structPadding.wPadding[0] + structPadding.wPadding[1];
paddingResultTy = Torch::ValueTensorType::get(op->getContext(), newSizes,
ty.getOptionalDtype());
} else {
paddingResultTy = Torch::ValueTensorType::get(
op->getContext(), ty.getOptionalSizes(), ty.getOptionalDtype());
}

auto zeroPadValue = rewriter.create<Torch::ConstantIntOp>(loc, 0);
auto pads = Torch::toTorchList(
loc, rewriter,
{structPadding.hPadding[0], structPadding.hPadding[1],
structPadding.wPadding[0], structPadding.wPadding[1]});
newInput = rewriter.create<Torch::AtenConstantPadNdOp>(
loc, paddingResultTy, newInput, pads, zeroPadValue);
newInput = structPadding.createZeroAtenPadOp(rewriter, loc, newInput);

// We want zero pad for the Conv2d since we are going to apply it with a
// padding op
@@ -163,6 +174,43 @@
->getResults();
}

std::optional<ValueRange>
convTranspose2dToTorch(ConvTransposeOp op, ConvTransposeOp::Adaptor adaptor,
ArrayRef<Type> types, ValueRange values,
ConversionPatternRewriter &rewriter) {
auto loc = op->getLoc();

auto newInput = values[0];
auto newWeights = values[1];
auto newBias = values[2];

auto stride = Torch::toTorchList(loc, rewriter, adaptor.getStride().vec());
auto dilation =
Torch::toTorchList(loc, rewriter, adaptor.getDilation().vec());
auto group =
rewriter.create<Torch::ConstantIntOp>(loc, adaptor.getGroupAttr());
auto outputPadding =
Torch::toTorchList(loc, rewriter, adaptor.getOutputPadding().vec());

mlir::Value padValue;
auto structPadding = Padding2d::get(adaptor);
if (!structPadding.isSymmetric()) {
// AtenConvTranspose2dInputOp supports only symmetric padding. This can be
// handled by creating a pad operation on the input.
newInput = structPadding.createZeroAtenPadOp(rewriter, loc, newInput);
padValue = Torch::toTorchList(loc, rewriter, {0, 0});
} else {
padValue = Torch::toTorchList(
loc, rewriter, {structPadding.hPadding[0], structPadding.wPadding[0]});
}

return rewriter
.create<Torch::AtenConvTranspose2dInputOp>(
loc, types[0], newInput, newWeights, newBias, stride, padValue,
outputPadding, group, dilation)
->getResults();
}

std::optional<ValueRange> resizeToTorch(ResizeOp op, ResizeOp::Adaptor adaptor,
ArrayRef<Type> types, ValueRange values,
ConversionPatternRewriter &rewriter) {
@@ -377,11 +425,10 @@ struct ConvertXTenNNToTorch
context);
patterns.add<ApplyXTenNNToTorch<DepthToSpaceOp, depthToSpaceToTorch>>(
context);
patterns.add<ApplyXTenNNToTorch<GridSampleOp, gridSampleToTorch>>(
context);
patterns.add<ApplyXTenNNToTorch<ReflectPadOp, padReflectToTorch>>(
context);
patterns.add<ApplyXTenNNToTorch<ResizeOp, resizeToTorch>>(
patterns.add<ApplyXTenNNToTorch<GridSampleOp, gridSampleToTorch>>(context);
patterns.add<ApplyXTenNNToTorch<ReflectPadOp, padReflectToTorch>>(context);
patterns.add<ApplyXTenNNToTorch<ResizeOp, resizeToTorch>>(context);
patterns.add<ApplyXTenNNToTorch<ConvTransposeOp, convTranspose2dToTorch>>(
context);
if (failed(applyPartialConversion(funcOp, target, std::move(patterns))))
signalPassFailure();
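
For a rough sense of what `convTranspose2dToTorch` produces in the symmetric-padding case, a sketch in torch-mlir textual IR follows. The value names, shapes, and the exact printed syntax of the Torch dialect ops are assumptions for illustration, not output captured from this pass.

// Illustrative sketch only: names, shapes, and printed syntax are assumed.
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%stride = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
%padding = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%out_pad = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
%dilation = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%result = torch.aten.conv_transpose2d.input %input, %weights, %bias, %stride, %padding, %out_pad, %int1, %dilation
    : !torch.vtensor<[1,16,32,32],f32>, !torch.vtensor<[16,8,3,3],f32>, !torch.vtensor<[8],f32>,
      !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.int, !torch.list<int>
    -> !torch.vtensor<[1,8,63,63],f32>

When the `pad` attribute is asymmetric, the pattern instead zero-pads the input with `Torch::AtenConstantPadNdOp` via `Padding2d::createZeroAtenPadOp` and passes a `(0, 0)` padding list to the conv-transpose op, as in the code above.
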
