Add Fused CrossEntropy and Update Vocab Sizes #251

Merged · 18 commits · Mar 24, 2023
6 changes: 4 additions & 2 deletions examples/llm/README.md
@@ -209,8 +209,10 @@ because more memory will enable you to use larger microbatch sizes.
 
 # Optimizing Performance
 The YAMLs in this repo are relatively well tuned for medium-to-large NVIDIA A100-40GB clusters.
-On different devices with more / less GPU memory,
-you may wish to edit the `device_train_microbatch_size` or `fsdp_config` values.
+
+If you are running with a CUDA-compatible GPU and have installed the LLM requirements, we enable a kernel fusion optimization for the Cross Entropy loss function at the end of the model by default. This should not affect model convergence, but if you would like to disable it, set `model.loss_fn=torch_crossentropy`. To re-enable it, set `model.loss_fn=fused_crossentropy` or omit the key from your YAML.
+
+On devices with more / less GPU memory, you may wish to edit the `device_train_microbatch_size` or `fsdp_config` values.
 In general, larger microbatch sizes and disabling `activation_checkpointing` lead to higher throughput.
 
 Note that each YAML specifies a `global_train_batch_size`, which is an optimization choice, i.e. the **math** being performed,
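For reference, here is a minimal sketch (not repo code; it assumes only the repo's existing OmegaConf dependency, and the key name comes from the README text above) of how a `model.loss_fn` override resolves, including the fused default when the key is omitted:

```python
# A minimal sketch, assuming OmegaConf: resolving the loss_fn override.
from omegaconf import OmegaConf

cfg = OmegaConf.create("""
model:
  loss_fn: torch_crossentropy  # or fused_crossentropy; omit the key to get the fused default
""")
assert cfg.model.get('loss_fn', 'fused_crossentropy') == 'torch_crossentropy'
```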
2 changes: 1 addition & 1 deletion examples/llm/mcloud/mcli-1b-eval.yaml
@@ -37,7 +37,7 @@ parameters:
   n_layers: 24
   mlp_ratio: 4
   max_seq_len: 2048
-  vocab_size: 50257
+  vocab_size: 50368
   init_std: 0.02
   attn_pdrop: 0.0
   resid_pdrop: 0.0
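A note on the recurring `vocab_size` change in this PR: 50257 is the GPT-2 tokenizer's vocabulary size, while the new value 50368 = 64 × 787 is divisible by 64. The diff does not state the motivation, but padding the vocab dimension to a multiple of 64 is a common throughput optimization for GPU matmuls and fused kernels; a quick arithmetic check:

```python
# Sanity check on the two values (the padding rationale above is an assumption,
# not stated in the PR):
old_vocab, new_vocab = 50257, 50368
assert old_vocab % 64 != 0  # GPT-2's raw vocab size is not 64-aligned
assert new_vocab % 64 == 0  # 50368 = 64 * 787
```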
2 changes: 1 addition & 1 deletion examples/llm/mcloud/mcli-1b-max-seq-len-8k.yaml
@@ -45,7 +45,7 @@ parameters:
   n_layers: 24
   mlp_ratio: 4
   max_seq_len: ${max_seq_len}
-  vocab_size: 50257
+  vocab_size: 50368
   init_std: 0.02
   attn_pdrop: 0.0
   resid_pdrop: 0.0
1 change: 1 addition & 0 deletions examples/llm/requirements.txt
@@ -11,3 +11,4 @@ omegaconf==2.2.3
 wandb==0.13.6
 pytest>=7.2.1,<8
 torchmetrics==0.11.3
+xentropy-cuda-lib@git+https://github.com/HazyResearch/flash-attention.git@v0.2.8#subdirectory=csrc/xentropy
@@ -18,7 +18,7 @@ def __init__(
         n_layers: int = 24,
         mlp_ratio: int = 4,
         max_seq_len: int = 2048,
-        vocab_size: int = 50257,
+        vocab_size: int = 50368,
         init_std: float = 0.02,
         attn_pdrop: float = 0.0,
         resid_pdrop: float = 0.0,
@@ -123,6 +123,8 @@ def __init__(
         self.use_cache = use_cache
         if 'name' in kwargs:
             del kwargs['name']
+        if 'loss_fn' in kwargs:
+            del kwargs['loss_fn']
         super().__init__(**kwargs)
 
         self._validate_config()
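The two new `del` statements above strip composer-side options before the remaining kwargs are forwarded to the parent config class's `__init__`, which would not recognize them. An equivalent, slightly more compact pattern (a sketch with hypothetical values, not the repo's code) is:

```python
# Hypothetical equivalent of the `if ... in kwargs: del` pairs above.
kwargs = {'name': 'mosaic_gpt', 'loss_fn': 'fused_crossentropy', 'use_cache': False}
for key in ('name', 'loss_fn'):
    kwargs.pop(key, None)  # drop keys the parent __init__ would reject
assert kwargs == {'use_cache': False}
```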
21 changes: 18 additions & 3 deletions examples/llm/src/models/mosaic_gpt/mosaic_gpt.py
@@ -403,6 +403,22 @@ def __init__(self, om_model_config: DictConfig):
             'Perplexity':
                 Perplexity(),
         }
+        loss_fn_config = om_model_config.get('loss_fn', 'fused_crossentropy')
+        if loss_fn_config == 'fused_crossentropy':
+            try:
+                from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss  # type: ignore # isort: skip
+            except ImportError as e:
+                raise ValueError(
+                    'Fused Cross Entropy is not installed. Either (1) have a CUDA-compatible GPU and `pip install .[llm]`, or (2) set your config model.loss_fn=torch_crossentropy.'
+                ) from e
+            warnings.warn('Using Fused Cross Entropy Loss.')
+            self.loss_fn = FusedCrossEntropyLoss(ignore_index=-100)
+        elif loss_fn_config == 'torch_crossentropy':
+            self.loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
+        else:
+            raise ValueError(
+                f'Specified loss_fn={loss_fn_config} not recognized. `loss_fn` must be one of [`fused_crossentropy`, `torch_crossentropy`].'
+            )
 
     def get_targets(self, batch):
         targets = torch.roll(batch['labels'], shifts=-1)
@@ -427,9 +443,8 @@ def eval_forward(self, batch, outputs=None):
 
     def loss(self, outputs, batch):
         targets = self.get_targets(batch)
-        return F.cross_entropy(outputs.view(-1, outputs.size(-1)),
-                               targets.view(-1),
-                               ignore_index=-100)
+        return self.loss_fn(outputs.view(-1, outputs.size(-1)),
+                            targets.view(-1))
 
     def get_metrics(self, is_train=False):
         return self.train_metrics if is_train else self.eval_metrics
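Both loss implementations expect 2-D logits of shape `(N, C)` and 1-D targets of shape `(N,)`, which is why `loss` flattens the `(batch, seq, vocab)` outputs. A small, self-contained shape check (values are illustrative only):

```python
# Shape sketch for the flattening in `loss` above.
import torch
import torch.nn as nn

batch, seq, vocab = 2, 8, 50368
logits = torch.randn(batch, seq, vocab)
targets = torch.randint(vocab, (batch, seq))
targets[0, :3] = -100  # positions masked out of the loss

loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
loss = loss_fn(logits.view(-1, vocab), targets.view(-1))  # scalar mean over unmasked tokens
assert loss.ndim == 0
```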
56 changes: 56 additions & 0 deletions examples/llm/tests/test_model.py
@@ -323,6 +323,62 @@ def test_determinism(attention_type: str, precision):
     optimizer_2.step()
 
 
+@pytest.mark.gpu
+def test_loss_fn():
+    """Tests the Fused CrossEntropy vs torch.nn.CrossEntropyLoss loss function.
+
+    We provide non-zero tolerances to account for small numerical differences
+    between the two loss implementations.
+    """
+    from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss  # type: ignore # isort: skip
+
+    reproducibility.seed_all(1111)
+
+    conf_path = 'yamls/mosaic_gpt/testing.yaml'
+    with open(conf_path) as f:
+        test_cfg = om.load(f)
+
+    test_cfg.model.init_device = 'cuda:0'
+    test_cfg.device = 'cuda:0'
+
+    model_1 = COMPOSER_MODEL_REGISTRY[test_cfg.model.name](test_cfg.model)
+    model_2 = copy.deepcopy(model_1)
+    assert isinstance(model_1.loss_fn, torch.nn.CrossEntropyLoss)
+    model_2.loss_fn = FusedCrossEntropyLoss(ignore_index=-100)
+
+    optimizer_1 = DecoupledAdamW(model_1.parameters(),
+                                 lr=test_cfg.optimizer.lr,
+                                 betas=test_cfg.optimizer.betas,
+                                 eps=test_cfg.optimizer.eps,
+                                 weight_decay=test_cfg.optimizer.weight_decay)
+    optimizer_2 = DecoupledAdamW(model_2.parameters(),
+                                 lr=test_cfg.optimizer.lr,
+                                 betas=test_cfg.optimizer.betas,
+                                 eps=test_cfg.optimizer.eps,
+                                 weight_decay=test_cfg.optimizer.weight_decay)
+
+    for i in range(25):
+        batch = gen_random_batch(2, test_cfg)
+        output_1 = model_1(batch)
+        output_2 = model_2(batch)
+        assert output_1.allclose(output_2, rtol=1e-4,
+                                 atol=1e-4), f'differed at step {i}'
+
+        loss_1 = model_1.loss(output_1, batch)
+        loss_2 = model_2.loss(output_2, batch)
+        assert loss_1.allclose(loss_2, rtol=1e-3,
+                               atol=1e-3), f'differed at step {i}'
+        loss_1.backward()
+        loss_2.backward()
+        optimizer_1.step()
+        optimizer_2.step()
+
+    for p1, p2 in zip(model_1.parameters(), model_2.parameters()):
+        assert p1.data.shape == p2.data.shape
+        assert p1.data.allclose(p2.data, rtol=1e-5,
+                                atol=1e-4), f'differed at step {i}'
 
 
 @pytest.mark.parametrize('prefixlm', [False, True])
 def test_opt_wrapping(prefixlm):
     conf = {
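On the tolerances used above: `allclose(a, b, rtol, atol)` passes when |a − b| ≤ atol + rtol·|b| elementwise, so the loss check (rtol=1e-3, atol=1e-3) allows roughly an order of magnitude more drift than the logits check. A minimal illustration of that semantics:

```python
# torch.allclose semantics behind the test's assertions.
import torch

a, b = torch.tensor([1.00015]), torch.tensor([1.0])
assert torch.allclose(a, b, rtol=1e-4, atol=1e-4)      # |diff| ~ 1.5e-4 <= 1e-4 + 1e-4 * 1.0
assert not torch.allclose(a, b, rtol=1e-5, atol=1e-4)  # 1.5e-4 > 1e-4 + 1e-5 * 1.0
```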
2 changes: 1 addition & 1 deletion examples/llm/yamls/mosaic_gpt/125m.yaml
@@ -17,7 +17,7 @@ model:
   n_layers: 12
   mlp_ratio: 4
   max_seq_len: ${max_seq_len}
-  vocab_size: 50257
+  vocab_size: 50368
   init_std: 0.02
   attn_pdrop: 0.0
   resid_pdrop: 0.0
2 changes: 1 addition & 1 deletion examples/llm/yamls/mosaic_gpt/13b.yaml
@@ -17,7 +17,7 @@ model:
   n_layers: 40
   mlp_ratio: 4
   max_seq_len: ${max_seq_len}
-  vocab_size: 50257
+  vocab_size: 50368
   init_std: 0.02
   attn_pdrop: 0.0
   resid_pdrop: 0.0
2 changes: 1 addition & 1 deletion examples/llm/yamls/mosaic_gpt/1b.yaml
@@ -17,7 +17,7 @@ model:
   n_layers: 24
   mlp_ratio: 4
   max_seq_len: ${max_seq_len}
-  vocab_size: 50257
+  vocab_size: 50368
   init_std: 0.02
   attn_pdrop: 0.0
   resid_pdrop: 0.0
2 changes: 1 addition & 1 deletion examples/llm/yamls/mosaic_gpt/30b.yaml
@@ -17,7 +17,7 @@ model:
   n_layers: 48
   mlp_ratio: 4
   max_seq_len: ${max_seq_len}
-  vocab_size: 50257
+  vocab_size: 50368
   init_std: 0.02
   attn_pdrop: 0.0
   resid_pdrop: 0.0
2 changes: 1 addition & 1 deletion examples/llm/yamls/mosaic_gpt/350m.yaml
@@ -17,7 +17,7 @@ model:
   n_layers: 24
   mlp_ratio: 4
   max_seq_len: ${max_seq_len}
-  vocab_size: 50257
+  vocab_size: 50368
   init_std: 0.02
   attn_pdrop: 0.0
   resid_pdrop: 0.0
2 changes: 1 addition & 1 deletion examples/llm/yamls/mosaic_gpt/3b.yaml
@@ -17,7 +17,7 @@ model:
   n_layers: 32
   mlp_ratio: 4
   max_seq_len: ${max_seq_len}
-  vocab_size: 50257
+  vocab_size: 50368
   init_std: 0.02
   attn_pdrop: 0.0
   resid_pdrop: 0.0
2 changes: 1 addition & 1 deletion examples/llm/yamls/mosaic_gpt/70b.yaml
@@ -17,7 +17,7 @@ model:
   n_layers: 80
   mlp_ratio: 4
   max_seq_len: ${max_seq_len}
-  vocab_size: 50257
+  vocab_size: 50368
   init_std: 0.02
   attn_pdrop: 0.0
   resid_pdrop: 0.0
2 changes: 1 addition & 1 deletion examples/llm/yamls/mosaic_gpt/760m.yaml
@@ -17,7 +17,7 @@ model:
   n_layers: 24
   mlp_ratio: 4
   max_seq_len: ${max_seq_len}
-  vocab_size: 50257
+  vocab_size: 50368
   init_std: 0.02
   attn_pdrop: 0.0
   resid_pdrop: 0.0
2 changes: 1 addition & 1 deletion examples/llm/yamls/mosaic_gpt/7b.yaml
@@ -17,7 +17,7 @@ model:
   n_layers: 32
   mlp_ratio: 4
   max_seq_len: ${max_seq_len}
-  vocab_size: 50257
+  vocab_size: 50368
   init_std: 0.02
   attn_pdrop: 0.0
   resid_pdrop: 0.0
3 changes: 2 additions & 1 deletion examples/llm/yamls/mosaic_gpt/testing.yaml
@@ -17,12 +17,13 @@ model:
   n_layers: 2
   mlp_ratio: 4
   max_seq_len: ${max_seq_len}
-  vocab_size: 50257
+  vocab_size: 50368
   init_std: 0.02
   attn_pdrop: 0.0
   resid_pdrop: 0.0
   emb_pdrop: 0.0
   attn_impl: torch
+  loss_fn: torch_crossentropy
 
 # Tokenizer
 tokenizer:
3 changes: 2 additions & 1 deletion setup.py
@@ -56,7 +56,7 @@ def _dependencies_as_dict(deps: List[str]) -> Dict[str, str]:
     """map, e.g., 'foo>=1.5,<1.6' -> {'foo': '>=1.5,<1.6'}"""
     ret = {}
     for dep in deps:
-        elems = re.split('([=><])', dep.strip())
+        elems = re.split('([=><@])', dep.strip())
         ret[elems[0]] = ''.join(elems[1:])
     return ret
 
@@ -73,6 +73,7 @@ def _merge_dependencies(deps_base: List[str],
     # a GPU on your machine
     base_dict.pop('flash-attn', None)
     base_dict.pop('triton', None)
+    base_dict.pop('xentropy-cuda-lib', None)
     return [k + v for k, v in base_dict.items()]  # 'foo': '>3' -> 'foo>3'
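Why the regex gained `@`: a direct-URL requirement like the new `xentropy-cuda-lib` line has no `=`, `>`, or `<` immediately after the package name, so the old pattern's first split landed inside the URL and the key stored in the dict was the name plus most of the URL, which meant the `pop('xentropy-cuda-lib', None)` in `_merge_dependencies` could never find it. A quick demonstration using a standalone copy of the helper:

```python
# Standalone copy of _dependencies_as_dict with the fixed pattern.
import re

def _dependencies_as_dict(deps):
    """map, e.g., 'foo>=1.5,<1.6' -> {'foo': '>=1.5,<1.6'}"""
    ret = {}
    for dep in deps:
        elems = re.split('([=><@])', dep.strip())
        ret[elems[0]] = ''.join(elems[1:])
    return ret

deps = [
    'torchmetrics==0.11.3',
    'xentropy-cuda-lib@git+https://github.com/HazyResearch/flash-attention.git@v0.2.8#subdirectory=csrc/xentropy',
]
d = _dependencies_as_dict(deps)
assert set(d) == {'torchmetrics', 'xentropy-cuda-lib'}  # name now splits off at the '@'
```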