Merge pull request #34 from deepray-AI/hotfix
[Fix] Fix a vital performance bug when run in graph mode deepray will…
fuhailin authored Nov 6, 2023
2 parents 9936eb6 + 6c9e9a2 commit 7505906
Showing 6 changed files with 105 additions and 69 deletions.
131 changes: 73 additions & 58 deletions deepray/core/base_trainer.py
@@ -20,6 +20,7 @@

import json
import os
import sys
import time
from typing import Union, List, Dict, Text

@@ -202,8 +203,9 @@ def __init__(
use_horovod=None,
run_eagerly=None,
jit_compile=None,
**kwargs,
**kwargs
):
super().__init__(**kwargs)
self.strategy = distribution_utils.get_distribution_strategy()

self._model = {}
@@ -260,6 +262,22 @@ def __init__(
if self.use_float16:
self.optimizer = tf.keras.mixed_precision.LossScaleOptimizer(self.optimizer, dynamic=True)

with distribution_utils.get_strategy_scope(self.strategy):
# To correctly place the model weights on accelerators,
# model should be created in scope.
if isinstance(self._loss, compile_utils.LossesContainer):
self.loss_container = self._loss
else:
self.loss_container = compile_utils.LossesContainer(
self._loss, self._loss_weights, output_names=self.main_model.output_names
)
self.metric_container = compile_utils.MetricsContainer(
self._metrics,
self._weighted_metrics,
output_names=self.main_model.output_names,
# from_serialized=from_serialized,
) if self._metrics or self._weighted_metrics else None

@property
def main_model(self):
"""
@@ -531,6 +549,7 @@ def fit(
and what the model expects or when the input data is empty.
"""
self.steps_per_epoch = steps_per_epoch if steps_per_epoch else -1
self.validation_steps = eval_steps
if FLAGS.benchmark or FLAGS.stop_steps >= 0:
if FLAGS.stop_steps >= 0:
self.steps_per_epoch = FLAGS.stop_steps
@@ -580,22 +599,6 @@ def fit(
self.train_summary_writer = None
eval_input_fn = None

with distribution_utils.get_strategy_scope(self.strategy):
# To correctly place the model weights on accelerators,
# model should be created in scope.
if isinstance(self._loss, compile_utils.LossesContainer):
self.loss_container = self._loss
else:
self.loss_container = compile_utils.LossesContainer(
self._loss, self._loss_weights, output_names=self.main_model.output_names
)
self.metric_container = compile_utils.MetricsContainer(
self._metrics,
self._weighted_metrics,
output_names=self.main_model.output_names,
# from_serialized=from_serialized,
) if self._metrics or self._weighted_metrics else None

self._checkpoints, self._managers = {}, {}
for name, model in self._model.items():
if "main" in name:
@@ -708,7 +711,9 @@ def run_customized_training_loop(
# Training loop starts here.
self.current_step = self._first_steps = self.optimizer.iterations.numpy()

self.first_batch = True
if self.use_horovod:
with tf.init_scope():
self.first_batch = tf.Variable(True, trainable=False, dtype=tf.bool, name='first_batch')
if not hasattr(self.main_model, 'optimizer'):
raise ValueError('User should set optimizer attribute to model '
'inside `model_fn`.')
@@ -749,11 +754,9 @@
elif steps > 1 and self.optimizer.iterations.numpy() > self.current_step:
steps = self.optimizer.iterations.numpy() - self.current_step
training_logs = self.get_metrics_result()
self.first_batch = False
self.on_batch_end(training_logs, steps, t0)
break

self.first_batch = False
self.on_batch_end(training_logs, steps, t0)
self.on_epoch_end(epoch, self.current_step, eval_input, epoch_logs=training_logs)
if self.main_model.stop_training:
@@ -819,11 +822,10 @@ def train_single_step(self, iterator, num_grad_accumulates):
self.forward(iterator)
if _ == 0 or (_ + 1) % num_grad_accumulates == 0:
self.step(num_grad_accumulates)
if self.use_horovod and _ == 0 and self.first_batch:
hvd.broadcast_variables(self.main_model.variables, 0)
hvd.broadcast_variables(self.optimizer.variables(), 0)
if self.use_horovod and self.first_batch:
self.do_broadcast()
else:
self._replicated_step(iterator, self.first_batch)
self._replicated_step(iterator)
return self.get_metrics_result()

@property
@@ -833,12 +835,30 @@ def trainable_variables(self):
else:
return self.main_model.trainable_variables

def _replicated_step(self, inputs, first_batch=False):
def do_broadcast(self):
model_broadcast_vars = [
var for var in self.main_model.variables
if (not isinstance(var, TrainableWrapper)) and (not isinstance(var, DEResourceVariable))
]
opt_broadcast_vars = [
var for var in self.optimizer.variables()
if (not isinstance(var, TrainableWrapper)) and (not isinstance(var, DEResourceVariable))
]

print_op = tf.print(
f"Broadcasting {len(model_broadcast_vars)} model variables & {len(opt_broadcast_vars)} optimizer variables...",
output_stream=sys.stdout
)
with tf.control_dependencies([print_op]):
hvd.broadcast_variables(model_broadcast_vars + opt_broadcast_vars, root_rank=0)
self.first_batch.assign(False)

def _replicated_step(self, inputs):
"""Replicated training step."""
inputs, labels, sample_weight = data_adapter.unpack_x_y_sample_weight(inputs)
x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(inputs)
with tf.GradientTape() as tape:
model_outputs = self.main_model(inputs, training=True)
loss = self.loss_container(labels, model_outputs, sample_weight=sample_weight)
model_outputs = self.main_model(x, training=True)
loss = self.loss_container(y, model_outputs, sample_weight=sample_weight)

if self.use_horovod and not FLAGS.use_dynamic_embedding:
tape = hvd.DistributedGradientTape(
@@ -847,28 +867,18 @@ def _replicated_step(self, inputs, first_batch=False):
# Run backwards pass.
self.optimizer.minimize(loss, self.trainable_variables, tape=tape)

if self.use_horovod and first_batch:
broadcast_vars = [
var for var in self.main_model.variables
if (not isinstance(var, TrainableWrapper)) and (not isinstance(var, DEResourceVariable))
]
hvd.broadcast_variables(broadcast_vars, root_rank=0)

opt_broadcast_vars = [
var for var in self.optimizer.variables()
if (not isinstance(var, TrainableWrapper)) and (not isinstance(var, DEResourceVariable))
]
hvd.broadcast_variables(opt_broadcast_vars, root_rank=0)
if self.use_horovod and self.first_batch:
self.do_broadcast()

# For reporting, the metric takes the mean of losses.
if self.metric_container:
self.metric_container.update_state(y_true=labels, y_pred=model_outputs, sample_weight=sample_weight)
self.metric_container.update_state(y_true=y, y_pred=model_outputs, sample_weight=sample_weight)

def forward(self, inputs):
inputs, labels, sample_weight = data_adapter.unpack_x_y_sample_weight(inputs)
x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(inputs)
with tf.GradientTape() as tape:
model_outputs = self.main_model(inputs, training=True)
loss = self.loss_container(labels, model_outputs, sample_weight=sample_weight)
model_outputs = self.main_model(x, training=True)
loss = self.loss_container(y, model_outputs, sample_weight=sample_weight)

# Compute gradients
if version.parse(tf.keras.__version__.replace("-tf", "+tf")) < version.parse("2.11"):
@@ -880,7 +890,7 @@ def forward(self, inputs):

# For reporting, the metric takes the mean of losses.
if self.metric_container:
self.metric_container.update_state(y_true=labels, y_pred=model_outputs, sample_weight=sample_weight)
self.metric_container.update_state(y_true=y, y_pred=model_outputs, sample_weight=sample_weight)

def step(self, num_grad_accumulates):
gradients = self.accum_gradients.gradients
@@ -897,25 +907,31 @@ def step(self, num_grad_accumulates):
self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
self.accum_gradients.reset()

def predict_step(self, iterator):
def forward_step(self, iterator):
"""Calculates evaluation metrics on distributed devices."""

def _test_step_fn(inputs):
def _forward_step_fn(inputs):
"""Replicated accuracy calculation."""
inputs, labels, sample_weight = data_adapter.unpack_x_y_sample_weight(inputs)
model_outputs = self.main_model(inputs, training=False)
if labels is not None and self.metric_container:
self.metric_container.update_state(labels, model_outputs, sample_weight=sample_weight)
x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(inputs)
model_outputs = self.main_model(x, training=False)
if y is not None and self.metric_container:
if self.use_horovod:
y = hvd.allgather(y)
model_outputs = hvd.allgather(model_outputs)
if sample_weight:
sample_weight = hvd.allgather(sample_weight)

self.metric_container.update_state(y, model_outputs, sample_weight=sample_weight)
return model_outputs

def tuple_fun(x):
return x,

if self.strategy:
outputs = self.strategy.run(_test_step_fn, args=(iterator,))
outputs = self.strategy.run(_forward_step_fn, args=(iterator,))
map_func = self.strategy.experimental_local_results
else:
outputs = _test_step_fn(iterator)
outputs = _forward_step_fn(iterator)
map_func = tuple_fun
return tf.nest.map_structure(map_func, outputs)

@@ -954,12 +970,11 @@ def train_steps(self, iterator, steps, num_grad_accumulates):
self.forward(next(iterator))
if _ == 0 or (_ + 1) % num_grad_accumulates == 0:
self.step(num_grad_accumulates)
if self.use_horovod and _ == 0 and self.first_batch:
hvd.broadcast_variables(self.main_model.variables, 0)
hvd.broadcast_variables(self.optimizer.variables(), 0)
if self.use_horovod and self.first_batch:
self.do_broadcast()
else:
for _ in tf.range(steps):
self._replicated_step(next(iterator), (self.first_batch and _ == 0))
self._replicated_step(next(iterator))
return self.get_metrics_result()

def train_single_step_strategy(self, iterator, num_grad_accumulates):
@@ -984,7 +999,7 @@ def make_train_function(self):
if not self.run_eagerly:
_train_single_step = tf.function(self.train_single_step)
_train_multi_steps = tf.function(self.train_steps)
self.predict_step = tf.function(self.predict_step)
self.forward_step = tf.function(self.forward_step)
else:
_train_single_step = self.train_single_step
_train_multi_steps = self.train_steps
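
The core of this change is that `first_batch` is now a `tf.Variable` created under `tf.init_scope()` and cleared with `assign(False)` inside the new `do_broadcast` helper, instead of a plain Python attribute. When the training step is wrapped in `tf.function` (graph mode), a Python bool is frozen at trace time, so the one-time Horovod broadcast can end up baked into the traced graph and re-executed on every step, which is presumably the performance bug named in the commit title; reading a boolean variable defers that decision to run time. Below is a minimal standalone sketch of the pattern, assuming TensorFlow 2.x and Horovod; `ToyTrainer`, the squared-error loss, and the training-step shape are illustrative and not part of the repository.

```python
# Minimal sketch (not the repository's code) of a trace-safe "broadcast once" flag,
# assuming TensorFlow 2.x and Horovod are installed and multiple ranks are launched.
import tensorflow as tf
import horovod.tensorflow as hvd

hvd.init()


class ToyTrainer:

  def __init__(self, model, optimizer):
    self.model = model
    self.optimizer = optimizer
    with tf.init_scope():
      # A tf.Variable is read at run time; a plain Python bool would be captured as
      # True when tf.function traces the step, baking the broadcast into the graph
      # so it runs on every iteration instead of only the first one.
      self.first_batch = tf.Variable(True, trainable=False, dtype=tf.bool, name="first_batch")

  def do_broadcast(self):
    # optimizer.variables() mirrors the call used in the diff above.
    hvd.broadcast_variables(self.model.variables, root_rank=0)
    hvd.broadcast_variables(self.optimizer.variables(), root_rank=0)
    self.first_batch.assign(False)

  @tf.function
  def train_step(self, x, y):
    with tf.GradientTape() as tape:
      loss = tf.reduce_mean(tf.square(self.model(x, training=True) - y))
    tape = hvd.DistributedGradientTape(tape)
    grads = tape.gradient(loss, self.model.trainable_variables)
    self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))
    if self.first_batch:  # AutoGraph turns this into a cond on the variable's value
      self.do_broadcast()
    return loss
```
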
37 changes: 30 additions & 7 deletions deepray/core/module.py
@@ -14,6 +14,10 @@

class Module():

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.validation_steps = None

def steps_to_run(self, current_step, steps_per_epoch, steps_per_loop):
"""Calculates steps to run on device."""
if steps_per_loop <= 0:
@@ -112,7 +116,7 @@ def on_epoch_end(self, epoch, current_step, eval_input, epoch_logs=None):
if self.metric_container:
self.metric_container.reset_state()

val_logs = self.run_evaluation(eval_input, current_step)
val_logs = self.run_evaluation(eval_input, self.validation_steps)
val_logs = {'val_' + name: val for name, val in val_logs.items()}
epoch_logs.update(val_logs)

@@ -126,19 +130,38 @@ def on_epoch_end(self, epoch, current_step, eval_input, epoch_logs=None):
"""
self.callbacks.on_epoch_end(epoch, epoch_logs)

def run_evaluation(self, eval_input, current_training_step=0):
def run_evaluation(self, eval_input, validation_steps=None):
if validation_steps is None:
if self.validation_steps is not None:
validation_steps = self.validation_steps
else:
if self.validation_steps is None:
self.validation_steps = validation_steps
"""Runs validation steps and aggregate metrics."""
if not isinstance(eval_input, Iterator):
eval_input = distribution_utils.make_distributed_iterator(self.strategy, eval_input)

step_num = 0
while 1:
current_step = 0
while validation_steps is None or current_step < validation_steps:
try:
self.predict_step(next(eval_input))
step_num += 1
t0 = time.time()
for _ in tf.range(FLAGS.steps_per_summary):
self.forward_step(next(eval_input))
current_step += 1
elapse_time = time.time() - t0
# Updates validation logging.
if validation_steps is None:
training_status = 'Valid Step: %d / time=%.3f sec' % (current_step, elapse_time)
else:
training_status = 'Valid Step: %d/%d / time=%.3f sec' % (current_step, validation_steps, elapse_time)
for key, value in self.get_metrics_result().items():
metric_value = value.numpy().astype(float)
training_status += ' %s=%f' % (key, metric_value)
logging.info(training_status)
except (tf.errors.OutOfRangeError, StopIteration):
self.validation_steps = current_step
if is_main_process():
logging.info('Data exhausted after %d steps', step_num)
logging.info('Data exhausted after %d steps', current_step)
break

return self.get_metrics_result()
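
In `Module.run_evaluation`, the loop now counts how many steps it takes to exhaust the validation data on the first pass, stores that count in `self.validation_steps`, and uses it as the stopping condition on later epochs while logging progress every `FLAGS.steps_per_summary` steps. A minimal standalone sketch of that caching behaviour follows; the `Evaluator` class and the `forward_step` callable are illustrative stand-ins, not the repository's API.

```python
# Minimal sketch (illustrative, not the repository's code) of caching the length of a
# validation pass the first time the eval data is exhausted, then reusing that count.
import logging
import time

logging.basicConfig(level=logging.INFO)


class Evaluator:

  def __init__(self, forward_step, steps_per_summary=10):
    self.forward_step = forward_step      # callable that consumes one batch
    self.steps_per_summary = steps_per_summary
    self.validation_steps = None          # unknown until the data runs out once

  def run_evaluation(self, eval_iterator, validation_steps=None):
    if validation_steps is None:
      validation_steps = self.validation_steps
    current_step = 0
    while validation_steps is None or current_step < validation_steps:
      try:
        t0 = time.time()
        for _ in range(self.steps_per_summary):
          if validation_steps is not None and current_step >= validation_steps:
            break
          self.forward_step(next(eval_iterator))
          current_step += 1
        logging.info("Valid Step: %d / time=%.3f sec", current_step, time.time() - t0)
      except StopIteration:
        # Remember how long one full pass is; later epochs stop on the counter
        # instead of waiting for the iterator to raise again.
        self.validation_steps = current_step
        logging.info("Data exhausted after %d steps", current_step)
        break
    return current_step
```

On the first epoch the loop runs until the iterator raises `StopIteration`; afterwards the cached count bounds it, which matches how `fit` now forwards `eval_steps` into `self.validation_steps`.
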
1 change: 0 additions & 1 deletion deepray/utils/flags/_base.py
@@ -157,7 +157,6 @@ def define_base(
"specified, the training batch size (--batch_size) will be used."
)
)
flags.DEFINE_integer("predict_batch_size", 8, 'Total batch size for prediction.')
key_flags.append("batch_size")

if num_gpus:
Expand Down
1 change: 0 additions & 1 deletion deepray/utils/flags/common_flags.py
@@ -156,7 +156,6 @@ def define_common_flags():
)
flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.")
flags.DEFINE_integer("prebatch", 1, "prebatch size for tfrecord")
flags.DEFINE_list("label", [], "label name")
flags.DEFINE_string("feature_map", os.path.join(os.getcwd(), "business/data/feature_map.csv"), "path to feature_map")
flags.DEFINE_string("black_list", None, "black list for feature_map")
flags.DEFINE_string("white_list", None, "white list for feature_map")
2 changes: 1 addition & 1 deletion deepray/version.py
@@ -21,7 +21,7 @@
# We follow Semantic Versioning (https://semver.org/)
_MAJOR_VERSION = "0"
_MINOR_VERSION = "21"
_PATCH_VERSION = "6"
_PATCH_VERSION = "7"

# When building releases, we can update this value on the release branch to
# reflect the current release candidate ('rc0', 'rc1') or, finally, the official
2 changes: 1 addition & 1 deletion modelzoo/LanguageModeling/BERT/run_squad_predict.py
@@ -122,7 +122,7 @@ def tuple_fun(x):
elapsed_secs = 0

for _ in range(num_steps):
predictions = trainer.predict_step(next(predict_iterator))
predictions = trainer.forward_step(next(predict_iterator))
if FLAGS.benchmark:
# transfer tensor to CPU for synchronization
t0 = predictions['unique_ids'][0]
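
The SQuAD script only picks up the rename from `predict_step` to `forward_step`, but inside the trainer the renamed step now also runs `hvd.allgather` on labels, outputs, and sample weights before updating metrics, so each rank's metric state reflects the global evaluation batch rather than its local shard. Below is a minimal sketch of that aggregation pattern, assuming TensorFlow 2.x and Horovod; the model, metric, and `eval_step` wrapper are illustrative, not the repository's code.

```python
# Minimal sketch (illustrative) of gathering labels and predictions across Horovod
# ranks before updating a Keras metric, so each rank scores the global eval batch.
# Assumes TensorFlow 2.x and Horovod are installed and multiple ranks are launched.
import tensorflow as tf
import horovod.tensorflow as hvd

hvd.init()
accuracy = tf.keras.metrics.BinaryAccuracy()


@tf.function
def eval_step(model, x, y, sample_weight=None):
  y_pred = model(x, training=False)
  # hvd.allgather concatenates every rank's tensor along the batch dimension.
  y = hvd.allgather(y)
  y_pred = hvd.allgather(y_pred)
  if sample_weight is not None:
    sample_weight = hvd.allgather(sample_weight)
  accuracy.update_state(y, y_pred, sample_weight=sample_weight)
  return y_pred
```

The trade-off is extra communication on every evaluation step, but it keeps per-rank metric values consistent without a separate reduction at the end.
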
