Add LoRA support to AI Edge Transformers.
PiperOrigin-RevId: 704425190
hheydary authored and copybara-github committed Dec 20, 2024
1 parent 9d387ec commit 174f290
Showing 7 changed files with 844 additions and 58 deletions.
(File path not captured in this view; the two hunks below modify the Gemma example's conversion script.)
@@ -29,10 +29,15 @@
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
 )
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'gemma',
+    'The prefix of the output tflite model name.',
+)
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)


 def main(_):
   pytorch_model = gemma1.build_2b_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'gemma_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )

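For orientation, here is a minimal sketch of how the updated conversion entry point might be driven after this change, with the new output naming and LoRA flags passed as keyword arguments. The import paths, checkpoint location, prefill lengths, and the rank list are illustrative assumptions, not values taken from this commit.

```python
# Hypothetical usage sketch; paths, ranks, and import locations are assumptions.
from ai_edge_torch.generative.examples.gemma import gemma1
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities.model_builder import ExportConfig

# Build the PyTorch Gemma 2B model from a local checkpoint (assumed path).
pytorch_model = gemma1.build_2b_model(
    '/data/checkpoints/gemma-2b',
    kv_cache_max_len=1024,
)

# Convert to TFLite, requesting LoRA-enabled variants at the given ranks.
converter.convert_to_tflite(
    pytorch_model,
    output_path='/tmp/',
    output_name_prefix='gemma',
    prefill_seq_len=[256, 512],
    quantize=True,
    lora_ranks=[4, 16],  # illustrative ranks, mirrors the new --lora_ranks flag
    export_config=ExportConfig(),
)
```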
27 changes: 22 additions & 5 deletions ai_edge_torch/generative/layers/attention.py
@@ -19,6 +19,7 @@

 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
+from ai_edge_torch.generative.layers import lora as lora_utils
 from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
@@ -66,6 +67,7 @@ def forward(
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
       kv_cache: kv_utils.KVCacheEntry = None,
+      lora: Optional[lora_utils.LoRAEntry] = None,
   ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
     """Forward function of the TransformerBlock.
@@ -75,6 +77,7 @@
       mask (torch.Tensor): the optional mask tensor.
       input_pos (torch.Tensor): the optional input position tensor.
       kv_cache (KVCacheEntry): the optional kv cache entry.
+      lora (LoRAEntry): the optional lora entry.
     Returns:
       output activation from this transformer block, and updated kv cache (if
@@ -83,7 +86,9 @@
     kv = None
     if self.config.parallel_residual:
       x_norm = self.pre_atten_norm(x)
-      atten_func_out = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
+      atten_func_out = self.atten_func(
+          x_norm, rope, mask, input_pos, kv_cache, lora
+      )
       if kv_cache is None:
         attn_out = atten_func_out
       else:
@@ -92,7 +97,9 @@
       output = x + attn_out + ff_out
     else:
       x_norm = self.pre_atten_norm(x)
-      atten_func_out = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
+      atten_func_out = self.atten_func(
+          x_norm, rope, mask, input_pos, kv_cache, lora
+      )
       if kv_cache is None:
         attn_out = atten_func_out
       else:
@@ -152,6 +159,7 @@ def forward(
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
       kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+      lora: Optional[lora_utils.LoRAEntry] = None,
   ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
     """Forward function of the CausalSelfAttention layer, which can support
@@ -163,6 +171,7 @@
       mask (torch.Tensor): the optional mask tensor.
       input_pos (torch.Tensor): the optional input position tensor.
       kv_cache (KVCacheEntry): The KV cache entry corresponding to this module.
+      lora (LoRAEntry): the optional lora entry.
     Returns:
       output activation from this self attention layer, and the updated
@@ -201,6 +210,11 @@ def forward(
           dim=-1,
       )

+    if lora is not None:
+      q += lora_utils.apply_lora(x, lora.attention.query, shape=q.shape)
+      k += lora_utils.apply_lora(x, lora.attention.key, shape=k.shape)
+      v += lora_utils.apply_lora(x, lora.attention.value, shape=v.shape)
+
     q = self.query_norm(q)
     k = self.key_norm(k)

@@ -218,18 +232,21 @@
       kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
       k, v = kv_cache.k_cache, kv_cache.v_cache

-    y = self.sdpa_func(
+    sdpa_out = self.sdpa_func(
         q,
         k,
         v,
         self.config.head_dim,
         mask=mask,
         softcap=self.config.logit_softcap,
     )
-    y = y.reshape(B, T, -1)
+    sdpa_out = sdpa_out.reshape(B, T, -1)

     # Compute the output projection.
-    y = self.output_projection(y)
+    y = self.output_projection(sdpa_out)
+    if lora is not None:
+      y += lora_utils.apply_lora(sdpa_out, lora.attention.output)

     return y if kv_cache is None else (y, kv_cache)
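To make the apply_lora calls above easier to read: a LoRA adapter adds a low-rank correction x @ A @ B on top of a frozen projection's output. The lora module itself is not part of the hunks shown in this view, so the sketch below is a generic illustration with assumed names, shapes, and scaling, not that module's actual implementation.

```python
# Generic LoRA sketch for reading the diff above; the real ai_edge_torch
# lora module is not shown in this commit view, so names here are assumptions.
from typing import Optional, Tuple

import torch


def apply_lora_sketch(
    x: torch.Tensor,  # (B, T, in_dim) activation fed to the frozen projection
    a_prime: torch.Tensor,  # (in_dim, rank) low-rank down-projection
    b_prime: torch.Tensor,  # (rank, out_dim) low-rank up-projection
    scale: float = 1.0,  # commonly alpha / rank
    shape: Optional[Tuple[int, ...]] = None,
) -> torch.Tensor:
  """Returns the low-rank correction added to a frozen projection's output."""
  # Rank-r update: two skinny matmuls instead of a full in_dim x out_dim matmul.
  delta = (x @ a_prime @ b_prime) * scale
  return delta.reshape(shape) if shape is not None else delta
```

In the hunks above, such a correction is added to the split query, key, and value projections (with shape= matching each tensor's head layout) and to the attention output projection, while the KV-cache handling and SDPA call are left unchanged.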


(Diff for the remaining 5 of the 7 changed files was not loaded in this page view.)
