Merge pull request tensorflow#2 from omalleyt12/rick_ps
Callback changes
rchao authored Nov 24, 2020
2 parents ac4175a + 84ec9f4 commit 8919853
Showing 1 changed file with 12 additions and 27 deletions.
39 changes: 12 additions & 27 deletions rfcs/20201121-keras-model-fit-ps.md

#### Callbacks

With `ParameterServerStrategy`, the return value of `Model.train_function` is a dict of `RemoteValue`s. This dict is passed as the `logs` argument to the `CallbackList` object. The `CallbackList` object relies on the `tf_utils.to_numpy_or_python_type` utility to convert these `logs` into `NumPy` values. We will extend the logic of this utility to support `ParameterServerStrategy`. The utility will sync the workers and fetch the `NumPy` value from each `RemoteValue`:


```
def to_numpy_or_python_type(logs):
  if isinstance(logs, RemoteValue):
    get_strategy().cluster_coordinator.join()  # Sync the workers.
    return logs.fetch()  # Return the NumPy results.
  else:
    ...  # Existing logic.
```
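The end-to-end behavior can be illustrated with a self-contained sketch. Note that `RemoteValue` and `ClusterCoordinator` below are simplified stand-ins written for this example, not the real `tf.distribute` classes, which block on asynchronous worker execution:

```python
# Simplified stand-ins for illustration only; the real classes live under
# tf.distribute.experimental.coordinator and block on asynchronous workers.
class RemoteValue:
    def __init__(self, value):
        self._value = value

    def fetch(self):
        # The real fetch() blocks until the worker has produced the value.
        return self._value


class ClusterCoordinator:
    def join(self):
        # The real join() blocks until all scheduled functions complete.
        pass


def to_numpy_or_python_type(logs, coordinator):
    if isinstance(logs, RemoteValue):
        coordinator.join()   # Sync the workers.
        return logs.fetch()  # Return the NumPy results.
    return logs              # Existing logic for tensors and NumPy values.


logs = RemoteValue({"loss": 0.25})
print(to_numpy_or_python_type(logs, ClusterCoordinator()))  # {'loss': 0.25}
```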


This utility is only used in the `CallbackList` object, which already handles converting `tf.Tensor` logs to `NumPy`. User-defined `Callback`s do not have to be aware of this logic and will not need changes to support `ParameterServerStrategy`.


##### Epoch-level callbacks
Since the workers will sync every epoch anyway, fetching the remote values incurs no additional synchronization cost.

##### Batch-level callbacks

Some users may want to use batch-level `Callback`s. When users set `steps_per_execution=N`, the `Callback`s will only execute every `N` steps, so batch-level callbacks will not be prohibitively slow for large `N`. However, in other cases, batch-level callbacks may cause a significant slowdown.
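The interaction between `steps_per_execution` and batch-level hooks can be sketched with a toy loop (this is illustrative only, not the real training loop, and the function names are assumptions):

```python
def run_epoch(total_steps, steps_per_execution, on_train_batch_end):
    # With steps_per_execution=N, each call to the train function runs N
    # steps on the cluster, so the batch-level hook fires once per N steps.
    step, hook_calls = 0, 0
    while step < total_steps:
        step += min(steps_per_execution, total_steps - step)
        on_train_batch_end(step - 1)  # Hook sees the last completed step.
        hook_calls += 1
    return hook_calls


# 100 steps with N=10: the hook (and any worker sync it forces) runs 10 times.
print(run_epoch(100, 10, lambda step: None))  # 10
```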

We will support batch-level `Callback`s, using existing logic in `CallbackList` to detect when batch-level `Callback`s are passed, so that the performance penalty of syncing the workers each batch is only incurred when the user has actually passed batch-level `Callback`s (for context, none of Keras's built-in `Callback`s other than the `ProgressBar` will incur this penalty). This logic was originally put in place to ensure that TPU async mode was only blocked when needed, and it applies equally well to `ParameterServerStrategy` without significant modifications.
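The detection can be sketched by comparing each callback's batch hooks against the base class, which is the spirit of the existing `CallbackList` check. The `Callback` class and `uses_batch_hooks` helper below are minimal stand-ins written for this sketch, not the actual Keras implementation:

```python
class Callback:
    # Minimal stand-in for tf.keras.callbacks.Callback.
    def on_train_batch_begin(self, batch, logs=None):
        pass

    def on_train_batch_end(self, batch, logs=None):
        pass


def uses_batch_hooks(callbacks):
    # Only sync the workers every batch if some callback actually overrides
    # a batch-level hook; otherwise training can stay fully asynchronous.
    hooks = ("on_train_batch_begin", "on_train_batch_end")
    return any(
        getattr(type(cb), hook) is not getattr(Callback, hook)
        for cb in callbacks
        for hook in hooks
    )


class LossPrinter(Callback):
    def on_train_batch_end(self, batch, logs=None):
        print(batch, logs)


print(uses_batch_hooks([Callback()]))                 # False
print(uses_batch_hooks([Callback(), LossPrinter()]))  # True
```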

We will also reuse existing logic to log a warning when a user's batch-level `Callback`s are causing a significant slowdown in training. This logic also resides in the `CallbackList` object.
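The warning could look roughly like the following timing check; the `maybe_warn_slow_callbacks` name and the 0.5 threshold are assumptions for illustration, not the actual Keras values:

```python
import warnings


def maybe_warn_slow_callbacks(batch_time, callback_time, threshold=0.5):
    # Warn when the time spent inside callbacks is a large fraction of the
    # time spent running a batch. The threshold is an assumed example value.
    if callback_time > threshold * batch_time:
        warnings.warn(
            f"Callbacks took {callback_time:.3f}s, which is slow compared "
            f"to the batch time of {batch_time:.3f}s. Check your callbacks."
        )
        return True
    return False


print(maybe_warn_slow_callbacks(batch_time=0.10, callback_time=0.08))  # True
print(maybe_warn_slow_callbacks(batch_time=0.10, callback_time=0.01))  # False
```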


#### Metrics variables
* Schedule design review (ETA: Early Dec)
* User model testing (ETA: Dec)
* Aligned design with approvals on this doc (ETA: End of Dec)
* Demonstrable working prototype (ETA: End of Dec)