Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Expected Reciprocal Rank metric #31891

Merged
merged 5 commits into from
Jul 12, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ public <T> void declareNamedObjects(BiConsumer<Value, List<T>> consumer, NamedOb
}
}

@Override
public String getName() {
return objectParser.getName();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -326,9 +326,9 @@ public boolean equals(Object obj) {
return false;
}
DiscountedCumulativeGain.Detail other = (DiscountedCumulativeGain.Detail) obj;
return (this.dcg == other.dcg &&
this.idcg == other.idcg &&
this.unratedDocs == other.unratedDocs);
return Double.compare(this.dcg, other.dcg) == 0 &&
Double.compare(this.idcg, other.idcg) == 0 &&
this.unratedDocs == other.unratedDocs;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,295 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.elasticsearch.index.rankeval;

import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.SearchHit;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;

import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.index.rankeval.EvaluationMetric.joinHitsWithRatings;

/**
* Implementation of the Expected Reciprocal Rank metric described in:<p>
*
* Chapelle, O., Metlzer, D., Zhang, Y., &amp; Grinspan, P. (2009).<br>
* Expected reciprocal rank for graded relevance.<br>
* Proceeding of the 18th ACM Conference on Information and Knowledge Management - CIKM ’09, 621.<br>
* https://doi.org/10.1145/1645953.1646033
*/
public class ExpectedReciprocalRank implements EvaluationMetric {

/** the default search window size */
private static final int DEFAULT_K = 10;

/** the search window size */
private final int k;

/**
* Optional. If set, this will be the rating for docs that are unrated in the ranking evaluation request
*/
private final Integer unknownDocRating;

private final int maxRelevance;

private final double two_pow_maxRelevance;

public static final String NAME = "expected_reciprocal_rank";

public ExpectedReciprocalRank(int maxRelevance) {
this(maxRelevance, null, DEFAULT_K);
}

/**
* @param maxRelevance
* the maximal relevance judgment in the evaluation dataset
* @param unknownDocRating
* the rating for documents the user hasn't supplied an explicit
* rating for. Can be {@code null}, in which case document is
* skipped.
* @param k
* the search window size all request use.
*/
public ExpectedReciprocalRank(int maxRelevance, @Nullable Integer unknownDocRating, int k) {
this.maxRelevance = maxRelevance;
this.unknownDocRating = unknownDocRating;
this.k = k;
// we can pre-calculate the constant used in metric calculation
this.two_pow_maxRelevance = Math.pow(2, this.maxRelevance);
}

ExpectedReciprocalRank(StreamInput in) throws IOException {
this.maxRelevance = in.readVInt();
this.unknownDocRating = in.readOptionalVInt();
this.k = in.readVInt();
this.two_pow_maxRelevance = Math.pow(2, this.maxRelevance);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeVInt(maxRelevance);
out.writeOptionalVInt(unknownDocRating);
out.writeVInt(k);
}

@Override
public String getWriteableName() {
return NAME;
}

int getK() {
return this.k;
}

int getMaxRelevance() {
return this.maxRelevance;
}

/**
* get the rating used for unrated documents
*/
public Integer getUnknownDocRating() {
return this.unknownDocRating;
}


@Override
public Optional<Integer> forcedSearchSize() {
return Optional.of(k);
}

@Override
public EvalQueryQuality evaluate(String taskId, SearchHit[] hits, List<RatedDocument> ratedDocs) {
List<RatedSearchHit> ratedHits = joinHitsWithRatings(hits, ratedDocs);
if (ratedHits.size() > this.k) {
ratedHits = ratedHits.subList(0, k);
}
List<Integer> ratingsInSearchHits = new ArrayList<>(ratedHits.size());
int unratedResults = 0;
for (RatedSearchHit hit : ratedHits) {
// unknownDocRating might be null, in which case unrated will be ignored in the calculation.
// we still need to add them as a placeholder so the rank of the subsequent ratings is correct
ratingsInSearchHits.add(hit.getRating().orElse(unknownDocRating));
if (hit.getRating().isPresent() == false) {
unratedResults++;
}
}

double p = 1;
double err = 0;
int rank = 1;
for (Integer rating : ratingsInSearchHits) {
if (rating != null) {
double probR = probabilityOfRelevance(rating);
err = err + (p * probR / rank);
p = p * (1 - probR);
}
rank++;
}

EvalQueryQuality evalQueryQuality = new EvalQueryQuality(taskId, err);
evalQueryQuality.addHitsAndRatings(ratedHits);
evalQueryQuality.setMetricDetails(new Detail(unratedResults));
return evalQueryQuality;
}

double probabilityOfRelevance(Integer rating) {
return (Math.pow(2, rating) - 1) / this.two_pow_maxRelevance;
}

private static final ParseField K_FIELD = new ParseField("k");
private static final ParseField UNKNOWN_DOC_RATING_FIELD = new ParseField("unknown_doc_rating");
private static final ParseField MAX_RELEVANCE_FIELD = new ParseField("maximum_relevance");
private static final ConstructingObjectParser<ExpectedReciprocalRank, Void> PARSER = new ConstructingObjectParser<>("dcg", false,
args -> {
int maxRelevance = (Integer) args[0];
Integer optK = (Integer) args[2];
return new ExpectedReciprocalRank(maxRelevance, (Integer) args[1],
optK == null ? DEFAULT_K : optK);
});


static {
PARSER.declareInt(constructorArg(), MAX_RELEVANCE_FIELD);
PARSER.declareInt(optionalConstructorArg(), UNKNOWN_DOC_RATING_FIELD);
PARSER.declareInt(optionalConstructorArg(), K_FIELD);
}

public static ExpectedReciprocalRank fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.startObject(NAME);
builder.field(MAX_RELEVANCE_FIELD.getPreferredName(), this.maxRelevance);
if (unknownDocRating != null) {
builder.field(UNKNOWN_DOC_RATING_FIELD.getPreferredName(), this.unknownDocRating);
}
builder.field(K_FIELD.getPreferredName(), this.k);
builder.endObject();
builder.endObject();
return builder;
}

@Override
public final boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}
ExpectedReciprocalRank other = (ExpectedReciprocalRank) obj;
return this.k == other.k &&
this.maxRelevance == other.maxRelevance
&& Objects.equals(unknownDocRating, other.unknownDocRating);
}

@Override
public final int hashCode() {
return Objects.hash(unknownDocRating, k, maxRelevance);
}

public static final class Detail implements MetricDetail {

private static ParseField UNRATED_FIELD = new ParseField("unrated_docs");
private final int unratedDocs;

Detail(int unratedDocs) {
this.unratedDocs = unratedDocs;
}

Detail(StreamInput in) throws IOException {
this.unratedDocs = in.readVInt();
}

@Override
public
String getMetricName() {
return NAME;
}

@Override
public XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException {
return builder.field(UNRATED_FIELD.getPreferredName(), this.unratedDocs);
}

private static final ConstructingObjectParser<Detail, Void> PARSER = new ConstructingObjectParser<>(NAME, true, args -> {
return new Detail((Integer) args[0]);
});

static {
PARSER.declareInt(constructorArg(), UNRATED_FIELD);
}

public static Detail fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeVInt(this.unratedDocs);
}

@Override
public String getWriteableName() {
return NAME;
}

/**
* @return the number of unrated documents in the search results
*/
public Object getUnratedDocs() {
return this.unratedDocs;
}

@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}
ExpectedReciprocalRank.Detail other = (ExpectedReciprocalRank.Detail) obj;
return this.unratedDocs == other.unratedDocs;
}

@Override
public int hashCode() {
return Objects.hash(this.unratedDocs);
}
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
/**
* Assuming the docs are ranked in the following order:
*
* rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / log_2(rank + 1)
* rank | relevance | 2^(relevance) - 1 | log_2(rank + 1) | (2^(relevance) - 1) / log_2(rank + 1)
* -------------------------------------------------------------------------------------------
* 1 | 3 | 7.0 | 1.0 | 7.0 | 7.0 | 
* 2 | 2 | 3.0 | 1.5849625007211563 | 1.8927892607143721
Expand All @@ -82,7 +82,7 @@ public void testDCGAt() {
* Check with normalization: to get the maximal possible dcg, sort documents by
* relevance in descending order
*
* rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / log_2(rank + 1)
* rank | relevance | 2^(relevance) - 1 | log_2(rank + 1) | (2^(relevance) - 1) / log_2(rank + 1)
* ---------------------------------------------------------------------------------------
* 1 | 3 | 7.0 | 1.0  | 7.0
* 2 | 3 | 7.0 | 1.5849625007211563 | 4.416508275000202
Expand All @@ -101,7 +101,7 @@ public void testDCGAt() {
* This tests metric when some documents in the search result don't have a
* rating provided by the user.
*
* rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / log_2(rank + 1)
* rank | relevance | 2^(relevance) - 1 | log_2(rank + 1) | (2^(relevance) - 1) / log_2(rank + 1)
* -------------------------------------------------------------------------------------------
* 1 | 3 | 7.0 | 1.0 | 7.0 2 | 
* 2 | 3.0 | 1.5849625007211563 | 1.8927892607143721
Expand Down Expand Up @@ -134,7 +134,7 @@ public void testDCGAtSixMissingRatings() {
* Check with normalization: to get the maximal possible dcg, sort documents by
* relevance in descending order
*
* rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / log_2(rank + 1)
* rank | relevance | 2^(relevance) - 1 | log_2(rank + 1) | (2^(relevance) - 1) / log_2(rank + 1)
* ----------------------------------------------------------------------------------------
* 1 | 3 | 7.0 | 1.0  | 7.0
* 2 | 3 | 7.0 | 1.5849625007211563 | 4.416508275000202
Expand All @@ -154,7 +154,7 @@ public void testDCGAtSixMissingRatings() {
* documents than search hits because we restrict DCG to be calculated at the
* fourth position
*
* rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / log_2(rank + 1)
* rank | relevance | 2^(relevance) - 1 | log_2(rank + 1) | (2^(relevance) - 1) / log_2(rank + 1)
* -------------------------------------------------------------------------------------------
* 1 | 3 | 7.0 | 1.0 | 7.0 2 | 
* 2 | 3.0 | 1.5849625007211563 | 1.8927892607143721
Expand Down Expand Up @@ -191,7 +191,7 @@ public void testDCGAtFourMoreRatings() {
* Check with normalization: to get the maximal possible dcg, sort documents by
* relevance in descending order
*
* rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / log_2(rank + 1)
* rank | relevance | 2^(relevance) - 1 | log_2(rank + 1) | (2^(relevance) - 1) / log_2(rank + 1)
* ---------------------------------------------------------------------------------------
* 1 | 3 | 7.0 | 1.0  | 7.0
* 2 | 3 | 7.0 | 1.5849625007211563 | 4.416508275000202
Expand Down
Loading