diff --git a/requirements-benchmark.txt b/requirements-benchmark.txt index c82da257d5..f730b33c6a 100644 --- a/requirements-benchmark.txt +++ b/requirements-benchmark.txt @@ -8,3 +8,4 @@ tqdm == 4.59.0 pandas == 1.2.4 seaborn == 0.11.1 pytorch-lightning >= 1.3 +torchmetrics>=0.7.0, <0.10.1 diff --git a/xformers/benchmarks/benchmark_vit_timm.py b/xformers/benchmarks/benchmark_vit_timm.py index 7089599f38..0e747e6ffd 100644 --- a/xformers/benchmarks/benchmark_vit_timm.py +++ b/xformers/benchmarks/benchmark_vit_timm.py @@ -201,7 +201,9 @@ def __init__( self.head = nn.Linear(dim, num_classes) self.criterion = torch.nn.CrossEntropyLoss() - self.val_accuracy = Accuracy(task="multiclass", num_classes=num_classes) + # For torchmetrics > 0.11: + # self.val_accuracy = Accuracy(task="multiclass", num_classes=num_classes) + self.val_accuracy = Accuracy() @staticmethod def linear_warmup_cosine_decay(warmup_steps, total_steps):