Merge pull request pytorch#121 from intel-staging/ov_torchort_integration_2

Documentation and sample code changes for torch_ort_inference
askhade authored Jun 27, 2022
2 parents c318f37 + ca799ae commit 3599dcf
Showing 6 changed files with 333 additions and 221 deletions.
33 changes: 24 additions & 9 deletions Readme.md
@@ -147,32 +147,42 @@ To see torch-ort in action, see https://github.com/microsoft/onnxruntime-trainin

# Accelerate inference for PyTorch models with ONNX Runtime (Preview)

ONNX Runtime for PyTorch accelerates PyTorch model inference using ONNX Runtime.
ONNX Runtime for PyTorch is now extended to support PyTorch model inference using ONNX Runtime.

It is available via the torch-ort-inference python package. This preview package enables OpenVINO™ Execution Provider for ONNX Runtime by default for accelerating inference on various Intel CPUs and integrated GPUs.
It is available via the torch-ort-inference python package. This preview package enables OpenVINO™ Execution Provider for ONNX Runtime by default for accelerating inference on various Intel® CPUs, Intel® integrated GPUs, and Intel® Movidius™ Vision Processing Units (VPUs).

This repository contains the source code for the package, as well as instructions for running the package.

## Prerequisites

- Ubuntu 18.04, 20.04

- Python* 3.7, 3.8 or 3.9

## Install in a local Python environment

By default, torch-ort-inference depends on PyTorch 1.12 and ONNX Runtime OpenVINO EP 1.12.

Install torch-ort-inference with OpenVINO dependencies
1. Install torch-ort-inference with OpenVINO dependencies.

- `pip install torch-ort-inference[openvino]`
- `pip install torch-ort-inference[openvino]`
<br/><br/>
2. Run post-installation script

## Verify your installation
- `python -m torch_ort.configure`

Once you have created your environment, using Python, execute the following steps to validate that your installation is correct.
## Verify your installation

1. Download an inference script
Once you have created your environment, execute the following steps to validate that your installation is correct.

- `wget https://raw.githubusercontent.com/pytorch/ort/main/torch_ort_inference/tests/bert_for_sequence_classification.py`
1. Clone this repo

- `git clone git@github.com:pytorch/ort.git`
<br/><br/>
2. Install extra dependencies

- `pip install wget pandas transformers`

<br/><br/>
3. Run the inference script

- `python ./ort/torch_ort_inference/tests/bert_for_sequence_classification.py`
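
   The script also accepts OpenVINO™ provider options on the command line. The flags shown here are inferred from the script's argument parsing (visible later in this diff), so treat this as an illustrative invocation rather than documented usage:

   - `python ./ort/torch_ort_inference/tests/bert_for_sequence_classification.py --provider openvino --backend GPU --precision FP16`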
@@ -204,6 +214,11 @@ If no provider options are specified by user, OpenVINO™ Execution Provider is
backend = "CPU"
precision = "FP32"
```
For more details on APIs, see [usage.md](/torch_ort_inference/docs/usage.md).
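
For quick reference, a minimal sketch of relying on those defaults, based on the APIs described in usage.md (the tiny placeholder model is illustrative only; any `torch.nn.Module` in eval mode is wrapped the same way):

```python
import torch
from torch_ort import ORTInferenceModule

# Placeholder model for illustration; substitute your own trained model.
model = torch.nn.Sequential(torch.nn.Linear(4, 2))
model.eval()

# No provider options passed, so OpenVINO EP runs with backend="CPU", precision="FP32".
model = ORTInferenceModule(model)

with torch.no_grad():
    output = model(torch.randn(1, 4))
```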

### Note

Currently, Vision models are supported on Intel® VPUs. Support for NLP models may be added in future releases.

## License

6 changes: 6 additions & 0 deletions torch_ort_inference/docs/install.md
@@ -2,6 +2,12 @@

You can install and run torch-ort-inference in your local environment.

## Prerequisites

- Ubuntu 18.04, 20.04

- Python* 3.7, 3.8 or 3.9

## Run in a Python environment

### Default dependencies
42 changes: 42 additions & 0 deletions torch_ort_inference/docs/usage.md
@@ -0,0 +1,42 @@
# APIs for OpenVINO™ integration with TorchORT

This document describes available Python APIs for OpenVINO™ integration with TorchORT to accelerate inference for PyTorch models on various Intel hardware.

## Essential APIs

To add the OpenVINO™ integration with TorchORT package to your PyTorch application, add the following two lines of code:

```python
from torch_ort import ORTInferenceModule
model = ORTInferenceModule(model)
```

By default, the CPU backend with FP32 precision is enabled. You can set a different backend and supported precision using OpenVINOProviderOptions as shown below:

```python
provider_options = OpenVINOProviderOptions(backend = "GPU", precision = "FP16")
model = ORTInferenceModule(model, provider_options = provider_options)
```
Supported backend-precision combinations:
| Backend | Precision |
| --------| --------- |
| CPU | FP32 |
| GPU | FP32 |
| GPU | FP16 |
| MYRIAD | FP16 |
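
Putting the APIs and the table together, the following is a minimal end-to-end sketch; the small model, the input shape, and the GPU/FP16 choice are illustrative assumptions rather than part of the package:

```python
import torch
from torch_ort import ORTInferenceModule, OpenVINOProviderOptions

# Illustrative model only; in practice this is your trained PyTorch model.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
    torch.nn.ReLU(),
    torch.nn.AdaptiveAvgPool2d(1),
    torch.nn.Flatten(),
    torch.nn.Linear(8, 10),
)
model.eval()

# Choose a backend/precision pair from the table above, e.g. integrated GPU with FP16.
provider_options = OpenVINOProviderOptions(backend="GPU", precision="FP16")
model = ORTInferenceModule(model, provider_options=provider_options)

with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)
```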

## Additional APIs

To save the inline exported ONNX model, use DebugOptions as shown below:

```python
debug_options = DebugOptions(save_onnx=True, onnx_prefix='<model_name>')
model = ORTInferenceModule(model, debug_options=debug_options)
```

To enable verbose logging of the TorchORT pipeline execution, use DebugOptions as shown below:

```python
debug_options = DebugOptions(log_level=LogLevel.VERBOSE)
model = ORTInferenceModule(model, debug_options=debug_options)
```
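
The snippets above omit imports. Below is a minimal sketch that combines both debug settings; it assumes `DebugOptions` and `LogLevel` are importable from `torch_ort` alongside `ORTInferenceModule`, and that the two options can be passed in a single `DebugOptions` call:

```python
import torch
from torch_ort import ORTInferenceModule, DebugOptions, LogLevel  # assumed import path

model = torch.nn.Linear(8, 2)
model.eval()

# Assumption: save_onnx/onnx_prefix and log_level can be combined in one DebugOptions.
debug_options = DebugOptions(log_level=LogLevel.VERBOSE, save_onnx=True, onnx_prefix="sample_model")
model = ORTInferenceModule(model, debug_options=debug_options)

with torch.no_grad():
    model(torch.randn(1, 8))  # verbose logs are emitted; the exported ONNX is saved with the "sample_model" prefix
```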
140 changes: 93 additions & 47 deletions torch_ort_inference/tests/bert_for_sequence_classification.py
@@ -8,6 +8,7 @@
import numpy as np
import time
import pandas as pd
import pathlib

from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
@@ -16,38 +17,65 @@
from torch_ort import ORTInferenceModule, OpenVINOProviderOptions

ov_backend_precisions = {"CPU": ["FP32"], "GPU": ["FP32", "FP16"]}

inference_execution_providers = ["openvino"]

def preprocess_input(tokenizer, sentences):
# Tokenization & Input Formatting
# Config: "do_lower_case": true, "model_max_length": 512
inputs = []

MAX_LEN = 64

for sentence in sentences:
tokenized_inputs = tokenizer(
sentence,
return_tensors="pt",
padding='max_length',
truncation=True)
inputs.append(tokenized_inputs)
# `encode` will:
# (1) Tokenize the sentence.
# (2) Prepend the `[CLS]` token to the start.
# (3) Append the `[SEP]` token to the end.
# (4) Map tokens to their IDs.
encoded_sent = tokenizer.encode(
sentence, # Sentence to encode.
add_special_tokens = True, # Add '[CLS]' and '[SEP]'
)

# Pad our input tokens with value 0.
if len(encoded_sent) < MAX_LEN:
encoded_sent.extend([0]*(MAX_LEN-len(encoded_sent)))

# Truncate to MAX_LEN
if len(encoded_sent) > MAX_LEN:
print("WARNING: During preprocessing, number of tokens for the sentence {}"\
"exceedeed MAX LENGTH {}. This might impact accuracy of the results".format(
sentence,
MAX_LEN
))
encoded_sent = encoded_sent[:MAX_LEN]

# Create the attention mask.
# - If a token ID is 0, then it's padding, set the mask to 0.
# - If a token ID is > 0, then it's a real token, set the mask to 1.
att_mask = [int(token_id > 0) for token_id in encoded_sent]

# Store the input ids and attention masks for the sentence.
inputs.append({'input_ids': torch.unsqueeze(torch.tensor(encoded_sent),0),
'attention_mask': torch.unsqueeze(torch.tensor(att_mask),0)})

return inputs


def infer(model, tokenizer, inputs):
def infer(model, sentences, inputs):
num_sentences = len(sentences)
total_infer_time = 0
results = {}

# Run inference
for i in range(len(inputs)):
for i in range(num_sentences):
input_ids = (inputs[i])['input_ids']
attention_masks = (inputs[i])['attention_mask']
with torch.no_grad():
# warm-up
if i == 0:
t0 = time.time()
model(input_ids, attention_masks)
print("warm up time:", time.time()-t0)
# infer
t0 = time.time()
outputs = model(input_ids, attention_masks)
@@ -63,18 +91,21 @@ def infer(model, tokenizer, inputs):

# predictions
pred_flat = np.argmax(logits, axis=1).flatten()
orig_sent = tokenizer.decode(input_ids[0],skip_special_tokens=True)
orig_sent = sentences[i]
results[orig_sent] = pred_flat[0]

print("\n Top (20) Results: \n")
print("\n Number of sentences: {}".format(num_sentences))
if num_sentences > 20:
print(" First 20 results:")
print("\t Grammar correctness label (0=unacceptable, 1=acceptable)\n")
count = 0
for k, v in results.items():
print("\t{!r} : {!r}".format(k, v))
if count == 20:
break
count = count + 1
print("\nInference time: {:.4f}s".format(total_infer_time))

print("\n Average inference time: {:.4f}ms".format((total_infer_time/num_sentences)*1000))
print(" Total Inference time: {:.4f}ms".format(total_infer_time * 1000))

def main():
# 1. Basic setup
@@ -85,7 +116,7 @@ def main():
"--pytorch-only",
action="store_true",
default=False,
help="disables ONNX Runtime",
help="disables ONNX Runtime inference",
)
parser.add_argument(
"--input",
@@ -119,25 +150,59 @@
if not args.pytorch_only:
if args.provider is None:
print("OpenVINOExecutionProvider is enabled with CPU and FP32 by default.")
if args.backend or args.precision:
raise ValueError("Provider not specified!! Please specify provider arg along with backend and precision.")
elif args.provider == "openvino":
if args.backend and args.precision:
if args.backend not in list(ov_backend_precisions.keys()):
raise Exception(
"Invalid backend. Valid values are:",
list(ov_backend_precisions.keys()),
)
raise ValueError(
"Invalid backend. Valid values are: {}".format(
list(ov_backend_precisions.keys())))
if args.precision not in ov_backend_precisions[args.backend]:
raise Exception("Invalid precision for provided backend. Valid values are:",
list(ov_backend_precisions[args.backend]))
else:
print(
"OpenVINOExecutionProvider is enabled with CPU and FP32 by default."
+ " Please specify both backend and precision to override.\n"
raise ValueError("Invalid precision for provided backend. Valid values are: {}".format(
list(ov_backend_precisions[args.backend])))
elif args.backend or args.precision:
raise ValueError(
"Please specify both backend and precision to override default options.\n"
)
else:
print("OpenVINOExecutionProvider is enabled with CPU and FP32 by default.")
else:
raise Exception("Invalid execution provider!!")
raise ValueError("Invalid execution provider!! Available providers are: {}".format(inference_execution_providers))
else:
print("ONNXRuntime inference is disabled.")
if args.provider or args.precision or args.backend:
raise ValueError("provider, backend, precision arguments are not applicable for --pytorch-only option.")

# 2. Read input sentence(s)
# Input can be a single sentence or a list of sentences in a .tsv file.
if args.input and args.input_file:
raise ValueError("Please provide either input or input file for inference.")

# 2. Load Model
if args.input is not None:
sentences = [args.input]
elif args.input_file is not None:
file_name = args.input_file
if not os.path.exists(file_name):
raise ValueError("Invalid input file path: %s" % file_name)
if os.stat(file_name).st_size == 0:
raise ValueError("Input file is empty!!")
name, ext = os.path.splitext(file_name)
if ext != ".tsv":
raise ValueError("Invalid input file format. Please provide .tsv file.")
df = pd.read_csv(
file_name,
delimiter="\t",
header=None,
names=["Id", "Sentence"],
skiprows=1,
)
sentences = df.Sentence.values
else:
print("Input not provided! Using default input...")
sentences = ["This is a BERT sample.","User input is valid not."]

# 3. Load Model
# Pretrained model fine-tuned on CoLA dataset from huggingface model hub to predict grammar correctness
model = AutoModelForSequenceClassification.from_pretrained(
"textattack/bert-base-uncased-CoLA"
@@ -155,31 +220,12 @@ def main():
# Convert model for evaluation
model.eval()

# 3. Read input sentence(s)
# Input can be a single sentence, list of single sentences in a .tsv file.
if args.input is not None:
sentences = [args.input]
elif args.input_file is not None:
if not os.path.exists(args.input_file):
raise ValueError("Invalid input file path: %s" % args.input_file)
df = pd.read_csv(
args.input_file,
delimiter="\t",
header=None,
names=["Id", "Sentence"],
skiprows=1,
)
sentences = df.Sentence.values
else:
print("Input not provided! Using default input...")
sentences = ["This is a sample input."]

# 4. Load Tokenizer & Preprocess input sentences
tokenizer = AutoTokenizer.from_pretrained("textattack/bert-base-uncased-CoLA")
inputs = preprocess_input(tokenizer, sentences)

# 5. Infer
infer(model, tokenizer, inputs)
infer(model, sentences, inputs)


if __name__ == "__main__":