Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix composite metric case in handlers
Browse files Browse the repository at this point in the history
  • Loading branch information
szha committed Oct 31, 2019
1 parent 9396030 commit 2ad3224
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 20 deletions.
15 changes: 2 additions & 13 deletions python/mxnet/gluon/contrib/estimator/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from .event_handler import MetricHandler, ValidationHandler, LoggingHandler, StoppingHandler
from .event_handler import TrainBegin, EpochBegin, BatchBegin, BatchEnd, EpochEnd, TrainEnd
from .utils import _check_metrics
from ...data import DataLoader
from ...loss import SoftmaxCrossEntropyLoss
from ...loss import Loss as gluon_loss
Expand Down Expand Up @@ -68,7 +69,7 @@ def __init__(self, net,

self.net = net
self.loss = self._check_loss(loss)
self.train_metrics = self._check_metrics(metrics)
self.train_metrics = _check_metrics(metrics)

self.context = self._check_context(context)
self._initialize(initializer)
Expand All @@ -84,18 +85,6 @@ def _check_loss(self, loss):
"refer to gluon.loss.Loss:{}".format(loss))
return loss

def _check_metrics(self, metrics):
if isinstance(metrics, CompositeEvalMetric):
metrics = metrics.metrics
elif isinstance(metrics, EvalMetric):
metrics = [metrics]
else:
metrics = metrics or []
if not all([isinstance(metric, EvalMetric) for metric in metrics]):
raise ValueError("metrics must be a Metric or a list of Metric, "
"refer to mxnet.metric.EvalMetric:{}".format(metrics))
return metrics

def _check_context(self, context):
# infer available context
gpus = num_gpus()
Expand Down
14 changes: 9 additions & 5 deletions python/mxnet/gluon/contrib/estimator/event_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@

import numpy as np

from ....metric import EvalMetric
from ....metric import EvalMetric, CompositeEvalMetric
from ....metric import Loss as metric_loss
from .utils import _check_metrics

__all__ = ['TrainBegin', 'TrainEnd', 'EpochBegin', 'EpochEnd', 'BatchBegin', 'BatchEnd',
'StoppingHandler', 'MetricHandler', 'ValidationHandler',
Expand Down Expand Up @@ -118,7 +119,7 @@ class MetricHandler(EpochBegin, BatchEnd):
"""

def __init__(self, train_metrics):
self.train_metrics = train_metrics or []
self.train_metrics = _check_metrics(train_metrics)
# order to be called among all callbacks
# metrics need to be calculated before other callbacks can access them
self.priority = -np.Inf
Expand Down Expand Up @@ -173,7 +174,7 @@ def __init__(self,
self.eval_fn = eval_fn
self.epoch_period = epoch_period
self.batch_period = batch_period
self.val_metrics = val_metrics
self.val_metrics = _check_metrics(val_metrics)
self.current_batch = 0
self.current_epoch = 0
# order to be called among all callbacks
Expand Down Expand Up @@ -255,8 +256,8 @@ def __init__(self, file_name=None,
"E.g: LoggingHandler(verbose=LoggingHandler.LOG_PER_EPOCH)"
% verbose)
self.verbose = verbose
self.train_metrics = train_metrics or []
self.val_metrics = val_metrics or []
self.train_metrics = _check_metrics(train_metrics)
self.val_metrics = _check_metrics(val_metrics)
self.batch_index = 0
self.current_epoch = 0
self.processed_samples = 0
Expand Down Expand Up @@ -637,6 +638,9 @@ def __init__(self,
if not isinstance(monitor, EvalMetric):
raise ValueError("Please provide one of the metric objects as monitor, "
"You can create these objects using estimator.prepare_loss_and_metric()")
if isinstance(monitor, CompositeEvalMetric):
raise ValueError("CompositeEvalMetric is not supported for EarlyStoppingHandler, "
"please specify a simple metric instead.")
self.monitor = monitor
self.baseline = baseline
self.patience = patience
Expand Down
34 changes: 34 additions & 0 deletions python/mxnet/gluon/contrib/estimator/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=wildcard-import, unused-variable
"""Gluon Estimator Utility Functions"""

from ....metric import EvalMetric, CompositeEvalMetric

def _check_metrics(metrics):
if isinstance(metrics, CompositeEvalMetric):
metrics = [m for metric in metrics.metrics for m in _check_metrics(metric)]
elif isinstance(metrics, EvalMetric):
metrics = [metrics]
else:
metrics = metrics or []
if not all([isinstance(metric, EvalMetric) for metric in metrics]):
raise ValueError("metrics must be a Metric or a list of Metric, "
"refer to mxnet.metric.EvalMetric:{}".format(metrics))
return metrics
7 changes: 5 additions & 2 deletions tests/nightly/estimator/test_sentiment_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,13 @@ def run(net, train_dataloader, test_dataloader, num_epochs, ctx, lr):
# Define loss and evaluation metrics
loss = gluon.loss.SoftmaxCrossEntropyLoss()
metrics = mx.metric.CompositeEvalMetric()
metrics.add([mx.metric.Accuracy(), mx.metric.Loss()])
acc = mx.metric.Accuracy()
nested_metrics = mx.metric.CompositeEvalMetric()
metrics.add([acc, mx.metric.Loss()])
nested_metrics.add([metrics, mx.metric.Accuracy()])

# Define estimator
est = estimator.Estimator(net=net, loss=loss, metrics=metrics,
est = estimator.Estimator(net=net, loss=loss, metrics=nested_metrics,
trainer=trainer, context=ctx)
# Begin training
est.fit(train_data=train_dataloader, val_data=test_dataloader,
Expand Down

0 comments on commit 2ad3224

Please sign in to comment.