Switch to Python logging (+ lint) (#1627)
Summary:
Pull Request resolved: #1627

Python logging offers a number of benefits, such as timestamped messages, better
cross-library compatibility, and the ability to add multiple output handlers.
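
A minimal sketch of the pattern this commit adopts: entry-point scripts call
logging.basicConfig() once, and library modules log through named loggers. The
file handler at the end is a hypothetical illustration of the multi-handler
benefit, not part of this diff:

    import logging

    # Entry points configure the root logger once (see eval_lm.py below).
    logging.basicConfig(
        format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=logging.INFO,
    )

    # Library modules get a named, per-module logger instead of calling print().
    logger = logging.getLogger(__name__)
    logger.info('loading model(s) from %s', 'checkpoint.pt')  # hypothetical message

    # Hypothetical second handler: tee all log output to a file as well.
    logging.getLogger().addHandler(logging.FileHandler('fairseq.log'))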

Pull Request resolved: fairinternal/fairseq-py#646

Reviewed By: spencerp

Differential Revision: D15815620

Pulled By: myleott

fbshipit-source-id: 5e64e9929b5e4b9dd5bb49bcdf7c510631907134
myleott authored and facebook-github-bot committed Jan 17, 2020
1 parent 1bb218f commit fb76dac
Showing 43 changed files with 533 additions and 237 deletions.
40 changes: 27 additions & 13 deletions eval_lm.py
@@ -8,9 +8,9 @@
Evaluate the perplexity of a trained language model.
"""

import logging
import math

import numpy as np
import torch

from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
@@ -19,6 +19,14 @@
from fairseq.sequence_scorer import SequenceScorer


logging.basicConfig(
format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
level=logging.INFO,
)
logger = logging.getLogger('fairseq_cli.eval_lm')


class WordStat(object):
def __init__(self, word, is_bpe):
self.word = word
@@ -50,14 +58,14 @@ def main(parsed_args):

utils.import_user_module(parsed_args)

print(parsed_args)
logger.info(parsed_args)

use_cuda = torch.cuda.is_available() and not parsed_args.cpu

task = tasks.setup_task(parsed_args)

# Load ensemble
print('| loading model(s) from {}'.format(parsed_args.path))
logger.info('loading model(s) from {}'.format(parsed_args.path))
models, args = checkpoint_utils.load_model_ensemble(
parsed_args.path.split(':'),
arg_overrides=eval(parsed_args.model_overrides),
@@ -85,7 +93,7 @@ def main(parsed_args):
context_window=args.context_window,
pad_idx=task.source_dictionary.pad(),
)
print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset)))
logger.info('{} {} {} examples'.format(args.data, args.gen_subset, len(dataset)))

# Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
for model in models:
@@ -97,7 +105,7 @@

assert len(models) > 0

print('| num. model params: {}'.format(sum(p.numel() for p in models[0].parameters())))
logger.info('num. model params: {}'.format(sum(p.numel() for p in models[0].parameters())))

itr = task.get_batch_iterator(
dataset=dataset,
@@ -123,11 +131,11 @@
raise NotImplementedError
else:
bpe_cont = args.remove_bpe.rstrip()
bpe_toks = set(
bpe_toks = {
i
for i in range(len(task.source_dictionary))
if task.source_dictionary[i].endswith(bpe_cont)
)
}
bpe_len = len(bpe_cont)
else:
bpe_toks = None
@@ -171,8 +179,10 @@ def main(parsed_args):

inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
if inf_scores.any():
print('| Skipping tokens with inf scores:',
task.target_dictionary.string(tokens[inf_scores.nonzero()]))
logger.info(
'skipping tokens with inf scores: %s',
task.target_dictionary.string(tokens[inf_scores.nonzero()])
)
pos_scores = pos_scores[(~inf_scores).nonzero()]
score_sum += pos_scores.sum().cpu()
count += pos_scores.numel() - skipped_toks
@@ -202,7 +212,7 @@ def main(parsed_args):
is_bpe = False
w = ''
if args.output_word_probs:
print(
logger.info(
str(int(sample_id)) + " "
+ ('\t'.join('{} [{:2f}]'.format(x[0], x[1]) for x in word_prob))
)
@@ -211,12 +221,16 @@
t.log({'wps': round(wps_meter.avg)})

avg_nll_loss = -score_sum / count / math.log(2) # convert to base 2
print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
print('| Loss (base 2): {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss, 2**avg_nll_loss))
logger.info('Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(
gen_timer.n, gen_timer.sum, 1. / gen_timer.avg
))
logger.info('Loss (base 2): {:.4f}, Perplexity: {:.2f}'.format(
avg_nll_loss, 2**avg_nll_loss
))

if args.output_word_stats:
for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
print(ws)
logger.info(ws)


def cli_main():
13 changes: 8 additions & 5 deletions fairseq/checkpoint_utils.py
@@ -17,6 +17,9 @@
from torch.serialization import default_restore_location


logger = logging.getLogger(__name__)


def save_checkpoint(args, trainer, epoch_itr, val_loss):
from fairseq import distributed_utils, meters

@@ -77,7 +80,7 @@ def is_better(a, b):
PathManager.copy(checkpoints[0], cp, overwrite=True)

write_timer.stop()
print(
logger.info(
"| saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format(
checkpoints[0], epoch, updates, val_loss, write_timer.sum
)
@@ -231,7 +234,7 @@ def torch_persistent_save(*args, **kwargs):
return torch.save(*args, **kwargs)
except Exception:
if i == 2:
logging.error(traceback.format_exc())
logger.error(traceback.format_exc())


def convert_state_dict_type(state_dict, ttype=torch.FloatTensor):
@@ -388,8 +391,8 @@ def prune_state_dict(state_dict, args):
return state_dict

# apply pruning
print(
"| Pruning model to specified layer configuration - this works best if the model was trained with LayerDrop"
logger.info(
"Pruning model to specified layer configuration - this works best if the model was trained with LayerDrop"
)

def create_pruning_pass(layers_to_keep, layer_name):
@@ -485,7 +488,7 @@ def verify_checkpoint_directory(save_dir: str) -> None:
with open(temp_file_path, "w"):
pass
except OSError as e:
print("| Unable to access checkpoint save directory: {}".format(save_dir))
logger.warning("Unable to access checkpoint save directory: {}".format(save_dir))
raise e
else:
os.remove(temp_file_path)
10 changes: 7 additions & 3 deletions fairseq/data/data_utils.py
@@ -9,13 +9,17 @@
from collections import Iterable
import contextlib
import itertools
import logging
import os
import sys
import types

import numpy as np


logger = logging.getLogger(__name__)


def infer_language_pair(path):
"""Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
src, dst = None, None
@@ -78,7 +82,7 @@ def load_indexed_dataset(path, dictionary, dataset_impl=None, combine=False, def
)
if dataset is None:
break
print('| loaded {} examples from: {}'.format(len(dataset), path_k))
logger.info('loaded {} examples from: {}'.format(len(dataset), path_k))
datasets.append(dataset)
if not combine:
break
@@ -187,8 +191,8 @@ def filter_by_size(indices, dataset, max_positions, raise_exception=False):
'skip this example with --skip-invalid-size-inputs-valid-test'
).format(ignored[0], dataset.size(ignored[0]), max_positions))
if len(ignored) > 0:
print((
'| WARNING: {} samples have invalid sizes and will be skipped, '
logger.warning((
'{} samples have invalid sizes and will be skipped, '
'max_positions={}, first few sample ids={}'
).format(len(ignored), max_positions, ignored[:10]))
return indices
7 changes: 6 additions & 1 deletion fairseq/data/language_pair_dataset.py
@@ -3,12 +3,17 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

import numpy as np
import torch

from . import data_utils, FairseqDataset


logger = logging.getLogger(__name__)


def collate(
samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=False,
input_feeding=True,
@@ -26,7 +31,7 @@ def check_alignment(alignment, src_len, tgt_len):
if alignment is None or len(alignment) == 0:
return False
if alignment[:, 0].max().item() >= src_len - 1 or alignment[:, 1].max().item() >= tgt_len - 1:
print("| alignment size mismatch found, skipping alignment!")
logger.warning("alignment size mismatch found, skipping alignment!")
return False
return True

7 changes: 6 additions & 1 deletion fairseq/data/subsample_dataset.py
@@ -3,11 +3,16 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

import numpy as np

from . import BaseWrapperDataset


logger = logging.getLogger(__name__)


class SubsampleDataset(BaseWrapperDataset):
"""Subsamples a given dataset by a specified ratio. Subsampling is done on the number of examples
@@ -23,7 +28,7 @@ def __init__(self, dataset, size_ratio):
self.indices = np.random.choice(
list(range(len(self.dataset))), self.actual_size, replace=False
)
print(
logger.info(
"subsampled dataset from {} to {} (ratio={})".format(
len(self.dataset), self.actual_size, size_ratio
)
32 changes: 14 additions & 18 deletions fairseq/distributed_utils.py
@@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import pickle
import socket
@@ -16,6 +17,9 @@
from fairseq import utils


logger = logging.getLogger(__name__)


def is_master(args):
return args.distributed_rank == 0

@@ -76,42 +80,34 @@ def distributed_init(args):
if torch.distributed.is_initialized():
warnings.warn('Distributed is already initialized, cannot initialize twice!')
else:
print('| distributed init (rank {}): {}'.format(
args.distributed_rank, args.distributed_init_method), flush=True)
logger.info('distributed init (rank {}): {}'.format(
args.distributed_rank, args.distributed_init_method,
))
dist.init_process_group(
backend=args.distributed_backend,
init_method=args.distributed_init_method,
world_size=args.distributed_world_size,
rank=args.distributed_rank,
)
print('| initialized host {} as rank {}'.format(
socket.gethostname(), args.distributed_rank), flush=True)
logger.info('initialized host {} as rank {}'.format(
socket.gethostname(), args.distributed_rank,
))

# perform a dummy all-reduce to initialize the NCCL communicator
if torch.cuda.is_available():
dist.all_reduce(torch.zeros(1).cuda())
else:
dist.all_reduce(torch.zeros(1))

suppress_output(is_master(args))
if is_master(args):
logging.getLogger().setLevel(logging.INFO)
else:
logging.getLogger().setLevel(logging.WARNING)

args.distributed_rank = torch.distributed.get_rank()
return args.distributed_rank


def suppress_output(is_master):
"""Suppress printing on the current device. Force printing with `force=True`."""
import builtins as __builtin__
builtin_print = __builtin__.print

def print(*args, **kwargs):
force = kwargs.pop('force', False)
if is_master or force:
builtin_print(*args, **kwargs)

__builtin__.print = print


def get_rank():
return dist.get_rank()

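The distributed_utils.py change above drops the suppress_output() monkey-patching
of builtins.print in favor of per-rank log levels. A minimal sketch of the
resulting behavior, with the helper name and rank value hypothetical:

    import logging

    def configure_rank_logging(rank):
        # Mirrors distributed_init() above: the master rank keeps INFO-level
        # output, while all other ranks are quieted down to WARNING.
        level = logging.INFO if rank == 0 else logging.WARNING
        logging.getLogger().setLevel(level)

    configure_rank_logging(rank=0)  # master: INFO and above are emitted
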
8 changes: 4 additions & 4 deletions fairseq/file_utils.py
@@ -54,7 +54,7 @@ def load_archive_file(archive_file):
try:
resolved_archive_file = cached_path(archive_file, cache_dir=None)
except EnvironmentError:
print(
logger.info(
"Archive name '{}' was not found in archive name list. "
"We assumed '{}' was a path or URL but couldn't find any file "
"associated to this path or URL.".format(
@@ -65,16 +65,16 @@ def load_archive_file(archive_file):
return None

if resolved_archive_file == archive_file:
print("loading archive file {}".format(archive_file))
logger.info("loading archive file {}".format(archive_file))
else:
print("loading archive file {} from cache at {}".format(
logger.info("loading archive file {} from cache at {}".format(
archive_file, resolved_archive_file))

# Extract archive to temp dir and replace .tar.bz2 if necessary
tempdir = None
if not os.path.isdir(resolved_archive_file):
tempdir = tempfile.mkdtemp()
print("extracting archive file {} to temp dir {}".format(
logger.info("extracting archive file {} to temp dir {}".format(
resolved_archive_file, tempdir))
ext = os.path.splitext(archive_file)[1][1:]
with tarfile.open(resolved_archive_file, 'r:' + ext) as archive:
12 changes: 8 additions & 4 deletions fairseq/hub_utils.py
@@ -6,6 +6,7 @@

import argparse
import copy
import logging
import os
from typing import List, Dict, Iterator, Tuple, Any

@@ -16,6 +17,9 @@
from fairseq.data import encoders


logger = logging.getLogger(__name__)


def from_pretrained(
model_name_or_path,
checkpoint_file='model.pt',
@@ -172,15 +176,15 @@ def getarg(name, default):

for source_tokens, target_hypotheses in zip(tokenized_sentences, outputs):
src_str_with_unk = self.string(source_tokens)
print('S\t{}'.format(src_str_with_unk))
logger.info('S\t{}'.format(src_str_with_unk))
for hypo in target_hypotheses:
hypo_str = self.decode(hypo['tokens'])
print('H\t{}\t{}'.format(hypo['score'], hypo_str))
print('P\t{}'.format(
logger.info('H\t{}\t{}'.format(hypo['score'], hypo_str))
logger.info('P\t{}'.format(
' '.join(map(lambda x: '{:.4f}'.format(x), hypo['positional_scores'].tolist()))
))
if hypo['alignment'] is not None and getarg('print_alignment', False):
print('A\t{}'.format(
logger.info('A\t{}'.format(
' '.join(map(lambda x: str(utils.item(x)), hypo['alignment'].int().cpu()))
))
return outputs
