Fix typo in inputRank check of AtenBatchNormOp (#1046)
The original conversion pattern for `AtenBatchNormOp` required that
the input rank be greater than 2; however, the only
expectation in the conversion pattern and in PyTorch is that the input
rank is greater than 1, since the second dimension of the input must
match the size of the `weight`, `bias`, `runningMean`, and
`runningVar` inputs. This commit fixes the `inputRank` check.
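
For reference, a minimal PyTorch sketch (illustrative only, not part of this commit) of the case the old check rejected: `BatchNorm1d` accepts a rank-2 `(N, C)` input as long as dimension 1 matches the feature count of `weight`, `bias`, `running_mean`, and `running_var`.

```python
import torch

# Illustrative only: BatchNorm1d over a rank-2 (N, C) input.
# The old lowering required rank > 2 and would reject this shape,
# even though PyTorch itself accepts it.
bn = torch.nn.BatchNorm1d(4)   # 4 features -> weight/bias/running stats of size 4
bn.eval()

x = torch.rand(10, 4)          # batch of 10, 4 features (rank 2)
y = bn(x)                      # valid in PyTorch
print(y.shape)                 # torch.Size([10, 4])
```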
ramiro050 authored Jul 15, 2022
1 parent 3589134 commit afdaa60
Showing 3 changed files with 29 additions and 2 deletions.
1 change: 1 addition & 0 deletions e2e_testing/torchscript/xfail_sets.py
@@ -93,6 +93,7 @@
"TypePromotionAlphaWiderModule_basic",
"Conv2dWithPaddingDilationStrideStaticModule_basic",
"BatchNorm1DModule_basic",
"BatchNorm1DWith2DInputModule_basic",
"BatchNorm2DModule_basic",
"BatchNorm3DModule_basic",
"FlattenStaticModule_basic",
4 changes: 2 additions & 2 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -1134,9 +1134,9 @@ class ConvertAtenBatchNormOp : public OpConversionPattern<AtenBatchNormOp> {
     auto runningVarType = runningVar.getType().cast<RankedTensorType>();
 
     auto inputRank = inputType.getRank();
-    if (inputRank <= 2)
+    if (inputRank < 2)
       return rewriter.notifyMatchFailure(
-          op, "input should have rank larger than 2");
+          op, "input should have rank larger than 1");
 
     if (weightType.getRank() != 1 || biasType.getRank() != 1 ||
         runningMeanType.getRank() != 1 || runningVarType.getRank() != 1) {
26 changes: 26 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/norm_like.py
@@ -37,6 +37,32 @@ def BatchNorm1DModule_basic(module, tu: TestUtils):

 # ==============================================================================
 
+class BatchNorm1DWith2DInputModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.bn1d = torch.nn.BatchNorm1d(4)
+        self.bn1d.eval()
+        self.bn1d.running_mean = torch.tensor([0.5, 0.4, 0.3, 0.6])
+        self.bn1d.running_var = torch.tensor([3.0, 2.0, 4.0, 5.0])
+        self.bn1d.weight = torch.nn.Parameter(
+            torch.tensor([3.0, 2.0, 4.0, 5.0]))
+        self.bn1d.bias = torch.nn.Parameter(torch.tensor([0.5, 0.4, 0.3, 0.6]))
+
+    @export
+    @annotate_args([
+        None,
+        ([10, 4], torch.float32, True),
+    ])
+    def forward(self, x):
+        return self.bn1d(x)
+
+
+@register_test_case(module_factory=lambda: BatchNorm1DWith2DInputModule())
+def BatchNorm1DWith2DInputModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(10, 4))
+
+# ==============================================================================
+
 class BatchNorm2DModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
