Skip to content

Commit

Permalink
fix build break
Browse files Browse the repository at this point in the history
  • Loading branch information
fs-eire committed Aug 31, 2024
1 parent 58c5528 commit f328b6b
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/utils.h"

#include "core/providers/cuda/math/binary_elementwise_ops.h"
#include "core/providers/cuda/math/binary_elementwise_ops_impl.h"
#include "core/providers/cuda/math/unary_elementwise_ops_impl.h"

#include "core/providers/utils.h"

using namespace onnxruntime::common;
namespace onnxruntime {
namespace cuda {
Expand Down Expand Up @@ -51,7 +51,7 @@ Status BinaryElementwise<ShouldBroadcast>::Prepare(OpKernelContext* context, Bin
const auto& rhs_shape = rhs_tensor->Shape();

TensorShape output_shape;
ORT_RETURN_IF_ERROR(utils::ComputeOutputShape(Node().Name(), lhs_shape, rhs_shape, output_shape));
ORT_RETURN_IF_ERROR(utils::ComputeBroadcastOutputShape(Node().Name(), lhs_shape, rhs_shape, output_shape));
auto output_tensor = context->Output(0, output_shape);

ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(lhs_tensor, rhs_tensor, output_tensor, p));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

#include "core/providers/shared_library/provider_api.h"
#include "core/providers/utils.h"
#include "core/providers/cuda/math/variadic_elementwise_ops.h"

#include <cassert>
Expand Down Expand Up @@ -209,7 +210,7 @@ Status VariadicElementwiseOp<VariadicElementwiseOpTag, SupportedElementTypes...>
TensorShape output_shape;
TensorShape previous_output_shape = first_input_tensor.Shape();
for (int index = 1; index < input_count; index++) {
ORT_RETURN_IF_ERROR(ComputeOutputShape(
ORT_RETURN_IF_ERROR(utils::ComputeBroadcastOutputShape(
node_name, previous_output_shape, input_tensors[index].get().Shape(), output_shape));
previous_output_shape = output_shape;
}
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@

#pragma once

#ifndef SHARED_PROVIDER
#include "core/framework/framework_common.h"
#include "core/framework/op_kernel_context_internal.h"
#include "core/providers/common.h"
#endif

namespace onnxruntime {
namespace utils {
Expand Down

0 comments on commit f328b6b

Please sign in to comment.