Modify signature of dequantize ops for decomposed quantized Tensor (#2308)

Summary:
X-link: pytorch/pytorch#121450


Note: The initial purpose of this PR is to gather suggestions and feedback on a better alternative, if any.

At present, dequantize ops for the decomposed quantized Tensor representation, e.g. dequantize_per_tensor(), assume the output dtype is torch.float and therefore do not take an output dtype in their operator argument list. This signature becomes unusable when that assumption breaks: if the output dtype differs from torch.float, there is no way to specify it during dequantization.

This change generalizes the signature of dequantize ops such as dequantize_per_tensor() for wider use cases where the output dtype can differ from torch.float and needs to be passed during dequantization. The proposal is to add an argument named 'output_dtype' to solve the problem. We would also welcome suggestions and feedback on any better alternative.
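
For context, below is a minimal sketch (not part of this commit) of how a caller might exercise the new optional out_dtype argument of the out-variant op. The include paths, namespace qualification, and the TensorFactory test helper are assumptions based on ExecuTorch test conventions; see the test changes in this commit for the actual usage.

```cpp
// Sketch only: exercising the new optional out_dtype argument.
// Include paths and namespace qualification are assumptions, not part of this commit.
#include <executorch/kernels/quantized/NativeFunctions.h> // assumed header location
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h> // assumed helper

using exec_aten::optional;
using exec_aten::ScalarType;
using exec_aten::Tensor;
using torch::executor::testing::TensorFactory;
using torch::executor::native::dequantize_per_tensor_out; // assumed namespace

void dequantize_out_dtype_example() {
  TensorFactory<ScalarType::Byte> tf_byte;
  TensorFactory<ScalarType::Float> tf_float;

  Tensor input = tf_byte.full({3, 5}, 100); // quantized values
  Tensor out = tf_float.zeros({3, 5});      // dequantized output buffer

  // Existing behavior, preserved: no out_dtype given (defaults to None),
  // so the kernel assumes a torch.float output.
  dequantize_per_tensor_out(
      input, /*scale=*/0.5, /*zero_point=*/30,
      /*quant_min=*/0, /*quant_max=*/255,
      ScalarType::Byte, optional<ScalarType>(), out);

  // New behavior: out_dtype passed explicitly; the kernel checks that it
  // matches out.scalar_type().
  dequantize_per_tensor_out(
      input, /*scale=*/0.5, /*zero_point=*/30,
      /*quant_min=*/0, /*quant_max=*/255,
      ScalarType::Byte, optional<ScalarType>(ScalarType::Float), out);
}
```

Omitting out_dtype preserves today's behavior; passing it makes the intended output dtype explicit so kernels can support outputs other than torch.float going forward.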


cc jerryzh168 jianyuh raghuramank100 jamesr66a vkuzo jgong5 Xia-Weiwen leslie-fang-intel

X-link: pytorch/pytorch#119173

Reviewed By: digantdesai

Differential Revision: D53590486

Pulled By: manuelcandales
kausikmaiti authored and facebook-github-bot committed Mar 11, 2024
1 parent caade55 commit 72cc9e0
Showing 7 changed files with 76 additions and 16 deletions.
8 changes: 8 additions & 0 deletions examples/xtensa/ops/dequantize_per_tensor.cpp
@@ -25,7 +25,15 @@ void dequantize_per_tensor_out(
int64_t quant_min,
int64_t quant_max,
ScalarType dtype,
exec_aten::optional<ScalarType>& out_dtype,
Tensor& out) {
if (out_dtype.has_value()) {
ET_CHECK_MSG(
out_dtype.value() == ScalarType::Float,
"Expected out dtype to be Float but got %hhd",
out_dtype.value());
}

float* out_data = out.mutable_data_ptr<float>();
size_t numel = out.numel();

2 changes: 1 addition & 1 deletion examples/xtensa/ops/functions.yaml
@@ -26,7 +26,7 @@
- arg_meta: null
kernel_name: impl::HiFi::quantized_linear_pt2_out

- func: xtensa::dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
- func: xtensa::dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None, Tensor(a!) out) -> Tensor(a!)
variants: function
kernels:
- arg_meta: null
37 changes: 28 additions & 9 deletions kernels/quantized/cpu/op_dequantize.cpp
@@ -33,6 +33,7 @@ void check_dequantize_per_tensor_args(
int64_t quant_min,
int64_t quant_max,
ScalarType dtype,
exec_aten::optional<ScalarType>& out_dtype,
Tensor& out) {
ET_CHECK_MSG(
input.scalar_type() == ScalarType::Byte ||
@@ -47,10 +48,11 @@
"input.scalar_type() %" PRId8 " is not matching dtype argumenta:",
static_cast<int8_t>(input.scalar_type()));

ET_CHECK_MSG(
out.scalar_type() == ScalarType::Float,
"out.scalar_type() %" PRId8 " is not supported:",
static_cast<int8_t>(out.scalar_type()));
if (out_dtype.has_value()) {
ET_CHECK_MSG(
out.scalar_type() == out_dtype.value(),
"output_dtype must match the dtype of the out tensor");
}

ET_CHECK_MSG(
quant_min <= quant_max,
@@ -77,13 +79,15 @@ Tensor& dequantize_per_tensor_out(
int64_t quant_min,
int64_t quant_max,
ScalarType dtype,
exec_aten::optional<ScalarType> out_dtype,
Tensor& out) {
torch::executor::Error err = resize_tensor(out, input.sizes());
ET_CHECK_MSG(
err == torch::executor::Error::Ok,
"Failed to resize out Tensor in dequantize_per_tensor_out");

check_dequantize_per_tensor_args(input, quant_min, quant_max, dtype, out);
check_dequantize_per_tensor_args(
input, quant_min, quant_max, dtype, out_dtype, out);

// calculate the dequantized output, cast scale to float to match fbgemm
// behavior
@@ -128,6 +132,7 @@ Tensor& dequantize_per_tensor_tensor_args_out(
int64_t quant_min,
int64_t quant_max,
ScalarType dtype,
exec_aten::optional<ScalarType> out_dtype,
Tensor& out) {
ET_CHECK_MSG(
scale.scalar_type() == ScalarType::Double,
@@ -153,6 +158,7 @@
quant_min,
quant_max,
dtype,
out_dtype,
out);
return out;
}
@@ -165,6 +171,7 @@ Tensor& dequantize_per_channel_out(
int64_t quant_min,
int64_t quant_max,
ScalarType dtype,
exec_aten::optional<ScalarType> out_dtype,
Tensor& out) {
torch::executor::Error err = resize_tensor(out, input.sizes());

@@ -205,7 +212,8 @@ Tensor& dequantize_per_channel_out(
ssize_t(zero_point.numel()),
ssize_t(input.size(axis)));

check_dequantize_per_tensor_args(input, quant_min, quant_max, dtype, out);
check_dequantize_per_tensor_args(
input, quant_min, quant_max, dtype, out_dtype, out);

// a list contains all dimensions except axis
int64_t dims[input.dim() - 1];
@@ -281,10 +289,19 @@ Tensor& dequantize_per_channel_out(
int64_t quant_min,
int64_t quant_max,
ScalarType dtype,
exec_aten::optional<ScalarType> out_dtype,
Tensor& out) {
(void)context;
return dequantize_per_channel_out(
input, scale, zero_point, axis, quant_min, quant_max, dtype, out);
input,
scale,
zero_point,
axis,
quant_min,
quant_max,
dtype,
out_dtype,
out);
}

Tensor& dequantize_per_tensor_out(
@@ -295,12 +312,13 @@ Tensor& dequantize_per_tensor_out(
int64_t quant_min,
int64_t quant_max,
ScalarType dtype,
exec_aten::optional<ScalarType> out_dtype,
Tensor& out) {
// TODO(larryliu): Add a context arg to the real op function and remove this
// wrapper
(void)context;
return dequantize_per_tensor_out(
input, scale, zero_point, quant_min, quant_max, dtype, out);
input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out);
}

Tensor& dequantize_per_tensor_tensor_args_out(
Expand All @@ -311,12 +329,13 @@ Tensor& dequantize_per_tensor_tensor_args_out(
int64_t quant_min,
int64_t quant_max,
ScalarType dtype,
exec_aten::optional<ScalarType> out_dtype,
Tensor& out) {
// TODO(larryliu): Add a context arg to the real op function and remove this
// wrapper
(void)context;
return dequantize_per_tensor_tensor_args_out(
input, scale, zero_point, quant_min, quant_max, dtype, out);
input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out);
}

} // namespace native
6 changes: 3 additions & 3 deletions kernels/quantized/quantized.yaml
@@ -10,13 +10,13 @@
- arg_meta: null
kernel_name: torch::executor::choose_qparams_tensor_out

- func: quantized_decomposed::dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
- func: quantized_decomposed::dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None, Tensor(a!) out) -> Tensor(a!)
variants: function
kernels:
- arg_meta: null
kernel_name: torch::executor::dequantize_per_tensor_out

- func: quantized_decomposed::dequantize_per_tensor.Tensor_out(Tensor input, Tensor scale, Tensor zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
- func: quantized_decomposed::dequantize_per_tensor.Tensor_out(Tensor input, Tensor scale, Tensor zero_point, int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None, Tensor(a!) out) -> Tensor(a!)
variants: function
kernels:
- arg_meta: null
@@ -28,7 +28,7 @@
- arg_meta: null
kernel_name: torch::executor::quantize_per_channel_out

- func: quantized_decomposed::dequantize_per_channel.out(Tensor input, Tensor scales, Tensor zero_points, int axis, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
- func: quantized_decomposed::dequantize_per_channel.out(Tensor input, Tensor scales, Tensor zero_points, int axis, int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None, Tensor(a!) out) -> Tensor(a!)
variants: function
kernels:
- arg_meta: null
7 changes: 7 additions & 0 deletions kernels/quantized/test/op_add_test.cpp
@@ -20,6 +20,7 @@

using namespace ::testing;
using exec_aten::ArrayRef;
using exec_aten::optional;
using exec_aten::RuntimeContext;
using exec_aten::Scalar;
using exec_aten::ScalarType;
@@ -190,6 +191,8 @@ TEST(OpQuantizeAddTest, ConsitencyWithReferencePattern) {
Tensor qinput2 = tfo.zeros({3, 5});
Tensor qoutput = tfo.zeros({3, 5});

optional<ScalarType> out_dtype = optional<ScalarType>();

RuntimeContext context{};
// q -> qadd -> dq
// 3.5 / 0.5 + 1 = 8
@@ -235,6 +238,7 @@ TEST(OpQuantizeAddTest, ConsitencyWithReferencePattern) {
quant_min,
quant_max,
ScalarType::Byte,
out_dtype,
reference_op_output);

// now get results for q -> dq -> fp add -> q -> dq
@@ -245,6 +249,7 @@ TEST(OpQuantizeAddTest, ConsitencyWithReferencePattern) {
quant_min,
quant_max,
ScalarType::Byte,
out_dtype,
dq_input1);

dequantize_per_tensor_out(
@@ -254,6 +259,7 @@ TEST(OpQuantizeAddTest, ConsitencyWithReferencePattern) {
quant_min,
quant_max,
ScalarType::Byte,
out_dtype,
dq_input2);

add_out(context, dq_input1, dq_input2, 1.0, fp_output);
@@ -274,6 +280,7 @@ TEST(OpQuantizeAddTest, ConsitencyWithReferencePattern) {
quant_min,
quant_max,
ScalarType::Byte,
out_dtype,
reference_pattern_output);

Tensor expected = tf.full({3, 5}, 7.0);
30 changes: 27 additions & 3 deletions kernels/quantized/test/op_dequantize_test.cpp
@@ -18,6 +18,7 @@

using namespace ::testing;
using exec_aten::ArrayRef;
using exec_aten::optional;
using exec_aten::Scalar;
using exec_aten::ScalarType;
using exec_aten::Tensor;
@@ -43,7 +44,14 @@ void test_dtype() {
// (100 - 30) * 0.5
Tensor expected = tfo.full({3, 5}, 35);
dequantize_per_tensor_out(
input, scale, zero_point, quant_min, quant_max, DTYPE, out);
input,
scale,
zero_point,
quant_min,
quant_max,
DTYPE,
optional<ScalarType>(),
out);

EXPECT_TENSOR_EQ(out, expected);
}
@@ -66,7 +74,14 @@ TEST(OpDequantizeOutTest, NonWholeNumbers) {
// (100 - 30) * 0.5
Tensor expected = tfo.full({3, 5}, 31.5);
dequantize_per_tensor_out(
input, scale, zero_point, quant_min, quant_max, ScalarType::Byte, out);
input,
scale,
zero_point,
quant_min,
quant_max,
ScalarType::Byte,
optional<ScalarType>(),
out);

EXPECT_TENSOR_EQ(out, expected);
}
@@ -87,7 +102,14 @@ TEST(OpDequantizeOutTest, TensorArgOverload) {
// (100 - 30) * 0.5
Tensor expected = tfo.full({3, 5}, 31.5);
dequantize_per_tensor_tensor_args_out(
input, scale, zero_point, quant_min, quant_max, ScalarType::Byte, out);
input,
scale,
zero_point,
quant_min,
quant_max,
ScalarType::Byte,
optional<ScalarType>(),
out);

EXPECT_TENSOR_EQ(out, expected);
}
@@ -116,6 +138,7 @@ TEST(OpDequantizeOutTest, DequantizePerChannel) {
quant_min,
quant_max,
ScalarType::Byte,
optional<ScalarType>(),
out);

EXPECT_TENSOR_EQ(out, expected);
@@ -136,6 +159,7 @@
quant_min,
quant_max,
ScalarType::Byte,
optional<ScalarType>(),
out);

EXPECT_TENSOR_EQ(out, expected);
2 changes: 2 additions & 0 deletions kernels/quantized/test/op_embedding_test.cpp
@@ -20,6 +20,7 @@

using namespace ::testing;
using exec_aten::ArrayRef;
using exec_aten::optional;
using exec_aten::RuntimeContext;
using exec_aten::Scalar;
using exec_aten::ScalarType;
@@ -149,6 +150,7 @@ TEST(OpQuantizedEmbeddingTest, ConsitencyWithReferencePattern) {
quant_min,
quant_max,
ScalarType::Byte,
optional<ScalarType>(),
weight);

embedding_out(
