Hands-on TorchMetrics - Supercharge Your PyTorch Projects


What is TorchMetrics?

TorchMetrics is a comprehensive library of more than 100 PyTorch metric implementations, with an intuitive API that also makes it easy to create your own metrics. With TorchMetrics, you gain access to a wide range of benefits:

  • Enhanced Reproducibility: Embrace a standardized interface that enhances reproducibility, making your experiments more reliable and easily shareable.

  • Minimized Boilerplate Code: Say goodbye to tedious boilerplate code. TorchMetrics streamlines the process, allowing you to focus on what truly matters: your research.

  • Distributed-Training Compatibility: Seamlessly incorporate TorchMetrics into your distributed-training setups, ensuring smooth and efficient processing across multiple devices.

  • Rigorously Tested: Rest assured that TorchMetrics has undergone rigorous testing, delivering consistent and dependable results every time.

  • Automatic Batch Accumulation: TorchMetrics automates batch accumulation, eliminating the need for manual calculations and reducing potential errors.

  • Synchronization Across Devices: With TorchMetrics, data synchronization between multiple devices is handled automatically, making your multi-device experiments hassle-free.

Table of contents

1. Installation

2. How does the Metric class work?

3. What is a MetricCollection?

4. How to implement a custom metric?

5. Harnessing the Power of TorchMetrics with PyTorch Lightning

6. Wrapping up

1. Installation

To install the latest version of TorchMetrics from PyPI, run the following in a terminal (inside a Jupyter notebook cell, prefix the command with !):

pip install torchmetrics

2. How does the Metric class work?

TorchMetrics offers a wide range of pre-built metrics, including Accuracy, Dice, F1 Score, Recall, Mean Absolute Error, and more. Each of these metrics inherits from the base class called “Metric,” serving as the fundamental parent class for all other metrics within the library.

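from abc import ABC
from torch.nn import Module

# Abridged excerpt of the TorchMetrics base class (docstring only)
class Metric(Module, ABC):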
    """Base class for all metrics present in the Metrics API.

    This class is inherited by all metrics and implements the following functionality:
    1. Handles the transfer of metric states to correct device
    2. Handles the synchronization of metric states across processes

    The three core methods of the base class are
    * ``add_state()`` => Add metric state variable. Only used by subclasses.
    * ``forward()`` => Aggregate and evaluate batch input directly.
    * ``reset()`` => Reset metric state variables to their default value.

    which should almost never be overwritten by child classes. Instead, the following methods should be overwritten
    * ``update()`` => Override this method to update the state variables of your metric class.
    * ``compute()`` => Override this method to compute the final metric value.
    """
    pass

Much like torch.nn, most metrics in TorchMetrics come in both a module-based version (a class such as torchmetrics.MeanAbsoluteError) and a functional version. The functional versions are plain Python functions that take torch.Tensor inputs and return the metric as a torch.Tensor. The following code uses the functional API to calculate the mean absolute error (MAE) for a simulated regression problem:

# PyTorch library
import torch

# Import our library
import torchmetrics

# Simulate a regression problem
preds = torch.randn(10)  # 10 random values standing in for the model's predictions
target = torch.randn(10)  # 10 random values standing in for the ground-truth targets

# Calculate mean absolute error (MAE) using the functional API of TorchMetrics
mae = torchmetrics.functional.mean_absolute_error(preds, target)
# MAE is a metric commonly used to evaluate the performance of regression models

print("Mean Absolute Error:", mae.item())
# Example output (the value changes between runs because the inputs are random):
# => Mean Absolute Error: 1.2326053380966187

When evaluating machine learning models, the assessment process involves examining their performance from various angles using multiple metrics. A single metric might not be sufficient to provide a comprehensive understanding of how well the model is performing, as different metrics capture different aspects of the model’s behavior. Therefore, it becomes essential to utilize a diverse set of evaluation metrics tailored to the specific task at hand.

For instance, in a classification problem, metrics like accuracy, precision, recall, F1 score, and area under the receiver operating characteristic curve (AUC-ROC) are commonly employed. Each of these metrics sheds light on different aspects of the model’s classification performance, such as overall accuracy, sensitivity, specificity, and the trade-off between precision and recall.

Similarly, in regression problems, metrics like mean squared error (MSE), mean absolute error (MAE), and R-squared are frequently utilized to gauge the model’s ability to approximate the target values accurately.

To effectively evaluate machine learning models and gain deeper insights into their strengths and weaknesses, it is recommended to consider using a combination of metrics. One approach to managing multiple metrics efficiently is by leveraging tools like torchmetrics.MetricCollection, which allows for the seamless computation of various metrics simultaneously, while ensuring their specialized logic and internal states are well-managed.

3. What is a MetricCollection?

Every metric performs its own internal computations and manages its own state: the counts, sums, or intermediate values it tracks while it is being evaluated. For this reason, nesting metrics, i.e. initializing one metric inside another metric's code to build a hierarchy of metrics, is not recommended. Here is an example of what to avoid, where a metric is initialized inside a custom one:

import torchmetrics

# This is NOT recommended
class NestedMetric(torchmetrics.Metric):
    def __init__(self):
        super().__init__()
        self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=5)  # initialize the inner accuracy metric

    def update(self, preds, target):
        self.accuracy.update(preds, target) # update the value

    def compute(self):
        return self.accuracy.compute() # compute the value

Instead, combine several metrics with a MetricCollection, which manages the metrics' states for you and computes them all in a single call:

import torch
import torchmetrics

# Simulate a 5-class classification problem
preds = torch.randn(100, 5).softmax(dim=-1)
target = torch.randint(5, (100,))

# Create a MetricCollection from a dict of named metrics
metric_collection = torchmetrics.MetricCollection({
    'accuracy': torchmetrics.Accuracy(task='multiclass', num_classes=5),
    'f1_score': torchmetrics.F1Score(task='multiclass', num_classes=5),
})

# Compute all metrics in one call; the result is a dict keyed by metric name
results = metric_collection(preds, target)

# Access individual metric results
accuracy_result = results['accuracy']
f1_score_result = results['f1_score']

print("Accuracy:", accuracy_result.item())
print("F1 Score:", f1_score_result.item())

4. How to implement a custom metric?

In the world of machine learning and deep learning, the performance of a model is often evaluated using various metrics. While popular libraries like TorchMetrics offer a wide range of pre-built metrics, there are instances where you may need a custom metric tailored to your specific task or problem. Implementing a custom metric allows you to capture domain-specific insights and measure the model’s success based on your unique requirements.

To get started with creating a custom metric, follow these steps:

1. Define the Metric Logic
Begin by defining the logic for your custom metric. Decide what values the metric should track and how it should be computed based on model predictions and ground-truth labels. Whether it's a complex calculation or a straightforward measure, having a clear understanding of the metric's behavior is crucial.
2. Subclass torchmetrics.Metric
To implement your custom metric, you'll need to create a subclass of torchmetrics.Metric, which provides the necessary infrastructure for handling metric updates, computations, and aggregations. By subclassing torchmetrics.Metric, you inherit a set of methods that facilitate the metric's functioning.
3. Implement Metric Methods

Within your custom metric class, you will need to implement three main methods:

1. __init__(): In this method, call super().__init__() and register the internal state variables your metric needs (counters, sums, or other accumulators) with self.add_state(), so TorchMetrics can reset them and synchronize them across devices.

2. update(): This method receives model predictions and corresponding ground-truth labels and updates the internal state variables based on the metric logic defined earlier.

3. compute(): This method computes the final metric value from the accumulated state variables and returns it, typically as a torch.Tensor.

Let’s implement mean squared error (MSE) as a custom metric:

import torch
import torchmetrics

class MeanSquaredError(torchmetrics.Metric):
    def __init__(self, **kwargs):
        # compute_on_step was deprecated and later removed from TorchMetrics;
        # forward any remaining base-class options through **kwargs instead
        super().__init__(**kwargs)
        self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total_samples", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds, target):
        squared_error = torch.pow(preds - target, 2).sum()
        self.sum_squared_error += squared_error
        self.total_samples += target.numel()

    def compute(self):
        return self.sum_squared_error / self.total_samples

# Simulate a regression problem
preds = torch.randn(100)
target = torch.randn(100)

# Initialize the custom metric
mean_squared_error = MeanSquaredError()

# Update the metric with model predictions and target values
mean_squared_error.update(preds, target)

# Compute the mean squared error
mse = mean_squared_error.compute()

print("Mean Squared Error:", mse.item())

Explanation:

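  • In the __init__ method, we register two state variables with add_state(): sum_squared_error, which keeps track of the sum of squared errors, and total_samples, which counts the total number of samples processed. The dist_reduce_fx="sum" argument tells TorchMetrics to sum each state across processes when it synchronizes them during distributed training.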
  • The update method receives model predictions (preds) and target values (target) as inputs. It calculates the squared error between the predictions and targets, updates the sum_squared_error, and increments the total_samples accordingly.
  • The compute method computes the final mean squared error by dividing the sum_squared_error by the total_samples.


5. Harnessing the Power of TorchMetrics with PyTorch Lightning

When it comes to building and training sophisticated machine learning models, PyTorch Lightning and TorchMetrics complement each other well. PyTorch Lightning simplifies model development by providing a high-level interface for organizing code, managing training loops, and handling distributed training. TorchMetrics, on the other hand, provides an extensive suite of pre-built metrics that make evaluating model performance easier.

Let’s consider a simple example of using TorchMetrics with PyTorch Lightning for a multiclass classification task. We’ll use the popular Iris dataset (150 flowers, 4 features, 3 species) for simplicity and build a basic model using PyTorch Lightning. Then, we’ll integrate TorchMetrics to compute evaluation metrics (Accuracy, Precision, and Recall) during validation and testing.

Let’s begin by importing the modules we will use and implementing the classifier:


import torch
import torch.nn as nn
import torch.optim as optim
import torchmetrics

from pytorch_lightning import LightningModule, Trainer
from sklearn.datasets import load_iris  # torchvision has no Iris dataset; scikit-learn ships one
from torch.utils.data import DataLoader, TensorDataset, random_split

# Define the model
class SimpleClassifier(LightningModule):
    def __init__(self, input_dim, hidden_dim, output_dim, lr=1e-2):
        super().__init__()
        self.lr = lr
        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )
        self.loss_fn = nn.CrossEntropyLoss()

        # Iris has 3 classes, so the metrics use the multiclass task
        metrics = torchmetrics.MetricCollection({
            'accuracy': torchmetrics.Accuracy(task='multiclass', num_classes=output_dim),
            'precision': torchmetrics.Precision(task='multiclass', num_classes=output_dim, average='macro'),
            'recall': torchmetrics.Recall(task='multiclass', num_classes=output_dim, average='macro'),
        })
        # Separate copies per stage so validation and test states never mix
        self.val_metrics = metrics.clone(prefix='val_')
        self.test_metrics = metrics.clone(prefix='test_')

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        inputs, targets = batch
        loss = self.loss_fn(self(inputs), targets)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        inputs, targets = batch
        outputs = self(inputs)
        self.log("val_loss", self.loss_fn(outputs, targets))
        self.val_metrics.update(outputs, targets)  # accumulate over the whole epoch

    def on_validation_epoch_end(self):
        self.log_dict(self.val_metrics.compute())
        self.val_metrics.reset()

    def test_step(self, batch, batch_idx):
        inputs, targets = batch
        self.test_metrics.update(self(inputs), targets)

    def on_test_epoch_end(self):
        self.log_dict(self.test_metrics.compute())
        self.test_metrics.reset()

    def configure_optimizers(self):
        # Lightning needs an optimizer in order to train
        return optim.Adam(self.parameters(), lr=self.lr)

Let’s train the model and evaluate it afterward:

if __name__ == "__main__":

    # Prepare the data: 150 samples, 4 features, 3 classes
    iris = load_iris()
    features = torch.tensor(iris.data, dtype=torch.float32)
    labels = torch.tensor(iris.target, dtype=torch.long)
    dataset = TensorDataset(features, labels)
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=16)

    # Train the model (the metrics live inside the LightningModule)
    model = SimpleClassifier(input_dim=4, hidden_dim=16, output_dim=3)

    trainer = Trainer(max_epochs=10)
    trainer.fit(model, train_loader, val_loader)

    # Evaluate the model; test() returns one dict of logged values per dataloader
    results = trainer.test(model, val_loader, verbose=False)

    # Access individual metric results (keys carry the 'test_' prefix)
    print("Validation Accuracy:", results[0]['test_accuracy'])
    print("Validation Precision:", results[0]['test_precision'])
    print("Validation Recall:", results[0]['test_recall'])

6. Wrapping up

To wrap up, TorchMetrics is an invaluable tool for PyTorch projects, offering a rich library of pre-built metrics and a user-friendly API. Its seamless integration with PyTorch and PyTorch Lightning simplifies model evaluation, allowing for better insights and informed decision-making. Elevate your model evaluation with TorchMetrics and unleash the true potential of your machine learning endeavors.
