Dynamic quantization + minor improvements inference APIs
dskhudia committed Aug 19, 2022
1 parent 5ec9c7a commit 7fbcae4
Showing 3 changed files with 61 additions and 8 deletions.

composer/utils/__init__.py (3 changes: 2 additions & 1 deletion)
@@ -9,7 +9,7 @@
                                          ensure_folder_is_empty, format_name_with_dist, format_name_with_dist_and_time,
                                          get_file, is_tar)
 from composer.utils.import_helpers import MissingConditionalImportError, import_object
-from composer.utils.inference import export_for_inference, export_with_logger
+from composer.utils.inference import export_for_inference, export_with_logger, quantize_dynamic
 from composer.utils.iter_helpers import IteratorFileStream, ensure_tuple, map_collection
 from composer.utils.misc import is_model_deepspeed, is_notebook
 from composer.utils.object_store import (LibcloudObjectStore, ObjectStore, ObjectStoreTransientError, S3ObjectStore,
@@ -39,6 +39,7 @@
     'ensure_folder_has_no_conflicting_files',
     'export_for_inference',
     'export_with_logger',
+    'quantize_dynamic',
     'format_name_with_dist',
     'format_name_with_dist_and_time',
     'is_tar',
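
With the re-export above, the transform becomes importable straight from the package root. A minimal sketch of the new import path, assuming an installed composer package:

    # quantize_dynamic now sits alongside the other public inference helpers.
    from composer.utils import export_for_inference, quantize_dynamic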

composer/utils/inference.py (28 changes: 21 additions & 7 deletions)
@@ -9,6 +9,7 @@
 
 import contextlib
 import copy
+import functools
 import logging
 import os
 import tempfile
@@ -29,10 +30,24 @@
 
 log = logging.getLogger(__name__)
 
-__all__ = ['export_for_inference', 'ExportFormat', 'export_with_logger']
+__all__ = ['export_for_inference', 'ExportFormat', 'export_with_logger', 'quantize_dynamic']
 
 Transform = Callable[[nn.Module], nn.Module]
 
+quantize_dynamic = functools.partial(torch.quantization.quantize_dynamic, qconfig_spec={torch.nn.Linear})
+"""This is the most common way to use dynamic quantization.
+
+Example:
+    from composer.utils import quantize_dynamic
+
+    export_for_inference(
+        ...
+        transforms=[quantize_dynamic],
+        ...
+    )
+
+A user can always redefine it with extra options; this also serves as an example of what to pass to ``transforms``.
+"""
+
+
 class ExportFormat(StringEnum):
     """Enum class for the supported export formats.
@@ -149,15 +164,14 @@ def export_for_inference(
             try:
                 export_model = torch.jit.script(model)
             except Exception as e:
-                log.warning(
-                    'Scripting with torch.jit.script failed with the following exception. Trying torch.jit.trace!',
-                    exc_info=True)
                 if sample_input is not None:
+                    log.warning('Scripting with torch.jit.script failed. Trying torch.jit.trace!',)
                     export_model = torch.jit.trace(model, sample_input)
                 else:
-                    raise RuntimeError(
-                        'Scripting with torch.jit.script failed and sample inputs are not provided for tracing with torch.jit.trace'
-                    ) from e
+                    log.warning(
+                        'Scripting with torch.jit.script failed and sample inputs are not provided for tracing '
+                        'with torch.jit.trace',
+                        exc_info=True)
 
             if export_model is not None:
                 torch.jit.save(export_model, local_save_path)
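
Putting both changes together, a minimal sketch of an export call that applies the quantization transform and also supplies a sample_input so the torch.jit.trace fallback can run if scripting fails; the toy model, save path, and input shape are illustrative:

    import torch

    from composer.utils.inference import export_for_inference, quantize_dynamic

    # A toy module; any nn.Module in eval mode works here.
    model = torch.nn.Sequential(torch.nn.Linear(256, 128), torch.nn.ReLU())
    model.eval()

    export_for_inference(
        model=model,
        save_format='torchscript',
        save_path='model_int8.pt',
        sample_input=(torch.randn(8, 256),),  # enables the torch.jit.trace fallback
        transforms=[quantize_dynamic],  # dynamic int8 quantization of Linear layers
    )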

tests/utils/test_inference.py (38 changes: 38 additions & 0 deletions)
@@ -275,3 +275,41 @@ def test_export_with_other_logger(model_cls, sample_input):
                                                    save_object_store=None,
                                                    sample_input=ANY,
                                                    transforms=None)


+class LinModel(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.lin1 = nn.Linear(256, 128)
+        self.lin2 = nn.Linear(128, 256)
+
+    def forward(self, x):
+        x = self.lin1(x)
+        x = self.lin2(x)
+        return x
+
+
+@pytest.mark.parametrize(
+    'model_cls',
+    [
+        (LinModel),
+    ],
+)
+def test_dynamic_quantize(model_cls):
+    model = model_cls()
+    model.eval()
+
+    save_format = 'torchscript'
+    with tempfile.TemporaryDirectory() as tempdir:
+        save_path_no_quantize = os.path.join(tempdir, 'model_no_quantize.pt')
+        inference.export_for_inference(model=model, save_format=save_format, save_path=save_path_no_quantize)
+        save_path_quantize = os.path.join(tempdir, 'model_quantize.pt')
+        inference.export_for_inference(model=model,
+                                       save_format=save_format,
+                                       save_path=save_path_quantize,
+                                       transforms=[inference.quantize_dynamic])
+        no_quantize_size = os.path.getsize(save_path_no_quantize)
+        quantize_size = os.path.getsize(save_path_quantize)
+        # The size difference should be almost 4x (fp32 weights become int8).
+        assert no_quantize_size > 3 * quantize_size, "Quantization didn't work"
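
The 3x threshold is a conservative check on the expected ~4x shrink. A back-of-envelope sketch of where the 4x comes from, using the LinModel dimensions above:

    # LinModel stores 256*128 + 128*256 = 65,536 Linear weights.
    # fp32 uses 4 bytes per weight; dynamic quantization stores them as int8
    # (1 byte each), so the weight payload shrinks ~4x. Biases, quantization
    # scales/zero-points, and TorchScript overhead keep the measured ratio
    # slightly under 4x, hence the 3x assertion.
    weights = 256 * 128 + 128 * 256   # 65,536 parameters
    fp32_bytes = 4 * weights          # 262,144 bytes
    int8_bytes = 1 * weights          # 65,536 bytes
    print(fp32_bytes / int8_bytes)    # 4.0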
