From c7d0aaf352f24f301baafa3e0399e31b8e950160 Mon Sep 17 00:00:00 2001 From: Vitaliy Chiley Date: Thu, 16 Mar 2023 02:12:07 +0000 Subject: [PATCH] refactor attn layers --- examples/common/fdiff.py | 6 +- examples/llm/__init__.py | 14 +- examples/llm/scripts/export_for_inference.py | 15 +- examples/llm/src/__init__.py | 14 +- examples/llm/src/models/layers/__init__.py | 14 +- examples/llm/src/models/layers/attention.py | 592 ++++++++---------- .../llm/src/models/layers/flash_attention.py | 14 +- examples/llm/src/models/layers/gpt_blocks.py | 17 +- examples/llm/src/models/mosaic_gpt.py | 89 +-- examples/llm/tests/test_flash_triton_torch.py | 181 +++--- examples/llm/tests/test_hf_v_mosaic_gpt.py | 11 +- examples/llm/tests/test_model.py | 12 +- examples/llm/tests/test_training.py | 1 + 13 files changed, 447 insertions(+), 533 deletions(-) diff --git a/examples/common/fdiff.py b/examples/common/fdiff.py index 6ea10f917..c22a3b548 100644 --- a/examples/common/fdiff.py +++ b/examples/common/fdiff.py @@ -7,8 +7,6 @@ """Monitor rate of change of loss.""" from __future__ import annotations -from typing import Any, Dict - from composer.core import Callback, State from composer.loggers import Logger @@ -16,8 +14,8 @@ class FDiffMetrics(Callback): """Rate of chage of metrics. - tracks and plots the rate of change of metrics effectively taking the numerical - derivative of the metrics + tracks and plots the rate of change of metrics effectively taking the + numerical derivative of the metrics """ def __init__(self, diff_train_metrics=True, diff_eval_metrics=True): diff --git a/examples/llm/__init__.py b/examples/llm/__init__.py index cf51d3176..255f17b8d 100644 --- a/examples/llm/__init__.py +++ b/examples/llm/__init__.py @@ -11,8 +11,9 @@ from examples.llm.src.models.hf import (ComposerHFCausalLM, ComposerHFPrefixLM, ComposerHFT5) from examples.llm.src.models.layers.attention import ( - FlashCausalAttention, TorchCausalAttention, TritonFlashCausalAttention, - alibi_bias) + MultiheadAttention, alibi_bias, attn_bias_, attn_bias_shape, + generate_attn_bias, scaled_multihead_dot_product_attention, + scaled_multihead_dot_product_self_attention) from examples.llm.src.models.layers.flash_attention import (FlashAttention, FlashMHA) from examples.llm.src.models.layers.gpt_blocks import GPTMLP, GPTBlock @@ -37,9 +38,12 @@ 'ComposerHFPrefixLM', 'ComposerHFT5', 'COMPOSER_MODEL_REGISTRY', - 'TorchCausalAttention', - 'FlashCausalAttention', - 'TritonFlashCausalAttention', + 'scaled_multihead_dot_product_attention', + 'scaled_multihead_dot_product_self_attention', + 'MultiheadAttention', + 'attn_bias_shape', + 'attn_bias_', + 'generate_attn_bias', 'alibi_bias', 'GPTMLP', 'GPTBlock', diff --git a/examples/llm/scripts/export_for_inference.py b/examples/llm/scripts/export_for_inference.py index b9eccc336..320dafd17 100644 --- a/examples/llm/scripts/export_for_inference.py +++ b/examples/llm/scripts/export_for_inference.py @@ -44,7 +44,6 @@ from composer.utils import get_device, maybe_create_object_store_from_uri from omegaconf import OmegaConf as om -from examples.llm import TorchCausalAttention from examples.llm.src.model_registry import COMPOSER_MODEL_REGISTRY @@ -127,18 +126,8 @@ def main(cfg): load_weights_only=True) # replace flash/triton attention with torch causal attention for idx in range(cfg.model.n_layers): - torch_causal_attn = TorchCausalAttention(cfg.model) - torch_causal_attn.mhsa.in_proj_weight = orig_model.model.transformer.blocks[ - idx].causal_attn.mhsa.Wqkv.weight - 
torch_causal_attn.mhsa.in_proj_bias = orig_model.model.transformer.blocks[ - idx].causal_attn.mhsa.Wqkv.bias - torch_causal_attn.mhsa.out_proj.weight = ( - orig_model.model.transformer.blocks[idx].causal_attn.mhsa. - out_proj.weight) - torch_causal_attn.mhsa.out_proj.bias = orig_model.model.transformer.blocks[ - idx].causal_attn.mhsa.out_proj.bias - export_model.model.transformer.blocks[ - idx].causal_attn = torch_causal_attn + export_model.model.transformer.blocks[idx].attn.load_state_dict( + orig_model.model.transformer.blocks[idx].attn.state_dict()) else: export_model = orig_model diff --git a/examples/llm/src/__init__.py b/examples/llm/src/__init__.py index 93f9d823e..91cc4f5a9 100644 --- a/examples/llm/src/__init__.py +++ b/examples/llm/src/__init__.py @@ -7,8 +7,9 @@ from examples.llm.src.models.hf import (ComposerHFCausalLM, ComposerHFPrefixLM, ComposerHFT5) from examples.llm.src.models.layers.attention import ( - FlashCausalAttention, TorchCausalAttention, TritonFlashCausalAttention, - alibi_bias) + MultiheadAttention, alibi_bias, attn_bias_, attn_bias_shape, + generate_attn_bias, scaled_multihead_dot_product_attention, + scaled_multihead_dot_product_self_attention) from examples.llm.src.models.layers.flash_attention import (FlashAttention, FlashMHA) from examples.llm.src.models.layers.gpt_blocks import GPTMLP, GPTBlock @@ -25,9 +26,12 @@ 'ComposerHFPrefixLM', 'ComposerHFT5', 'COMPOSER_MODEL_REGISTRY', - 'TorchCausalAttention', - 'FlashCausalAttention', - 'TritonFlashCausalAttention', + 'scaled_multihead_dot_product_attention', + 'scaled_multihead_dot_product_self_attention', + 'MultiheadAttention', + 'attn_bias_shape', + 'attn_bias_', + 'generate_attn_bias', 'alibi_bias', 'GPTMLP', 'GPTBlock', diff --git a/examples/llm/src/models/layers/__init__.py b/examples/llm/src/models/layers/__init__.py index 26d1789af..ac7092b66 100644 --- a/examples/llm/src/models/layers/__init__.py +++ b/examples/llm/src/models/layers/__init__.py @@ -2,8 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 from examples.llm.src.models.layers.attention import ( - FlashCausalAttention, TorchCausalAttention, TritonFlashCausalAttention, - alibi_bias) + MultiheadAttention, alibi_bias, attn_bias_, attn_bias_shape, + generate_attn_bias, scaled_multihead_dot_product_attention, + scaled_multihead_dot_product_self_attention) from examples.llm.src.models.layers.flash_attention import (FlashAttention, FlashMHA) from examples.llm.src.models.layers.gpt_blocks import GPTMLP, GPTBlock @@ -11,9 +12,12 @@ __all__ = [ 'FlashAttention', 'FlashMHA', - 'TorchCausalAttention', - 'FlashCausalAttention', - 'TritonFlashCausalAttention', + 'scaled_multihead_dot_product_attention', + 'scaled_multihead_dot_product_self_attention', + 'MultiheadAttention', + 'attn_bias_shape', + 'attn_bias_', + 'generate_attn_bias', 'alibi_bias', 'GPTMLP', 'GPTBlock', diff --git a/examples/llm/src/models/layers/attention.py b/examples/llm/src/models/layers/attention.py index 46cd07eb5..0ed0d975e 100644 --- a/examples/llm/src/models/layers/attention.py +++ b/examples/llm/src/models/layers/attention.py @@ -3,380 +3,306 @@ """Attention layers for the GPT models.""" +import math import warnings from typing import Optional import torch -import torch.nn as nn from einops import rearrange from omegaconf import DictConfig - - -class TorchCausalAttention(nn.Module): - - def __init__(self, cfg: DictConfig, device: Optional[str] = None): - super().__init__() - self.mhsa = nn.MultiheadAttention( - embed_dim=cfg.d_model, - num_heads=cfg.n_heads, - dropout=cfg.attn_pdrop, 
- bias=True, - batch_first=True, - device=device, - ) - self.mhsa.out_proj._is_residual = True # type: ignore - - warnings.warn( - DeprecationWarning( - 'Using `attn_impl: torch` is deprecated; recommened using `attn_impl: flash`.' - )) - - def forward(self, x, key_padding_mask, attn_mask=None): - if key_padding_mask is not None: - key_padding_mask = ~key_padding_mask - return self.mhsa(x, - x, - x, - attn_mask=attn_mask, - key_padding_mask=key_padding_mask, - need_weights=True) - - @staticmethod - def mask_shape(n_heads, seq_len, alibi): - if alibi: - return (n_heads, seq_len, seq_len) - return (seq_len, seq_len) - - @staticmethod - def attn_mask_(attn_mask, n_heads, seq_len, alibi=False, alibi_bias_max=8): - # in-place fill causal attn mask - # - # Two important disclaimers - # 1. Torch uses additive attention. If your attn_mask/key_padding mask is a float tensor, it will add the floats - # directly to your attention matrix. If they are boolean masks, True will be converted to -inf before adding the - # mask to your attentions. See https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html#torch.nn.MultiheadAttention.forward - # Basically True/-inf indicates tokens we do not want to attend to. - # - # 2. This is is the exact opposite behavior of Huggingface's tokenizers, which use the convention that True denotes tokens - # we do want to attend to. See https://huggingface.co/docs/transformers/glossary#attention-mask - attn_mask.fill_(float('-inf')) - # attn_mask.triu_(diagonal=1) # triu_ is not implemented for cuda bf16 - # TODO: revert back to triu_ when torch supports triu_ for cuda bf16 - attn_mask.masked_fill_(attn_mask.to(bool).fill_(1).tril_(), 0.) - - if alibi: - device, dtype = attn_mask.device, attn_mask.dtype - a_bias = alibi_bias(n_heads, - seq_len, - full=True, - alibi_bias_max=alibi_bias_max, - device=device, - dtype=dtype) - attn_mask.add_(a_bias.squeeze()) - - return attn_mask - - @staticmethod - def generate_attn_mask( - attn_mask, - batch_size, - heads, - seq_len, - key_padding_mask=None, - alibi=False, - dtype=None, - ): - - # select seq_len subset of attn mask - attn_mask = attn_mask[..., :seq_len, :seq_len] - - if key_padding_mask is not None and _check_apply_key_padding_mask( - key_padding_mask): - attn_mask = attn_mask.expand(batch_size, heads, seq_len, - seq_len).clone() - - kpm_fill_value = -1e4 # numerically stable -inf - attn_mask.masked_fill_( - ~key_padding_mask.view(batch_size, 1, 1, seq_len), - kpm_fill_value) - attn_mask.masked_fill_( - ~key_padding_mask.view(batch_size, 1, seq_len, 1), - kpm_fill_value) - attn_mask = attn_mask.reshape(-1, seq_len, seq_len) - elif alibi: - # WARNING: Alibi with torch attn is not thoroughly tested - # torch mask is supposed to be of shape nzz x SeqLen x SeqLen - # we must braodcast to batch size then flatten batchsize * n_heads dim - # Note: if key_padding_mask is triggered, the needed expansion is already done. 
- attn_mask = attn_mask.expand(batch_size, heads, seq_len, - seq_len).reshape(-1, seq_len, seq_len) - - return attn_mask - - -class FlashCausalAttention(nn.Module): +from torch import nn + + +def scaled_multihead_dot_product_attention( + query, + key, + value, + n_heads, + softmax_scale=None, + attn_bias=None, + key_padding_mask=None, + is_causal=False, + dropout_p=0.0, + training=False, + needs_weights=False, +): + q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads) + k = rearrange(key, 'b s (h d) -> b h s d', h=n_heads) + v = rearrange(value, 'b s (h d) -> b h s d', h=n_heads) + + b, _, s, d = q.shape + + if softmax_scale is None: + softmax_scale = 1 / math.sqrt(d) + + attn_weight = q @ k.transpose(-2, -1) + attn_weight *= softmax_scale + + if attn_bias is not None: + attn_weight = attn_weight + attn_bias + + if key_padding_mask is not None: + attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s)), + -float('inf')) + + if is_causal: + causal_mask = attn_weight.new_ones(s, s, dtype=torch.bool) + causal_mask.tril_() + causal_mask.logical_not_() + attn_weight = attn_weight.masked_fill_(causal_mask.view(1, 1, s, s), + -float('inf')) + + attn_weight = torch.softmax(attn_weight, dim=-1) + + if dropout_p: + attn_weight = torch.nn.functional.dropout(attn_weight, + p=dropout_p, + training=training, + inplace=True) + + out = attn_weight @ v + out = rearrange(out, 'b h s d -> b s (h d)', h=n_heads) + + if needs_weights: + return out, attn_weight + return out, None + + +def scaled_multihead_dot_product_self_attention( + qkv, + n_heads, + softmax_scale=None, + attn_bias=None, + key_padding_mask=None, + is_causal=False, + dropout_p=0.0, + training=False, + needs_weights=False, +): + qkv = rearrange(qkv, 'b s (t hd) -> b s t hd', t=3) + return scaled_multihead_dot_product_attention( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + n_heads, + softmax_scale=softmax_scale, + attn_bias=attn_bias, + key_padding_mask=key_padding_mask, + is_causal=is_causal, + dropout_p=dropout_p, + training=training, + needs_weights=needs_weights, + ) + + +class MultiheadAttention(nn.Module): + """Multi-head self attention. + + Using torch or triton attention implemetation enables user to also use + additive bias. 
+ """ def __init__(self, cfg: DictConfig, device: Optional[str] = None): super().__init__() - try: - from flash_attn.flash_attention import ( # type: ignore - FlashAttention, FlashMHA) - except ImportError as e: - raise e + self.attn_impl = cfg.get('attn_impl') self.clip_qkv = cfg.get('attn_clip_qkv') self.attn_qk_ln = cfg.get('attn_qk_ln') - self.softmax_scale = cfg.get('softmax_scale') + self.d_model = cfg.d_model self.n_heads = cfg.n_heads + self.softmax_scale = cfg.get('softmax_scale') + if self.softmax_scale is None: + self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads) + self.attn_dropout_p = cfg.get('attn_pdrop') + + self.Wqkv = nn.Linear(self.d_model, + 3 * self.d_model, + bias=True, + device=device) + # for param init fn; enables shape based init of fused layers + fuse_splits = (cfg.d_model, 2 * cfg.d_model) + self.Wqkv._fused = (0, fuse_splits) # type: ignore + + if self.attn_qk_ln: + self.q_ln = nn.LayerNorm(self.d_model, device=device) + self.k_ln = nn.LayerNorm(self.d_model, device=device) + + self.is_causal = True # causal attn impl + if self.attn_impl == 'flash': + try: + from flash_attn import flash_attention # type: ignore + except ImportError as e: + raise e + + self.inner_attn = flash_attention.FlashAttention( # type: ignore + attention_dropout=self.attn_dropout_p, + device=device, + softmax_scale=self.softmax_scale) + elif self.attn_impl == 'triton': + if self.attn_dropout_p: + raise ValueError( + 'Triton kernel does not support attention dropout.') + + try: + from examples.llm.src.models import layers # type: ignore + except ImportError as e: + raise e + + warnings.warn( + 'While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. 
If encountered, we recommend ' + 'using `attn_impl: flash`.') + + self.inner_attn = layers.flash_attention.FlashAttention( # type: ignore + num_heads=cfg.n_heads, + softmax_scale=self.softmax_scale, + device=device) + + elif self.attn_impl == 'torch': + pass - if self.attn_qk_ln or self.clip_qkv or self.softmax_scale: - self.W_qkv = nn.Linear(self.d_model, - 3 * self.d_model, - bias=True, - device=device) - self.inner_attn = FlashAttention(attention_dropout=cfg.attn_pdrop, - device=device, - softmax_scale=self.softmax_scale) - self.out_proj = nn.Linear(self.d_model, - self.d_model, - bias=True, - device=device) - # for param init fn - fuse_splits = (cfg.d_model, 2 * cfg.d_model) - self.W_qkv._fused = (0, fuse_splits) # type: ignore - self.out_proj._is_residual = True # type: ignore - - if self.attn_qk_ln: - self.q_ln = nn.LayerNorm(self.d_model, device=device) - self.k_ln = nn.LayerNorm(self.d_model, device=device) else: - self.mhsa = FlashMHA( - embed_dim=cfg.d_model, - num_heads=cfg.n_heads, - attention_dropout=cfg.attn_pdrop, - bias=True, - batch_first=True, - causal=True, - device=device, - ) - # for param init fn - fuse_splits = (cfg.d_model, 2 * cfg.d_model) - self.mhsa.Wqkv._fused = (0, fuse_splits) # type: ignore - self.mhsa.out_proj._is_residual = True - - def forward(self, x, key_padding_mask, attn_mask=None): - assert attn_mask is None - - if self.attn_qk_ln or self.clip_qkv or self.softmax_scale: - qkv = self.W_qkv(x) - if self.clip_qkv: - qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) - if self.attn_qk_ln: - # Applying layernorm to qk - dtype = qkv.dtype - q, k, v = qkv.split(self.d_model, dim=-1) - q = self.q_ln(q).to(dtype) - k = self.k_ln(k).to(dtype) - qkv = torch.cat([q, k, v], dim=-1) - - # attention + raise ValueError(f"{cfg.get('attn_impl')=} is an invalid setting.") + + self.out_proj = nn.Linear(self.d_model, + self.d_model, + bias=True, + device=device) + self.out_proj._is_residual = True # type: ignore + + def forward(self, + x, + attn_bias=None, + key_padding_mask=None, + needs_weights=False): + qkv = self.Wqkv(x) + if self.clip_qkv: + qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) + if self.attn_qk_ln: + # Applying layernorm to qk + dtype = qkv.dtype + q, k, v = qkv.split(self.d_model, dim=-1) + q = self.q_ln(q).to(dtype) + k = self.k_ln(k).to(dtype) + qkv = torch.cat([q, k, v], dim=-1) + + # attention + if self.attn_impl == 'flash': + if needs_weights: + raise NotImplementedError( + 'FlashAttn cannot return attention weights.') + if attn_bias is not None: + raise RuntimeError( + 'attn_impl: flash does not support an attn bias.') qkv = rearrange(qkv, - 'b s (three h d) -> b s three h d', - three=3, + 'b s (t h d) -> b s t h d', + t=3, h=self.n_heads) context, attn_weights = self.inner_attn( qkv, key_padding_mask=key_padding_mask, - causal=True, + causal=self.is_causal, need_weights=False) - return self.out_proj(rearrange( - context, 'b s h d -> b s (h d)')), attn_weights - - else: - return self.mhsa(x, - key_padding_mask=key_padding_mask, - need_weights=False) - - @staticmethod - def mask_shape(*args, **kwargs): - return None - - @staticmethod - def attn_mask_(*args, **kwargs): - return None - - @staticmethod - def generate_attn_mask( - attn_mask, - batch_size, - heads, - seq_len, - key_padding_mask=None, - alibi=False, - dtype=None, - ): - return attn_mask # None - - -class TritonFlashCausalAttention(nn.Module): - """Multi-headed self attention using triton FlashAttn kernel. 
+ context = rearrange(context, 'b s h d -> b s (h d)') - This also includes bias for Alibi integration. - """ - - def __init__(self, cfg: DictConfig, device: Optional[str] = None): - super().__init__() - try: - from examples.llm.src.models.layers.flash_attention import ( # type: ignore - FlashAttention, FlashMHA) - except ImportError as e: - raise e - - assert cfg.attn_pdrop == 0, 'triton kernel does not support attn_dropout' - - self.clip_qkv = cfg.get('attn_clip_qkv') - self.attn_qk_ln = cfg.get('attn_qk_ln') - self.d_model = cfg.d_model - self.n_heads = cfg.n_heads - - if self.attn_qk_ln or self.clip_qkv: - self.Wqkv = nn.Linear(self.d_model, - 3 * self.d_model, - bias=True, - device=device) - self.inner_attn = FlashAttention( - num_heads=cfg.n_heads, - softmax_scale=cfg.get('softmax_scale'), - device=device) - self.out_proj = nn.Linear(self.d_model, - self.d_model, - bias=True, - device=device) - # for param init fn - fuse_splits = (cfg.d_model, 2 * cfg.d_model) - self.Wqkv._fused = (0, fuse_splits) # type: ignore - self.out_proj._is_residual = True # type: ignore - - if self.attn_qk_ln: - self.q_ln = nn.LayerNorm(self.d_model, device=device) - self.k_ln = nn.LayerNorm(self.d_model, device=device) - else: - self.mhsa = FlashMHA( - embed_dim=cfg.d_model, - num_heads=cfg.n_heads, - bias=True, - batch_first=True, - causal=True, - softmax_scale=cfg.get('softmax_scale'), - device=device, - ) - # for param init fn - fuse_splits = (cfg.d_model, 2 * cfg.d_model) - self.mhsa.Wqkv._fused = (0, fuse_splits) # type: ignore - self.mhsa.out_proj._is_residual = True # type: ignore - - warnings.warn( - 'While `attn_impl: triton` can be faster than `attn_impl: flash` ' - 'it uses more memory. When training larger models this can trigger ' - 'alloc retries which hurts performance. If encountered, we recommend ' - 'using `attn_impl: flash`.') - - def forward(self, x, key_padding_mask=None, attn_mask=None): - if self.attn_qk_ln or self.clip_qkv: - qkv = self.Wqkv(x) - if self.clip_qkv: - qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) - if self.attn_qk_ln: - # Applying layernorm to qk - dtype = qkv.dtype - q, k, v = qkv.split(self.d_model, dim=-1) - q = self.q_ln(q).to(dtype) - k = self.k_ln(k).to(dtype) - qkv = torch.cat([q, k, v], dim=-1) - - # attention + elif self.attn_impl == 'triton': + if needs_weights: + raise NotImplementedError( + 'Triton variant of FlashAttn cannot return attention weigths.' 
+ ) context, attn_weights = self.inner_attn( qkv, key_padding_mask=key_padding_mask, - attn_mask=attn_mask, - is_causal=True) - - return self.out_proj(context), attn_weights + attn_bias=attn_bias, + is_causal=self.is_causal) + elif self.attn_impl == 'torch': + context, attn_weights = scaled_multihead_dot_product_self_attention( + qkv, + self.n_heads, + softmax_scale=self.softmax_scale, + attn_bias=attn_bias, + key_padding_mask=key_padding_mask, + is_causal=self.is_causal, + dropout_p=self.attn_dropout_p, + training=self.training, + needs_weights=needs_weights, + ) else: - return self.mhsa(x, - key_padding_mask=None, - attn_mask=attn_mask, - need_weights=False) + raise RuntimeError('Internal Logic Error') + + return self.out_proj(context), attn_weights + - @staticmethod - def mask_shape(n_heads, seq_len, alibi): +def attn_bias_shape(attn_impl, n_heads, seq_len, alibi): + if attn_impl == 'flash': + return None + elif attn_impl == 'triton': + # in triton, key_padding_mask must be integrated into attn_bias + return (1, n_heads, 1, seq_len) if alibi else (1, 1, 1, seq_len) + elif attn_impl == 'torch': return (1, n_heads, 1, seq_len) if alibi else None + else: + raise ValueError(f'{attn_impl=} is an invalid setting.') - @staticmethod - def attn_mask_(attn_mask, n_heads, seq_len, alibi=False, alibi_bias_max=8): - if attn_mask is not None: - # in-place fill causal attn mask - attn_mask.zero_() +def attn_bias_(attn_impl, + attn_bias, + n_heads, + seq_len, + alibi=False, + alibi_bias_max=8): + if attn_impl == 'flash': + return None + elif attn_impl == 'triton': + attn_bias.zero_() + if alibi: + # in place add alibi to attn bias + device, dtype = attn_bias.device, attn_bias.dtype + attn_bias.add_( + alibi_bias(n_heads, + seq_len, + full=False, + alibi_bias_max=alibi_bias_max, + device=device, + dtype=dtype)) + return attn_bias + elif attn_impl == 'torch': + if attn_bias is not None: + attn_bias.zero_() if alibi: - device, dtype = attn_mask.device, attn_mask.dtype - attn_mask.add_( + # in place add alibi to attn bias + device, dtype = attn_bias.device, attn_bias.dtype + attn_bias.add_( alibi_bias(n_heads, seq_len, full=False, alibi_bias_max=alibi_bias_max, device=device, dtype=dtype)) + return attn_bias + else: + raise ValueError(f'{attn_impl=} is an invalid setting.') + + +def generate_attn_bias(attn_impl, + attn_bias, + seq_len, + batch_size, + key_padding_mask=None): + if attn_bias is not None: + # select seq_len subset of attn mask + attn_bias = attn_bias[..., :seq_len, :seq_len] + + if attn_impl == 'triton' and key_padding_mask is not None: + attn_bias = attn_bias.expand(batch_size, -1, -1, -1) + attn_bias.masked_fill( + ~key_padding_mask.view((batch_size, 1, 1, seq_len)), -float('inf')) - return attn_mask - - @staticmethod - def generate_attn_mask( - attn_mask, - batch_size, - heads, - seq_len, - key_padding_mask=None, - alibi=False, - dtype=None, - ): - if attn_mask is not None: - # select seq_len subset of attn mask - attn_mask = attn_mask[..., :seq_len, :seq_len] - - if key_padding_mask is not None and _check_apply_key_padding_mask( - key_padding_mask): - if attn_mask is None: - attn_mask = key_padding_mask.new_zeros( - ((batch_size, 1, seq_len, seq_len)), dtype=dtype) - - kpm_fill_value = -1e4 # numerically stable -inf - attn_mask = attn_mask.masked_fill( - ~key_padding_mask.view((batch_size, 1, 1, seq_len)), - kpm_fill_value) - attn_mask = attn_mask.masked_fill( - ~key_padding_mask.view((batch_size, 1, seq_len, 1)), - kpm_fill_value) - - return attn_mask - - -def 
_check_apply_key_padding_mask(key_padding_mask):
-    if key_padding_mask.bool().logical_not().any():
-        # check to verify all tokens after the first invalid tokens are invalid.
-        # if there are no valid tokens after the first invalid token,
-        # key_padding_mask isn't required given causal mask will eliminate
-        # unwanted token interaction.
-        # WARNING: this approach only works for right padded causal attn
-        # NOTE: I chose this algorithm given its vectorized; there is room for improvement...
-        c_sum = key_padding_mask.cumsum(1)
-        num_valid_tokens = c_sum[:, -1].long()
-        vals = c_sum[range(key_padding_mask.size(0)), num_valid_tokens - 1]
-        return any(vals != num_valid_tokens)
-    return False
+    return attn_bias
 
 
 def alibi_bias(n_heads,
diff --git a/examples/llm/src/models/layers/flash_attention.py b/examples/llm/src/models/layers/flash_attention.py
index 9d192d3b0..37ed8842d 100644
--- a/examples/llm/src/models/layers/flash_attention.py
+++ b/examples/llm/src/models/layers/flash_attention.py
@@ -44,7 +44,7 @@ def forward(
         self,
         qkv,
         key_padding_mask: Optional[Tensor] = None,
-        attn_mask: Optional[Tensor] = None,
+        attn_bias: Optional[Tensor] = None,
         is_causal: bool = False,
     ) -> Tuple[Tensor, Optional[Tensor]]:
         r"""Multiheaded softmax attention.
@@ -52,7 +52,7 @@ def forward(
         Arguments:
             qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
             key_padding_mask: not implemented for triton kernel.
-            attn_mask: If specified, a 4D mask of floats which will be added to the attention weight. Must braodcast to (B, H, S, S).
+            attn_bias: If specified, a 4D mask of floats which will be added to the attention weight. Must broadcast to (B, H, S, S).
             is_causal: If specified, applies a causal mask as attention mask. Default: ``False``.
         """
         from flash_attn import flash_attn_triton  # type: ignore
@@ -63,11 +63,11 @@ def forward(
         if key_padding_mask is not None and key_padding_mask.bool().logical_not(
         ).any():
             raise NotImplementedError(
-                f'assumes key_padding_mask is taken care of by attn_mask')
+                f'assumes key_padding_mask is taken care of by attn_bias')
 
         qkv = rearrange(qkv, 'b s (t h d) -> b s t h d', t=3, h=self.num_heads)
         attn_output = flash_attn_triton.flash_attn_qkvpacked_func(
-            qkv, attn_mask, is_causal, self.softmax_scale)
+            qkv, attn_bias, is_causal, self.softmax_scale)
         output = rearrange(attn_output, 'b s h d -> b s (h d)')
         return output, None
 
@@ -110,14 +110,14 @@ def __init__(self,
     def forward(self,
                 x,
                 key_padding_mask: Optional[torch.Tensor] = None,
-                attn_mask: Optional[torch.Tensor] = None,
+                attn_bias: Optional[torch.Tensor] = None,
                 need_weights: bool = False):
         r"""Multiheaded softmax attention.
 
         Args:
             x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim)
             key_padding_mask: not implemented for triton kernel.
-            attn_mask: If specified, a 4D mask of floats which will be added to the attention weight. Must braodcast to (B, H, S, S).
+            attn_bias: If specified, a 4D mask of floats which will be added to the attention weight. Must broadcast to (B, H, S, S).
             need_weights: not implemented for triton kernel.
""" if need_weights: @@ -127,6 +127,6 @@ def forward(self, context, attn_weights = self.inner_attn( qkv, key_padding_mask=key_padding_mask, - attn_mask=attn_mask, + attn_bias=attn_bias, is_causal=self.causal) return self.out_proj(context), attn_weights diff --git a/examples/llm/src/models/layers/gpt_blocks.py b/examples/llm/src/models/layers/gpt_blocks.py index e1525e21e..bbe39287c 100644 --- a/examples/llm/src/models/layers/gpt_blocks.py +++ b/examples/llm/src/models/layers/gpt_blocks.py @@ -9,6 +9,8 @@ import torch.nn as nn from omegaconf import DictConfig +from examples.llm.src.models.layers.attention import MultiheadAttention + class GPTMLP(nn.Module): @@ -29,15 +31,10 @@ def forward(self, x): class GPTBlock(nn.Module): - def __init__(self, - cfg: DictConfig, - causal_attn_cls, - device: Optional[str] = None): + def __init__(self, cfg: DictConfig, device: Optional[str] = None): super().__init__() - if cfg.get('alibi', False): - assert cfg.attn_impl == 'triton' or cfg.attn_impl == 'torch', 'Only triton kernel or torch supports alibi' self.ln_1 = nn.LayerNorm(cfg.d_model, device=device) - self.causal_attn = causal_attn_cls(cfg, device) + self.attn = MultiheadAttention(cfg, device) self.ln_2 = nn.LayerNorm(cfg.d_model, device=device) self.mlp = GPTMLP(cfg, device=device) self.resid_attn_dropout = nn.Dropout(cfg.resid_pdrop) @@ -46,11 +43,13 @@ def __init__(self, def forward( self, x: torch.Tensor, + attn_bias: Optional[torch.Tensor] = None, key_padding_mask: Optional[torch.ByteTensor] = None, - attn_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: a = self.ln_1(x) - b, _ = self.causal_attn(a, key_padding_mask, attn_mask) + b, _ = self.attn(a, + attn_bias=attn_bias, + key_padding_mask=key_padding_mask) x = x + self.resid_attn_dropout(b) m = self.ln_2(x) n = self.mlp(m) diff --git a/examples/llm/src/models/mosaic_gpt.py b/examples/llm/src/models/mosaic_gpt.py index e328e18d4..ab665969b 100644 --- a/examples/llm/src/models/mosaic_gpt.py +++ b/examples/llm/src/models/mosaic_gpt.py @@ -29,39 +29,15 @@ def __init__(self, cfg: DictConfig): super().__init__() assert cfg.name == 'mosaic_gpt', f'Tried to build MosaicGPT model with cfg.name={cfg.name}' self.cfg = cfg - if cfg.attn_impl == 'torch': - self.causal_attn_cls = attention.TorchCausalAttention - elif cfg.attn_impl == 'flash': - self.causal_attn_cls = attention.FlashCausalAttention - elif cfg.attn_impl == 'triton': - self.causal_attn_cls = attention.TritonFlashCausalAttention - else: - raise ValueError(f'Unknown attn_impl={cfg.attn_impl}') - - if cfg.get('attn_qk_ln') and cfg.attn_impl not in ['flash', 'triton']: - raise NotImplementedError( - 'LayerNorm over queries and keys in attention is only implemented with flash and triton attention.' - ) - if cfg.get('attn_clip_qkv') and cfg.attn_impl not in [ - 'flash', 'triton' - ]: - raise NotImplementedError( - 'QKV clipping only implemented with flash and triton attention.' - ) - - if cfg.get('softmax_scale') and cfg.attn_impl not in [ - 'flash', 'triton' - ]: - raise NotImplementedError( - 'softmax_scale only implemented with flash and triton attention.' 
- ) + self.attn_impl = cfg.attn_impl self.alibi = cfg.get('alibi', False) self.alibi_bias_max = cfg.get('alibi_bias_max', 8 if self.alibi else None) if self.alibi and cfg.attn_impl not in ['torch', 'triton']: raise NotImplementedError( 'alibi only implemented with torch and triton attention.') + # CogView (https://arxiv.org/abs/2105.13290) and GLM-130B (https://arxiv.org/abs/2210.02414) # both report this helping with stabilizing training self.embedding_fraction = cfg.get('embedding_fraction', 1) @@ -84,9 +60,7 @@ def __init__(self, cfg: DictConfig): self.transformer.update({ 'blocks': nn.ModuleList([ - gpt_blocks.GPTBlock(cfg, - causal_attn_cls=self.causal_attn_cls, - device=cfg.init_device) + gpt_blocks.GPTBlock(cfg, device=cfg.init_device) for _ in range(cfg.n_layers) ]) }) @@ -114,15 +88,10 @@ def __init__(self, cfg: DictConfig): self.apply(self.param_init_fn) # define attn mask - self._attn_mask_initialized = False - mask_shape = self.causal_attn_cls.mask_shape(cfg.n_heads, - cfg.max_seq_len, - self.alibi) - if mask_shape is not None: - self.register_buffer( - 'attn_mask', torch.empty(mask_shape, device=cfg.init_device)) - else: - self.attn_mask = None + self._attn_bias_initialized = False + self.attn_bias = None + self.attn_bias_shape = attention.attn_bias_shape( + self.attn_impl, cfg.n_heads, cfg.max_seq_len, self.alibi) if cfg.get('no_bias', False): for module in self.modules(): @@ -135,27 +104,30 @@ def __init__(self, cfg: DictConfig): if cfg.get('verbose') and cfg.get('verbose') > 2: print(self) - def _attn_mask(self, + def _attn_bias(self, batch_size=None, seq_len=None, key_padding_mask=None, + device=None, dtype=None): - if not self._attn_mask_initialized: - self.causal_attn_cls.attn_mask_(self.attn_mask, - self.cfg.n_heads, - self.cfg.max_seq_len, - alibi=self.alibi, - alibi_bias_max=self.alibi_bias_max) - self._attn_mask_initialized = True - - return self.causal_attn_cls.generate_attn_mask( - self.attn_mask, - batch_size, - self.cfg.n_heads, - seq_len, - key_padding_mask=key_padding_mask, - alibi=self.alibi, - dtype=dtype) + if not self._attn_bias_initialized: + if self.attn_bias_shape: + self.attn_bias = torch.empty(self.attn_bias_shape, + device=device, + dtype=dtype) + attention.attn_bias_(self.attn_impl, + self.attn_bias, + self.cfg.n_heads, + self.cfg.max_seq_len, + alibi=self.alibi, + alibi_bias_max=self.alibi_bias_max) + self._attn_bias_initialized = True + + return attention.generate_attn_bias(self.attn_impl, + self.attn_bias, + seq_len, + batch_size, + key_padding_mask=key_padding_mask) def forward(self, input_ids: torch.LongTensor, @@ -183,9 +155,10 @@ def forward(self, assert isinstance(self.transformer.emb_drop, nn.Module) # pyright x = self.transformer.emb_drop(x_shrunk) - attn_mask = self._attn_mask(batch_size=B, + attn_bias = self._attn_bias(batch_size=B, seq_len=S, key_padding_mask=key_padding_mask, + device=x.device, dtype=x.dtype) if self.cfg.attn_impl == 'flash' and key_padding_mask is None: # HazyResearch FlashMHA appears to use more memory when `key_padding_mask=None` @@ -197,7 +170,9 @@ def forward(self, else: mod_key_padding_mask = key_padding_mask for block in self.transformer.blocks: # type: ignore - x = block(x, mod_key_padding_mask, attn_mask) + x = block(x, + attn_bias=attn_bias, + key_padding_mask=mod_key_padding_mask) x = self.transformer.ln_f(x) # type: ignore # output embedding weight tied to input embedding assert isinstance(self.transformer.wte, nn.Module) # pyright diff --git a/examples/llm/tests/test_flash_triton_torch.py 
b/examples/llm/tests/test_flash_triton_torch.py index eb7c44098..d1b6cd199 100644 --- a/examples/llm/tests/test_flash_triton_torch.py +++ b/examples/llm/tests/test_flash_triton_torch.py @@ -12,45 +12,77 @@ def allclose_helper(t0, t1, rtol=1e-2, atol=1e-2): @pytest.mark.gpu -def test_flash_torch(device='cuda'): - from examples.llm.src.models.layers.attention import ( # type: ignore - FlashCausalAttention, TorchCausalAttention) +@pytest.mark.parametrize('attn_impl_0', ['flash', 'triton', 'torch']) +@pytest.mark.parametrize('attn_impl_1', ['flash', 'triton', 'torch']) +@pytest.mark.parametrize('attn_clip_qkv', [True, False]) +@pytest.mark.parametrize('attn_qk_ln', [True, False]) +@pytest.mark.parametrize('alibi', [True, False]) +def test_attn_impl(attn_impl_0, + attn_impl_1, + attn_clip_qkv, + attn_qk_ln, + alibi, + device='cuda'): + """Compare all attn impl with each other. + + Includes testing with and without attn_clip_qkv, attn_qk_ln, and alibi. + """ + from examples.llm.src.models.layers import attention # type: ignore + + if alibi and (attn_impl_0 == 'flash' or attn_impl_1 == 'flash'): + pytest.xfail('flash attn does not support alibi') reproducibility.seed_all(7) cfg = om.create({ + 'attn_impl': 'flash', 'd_model': 256, 'n_heads': 2, 'attn_pdrop': 0, + 'attn_clip_qkv': attn_clip_qkv, + 'attn_qk_ln': attn_qk_ln, }) n, s, f = 2, 16, cfg.d_model - fca = FlashCausalAttention(cfg).to(device) - tca = TorchCausalAttention(cfg).to(device) - - def gen_tca_mask(): - ms = TorchCausalAttention.mask_shape(cfg.n_heads, s, False) - attn_mask = torch.empty(*ms).to(device) - TorchCausalAttention.attn_mask_(attn_mask, cfg.n_heads, s) - return attn_mask + cfg.attn_impl = attn_impl_0 + attn0 = attention.MultiheadAttention(cfg).to(device) + cfg.attn_impl = attn_impl_1 + attn1 = attention.MultiheadAttention(cfg).to(device) - # clone weights - tca.mhsa.in_proj_weight.data = fca.mhsa.Wqkv.weight.data.clone().detach() - tca.mhsa.in_proj_bias.data = fca.mhsa.Wqkv.bias.data.clone().detach() - tca.mhsa.out_proj.weight.data = fca.mhsa.out_proj.weight.data.clone( - ).detach() - tca.mhsa.out_proj.bias.data = fca.mhsa.out_proj.bias.data.clone().detach() + attn1.load_state_dict(attn0.state_dict()) key_padding_mask = torch.ones(n, s).to(device).bool() + + def gen_bias(attn_impl, key_padding_mask): + attn_bias = None + bs = attention.attn_bias_shape(attn_impl, cfg.n_heads, s, alibi) + if bs is not None: + attn_bias = torch.empty(*bs, device=device) + attention.attn_bias_(attn_impl, + attn_bias, + cfg.n_heads, + s, + alibi=alibi, + alibi_bias_max=8) + attn_bias = attention.generate_attn_bias( + attn_impl, attn_bias, s, n, key_padding_mask=key_padding_mask) + + if attn_impl == 'triton': + return attn_bias, key_padding_mask + + return attn_bias, key_padding_mask + x0 = torch.randn(n, s, f).to(device) x1 = x0.clone().detach() x0.requires_grad = True x1.requires_grad = True with torch.autocast(x0.device.type): - y0, _ = fca(x0, key_padding_mask, attn_mask=None) - y1, _ = tca(x1, key_padding_mask, attn_mask=gen_tca_mask()) + attn_bias, kpm = gen_bias(attn0.attn_impl, key_padding_mask) + y0, _ = attn0(x0, attn_bias=attn_bias, key_padding_mask=kpm) + attn_bias, kpm = gen_bias(attn1.attn_impl, key_padding_mask) + y1, _ = attn1(x1, attn_bias=attn_bias, key_padding_mask=kpm) y0 *= key_padding_mask.unsqueeze(-1) y1 *= key_padding_mask.unsqueeze(-1) @@ -62,70 +94,72 @@ def gen_tca_mask(): assert allclose_helper(y0, y1) - assert allclose_helper(tca.mhsa.out_proj.bias.grad, - fca.mhsa.out_proj.bias.grad) - assert 
allclose_helper(tca.mhsa.out_proj.weight.grad, - fca.mhsa.out_proj.weight.grad) - assert allclose_helper(tca.mhsa.in_proj_bias.grad, fca.mhsa.Wqkv.bias.grad) - assert allclose_helper(tca.mhsa.in_proj_weight.grad, - fca.mhsa.Wqkv.weight.grad) + torch_name_param_map = {n: p for n, p in attn1.named_parameters()} + for n, p in attn0.named_parameters(): + tp = torch_name_param_map[n] + assert allclose_helper(p, tp) + assert allclose_helper(p.grad, tp.grad) assert allclose_helper(x0.grad, x1.grad) @pytest.mark.gpu -@pytest.mark.parametrize('attn_clip_qkv,attn_qk_ln', [ - (False, False), - (False, True), - (True, False), - (True, True), -]) -def test_flash_triton(attn_clip_qkv, attn_qk_ln, device='cuda'): - from examples.llm.src.models.layers.attention import ( # type: ignore - FlashCausalAttention, TritonFlashCausalAttention) +@pytest.mark.parametrize('attn_impl', ['flash', 'triton', 'torch']) +def test_vs_mha(attn_impl, device='cuda'): + """Compare diff attn_impl to torch.nn.MultiheadAttention.""" + from examples.llm.src.models.layers import attention # type: ignore - reproducibility.seed_all(7) + reproducibility.seed_all(17) cfg = om.create({ + 'attn_impl': attn_impl, 'd_model': 256, 'n_heads': 2, 'attn_pdrop': 0, - 'attn_clip_qkv': attn_clip_qkv, - 'attn_qk_ln': attn_qk_ln, + 'attn_clip_qkv': False, + 'attn_qk_ln': False, }) n, s, f = 2, 16, cfg.d_model - fca = FlashCausalAttention(cfg).to(device) - tfca = TritonFlashCausalAttention(cfg).to(device) + mmhsa = attention.MultiheadAttention(cfg).to(device) + tmhsa = torch.nn.MultiheadAttention( + embed_dim=cfg.d_model, + num_heads=cfg.n_heads, + dropout=cfg.attn_pdrop, + bias=True, + batch_first=True, + device=device, + ) + + def gen_tca_mask(): + # generate causal mask for torch attn + ms = (s, s) + attn_mask = torch.empty(*ms).to(device) + attn_mask.fill_(float('-inf')) + attn_mask.masked_fill_(attn_mask.to(torch.bool).fill_(1).tril_(), 0.) 
+ return attn_mask + # clone weights - if cfg.attn_qk_ln or cfg.attn_clip_qkv: - tfca.Wqkv.weight.data = fca.W_qkv.weight.data.clone().detach() - tfca.Wqkv.bias.data = fca.W_qkv.bias.data.clone().detach() - tfca.out_proj.weight.data = fca.out_proj.weight.data.clone().detach() - tfca.out_proj.bias.data = fca.out_proj.bias.data.clone().detach() - if cfg.attn_qk_ln: - tfca.q_ln.weight.data = fca.q_ln.weight.data.clone().detach() - tfca.q_ln.bias.data = fca.q_ln.bias.data.clone().detach() - tfca.k_ln.weight.data = fca.k_ln.weight.data.clone().detach() - tfca.k_ln.bias.data = fca.k_ln.bias.data.clone().detach() - else: - tfca.mhsa.Wqkv.weight.data = fca.mhsa.Wqkv.weight.data.clone().detach() - tfca.mhsa.Wqkv.bias.data = fca.mhsa.Wqkv.bias.data.clone().detach() - tfca.mhsa.out_proj.weight.data = fca.mhsa.out_proj.weight.data.clone( - ).detach() - tfca.mhsa.out_proj.bias.data = fca.mhsa.out_proj.bias.data.clone( - ).detach() - - key_padding_mask = torch.ones(n, s).to(device) + tmhsa.in_proj_weight.data = mmhsa.Wqkv.weight.data.clone().detach() + tmhsa.in_proj_bias.data = mmhsa.Wqkv.bias.data.clone().detach() + tmhsa.out_proj.weight.data = mmhsa.out_proj.weight.data.clone().detach() + tmhsa.out_proj.bias.data = mmhsa.out_proj.bias.data.clone().detach() + + key_padding_mask = torch.ones(n, s).to(device).bool() x0 = torch.randn(n, s, f).to(device) x1 = x0.clone().detach() x0.requires_grad = True x1.requires_grad = True with torch.autocast(x0.device.type): - y0, _ = fca(x0, key_padding_mask, attn_mask=None) - y1, _ = tfca(x1, key_padding_mask, attn_mask=None) + y0, _ = mmhsa(x0, attn_bias=None, key_padding_mask=key_padding_mask) + y1, _ = tmhsa(x1, + x1, + x1, + attn_mask=gen_tca_mask(), + key_padding_mask=~key_padding_mask, + need_weights=True) y0 *= key_padding_mask.unsqueeze(-1) y1 *= key_padding_mask.unsqueeze(-1) @@ -137,25 +171,10 @@ def test_flash_triton(attn_clip_qkv, attn_qk_ln, device='cuda'): assert allclose_helper(y0, y1) - if cfg.attn_qk_ln or cfg.attn_clip_qkv: - assert allclose_helper(tfca.out_proj.bias.grad, fca.out_proj.bias.grad) - assert allclose_helper(tfca.out_proj.weight.grad, - fca.out_proj.weight.grad) - if cfg.attn_qk_ln: - assert allclose_helper(tfca.q_ln.bias.grad, fca.q_ln.bias.grad) - assert allclose_helper(tfca.q_ln.weight.grad, fca.q_ln.weight.grad) - assert allclose_helper(tfca.k_ln.bias.grad, fca.k_ln.bias.grad) - assert allclose_helper(tfca.k_ln.weight.grad, fca.k_ln.weight.grad) - assert allclose_helper(tfca.Wqkv.bias.grad, fca.W_qkv.bias.grad) - assert allclose_helper(tfca.Wqkv.weight.grad, fca.W_qkv.weight.grad) - else: - assert allclose_helper(tfca.mhsa.out_proj.bias.grad, - fca.mhsa.out_proj.bias.grad) - assert allclose_helper(tfca.mhsa.out_proj.weight.grad, - fca.mhsa.out_proj.weight.grad) - assert allclose_helper(tfca.mhsa.Wqkv.bias.grad, - fca.mhsa.Wqkv.bias.grad) - assert allclose_helper(tfca.mhsa.Wqkv.weight.grad, - fca.mhsa.Wqkv.weight.grad) + assert allclose_helper(tmhsa.out_proj.bias.grad, mmhsa.out_proj.bias.grad) + assert allclose_helper(tmhsa.out_proj.weight.grad, + mmhsa.out_proj.weight.grad) + assert allclose_helper(tmhsa.in_proj_bias.grad, mmhsa.Wqkv.bias.grad) + assert allclose_helper(tmhsa.in_proj_weight.grad, mmhsa.Wqkv.weight.grad) assert allclose_helper(x0.grad, x1.grad) diff --git a/examples/llm/tests/test_hf_v_mosaic_gpt.py b/examples/llm/tests/test_hf_v_mosaic_gpt.py index 446cd3f63..07a904f13 100644 --- a/examples/llm/tests/test_hf_v_mosaic_gpt.py +++ b/examples/llm/tests/test_hf_v_mosaic_gpt.py @@ -180,15 +180,8 @@ def 
test_compare_hf_v_mosaic_gpt(attn_impl, dropout, strict, alibi, mask_val, '.mlp.c_fc.': '.mlp.mlp_up.', '.mlp.c_proj.': '.mlp.mlp_down.', } - if attn_impl == 'torch': - hf_2_mosaic_key_mods[ - '.attn.c_attn.weight'] = '.causal_attn.mhsa.in_proj_weight' - hf_2_mosaic_key_mods[ - '.attn.c_attn.bias'] = '.causal_attn.mhsa.in_proj_bias' - hf_2_mosaic_key_mods['.attn.c_proj.'] = '.causal_attn.mhsa.out_proj.' - else: - hf_2_mosaic_key_mods['.attn.c_attn.'] = '.causal_attn.mhsa.Wqkv.' - hf_2_mosaic_key_mods['.attn.c_proj.'] = '.causal_attn.mhsa.out_proj.' + hf_2_mosaic_key_mods['.attn.c_attn.'] = '.attn.Wqkv.' + hf_2_mosaic_key_mods['.attn.c_proj.'] = '.attn.out_proj.' # convert hf gpt statedict to mosaic gpt statedict using the dict and list above _hf_model_statedict = {} diff --git a/examples/llm/tests/test_model.py b/examples/llm/tests/test_model.py index 248b2a93b..02c159872 100644 --- a/examples/llm/tests/test_model.py +++ b/examples/llm/tests/test_model.py @@ -157,18 +157,20 @@ def test_attention_mechanism(batch_size=2): axis=1) expected_zerod_weights |= torch_key_padding - attn_mask = model.model._attn_mask(batch_size=batch_size, + attn_bias = model.model._attn_bias(batch_size=batch_size, seq_len=test_cfg.max_seq_len, key_padding_mask=key_padding_mask) for block in model.model.transformer.blocks: a = block.ln_1(x) - b, attention_weights = block.causal_attn(a, - key_padding_mask, - attn_mask=attn_mask) + b, attention_weights = block.attn(a, + attn_bias=attn_bias, + key_padding_mask=key_padding_mask, + needs_weights=True) zerod_weights = (attention_weights == 0) - assert torch.equal(expected_zerod_weights, zerod_weights) + assert torch.equal(expected_zerod_weights.expand(*zerod_weights.shape), + zerod_weights) x = x + block.resid_attn_dropout(b) m = block.ln_2(x) n = block.mlp(m) diff --git a/examples/llm/tests/test_training.py b/examples/llm/tests/test_training.py index f55157c2c..3704dbfd8 100644 --- a/examples/llm/tests/test_training.py +++ b/examples/llm/tests/test_training.py @@ -59,6 +59,7 @@ def test_train(device, logit_scale): ) test_cfg = gpt_tiny_cfg(conf_path='yamls/mosaic_gpt/125m.yaml') + test_cfg.eval_subset_num_batches = 2 if logit_scale: test_cfg.model.logit_scale = logit_scale
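
Example usage of the refactored attention API: a minimal sketch mirroring the new tests in test_flash_triton_torch.py. It assumes `attn_impl: torch` with alibi enabled so it runs on CPU without the flash-attn or triton kernels; the config fields are taken from the test file.

import torch
from omegaconf import OmegaConf as om

from examples.llm.src.models.layers import attention

cfg = om.create({
    'attn_impl': 'torch',
    'd_model': 256,
    'n_heads': 2,
    'attn_pdrop': 0,
    'attn_clip_qkv': False,
    'attn_qk_ln': False,
})
n, s, alibi = 2, 16, True

attn = attention.MultiheadAttention(cfg)

# Allocate the additive bias once, fill it in place (alibi), then slice it
# per forward pass with generate_attn_bias.
attn_bias = None
bias_shape = attention.attn_bias_shape(cfg.attn_impl, cfg.n_heads, s, alibi)
if bias_shape is not None:
    attn_bias = torch.empty(*bias_shape)
    attention.attn_bias_(cfg.attn_impl,
                         attn_bias,
                         cfg.n_heads,
                         s,
                         alibi=alibi,
                         alibi_bias_max=8)

key_padding_mask = torch.ones(n, s, dtype=torch.bool)
attn_bias = attention.generate_attn_bias(cfg.attn_impl,
                                         attn_bias,
                                         s,
                                         n,
                                         key_padding_mask=key_padding_mask)

x = torch.randn(n, s, cfg.d_model)
out, _ = attn(x, attn_bias=attn_bias, key_padding_mask=key_padding_mask)
print(out.shape)  # torch.Size([2, 16, 256])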
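
For `attn_impl: triton`, the kernel consumes only `attn_bias` (the repo's triton FlashAttention raises if `key_padding_mask` actually contains padding), so padding must be folded into the bias. A minimal sketch of what `generate_attn_bias` is meant to do for that path; note that `masked_fill` is out-of-place, so its result has to be kept.

import torch

batch_size, seq_len = 2, 4
# attn_bias_shape('triton', n_heads, seq_len, alibi=False) -> (1, 1, 1, seq_len)
attn_bias = torch.zeros(1, 1, 1, seq_len)
key_padding_mask = torch.tensor([
    [True, True, True, False],   # last key position is padding
    [True, True, False, False],  # last two key positions are padding
])

attn_bias = attn_bias.expand(batch_size, -1, -1, -1)
attn_bias = attn_bias.masked_fill(
    ~key_padding_mask.view(batch_size, 1, 1, seq_len), float('-inf'))
print(attn_bias[:, 0, 0])
# tensor([[0., 0., 0., -inf],
#         [0., 0., -inf, -inf]])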
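
For reference, the `alibi_bias` helper called by `attn_bias_` is not changed by this patch. The standard ALiBi bias (Press et al.) in the `(1, n_heads, 1, seq_len)` layout returned by `attn_bias_shape` can be sketched as below; this is an illustration of the formulation, not a copy of the repo's helper.

import torch


def alibi_bias_sketch(n_heads, seq_len, alibi_bias_max=8, device=None):
    # Per-head slopes: 2**(-alibi_bias_max * h / n_heads) for h = 1..n_heads,
    # so later heads decay faster with distance.
    m = torch.arange(1, n_heads + 1, dtype=torch.float32, device=device)
    slopes = (2.0**(-alibi_bias_max * m / n_heads)).view(1, n_heads, 1, 1)
    # Relative key positions: the most recent key gets bias 0, earlier keys
    # get increasingly negative bias.
    positions = torch.arange(1 - seq_len, 1, dtype=torch.float32,
                             device=device).view(1, 1, 1, seq_len)
    return positions * slopes  # shape: (1, n_heads, 1, seq_len)


print(alibi_bias_sketch(2, 4)[0, 0, 0])
# head 0 with alibi_bias_max=8: tensor([-0.1875, -0.1250, -0.0625, 0.0000])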