Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Fix a vital performance bug when run in graph mode deepray will… #34

Merged
merged 3 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 28 additions & 24 deletions deepray/core/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import json
import os
import sys
import time
from typing import Union, List, Dict, Text

Expand Down Expand Up @@ -708,7 +709,7 @@ def run_customized_training_loop(
# Training loop starts here.
self.current_step = self._first_steps = self.optimizer.iterations.numpy()

self.first_batch = True
self.first_batch = tf.Variable(True, trainable=False, dtype=tf.bool, name='first_batch')
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be better to create the first_batch variable under init_scope? That ensures the variable is created in eager mode, so it stays compatible with Python control-flow code at all times, and it also keeps the creation safe from the TF gradient tape and other active contexts.

with tf.init_scope():
    # Initialization runs with eager execution enabled
    # assert tf.executing_eagerly()
    tf.Variable(True, trainable=False, dtype=tf.bool, name='first_batch')

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

if not hasattr(self.main_model, 'optimizer'):
raise ValueError('User should set optimizer attribute to model '
'inside `model_fn`.')
Expand Down Expand Up @@ -749,11 +750,9 @@ def run_customized_training_loop(
elif steps > 1 and self.optimizer.iterations.numpy() > self.current_step:
steps = self.optimizer.iterations.numpy() - self.current_step
training_logs = self.get_metrics_result()
self.first_batch = False
self.on_batch_end(training_logs, steps, t0)
break

self.first_batch = False
self.on_batch_end(training_logs, steps, t0)
self.on_epoch_end(epoch, self.current_step, eval_input, epoch_logs=training_logs)
if self.main_model.stop_training:
Expand Down Expand Up @@ -819,11 +818,10 @@ def train_single_step(self, iterator, num_grad_accumulates):
self.forward(iterator)
if _ == 0 or (_ + 1) % num_grad_accumulates == 0:
self.step(num_grad_accumulates)
if self.use_horovod and _ == 0 and self.first_batch:
hvd.broadcast_variables(self.main_model.variables, 0)
hvd.broadcast_variables(self.optimizer.variables(), 0)
if self.use_horovod and self.first_batch:
self.do_broadcast()
else:
self._replicated_step(iterator, self.first_batch)
self._replicated_step(iterator)
return self.get_metrics_result()

@property
Expand All @@ -833,7 +831,24 @@ def trainable_variables(self):
else:
return self.main_model.trainable_variables

def _replicated_step(self, inputs, first_batch=False):
def do_broadcast(self):
    """Broadcast dense model and optimizer state from rank 0 to all Horovod workers.

    Variables backed by ``TrainableWrapper`` or ``DEResourceVariable``
    (dynamic-embedding storage) are filtered out, as they are not ordinary
    dense variables.  Once the broadcast has been issued, ``first_batch``
    is flipped to ``False`` so this runs only for the first batch.
    """
    def _is_dense(var):
        # Exclude dynamic-embedding variables from the broadcast.
        return not isinstance(var, (TrainableWrapper, DEResourceVariable))

    model_vars = [v for v in self.main_model.variables if _is_dense(v)]
    opt_vars = [v for v in self.optimizer.variables() if _is_dense(v)]
    all_vars = model_vars + opt_vars

    print_op = tf.print(f"Broadcasting {len(all_vars)} variables...", output_stream=sys.stdout)
    # The control dependency guarantees the message is emitted before the
    # broadcast when this is traced into a graph.
    with tf.control_dependencies([print_op]):
        hvd.broadcast_variables(all_vars, root_rank=0)
    self.first_batch.assign(False)

def _replicated_step(self, inputs):
"""Replicated training step."""
inputs, labels, sample_weight = data_adapter.unpack_x_y_sample_weight(inputs)
with tf.GradientTape() as tape:
Expand All @@ -847,18 +862,8 @@ def _replicated_step(self, inputs, first_batch=False):
# Run backwards pass.
self.optimizer.minimize(loss, self.trainable_variables, tape=tape)

if self.use_horovod and first_batch:
broadcast_vars = [
var for var in self.main_model.variables
if (not isinstance(var, TrainableWrapper)) and (not isinstance(var, DEResourceVariable))
]
hvd.broadcast_variables(broadcast_vars, root_rank=0)

opt_broadcast_vars = [
var for var in self.optimizer.variables()
if (not isinstance(var, TrainableWrapper)) and (not isinstance(var, DEResourceVariable))
]
hvd.broadcast_variables(opt_broadcast_vars, root_rank=0)
if self.use_horovod and self.first_batch:
self.do_broadcast()

# For reporting, the metric takes the mean of losses.
if self.metric_container:
Expand Down Expand Up @@ -954,12 +959,11 @@ def train_steps(self, iterator, steps, num_grad_accumulates):
self.forward(next(iterator))
if _ == 0 or (_ + 1) % num_grad_accumulates == 0:
self.step(num_grad_accumulates)
if self.use_horovod and _ == 0 and self.first_batch:
hvd.broadcast_variables(self.main_model.variables, 0)
hvd.broadcast_variables(self.optimizer.variables(), 0)
if self.use_horovod and self.first_batch:
self.do_broadcast()
else:
for _ in tf.range(steps):
self._replicated_step(next(iterator), (self.first_batch and _ == 0))
self._replicated_step(next(iterator))
return self.get_metrics_result()

def train_single_step_strategy(self, iterator, num_grad_accumulates):
Expand Down
2 changes: 1 addition & 1 deletion deepray/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
# We follow Semantic Versioning (https://semver.org/)
_MAJOR_VERSION = "0"
_MINOR_VERSION = "21"
_PATCH_VERSION = "6"
_PATCH_VERSION = "7"

# When building releases, we can update this value on the release branch to
# reflect the current release candidate ('rc0', 'rc1') or, finally, the official
Expand Down
Loading