
Commit 135e755: update module API for other submodules
eric-haibin-lin committed Jul 12, 2017 (1 parent: 955b13d)
Showing 4 changed files with 15 additions and 11 deletions.
python/mxnet/module/base_module.py (2 changes: 1 addition & 1 deletion)
@@ -849,7 +849,7 @@ def get_input_grads(self, merge_multi_context=True):
         """
         raise NotImplementedError()
 
-    def update(self):
+    def update(self, sparse_pull_dict=None):
         """Updates parameters according to the installed optimizer and the gradients computed
         in the previous forward-backward batch.
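Because the new argument defaults to None, existing call sites of BaseModule.update() keep working unchanged; only callers that want the new behavior pass the keyword. A minimal caller-side sketch against the 0.x Module API on this branch (the semantics of sparse_pull_dict are not specified in this diff, so it is treated as an opaque optional argument; shapes and names are illustrative):

import mxnet as mx

# A minimal bound module: one fully connected layer, no labels.
data = mx.sym.Variable('data')
net = mx.sym.FullyConnected(data, num_hidden=4, name='fc')
mod = mx.mod.Module(net, data_names=['data'], label_names=None)
mod.bind(data_shapes=[('data', (8, 16))])
mod.init_params()
mod.init_optimizer(optimizer='sgd',
                   optimizer_params=(('learning_rate', 0.01),))

batch = mx.io.DataBatch(data=[mx.nd.ones((8, 16))], label=None)
mod.forward(batch, is_train=True)
mod.backward()
mod.update()  # unchanged call; sparse_pull_dict defaults to None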
python/mxnet/module/bucketing_module.py (9 changes: 5 additions & 4 deletions)
@@ -334,7 +334,7 @@ def switch_bucket(self, bucket_key, data_shapes, label_shapes=None):
 
     def init_optimizer(self, kvstore='local', optimizer='sgd',
                        optimizer_params=(('learning_rate', 0.01),),
-                       force_init=False):
+                       force_init=False, sparse_pull_dict=None):
         """Installs and initializes optimizers.
 
         Parameters
@@ -356,7 +356,8 @@ def init_optimizer(self, kvstore='local', optimizer='sgd',
             return
 
         self._curr_module.init_optimizer(kvstore, optimizer, optimizer_params,
-                                         force_init=force_init)
+                                         force_init=force_init,
+                                         sparse_pull_dict=sparse_pull_dict)
         for mod in self._buckets.values():
             if mod is not self._curr_module:
                 mod.borrow_optimizer(self._curr_module)
@@ -399,13 +400,13 @@ def backward(self, out_grads=None):
         assert self.binded and self.params_initialized
         self._curr_module.backward(out_grads=out_grads)
 
-    def update(self):
+    def update(self, sparse_pull_dict=None):
         """Updates parameters according to installed optimizer and the gradient computed
         in the previous forward-backward cycle.
         """
         assert self.binded and self.params_initialized and self.optimizer_initialized
         self._params_dirty = True
-        self._curr_module.update()
+        self._curr_module.update(sparse_pull_dict=sparse_pull_dict)
 
     def get_outputs(self, merge_multi_context=True):
         """Gets outputs from a previous forward computation.
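The bucketing changes preserve the single-optimizer invariant: only the currently active bucket (_curr_module) initializes an optimizer, now receiving sparse_pull_dict, and every other bucket borrows that optimizer via borrow_optimizer(), so the dict only needs to be threaded through once; update() then forwards it to whichever bucket is current. A hedged sketch of that call path, assuming a trivial sym_gen and this branch's API (bucket keys and shapes are illustrative):

import mxnet as mx

def sym_gen(bucket_key):
    # One symbol per bucket; here the input width doubles as the key.
    data = mx.sym.Variable('data')
    net = mx.sym.FullyConnected(data, num_hidden=4, name='fc')
    return net, ('data',), None

mod = mx.mod.BucketingModule(sym_gen, default_bucket_key=16)
mod.bind(data_shapes=[('data', (8, 16))])
mod.init_params()
# Passed once: the current bucket's optimizer receives the dict, and the
# remaining buckets borrow that optimizer instead of creating their own.
mod.init_optimizer(optimizer='sgd', sparse_pull_dict=None)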
python/mxnet/module/python_module.py (5 changes: 3 additions & 2 deletions)
@@ -110,7 +110,7 @@ def init_params(self, initializer=Uniform(0.01), arg_params=None, aux_params=None,
         """
         pass
 
-    def update(self):
+    def update(self, sparse_pull_dict=None):
         """Updates parameters according to the installed optimizer and the gradients computed
         in the previous forward-backward batch. Currently we do nothing here. Subclass should
         override this method if it contains parameters.
@@ -196,7 +196,8 @@ def _compute_output_shapes(self):
         raise NotImplementedError()
 
     def init_optimizer(self, kvstore='local', optimizer='sgd',
-                       optimizer_params=(('learning_rate', 0.01),), force_init=False):
+                       optimizer_params=(('learning_rate', 0.01),), force_init=False,
+                       sparse_pull_dict=None):
         """Installs and initializes optimizers. By default we do nothing. Subclass should
         override this method if needed.
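For PythonModule, update() is a documented no-op unless a subclass holds parameters, so the base class only needs to accept and ignore the keyword; any subclass that overrides update() should now match the widened signature, or callers passing the keyword will raise a TypeError. A hypothetical subclass sketch:

from mxnet.module.python_module import PythonLossModule

class MyLoss(PythonLossModule):
    def update(self, sparse_pull_dict=None):
        # This module holds no parameters, so the keyword is accepted
        # purely for signature compatibility with BaseModule.update
        # and ignored.
        pass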
python/mxnet/module/sequential_module.py (10 changes: 6 additions & 4 deletions)
@@ -275,7 +275,8 @@ def bind(self, data_shapes, label_shapes=None, for_training=True,
 
     def init_optimizer(self, kvstore='local', optimizer='sgd',
                        optimizer_params=(('learning_rate', 0.01),),
-                       force_init=False):
+                       force_init=False,
+                       sparse_pull_dict=None):
         """Installs and initializes optimizers.
 
         Parameters
@@ -298,7 +299,8 @@ def init_optimizer(self, kvstore='local', optimizer='sgd',
 
         for module in self._modules:
             module.init_optimizer(kvstore=kvstore, optimizer=optimizer,
-                                  optimizer_params=optimizer_params, force_init=force_init)
+                                  optimizer_params=optimizer_params, force_init=force_init,
+                                  sparse_pull_dict=sparse_pull_dict)
 
         self.optimizer_initialized = True
 
@@ -344,14 +346,14 @@ def backward(self, out_grads=None):
 
             out_grads = module.get_input_grads()
 
-    def update(self):
+    def update(self, sparse_pull_dict=None):
         """Updates parameters according to installed optimizer and the gradient computed
         in the previous forward-backward cycle.
         """
         assert self.binded and self.params_initialized and self.optimizer_initialized
 
         for module in self._modules:
-            module.update()
+            module.update(sparse_pull_dict=sparse_pull_dict)
 
     def get_outputs(self, merge_multi_context=True):
         """Gets outputs from a previous forward computation.
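SequentialModule simply fans the keyword out: init_optimizer() and update() hand the same sparse_pull_dict to every child, so a chain of modules behaves like a single module with respect to the new argument. A sketch of that fan-out on this branch, assuming two trivial children (auto_wiring connects the first module's output to the second's input; shapes are illustrative):

import mxnet as mx

net1 = mx.sym.FullyConnected(mx.sym.Variable('data'), num_hidden=8, name='fc1')
net2 = mx.sym.FullyConnected(mx.sym.Variable('data'), num_hidden=2, name='fc2')

seq = mx.mod.SequentialModule()
seq.add(mx.mod.Module(net1, label_names=None))
seq.add(mx.mod.Module(net2, label_names=None), auto_wiring=True)

seq.bind(data_shapes=[('data', (8, 16))])
seq.init_params()
# Every child module receives the same optional dict.
seq.init_optimizer(optimizer='sgd', sparse_pull_dict=None)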
