
Commit

refactor attn layers
vchiley committed Mar 16, 2023
1 parent 67dc1b7 commit c7d0aaf
Showing 13 changed files with 447 additions and 533 deletions.
examples/common/fdiff.py (6 changes: 2 additions & 4 deletions)
@@ -7,17 +7,15 @@
 """Monitor rate of change of loss."""
 from __future__ import annotations
 
-from typing import Any, Dict
-
 from composer.core import Callback, State
 from composer.loggers import Logger
 
 
 class FDiffMetrics(Callback):
     """Rate of chage of metrics.
 
-    tracks and plots the rate of change of metrics effectively taking the numerical
-    derivative of the metrics
+    tracks and plots the rate of change of metrics effectively taking the
+    numerical derivative of the metrics
    """
 
     def __init__(self, diff_train_metrics=True, diff_eval_metrics=True):
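
For context on the docstring above: FDiffMetrics takes a discrete first difference of each logged metric. A minimal sketch of that numerical-derivative idea, assuming a standalone helper rather than the callback's actual attributes (the class name FirstDifference is made up for illustration):

from typing import Optional


class FirstDifference:
    """Illustrative only: metric[t] - metric[t-1], the discrete rate of change."""

    def __init__(self) -> None:
        self.prev: Optional[float] = None

    def update(self, value: float) -> float:
        # Difference between consecutive observations; 0.0 on the first call.
        diff = 0.0 if self.prev is None else value - self.prev
        self.prev = value
        return diff

Feeding the training loss into such a helper each batch yields the rate-of-change signal the callback tracks and plots.
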
examples/llm/__init__.py (14 changes: 9 additions & 5 deletions)
Expand Up @@ -11,8 +11,9 @@
from examples.llm.src.models.hf import (ComposerHFCausalLM,
ComposerHFPrefixLM, ComposerHFT5)
from examples.llm.src.models.layers.attention import (
FlashCausalAttention, TorchCausalAttention, TritonFlashCausalAttention,
alibi_bias)
MultiheadAttention, alibi_bias, attn_bias_, attn_bias_shape,
generate_attn_bias, scaled_multihead_dot_product_attention,
scaled_multihead_dot_product_self_attention)
from examples.llm.src.models.layers.flash_attention import (FlashAttention,
FlashMHA)
from examples.llm.src.models.layers.gpt_blocks import GPTMLP, GPTBlock
@@ -37,9 +38,12 @@
     'ComposerHFPrefixLM',
     'ComposerHFT5',
     'COMPOSER_MODEL_REGISTRY',
-    'TorchCausalAttention',
-    'FlashCausalAttention',
-    'TritonFlashCausalAttention',
+    'scaled_multihead_dot_product_attention',
+    'scaled_multihead_dot_product_self_attention',
+    'MultiheadAttention',
+    'attn_bias_shape',
+    'attn_bias_',
+    'generate_attn_bias',
     'alibi_bias',
     'GPTMLP',
     'GPTBlock',
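
For orientation: the refactor replaces the per-backend classes (TorchCausalAttention, FlashCausalAttention, TritonFlashCausalAttention) with a single MultiheadAttention module plus free functions for the attention math and bias construction. The sketch below shows, in generic PyTorch, roughly what scaled multihead dot-product attention with an additive bias such as ALiBi computes; the function names here are illustrative, and the real signatures in examples/llm/src/models/layers/attention.py may differ.

import math

import torch


def sdpa_with_bias(q, k, v, attn_bias=None, is_causal=True):
    # q, k, v: (batch, n_heads, seq_len, head_dim); attn_bias is broadcastable
    # to (batch, n_heads, seq_len, seq_len), e.g. an ALiBi bias.
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if attn_bias is not None:
        scores = scores + attn_bias
    if is_causal:
        s = scores.size(-1)
        mask = torch.ones(s, s, dtype=torch.bool, device=q.device).triu(1)
        scores = scores.masked_fill(mask, float('-inf'))
    return torch.softmax(scores, dim=-1) @ v


def simple_alibi_bias(n_heads, seq_len):
    # One common ALiBi formulation: a per-head linear penalty on key position.
    # Under a causal mask this matches -slope * (query_pos - key_pos) up to a
    # per-row constant, which the softmax ignores.
    slopes = torch.tensor([2**(-8 * (h + 1) / n_heads) for h in range(n_heads)])
    dist = torch.arange(1 - seq_len, 1, dtype=torch.float32).view(1, 1, seq_len)
    return slopes.view(n_heads, 1, 1) * dist  # (n_heads, 1, seq_len)

Factoring the bias out as an explicit tensor is presumably what lets one module drive the torch, flash, and triton kernels: each backend only needs (q, k, v, bias).
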
examples/llm/scripts/export_for_inference.py (15 changes: 2 additions & 13 deletions)
@@ -44,7 +44,6 @@
 from composer.utils import get_device, maybe_create_object_store_from_uri
 from omegaconf import OmegaConf as om
 
-from examples.llm import TorchCausalAttention
 from examples.llm.src.model_registry import COMPOSER_MODEL_REGISTRY
 
 
@@ -127,18 +126,8 @@ def main(cfg):
                        load_weights_only=True)
        # replace flash/triton attention with torch causal attention
        for idx in range(cfg.model.n_layers):
-            torch_causal_attn = TorchCausalAttention(cfg.model)
-            torch_causal_attn.mhsa.in_proj_weight = orig_model.model.transformer.blocks[
-                idx].causal_attn.mhsa.Wqkv.weight
-            torch_causal_attn.mhsa.in_proj_bias = orig_model.model.transformer.blocks[
-                idx].causal_attn.mhsa.Wqkv.bias
-            torch_causal_attn.mhsa.out_proj.weight = (
-                orig_model.model.transformer.blocks[idx].causal_attn.mhsa.
-                out_proj.weight)
-            torch_causal_attn.mhsa.out_proj.bias = orig_model.model.transformer.blocks[
-                idx].causal_attn.mhsa.out_proj.bias
-            export_model.model.transformer.blocks[
-                idx].causal_attn = torch_causal_attn
+            export_model.model.transformer.blocks[idx].attn.load_state_dict(
+                orig_model.model.transformer.blocks[idx].attn.state_dict())
     else:
         export_model = orig_model
 
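
The dozen lines of per-tensor copying above collapse into a state_dict round-trip, which works whenever the source and destination modules expose the same parameter names and shapes. A generic illustration of the pattern, using nn.Linear stand-ins rather than the actual attention classes:

import torch.nn as nn

src = nn.Linear(8, 8)  # stand-in for a trained attention block
dst = nn.Linear(8, 8)  # stand-in for the export-friendly replacement

# Copies every parameter and buffer by name; by default this raises if any
# name or shape disagrees, which is the consistency guarantee the export
# script now relies on instead of copying Wqkv/out_proj tensors by hand.
dst.load_state_dict(src.state_dict())
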
examples/llm/src/__init__.py (14 changes: 9 additions & 5 deletions)
@@ -7,8 +7,9 @@
 from examples.llm.src.models.hf import (ComposerHFCausalLM, ComposerHFPrefixLM,
                                         ComposerHFT5)
 from examples.llm.src.models.layers.attention import (
-    FlashCausalAttention, TorchCausalAttention, TritonFlashCausalAttention,
-    alibi_bias)
+    MultiheadAttention, alibi_bias, attn_bias_, attn_bias_shape,
+    generate_attn_bias, scaled_multihead_dot_product_attention,
+    scaled_multihead_dot_product_self_attention)
 from examples.llm.src.models.layers.flash_attention import (FlashAttention,
                                                             FlashMHA)
 from examples.llm.src.models.layers.gpt_blocks import GPTMLP, GPTBlock
@@ -25,9 +26,12 @@
     'ComposerHFPrefixLM',
     'ComposerHFT5',
     'COMPOSER_MODEL_REGISTRY',
-    'TorchCausalAttention',
-    'FlashCausalAttention',
-    'TritonFlashCausalAttention',
+    'scaled_multihead_dot_product_attention',
+    'scaled_multihead_dot_product_self_attention',
+    'MultiheadAttention',
+    'attn_bias_shape',
+    'attn_bias_',
+    'generate_attn_bias',
     'alibi_bias',
     'GPTMLP',
     'GPTBlock',
examples/llm/src/models/layers/__init__.py (14 changes: 9 additions & 5 deletions)
@@ -2,18 +2,22 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from examples.llm.src.models.layers.attention import (
-    FlashCausalAttention, TorchCausalAttention, TritonFlashCausalAttention,
-    alibi_bias)
+    MultiheadAttention, alibi_bias, attn_bias_, attn_bias_shape,
+    generate_attn_bias, scaled_multihead_dot_product_attention,
+    scaled_multihead_dot_product_self_attention)
 from examples.llm.src.models.layers.flash_attention import (FlashAttention,
                                                             FlashMHA)
 from examples.llm.src.models.layers.gpt_blocks import GPTMLP, GPTBlock
 
 __all__ = [
     'FlashAttention',
     'FlashMHA',
-    'TorchCausalAttention',
-    'FlashCausalAttention',
-    'TritonFlashCausalAttention',
+    'scaled_multihead_dot_product_attention',
+    'scaled_multihead_dot_product_self_attention',
+    'MultiheadAttention',
+    'attn_bias_shape',
+    'attn_bias_',
+    'generate_attn_bias',
     'alibi_bias',
     'GPTMLP',
     'GPTBlock',
(The remaining 8 of the 13 changed files are not shown here.)
