refactor: extract shared util function ComputeBroadcastOutputShape #21940

Merged · 3 commits · Sep 4, 2024

Changes from all commits
@@ -10,6 +10,7 @@

 // ORT system.
 #include "core/providers/cuda/tensor/expand.h"
+#include "core/providers/common.h"

 // std C++.
 #include <iostream>
@@ -51,7 +52,7 @@ Status DistributedExpand<T>::ComputeInternal(OpKernelContext* context) const {
   TensorShapeVector original_output_dims{p_shape, p_shape + shape_tensor->Shape().Size()};
   TensorShape original_output_shape(original_output_dims);
   ORT_ENFORCE(
-      onnxruntime::cuda::ComputeOutputShape(
+      onnxruntime::ComputeBroadcastOutputShape(
           Node().Name(),
           original_input_shape,
           original_output_dims, original_output_shape)
29 changes: 0 additions & 29 deletions onnxruntime/core/providers/cann/cann_utils.cc
@@ -224,34 +224,5 @@ void GenerateHashValue(const std::string string, HashValue& hash_value) {
   hash_value = hash[0] | (uint64_t(hash[1]) << 32);
 }

-Status ComputeOutputShape(const std::string& node_name, const TensorShape& lhs_shape,
-                          const TensorShape& rhs_shape, TensorShape& out_shape) {
-  size_t lhs_rank = lhs_shape.NumDimensions();
-  size_t rhs_rank = rhs_shape.NumDimensions();
-  size_t out_rank = std::max(lhs_rank, rhs_rank);
-
-  std::vector<int64_t> output_dims(out_rank, 0);
-  for (size_t i = 0; i < out_rank; ++i) {
-    int64_t lhs_dim = 1;
-    if (i < lhs_rank)
-      lhs_dim = lhs_shape[lhs_rank - 1 - i];
-    int64_t rhs_dim = 1;
-    if (i < rhs_rank)
-      rhs_dim = rhs_shape[rhs_rank - 1 - i];
-    int64_t max = std::max(lhs_dim, rhs_dim);
-    int64_t min = std::min(lhs_dim, rhs_dim);
-    int64_t out_dim = (min == 0 ? min : max);  // special case a dim value of 0.
-    if (lhs_dim != out_dim && lhs_dim != 1)
-      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": left operand cannot broadcast on dim ", lhs_rank - 1 - i,
-                             " LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString());
-    if (rhs_dim != out_dim && rhs_dim != 1)
-      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": right operand cannot broadcast on dim ", rhs_rank - 1 - i,
-                             " LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString());
-    output_dims[out_rank - 1 - i] = out_dim;
-  }
-  out_shape = TensorShape(output_dims);
-  return Status::OK();
-}
-
 }  // namespace cann
 }  // namespace onnxruntime
2 changes: 0 additions & 2 deletions onnxruntime/core/providers/cann/cann_utils.h
@@ -124,8 +124,6 @@ Status aclrtblasGemmEx(aclTransType transA,

 bool FileExist(const std::string& file_name);
 void GenerateHashValue(const std::string string, HashValue& hash_value);
-Status ComputeOutputShape(const std::string& node_name, const TensorShape& lhs_shape,
-                          const TensorShape& rhs_shape, TensorShape& out_shape);

 std::unique_ptr<Model> CreateModel(const GraphViewer& graph_viewer, const logging::Logger& logger);
@@ -2,6 +2,8 @@
 // Copyright (c) Huawei. All rights reserved.
 // Licensed under the MIT License.

+#include "core/providers/shared_library/provider_api.h"
+#include "core/providers/common.h"
 #include "core/providers/cann/math/binary_elementwise_ops.h"
 #include <vector>
 #include <algorithm>
@@ -20,7 +22,7 @@ Status BinaryElementwise::Prepare(OpKernelContext* ctx, CannPreparation& prepare
   const Tensor* B = ctx->Input<Tensor>(1);

   TensorShape output_shape;
-  ORT_RETURN_IF_ERROR(ComputeOutputShape(Node().Name(), A->Shape(), B->Shape(), output_shape));
+  ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(Node().Name(), A->Shape(), B->Shape(), output_shape));
   Tensor* C = ctx->Output(0, output_shape);

   void* A_data = const_cast<void*>(A->DataRaw());
34 changes: 34 additions & 0 deletions onnxruntime/core/providers/common.h
@@ -180,4 +180,38 @@
   return accumulate(c.cbegin(), c.cend(), static_cast<T>(1), std::multiplies<T>());
 }

+/// <summary>
+/// Compute the output shape for broadcasting the given input shapes of lhs and rhs.
+/// </summary>
+inline Status ComputeBroadcastOutputShape(const std::string& node_name,
+                                          const TensorShape& lhs_shape,
+                                          const TensorShape& rhs_shape,
+                                          TensorShape& out_shape) {
+  size_t lhs_rank = lhs_shape.NumDimensions();
+  size_t rhs_rank = rhs_shape.NumDimensions();
+  size_t out_rank = std::max(lhs_rank, rhs_rank);
+
+  std::vector<int64_t> output_dims(out_rank, 0);
+  for (size_t i = 0; i < out_rank; ++i) {
+    int64_t lhs_dim = 1;
+    if (i < lhs_rank)
+      lhs_dim = lhs_shape[lhs_rank - 1 - i];
+    int64_t rhs_dim = 1;
+    if (i < rhs_rank)
+      rhs_dim = rhs_shape[rhs_rank - 1 - i];
+    int64_t max = std::max(lhs_dim, rhs_dim);
+    int64_t min = std::min(lhs_dim, rhs_dim);
+    int64_t out_dim = (min == 0 ? min : max);  // special case a dim value of 0.
+    if (lhs_dim != out_dim && lhs_dim != 1)
+      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": left operand cannot broadcast on dim ", lhs_rank - 1 - i,
+                             " LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString());
+    if (rhs_dim != out_dim && rhs_dim != 1)
+      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": right operand cannot broadcast on dim ", rhs_rank - 1 - i,
+                             " LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString());
+    output_dims[out_rank - 1 - i] = out_dim;
+  }
+  out_shape = TensorShape(output_dims);
+  return Status::OK();
+}
+
 }  // namespace onnxruntime

GitHub Actions / Optional Lint C++ (cpplint via reviewdog) annotations on onnxruntime/core/providers/common.h:
  line 186: Add #include <string> for string [build/include_what_you_use] [4]
  line 194: Add #include <vector> for vector<> [build/include_what_you_use] [4]
  line 203: Add #include <algorithm> for min [build/include_what_you_use] [4]
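For reference, a minimal caller sketch of the new shared helper (not part of the diff; the node name and shapes are made up, and it assumes TensorShape's initializer_list constructor and the usual ORT status macros are in scope):

// Illustrative sketch only, hypothetical node name and shapes. Dims are
// right-aligned and each pair takes the larger value; a missing dim acts
// as 1, and a 0 dim propagates to the output (the special case above).
#include "core/providers/common.h"

onnxruntime::Status BroadcastDemo() {
  onnxruntime::TensorShape lhs({2, 1, 4});
  onnxruntime::TensorShape rhs({3, 1});
  onnxruntime::TensorShape out;
  // {2,1,4} vs {3,1}: 4 vs 1 -> 4, 1 vs 3 -> 3, 2 vs implicit 1 -> 2.
  ORT_RETURN_IF_ERROR(
      onnxruntime::ComputeBroadcastOutputShape("demo_node", lhs, rhs, out));
  // out is now {2,3,4}; mismatched non-1 dims (e.g. {2,3} vs {2,4}) return FAIL.
  return onnxruntime::Status::OK();
}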
32 changes: 3 additions & 29 deletions onnxruntime/core/providers/cuda/math/binary_elementwise_ops.cc
@@ -1,6 +1,8 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.

+#include "core/providers/shared_library/provider_api.h"
+#include "core/providers/common.h"
 #include "core/providers/cuda/math/binary_elementwise_ops.h"
 #include "core/providers/cuda/math/binary_elementwise_ops_impl.h"
 #include "core/providers/cuda/math/unary_elementwise_ops_impl.h"
@@ -21,34 +23,6 @@ Status BinaryElementwise<ShouldNotBroadcast>::Prepare(OpKernelContext* context,
   return Status::OK();
 }

-Status ComputeOutputShape(const std::string& node_name, const TensorShape& lhs_shape, const TensorShape& rhs_shape, TensorShape& out_shape) {
-  size_t lhs_rank = lhs_shape.NumDimensions();
-  size_t rhs_rank = rhs_shape.NumDimensions();
-  size_t out_rank = std::max(lhs_rank, rhs_rank);
-
-  std::vector<int64_t> output_dims(out_rank, 0);
-  for (size_t i = 0; i < out_rank; ++i) {
-    int64_t lhs_dim = 1;
-    if (i < lhs_rank)
-      lhs_dim = lhs_shape[lhs_rank - 1 - i];
-    int64_t rhs_dim = 1;
-    if (i < rhs_rank)
-      rhs_dim = rhs_shape[rhs_rank - 1 - i];
-    int64_t max = std::max(lhs_dim, rhs_dim);
-    int64_t min = std::min(lhs_dim, rhs_dim);
-    int64_t out_dim = (min == 0 ? min : max);  // special case a dim value of 0.
-    if (lhs_dim != out_dim && lhs_dim != 1)
-      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": left operand cannot broadcast on dim ", lhs_rank - 1 - i,
-                             " LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString());
-    if (rhs_dim != out_dim && rhs_dim != 1)
-      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": right operand cannot broadcast on dim ", rhs_rank - 1 - i,
-                             " LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString());
-    output_dims[out_rank - 1 - i] = out_dim;
-  }
-  out_shape = TensorShape(output_dims);
-  return Status::OK();
-}
-
 Status BinaryElementwiseBroadcastPrepare(
     const Tensor* lhs_tensor,
     const Tensor* rhs_tensor,
@@ -77,7 +51,7 @@ Status BinaryElementwise<ShouldBroadcast>::Prepare(OpKernelContext* context, Bin
   const auto& rhs_shape = rhs_tensor->Shape();

   TensorShape output_shape;
-  ORT_RETURN_IF_ERROR(ComputeOutputShape(Node().Name(), lhs_shape, rhs_shape, output_shape));
+  ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(Node().Name(), lhs_shape, rhs_shape, output_shape));
   auto output_tensor = context->Output(0, output_shape);

   ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(lhs_tensor, rhs_tensor, output_tensor, p));
6 changes: 0 additions & 6 deletions onnxruntime/core/providers/cuda/math/binary_elementwise_ops.h
@@ -108,12 +108,6 @@ struct BinaryElementwisePreparation {
   }
 };

-Status ComputeOutputShape(
-    const std::string& node_name,
-    const TensorShape& lhs_shape,
-    const TensorShape& rhs_shape,
-    TensorShape& out_shape);
-
 Status BinaryElementwiseBroadcastPrepare(
     const Tensor* lhs_tensor,
     const Tensor* rhs_tensor,
@@ -2,6 +2,7 @@
 // Licensed under the MIT License.

 #include "core/providers/shared_library/provider_api.h"
+#include "core/providers/common.h"
 #include "core/providers/cuda/math/variadic_elementwise_ops.h"

 #include <cassert>
@@ -209,7 +210,7 @@ Status VariadicElementwiseOp<VariadicElementwiseOpTag, SupportedElementTypes...>
   TensorShape output_shape;
   TensorShape previous_output_shape = first_input_tensor.Shape();
   for (int index = 1; index < input_count; index++) {
-    ORT_RETURN_IF_ERROR(ComputeOutputShape(
+    ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(
         node_name, previous_output_shape, input_tensors[index].get().Shape(), output_shape));
     previous_output_shape = output_shape;
  }
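The loop above left-folds the pairwise rule over all inputs: each iteration broadcasts the running output shape with the next input's shape. A hedged sketch with made-up shapes (same assumptions as the earlier demo):

// Illustrative only: chained pairwise broadcasting, mirroring the loop above.
onnxruntime::Status VariadicDemo() {
  const onnxruntime::TensorShape shapes[] = {{1, 3}, {2, 1, 1}, {1}};  // hypothetical
  onnxruntime::TensorShape running = shapes[0];
  for (size_t i = 1; i < 3; ++i) {
    onnxruntime::TensorShape next;
    ORT_RETURN_IF_ERROR(onnxruntime::ComputeBroadcastOutputShape(
        "variadic_demo", running, shapes[i], next));
    running = next;  // {1,3} -> {2,1,3} -> {2,1,3}
  }
  // Any pair with mismatched non-1 dims fails the whole chain with FAIL.
  return onnxruntime::Status::OK();
}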
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/cuda/tensor/expand.cc
@@ -95,7 +95,7 @@ Status Expand::ComputeInternal(OpKernelContext* ctx) const {
   TensorShapeVector output_dims{p_shape, p_shape + input_shape_tensor.Shape().Size()};
   TensorShape output_shape(output_dims);

-  ORT_RETURN_IF_ERROR(ComputeOutputShape(Node().Name(), input_data_tensor.Shape(), output_dims, output_shape));
+  ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(Node().Name(), input_data_tensor.Shape(), output_dims, output_shape));
   auto& output_tensor = *ctx->Output(0, output_shape);
   if (0 == output_shape.Size()) {
     return Status::OK();
@@ -202,7 +202,7 @@ std::unique_ptr<Tensor> FuncExpand(
   TensorShape output_shape(output_dims);

   ORT_ENFORCE(
-      ComputeOutputShape(
+      ComputeBroadcastOutputShape(
           cuda_kernel->Node().Name(),
           input_data_tensor->Shape(),
           output_dims, output_shape)
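Aside (not stated in the diff itself): ONNX Expand defines its output as the two-way broadcast of the input shape with the requested shape, which is why both Expand::ComputeInternal and FuncExpand can call the shared helper directly; for example, with hypothetical values, an input of shape {3, 1} expanded with shape {2, 1, 4} broadcasts to {2, 3, 4}, replicating the input along every dim where it has size 1.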
6 changes: 0 additions & 6 deletions onnxruntime/core/providers/cuda/tensor/expand.h
@@ -14,12 +14,6 @@ class Expand final : public CudaKernel {
   Status ComputeInternal(OpKernelContext* context) const override;
 };

-Status ComputeOutputShape(
-    const std::string& node_name,
-    const TensorShape& lhs_shape,
-    const TensorShape& rhs_shape,
-    TensorShape& out_shape);
-
 Status FuncExpand(
     const CudaKernel* cuda_kernel,
     OpKernelContext* ctx,