From 591fe9bc304eeac0e32255363cb8e7a912391dca Mon Sep 17 00:00:00 2001
From: zzpmiracle
Date: Mon, 28 Nov 2022 17:39:06 +0800
Subject: [PATCH] support for stable diffusion

---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     |  91 +++++++++++++
 .../Torch/Transforms/DecomposeComplexOps.cpp  | 121 +++++++++++++++++-
 .../jit_ir/build_tools/torch_ods_gen.py       |   9 ++
 3 files changed, 215 insertions(+), 6 deletions(-)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 1ef447eaab3b..e6bd53ac28ac 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -4244,6 +4244,97 @@ def Torch_AtenNativeLayerNormOp : Torch_Op<"aten.native_layer_norm", [
   }];
 }
 
+def Torch_AtenInstanceNormOp : Torch_Op<"aten.instance_norm", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::instance_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    AnyTorchOptionalTensorType:$weight,
+    AnyTorchOptionalTensorType:$bias,
+    AnyTorchOptionalTensorType:$running_mean,
+    AnyTorchOptionalTensorType:$running_var,
+    Torch_BoolType:$use_input_stats,
+    Torch_FloatType:$momentum,
+    Torch_FloatType:$eps,
+    Torch_BoolType:$cudnn_enabled
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenInstanceNormOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 9, 1);
+    }
+    void AtenInstanceNormOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 9, 1);
+    }
+  }];
+}
+
+def Torch_AtenGroupNormOp : Torch_Op<"aten.group_norm", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::group_norm : (Tensor, int, Tensor?, Tensor?, float, bool) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    Torch_IntType:$num_groups,
+    AnyTorchOptionalTensorType:$weight,
+    AnyTorchOptionalTensorType:$bias,
+    Torch_FloatType:$eps,
+    Torch_BoolType:$cudnn_enabled
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenGroupNormOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 6, 1);
+    }
+    void AtenGroupNormOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 6, 1);
+    }
+  }];
+}
+
+def Torch_AtenNativeGroupNormOp : Torch_Op<"aten.native_group_norm", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::native_group_norm : (Tensor, Tensor?, Tensor?, int, int, int, int, float) -> (Tensor, Tensor, Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    AnyTorchOptionalTensorType:$weight,
+    AnyTorchOptionalTensorType:$bias,
+    Torch_IntType:$N,
+    Torch_IntType:$C,
+    Torch_IntType:$HxW,
+    Torch_IntType:$group,
+    Torch_FloatType:$eps
+  );
+  let results = (outs
+    AnyTorchTensorType:$result0,
+    AnyTorchTensorType:$result1,
+    AnyTorchTensorType:$result2
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenNativeGroupNormOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 8, 3);
+    }
+    void AtenNativeGroupNormOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 8, 3);
+    }
+  }];
+}
+
 def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 89be8e51b931..04327f091981 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -1800,6 +1800,119 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
     return success();
   }
 };
+
+class DecomposeAtenGroupNormOp : public OpRewritePattern<AtenGroupNormOp> {
+  using OpRewritePattern<AtenGroupNormOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenGroupNormOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto context = op.getContext();
+    Value input = op.input();
+    auto inputTy = input.getType().cast<BaseTensorType>();
+    if (!inputTy.hasSizes())
+      return rewriter.notifyMatchFailure(
+          op, "input tensor should have known sizes.");
+    ArrayRef<int64_t> inputSize = inputTy.getSizes();
+    int64_t inputRank = inputTy.getSizes().size();
+    if (inputRank != 4) {
+      return rewriter.notifyMatchFailure(
+          op, "group norm only supports 4D input for now.");
+    }
+    Value num_groups = op.num_groups();
+    int64_t num_groups_int;
+    if (!matchPattern(num_groups, m_TorchConstantInt(&num_groups_int)))
+      return rewriter.notifyMatchFailure(
+          op, "non-constant num_groups for AtenGroupNormOp");
+
+    // Reshape the input to [N, G, -1 (C//G), H, W].
+    Value negOne =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1));
+    Value one =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+    Value none = rewriter.create<ConstantNoneOp>(loc);
+    SmallVector<int64_t> inputForNormTySize(inputRank + 1,
+                                            ShapedType::kDynamicSize);
+    inputForNormTySize[1] = num_groups_int;
+    Type inputForNormTy = inputTy.getWithSizesAndDtype(
+        llvm::makeArrayRef(inputForNormTySize), inputTy.getDtype());
+    SmallVector<Value> orginInputSize;
+    for (int i = 0; i < inputRank; ++i) {
+      Value index =
+          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
+      orginInputSize.push_back(
+          rewriter.create<AtenSizeIntOp>(loc, input, index));
+    }
+    SmallVector<Value> inputForNormSize{orginInputSize.begin(),
+                                        orginInputSize.end()};
+    inputForNormSize.insert(inputForNormSize.begin() + 1, num_groups);
+    inputForNormSize[2] = negOne;
+    Value inputForNormSizeList = rewriter.create<PrimListConstructOp>(
+        loc, ListType::get(IntType::get(context)), inputForNormSize);
+    Value reshapedInput = rewriter.create<AtenViewOp>(
+        loc, inputForNormTy, input, inputForNormSizeList);
+    // Only keep N and G; reduce over the C//G, H, W dimensions.
+    int64_t axis = 2;
+    std::vector<int64_t> meanVarTySizes(inputForNormTySize.size(), 1);
+    for (int i = 0; i < axis; i++)
+      meanVarTySizes[i] = inputForNormTySize[i];
+    auto meanVarTy = inputTy.getWithSizesAndDtype(
+        llvm::makeArrayRef(meanVarTySizes), inputTy.getDtype());
+    SmallVector<Value> normalizedShapeSize{inputForNormSize.begin() + axis,
+                                           inputForNormSize.end()};
+    auto normalizedSizeList = rewriter.create<PrimListConstructOp>(
+        loc, ListType::get(IntType::get(context)), normalizedShapeSize);
+
+    auto nativeLayerNorm =
+        rewriter
+            .create<AtenNativeLayerNormOp>(
+                loc, inputForNormTy, meanVarTy, meanVarTy, reshapedInput,
+                normalizedSizeList, none, none, op.eps())
+            .getResult(0);
+    // Reshape back to the original shape.
+    Value inputSizeList = rewriter.create<PrimListConstructOp>(
+        loc, ListType::get(IntType::get(context)), orginInputSize);
+    Value originOutput = rewriter.create<AtenViewOp>(
+        loc, op.getType(), nativeLayerNorm, inputSizeList);
+    // Reshape weight and bias to [C, 1, 1] so they broadcast per channel.
+    Value weight = op.weight();
+    Value bias = op.bias();
+    if (!weight.getType().isa<Torch::NoneType>() ||
+        !bias.getType().isa<Torch::NoneType>()) {
+      SmallVector<Value> weightsAndBiasSize(inputRank - 1, one);
+      weightsAndBiasSize[0] = orginInputSize[1];
+
+      SmallVector<int64_t> weightsAndBiasTySize(inputRank - 1,
+                                                ShapedType::kDynamicSize);
+      // weightsAndBiasTySize[1] = ShapedType::kDynamicSize;
+
+      Value weightsAndBiasSizeList = rewriter.create<PrimListConstructOp>(
+          loc, ListType::get(IntType::get(context)), weightsAndBiasSize);
+      if (!weight.getType().isa<Torch::NoneType>()) {
+        BaseTensorType weightType = weight.getType().cast<BaseTensorType>();
+        Type weightTy = weightType.getWithSizesAndDtype(
+            llvm::makeArrayRef(weightsAndBiasTySize), weightType.getDtype());
+        weight = rewriter.create<AtenViewOp>(loc, weightTy, weight,
+                                             weightsAndBiasSizeList);
+        originOutput = rewriter.create<AtenMulTensorOp>(loc, op.getType(),
+                                                        originOutput, weight);
+      }
+      if (!bias.getType().isa<Torch::NoneType>()) {
+        BaseTensorType biasType = bias.getType().cast<BaseTensorType>();
+        Type biasTy = biasType.getWithSizesAndDtype(
+            llvm::makeArrayRef(weightsAndBiasTySize), biasType.getDtype());
+        bias = rewriter.create<AtenViewOp>(loc, biasTy, bias,
+                                           weightsAndBiasSizeList);
+        Value alpha =
+            rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
+        originOutput = rewriter.create<AtenAddTensorOp>(
+            loc, op.getType(), originOutput, bias, alpha);
+      }
+    }
+    rewriter.replaceOp(op, {originOutput});
+    return success();
+  }
+};
+
 } // namespace
 
 namespace {
@@ -2506,12 +2619,6 @@ class DecomposeAtenToDtypeLayoutOp
           op, "unimplemented: pin_memory is expected to be false");
     }
 
-    // TODO: Add support for non-None device arg.
-    if (!op.device().getType().isa<Torch::NoneType>()) {
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: device arg must be None");
-    }
-
     // TODO: Add support for non-strided layout.
     // torch.layout is by default strided i.e. 0.
     if (!op.layout().getType().isa<Torch::NoneType>()) {
@@ -3378,6 +3485,8 @@ class DecomposeComplexOpsPass
     patterns.add(context);
     target.addIllegalOp();
     patterns.add(context);
+    target.addIllegalOp<AtenGroupNormOp>();
+    patterns.add<DecomposeAtenGroupNormOp>(context);
     target.addIllegalOp();
     patterns.add(context);
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 5e2febb20323..e16792b96259 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -366,6 +366,15 @@ def emit_with_mutating_variants(key, **kwargs):
     emit(
         "aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)"
     )
+    emit(
+        "aten::instance_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)"
+    )
+    emit(
+        "aten::group_norm : (Tensor, int, Tensor?, Tensor?, float, bool) -> (Tensor)"
+    )
+    emit(
+        "aten::native_group_norm : (Tensor, Tensor?, Tensor?, int, int, int, int, float) -> (Tensor, Tensor, Tensor)"
+    )
     emit(
         "aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)"
     )
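Note: the DecomposeAtenGroupNormOp pattern above rewrites aten.group_norm as a view to [N, G, C//G, H, W], a non-affine aten.native_layer_norm over the trailing three dimensions, a view back to [N, C, H, W], and a separate per-channel multiply/add for weight and bias. Below is a minimal PyTorch sketch of that same decomposition, outside the patch itself, useful only for checking the numerics; the helper name group_norm_via_layer_norm is illustrative and not part of torch-mlir.

    import torch
    import torch.nn.functional as F

    def group_norm_via_layer_norm(x, num_groups, weight=None, bias=None, eps=1e-5):
        # [N, C, H, W] -> [N, G, C//G, H, W], normalize over (C//G, H, W).
        n, c, h, w = x.shape
        grouped = x.reshape(n, num_groups, -1, h, w)
        normalized = F.layer_norm(grouped, grouped.shape[2:], eps=eps)
        out = normalized.reshape(n, c, h, w)
        # Affine parameters are applied per channel, mirroring the separate
        # AtenMulTensorOp / AtenAddTensorOp steps in the decomposition.
        if weight is not None:
            out = out * weight.reshape(c, 1, 1)
        if bias is not None:
            out = out + bias.reshape(c, 1, 1)
        return out

    x = torch.randn(2, 8, 6, 6)
    w, b = torch.randn(8), torch.randn(8)
    ref = F.group_norm(x, num_groups=4, weight=w, bias=b)
    assert torch.allclose(ref, group_norm_via_layer_norm(x, 4, w, b), atol=1e-5)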