Skip to content

Commit

Permalink
[ issue #38 ] F-Measure (F1, F0.5 and F2)
Browse files Browse the repository at this point in the history
  • Loading branch information
agazzarini committed Sep 16, 2018
1 parent dd6b1f6 commit cc10d29
Show file tree
Hide file tree
Showing 19 changed files with 550 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,16 @@ public abstract class Metric implements HitsCollector {
protected String idFieldName = DEFAULT_ID_FIELD_NAME;
protected JsonNode relevantDocuments;
protected Map<String, ValueFactory> values = new LinkedHashMap<>();
protected List<String> versions;

/**
* Sets into this metrics the different versions available in the current evaluation process.
*
* @param versions the different versions available in the current evaluation process.
*/
public void setVersions(final List<String> versions) {
versions.forEach(version -> values.put(version, valueFactory()));
this.versions = versions;
versions.forEach(version -> values.put(version, createValueFactory(version)));
}

/**
Expand Down Expand Up @@ -126,7 +128,7 @@ public String getName() {
*
* @return the factory which will be used for actually computing metric value(s).
*/
public abstract ValueFactory valueFactory();
public abstract ValueFactory createValueFactory(final String version);

/**
* Returns a map of the available versions with the corresponding value factory.
Expand All @@ -144,6 +146,6 @@ public Map<String, ValueFactory> getVersions() {
* @return the {@link ValueFactory} instance associated with a given version.
*/
public ValueFactory valueFactory(final String version) {
return values.getOrDefault(version, valueFactory());
return values.get(version);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
* @since 1.0
*/
public abstract class ValueFactory implements HitsCollector {
private String version;
private final String version;
private final Metric owner;
protected long totalHits;

Expand All @@ -27,8 +27,9 @@ public abstract class ValueFactory implements HitsCollector {
*
* @param owner the owner metric.
*/
protected ValueFactory(final Metric owner) {
protected ValueFactory(final Metric owner, final String version){
this.owner = owner;
this.version = version;
}

public Metric owner() {
Expand All @@ -38,7 +39,6 @@ public Metric owner() {
@Override
public void setTotalHits(final long totalHits, final String version) {
this.totalHits = totalHits;
this.version = version;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,15 @@

import static io.sease.rre.Calculator.*;

/**
* Precision and recall are single-value metrics based on the whole list of documents returned by the system.
* For systems that return a ranked sequence of documents, it is desirable to also consider the order in which the
* returned documents are presented. By computing a precision and recall at every position in the ranked sequence of
* documents, one can plot a precision-recall curve, plotting precision.
*
* @author agazzarini
* @since 1.0
*/
public class AveragePrecision extends Metric {
/**
* Builds a new {@link AveragePrecision} metric.
Expand All @@ -17,9 +26,8 @@ public AveragePrecision() {
}

@Override
public ValueFactory valueFactory() {

return new ValueFactory(this) {
public ValueFactory createValueFactory(final String version) {
return new ValueFactory(this, version) {
private BigDecimal relevantItemsFound = BigDecimal.ZERO;

private BigDecimal howManyRelevantDocuments;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ class MutableValueFactory extends ValueFactory {
*
* @param owner the owner metric.
*/
private MutableValueFactory(final Metric owner) {
super(owner);
private MutableValueFactory(final Metric owner, final String version) {
super(owner, version);
}

@Override
Expand Down Expand Up @@ -75,12 +75,12 @@ public AveragedMetric(final String name) {
*/
public void collect(final String version, final BigDecimal additionalValue) {
((MutableValueFactory)
values.computeIfAbsent(version, v -> valueFactory()))
values.computeIfAbsent(version, this::createValueFactory))
.collect(additionalValue);
}

@Override
public ValueFactory valueFactory() {
return new MutableValueFactory(this);
public ValueFactory createValueFactory(final String version) {
return new MutableValueFactory(this, version);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package io.sease.rre.core.domain.metrics.impl;

/**
* A F-measure which weighs recall lower than precision (by attenuating the influence of false negatives).
*
* @author agazzarini
* @since 1.0
*/
public class F0_5 extends FMeasure {
/**
* Builds a new F1 metric instance.
*/
public F0_5() {
super("F0.5", 0.5f);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package io.sease.rre.core.domain.metrics.impl;

/**
* The most popular F-Measure, which balances between precisin and recall using 1 as beta factor.
*
* @author agazzarini
* @since 1.0
*/
public class F1 extends FMeasure {
/**
* Builds a new F1 metric instance.
*/
public F1() {
super("F1", 1);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package io.sease.rre.core.domain.metrics.impl;

/**
* A F-measure which weighs recall higher than precision (by placing more emphasis on false negatives).
*
* @author agazzarini
* @since 1.0
*/
public class F2 extends FMeasure {
/**
* Builds a new F1 metric instance.
*/
public F2() {
super("F2", 2);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package io.sease.rre.core.domain.metrics.impl;

import com.fasterxml.jackson.databind.JsonNode;
import io.sease.rre.Calculator;
import io.sease.rre.core.domain.metrics.Metric;
import io.sease.rre.core.domain.metrics.ValueFactory;

import java.math.BigDecimal;
import java.util.List;
import java.util.Map;

/**
* The F-measure measures the effectiveness of retrieval with respect to a user who attaches β times as much importance to recall as precision.
* In statistical analysis of binary classification, the F1 score (also F-score or F-measure) is a measure of a test's accuracy.
* It considers both the precision p and the recall r of the test to compute the score: p is the number of correct positive results
* divided by the number of all positive results returned by the classifier, and r is the number of correct positive results
* divided by the number of all relevant samples (all samples that should have been identified as positive).
*
* The F1 score is the harmonic average of the precision and recall, where an F1 score reaches its best value at 1
* (perfect precision and recall) and worst at 0.
*
* (Wikipedia)
*
* @author agazzarini
* @since 1.1
*/
public abstract class FMeasure extends Metric {

private final BigDecimal beta;
final Metric precision = new Precision();
final Metric recall = new Recall();

/**
* Builds a new F-Measure/F-Score metric with the given beta factor.
*
* @param beta the balance factor between precision and recall.
*/
public FMeasure(final String name, final float beta) {
super(name);
this.beta = BigDecimal.valueOf(beta).pow(2);
}

@Override
public ValueFactory createValueFactory(final String version) {
return new ValueFactory(this, version) {
@Override
public BigDecimal value() {
final BigDecimal betaPlusOne = Calculator.sum(BigDecimal.ONE, beta);

final BigDecimal p = precision.valueFactory(version).value();
final BigDecimal r = recall.valueFactory(version).value();

if (p.doubleValue() == 0 || r.doubleValue() == 0) return BigDecimal.ZERO;

final BigDecimal precisionTimesBeta = Calculator.multiply(p, beta);

final BigDecimal dividend = Calculator.multiply(p,r);
final BigDecimal divisor = Calculator.sum(precisionTimesBeta,r);

return Calculator.multiply(betaPlusOne, Calculator.divide(dividend, divisor));
}

@Override
public void setTotalHits(long totalHits, String version) {
precision.setTotalHits(totalHits, version);
recall.setTotalHits(totalHits, version);
}

@Override
public void collect(final Map<String, Object> hit, final int rank, final String version) {
precision.collect(hit, rank, version);
recall.collect(hit, rank, version);
}
};
}

@Override
public void setTotalHits(long totalHits, String version) {
super.setTotalHits(totalHits, version);
}

@Override
public void setRelevantDocuments(JsonNode relevantDocuments) {
super.setRelevantDocuments(relevantDocuments);
precision.setRelevantDocuments(relevantDocuments);
recall.setRelevantDocuments(relevantDocuments);
}

@Override
public void setVersions(List<String> versions) {
super.setVersions(versions);
precision.setVersions(versions);
recall.setVersions(versions);
}

@Override
public void setIdFieldName(String idFieldName) {
super.setIdFieldName(idFieldName);
precision.setIdFieldName(idFieldName);
recall.setIdFieldName(idFieldName);
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package io.sease.rre.core.domain.metrics.impl;

import com.fasterxml.jackson.databind.JsonNode;
import io.sease.rre.Func;
import io.sease.rre.core.domain.metrics.Metric;
import io.sease.rre.core.domain.metrics.ValueFactory;

Expand All @@ -12,7 +11,6 @@
import java.util.Map;
import java.util.stream.StreamSupport;

import static io.sease.rre.Field.GAIN;
import static io.sease.rre.Func.gainOrRatingNode;
import static java.util.Collections.emptyList;
import static java.util.stream.Collectors.groupingBy;
Expand All @@ -34,8 +32,8 @@ public NDCGAtTen() {
}

@Override
public ValueFactory valueFactory() {
return new ValueFactory(this) {
public ValueFactory createValueFactory(final String version) {
return new ValueFactory(this, version) {
private BigDecimal dcg = BigDecimal.ZERO;

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import static io.sease.rre.Calculator.divide;

/**
* AveragePrecision is the fraction of the documents retrieved that are relevant to the user's information need.
* Precision is the fraction of the documents retrieved that are relevant to the user's information need.
*
* @author agazzarini
* @since 1.0
Expand All @@ -24,8 +24,8 @@ public Precision() {
}

@Override
public ValueFactory valueFactory() {
return new ValueFactory(this) {
public ValueFactory createValueFactory(final String version) {
return new ValueFactory(this, version) {
final AtomicInteger relevantItemsFound = new AtomicInteger();

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ public class PrecisionAtK extends Metric {
}

@Override
public ValueFactory valueFactory() {
return new ValueFactory(this) {
public ValueFactory createValueFactory(final String version) {
return new ValueFactory(this, version) {
private final List<Map<String, Object>> collected = new ArrayList<>();

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ public Recall() {
}

@Override
public ValueFactory valueFactory() {
return new ValueFactory(this) {
public ValueFactory createValueFactory(final String version) {
return new ValueFactory(this, version) {
final AtomicInteger relevantItemsFound = new AtomicInteger();

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package io.sease.rre.core.domain.metrics.impl;

import com.fasterxml.jackson.databind.JsonNode;
import io.sease.rre.Field;
import io.sease.rre.Func;
import io.sease.rre.core.domain.metrics.Metric;
import io.sease.rre.core.domain.metrics.ValueFactory;
Expand All @@ -11,7 +10,7 @@
import java.util.Map;

/**
* Reciprocal Rank metric
* The reciprocal rank of a query response is the multiplicative inverse of the rank of the first correct answer.
*
* @author agazzarini
* @since 1.0
Expand All @@ -25,8 +24,8 @@ public ReciprocalRank() {
}

@Override
public ValueFactory valueFactory() {
return new ValueFactory(this) {
public ValueFactory createValueFactory(final String version) {
return new ValueFactory(this, version) {
private int rank;
private int maxGain;

Expand Down
Loading

0 comments on commit cc10d29

Please sign in to comment.