
CPU/GPU memory benchmarking utilities - Remove support for python 3.5 (now only 3.6+) #3186

Merged: 12 commits merged into master from the memory-benchmark branch on Mar 17, 2020

Conversation

thomwolf (Member) commented Mar 9, 2020

This PR adds some utilities to benchmark the (RAM) memory consumption of the models.
This is actually a generic utility that can work with any arbitrary Python code.

Ex:

import torch
from transformers import GPT2Model, GPT2Tokenizer
from transformers import start_memory_tracing, stop_memory_tracing

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2Model.from_pretrained('gpt2')

sequence = tokenizer.encode("Hello how are you", return_tensors='pt')

# Line by line memory tracing (all code in the module `transformers`).
trace = start_memory_tracing(modules_to_trace="transformers")
output = model(sequence)
summary = stop_memory_tracing(trace)

# The summary contains three fields:
# `sequential`: list of line-by-line consumption (with line code and location)
# `cumulative`: list of cumulative line-by-line consumption (when lines are executed several times),
#     ordered from the most memory-consuming line to the least (also with line code and location)
# `total`: total memory consumption of the script (defaults to summing the memory increase at each line,
#     ignoring released memory; can be set to count both increases and releases, but that is less reliable on Ubuntu)
# Each `Memory` object contains CPU, GPU and CPU + GPU memory, each as both an int and a human-readable string

print(f"Total memory consumption: {summary.total}")
top_line = summary.cumulative[0]
print(f"Consumed {top_line.cpu_gpu}: {top_line.frame.line_text} at {top_line.frame.filename}:{top_line.frame.line_number}")
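Continuing the snippet above, a minimal sketch that walks the full line-by-line trace, assuming `sequential` entries expose the same `frame` and `cpu_gpu` fields used for `cumulative`:

# Walk the line-by-line trace recorded above.
for state in summary.sequential:
    print(
        f"{state.frame.filename}:{state.frame.line_number}: "
        f"mem {state.cpu_gpu}: {state.frame.line_text}"
    )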

Incorporated in the ./examples/benchmarks.py script. Example of a command-line run:

(py37) bash-3.2$ python ./examples/benchmarks.py --models gpt2 --torch --batch_sizes 1 --slice_sizes 64 256 512 512 512 --no_speed --verbose
Running with arguments Namespace(amp=False, average_over=30, batch_sizes=[1], csv_filename=None, fp16=False, keras_predict=False, models=['gpt2'], no_memory=False, no_speed=True, save_to_csv=False, slice_sizes=[64, 256, 512, 512, 512], tensorflow=False, torch=True, torch_cuda=False, torchscript=False, verbose=False, xla=False)
1 / 1
Token indices sequence length is longer than the specified maximum sequence length for this model (2708 > 1024). Running this sequence through the model will result in indexing errors
....
/Users/thomwolf/Documents/GitHub/transformers/src/transformers/modeling_gpt2.py:487: mem 0.000B:                 presents = presents + (present,)
/Users/thomwolf/Documents/GitHub/transformers/src/transformers/modeling_gpt2.py:489: mem 0.000B:             if self.output_attentions:
/Users/thomwolf/Documents/GitHub/transformers/src/transformers/modeling_gpt2.py:477: mem 0.000B:         for i, (block, layer_past) in enumerate(zip(self.h, past)):
/Users/thomwolf/Documents/GitHub/transformers/src/transformers/modeling_gpt2.py:492: mem 0.000B:         hidden_states = self.ln_f(hidden_states)
/Users/thomwolf/Documents/GitHub/transformers/src/transformers/modeling_gpt2.py:494: mem 0.000B:         hidden_states = hidden_states.view(*output_shape)
/Users/thomwolf/Documents/GitHub/transformers/src/transformers/modeling_gpt2.py:496: mem 0.000B:         if self.output_hidden_states:
/Users/thomwolf/Documents/GitHub/transformers/src/transformers/modeling_gpt2.py:499: mem 0.000B:         outputs = (hidden_states,)
/Users/thomwolf/Documents/GitHub/transformers/src/transformers/modeling_gpt2.py:500: mem 0.000B:         if self.output_past:
/Users/thomwolf/Documents/GitHub/transformers/src/transformers/modeling_gpt2.py:501: mem 0.000B:             outputs = outputs + (presents,)
/Users/thomwolf/Documents/GitHub/transformers/src/transformers/modeling_gpt2.py:502: mem 0.000B:         if self.output_hidden_states:
/Users/thomwolf/Documents/GitHub/transformers/src/transformers/modeling_gpt2.py:504: mem 0.000B:         if self.output_attentions:
/Users/thomwolf/Documents/GitHub/transformers/src/transformers/modeling_gpt2.py:509: mem 0.000B:         return outputs  # last hidden state, (presents), (all hidden_states), (attentions)

Top 5 script lines consuming the most memory:
0 => /Users/thomwolf/Documents/GitHub/transformers/src/transformers/activations.py:31: mem 276.004MB:     return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
1 => /Users/thomwolf/Documents/GitHub/transformers/src/transformers/modeling_utils.py:1311: mem 151.520MB:         x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
2 => /Users/thomwolf/Documents/GitHub/transformers/src/transformers/modeling_gpt2.py:146: mem 146.004MB:         w = w * b - 1e4 * (1 - b)
3 => /Users/thomwolf/Documents/GitHub/transformers/src/transformers/modeling_gpt2.py:143: mem 132.004MB:             w = w / math.sqrt(v.size(-1))
4 => /Users/thomwolf/Documents/GitHub/transformers/src/transformers/modeling_gpt2.py:187: mem 36.000MB:         present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking
5 => /Users/thomwolf/Documents/GitHub/transformers/src/transformers/modeling_gpt2.py:159: mem 33.000MB:         outputs = [torch.matmul(w, v)]

Memory increase computed by summing traced script lines: 843.758MB
=========== RESULTS ===========
        ======= MODEL CHECKPOINT: gpt2 =======
                ===== BATCH SIZE: 1 =====
                gpt2/1/64: N/A 75.176MB
                gpt2/1/256: N/A 349.695MB
                gpt2/1/512: N/A 843.758MB
                gpt2/1/512: N/A 843.758MB
                gpt2/1/512: N/A 843.758MB

codecov-io commented Mar 12, 2020

Codecov Report

Merging #3186 into master will decrease coverage by 0.50%.
The diff coverage is 30.76%.


@@            Coverage Diff             @@
##           master    #3186      +/-   ##
==========================================
- Coverage   78.15%   77.64%   -0.51%     
==========================================
  Files          98       99       +1     
  Lines       16641    16795     +154     
==========================================
+ Hits        13006    13041      +35     
- Misses       3635     3754     +119     
Impacted Files Coverage Δ
src/transformers/modeling_utils.py 91.34% <15.38%> (-3.07%) ⬇️
src/transformers/benchmark_utils.py 31.74% <31.74%> (ø)
src/transformers/__init__.py 98.92% <100.00%> (+0.01%) ⬆️
src/transformers/configuration_gpt2.py 97.29% <100.00%> (+0.07%) ⬆️
src/transformers/modeling_gpt2.py 86.07% <100.00%> (ø)
src/transformers/modeling_tf_utils.py 91.53% <0.00%> (-2.17%) ⬇️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@thomwolf thomwolf changed the title [WIP] Better memory benchmarking for the models CPU and GPU memory benchmarking utilities Mar 12, 2020
@thomwolf thomwolf requested a review from LysandreJik March 12, 2020 22:02

Memory = namedtuple("Memory", ["bytes", "string"])
MemoryState = namedtuple("MemoryState", ["frame", "cpu", "gpu", "cpu_gpu"])
MemorySummary = namedtuple("MemorySummary", ["sequential", "cumulative", "total"])
Member:
Should we use the alternative, typed syntax for NamedTuples?

I like it a lot, it is slightly more powerful, and has the same Python requirements (unless we use type hints, which require 3.6): https://docs.python.org/3/library/typing.html#typing.NamedTuple

Member Author:
Oh yes that's a lot better!
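For illustration, a minimal sketch of the typed syntax under discussion; the field types are inferred from how these tuples are used in this thread and may differ from the merged code:

from typing import Any, List, NamedTuple


class Memory(NamedTuple):
    bytes: int   # raw number of bytes
    string: str  # human-readable form, e.g. "276.004MB"


class MemoryState(NamedTuple):
    frame: Any        # the Frame location tuple (filename, line number, line text)
    cpu: Memory
    gpu: Memory
    cpu_gpu: Memory


class MemorySummary(NamedTuple):
    sequential: List[MemoryState]
    cumulative: List[MemoryState]
    total: Memory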

Contributor:
FWIW, you could also use a pd.DataFrame and do the diff code with .diff() in about two lines. That would also make CSV saving cheap. I think the examples might already have a pandas dependency implicitly through pytorch-lightning.
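A rough sketch of that idea, assuming the tracer's raw per-line readings can be loaded as rows (the column names and byte values below are made up for the illustration):

import pandas as pd

# Hypothetical cumulative readings taken before each traced line executes.
readings = pd.DataFrame(
    [
        ("modeling_gpt2.py", 143, 100_000_000, 0),
        ("modeling_gpt2.py", 146, 238_000_000, 0),
        ("activations.py", 31, 391_000_000, 0),
    ],
    columns=["filename", "line_number", "cpu_bytes", "gpu_bytes"],
)

# Per-line consumption is the difference between consecutive readings.
readings["cpu_delta"] = readings["cpu_bytes"].diff().fillna(0)
readings["gpu_delta"] = readings["gpu_bytes"].diff().fillna(0)

# CSV export becomes a one-liner, as noted above.
readings.to_csv("memory_trace.csv", index=False)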

@@ -496,3 +498,253 @@ def _resumable_file_manager():
json.dump(meta, meta_file)

return cache_path

Member:
should we create a new benchmarking_utils.py file for this?

Member Author:
indeed!

@thomwolf thomwolf assigned thomwolf and unassigned thomwolf Mar 12, 2020
@thomwolf thomwolf requested a review from sshleifer March 12, 2020 23:01
sshleifer (Contributor) left a comment:
Love this. Excited to use!
Left some comments which might be out of scope for this PR

examples/benchmarks.py
average_time = sum(runtimes) / float(len(runtimes)) / 3.0
dictionary[model_name]["results"][batch_size][slice_size] = average_time
if not no_memory:
# model.add_memory_hooks() # Forward method tracing (only for PyTorch models)
Contributor:
(nit) delete this?

@@ -250,15 +257,21 @@

def create_setup_and_compute(
Contributor:
I definitely lack context, but I feel like I would want to use this with one model, one batch size, and one slice size. The fact that the code takes lists of models, lists of batch sizes, and lists of slice sizes adds a fair amount of complexity (e.g. results[model_name]["memory"][batch_size][slice_size] would just be results['memory']).

If the use case is comparing different runs, I guess the signature makes sense. Is that the reasoning?

Member Author:
Yes, that's the script we used for this blog post, for instance: https://medium.com/huggingface/benchmarking-transformers-pytorch-and-tensorflow-e2917fb891c2

I kept it as is and just added memory benchmarking in addition to speed (and a little more flexibility in the CLI args).
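For context, a rough illustration of the nested layout under discussion; the memory strings are copied from the sample run above, and the speed entries are "N/A" because --no_speed was passed:

# Rough illustration of the nested results structure, keyed as
# dictionary[model_name]["memory" or "results"][batch_size][slice_size].
results = {
    "gpt2": {
        "memory": {1: {64: "75.176MB", 256: "349.695MB", 512: "843.758MB"}},
        "results": {1: {64: "N/A", 256: "N/A", 512: "N/A"}},  # speed disabled in this run
    }
}
print(results["gpt2"]["memory"][1][512])  # -> 843.758MB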

)
)
print(
"\nLines with lowest memory consumption:\n"
Contributor:
is this useful output? are there negatives in cumulative?

Member Author:
There are (in particular on my Ubuntu machine, there are some memory releases during execution which make the output noisier).

Also, if you don't keep the output of the model, the last line is just a large release of memory that cancels everything you've allocated, haha.

)
)
print(
"\nLines with lowest memory consumption:\n"
Contributor:
could we share a def print_summary_statistics with _compute_pytorch?

Member Author:
Indeed
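A possible shape for such a shared helper, sketched from the output format shown in the PR description; the name comes from the suggestion above and the exact signature is an assumption:

def print_summary_statistics(summary, top_k=5):
    # Shared between the PyTorch and TensorFlow paths: report the most and
    # least memory-hungry traced lines plus the total increase.
    print(f"\nTop {top_k} script lines consuming the most memory:")
    for i, state in enumerate(summary.cumulative[:top_k]):
        print(f"{i} => {state.frame.filename}:{state.frame.line_number}: "
              f"mem {state.cpu_gpu}: {state.frame.line_text}")
    print("\nLines with lowest memory consumption:")
    for i, state in enumerate(summary.cumulative[-top_k:]):
        print(f"{i} => {state.frame.filename}:{state.frame.line_number}: "
              f"mem {state.cpu_gpu}: {state.frame.line_text}")
    print(f"\nMemory increase computed by summing traced script lines: {summary.total}")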

dictionary[model_name]["memory"][batch_size][slice_size] = "N/A"

if not no_speed:
runtimes = timeit.repeat(lambda: inference(sequence), repeat=average_over, number=3)
Contributor:
(nit) number=3 feels like it should be exposed as a higher level kwarg like nruns=3

Member Author:
Yeah that's the previous code. We can keep it as is for now.
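For illustration, a minimal sketch of the suggested change, with nruns exposed as a keyword argument in place of the hard-coded number=3 (the helper name and signature here are hypothetical):

import timeit

def time_inference(inference, sequence, average_over=30, nruns=3):
    # timeit.repeat returns `average_over` measurements; each one is the total
    # time taken by `nruns` calls to inference(sequence).
    runtimes = timeit.repeat(lambda: inference(sequence), repeat=average_over, number=nruns)
    # Average per-call runtime, matching the existing `/ 3.0` division.
    return sum(runtimes) / float(len(runtimes)) / nruns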



@@ -66,6 +67,47 @@ def num_parameters(self, only_trainable: bool = False) -> int:
params = filter(lambda x: x.requires_grad, self.parameters()) if only_trainable else self.parameters()
return sum(p.numel() for p in params)

@staticmethod
def _hook_rss_memory_pre_forward(module, *args, **kwargs):
Contributor:
what does rss refer to in this context?

Member Author:
"resident set size" as in the psutil RSS memory info https://psutil.readthedocs.io/en/latest/#psutil.Process.memory_info

I didn't comment much these PyTorch hooks because I feel like the tracing methods are more general (can be used both in PT and TF) and they give the same results. I kept them for now in the codebase though.
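For reference, a minimal sketch of what an RSS pre-forward hook can look like with psutil; the attribute name stored on the module is an assumption, not necessarily what the PR uses:

import os

import psutil


def _hook_rss_memory_pre_forward(module, *args, **kwargs):
    # Record the resident set size (RSS) of the current process right before
    # the module's forward pass runs.
    process = psutil.Process(os.getpid())
    module.mem_rss_pre_forward = process.memory_info().rss
    return None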

thomwolf (Member Author) left a comment:
@sshleifer, regarding the pandas dependency: note that these tracing methods are in the library itself, not in the examples.

And we are very careful about the dependencies we add for the main library (also note that pytorch-lightning doesn't depend on pandas).

@thomwolf thomwolf changed the title CPU and GPU memory benchmarking utilities CPU/GPU memory benchmarking utilities - Remove support for python 3.5 (now only 3.6+) Mar 17, 2020
LysandreJik (Member) left a comment:
Great, LGTM. Glad the benchmarking script was upgraded with that!

Comment on lines +51 to +60
class UsedMemoryState(NamedTuple):
""" `UsedMemoryState` are named tuples with the following fields:
- 'frame': a `Frame` namedtuple (see below) storing information on the current tracing frame (current file, location in current file)
- 'cpu_memory': CPU RSS memory state *before* executing the line
- 'gpu_memory': GPU used memory *before* executing the line (sum for all GPUs or for only `gpus_to_trace` if provided)
"""

frame: Frame
cpu_memory: int
gpu_memory: int
Member:
That's clean
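The `Frame` tuple referenced in the docstring is not shown in this excerpt; here is a sketch based only on the fields used elsewhere in this thread (`filename`, `line_number`, `line_text`); the actual definition may carry more fields:

from typing import NamedTuple


class Frame(NamedTuple):
    """Location of a traced line: which file, which line, and its source text."""

    filename: str
    line_number: int
    line_text: str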

@LysandreJik LysandreJik merged commit 2187c49 into master Mar 17, 2020
@LysandreJik LysandreJik deleted the memory-benchmark branch March 17, 2020 14:17
@patrickvonplaten patrickvonplaten mentioned this pull request Mar 19, 2020
jplu pushed a commit to jplu/transformers that referenced this pull request Mar 25, 2020
… (now only 3.6+) (huggingface#3186)

* memory benchmark rss

* have both forward pass and line-by-line mem tracing

* cleaned up tracing

* refactored and cleaning up API

* no f-strings yet...

* add GPU mem logging

* fix GPU memory monitoring

* style and quality

* clean up and doc

* update with comments

* Switching to python 3.6+

* fix quality