Commit

Conversion from FP32 model to Mixed Precision model (apache#15118)
* Initial AMP commit

* Fix

* Merge AMP Changes

* AMP Changes to support conditional op names switch

* Add example and fix issues with AMP conversion

* Remove amp convert symbol test

* Fix comment for inference use case

* Remove input_names for convert_hybrid_block

* Check all conditions

* Fix lint

* Fix error_str for load_dict

* Fix lint, Add tests, fix bugs, add examples

* Fix warnings

* Add license for example script

* Remove gpu test and move tests to test_contrib_amp

* Clean up AMP tests

* Add additional comments, add tutorial

* Move the test to gpu dir

* Make the code python3 compatible

* Upgrade archive utility, fixes: apache#15084

* Allow AR path to be chosen by user

* Use current_context in tutorial

* Update __all__

* Merge with load params API changes

* Revert "Allow AR path to be chosen by user"

This reverts commit 94156b6.

* Revert "Upgrade archive utility, fixes: apache#15084"

This reverts commit ea7dd32.

* Set numpy dtype to float32

* Address review comments

* Add range based for

* Change quantized to low precision

* Fix lint

* Fix pylint

* Forward args for Node::Create

* Fixes

* Add dtype casting wherever needed

* Fix lint in source

* Add cast_optional_params to example

* Tweak example

* Add README

* Add README

* Add cast_optional_params test for convert_model and convert_hybrid_block
anirudh2290 authored and ptrendx committed Jun 28, 2019
1 parent 11e6d45 commit ca565a0
Showing 15 changed files with 1,766 additions and 96 deletions.
42 changes: 42 additions & 0 deletions docs/tutorials/amp/amp_tutorial.md
@@ -31,12 +31,14 @@ For demonstration purposes we will use a synthetic data loader.


```python
import os
import logging
import warnings
import time
import numpy as np
import mxnet as mx
import mxnet.gluon as gluon
from mxnet import autograd
from mxnet.test_utils import download_model
import gluoncv as gcv
from gluoncv.model_zoo import get_model
```

@@ -249,6 +251,46 @@

We got a 60% speed increase from 3 additional lines of code!
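
For reference, the three lines are the AMP calls themselves. A minimal sketch of how they fit into the training loop shown above (assuming `net`, `trainer`, `loss_fn`, `data`, `label`, and `batch_size` are defined as earlier in this tutorial):

```python
from mxnet.contrib import amp

amp.init()                 # 1. patch MXNet for mixed precision (call before building the net)
amp.init_trainer(trainer)  # 2. enable dynamic loss scaling on the Gluon trainer

with autograd.record():
    out = net(data)
    loss = loss_fn(out, label)
    with amp.scale_loss(loss, trainer) as scaled_loss:  # 3. scale the loss before backward
        autograd.backward(scaled_loss)
trainer.step(batch_size)
```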

## Inference with AMP

To run inference in mixed precision on a model trained in FP32, use the conversion APIs: `amp.convert_model` for symbolic models and `amp.convert_hybrid_block` for Gluon models. These APIs take the FP32 model as input and return a mixed precision model that can be used to run inference. Below, we demonstrate for both a Gluon model and a symbolic model how to: 1. Convert the FP32 model to a mixed precision model. 2. Run inference on the mixed precision model.

```python
with mx.Context(mx.gpu(0)):
    # Below is an example of converting a gluon hybrid block to a mixed precision block
    model = get_model("resnet50_v1")
    model.collect_params().initialize(ctx=mx.current_context())
    model.hybridize()
    model(mx.nd.zeros((1, 3, 224, 224)))
    converted_model = amp.convert_hybrid_block(model)

    # Run dummy inference with the converted gluon model
    result = converted_model.forward(mx.nd.random.uniform(shape=(1, 3, 224, 224),
                                                          dtype=np.float32))

    # Below is an example of converting a symbolic model to a mixed precision model
    dir_path = os.path.dirname(os.path.realpath(__file__))
    model_path = os.path.join(dir_path, 'model')
    if not os.path.isdir(model_path):
        os.mkdir(model_path)
    prefix, epoch = mx.test_utils.download_model("imagenet1k-resnet-18", dst_dir=model_path)
    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
    result_sym, result_arg_params, result_aux_params = amp.convert_model(sym,
                                                                         arg_params,
                                                                         aux_params)

    # Run dummy inference with the converted symbolic model
    mod = mx.mod.Module(result_sym, data_names=['data'], label_names=['softmax_label'],
                        context=mx.current_context())
    mod.bind(data_shapes=[('data', (1, 3, 224, 224))],
             label_shapes=[('softmax_label', (1,))])
    mod.set_params(result_arg_params, result_aux_params)
    mod.forward(mx.io.DataBatch(data=[mx.nd.ones((1, 3, 224, 224))],
                                label=[mx.nd.ones((1,))]))
    mod.get_outputs()[0].wait_to_read()
    print("Conversion and Inference completed successfully")
```
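
Both conversion APIs also take a `cast_optional_params` flag (used by the example script added in this commit). When set, eligible model parameters are themselves cast to the target dtype instead of being kept in FP32. A sketch reusing the names from the example above:

```python
# Also cast eligible params to the target dtype (FP16), not only the activations
result_sym, result_arg_params, result_aux_params = amp.convert_model(
    sym, arg_params, aux_params, cast_optional_params=True)
converted_model = amp.convert_hybrid_block(model, cast_optional_params=True)
```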



## Current limitations of AMP

- AMP's dynamic loss scaling currently supports only the Gluon trainer with the `update_on_kvstore=False` option set; see the sketch below
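
For illustration, a trainer compatible with dynamic loss scaling would be constructed along these lines (a minimal sketch; `net` and the optimizer settings are placeholders):

```python
trainer = gluon.Trainer(net.collect_params(), 'sgd',
                        {'learning_rate': 0.1},
                        update_on_kvstore=False)  # required for AMP dynamic loss scaling
amp.init_trainer(trainer)  # attaches the dynamic loss scaler to this trainer
```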
35 changes: 35 additions & 0 deletions example/automatic-mixed-precision/README.md
@@ -0,0 +1,35 @@
<!--- Licensed to the Apache Software Foundation (ASF) under one -->
<!--- or more contributor license agreements. See the NOTICE file -->
<!--- distributed with this work for additional information -->
<!--- regarding copyright ownership. The ASF licenses this file -->
<!--- to you under the Apache License, Version 2.0 (the -->
<!--- "License"); you may not use this file except in compliance -->
<!--- with the License. You may obtain a copy of the License at -->

<!--- http://www.apache.org/licenses/LICENSE-2.0 -->

<!--- Unless required by applicable law or agreed to in writing, -->
<!--- software distributed under the License is distributed on an -->
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
<!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->

# Conversion of FP32 models to Mixed Precision Models


This folder contains examples for converting FP32 models to mixed precision models. The script converts FP32 symbolic models or Gluon models to mixed precision models.

## Basic Usage

1. AMP model conversion for a Gluon model, casting the params to FP16 wherever possible. The command below converts the `resnet101_v1` model to a mixed precision model, casting params to FP16 wherever possible, then loads the converted model and runs inference on it.

```bash
python amp_model_conversion.py --model resnet101_v1 --use-gluon-model --run-dummy-inference --cast-optional-params
```

2. AMP model conversion for a symbolic model, keeping the params in FP32 wherever possible (`--cast-optional-params` not used).

```bash
python amp_model_conversion.py --model imagenet1k-resnet-152 --run-dummy-inference
```
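
After conversion, the script saves the mixed precision symbol and params next to the downloaded model. A minimal sketch of loading the converted symbolic model for inference (assuming the files produced by the second command above and epoch 0; the `-amp` file names and the `model/` prefix follow the script's save logic):

```python
import mxnet as mx

prefix = 'model/imagenet1k-resnet-152'
sym = mx.sym.load('%s-amp-symbol.json' % prefix)
save_dict = mx.nd.load('%s-amp-0000.params' % prefix)
# Split the saved dict back into arg and aux params
arg_params = {k[4:]: v for k, v in save_dict.items() if k.startswith('arg:')}
aux_params = {k[4:]: v for k, v in save_dict.items() if k.startswith('aux:')}

mod = mx.mod.Module(sym, data_names=['data'], label_names=['softmax_label'], context=mx.gpu(0))
mod.bind(data_shapes=[('data', (1, 3, 224, 224))], label_shapes=[('softmax_label', (1,))])
mod.set_params(arg_params, aux_params)
```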
119 changes: 119 additions & 0 deletions example/automatic-mixed-precision/amp_model_conversion.py
@@ -0,0 +1,119 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import os
import logging
import argparse
import mxnet as mx
from common import modelzoo
import gluoncv
from gluoncv.model_zoo import get_model
from mxnet.contrib.amp import amp
import numpy as np

def download_model(model_name, logger=None):
    dir_path = os.path.dirname(os.path.realpath(__file__))
    model_path = os.path.join(dir_path, 'model')
    if logger is not None:
        logger.info('Downloading model {}... into path {}'.format(model_name, model_path))
    return modelzoo.download_model(model_name, model_path)


def save_symbol(fname, sym, logger=None):
    if logger is not None:
        logger.info('Saving symbol into file at {}'.format(fname))
    sym.save(fname, remove_amp_cast=False)


def save_params(fname, arg_params, aux_params, logger=None):
    if logger is not None:
        logger.info('Saving params into file at {}'.format(fname))
    save_dict = {('arg:%s' % k): v.as_in_context(mx.cpu()) for k, v in arg_params.items()}
    save_dict.update({('aux:%s' % k): v.as_in_context(mx.cpu()) for k, v in aux_params.items()})
    mx.nd.save(fname, save_dict)


if __name__ == '__main__':
    symbolic_models = ['imagenet1k-resnet-152',
                       'imagenet1k-resnet-18',
                       'imagenet1k-resnet-34',
                       'imagenet1k-resnet-50',
                       'imagenet1k-resnet-101',
                       'imagenet1k-resnext-50',
                       'imagenet1k-resnext-101',
                       'imagenet1k-resnext-101-64x4d',
                       'imagenet11k-place365ch-resnet-152',
                       'imagenet11k-place365ch-resnet-50']
    gluon_models = ['resnet18_v1',
                    'resnet50_v1',
                    'resnet101_v1',
                    'squeezenet1.0',
                    'mobilenet1.0',
                    'mobilenetv2_1.0',
                    'inceptionv3']
    models = symbolic_models + gluon_models

    parser = argparse.ArgumentParser(description='Convert a provided FP32 model to a mixed precision model')
    parser.add_argument('--model', type=str, choices=models)
    parser.add_argument('--run-dummy-inference', action='store_true', default=False,
                        help='Will generate random input of shape (1, 3, 224, 224) '
                             'and run a dummy inference forward pass')
    parser.add_argument('--use-gluon-model', action='store_true', default=False,
                        help='If enabled, will download pretrained model from Gluon-CV '
                             'and convert to mixed precision model')
    parser.add_argument('--cast-optional-params', action='store_true', default=False,
                        help='If enabled, will try to cast params to target dtype wherever possible')
    args = parser.parse_args()
    logging.basicConfig()
    logger = logging.getLogger('logger')
    logger.setLevel(logging.INFO)

    if not args.use_gluon_model:
        assert args.model in symbolic_models, (
            "Please choose one of the available symbolic models: {}. "
            "If you want to use a Gluon model, run the script with "
            "--use-gluon-model".format(symbolic_models))

        prefix, epoch = download_model(model_name=args.model, logger=logger)
        sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
        result_sym, result_arg_params, result_aux_params = amp.convert_model(
            sym, arg_params, aux_params, cast_optional_params=args.cast_optional_params)
        sym_name = "%s-amp-symbol.json" % (prefix)
        save_symbol(sym_name, result_sym, logger)
        param_name = '%s-%04d.params' % (prefix + '-amp', epoch)
        save_params(param_name, result_arg_params, result_aux_params, logger)
        if args.run_dummy_inference:
            logger.info("Running inference on the mixed precision model with dummy input, batch size: 1")
            mod = mx.mod.Module(result_sym, data_names=['data'], label_names=['softmax_label'],
                                context=mx.gpu(0))
            mod.bind(data_shapes=[('data', (1, 3, 224, 224))],
                     label_shapes=[('softmax_label', (1,))])
            # Use the converted params: with --cast-optional-params they may already be FP16
            mod.set_params(result_arg_params, result_aux_params)
            mod.forward(mx.io.DataBatch(data=[mx.nd.ones((1, 3, 224, 224))],
                                        label=[mx.nd.ones((1,))]))
            result = mod.get_outputs()[0].asnumpy()
            logger.info("Inference run successfully")
    else:
        assert args.model in gluon_models, (
            "Please choose one of the available gluon models: {}. "
            "If you want to use a symbolic model instead, remove "
            "--use-gluon-model when running the script".format(gluon_models))

        net = gluoncv.model_zoo.get_model(args.model, pretrained=True)
        net.hybridize()
        # Run one forward pass so the hybridized graph is built before exporting
        result_before = net.forward(mx.nd.zeros((1, 3, 224, 224)))
        net.export("{}".format(args.model))
        net = amp.convert_hybrid_block(net, cast_optional_params=args.cast_optional_params)
        net.export("{}-amp".format(args.model), remove_amp_cast=False)
        if args.run_dummy_inference:
            logger.info("Running inference on the mixed precision model with dummy inputs, batch size: 1")
            result_after = net.forward(mx.nd.zeros((1, 3, 224, 224), dtype=np.float32, ctx=mx.gpu(0)))
            logger.info("Inference run successfully")
1 change: 1 addition & 0 deletions example/automatic-mixed-precision/common
49 changes: 49 additions & 0 deletions include/mxnet/c_api.h
@@ -1764,6 +1764,55 @@ MXNET_DLL int MXQuantizeSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_handle,
                               const mx_uint num_offline, const char **offline_params,
                               const char *quantized_dtype, const bool calib_quantize);

/*!
 * \brief Convert a symbol into a mixed precision symbol with cast operators for target dtype casting
 * \param sym_handle symbol to be converted
 * \param ret_sym_handle mixed precision symbol result
 * \param num_args number of arguments for known dtypes
 * \param arg_type_data arg types of the arguments
 * \param num_ind_ptr number of entries in ind_ptr
 * \param ind_ptr index pointer delimiting the values in conditional_param_vals for each conditional param name
 * \param target_dtype target_dtype for mixed precision symbol
 * \param cast_optional_params whether to cast optional params to target_dtype
 * \param num_target_dtype_op_names number of ops to be casted to target_dtype
 * \param num_fp32_op_names number of ops to be casted to FP32
 * \param num_widest_dtype_op_names number of ops to be casted to the widest dtype
 * \param num_conditional_fp32_op_names number of ops to be casted to FP32 based on a condition
 * \param num_excluded_symbols number of symbols to be excluded from casting
 * \param num_model_params number of model parameters
 * \param target_dtype_op_names op names to be casted to target_dtype
 * \param fp32_op_names op names to be casted to FP32
 * \param widest_dtype_op_names op names to be casted to the widest dtype
 * \param conditional_fp32_op_names op names to be casted to FP32 conditionally
 * \param excluded_symbols symbol names to be excluded from casting
 * \param conditional_param_names param names for conditional FP32 casting
 * \param conditional_param_vals param values for conditional FP32 casting
 * \param model_param_names names for model parameters
 * \param arg_names argument names for which type information is provided
 */
MXNET_DLL int MXReducePrecisionSymbol(SymbolHandle sym_handle,
                                      SymbolHandle *ret_sym_handle,
                                      mx_uint num_args,
                                      const int* arg_type_data,
                                      mx_uint num_ind_ptr,
                                      const int* ind_ptr,
                                      const int* target_dtype,
                                      const int cast_optional_params,
                                      const mx_uint num_target_dtype_op_names,
                                      const mx_uint num_fp32_op_names,
                                      const mx_uint num_widest_dtype_op_names,
                                      const mx_uint num_conditional_fp32_op_names,
                                      const mx_uint num_excluded_symbols,
                                      const mx_uint num_model_params,
                                      const char **target_dtype_op_names,
                                      const char **fp32_op_names,
                                      const char **widest_dtype_op_names,
                                      const char **conditional_fp32_op_names,
                                      const char **excluded_symbols,
                                      const char **conditional_param_names,
                                      const char **conditional_param_vals,
                                      const char **model_param_names,
                                      const char **arg_names);
/*!
* \brief Set calibration table to node attributes in the sym
* \param sym_handle symbol whose node attributes are to be set by calibration table