Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix QuantizedReshapeNode() function.
Browse files Browse the repository at this point in the history
  • Loading branch information
agrabow committed Feb 9, 2022
1 parent eece463 commit 2c27399
Showing 1 changed file with 19 additions and 8 deletions.
27 changes: 19 additions & 8 deletions src/operator/quantization/quantized_reshape.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,15 +94,24 @@ MXNET_OPERATOR_REGISTER_QUANTIZED_RESHAPE(_npx_quantized_reshape)
.set_attr<mxnet::FInferShape>("FInferShape", QuantizedReshapeInferShape<NumpyXReshapeShape>)
.add_arguments(NumpyXReshapeParam::__FIELDS__());

template <bool is_numpy_op>
enum ReshapeModule { NumPy = 0, NDArray = 1 };

inline const char* QuantizedReshapeModeMap(ReshapeModule module) {
switch (module) {
case ReshapeModule::NumPy:
return "_npx_quantized_reshape";
case ReshapeModule::NDArray:
return "_contrib_quantized_reshape";
default:
return nullptr;
}
}

template <ReshapeModule module>
nnvm::ObjectPtr QuantizedReshapeNode(const NodeAttrs& attrs) {
nnvm::ObjectPtr node = nnvm::Node::Create();

if constexpr (is_numpy_op) {
node->attrs.op = Op::Get("_npx_quantized_reshape");
} else {
node->attrs.op = Op::Get("_contrib_quantized_reshape");
}
node->attrs.op = Op::Get(QuantizedReshapeModeMap(module));
node->attrs.name = "quantized_" + attrs.name;
node->attrs.dict = attrs.dict;

Expand All @@ -112,9 +121,11 @@ nnvm::ObjectPtr QuantizedReshapeNode(const NodeAttrs& attrs) {
return node;
}

NNVM_REGISTER_OP(_npx_reshape).set_attr<FQuantizedOp>("FQuantizedOp", QuantizedReshapeNode<true>);
NNVM_REGISTER_OP(_npx_reshape)
.set_attr<FQuantizedOp>("FQuantizedOp", QuantizedReshapeNode<ReshapeModule::NumPy>);

NNVM_REGISTER_OP(Reshape).set_attr<FQuantizedOp>("FQuantizedOp", QuantizedReshapeNode<false>);
NNVM_REGISTER_OP(Reshape).set_attr<FQuantizedOp>("FQuantizedOp",
QuantizedReshapeNode<ReshapeModule::NDArray>);

} // namespace op
} // namespace mxnet

0 comments on commit 2c27399

Please sign in to comment.