support multiple models to train/save/optimize for inference
fuhailin committed Oct 27, 2023
1 parent b627ca8 commit 78b7316
Showing 8 changed files with 128 additions and 90 deletions.
66 changes: 36 additions & 30 deletions deepray/core/base_trainer.py
@@ -208,22 +208,25 @@ def __init__(

     self._model = {}
     if isinstance(model, list):
-      if len(model) == 1:
+      if len(model) > 0:
         self._model = {"main_model": model[0]}
-      elif len(model) == 2:
-        self._model = {"main_model": model[0], "sub_model": model[1]}
-      else:
-        for i in range(len(model)):
-          if i == 0:
-            self._model["main_model"] = model[i]
-          else:
+        if len(model) == 2:
+          self._model["sub_model"] = model[1]
+        else:
+          for i in range(1, len(model)):
             self._model[f"sub_model{i}"] = model[i]
+      else:
+        raise ValueError("Not a reachable model.")
     elif isinstance(model, dict):
-      self._model = model
+      main_keys = [k for k in model.keys() if "main" in k]
+      if len(main_keys) == 1:
+        self._model = model
+      else:
+        raise ValueError("Must set exactly one model whose key contains \"main\"")
     elif isinstance(model, tf.keras.Model):
       self._model = {"main_model": model}
     else:
-      ValueError("Not a reachable model.")
+      raise ValueError("Not a reachable model.")

     self._loss = loss
     self._metrics = metrics
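The hunk above normalizes every accepted `model` argument into a dict keyed by "main_model"/"sub_model"/"sub_model{i}", and the renamed `main_model` property in the next hunk is how the rest of the trainer retrieves the main entry. A minimal standalone sketch of the same rules (the helper name `normalize_models` is hypothetical, not part of the commit):

from typing import Dict, Text

import tensorflow as tf


def normalize_models(model) -> Dict[Text, tf.keras.Model]:
  """Mirror of the __init__ normalization above, for illustration only."""
  if isinstance(model, list):
    if len(model) == 0:
      raise ValueError("Empty model list.")
    models = {"main_model": model[0]}
    if len(model) == 2:
      models["sub_model"] = model[1]
    else:
      # Three or more entries become sub_model1, sub_model2, ...
      for i in range(1, len(model)):
        models[f"sub_model{i}"] = model[i]
    return models
  if isinstance(model, dict):
    # Exactly one key must contain "main" so the trainer can locate it.
    if len([k for k in model if "main" in k]) != 1:
      raise ValueError('Exactly one key must contain "main"')
    return model
  if isinstance(model, tf.keras.Model):
    return {"main_model": model}
  raise ValueError("Unsupported model argument.")

So normalize_models([backbone, head]) yields {"main_model": backbone, "sub_model": head}, and a single tf.keras.Model is wrapped as {"main_model": model}.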
@@ -258,7 +261,7 @@ def __init__(
       self.optimizer = tf.keras.mixed_precision.LossScaleOptimizer(self.optimizer, dynamic=True)

   @property
-  def model(self):
+  def main_model(self):
     """
     Returns:
       The main model
@@ -584,12 +587,12 @@ def fit(
       self.loss_container = self._loss
     else:
       self.loss_container = compile_utils.LossesContainer(
-        self._loss, self._loss_weights, output_names=self.model.output_names
+        self._loss, self._loss_weights, output_names=self.main_model.output_names
       )
     self.metric_container = compile_utils.MetricsContainer(
       self._metrics,
       self._weighted_metrics,
-      output_names=self.model.output_names,
+      output_names=self.main_model.output_names,
       # from_serialized=from_serialized,
     ) if self._metrics or self._weighted_metrics else None

@@ -611,7 +614,10 @@ def fit(
     if FLAGS.init_checkpoint:
       for (name, ckpt), init_ckpt in zip(self._checkpoints.items(), FLAGS.init_checkpoint):
         if init_ckpt:
-          latest_checkpoint = tf.train.latest_checkpoint(init_ckpt)
+          if tf.io.gfile.isdir(init_ckpt):
+            latest_checkpoint = tf.train.latest_checkpoint(init_ckpt)
+          else:
+            latest_checkpoint = init_ckpt
           logging.info(
             f'Checkpoint file {latest_checkpoint} found and restoring from initial checkpoint for {name} model.'
           )
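With this hunk, each --init_checkpoint entry may be either a checkpoint directory or a concrete checkpoint prefix. A minimal sketch of the resolution step in isolation (the helper name is illustrative):

import tensorflow as tf


def resolve_checkpoint(path: str) -> str:
  """Return a restorable checkpoint prefix for a directory or prefix path."""
  if tf.io.gfile.isdir(path):
    # Pick the newest checkpoint in the directory; returns None if it holds none.
    return tf.train.latest_checkpoint(path)
  # Anything else is treated as an explicit checkpoint prefix.
  return path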
@@ -636,7 +642,7 @@ def fit(
       callbacks,
       add_history=True,
       add_progbar=verbose != 0,
-      model=self.model,
+      model=self.main_model,
       verbose=verbose,
       epochs=self.epochs,
       steps=self.steps_per_epoch * self.epochs,
@@ -649,7 +655,7 @@ def fit(
     else:
       opt = self.optimizer

-    self.model.compile(
+    self.main_model.compile(
       optimizer=opt,
       loss=self._loss,
       loss_weights=self._loss_weights,
@@ -680,7 +686,7 @@ def fit(

     # Horovod: write logs on worker 0.
     verbose = 2 if is_main_process() else 0
-    history = self.model.fit(
+    history = self.main_model.fit(
       train_input,
       epochs=self.epochs,
       steps_per_epoch=self.steps_per_epoch if self.steps_per_epoch else None,
@@ -703,7 +709,7 @@ def run_customized_training_loop(
     self.current_step = self._first_steps = self.optimizer.iterations.numpy()

     self.first_batch = True
-    if not hasattr(self.model, 'optimizer'):
+    if not hasattr(self.main_model, 'optimizer'):
       raise ValueError('User should set optimizer attribute to model '
                        'inside `model_fn`.')
     # if self.sub_model_export_name and self.sub_model is None:
@@ -750,8 +756,8 @@ def run_customized_training_loop(
       self.first_batch = False
       self.on_batch_end(training_logs, steps, t0)
       self.on_epoch_end(epoch, self.current_step, eval_input, epoch_logs=training_logs)
-      if self.model.stop_training:
-        logging.info(f"self.model.stop_training = {self.model.stop_training}")
+      if self.main_model.stop_training:
+        logging.info(f"self.main_model.stop_training = {self.main_model.stop_training}")
         break
     self.callbacks.on_train_end(logs=training_logs)

@@ -764,7 +770,7 @@ def run_customized_training_loop(
       export.export_to_checkpoint(self.manager, self.current_step)
     if is_main_process():
       training_summary = {'total_training_steps': self.current_step}
-      if self.metric_container.metrics:
+      if self.loss_container:
         training_summary['train_loss'] = self._float_metric_value(self.loss_container.metrics[0])

       if self.metric_container and self.metric_container.metrics:
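Note the guard change here is more than the rename: train_loss is read from self.loss_container, so the old condition self.metric_container.metrics would presumably fail with an AttributeError whenever no metrics were configured; gating on the loss container (and None-checking metric_container just below) avoids that.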
@@ -797,7 +803,7 @@ def run_customized_training_loop(
         dllogging.logger.log(step=(), data={"total_loss": training_summary['train_loss']}, verbosity=Verbosity.DEFAULT)
         dllogging.logger.log(data=results_perf, step=tuple())

-    return self.model
+    return self.main_model

   def train_single_step(self, iterator, num_grad_accumulates):
     """Performs a distributed training step.
@@ -814,7 +820,7 @@ def train_single_step(self, iterator, num_grad_accumulates):
       if _ == 0 or (_ + 1) % num_grad_accumulates == 0:
         self.step(num_grad_accumulates)
         if self.use_horovod and _ == 0 and self.first_batch:
-          hvd.broadcast_variables(self.model.variables, 0)
+          hvd.broadcast_variables(self.main_model.variables, 0)
           hvd.broadcast_variables(self.optimizer.variables(), 0)
     else:
       self._replicated_step(iterator, self.first_batch)
@@ -823,15 +829,15 @@ def train_single_step(self, iterator, num_grad_accumulates):
   @property
   def trainable_variables(self):
     if hasattr(self.loss_container, 'trainable_variables'):
-      return self.model.trainable_variables + self.loss_container.trainable_variables
+      return self.main_model.trainable_variables + self.loss_container.trainable_variables
     else:
-      return self.model.trainable_variables
+      return self.main_model.trainable_variables

   def _replicated_step(self, inputs, first_batch=False):
     """Replicated training step."""
     inputs, labels, sample_weight = data_adapter.unpack_x_y_sample_weight(inputs)
     with tf.GradientTape() as tape:
-      model_outputs = self.model(inputs, training=True)
+      model_outputs = self.main_model(inputs, training=True)
       loss = self.loss_container(labels, model_outputs, sample_weight=sample_weight)

     if self.use_horovod and not FLAGS.use_dynamic_embedding:
@@ -843,7 +849,7 @@ def _replicated_step(self, inputs, first_batch=False):

     if self.use_horovod and first_batch:
       broadcast_vars = [
-        var for var in self.model.variables
+        var for var in self.main_model.variables
         if (not isinstance(var, TrainableWrapper)) and (not isinstance(var, DEResourceVariable))
       ]
       hvd.broadcast_variables(broadcast_vars, root_rank=0)
@@ -861,7 +867,7 @@ def _replicated_step(self, inputs, first_batch=False):
   def forward(self, inputs):
     inputs, labels, sample_weight = data_adapter.unpack_x_y_sample_weight(inputs)
     with tf.GradientTape() as tape:
-      model_outputs = self.model(inputs, training=True)
+      model_outputs = self.main_model(inputs, training=True)
       loss = self.loss_container(labels, model_outputs, sample_weight=sample_weight)

     # Compute gradients
@@ -897,7 +903,7 @@ def predict_step(self, iterator):
     def _test_step_fn(inputs):
       """Replicated accuracy calculation."""
       inputs, labels, sample_weight = data_adapter.unpack_x_y_sample_weight(inputs)
-      model_outputs = self.model(inputs, training=False)
+      model_outputs = self.main_model(inputs, training=False)
       if labels is not None and self.metric_container:
         self.metric_container.update_state(labels, model_outputs, sample_weight=sample_weight)
       return model_outputs
@@ -949,7 +955,7 @@ def train_steps(self, iterator, steps, num_grad_accumulates):
       if _ == 0 or (_ + 1) % num_grad_accumulates == 0:
         self.step(num_grad_accumulates)
         if self.use_horovod and _ == 0 and self.first_batch:
-          hvd.broadcast_variables(self.model.variables, 0)
+          hvd.broadcast_variables(self.main_model.variables, 0)
           hvd.broadcast_variables(self.optimizer.variables(), 0)
     else:
       for _ in tf.range(steps):
2 changes: 1 addition & 1 deletion deepray/datasets/wikicorpus_en/processing/file_utils.py
@@ -497,7 +497,7 @@ def _resumable_file_manager():
     # GET file object
     if url.startswith("s3://"):
       if resume_download:
-        logger.warn('Warning: resumable downloads are not implemented for "s3://" urls')
+        logging.warning('Warning: resumable downloads are not implemented for "s3://" urls')
       s3_get(url, temp_file, proxies=proxies)
     else:
       http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
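Worth noting: warn() is a deprecated alias of warning() in Python's logging API, so the new call also avoids a DeprecationWarning.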
2 changes: 1 addition & 1 deletion deepray/utils/export/__init__.py
@@ -1 +1 @@
-from .export import SavedModel, TFTRTModel, export_to_savedmodel, export_to_checkpoint
+from .export import SavedModel, TFTRTModel, export_to_savedmodel, export_to_checkpoint, optimize_for_inference
58 changes: 54 additions & 4 deletions deepray/utils/export/export.py
@@ -20,10 +20,12 @@

 import os
 import re
-from typing import Optional, Union, Dict, Text
+import tempfile
+from typing import Optional, Union, Dict, Text, List

 import tensorflow as tf
 from absl import logging, flags
+from keras.engine import data_adapter
 from tensorflow.python.compiler.tensorrt import trt_convert as trt
 from tensorflow.python.saved_model import signature_constants
 from tensorflow.python.saved_model import tag_constants
@@ -84,7 +86,7 @@ def export_to_savedmodel(
     savedmodel_dir: Optional[Text] = None,
     checkpoint_dir: Optional[Union[Text, Dict[Text, Text]]] = None,
     restore_model_using_load_weights: bool = False
-) -> None:
+) -> Text:
   """Export keras model for serving which does not include the optimizer.

   Arguments:
@@ -148,11 +150,46 @@ def helper(name, _model: tf.keras.Model, _checkpoint_dir):
     if is_main_process():
       logging.info(f"save pb model to: {_savedmodel_dir}, without optimizer & traces")

+    return _savedmodel_dir
+
   if isinstance(model, dict):
+    ans = []
     for name, _model in model.items():
-      helper(name, _model, _checkpoint_dir=checkpoint_dir[name] if checkpoint_dir else None)
+      _dir = helper(name, _model, _checkpoint_dir=checkpoint_dir[name] if checkpoint_dir else None)
+      ans.append(_dir)
+    prefix_path = longestCommonPrefix(ans)
+    logging.info(f"Export multiple models to {prefix_path}*")
+    return prefix_path
   else:
-    helper(name="main", _model=model, _checkpoint_dir=checkpoint_dir)
+    return helper(name="main", _model=model, _checkpoint_dir=checkpoint_dir)
+
+
+def optimize_for_inference(
+    model: Union[tf.keras.Model, Dict[Text, tf.keras.Model]],
+    dataset: tf.data.Dataset,
+    savedmodel_dir: Text,
+) -> None:
+  x, y, z = data_adapter.unpack_x_y_sample_weight(next(iter(dataset)))
+  preds = model(x)
+  logging.info(preds)
+
+  def helper(_model, path):
+    tmp_path = tempfile.mkdtemp(dir='/tmp/')
+    export_to_savedmodel(_model, savedmodel_dir=tmp_path)
+    file = os.path.join(path, "saved_model.pb")
+    if tf.io.gfile.exists(file):
+      tf.io.gfile.remove(file)
+      logging.info(f"Replace optimized saved_model.pb for {file}")
+      tf.io.gfile.copy(os.path.join(tmp_path + "_main", "saved_model.pb"), file, overwrite=True)
+    else:
+      raise FileNotFoundError(f"{file} does not exist.")
+
+  if isinstance(model, dict):
+    for name, _model in model.items():
+      src = savedmodel_dir + name
+      helper(_model, src)
+  else:
+    helper(model, savedmodel_dir)


 class SavedModel:
@@ -219,3 +256,16 @@ def __call__(self, x, **kwargs):
   def infer_step(self, x):
     output = self.graph_func(**x)
     return output
+
+
+def longestCommonPrefix(strs: List[str]) -> str:
+  if not strs:
+    return ""
+
+  length, count = len(strs[0]), len(strs)
+  for i in range(length):
+    c = strs[0][i]
+    if any(i == len(strs[j]) or strs[j][i] != c for j in range(1, count)):
+      return strs[0][:i]
+
+  return strs[0]
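Taken together, a rough usage sketch of the new export path. This is hedged: main_model, sub_model, eval_dataset, and main_export_dir are placeholders, and the behavior in the comments is inferred from the diff above:

from deepray.utils import export

models = {"main_model": main_model, "sub_model": sub_model}

# With a dict input, export_to_savedmodel now returns the longest common
# prefix of the per-model export directories (it logs "Export multiple
# models to <prefix>*"), computed by the character-wise longestCommonPrefix
# scan above.
prefix = export.export_to_savedmodel(models)

# optimize_for_inference runs one batch through the model, re-exports it to
# a temp dir, then swaps the fresh saved_model.pb into an existing export
# directory (the dict branch maps each name to savedmodel_dir + name). It
# calls model(x) before checking for a dict, so it is most naturally driven
# with a single callable model:
export.optimize_for_inference(main_model, eval_dataset, savedmodel_dir=main_export_dir)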
1 change: 1 addition & 0 deletions docker.sh
@@ -9,6 +9,7 @@ docker pull hailinfufu/deepray-dev:latest-py${PY_VERSION}-tf${TF_VERSION}-cu116-

 docker run --gpus all -it \
   --rm=true \
+  --name="deepray_dev" \
   -w /workspaces \
   --volume=dev-build:/workspaces \
   --shm-size=1g \
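Giving the dev container a fixed name makes it easy to re-attach to a running session, e.g.:

docker exec -it deepray_dev bash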
2 changes: 2 additions & 0 deletions modelzoo/Recommendation/criteo_ctr/optimize_for_inference.py
@@ -33,6 +33,8 @@ def main(_):
     tf.io.gfile.remove(file)
     logging.info(f"Replace optimized saved_model.pb for {file}")
     tf.io.gfile.copy(os.path.join(tmp_path + "_main", "saved_model.pb"), file, overwrite=True)
+  else:
+    raise FileNotFoundError(f"{file} does not exist.")


if __name__ == "__main__":
9 changes: 0 additions & 9 deletions modelzoo/Recommendation/criteo_ctr/run_horovod.sh
@@ -86,14 +86,5 @@ $hvd_command $nsys_command python train.py \
   $use_hvd $use_fp16 $use_xla_tag
 set +x

-# if [ $num_gpu -gt 1 ]; then
-#   python optimize_for_inference.py \
-#     --feature_map=feature_map_small.csv \
-#     --use_dynamic_embedding=True \
-#     --model_dir=${RESULTS_DIR} \
-#     --distribution_strategy=off \
-#     $use_fp16 $use_xla_tag
-# fi
-
 # --init_checkpoint=/results/tf_tfra_training_criteo_dcn_fp32_gbs4096_231018053444/ckpt_main_model/ \
 # --init_weights="/results/tf_tfra_training_criteo_dcn_fp32_gbs16384_231016072901/export_main/variables" \
