This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[1.x] TensorRT: add INT8 with calibration #19011

Merged 1 commit into apache:v1.x on Sep 15, 2020

Conversation

Contributor

@Kh4L Kh4L commented Aug 25, 2020

Description

This PR adds INT8-with-calibration support to MXNet-TensorRT.
It enables TensorRT's internal optimization to create an INT8 engine (which will contain some INT8 kernels, wherever they are faster than the FP16 or FP32 ones).
In this first version, the quantization and de-quantization values are computed during a calibration phase. During this phase (whose length, in iterations, is set by calibration_iters), the user is expected to provide samples representative of the inference data, which are used to calibrate the engine. Inference is slower during this phase.
Once calibration is done, the MXNet-TensorRT inference model is ready for fast INT8 inference.
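As an aside on what calibration computes: the quantization and de-quantization values are, in essence, per-tensor scales. Below is a minimal illustration of symmetric INT8 quantization using a max-abs range. This is a deliberate simplification (TensorRT's calibrator chooses dynamic ranges more carefully, e.g. via entropy minimization), and all numbers are hypothetical:

```python
# Illustration only: symmetric INT8 quantization with a max-abs range.
# TensorRT's calibrator picks dynamic ranges more carefully than this,
# but the quantize/de-quantize arithmetic has the same shape.
def quantize_int8(values):
    scale = max(abs(v) for v in values) / 127.0   # de-quantization value
    q = [max(-128, min(127, round(v / scale))) for v in values]
    return q, scale

def dequantize_int8(q, scale):
    return [x * scale for x in q]

activations = [0.5, -1.0, 2.54, -0.02]            # hypothetical tensor
q, scale = quantize_int8(activations)
print(q)                                          # [25, -50, 127, -1]
print(dequantize_int8(q, scale))                  # close to the inputs
```

The round trip loses at most half a quantization step per value, which is why a representative calibration set matters: the chosen ranges determine that step size.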

Saving and loading of the calibration tables will be added in a later PR.

Usage

We set calibration_iters to the number of batches we can feed from the calibration dataset.
For instance:

calibration_iters = 500
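The relationship between the calibration set, the batch size, and calibration_iters can be sketched with hypothetical numbers (16,000 samples and a batch size of 32, both assumed for illustration, not taken from this PR):

```python
# Hypothetical sizes, chosen so the arithmetic matches the example above.
# Every calibration sample should be seen exactly once, so the dataset
# size should be a multiple of the batch size.
dataset_size = 16000   # samples in the calibration dataset (assumed)
batch_size = 32        # batch size fed to the executor (assumed)
calibration_iters = dataset_size // batch_size
print(calibration_iters)  # 500
```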

We call optimize_for:
Symbolic

trt_sym = sym.optimize_for('TensorRT', args=arg_params, aux=aux_params, ctx=mx.gpu(0),
                           int8_mode=True,
                           calibration_iters=calibration_iters)

Gluon

model.optimize_for(input_data, backend='TensorRT', static_alloc=True, static_shape=True,
                   backend_opts={'int8_mode':True, 'calibration_iters':calibration_iters})

We create the executor and feed it the calibration data:

# (len(calibration_dataset) / batch_size) == calibration_iters
for d in calibration_dataset:
    executor.forward(is_train=False, data=d)

Calibration is slower than regular inference. Once it's done, an info message is printed on stdout.
The executor with TRT INT8 engines is then ready!
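Putting the steps together: calibration is just calibration_iters forward passes over representative data. The control flow can be mocked in plain Python (MockExecutor and the list-of-batches dataset are stand-ins invented for this sketch, not MXNet APIs):

```python
# Stand-in mock of the calibration control flow described above.
# MockExecutor plays the role of the MXNet-TensorRT executor: it counts
# forward passes and, after `calibration_iters` of them, flips to the
# calibrated state, mirroring the "engine ready" info message on stdout.
class MockExecutor:
    def __init__(self, calibration_iters):
        self.calibration_iters = calibration_iters
        self.seen = 0
        self.calibrated = False

    def forward(self, is_train=False, data=None):
        if not self.calibrated:
            self.seen += 1               # slow path: collecting ranges
            if self.seen == self.calibration_iters:
                self.calibrated = True   # INT8 engine built
                print("TensorRT INT8 engine ready")

calibration_dataset = [[0.0]] * 500      # 500 stand-in batches
executor = MockExecutor(calibration_iters=len(calibration_dataset))
for d in calibration_dataset:
    executor.forward(is_train=False, data=d)
print(executor.calibrated)               # True
```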

@mxnet-bot

Hey @Kh4L , Thanks for submitting the PR
All tests are already queued to run once. If tests fail, you can trigger one or more tests again with the following commands:

  • To trigger all jobs: @mxnet-bot run ci [all]
  • To trigger specific jobs: @mxnet-bot run ci [job1, job2]

CI supported jobs: [website, centos-cpu, unix-gpu, miscellaneous, windows-gpu, centos-gpu, clang, unix-cpu, windows-cpu, edge, sanity]


Note:
Only the following 3 categories can trigger CI: PR Author, MXNet Committer, Jenkins Admin.
All CI tests must pass before the PR can be merged.

export CUDNN_VERSION=${CUDNN_VERSION:-7.0.3}
export MXNET_ENABLE_CYTHON=0
export DMLC_LOG_STACK_TRACE_DEPTH=10
pip3 install --extra-index-url https://developer.download.nvidia.com/compute/redist nvidia-dali-cuda100==0.24
Contributor

@leezu could you assist reviewing this PR? I think the structure wrt installation of packages and creating a new job might not be in line.

Member

this address returns 404 for me: https://developer.download.nvidia.com/compute/redist

Contributor Author

https://developer.download.nvidia.com/compute/redist isn't supposed to be accessed by itself.
The pip install line is the standard way of installing DALI:
https://docs.nvidia.com/deeplearning/dali/user-guide/docs/installation.html

@marcoabreu
Contributor

@KellenSunderland

@Kh4L Kh4L force-pushed the backport_int8 branch 6 times, most recently from 8090dd1 to 99efa1c Compare September 1, 2020 06:34
@Kh4L
Contributor Author

Kh4L commented Sep 1, 2020

@mxnet-bot run ci [sanity]

@mxnet-bot

Jenkins CI successfully triggered : [sanity]

@Kh4L Kh4L changed the title [WIP] TensorRT: add INT8 with calibration TensorRT: add INT8 with calibration Sep 1, 2020
@Kh4L
Contributor Author

Kh4L commented Sep 2, 2020

@mxnet-bot run ci [unix-gpu]

@mxnet-bot

Jenkins CI successfully triggered : [unix-gpu]

Contributor

@samskalicky samskalicky left a comment

Reviewed subgraph property portions, looks good to me. Please find someone to review the TRT-specific changes (maybe @ptrendx or @KellenSunderland) and the CI changes (maybe @josephevans?) too.

pip3 install --extra-index-url https://developer.download.nvidia.com/compute/redist nvidia-dali-cuda100==0.24
wget -nc http://data.mxnet.io/data/val_256_q90.rec
python3.6 tests/python/tensorrt/rec2idx.py val_256_q90.rec val_256_q90.idx
nosetests-3.4 $NOSE_COVERAGE_ARGUMENTS $NOSE_TIMER_ARGUMENTS --with-xunit --xunit-file nosetests_trt_gpu.xml --verbose --nocapture tests/python/tensorrt/
Member

In MXNet we switched to pytest and are no longer using nose.

Contributor

I missed that discussion, can you point me to the RFC so I can catch up?

Contributor Author

I forgot to add the [1.x] tag; this is a 1.x PR, where we still use nosetests.

Contributor Author

@szha is it fine to leave this as nosetests, since this is the 1.8 branch?

@Kh4L Kh4L changed the title TensorRT: add INT8 with calibration [1.x] TensorRT: add INT8 with calibration Sep 4, 2020
@@ -0,0 +1,107 @@
# Licensed to the Apache Software Foundation (ASF) under one
Member

DALI has this script; why include it here too?

Signed-off-by: Serge Panev <spanev@nvidia.com>
Member

@ptrendx ptrendx left a comment

LGTM

@ptrendx ptrendx merged commit 606933f into apache:v1.x Sep 15, 2020
7 participants