Skip to content

Commit

Permalink
cleaned up, added comments
Browse files Browse the repository at this point in the history
  • Loading branch information
mimno committed Jun 17, 2015
1 parent 442de6b commit 3f6a274
Showing 1 changed file with 51 additions and 40 deletions.
91 changes: 51 additions & 40 deletions src/cc/mallet/classify/evaluate/ConfusionMatrix.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,16 @@ License, or (at your option) any later version.
* Calculates and prints confusion matrix, accuracy,
* and precision for a given classification trial.
*/
public class ConfusionMatrix
{
public class ConfusionMatrix {
private static Logger logger = MalletLogger.getLogger(ConfusionMatrix.class.getName());

int numClasses;
/**
* the list of classifications from the trial
*/
ArrayList classifications;
/**
* 2-d confiusion matrix
*/
* the list of classifications from the trial
*/
ArrayList classifications;
/** 2-d confusion matrix, indexed by [actual class][predicted class] */
int[][] values;

Trial trial;
Expand All @@ -65,41 +63,42 @@ public class ConfusionMatrix
* Constructs matrix and calculates values
* @param t the trial to build matrix from
*/
public ConfusionMatrix(Trial t) {
    this.trial = t;
    this.classifications = t;

    // The number of classes is normally read from the label alphabet of the
    // first classification. An empty trial has no first classification, so
    // fall back to the classifier's label alphabet (the same alphabet that
    // toString() uses to look up label names) instead of failing on get(0).
    if (classifications.isEmpty()) {
        this.numClasses = t.getClassifier().getLabelAlphabet().size();
    }
    else {
        Labeling tempLabeling =
            ((Classification) classifications.get(0)).getLabeling();
        this.numClasses = tempLabeling.getLabelAlphabet().size();
    }
    values = new int[numClasses][numClasses];

    for (int i = 0; i < classifications.size(); i++) {
        // Fetch and cast each classification once per iteration.
        Classification classification = (Classification) classifications.get(i);

        // Predicted class: best-scoring index of the classifier's label vector.
        int bestIndex = classification.getLabelVector().getBestIndex();
        // True class: best index of the labeling attached to the instance.
        int correctIndex = classification.getInstance().getLabeling().getBestIndex();
        assert(correctIndex != -1);

        // Rows are indexed by the true class, columns by the predicted class.
        values[correctIndex][bestIndex]++;
    }
}

/** Return the count at row i (true) , column j (predicted) */
double value(int i, int j)
{
double value(int i, int j) {
assert(i >= 0 && j >= 0 && i < numClasses && j < numClasses);
return values[i][j];
}

static private void appendJustifiedInt (StringBuffer sb, int i, boolean zeroDot) {
if (i < 100) sb.append (' ');
if (i < 10) sb.append (' ');
if (i == 0 && zeroDot)
if (i == 0 && zeroDot) {
sb.append (".");
else
}
else {
sb.append (""+i);
}
}

public String toString () {
Expand All @@ -108,24 +107,33 @@ public String toString () {
LabelAlphabet labelAlphabet = trial.getClassifier().getLabelAlphabet();
for (int i = 0; i < numClasses; i++) {
int len = labelAlphabet.lookupLabel(i).toString().length();
if (maxLabelNameLength < len)
if (maxLabelNameLength < len) {
maxLabelNameLength = len;
}
}

double[] distribution = new double[values.length];
for (int i = 0; i < distribution.length; i++)
distribution[i] = MatrixOps.sum(values[i]);
double baselineAccuracy = MatrixOps.max(distribution) / MatrixOps.sum(distribution);
// These counts will be integers, but we'll keep them as doubles so we can divide later
double[] correctLabelCounts = new double[values.length];

for (int i = 0; i < correctLabelCounts.length; i++){
// This sum is the number of instances whose correct class is i
correctLabelCounts[i] = MatrixOps.sum(values[i]);
}
// Find the count of the most frequent class and divide that by
// the total number of instances.
double baselineAccuracy = MatrixOps.max(correctLabelCounts) / MatrixOps.sum(correctLabelCounts);

sb.append ("Confusion Matrix, row=true, column=predicted accuracy="+trial.getAccuracy()+" most-frequent-tag baseline="+baselineAccuracy+"\n");
for (int i = 0; i < maxLabelNameLength-5+4; i++) sb.append (' ');

for (int i = 0; i < maxLabelNameLength-5+4; i++) { sb.append (' '); }
sb.append ("label");
for (int c2 = 0; c2 < Math.min(10,numClasses); c2++) sb.append (" "+c2);
for (int c2 = 10; c2 < numClasses; c2++) sb.append (" "+c2);
for (int c2 = 0; c2 < Math.min(10,numClasses); c2++) { sb.append (" "+c2); }
for (int c2 = 10; c2 < numClasses; c2++) { sb.append (" " + c2); }
sb.append (" |total\n");
for (int c = 0; c < numClasses; c++) {
appendJustifiedInt (sb, c, false);
String labelName = labelAlphabet.lookupLabel(c).toString();
for (int i = 0; i < maxLabelNameLength-labelName.length(); i++) sb.append (' ');
for (int i = 0; i < maxLabelNameLength-labelName.length(); i++) { sb.append (' '); }
sb.append (" "+labelName+" ");
for (int c2 = 0; c2 < numClasses; c2++) {
appendJustifiedInt (sb, values[c][c2], true);
Expand All @@ -140,45 +148,48 @@ public String toString () {
/**
* Returns the precision of this predicted class
*/
public double getPrecision (int predictedClassIndex)
{
public double getPrecision (int predictedClassIndex) {
int total = 0;
for (int trueClassIndex=0; trueClassIndex < this.numClasses; trueClassIndex++) {
total += values[trueClassIndex][predictedClassIndex];
}
if (total == 0)

if (total == 0) {
return 0.0;
else
}
else {
return (double) (values[predictedClassIndex][predictedClassIndex]) / total;
}
}

/**
* Returns percent of time that class2 is true class when
* class1 is predicted class
*
*/
public double getConfusionBetween (int class1, int class2)
{
public double getConfusionBetween (int class1, int class2) {
int total = 0;
for (int trueClassIndex=0; trueClassIndex < this.numClasses; trueClassIndex++) {
total += values[trueClassIndex][class1];
}
if (total == 0)
if (total == 0) {
return 0.0;
else
return (double) (values[class2][class1]) / total;
}
else {
return (double) (values[class2][class1]) / total;
}
}

/**
* Returns the percentage of instances with
* true label = classIndex
*/
public double getClassPrior (int classIndex)
{
int sum= 0;
for(int i=0; i < numClasses; i++)
public double getClassPrior (int classIndex) {
double sum = 0;
for (int i=0; i < numClasses; i++) {
sum += values[classIndex][i];
return (double)sum / classifications.size();
}
return sum / classifications.size();
}


Expand Down

0 comments on commit 3f6a274

Please sign in to comment.