Skip to content

Commit

Permalink
Merge pull request facebookresearch#23 from fairinternal/benchmark_plots
Browse files Browse the repository at this point in the history
[feat] Generate plots from the benchmarks
  • Loading branch information
blefaudeux authored Apr 13, 2021
2 parents c579909 + 7f36772 commit 38b655a
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .isort.cfg
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
[settings]
known_third_party =pytest,setuptools,sklearn,torch,tqdm
known_third_party =matplotlib,pandas,pytest,seaborn,setuptools,sklearn,torch,tqdm
14 changes: 11 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ Flexible Transformers, defined by interoperable and optimized building blocks th
[ ] Waay more tests, find more invariants depending on the blocks

[ ] Benchmark:
[ ] add at least something basic to check training
[ ] measure throughput and memory
[x] add at least something basic to check training
[x] measure throughput and memory
[ ] autogenerate text report
[ ] autogenerate curves
[x] autogenerate curves

## Architecture, code
[x] Remove the AttrDict dependency
Expand Down Expand Up @@ -80,6 +80,14 @@ Models live in `xformers/models`. As a general rule, one should try to write the
These live in `xformers/benchmarks`. For instance, to sweep over different attention settings and log the maximum memory use and runtime, run
`python3 benchmarks/benchmark_attention.py`. A subset of the settings can be selected through command line arguments, for instance `python3 benchmarks/benchmark_attention.py --causal True --attentions random --activations gelu -fp16 True`.

Some examples, generated on CPU:

![](docs/plots/memory_vs_attention.png)

![](docs/plots/runtime_vs_attention.png)



## Bibliography
DRAFT, needs a proper citation format, ..

Expand Down
70 changes: 60 additions & 10 deletions benchmarks/benchmark_attention.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import argparse
import json
import time
from typing import Dict, Optional
from typing import Any, Dict, List, Optional

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
import torch.nn.functional as F
from sklearn.model_selection import ParameterGrid
Expand All @@ -20,6 +23,8 @@

# Credits: Sean Naren

_use_cuda = torch.cuda.is_available()


def _train_for_several_steps(
block: xFormerEncoderBlock,
Expand All @@ -36,8 +41,9 @@ def _train_for_several_steps(
# and this makes it bad for tests
optim = torch.optim.SGD(block.parameters(), lr=lr, momentum=0.9)

torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
if _use_cuda:
torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()

start_time = time.time()
for _ in range(num_steps):
Expand All @@ -55,8 +61,11 @@ def _train_for_several_steps(
torch.nn.utils.clip_grad_norm_(block.parameters(), clip_norm, norm_type)
optim.step()

torch.cuda.synchronize()
max_memory = torch.cuda.max_memory_allocated() / 2 ** 20
if _use_cuda:
torch.cuda.synchronize()
max_memory = torch.cuda.max_memory_allocated() / 2 ** 20
else:
max_memory = -1
run_time = time.time() - start_time

return {"run_time": run_time, "max_memory": round(max_memory, 1)}
Expand Down Expand Up @@ -98,9 +107,7 @@ def test_xformer_encoder_block(
sequence_length=sequence_length,
embed_dim=embed_dim,
dropout=dropout,
)

block.to(device)
).to(device)

return benchmark_model(
num_steps=num_steps,
Expand Down Expand Up @@ -168,6 +175,43 @@ def instantiate_xformer(
return block


def plot(args, results: List[Dict[str, Any]]):
    """Save bar plots of peak memory and runtime for each attention mechanism.

    The benchmark results are narrowed down to a single value for every
    parameter other than the sequence length and the attention name, so that
    each bar in the plots corresponds to exactly one benchmark run.

    Args:
        args: parsed command line namespace. The fields read here are
            ``heads``, ``pytorch_amp``, ``embedding_dim``, ``causal``,
            ``batch_size`` and ``activations``.
        results: one dict per benchmarked configuration. The keys read here
            are ``activation``, ``heads``, ``autocast``, ``embed_dim``,
            ``causal``, ``batch_size``, ``sequence_length``,
            ``attention_name``, ``max_memory`` and ``run_time``.

    Side effects:
        Writes ``memory_vs_attention.png`` and ``runtime_vs_attention.png``
        using the given relative file names.
    """
    df = pd.DataFrame(results)

    # For list-valued options, plot the last requested value.
    heads = args.heads[-1]
    # This has to be a scalar: comparing the column against a one-element
    # list makes pandas attempt an element-wise comparison with mismatched
    # lengths instead of a broadcasted equality test.
    amp = args.pytorch_amp if args.pytorch_amp is not None else True
    emb = args.embedding_dim[-1]
    causal = args.causal if args.causal is not None else True
    batch_size = args.batch_size[-1]
    activation = args.activations[-1]

    df_filtered = df[
        (df["activation"] == activation)
        & (df["heads"] == heads)
        & (df["autocast"] == amp)
        & (df["embed_dim"] == emb)
        & (df["causal"] == causal)
        & (df["batch_size"] == batch_size)
    ]

    # (result column, y axis label, title, output file)
    figures = (
        ("max_memory", "Max memory being used", "Memory use", "memory_vs_attention.png"),
        ("run_time", "Average epoch time", "Runtime", "runtime_vs_attention.png"),
    )

    for metric, ylabel, title, filename in figures:
        sns.barplot(
            x="sequence_length", y=metric, hue="attention_name", data=df_filtered
        )
        plt.xlabel("Sequence length")
        plt.ylabel(ylabel)
        plt.title(title)
        plt.savefig(filename)
        # Clear the current figure so the bars of one plot never leak into
        # the next one.
        plt.clf()


if __name__ == "__main__":
# Get the user requests
parser = argparse.ArgumentParser(
Expand All @@ -186,17 +230,19 @@ def instantiate_xformer(
"-sl", "--sequence_length", nargs="+", default=[128, 512, 768], type=int
)
parser.add_argument("-bs", "--batch_size", nargs="+", default=[8, 16, 32], type=int)
parser.add_argument("-hd", "--heads", nargs="+", default=[8, 16], type=int)

parser.add_argument(
"-fp16", "--pytorch_amp", action="store", default=None, type=bool
)
parser.add_argument("-causal", "--causal", action="store", default=None, type=bool)
parser.add_argument("-plot", "--plot", action="store_true", default=False)

args = parser.parse_args()

# Setup the test configs
constants = {
"device": torch.device("cuda"),
"device": torch.device("cuda") if _use_cuda else torch.device("cpu"),
"num_warmup": 5,
"num_steps": 10,
"dropout": 0.0,
Expand All @@ -209,7 +255,7 @@ def instantiate_xformer(
if args.pytorch_amp is not None
else [False, True],
"causal": [args.causal] if args.causal is not None else [False, True],
"heads": [8, 16],
"heads": args.heads,
"activation": args.activations,
"attention_name": args.attentions,
"feedforward_name": list(FEEDFORWARD_REGISTRY.keys()),
Expand All @@ -233,3 +279,7 @@ def instantiate_xformer(
grid_outputs.append(results)

print(json.dumps(grid_outputs, sort_keys=True, indent=4))

# Optional plots
if args.plot:
plot(args, grid_outputs)
Binary file added docs/plots/memory_vs_attention.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/plots/runtime_vs_attention.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 2 additions & 0 deletions requirements-benchmark.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@
torch >= 1.5.1
scikit-learn == 0.24.1
tqdm == 4.59.0
pandas == 1.2.4
seaborn == 0.11.1
2 changes: 2 additions & 0 deletions requirements-test.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Get core deps.
-r requirements.txt
-r requirements-benchmark.txt


# Tools for static checking.
black == 20.8b1
Expand Down

0 comments on commit 38b655a

Please sign in to comment.