This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Estimator] handle composite metrics in estimator #16676

Merged
merged 3 commits on Oct 31, 2019
15 changes: 3 additions & 12 deletions python/mxnet/gluon/contrib/estimator/estimator.py
@@ -24,14 +24,15 @@

from .event_handler import MetricHandler, ValidationHandler, LoggingHandler, StoppingHandler
from .event_handler import TrainBegin, EpochBegin, BatchBegin, BatchEnd, EpochEnd, TrainEnd
from .utils import _check_metrics
from ...data import DataLoader
from ...loss import SoftmaxCrossEntropyLoss
from ...loss import Loss as gluon_loss
from ...trainer import Trainer
from ...utils import split_and_load
from .... import autograd
from ....context import Context, cpu, gpu, num_gpus
from ....metric import EvalMetric, Accuracy
from ....metric import Accuracy
from ....metric import Loss as metric_loss

__all__ = ['Estimator']
@@ -68,7 +69,7 @@ def __init__(self, net,

self.net = net
self.loss = self._check_loss(loss)
self.train_metrics = self._check_metrics(metrics)
self.train_metrics = _check_metrics(metrics)

self.context = self._check_context(context)
self._initialize(initializer)
@@ -84,16 +85,6 @@ def _check_loss(self, loss):
"refer to gluon.loss.Loss:{}".format(loss))
return loss

def _check_metrics(self, metrics):
if isinstance(metrics, EvalMetric):
metrics = [metrics]
else:
metrics = metrics or []
if not all([isinstance(metric, EvalMetric) for metric in metrics]):
raise ValueError("metrics must be a Metric or a list of Metric, "
"refer to mxnet.metric.EvalMetric:{}".format(metrics))
return metrics

def _check_context(self, context):
# infer available context
gpus = num_gpus()
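Note: with metric validation moved into the shared _check_metrics helper (new utils.py below), the Estimator constructor now accepts a CompositeEvalMetric directly and stores the flattened child metrics. A minimal sketch of the expected behavior, not part of the diff (the tiny net and loss here are placeholders):

# Illustrative sketch, not part of the diff: a CompositeEvalMetric passed as
# `metrics` should end up flattened on est.train_metrics.
import mxnet as mx
from mxnet import gluon
from mxnet.gluon.contrib.estimator import Estimator

net = gluon.nn.Dense(2)
net.initialize()
loss = gluon.loss.SoftmaxCrossEntropyLoss()

composite = mx.metric.CompositeEvalMetric()
composite.add(mx.metric.Accuracy())
composite.add(mx.metric.MAE())

est = Estimator(net=net, loss=loss, metrics=composite, context=mx.cpu())
print([type(m).__name__ for m in est.train_metrics])  # expect ['Accuracy', 'MAE']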
14 changes: 9 additions & 5 deletions python/mxnet/gluon/contrib/estimator/event_handler.py
@@ -26,8 +26,9 @@

import numpy as np

from ....metric import EvalMetric
from ....metric import EvalMetric, CompositeEvalMetric
from ....metric import Loss as metric_loss
from .utils import _check_metrics

__all__ = ['TrainBegin', 'TrainEnd', 'EpochBegin', 'EpochEnd', 'BatchBegin', 'BatchEnd',
'StoppingHandler', 'MetricHandler', 'ValidationHandler',
@@ -118,7 +119,7 @@ class MetricHandler(EpochBegin, BatchEnd):
"""

def __init__(self, train_metrics):
self.train_metrics = train_metrics or []
self.train_metrics = _check_metrics(train_metrics)
# order to be called among all callbacks
# metrics need to be calculated before other callbacks can access them
self.priority = -np.Inf
@@ -173,7 +174,7 @@ def __init__(self,
self.eval_fn = eval_fn
self.epoch_period = epoch_period
self.batch_period = batch_period
self.val_metrics = val_metrics
self.val_metrics = _check_metrics(val_metrics)
self.current_batch = 0
self.current_epoch = 0
# order to be called among all callbacks
@@ -255,8 +256,8 @@ def __init__(self, file_name=None,
"E.g: LoggingHandler(verbose=LoggingHandler.LOG_PER_EPOCH)"
% verbose)
self.verbose = verbose
self.train_metrics = train_metrics or []
self.val_metrics = val_metrics or []
self.train_metrics = _check_metrics(train_metrics)
self.val_metrics = _check_metrics(val_metrics)
self.batch_index = 0
self.current_epoch = 0
self.processed_samples = 0
@@ -637,6 +638,9 @@ def __init__(self,
if not isinstance(monitor, EvalMetric):
raise ValueError("Please provide one of the metric objects as monitor, "
"You can create these objects using estimator.prepare_loss_and_metric()")
if isinstance(monitor, CompositeEvalMetric):
raise ValueError("CompositeEvalMetric is not supported for EarlyStoppingHandler, "
"please specify a simple metric instead.")
self.monitor = monitor
self.baseline = baseline
self.patience = patience
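Note: the new guard makes EarlyStoppingHandler fail fast when handed a composite monitor, since early stopping needs a single value to track. A small sketch of the intended behavior, not part of the diff:

# Illustrative sketch, not part of the diff: a CompositeEvalMetric monitor is
# rejected up front; a plain metric is still a valid monitor.
import mxnet as mx
from mxnet.gluon.contrib.estimator.event_handler import EarlyStoppingHandler

composite = mx.metric.CompositeEvalMetric()
composite.add(mx.metric.Accuracy())

try:
    EarlyStoppingHandler(monitor=composite)
except ValueError as err:
    print(err)  # "CompositeEvalMetric is not supported for EarlyStoppingHandler, ..."

handler = EarlyStoppingHandler(monitor=mx.metric.Accuracy())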
34 changes: 34 additions & 0 deletions python/mxnet/gluon/contrib/estimator/utils.py
@@ -0,0 +1,34 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=wildcard-import, unused-variable
"""Gluon Estimator Utility Functions"""

from ....metric import EvalMetric, CompositeEvalMetric

def _check_metrics(metrics):
if isinstance(metrics, CompositeEvalMetric):
metrics = [m for metric in metrics.metrics for m in _check_metrics(metric)]
elif isinstance(metrics, EvalMetric):
metrics = [metrics]
else:
metrics = metrics or []
if not all([isinstance(metric, EvalMetric) for metric in metrics]):
raise ValueError("metrics must be a Metric or a list of Metric, "
"refer to mxnet.metric.EvalMetric:{}".format(metrics))
return metrics
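Note: the helper recurses through nested CompositeEvalMetric objects and returns a flat list of leaf EvalMetric instances, which is what the handlers above iterate over. A usage sketch, not part of the diff (assuming the new file maps to the module mxnet.gluon.contrib.estimator.utils):

# Illustrative sketch, not part of the diff: nested composites are flattened
# into one list of leaf metrics; anything that is not an EvalMetric raises.
import mxnet as mx
from mxnet.gluon.contrib.estimator.utils import _check_metrics

inner = mx.metric.CompositeEvalMetric()
inner.add([mx.metric.Accuracy(), mx.metric.Loss()])  # composite of Accuracy + Loss

outer = mx.metric.CompositeEvalMetric()
outer.add([inner, mx.metric.Accuracy()])             # composite nesting a composite

flat = _check_metrics(outer)
print([type(m).__name__ for m in flat])  # expect ['Accuracy', 'Loss', 'Accuracy']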
6 changes: 5 additions & 1 deletion tests/nightly/estimator/test_sentiment_rnn.py
@@ -190,10 +190,14 @@ def run(net, train_dataloader, test_dataloader, num_epochs, ctx, lr):
trainer = mx.gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': lr})
# Define loss and evaluation metrics
loss = gluon.loss.SoftmaxCrossEntropyLoss()
metrics = mx.metric.CompositeEvalMetric()
acc = mx.metric.Accuracy()
nested_metrics = mx.metric.CompositeEvalMetric()
metrics.add([acc, mx.metric.Loss()])
nested_metrics.add([metrics, mx.metric.Accuracy()])

# Define estimator
est = estimator.Estimator(net=net, loss=loss, metrics=acc,
est = estimator.Estimator(net=net, loss=loss, metrics=nested_metrics,
trainer=trainer, context=ctx)
# Begin training
est.fit(train_data=train_dataloader, val_data=test_dataloader,