From eb94a1d19fcb3472f799050cf1996e22f6946001 Mon Sep 17 00:00:00 2001
From: Akshay Sonawane <111780983+apsonawane@users.noreply.github.com>
Date: Wed, 6 Mar 2024 13:55:50 -0800
Subject: [PATCH] Add Mistral fp16 config (#980)
## Describe your changes
Add a float16 (CUDA) optimization configuration for Mistral, along with a `mistral.py` helper script that runs the Olive workflow and smoke-tests inference on the optimized model.
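For reference, a minimal sketch of how the new configuration can be driven from Python, mirroring what `examples/mistral/mistral.py` does (the repo-relative path and model id below are illustrative):
```python
# Illustrative sketch: load the fp16 config and hand it to Olive's workflow runner.
import json
from pathlib import Path

from olive.workflows import run as olive_run

# assumes this is run from the repository root
config_path = Path("examples/mistral/mistral_fp16_optimize.json")
with config_path.open() as fin:
    olive_config = json.load(fin)

# optionally point the workflow at a specific checkpoint, as mistral.py does
olive_config["input_model"]["config"]["model_path"] = "mistralai/Mistral-7B-v0.1"
olive_run(olive_config)
```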
## Checklist before requesting a review
- [ ] Add unit tests for this change.
- [ ] Make sure all tests can pass.
- [x] Update documents if necessary.
- [x] Lint and apply fixes to your code by running `lintrunner -a`
- [ ] Is this a user-facing change? If yes, give a description of this
change to be included in the release notes.
- [ ] Is this PR including examples changes? If yes, please remember to
update [example
documentation](https://github.com/microsoft/Olive/blob/main/docs/source/examples.md)
in a follow-up PR.
## (Optional) Issue link
---
docs/source/examples.md | 2 +-
examples/mistral/mistral.py | 93 +++++++++++++++++++
examples/mistral/mistral_fp16_optimize.json | 98 +++++++++++++++++++++
examples/mistral/readme.md | 14 ++-
4 files changed, 205 insertions(+), 2 deletions(-)
create mode 100644 examples/mistral/mistral.py
create mode 100644 examples/mistral/mistral_fp16_optimize.json
diff --git a/docs/source/examples.md b/docs/source/examples.md
index ff184e799..8ae99dca6 100644
--- a/docs/source/examples.md
+++ b/docs/source/examples.md
@@ -3,7 +3,7 @@
|Scenario| Model|Examples|Hardware Targeted Optimization|
|---|-----------|-----------|-----------|
|NLP|llama2|[Link](https://github.com/microsoft/Olive/tree/main/examples/llama2)|`CPU`: with ONNX Runtime optimizations for optimized FP32 ONNX model<br>`CPU`: with ONNX Runtime optimizations for optimized INT8 ONNX model<br>`CPU`: with ONNX Runtime optimizations for optimized INT4 ONNX model<br>`GPU`: with ONNX Runtime optimizations for optimized FP16 ONNX model<br>`GPU`: with ONNX Runtime optimizations for optimized INT4 ONNX model<br>`GPU`: with QLoRA for model fine tune and ONNX Runtime optimizations for optimized INT4 ONNX model<br>`AzureML compute`: with AzureML compute to fine tune and optimize for your local GPUs
-||mistral|[Link](https://github.com/microsoft/Olive/tree/main/examples/mistral)|`CPU`: with Optimum conversion and ONNX Runtime optimizations and Intel® Neural Compressor static quantization for optimized INT8 ONNX model
+||mistral|[Link](https://github.com/microsoft/Olive/tree/main/examples/mistral)|`CPU`: with Optimum conversion and ONNX Runtime optimizations and Intel® Neural Compressor static quantization for optimized INT8 ONNX model<br>`GPU`: with ONNX Runtime optimizations for optimized FP16 ONNX model
||open llama|[Link](https://github.com/microsoft/Olive/tree/main/examples/open_llama)|`GPU`: with Optimum conversion and merging and ONNX Runtime optimizations for optimized ONNX model<br>`GPU`: with SparseGPT and TorchTRT conversion for an optimized PyTorch model with sparsity<br>`GPU`: with PyTorch LoRA/QLoRA/LoftQ for model fine tune<br>`GPU`: with ONNX Runtime QLoRA for model fine tune<br>`AzureML compute`: with Optimum conversion and merging and ONNX Runtime optimizations in AzureML<br>`CPU`: with Optimum conversion and merging and ONNX Runtime optimizations and Intel® Neural Compressor 4-bits weight-only quantization for optimized INT4 ONNX model
||phi|[Link](https://github.com/microsoft/Olive/tree/main/examples/phi)|`GPU`: with PyTorch QLoRA for model fine tune
||phi2|[Link](https://github.com/microsoft/Olive/tree/main/examples/phi2)|`CPU`: with ONNX Runtime optimizations fp32/int4<br>`GPU` with ONNX Runtime optimizations fp16/int4.
diff --git a/examples/mistral/mistral.py b/examples/mistral/mistral.py
new file mode 100644
index 000000000..bbdbe9ed4
--- /dev/null
+++ b/examples/mistral/mistral.py
@@ -0,0 +1,93 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+import argparse
+import json
+import shutil
+from pathlib import Path
+
+import onnxruntime as ort
+import torch
+from transformers import AutoConfig, LlamaTokenizer
+
+from olive.workflows import run as olive_run
+
+# ruff: noqa: T201, T203
+
+
+def optimize(model_name: str, optimized_model_des: Path, config_name: str):
+    ort.set_default_logger_severity(4)
+    cur_dir = Path(__file__).resolve().parent
+
+    # Optimize the model with Olive
+    print(f"\nOptimizing {model_name}")
+
+    with (cur_dir / config_name).open() as fin:
+        olive_config = json.load(fin)
+
+    olive_config["input_model"]["config"]["model_path"] = model_name
+    olive_run(olive_config)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--optimize", action="store_true", help="Runs the optimization step")
+    parser.add_argument(
+        "--model-id",
+        dest="model_id",
+        type=str,
+        default="mistralai/Mistral-7B-v0.1",
+        help="Model Id to load",
+    )
+    parser.add_argument("--config", type=str, default="mistral_optimize.json", help="Olive config file to load")
+    parser.add_argument("--inference", action="store_true", help="Runs the inference step")
+    args = parser.parse_args()
+
+    script_dir = Path(__file__).resolve().parent
+    optimized_model_dir = script_dir / "models" / "convert-optimize-perf_tuning" / "mistral_gpu-cuda_model"
+
+    if args.optimize:
+        shutil.rmtree(optimized_model_dir, ignore_errors=True)
+
+    if args.optimize or not optimized_model_dir.exists():
+        optimize(args.model_id, optimized_model_dir, args.config)
+
+    if args.inference:
+        prompt = "Is it normal to have a dark ring around the iris of my eye?"
+
+        tokenizer = LlamaTokenizer.from_pretrained(args.model_id)
+        tokens = tokenizer(prompt, return_tensors="pt")
+        tokenizer = None
+
+        config = AutoConfig.from_pretrained(args.model_id)
+        num_heads = config.num_key_value_heads
+        head_size = config.hidden_size // config.num_attention_heads
+        past_seq_len = 0
+
+        position_ids = tokens.attention_mask.long().cumsum(-1) - 1  # position ids derived from the attention mask
+        position_ids.masked_fill_(tokens.attention_mask == 0, 1)
+
+        onnx_inputs = {
+            "input_ids": tokens.input_ids.numpy(),
+            "attention_mask": tokens.attention_mask.numpy(),
+            "position_ids": position_ids.numpy(),
+        }
+        for i in range(config.num_hidden_layers):  # seed an empty (zero-length) KV cache for every layer
+            onnx_inputs[f"past_key_values.{i}.key"] = torch.rand(
+                1, num_heads, past_seq_len, head_size, dtype=torch.float16
+            ).numpy()
+            onnx_inputs[f"past_key_values.{i}.value"] = torch.rand(
+                1, num_heads, past_seq_len, head_size, dtype=torch.float16
+            ).numpy()
+
+        model_path = optimized_model_dir / "model.onnx"
+
+        session = ort.InferenceSession(str(model_path), providers=["CUDAExecutionProvider"])
+        session.run(None, onnx_inputs)
+
+        print("Inference test completed successfully!")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/mistral/mistral_fp16_optimize.json b/examples/mistral/mistral_fp16_optimize.json
new file mode 100644
index 000000000..856db1846
--- /dev/null
+++ b/examples/mistral/mistral_fp16_optimize.json
@@ -0,0 +1,98 @@
+{
+    "input_model": {
+        "type": "PyTorchModel",
+        "config": {
+            "hf_config": {
+                "model_name": "mistralai/Mistral-7B-v0.1",
+                "model_class": "MistralForCausalLM"
+            }
+        }
+    },
+    "evaluators": {
+        "common_evaluator": {
+            "metrics": [
+                {
+                    "name": "latency",
+                    "type": "latency",
+                    "sub_types": [
+                        {
+                            "name": "avg",
+                            "priority": 1
+                        }
+                    ],
+                    "user_config": {
+                        "user_script": "user_script.py",
+                        "dataloader_func": "create_dataloader",
+                        "batch_size": 1,
+                        "inference_settings": {
+                            "onnx": {
+                                "session_options": {
+                                    "enable_profiling": false
+                                }
+                            }
+                        }
+                    }
+                }
+            ]
+        }
+    },
+    "passes": {
+        "convert": {
+            "type": "OptimumConversion",
+            "config": {
+                "target_opset": 14,
+                "extra_args": {
+                    "legacy": false,
+                    "no_post_process": false
+                }
+            }
+        },
+        "optimize": {
+            "type": "OrtTransformersOptimization",
+            "config": {
+                "model_type": "gpt2",
+                "use_gpu": true,
+                "keep_io_types": false,
+                "num_heads": 32,
+                "hidden_size": 4096,
+                "opt_level": 0,
+                "optimization_options": {
+                    "use_multi_head_attention": false
+                },
+                "save_as_external_data": true,
+                "all_tensors_to_one_file": true,
+                "float16": true,
+                "use_gqa": true
+            }
+        },
+        "perf_tuning": {
+            "type": "OrtPerfTuning",
+            "config": {
+                "user_script": "user_script.py",
+                "dataloader_func": "create_dataloader",
+                "batch_size": 1,
+                "enable_profiling": false
+            }
+        }
+    },
+    "pass_flows": [
+        [
+            "convert",
+            "optimize",
+            "perf_tuning"
+        ]
+    ],
+    "engine": {
+        "evaluate_input_model": false,
+        "evaluator": "common_evaluator",
+        "cache_dir": "cache",
+        "output_name": "mistral",
+        "output_dir": "models",
+        "execution_providers": [
+            "CUDAExecutionProvider"
+        ],
+        "clean_cache": false,
+        "log_severity_level": 0,
+        "log_to_file": true
+    }
+}
diff --git a/examples/mistral/readme.md b/examples/mistral/readme.md
index 2a90ea0ee..2f4e5b71c 100644
--- a/examples/mistral/readme.md
+++ b/examples/mistral/readme.md
@@ -28,10 +28,22 @@ git config --system core.longpaths true
```
## Usage
+CPU:
```bash
-python -m olive.workflows.run --config mistral_optimize.json
+python mistral.py --optimize --config mistral_optimize.json
```
+GPU:
+```bash
+python mistral.py --optimize --config mistral_fp16_optimize.json
+```
+## Test Inference
+To test inference on the model, run the script with `--inference`:
+```bash
+CUDA_VISIBLE_DEVICES=6 python mistral.py --inference
+```
+Currently, inference is only supported for the float16 model running on GPU.
+
### Local model
if the input model is saved locally, you can specify the configuration like the following:
```json