forked from PaddlePaddle/Paddle
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
331 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
early-stop | ||
======== | ||
早停算法接口在实验中如何使用 | ||
|
||
MedianStop | ||
------ | ||
|
||
.. py:class:: paddleslim.nas.early_stop.MedianStop(strategy, start_epoch, mode) | ||
`源代码 <>`_ | ||
|
||
MedianStop是利用历史较好实验的中间结果来判断当前实验是否有运行完成的必要,如果当前实验在中间步骤的结果差于历史记录的实验列表中相同步骤的结果的中值,则代表当前实验是较差的实验,可以提前终止。参考 `Google Vizier: A Service for Black-Box Optimization <https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/46180.pdf>`_. | ||
|
||
**参数:** | ||
|
||
- **strategy<class instance>** - 搜索策略的实例,例如是SANAS的实例。 | ||
- **start_epoch<int>** - 起始epoch,代表从第几个epoch开始监控实验中间结果。 | ||
- **mode<str>** - 中间结果是越大越好还是越小越好,在'minimize'和'maxmize'之间选择。默认:'maxmize'。 | ||
|
||
**返回:** | ||
一个MedianStop的实例 | ||
|
||
**示例代码:** | ||
|
||
.. code-block:: python | ||
from paddleslim.nas import SANAS | ||
from paddleslim.nas.early_stop import MedianStop | ||
config = [('MobileNetV2Space')] | ||
sanas = SANAS(config, server_addr=("", 8732), save_checkpoint=None) | ||
earlystop = MedianStop(sanas, start_epoch = 2) | ||
.. py:method:: get_status(step, result, epochs): | ||
获取当前实验当前result的状态。 | ||
**参数:** | ||
- **step<int>** - 当前实验是当前client中的第几个实验。 | ||
- **result<float>** - 当前实验的中间步骤的result,可以为损失值,也可以为准确率等指标,只要和`mode`对应即可。 | ||
- **epochs<int>** - 在搜索过程中每个实验需要运行的总得epoch数量。 | ||
**返回:** | ||
返回当前实验在当前epoch的状态,为`GOOD`或者`BAD`,如果为`BAD`,则代表当前实验可以早停。 | ||
**示例代码:** | ||
.. code-block:: python | ||
from paddleslim.nas import SANAS | ||
from paddleslim.nas.early_stop import MedianStop | ||
steps = 10 | ||
epochs = 7 | ||
config = [('MobileNetV2Space')] | ||
sanas = SANAS(config, server_addr=("", 8732), save_checkpoint=None) | ||
earlystop = MedianStop(sanas, 2) | ||
for step in range(steps): | ||
archs = sanas.next_archs()[0] | ||
for epoch in range(epochs): | ||
for data in train_reader(): | ||
loss = archs(data) | ||
for data in test_reader(): | ||
loss = archs(data) | ||
avg_cost = np.mean(loss) | ||
status = earlystop.get_status(step, avg_cost, epochs) | ||
if status == 'BAD': | ||
break; | ||
sanas.reward(avg_cost) | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License" | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from __future__ import absolute_import | ||
from .early_stop import EarlyStopBase | ||
from .median_stop import MedianStop | ||
|
||
__all__ = ['EarlyStopBase', 'MedianStop'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
__all__ = ['EarlyStopBase'] | ||
|
||
|
||
class EarlyStopBase(object): | ||
""" Abstract early Stop algorithm. | ||
""" | ||
|
||
def get_status(self, iter, result): | ||
"""Get experiment status. | ||
""" | ||
raise NotImplementedError( | ||
'get_status in Early Stop algorithm NOT implemented.') | ||
|
||
def client_end(self): | ||
""" Stop a client, this function may useful for the client that result is better and better. | ||
""" | ||
raise NotImplementedError( | ||
'client_end in Early Stop algorithm NOT implemented.') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License" | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from __future__ import absolute_import | ||
from .median_stop import MedianStop | ||
|
||
__all__ = ['MedianStop'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import logging | ||
from multiprocessing.managers import BaseManager | ||
from ..early_stop import EarlyStopBase | ||
from ....common.log_helper import get_logger | ||
|
||
PublicAuthKey = u'AbcXyz3' | ||
|
||
__all__ = ['MedianStop'] | ||
|
||
_logger = get_logger(__name__, level=logging.INFO) | ||
|
||
completed_history = dict() | ||
|
||
|
||
def return_completed_history(): | ||
return completed_history | ||
|
||
|
||
class MedianStop(EarlyStopBase): | ||
""" | ||
Median Stop, reference: | ||
Args: | ||
strategy<class instance>: the stategy of search. | ||
start_epoch<int>: which step to start early stop algorithm. | ||
mode<str>: bigger is better or smaller is better, chooice in ['maxmize', 'minimize']. Default: maxmize. | ||
""" | ||
|
||
def __init__(self, strategy, start_epoch, mode='maxmize'): | ||
self._start_epoch = start_epoch | ||
self._running_history = dict() | ||
self._strategy = strategy | ||
self._mode = mode | ||
self._is_server = self._strategy._is_server | ||
self._manager = self._start_manager() | ||
assert self._mode in [ | ||
'maxmize', 'minimize' | ||
], 'mode of MedianStop must be \'maxmize\' or \'minimize\', but received mode is {}'.format( | ||
self._mode) | ||
|
||
def _start_manager(self): | ||
self._server_ip = self._strategy._server_ip | ||
self._server_port = self._strategy._server_port + 1 | ||
|
||
if self._is_server: | ||
BaseManager.register( | ||
'get_completed_history', callable=return_completed_history) | ||
base_manager = BaseManager( | ||
address=(self._server_ip, self._server_port), | ||
authkey=PublicAuthKey) | ||
|
||
base_manager.start() | ||
else: | ||
BaseManager.register('get_completed_history') | ||
base_manager = BaseManager( | ||
address=(self._server_ip, self._server_port), | ||
authkey=PublicAuthKey) | ||
base_manager.connect() | ||
return base_manager | ||
|
||
def _update_data(self, exp_name, result): | ||
if exp_name not in self._running_history.keys(): | ||
self._running_history[exp_name] = [] | ||
self._running_history[exp_name].append(result) | ||
|
||
def _convert_running2completed(self, exp_name, status): | ||
""" | ||
Convert experiment record from running to complete. | ||
Args: | ||
exp_name<str>: the name of experiment. | ||
status<str>: the status of this experiment. | ||
""" | ||
_logger.debug('the status of this experiment is {}'.format(status)) | ||
completed_avg_history = dict() | ||
if exp_name in self._running_history: | ||
if status == "GOOD": | ||
count = 0 | ||
history_sum = 0 | ||
result = [] | ||
for res in self._running_history[exp_name]: | ||
count += 1 | ||
history_sum += res | ||
result.append(history_sum / count) | ||
completed_avg_history[exp_name] = result | ||
self._running_history.pop(exp_name) | ||
|
||
if len(completed_avg_history) > 0: | ||
while True: | ||
try: | ||
new_dict = self._manager.get_completed_history() | ||
new_dict.update(completed_avg_history) | ||
break | ||
except Exception as err: | ||
_logger.error("update data error: {}".format(err)) | ||
|
||
def get_status(self, step, result, epochs): | ||
""" | ||
Get current experiment status | ||
Args: | ||
step: step in this client. | ||
result: the result of this epoch. | ||
epochs: whole epochs. | ||
Return: | ||
the status of this experiment. | ||
""" | ||
exp_name = self._strategy._client_name + str(step) | ||
self._update_data(exp_name, result) | ||
|
||
_logger.debug("running history after update data: {}".format( | ||
self._running_history)) | ||
|
||
curr_step = len(self._running_history[exp_name]) | ||
status = "GOOD" | ||
if curr_step < self._start_epoch: | ||
return status | ||
|
||
res_same_step = [] | ||
|
||
def list2dict(lists): | ||
res_dict = dict() | ||
for l in lists: | ||
tmp_dict = dict() | ||
tmp_dict[l[0]] = l[1] | ||
res_dict.update(tmp_dict) | ||
return res_dict | ||
|
||
while True: | ||
try: | ||
completed_avg_history = self._manager.get_completed_history() | ||
break | ||
except Exception as err: | ||
_logger.error("get status error: {}".format(err)) | ||
|
||
if len(completed_avg_history.keys()) == 0: | ||
for exp in self._running_history.keys(): | ||
if curr_step <= len(self._running_history[exp]): | ||
res_same_step.append(self._running_history[exp][curr_step - | ||
1]) | ||
else: | ||
completed_avg_history_dict = list2dict(completed_avg_history.items( | ||
)) | ||
|
||
for exp in completed_avg_history.keys(): | ||
if curr_step <= len(completed_avg_history_dict[exp]): | ||
res_same_step.append(completed_avg_history_dict[exp][ | ||
curr_step - 1]) | ||
|
||
_logger.debug("result of same step in other experiment: {}".format( | ||
res_same_step)) | ||
if res_same_step: | ||
res_same_step.sort() | ||
|
||
if self._mode == 'maxmize' and result < res_same_step[( | ||
len(res_same_step) - 1) // 2]: | ||
status = "BAD" | ||
|
||
if self._mode == 'minimize' and result > res_same_step[len( | ||
res_same_step) // 2]: | ||
status = "BAD" | ||
|
||
if curr_step == epochs: | ||
self._convert_running2completed(exp_name, status) | ||
|
||
return status | ||
|
||
def __del__(self): | ||
if self._is_server: | ||
self._manager.shutdown() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters