feat: add more tests and support for asymmetric pads
ehsan-toosi committed Jul 31, 2024
1 parent 4e96abe commit 5d8725f
Showing 2 changed files with 117 additions and 51 deletions.
69 changes: 40 additions & 29 deletions lib/Conversion/XTenNNToTorch.cpp
@@ -96,6 +96,33 @@ struct Padding2d {
 
     return {{hPadding[0], hPadding[1]}, {wPadding[0], wPadding[1]}};
   }
+
+  // Zero padding of the input value with hPadding and wPadding using
+  // AtenConstantPadNdOp.
+  Value createZeroAtenPadOp(ConversionPatternRewriter &rewriter, Location loc,
+                            Value input) {
+
+    // Build new vtensor result type
+    auto ty = cast<Torch::ValueTensorType>(input.getType());
+    mlir::Type paddingResultTy;
+    std::optional<llvm::ArrayRef<int64_t>> optSizes = ty.getOptionalSizes();
+    if (optSizes) {
+      auto newSizes = ty.getSizes().vec();
+      newSizes[2] += hPadding[0] + hPadding[1];
+      newSizes[3] += wPadding[0] + wPadding[1];
+      paddingResultTy = Torch::ValueTensorType::get(
+          rewriter.getContext(), newSizes, ty.getOptionalDtype());
+    } else {
+      paddingResultTy = Torch::ValueTensorType::get(
+          rewriter.getContext(), ty.getOptionalSizes(), ty.getOptionalDtype());
+    }
+
+    auto zeroPadValue = rewriter.create<Torch::ConstantIntOp>(loc, 0);
+    auto pads = Torch::toTorchList(
+        loc, rewriter, {hPadding[0], hPadding[1], wPadding[0], wPadding[1]});
+    return rewriter.create<Torch::AtenConstantPadNdOp>(
+        loc, paddingResultTy, input, pads, zeroPadValue);
+  }
 };
 
 template <typename SrcOpT>
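
For reference, a standalone PyTorch sketch (not part of this commit) of the op the new helper emits. Note that `torch.constant_pad_nd` takes its pad list starting from the last dimension, i.e. (w_left, w_right, h_top, h_bottom) for an NCHW tensor; the shapes and pad values below are illustrative only.

```python
import torch

x = torch.ones(1, 1, 2, 3)          # NCHW input
h_pad, w_pad = (1, 2), (0, 1)       # (top, bottom), (left, right)

# Pad list is ordered innermost dimension first: W pads, then H pads.
y = torch.constant_pad_nd(x, (w_pad[0], w_pad[1], h_pad[0], h_pad[1]), 0)
print(y.shape)  # torch.Size([1, 1, 5, 4]): H = 2 + 1 + 2, W = 3 + 0 + 1
```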
@@ -123,29 +150,7 @@ std::optional<ValueRange> groupConv2dToTorch(GroupConv2dOp op, GroupConv2dOp::Ad
   if (!structPadding.isSymmetric()) {
     // Padding is not symmetric, but symmetric padding is the only mode the
     // aten conv2d op supports. We circumvent this with an explicit pad op.
-
-    // Build new vtensor result type
-    auto ty = cast<Torch::ValueTensorType>(newInput.getType());
-    mlir::Type paddingResultTy;
-    std::optional<llvm::ArrayRef<int64_t>> optSizes = ty.getOptionalSizes();
-    if (optSizes) {
-      auto newSizes = ty.getSizes().vec();
-      newSizes[2] += structPadding.hPadding[0] + structPadding.hPadding[1];
-      newSizes[3] += structPadding.wPadding[0] + structPadding.wPadding[1];
-      paddingResultTy = Torch::ValueTensorType::get(op->getContext(), newSizes,
-                                                    ty.getOptionalDtype());
-    } else {
-      paddingResultTy = Torch::ValueTensorType::get(
-          op->getContext(), ty.getOptionalSizes(), ty.getOptionalDtype());
-    }
-
-    auto zeroPadValue = rewriter.create<Torch::ConstantIntOp>(loc, 0);
-    auto pads = Torch::toTorchList(
-        loc, rewriter,
-        {structPadding.hPadding[0], structPadding.hPadding[1],
-         structPadding.wPadding[0], structPadding.wPadding[1]});
-    newInput = rewriter.create<Torch::AtenConstantPadNdOp>(
-        loc, paddingResultTy, newInput, pads, zeroPadValue);
+    newInput = structPadding.createZeroAtenPadOp(rewriter, loc, newInput);
 
     // We want zero padding for the Conv2d itself, since the padding has
     // already been applied by the explicit pad op.
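
As a sanity check on the rewrite above, a minimal PyTorch sketch (not part of this commit) of the same decomposition: an asymmetrically padded conv2d becomes an explicit zero-pad followed by a conv2d with padding 0. All shapes and pad values are illustrative.

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
w = torch.randn(4, 3, 3, 3)

# Asymmetric pads: top=1, bottom=2, left=0, right=1 (F.pad wants W first).
y_asym = F.conv2d(F.pad(x, (0, 1, 1, 2), value=0.0), w, padding=0)

# In the symmetric case, the two forms must agree.
y_explicit = F.conv2d(F.pad(x, (1, 1, 1, 1), value=0.0), w, padding=0)
y_builtin = F.conv2d(x, w, padding=1)
assert torch.allclose(y_explicit, y_builtin, atol=1e-6)
```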
@@ -187,15 +192,21 @@ convTranspose2dToTorch(ConvTransposeOp op, ConvTransposeOp::Adaptor adaptor,
   auto outputPadding =
       Torch::toTorchList(loc, rewriter, adaptor.getOutputPadding().vec());
 
-  // TODO: check the order is matched with what torch needs.
-  auto pads = Padding2d::get(adaptor);
-  auto inputPadding = Torch::toTorchList(
-      loc, rewriter,
-      {pads.hPadding[0], pads.hPadding[1], pads.wPadding[0], pads.hPadding[1]});
+  mlir::Value padValue;
+  auto structPadding = Padding2d::get(adaptor);
+  if (!structPadding.isSymmetric()) {
+    // AtenConvTranspose2dInputOp supports only symmetric padding. This can be
+    // handled by creating a pad operation on the input.
+    newInput = structPadding.createZeroAtenPadOp(rewriter, loc, newInput);
+    padValue = Torch::toTorchList(loc, rewriter, {0, 0});
+  } else {
+    padValue = Torch::toTorchList(
+        loc, rewriter, {structPadding.hPadding[0], structPadding.wPadding[0]});
+  }
 
   return rewriter
       .create<Torch::AtenConvTranspose2dInputOp>(
-          loc, types[0], newInput, newWeights, newBias, stride, inputPadding,
+          loc, types[0], newInput, newWeights, newBias, stride, padValue,
           outputPadding, group, dilation)
       ->getResults();
 }
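
For context on the symmetric branch (not part of this commit): PyTorch's `conv_transpose2d` accepts a single padding value per spatial dimension, which is why {hPadding[0], wPadding[0]} is passed straight through when the pads are symmetric. Its output size follows out = (in - 1) * stride - 2 * padding + dilation * (k - 1) + output_padding + 1; the shapes below are illustrative.

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 8, 8)
w = torch.randn(4, 3, 3, 3)   # (in_channels, out_channels // groups, kH, kW)

y = F.conv_transpose2d(x, w, stride=2, padding=1, output_padding=1)
print(y.shape)  # (8-1)*2 - 2*1 + (3-1) + 1 + 1 = 16 -> torch.Size([1, 3, 16, 16])
```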