diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py
index b634a9aec156..bebbb94eb038 100644
--- a/python/mxnet/gluon/contrib/estimator/estimator.py
+++ b/python/mxnet/gluon/contrib/estimator/estimator.py
@@ -24,6 +24,7 @@
 from .event_handler import MetricHandler, ValidationHandler, LoggingHandler, StoppingHandler
 from .event_handler import TrainBegin, EpochBegin, BatchBegin, BatchEnd, EpochEnd, TrainEnd
+from .utils import _check_metrics
 from ...data import DataLoader
 from ...loss import SoftmaxCrossEntropyLoss
 from ...loss import Loss as gluon_loss
@@ -68,7 +69,7 @@ def __init__(self, net,
         self.net = net
         self.loss = self._check_loss(loss)
-        self.train_metrics = self._check_metrics(metrics)
+        self.train_metrics = _check_metrics(metrics)
         self.context = self._check_context(context)
 
         self._initialize(initializer)
@@ -84,18 +85,6 @@ def _check_loss(self, loss):
                              "refer to gluon.loss.Loss:{}".format(loss))
         return loss
 
-    def _check_metrics(self, metrics):
-        if isinstance(metrics, CompositeEvalMetric):
-            metrics = metrics.metrics
-        elif isinstance(metrics, EvalMetric):
-            metrics = [metrics]
-        else:
-            metrics = metrics or []
-        if not all([isinstance(metric, EvalMetric) for metric in metrics]):
-            raise ValueError("metrics must be a Metric or a list of Metric, "
-                             "refer to mxnet.metric.EvalMetric:{}".format(metrics))
-        return metrics
-
     def _check_context(self, context):
         # infer available context
         gpus = num_gpus()
diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py
index da2c84455e35..c5a4f1a3f836 100644
--- a/python/mxnet/gluon/contrib/estimator/event_handler.py
+++ b/python/mxnet/gluon/contrib/estimator/event_handler.py
@@ -26,8 +26,9 @@
 
 import numpy as np
 
-from ....metric import EvalMetric
+from ....metric import EvalMetric, CompositeEvalMetric
 from ....metric import Loss as metric_loss
+from .utils import _check_metrics
 
 __all__ = ['TrainBegin', 'TrainEnd', 'EpochBegin', 'EpochEnd', 'BatchBegin', 'BatchEnd',
            'StoppingHandler', 'MetricHandler', 'ValidationHandler',
@@ -118,7 +119,7 @@ class MetricHandler(EpochBegin, BatchEnd):
     """
 
     def __init__(self, train_metrics):
-        self.train_metrics = train_metrics or []
+        self.train_metrics = _check_metrics(train_metrics)
         # order to be called among all callbacks
         # metrics need to be calculated before other callbacks can access them
         self.priority = -np.Inf
@@ -173,7 +174,7 @@ def __init__(self,
         self.eval_fn = eval_fn
         self.epoch_period = epoch_period
         self.batch_period = batch_period
-        self.val_metrics = val_metrics
+        self.val_metrics = _check_metrics(val_metrics)
         self.current_batch = 0
         self.current_epoch = 0
         # order to be called among all callbacks
@@ -255,8 +256,8 @@ def __init__(self, file_name=None,
                              "E.g: LoggingHandler(verbose=LoggingHandler.LOG_PER_EPOCH)"
                              % verbose)
         self.verbose = verbose
-        self.train_metrics = train_metrics or []
-        self.val_metrics = val_metrics or []
+        self.train_metrics = _check_metrics(train_metrics)
+        self.val_metrics = _check_metrics(val_metrics)
         self.batch_index = 0
         self.current_epoch = 0
         self.processed_samples = 0
@@ -637,6 +638,9 @@ def __init__(self,
         if not isinstance(monitor, EvalMetric):
             raise ValueError("Please provide one of the metric objects as monitor, "
                              "You can create these objects using estimator.prepare_loss_and_metric()")
+        if isinstance(monitor, CompositeEvalMetric):
+            raise ValueError("CompositeEvalMetric is not supported for EarlyStoppingHandler, "
+                             "please specify a simple metric instead.")
         self.monitor = monitor
         self.baseline = baseline
         self.patience = patience
diff --git a/python/mxnet/gluon/contrib/estimator/utils.py b/python/mxnet/gluon/contrib/estimator/utils.py
new file mode 100644
index 000000000000..f5be0878e0d9
--- /dev/null
+++ b/python/mxnet/gluon/contrib/estimator/utils.py
@@ -0,0 +1,34 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=wildcard-import, unused-variable
+"""Gluon Estimator Utility Functions"""
+
+from ....metric import EvalMetric, CompositeEvalMetric
+
+def _check_metrics(metrics):
+    if isinstance(metrics, CompositeEvalMetric):
+        metrics = [m for metric in metrics.metrics for m in _check_metrics(metric)]
+    elif isinstance(metrics, EvalMetric):
+        metrics = [metrics]
+    else:
+        metrics = metrics or []
+    if not all([isinstance(metric, EvalMetric) for metric in metrics]):
+        raise ValueError("metrics must be a Metric or a list of Metric, "
+                         "refer to mxnet.metric.EvalMetric:{}".format(metrics))
+    return metrics
diff --git a/tests/nightly/estimator/test_sentiment_rnn.py b/tests/nightly/estimator/test_sentiment_rnn.py
index 6386cdbd3128..233355b7ebfd 100644
--- a/tests/nightly/estimator/test_sentiment_rnn.py
+++ b/tests/nightly/estimator/test_sentiment_rnn.py
@@ -191,10 +191,13 @@ def run(net, train_dataloader, test_dataloader, num_epochs, ctx, lr):
     # Define loss and evaluation metrics
     loss = gluon.loss.SoftmaxCrossEntropyLoss()
     metrics = mx.metric.CompositeEvalMetric()
-    metrics.add([mx.metric.Accuracy(), mx.metric.Loss()])
+    acc = mx.metric.Accuracy()
+    nested_metrics = mx.metric.CompositeEvalMetric()
+    metrics.add([acc, mx.metric.Loss()])
+    nested_metrics.add([metrics, mx.metric.Accuracy()])
 
     # Define estimator
-    est = estimator.Estimator(net=net, loss=loss, metrics=metrics,
+    est = estimator.Estimator(net=net, loss=loss, metrics=nested_metrics,
                               trainer=trainer, context=ctx)
     # Begin training
     est.fit(train_data=train_dataloader, val_data=test_dataloader,