Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add structural metrics #245

Merged
merged 8 commits into from
Jan 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
### 1.6.0dev - UNRELEASED

#### Features

#### New Features
* [Metrics] - [#245](https://github.com/a-r-j/graphein/pull/221) Adds a selection of structural metrics relevant to protein structures.
* [Tensor Operations] - [#244](https://github.com/a-r-j/graphein/pull/221) Adds suite of utilities for working directly with tensor-based representations of proteins.

#### Improvements
Expand Down
2 changes: 2 additions & 0 deletions graphein/ml/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .gdt import GDT_TS, gdt
from .tm_score import TMScore, tm_score
100 changes: 100 additions & 0 deletions graphein/ml/metrics/gdt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import torch
from torchmetrics import Metric

from ...protein.tensor.types import AtomTensor, CoordTensor


def gdt(
y: Union[CoordTensor, AtomTensor],
y_hat: Union[CoordTensor, AtomTensor],
ca_only: bool = True,
cutoff: float = 4,
ts: bool = True,
) -> torch.Tensor:
"""Global Distance Deviation Test metric (GDDT).

https://en.wikipedia.org/wiki/Global_distance_test


The GDT score is calculated as the largest set of amino acid residues'
alpha carbon atoms in the model structure falling within a defined
distance cutoff of their position in the experimental structure, after
iteratively superimposing the two structures. By the original design the
GDT algorithm calculates 20 GDT scores, i.e. for each of 20 consecutive distance
cutoffs (``0.5 Å, 1.0 Å, 1.5 Å, ... 10.0 Å``). For structure similarity assessment
it is intended to use the GDT scores from several cutoff distances, and scores
generally increase with increasing cutoff. A plateau in this increase may
indicate an extreme divergence between the experimental and predicted structures,
such that no additional atoms are included in any cutoff of a reasonable distance.
The conventional GDT_TS total score in CASP is the average result of cutoffs at
``1``, ``2``, ``4``, and ``8`` Å.

:param y: Tensor of groundtruth (reference) atom positions.
:type y: Union[graphein.protein.tensor.CoordTensor, graphein.protein.tensor.AtomTensor]
:param y_hat: Tensor of atom positions.
:type y_hat: Union[graphein.protein.tensor.CoordTensor, graphein.protein.tensor.AtomTensor]
:param ca_only: Whether or not to consider only Ca positions. Default is ``True``.
:type ca_only: bool
:param cutoff: Custom threshold to use.
:type cutoff: float
:param ts: Whether or not to use "Total Score" mode, where the scores over the thresholds
``1, 2, 4, 8`` are averaged (as per CASP).
:type ts: bool
:returns: GDT score (torch.FloatTensor)
:rtype: torch.Tensor
"""
if y.ndim == 3:
y = y[:, 1, :] if ca_only else y.reshape(-1, 3)
if y_hat.ndim == 3:
y_hat = y_hat[:, 1, :] if ca_only else y_hat.reshape(-1, 3)
# Get distance between points
dist = torch.norm(y - y_hat, dim=1)

if not ts:
# Return fraction of distances below cutoff
return (dist < cutoff).sum() / dist.numel()
# Return mean fraction of distances below cutoff for each cutoff (1, 2, 4, 8)
count_1 = (dist < 1).sum() / dist.numel()
count_2 = (dist < 2).sum() / dist.numel()
count_4 = (dist < 4).sum() / dist.numel()
count_8 = (dist < 8).sum() / dist.numel()
return torch.mean(torch.tensor([count_1, count_2, count_4, count_8]))


class GDT_TS(Metric):
def __init__(self):
"""Torchmetrics implementation of GDT_TS."""
super().__init__()
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
self.add_state(
"correct", default=torch.tensor(0.0), dist_reduce_fx="sum"
)
higher_is_better = True
full_state_update = True

@property
def higher_is_better(self):
return True

def update(
self, preds: torch.Tensor, target: torch.Tensor, batch: torch.Tensor
):
"""Update method for metric.

:param pred: Tensor of predictions.
:type pred: torch.Tensor
:param target: Tensor of target structures.
:type target: torch.Tensor
:param batch. Batch tensor, indicating which indices belong to which example in the batch.
Assumes a PyTorch Geometric batching scheme.
type batch: torch.Tensor.
"""
y = unbatch(target, batch)
y_hat = unbatch(preds, batch)

for i, j in zip(y, y_hat):
self.correct += gdt(i, j)
self.total += 1

def compute(self) -> float:
return self.correct / self.total
82 changes: 82 additions & 0 deletions graphein/ml/metrics/tm_score.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from typing import Union

import torch
from torchmetrics import Metric

from ...protein.tensor.types import AtomTensor, CoordTensor


def tm_score(
y_hat: Union[CoordTensor, AtomTensor], y: Union[CoordTensor, AtomTensor]
) -> torch.Tensor:
"""Compute TMScore between ``y_hat`` and ``y``. Requires aligned structures.

TM-score is a measure of similarity between two protein structures.
The TM-score is intended as a more accurate measure of the global
similarity of full-length protein structures than the often used RMSD
measure. The TM-score indicates the similarity between two structures
by a score between ``[0, 1]``, where 1 indicates a perfect match
between two structures (thus the higher the better). Generally scores
below 0.20 corresponds to randomly chosen unrelated proteins whereas
structures with a score higher than 0.5 assume roughly the same fold.
A quantitative study shows that proteins of TM-score = 0.5 have a
posterior probability of 37% in the same CATH topology family and of
13% in the same SCOP fold family. The probabilities increase rapidly
when TM-score > 0.5. The TM-score is designed to be independent of
protein lengths.

https://en.wikipedia.org/wiki/Template_modeling_score

:param y_hat: Tensor of atom positions (aligned to ``y``).
:type y_hat: Union[graphein.protein.tensor.types.CoordTensor,
graphein.protein.tensor.types.AtomTensor]
:param y: Tensor of groundtruth/reference atom positions.
:type y: Union[graphein.protein.tensor.types.CoordTensor,
graphein.protein.tensor.types.AtomTensor]
:returns: TMScore of aligned pair.
:rtype: torch.Tensor
"""
# Get CA
if y_hat.ndim == 3:
y_hat = y_hat[:, 1, :]
if y.ndim == 3:
y = y[:, 1, :]

l_target = y.shape[0]

d0_l_target = 1.24 * np.power(l_target - 15, 1 / 3) - 1.8

di = torch.pairwise_distance(y_hat, y)

return torch.sum(1 / (1 + (di / d0_l_target) ** 2)) / l_target


class TMScore(Metric):
def __init__(self):
super().__init__()
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
self.add_state(
"correct", default=torch.tensor(0.0), dist_reduce_fx="sum"
)
higher_is_better: bool = True
full_state_update: bool = True

@property
def higher_is_better(self):
return True

def update(
self,
pred: Union[COORD_TENSOR, ATOM_TENSOR],
target: Union[COORD_TENSOR, ATOM_TENSOR],
batch: torch.Tensor,
):
y = unbatch(target, batch)
y_hat = unbatch(pred, batch)

for i, j in zip(y, y_hat):
self.correct += tm_score(i, j)
self.total += 1

def compute(self):
return self.correct / self.total