From 7e40466242dc255e9a8262e31e6b90944f0f2047 Mon Sep 17 00:00:00 2001 From: danthe3rd Date: Thu, 8 Dec 2022 10:09:23 +0000 Subject: [PATCH] Fix CI after torchmetrics update It now takes an argument: https://torchmetrics.readthedocs.io/en/stable/classification/accuracy.html [ghstack-poisoned] --- xformers/benchmarks/benchmark_vit_timm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xformers/benchmarks/benchmark_vit_timm.py b/xformers/benchmarks/benchmark_vit_timm.py index 524812c540..7089599f38 100644 --- a/xformers/benchmarks/benchmark_vit_timm.py +++ b/xformers/benchmarks/benchmark_vit_timm.py @@ -201,7 +201,7 @@ def __init__( self.head = nn.Linear(dim, num_classes) self.criterion = torch.nn.CrossEntropyLoss() - self.val_accuracy = Accuracy() + self.val_accuracy = Accuracy(task="multiclass", num_classes=num_classes) @staticmethod def linear_warmup_cosine_decay(warmup_steps, total_steps):