Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

imports: deprecate from pkg root [4/n] Retrieval #1699

Merged
merged 11 commits into from
Apr 11, 2023
24 changes: 13 additions & 11 deletions src/torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,18 +88,20 @@
TweedieDevianceScore,
WeightedMeanAbsolutePercentageError,
)
from torchmetrics.retrieval import RetrievalFallOut # noqa: E402
from torchmetrics.retrieval import RetrievalHitRate # noqa: E402
from torchmetrics.retrieval import ( # noqa: E402
RetrievalMAP,
RetrievalMRR,
RetrievalNormalizedDCG,
RetrievalPrecision,
RetrievalPrecisionRecallCurve,
RetrievalRecall,
RetrievalRecallAtFixedPrecision,
RetrievalRPrecision,
from torchmetrics.retrieval._deprecated import _RetrievalFallOut as RetrievalFallOut # noqa: E402
from torchmetrics.retrieval._deprecated import _RetrievalHitRate as RetrievalHitRate # noqa: E402
from torchmetrics.retrieval._deprecated import _RetrievalMAP as RetrievalMAP # noqa: E402
from torchmetrics.retrieval._deprecated import _RetrievalMRR as RetrievalMRR # noqa: E402
from torchmetrics.retrieval._deprecated import _RetrievalNormalizedDCG as RetrievalNormalizedDCG # noqa: E402
from torchmetrics.retrieval._deprecated import _RetrievalPrecision as RetrievalPrecision # noqa: E402
from torchmetrics.retrieval._deprecated import ( # noqa: E402
_RetrievalPrecisionRecallCurve as RetrievalPrecisionRecallCurve,
)
from torchmetrics.retrieval._deprecated import _RetrievalRecall as RetrievalRecall # noqa: E402
from torchmetrics.retrieval._deprecated import ( # noqa: E402
_RetrievalRecallAtFixedPrecision as RetrievalRecallAtFixedPrecision,
)
from torchmetrics.retrieval._deprecated import _RetrievalRPrecision as RetrievalRPrecision # noqa: E402
from torchmetrics.text import ( # noqa: E402
BLEUScore,
CharErrorRate,
Expand Down
20 changes: 11 additions & 9 deletions src/torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,17 @@
from torchmetrics.functional.regression.symmetric_mape import symmetric_mean_absolute_percentage_error
from torchmetrics.functional.regression.tweedie_deviance import tweedie_deviance_score
from torchmetrics.functional.regression.wmape import weighted_mean_absolute_percentage_error
from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision
from torchmetrics.functional.retrieval.fall_out import retrieval_fall_out
from torchmetrics.functional.retrieval.hit_rate import retrieval_hit_rate
from torchmetrics.functional.retrieval.ndcg import retrieval_normalized_dcg
from torchmetrics.functional.retrieval.precision import retrieval_precision
from torchmetrics.functional.retrieval.precision_recall_curve import retrieval_precision_recall_curve
from torchmetrics.functional.retrieval.r_precision import retrieval_r_precision
from torchmetrics.functional.retrieval.recall import retrieval_recall
from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank
from torchmetrics.functional.retrieval._deprecated import _retrieval_average_precision as retrieval_average_precision
from torchmetrics.functional.retrieval._deprecated import _retrieval_fall_out as retrieval_fall_out
from torchmetrics.functional.retrieval._deprecated import _retrieval_hit_rate as retrieval_hit_rate
from torchmetrics.functional.retrieval._deprecated import _retrieval_normalized_dcg as retrieval_normalized_dcg
from torchmetrics.functional.retrieval._deprecated import _retrieval_precision as retrieval_precision
from torchmetrics.functional.retrieval._deprecated import (
_retrieval_precision_recall_curve as retrieval_precision_recall_curve,
)
from torchmetrics.functional.retrieval._deprecated import _retrieval_r_precision as retrieval_r_precision
from torchmetrics.functional.retrieval._deprecated import _retrieval_recall as retrieval_recall
from torchmetrics.functional.retrieval._deprecated import _retrieval_reciprocal_rank as retrieval_reciprocal_rank
from torchmetrics.functional.text.bleu import bleu_score
from torchmetrics.functional.text.cer import char_error_rate
from torchmetrics.functional.text.chrf import chrf_score
Expand Down
30 changes: 21 additions & 9 deletions src/torchmetrics/functional/retrieval/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision # noqa: F401
from torchmetrics.functional.retrieval.fall_out import retrieval_fall_out # noqa: F401
from torchmetrics.functional.retrieval.hit_rate import retrieval_hit_rate # noqa: F401
from torchmetrics.functional.retrieval.ndcg import retrieval_normalized_dcg # noqa: F401
from torchmetrics.functional.retrieval.precision import retrieval_precision # noqa: F401
from torchmetrics.functional.retrieval.precision_recall_curve import retrieval_precision_recall_curve # noqa: F401
from torchmetrics.functional.retrieval.r_precision import retrieval_r_precision # noqa: F401
from torchmetrics.functional.retrieval.recall import retrieval_recall # noqa: F401
from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank # noqa: F401
from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision
from torchmetrics.functional.retrieval.fall_out import retrieval_fall_out
from torchmetrics.functional.retrieval.hit_rate import retrieval_hit_rate
from torchmetrics.functional.retrieval.ndcg import retrieval_normalized_dcg
from torchmetrics.functional.retrieval.precision import retrieval_precision
from torchmetrics.functional.retrieval.precision_recall_curve import retrieval_precision_recall_curve
from torchmetrics.functional.retrieval.r_precision import retrieval_r_precision
from torchmetrics.functional.retrieval.recall import retrieval_recall
from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank

__all__ = [
"retrieval_average_precision",
"retrieval_fall_out",
"retrieval_hit_rate",
"retrieval_normalized_dcg",
"retrieval_precision",
"retrieval_precision_recall_curve",
"retrieval_r_precision",
"retrieval_recall",
"retrieval_reciprocal_rank",
]
140 changes: 140 additions & 0 deletions src/torchmetrics/functional/retrieval/_deprecated.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
from typing import Optional, Tuple

from torch import Tensor

from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision
from torchmetrics.functional.retrieval.fall_out import retrieval_fall_out
from torchmetrics.functional.retrieval.hit_rate import retrieval_hit_rate
from torchmetrics.functional.retrieval.ndcg import retrieval_normalized_dcg
from torchmetrics.functional.retrieval.precision import retrieval_precision
from torchmetrics.functional.retrieval.precision_recall_curve import retrieval_precision_recall_curve
from torchmetrics.functional.retrieval.r_precision import retrieval_r_precision
from torchmetrics.functional.retrieval.recall import retrieval_recall
from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank
from torchmetrics.utilities.prints import _deprecated_root_import_func


def _retrieval_average_precision(preds: Tensor, target: Tensor, top_k: Optional[int] = None) -> Tensor:
"""Wrapper for deprecated import.

>>> from torch import tensor
>>> preds = tensor([0.2, 0.3, 0.5])
>>> target = tensor([True, False, True])
>>> _retrieval_average_precision(preds, target)
tensor(0.8333)
"""
_deprecated_root_import_func("retrieval_average_precision", "retrieval")
return retrieval_average_precision(preds=preds, target=target, top_k=top_k)


def _retrieval_fall_out(preds: Tensor, target: Tensor, top_k: Optional[int] = None) -> Tensor:
"""Wrapper for deprecated import.

>>> from torch import tensor
>>> preds = tensor([0.2, 0.3, 0.5])
>>> target = tensor([True, False, True])
>>> _retrieval_fall_out(preds, target, top_k=2)
tensor(1.)
"""
_deprecated_root_import_func("retrieval_fall_out", "retrieval")
return retrieval_fall_out(preds=preds, target=target, top_k=top_k)


def _retrieval_hit_rate(preds: Tensor, target: Tensor, top_k: Optional[int] = None) -> Tensor:
"""Wrapper for deprecated import.

>>> from torch import tensor
>>> preds = tensor([0.2, 0.3, 0.5])
>>> target = tensor([True, False, True])
>>> _retrieval_hit_rate(preds, target, top_k=2)
tensor(1.)
"""
_deprecated_root_import_func("retrieval_hit_rate", "retrieval")
return retrieval_hit_rate(preds=preds, target=target, top_k=top_k)


def _retrieval_normalized_dcg(preds: Tensor, target: Tensor, top_k: Optional[int] = None) -> Tensor:
"""Wrapper for deprecated import.

>>> from torch import tensor
>>> preds = tensor([.1, .2, .3, 4, 70])
>>> target = tensor([10, 0, 0, 1, 5])
>>> _retrieval_normalized_dcg(preds, target)
tensor(0.6957)
"""
_deprecated_root_import_func("retrieval_normalized_dcg", "retrieval")
return retrieval_normalized_dcg(preds=preds, target=target, top_k=top_k)


def _retrieval_precision(
preds: Tensor, target: Tensor, top_k: Optional[int] = None, adaptive_k: bool = False
) -> Tensor:
"""Wrapper for deprecated import.

>>> from torch import tensor
>>> preds = tensor([0.2, 0.3, 0.5])
>>> target = tensor([True, False, True])
>>> _retrieval_precision(preds, target, top_k=2)
tensor(0.5000)
"""
_deprecated_root_import_func("retrieval_precision", "retrieval")
return retrieval_precision(preds=preds, target=target, top_k=top_k, adaptive_k=adaptive_k)


def _retrieval_precision_recall_curve(
preds: Tensor, target: Tensor, max_k: Optional[int] = None, adaptive_k: bool = False
) -> Tuple[Tensor, Tensor, Tensor]:
"""Wrapper for deprecated import.

>>> from torch import tensor
>>> preds = tensor([0.2, 0.3, 0.5])
>>> target = tensor([True, False, True])
>>> precisions, recalls, top_k = _retrieval_precision_recall_curve(preds, target, max_k=2)
>>> precisions
tensor([1.0000, 0.5000])
>>> recalls
tensor([0.5000, 0.5000])
>>> top_k
tensor([1, 2])
"""
_deprecated_root_import_func("retrieval_precision_recall_curve", "retrieval")
return retrieval_precision_recall_curve(preds=preds, target=target, max_k=max_k, adaptive_k=adaptive_k)


def _retrieval_r_precision(preds: Tensor, target: Tensor) -> Tensor:
"""Wrapper for deprecated import.

>>> from torch import tensor
>>> preds = tensor([0.2, 0.3, 0.5])
>>> target = tensor([True, False, True])
>>> _retrieval_r_precision(preds, target)
tensor(0.5000)
"""
_deprecated_root_import_func("retrieval_r_precision", "retrieval")
return retrieval_r_precision(preds=preds, target=target)


def _retrieval_recall(preds: Tensor, target: Tensor, top_k: Optional[int] = None) -> Tensor:
"""Wrapper for deprecated import.

>>> from torch import tensor
>>> preds = tensor([0.2, 0.3, 0.5])
>>> target = tensor([True, False, True])
>>> _retrieval_recall(preds, target, top_k=2)
tensor(0.5000)
"""
_deprecated_root_import_func("retrieval_recall", "retrieval")
return retrieval_recall(preds=preds, target=target, top_k=top_k)


def _retrieval_reciprocal_rank(preds: Tensor, target: Tensor) -> Tensor:
"""Wrapper for deprecated import.

>>> from torch import tensor
>>> preds = tensor([0.2, 0.3, 0.5])
>>> target = tensor([False, True, False])
>>> _retrieval_reciprocal_rank(preds, target)
tensor(0.5000)
"""
_deprecated_root_import_func("retrieval_reciprocal_rank", "retrieval")
return retrieval_reciprocal_rank(preds=preds, target=target)
34 changes: 22 additions & 12 deletions src/torchmetrics/retrieval/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.retrieval.average_precision import RetrievalMAP # noqa: F401
from torchmetrics.retrieval.average_precision import RetrievalMAP
from torchmetrics.retrieval.base import RetrievalMetric # noqa: F401
from torchmetrics.retrieval.fall_out import RetrievalFallOut # noqa: F401
from torchmetrics.retrieval.hit_rate import RetrievalHitRate # noqa: F401
from torchmetrics.retrieval.ndcg import RetrievalNormalizedDCG # noqa: F401
from torchmetrics.retrieval.precision import RetrievalPrecision # noqa: F401
from torchmetrics.retrieval.precision_recall_curve import ( # noqa: F401
RetrievalPrecisionRecallCurve,
RetrievalRecallAtFixedPrecision,
)
from torchmetrics.retrieval.r_precision import RetrievalRPrecision # noqa: F401
from torchmetrics.retrieval.recall import RetrievalRecall # noqa: F401
from torchmetrics.retrieval.reciprocal_rank import RetrievalMRR # noqa: F401
from torchmetrics.retrieval.fall_out import RetrievalFallOut
from torchmetrics.retrieval.hit_rate import RetrievalHitRate
from torchmetrics.retrieval.ndcg import RetrievalNormalizedDCG
from torchmetrics.retrieval.precision import RetrievalPrecision
from torchmetrics.retrieval.precision_recall_curve import RetrievalPrecisionRecallCurve, RetrievalRecallAtFixedPrecision
from torchmetrics.retrieval.r_precision import RetrievalRPrecision
from torchmetrics.retrieval.recall import RetrievalRecall
from torchmetrics.retrieval.reciprocal_rank import RetrievalMRR

__all__ = [
"RetrievalFallOut",
"RetrievalHitRate",
"RetrievalMAP",
"RetrievalMRR",
"RetrievalNormalizedDCG",
"RetrievalPrecision",
"RetrievalPrecisionRecallCurve",
"RetrievalRecall",
"RetrievalRecallAtFixedPrecision",
"RetrievalRPrecision",
]
Loading