Skip to content

Commit

Permalink
fix ensure_bagua_tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
wangraying committed Oct 21, 2021
1 parent dfd81f9 commit 2aac918
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
4 changes: 1 addition & 3 deletions bagua/torch_api/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,9 +294,7 @@ def with_bagua( # pytype: disable=module-attr
... GradientAllReduce()
... )
"""
self.bagua_module_name = "{}_{}".format(
self.__class__.__name__, next(BaguaModule.__id_iter)
)
self.bagua_module_name = "{}_{}".format(self.__class__.__name__, id(self))

# set bucket process group
if process_group is None:
Expand Down
5 changes: 5 additions & 0 deletions bagua/torch_api/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ def ensure_bagua_tensor(
self.bagua_tensor_name == name
), "assigning a different name to existing bagua tensor is forbidden"

assert (
self.bagua_module_name == module_name
), "assigning a different module name to existing bagua tensor is forbidden"
return self

self.bagua_tensor_name = name if name is not None else ""
self.bagua_module_name = module_name
self.bagua_backend = (
Expand Down

0 comments on commit 2aac918

Please sign in to comment.