Initial release
What is TorchMetrics
TorchMetrics is a collection of 25+ PyTorch metric implementations and an easy-to-use API for creating custom metrics. It offers:
- A standardized interface to increase reproducibility
- Reduced boilerplate
- Compatibility with distributed training
- Automatic accumulation over batches
- Automatic synchronization between multiple devices
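Why synchronize metric states rather than per-device results? Consider accuracy computed on two devices that see different numbers of samples. The sketch below is plain PyTorch and only simulates the two devices as a Python list (hypothetical data); it does not show how TorchMetrics communicates between processes, only why summing the raw counts before computing gives a different, correct answer compared with averaging per-device accuracies.

```python
import torch

# Hypothetical per-device states: number of correct predictions and number of samples.
device_states = [
    {"correct": torch.tensor(9), "total": torch.tensor(10)},   # device 0: 90% on 10 samples
    {"correct": torch.tensor(30), "total": torch.tensor(60)},  # device 1: 50% on 60 samples
]

# Naive approach: average the per-device accuracies (biased when sizes differ).
naive = torch.stack([s["correct"] / s["total"] for s in device_states]).mean()

# State-based approach: sum the raw states across devices first,
# then compute the metric once from the combined states.
correct = torch.stack([s["correct"] for s in device_states]).sum()
total = torch.stack([s["total"] for s in device_states]).sum()
synced = correct / total

print(f"naive average: {naive.item():.4f}")   # naive average: 0.7000
print(f"synced states: {synced.item():.4f}")  # synced states: 0.5571
```

Only the state-based value equals the accuracy over all 70 samples (39 correct), which is why accumulating and reducing states is the safer design for distributed evaluation.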
You can use TorchMetrics with any PyTorch model, or within PyTorch Lightning to enjoy additional features:
- Module metrics are automatically placed on the correct device.
- Native support for logging metrics in Lightning to reduce even more boilerplate.
Using functional metrics
Similar to torch.nn, most metrics have both a module-based and a functional version. The functional versions implement the basic operations required to compute each metric. They are simple Python functions that take torch.Tensor inputs and return the corresponding metric as a torch.Tensor.
import torch
# import our library
import torchmetrics
# simulate a classification problem
preds = torch.randn(10, 5).softmax(dim=-1)
target = torch.randint(5, (10,))
acc = torchmetrics.functional.accuracy(preds, target)
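As a rough picture of what a functional metric does, the sketch below computes multiclass accuracy from probabilities and integer targets in plain PyTorch. `simple_accuracy` is a hypothetical helper written for illustration; it is not the TorchMetrics implementation, which additionally validates inputs and supports more input formats.

```python
import torch


def simple_accuracy(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Fraction of samples whose highest-scoring class matches the target.

    preds:  (N, C) tensor of class probabilities or scores
    target: (N,) tensor of integer class labels
    """
    # Pick the predicted class for each sample.
    pred_labels = preds.argmax(dim=-1)
    # Compare with the targets and average the resulting booleans.
    return (pred_labels == target).float().mean()


# Small deterministic example: 3 of 4 predictions are correct.
preds = torch.tensor([
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.2, 0.2, 0.6],
    [0.5, 0.4, 0.1],
])
target = torch.tensor([0, 1, 2, 1])
acc = simple_accuracy(preds, target)
print(acc)  # tensor(0.7500)
```

Like the library's functional API, the helper is stateless: the same inputs always give the same output, and nothing is remembered between calls.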
Using Module metrics
Nearly all functional metrics have a corresponding module-based metric that calls its functional counterpart underneath. Module-based metrics are characterized by one or more internal metric states (similar to the parameters of a PyTorch module) that allow them to offer additional functionality:
- Accumulation of multiple batches
- Automatic synchronization between multiple devices
- Metric arithmetic
import torch
# import our library
import torchmetrics
# initialize metric
metric = torchmetrics.Accuracy()
n_batches = 10
for i in range(n_batches):
    # simulate a classification problem
    preds = torch.randn(10, 5).softmax(dim=-1)
    target = torch.randint(5, (10,))
    # metric on current batch
    acc = metric(preds, target)
    print(f"Accuracy on batch {i}: {acc}")
# metric on all batches, computed from the accumulated states
acc = metric.compute()
print(f"Accuracy on all data: {acc}")
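The accumulation behind `metric(preds, target)` and `metric.compute()` can be pictured as a small object that keeps running counts as its internal state. The class below is a hypothetical plain-PyTorch sketch of that update/compute/reset pattern; `RunningAccuracy` is not part of TorchMetrics, does not subclass the library's `Metric` base class, and omits device placement and distributed synchronization.

```python
import torch


class RunningAccuracy:
    """Hypothetical accuracy accumulator with two internal states."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        # Internal states: running counts over all batches seen so far.
        self.correct = torch.tensor(0)
        self.total = torch.tensor(0)

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        # Add this batch's counts to the running states.
        pred_labels = preds.argmax(dim=-1)
        self.correct += (pred_labels == target).sum()
        self.total += target.numel()

    def compute(self) -> torch.Tensor:
        # Accuracy over everything accumulated since the last reset.
        return self.correct.float() / self.total

    def __call__(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Update the accumulated states and return the accuracy of this batch alone.
        self.update(preds, target)
        return (preds.argmax(dim=-1) == target).float().mean()


torch.manual_seed(0)
metric = RunningAccuracy()
all_preds, all_target = [], []
for i in range(10):
    preds = torch.randn(10, 5).softmax(dim=-1)
    target = torch.randint(5, (10,))
    batch_acc = metric(preds, target)  # per-batch value
    all_preds.append(preds)
    all_target.append(target)

# The accumulated value matches accuracy on the concatenated data.
total_acc = metric.compute()
full = (torch.cat(all_preds).argmax(dim=-1) == torch.cat(all_target)).float().mean()
print(torch.isclose(total_acc, full))  # tensor(True)
```

Storing counts instead of per-batch accuracies keeps the final result exact even when batch sizes differ, and the counts are also what would need to be summed across devices.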
Built-in metrics
- Accuracy
- AveragePrecision
- AUC
- AUROC
- F1
- HammingDistance
- ROC
- ExplainedVariance
- MeanSquaredError
- R2Score
- bleu_score
- embedding_similarity
And many more!
Contributors
@Borda, @SkafteNicki, @williamFalcon, @teddykoker, @justusschock, @tadejsv, @edenlightning, @ydcjeff, @ddrevicky, @ananyahjha93, @awaelchli, @rohitgr7, @akihironitta, @manipopopo, @Diuven, @arnaudgelas, @s-rog, @c00k1ez, @tgaddair, @elias-ramzi, @cuent, @jpcarzolio, @bryant1410, @shivdhar, @Sordie, @krzysztofwos, @abhik-99, @bernardomig, @peblair, @InCogNiTo124, @j-dsouza, @pranjaldatta, @ananthsub, @deng-cy, @abhinavg97, @tridao, @prampey, @abrahambotros, @ozen, @ShomyLiu, @yuntai, @pwwang
If we missed anyone because their commit email doesn't match their GitHub account, let us know :]