diff --git a/rfcs/20201121-keras-model-fit-ps.md b/rfcs/20201121-keras-model-fit-ps.md index 63702e8ae..9a4d7c579 100644 --- a/rfcs/20201121-keras-model-fit-ps.md +++ b/rfcs/20201121-keras-model-fit-ps.md @@ -10,9 +10,11 @@ ## Background -With the recent release of TF2 parameter server training support ([ddoc](https://github.com/tensorflow/community/blob/master/rfcs/20200306-single-client-parameter-server.md)) ([API](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/distribute/parameter_server_strategy_v2.py)) ([tutorial](https://www.tensorflow.org/tutorials/distribute/parameter_server_training)), custom training loop (CTL) users have started using the `ParameterServerStrategy` and `ClusterCoordinator` APIs for parameter server style distributed training. `ParameterServerStrategy` provides implementation of variable placement, and APIs for defining computation, and `ClusterCoordinator` provides APIs for dataset creation, asynchronous function scheduling and remote execution. The asynchronicity brought by `ClusterCoordinator` provides scalability and training fault tolerance, and at the same time implications such as the need for remote resource creation. +With the recent release of TF2 parameter server training support ([API](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/distribute/parameter_server_strategy_v2.py)) ([tutorial](https://www.tensorflow.org/tutorials/distribute/parameter_server_training)), custom training loop (CTL) users have started using the `tf.distribute.experimental.ParameterServerStrategy` and `tf.distribute.experimental.coordinator.ClusterCoordinator` APIs for parameter server style distributed training. `ParameterServerStrategy` provides implementation of variable placement, and APIs for defining computation, and `ClusterCoordinator` provides APIs for dataset creation, asynchronous function scheduling and remote execution. The asynchronicity brought by `ClusterCoordinator` provides scalability and training fault tolerance, and at the same time implications such as the need for remote resource creation. -While CTL user flow has since been supported, Keras `model.fit` training API is not yet. It has been a common ask (as shown in a survey conducted earlier this year) for availability, given its simplicity and support for a variety of machine learning models, metrics, optimizers, etc. +TF2 parameter server training is based on one coordinator task, multiple workers, and multiple (usually fewer than workers) parameter servers (referred to as "ps"). Workers and parameter servers run TensorFlow servers, while the coordinator creates resources on workers and parameter servers, dispatches functions, coordinates the training amd writes checkpoints etc. + +While CTL user flow has been supported since the release of TF 2.4, Keras `model.fit` training API is not yet. It has been a common ask (as shown in a survey conducted earlier this year) for availability, given its simplicity and support for a variety of machine learning models, metrics, optimizers, etc. In this design, we will discuss the changes in `model.fit` API that we expect to make to accommodate asynchronous, coordinator-based parameter server training flow, and challenges the integration may have given the historical focus of synchronous distributed training with `model.fit`. @@ -26,6 +28,15 @@ In this design, we will discuss the changes in `model.fit` API that we expect to * Minimal performance implications +## Glossary + +* Parameter server training: a distributed training method utilizing multiple machines to scale up model training, utilizing a set of training workers, and a set of parameter servers that store the training variables. + +* Coordinator: A task (referring to a program run on a dedicated machine) where the python program creates variables on parameter servers, defines functions to be executed on remote workers, and controls the training via `ParameterServerStrategy` and `ClusterCoordinator` APIs. + +* `model.fit`: A Keras API for running epoch and step based training loops, with user-provided optimizers, metrics, loss, and callbacks etc. + + ## Proposed options and solutions Let’s first take a look at the proposed user flow (on the coordinator). It is expected to be largely the same with other strategies (except for the strategy swap). Unless mentioned otherwise, the discussion here applies to the python program intended to be run on the coordinator. @@ -44,13 +55,13 @@ with strategy.scope(): model = ... # Building a Keras model model.compile(optimizer=..., loss=...) # `ClusterCoordinator` is created def dataset_fn(): - ... # Make use of `preproc_stage` for transformation + return tf.data.Dataset.X... # Make use of `preproc_stage` for transformation history = model.fit(dataset_fn, epochs=..., steps_per_epoch=..., callbacks=[...]) logging.info("result: %r", history) ``` -with a dataset: +with a dataset instance: ``` @@ -60,8 +71,9 @@ with strategy.scope(): preproc_stage = ... # Some Keras preproc layers model = ... # Building a Keras model model.compile(optimizer=..., loss=...) # `ClusterCoordinator` is created -dataset = tf.data.Dataset.... # Make use of `preproc_stage` for transformation -history = model.fit(dataset, epochs=..., steps_per_epoch=..., callbacks=[...]) +dataset = tf.data.Dataset.X... # Make use of `preproc_stage` for transformation +# model.fit serializes and deserializes dataset onto workers +history = model.fit(dataset, epochs=..., steps_per_epoch=..., callbacks=[...]) logging.info("result: %r", history) ``` @@ -81,7 +93,7 @@ Previous discussion indicates that although an API modification is needed, the s ##### `dataset_fn` path -In TF2 parameter server training, `ClusterCoordinator` naturally supports a dataset function to be passed in to `create_per_worker_dataset` API, which creates datasets on remote workers. By leveraging such data factory support, `model.fit` with `dataset_fn` can be implemented by subclassing the existing Keras `DataHandler` to provide a worker-distributed dataset for Keras to use (i.e. call `iter` on). Please see `DataHandler` section below for proposed changes. +In TF2 parameter server training, `ClusterCoordinator` naturally supports a dataset function to be passed in to `create_per_worker_dataset` API, which creates datasets on remote workers. By leveraging such data factory support, `model.fit` with `dataset_fn` can be implemented by subclassing the existing Keras `DataHandler` (a Keras internal private API) to provide a worker-distributed dataset for Keras to use (i.e. call `iter` on). Please see `DataHandler` section below for proposed changes. In terms of how users pass a dataset factory into `model.fit`, there are a couple of options: @@ -97,6 +109,11 @@ def dataset_fn(): history = model.fit(dataset_fn, epochs=..., steps_per_epoch=..., callbacks=[...]) ``` +Pros: +* `callable` does not require users to use additional APIs and may be less overhead. + +Cons: +* Less future proof as there could be different intepretation of callable passed as `dataset` to `model.fit` in the future. ###### Option 2: dataset factory @@ -131,6 +148,12 @@ class DataFactory(object): return self.x(*args, **kwargs) ``` +Pros: +* If `dataset` has a different intepretation, for example it takes an argument instead of none, we get a adapting layer with a `DatasetFactory`. + +Cons: +* This requires users to use an additional symbol. + The following discussion is tentatively based on option 1, where a simple callable is taken.