CPU/GPU memory benchmarking utilities - Remove support for python 3.5 (now only 3.6+) #3186

Merged · Mar 17, 2020 · 12 commits
.circleci/config.yml: 10 changes (5 additions, 5 deletions)
@@ -3,7 +3,7 @@ jobs:
run_tests_torch_and_tf:
working_directory: ~/transformers
docker:
- image: circleci/python:3.5
- image: circleci/python:3.6
environment:
OMP_NUM_THREADS: 1
resource_class: xlarge
@@ -46,7 +46,7 @@ jobs:
run_tests_custom_tokenizers:
working_directory: ~/transformers
docker:
- image: circleci/python:3.5
- image: circleci/python:3.6
environment:
RUN_CUSTOM_TOKENIZERS: yes
steps:
@@ -56,7 +56,7 @@ jobs:
run_examples_torch:
working_directory: ~/transformers
docker:
- image: circleci/python:3.5
- image: circleci/python:3.6
environment:
OMP_NUM_THREADS: 1
resource_class: xlarge
@@ -69,7 +69,7 @@ jobs:
deploy_doc:
working_directory: ~/transformers
docker:
- image: circleci/python:3.5
- image: circleci/python:3.6
steps:
- add_ssh_keys:
fingerprints:
@@ -94,7 +94,7 @@ jobs:
check_repository_consistency:
working_directory: ~/transformers
docker:
- image: circleci/python:3.5
- image: circleci/python:3.6
resource_class: small
parallelism: 1
steps:
README.md: 2 changes (1 addition, 1 deletion)
@@ -66,7 +66,7 @@ Choose the right framework for every part of a model's lifetime

## Installation

This repo is tested on Python 3.5+, PyTorch 1.0.0+ and TensorFlow 2.0.0-rc1
This repo is tested on Python 3.6+, PyTorch 1.0.0+ and TensorFlow 2.0.0-rc1

You should install 🤗 Transformers in a [virtual environment](https://docs.python.org/3/library/venv.html). If you're unfamiliar with Python virtual environments, check out the [user guide](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).

docs/source/installation.md: 2 changes (1 addition, 1 deletion)
@@ -1,6 +1,6 @@
# Installation

Transformers is tested on Python 3.5+ and PyTorch 1.1.0
Transformers is tested on Python 3.6+ and PyTorch 1.1.0

## With pip

examples/benchmarks.py: 178 changes (157 additions, 21 deletions)
@@ -24,7 +24,17 @@
from time import time
from typing import List

from transformers import AutoConfig, AutoTokenizer, is_tf_available, is_torch_available
from transformers import (
AutoConfig,
AutoTokenizer,
MemorySummary,
MemoryState,
Frame,
is_tf_available,
is_torch_available,
start_memory_tracing,
stop_memory_tracing,
)
sshleifer marked this conversation as resolved.


if is_tf_available():
@@ -250,15 +260,21 @@

def create_setup_and_compute(
Contributor:
I definitely lack context, but I feel like I would want to use this with one model, one batch size, and one slice size. The fact that the code is taking lists of models, lists of batch sizes, and lists of slice sizes adds a fair amount of complexity (e.g. results[model_name]["memory"][batch_size][slice_size] would just be results['memory']).

If the use case is comparing different runs I guess the signature makes sense. Is that the reasoning?

Member Author:

Yes, that's the script we used for this blog post, for instance: https://medium.com/huggingface/benchmarking-transformers-pytorch-and-tensorflow-e2917fb891c2

I kept it as is, just added memory benchmarking in addition to speed (and a little more flexibility in the CLI args). A usage sketch of the resulting interface is included after the benchmarks.py diff below.

model_names: List[str],
batch_sizes: List[int],
slice_sizes: List[int],
gpu: bool = True,
tensorflow: bool = False,
average_over: int = 3,
no_speed: bool = False,
no_memory: bool = False,
verbose: bool = False,
torchscript: bool = False,
xla: bool = False,
amp: bool = False,
fp16: bool = False,
save_to_csv: bool = False,
csv_filename: str = f"results_{round(time())}.csv",
csv_memory_filename: str = f"memory_{round(time())}.csv",
):
if xla:
tf.config.optimizer.set_jit(True)
@@ -267,11 +283,25 @@ def create_setup_and_compute(

if tensorflow:
dictionary = {model_name: {} for model_name in model_names}
results = _compute_tensorflow(model_names, dictionary, average_over, amp)
results = _compute_tensorflow(
model_names, batch_sizes, slice_sizes, dictionary, average_over, amp, no_speed, no_memory, verbose
)
else:
device = "cuda" if (gpu and torch.cuda.is_available()) else "cpu"
dictionary = {model_name: {} for model_name in model_names}
results = _compute_pytorch(model_names, dictionary, average_over, device, torchscript, fp16)
results = _compute_pytorch(
model_names,
batch_sizes,
slice_sizes,
dictionary,
average_over,
device,
torchscript,
fp16,
no_speed,
no_memory,
verbose,
)

print("=========== RESULTS ===========")
for model_name in model_names:
@@ -280,13 +310,19 @@ def create_setup_and_compute(
print("\t\t" + f"===== BATCH SIZE: {batch_size} =====")
for slice_size in results[model_name]["ss"]:
result = results[model_name]["results"][batch_size][slice_size]
memory = results[model_name]["memory"][batch_size][slice_size]
if isinstance(result, str):
print(f"\t\t{model_name}/{batch_size}/{slice_size}: " f"{result}")
print(f"\t\t{model_name}/{batch_size}/{slice_size}: " f"{result} " f"{memory}")
else:
print(f"\t\t{model_name}/{batch_size}/{slice_size}: " f"{(round(1000 * result) / 1000)}" f"s")
print(
f"\t\t{model_name}/{batch_size}/{slice_size}: "
f"{(round(1000 * result) / 1000)}"
f"s "
f"{memory}"
)

if save_to_csv:
with open(csv_filename, mode="w") as csv_file:
with open(csv_filename, mode="w") as csv_file, open(csv_memory_filename, mode="w") as csv_memory_file:
fieldnames = [
"model",
"1x8",
@@ -317,6 +353,8 @@

writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
writer.writeheader()
memory_writer = csv.DictWriter(csv_memory_file, fieldnames=fieldnames)
memory_writer.writeheader()

for model_name in model_names:
model_results = {
@@ -326,8 +364,52 @@
}
writer.writerow({"model": model_name, **model_results})

model_memory_results = {
f"{bs}x{ss}": results[model_name]["memory"][bs][ss]
for bs in results[model_name]["memory"]
for ss in results[model_name]["memory"][bs]
}
memory_writer.writerow({"model": model_name, **model_memory_results})


def _compute_pytorch(model_names, dictionary, average_over, device, torchscript, fp16):
def print_summary_statistics(summary: MemorySummary):
print(
"\nLines by line memory consumption:\n"
+ "\n".join(
f"{state.frame.filename}:{state.frame.line_number}: mem {state.cpu_gpu}: {state.frame.line_text}"
for state in summary.sequential
)
)
print(
"\nLines with top memory consumption:\n"
+ "\n".join(
f"=> {state.frame.filename}:{state.frame.line_number}: mem {state.cpu_gpu}: {state.frame.line_text}"
for state in summary.cumulative[:6]
)
)
print(
"\nLines with lowest memory consumption:\n"
+ "\n".join(
f"=> {state.frame.filename}:{state.frame.line_number}: mem {state.cpu_gpu}: {state.frame.line_text}"
for state in summary.cumulative[-6:]
)
)
print(f"\nTotal memory increase: {summary.total}")


def _compute_pytorch(
model_names,
batch_sizes,
slice_sizes,
dictionary,
average_over,
device,
torchscript,
fp16,
no_speed,
no_memory,
verbose,
):
for c, model_name in enumerate(model_names):
print(f"{c + 1} / {len(model_names)}")
config = AutoConfig.from_pretrained(model_name, torchscript=torchscript)
@@ -337,17 +419,17 @@ def _compute_pytorch(model_names, dictionary, average_over, device, torchscript,
tokenized_sequence = tokenizer.encode(input_text, add_special_tokens=False)

max_input_size = tokenizer.max_model_input_sizes[model_name]
batch_sizes = [1, 2, 4, 8]
slice_sizes = [8, 64, 128, 256, 512, 1024]

dictionary[model_name] = {"bs": batch_sizes, "ss": slice_sizes, "results": {}}
dictionary[model_name] = {"bs": batch_sizes, "ss": slice_sizes, "results": {}, "memory": {}}
dictionary[model_name]["results"] = {i: {} for i in batch_sizes}
dictionary[model_name]["memory"] = {i: {} for i in batch_sizes}

for batch_size in batch_sizes:
if fp16:
model.half()
model.to(device)
model.eval()

for slice_size in slice_sizes:
if max_input_size is not None and slice_size > max_input_size:
dictionary[model_name]["results"][batch_size][slice_size] = "N/A"
@@ -362,18 +444,40 @@ def _compute_pytorch(model_names, dictionary, average_over, device, torchscript,
inference = model
inference(sequence)

print("Going through model with sequence of shape", sequence.shape)
runtimes = timeit.repeat(lambda: inference(sequence), repeat=average_over, number=3)
average_time = sum(runtimes) / float(len(runtimes)) / 3.0
dictionary[model_name]["results"][batch_size][slice_size] = average_time
if not no_memory:
# model.add_memory_hooks() # Forward method tracing (only for PyTorch models)
Contributor: (nit) delete this?
# Line by line memory tracing (all code in the module `transformers`) works for all models/arbitrary code
trace = start_memory_tracing("transformers")
inference(sequence)
summary = stop_memory_tracing(trace)

if verbose:
print_summary_statistics(summary)

dictionary[model_name]["memory"][batch_size][slice_size] = str(summary.total)
else:
dictionary[model_name]["memory"][batch_size][slice_size] = "N/A"

if not no_speed:
print("Going through model with sequence of shape", sequence.shape)
runtimes = timeit.repeat(lambda: inference(sequence), repeat=average_over, number=3)
average_time = sum(runtimes) / float(len(runtimes)) / 3.0
dictionary[model_name]["results"][batch_size][slice_size] = average_time
else:
dictionary[model_name]["results"][batch_size][slice_size] = "N/A"

except RuntimeError as e:
print("Doesn't fit on GPU.", e)
torch.cuda.empty_cache()
dictionary[model_name]["results"][batch_size][slice_size] = "N/A"
dictionary[model_name]["memory"][batch_size][slice_size] = "N/A"
return dictionary


def _compute_tensorflow(model_names, dictionary, average_over, amp):
def _compute_tensorflow(
model_names, batch_sizes, slice_sizes, dictionary, average_over, amp, no_speed, no_memory, verbose
):
for c, model_name in enumerate(model_names):
print(f"{c + 1} / {len(model_names)}")
config = AutoConfig.from_pretrained(model_name)
@@ -383,11 +487,10 @@ def _compute_tensorflow(model_names, dictionary, average_over, amp):
tokenized_sequence = tokenizer.encode(input_text, add_special_tokens=False)

max_input_size = tokenizer.max_model_input_sizes[model_name]
batch_sizes = [1, 2, 4, 8]
slice_sizes = [8, 64, 128, 256, 512, 1024]

dictionary[model_name] = {"bs": batch_sizes, "ss": slice_sizes, "results": {}}
dictionary[model_name] = {"bs": batch_sizes, "ss": slice_sizes, "results": {}, "memory": {}}
dictionary[model_name]["results"] = {i: {} for i in batch_sizes}
dictionary[model_name]["memory"] = {i: {} for i in batch_sizes}

print("Using model", model)

@@ -409,13 +512,31 @@ def inference(inputs):
# To make sure that the model is traced + that the tensors are on the appropriate device
inference(sequence)

runtimes = timeit.repeat(lambda: inference(sequence), repeat=average_over, number=3)
average_time = sum(runtimes) / float(len(runtimes)) / 3.0
dictionary[model_name]["results"][batch_size][slice_size] = average_time
if not no_memory:
# Line by line memory tracing (all code in the module `transformers`) works for all models/arbitrary code
trace = start_memory_tracing("transformers")
inference(sequence)
summary = stop_memory_tracing(trace)

if verbose:
print_summary_statistics(summary)

dictionary[model_name]["memory"][batch_size][slice_size] = str(summary.total)
else:
dictionary[model_name]["memory"][batch_size][slice_size] = "N/A"

if not no_speed:
runtimes = timeit.repeat(lambda: inference(sequence), repeat=average_over, number=3)
Contributor: (nit) number=3 feels like it should be exposed as a higher level kwarg like nruns=3

Member Author: Yeah that's the previous code. We can keep it as is for now.

average_time = sum(runtimes) / float(len(runtimes)) / 3.0
dictionary[model_name]["results"][batch_size][slice_size] = average_time
else:
dictionary[model_name]["results"][batch_size][slice_size] = "N/A"

except tf.errors.ResourceExhaustedError as e:
print("Doesn't fit on GPU.", e)
torch.cuda.empty_cache()
dictionary[model_name]["results"][batch_size][slice_size] = "N/A"
dictionary[model_name]["memory"][batch_size][slice_size] = "N/A"
return dictionary


Expand All @@ -433,6 +554,9 @@ def main():
"of all available model "
"architectures.",
)
parser.add_argument("--verbose", required=False, action="store_true", help="Verbose memory tracing")
parser.add_argument("--no_speed", required=False, action="store_true", help="Don't perform speed measurments")
parser.add_argument("--no_memory", required=False, action="store_true", help="Don't perform memory measurments")
parser.add_argument(
"--torch", required=False, action="store_true", help="Benchmark the Pytorch version of the " "models"
)
@@ -477,6 +601,8 @@ def main():
parser.add_argument(
"--average_over", required=False, default=30, type=int, help="Times an experiment will be run."
)
parser.add_argument("--batch_sizes", nargs="+", type=int, default=[1, 2, 4, 8])
parser.add_argument("--slice_sizes", nargs="+", type=int, default=[8, 64, 128, 256, 512, 1024])

args = parser.parse_args()
if args.models == "all":
@@ -501,13 +627,18 @@ def main():
if is_torch_available():
create_setup_and_compute(
model_names=args.models,
batch_sizes=args.batch_sizes,
slice_sizes=args.slice_sizes,
tensorflow=False,
gpu=args.torch_cuda,
torchscript=args.torchscript,
fp16=args.fp16,
save_to_csv=args.save_to_csv,
csv_filename=args.csv_filename,
average_over=args.average_over,
no_speed=args.no_speed,
no_memory=args.no_memory,
verbose=args.verbose,
)
else:
raise ImportError("Trying to run a PyTorch benchmark but PyTorch was not found in the environment.")
@@ -516,12 +647,17 @@ def main():
if is_tf_available():
create_setup_and_compute(
model_names=args.models,
batch_sizes=args.batch_sizes,
slice_sizes=args.slice_sizes,
tensorflow=True,
xla=args.xla,
amp=args.amp,
save_to_csv=args.save_to_csv,
csv_filename=args.csv_filename,
average_over=args.average_over,
no_speed=args.no_speed,
no_memory=args.no_memory,
verbose=args.verbose,
)
else:
raise ImportError("Trying to run a TensorFlow benchmark but TensorFlow was not found in the environment.")
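As referenced in the review thread above, here is a minimal usage sketch of the list-based interface. It is only an illustration, not part of the PR: the model name and sizes are arbitrary, the keyword names follow the create_setup_and_compute signature added in this diff, and the import path assumes the script is importable from the examples directory.

```python
# Minimal sketch: a single-model, single-batch-size, single-slice-size run is
# expressed as one-element lists. Assumes PyTorch and transformers are installed
# and that examples/benchmarks.py is on the import path (hypothetical setup).
from benchmarks import create_setup_and_compute  # i.e. examples/benchmarks.py

create_setup_and_compute(
    model_names=["bert-base-uncased"],  # one model
    batch_sizes=[8],                    # one batch size
    slice_sizes=[128],                  # one sequence slice size
    tensorflow=False,                   # use the PyTorch path
    gpu=True,
    average_over=3,
    no_speed=False,                     # keep the timeit-based speed measurement
    no_memory=False,                    # keep the line-by-line memory tracing
    verbose=True,                       # print the MemorySummary statistics
    save_to_csv=False,
)
```

The new --batch_sizes and --slice_sizes CLI flags feed these same lists, so the previously hard-coded values ([1, 2, 4, 8] and [8, 64, 128, 256, 512, 1024]) are now just defaults.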
examples/requirements.txt: 1 change (1 addition, 0 deletions)
@@ -2,3 +2,4 @@ tensorboardX
tensorboard
scikit-learn
seqeval
psutil
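psutil, added above, is presumably what the tracing utilities use to read process (CPU) memory. As a rough, standalone sketch of the tracing pattern that the updated benchmark loops wrap: the function and field names are those imported and used in examples/benchmarks.py in this diff, while the model checkpoint and input are purely illustrative.

```python
# Standalone sketch of the line-by-line memory tracing used by the benchmark script.
# start_memory_tracing / stop_memory_tracing and the MemorySummary fields
# (sequential, cumulative, total) are the names this PR imports from transformers;
# the model and input below are only an example.
import torch
from transformers import AutoModel, AutoTokenizer, start_memory_tracing, stop_memory_tracing

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")
sequence = torch.tensor([tokenizer.encode("Hello world", add_special_tokens=True)])

trace = start_memory_tracing("transformers")  # trace every line executed inside the transformers module
model(sequence)
summary = stop_memory_tracing(trace)

print(f"Total memory increase: {summary.total}")
for state in summary.cumulative[:6]:  # lines with the largest cumulative (CPU/GPU) increase
    print(f"{state.frame.filename}:{state.frame.line_number}: mem {state.cpu_gpu}")
```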