Skip to content

Commit

Permalink
fixed checkpointing for tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
hariharan-devarajan committed Nov 29, 2023
1 parent 6221b33 commit 8a1fb5a
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
10 changes: 5 additions & 5 deletions dlio_benchmark/framework/tf_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,13 @@ def __init__(self, profiling):
self.optimization_state = None
if len(self.args.optimization_groups) > 0:
self.optimization_state = dict()
tensor_array = []
tensor_array_size = 0
for index, state in enumerate(self.args.optimization_groups):
if state > 0:
self.optimization_state[str(index)] = {'a': self._get_tensor(state*num_ranks),
'b': self._get_tensor(state*num_ranks)}
tensor_array.append(self._get_tensor(state*num_ranks))
self.optimization_state["combined"] = tensor_array
self.optimization_state[str(index)] = {'a': self._get_tensor(state),
'b': self._get_tensor(state)}
tensor_array_size += state
self.optimization_state["combined"] = self._get_tensor(tensor_array_size)
self.layer_state = None
if len(self.args.layer_parameters) > 0:
self.layer_state = dict()
Expand Down
6 changes: 3 additions & 3 deletions dlio_benchmark/framework/torch_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,12 @@ def __init__(self, profiling):
self.optimization_state = None
if len(self.args.optimization_groups) > 0:
self.optimization_state = dict()
tensor_array = []
tensor_array_size = 0
for index, state in enumerate(self.args.optimization_groups):
if state > 0:
self.optimization_state[str(index)] = {'a': self._get_tensor(state), 'b': self._get_tensor(state)}
tensor_array.append(self._get_tensor(state))
self.optimization_state["combined"] = tensor_array
tensor_array_size += state
self.optimization_state["combined"] = self._get_tensor(tensor_array_size)
self.layer_state = None
if len(self.args.layer_parameters) > 0:
self.layer_state = dict()
Expand Down

0 comments on commit 8a1fb5a

Please sign in to comment.