Add Tschuprow's T and Pearson's Contingency Coefficient #1334

Merged
merged 22 commits into from
Nov 15, 2022
2 changes: 2 additions & 0 deletions docs/source/links.rst
@@ -97,3 +97,5 @@
.. _Kendall Rank Correlation Coefficient: https://en.wikipedia.org/wiki/Kendall_rank_correlation_coefficient
.. _The Treatment of Ties in Ranking Problems: https://www.jstor.org/stable/2332303
.. _LogCosh Error: https://arxiv.org/pdf/2101.10427.pdf
.. _Tschuprow's T: https://en.wikipedia.org/wiki/Tschuprow%27s_T
.. _Pearson's Contingency Coefficient: https://www.itl.nist.gov/div898/software/dataplot/refman2/auxillar/pearcont.htm
26 changes: 26 additions & 0 deletions docs/source/nominal/pearsons_contingency_coefficient.rst
@@ -0,0 +1,26 @@
.. customcarditem::
   :header: Pearson's Contingency Coefficient
   :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg
   :tags: Nominal

#################################
Pearson's Contingency Coefficient
#################################

Module Interface
________________

.. autoclass:: torchmetrics.PearsonsContingencyCoefficient
    :noindex:

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.pearsons_contingency_coefficient
    :noindex:

pearsons_contingency_coefficient_matrix
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.nominal.pearsons_contingency_coefficient_matrix
    :noindex:
26 changes: 26 additions & 0 deletions docs/source/nominal/tschuprows_t.rst
@@ -0,0 +1,26 @@
.. customcarditem::
   :header: Tschuprow's T
   :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg
   :tags: Nominal

#############
Tschuprow's T
#############

Module Interface
________________

.. autoclass:: torchmetrics.TschuprowsT
    :noindex:

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.tschuprows_t
    :noindex:

tschuprows_t_matrix
^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.nominal.tschuprows_t_matrix
    :noindex:
1 change: 1 addition & 0 deletions requirements/nominal_test.txt
@@ -1,2 +1,3 @@
pandas # cannot pin version due to numpy version incompatibility
dython # todo: pin version, but some version resolution issue
scipy
4 changes: 3 additions & 1 deletion src/torchmetrics/__init__.py
@@ -53,7 +53,7 @@
UniversalImageQualityIndex,
)
from torchmetrics.metric import Metric # noqa: E402
from torchmetrics.nominal import CramersV # noqa: E402
from torchmetrics.nominal import CramersV, PearsonsContingencyCoefficient, TschuprowsT # noqa: E402
from torchmetrics.regression import ( # noqa: E402
ConcordanceCorrCoef,
CosineSimilarity,
@@ -152,6 +152,7 @@
"MultioutputWrapper",
"MultiScaleStructuralSimilarityIndexMeasure",
"PearsonCorrCoef",
"PearsonsContingencyCoefficient",
"PermutationInvariantTraining",
"Perplexity",
"Precision",
@@ -186,6 +187,7 @@
"SymmetricMeanAbsolutePercentageError",
"TotalVariation",
"TranslationEditRate",
"TschuprowsT",
"UniversalImageQualityIndex",
"WeightedMeanAbsolutePercentageError",
"WordErrorRate",
4 changes: 4 additions & 0 deletions src/torchmetrics/functional/__init__.py
@@ -43,6 +43,8 @@
from torchmetrics.functional.image.tv import total_variation
from torchmetrics.functional.image.uqi import universal_image_quality_index
from torchmetrics.functional.nominal.cramers import cramers_v
from torchmetrics.functional.nominal.pearson import pearsons_contingency_coefficient
from torchmetrics.functional.nominal.tschuprows import tschuprows_t
from torchmetrics.functional.pairwise.cosine import pairwise_cosine_similarity
from torchmetrics.functional.pairwise.euclidean import pairwise_euclidean_distance
from torchmetrics.functional.pairwise.linear import pairwise_linear_similarity
@@ -131,6 +133,7 @@
"pairwise_linear_similarity",
"pairwise_manhattan_distance",
"pearson_corrcoef",
"pearsons_contingency_coefficient",
"permutation_invariant_training",
"perplexity",
"pit_permutate",
@@ -165,6 +168,7 @@
"symmetric_mean_absolute_percentage_error",
"total_variation",
"translation_edit_rate",
"tschuprows_t",
"universal_image_quality_index",
"spectral_angle_mapper",
"weighted_mean_absolute_percentage_error",
5 changes: 5 additions & 0 deletions src/torchmetrics/functional/nominal/__init__.py
@@ -12,3 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.functional.nominal.cramers import cramers_v, cramers_v_matrix # noqa: F401
from torchmetrics.functional.nominal.pearson import ( # noqa: F401
pearsons_contingency_coefficient,
pearsons_contingency_coefficient_matrix,
)
from torchmetrics.functional.nominal.tschuprows import tschuprows_t, tschuprows_t_matrix # noqa: F401
70 changes: 14 additions & 56 deletions src/torchmetrics/functional/nominal/cramers.py
@@ -19,53 +19,14 @@
from typing_extensions import Literal

from torchmetrics.functional.classification.confusion_matrix import _multiclass_confusion_matrix_update
from torchmetrics.functional.nominal.utils import _handle_nan_in_data
from torchmetrics.utilities.prints import rank_zero_warn


def _cramers_input_validation(nan_strategy: str, nan_replace_value: Optional[Union[int, float]]) -> None:
if nan_strategy not in ["replace", "drop"]:
raise ValueError(
f"Argument `nan_strategy` is expected to be one of `['replace', 'drop']`, but got {nan_strategy}"
)
if nan_strategy == "replace" and not isinstance(nan_replace_value, (int, float)):
raise ValueError(
"Argument `nan_replace` is expected to be of a type `int` or `float` when `nan_strategy = 'replace`, "
f"but got {nan_replace_value}"
)


def _compute_expected_freqs(confmat: Tensor) -> Tensor:
"""Compute the expected frequenceis from the provided confusion matrix."""
margin_sum_rows, margin_sum_cols = confmat.sum(1), confmat.sum(0)
expected_freqs = torch.einsum("r, c -> rc", margin_sum_rows, margin_sum_cols) / confmat.sum()
return expected_freqs


def _compute_chi_squared(confmat: Tensor, bias_correction: bool) -> Tensor:
"""Chi-square test of independenc of variables in a confusion matrix table.

Adapted from: https://github.com/scipy/scipy/blob/v1.9.2/scipy/stats/contingency.py.
"""
expected_freqs = _compute_expected_freqs(confmat)
# Get degrees of freedom
df = expected_freqs.numel() - sum(expected_freqs.shape) + expected_freqs.ndim - 1
if df == 0:
return torch.tensor(0.0, device=confmat.device)

if df == 1 and bias_correction:
diff = expected_freqs - confmat
direction = diff.sign()
confmat += direction * torch.minimum(0.5 * torch.ones_like(direction), direction.abs())

return torch.sum((confmat - expected_freqs) ** 2 / expected_freqs)


def _drop_empty_rows_and_cols(confmat: Tensor) -> Tensor:
"""Drop all rows and columns containing only zeros."""
confmat = confmat[confmat.sum(1) != 0]
confmat = confmat[:, confmat.sum(0) != 0]
return confmat
from torchmetrics.functional.nominal.utils import (
_compute_bias_corrected_values,
_compute_chi_squared,
_drop_empty_rows_and_cols,
_handle_nan_in_data,
_nominal_input_validation,
_unable_to_use_bias_correction_warning,
)


def _cramers_v_update(
@@ -110,15 +71,11 @@ def _cramers_v_compute(confmat: Tensor, bias_correction: bool) -> Tensor:
n_rows, n_cols = confmat.shape

if bias_correction:
phi_squared_corrected = torch.max(
torch.tensor(0.0, device=confmat.device), phi_squared - ((n_rows - 1) * (n_cols - 1)) / (cm_sum - 1)
phi_squared_corrected, rows_corrected, cols_corrected = _compute_bias_corrected_values(
phi_squared, n_rows, n_cols, cm_sum
)
rows_corrected = n_rows - (n_rows - 1) ** 2 / (cm_sum - 1)
cols_corrected = n_cols - (n_cols - 1) ** 2 / (cm_sum - 1)
if min(rows_corrected, cols_corrected) == 1:
rank_zero_warn(
"Unable to compute Cramer's V using bias correction. Please consider to set `bias_correction=False`."
)
_unable_to_use_bias_correction_warning(metric_name="Cramer's V")
return torch.tensor(float("nan"), device=confmat.device)
cramers_v_value = torch.sqrt(phi_squared_corrected / min(rows_corrected - 1, cols_corrected - 1))
else:
@@ -136,7 +93,7 @@ def cramers_v(
r"""Compute `Cramer's V`_ statistic measuring the association between two categorical (nominal) data series.

.. math::
V = \sqrt{\frac{\chi^2 / 2}{\min(r - 1, k - 1)}}
V = \sqrt{\frac{\chi^2 / n}{\min(r - 1, k - 1)}}

where

@@ -172,6 +129,7 @@ def cramers_v(
>>> cramers_v(preds, target)
tensor(0.5284)
"""
_nominal_input_validation(nan_strategy, nan_replace_value)
num_classes = len(torch.cat([preds, target]).unique())
confmat = _cramers_v_update(preds, target, num_classes, nan_strategy, nan_replace_value)
return _cramers_v_compute(confmat, bias_correction)
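
The docstring fix in this file replaces :math:`\chi^2 / 2` with :math:`\chi^2 / n`. A minimal sketch in plain Python (not the torchmetrics implementation; `chi_squared` and `cramers_v_sketch` are hypothetical helper names, and the table is assumed to have no empty rows or columns) illustrates why dividing by the sample size `n` keeps the uncorrected statistic in [0, 1]:

```python
import math


def chi_squared(table):
    """Return (chi2, n) for a contingency table given as a list of rows.

    Assumes every row and column has a non-zero sum, so that no expected
    frequency is zero.
    """
    n = sum(sum(row) for row in table)
    row_sums = [sum(row) for row in table]
    col_sums = [sum(col) for col in zip(*table)]
    chi2 = 0.0
    for i, row in enumerate(table):
        for j, observed in enumerate(row):
            expected = row_sums[i] * col_sums[j] / n
            chi2 += (observed - expected) ** 2 / expected
    return chi2, n


def cramers_v_sketch(table):
    """Uncorrected Cramer's V: sqrt((chi2 / n) / min(r - 1, k - 1))."""
    chi2, n = chi_squared(table)
    n_rows, n_cols = len(table), len(table[0])
    return math.sqrt((chi2 / n) / min(n_rows - 1, n_cols - 1))


perfect = [[50, 0], [0, 50]]        # the two variables always agree
independent = [[10, 20], [20, 40]]  # rows are proportional to each other

v_perfect = cramers_v_sketch(perfect)
v_independent = cramers_v_sketch(independent)
print(v_perfect, v_independent)
```

For the perfectly associated table, :math:`\chi^2 = n = 100`, so dividing by `n` gives the expected maximum of 1, whereas dividing by 2 would give a value far above 1.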
@@ -210,7 +168,7 @@ def cramers_v_matrix(
[0.0542, 0.0000, 0.0000, 1.0000, 0.1100],
[0.1337, 0.0000, 0.0649, 0.1100, 1.0000]])
"""
_cramers_input_validation(nan_strategy, nan_replace_value)
_nominal_input_validation(nan_strategy, nan_replace_value)
num_variables = matrix.shape[1]
cramers_v_matrix_value = torch.ones(num_variables, num_variables, device=matrix.device)
for i, j in itertools.combinations(range(num_variables), 2):
165 changes: 165 additions & 0 deletions src/torchmetrics/functional/nominal/pearson.py
@@ -0,0 +1,165 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
from typing import Optional, Union

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.classification.confusion_matrix import _multiclass_confusion_matrix_update
from torchmetrics.functional.nominal.utils import (
_compute_chi_squared,
_drop_empty_rows_and_cols,
_handle_nan_in_data,
_nominal_input_validation,
)


def _pearsons_contingency_coefficient_update(
    preds: Tensor,
    target: Tensor,
    num_classes: int,
    nan_strategy: Literal["replace", "drop"] = "replace",
    nan_replace_value: Optional[Union[int, float]] = 0.0,
) -> Tensor:
    """Compute the bins used to update the confusion matrix for the Pearson's Contingency Coefficient calculation.

    Args:
        preds: 1D or 2D tensor of categorical (nominal) data
        target: 1D or 2D tensor of categorical (nominal) data
        num_classes: Integer specifying the number of classes
        nan_strategy: Indication of whether to replace or drop ``NaN`` values
        nan_replace_value: Value to replace ``NaN``s when ``nan_strategy = 'replace'``

    Returns:
        Non-reduced confusion matrix
    """
    preds = preds.argmax(1) if preds.ndim == 2 else preds
    target = target.argmax(1) if target.ndim == 2 else target
    preds, target = _handle_nan_in_data(preds, target, nan_strategy, nan_replace_value)
    return _multiclass_confusion_matrix_update(preds, target, num_classes)


def _pearsons_contingency_coefficient_compute(confmat: Tensor) -> Tensor:
    """Compute Pearson's Contingency Coefficient based on a pre-computed confusion matrix.

    Args:
        confmat: Confusion matrix for observed data

    Returns:
        Pearson's Contingency Coefficient
    """
    confmat = _drop_empty_rows_and_cols(confmat)
    cm_sum = confmat.sum()
    chi_squared = _compute_chi_squared(confmat, bias_correction=False)
    phi_squared = chi_squared / cm_sum

    pearsons_contingency_coefficient_value = torch.sqrt(phi_squared / (1 + phi_squared))
    return pearsons_contingency_coefficient_value.clamp(0.0, 1.0)


def pearsons_contingency_coefficient(
    preds: Tensor,
    target: Tensor,
    nan_strategy: Literal["replace", "drop"] = "replace",
    nan_replace_value: Optional[Union[int, float]] = 0.0,
) -> Tensor:
    r"""Compute `Pearson's Contingency Coefficient`_ measuring the association between two categorical (nominal)
    data series.

    .. math::
        C = \sqrt{\frac{\chi^2 / n}{1 + \chi^2 / n}}

    where

    .. math::
        \chi^2 = \sum_{i,j} \frac{\left(n_{ij} - \frac{n_{i.} n_{.j}}{n}\right)^2}{\frac{n_{i.} n_{.j}}{n}}

    Pearson's Contingency Coefficient is a symmetric coefficient, i.e.

    .. math::
        C(preds, target) = C(target, preds)

    The output values lie in [0, 1].

    Args:
        preds: 1D or 2D tensor of categorical (nominal) data
            - 1D shape: (batch_size,)
            - 2D shape: (batch_size, num_classes)
        target: 1D or 2D tensor of categorical (nominal) data
            - 1D shape: (batch_size,)
            - 2D shape: (batch_size, num_classes)
        nan_strategy: Indication of whether to replace or drop ``NaN`` values
        nan_replace_value: Value to replace ``NaN``s when ``nan_strategy = 'replace'``

    Returns:
        Pearson's Contingency Coefficient

    Example:
        >>> from torchmetrics.functional import pearsons_contingency_coefficient
        >>> _ = torch.manual_seed(42)
        >>> preds = torch.randint(0, 4, (100,))
        >>> target = torch.round(preds + torch.randn(100)).clamp(0, 4)
        >>> pearsons_contingency_coefficient(preds, target)
        tensor(0.6948)
    """
    _nominal_input_validation(nan_strategy, nan_replace_value)
    num_classes = len(torch.cat([preds, target]).unique())
    confmat = _pearsons_contingency_coefficient_update(preds, target, num_classes, nan_strategy, nan_replace_value)
    return _pearsons_contingency_coefficient_compute(confmat)


def pearsons_contingency_coefficient_matrix(
    matrix: Tensor,
    nan_strategy: Literal["replace", "drop"] = "replace",
    nan_replace_value: Optional[Union[int, float]] = 0.0,
) -> Tensor:
    r"""Compute `Pearson's Contingency Coefficient`_ statistic between a set of multiple variables.

    This can serve as a convenient tool to compute Pearson's Contingency Coefficient for analyses
    of correlation between categorical variables in your dataset.

    Args:
        matrix: A tensor of categorical (nominal) data, where:
            - rows represent a number of data points
            - columns represent a number of categorical (nominal) features
        nan_strategy: Indication of whether to replace or drop ``NaN`` values
        nan_replace_value: Value to replace ``NaN``s when ``nan_strategy = 'replace'``

    Returns:
        Pearson's Contingency Coefficient statistic for a dataset of categorical variables

    Example:
        >>> from torchmetrics.functional.nominal import pearsons_contingency_coefficient_matrix
        >>> _ = torch.manual_seed(42)
        >>> matrix = torch.randint(0, 4, (200, 5))
        >>> pearsons_contingency_coefficient_matrix(matrix)
        tensor([[1.0000, 0.2326, 0.1959, 0.2262, 0.2989],
                [0.2326, 1.0000, 0.1386, 0.1895, 0.1329],
                [0.1959, 0.1386, 1.0000, 0.1840, 0.2335],
                [0.2262, 0.1895, 0.1840, 1.0000, 0.2737],
                [0.2989, 0.1329, 0.2335, 0.2737, 1.0000]])
    """
    _nominal_input_validation(nan_strategy, nan_replace_value)
    num_variables = matrix.shape[1]
    pearsons_cont_coef_matrix_value = torch.ones(num_variables, num_variables, device=matrix.device)
    for i, j in itertools.combinations(range(num_variables), 2):
        x, y = matrix[:, i], matrix[:, j]
        num_classes = len(torch.cat([x, y]).unique())
        confmat = _pearsons_contingency_coefficient_update(x, y, num_classes, nan_strategy, nan_replace_value)
        pearsons_cont_coef_matrix_value[i, j] = pearsons_cont_coef_matrix_value[
            j, i
        ] = _pearsons_contingency_coefficient_compute(confmat)
    return pearsons_cont_coef_matrix_value
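
For readers who want to check the arithmetic without torch, the following is a minimal sketch in plain Python of the two steps in this file: the coefficient :math:`\sqrt{\phi^2 / (1 + \phi^2)}` with :math:`\phi^2 = \chi^2 / n`, and the pairwise fill of a symmetric matrix. It is not the torchmetrics code. The function names are hypothetical, categories are taken from the observed values of each series (which plays the role of dropping empty rows and columns), and NaN handling is omitted.

```python
import itertools
import math
from collections import Counter


def contingency_coefficient_sketch(x, y):
    """Pearson's contingency coefficient for two equally long category lists."""
    n = len(x)
    joint = Counter(zip(x, y))            # observed cell counts n_ij
    x_counts, y_counts = Counter(x), Counter(y)  # marginal counts
    chi2 = 0.0
    for a, n_a in x_counts.items():
        for b, n_b in y_counts.items():
            expected = n_a * n_b / n      # never zero: only observed categories
            chi2 += (joint[(a, b)] - expected) ** 2 / expected
    phi_squared = chi2 / n
    return math.sqrt(phi_squared / (1 + phi_squared))


def contingency_coefficient_matrix_sketch(columns):
    """Fill a symmetric matrix of pairwise coefficients; the diagonal is set to 1."""
    num_variables = len(columns)
    result = [[1.0] * num_variables for _ in range(num_variables)]
    for i, j in itertools.combinations(range(num_variables), 2):
        result[i][j] = result[j][i] = contingency_coefficient_sketch(columns[i], columns[j])
    return result


a = [0, 0, 1, 1] * 25
b = [0, 0, 1, 1] * 25  # identical to a
c = [0, 1, 0, 1] * 25  # independent of a and b
matrix = contingency_coefficient_matrix_sketch([a, b, c])
for row in matrix:
    print(row)
```

In this toy data the two identical binary series give :math:`\sqrt{1/2}` rather than 1, since the coefficient's attainable maximum depends on the table size, while the independent pairs give 0.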