diff --git a/examples/llm/src/models/mosaic_gpt.py b/examples/llm/src/models/mosaic_gpt.py
index 34e295f87..e328e18d4 100644
--- a/examples/llm/src/models/mosaic_gpt.py
+++ b/examples/llm/src/models/mosaic_gpt.py
@@ -6,6 +6,7 @@
 Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
 """
 
+import math
 import warnings
 from typing import Optional
 
@@ -92,6 +93,20 @@ def __init__(self, cfg: DictConfig):
         self.transformer.update(
             {'ln_f': nn.LayerNorm(cfg.d_model, device=cfg.init_device)})
 
+        # enables scaling output logits; similar to a softmax "temperature"
+        # PaLM paper uses scale 1/sqrt(cfg.d_model)
+        self.logit_scale = None
+        if self.cfg.get('logit_scale') is not None:
+            logit_scale = self.cfg.get('logit_scale')
+            if isinstance(logit_scale, str):
+                if logit_scale == 'inv_sqrt_d_model':
+                    logit_scale = 1 / math.sqrt(self.cfg.d_model)
+                else:
+                    raise ValueError(
+                        f"{logit_scale=} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
+                    )
+            self.logit_scale = logit_scale
+
         if cfg.init_device != 'meta':
             print(
                 f'You are using {cfg.init_device=}, but you can also use cfg.init_device="meta" with Composer + FSDP for fast initialization.'
@@ -188,6 +203,14 @@ def forward(self,
         assert isinstance(self.transformer.wte, nn.Module)  # pyright
         assert isinstance(self.transformer.wte.weight, torch.Tensor)  # pyright
         logits = F.linear(x, self.transformer.wte.weight, None)
+
+        if self.logit_scale is not None:
+            if self.logit_scale == 0:
+                warnings.warn(
+                    f'Multiplying logits by {self.logit_scale=}. This will produce uniform (uninformative) outputs.'
+                )
+            logits *= self.logit_scale
+
         return logits
 
     # Param Initialization, needed for device='meta' fast initialization
diff --git a/examples/llm/tests/test_training.py b/examples/llm/tests/test_training.py
index 70fb09b3c..f55157c2c 100644
--- a/examples/llm/tests/test_training.py
+++ b/examples/llm/tests/test_training.py
@@ -46,7 +46,8 @@ def gpt_tiny_cfg(conf_path='yamls/mosaic_gpt/125m.yaml'):
                      not torch.cuda.is_available(),
                      reason='testing with cuda requires GPU')),
 ])
-def test_train(device):
+@pytest.mark.parametrize('logit_scale', [None, 0.036, 'inv_sqrt_d_model'])
+def test_train(device, logit_scale):
     if not os.path.isdir('./my-copy-c4/val'):
         pytest.xfail('c4 dataset not set up as expected')
 
@@ -58,6 +59,8 @@ def test_train(device):
     )
 
     test_cfg = gpt_tiny_cfg(conf_path='yamls/mosaic_gpt/125m.yaml')
+    if logit_scale:
+        test_cfg.model.logit_scale = logit_scale
 
     if device == 'cpu':
         pytest.xfail(