Fix Onnx Export for Composer HuggingFaceModels (#1557)
Fix dictionary input error, GPU export, and opset_version errors
nik-mosaic authored Oct 1, 2022
1 parent 1b7ffce commit 0e5af6c
Showing 7 changed files with 225 additions and 15 deletions.
2 changes: 1 addition & 1 deletion composer/callbacks/export_for_inference.py
@@ -86,5 +86,5 @@ def export_model(self, state: State, logger: Logger):
        save_path=self.save_path,
        logger=logger,
        save_object_store=self.save_object_store,
-       sample_input=(self.sample_input,),
+       sample_input=(self.sample_input, {}),
        transforms=self.transforms)
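
Why the trailing empty dict: `torch.onnx.export` treats a dict at the end of its args tuple as keyword arguments, so a dictionary-valued sample input (the usual case for a Composer HuggingFaceModel) would otherwise be unpacked as kwargs instead of being passed as one positional argument. Appending `{}` keeps the dict positional. A minimal sketch of the convention, using an illustrative toy module rather than anything from this repository:

import torch

class DictModel(torch.nn.Module):
    """Toy module with a HuggingFace-style forward(batch) signature."""

    def forward(self, batch):
        return batch['x'].sum(dim=-1)

model = DictModel()
batch = {'x': torch.randn(2, 4)}

# Without the trailing {}, export would unpack `batch` as keyword
# arguments and attempt forward(x=...), which this model rejects.
torch.onnx.export(model, (batch, {}), 'dict_model.onnx')
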
2 changes: 1 addition & 1 deletion composer/trainer/trainer.py
@@ -2568,5 +2568,5 @@ def export_for_inference(
        save_path=save_path,
        logger=self.logger,
        save_object_store=save_object_store,
-       sample_input=(sample_input,),
+       sample_input=(sample_input, {}),
        transforms=transforms)
36 changes: 32 additions & 4 deletions composer/utils/inference.py
@@ -63,6 +63,7 @@ def export_for_inference(
    save_path: str,
    save_object_store: Optional[ObjectStore] = None,
    sample_input: Optional[Any] = None,
    dynamic_axes: Optional[Any] = None,
    surgery_algs: Optional[Union[Callable[[nn.Module], nn.Module], Sequence[Callable[[nn.Module], nn.Module]]]] = None,
    transforms: Optional[Sequence[Transform]] = None,
    load_path: Optional[str] = None,
@@ -86,6 +87,8 @@
        sample_input (Any, optional): Example model inputs used for tracing. This is needed for "onnx" export.
            The ``sample_input`` need not match the batch size you intend to use for inference. However, the model
            should accept the ``sample_input`` as is. (default: ``None``)
        dynamic_axes (Any, optional): Dictionary specifying the axes of input/output tensors as dynamic. May be required
            for exporting models using older versions of PyTorch when types cannot be inferred.
        surgery_algs (Union[Callable, Sequence[Callable]], optional): Algorithms that should be applied to the model
            before loading a checkpoint. Each should be a callable that takes a model and returns the modified model.
            ``surgery_algs`` are applied before ``transforms``. (default: ``None``)
@@ -118,12 +121,25 @@
    if dist.get_global_rank() != 0:
        return

-   # make a copy of the model so that we don't modify the original model
+   # Make a copy of the model so that we don't modify the original model
    model = copy.deepcopy(model)

-   # make a copy of the sample input so that we don't modify the original sample input
+   # Make a copy of the sample input so that we don't modify the original sample input
    sample_input = copy.deepcopy(sample_input)

    # Move model and sample input to CPU for export
    cpu = torch.device('cpu')
    model.to(device=cpu)
    if sample_input is not None:
        # Use a list: ensure_tuple returns a tuple, which does not support
        # the item assignment needed to move tensors onto the CPU.
        sample_input = list(ensure_tuple(sample_input))
        for i in range(len(sample_input)):
            if isinstance(sample_input[i], torch.Tensor):
                sample_input[i] = sample_input[i].to(cpu)
            elif isinstance(sample_input[i], dict):
                for key, value in sample_input[i].items():
                    if isinstance(value, torch.Tensor):
                        sample_input[i][key] = value.to(cpu)
        sample_input = tuple(sample_input)

    # Apply surgery algorithms in the given order
    for alg in ensure_tuple(surgery_algs):
        model = alg(model)
@@ -178,14 +194,26 @@ def export_for_inference(
    if save_format == ExportFormat.ONNX:
        if sample_input is None:
            raise ValueError(f'sample_input argument is required for onnx export')
        sample_input = ensure_tuple(sample_input)

        input_names = []

        # Extract input names from sample_input if it contains dicts
        for i in range(len(sample_input)):
            if isinstance(sample_input[i], dict):
                input_names += list(sample_input[i].keys())

        # Default input name if no dict present
        if input_names == []:
            input_names = ['input']

        torch.onnx.export(
            model,
            sample_input,
            local_save_path,
-           input_names=['input'],
+           input_names=input_names,
            output_names=['output'],
            dynamic_axes=dynamic_axes,
            opset_version=13,
        )

    # upload if required.
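
Taken together, the inference.py changes mean a dict-valued sample input contributes its keys as the ONNX graph's input names, and the new dynamic_axes argument is forwarded to torch.onnx.export so batch and sequence dimensions can be marked variable. A hedged sketch of the resulting call pattern, assuming `model` is a Composer HuggingFaceModel built elsewhere and the save path is illustrative:

import torch
from composer.utils import inference

# Dict input: its keys become the exported graph's input names.
sample_input = {
    'input_ids': torch.randint(0, 30522, (2, 32)),
    'attention_mask': torch.ones(2, 32, dtype=torch.int64),
}

# Mark batch and sequence dimensions dynamic so the exported model
# accepts other batch sizes and sequence lengths at inference time.
dynamic_axes = {
    'input_ids': {0: 'batch_size', 1: 'seq_len'},
    'attention_mask': {0: 'batch_size', 1: 'seq_len'},
}

inference.export_for_inference(
    model=model,  # assumed: a HuggingFaceModel constructed elsewhere
    save_format='onnx',
    save_path='model.onnx',
    sample_input=(sample_input, {}),
    dynamic_axes=dynamic_axes,
)
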
10 changes: 9 additions & 1 deletion examples/exporting_for_inference.ipynb
@@ -413,6 +413,14 @@
"print(f\"The predicted classes are {np.argmax(outputs[0], axis=1)}\")"
]
},
{
"cell_type": "markdown",
"id": "0bc52f62",
"metadata": {},
"source": [
"If our input is a dictionary, as if often the case when using a Composer [HuggingFaceModel](https://docs.mosaicml.com/en/stable/examples/huggingface_models.html), we'll need to make sure all the elements of our input dictionary are numpy arrays before calling `ort_session.run()`."
]
},
{
"cell_type": "markdown",
"id": "ca091f8e",
@@ -454,7 +462,7 @@
"export_for_inference(model=model, \n",
" save_format=save_format, \n",
" save_path=model_save_path, \n",
" sample_input=(input,),\n",
" sample_input=(input, {}),\n",
" surgery_algs=[cf.apply_squeeze_excite],\n",
" load_path=checkpoint_path)"
]
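
As the new markdown cell notes, onnxruntime consumes numpy arrays rather than torch tensors. A minimal sketch of running an exported model with a dictionary input, mirroring the pattern this commit's tests use (the model path is illustrative, and `sample_input` is assumed to be the dict of tensors used at export time):

import onnxruntime as ort

ort_session = ort.InferenceSession('model.onnx')

# onnxruntime expects numpy arrays keyed by the graph's input names,
# so convert each tensor in the input dict before running the session.
onnx_input = {name: tensor.cpu().numpy() for name, tensor in sample_input.items()}
outputs = ort_session.run(None, onnx_input)
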
4 changes: 2 additions & 2 deletions setup.py
@@ -191,8 +191,8 @@ def package_files(prefix: str, directory: str, extension: str):
]

extra_deps['onnx'] = [
-    'onnx>=1.11.0,<2',
-    'onnxruntime>=1.11.0,<2',
+    'onnx>=1.12.0,<2',
+    'onnxruntime>=1.12.1,<2',
]

extra_deps['all'] = set(dep for deps in extra_deps.values() for dep in deps)
4 changes: 2 additions & 2 deletions tests/callbacks/test_inference.py
@@ -47,7 +47,7 @@ def test_inference_callback_torchscript(model_cls):
        save_path=save_path,
        logger=trainer.logger,
        save_object_store=None,
-       sample_input=(exp_for_inf_callback.sample_input,),
+       sample_input=(exp_for_inf_callback.sample_input, {}),
        transforms=None)


@@ -78,5 +78,5 @@ def test_inference_callback_onnx(model_cls):
        save_path=save_path,
        logger=trainer.logger,
        save_object_store=None,
-       sample_input=(exp_for_inf_callback.sample_input,),
+       sample_input=(exp_for_inf_callback.sample_input, {}),
        transforms=None)
182 changes: 178 additions & 4 deletions tests/utils/test_inference.py
@@ -15,6 +15,7 @@
from torch.utils.data import DataLoader

from composer.core import State
from composer.functional import apply_gated_linear_units
from composer.loggers import InMemoryLogger, Logger
from composer.loggers.logger_destination import LoggerDestination
from composer.models import composer_resnet
@@ -62,6 +63,181 @@ def test_export_for_inference_torchscript(model_cls, sample_input):
)


def test_huggingface_export_for_inference_onnx():
    pytest.importorskip('onnx')
    pytest.importorskip('onnxruntime')
    pytest.importorskip('transformers')

    import onnx
    import onnx.checker
    import onnxruntime as ort
    import transformers

    from composer.models import HuggingFaceModel

    # HuggingFace BERT model.
    # Dummy sequence batch with 2 labels, sequence length 32, and a vocab size of 30522 (BERT).
    input_ids = torch.randint(low=0, high=30522, size=(2, 32))
    labels = torch.randint(low=0, high=1, size=(2,))
    token_type_ids = torch.zeros(size=(2, 32), dtype=torch.int64)
    attention_mask = torch.randint(low=0, high=1, size=(2, 32))
    sample_input = {
        'input_ids': input_ids,
        'labels': labels,
        'token_type_ids': token_type_ids,
        'attention_mask': attention_mask,
    }
    dynamic_axes = {
        'input_ids': {0: 'batch_size', 1: 'seq_len'},
        'labels': {0: 'batch_size'},
        'token_type_ids': {0: 'batch_size', 1: 'seq_len'},
        'attention_mask': {0: 'batch_size', 1: 'seq_len'},
    }
    # Non-pretrained model to avoid a slow test that downloads the weights.
    config = transformers.AutoConfig.from_pretrained('bert-base-uncased', num_labels=2, hidden_act='gelu_new')
    hf_model = transformers.AutoModelForSequenceClassification.from_config(config)  # type: ignore (thirdparty)

    model = HuggingFaceModel(hf_model)

    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    apply_gated_linear_units(model, optimizer)

    model.eval()

    orig_out = model(sample_input)

    save_format = 'onnx'
    with tempfile.TemporaryDirectory() as tempdir:
        save_path = os.path.join(tempdir, f'model.{save_format}')
        inference.export_for_inference(
            model=model,
            save_format=save_format,
            save_path=save_path,
            sample_input=(sample_input, {}),
            dynamic_axes=dynamic_axes,
        )
        loaded_model = onnx.load(save_path)

        onnx.checker.check_model(loaded_model)

        ort_session = ort.InferenceSession(save_path)

        for key, value in sample_input.items():
            sample_input[key] = value.numpy()

        loaded_model_out = ort_session.run(None, sample_input)

        torch.testing.assert_close(
            orig_out['logits'].detach().numpy(),
            loaded_model_out[1],
            rtol=1e-4,  # lower tolerance for ONNX
            atol=1e-3,  # lower tolerance for ONNX
            msg=f'output mismatch with {save_format}',
        )


@pytest.mark.gpu
def test_gpu_huggingface_export_for_inference_onnx():
    pytest.importorskip('onnx')
    pytest.importorskip('onnxruntime')
    pytest.importorskip('transformers')

    import onnx
    import onnx.checker
    import onnxruntime as ort
    import transformers

    from composer.functional import apply_fused_layernorm
    from composer.models import HuggingFaceModel

    # HuggingFace BERT model.
    # Dummy sequence batch with 2 labels, sequence length 32, and a vocab size of 30522 (BERT).
    input_ids = torch.randint(low=0, high=30522, size=(2, 32))
    labels = torch.randint(low=0, high=1, size=(2,))
    token_type_ids = torch.zeros(size=(2, 32), dtype=torch.int64)
    attention_mask = torch.randint(low=0, high=1, size=(2, 32))
    sample_input = {
        'input_ids': input_ids,
        'labels': labels,
        'token_type_ids': token_type_ids,
        'attention_mask': attention_mask,
    }
    dynamic_axes = {
        'input_ids': {0: 'batch_size', 1: 'seq_len'},
        'labels': {0: 'batch_size'},
        'token_type_ids': {0: 'batch_size', 1: 'seq_len'},
        'attention_mask': {0: 'batch_size', 1: 'seq_len'},
    }
    # Non-pretrained model to avoid a slow test that downloads the weights.
    config = transformers.AutoConfig.from_pretrained('bert-base-uncased', num_labels=2, hidden_act='gelu_new')
    hf_model = transformers.AutoModelForSequenceClassification.from_config(config)  # type: ignore (thirdparty)

    model = HuggingFaceModel(hf_model)

    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    apply_gated_linear_units(model, optimizer)
    apply_fused_layernorm(model, optimizer)

    model.eval()
    orig_out = model(sample_input)

    gpu = torch.device('cuda:0')
    model.to(gpu)
    for key, val in sample_input.items():
        sample_input[key] = val.to(gpu)

    save_format = 'onnx'
    with tempfile.TemporaryDirectory() as tempdir:
        save_path = os.path.join(tempdir, f'model.{save_format}')
        inference.export_for_inference(
            model=model,
            save_format=save_format,
            save_path=save_path,
            sample_input=(sample_input, {}),
            dynamic_axes=dynamic_axes,
        )
        loaded_model = onnx.load(save_path)

        onnx.checker.check_model(loaded_model)

        ort_session = ort.InferenceSession(save_path)

        for key, value in sample_input.items():
            sample_input[key] = value.cpu().numpy()

        loaded_model_out = ort_session.run(None, sample_input)

        torch.testing.assert_close(
            orig_out['logits'].detach().numpy(),
            loaded_model_out[1],
            rtol=1e-4,  # lower tolerance for ONNX
            atol=1e-3,  # lower tolerance for ONNX
            msg=f'output mismatch with {save_format}',
        )


@pytest.mark.parametrize(
    'model_cls, sample_input',
    [
@@ -87,7 +263,7 @@ def test_export_for_inference_onnx(model_cls, sample_input):
            model=model,
            save_format=save_format,
            save_path=save_path,
-           sample_input=(sample_input,),
+           sample_input=(sample_input, {}),
        )
        loaded_model = onnx.load(save_path)
        onnx.checker.check_model(loaded_model)
@@ -152,7 +328,7 @@ def test_export_for_inference_onnx_ddp(model_cls, sample_input):
            model=state.model.module,
            save_format=save_format,
            save_path=save_path,
-           sample_input=(sample_input,),
+           sample_input=(sample_input, {}),
        )

        loaded_model = onnx.load(save_path)
@@ -247,7 +423,6 @@ def test_export_with_file_artifact_logger(model_cls, sample_input):
            model=model,
            save_format=save_format,
            save_path=save_path,
-           sample_input=(sample_input,),
            logger=mock_logger,
        )

@@ -292,7 +467,6 @@ def test_export_with_other_logger(model_cls, sample_input):
            model=model,
            save_format=save_format,
            save_path=save_path,
-           sample_input=(sample_input,),
            logger=mock_logger,
        )
