Skip to content

Commit

Permalink
SeaseLtd#109: Updates to NDCG@K implementation:
Browse files Browse the repository at this point in the history
- Make maxgrade and fairgrade configurable at construction time;
- Use floating point grade values in gain function.
  • Loading branch information
Matt Pearce committed Mar 16, 2020
1 parent 8fcd697 commit 9ec19ad
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.StreamSupport;

Expand All @@ -42,15 +43,43 @@
*/
public class NDCGAtK extends Metric {
private final static BigDecimal TWO = new BigDecimal(2);
final int k;

private final BigDecimal fairgrade;
private final BigDecimal maxgrade;
private final int k;

/**
* Builds a new NDCGAtK metric with default maximum and missing judgement
* grades.
* <p>
* Delegates to {@link #NDCGAtK(int, Float, Float)} with both grades left
* {@code null}, which makes that constructor use
* {@link Metric#DEFAULT_MAX_GRADE} and {@link Metric#DEFAULT_MISSING_GRADE}.
*
* @param k the top k reference elements used for building the measure.
*/
public NDCGAtK(final int k) {
// null for both grades selects the defaults in the three-argument constructor.
this(k, null, null);
}

/**
* Builds a new NDCGAtK metric.
* <p>
* The name passed to the superclass is {@code "NDCG@" + k}.
*
* @param k the top k reference elements used for building the measure.
* @param maxgrade the maximum grade available when judging documents. If
* {@code null}, will default to {@link Metric#DEFAULT_MAX_GRADE}.
* @param defaultgrade the grade used for a judgment that provides no gain or
* rating value of its own. If {@code null}, will default to
* {@code maxgrade / 2} when {@code maxgrade} has been
* specified, or to {@link Metric#DEFAULT_MISSING_GRADE}
* otherwise.
*/
public NDCGAtK(@JsonProperty("k") final int k) {
public NDCGAtK(@JsonProperty("k") final int k,
@JsonProperty("maxgrade") final Float maxgrade,
@JsonProperty("defaultgrade") final Float defaultgrade) {
super("NDCG@" + k);
// No explicit maximum: use the default maximum, and the default missing grade
// unless the caller supplied a defaultgrade.
if (maxgrade == null) {
this.maxgrade = DEFAULT_MAX_GRADE;
this.fairgrade = Optional.ofNullable(defaultgrade).map(BigDecimal::valueOf).orElse(DEFAULT_MISSING_GRADE);
} else {
// Explicit maximum: a missing defaultgrade becomes half of the maximum,
// computed with a scale of 8 and HALF_UP rounding.
this.maxgrade = BigDecimal.valueOf(maxgrade);
this.fairgrade = Optional.ofNullable(defaultgrade).map(BigDecimal::valueOf).orElseGet(() -> this.maxgrade.divide(TWO, 8, RoundingMode.HALF_UP));
}
this.k = k;
}

Expand All @@ -69,8 +98,8 @@ public void collect(final Map<String, Object> hit, final int rank, final String
if (rank > k) return;
judgment(id(hit))
.ifPresent(judgment -> {
final BigDecimal value = gainOrRatingNode(judgment).map(JsonNode::decimalValue).orElse(TWO);
BigDecimal numerator = TWO.pow(value.intValue()).subtract(BigDecimal.ONE);
final BigDecimal value = gainOrRatingNode(judgment).map(JsonNode::decimalValue).orElse(fairgrade);
BigDecimal numerator = BigDecimal.valueOf(Math.pow(TWO.doubleValue(), value.doubleValue())).subtract(BigDecimal.ONE);
if (rank == 1) {
dcg = numerator;
} else {
Expand Down Expand Up @@ -98,28 +127,28 @@ public BigDecimal value() {

private BigDecimal idealDcg(final JsonNode relevantDocuments) {
final int windowSize = Math.min(relevantDocuments.size(), k);
final int[] gains = new int[windowSize];
final double[] gains = new double[windowSize];

final Map<Integer, List<JsonNode>> groups =
final Map<BigDecimal, List<JsonNode>> groups =
StreamSupport.stream(relevantDocuments.spliterator(), false)
.collect(groupingBy(doc -> gainOrRatingNode(doc).map(JsonNode::intValue).orElse(2)));
.collect(groupingBy(doc -> gainOrRatingNode(doc).map(JsonNode::decimalValue).orElse(fairgrade)));

Set<Integer> ratingValues = groups.keySet();
List<Integer> ratingsSorted = new ArrayList<>(ratingValues);
Set<BigDecimal> ratingValues = groups.keySet();
List<BigDecimal> ratingsSorted = new ArrayList<>(ratingValues);
ratingsSorted.sort(Collections.reverseOrder());
int startIndex = 0;
for (Integer ratingValue : ratingsSorted) {
for (BigDecimal ratingValue : ratingsSorted) {
if (startIndex < windowSize) {
List<JsonNode> docsPerRating = groups.get(ratingValue);
int endIndex = startIndex + docsPerRating.size();
Arrays.fill(gains, startIndex, Math.min(windowSize, endIndex), ratingValue);
Arrays.fill(gains, startIndex, Math.min(windowSize, endIndex), ratingValue.doubleValue());
startIndex = endIndex;
}
}

BigDecimal result = BigDecimal.ZERO;
for (int i = 1; i <= gains.length; i++) {
BigDecimal num = TWO.pow(gains[i-1]).subtract(BigDecimal.ONE);
BigDecimal num = BigDecimal.valueOf(Math.pow(TWO.doubleValue(), gains[i-1])).subtract(BigDecimal.ONE);
double den = Math.log(i + 1) / Math.log(2);
result = result.add((num.divide(new BigDecimal(den), 2, RoundingMode.FLOOR)));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,4 +256,44 @@ public void _10_judgments_15_search_only_results_10th_relevant_result() {
cut.valueFactory(A_VERSION).value().doubleValue(),
0);
}

/**
 * Scenario: 10 judgments, 15 search results, and only the 10th collected result
 * is one of the judged documents.
 * <p>
 * The first ten entries of {@code FIFTEEN_SEARCH_HITS} are judged, each with a
 * node built by {@code createJudgmentNode(4)}. The nine hits collected first come
 * from {@code ANOTHER_FIVE_SEARCH_HITS} and {@code ANOTHER_FOUR_SEARCH_HITS}; the
 * expected value is 0 for every k below 10, so none of them is expected to
 * contribute any gain. The 10th collected hit is {@code FIFTEEN_SEARCH_HITS[9]}.
 */
@Test
public void _10_judgments_15_search_only_results_10th_relevant_result_topscore() {
    final ObjectNode judgements = mapper.createObjectNode();
    stream(FIFTEEN_SEARCH_HITS).limit(10).forEach(docid -> judgements.set(docid, createJudgmentNode(4)));
    cut.setRelevantDocuments(judgements);

    cut.setTotalHits(FIFTEEN_SEARCH_HITS.length, A_VERSION);
    stream(ANOTHER_FIVE_SEARCH_HITS)
            .map(this::searchHit)
            .forEach(hit -> cut.collect(hit, counter.incrementAndGet(), A_VERSION));

    stream(ANOTHER_FOUR_SEARCH_HITS)
            .map(this::searchHit)
            .forEach(hit -> cut.collect(hit, counter.incrementAndGet(), A_VERSION));

    // The only judged document among the collected hits, placed at rank 10.
    cut.collect(searchHit(FIFTEEN_SEARCH_HITS[9]), counter.incrementAndGet(), A_VERSION);

    // Expected metric value keyed by the k under test. A plain HashMap is used
    // instead of double-brace initialization, which would create an anonymous
    // subclass holding a reference to this test instance.
    final Map<Integer, Double> expectations = new HashMap<>();
    for (int cutoff = 1; cutoff <= 9; cutoff++) {
        expectations.put(cutoff, 0.0);
    }
    expectations.put(10, 0.06);

    assertEquals(
            expectations.get(currentAppliedK),
            cut.valueFactory(A_VERSION).value().doubleValue(),
            0);
}
}

0 comments on commit 9ec19ad

Please sign in to comment.