refactor attn layers #240

Merged
18 commits merged on Mar 18, 2023
18 changes: 9 additions & 9 deletions examples/llm/__init__.py
@@ -11,10 +11,9 @@
 from examples.llm.src.models.hf import (ComposerHFCausalLM,
                                         ComposerHFPrefixLM, ComposerHFT5)
 from examples.llm.src.models.layers.attention import (
-    FlashCausalAttention, TorchCausalAttention, TritonFlashCausalAttention,
-    alibi_bias)
-from examples.llm.src.models.layers.flash_attention import (FlashAttention,
-                                                             FlashMHA)
+    MultiheadAttention, alibi_bias, attn_bias, attn_bias_shape,
+    flash_attn_fn, scaled_multihead_dot_product_attention,
+    triton_flash_attn_fn)
 from examples.llm.src.models.layers.gpt_blocks import GPTMLP, GPTBlock
 from examples.llm.src.models.mosaic_gpt import ComposerMosaicGPT, MosaicGPT
 from examples.llm.src.tokenizer import (TOKENIZER_REGISTRY, HFTokenizer,
@@ -31,15 +30,16 @@
     ) from e
 
 __all__ = [
-    'FlashAttention',
-    'FlashMHA',
     'ComposerHFCausalLM',
     'ComposerHFPrefixLM',
     'ComposerHFT5',
     'COMPOSER_MODEL_REGISTRY',
-    'TorchCausalAttention',
-    'FlashCausalAttention',
-    'TritonFlashCausalAttention',
+    'scaled_multihead_dot_product_attention',
+    'flash_attn_fn',
+    'triton_flash_attn_fn',
+    'MultiheadAttention',
+    'attn_bias_shape',
+    'attn_bias',
     'alibi_bias',
     'GPTMLP',
     'GPTBlock',
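The import changes above summarize the refactor: the separate TorchCausalAttention, FlashCausalAttention, TritonFlashCausalAttention, and FlashMHA classes are collapsed into one MultiheadAttention module plus functional kernels (scaled_multihead_dot_product_attention, flash_attn_fn, triton_flash_attn_fn) and bias helpers (attn_bias_shape, attn_bias, alibi_bias). A minimal sketch of that pattern follows; the names attn_impl, Wqkv, and out_proj and all signatures here are illustrative assumptions, not the PR's exact code.

import torch
import torch.nn as nn


def torch_attn_fn(q, k, v, n_heads, is_causal=True):
    # Reference (non-fused) kernel; flash/triton variants would share this signature.
    b, s, d = q.shape
    q, k, v = (x.view(b, s, n_heads, d // n_heads).transpose(1, 2) for x in (q, k, v))
    scores = q @ k.transpose(-2, -1) / (d // n_heads) ** 0.5
    if is_causal:
        mask = torch.triu(torch.ones(s, s, dtype=torch.bool, device=q.device), 1)
        scores = scores.masked_fill(mask, float('-inf'))
    out = torch.softmax(scores, dim=-1) @ v
    return out.transpose(1, 2).reshape(b, s, d)


class UnifiedMultiheadAttention(nn.Module):
    # Sketch only: one module owns the projections and picks a kernel at init time.
    def __init__(self, d_model, n_heads, attn_impl='torch'):
        super().__init__()
        self.n_heads = n_heads
        self.Wqkv = nn.Linear(d_model, 3 * d_model)        # fused QKV projection
        self.out_proj = nn.Linear(d_model, d_model)
        self.attn_fn = {'torch': torch_attn_fn}[attn_impl]  # 'flash'/'triton' would slot in here

    def forward(self, x):
        q, k, v = self.Wqkv(x).chunk(3, dim=-1)
        return self.out_proj(self.attn_fn(q, k, v, self.n_heads))


x = torch.randn(2, 16, 64)
y = UnifiedMultiheadAttention(d_model=64, n_heads=8)(x)   # -> (2, 16, 64)

Callers then pick the backend through configuration rather than by importing a backend-specific class, which is what lets the export script below swap implementations with a plain state-dict copy.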
15 changes: 2 additions & 13 deletions examples/llm/scripts/export_for_inference.py
@@ -44,7 +44,6 @@
 from composer.utils import get_device, maybe_create_object_store_from_uri
 from omegaconf import OmegaConf as om
 
-from examples.llm import TorchCausalAttention
 from examples.llm.src.model_registry import COMPOSER_MODEL_REGISTRY
 
 
@@ -127,18 +126,8 @@ def main(cfg):
                                  load_weights_only=True)
         # replace flash/triton attention with torch causal attention
         for idx in range(cfg.model.n_layers):
-            torch_causal_attn = TorchCausalAttention(cfg.model)
-            torch_causal_attn.mhsa.in_proj_weight = orig_model.model.transformer.blocks[
-                idx].causal_attn.mhsa.Wqkv.weight
-            torch_causal_attn.mhsa.in_proj_bias = orig_model.model.transformer.blocks[
-                idx].causal_attn.mhsa.Wqkv.bias
-            torch_causal_attn.mhsa.out_proj.weight = (
-                orig_model.model.transformer.blocks[idx].causal_attn.mhsa.
-                out_proj.weight)
-            torch_causal_attn.mhsa.out_proj.bias = orig_model.model.transformer.blocks[
-                idx].causal_attn.mhsa.out_proj.bias
-            export_model.model.transformer.blocks[
-                idx].causal_attn = torch_causal_attn
+            export_model.model.transformer.blocks[idx].attn.load_state_dict(
+                orig_model.model.transformer.blocks[idx].attn.state_dict())
     else:
         export_model = orig_model
 
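The export path simplifies because, after the refactor, every backend shares one attention module with identical parameter names, so converting a flash/triton-trained block to the torch implementation is a state-dict copy rather than a field-by-field weight transfer. A minimal sketch of why this works, assuming a hypothetical Attn module with the shared Wqkv/out_proj parameters:

import torch.nn as nn


class Attn(nn.Module):
    # Stand-in for the refactored attention block: same parameter names
    # regardless of which kernel (torch/flash/triton) it dispatches to.
    def __init__(self, d_model=8):
        super().__init__()
        self.Wqkv = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)


trained = Attn()   # e.g. a block built with a triton/flash backend
export = Attn()    # a block built with the torch backend for inference

# Because the parameter trees match key-for-key, one call moves all weights;
# strict=True (the default) raises if any key is missing or unexpected.
export.load_state_dict(trained.state_dict())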
17 changes: 8 additions & 9 deletions examples/llm/src/__init__.py
@@ -7,27 +7,26 @@
 from examples.llm.src.models.hf import (ComposerHFCausalLM, ComposerHFPrefixLM,
                                         ComposerHFT5)
 from examples.llm.src.models.layers.attention import (
-    FlashCausalAttention, TorchCausalAttention, TritonFlashCausalAttention,
-    alibi_bias)
-from examples.llm.src.models.layers.flash_attention import (FlashAttention,
-                                                             FlashMHA)
+    MultiheadAttention, alibi_bias, attn_bias, attn_bias_shape, flash_attn_fn,
+    scaled_multihead_dot_product_attention, triton_flash_attn_fn)
 from examples.llm.src.models.layers.gpt_blocks import GPTMLP, GPTBlock
 from examples.llm.src.models.mosaic_gpt import ComposerMosaicGPT, MosaicGPT
 from examples.llm.src.tokenizer import (TOKENIZER_REGISTRY, HFTokenizer,
                                         LLMTokenizer)
 
 __all__ = [
     'build_text_denoising_dataloader',
+    'flash_attn_fn',
+    'triton_flash_attn_fn',
     'MixtureOfDenoisersCollator',
-    'FlashAttention',
-    'FlashMHA',
     'ComposerHFCausalLM',
     'ComposerHFPrefixLM',
     'ComposerHFT5',
     'COMPOSER_MODEL_REGISTRY',
-    'TorchCausalAttention',
-    'FlashCausalAttention',
-    'TritonFlashCausalAttention',
+    'scaled_multihead_dot_product_attention',
+    'MultiheadAttention',
+    'attn_bias_shape',
+    'attn_bias',
     'alibi_bias',
     'GPTMLP',
     'GPTBlock',
17 changes: 8 additions & 9 deletions examples/llm/src/models/layers/__init__.py
@@ -2,18 +2,17 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from examples.llm.src.models.layers.attention import (
-    FlashCausalAttention, TorchCausalAttention, TritonFlashCausalAttention,
-    alibi_bias)
-from examples.llm.src.models.layers.flash_attention import (FlashAttention,
-                                                             FlashMHA)
+    MultiheadAttention, alibi_bias, attn_bias, attn_bias_shape, flash_attn_fn,
+    scaled_multihead_dot_product_attention, triton_flash_attn_fn)
 from examples.llm.src.models.layers.gpt_blocks import GPTMLP, GPTBlock
 
 __all__ = [
-    'FlashAttention',
-    'FlashMHA',
-    'TorchCausalAttention',
-    'FlashCausalAttention',
-    'TritonFlashCausalAttention',
+    'scaled_multihead_dot_product_attention',
+    'flash_attn_fn',
+    'triton_flash_attn_fn',
+    'MultiheadAttention',
+    'attn_bias_shape',
+    'attn_bias',
     'alibi_bias',
     'GPTMLP',
     'GPTBlock',
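The exported bias helpers (attn_bias_shape, attn_bias, alibi_bias) suggest that positional information such as ALiBi is materialized once as an additive bias tensor and handed to whichever kernel is active. A hedged illustration of that idea; build_alibi_bias and its shape conventions are assumptions, not the PR's actual helpers:

import torch


def alibi_slopes(n_heads: int) -> torch.Tensor:
    # Standard ALiBi slope schedule for power-of-two head counts:
    # 2^(-8/n), 2^(-16/n), ..., 2^(-8).
    return torch.tensor([2 ** (-8 * (i + 1) / n_heads) for i in range(n_heads)])


def build_alibi_bias(n_heads: int, seq_len: int) -> torch.Tensor:
    # Additive bias of shape (1, n_heads, 1, seq_len): each head penalizes
    # distance to the attended key with its own slope.
    positions = torch.arange(1 - seq_len, 1)                # ..., -2, -1, 0
    return (alibi_slopes(n_heads).view(1, n_heads, 1, 1) *
            positions.view(1, 1, 1, seq_len))


bias = build_alibi_bias(n_heads=8, seq_len=16)
# scores = q @ k.transpose(-2, -1) / scale + bias   # added before softmax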