add early stop (PaddlePaddle#106)
ceci3 authored Apr 8, 2020
1 parent abb30ef commit 7d1ec56
Showing 7 changed files with 331 additions and 7 deletions.
73 changes: 73 additions & 0 deletions docs/zh_cn/api_cn/early_stop.rst
@@ -0,0 +1,73 @@
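early-stop
==========
How to use the early-stop algorithm interface in experiments.

MedianStop
----------

.. py:class:: paddleslim.nas.early_stop.MedianStop(strategy, start_epoch, mode)

`Source code <>`_

MedianStop uses the intermediate results of earlier, better-performing experiments to decide whether the current experiment is worth running to completion. If the current experiment's result at an intermediate step is worse than the median of the results recorded at the same step across the experiment history, the current experiment is considered poor and can be terminated early. See `Google Vizier: A Service for Black-Box Optimization <https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/46180.pdf>`_.

**Parameters:**

- **strategy<class instance>** - An instance of the search strategy, for example an instance of SANAS.
- **start_epoch<int>** - The starting epoch, that is, the epoch from which the experiment's intermediate results start being monitored.
- **mode<str>** - Whether larger or smaller intermediate results are better; one of 'minimize' or 'maxmize' (the spelling used in the code). Default: 'maxmize'.

**Returns:**
An instance of MedianStop.

**Example:**

.. code-block:: python

  from paddleslim.nas import SANAS
  from paddleslim.nas.early_stop import MedianStop
  config = [('MobileNetV2Space')]
  sanas = SANAS(config, server_addr=("", 8732), save_checkpoint=None)
  earlystop = MedianStop(sanas, start_epoch=2)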
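   .. py:method:: get_status(step, result, epochs)

   Get the status of the current experiment given its current result.

   **Parameters:**

   - **step<int>** - The index of the current experiment within the current client.
   - **result<float>** - The intermediate result of the current experiment. It can be a loss value, an accuracy, or another metric, as long as it is consistent with `mode`.
   - **epochs<int>** - The total number of epochs each experiment runs during the search.

   **Returns:**
   The status of the current experiment at the current epoch, either `GOOD` or `BAD`. `BAD` means the current experiment can be stopped early.

   **Example:**

   .. code-block:: python

     import numpy as np
     from paddleslim.nas import SANAS
     from paddleslim.nas.early_stop import MedianStop
     steps = 10
     epochs = 7
     config = [('MobileNetV2Space')]
     sanas = SANAS(config, server_addr=("", 8732), save_checkpoint=None)
     earlystop = MedianStop(sanas, 2)
     # train_reader and test_reader are user-defined data readers.
     for step in range(steps):
         archs = sanas.next_archs()[0]
         for epoch in range(epochs):
             for data in train_reader():
                 loss = archs(data)
             for data in test_reader():
                 loss = archs(data)
             avg_cost = np.mean(loss)
             status = earlystop.get_status(step, avg_cost, epochs)
             # Stop training this architecture once it is judged BAD.
             if status == 'BAD':
                 break
         sanas.reward(avg_cost)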
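
The median rule documented above can be sketched in isolation. The block below is a minimal illustration, not PaddleSlim's implementation: `median_status` and its `peer_results` argument are hypothetical names introduced here. It assumes the same index choice that `median_stop.py` in this commit uses after sorting (the lower median for 'maxmize', the upper median for 'minimize') and leaves out the shared history and the `start_epoch` check.

```python
def median_status(result, peer_results, mode="maxmize"):
    """Hypothetical helper: classify `result` against peers at the same step.

    `peer_results` holds the values other experiments reached at this step.
    """
    if mode not in ("maxmize", "minimize"):
        raise ValueError("mode must be 'maxmize' or 'minimize'")
    if not peer_results:
        # Nothing to compare against yet, so keep the experiment running.
        return "GOOD"

    ordered = sorted(peer_results)
    n = len(ordered)
    # Larger is better: BAD when below the lower median.
    if mode == "maxmize" and result < ordered[(n - 1) // 2]:
        return "BAD"
    # Smaller is better: BAD when above the upper median.
    if mode == "minimize" and result > ordered[n // 2]:
        return "BAD"
    return "GOOD"


if __name__ == "__main__":
    peers = [0.6, 0.7, 0.8, 0.9]
    print(median_status(0.65, peers))              # accuracy-like metric
    print(median_status(0.85, peers, "minimize"))  # loss-like metric
```

With an even number of peers the two modes compare against different elements, so a value lying between the two middle entries is reported as `GOOD` in either mode.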
2 changes: 1 addition & 1 deletion paddleslim/common/sa_controller.py
@@ -113,7 +113,7 @@ def current_tokens(self):

         return self._current_tokens

-    def update(self, tokens, reward, iter, client_num):
+    def update(self, tokens, reward, iter, client_num=1):
         """
         Update the controller according to latest tokens and reward.
18 changes: 18 additions & 0 deletions paddleslim/nas/early_stop/__init__.py
@@ -0,0 +1,18 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from .early_stop import EarlyStopBase
from .median_stop import MedianStop

__all__ = ['EarlyStopBase', 'MedianStop']
32 changes: 32 additions & 0 deletions paddleslim/nas/early_stop/early_stop.py
@@ -0,0 +1,32 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = ['EarlyStopBase']


class EarlyStopBase(object):
    """ Abstract early stop algorithm.
    """

    def get_status(self, iter, result):
        """Get experiment status.
        """
        raise NotImplementedError(
            'get_status in Early Stop algorithm NOT implemented.')

    def client_end(self):
        """ Stop a client. This function may be useful for a client whose results keep getting better.
        """
        raise NotImplementedError(
            'client_end in Early Stop algorithm NOT implemented.')
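
To show how this base class is meant to be extended, here is a minimal sketch of a custom rule. It is an illustration only: `ThresholdStop`, `threshold`, and `start_iter` are hypothetical names that do not exist in PaddleSlim, and `EarlyStopBase` is re-declared locally as a stand-in so the block runs on its own.

```python
class EarlyStopBase(object):
    """Local stand-in mirroring the interface added in this commit."""

    def get_status(self, iter, result):
        raise NotImplementedError(
            'get_status in Early Stop algorithm NOT implemented.')

    def client_end(self):
        raise NotImplementedError(
            'client_end in Early Stop algorithm NOT implemented.')


class ThresholdStop(EarlyStopBase):
    """Hypothetical rule: report BAD once a result drops below a threshold."""

    def __init__(self, threshold, start_iter=0):
        self._threshold = threshold
        self._start_iter = start_iter

    def get_status(self, iter, result):
        # Do not judge the experiment before the warm-up iterations finish.
        if iter < self._start_iter:
            return "GOOD"
        return "GOOD" if result >= self._threshold else "BAD"

    def client_end(self):
        # This rule keeps no shared state, so there is nothing to release.
        pass


if __name__ == "__main__":
    rule = ThresholdStop(threshold=0.5, start_iter=2)
    for i, acc in enumerate([0.1, 0.3, 0.4, 0.6]):
        print(i, rule.get_status(i, acc))
```

A subclass that overrides neither method still raises `NotImplementedError` when called, which is how the base class enforces the interface.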
17 changes: 17 additions & 0 deletions paddleslim/nas/early_stop/median_stop/__init__.py
@@ -0,0 +1,17 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from .median_stop import MedianStop

__all__ = ['MedianStop']
184 changes: 184 additions & 0 deletions paddleslim/nas/early_stop/median_stop/median_stop.py
@@ -0,0 +1,184 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from multiprocessing.managers import BaseManager
from ..early_stop import EarlyStopBase
from ....common.log_helper import get_logger

PublicAuthKey = u'AbcXyz3'

__all__ = ['MedianStop']

_logger = get_logger(__name__, level=logging.INFO)

completed_history = dict()


def return_completed_history():
    return completed_history


class MedianStop(EarlyStopBase):
    """
    Median Stop, reference:
    Args:
        strategy<class instance>: the strategy of search.
        start_epoch<int>: which step to start early stop algorithm.
        mode<str>: bigger is better or smaller is better, choice in ['maxmize', 'minimize']. Default: maxmize.
    """

    def __init__(self, strategy, start_epoch, mode='maxmize'):
        self._start_epoch = start_epoch
        self._running_history = dict()
        self._strategy = strategy
        self._mode = mode
        self._is_server = self._strategy._is_server
        self._manager = self._start_manager()
        assert self._mode in [
            'maxmize', 'minimize'
        ], 'mode of MedianStop must be \'maxmize\' or \'minimize\', but received mode is {}'.format(
            self._mode)

    def _start_manager(self):
        self._server_ip = self._strategy._server_ip
        self._server_port = self._strategy._server_port + 1

        if self._is_server:
            BaseManager.register(
                'get_completed_history', callable=return_completed_history)
            base_manager = BaseManager(
                address=(self._server_ip, self._server_port),
                authkey=PublicAuthKey)

            base_manager.start()
        else:
            BaseManager.register('get_completed_history')
            base_manager = BaseManager(
                address=(self._server_ip, self._server_port),
                authkey=PublicAuthKey)
            base_manager.connect()
        return base_manager

    def _update_data(self, exp_name, result):
        if exp_name not in self._running_history.keys():
            self._running_history[exp_name] = []
        self._running_history[exp_name].append(result)

    def _convert_running2completed(self, exp_name, status):
        """
        Convert experiment record from running to complete.
        Args:
           exp_name<str>: the name of experiment.
           status<str>: the status of this experiment.
        """
        _logger.debug('the status of this experiment is {}'.format(status))
        completed_avg_history = dict()
        if exp_name in self._running_history:
            if status == "GOOD":
                count = 0
                history_sum = 0
                result = []
                for res in self._running_history[exp_name]:
                    count += 1
                    history_sum += res
                    result.append(history_sum / count)
                completed_avg_history[exp_name] = result
            self._running_history.pop(exp_name)

        if len(completed_avg_history) > 0:
            while True:
                try:
                    new_dict = self._manager.get_completed_history()
                    new_dict.update(completed_avg_history)
                    break
                except Exception as err:
                    _logger.error("update data error: {}".format(err))

    def get_status(self, step, result, epochs):
        """
        Get current experiment status
        Args:
            step: step in this client.
            result: the result of this epoch.
            epochs: whole epochs.
        Return:
            the status of this experiment.
        """
        exp_name = self._strategy._client_name + str(step)
        self._update_data(exp_name, result)

        _logger.debug("running history after update data: {}".format(
            self._running_history))

        curr_step = len(self._running_history[exp_name])
        status = "GOOD"
        if curr_step < self._start_epoch:
            return status

        res_same_step = []

        def list2dict(lists):
            res_dict = dict()
            for l in lists:
                tmp_dict = dict()
                tmp_dict[l[0]] = l[1]
                res_dict.update(tmp_dict)
            return res_dict

        while True:
            try:
                completed_avg_history = self._manager.get_completed_history()
                break
            except Exception as err:
                _logger.error("get status error: {}".format(err))

        if len(completed_avg_history.keys()) == 0:
            for exp in self._running_history.keys():
                if curr_step <= len(self._running_history[exp]):
                    res_same_step.append(self._running_history[exp][curr_step -
                                                                    1])
        else:
            completed_avg_history_dict = list2dict(completed_avg_history.items(
            ))

            for exp in completed_avg_history.keys():
                if curr_step <= len(completed_avg_history_dict[exp]):
                    res_same_step.append(completed_avg_history_dict[exp][
                        curr_step - 1])

        _logger.debug("result of same step in other experiment: {}".format(
            res_same_step))
        if res_same_step:
            res_same_step.sort()

            if self._mode == 'maxmize' and result < res_same_step[(
                    len(res_same_step) - 1) // 2]:
                status = "BAD"

            if self._mode == 'minimize' and result > res_same_step[len(
                    res_same_step) // 2]:
                status = "BAD"

        if curr_step == epochs:
            self._convert_running2completed(exp_name, status)

        return status

    def __del__(self):
        if self._is_server:
            self._manager.shutdown()
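
As read from `_convert_running2completed` above, a finished experiment with status `GOOD` is stored in the completed history as a list of running averages rather than raw per-epoch results. The sketch below isolates that conversion under that reading; `running_averages` is a hypothetical helper name, not part of the module.

```python
def running_averages(results):
    """Hypothetical helper: the i-th output is the mean of results[0..i]."""
    averages = []
    total = 0.0
    for count, value in enumerate(results, start=1):
        total += value
        averages.append(total / count)
    return averages


if __name__ == "__main__":
    # Per-epoch results of one finished experiment.
    print(running_averages([0.5, 0.75, 1.0]))  # [0.5, 0.625, 0.75]
```

Averaging smooths the stored curve, so a single noisy epoch in a past experiment has less influence on the value later experiments are compared against at the same step.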
12 changes: 6 additions & 6 deletions paddleslim/nas/sa_nas.py
@@ -118,9 +118,9 @@ def __init__(self,
         self._key = str(self._configs)
         self._current_tokens = init_tokens

-        server_ip, server_port = server_addr
-        if server_ip == None or server_ip == "":
-            server_ip = self._get_host_ip()
+        self._server_ip, self._server_port = server_addr
+        if self._server_ip == None or self._server_ip == "":
+            self._server_ip = self._get_host_ip()

         factory = SearchSpaceFactory()
         self._search_space = factory.get_search_space(configs)
@@ -171,16 +171,16 @@ def __init__(self,
             max_client_num = 100
             self._controller_server = ControllerServer(
                 controller=self._controller,
-                address=(server_ip, server_port),
+                address=(self._server_ip, self._server_port),
                 max_client_num=max_client_num,
                 search_steps=search_steps,
                 key=self._key)
             self._controller_server.start()
             server_port = self._controller_server.port()

         self._controller_client = ControllerClient(
-            server_ip,
-            server_port,
+            self._server_ip,
+            self._server_port,
             key=self._key,
             client_name=self._client_name)
