Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add scalar tensor wrapped #1358

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 28 additions & 28 deletions impl/torch/functions/functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -391,10 +391,10 @@ diopiError_t diopiDivScalar(diopiContextHandle_t ctx, diopiTensorHandle_t out, d
diopiRoundMode_t rounding_mode) {
impl::aten::setCurStream(ctx);
auto atInput = impl::aten::buildATen(input);
auto atOther = impl::aten::buildAtScalar(other);
auto atOther = impl::aten::buildAtScalarTensor(other);
auto roundingMode = impl::aten::getRoundingMode(rounding_mode);
auto atOut = impl::aten::buildATen(out);
CALL_ATEN_CUDA_FUNC(div_out, atOut, atInput, c10::scalar_to_tensor(atOther), roundingMode);
CALL_ATEN_CUDA_FUNC(div_out, atOut, atInput, atOther, roundingMode);

return diopiSuccess;
}
Expand All @@ -406,9 +406,9 @@ diopiError_t diopiDivScalar(diopiContextHandle_t ctx, diopiTensorHandle_t out, d
diopiError_t diopiDivInpScalar(diopiContextHandle_t ctx, diopiTensorHandle_t input, const diopiScalar_t* other, diopiRoundMode_t rounding_mode) {
impl::aten::setCurStream(ctx);
auto atInput = impl::aten::buildATen(input);
auto atOther = impl::aten::buildAtScalar(other);
auto atOther = impl::aten::buildAtScalarTensor(other);
auto roundingMode = impl::aten::getRoundingMode(rounding_mode);
CALL_ATEN_CUDA_FUNC(div_, atInput, c10::scalar_to_tensor(atOther), roundingMode);
CALL_ATEN_CUDA_FUNC(div_, atInput, atOther, roundingMode);

return diopiSuccess;
}
Expand Down Expand Up @@ -1587,20 +1587,20 @@ diopiError_t diopiAddScalar(diopiContextHandle_t ctx, diopiTensorHandle_t out, d
const diopiScalar_t* alpha) {
impl::aten::setCurStream(ctx);
auto atInput = impl::aten::buildATen(input);
auto atOther = impl::aten::buildAtScalar(other);
auto atOther = impl::aten::buildAtScalarTensor(other);
auto atAlpha = impl::aten::buildAtScalar(alpha);
auto atOut = impl::aten::buildATen(out);
CALL_ATEN_CUDA_FUNC(add_out, atOut, atInput, c10::scalar_to_tensor(atOther), atAlpha);
CALL_ATEN_CUDA_FUNC(add_out, atOut, atInput, atOther, atAlpha);

return diopiSuccess;
}

diopiError_t diopiAddInpScalar(diopiContextHandle_t ctx, diopiTensorHandle_t input, const diopiScalar_t* other, const diopiScalar_t* alpha) {
impl::aten::setCurStream(ctx);
auto atInput = impl::aten::buildATen(input);
auto atOther = impl::aten::buildAtScalar(other);
auto atOther = impl::aten::buildAtScalarTensor(other);
auto atAlpha = impl::aten::buildAtScalar(alpha);
CALL_ATEN_CUDA_FUNC(add_, atInput, c10::scalar_to_tensor(atOther), atAlpha);
CALL_ATEN_CUDA_FUNC(add_, atInput, atOther, atAlpha);

return diopiSuccess;
}
Expand Down Expand Up @@ -1653,20 +1653,20 @@ diopiError_t diopiSubScalar(diopiContextHandle_t ctx, diopiTensorHandle_t out, d
const diopiScalar_t* alpha) {
impl::aten::setCurStream(ctx);
auto atInput = impl::aten::buildATen(input);
auto atOther = impl::aten::buildAtScalar(other);
auto atOther = impl::aten::buildAtScalarTensor(other);
auto atAlpha = impl::aten::buildAtScalar(alpha);
auto atOut = impl::aten::buildATen(out);
CALL_ATEN_CUDA_FUNC(sub_out, atOut, atInput, c10::scalar_to_tensor(atOther), atAlpha);
CALL_ATEN_CUDA_FUNC(sub_out, atOut, atInput, atOther, atAlpha);

return diopiSuccess;
}

diopiError_t diopiSubInpScalar(diopiContextHandle_t ctx, diopiTensorHandle_t input, const diopiScalar_t* other, const diopiScalar_t* alpha) {
impl::aten::setCurStream(ctx);
auto atInput = impl::aten::buildATen(input);
auto atOther = impl::aten::buildAtScalar(other);
auto atOther = impl::aten::buildAtScalarTensor(other);
auto atAlpha = impl::aten::buildAtScalar(alpha);
CALL_ATEN_CUDA_FUNC(sub_, atInput, c10::scalar_to_tensor(atOther), atAlpha);
CALL_ATEN_CUDA_FUNC(sub_, atInput, atOther, atAlpha);

return diopiSuccess;
}
Expand All @@ -1693,18 +1693,18 @@ diopiError_t diopiMulInp(diopiContextHandle_t ctx, diopiTensorHandle_t input, di
diopiError_t diopiMulScalar(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, const diopiScalar_t* other) {
impl::aten::setCurStream(ctx);
auto atInput = impl::aten::buildATen(input);
auto atOther = impl::aten::buildAtScalar(other);
auto atOther = impl::aten::buildAtScalarTensor(other);
auto atOut = impl::aten::buildATen(out);
CALL_ATEN_CUDA_FUNC(mul_out, atOut, atInput, c10::scalar_to_tensor(atOther));
CALL_ATEN_CUDA_FUNC(mul_out, atOut, atInput, atOther);

return diopiSuccess;
}

diopiError_t diopiMulInpScalar(diopiContextHandle_t ctx, diopiTensorHandle_t input, const diopiScalar_t* other) {
impl::aten::setCurStream(ctx);
auto atInput = impl::aten::buildATen(input);
auto atOther = impl::aten::buildAtScalar(other);
CALL_ATEN_CUDA_FUNC(mul_, atInput, c10::scalar_to_tensor(atOther));
auto atOther = impl::aten::buildAtScalarTensor(other);
CALL_ATEN_CUDA_FUNC(mul_, atInput, atOther);

return diopiSuccess;
}
Expand Down Expand Up @@ -2017,18 +2017,18 @@ diopiError_t diopiBitwiseAndInp(diopiContextHandle_t ctx, diopiTensorHandle_t in
diopiError_t diopiBitwiseAndScalar(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, const diopiScalar_t* other) {
impl::aten::setCurStream(ctx);
auto atInput = impl::aten::buildATen(input);
auto atOther = impl::aten::buildAtScalar(other);
auto atOther = impl::aten::buildAtScalarTensor(other);
auto atOut = impl::aten::buildATen(out);
CALL_ATEN_CUDA_FUNC(bitwise_and_out, atOut, atInput, c10::scalar_to_tensor(atOther));
CALL_ATEN_CUDA_FUNC(bitwise_and_out, atOut, atInput, atOther);

return diopiSuccess;
}

diopiError_t diopiBitwiseAndInpScalar(diopiContextHandle_t ctx, diopiTensorHandle_t input, const diopiScalar_t* other) {
impl::aten::setCurStream(ctx);
auto atInput = impl::aten::buildATen(input);
auto atOther = impl::aten::buildAtScalar(other);
CALL_ATEN_CUDA_FUNC(bitwise_and_, atInput, c10::scalar_to_tensor(atOther));
auto atOther = impl::aten::buildAtScalarTensor(other);
CALL_ATEN_CUDA_FUNC(bitwise_and_, atInput, atOther);

return diopiSuccess;
}
Expand All @@ -2055,18 +2055,18 @@ diopiError_t diopiBitwiseOrInp(diopiContextHandle_t ctx, diopiTensorHandle_t inp
diopiError_t diopiBitwiseOrScalar(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, const diopiScalar_t* other) {
impl::aten::setCurStream(ctx);
auto atInput = impl::aten::buildATen(input);
auto atOther = impl::aten::buildAtScalar(other);
auto atOther = impl::aten::buildAtScalarTensor(other);
auto atOut = impl::aten::buildATen(out);
CALL_ATEN_CUDA_FUNC(bitwise_or_out, atOut, atInput, c10::scalar_to_tensor(atOther));
CALL_ATEN_CUDA_FUNC(bitwise_or_out, atOut, atInput, atOther);

return diopiSuccess;
}

diopiError_t diopiBitwiseOrInpScalar(diopiContextHandle_t ctx, diopiTensorHandle_t input, const diopiScalar_t* other) {
impl::aten::setCurStream(ctx);
auto atInput = impl::aten::buildATen(input);
auto atOther = impl::aten::buildAtScalar(other);
CALL_ATEN_CUDA_FUNC(bitwise_or_, atInput, c10::scalar_to_tensor(atOther));
auto atOther = impl::aten::buildAtScalarTensor(other);
CALL_ATEN_CUDA_FUNC(bitwise_or_, atInput, atOther);

return diopiSuccess;
}
Expand Down Expand Up @@ -4628,19 +4628,19 @@ diopiError_t diopiRemainderTensor(diopiContextHandle_t ctx, diopiTensorHandle_t
diopiError_t diopiRemainderScalar(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, const diopiScalar_t* other) {
impl::aten::setCurStream(ctx);
auto atInput = impl::aten::buildATen(input);
auto atOther = impl::aten::buildAtScalar(other);
auto atOther = impl::aten::buildAtScalarTensor(other);
auto atOut = impl::aten::buildATen(out);
CALL_ATEN_CUDA_FUNC(remainder_out, atOut, atInput, c10::scalar_to_tensor(atOther));
CALL_ATEN_CUDA_FUNC(remainder_out, atOut, atInput, atOther);

return diopiSuccess;
}

diopiError_t diopiRemainder(diopiContextHandle_t ctx, diopiTensorHandle_t out, const diopiScalar_t* input, diopiConstTensorHandle_t other) {
impl::aten::setCurStream(ctx);
auto atInputScalar = impl::aten::buildAtScalar(input);
auto atInputScalar = impl::aten::buildAtScalarTensor(input);
auto atOther = impl::aten::buildATen(other);
auto atOut = impl::aten::buildATen(out);
CALL_ATEN_CUDA_FUNC(remainder_out, atOut, c10::scalar_to_tensor(atInputScalar), atOther);
CALL_ATEN_CUDA_FUNC(remainder_out, atOut, atInputScalar, atOther);
return diopiSuccess;
}

Expand Down
7 changes: 7 additions & 0 deletions impl/torch/helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,13 @@ at::Scalar buildAtScalar(const diopiScalar_t* scalar) {
}
}

at::Tensor buildAtScalarTensor(const diopiScalar_t* scalar) {
auto atScalar = buildAtScalar(scalar);
at::Tensor atScalarTensor = c10::scalar_to_tensor(atScalar);
atScalarTensor.unsafeGetTensorImpl()->set_wrapped_number(true);
return atScalarTensor;
}

void buildDiopiTensor(diopiContextHandle_t ctx, const at::Tensor& input, diopiTensorHandle_t* out) {
at::IntArrayRef atSize = input.sizes();
at::IntArrayRef atStride = input.strides();
Expand Down
2 changes: 2 additions & 0 deletions impl/torch/helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ inline bool isFloat(const diopiScalar_t* scalar) { return scalar->stype > 7; }

at::Scalar buildAtScalar(const diopiScalar_t* scalar);

at::Tensor buildAtScalarTensor(const diopiScalar_t* scalar);

inline at::IntArrayRef buildAtIntArray(const diopiSize_t* size) { return at::IntArrayRef(size->data, size->len); }

inline at::IntArrayRef buildAtIntArray(diopiSize_t size) { return at::IntArrayRef(size.data, size.len); }
Expand Down
Loading