Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix SliceChannel Type inference (#16748)
Browse files Browse the repository at this point in the history
* Refactor elemwise_op_common and change SliceChannel InferType

* Add gluoncv models

* Comment Faster RCNN models
  • Loading branch information
anirudh2290 authored Nov 8, 2019
1 parent 5dfa121 commit a37dcd4
Show file tree
Hide file tree
Showing 4 changed files with 167 additions and 23 deletions.
139 changes: 132 additions & 7 deletions example/automatic-mixed-precision/amp_model_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,129 @@ def save_params(fname, arg_params, aux_params, logger=None):
'imagenet1k-resnext-101-64x4d',
'imagenet11k-place365ch-resnet-152',
'imagenet11k-place365ch-resnet-50']
gluon_models = ['resnet18_v1',
# Faster RCNN and Mask RCNN commented because of model loading issues
# https://github.com/dmlc/gluon-cv/issues/1034
gluon_models = [#'faster_rcnn_fpn_resnet50_v1b_coco',
'mobilenetv2_0.75',
'cifar_resnet56_v1',
'mobilenet0.25',
'mobilenet1.0',
#'mask_rcnn_fpn_resnet50_v1b_coco',
'simple_pose_resnet152_v1b',
'ssd_512_resnet50_v1_voc',
#'faster_rcnn_resnet50_v1b_voc',
'cifar_resnet20_v1',
'yolo3_darknet53_voc',
'resnet101_v1c',
'simple_pose_resnet18_v1b',
#'mask_rcnn_resnet50_v1b_coco',
'ssd_512_mobilenet1.0_coco',
'vgg19_bn',
#'faster_rcnn_resnet50_v1b_coco',
'cifar_resnet110_v1',
'yolo3_mobilenet1.0_voc',
'cifar_resnext29_16x64d',
'resnet34_v1',
'densenet121',
#'mask_rcnn_fpn_resnet101_v1d_coco',
'vgg13_bn',
'vgg19',
'resnet152_v1d',
'resnet152_v1s',
'densenet201',
'alexnet',
'se_resnext50_32x4d',
'resnet50_v1d_0.86',
'resnet18_v1b_0.89',
'yolo3_darknet53_coco',
'resnet152_v1',
'resnext101_64x4d',
'vgg13',
'resnet101_v1d_0.76',
'simple_pose_resnet50_v1d',
'senet_154',
'resnet50_v1',
'resnet101_v1',
'se_resnext101_32x4d',
'fcn_resnet101_voc',
'resnet152_v2',
#'mask_rcnn_resnet101_v1d_coco',
'squeezenet1.1',
'mobilenet0.5',
'resnet34_v2',
'resnet18_v1',
'resnet152_v1b',
'resnet101_v2',
'cifar_resnet56_v2',
'ssd_512_resnet101_v2_voc',
'resnet50_v1d_0.37',
'mobilenetv2_0.5',
#'faster_rcnn_fpn_bn_resnet50_v1b_coco',
'resnet50_v1c',
'densenet161',
'simple_pose_resnet50_v1b',
'resnet18_v1b',
'darknet53',
'fcn_resnet50_ade',
'cifar_wideresnet28_10',
'simple_pose_resnet101_v1d',
'vgg16',
'ssd_512_resnet50_v1_coco',
'resnet101_v1d_0.73',
'squeezenet1.0',
'mobilenet1.0',
'resnet50_v1b',
#'faster_rcnn_resnet101_v1d_coco',
'ssd_512_mobilenet1.0_voc',
'cifar_wideresnet40_8',
'cifar_wideresnet16_10',
'cifar_resnet110_v2',
'resnet101_v1s',
'mobilenetv2_0.25',
'resnet152_v1c',
'se_resnext101_64x4d',
#'faster_rcnn_fpn_resnet101_v1d_coco',
'resnet50_v1d',
'densenet169',
'resnet34_v1b',
'resnext50_32x4d',
'resnet101_v1',
'resnet101_v1b',
'resnet50_v1s',
'mobilenet0.75',
'cifar_resnet20_v2',
'resnet101_v1d',
'vgg11_bn',
'resnet18_v2',
'vgg11',
'simple_pose_resnet101_v1b',
'resnext101_32x4d',
'resnet50_v2',
'vgg16_bn',
'mobilenetv2_1.0',
'inceptionv3']
'resnet50_v1d_0.48',
'resnet50_v1d_0.11',
'fcn_resnet101_ade',
'simple_pose_resnet152_v1d',
'yolo3_mobilenet1.0_coco',
'fcn_resnet101_coco']
# TODO(anisub): add support for other models from gluoncv
# Not supported today mostly because of broken net.forward calls
segmentation_models = ['deeplab_resnet50_ade',
'psp_resnet101_voc',
'deeplab_resnet152_voc',
'deeplab_resnet101_ade',
'deeplab_resnet152_coco',
'psp_resnet101_ade',
'deeplab_resnet101_coco',
'psp_resnet101_citys',
'psp_resnet50_ade',
'psp_resnet101_coco',
'deeplab_resnet101_voc']
calib_ssd_models = ["ssd_512_vgg16_atrous_voc",
"ssd_300_vgg16_atrous_voc",
"ssd_300_vgg16_atrous_coco"]
calib_inception_models = ["inceptionv3"]
gluon_models = gluon_models + segmentation_models + \
calib_ssd_models + calib_inception_models
models = symbolic_models + gluon_models

parser = argparse.ArgumentParser(description='Convert a provided FP32 model to a mixed precision model')
Expand Down Expand Up @@ -106,14 +222,23 @@ def save_params(fname, arg_params, aux_params, logger=None):
else:
assert args.model in gluon_models, "Please choose one of the available gluon models: {} \
If you want to use symbolic model instead, remove --use-gluon-model when running the script".format(gluon_models)
shape = None
if args.model in segmentation_models:
shape = (1, 3, 480, 480)
elif args.model in calib_ssd_models:
shape = (1, 3, 512, 544)
elif args.model in calib_inception_models:
shape = (1, 3, 299, 299)
else:
shape = (1, 3, 224, 224)
net = gluoncv.model_zoo.get_model(args.model, pretrained=True)
net.hybridize()
result_before1 = net.forward(mx.nd.zeros((1, 3, 224, 224)))
result_before1 = net.forward(mx.nd.random.uniform(shape=shape))
net.export("{}".format(args.model))
net = amp.convert_hybrid_block(net, cast_optional_params=args.cast_optional_params)
net.export("{}-amp".format(args.model), remove_amp_cast=False)
if args.run_dummy_inference:
logger.info("Running inference on the mixed precision model with dummy inputs, batch size: 1")
result_after = net.forward(mx.nd.zeros((1, 3, 224, 224), dtype=np.float32, ctx=mx.gpu(0)))
result_after = net.forward(mx.nd.zeros((1, 3, 224, 224), dtype=np.float32, ctx=mx.gpu(0)))
result_after = net.forward(mx.nd.random.uniform(shape=shape, dtype=np.float32, ctx=mx.gpu(0)))
result_after = net.forward(mx.nd.random.uniform(shape=shape, dtype=np.float32, ctx=mx.gpu(0)))
logger.info("Inference run successfully")
27 changes: 21 additions & 6 deletions src/operator/elemwise_op_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,10 @@ template<typename AttrType, bool (*is_none)(const AttrType&),
bool (*assign)(AttrType*, const AttrType&), bool reverse_infer,
std::string (*attr_string)(const AttrType&),
index_t n_in = -1, index_t n_out = -1>
inline bool ElemwiseAttr(const nnvm::NodeAttrs& attrs,
std::vector<AttrType> *in_attrs,
std::vector<AttrType> *out_attrs,
const AttrType& none) {
inline bool ElemwiseAttrHelper(const std::string& node_name,
std::vector<AttrType> *in_attrs,
std::vector<AttrType> *out_attrs,
const AttrType& none) {
AttrType dattr = none;
size_t in_size = in_attrs->size();
size_t out_size = out_attrs->size();
Expand All @@ -133,7 +133,7 @@ inline bool ElemwiseAttr(const nnvm::NodeAttrs& attrs,
auto deduce = [&](const std::vector<AttrType>& vec, size_t size, const char *name) {
for (size_t i = 0; i < size; ++i) {
CHECK(assign(&dattr, vec.at(i)))
<< "Incompatible attr in node " << attrs.name << " at " << i << "-th "
<< "Incompatible attr in node " << node_name << " at " << i << "-th "
<< name << ": " << "expected " << attr_string(dattr)
<< ", got " << attr_string(vec.at(i));
}
Expand All @@ -145,7 +145,7 @@ inline bool ElemwiseAttr(const nnvm::NodeAttrs& attrs,
auto write = [&](std::vector<AttrType> *vec, size_t size, const char *name) {
for (size_t i = 0; i < size; ++i) {
CHECK(assign(&(vec->at(i)), dattr))
<< "Incompatible attr in node " << attrs.name << " at " << i << "-th "
<< "Incompatible attr in node " << node_name << " at " << i << "-th "
<< name << ": " << "expected " << attr_string(dattr)
<< ", got " << attr_string(vec->at(i));
}
Expand All @@ -158,6 +158,21 @@ inline bool ElemwiseAttr(const nnvm::NodeAttrs& attrs,
return true;
}


template<typename AttrType, bool (*is_none)(const AttrType&),
bool (*assign)(AttrType*, const AttrType&), bool reverse_infer,
std::string (*attr_string)(const AttrType&),
index_t n_in = -1, index_t n_out = -1>
inline bool ElemwiseAttr(const nnvm::NodeAttrs& attrs,
std::vector<AttrType> *in_attrs,
std::vector<AttrType> *out_attrs,
const AttrType& none) {
return ElemwiseAttrHelper<AttrType, is_none,
assign, reverse_infer,
attr_string, n_in,
n_out>(attrs.name, in_attrs, out_attrs, none);
}

template<index_t n_in, index_t n_out>
inline bool ElemwiseShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_attrs,
Expand Down
15 changes: 5 additions & 10 deletions src/operator/slice_channel-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include <utility>
#include "./operator_common.h"
#include "./channel_op_common.h"
#include "./elemwise_op_common.h"

namespace mxnet {
namespace op {
Expand Down Expand Up @@ -176,16 +177,10 @@ class SliceChannelProp : public OperatorProperty {
bool InferType(std::vector<int> *in_type,
std::vector<int> *out_type,
std::vector<int> *aux_type) const override {
CHECK_EQ(in_type->size(), 1U);
int dtype = (*in_type)[0];
CHECK_NE(dtype, -1) << "First input must have specified type";
out_type->clear();
out_type->reserve(param_.num_outputs);
for (int i = 0; i < param_.num_outputs; ++i) {
out_type->push_back(dtype);
}
aux_type->clear();
return true;
std::string node_name = "slice_channel_node";
return ElemwiseAttrHelper<int, type_is_none,
type_assign, true,
type_string, 1>(node_name, in_type, out_type, -1);
}

bool InferShape(mxnet::ShapeVector *in_shape,
Expand Down
9 changes: 9 additions & 0 deletions tests/python/gpu/test_contrib_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,15 @@ def test_fp16_casting():
exe = final_res.simple_bind(ctx=mx.gpu(), data=(1, 2), data2=(1, 2))
assert exe.arg_arrays[0].dtype == np.float16

# Check for symbol which has slice channel
data = mx.sym.var("data")
data2 = mx.sym.var("data2")
data._set_attr(__dtype__="-1")
data2._set_attr(__dtype__="-1")
concat_res = mx.sym.concat(data, data2)
out = mx.sym.split(concat_res, axis=1, num_outputs=2)
final_res = amp.convert_symbol(out)


if __name__ == '__main__':
import nose
Expand Down

0 comments on commit a37dcd4

Please sign in to comment.