Update on "Fix CI after torchmetrics update"
`Accuracy` now takes a `task` argument: https://torchmetrics.readthedocs.io/en/stable/classification/accuracy.html

Upstream change in torchmetrics (Lightning-AI):
Lightning-AI/torchmetrics@20eab43

Somehow this is failing with a SEGFAULT on my A100 (in a triton kernel):
```
#0  0x00007fffc0f62e10 in ?? () from /lib/x86_64-linux-gnu/libcuda.so
#1  0x00007fffc0f9303c in ?? () from /lib/x86_64-linux-gnu/libcuda.so
#2  0x00007fffc0f2ea13 in ?? () from /lib/x86_64-linux-gnu/libcuda.so
#3  0x00007fffc0f94603 in ?? () from /lib/x86_64-linux-gnu/libcuda.so
#4  0x00007fffc119e4a0 in ?? () from /lib/x86_64-linux-gnu/libcuda.so
#5  0x00007fffc0f3728f in ?? () from /lib/x86_64-linux-gnu/libcuda.so
#6  0x00007fffc0f3999f in ?? () from /lib/x86_64-linux-gnu/libcuda.so
#7  0x00007fffc0fdb1c2 in ?? () from /lib/x86_64-linux-gnu/libcuda.so
#8  0x00007fff502234c0 in _launch ()
   from /data/home/XXXXX/.triton/cache/704a3e6949e60326bc68d18a620bee50/layer_norm_fw.so
#9  0x00007fff3c0eea25 in launch ()
   from /data/home/XXXXX/.triton/cache/2cebb5590a024a2e06fe9de08c6b7079/k_dropout_bw.so
#10 0x0000555555698422 in cfunction_call (func=0x7fff3c6e5760, args=<optimized out>, kwargs=<optimized out>)
    at /usr/local/src/conda/python-3.10.6/Objects/methodobject.c:552
```

[ghstack-poisoned]
danthe3rd committed Dec 8, 2022
1 parent 7e40466 commit d119652
Showing 2 changed files with 4 additions and 1 deletion.
1 change: 1 addition & 0 deletions requirements-benchmark.txt
@@ -8,3 +8,4 @@ tqdm == 4.59.0
 pandas == 1.2.4
 seaborn == 0.11.1
 pytorch-lightning >= 1.3
+torchmetrics>=0.7.0, <0.10.1
4 changes: 3 additions & 1 deletion xformers/benchmarks/benchmark_vit_timm.py
@@ -201,7 +201,9 @@ def __init__(
         self.head = nn.Linear(dim, num_classes)
 
         self.criterion = torch.nn.CrossEntropyLoss()
-        self.val_accuracy = Accuracy(task="multiclass", num_classes=num_classes)
+        # For torchmetrics > 0.11:
+        # self.val_accuracy = Accuracy(task="multiclass", num_classes=num_classes)
+        self.val_accuracy = Accuracy()
 
     @staticmethod
     def linear_warmup_cosine_decay(warmup_steps, total_steps):
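
Note: the change above hard-codes the older `Accuracy()` call and relies on the version pin in requirements-benchmark.txt. An alternative, not part of this commit, would be to pick the constructor arguments from the installed torchmetrics version. The sketch below is only an illustration: the helper names `parse_major_minor` and `accuracy_kwargs` are hypothetical, and the 0.11 cutoff is an assumption taken from the comment in the diff.

```python
import re


def parse_major_minor(version):
    """Extract (major, minor) from a version string such as '0.10.3'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def accuracy_kwargs(torchmetrics_version, num_classes):
    """Return keyword arguments for Accuracy(...) for a given torchmetrics version.

    Assumption: releases from 0.11 onward expect `task` (and `num_classes`
    for multiclass), while older releases accept a bare Accuracy().
    """
    if parse_major_minor(torchmetrics_version) >= (0, 11):
        return {"task": "multiclass", "num_classes": num_classes}
    return {}


# Hypothetical usage inside the model's __init__:
#   import torchmetrics
#   self.val_accuracy = Accuracy(
#       **accuracy_kwargs(torchmetrics.__version__, num_classes)
#   )
```

With this approach the benchmark would not need editing again if the upper bound on torchmetrics were lifted, at the cost of a small amount of version-parsing code.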