Conversion from FP32 model to Mixed Precision model (apache#15118)
* Initial AMP commit
* Fix
* Merge AMP Changes
* AMP Changes to support conditional op names switch
* Add example and fix issues with AMP conversion
* Remove amp convert symbol test
* Fix comment for inference use case
* Remove input_names for convert_hybrid_block
* Check all conditions
* Fix lint
* Fix error_str for load_dict
* Fix lint, Add tests, fix bugs, add examples
* Fix warnings
* Add license for example script
* Remove gpu test and move tests to test_contrib_amp
* Clean up AMP tests
* Add additional comments, add tutorial
* Move the test to gpu dir
* Make the code python3 compatible
* Upgrade archive utility, fixes: apache#15084
* Allow AR path to be chosen by user
* Use current_context in tutorial
* Update __all__
* Merge with load params API changes
* Revert "Allow AR path to be chosen by user" This reverts commit 94156b6.
* Revert "Upgrade archive utility, fixes: apache#15084" This reverts commit ea7dd32.
* Set numpy dtype to float32
* Address review comments
* Add range based for
* Change quantized to low precision
* Fix lint
* Fix pylint
* Forward args for Node::Create
* Fixes
* Add dtype casting wherever needed
* Fix lint in source
* Add cast_optional_params to example
* Tweak example
* Add README
* Add README
* Add cast_optional_params test for convert_model and convert_hybrid_block
1 parent 11e6d45 · commit ca565a0
Showing 15 changed files with 1,766 additions and 96 deletions.
@@ -0,0 +1,35 @@
<!--- Licensed to the Apache Software Foundation (ASF) under one -->
<!--- or more contributor license agreements. See the NOTICE file -->
<!--- distributed with this work for additional information -->
<!--- regarding copyright ownership. The ASF licenses this file -->
<!--- to you under the Apache License, Version 2.0 (the -->
<!--- "License"); you may not use this file except in compliance -->
<!--- with the License. You may obtain a copy of the License at -->

<!--- http://www.apache.org/licenses/LICENSE-2.0 -->

<!--- Unless required by applicable law or agreed to in writing, -->
<!--- software distributed under the License is distributed on an -->
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
<!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->
# Conversion of FP32 models to Mixed Precision Models

This folder contains examples for converting FP32 models to mixed precision models. The script can convert either FP32 symbolic models or Gluon models to mixed precision models.
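Under the hood, the script uses the conversion APIs from `mxnet.contrib.amp`. The sketch below shows that flow directly, outside the script; the checkpoint prefix, the Gluon model choice, and the single warm-up forward pass are illustrative assumptions (a GPU is assumed for `convert_hybrid_block`, which is also how the example runs), not a prescribed recipe.

```python
import mxnet as mx
from mxnet.contrib.amp import amp

# Symbolic (Module) model: convert the symbol together with both param dicts.
# 'imagenet1k-resnet-152' / epoch 0 is a placeholder checkpoint prefix.
sym, arg_params, aux_params = mx.model.load_checkpoint('imagenet1k-resnet-152', 0)
amp_sym, amp_args, amp_auxs = amp.convert_model(sym, arg_params, aux_params,
                                                cast_optional_params=True)

# Gluon model: hybridize and run one forward pass so the graph is cached,
# then convert the block (this mirrors what amp_model_conversion.py does).
net = mx.gluon.model_zoo.vision.resnet18_v1(pretrained=True)
net.hybridize()
net(mx.nd.zeros((1, 3, 224, 224)))
amp_net = amp.convert_hybrid_block(net, cast_optional_params=True)
```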
## Basic Usage
1. AMP model conversion for a Gluon model, casting the params to FP16 wherever possible. The command below converts the `resnet101_v1` model to a mixed precision model, casts its params to FP16 where possible, and then runs a dummy inference forward pass on the converted model (a sketch for loading the exported files back follows these examples).

```bash
python amp_model_conversion.py --model resnet101_v1 --use-gluon-model --run-dummy-inference --cast-optional-params
```
2. AMP model conversion for a symbolic model, keeping the params in FP32 wherever possible (`--cast-optional-params` not passed).

```bash
python amp_model_conversion.py --model imagenet1k-resnet-152 --run-dummy-inference
```
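Both commands write the converted artifacts to disk: the symbolic path saves a `*-amp-symbol.json` / `*-amp-<epoch>.params` checkpoint alongside the downloaded model files, and the Gluon path exports `<model>-amp-symbol.json` and `<model>-amp-0000.params`. A minimal sketch for loading the Gluon export back for inference is shown below; the `resnet101_v1` prefix and the GPU context are assumptions carried over from example 1.

```python
import mxnet as mx

# Assumes example 1 above was run with --model resnet101_v1, which exports
# resnet101_v1-amp-symbol.json and resnet101_v1-amp-0000.params.
ctx = mx.gpu(0)
net = mx.gluon.nn.SymbolBlock.imports('resnet101_v1-amp-symbol.json', ['data'],
                                      'resnet101_v1-amp-0000.params', ctx=ctx)
out = net(mx.nd.ones((1, 3, 224, 224), ctx=ctx))
print(out.shape)
```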
example/automatic-mixed-precision/amp_model_conversion.py
119 changes: 119 additions & 0 deletions

@@ -0,0 +1,119 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import os
import logging
import argparse
import mxnet as mx
from common import modelzoo
import gluoncv
from gluoncv.model_zoo import get_model
from mxnet.contrib.amp import amp
import numpy as np

def download_model(model_name, logger=None):
    dir_path = os.path.dirname(os.path.realpath(__file__))
    model_path = os.path.join(dir_path, 'model')
    if logger is not None:
        logger.info('Downloading model {}... into path {}'.format(model_name, model_path))
    # Use the function arguments rather than the global CLI args when downloading.
    return modelzoo.download_model(model_name, model_path)


def save_symbol(fname, sym, logger=None):
    if logger is not None:
        logger.info('Saving symbol into file at {}'.format(fname))
    sym.save(fname, remove_amp_cast=False)


def save_params(fname, arg_params, aux_params, logger=None):
    if logger is not None:
        logger.info('Saving params into file at {}'.format(fname))
    save_dict = {('arg:%s' % k): v.as_in_context(mx.cpu()) for k, v in arg_params.items()}
    save_dict.update({('aux:%s' % k): v.as_in_context(mx.cpu()) for k, v in aux_params.items()})
    mx.nd.save(fname, save_dict)


if __name__ == '__main__':
    symbolic_models = ['imagenet1k-resnet-152',
                       'imagenet1k-resnet-18',
                       'imagenet1k-resnet-34',
                       'imagenet1k-resnet-50',
                       'imagenet1k-resnet-101',
                       'imagenet1k-resnext-50',
                       'imagenet1k-resnext-101',
                       'imagenet1k-resnext-101-64x4d',
                       'imagenet11k-place365ch-resnet-152',
                       'imagenet11k-place365ch-resnet-50']
    gluon_models = ['resnet18_v1',
                    'resnet50_v1',
                    'resnet101_v1',
                    'squeezenet1.0',
                    'mobilenet1.0',
                    'mobilenetv2_1.0',
                    'inceptionv3']
    models = symbolic_models + gluon_models

    parser = argparse.ArgumentParser(description='Convert a provided FP32 model to a mixed precision model')
    parser.add_argument('--model', type=str, choices=models)
    parser.add_argument('--run-dummy-inference', action='store_true', default=False,
                        help='Will generate random input of shape (1, 3, 224, 224) '
                             'and run a dummy inference forward pass')
    parser.add_argument('--use-gluon-model', action='store_true', default=False,
                        help='If enabled, will download pretrained model from Gluon-CV '
                             'and convert to mixed precision model')
    parser.add_argument('--cast-optional-params', action='store_true', default=False,
                        help='If enabled, will try to cast params to target dtype wherever possible')
    args = parser.parse_args()
    logging.basicConfig()
    logger = logging.getLogger('logger')
    logger.setLevel(logging.INFO)

    if not args.use_gluon_model:
        assert args.model in symbolic_models, "Please choose one of the available symbolic models: {} \
            If you want to use gluon, use the script with --use-gluon-model".format(symbolic_models)

        prefix, epoch = download_model(model_name=args.model, logger=logger)
        sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
        result_sym, result_arg_params, result_aux_params = amp.convert_model(sym, arg_params, aux_params,
                                                                             cast_optional_params=args.cast_optional_params)
        sym_name = "%s-amp-symbol.json" % (prefix)
        save_symbol(sym_name, result_sym, logger)
        param_name = '%s-%04d.params' % (prefix + '-amp', epoch)
        save_params(param_name, result_arg_params, result_aux_params, logger)
        if args.run_dummy_inference:
            logger.info("Running inference on the mixed precision model with dummy input, batch size: 1")
            mod = mx.mod.Module(result_sym, data_names=['data'], label_names=['softmax_label'], context=mx.gpu(0))
            mod.bind(data_shapes=[['data', (1, 3, 224, 224)]], label_shapes=[['softmax_label', (1,)]])
            # Bind the converted params (not the original FP32 ones) so dtypes match the AMP graph.
            mod.set_params(result_arg_params, result_aux_params)
            mod.forward(mx.io.DataBatch(data=[mx.nd.ones((1, 3, 224, 224))],
                                        label=[mx.nd.ones((1,))]))
            result = mod.get_outputs()[0].asnumpy()
            logger.info("Inference run successfully")
    else:
        assert args.model in gluon_models, "Please choose one of the available gluon models: {} \
            If you want to use a symbolic model instead, remove --use-gluon-model when running the script".format(gluon_models)
        net = gluoncv.model_zoo.get_model(args.model, pretrained=True)
        net.hybridize()
        # Run one forward pass so the hybridized graph is cached before export/conversion.
        result_before1 = net.forward(mx.nd.zeros((1, 3, 224, 224)))
        net.export("{}".format(args.model))
        # Convert the FP32 HybridBlock to mixed precision and export it, keeping the amp_cast ops.
        net = amp.convert_hybrid_block(net, cast_optional_params=args.cast_optional_params)
        net.export("{}-amp".format(args.model), remove_amp_cast=False)
        if args.run_dummy_inference:
            logger.info("Running inference on the mixed precision model with dummy inputs, batch size: 1")
            result_after = net.forward(mx.nd.zeros((1, 3, 224, 224), dtype=np.float32, ctx=mx.gpu(0)))
            result_after = net.forward(mx.nd.zeros((1, 3, 224, 224), dtype=np.float32, ctx=mx.gpu(0)))
            logger.info("Inference run successfully")
@@ -0,0 +1 @@
../image-classification/common