diff --git a/impl/torch/functions/functions.cpp b/impl/torch/functions/functions.cpp index 6ee8e104e..f519baf54 100644 --- a/impl/torch/functions/functions.cpp +++ b/impl/torch/functions/functions.cpp @@ -391,10 +391,10 @@ diopiError_t diopiDivScalar(diopiContextHandle_t ctx, diopiTensorHandle_t out, d diopiRoundMode_t rounding_mode) { impl::aten::setCurStream(ctx); auto atInput = impl::aten::buildATen(input); - auto atOther = impl::aten::buildAtScalar(other); + auto atOther = impl::aten::buildAtScalarTensor(other); auto roundingMode = impl::aten::getRoundingMode(rounding_mode); auto atOut = impl::aten::buildATen(out); - CALL_ATEN_CUDA_FUNC(div_out, atOut, atInput, c10::scalar_to_tensor(atOther), roundingMode); + CALL_ATEN_CUDA_FUNC(div_out, atOut, atInput, atOther, roundingMode); return diopiSuccess; } @@ -406,9 +406,9 @@ diopiError_t diopiDivScalar(diopiContextHandle_t ctx, diopiTensorHandle_t out, d diopiError_t diopiDivInpScalar(diopiContextHandle_t ctx, diopiTensorHandle_t input, const diopiScalar_t* other, diopiRoundMode_t rounding_mode) { impl::aten::setCurStream(ctx); auto atInput = impl::aten::buildATen(input); - auto atOther = impl::aten::buildAtScalar(other); + auto atOther = impl::aten::buildAtScalarTensor(other); auto roundingMode = impl::aten::getRoundingMode(rounding_mode); - CALL_ATEN_CUDA_FUNC(div_, atInput, c10::scalar_to_tensor(atOther), roundingMode); + CALL_ATEN_CUDA_FUNC(div_, atInput, atOther, roundingMode); return diopiSuccess; } @@ -1587,10 +1587,10 @@ diopiError_t diopiAddScalar(diopiContextHandle_t ctx, diopiTensorHandle_t out, d const diopiScalar_t* alpha) { impl::aten::setCurStream(ctx); auto atInput = impl::aten::buildATen(input); - auto atOther = impl::aten::buildAtScalar(other); + auto atOther = impl::aten::buildAtScalarTensor(other); auto atAlpha = impl::aten::buildAtScalar(alpha); auto atOut = impl::aten::buildATen(out); - CALL_ATEN_CUDA_FUNC(add_out, atOut, atInput, c10::scalar_to_tensor(atOther), atAlpha); + CALL_ATEN_CUDA_FUNC(add_out, atOut, atInput, atOther, atAlpha); return diopiSuccess; } @@ -1598,9 +1598,9 @@ diopiError_t diopiAddScalar(diopiContextHandle_t ctx, diopiTensorHandle_t out, d diopiError_t diopiAddInpScalar(diopiContextHandle_t ctx, diopiTensorHandle_t input, const diopiScalar_t* other, const diopiScalar_t* alpha) { impl::aten::setCurStream(ctx); auto atInput = impl::aten::buildATen(input); - auto atOther = impl::aten::buildAtScalar(other); + auto atOther = impl::aten::buildAtScalarTensor(other); auto atAlpha = impl::aten::buildAtScalar(alpha); - CALL_ATEN_CUDA_FUNC(add_, atInput, c10::scalar_to_tensor(atOther), atAlpha); + CALL_ATEN_CUDA_FUNC(add_, atInput, atOther, atAlpha); return diopiSuccess; } @@ -1653,10 +1653,10 @@ diopiError_t diopiSubScalar(diopiContextHandle_t ctx, diopiTensorHandle_t out, d const diopiScalar_t* alpha) { impl::aten::setCurStream(ctx); auto atInput = impl::aten::buildATen(input); - auto atOther = impl::aten::buildAtScalar(other); + auto atOther = impl::aten::buildAtScalarTensor(other); auto atAlpha = impl::aten::buildAtScalar(alpha); auto atOut = impl::aten::buildATen(out); - CALL_ATEN_CUDA_FUNC(sub_out, atOut, atInput, c10::scalar_to_tensor(atOther), atAlpha); + CALL_ATEN_CUDA_FUNC(sub_out, atOut, atInput, atOther, atAlpha); return diopiSuccess; } @@ -1664,9 +1664,9 @@ diopiError_t diopiSubScalar(diopiContextHandle_t ctx, diopiTensorHandle_t out, d diopiError_t diopiSubInpScalar(diopiContextHandle_t ctx, diopiTensorHandle_t input, const diopiScalar_t* other, const diopiScalar_t* alpha) { impl::aten::setCurStream(ctx); auto atInput = impl::aten::buildATen(input); - auto atOther = impl::aten::buildAtScalar(other); + auto atOther = impl::aten::buildAtScalarTensor(other); auto atAlpha = impl::aten::buildAtScalar(alpha); - CALL_ATEN_CUDA_FUNC(sub_, atInput, c10::scalar_to_tensor(atOther), atAlpha); + CALL_ATEN_CUDA_FUNC(sub_, atInput, atOther, atAlpha); return diopiSuccess; } @@ -1693,9 +1693,9 @@ diopiError_t diopiMulInp(diopiContextHandle_t ctx, diopiTensorHandle_t input, di diopiError_t diopiMulScalar(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, const diopiScalar_t* other) { impl::aten::setCurStream(ctx); auto atInput = impl::aten::buildATen(input); - auto atOther = impl::aten::buildAtScalar(other); + auto atOther = impl::aten::buildAtScalarTensor(other); auto atOut = impl::aten::buildATen(out); - CALL_ATEN_CUDA_FUNC(mul_out, atOut, atInput, c10::scalar_to_tensor(atOther)); + CALL_ATEN_CUDA_FUNC(mul_out, atOut, atInput, atOther); return diopiSuccess; } @@ -1703,8 +1703,8 @@ diopiError_t diopiMulScalar(diopiContextHandle_t ctx, diopiTensorHandle_t out, d diopiError_t diopiMulInpScalar(diopiContextHandle_t ctx, diopiTensorHandle_t input, const diopiScalar_t* other) { impl::aten::setCurStream(ctx); auto atInput = impl::aten::buildATen(input); - auto atOther = impl::aten::buildAtScalar(other); - CALL_ATEN_CUDA_FUNC(mul_, atInput, c10::scalar_to_tensor(atOther)); + auto atOther = impl::aten::buildAtScalarTensor(other); + CALL_ATEN_CUDA_FUNC(mul_, atInput, atOther); return diopiSuccess; } @@ -2017,9 +2017,9 @@ diopiError_t diopiBitwiseAndInp(diopiContextHandle_t ctx, diopiTensorHandle_t in diopiError_t diopiBitwiseAndScalar(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, const diopiScalar_t* other) { impl::aten::setCurStream(ctx); auto atInput = impl::aten::buildATen(input); - auto atOther = impl::aten::buildAtScalar(other); + auto atOther = impl::aten::buildAtScalarTensor(other); auto atOut = impl::aten::buildATen(out); - CALL_ATEN_CUDA_FUNC(bitwise_and_out, atOut, atInput, c10::scalar_to_tensor(atOther)); + CALL_ATEN_CUDA_FUNC(bitwise_and_out, atOut, atInput, atOther); return diopiSuccess; } @@ -2027,8 +2027,8 @@ diopiError_t diopiBitwiseAndScalar(diopiContextHandle_t ctx, diopiTensorHandle_t diopiError_t diopiBitwiseAndInpScalar(diopiContextHandle_t ctx, diopiTensorHandle_t input, const diopiScalar_t* other) { impl::aten::setCurStream(ctx); auto atInput = impl::aten::buildATen(input); - auto atOther = impl::aten::buildAtScalar(other); - CALL_ATEN_CUDA_FUNC(bitwise_and_, atInput, c10::scalar_to_tensor(atOther)); + auto atOther = impl::aten::buildAtScalarTensor(other); + CALL_ATEN_CUDA_FUNC(bitwise_and_, atInput, atOther); return diopiSuccess; } @@ -2055,9 +2055,9 @@ diopiError_t diopiBitwiseOrInp(diopiContextHandle_t ctx, diopiTensorHandle_t inp diopiError_t diopiBitwiseOrScalar(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, const diopiScalar_t* other) { impl::aten::setCurStream(ctx); auto atInput = impl::aten::buildATen(input); - auto atOther = impl::aten::buildAtScalar(other); + auto atOther = impl::aten::buildAtScalarTensor(other); auto atOut = impl::aten::buildATen(out); - CALL_ATEN_CUDA_FUNC(bitwise_or_out, atOut, atInput, c10::scalar_to_tensor(atOther)); + CALL_ATEN_CUDA_FUNC(bitwise_or_out, atOut, atInput, atOther); return diopiSuccess; } @@ -2065,8 +2065,8 @@ diopiError_t diopiBitwiseOrScalar(diopiContextHandle_t ctx, diopiTensorHandle_t diopiError_t diopiBitwiseOrInpScalar(diopiContextHandle_t ctx, diopiTensorHandle_t input, const diopiScalar_t* other) { impl::aten::setCurStream(ctx); auto atInput = impl::aten::buildATen(input); - auto atOther = impl::aten::buildAtScalar(other); - CALL_ATEN_CUDA_FUNC(bitwise_or_, atInput, c10::scalar_to_tensor(atOther)); + auto atOther = impl::aten::buildAtScalarTensor(other); + CALL_ATEN_CUDA_FUNC(bitwise_or_, atInput, atOther); return diopiSuccess; } @@ -4628,19 +4628,19 @@ diopiError_t diopiRemainderTensor(diopiContextHandle_t ctx, diopiTensorHandle_t diopiError_t diopiRemainderScalar(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, const diopiScalar_t* other) { impl::aten::setCurStream(ctx); auto atInput = impl::aten::buildATen(input); - auto atOther = impl::aten::buildAtScalar(other); + auto atOther = impl::aten::buildAtScalarTensor(other); auto atOut = impl::aten::buildATen(out); - CALL_ATEN_CUDA_FUNC(remainder_out, atOut, atInput, c10::scalar_to_tensor(atOther)); + CALL_ATEN_CUDA_FUNC(remainder_out, atOut, atInput, atOther); return diopiSuccess; } diopiError_t diopiRemainder(diopiContextHandle_t ctx, diopiTensorHandle_t out, const diopiScalar_t* input, diopiConstTensorHandle_t other) { impl::aten::setCurStream(ctx); - auto atInputScalar = impl::aten::buildAtScalar(input); + auto atInputScalar = impl::aten::buildAtScalarTensor(input); auto atOther = impl::aten::buildATen(other); auto atOut = impl::aten::buildATen(out); - CALL_ATEN_CUDA_FUNC(remainder_out, atOut, c10::scalar_to_tensor(atInputScalar), atOther); + CALL_ATEN_CUDA_FUNC(remainder_out, atOut, atInputScalar, atOther); return diopiSuccess; } diff --git a/impl/torch/helper.cpp b/impl/torch/helper.cpp index ff5176fdd..aa7f37063 100644 --- a/impl/torch/helper.cpp +++ b/impl/torch/helper.cpp @@ -94,6 +94,13 @@ at::Scalar buildAtScalar(const diopiScalar_t* scalar) { } } +at::Tensor buildAtScalarTensor(const diopiScalar_t* scalar) { + auto atScalar = buildAtScalar(scalar); + at::Tensor atScalarTensor = c10::scalar_to_tensor(atScalar); + atScalarTensor.unsafeGetTensorImpl()->set_wrapped_number(true); + return atScalarTensor; +} + void buildDiopiTensor(diopiContextHandle_t ctx, const at::Tensor& input, diopiTensorHandle_t* out) { at::IntArrayRef atSize = input.sizes(); at::IntArrayRef atStride = input.strides(); diff --git a/impl/torch/helper.hpp b/impl/torch/helper.hpp index 930f860f4..8ffc607a5 100644 --- a/impl/torch/helper.hpp +++ b/impl/torch/helper.hpp @@ -104,6 +104,8 @@ inline bool isFloat(const diopiScalar_t* scalar) { return scalar->stype > 7; } at::Scalar buildAtScalar(const diopiScalar_t* scalar); +at::Tensor buildAtScalarTensor(const diopiScalar_t* scalar); + inline at::IntArrayRef buildAtIntArray(const diopiSize_t* size) { return at::IntArrayRef(size->data, size->len); } inline at::IntArrayRef buildAtIntArray(diopiSize_t size) { return at::IntArrayRef(size.data, size.len); }