Initial release

@Borda Borda released this 12 Mar 13:33
· 1803 commits to master since this release

What is TorchMetrics

TorchMetrics is a collection of 25+ PyTorch metrics implementations and an easy-to-use API to create custom metrics. It offers:

  • A standardized interface to increase reproducibility
  • Reduced boilerplate
  • Distributed-training compatible
  • Automatic accumulation over batches
  • Automatic synchronization between multiple devices

You can use TorchMetrics in any PyTorch model, or within PyTorch Lightning to enjoy additional features:

  • Module metrics are automatically placed on the correct device.
  • Native support for logging metrics in Lightning to reduce even more boilerplate.

Using functional metrics

Similar to torch.nn, most metrics have both a module-based and a functional version. The functional versions implement the basic operations required for computing each metric. They are simple Python functions that take torch.tensors as input and return the corresponding metric as a torch.tensor.

import torch
# import our library
import torchmetrics

# simulate a classification problem
preds = torch.randn(10, 5).softmax(dim=-1)
target = torch.randint(5, (10,))

acc = torchmetrics.functional.accuracy(preds, target)
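
To make the idea of a stateless functional metric concrete, here is a minimal plain-Python sketch of top-1 accuracy. It is an illustration only, not the TorchMetrics implementation: the helper name `accuracy_sketch` is hypothetical, and it works on nested lists rather than tensors.

```python
def accuracy_sketch(preds, target):
    """Top-1 accuracy from per-class scores (hypothetical helper, not library code)."""
    if len(preds) != len(target):
        raise ValueError("preds and target must have the same length")
    if not preds:
        raise ValueError("at least one sample is required")
    # The predicted class is the index of the highest score in each row.
    predicted = [max(range(len(row)), key=row.__getitem__) for row in preds]
    # Count the samples whose predicted class matches the target label.
    correct = sum(int(p == t) for p, t in zip(predicted, target))
    return correct / len(target)


# Four samples, three classes; the third sample is misclassified.
preds = [[0.1, 0.7, 0.2], [0.8, 0.1, 0.1], [0.3, 0.3, 0.4], [0.2, 0.5, 0.3]]
target = [1, 0, 1, 1]
acc = accuracy_sketch(preds, target)
print(acc)  # 0.75
```

The function keeps no state between calls, which is what distinguishes a functional metric from the module-based metrics described next.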

Using Module metrics

Nearly all functional metrics have a corresponding module-based metric that calls its functional counterpart underneath. The module-based metrics are characterized by having one or more internal metric states (similar to the parameters of a PyTorch module) that allow them to offer additional functionality:

  • Accumulation of multiple batches
  • Automatic synchronization between multiple devices
  • Metric arithmetic

import torch
# import our library
import torchmetrics

# initialize metric
metric = torchmetrics.Accuracy()

n_batches = 10
for i in range(n_batches):
    # simulate a classification problem
    preds = torch.randn(10, 5).softmax(dim=-1)
    target = torch.randint(5, (10,))
    # metric on current batch
    acc = metric(preds, target)
    print(f"Accuracy on batch {i}: {acc}")

# metric on all batches using custom accumulation
acc = metric.compute()
print(f"Accuracy on all data: {acc}")
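
The difference between the per-batch values and the final `compute()` result comes from the internal state. The following plain-Python sketch shows one way such state could be accumulated over batches and merged across devices. It is an assumption-laden illustration, not how TorchMetrics is implemented: the class `RunningAccuracy` and its `merge` method are hypothetical names, and it takes predicted labels as plain lists.

```python
class RunningAccuracy:
    """Hypothetical stateful accuracy; a sketch, not the TorchMetrics class."""

    def __init__(self):
        # Internal state: running counts rather than per-batch averages.
        self.correct = 0
        self.total = 0

    def update(self, predicted, target):
        # Add the counts from one batch to the running state.
        self.correct += sum(int(p == t) for p, t in zip(predicted, target))
        self.total += len(target)

    def compute(self):
        # Accuracy over everything seen so far.
        if self.total == 0:
            raise ValueError("compute() called before any update()")
        return self.correct / self.total

    def merge(self, other):
        # Sum-reduce the state of another instance, e.g. one from another device.
        self.correct += other.correct
        self.total += other.total


metric = RunningAccuracy()
metric.update([0, 1, 1, 0], [0, 1, 0, 0])  # batch accuracy 3/4
metric.update([1, 1], [0, 0])              # batch accuracy 0/2
overall = metric.compute()
print(overall)  # 0.5, i.e. 3 correct out of 6 samples

other = RunningAccuracy()                  # stands in for a second device
other.update([1, 0], [1, 0])               # batch accuracy 2/2
metric.merge(other)
synced = metric.compute()
print(synced)  # 0.625, i.e. 5 correct out of 8 samples
```

Accumulating counts instead of averaging batch results matters when batches differ in size: the mean of the two batch accuracies above would be 0.375, while the accuracy over all six samples is 0.5.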

Built-in metrics

  • Accuracy
  • AveragePrecision
  • AUC
  • AUROC
  • F1
  • Hamming Distance
  • ROC
  • ExplainedVariance
  • MeanSquaredError
  • R2Score
  • bleu_score
  • embedding_similarity

And many more!

Contributors

@Borda, @SkafteNicki, @williamFalcon, @teddykoker, @justusschock, @tadejsv, @edenlightning, @ydcjeff, @ddrevicky, @ananyahjha93, @awaelchli, @rohitgr7, @akihironitta, @manipopopo, @Diuven, @arnaudgelas, @s-rog, @c00k1ez, @tgaddair, @elias-ramzi, @cuent, @jpcarzolio, @bryant1410, @shivdhar, @Sordie, @krzysztofwos, @abhik-99, @bernardomig, @peblair, @InCogNiTo124, @j-dsouza, @pranjaldatta, @ananthsub, @deng-cy, @abhinavg97, @tridao, @prampey, @abrahambotros, @ozen, @ShomyLiu, @yuntai, @pwwang

If we forgot someone because a commit email did not match a GitHub account, let us know :]