Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Add shortcut to merge parameter into base setup #2540

Merged
merged 2 commits into from
Jun 19, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/trials/mnist-pytorch/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from nni.utils import merge_parameter
from torchvision import datasets, transforms

logger = logging.getLogger('mnist_AutoML')
Expand Down Expand Up @@ -157,8 +158,7 @@ def get_params():
# get parameters form tuner
tuner_params = nni.get_next_parameter()
logger.debug(tuner_params)
params = vars(get_params())
params.update(tuner_params)
params = vars(merge_parameter(get_params(), tuner_params))
print(params)
main(params)
except Exception as exception:
Expand Down
40 changes: 40 additions & 0 deletions src/sdk/pynni/nni/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,3 +216,43 @@ def json2parameter(x, is_rand, random_state, oldy=None, Rand=False, name=NodeTyp
else:
y = copy.deepcopy(x)
return y


def merge_parameter(base_params, override_params):
"""
Update the parameters in ``base_params`` with ``override_params``.
Can be useful to override parsed command line arguments.

Parameters
----------
base_params : namespace or dict
Base parameters. A key-value mapping.
override_params : dict or None
Parameters to override. Usually the parameters got from ``get_next_parameters()``.
When it is none, nothing will happen.

Returns
-------
namespace or dict
The updated ``base_params``. Note that ``base_params`` will be updated inplace. The return value is
only for convenience.
"""
if override_params is None:
return base_params
is_dict = isinstance(base_params, dict)
for k, v in override_params.items():
if is_dict:
if k not in base_params:
raise ValueError('Key \'%s\' not found in base parameters.' % k)
if type(base_params[k]) != type(v) and base_params[k] is not None:
raise TypeError('Expected \'%s\' in override parameters to have type \'%s\', but found \'%s\'.' %
(k, type(base_params[k]), type(v)))
base_params[k] = v
else:
if not hasattr(base_params, k):
raise ValueError('Key \'%s\' not found in base parameters.' % k)
if type(getattr(base_params, k)) != type(v) and getattr(base_params, k) is not None:
raise TypeError('Expected \'%s\' in override parameters to have type \'%s\', but found \'%s\'.' %
(k, type(getattr(base_params, k)), type(v)))
setattr(base_params, k, v)
return base_params