Print base 2 scores in generate.py, interactive.py, eval_lm.py
Summary: Pull Request resolved: fairinternal/fairseq-py#987

Differential Revision: D19373600

Pulled By: myleott

fbshipit-source-id: d0cd0b616a95c7907d1856d786b4ebdb7a084011
myleott authored and facebook-github-bot committed Jan 13, 2020
1 parent ab6ce42 commit 660d69f
Showing 3 changed files with 21 additions and 7 deletions.
eval_lm.py (8 changes: 5 additions & 3 deletions)
@@ -8,6 +8,8 @@
 Evaluate the perplexity of a trained language model.
 """

+import math
+
 import numpy as np
 import torch

@@ -95,7 +97,7 @@ def main(parsed_args):

     assert len(models) > 0

-    print('num. model params: {}'.format(sum(p.numel() for p in models[0].parameters())))
+    print('| num. model params: {}'.format(sum(p.numel() for p in models[0].parameters())))

     itr = task.get_batch_iterator(
         dataset=dataset,
@@ -208,9 +210,9 @@ def main(parsed_args):
             wps_meter.update(sample['ntokens'])
             t.log({'wps': round(wps_meter.avg)})

-    avg_nll_loss = -score_sum / count
+    avg_nll_loss = -score_sum / count / math.log(2)  # convert to base 2
     print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
-    print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss, np.exp(avg_nll_loss)))
+    print('| Loss (base 2): {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss, 2**avg_nll_loss))

     if args.output_word_stats:
         for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
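
The eval_lm.py change only switches the unit of the reported loss from nats to bits; the perplexity itself is unchanged, because 2 ** (x / ln 2) == e ** x. A minimal standalone sketch with a made-up loss value (not taken from the commit) that checks this:

import math

# Hypothetical average negative log-likelihood in nats (base e).
nll_nats = 3.217
# Same quantity expressed in bits (base 2), as eval_lm.py now reports it.
nll_bits = nll_nats / math.log(2)

ppl_from_nats = math.exp(nll_nats)   # old report: e ** loss
ppl_from_bits = 2 ** nll_bits        # new report: 2 ** loss
assert abs(ppl_from_nats - ppl_from_bits) < 1e-9
print('Loss (base 2): {:.4f}, Perplexity: {:.2f}'.format(nll_bits, ppl_from_bits))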
generate.py (9 changes: 7 additions & 2 deletions)
@@ -7,7 +7,9 @@
 Translate pre-processed data with a trained model.
 """

+import math
 import os
+
 import torch

 from fairseq import bleu, checkpoint_utils, options, progress_bar, tasks, utils
@@ -151,12 +153,14 @@ def main(args):
                     )

                     if not args.quiet:
-                        print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str), file=output_file)
+                        score = hypo['score'] / math.log(2)  # convert to base 2
+                        print('H-{}\t{}\t{}'.format(sample_id, score, hypo_str), file=output_file)
                         print('P-{}\t{}'.format(
                             sample_id,
                             ' '.join(map(
                                 lambda x: '{:.4f}'.format(x),
-                                hypo['positional_scores'].tolist(),
+                                # convert from base e to base 2
+                                hypo['positional_scores'].div_(math.log(2)).tolist(),
                             ))
                         ), file=output_file)

@@ -192,6 +196,7 @@ def main(args):
             t.log({'wps': round(wps_meter.avg)})
             num_sentences += sample['nsentences']

+    print('| NOTE: hypothesis and token scores are output in base 2')
     print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
         num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg), file=output_file)
     if has_target:
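
In generate.py the per-token scores are rescaled with Tensor.div_, an in-place division by ln 2, so the positional_scores tensor itself ends up holding base-2 values; any later use of that same tensor inside the loop would therefore also see base-2 scores. A small sketch with hypothetical scores (not from the commit) showing the effect:

import math

import torch

# Hypothetical per-token log-probabilities in nats (base e).
positional_scores = torch.tensor([-0.105, -2.303, -0.693])
original = positional_scores.clone()

# In-place division by ln 2 rescales the values to log base 2.
positional_scores.div_(math.log(2))
print(' '.join('{:.4f}'.format(x) for x in positional_scores.tolist()))

# The underlying probabilities are unchanged; only the log base differs.
assert torch.allclose(2 ** positional_scores, original.exp())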
interactive.py (11 changes: 9 additions & 2 deletions)
@@ -9,6 +9,7 @@

 from collections import namedtuple
 import fileinput
+import math

 import torch

@@ -130,6 +131,7 @@ def decode_fn(x):

     if args.buffer_size > 1:
         print('| Sentence buffer size:', args.buffer_size)
+    print('| NOTE: hypothesis and token scores are output in base 2')
     print('| Type the input sentence and press return:')
     start_id = 0
     for inputs in buffered_read(args.input, args.buffer_size):
@@ -169,10 +171,15 @@ def decode_fn(x):
                     remove_bpe=args.remove_bpe,
                 )
                 hypo_str = decode_fn(hypo_str)
-                print('H-{}\t{}\t{}'.format(id, hypo['score'], hypo_str))
+                score = hypo['score'] / math.log(2)  # convert to base 2
+                print('H-{}\t{}\t{}'.format(id, score, hypo_str))
                 print('P-{}\t{}'.format(
                     id,
-                    ' '.join(map(lambda x: '{:.4f}'.format(x), hypo['positional_scores'].tolist()))
+                    ' '.join(map(
+                        lambda x: '{:.4f}'.format(x),
+                        # convert from base e to base 2
+                        hypo['positional_scores'].div_(math.log(2)).tolist(),
+                    ))
                 ))
                 if args.print_alignment:
                     alignment_str = " ".join(["{}-{}".format(src, tgt) for src, tgt in alignment])
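
Because the printed H- and P- scores are now in base 2 (hence the NOTE lines added above), downstream tooling that expects natural-log scores has to convert back. A hypothetical helper, not part of fairseq, sketching one way to parse this output and rescale the scores to nats:

import math

def parse_hypothesis_lines(lines):
    # Parse H-/P- lines as printed by generate.py and interactive.py,
    # converting the base-2 scores back to nats (multiply by ln 2).
    results = {}
    for line in lines:
        tag, rest = line.split('\t', 1)
        kind, sent_id = tag.split('-', 1)
        entry = results.setdefault(sent_id, {})
        if kind == 'H':
            score, hypo_str = rest.split('\t', 1)
            entry['score_nats'] = float(score) * math.log(2)
            entry['hypo'] = hypo_str
        elif kind == 'P':
            entry['positional_nats'] = [float(x) * math.log(2) for x in rest.split()]
    return results

# Example with made-up output lines:
example = [
    'H-0\t-1.4427\thello world',
    'P-0\t-0.7213 -0.7213',
]
print(parse_hypothesis_lines(example))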
