TFSimilarity.retrieval_metrics.PrecisionAtK

Computes Precision@K, the fraction of a query's K nearest neighbors that are relevant, as defined below.

Inherits From: RetrievalMetric, ABC

```python
TFSimilarity.retrieval_metrics.PrecisionAtK(
    name: str = "precision", k: int = 5, **kwargs
) -> None
```

$$ P_i@K = \frac{TP_i}{TP_i + FP_i} = \frac{\sum_{j=1}^{K} rel_{i,j}}{K} $$

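Where:

  • $K$ is the number of neighbors in the result set of the $i$-th query.

  • $TP_i$ and $FP_i$ are the numbers of true and false positives among those $K$ neighbors. Every retrieved neighbor is one or the other, so $TP_i + FP_i = K$.

  • $rel_{i,j}$ is the relevance mask (indicator function) for the $i$-th query. It is 1 if the $j$-th ranked result is relevant and 0 otherwise.

  • $i$ indexes the queries and $j$ indexes the ranked results of a query.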

P@K is order-insensitive. It does not take into account where the TP results are ranked within the top K.
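The per-query formula can be sketched in NumPy. This assumes the relevance mask is a 2D array with one row per query and neighbors in rank order. `precision_at_k_per_query` is a hypothetical helper for illustration, not a TF Similarity function. The two example queries have their hits at different ranks but receive the same score, which shows the rank-insensitivity noted above.

```python
import numpy as np


def precision_at_k_per_query(rel: np.ndarray, k: int) -> np.ndarray:
    """Per-query P_i@K from a 2D relevance mask (illustrative sketch).

    rel[i, j] is 1 (or True) when the j-th ranked neighbor of query i is
    relevant. Only the first k columns are considered.
    """
    top_k = np.asarray(rel, dtype=np.float64)[:, :k]
    # TP_i is the number of relevant neighbors in the top k. Every retrieved
    # neighbor is either a TP or an FP, so the denominator TP_i + FP_i is k.
    return top_k.sum(axis=1) / k


# Two queries with the same number of hits at different ranks.
rel = np.array([
    [1, 1, 1, 0, 0],  # hits at ranks 1-3
    [0, 0, 1, 1, 1],  # hits at ranks 3-5
])
print(precision_at_k_per_query(rel, k=5))  # [0.6 0.6]
```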

This metric is useful when we are interested in evaluating the embedding within the context of a kNN classifier or as part of a clustering method.

Args

| Argument | Description |
|---|---|
| `name` | Name associated with the metric object, e.g., `precision@5`. |
| `canonical_name` | The canonical name associated with the metric, e.g., `precision@K`. |
| `k` | The number of nearest neighbors over which the metric is computed. |
| `distance_threshold` | The max distance below which a nearest neighbor is considered a valid match. |
| `average` | `'micro'` or `'macro'`. Determines the type of averaging performed on the data. `'micro'` calculates metrics globally over all data. `'macro'` calculates metrics for each label and takes the unweighted mean. |
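The two averaging modes can be contrasted in a small NumPy sketch. It assumes that `'macro'` groups queries by label, averages their P@K within each group, and then takes the unweighted mean over groups. `average_precision_at_k` is a hypothetical helper used only for this illustration.

```python
import numpy as np


def average_precision_at_k(per_query: np.ndarray, query_labels: np.ndarray,
                           average: str = "micro") -> float:
    """Aggregate per-query P@K into one score (illustrative sketch)."""
    per_query = np.asarray(per_query, dtype=np.float64)
    query_labels = np.asarray(query_labels)
    if average == "micro":
        # Every query contributes equally.
        return float(per_query.mean())
    if average == "macro":
        # Average within each label first, then take the unweighted mean
        # over labels, so every label contributes equally.
        labels = np.unique(query_labels)
        per_label = [per_query[query_labels == lbl].mean() for lbl in labels]
        return float(np.mean(per_label))
    raise ValueError(f"unsupported average: {average!r}")


per_query = np.array([1.0, 1.0, 1.0, 0.25])  # P@K for four queries
labels = np.array([0, 0, 0, 1])              # label 1 has a single query
print(average_precision_at_k(per_query, labels, "micro"))  # 0.8125
print(average_precision_at_k(per_query, labels, "macro"))  # 0.625
```

Here one label is served poorly. Micro averaging lets the three well-served queries of label 0 dominate. Macro averaging weights both labels equally, so the weak label pulls the score down further.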

Attributes

name

Methods

compute


```python
compute(
    *,
    query_labels: TFSimilarity.callbacks.IntTensor,
    match_mask: TFSimilarity.utils.BoolTensor,
    **kwargs
) -> TFSimilarity.callbacks.FloatTensor
```

Computes the metric.

Args

| Argument | Description |
|---|---|
| `query_labels` | A 1D array of the labels associated with the embedding queries. |
| `match_mask` | A 2D mask where a 1 at position (i, j) indicates a match between the i-th query and its j-th ranked neighbor, and a 0 indicates a mismatch. |
| `**kwargs` | Additional compute args. |

Returns

A rank 0 tensor containing the metric.
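A `match_mask` of the shape described above can be built from the labels of the retrieved neighbors. This sketch assumes a match means the neighbor shares the query's label. `build_match_mask` and `neighbor_labels` are hypothetical names for this NumPy illustration, not the way TF Similarity itself produces the mask.

```python
import numpy as np


def build_match_mask(query_labels: np.ndarray,
                     neighbor_labels: np.ndarray) -> np.ndarray:
    """Return a bool mask of shape (num_queries, num_neighbors).

    Entry [i, j] is True when the j-th ranked neighbor of query i has the
    same label as the query (label equality is the assumed match rule).
    """
    query_labels = np.asarray(query_labels)
    neighbor_labels = np.asarray(neighbor_labels)
    # Broadcast each query label against its row of neighbor labels.
    return neighbor_labels == query_labels[:, np.newaxis]


query_labels = np.array([0, 1])
neighbor_labels = np.array([
    [0, 0, 2, 0, 1],  # neighbors of query 0, nearest first
    [1, 0, 0, 1, 1],  # neighbors of query 1, nearest first
])
match_mask = build_match_mask(query_labels, neighbor_labels)

k = 5
# Micro-averaged P@K: mean over queries of the fraction of hits in the top k.
p_at_k = match_mask[:, :k].mean(axis=1).mean()
print(p_at_k)
```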

get_config


```python
get_config()
```