Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Estimator] refactor estimator and clarify docs #16694

Merged
merged 3 commits into from
Nov 1, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 57 additions & 77 deletions python/mxnet/gluon/contrib/estimator/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,14 @@

from .event_handler import MetricHandler, ValidationHandler, LoggingHandler, StoppingHandler
from .event_handler import TrainBegin, EpochBegin, BatchBegin, BatchEnd, EpochEnd, TrainEnd
from .utils import _check_metrics
from .event_handler import _check_event_handlers
from .utils import _check_metrics, _suggest_metric_for_loss, _check_handler_metric_ref
from ...data import DataLoader
from ...loss import SoftmaxCrossEntropyLoss
from ...loss import Loss as gluon_loss
from ...trainer import Trainer
from ...utils import split_and_load
from .... import autograd
from ....context import Context, cpu, gpu, num_gpus
from ....metric import Accuracy
from ....metric import Loss as metric_loss

__all__ = ['Estimator']
Expand All @@ -48,8 +47,8 @@ class Estimator(object):
----------
net : gluon.Block
The model used for training.
loss : gluon.loss.Loss or list of gluon.loss.Loss
Loss(objective functions) to calculate during training.
loss : gluon.loss.Loss
Loss (objective) function to calculate during training.
metrics : EvalMetric or list of EvalMetric
Metrics for evaluating models.
initializer : Initializer
Expand All @@ -69,19 +68,17 @@ def __init__(self, net,

self.net = net
self.loss = self._check_loss(loss)
self.train_metrics = _check_metrics(metrics)
self._train_metrics = _check_metrics(metrics)
self._add_default_training_metrics()
self._add_validation_metrics()

self.context = self._check_context(context)
self._initialize(initializer)
self.trainer = self._check_trainer(trainer)

def _check_loss(self, loss):
if isinstance(loss, gluon_loss):
loss = [loss]
elif isinstance(loss, list) and all([isinstance(l, gluon_loss) for l in loss]):
loss = loss
else:
raise ValueError("loss must be a Loss or a list of Loss, "
if not isinstance(loss, gluon_loss):
raise ValueError("loss must be a Loss, "
"refer to gluon.loss.Loss:{}".format(loss))
return loss

Expand Down Expand Up @@ -166,31 +163,30 @@ def _get_data_and_label(self, batch, ctx, batch_axis=0):
label = split_and_load(label, ctx_list=ctx, batch_axis=batch_axis)
return data, label

def prepare_loss_and_metrics(self):
"""
Based on loss functions and training metrics in estimator
Create metric wrappers to record loss values,
Create copies of train loss/metric objects to record validation values
def _add_default_training_metrics(self):
if not self._train_metrics:
suggested_metric = _suggest_metric_for_loss(self.loss)
if suggested_metric:
self._train_metrics = [suggested_metric]
loss_name = self.loss.name.rstrip('1234567890')
self._train_metrics.append(metric_loss(loss_name))

Returns
-------
train_metrics, val_metrics
"""
if any(not hasattr(self, attribute) for attribute in
['train_metrics', 'val_metrics']):
# Use default mx.metric.Accuracy() for SoftmaxCrossEntropyLoss()
if not self.train_metrics and any([isinstance(l, SoftmaxCrossEntropyLoss) for l in self.loss]):
self.train_metrics = [Accuracy()]
self.val_metrics = []
for loss in self.loss:
# remove trailing numbers from loss name to avoid confusion
self.train_metrics.append(metric_loss(loss.name.rstrip('1234567890')))
for metric in self.train_metrics:
val_metric = copy.deepcopy(metric)
metric.name = "train " + metric.name
val_metric.name = "validation " + val_metric.name
self.val_metrics.append(val_metric)
return self.train_metrics, self.val_metrics
for metric in self._train_metrics:
metric.name = "training " + metric.name

def _add_validation_metrics(self):
    """Create the validation metrics as independent copies of the training metrics.

    Every metric in ``self._train_metrics`` is deep-copied so that validation
    statistics are accumulated separately from the training statistics. The
    copies are stored in ``self._val_metrics`` and named ``"validation <name>"``.
    """
    # Training metrics have already been renamed to "training <name>" by the
    # time this runs; drop that prefix so the copy is not labelled
    # "validation training <name>".
    train_prefix = "training "
    validation_metrics = []
    for metric in self._train_metrics:
        val_metric = copy.deepcopy(metric)
        base_name = val_metric.name
        if base_name.startswith(train_prefix):
            base_name = base_name[len(train_prefix):]
        val_metric.name = "validation " + base_name
        validation_metrics.append(val_metric)
    self._val_metrics = validation_metrics

@property
def train_metrics(self):
    """Return the training metrics held by this estimator (read-only property)."""
    training_metrics = self._train_metrics
    return training_metrics

@property
def val_metrics(self):
    """Return the validation metrics held by this estimator (read-only property)."""
    validation_metrics = self._val_metrics
    return validation_metrics

def evaluate_batch(self,
val_batch,
Expand All @@ -209,7 +205,7 @@ def evaluate_batch(self,
"""
data, label = self._get_data_and_label(val_batch, self.context, batch_axis)
pred = [self.net(x) for x in data]
loss = [self.loss[0](y_hat, y) for y_hat, y in zip(pred, label)]
loss = [self.loss(y_hat, y) for y_hat, y in zip(pred, label)]
# update metrics
for metric in val_metrics:
if isinstance(metric, metric_loss):
Expand Down Expand Up @@ -275,7 +271,7 @@ def fit_batch(self, train_batch,

with autograd.record():
pred = [self.net(x) for x in data]
loss = [self.loss[0](y_hat, y) for y_hat, y in zip(pred, label)]
loss = [self.loss(y_hat, y) for y_hat, y in zip(pred, label)]

for l in loss:
l.backward()
Expand Down Expand Up @@ -377,63 +373,47 @@ def fit(self, train_data,
handler.train_end(estimator_ref)

def _prepare_default_handlers(self, val_data, event_handlers):
event_handlers = event_handlers or []
default_handlers = []
self.prepare_loss_and_metrics()
event_handlers = _check_event_handlers(event_handlers)
added_default_handlers = []

# no need to add to default handler check as StoppingHandler does not use metrics
event_handlers.append(StoppingHandler(self.max_epoch, self.max_batch))
default_handlers.append("StoppingHandler")
added_default_handlers.append(StoppingHandler(self.max_epoch, self.max_batch))

if not any(isinstance(handler, MetricHandler) for handler in event_handlers):
event_handlers.append(MetricHandler(train_metrics=self.train_metrics))
default_handlers.append("MetricHandler")
added_default_handlers.append(MetricHandler(train_metrics=self.train_metrics))

if not any(isinstance(handler, ValidationHandler) for handler in event_handlers):
# no validation handler
if val_data:
# add default validation handler if validation data found
event_handlers.append(ValidationHandler(val_data=val_data, eval_fn=self.evaluate,
val_metrics=self.val_metrics))
default_handlers.append("ValidationHandler")
val_metrics = self.val_metrics
# add default validation handler if validation data found
added_default_handlers.append(ValidationHandler(val_data=val_data,
eval_fn=self.evaluate,
val_metrics=val_metrics))
else:
# use an empty list of validation metrics if there is no validation data and no validation handler
val_metrics = []

if not any(isinstance(handler, LoggingHandler) for handler in event_handlers):
event_handlers.append(LoggingHandler(train_metrics=self.train_metrics,
val_metrics=val_metrics))
default_handlers.append("LoggingHandler")
added_default_handlers.append(LoggingHandler(train_metrics=self.train_metrics,
val_metrics=val_metrics))

# if there is a mix of user defined event handlers and default event handlers
# they should have the same set of loss and metrics
if default_handlers and len(event_handlers) != len(default_handlers):
msg = "You are training with the following default event handlers: %s. " \
"They use loss and metrics from estimator.prepare_loss_and_metrics(). " \
"Please use the same set of metrics for all your other handlers." % \
", ".join(default_handlers)
# they should have the same set of metrics
mixing_handlers = event_handlers and added_default_handlers

event_handlers.extend(added_default_handlers)

if mixing_handlers:
msg = "The following default event handlers are added: {}.".format(
", ".join([type(h).__name__ for h in added_default_handlers]))
warnings.warn(msg)
# check if all handlers has the same set of references to loss and metrics
references = []


# check if all handlers have the same set of references to metrics
known_metrics = set(self.train_metrics + self.val_metrics)
for handler in event_handlers:
for attribute in dir(handler):
if any(keyword in attribute for keyword in ['metric' or 'monitor']):
reference = getattr(handler, attribute)
if isinstance(reference, list):
references += reference
else:
references.append(reference)
# remove None metric references
references = set([ref for ref in references if ref])
for metric in references:
if metric not in self.train_metrics + self.val_metrics:
msg = "We have added following default handlers for you: %s and used " \
"estimator.prepare_loss_and_metrics() to pass metrics to " \
"those handlers. Please use the same set of metrics " \
"for all your handlers." % \
", ".join(default_handlers)
raise ValueError(msg)
_check_handler_metric_ref(handler, known_metrics)

event_handlers.sort(key=lambda handler: getattr(handler, 'priority', 0))
return event_handlers
Expand Down
51 changes: 33 additions & 18 deletions python/mxnet/gluon/contrib/estimator/event_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.

# coding: utf-8
# pylint: disable=wildcard-import, unused-argument
# pylint: disable=wildcard-import, unused-argument, too-many-ancestors
"""Gluon EventHandlers for Estimators"""

import logging
Expand All @@ -34,33 +34,47 @@
'StoppingHandler', 'MetricHandler', 'ValidationHandler',
'LoggingHandler', 'CheckpointHandler', 'EarlyStoppingHandler']

class EventHandler(object):
    """Common base type for estimator event handlers.

    It defines no behavior of its own; it exists so that handler objects can
    be recognized with ``isinstance``.
    """

class TrainBegin(object):

def _check_event_handlers(handlers):
    """Normalize ``handlers`` into a validated list of event handlers.

    Parameters
    ----------
    handlers : EventHandler, iterable of EventHandler, or None
        A single handler, a collection of handlers, or a falsy value
        (treated as "no handlers").

    Returns
    -------
    list of EventHandler
        Always a newly created list, so callers may extend or sort it in
        place without modifying the object that was passed in.

    Raises
    ------
    ValueError
        If any element is not an ``EventHandler`` instance.
    """
    if isinstance(handlers, EventHandler):
        return [handlers]

    # Copy into a fresh list: the result is extended and sorted in place by
    # its consumer, which would otherwise mutate the caller's own list (and
    # fail outright for tuples).
    handler_list = list(handlers or [])
    if not all(isinstance(handler, EventHandler) for handler in handler_list):
        raise ValueError("handlers must be an EventHandler or a list of EventHandler, "
                         "got: {}".format(handlers))
    return handler_list


class TrainBegin(EventHandler):
def train_begin(self, estimator, *args, **kwargs):
pass


class TrainEnd(object):
class TrainEnd(EventHandler):
def train_end(self, estimator, *args, **kwargs):
pass


class EpochBegin(object):
class EpochBegin(EventHandler):
def epoch_begin(self, estimator, *args, **kwargs):
pass


class EpochEnd(object):
class EpochEnd(EventHandler):
def epoch_end(self, estimator, *args, **kwargs):
return False


class BatchBegin(object):
class BatchBegin(EventHandler):
def batch_begin(self, estimator, *args, **kwargs):
pass


class BatchEnd(object):
class BatchEnd(EventHandler):
def batch_end(self, estimator, *args, **kwargs):
return False

Expand Down Expand Up @@ -393,8 +407,8 @@ def __init__(self,
self.model_prefix = model_prefix
self.save_best = save_best
if self.save_best and not isinstance(self.monitor, EvalMetric):
raise ValueError("To save best model only, please provide one of the metric objects as monitor, "
"You can get these objects using estimator.prepare_loss_and_metric()")
raise ValueError("To save best model only, please provide one of the metric objects "
"from estimator.train_metrics and estimator.val_metrics as monitor.")
self.epoch_period = epoch_period
self.batch_period = batch_period
self.current_batch = 0
Expand Down Expand Up @@ -487,10 +501,10 @@ def _save_checkpoint(self, estimator):
monitor_name, monitor_value = self.monitor.get()
# check if monitor exists in train stats
if np.isnan(monitor_value):
warnings.warn(RuntimeWarning('Skipping save best because %s is not updated, make sure you '
'pass one of the metric objects as monitor, '
'you can use estimator.prepare_loss_and_metrics to'
'create all metric objects', monitor_name))
warnings.warn(RuntimeWarning(
'Skipping save best because %s is not updated, make sure you pass one of the '
'metric objects estimator.train_metrics and estimator.val_metrics as monitor',
monitor_name))
else:
if self.monitor_op(monitor_value, self.best):
prefix = self.model_prefix + '-best'
Expand Down Expand Up @@ -636,8 +650,9 @@ def __init__(self,
super(EarlyStoppingHandler, self).__init__()

if not isinstance(monitor, EvalMetric):
raise ValueError("Please provide one of the metric objects as monitor, "
"You can create these objects using estimator.prepare_loss_and_metric()")
raise ValueError(
"Please provide one of the metric objects from estimator.train_metrics and "
"estimator.val_metrics as monitor.")
if isinstance(monitor, CompositeEvalMetric):
raise ValueError("CompositeEvalMetric is not supported for EarlyStoppingHandler, "
"please specify a simple metric instead.")
Expand Down Expand Up @@ -693,9 +708,9 @@ def train_begin(self, estimator, *args, **kwargs):
def epoch_end(self, estimator, *args, **kwargs):
monitor_name, monitor_value = self.monitor.get()
if np.isnan(monitor_value):
warnings.warn(RuntimeWarning('%s is not updated, make sure you pass one of the metric objects'
'as monitor, you can use estimator.prepare_loss_and_metrics to'
'create all metric objects', monitor_name))
warnings.warn(RuntimeWarning(
'%s is not updated, make sure you pass one of the metric objects from'
'estimator.train_metrics and estimator.val_metrics as monitor.', monitor_name))
else:
if self.monitor_op(monitor_value - self.min_delta, self.best):
self.best = monitor_value
Expand Down
31 changes: 29 additions & 2 deletions python/mxnet/gluon/contrib/estimator/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
# pylint: disable=wildcard-import, unused-variable
"""Gluon Estimator Utility Functions"""

from ....metric import EvalMetric, CompositeEvalMetric
from ...loss import SoftmaxCrossEntropyLoss
from ....metric import Accuracy, EvalMetric, CompositeEvalMetric

def _check_metrics(metrics):
if isinstance(metrics, CompositeEvalMetric):
Expand All @@ -30,5 +31,31 @@ def _check_metrics(metrics):
metrics = metrics or []
if not all([isinstance(metric, EvalMetric) for metric in metrics]):
raise ValueError("metrics must be a Metric or a list of Metric, "
"refer to mxnet.metric.EvalMetric:{}".format(metrics))
"refer to mxnet.metric.EvalMetric: {}".format(metrics))
return metrics

def _check_handler_metric_ref(handler, known_metrics):
    """Check that every metric referenced by ``handler`` is a known metric.

    Each attribute of ``handler`` whose name contains ``metric`` or
    ``monitor`` is inspected. Metric objects found there, either directly or
    as elements of a list, are validated with ``_check_metric_known``.

    Parameters
    ----------
    handler : object
        The event handler to inspect.
    known_metrics : container
        The metrics the handler is allowed to reference.

    Raises
    ------
    ValueError
        Raised by ``_check_metric_known`` when a referenced metric is not in
        ``known_metrics``.
    """
    # This must be a real sequence of keywords. The expression
    # ``['metric' or 'monitor']`` evaluates to ``['metric']`` and silently
    # skips every ``monitor`` attribute.
    keywords = ('metric', 'monitor')
    for attribute in dir(handler):
        if not any(keyword in attribute for keyword in keywords):
            continue
        reference = getattr(handler, attribute)
        candidates = reference if isinstance(reference, list) else [reference]
        for candidate in candidates:
            # Matching attributes may hold things that are not metrics, such
            # as None or a comparison callable like ``monitor_op``. Only
            # actual metric objects are validated.
            if isinstance(candidate, EvalMetric):
                _check_metric_known(handler, candidate, known_metrics)

def _check_metric_known(handler, metric, known_metrics):
    """Raise ``ValueError`` unless ``metric`` is contained in ``known_metrics``.

    Parameters
    ----------
    handler : object
        The handler that refers to ``metric``; its type name is used in the
        error message.
    metric : object
        The metric reference to validate.
    known_metrics : container
        The collection of accepted metrics.
    """
    if metric in known_metrics:
        return

    handler_name = type(handler).__name__
    message = (
        'Event handler {} refers to a metric instance {} outside of '
        'the known training and validation metrics. Please use the metrics from '
        'estimator.train_metrics and estimator.val_metrics '
        'instead.'
    ).format(handler_name, metric)
    raise ValueError(message)

def _suggest_metric_for_loss(loss):
    """Return a default metric for ``loss``, or ``None`` if there is none.

    A ``SoftmaxCrossEntropyLoss`` is paired with a new ``Accuracy`` metric;
    any other loss yields ``None``.
    """
    is_softmax_ce = isinstance(loss, SoftmaxCrossEntropyLoss)
    return Accuracy() if is_softmax_ce else None
Loading