diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 22182ab204c..770e76945af 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -43,7 +43,8 @@ class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern<AtenOpT> {
     auto selfTy = self.getType().cast<TensorType>();

     if (!selfTy)
-      return op.emitError("Only Tensor types supported in TOSA");
+      return rewriter.notifyMatchFailure(op,
+                                         "Only Tensor types supported in TOSA");

     if (selfTy.getElementType().isa<mlir::FloatType>()) {
       rewriter.replaceOpWithNewOp<TosaOpT>(
@@ -53,8 +54,8 @@ class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern<AtenOpT> {
           self);
       return success();
     } else {
-      return op.emitError(
-          "Only floating-point datatype legalization supported");
+      return rewriter.notifyMatchFailure(
+          op, "Only floating-point datatype legalization supported");
     }
   }
 };
@@ -94,13 +95,14 @@ class ConvertAtenBinaryOp : public OpConversionPattern<AtenOpT> {
     auto rhsTy = rhs.getType().cast<TensorType>();

     if (!lhsTy || !rhsTy)
-      return op.emitError("Only Tensor types supported in TOSA");
+      return rewriter.notifyMatchFailure(op,
+                                         "Only Tensor types supported in TOSA");

     auto lhsElemTy = lhsTy.getElementType();
     auto rhsElemTy = rhsTy.getElementType();

     if (lhsElemTy != rhsElemTy)
-      return op.emitError("Add: input datatypes mismatched");
+      return rewriter.notifyMatchFailure(op, "Input datatypes mismatched");

     rewriter.replaceOpWithNewOp<TosaOpT>(
         op,
@@ -140,7 +142,8 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
   auto isInt = matchPattern(torchScalarValue, m_TorchConstantInt(&intValue));

   if (!isFloat && !isInt)
-    return op->emitError("Unable to extract the scalar constant");
+    return rewriter.notifyMatchFailure(op,
+                                       "Unable to extract the scalar constant");

   if (dtype.isa<mlir::FloatType>()) {
     tosaTensor = tosa::getConstTensor<float>(
@@ -149,12 +152,15 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
   } else if (auto intType = dtype.dyn_cast<IntegerType>()) {
     auto w = intType.getWidth();
     if (w != 32 && w != 64)
-      return op->emitError("Unsupported integer type") << intType;
+      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+        diag << "Unsupported integer type: " << intType;
+      });

     if (w == 32) {
       if (!isInValidRange<int32_t>(isFloat, doubleValue, isInt, intValue)) {
-        return op->emitError("Supplied value of scalar constant exceeds limits "
-                             "of destination type");
+        return rewriter.notifyMatchFailure(
+            op, "Supplied value of scalar constant exceeds limits "
+                "of destination type");
       }
       int32_t d = isFloat ? static_cast<int32_t>(doubleValue)
                           : static_cast<int32_t>(intValue);
@@ -162,15 +168,17 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
       tosaTensor =
           tosa::getConstTensor<int32_t>(rewriter, op, {d}, dshape).getValue();
     } else if (w == 64) {
       if (!isInValidRange<int64_t>(isFloat, doubleValue, isInt, intValue)) {
-        return op->emitError("Supplied value of scalar constant exceeds limits "
-                             "of destination type");
+        return rewriter.notifyMatchFailure(
+            op, "Supplied value of scalar constant exceeds limits "
+                "of destination type");
       }
       int64_t d = (isFloat ? static_cast<int64_t>(doubleValue) : intValue);
       tosaTensor =
           tosa::getConstTensor<int64_t>(rewriter, op, {d}, dshape).getValue();
     }
-  } else
-    return op->emitError("Usupported element type");
+  } else {
+    return rewriter.notifyMatchFailure(op, "Unsupported element type");
+  }

   return success();
 }
@@ -186,11 +194,13 @@ LogicalResult torchAlphaToTosaTensor(ConversionPatternRewriter &rewriter,
   // `alpha` has not been specified.
   int64_t alphaValue;
   if (!matchPattern(alphaScalar, m_TorchConstantInt(&alphaValue)))
-    return op->emitError("Currently only scalar constants are supported for "
-                         "alpha in TOSA operation");
+    return rewriter.notifyMatchFailure(
+        op, "Currently only scalar constants are supported for "
+            "alpha in TOSA operation");

   // When no alpha has been specified, this must be 1.
   if (checkForUnity && alphaValue != 1)
-    return op->emitError("Unsupported integer value for alpha");
+    return rewriter.notifyMatchFailure(op,
+                                       "Unsupported integer value for alpha");

   alphaTensor =
       mlir::tosa::getTosaConstTensorSingleF32(rewriter, op, alphaValue);
@@ -214,12 +224,13 @@ class ConvertAtenAddSubOp : public OpConversionPattern {
     auto rhsType = rhs.getType().dyn_cast<TensorType>();

     if (!lhsType)
-      return op.emitError("Only Tensor types supported in TOSA");
+      return rewriter.notifyMatchFailure(op,
+                                         "Only Tensor types supported in TOSA");

     if (auto lhsElemTy = lhsType.getElementType().dyn_cast<IntegerType>()) {
       if (lhsElemTy.getWidth() > 32)
-        return op.emitError(
-            "Integers with widths greater than 32 are not supported");
+        return rewriter.notifyMatchFailure(
+            op, "Integers with widths greater than 32 are not supported");
     }

     auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
@@ -228,16 +239,17 @@ class ConvertAtenAddSubOp : public OpConversionPattern {
     Type outElemTy = outType.getElementType();
     if (!outElemTy.isIntOrFloat()) {
-      return op.emitError(
-          "Only floating-point or integer datatype legalization supported");
+      return rewriter.notifyMatchFailure(
+          op, "Only floating-point or integer datatype legalization supported");
     }

     Value rhsAsTensor;
     if (!rhsType) {
       if (failed(torchScalarToTosaTensor(rewriter, op, op.other(),
                                          rhsAsTensor, outElemTy, {})))
-        return op.emitError("Currently only scalar constants are supported for "
-                            "conversion in TOSA operation");
+        return rewriter.notifyMatchFailure(
+            op, "Currently only scalar constants are supported for "
+                "conversion in TOSA operation");
     }
     auto rhsTensor = rhsType ? rhs : rhsAsTensor;
@@ -246,8 +258,9 @@ class ConvertAtenAddSubOp : public OpConversionPattern {
       if (failed(torchAlphaToTosaTensor(rewriter, op.getOperation(), op.alpha(),
                                         alphaTensor, outElemTy,
                                         /*checkForUnity=*/false))) {
-        return op.emitError("Currently only scalar constants are supported for "
-                            "alpha in conversion to TOSA operation");
+        return rewriter.notifyMatchFailure(
+            op, "Currently only scalar constants are supported for "
+                "alpha in conversion to TOSA operation");
       }

       auto multTensor = rewriter.create<tosa::MulOp>(
@@ -262,8 +275,8 @@ class ConvertAtenAddSubOp : public OpConversionPattern {
       return success();
     } else {
-      return op.emitError(
-          "Only floating-point datatype legalization supported");
+      return rewriter.notifyMatchFailure(
+          op, "Only floating-point datatype legalization supported");
     }
   }
 }; // namespace
@@ -283,26 +296,29 @@ class ConvertAtenCompareOp : public OpConversionPattern {
     auto rhsTy = rhs.getType().dyn_cast<TensorType>();

     if (!lhsTy)
-      return op.emitError("Only Tensor types supported in TOSA");
+      return rewriter.notifyMatchFailure(op,
+                                         "Only Tensor types supported in TOSA");

     auto lhsElemTy = lhsTy.getElementType();
     if (!lhsElemTy.isIntOrFloat())
-      return op.emitError(
-          "Only floating-point or integer datatype legalization supported");
+      return rewriter.notifyMatchFailure(
+          op, "Only floating-point or integer datatype legalization supported");

     // For bitwise operators, only integer datatype legalization is supported
     if (lhsElemTy.isa<mlir::FloatType>() &&
         std::is_same<AtenOpT, AtenBitwiseAndTensorOp>()) {
-      return op.emitError("For bitwise operators, only integer datatype "
-                          "legalization is supported");
+      return rewriter.notifyMatchFailure(op,
+                                         "For bitwise operators, only integer "
+                                         "datatype legalization is supported");
     }

     Value rhsAsTensor;
     if (!rhsTy) {
       if (failed(torchScalarToTosaTensor(rewriter, op, op.other(), rhsAsTensor,
                                          lhsElemTy, {})))
-        return op.emitError("Currently only scalar constants are supported for "
-                            "conversion in TOSA operation");
+        return rewriter.notifyMatchFailure(
+            op, "Currently only scalar constants are supported for "
+                "conversion in TOSA operation");
     }
     auto rhsTensor = rhsTy ? rhs : rhsAsTensor;
     // There is no Lesser operator in TOSA.
@@ -343,7 +359,8 @@ class ConvertAtenMulOp : public OpConversionPattern {
     auto lhsType = lhs.getType().dyn_cast<TensorType>();

     if (!lhsType)
-      return op.emitError("Only Tensor types supported in TOSA");
+      return rewriter.notifyMatchFailure(op,
+                                         "Only Tensor types supported in TOSA");

     auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
                        ->convertType(op.getType())
@@ -351,8 +368,8 @@ class ConvertAtenMulOp : public OpConversionPattern {
     Type outElemTy = outType.getElementType();
     if (!outElemTy.isIntOrFloat())
-      return op.emitError(
-          "Only floating-point or integer datatype legalization supported");
+      return rewriter.notifyMatchFailure(
+          op, "Only floating-point or integer datatype legalization supported");

     Value rhsTensor;
     if (std::is_same<AtenOpT, AtenSquareOp>()) {
@@ -363,10 +380,11 @@ class ConvertAtenMulOp : public OpConversionPattern {
       auto rhsType = rhs.getType().dyn_cast<TensorType>();
       if (!rhsType) {
         if (failed(torchScalarToTosaTensor(rewriter, op, op.other(),
-                                           rhsAsTensor, outElemTy, {})))
-          return op.emitError(
-              "Currently only scalar constants are supported for "
-              "conversion in TOSA operation");
+                                           rhsAsTensor, outElemTy, {}))) {
+          return rewriter.notifyMatchFailure(
+              op, "Currently only scalar constants are supported for "
+                  "conversion in TOSA operation");
+        }
       }
       rhsTensor = rhsType ? rhs : rhsAsTensor;
     }
@@ -383,11 +401,12 @@ class ConvertAtenMulOp : public OpConversionPattern {
                                                  lhs, rhsTensor,
                                                  /*shift=*/0);
       return success();
-    } else {
-      // Quantized multiplication may need to rescale inputs.
-      return op.emitError("Only floating-point or integer datatype "
-                          "legalization currently supported");
     }
+
+    // Quantized multiplication may need to rescale inputs.
+    return rewriter.notifyMatchFailure(
+        op, "Only floating-point or integer datatype "
+            "legalization currently supported");
   }
 };
@@ -405,19 +424,21 @@ class ConvertAtenDivOp : public OpConversionPattern {
     auto rhsTy = rhs.getType().dyn_cast<TensorType>();

     if (!lhsTy)
-      return op.emitError("Only Tensor types supported in TOSA");
+      return rewriter.notifyMatchFailure(op,
+                                         "Only Tensor types supported in TOSA");

     auto lhsElemTy = lhsTy.getElementType();
     if (!lhsElemTy.isIntOrFloat())
-      return op.emitError(
-          "Only floating-point or integer datatype legalization supported");
+      return rewriter.notifyMatchFailure(
+          op, "Only floating-point or integer datatype legalization supported");

     Value rhsAsTensor;
     if (!rhsTy) {
       if (failed(torchScalarToTosaTensor(rewriter, op, op.other(), rhsAsTensor,
                                          lhsElemTy, {})))
-        return op.emitError("Currently only scalar constants are supported for "
-                            "conversion in TOSA operation");
+        return rewriter.notifyMatchFailure(
+            op, "Currently only scalar constants are supported for "
+                "conversion in TOSA operation");
     }
     auto rhsTensor = rhsTy ? rhs : rhsAsTensor;
@@ -463,12 +484,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
     rewriter.replaceOpWithNewOp<tosa::TanhOp>(
         op, getTypeConverter()->convertType(op.getType()), self);
     return success();
-  } else {
-    // Sigmoid legalization in TOSA for quantized element-type uses
-    // specialized tosa.table construct.
-    return op.emitError(
-        "Only floating-point datatype legalization currently supported");
   }
+  // Sigmoid legalization in TOSA for quantized element-type uses specialized
+  // tosa.table construct.
+  return rewriter.notifyMatchFailure(
+      op, "Only floating-point datatype legalization currently supported");
 }

 template <>
@@ -481,12 +501,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
     rewriter.replaceOpWithNewOp<tosa::SigmoidOp>(
         op, getTypeConverter()->convertType(op.getType()), self);
     return success();
-  } else {
-    // Sigmoid legalization in TOSA for quantized element-type uses
-    // specialized tosa.table construct.
-    return op.emitError(
-        "Only floating-point datatype legalization currently supported");
   }
+  // Sigmoid legalization in TOSA for quantized element-type uses
+  // specialized tosa.table construct.
+  return rewriter.notifyMatchFailure(
+      op, "Only floating-point datatype legalization currently supported");
 }

 template <>
@@ -499,22 +518,23 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   // Maps to tosa.clamp which has both int and fp limits.
   int64_t clampMin = 0;
   Value clampIn = self;
-  if (selfTy) {
-    // Rescale the clampIn for quantized types. TBD
-    if (!selfTy.getElementType().isa<mlir::FloatType>()) {
-      return op.emitError(
-          "Only floating-point datatype legalization currently supported");
-    }
-    rewriter.replaceOpWithNewOp<tosa::ClampOp>(
-        op, getTypeConverter()->convertType(op.getType()), clampIn,
-        rewriter.getI64IntegerAttr(clampMin),
-        rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max()),
-        rewriter.getF32FloatAttr(0.0f),
-        rewriter.getF32FloatAttr(std::numeric_limits<float>::max()));
-    return success();
-  } else {
-    return op.emitError("Only Tensor types supported in TOSA");
+  if (!selfTy) {
+    return rewriter.notifyMatchFailure(op,
+                                       "Only Tensor types supported in TOSA");
   }
+
+  // Rescale the clampIn for quantized types. TBD
+  if (!selfTy.getElementType().isa<mlir::FloatType>()) {
+    return rewriter.notifyMatchFailure(
+        op, "Only floating-point datatype legalization currently supported");
+  }
+  rewriter.replaceOpWithNewOp<tosa::ClampOp>(
+      op, getTypeConverter()->convertType(op.getType()), clampIn,
+      rewriter.getI64IntegerAttr(clampMin),
+      rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max()),
+      rewriter.getF32FloatAttr(0.0f),
+      rewriter.getF32FloatAttr(std::numeric_limits<float>::max()));
+  return success();
 }

 using ReductionConvFunc = llvm::Optional<Value> (*)(PatternRewriter &,
@@ -547,14 +567,15 @@ class ConvertAtenReductionOp : public OpConversionPattern {
     auto selfTy = self.getType().cast<TensorType>();

     if (!selfTy)
-      return op.emitError("Only Tensor types supported in TOSA");
+      return rewriter.notifyMatchFailure(op,
+                                         "Only Tensor types supported in TOSA");

     auto outputTy = OpConversionPattern<AtenOpT>::getTypeConverter()
                         ->convertType(op.getType())
                         .template cast<RankedTensorType>();
     if (!outputTy)
-      return op.emitError(
-          "Only ranked tensor type outputs permitted for reduce_mean");
+      return rewriter.notifyMatchFailure(
+          op, "Only ranked tensor type outputs permitted for reduce_mean");

     ElementsAttr reduceDimsAttr;
     bool keepDims;
@@ -677,7 +698,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   auto selfTy = self.getType().template cast<RankedTensorType>();

   if (!selfTy)
-    return op.emitError("Only ranked tensor types supported in TOSA argmax");
+    return rewriter.notifyMatchFailure(
+        op, "Only ranked tensor types supported in TOSA argmax");

   int64_t reduceDim;
   if (!matchPattern(op.dim(), m_TorchConstantInt(&reduceDim))) {
@@ -729,11 +751,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   auto castToInt64 = [&](Value result) -> LogicalResult {
     auto resTy = result.getType().cast<ShapedType>();
     if (!resTy)
-      return op.emitError("Argmax: Result is not a shaped type");
+      return rewriter.notifyMatchFailure(op,
+                                         "Argmax: Result is not a shaped type");

     auto resShape = resTy.getShape();
-    auto outTy =
-        RankedTensorType::get(resShape, outputETy); // rewriter.getI64Type());
+    auto outTy = RankedTensorType::get(resShape, outputETy);

     rewriter.replaceOpWithNewOp<tosa::CastOp>(
         op, getTypeConverter()->convertType(outTy), result);
@@ -779,11 +801,13 @@ class ConvertAtenSqueezeOp : public OpConversionPattern {
     auto selfTy = self.getType().template cast<RankedTensorType>();

     if (!selfTy)
-      return op.emitError("Only ranked tensor types supported in TOSA argmax");
+      return rewriter.notifyMatchFailure(
+          op, "Only ranked tensor types supported in TOSA argmax");

     SmallVector<int64_t> newOutputShape;
     if (failed(generateSqueezedShape(op, selfTy, rewriter, newOutputShape)))
-      return op.emitError("Squeeze could not compute new shape");
+      return rewriter.notifyMatchFailure(op,
+                                         "Squeeze could not compute new shape");

     auto resultTy = OpConversionPattern<AtenOpT>::getTypeConverter()
                         ->convertType(op.getResult().getType())
@@ -871,17 +895,20 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   auto selfTy = self.getType().template cast<RankedTensorType>();
   if (!selfTy)
-    return op.emitError("Only ranked tensor types supported in TOSA Pow");
+    return rewriter.notifyMatchFailure(
+        op, "Only ranked tensor types supported in TOSA Pow");

   if (!selfTy.getElementType().isa<mlir::FloatType>())
-    return op.emitError("Only floating-point datatype legalization supported");
+    return rewriter.notifyMatchFailure(
+        op, "Only floating-point datatype legalization supported");

   Value expTensor;
   Value expScalar = op.exponent();
   if (failed(torchScalarToTosaTensor(rewriter, op, expScalar, expTensor,
                                      selfTy.getElementType(), {})))
-    return op.emitError("Currently only scalar constants are supported for "
-                        "conversion in TOSA Pow operation");
+    return rewriter.notifyMatchFailure(
+        op, "Currently only scalar constants are supported for "
+            "conversion in TOSA Pow operation");

   rewriter.replaceOpWithNewOp<tosa::PowOp>(
       op, getTypeConverter()->convertType(op.getType()), self, expTensor);
@@ -927,7 +954,8 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern {
     auto rhsElemTy = rhsTy.getElementType();

     if (lhsElemTy != rhsElemTy)
-      return op.emitError("Matmul: input datatypes mismatched");
+      return rewriter.notifyMatchFailure(op,
+                                         "Matmul: input datatypes mismatched");

     // Legalization constructs may offer input shapes but expect output shapes
     // to be inferred, e.g.
@@ -1451,12 +1479,13 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern {
     Value lhs, rhs;

     if (failed(readMatMulInputs(op, adaptor, rewriter, lhs, rhs)))
-      return op.emitError("Failed to read matmul inputs");
+      return rewriter.notifyMatchFailure(op, "Failed to read matmul inputs");

     Value output;

     if (failed(performMatmul(op, adaptor, rewriter, lhs, rhs, output)))
-      return op.emitError("Failed to perform matmul operation");
+      return rewriter.notifyMatchFailure(op,
+                                         "Failed to perform matmul operation");

     rewriter.replaceOpWithNewOp<tensor::CastOp>(
         op,
@@ -1485,7 +1514,8 @@ class ConvertAtenMatMulOp : public ConvertAtenMatmulBaseOp {
     auto rhsTy = rhs.getType().cast<RankedTensorType>();

     if (!lhsTy || !rhsTy)
-      return op.emitError("Only ranked tensor types supported in TOSA matmul");
+      return rewriter.notifyMatchFailure(
+          op, "Only ranked tensor types supported in TOSA matmul");

     return success();
   }
@@ -1508,7 +1538,8 @@ class ConvertAtenMmOp : public ConvertAtenMatmulBaseOp {
     auto rhsTy = rhs.getType().cast<RankedTensorType>();

     if (!lhsTy || !rhsTy)
-      return op.emitError("Only ranked tensor types supported in TOSA matmul");
+      return rewriter.notifyMatchFailure(
+          op, "Only ranked tensor types supported in TOSA matmul");

     auto lhsRank = lhsTy.getRank();
     auto rhsRank = rhsTy.getRank();
@@ -1544,7 +1575,8 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp {
     auto rhsTy = rhs.getType().cast<RankedTensorType>();

     if (!lhsTy || !rhsTy)
-      return op.emitError("Only ranked tensor types supported in TOSA matmul");
+      return rewriter.notifyMatchFailure(
+          op, "Only ranked tensor types supported in TOSA matmul");

     auto lhsRank = lhsTy.getRank();
     auto rhsRank = rhsTy.getRank();
@@ -1555,8 +1587,10 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp {
       return op.emitError("aten.Linear called but weight rank not 2 or 3");

     // Protection against crash due to unguarded code in TOSA->LinAlg.
+    // TODO: This should be handled in TOSA->LinAlg instead.
     if (!lhsTy.hasStaticShape() || !rhsTy.hasStaticShape())
-      return op.emitError("aten.Linear needs statically shaped input");
+      return rewriter.notifyMatchFailure(
+          op, "aten.Linear needs statically shaped input");

     return success();
   }
@@ -1569,7 +1603,7 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp {
     Value lhs, rhs;

     if (failed(readMatMulInputs(op, adaptor, rewriter, lhs, rhs)))
-      return op.emitError("Failed to read matmul inputs");
+      return rewriter.notifyMatchFailure(op, "Failed to read matmul inputs");

     // The aten.Linear op has a bias tensor that is added to the matmul output.
     auto bias = adaptor.bias();
@@ -1578,8 +1612,8 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp {
     // TOSA does not mandate that elementwise op tensors need to be ranked.
     if (!biasTy.template isa<Torch::NoneType>() &&
         !biasTy.template isa<TensorType>())
-      return op.emitError("Only tensor types supported in GEMM to "
-                          "TOSA for bias tensor");
+      return rewriter.notifyMatchFailure(
+          op, "Only tensor types supported in GEMM to TOSA for bias tensor");

     // RHS must have its last two dims transposed prior to matrix
     // multiplication.
@@ -1617,7 +1651,8 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp {
     Value matmulOutput;
     if (failed(
             this->performMatmul(op, adaptor, rewriter, lhs, rhs, matmulOutput)))
-      return op.emitError("Failed to perform matmul operation");
+      return rewriter.notifyMatchFailure(op,
+                                         "Failed to perform matmul operation");

     Value matmulPlusBias = matmulOutput;
     if (!biasTy.template isa<Torch::NoneType>()) {
@@ -1651,17 +1686,20 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   auto selfTy = self.getType().template cast<RankedTensorType>();

   if (!selfTy)
-    return op.emitError("Only ranked tensor types supported in TOSA Rsub");
+    return rewriter.notifyMatchFailure(
+        op, "Only ranked tensor types supported in TOSA Rsub");

   if (!selfTy.getElementType().isa<mlir::FloatType>())
-    return op.emitError("Only floating-point datatype legalization supported");
+    return rewriter.notifyMatchFailure(
+        op, "Only floating-point datatype legalization supported");

   Value otherTensor, alphaTensor;

   if (failed(torchScalarToTosaTensor(rewriter, op, otherScalar, otherTensor,
                                      selfTy.getElementType(), {})))
-    return op.emitError("Currently only scalar constants are supported for "
-                        "conversion in TOSA Rsub operation");
+    return rewriter.notifyMatchFailure(
+        op, "Currently only scalar constants are supported for "
+            "conversion in TOSA Rsub operation");

   if (failed(torchAlphaToTosaTensor(rewriter, op.getOperation(), alphaScalar,
                                     alphaTensor, selfTy.getElementType(),
@@ -1694,8 +1732,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
                       .template cast<RankedTensorType>();

   if (!inputTy || !weightTy || !outputTy)
-    return op.emitError(
-        "Input, weight and output to Convolution must be ranked tensors");
+    return rewriter.notifyMatchFailure(
+        op, "Input, weight and output to Convolution must be ranked tensors");

   auto inputElemTy = inputTy.getElementType();
   auto weightElemTy = weightTy.getElementType();
@@ -1703,10 +1741,12 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   auto inputShape = inputTy.getShape();
   auto weightShape = weightTy.getShape();

   if (inputTy.getRank() != 4)
-    return op.emitError("Unimplemented: only 2D convolutions supported");
+    return rewriter.notifyMatchFailure(
+        op, "Unimplemented: only 2D convolutions supported");

   if (!weightTy.hasStaticShape())
-    return op.emitError("Unimplemented: TOSA only supports static weight");
+    return rewriter.notifyMatchFailure(
+        op, "Unimplemented: TOSA only supports static weight");

   // Bias is optional. TOSA mandates a zero tensor here, so construct one if
   // required.
@@ -1728,7 +1768,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
     }
   } else {
     if (!bias.getType().cast<RankedTensorType>())
-      return op.emitError("Bias provided but not a ranked tensor");
+      return rewriter.notifyMatchFailure(
+          op, "Bias provided but not a ranked tensor");
   }

   auto biasElemTy = inputElemTy.template isa<mlir::FloatType>() ? inputElemTy
@@ -1850,19 +1891,22 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   auto selfTy = self.getType().template cast<RankedTensorType>();

   if (!selfTy)
-    return op.emitError("Only ranked tensor types supported in TOSA Reshape");
+    return rewriter.notifyMatchFailure(
+        op, "Only ranked tensor types supported in TOSA Reshape");

   // Check that at most one dimension is -1
   SmallVector<int64_t> newShape;
   if (!matchPattern(op.shape(), m_TorchConstantIntList(newShape)))
-    return op.emitError("Only constant shape supported in TOSA Reshape");
+    return rewriter.notifyMatchFailure(
+        op, "Only constant shape supported in TOSA Reshape");

   int auto_sz = 0;
   for (auto s : newShape)
     auto_sz += (s == -1 ? 1 : 0);
   if (auto_sz > 1)
-    op.emitError("At most one dimension may be specified as -1 to "
-                 "automatically calculate its size");
+    return rewriter.notifyMatchFailure(
+        op, "At most one dimension may be specified as -1 to "
+            "automatically calculate its size");

   auto newType = RankedTensorType::get(newShape, selfTy.getElementType());
@@ -1929,7 +1973,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite(

   // Not a ranked tensor output
   if (!adaptor.input().getType().dyn_cast<RankedTensorType>())
-    return op.emitError("Only ranked tensor types are supported");
+    return rewriter.notifyMatchFailure(
+        op, "Only ranked tensor types are supported");

   auto outType = getTypeConverter()->convertType(op.getType());
@@ -1937,12 +1982,13 @@
   // FIXME: Handle training and momentum.
   if (op.momentum().getType().isa<Torch::NoneType>())
-    op.emitError("Unsupported None for momentum");
+    return rewriter.notifyMatchFailure(op, "Unsupported None for momentum");

   auto meanType = adaptor.running_mean().getType().dyn_cast<TensorType>();
   auto varianceType = adaptor.running_var().getType().dyn_cast<TensorType>();
   if (!varianceType || !meanType)
-    return op.emitError("Only ranked tensor types are supported");
+    return rewriter.notifyMatchFailure(
+        op, "Only ranked tensor types are supported");

   // Normalization ops perform elementwise ops of a single mean/stdev value
   // against the feature map and because input is NCHW, the rank-1 value must be
@@ -1954,7 +2000,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
     RankedTensorType toBcastType =
         toBcast.getType().dyn_cast<RankedTensorType>();
     if (toBcastType.getRank() > 1)
-      op->emitError("Rank cannot be more than 1");
+      return rewriter.notifyMatchFailure(op, "Rank cannot be more than 1");

     RankedTensorType outTensorType = outType.cast<RankedTensorType>();
     SmallVector<int64_t> newShape = {toBcastType.getShape()[0]};
@@ -1974,26 +2020,27 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   if (failed(reshapeToNormInputDim(op.getOperation(), rewriter,
                                    getTypeConverter(), outType,
                                    adaptor.running_mean(), meanVal)))
-    op.emitError("Failed to reshape running mean");
+    return rewriter.notifyMatchFailure(op, "Failed to reshape running mean");

   if (failed(reshapeToNormInputDim(op.getOperation(), rewriter,
                                    getTypeConverter(), outType,
                                    adaptor.running_var(), varianceVal)))
-    op.emitError("Failed to reshape running variance");
+    return rewriter.notifyMatchFailure(op,
+                                       "Failed to reshape running variance");

   if (failed(reshapeToNormInputDim(op.getOperation(), rewriter,
                                    getTypeConverter(), outType,
                                    adaptor.weight(), weightVal)))
-    op.emitError("Failed to reshape weight");
+    return rewriter.notifyMatchFailure(op, "Failed to reshape weight");

   if (failed(reshapeToNormInputDim(op.getOperation(), rewriter,
                                    getTypeConverter(), outType,
                                    adaptor.bias(), biasVal)))
-    op.emitError("Failed to reshape bias");
+    return rewriter.notifyMatchFailure(op, "Failed to reshape bias");

   double eps;
   if (!matchPattern(op.eps(), m_TorchConstantFloat(&eps)))
-    return op.emitError("eps must be a scalar constant");
+    return rewriter.notifyMatchFailure(op, "eps must be a scalar constant");

   auto epsilonConst =
       mlir::tosa::getTosaConstTensorSingleF32(rewriter, op, eps);
@@ -2021,11 +2068,13 @@ LogicalResult ConvertAtenOp::matchAndRewrite(

   // Not a ranked tensor output
   if (!adaptor.input().getType().dyn_cast<RankedTensorType>())
-    return op.emitError("Only ranked tensor types are supported");
+    return rewriter.notifyMatchFailure(
+        op, "Only ranked tensor types are supported");

   auto inputType = adaptor.input().getType().cast<RankedTensorType>();
   if (inputType.getRank() > 4)
-    return op.emitError("Only up to 4D tensors are supported");
+    return rewriter.notifyMatchFailure(op,
+                                       "Only up to 4D tensors are supported");

   auto outType = getTypeConverter()->convertType(op.getType(0));
@@ -2033,9 +2082,9 @@
   // FIXME: Handle the None cases for the optional parameters.
   if (adaptor.weight().getType().isa<Torch::NoneType>())
-    return op.emitError("Unsupported None for weight");
+    return rewriter.notifyMatchFailure(op, "Unsupported None for weight");
   if (adaptor.bias().getType().isa<Torch::NoneType>())
-    return op.emitError("Unsupported None for bias");
+    return rewriter.notifyMatchFailure(op, "Unsupported None for bias");

   auto weightType = adaptor.weight().getType().cast<RankedTensorType>();
   auto biasType = adaptor.bias().getType().cast<RankedTensorType>();
@@ -2065,7 +2114,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
     if (inputType.getShape()[index + meanAndVarShapeRank] != value ||
         weightType.getShape()[index] != value ||
         biasType.getShape()[index] != value)
-      return op.emitError("mismatching contracting dimension");
+      return rewriter.notifyMatchFailure(op,
+                                         "mismatching contracting dimension");
   }

   // Helper for computing mean and variance.
@@ -2146,7 +2196,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   double eps;
   if (!matchPattern(op.eps(), m_TorchConstantFloat(&eps)))
-    return op.emitError("eps must be a scalar constant");
+    return rewriter.notifyMatchFailure(op, "eps must be a scalar constant");

   auto epsilonConst =
       mlir::tosa::getTosaConstTensorSingleF32(rewriter, op, eps);
@@ -2181,7 +2231,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   // Not a ranked tensor type
   auto selfType = adaptor.self().getType().dyn_cast<RankedTensorType>();
   if (!selfType || !selfType.hasStaticShape())
-    return op.emitError(
+    return rewriter.notifyMatchFailure(
+        op,
         "Only ranked tensor types with static shapes are currently supported");

   int64_t selfRank = selfType.getRank();
@@ -2189,19 +2240,21 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   int64_t start_dim, end_dim;

   if (!matchPattern(op.start_dim(), m_TorchConstantInt(&start_dim)))
-    return op.emitError("start_dim must be a Scalar constant");
+    return rewriter.notifyMatchFailure(op,
+                                       "start_dim must be a Scalar constant");
   start_dim = toPositiveDim(start_dim, selfRank);

   if (!matchPattern(op.end_dim(), m_TorchConstantInt(&end_dim)))
-    return op.emitError("end_dim must be a Scalar constant");
+    return rewriter.notifyMatchFailure(op, "end_dim must be a Scalar constant");
   end_dim = toPositiveDim(end_dim, selfRank);

   if (selfRank > 0 && !isValidDim(start_dim, selfRank))
-    return op.emitError("start_dim is statically invalid");
+    return rewriter.notifyMatchFailure(op, "start_dim is statically invalid");
   if (selfRank > 0 && !isValidDim(end_dim, selfRank))
-    return op.emitError("end_dim is statically invalid");
+    return rewriter.notifyMatchFailure(op, "end_dim is statically invalid");
   if (end_dim < start_dim)
-    return op.emitError("end_dim must be larger than start_dim");
+    return rewriter.notifyMatchFailure(op,
+                                       "end_dim must be larger than start_dim");

   SmallVector<int64_t> newShape;
   for (auto s : llvm::enumerate(selfType.getShape())) {
@@ -2238,7 +2291,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   // Not a ranked tensor type
   auto selfType = adaptor.self().getType().dyn_cast<RankedTensorType>();
   if (!selfType)
-    return op.emitError(
+    return rewriter.notifyMatchFailure(
+        op,
         "Only ranked tensor types with static shapes are currently supported");

   SmallVector<int64_t> dimListInt;
   if (!matchPattern(adaptor.dims(), m_TorchConstantIntList(dimListInt)))
     return rewriter.notifyMatchFailure(
         op, "Only constant dimensions are currently supported");

   int64_t selfRank = selfType.getRank();
+  // TODO: If this is already verified on the op then we can drop checking here.
   for (auto &d : dimListInt) {
     d = toPositiveDim(d, selfRank);
     if (!isValidDim(d, selfRank))
-      return op.emitError("Not all dims are valid");
+      return rewriter.notifyMatchFailure(op, "Not all dims are valid");
   }

   auto transposeDimsConst = mlir::tosa::getConstTensor<int32_t>(
@@ -2271,7 +2326,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   // Not a tensor type.
   auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
   if (!selfType)
-    return op.emitError("Only tensor types are currently supported");
+    return rewriter.notifyMatchFailure(
+        op, "Only tensor types are currently supported");

   // Constant value of ln2.
   SmallVector<int64_t> ln2Shape(selfType.getRank(), 1);
@@ -2298,29 +2354,32 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   // Not a tensor type.
   auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
   if (!selfType)
-    return op.emitError("Only tensor types are currently supported");
+    return rewriter.notifyMatchFailure(
+        op, "Only tensor types are currently supported");

   auto selfElemTy = selfType.getElementType();
   if (!selfElemTy.isIntOrFloat())
-    return op.emitError(
-        "Only floating-point or integer datatype legalization supported");
+    return rewriter.notifyMatchFailure(
+        op, "Only floating-point or integer datatype legalization supported");

   // Integer types with width > 32 are not supported
   auto selfIntType = selfElemTy.dyn_cast<IntegerType>();
   if (selfIntType && selfIntType.getWidth() > 32) {
-    return op.emitError(
-        "Integer types with width greater than 32 are not supported");
+    return rewriter.notifyMatchFailure(
+        op, "Integer types with width greater than 32 are not supported");
   }

   SmallVector<int64_t> constTypeShape(selfType.getRank(), 1);
   Value threshold, value;

   if (failed(torchScalarToTosaTensor(rewriter, op, op.threshold(), threshold,
                                      selfElemTy, constTypeShape)))
-    return op.emitError("Only scalar constant is supported for threshold");
+    return rewriter.notifyMatchFailure(
+        op, "Only scalar constant is supported for threshold");

   if (failed(torchScalarToTosaTensor(rewriter, op, op.value(), value,
                                      selfElemTy, constTypeShape)))
-    return op.emitError("Only scalar constant is supported for value");
+    return rewriter.notifyMatchFailure(
+        op, "Only scalar constant is supported for value");

   // Threshold only clamps the upper values. tosa::ClampOp has the same
   // value for both threshold and clamped value so cannot be used.
@@ -2345,23 +2404,24 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   // Not a tensor type.
   auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
   if (!selfType) {
-    return op.emitError("Only tensor types are currently supported");
+    return rewriter.notifyMatchFailure(
+        op, "Only tensor types are currently supported");
   }

   auto selfRank = selfType.getRank();
   auto selfElemTy = selfType.getElementType();
   if (!selfElemTy.isIntOrFloat()) {
-    return op.emitError(
-        "Only floating-point or integer datatype legalization supported");
+    return rewriter.notifyMatchFailure(
+        op, "Only floating-point or integer datatype legalization supported");
   }

   int64_t dim;
   if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
-    return op->emitError("dim must be a Scalar constant");
+    return rewriter.notifyMatchFailure(op, "dim must be a Scalar constant");

   dim = toPositiveDim(dim, selfRank);
   if (!isValidDim(dim, selfRank))
-    return op.emitError("dim is statically invalid");
+    return rewriter.notifyMatchFailure(op, "dim is statically invalid");

   SmallVector<int64_t> outShape;
   for (auto en : llvm::enumerate(selfType.getShape())) {
@@ -2386,7 +2446,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   // Not a tensor type.
   auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
   if (!selfType)
-    return op.emitError("Only tensor types are currently supported");
+    return rewriter.notifyMatchFailure(
+        op, "Only tensor types are currently supported");

   // FIXME: memory_format is not handled.
@@ -2403,16 +2464,17 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   // Not a tensor type.
   auto selfType = adaptor.input().getType().dyn_cast<TensorType>();
   if (!selfType)
-    return op.emitError("Only tensor types are currently supported");
+    return rewriter.notifyMatchFailure(
+        op, "Only tensor types are currently supported");

   // FIXME: train and p are not handled.

   bool train;
   if (!matchPattern(op.train(), m_TorchConstantBool(&train)))
-    op.emitError("train must be a Scalar constant");
+    return rewriter.notifyMatchFailure(op, "train must be a Scalar constant");

   if (train)
-    op.emitError("train must be false");
+    return rewriter.notifyMatchFailure(op, "train must be false");

   rewriter.replaceOpWithNewOp<tosa::CastOp>(
       op, getTypeConverter()->convertType(op.getType()), adaptor.input());
@@ -2428,17 +2490,19 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   // Not a tensor type.
   auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
   if (!selfType)
-    return op.emitError("Only tensor types are currently supported");
+    return rewriter.notifyMatchFailure(
+        op, "Only tensor types are currently supported");

   auto selfElemTy = selfType.getElementType();
   if (!selfElemTy.isIntOrFloat()) {
-    return op.emitError(
-        "Only floating-point or integer datatype legalization supported");
+    return rewriter.notifyMatchFailure(
+        op, "Only floating-point or integer datatype legalization supported");
   }

   SmallVector<int64_t> outShape;
   if (!matchPattern(op.size(), m_TorchConstantIntList(outShape)))
-    return op.emitError("size must consist of Scalar constants");
+    return rewriter.notifyMatchFailure(op,
+                                       "size must consist of Scalar constants");

   rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
       op, getTypeConverter()->convertType(op.getType()), adaptor.self(),
@@ -2530,18 +2594,20 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   // Not a tensor type.
   auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
   if (!selfType)
-    return op.emitError("Only tensor types are currently supported");
+    return rewriter.notifyMatchFailure(
+        op, "Only tensor types are currently supported");

   auto selfElemTy = selfType.getElementType();
   if (!selfElemTy.isa<mlir::FloatType>()) {
-    return op.emitError("Only floating-point datatype legalization supported");
+    return rewriter.notifyMatchFailure(
+        op, "Only floating-point datatype legalization supported");
   }

   // TODO: Handle approximate.
   std::string approximate;
   if (!matchPattern(op.approximate(), m_TorchConstantStr(approximate)) ||
       approximate != "none") {
-    return op.emitError("Unsupported value of approximate");
+    return rewriter.notifyMatchFailure(op, "Unsupported value of approximate");
   }

   Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.self());
@@ -2561,18 +2627,20 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   // Not a tensor type.
   auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
   if (!selfType)
-    return op.emitError("Only tensor types are currently supported");
+    return rewriter.notifyMatchFailure(
+        op, "Only tensor types are currently supported");

   auto selfElemTy = selfType.getElementType();
   if (!selfElemTy.isa<mlir::FloatType>()) {
-    return op.emitError("Only floating-point datatype legalization supported");
+    return rewriter.notifyMatchFailure(
+        op, "Only floating-point datatype legalization supported");
   }

   // TODO: Handle approximate.
   std::string approximate;
   if (!matchPattern(op.approximate(), m_TorchConstantStr(approximate)) ||
       approximate != "none") {
-    return op.emitError("Unsupported value of approximate");
+    return rewriter.notifyMatchFailure(op, "Unsupported value of approximate");
   }

   auto loc = op->getLoc();
@@ -2620,10 +2688,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   auto indicesType = indices.getType().dyn_cast<RankedTensorType>();

   if (!indicesType || !indicesType.getElementType().isa<IntegerType>())
-    return op.emitError("Indices must be of integer tensor type");
+    return rewriter.notifyMatchFailure(
+        op, "Indices must be of integer tensor type");

   if (indicesType.getRank() != 2)
-    return op.emitError("indices must be of rank 2");
+    return rewriter.notifyMatchFailure(op, "indices must be of rank 2");

   auto weightType = weight.getType().cast<RankedTensorType>();
   if (weightType.getRank() != 2)
@@ -2711,22 +2780,23 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
   if (!selfType)
-    return op.emitError("Only tensor types are supported");
+    return rewriter.notifyMatchFailure(op, "Only tensor types are supported");

   // Only statically resolvable values are currently supported
   int64_t dim0, dim1;
   if (!matchPattern(op.dim0(), m_TorchConstantInt(&dim0)))
-    return op->emitError("dim0 must be a Scalar constant");
+    return rewriter.notifyMatchFailure(op, "dim0 must be a Scalar constant");

   if (!matchPattern(op.dim1(), m_TorchConstantInt(&dim1)))
-    return op->emitError("dim1 must be a Scalar constant");
+    return rewriter.notifyMatchFailure(op, "dim1 must be a Scalar constant");

   dim0 = toPositiveDim(dim0, selfType.getRank());
   dim1 = toPositiveDim(dim1, selfType.getRank());

   auto selfRank = selfType.getRank();
   if (!isValidDim(dim0, selfRank) || !isValidDim(dim1, selfRank))
-    return op->emitError("dim0 and dim1 must be less than tensor rank");
+    return rewriter.notifyMatchFailure(
+        op, "dim0 and dim1 must be less than tensor rank");

   SmallVector<int32_t> transposeDims;
   for (auto i = 0; i < selfType.getRank(); ++i)
@@ -2752,12 +2822,12 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
   if (!selfType)
-    return op.emitError("Only tensor types are supported");
+    return rewriter.notifyMatchFailure(op, "Only tensor types are supported");

   auto indicesType =
       getTypeConverter()->convertType(op.getType(1)).dyn_cast<TensorType>();
   if (!indicesType)
-    return op.emitError("Only tensor types are supported");
+    return rewriter.notifyMatchFailure(op, "Only tensor types are supported");

   auto selfElemType = selfType.getElementType();
   auto indicesElemType = indicesType.getElementType();
@@ -2765,16 +2835,16 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   // Only statically deducible values are currently supported
   int64_t dim;
   if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
-    return op->emitError("dim must be a Scalar constant");
+    return rewriter.notifyMatchFailure(op, "dim must be a Scalar constant");

   dim = toPositiveDim(dim, selfType.getRank());

   if (!isValidDim(dim, selfType.getRank()))
-    return op->emitError("dim must be less than tensor rank");
+    return rewriter.notifyMatchFailure(op, "dim must be less than tensor rank");

   bool keepDim;
   if (!matchPattern(op.keepdim(), m_TorchConstantBool(&keepDim)))
-    return op->emitError("keepdim must be a Scalar constant");
+    return rewriter.notifyMatchFailure(op, "keepdim must be a Scalar constant");

   SmallVector<int64_t> reducedShape, prunedShape;
   for (auto en : llvm::enumerate(selfType.getShape())) {
@@ -2820,39 +2890,42 @@
   auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
   if (!selfType || !selfType.hasStaticShape())
-    return op.emitError("Only tensor types with static shape are supported");
+    return rewriter.notifyMatchFailure(
+        op, "Only tensor types with static shape are supported");

   // Only statically deducible values are currently supported
   int64_t dim;
   if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
-    return op->emitError("dim must be a Scalar constant");
+    return rewriter.notifyMatchFailure(op, "dim must be a Scalar constant");

   dim = toPositiveDim(dim, selfType.getRank());
   if (!isValidDim(dim, selfType.getRank()))
-    return op->emitError("dim must less than tensor rank");
+    return rewriter.notifyMatchFailure(op, "dim must be less than tensor rank");

   int64_t start;
   if (!matchPattern(op.start(), m_TorchConstantInt(&start)))
-    return op->emitError("start must be a Scalar constant");
+    return rewriter.notifyMatchFailure(op, "start must be a Scalar constant");

   if (start < 0)
-    return op->emitError("Currently unsupported: start < 0");
+    return rewriter.notifyMatchFailure(op, "Currently unsupported: start < 0");

   int64_t end;
   if (!matchPattern(op.end(), m_TorchConstantInt(&end)))
-    return op->emitError("end must be a Scalar constant");
+    return rewriter.notifyMatchFailure(op, "end must be a Scalar constant");

   // FIXME: add support for start/end < 0 and end < start
   if (end < start)
-    return op->emitError("Currently unsupported: end < start");
+    return rewriter.notifyMatchFailure(op,
+                                       "Currently unsupported: end < start");

   int64_t step;
   if (!matchPattern(op.step(), m_TorchConstantInt(&step)))
-    return op->emitError("step must be a Scalar constant");
+    return rewriter.notifyMatchFailure(op, "step must be a Scalar constant");

   if (step != 1)
-    return op->emitError("step value other than 1 is currently unsupported");
+    return rewriter.notifyMatchFailure(
+        op, "step value other than 1 is currently unsupported");

   SmallVector<int64_t> startSlice(selfType.getRank(), 0);
   SmallVector<int64_t> sizeSlice = llvm::to_vector(selfType.getShape());
@@ -2962,7 +3035,8 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern {
     // case of adaptive pooling. Also performs input CHW->HWC transpose.
    if (failed(processInputs(op, adaptor, rewriter, input, kernel, stride, pad,
                             outputTy)))
-      return op.emitError("Failed to process inputs for pooling");
+      return rewriter.notifyMatchFailure(
+          op, "Failed to process inputs for pooling");

     auto pooledOutput =
         rewriter
             .create<TosaOpT>(
@@ -2996,7 +3070,8 @@ class ConvertAtenAdaptivePoolingOp
     auto inputXchw = adaptor.self();
     auto inputTy = inputXchw.getType().template cast<RankedTensorType>();
     if (!inputTy)
-      return op.emitError("Adaptive avgpool requires ranked tensor input");
+      return rewriter.notifyMatchFailure(
+          op, "Adaptive avgpool requires ranked tensor input");

     auto inputShape = inputTy.getShape();
     auto inputRank = inputTy.getRank();
@@ -3004,7 +3079,8 @@ class ConvertAtenAdaptivePoolingOp

     // Rank sanity check.
     if (inputTy.getRank() != 4 && inputRank != 3)
-      return op.emitError("NCHW->NHWC transpose requires 3D or 4D tensor");
+      return rewriter.notifyMatchFailure(
+          op, "NCHW->NHWC transpose requires 3D or 4D tensor");

     int64_t inputHDim = inputShape[inputRank - 2];
     int64_t inputWDim = inputShape[inputRank - 1];
@@ -3020,8 +3096,8 @@ class ConvertAtenAdaptivePoolingOp
       outputHDim = outputWDim = outputSize[0];
     } else {
       if (outputSize.size() != 2)
-        return op.emitError(
-            "Adaptive avgpool output_size not 1 or 2 elements.");
+        return rewriter.notifyMatchFailure(
+            op, "Adaptive avgpool output_size not 1 or 2 elements.");

       // Assumes 'None' (e.g. output_size=(None, 5) ) is expressed as <=0.
       outputHDim =
@@ -3096,12 +3172,14 @@ static LogicalResult getOutputTypeAndPoolingParameters(
   RankedTensorType inputTy = inputXchw.getType().cast<RankedTensorType>();
   if (!inputTy)
-    return op.emitError("Pooling op requires ranked tensor input");
+    return rewriter.notifyMatchFailure(
+        op, "Pooling op requires ranked tensor input");

   auto inputRank = inputTy.getRank();
   // Rank sanity check.
   if (inputTy.getRank() != 4 && inputRank != 3)
-    return op.emitError("NCHW->NHWC transpose requires 3D or 4D tensor");
+    return rewriter.notifyMatchFailure(
+        op, "NCHW->NHWC transpose requires 3D or 4D tensor");

   SmallVector<int64_t> kernelSizeInts, strideInts, paddingInts;
   if (!matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSizeInts)))
@@ -3149,7 +3227,8 @@ class ConvertAtenMaxPool2dOp
             op, "Non-const dilation for pooling op unsupported.");
     // TOSA pooling only supports unit dilation.
     if (dilationArray[0] > 1 || dilationArray[1] > 1)
-      return op.emitError("Cannot process non-unit pooling dilation.");
+      return rewriter.notifyMatchFailure(
+          op, "Cannot process non-unit pooling dilation.");

     if (failed(getOutputTypeAndPoolingParameters(
@@ -3206,29 +3285,32 @@ class ConvertAtenConstPatternOp : public OpConversionPattern {
             .template dyn_cast<TensorType>();

     if (!outType)
-      return op.emitError("Only Tensor types supported in TOSA");
+      return rewriter.notifyMatchFailure(op,
+                                         "Only Tensor types supported in TOSA");

     Type outElemTy = outType.getElementType();
     if (!outElemTy.isIntOrFloat())
-      return op.emitError(
-          "Only floating-point or integer datatype legalization supported");
+      return rewriter.notifyMatchFailure(
+          op, "Only floating-point or integer datatype legalization supported");

     // FIXME: Handle layout, device and pin_memory. Assume dtype has been
     // processed to set output type correctly?
     if (!op.layout().getType().template isa<Torch::NoneType>())
-      return op.emitError("Only default layout is supported");
+      return rewriter.notifyMatchFailure(op,
+                                         "Only default layout is supported");

     bool pinMemory;
     if (!op.pin_memory().getType().template isa<Torch::NoneType>() &&
         (!matchPattern(op.pin_memory(), m_TorchConstantBool(&pinMemory)) ||
          pinMemory)) {
-      return op.emitError(
-          "Unsupported pin_memory, should be either None or false");
+      return rewriter.notifyMatchFailure(
+          op, "Unsupported pin_memory, should be either None or false");
     }

     SmallVector<int64_t> shape;
     if (!matchPattern(op.size(), m_TorchConstantIntList(shape))) {
-      return op.emitError("Shape must be a list of Scalar constants");
+      return rewriter.notifyMatchFailure(
+          op, "Shape must be a list of Scalar constants");
     }

     int64_t size = 1;
@@ -3258,18 +3340,19 @@ class ConvertAtenFillScalarOp : public OpConversionPattern {
             .template dyn_cast<TensorType>();

     if (!outType || !outType.hasStaticShape())
-      return op.emitError(
-          "Only Tensor types with static shapes are currently supported");
+      return rewriter.notifyMatchFailure(
+          op, "Only Tensor types with static shapes are currently supported");

     Type outElemTy = outType.getElementType();
     if (!outElemTy.isIntOrFloat()) {
-      return op.emitError(
-          "Only floating-point or integer datatype legalization supported");
+      return rewriter.notifyMatchFailure(
+          op, "Only floating-point or integer datatype legalization supported");
     }

     Value constOp;
     if (failed(torchScalarToTosaTensor(rewriter, op, op.value(), constOp,
                                        outElemTy, outType.getShape())))
-      return op.emitError("Supplied value must be a Scalar constant");
+      return rewriter.notifyMatchFailure(
+          op, "Supplied value must be a Scalar constant");

     rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outType, constOp);
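
The mechanical change applied throughout this patch is the swap from op.emitError(...) to rewriter.notifyMatchFailure(op, ...). The distinction matters inside MLIR's dialect conversion framework: emitError prints a user-visible diagnostic even though returning failure from matchAndRewrite only means this particular pattern did not apply, whereas notifyMatchFailure merely records the reason, surfacing it under -debug or when legalization ultimately fails. Below is a minimal sketch (not part of the patch) of the two forms used above; MyAtenOp and MyTosaOp are hypothetical placeholder ops, not ops from this file.

// Minimal sketch, assuming hypothetical MyAtenOp/MyTosaOp definitions exist.
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

class SketchPattern : public OpConversionPattern<MyAtenOp> {
public:
  using OpConversionPattern<MyAtenOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(MyAtenOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto ty = adaptor.self().getType().dyn_cast<RankedTensorType>();

    // Plain-string form: the reason is recorded, not printed. Another
    // pattern (or a later pass) may still legalize this op, so no
    // user-facing error is emitted here.
    if (!ty)
      return rewriter.notifyMatchFailure(op, "expected a ranked tensor");

    // Callback form: the Diagnostic is built lazily, so streaming types or
    // attributes into the message costs nothing unless the reason is
    // actually consumed (e.g. under -debug).
    if (!ty.getElementType().isa<mlir::FloatType>())
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "unsupported element type: " << ty.getElementType();
      });

    rewriter.replaceOpWithNewOp<MyTosaOp>(
        op, getTypeConverter()->convertType(op.getType()), adaptor.self());
    return success();
  }
};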