🌵 Olive kv_cache_config under io_config #1121

Merged 10 commits on May 7, 2024
36 changes: 35 additions & 1 deletion docs/source/overview/options.md
@@ -102,7 +102,41 @@ find more details in [Olive Models](https://microsoft.github.io/Olive/api/models
- `input_shapes: [List[List[int]]]` The input shapes of the model.
- `output_names: [List[str]]` The output names of the model.
- `dynamic_axes: [Dict[str, Dict[str, str]]]` The dynamic axes of the model. The key is the name of the input or output and the value is a dictionary that contains the dynamic axes of the input or output. The key of the value dictionary is the index of the dynamic axis and the value is the name of the dynamic axis. For example, `{"input": {"0": "batch_size"}, "output": {"0": "batch_size"}}` means the first dimension of the input and output is dynamic and the name of the dynamic axis is `batch_size`.

- `string_to_int_dim_params: List[str]` The list of input names in dynamic axes that need to be converted to int values.
- `kv_cache_config: Union[bool, Dict[str, str]]` The key-value cache configuration. See the example `io_config` after this list.
- If it is `False`, Olive will not use a key-value cache.
- If it is `True`, Olive will infer the cache configuration from the `input_names`/`input_shapes` and the input model, based on the default `kv_cache_config`.
- If it is a dictionary, it should contain the key-value cache configuration. Here is the default configuration:
- `ort_past_key_name`: "past_key_<id>"
Template for the past key name. The `<id>` will be replaced by the id of the past key.
- `ort_past_value_name`: "past_value_<id>"
Template for the past value name. The `<id>` will be replaced by the id of the past value.
- `ort_present_key_name`: "present_key_<id>"
Template for the present key name. The `<id>` will be replaced by the id of the present key.
- `ort_present_value_name`: "present_value_<id>"
Template for the present value name. The `<id>` will be replaced by the id of the present value.
- `world_size`: 1
It is only used for distributed models.
- `num_hidden_layers`: null
If null, Olive will infer the number of hidden layers from the model.
- `num_attention_heads`: null
If null, Olive will infer the number of attention heads from the model.
- `hidden_size`: null
If null, Olive will infer the hidden size from the model.
- `past_sequence_length`: null
If null, Olive will infer the past sequence length from the model.
- `batch_size`: 0
The batch size of the model. If it is 0, Olive will use the batch size of `input_ids` from the `input_shapes`.
- `dtype`: "float32"
The data type of the model.
- `shared_kv`: false
Whether the past and present key-value caches are shared. If true, the dynamic axes of the past and present key-value caches will be the same.
- `sequence_length_idx`: 2
In most cases, the input shape for the key-value cache is (batch_size, num_attention_heads/world_size, sequence_length, hidden_size/num_attention_heads). This is the index of the sequence length in that shape.
- `past_kv_dynamic_axis`: null
The dynamic axis of the past key-value cache. If it is null, Olive will infer the dynamic axis.
- `present_kv_dynamic_axis`: null
The dynamic axis of the present key-value cache. If it is null, Olive will infer the dynamic axis.
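
For example, here is an `io_config` that enables the key-value cache with an explicit configuration. This is a sketch: the `kv_cache_config` values are the documented defaults above, and the input names, shapes, and types are illustrative values taken from the Llama-2 examples in this change.

```json
{
  "io_config": {
    "input_names": ["input_ids", "attention_mask"],
    "output_names": ["logits"],
    "input_shapes": [[2, 8], [2, 40]],
    "input_types": ["int32", "int32"],
    "dynamic_axes": {
      "input_ids": {"0": "batch_size", "1": "sequence_length"},
      "attention_mask": {"0": "batch_size", "1": "sequence_length"}
    },
    "kv_cache_config": {
      "ort_past_key_name": "past_key_<id>",
      "ort_past_value_name": "past_value_<id>",
      "ort_present_key_name": "present_key_<id>",
      "ort_present_value_name": "present_value_<id>",
      "world_size": 1,
      "num_hidden_layers": null,
      "num_attention_heads": null,
      "hidden_size": null,
      "past_sequence_length": null,
      "batch_size": 0,
      "dtype": "float32",
      "shared_kv": false,
      "sequence_length_idx": 2,
      "past_kv_dynamic_axis": null,
      "present_kv_dynamic_axis": null
    }
  }
}
```

Setting `kv_cache_config` to `true` is equivalent to using these defaults and letting Olive infer the null values from the model.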
- <a name="hf_config"></a> `hf_config: [Dict]` Instead of `model_path` or `model_loader`, the model can be specified using a dictionary describing a Hugging Face model. This dictionary specifies the following items:

14 changes: 11 additions & 3 deletions examples/llama2/llama2_qlora.json
@@ -2,9 +2,17 @@
"input_model": {
"type": "PyTorchModel",
"config": {
"model_script": "user_script.py",
"io_config": "get_merged_decoder_with_past_io_config",
"dummy_inputs_func": "get_merged_decoder_with_past_dummy_inputs",
"io_config": {
"input_names": ["input_ids", "attention_mask"],
"output_names": ["logits"],
"input_shapes": [[2,8], [2,40]],
"input_types": ["int32", "int32"],
"dynamic_axes": {
"input_ids": {"0": "batch_size", "1": "sequence_length"},
"attention_mask": {"0": "batch_size", "1": "sequence_length"}
},
"kv_cache_config": true
},
"hf_config": {
"model_name": "meta-llama/Llama-2-7b-hf",
"task": "text-generation",
14 changes: 11 additions & 3 deletions examples/llama2/llama2_template.json
@@ -2,9 +2,17 @@
"input_model": {
"type": "PyTorchModel",
"config": {
"model_script": "user_script.py",
"io_config": "get_merged_decoder_with_past_io_config",
"dummy_inputs_func": "get_merged_decoder_with_past_dummy_inputs",
"io_config": {
"input_names": ["input_ids", "attention_mask"],
"output_names": ["logits"],
"input_shapes": [[2,8], [2,40]],
"input_types": ["int32", "int32"],
"dynamic_axes": {
"input_ids": {"0": "batch_size", "1": "sequence_length"},
"attention_mask": {"0": "batch_size", "1": "sequence_length"}
},
"kv_cache_config": true
},
"hf_config": {
"model_name": "<model_name_placeholder>",
"model_class": "LlamaForCausalLM",
14 changes: 11 additions & 3 deletions examples/llama2/llama2_tensor_parallel.json
@@ -2,9 +2,17 @@
"input_model":{
"type": "PyTorchModel",
"config": {
"model_script": "user_script.py",
"io_config": "get_merged_decoder_with_past_io_config",
"dummy_inputs_func": "get_merged_decoder_with_past_dummy_inputs",
"io_config": {
"input_names": ["input_ids", "attention_mask"],
"output_names": ["logits"],
"input_shapes": [[2,8], [2,40]],
"input_types": ["int32", "int32"],
"dynamic_axes": {
"input_ids": {"0": "batch_size", "1": "sequence_length"},
"attention_mask": {"0": "batch_size", "1": "sequence_length"}
},
"kv_cache_config": true
},
"hf_config": {
"model_name": "meta-llama/Llama-2-7b-hf",
"model_class": "LlamaForCausalLM",
14 changes: 11 additions & 3 deletions examples/llama2/notebook/llama2/config.json
@@ -8,9 +8,17 @@
"input_model": {
"type": "PyTorchModel",
"config": {
"model_script": "user_script.py",
"io_config": "get_merged_decoder_with_past_io_config",
"dummy_inputs_func": "get_merged_decoder_with_past_dummy_inputs",
"io_config": {
"input_names": ["input_ids", "attention_mask"],
"output_names": ["logits"],
"input_shapes": [[2,8], [2,40]],
"input_types": ["int32", "int32"],
"dynamic_axes": {
"input_ids": {"0": "batch_size", "1": "sequence_length"},
"attention_mask": {"0": "batch_size", "1": "sequence_length"}
},
"kv_cache_config": true
},
"model_path": {
"type": "azureml_registry_model",
"config": {
33 changes: 0 additions & 33 deletions examples/llama2/notebook/llama2/user_script.py
@@ -3,7 +3,6 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------

from itertools import chain
from typing import List, Tuple, Union

import torch
@@ -19,13 +18,6 @@
# -----------------------------------------------------------------------------


def get_merged_decoder_with_past_dummy_inputs(model: PyTorchModelHandler):
"""Get dummy inputs for merged decoder model with past_key_values."""
# Dummy values for export
batch_size, seq_length, past_seq_length = 2, 8, 0
return get_merged_sample_with_past_kv_inputs(model, batch_size, seq_length, past_seq_length)


def get_merged_sample_with_past_kv_inputs(
model: PyTorchModelHandler,
batch_size: int,
@@ -152,31 +144,6 @@ def get_merged_model_dynamic_axes(input_names: List[str], output_names: List[str
return dynamic_axes


def get_merged_decoder_with_past_io_config(model: PyTorchModelHandler):
config = model.get_hf_model_config()

input_names = [
"input_ids",
"attention_mask",
"position_ids",
*list(
chain.from_iterable(
(f"past_key_values.{i}.key", f"past_key_values.{i}.value") for i in range(config.num_hidden_layers)
)
),
]
output_names = [
"logits",
*list(chain.from_iterable((f"present.{i}.key", f"present.{i}.value") for i in range(config.num_hidden_layers))),
]
dynamic_axes = get_merged_model_dynamic_axes(input_names, output_names)
return {
"input_names": input_names,
"dynamic_axes": dynamic_axes,
"output_names": output_names,
}


# -----------------------------------------------------------------------------
# Metric Data Loader
# -----------------------------------------------------------------------------
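
The deleted helpers above generated per-layer io entries named `past_key_values.<i>.key`/`.value` and `present.<i>.key`/`.value`. With `kv_cache_config: true`, Olive instead infers the cache inputs and outputs from the model config using the default name templates documented in `options.md`. A sketch of the inferred names, assuming those defaults (entries shown for layer 0 only; the full lists cover every hidden layer):

```json
{
  "input_names": ["input_ids", "attention_mask", "past_key_0", "past_value_0"],
  "output_names": ["logits", "present_key_0", "present_value_0"]
}
```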
14 changes: 11 additions & 3 deletions examples/llama2/notebook/llama2_multiep/config_cpu.json
@@ -2,9 +2,17 @@
"input_model": {
"type": "PyTorchModel",
"config": {
"model_script": "user_script.py",
"io_config": "get_merged_decoder_with_past_io_config",
"dummy_inputs_func": "get_merged_decoder_with_past_dummy_inputs",
"io_config": {
"input_names": ["input_ids", "attention_mask"],
"output_names": ["logits"],
"input_shapes": [[2,8], [2,40]],
"input_types": ["int32", "int32"],
"dynamic_axes": {
"input_ids": {"0": "batch_size", "1": "sequence_length"},
"attention_mask": {"0": "batch_size", "1": "sequence_length"}
},
"kv_cache_config": true
},
"hf_config": {
"model_name": "meta-llama/Llama-2-7b-hf",
"model_class": "LlamaForCausalLM",
14 changes: 11 additions & 3 deletions examples/llama2/notebook/llama2_multiep/config_gpu.json
@@ -2,9 +2,17 @@
"input_model": {
"type": "PyTorchModel",
"config": {
"model_script": "user_script.py",
"io_config": "get_merged_decoder_with_past_io_config",
"dummy_inputs_func": "get_merged_decoder_with_past_dummy_inputs",
"io_config": {
"input_names": ["input_ids", "attention_mask"],
"output_names": ["logits"],
"input_shapes": [[2,8], [2,40]],
"input_types": ["int32", "int32"],
"dynamic_axes": {
"input_ids": {"0": "batch_size", "1": "sequence_length"},
"attention_mask": {"0": "batch_size", "1": "sequence_length"}
},
"kv_cache_config": true
},
"hf_config": {
"model_name": "meta-llama/Llama-2-7b-hf",
"model_class": "LlamaForCausalLM",
14 changes: 11 additions & 3 deletions examples/llama2/notebook/llama2_multiep/config_multi_ep.json
@@ -2,9 +2,17 @@
"input_model": {
"type": "PyTorchModel",
"config": {
"model_script": "user_script.py",
"io_config": "get_merged_decoder_with_past_io_config",
"dummy_inputs_func": "get_merged_decoder_with_past_dummy_inputs",
"io_config": {
"input_names": ["input_ids", "attention_mask"],
"output_names": ["logits"],
"input_shapes": [[2,8], [2,40]],
"input_types": ["int32", "int32"],
"dynamic_axes": {
"input_ids": {"0": "batch_size", "1": "sequence_length"},
"attention_mask": {"0": "batch_size", "1": "sequence_length"}
},
"kv_cache_config": true
},
"hf_config": {
"model_name": "meta-llama/Llama-2-7b-hf",
"model_class": "LlamaForCausalLM",
33 changes: 0 additions & 33 deletions examples/llama2/notebook/llama2_multiep/user_script.py
@@ -3,7 +3,6 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------

from itertools import chain
from typing import List, Tuple, Union

import torch
@@ -19,13 +18,6 @@
# -----------------------------------------------------------------------------


def get_merged_decoder_with_past_dummy_inputs(model: PyTorchModelHandler):
"""Get dummy inputs for merged decoder model with past_key_values."""
# Dummy values for export
batch_size, seq_length, past_seq_length = 2, 8, 0
return get_merged_sample_with_past_kv_inputs(model, batch_size, seq_length, past_seq_length)


def get_merged_sample_with_past_kv_inputs(
model: PyTorchModelHandler,
batch_size: int,
@@ -152,31 +144,6 @@ def get_merged_model_dynamic_axes(input_names: List[str], output_names: List[str
return dynamic_axes


def get_merged_decoder_with_past_io_config(model: PyTorchModelHandler):
config = model.get_hf_model_config()

input_names = [
"input_ids",
"attention_mask",
"position_ids",
*list(
chain.from_iterable(
(f"past_key_values.{i}.key", f"past_key_values.{i}.value") for i in range(config.num_hidden_layers)
)
),
]
output_names = [
"logits",
*list(chain.from_iterable((f"present.{i}.key", f"present.{i}.value") for i in range(config.num_hidden_layers))),
]
dynamic_axes = get_merged_model_dynamic_axes(input_names, output_names)
return {
"input_names": input_names,
"dynamic_axes": dynamic_axes,
"output_names": output_names,
}


# -----------------------------------------------------------------------------
# Metric Data Loader
# -----------------------------------------------------------------------------
37 changes: 0 additions & 37 deletions examples/llama2/user_script.py
@@ -4,7 +4,6 @@
# --------------------------------------------------------------------------

from argparse import Namespace
from itertools import chain
from typing import List, Tuple, Union

import torch
@@ -20,13 +19,6 @@
# -----------------------------------------------------------------------------


def get_merged_decoder_with_past_dummy_inputs(model: PyTorchModelHandler):
"""Get dummy inputs for merged decoder model with past_key_values."""
# Dummy values for export
batch_size, seq_length, past_seq_length = 2, 8, 0
return get_merged_sample_with_past_kv_inputs(model, batch_size, seq_length, past_seq_length)


def get_merged_sample_with_past_kv_inputs(
model: PyTorchModelHandler,
batch_size: int,
@@ -156,35 +148,6 @@ def get_merged_model_dynamic_axes(input_names: List[str], output_names: List[str
return dynamic_axes


def get_merged_decoder_with_past_io_config(model: PyTorchModelHandler):
if model.hf_config is not None:
config = model.get_hf_model_config()
else:
# Using Namespace class to access dict items like class attributes
config = Namespace(**model.model_attributes)

input_names = [
"input_ids",
"attention_mask",
"position_ids",
*list(
chain.from_iterable(
(f"past_key_values.{i}.key", f"past_key_values.{i}.value") for i in range(config.num_hidden_layers)
)
),
]
output_names = [
"logits",
*list(chain.from_iterable((f"present.{i}.key", f"present.{i}.value") for i in range(config.num_hidden_layers))),
]
dynamic_axes = get_merged_model_dynamic_axes(input_names, output_names)
return {
"input_names": input_names,
"dynamic_axes": dynamic_axes,
"output_names": output_names,
}


# -----------------------------------------------------------------------------
# Metric Data Loader
# -----------------------------------------------------------------------------
14 changes: 11 additions & 3 deletions examples/phi2/phi2_optimize_template.json
@@ -2,9 +2,17 @@
"input_model": {
"type": "PyTorchModel",
"config": {
"model_script": "user_script.py",
"dummy_inputs_func": "dummy_inputs",
"io_config": "get_io_config",
"io_config": {
"input_names": ["input_ids", "attention_mask"],
"output_names": ["logits"],
"input_shapes": [[2,8], [2,40]],
"input_types": ["int32", "int32"],
"dynamic_axes": {
"input_ids": {"0": "batch_size", "1": "sequence_length"},
"attention_mask": {"0": "batch_size", "1": "sequence_length"}
},
"kv_cache_config": true
},
"hf_config": {
"model_name": "microsoft/phi-2",
"task": "text-generation",