Prune xfail e2e LTC tests & fix bugs from functionalization pass (#1044)
- Pruned the number of xfailed e2e LTC tests from 305 to 134
  - Reviewed every failure to ensure the error genuinely warrants an xfail
- Fixed a bug where non-tensor outputs of the LTC computation had `.to('cpu')` called on them, which caused failures and inflated the xfail count
- Fixed a bug in the `HBC_basic` test where a constant tensor was created in the constructor without being registered as a buffer, which prevented its device from being updated when the parent `torch.nn.Module` was moved to the `lazy` device (see the short sketch right after this list)
  - Note that this test is still xfail'd due to some unsupported ops. Left a comment about potential issues that may arise if it is re-enabled in the future
- Updated the autogenerated `GeneratedTorchOps.td` to reflect the latest set of supported ops
- Renamed `aten.zero.functionalization` to `aten.zero` to reflect upstream PyTorch changes
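For context, a minimal, self-contained sketch of the buffer issue described above (an editor's illustration, not code from this commit): a plain tensor attribute is invisible to `torch.nn.Module.to(...)`, while a registered buffer moves with the module. The `'meta'` device is used only so the snippet runs without an LTC backend; in the LTC test config the target device is `'lazy'`.

import torch

class BufferDemo(torch.nn.Module):
    """Hypothetical module illustrating the HBC_basic fix; not part of this commit."""
    def __init__(self):
        super().__init__()
        # Plain attribute: Module.to(...) does not track it, so it never leaves the CPU.
        self.plain_weight = torch.tensor([0.4])
        # Registered buffer: tracked by the module and moved together with it.
        self.register_buffer("buffered_weight", torch.tensor([0.4]))

m = BufferDemo().to("meta")       # stand-in for .to('lazy') in the LTC test setup
print(m.plain_weight.device)      # cpu  -- the bug: device never updated
print(m.buffered_weight.device)   # meta -- moves with the module, as intended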
henrytwo authored and antoniojkim committed Jul 19, 2022
1 parent 0d6e051 commit a335411
Showing 5 changed files with 18 additions and 202 deletions.
177 changes: 3 additions & 174 deletions e2e_testing/torchscript/xfail_sets.py
@@ -181,60 +181,22 @@
LTC_XFAIL_SET = {
"AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
"AddIntModule_basic",
"AllBoolFalseModule_basic",
"AllBoolTrueModule_basic",
"AnyBoolFalseModule_basic",
"AnyBoolTrueModule_basic",
"ArangeDtypeFloatModule_basic",
"ArangeDtypeIntModule_basic",
"ArangeFalsePinMemoryModule_basic",
"ArangeFloatModule_basic",
"ArangeIntModule_basic",
"ArangeNegativeStartFloatModule_basic",
"ArangeNegativeStartIntModule_basic",
"ArangeStartFloatModule_basic",
"ArangeStartIntModule_basic",
"ArangeStartNegativeStepFloatModule_basic",
"ArangeStartNegativeStepIntModule_basic",
"ArangeStartStepFloatModule_basic",
"ArangeStartStepIntModule_basic",
"ArangeZeroElementOutputModule_basic",
"AvgPool2dCeilModeTrueModule_basic",
"AvgPool2dDivisorOverrideModule_basic",
"AvgPool2dFloatModule_basic",
"AvgPool2dIntModule_basic",
"AvgPool2dStaticModule_basic",
"BernoulliFloatModule_basic",
"BernoulliModule_basic",
"BernoulliOnesModule_basic",
"BernoulliTensorModule_basic",
"BernoulliZerosModule_basic",
"BincountMinlengthModule_basic",
"BincountModule_basic",
"BincountStaticSizeModule_basic",
"BoolFloatConstantModule_basic",
"BoolFloatFalseModule_basic",
"BoolFloatTrueModule_basic",
"BoolIntConstantModule_basic",
"BoolIntFalseModule_basic",
"BoolIntTrueModule_basic",
"CeilFloatModule_basic",
"DivFloatModule_basic",
"DropoutTrainModule_basic",
"ElementwiseAtenLogicalOrOpBrodcastModule_basic",
"ElementwiseAtenLogicalOrOpDiffArgs1Module_basic",
"ElementwiseAtenLogicalOrOpDiffArgs2Module_basic",
"ElementwiseAtenLogicalOrOpDiffArgs3Module_basic",
"ElementwiseAtenLogicalOrOpModule_basic",
"ElementwiseAtenLogicalOrOpNegativeModule_basic",
"ElementwiseAtenLogicalOrOpRandomFloatModule_basic",
"ElementwiseAtenLogicalOrOpRandomModule_basic",
"ElementwiseClampMaxModule_basic",
"ElementwiseClampMinModule_basic",
"ElementwiseClampModule_basic",
"ElementwiseAtenFloorDivideBroadcastModule_basic",
"ElementwiseAtenFloorDivideModule_basic",
"ElementwiseWhereScalarModule_basic",
"ElementwiseWhereScalarOtherModule_basic",
"ElementwiseWhereScalarSelfModule_basic",
@@ -244,11 +206,6 @@
"EmptyLikeModule_falsePinMemory",
"EmptyLikeModule_float",
"EmptyLikeModule_int",
"EmptyModule_contiguous",
"EmptyModule_defaultDtype",
"EmptyModule_falsePinMemory",
"EmptyModule_float",
"EmptyModule_int",
"EqIntModule_basic",
"Fill_TensorFloat64WithFloat32_basic",
"Fill_TensorFloat64WithFloat64_basic",
@@ -261,12 +218,6 @@
"FullLikeModuleInt2DStatic_basic",
"FullLikeModuleInt2D_basic",
"FullLikeModuleInt3D_basic",
"FullModuleDefaultDtype_basic",
"FullModuleFalsePinMemory_basic",
"FullModuleFloat2D_basic",
"FullModuleFloat3D_basic",
"FullModuleInt2D_basic",
"FullModuleInt3D_basic",
"GeFloatIntModule_basic",
"GeFloatModule_basic",
"GtFloatIntModule_basic",
@@ -308,56 +259,13 @@
"IndexPutImpl2DFloatNonAccumulateModule_basic",
"IndexPutImpl3DFloatAccumulateModule_basic",
"IndexPutImpl3DFloatNonAccumulateModule_basic",
"IndexSelectDynamicIndexSizeModule_basic",
"IndexSelectDynamicInputSizeModule_basic",
"IndexSelectDynamicModulebasic",
"IndexSelectSingleIdxModule_basic",
"IndexSelectTwoIdxModule_basic",
"IndexSelectWholeDimensionModule_basic",
"IndexSelectWholeTensorModule_basic",
"IndexTensorModule3dInput_basic",
"IndexTensorModule_basic",
"MaskedFillScalarDefaultModule_basic",
"MaskedFillScalarFloatValueModule_basic",
"MaskedFillScalarIntValueModule_basic",
"Matmul_dot",
"Matmul_matvec",
"Matmul_vecmat",
"MaxPool2dCeilModeTrueModule_basic",
"MaxPool2dModule_basic",
"MaxPool2dStaticModule_basic",
"MaxPool2dWith3dInputModule_basic",
"MaxPool2dWithIndicesAllNegativeValuesModule_basic",
"MaxPool2dWithIndicesAllOnesModule_basic",
"MaxPool2dWithIndicesBackwardDynamic3DModule_basic",
"MaxPool2dWithIndicesBackwardDynamic4DModule_basic",
"MaxPool2dWithIndicesBackwardStatic3DModule_basic",
"MaxPool2dWithIndicesBackwardStatic4DModule_basic",
"MaxPool2dWithIndicesCeilModeTrueModule_basic",
"MaxPool2dWithIndicesFullSizeKernelModule_basic",
"MaxPool2dWithIndicesModule_basic",
"MaxPool2dWithIndicesNonDefaultDilationModule_basic",
"MaxPool2dWithIndicesNonDefaultPaddingModule_basic",
"MaxPool2dWithIndicesNonDefaultParamsModule_basic",
"MaxPool2dWithIndicesNonDefaultStrideModule_basic",
"MaxPool2dWithIndicesStaticModule_basic",
"MaxPool2dWithIndicesWith3dInputModule_basic",
"MeanDimAllReduceKeepdimModule_basic",
"MeanDimAllReduceModule_basic",
"MeanDimDtypeModule_basic",
"MeanDimKeepdimModule_basic",
"MeanDimModule_basic",
"MeanDimNegativeModule_basic",
"MeanDtypeModule_basic",
"MeanDynamicSizesModule_basic",
"MeanModule_basic",
"MobilenetV3Module_basic",
"MulIntModule_basic",
"NativeBatchNorm1DModule_basic",
"NativeBatchNorm2DModule_basic",
"NativeBatchNorm3DModule_basic",
"NativeBatchNormNoneWeightModule_basic",
"NativeLayerNormDynamicModule_basic",
"NativeLayerNormModule_basic",
"NeFloatIntModule_basic",
"NeIntModule_basic",
"NewEmptyModuleDefaultDtype_basic",
@@ -375,69 +283,23 @@
"NewOnesModuleFloat3D_basic",
"NewOnesModuleInt2D_basic",
"NewOnesModuleInt3D_basic",
"NewZerosModuleDefaultDtype_basic",
"NewZerosModuleFalsePinMemory_basic",
"NewZerosModuleFloat2D_basic",
"NewZerosModuleFloat3D_basic",
"NewZerosModuleInt2D_basic",
"NewZerosModuleInt3D_basic",
"NllLossModuleBackward1DMeanWeight_basic",
"NllLossModuleBackward1DMean_basic",
"NllLossModuleBackward1DSumWeight_basic",
"NllLossModuleBackward1DSum_basic",
"NllLossModuleBackward1DWeight_basic",
"NllLossModuleBackward1D_basic",
"NllLossModuleBackwardMeanWeight_basic",
"NllLossModuleBackwardMean_basic",
"NllLossModuleBackwardSumWeight_basic",
"NllLossModuleBackwardSum_basic",
"NllLossModuleBackwardWeight_basic",
"NllLossModuleBackward_basic",
"NllLossModuleBackward_ignore_index",
"NllLossModule_1D_basic",
"NllLossModule_basic",
"NllLossModule_ignore_index_out_of_bounds_basic",
"NllLossModule_mean_basic",
"NllLossModule_sum_basic",
"NumelModule_basic",
"NumelZeroRankModule_basic",
"OnesLikeModule_defaultDtype",
"OnesLikeModule_falsePinMemory",
"OnesLikeModule_float",
"OnesLikeModule_int",
"OnesModuleDefaultDtype_basic",
"OnesModuleFalsePinMemory_basic",
"OnesModuleFloat_basic",
"OnesModuleInt_basic",
"QuantizedMLP_basic",
"RandLikeDtypeModule_basic",
"RandLikeModule_basic",
"ReduceMaxKeepDimReturnBoth_basic",
"ReduceMaxNegativeDim_basic",
"ReshapeAliasCollapseModule_basic",
"ReshapeAliasExpandModule_basic",
"ReturnThreeTensorFloat32_basic",
"ReturnTwoTensorF32I64_basic",
"ScalarImplicitFloatModule_basic",
"ScalarImplicitIntModule_basic",
"SelectIntModule_basic",
"SliceEndSleStartModule_basic",
"SliceNegIdxModule_basic",
"SliceOutOfLowerBoundEndIndexModule_basic",
"SliceOutOfLowerBoundStartIndexModule_basic",
"SliceOutOfUpperBoundIndexModule_basic",
"SliceSingleIdxModule_basic",
"SliceSizeTwoStepModule_basic",
"SliceStartEqEndModule_basic",
"SliceWholeTensorModule_basic",
"SqrtIntConstantModule_basic",
"SqrtIntModule_basic",
"StdBiasedModule_basic",
"StdUnbiasedModule_basic",
"SubFloatModule_basic",
"SubIntModule_basic",
"TModuleRank0_basic",
"TModuleRank1_basic",
"TableBatchEmbeddingModule_basic",
"TensorToBoolZeroRank_basic",
"TensorToBool_basic",
@@ -446,43 +308,10 @@
"TensorToIntZeroRank_basic",
"TensorToInt_basic",
"TensorsConcatModule_basic",
"TestMultipleTensorAndPrimitiveTypesReturn_basic",
"TestMultipleTensorReturn_basic",
"Threshold1dFloatModule_basic",
"Threshold1dIntI32Module_basic",
"Threshold1dIntModule_basic",
"Threshold2dFloatModule_basic",
"Threshold2dIntModule_basic",
"Threshold3dFloatModule_basic",
"Threshold3dIntModule_basic",
"ThresholdBackward1dFloatModule_basic",
"ThresholdBackward1dIntModule_basic",
"ThresholdBackward1dMixedModule_basic",
"ThresholdBackward2dFloatModule_basic",
"ThresholdBackward2dIntModule_basic",
"ThresholdBackward2dMixedModule_basic",
"ThresholdBackward3dFloatModule_basic",
"ThresholdBackward3dIntModule_basic",
"ThresholdBackward3dMixedModule_basic",
"TorchPrimLoopForLikeModule_basic",
"TorchPrimLoopWhileLikeModule_basic",
"UniformModule_basic",
"UniformStaticModule_basic",
"UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
"VarBiasedModule_basic",
"VarUnbiasedModule_basic",
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
"ZeroFloat32Module_basic",
"ZeroInt32Module_basic",
"ZeroInt64Module_basic",
"ZerosLikeModule_defaultDtype",
"ZerosLikeModule_falsePinMemory",
"ZerosLikeModule_float",
"ZerosLikeModule_int",
"ZerosModuleDefaultDtype_basic",
"ZerosModuleFalsePinMemory_basic",
"ZerosModuleFloat2D_basic",
"ZerosModuleFloat3D_basic",
"ZerosModuleInt2D_basic",
"ZerosModuleInt3D_basic",
}
25 changes: 1 addition & 24 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -5674,6 +5674,7 @@ def Torch_AtenViewOp : Torch_Op<"aten.view", [

def Torch_Aten_UnsafeViewOp : Torch_Op<"aten._unsafe_view", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::_unsafe_view : (Tensor, int[]) -> (Tensor)`";
@@ -6797,30 +6798,6 @@ def Torch_AtenAsStridedScatterOp : Torch_Op<"aten.as_strided_scatter", [
}];
}

def Torch_Aten_UnsafeViewCopyOp : Torch_Op<"aten._unsafe_view_copy", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::_unsafe_view_copy : (Tensor, int[]) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$size
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult Aten_UnsafeViewCopyOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void Aten_UnsafeViewCopyOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}

def Torch_Aten__Contains__StrOp : Torch_Op<"aten.__contains__.str", [
AllowsTypeRefinement,
HasValueSemantics,
@@ -496,7 +496,6 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::slice_scatter : (Tensor, Tensor, int, int?, int?, int) -> (Tensor)")
emit("aten::diagonal_scatter : (Tensor, Tensor, int, int, int) -> (Tensor)")
emit("aten::as_strided_scatter : (Tensor, Tensor, int[], int[], int?) -> (Tensor)")
emit("aten::_unsafe_view_copy : (Tensor, int[]) -> (Tensor)")


# Dict ops.
@@ -38,7 +38,7 @@ def __init__(self):
torch.empty([_num_interval], dtype=torch.float64).fill_(0.0),
)
self.register_buffer("_bin_ids", torch.arange(_num_interval))
self.positive_weight = torch.tensor([0.4])
self.register_buffer("positive_weight", torch.tensor([0.4]))
self.bin_ctr_in_use_after = 0
self.bin_ctr_weight_value = 0.9995
self.oneminusbin_ctr_weight_value = 0.0005
@@ -54,6 +54,9 @@ def __init__(self):
def forward(self, segment_value, segment_lengths, logit):
origin_prediction = torch.sigmoid(
logit + torch.log(self.positive_weight))
# TODO: If in the future this test is removed from xfail for LTC, we will probably hit some device related
# issues below when new tensors are created on the CPU, which is currently the default behaviour.
# The solution would be to move these tensors to ensure they are on the same device as the existing ones.
dense_segment_value = torch.zeros(logit.numel(), dtype=torch.int32)
validoffsets = torch.gt(
segment_lengths[1:self._num_logits+1], segment_lengths[0:self._num_logits])
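The TODO comment above suggests moving newly created tensors onto the same device as the existing ones; a hedged, self-contained sketch of that pattern (an assumption about the eventual fix, not part of this change):

import torch

def make_dense_segment_value(logit: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: create the tensor on the same device as `logit`
    # instead of relying on the CPU default that the TODO warns about.
    return torch.zeros(logit.numel(), dtype=torch.int32, device=logit.device)

make_dense_segment_value(torch.randn(4))  # CPU input -> CPU result; a lazy input would yield a lazy result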
12 changes: 10 additions & 2 deletions python/torch_mlir_e2e_test/torchscript/configs/lazy_tensor_core.py
@@ -5,9 +5,16 @@

import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend
import torch
from torch.utils._pytree import tree_map

from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem


def to_device(device):
"""Returns a lambda that maps `torch.Tensor` objects to `device`, and ignores other types"""
return lambda e: e.to(device) if isinstance(e, torch.Tensor) else e


class LazyTensorCoreTestConfig(TestConfig):
"""TestConfig that runs torch.nn.Module thru the Lazy Tensor Core frontend for Torch MLIR"""

@@ -23,12 +30,13 @@ def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:

for item in trace:
# We need to move all the inputs to the lazy device before running in LTC.
lazy_inputs = [arg.to('lazy') for arg in item.inputs]
lazy_inputs = tree_map(to_device('lazy'), item.inputs)
output = getattr(artifact, item.symbol)(*lazy_inputs)
cpu_outputs = tree_map(to_device('cpu'), output)

result.append(
TraceItem(symbol=item.symbol,
inputs=item.inputs,
output=output.to('cpu')))
output=cpu_outputs))

return result
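For reference, a minimal sketch of how the new `to_device` helper behaves together with `tree_map` on a mixed, structured output such as the ones returned by the multi-return tests (the `output` value below is a made-up example, not data from the test suite):

import torch
from torch.utils._pytree import tree_map

def to_device(device):
    # Same helper as above: move tensors, pass every other type through untouched.
    return lambda e: e.to(device) if isinstance(e, torch.Tensor) else e

# Hypothetical result mixing tensors with plain Python values; calling
# .to('cpu') directly on such an object was the failure described in the
# commit message.
output = (torch.ones(2), 3, [torch.zeros(1), "label"])
cpu_output = tree_map(to_device("cpu"), output)
# Tensors are moved to CPU; the int and the string come back unchanged.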
