Skip to content

Commit

Permalink
First round of comment addressing
Browse files Browse the repository at this point in the history
  • Loading branch information
rchao committed Nov 24, 2020
1 parent 8e106dc commit ac4175a
Showing 1 changed file with 30 additions and 7 deletions.
37 changes: 30 additions & 7 deletions rfcs/20201121-keras-model-fit-ps.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@

## Background

With the recent release of TF2 parameter server training support ([ddoc](https://github.com/tensorflow/community/blob/master/rfcs/20200306-single-client-parameter-server.md)) ([API](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/distribute/parameter_server_strategy_v2.py)) ([tutorial](https://www.tensorflow.org/tutorials/distribute/parameter_server_training)), custom training loop (CTL) users have started using the `ParameterServerStrategy` and `ClusterCoordinator` APIs for parameter server style distributed training. `ParameterServerStrategy` provides implementation of variable placement, and APIs for defining computation, and `ClusterCoordinator` provides APIs for dataset creation, asynchronous function scheduling and remote execution. The asynchronicity brought by `ClusterCoordinator` provides scalability and training fault tolerance, and at the same time implications such as the need for remote resource creation.
With the recent release of TF2 parameter server training support ([API](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/distribute/parameter_server_strategy_v2.py)) ([tutorial](https://www.tensorflow.org/tutorials/distribute/parameter_server_training)), custom training loop (CTL) users have started using the `tf.distribute.experimental.ParameterServerStrategy` and `tf.distribute.experimental.coordinator.ClusterCoordinator` APIs for parameter server style distributed training. `ParameterServerStrategy` provides implementation of variable placement, and APIs for defining computation, and `ClusterCoordinator` provides APIs for dataset creation, asynchronous function scheduling and remote execution. The asynchronicity brought by `ClusterCoordinator` provides scalability and training fault tolerance, and at the same time implications such as the need for remote resource creation.

While CTL user flow has since been supported, Keras `model.fit` training API is not yet. It has been a common ask (as shown in a survey conducted earlier this year) for availability, given its simplicity and support for a variety of machine learning models, metrics, optimizers, etc.
TF2 parameter server training is based on one coordinator task, multiple workers, and multiple (usually fewer than workers) parameter servers (referred to as "ps"). Workers and parameter servers run TensorFlow servers, while the coordinator creates resources on workers and parameter servers, dispatches functions, coordinates the training amd writes checkpoints etc.

While CTL user flow has been supported since the release of TF 2.4, Keras `model.fit` training API is not yet. It has been a common ask (as shown in a survey conducted earlier this year) for availability, given its simplicity and support for a variety of machine learning models, metrics, optimizers, etc.

In this design, we will discuss the changes in `model.fit` API that we expect to make to accommodate asynchronous, coordinator-based parameter server training flow, and challenges the integration may have given the historical focus of synchronous distributed training with `model.fit`.

Expand All @@ -26,6 +28,15 @@ In this design, we will discuss the changes in `model.fit` API that we expect to
* Minimal performance implications


## Glossary

* Parameter server training: a distributed training method utilizing multiple machines to scale up model training, utilizing a set of training workers, and a set of parameter servers that store the training variables.

* Coordinator: A task (referring to a program run on a dedicated machine) where the python program creates variables on parameter servers, defines functions to be executed on remote workers, and controls the training via `ParameterServerStrategy` and `ClusterCoordinator` APIs.

* `model.fit`: A Keras API for running epoch and step based training loops, with user-provided optimizers, metrics, loss, and callbacks etc.


## Proposed options and solutions

Let’s first take a look at the proposed user flow (on the coordinator). It is expected to be largely the same with other strategies (except for the strategy swap). Unless mentioned otherwise, the discussion here applies to the python program intended to be run on the coordinator.
Expand All @@ -44,13 +55,13 @@ with strategy.scope():
model = ... # Building a Keras model
model.compile(optimizer=..., loss=...) # `ClusterCoordinator` is created
def dataset_fn():
... # Make use of `preproc_stage` for transformation
return tf.data.Dataset.X... # Make use of `preproc_stage` for transformation
history = model.fit(dataset_fn, epochs=..., steps_per_epoch=..., callbacks=[...])
logging.info("result: %r", history)
```


with a dataset:
with a dataset instance:


```
Expand All @@ -60,8 +71,9 @@ with strategy.scope():
preproc_stage = ... # Some Keras preproc layers
model = ... # Building a Keras model
model.compile(optimizer=..., loss=...) # `ClusterCoordinator` is created
dataset = tf.data.Dataset.... # Make use of `preproc_stage` for transformation
history = model.fit(dataset, epochs=..., steps_per_epoch=..., callbacks=[...])
dataset = tf.data.Dataset.X... # Make use of `preproc_stage` for transformation
# model.fit serializes and deserializes dataset onto workers
history = model.fit(dataset, epochs=..., steps_per_epoch=..., callbacks=[...])
logging.info("result: %r", history)
```

Expand All @@ -81,7 +93,7 @@ Previous discussion indicates that although an API modification is needed, the s

##### `dataset_fn` path

In TF2 parameter server training, `ClusterCoordinator` naturally supports a dataset function to be passed in to `create_per_worker_dataset` API, which creates datasets on remote workers. By leveraging such data factory support, `model.fit` with `dataset_fn` can be implemented by subclassing the existing Keras `DataHandler` to provide a worker-distributed dataset for Keras to use (i.e. call `iter` on). Please see `DataHandler` section below for proposed changes.
In TF2 parameter server training, `ClusterCoordinator` naturally supports a dataset function to be passed in to `create_per_worker_dataset` API, which creates datasets on remote workers. By leveraging such data factory support, `model.fit` with `dataset_fn` can be implemented by subclassing the existing Keras `DataHandler` (a Keras internal private API) to provide a worker-distributed dataset for Keras to use (i.e. call `iter` on). Please see `DataHandler` section below for proposed changes.

In terms of how users pass a dataset factory into `model.fit`, there are a couple of options:

Expand All @@ -97,6 +109,11 @@ def dataset_fn():
history = model.fit(dataset_fn, epochs=..., steps_per_epoch=..., callbacks=[...])
```

Pros:
* `callable` does not require users to use additional APIs and may be less overhead.

Cons:
* Less future proof as there could be different intepretation of callable passed as `dataset` to `model.fit` in the future.


###### Option 2: dataset factory
Expand Down Expand Up @@ -131,6 +148,12 @@ class DataFactory(object):
return self.x(*args, **kwargs)
```

Pros:
* If `dataset` has a different intepretation, for example it takes an argument instead of none, we get a adapting layer with a `DatasetFactory`.

Cons:
* This requires users to use an additional symbol.


The following discussion is tentatively based on option 1, where a simple callable is taken.

Expand Down

0 comments on commit ac4175a

Please sign in to comment.