Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/feature/checkpoint' into feature/checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
hariharan-devarajan committed Jan 9, 2024
2 parents a397d11 + 63ff8c4 commit 3727e5a
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 7 deletions.
3 changes: 2 additions & 1 deletion dlio_benchmark/configs/workload/megatron_deepspeed.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
model: unet3d
# 8 node run with 4 GPUs per node and TPSIZE=4 and PPSIZE=8
model: megatron_deepspeed

framework: pytorch

Expand Down
2 changes: 1 addition & 1 deletion dlio_benchmark/data_generator/hdf5_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def generate(self):
"""
super().generate()
np.random.seed(10)
samples_per_iter=max(1, int(32*1024*1024/self._args.record_length))
samples_per_iter=max(1, int(self._args.generation_buffer_size/self._args.record_length))
record_labels = [0] * self.num_samples
for i in dlp.iter(range(self.my_rank, int(self.total_files_to_generate), self.comm_size)):
progress(i, self.total_files_to_generate, "Generating HDF5 Data")
Expand Down
7 changes: 2 additions & 5 deletions dlio_benchmark/framework/tf_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,8 @@ def __init__(self, profiling):
if self.args.checkpoint_type == CheckpointLocationType.RANK_ZERO:
rank_to_checkpoint = 0
if rank_to_checkpoint == self.args.my_rank:
num_ranks = 1
if self.args.checkpoint_type == CheckpointLocationType.RANK_ZERO:
num_ranks = self.args.comm_size
if self.args.model_size > 0:
self.model_state = {"a": self._get_tensor(self.args.model_size*num_ranks)}
self.model_state = {"a": self._get_tensor(self.args.model_size)}
self.optimization_state = None
if len(self.args.optimization_groups) > 0:
self.optimization_state = dict()
Expand All @@ -78,7 +75,7 @@ def __init__(self, profiling):
self.layer_state = dict()
for index, state in enumerate(self.args.layer_parameters):
if state > 0:
self.layer_state[str(index)] = self._get_tensor(state*num_ranks)
self.layer_state[str(index)] = self._get_tensor(state)

def _get_tensor(self, size):
return tf.random.uniform((int(size / 4),), maxval=100, dtype=tf.dtypes.int32)
Expand Down

0 comments on commit 3727e5a

Please sign in to comment.