Skip to content

Commit

Permalink
mlir: fix replacement of OpaqueElementsAttr (#1274)
Browse files Browse the repository at this point in the history
An earlier patch (bb47c16) incorrectly replaced the now-dropped
`OpaqueElementsAttr` with `SparseElementsAttr` in one place and with
`DenseElementsAttr` in another.  This patch fixes the problem by making
both replacements use the dense-equivalent type.
  • Loading branch information
ashay authored Aug 24, 2022
1 parent e2f862c commit 1d9d925
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions lib/Conversion/TorchToArith/TorchToArith.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
Expand Down Expand Up @@ -178,15 +179,17 @@ class ConvertTorchTensorLiteralOp
}));
return success();
}
if (auto elements = op.valueAttr().dyn_cast<SparseElementsAttr>()) {
if (auto elements = op.valueAttr().dyn_cast<DenseResourceElementsAttr>()) {
if (auto type = elements.getType().dyn_cast<RankedTensorType>()) {
if (auto intType = type.getElementType().dyn_cast<IntegerType>()) {
Type builtinTensorElemTy =
IntegerType::get(context, intType.getIntOrFloatBitWidth());
auto shapedType =
RankedTensorType::get(type.getShape(), builtinTensorElemTy);
AsmResourceBlob *blob = elements.getRawHandle().getBlob();
assert(blob && "Expecting dense resource with a valid blob");
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, DenseElementsAttr::get(shapedType, elements.getValues()));
op, DenseElementsAttr::get(shapedType, blob->getData()));
return success();
}
}
Expand Down

0 comments on commit 1d9d925

Please sign in to comment.