Skip to content
This repository has been archived by the owner on Mar 19, 2024. It is now read-only.

Commit

Permalink
Compute precision/recall for each label
Browse files Browse the repository at this point in the history
Summary:
This diff adds a new command to fastText that displays the precision/recall scores for each individual label: `test-label` (implemented by `printLabelStats`).
It collects the predicted labels that are above the given threshold and computes the scores from them.

For example, suppose the question "vinegar softens the bite of raw onions ?" has two labels: "vinegar" and "onions". fastText is asked to predict the labels above the given threshold. If two such labels are predicted, "pickling" and "onions", we obtain:
"onions" has a precision of 100%,
"pickling" has a precision of 0%,
"onions" has a recall of 100%,
"vinegar" has a recall of 0%.

Reviewed By: EdouardGrave

Differential Revision: D9991570

fbshipit-source-id: 63cff90f57659d51f5aa1f10243d40e253445aa6
  • Loading branch information
Celebio authored and facebook-github-bot committed Oct 24, 2018
1 parent 8e68462 commit be1e597
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 22 deletions.
93 changes: 77 additions & 16 deletions src/fasttext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -397,37 +397,29 @@ FastText::test(std::istream& in, int32_t k, real threshold) {
}

void FastText::predict(
std::istream& in,
int32_t k,
std::vector<std::pair<real, std::string>>& predictions,
const std::vector<int32_t>& words,
std::vector<std::pair<real, int32_t>>& predictions,
real threshold) const {
std::vector<int32_t> words, labels;
predictions.clear();
dict_->getLine(in, words, labels);
predictions.clear();
if (words.empty()) {
return;
}
Vector hidden(args_->dim);
Vector output(dict_->nlabels());
std::vector<std::pair<real, int32_t>> modelPredictions;
model_->predict(words, k, threshold, modelPredictions, hidden, output);
for (auto it = modelPredictions.cbegin(); it != modelPredictions.cend();
it++) {
predictions.push_back(
std::make_pair(it->first, dict_->getLabel(it->second)));
}
model_->predict(words, k, threshold, predictions, hidden, output);
}

void FastText::predict(
std::istream& in,
int32_t k,
bool print_prob,
real threshold) {
std::vector<std::pair<real, std::string>> predictions;
std::vector<std::pair<real, int32_t>> predictions;
while (in.peek() != EOF) {
std::vector<int32_t> words, labels;
dict_->getLine(in, words, labels);
predictions.clear();
predict(in, k, predictions, threshold);
predict(k, words, predictions, threshold);
if (predictions.empty()) {
std::cout << std::endl;
continue;
Expand All @@ -436,7 +428,7 @@ void FastText::predict(
if (it != predictions.cbegin()) {
std::cout << " ";
}
std::cout << it->second;
std::cout << dict_->getLabel(it->second);
if (print_prob) {
std::cout << " " << std::exp(it->first);
}
Expand All @@ -445,6 +437,75 @@ void FastText::predict(
}
}

/**
 * Prints one line per label with its F1 score, precision and recall.
 *
 * Each line is written to std::cout as
 *   F1-Score : <f1> Precision : <p> Recall : <r> <label>
 * Scores use fixed notation with 6 decimals. A score that cannot be computed
 * (precision when the label was never predicted, recall when the label never
 * appears as gold, F1 when either of the two is unknown) is printed as
 * "--------".
 *
 * The formatting state of std::cout (flags and precision) is restored before
 * returning, so later output is not silently switched to fixed notation.
 *
 * @param labelStats per-label counters; the index in the vector is used as
 *                   the label id passed to dict_->getLabel().
 */
void FastText::printLabelStats(
    const std::vector<LabelStats>& labelStats) const {
  // Sentinel for "score undefined"; real scores are always in [0, 1].
  const static double kUnknownValue = -1.0;

  auto computeF1Score = [](double precision, double recall) -> double {
    if (precision == kUnknownValue || recall == kUnknownValue) {
      return kUnknownValue;
    }
    // Avoid 0/0 when both precision and recall are zero.
    if (precision != 0 && recall != 0) {
      return 2 * precision * recall / (precision + recall);
    }
    return 0.;
  };
  auto displayScore = [](double value) {
    if (value == kUnknownValue) {
      std::cout << "--------";
    } else {
      std::cout << value;
    }
  };

  // Set the number format once and remember the previous state so that it
  // can be put back when we are done.
  const std::ios_base::fmtflags savedFlags = std::cout.flags();
  const std::streamsize savedPrecision = std::cout.precision();
  std::cout << std::fixed;
  std::cout.precision(6);

  for (size_t labelId = 0; labelId < labelStats.size(); labelId++) {
    const auto& labelStat = labelStats[labelId];
    // Precision is undefined when the label was never predicted.
    const double precision = labelStat.predicted
        ? static_cast<double>(labelStat.predictedGold) / labelStat.predicted
        : kUnknownValue;
    // Recall is undefined when the label never occurs in the gold labels.
    const double recall = labelStat.gold
        ? static_cast<double>(labelStat.predictedGold) / labelStat.gold
        : kUnknownValue;
    const double f1score = computeF1Score(precision, recall);

    std::cout << "F1-Score : ";
    displayScore(f1score);
    std::cout << " Precision : ";
    displayScore(precision);
    std::cout << " Recall : ";
    displayScore(recall);
    std::cout << " " << dict_->getLabel(static_cast<int32_t>(labelId))
              << std::endl;
  }

  std::cout.flags(savedFlags);
  std::cout.precision(savedPrecision);
}

void FastText::printLabelStats(std::istream& in, int32_t k, real threshold)
const {
std::vector<std::pair<real, int32_t>> predictions;
size_t labelsSize = dict_->nlabels();
std::vector<LabelStats> labelStats(labelsSize);
while (in.peek() != EOF) {
std::vector<int32_t> words, gold;
dict_->getLine(in, words, gold);
predictions.clear();
predict(k, words, predictions, threshold);
for (const auto& goldLabelId : gold) {
assert(goldLabelId < labelsSize);
labelStats[goldLabelId].gold++;
}
for (const auto& predictedLabel : predictions) {
int32_t predictedLabelId = predictedLabel.second;
assert(predictedLabelId < labelsSize);
labelStats[predictedLabelId].predicted++;
if (auto itFound =
std::find(gold.begin(), gold.end(), predictedLabelId) !=
gold.end()) {
labelStats[predictedLabelId].predictedGold++;
}
}
}
printLabelStats(labelStats);
}

void FastText::getSentenceVector(std::istream& in, fasttext::Vector& svec) {
svec.zero();
if (args_->model == model_name::sup) {
Expand Down
17 changes: 12 additions & 5 deletions src/fasttext.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,18 @@ class FastText {
bool quant_;
int32_t version;

// Per-label counters used to derive precision and recall.
struct LabelStats {
  // Number of times the label appears among the gold labels.
  int32_t gold = 0;
  // Number of times the label was predicted.
  int32_t predicted = 0;
  // Number of times the label was predicted and was also a gold label.
  int32_t predictedGold = 0;

  // All counters start at zero.
  LabelStats() = default;
};

void startThreads();
void predict(
int32_t,
const std::vector<int32_t>&,
std::vector<std::pair<real, int32_t>>&,
real = 0.0) const;
void printLabelStats(const std::vector<LabelStats>& labelStats) const;

public:
FastText();
Expand Down Expand Up @@ -95,11 +106,7 @@ class FastText {
void quantize(const Args);
std::tuple<int64_t, double, double> test(std::istream&, int32_t, real = 0.0);
void predict(std::istream&, int32_t, bool, real = 0.0);
void predict(
std::istream&,
int32_t,
std::vector<std::pair<real, std::string>>&,
real = 0.0) const;
void printLabelStats(std::istream&, int32_t, real = 0.0) const;
void ngramVectors(std::string);
void precomputeWordVectors(Matrix&);
void findNN(
Expand Down
42 changes: 42 additions & 0 deletions src/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ void printUsage() {
<< " supervised train a supervised classifier\n"
<< " quantize quantize a model to reduce the memory usage\n"
<< " test evaluate a supervised classifier\n"
<< " test-label print labels with precision and recall scores\n"
<< " predict predict most likely labels\n"
<< " predict-prob predict most likely labels with probabilities\n"
<< " skipgram train a skipgram model\n"
Expand Down Expand Up @@ -59,6 +60,16 @@ void printPredictUsage() {
<< std::endl;
}

// Writes the help text of the `test-label` command to standard error.
void printPrintLabelStatsUsage() {
  // Adjacent literals are concatenated by the compiler into one message.
  static const char* const kUsageText =
      "usage: fasttext test-label <model> <test-data> [<k>] [<th>]\n\n"
      " <model> model filename\n"
      " <test-data> test data filename\n"
      " <k> (optional; 1 by default) predict top k labels\n"
      " <th> (optional; 0.0 by default) probability threshold\n";
  std::cerr << kUsageText << std::endl;
}

void printPrintWordVectorsUsage() {
std::cerr << "usage: fasttext print-word-vectors <model>\n\n"
<< " <model> model filename\n"
Expand Down Expand Up @@ -186,6 +197,35 @@ void predict(const std::vector<std::string>& args) {
exit(0);
}

// Entry point of the `test-label` command.
//
// Expected arguments: args[2] is the model file, args[3] the test data file,
// args[4] (optional) the number of labels to predict and args[5] (optional)
// the probability threshold. Prints the usage text and exits with
// EXIT_FAILURE when the argument count is wrong or the test file cannot be
// opened; otherwise prints the per-label statistics and exits with 0.
void printLabelStats(const std::vector<std::string>& args) {
  const size_t argCount = args.size();
  const bool argCountOk = argCount >= 4 && argCount <= 6;
  if (!argCountOk) {
    printPrintLabelStatsUsage();
    exit(EXIT_FAILURE);
  }

  // Optional positional arguments fall back to their defaults when absent.
  int32_t k = 1;
  if (argCount >= 5) {
    k = std::stoi(args[4]);
  }
  real threshold = 0.0;
  if (argCount >= 6) {
    threshold = std::stof(args[5]);
  }

  FastText fasttext;
  fasttext.loadModel(std::string(args[2]));

  std::ifstream ifs(args[3]);
  if (!ifs.is_open()) {
    std::cerr << "Input file cannot be opened!" << std::endl;
    exit(EXIT_FAILURE);
  }

  fasttext.printLabelStats(ifs, k, threshold);
  ifs.close();

  exit(0);
}

void printWordVectors(const std::vector<std::string> args) {
if (args.size() != 3) {
printPrintWordVectorsUsage();
Expand Down Expand Up @@ -355,6 +395,8 @@ int main(int argc, char** argv) {
analogies(args);
} else if (command == "predict" || command == "predict-prob") {
predict(args);
} else if (command == "test-label") {
printLabelStats(args);
} else if (command == "dump") {
dump(args);
} else {
Expand Down
4 changes: 3 additions & 1 deletion src/model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,9 @@ void Model::predict(
std::vector<std::pair<real, int32_t>>& heap,
Vector& hidden,
Vector& output) const {
if (k <= 0) {
if (k == Model::kUnlimitedPredictions) {
k = osz_;
} else if (k <= 0) {
throw std::invalid_argument("k needs to be 1 or higher!");
}
if (args_->model != model_name::sup) {
Expand Down
2 changes: 2 additions & 0 deletions src/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ class Model {
bool quant_;
void
setQuantizePointer(std::shared_ptr<QMatrix>, std::shared_ptr<QMatrix>, bool);

static const int32_t kUnlimitedPredictions = -1;
};

} // namespace fasttext

0 comments on commit be1e597

Please sign in to comment.