Commit e3ed189

[Mhlo] fixes for code review
Signed-off-by: chongsong.chen <chongsong.chen@bytedance.com>
chenchongsong committed Jul 15, 2022
1 parent 2cedfc1 commit e3ed189
Showing 4 changed files with 20 additions and 15 deletions.
13 changes: 5 additions & 8 deletions src/Conversion/ONNXToMhlo/Tensor/Concat.cpp
@@ -13,6 +13,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "src/Conversion/ONNXToMhlo/ONNXToMhloCommon.hpp"
+#include "src/Support/TypeUtilities.hpp"
 
 using namespace mlir;
 
@@ -31,17 +32,13 @@ struct ONNXConcatOpLoweringToMhlo : public ConversionPattern {
     ONNXConcatOpAdaptor operandAdaptor(operands);
     ONNXConcatOp concatOp = llvm::cast<ONNXConcatOp>(op);
 
-    if (op->getNumResults() < 1) {
-      op->emitError() << "ONNXConcatOp Has No Output\n";
-      return failure();
-    }
-    RankedTensorType resultType =
-        op->getResult(0).getType().dyn_cast_or_null<RankedTensorType>();
-    if (resultType == nullptr) {
+    assert(op->getNumResults() == 1 && "ONNXConcatOp should have 1 result");
+    Type resultType = op->getResult(0).getType();
+    if (!onnx_mlir::isRankedShapedType(resultType)) {
       op->emitError() << "Concat Output Is Not Ranked\n";
       return failure();
     }
-    int64_t rank = resultType.getRank();
+    int64_t rank = onnx_mlir::getRank(resultType);
     int64_t axis = concatOp.axis();
     axis = axis >= 0 ? axis : rank + axis;
     assert(axis >= -rank && axis <= rank - 1 && "Axis out of rank range");
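The axis handling kept as context at the end of this hunk follows ONNX semantics: a negative axis counts back from the last dimension and is folded into the range [0, rank). A minimal standalone sketch of that arithmetic, not taken from the commit:

    // Sketch of the ONNX concat-axis normalization shown above (plain C++, no MLIR).
    #include <cassert>
    #include <cstdint>

    int64_t normalizeConcatAxis(int64_t axis, int64_t rank) {
      // ONNX allows axis in [-rank, rank - 1]; negative values count from the back.
      assert(axis >= -rank && axis <= rank - 1 && "Axis out of rank range");
      return axis >= 0 ? axis : rank + axis; // e.g. rank = 4, axis = -2  ->  2
    }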
6 changes: 6 additions & 0 deletions src/Support/TypeUtilities.cpp
@@ -36,6 +36,12 @@ ArrayRef<int64_t> getShape(Type ty) {
   return ty.cast<ShapedType>().getShape();
 }
 
+/// Get rank.
+int64_t getRank(Type ty) {
+  assert(isRankedShapedType(ty) && "Type must be ranked");
+  return ty.cast<ShapedType>().getRank();
+}
+
 /// Get the number of elements.
 int64_t getNumberOfElements(Type ty) {
   ArrayRef<int64_t> shape = getShape(ty);
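For illustration, the new getRank helper operates on an mlir::Type alone, mirroring the existing getShape accessor. A minimal usage sketch, assuming the caller checks rankedness first as getRank's assert requires; the wrapper function and the 5x5x3x32 shape are arbitrary examples, not part of the commit:

    // Hypothetical caller of the new onnx_mlir::getRank helper.
    #include <cassert>
    #include <cstdint>
    #include "mlir/IR/BuiltinTypes.h"
    #include "mlir/IR/MLIRContext.h"
    #include "src/Support/TypeUtilities.hpp"

    int64_t rankOfStaticTensor() {
      mlir::MLIRContext ctx;
      // Build a static 5x5x3x32 f32 tensor type; getRank needs only the type, not a value.
      auto ty = mlir::RankedTensorType::get(
          {5, 5, 3, 32}, mlir::Float32Type::get(&ctx));
      assert(onnx_mlir::isRankedShapedType(ty) && "guard mirrors getRank's own assert");
      return onnx_mlir::getRank(ty); // returns 4
    }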
2 changes: 2 additions & 0 deletions src/Support/TypeUtilities.hpp
@@ -22,6 +22,8 @@ mlir::Type getElementType(mlir::Type ty);
 bool isRankedShapedType(mlir::Type ty);
 /// Get shape.
 llvm::ArrayRef<int64_t> getShape(mlir::Type ty);
+/// Get rank.
+int64_t getRank(mlir::Type ty);
 /// Get the number of elements.
 int64_t getNumberOfElements(mlir::Type ty);
 /// Get the element size in bytes.
14 changes: 7 additions & 7 deletions test/mlir/conversion/onnx_to_mhlo/Tensor/Concat.mlir
@@ -1,13 +1,13 @@
 // RUN: onnx-mlir-opt --convert-onnx-to-mhlo %s -split-input-file | FileCheck %s
 
 // Test when output shape is unknown
-func @test_concat_unknown_dims(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tensor<5x5x3x32xf32>) -> tensor<?x?x?x?xf32> {
-  %0 = "onnx.Concat"(%arg0, %arg1) { axis = 2 : si64} : (tensor<5x5x1x32xf32>, tensor<5x5x3x32xf32>) -> tensor<?x?x?x?xf32>
-  "func.return"(%0) : (tensor<?x?x?x?xf32>) -> ()
-// CHECK-LABEL: func @test_concat_unknown_dims
-// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<5x5x1x32xf32>, [[PARAM_1_:%.+]]: tensor<5x5x3x32xf32>) -> tensor<?x?x?x?xf32> {
-// CHECK-NEXT: [[VAR_0_:%.+]] = "mhlo.concatenate"(%arg0, %arg1) {dimension = 2 : i64} : (tensor<5x5x1x32xf32>, tensor<5x5x3x32xf32>) -> tensor<?x?x?x?xf32>
-// CHECK-NEXT: return [[VAR_0_]] : tensor<?x?x?x?xf32>
+func @test_concat_dynamic_shape(%arg0 : tensor<5x5x?x32xf32>, %arg1 : tensor<5x5x?x32xf32>) -> tensor<5x5x?x32xf32> {
+  %0 = "onnx.Concat"(%arg0, %arg1) { axis = 2 : si64} : (tensor<5x5x?x32xf32>, tensor<5x5x?x32xf32>) -> tensor<5x5x?x32xf32>
+  "func.return"(%0) : (tensor<5x5x?x32xf32>) -> ()
+// CHECK-LABEL: func @test_concat_dynamic_shape
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<5x5x?x32xf32>, [[PARAM_1_:%.+]]: tensor<5x5x?x32xf32>) -> tensor<5x5x?x32xf32> {
+// CHECK-NEXT: [[VAR_0_:%.+]] = "mhlo.concatenate"(%arg0, %arg1) {dimension = 2 : i64} : (tensor<5x5x?x32xf32>, tensor<5x5x?x32xf32>) -> tensor<5x5x?x32xf32>
+// CHECK-NEXT: return [[VAR_0_]] : tensor<5x5x?x32xf32>
 // CHECK-NEXT: }
 }

