Add OnnxIOFloat16ToFloat32 Pass #1149

Merged · 20 commits · May 15, 2024
Changes from 10 commits
6 changes: 6 additions & 0 deletions docs/source/api/passes.rst
@@ -42,6 +42,12 @@ OnnxFloatToFloat16
--------------------
.. autoconfigclass:: olive.passes.OnnxFloatToFloat16

.. _onnx_io_float16_to_float32:

OnnxIOFloat16ToFloat32
----------------------
.. autoconfigclass:: olive.passes.OnnxIOFloat16ToFloat32

.. _ort_mixed_precision:

OrtMixedPrecision
13 changes: 13 additions & 0 deletions docs/source/features/passes/onnx.md
@@ -428,6 +428,19 @@ b. More fine-grained control of the conversion conditions is also possible:

See [Float16 Conversion](https://onnxruntime.ai/docs/performance/model-optimizations/float16.html#float16-conversion) for a more detailed description of the available configuration parameters.

## Inputs/Outputs Float16 to Float32 Conversion

Certain environments, such as ONNX Runtime WebGPU, prefer Float32 logits. The `OnnxIOFloat16ToFloat32` pass converts a model's inputs and outputs from Float16 to Float32.

### Example Configuration

a. The most basic configuration, which is suitable for many models, leaves all configuration options set to their default values:
```json
{
"type": "OnnxIOFloat16ToFloat32"
}
```
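
When the pass is used inside a full workflow configuration, it is registered under the `passes` section alongside the other passes. The Python sketch below mirrors how the `examples/phi2/phi2.py` script in this repository patches a template configuration; the file names and the `fp32_logits` key are simply the labels that example uses, and any unique pass name works:

```python
import json

# Load an existing Olive workflow configuration (file name taken from the phi2 example).
with open("phi2_genai.json") as f:
    config = json.load(f)

# Register the pass under "passes"; "fp32_logits" is an arbitrary, unique pass name.
config["passes"]["fp32_logits"] = {"type": "OnnxIOFloat16ToFloat32"}

# Write out the patched configuration for the web target.
with open("phi2_web.json", "w") as f:
    json.dump(config, f, indent=4)
```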

## Mixed Precision Conversion
Converting the model to mixed precision.

23 changes: 18 additions & 5 deletions examples/phi2/phi2.py
@@ -22,6 +22,7 @@
"cuda_fp16": [["convert", "optimize_cuda", "perf_tuning"]],
"cuda_int4": [["convert", "optimize_cuda", "blockwise_quant_int4", "perf_tuning"]],
"slicegpt": [["slice"]],
"web": [["builder", "io_float16_to_float32"]],
}
SUPPORTED_INFERENCE_CONFIG = {
"cpu_fp32": {
@@ -54,6 +55,7 @@
DEVICE_TO_EP = {
"cpu": "CPUExecutionProvider",
"gpu": "CUDAExecutionProvider",
"web": "JsExecutionProvider",
}


@@ -64,8 +66,8 @@ def get_args(raw_args):
"--model_type",
type=str,
default=None,
choices=["cpu_fp32", "cpu_int4", "cuda_fp16", "cuda_int4"],
help="Choose from cpu_fp32, cpu_int4, cuda_fp16, cuda_int4",
choices=["cpu_fp32", "cpu_int4", "cuda_fp16", "cuda_int4", "web"],
help="Choose from cpu_fp32, cpu_int4, cuda_fp16, cuda_int4 or web",
)
parser.add_argument(
"--finetune_method",
@@ -141,11 +143,22 @@ def main(raw_args=None):
template_json["systems"]["local_system"]["config"]["accelerators"] = [
{"device": device, "execution_providers": [DEVICE_TO_EP[device.lower()]]}
]

new_json_file = f"phi2_genai_{device.lower()}.json"
new_json_file = "phi2_web.json"
with open(new_json_file, "w") as f:
json.dump(template_json, f, indent=4)
elif model_type == "web":
json_file_template = "phi2_genai.json"
with open(json_file_template) as f:
template_json = json.load(f)
template_json["passes"]["builder"]["config"]["precision"] = "int4"
template_json["systems"]["local_system"]["config"]["accelerators"] = [
{"device": "GPU", "execution_providers": ["JsExecutionProvider"]}
]
fl_type = {"type": "OnnxIOFloat16ToFloat32"}
template_json["passes"]["fp32_logits"] = fl_type
new_json_file = "phi2_web.json"
with open(new_json_file, "w") as f:
json.dump(template_json, f, indent=4)

else:
if not args.optimum_optimization and not args.slicegpt and version.parse(OrtVersion) < version.parse("1.18.0"):
# Check if onnxruntime version is supported
48 changes: 48 additions & 0 deletions examples/phi3/README.md
@@ -0,0 +1,48 @@
# Phi3 optimization with Olive
This folder contains an example of optimizing the [Phi-3-Mini-4K-Instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) model from Hugging Face for different hardware targets with Olive.


## Prerequisites
* einops
* PyTorch: >=2.2.0 \
_The [official website](https://pytorch.org/) offers packages compatible with CUDA 11.8 and 12.1. Please select the appropriate version according to your needs._
* [Package onnxruntime](https://onnxruntime.ai/docs/install/#inference-install-table-for-all-languages): >=1.18.0
* [Package onnxruntime-genai](https://github.com/microsoft/onnxruntime-genai): >=0.2.0. If you target GPU, please install the GPU packages of onnxruntime and onnxruntime-genai.

Install the dependencies:
```
pip install -r requirements.txt
```

## Usage
We will use the `phi3.py` script to generate an optimized model for a chosen hardware target by running the following commands.

```
python phi3.py [--target HARDWARE_TARGET] [--precision DATA_TYPE] [--inference] [--prompt PROMPT] [--max_length LENGTH]

# Examples
python phi3.py --target web

python phi3.py --target mobile --inference --prompt "Write a story starting with once upon a time" --max_length 200
```

- `--target`: required. Choose from cpu, cuda, mobile or web.
- `--precision`: optional. fp32 or int4 (default) for the cpu target; fp32, fp16 or int4 (default) for the cuda target; int4 (default) only for the mobile and web targets.
- `--inference`: optional. Run inference with the optimized model. Not supported for the web target.
- `--prompt`: optional. The prompt text fed into the model. Takes effect only when `--inference` is set.
- `--max_length`: optional. The maximum length of the output from the model. Takes effect only when `--inference` is set.


This script:
1. Generates the Olive configuration file for the chosen hardware target and model precision.
2. Generates the optimized model with Olive based on that configuration file.
3. (optional) Runs inference with the optimized model using the ONNX Runtime Generation API. Not supported for the web target.


If you have an Olive configuration file, you can also run the olive command for model generation:
```
olive run [--config CONFIGURATION_FILE]

# Examples
olive run --config phi3_mobile_int4.json
```
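
The generated configuration can also be run from Python through Olive's workflow API, which is the same entry point `phi3.py` calls internally. A minimal sketch, assuming the `phi3_mobile_int4.json` configuration from the example above:

```
from olive.workflows import run as olive_run

# Run the workflow described by the generated configuration file;
# returns the footprints of the produced model(s).
footprints = olive_run("phi3_mobile_int4.json")
if footprints:
    print("Optimized model is generated.")
```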
212 changes: 212 additions & 0 deletions examples/phi3/phi3.py
@@ -0,0 +1,212 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

import argparse
import json
import time
from pathlib import Path

import onnxruntime_genai as og

from olive.workflows import run as olive_run

# flake8: noqa: T201


TARGETS = ["cpu", "cuda", "mobile", "web"]

TARGET_TO_EP = {
    "cpu": "CPUExecutionProvider",
    "mobile": "CPUExecutionProvider",
    "cuda": "CUDAExecutionProvider",
    "web": "JsExecutionProvider",
}


def get_args(raw_args):
    parser = argparse.ArgumentParser(description="phi3 optimization")

    parser.add_argument(
        "--target",
        type=str,
        default=None,
        choices=["cpu", "cuda", "mobile", "web"],
        help="Choose from cpu, cuda, mobile or web",
    )
    parser.add_argument(
        "--precision",
        type=str,
        default=None,
        choices=["fp32", "fp16", "int4"],
        help="Choose from fp32 or int4(default) for cpu target; "
        "fp32 or fp16 or int4(default) for gpu target; int4(default) for mobile or web",
    )
    parser.add_argument(
        "--inference",
        action="store_true",
        help="Run inference with optimized model",
    )
    parser.add_argument(
        "--prompt",
        nargs="*",
        type=str,
        default=["Write a joke"],
        help="The prompt text fed into the model. Not supported with Web target.",
    )
    parser.add_argument(
        "--max_length",
        type=int,
        default=200,
        help="Max length for generation. Not supported with Web target.",
    )

    return parser.parse_args(raw_args)


def main(raw_args=None):
    args = get_args(raw_args)
    if not args.target:
        raise ValueError("Please specify target")

    if not args.precision:
        args.precision = "int4"
    elif args.target in ("mobile", "web") and args.precision != "int4":
        raise ValueError("mobile or web only supports int4(default)")
    elif args.target == "cpu" and args.precision == "fp16":
        raise ValueError("Choose from fp32 or int4(default) for cpu target")

    if args.inference and args.target == "web":
        raise ValueError("Web model inference is not supported in this script")

    # Generate Olive configuration file for specific target
    print("\nGenerating Olive configuration file...")
    config_file = generate_config(args)
    print("Olive configuration file is generated...\n")

    # Generate optimized model for specific target
    print("Generating optimized model for", args.target, " ...\n")
    footprints = olive_run(config_file)
    if footprints:
        print("\nOptimized model is generated...")

    if args.inference:
        prompts = "Write a joke" if not args.prompt else "".join(args.prompt)

        chat_template = "<|user|>\n{input}<|end|>\n<|assistant|>"
        prompts = f"{chat_template.format(input=prompts)}"

        max_length = 200 if not args.max_length else args.max_length

        output_model_path = get_output_model_path(footprints)
        genai_run(prompts, str(output_model_path), max_length)


def generate_config(args):

    json_file_template = "phi3_template.json"
    with open(json_file_template) as f:
        template_json = json.load(f)

    target = str(args.target)
    device = "GPU" if target in ("cuda", "web") else "CPU"
    execution_providers = [TARGET_TO_EP[target.lower()]]
    template_json["systems"]["local_system"]["config"]["accelerators"] = [
        {"device": device, "execution_providers": execution_providers}
    ]

    model_builder = {
        "type": "ModelBuilder",
        "config": {
            "precision": args.precision,
        },
    }
    template_json["passes"]["builder"] = model_builder

    if target == "mobile":
        template_json["passes"]["builder"]["config"]["int4_accuracy_level"] = 4

    elif target == "web":
        fl_type = {"type": "OnnxIOFloat16ToFloat32"}
        template_json["passes"]["fp32_logits"] = fl_type

    new_json_file = f"phi3_{target.lower()}_{args.precision}.json"
    with open(new_json_file, "w") as f:
        json.dump(template_json, f, indent=4)

    return new_json_file


def get_output_model_path(footprints):
    # only one model output in phi3 optimization
    for footprint in footprints.values():
        for model_id in footprint.nodes:
            model_path = Path(footprint.get_model_path(model_id))
            break
    return model_path


def genai_run(prompt, model_path, max_length):

    print("\nModel inference starts...")

    print("Loading model...")
    app_started_timestamp = time.time()
    model = og.Model(model_path)
    model_loaded_timestamp = time.time()
    print("Model loaded in {:.2f} seconds".format(model_loaded_timestamp - app_started_timestamp))

    print("Creating tokenizer...")
    tokenizer = og.Tokenizer(model)
    tokenizer_stream = tokenizer.create_stream()
    input_tokens = tokenizer.encode(prompt)
    started_timestamp = time.time()

    print("Creating generator ...")
    params = og.GeneratorParams(model)
    # optimal search options for Phi3
    search_options = {
        "max_length": max_length,
        "top_k": 40,
        "top_p": 0.95,
        "temperature": 0.8,
        "repetition_penalty": 1.0,
    }
    params.set_search_options(**search_options)
    params.input_ids = input_tokens
    generator = og.Generator(model, params)
    print("Generator created")

    first = True
    new_tokens = []

    print("\n", prompt)

    try:
        while not generator.is_done():
            generator.compute_logits()
            generator.generate_next_token()
            if first:
                first_token_timestamp = time.time()
                first = False

            new_token = generator.get_next_tokens()[0]
            print(tokenizer_stream.decode(new_token), end="", flush=True)
            new_tokens.append(new_token)
    except KeyboardInterrupt:
        print(" --control+c pressed, aborting generation--")

    del generator

    run_time = time.time() - started_timestamp
    print(
        "\n\n"
        f"Prompt tokens: {len(input_tokens)}, New tokens: {len(new_tokens)},"
        f" Time to first: {(first_token_timestamp - started_timestamp):.2f}s,"
        f" New tokens per second: {len(new_tokens)/run_time:.2f} tps"
    )


if __name__ == "__main__":
    main()
36 changes: 36 additions & 0 deletions examples/phi3/phi3_template.json
@@ -0,0 +1,36 @@
{
    "input_model":{
        "type": "PyTorchModel",
        "config": {
            "hf_config": {
                "model_name": "microsoft/Phi-3-mini-4k-instruct",
                "task": "text-generation",
                "from_pretrained_args": {
                    "trust_remote_code": true
                }
            }
        }
    },
    "systems": {
        "local_system": {
            "type": "LocalSystem",
            "config": {
                "accelerators": [
                    {
                        "device": "CPU",
                        "execution_providers": [
                            "CPUExecutionProvider"
                        ]
                    }
                ]
            }
        }
    },
    "passes": {
    },
    "engine": {
        "cache_dir": "cache",
        "output_dir": "Opt_model"
    }
}
5 changes: 5 additions & 0 deletions examples/phi3/requirements.txt
@@ -0,0 +1,5 @@
einops
onnx>=1.15.0
onnxscript>=0.1.0.dev20240126
torch>=2.2.0
transformers>=4.36.2
1 change: 1 addition & 0 deletions olive/hardware/constants.py
@@ -28,6 +28,7 @@
"MIGraphXExecutionProvider",
"TensorrtExecutionProvider",
"OpenVINOExecutionProvider",
"JsExecutionProvider",
],
"npu": ["QNNExecutionProvider"],
}