Move meters, metrics and progress_bar into fairseq.logging (#1046)
Summary: Pull Request resolved: fairinternal/fairseq-py#1046

Differential Revision: D20030412

Pulled By: myleott

fbshipit-source-id: bd87391aa9cdb73306ee90a30eeb2bdeff3690f9
myleott authored and facebook-github-bot committed Feb 27, 2020
1 parent cef5653 commit f8b795f
Showing 12 changed files with 328 additions and 258 deletions.
71 changes: 38 additions & 33 deletions examples/speech_recognition/infer.py
@@ -14,8 +14,8 @@

import sentencepiece as spm
import torch
from fairseq import checkpoint_utils, options, progress_bar, utils, tasks
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq import checkpoint_utils, options, utils, tasks
from fairseq.logging import meters, progress_bar
from fairseq.utils import import_user_module


@@ -199,9 +199,15 @@ def main(args):

# Load dataset (possibly sharded)
itr = get_dataset_itr(args, task)
progress = progress_bar.progress_bar(
itr,
log_format=args.log_format,
log_interval=args.log_interval,
default_log_format=('tqdm' if not args.no_progress_bar else 'none'),
)

# Initialize generator
gen_timer = StopwatchMeter()
gen_timer = meters.StopwatchMeter()
generator = task.build_generator(args)

num_sentences = 0
@@ -213,36 +219,35 @@
sp.Load(os.path.join(args.data, "spm.model"))

res_files = prepare_result_files(args)
with progress_bar.build_progress_bar(args, itr) as t:
wps_meter = TimeMeter()
for sample in t:
sample = utils.move_to_cuda(sample) if use_cuda else sample
if "net_input" not in sample:
continue

prefix_tokens = None
if args.prefix_size > 0:
prefix_tokens = sample["target"][:, : args.prefix_size]

gen_timer.start()
hypos = task.inference_step(generator, models, sample, prefix_tokens)
num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
gen_timer.stop(num_generated_tokens)

for i, sample_id in enumerate(sample["id"].tolist()):
speaker = task.dataset(args.gen_subset).speakers[int(sample_id)]
id = task.dataset(args.gen_subset).ids[int(sample_id)]
target_tokens = (
utils.strip_pad(sample["target"][i, :], tgt_dict.pad()).int().cpu()
)
# Process top predictions
process_predictions(
args, hypos[i], sp, tgt_dict, target_tokens, res_files, speaker, id
)

wps_meter.update(num_generated_tokens)
t.log({"wps": round(wps_meter.avg)})
num_sentences += sample["nsentences"]
wps_meter = meters.TimeMeter()
for sample in progress:
sample = utils.move_to_cuda(sample) if use_cuda else sample
if "net_input" not in sample:
continue

prefix_tokens = None
if args.prefix_size > 0:
prefix_tokens = sample["target"][:, : args.prefix_size]

gen_timer.start()
hypos = task.inference_step(generator, models, sample, prefix_tokens)
num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
gen_timer.stop(num_generated_tokens)

for i, sample_id in enumerate(sample["id"].tolist()):
speaker = task.dataset(args.gen_subset).speakers[int(sample_id)]
id = task.dataset(args.gen_subset).ids[int(sample_id)]
target_tokens = (
utils.strip_pad(sample["target"][i, :], tgt_dict.pad()).int().cpu()
)
# Process top predictions
process_predictions(
args, hypos[i], sp, tgt_dict, target_tokens, res_files, speaker, id
)

wps_meter.update(num_generated_tokens)
progress.log({"wps": round(wps_meter.avg)})
num_sentences += sample["nsentences"]

logger.info(
"| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f}"
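
The inference script now builds its progress bar up front with the new keyword-based factory and iterates over the returned object, instead of wrapping the iterator in the removed `with progress_bar.build_progress_bar(args, itr) as t:` block. A minimal sketch of the new pattern follows; the iterator, loop body, and argument values are placeholders standing in for the real dataset iterator and command-line options, not fairseq code.

# Sketch only: `itr` and the loop body are illustrative placeholders.
from fairseq.logging import meters, progress_bar

itr = range(8)  # stands in for the dataset iterator

progress = progress_bar.progress_bar(
    itr,
    log_format=None,             # None falls back to default_log_format
    log_interval=100,
    default_log_format='tqdm',   # 'none' when --no-progress-bar is given
)

gen_timer = meters.StopwatchMeter()
wps_meter = meters.TimeMeter()
for sample in progress:
    gen_timer.start()
    num_generated_tokens = 1     # placeholder for real decoding work
    gen_timer.stop(num_generated_tokens)
    wps_meter.update(num_generated_tokens)
    progress.log({"wps": round(wps_meter.avg)})
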
8 changes: 8 additions & 0 deletions fairseq/__init__.py
@@ -6,6 +6,14 @@
__all__ = ['pdb']
__version__ = '0.9.0'

import sys

# backwards compatibility to support `from fairseq.meters import AverageMeter`
from fairseq.logging import meters, metrics, progress_bar # noqa
sys.modules['fairseq.meters'] = meters
sys.modules['fairseq.metrics'] = metrics
sys.modules['fairseq.progress_bar'] = progress_bar

import fairseq.criterions # noqa
import fairseq.models # noqa
import fairseq.modules # noqa
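
The `sys.modules` aliases registered above keep the old import paths working after the move: the relocated modules are installed under their former dotted names, so legacy imports resolve to the same module objects. A quick sanity check of that behaviour, assuming fairseq is installed with this change (the assertions are illustrative, not part of the codebase):

import sys

import fairseq  # executing fairseq/__init__.py installs the aliases
from fairseq.meters import AverageMeter            # legacy path still works
from fairseq.logging.meters import AverageMeter as NewAverageMeter

assert AverageMeter is NewAverageMeter
assert sys.modules['fairseq.meters'] is sys.modules['fairseq.logging.meters']
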
Empty file added fairseq/logging/__init__.py
fairseq/meters.py → fairseq/logging/meters.py: file renamed without changes.
fairseq/metrics.py → fairseq/logging/metrics.py: file renamed without changes.
99 changes: 65 additions & 34 deletions fairseq/progress_bar.py → fairseq/logging/progress_bar.py
@@ -7,53 +7,86 @@
Wrapper around various loggers and progress bars (e.g., tqdm).
"""

from collections import OrderedDict
from contextlib import contextmanager
import json
import logging
from numbers import Number
import os
import sys
from collections import OrderedDict
from contextlib import contextmanager
from numbers import Number
from typing import Optional

import torch

from fairseq import distributed_utils
from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
from .meters import AverageMeter, StopwatchMeter, TimeMeter


logger = logging.getLogger(__name__)


def build_progress_bar(args, iterator, epoch=None, prefix=None, default='tqdm', no_progress_bar='none'):
if args.log_format is None:
args.log_format = no_progress_bar if args.no_progress_bar else default

if args.log_format == 'tqdm' and not sys.stderr.isatty():
args.log_format = 'simple'

if args.log_format == 'json':
bar = json_progress_bar(iterator, epoch, prefix, args.log_interval)
elif args.log_format == 'none':
bar = noop_progress_bar(iterator, epoch, prefix)
elif args.log_format == 'simple':
bar = simple_progress_bar(iterator, epoch, prefix, args.log_interval)
elif args.log_format == 'tqdm':
bar = tqdm_progress_bar(iterator, epoch, prefix)
def progress_bar(
iterator,
log_format: Optional[str] = None,
log_interval: int = 100,
epoch: Optional[int] = None,
prefix: Optional[str] = None,
tensorboard_logdir: Optional[str] = None,
default_log_format: str = 'tqdm',
):
if log_format is None:
log_format = default_log_format
if log_format == 'tqdm' and not sys.stderr.isatty():
log_format = 'simple'

if log_format == 'json':
bar = JsonProgressBar(iterator, epoch, prefix, log_interval)
elif log_format == 'none':
bar = NoopProgressBar(iterator, epoch, prefix)
elif log_format == 'simple':
bar = SimpleProgressBar(iterator, epoch, prefix, log_interval)
elif log_format == 'tqdm':
bar = TqdmProgressBar(iterator, epoch, prefix)
else:
raise ValueError('Unknown log format: {}'.format(args.log_format))
raise ValueError('Unknown log format: {}'.format(log_format))

if args.tensorboard_logdir and distributed_utils.is_master(args):
if tensorboard_logdir:
try:
# [FB only] custom wrapper for TensorBoard
import palaas # noqa
from fairseq.fb_tbmf_wrapper import fb_tbmf_wrapper
bar = fb_tbmf_wrapper(bar, args, args.log_interval)
from .fb_tbmf_wrapper import FbTbmfWrapper
bar = FbTbmfWrapper(bar, log_interval)
except ImportError:
bar = tensorboard_log_wrapper(bar, args.tensorboard_logdir, args)
bar = TensorboardProgressBarWrapper(bar, tensorboard_logdir)

return bar


def build_progress_bar(
args,
iterator,
epoch: Optional[int] = None,
prefix: Optional[str] = None,
default: str = 'tqdm',
no_progress_bar: str = 'none',
):
"""Legacy wrapper that takes an argparse.Namespace."""
if getattr(args, 'no_progress_bar', False):
default = no_progress_bar
if getattr(args, 'distributed_rank', 0) == 0:
tensorboard_logdir = getattr(args, 'tensorboard_logdir', None)
else:
tensorboard_logdir = None
return progress_bar(
iterator,
log_format=args.log_format,
log_interval=args.log_interval,
epoch=epoch,
prefix=prefix,
tensorboard_logdir=tensorboard_logdir,
default_log_format=default,
)


def format_stat(stat):
if isinstance(stat, Number):
stat = '{:g}'.format(stat)
@@ -68,7 +101,7 @@ def format_stat(stat):
return stat


class progress_bar(object):
class BaseProgressBar(object):
"""Abstract class for progress bars."""
def __init__(self, iterable, epoch=None, prefix=None):
self.iterable = iterable
@@ -125,7 +158,7 @@ def rename_logger(logger, new_name):
logger.name = old_name


class json_progress_bar(progress_bar):
class JsonProgressBar(BaseProgressBar):
"""Log output in JSON format."""

def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
@@ -179,7 +212,7 @@ def _format_stats(self, stats, epoch=None, update=None):
return postfix


class noop_progress_bar(progress_bar):
class NoopProgressBar(BaseProgressBar):
"""No logging."""

def __init__(self, iterable, epoch=None, prefix=None):
@@ -198,7 +231,7 @@ def print(self, stats, tag=None, step=None):
pass


class simple_progress_bar(progress_bar):
class SimpleProgressBar(BaseProgressBar):
"""A minimal logger for non-TTY environments."""

def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
Expand Down Expand Up @@ -233,7 +266,7 @@ def print(self, stats, tag=None, step=None):
logger.info('{} | {}'.format(self.prefix, postfix))


class tqdm_progress_bar(progress_bar):
class TqdmProgressBar(BaseProgressBar):
"""Log to tqdm."""

def __init__(self, iterable, epoch=None, prefix=None):
Expand Down Expand Up @@ -261,13 +294,12 @@ def print(self, stats, tag=None, step=None):
SummaryWriter = None


class tensorboard_log_wrapper(progress_bar):
class TensorboardProgressBarWrapper(BaseProgressBar):
"""Log to tensorboard."""

def __init__(self, wrapped_bar, tensorboard_logdir, args):
def __init__(self, wrapped_bar, tensorboard_logdir):
self.wrapped_bar = wrapped_bar
self.tensorboard_logdir = tensorboard_logdir
self.args = args

if SummaryWriter is None:
logger.warning(
@@ -281,7 +313,6 @@ def _writer(self, key):
_writers = _tensorboard_writers
if key not in _writers:
_writers[key] = SummaryWriter(os.path.join(self.tensorboard_logdir, key))
_writers[key].add_text('args', str(vars(self.args)))
_writers[key].add_text('sys.argv', " ".join(sys.argv))
return _writers[key]

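Along with the CamelCase class renames (json_progress_bar → JsonProgressBar, and so on), the module now exposes two entry points: the new progress_bar() factory driven by explicit keyword arguments, and build_progress_bar(), kept as a thin legacy wrapper around an argparse.Namespace. A hedged sketch of calling both; the namespace fields mirror the options read in the wrapper above, while the iterator and concrete values are placeholders.

import argparse

from fairseq.logging import progress_bar

itr = range(100)  # placeholder iterator

# New API: everything is passed explicitly.
bar = progress_bar.progress_bar(
    itr,
    log_format='simple',
    log_interval=100,
    epoch=1,
    prefix='valid',
    tensorboard_logdir=None,   # no TensorBoard wrapper
    default_log_format='tqdm',
)

# Legacy API: options are pulled from an argparse.Namespace, as before the change.
args = argparse.Namespace(
    log_format='simple',
    log_interval=100,
    no_progress_bar=False,
    distributed_rank=0,        # only rank 0 is given a tensorboard_logdir
    tensorboard_logdir='',
)
legacy_bar = progress_bar.build_progress_bar(args, itr, epoch=1, prefix='valid')
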
2 changes: 1 addition & 1 deletion fairseq/options.py
@@ -198,7 +198,7 @@ def get_parser(desc, default_task="translation"):
parser = argparse.ArgumentParser(allow_abbrev=False)
# fmt: off
parser.add_argument('--no-progress-bar', action='store_true', help='disable progress bar')
parser.add_argument('--log-interval', type=int, default=1000, metavar='N',
parser.add_argument('--log-interval', type=int, default=100, metavar='N',
help='log progress every N batches (when progress bar is disabled)')
parser.add_argument('--log-format', default=None, help='log format to use',
choices=['json', 'none', 'simple', 'tqdm'])
6 changes: 3 additions & 3 deletions fairseq/trainer.py
@@ -17,9 +17,9 @@

import torch

from fairseq import checkpoint_utils, distributed_utils, metrics, models, optim, utils
from fairseq import checkpoint_utils, distributed_utils, models, optim, utils
from fairseq.file_io import PathManager
from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
from fairseq.logging import meters, metrics
from fairseq.optim import lr_scheduler


@@ -226,7 +226,7 @@ def load_checkpoint(

# reset TimeMeters, since their start times don't make sense anymore
for meter in metrics.get_meters("default"):
if isinstance(meter, TimeMeter):
if isinstance(meter, meters.TimeMeter):
meter.reset()
else:
logger.info("no existing checkpoint found {}".format(filename))
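
trainer.py now reaches meters and metrics through the fairseq.logging package. The checkpoint-restore logic above resets only the time-based meters, since their start times are meaningless after a restore. A small sketch of that check, assuming only the metrics API shown in the diff (get_meters("default") returning the meters registered under the default aggregation context):

from fairseq.logging import meters, metrics

def reset_time_meters_after_restore():
    # Wall-clock meters carry a start timestamp from before the restore,
    # so they are reset; counters and averages are left untouched.
    for meter in metrics.get_meters("default"):
        if isinstance(meter, meters.TimeMeter):
            meter.reset()
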
(Diffs for the remaining four changed files were not loaded.)
