Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GPT2 should not store/compute cached activations during finetuning #2356

Closed
wants to merge 2 commits into from

Conversation

thomwolf
Copy link
Member

This PR tries to fix the issue with large memory usage from GPT2 during fine-tuning.

Quick estimations

@LysandreJik compared memory usage with @minimaxir GPT2-simple (https://github.com/minimaxir/gpt-2-simple):

Small model, batch size 4, sequence length 512 (roughly similar):

  • us => 9.9GB,
  • GPT2-simple => 8.5GB

Increasing to a 1024 length:

  • us => 20.4GB...,
  • GPT2-simple => still 8.5GB

Medium model, batch size de 4, sequence length de 512

  • us => 23.36GB. OOM on a titan with 1024 seq len.
  • GPT2-simple throws an error related to layers not contained in the checkpoint

Possible reason

Investigating our run_lm_finetuning script and GPT2 model showed that we are alway computing/storing cached hidden-states (which are normally only useful for decoding).

This PR attempt to fix this most probable source of large memory usage.
It cleans up a little bit GPT2 codebase at the same time.

I haven't tried it yet on a large scale test.

cc @LysandreJik @arnicas

@minimaxir
Copy link

Not sure which size of GPT-2 you're testing with, but the 355M version utilizes gradient checkpointing for finetuning in gpt-2-simple, which is not the case with the 124M version w/ Transformers.

That might be a useful test case.

@arnicas
Copy link

arnicas commented Dec 29, 2019

I just tried this with gpt-2-medium on my poetry dataset and have the same memory error as before. Complete info below:

python run_lm_finetuning.py --output_dir=output --model_type=gpt2 --model_name_or_path=gpt2-medium --do_train --train_data_file=all_gen_lines.txt --per_gpu_train_batch_size=1
12/29/2019 17:48:47 - WARNING - __main__ -   Process rank: -1, device: cuda, n_gpu: 2, distributed training: False, 16-bits training: False
12/29/2019 17:48:47 - INFO - transformers.configuration_utils -   loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json from cache at /home/jupyter/.cache/torch/transformers/98aa65385e18b0efd17acd8bf64dcdf21406bb0c99c801c2d3c9f6bfd1f48f29.5f9150c569dadadaa1e66830d29254aa5cf43f8ccd76dc0c81e0102c67032367
12/29/2019 17:48:47 - INFO - transformers.configuration_utils -   Model config {
  "attn_pdrop": 0.1,
  "embd_pdrop": 0.1,
  "finetuning_task": null,
  "initializer_range": 0.02,
  "is_decoder": false,
  "layer_norm_epsilon": 1e-05,
  "n_ctx": 1024,
  "n_embd": 1024,
  "n_head": 16,
  "n_layer": 24,
  "n_positions": 1024,
  "n_special": 0,
  "num_labels": 1,
  "output_attentions": false,
  "output_hidden_states": false,
  "output_past": true,
  "predict_special_tokens": true,
  "pruned_heads": {},
  "resid_pdrop": 0.1,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "torchscript": false,
  "use_bfloat16": false,
  "vocab_size": 50257
}

12/29/2019 17:48:48 - INFO - transformers.tokenization_utils -   loading file https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json from cache at /home/jupyter/.cache/torch/transformers/f20f05d3ae37c4e3cd56764d48e566ea5adeba153dcee6eb82a18822c9c731ec.1512018be4ba4e8726e41b9145129dc30651ea4fec86aa61f4b9f40bf94eac71
12/29/2019 17:48:48 - INFO - transformers.tokenization_utils -   loading file https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt from cache at /home/jupyter/.cache/torch/transformers/6d882670c55563617571fe0c97df88626fb5033927b40fc18a8acf98dafd4946.70bec105b4158ed9a1747fea67a43f5dee97855c64d62b6ec3742f4cfdb5feda
12/29/2019 17:48:48 - INFO - transformers.modeling_utils -   loading weights file https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin from cache at /home/jupyter/.cache/torch/transformers/4b337a4f3b7d3e1518f799e238af607498c02938a3390152aaec7d4dabca5a02.8769029be4f66a5ae1055eefdd1d11621b901d510654266b8681719fff492d6e
12/29/2019 17:49:02 - INFO - __main__ -   Training/evaluation parameters Namespace(adam_epsilon=1e-08, block_size=1024, cache_dir='', config_name='', device=device(type='cuda'), do_eval=False, do_lower_case=False, do_train=True, eval_all_checkpoints=False, eval_data_file=None, evaluate_during_training=False, fp16=False, fp16_opt_level='O1', gradient_accumulation_steps=1, learning_rate=5e-05, local_rank=-1, logging_steps=50, max_grad_norm=1.0, max_steps=-1, mlm=False, mlm_probability=0.15, model_name_or_path='gpt2-medium', model_type='gpt2', n_gpu=2, no_cuda=False, num_train_epochs=1.0, output_dir='output', overwrite_cache=False, overwrite_output_dir=False, per_gpu_eval_batch_size=4, per_gpu_train_batch_size=1, save_steps=50, save_total_limit=None, seed=42, server_ip='', server_port='', tokenizer_name='', train_data_file='all_gen_lines.txt', warmup_steps=0, weight_decay=0.0)
12/29/2019 17:49:02 - INFO - __main__ -   Loading features from cached file gpt2-medium_cached_lm_1024_all_gen_lines.txt.bin
12/29/2019 17:49:02 - INFO - __main__ -   ***** Running training *****
12/29/2019 17:49:02 - INFO - __main__ -     Num examples = 2061
12/29/2019 17:49:02 - INFO - __main__ -     Num Epochs = 1
12/29/2019 17:49:02 - INFO - __main__ -     Instantaneous batch size per GPU = 1
12/29/2019 17:49:02 - INFO - __main__ -     Total train batch size (w. parallel, distributed & accumulation) = 2
12/29/2019 17:49:02 - INFO - __main__ -     Gradient Accumulation steps = 1
12/29/2019 17:49:02 - INFO - __main__ -     Total optimization steps = 1031
Epoch:   0%|                                                                                | 0/1 [00:00<?, ?it/s/home/jupyter/miniconda3/lib/python3.7/site-packages/torch/nn/parallel/_functions.py:61: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.
  warnings.warn('Was asked to gather along dimension 0, but all '
                                                                                                                 Traceback (most recent call last):                                              | 1/1031 [00:05<1:30:32,  5.27s/it]
  File "run_lm_finetuning.py", line 717, in <module>
    main()
  File "run_lm_finetuning.py", line 667, in main
    global_step, tr_loss = train(args, train_dataset, model, tokenizer)
  File "run_lm_finetuning.py", line 298, in train
    outputs = model(inputs, masked_lm_labels=labels) if args.mlm else model(inputs, labels=labels)
  File "/home/jupyter/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/jupyter/miniconda3/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 152, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/home/jupyter/miniconda3/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 162, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/home/jupyter/miniconda3/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 85, in parallel_apply
    output.reraise()
  File "/home/jupyter/miniconda3/lib/python3.7/site-packages/torch/_utils.py", line 385, in reraise
    raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/jupyter/miniconda3/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 60, in _worker
    output = module(*input, **kwargs)
  File "/home/jupyter/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/jupyter/miniconda3/lib/python3.7/site-packages/transformers/modeling_gpt2.py", line 549, in forward
    inputs_embeds=inputs_embeds)
  File "/home/jupyter/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/jupyter/miniconda3/lib/python3.7/site-packages/transformers/modeling_gpt2.py", line 460, in forward
    head_mask=head_mask[i])
  File "/home/jupyter/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/jupyter/miniconda3/lib/python3.7/site-packages/transformers/modeling_gpt2.py", line 236, in forward
    m = self.mlp(self.ln_2(x))
  File "/home/jupyter/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/jupyter/miniconda3/lib/python3.7/site-packages/transformers/modeling_gpt2.py", line 214, in forward
    h = self.act(self.c_fc(x))
  File "/home/jupyter/miniconda3/lib/python3.7/site-packages/transformers/modeling_gpt2.py", line 100, in gelu
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
RuntimeError: CUDA out of memory. Tried to allocate 16.00 MiB (GPU 0; 11.17 GiB total capacity; 10.77 GiB already allocated; 14.06 MiB free; 66.92 MiB cached)

Epoch:   0%|                                                                                | 0/1 [00:05<?, ?it/s]
Iteration:   0%|                                                               | 1/1031 [00:05<1:41:31,  5.91s/it]
(base) jupyter@lynn-ukpavilion:~/code/transformers/examples$ nvidia-smi
Sun Dec 29 17:49:34 2019
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 418.67       Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   49C    P0    71W / 149W |      0MiB / 11441MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   1  Tesla K80           Off  | 00000000:00:05.0 Off |                    0 |
| N/A   67C    P0    88W / 149W |      0MiB / 11441MiB |     94%      Default |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
(base) jupyter@lynn-ukpavilion:~/code/transformers/examples$
(base) jupyter@lynn-ukpavilion:~/code/transformers/examples$ git status
On branch fix-gpt2-finetuning-memory

@stale
Copy link

stale bot commented Mar 4, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

@stale stale bot added the wontfix label Mar 4, 2020
@stale stale bot closed this Mar 11, 2020
@rakeshchada
Copy link
Contributor

@thomwolf @LysandreJik : Curious about the status of this. It seems like the memory issues still exists with "run_lm_finetuning.py" and GPT-2. For instance, even a batch size of 1 doesn't help prevent OOM error when fine-tuning GPT-2 large with a sequence length of 1024 (despite using FP-16). Is there anything we could do here (apart from gradient checkpointing) that would make the memory usage lower as Thomas listed in his first comment above? Thanks.

@LysandreJik LysandreJik deleted the fix-gpt2-finetuning-memory branch April 27, 2022 15:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants