diff --git a/rfcs/20201121-keras-model-fit-ps.md b/rfcs/20201121-keras-model-fit-ps.md index 9a4d7c579..1c70d1f67 100644 --- a/rfcs/20201121-keras-model-fit-ps.md +++ b/rfcs/20201121-keras-model-fit-ps.md @@ -345,17 +345,20 @@ class DataFactoryAdapter(DataAdapter): #### Callbacks -With `ParameterServerStrategy`, the return value of `Model.train_function` is a dict `RemoteValue`s. This dict is passed as the `logs` argument to the `CallbackList` object. To obtain the `NumPy` value of each `RemoteValue`, one can do: +With `ParameterServerStrategy`, the return value of `Model.train_function` is a dict `RemoteValue`s. This dict is passed as the `logs` argument to the `CallbackList` object. The `CallbackList` object relies on the `tf_utils.to_numpy_or_python_type` utility to convert these `logs` into `NumPy` values. We will extend the logic of this utility to support `ParameterServerStrategy`. The utility will sync the workers and fetch the `NumPy` value from the `RemoteValue`s: ``` -def to_numpy(logs): - cluster_coordinator.join() # Sync the workers. - return {k: v.fetch() for k, v in logs.items()} # Return the NumPy results +def to_numpy_or_python_type(logs): + if isinstance(logs, RemoteValue): + get_strategy().cluster_coordinator.join() # Sync the workers. + return logs.fetch() # Return the NumPy results. + else: + ... # Existing logic. ``` -This logic can be handled by the `CallbackList` object, which already handles converting `tf.Tensor` logs to `NumPy`. However, obtaining these logs requires us to sync the workers, which will result in a slowdown if done every batch. Currently we only plan to sync the workers once every epoch at the end of the epoch. +This utility is only used in the `CallbackList` object, which already handles converting `tf.Tensor` logs to `NumPy`. User-defined `Callback`s do not have to be aware of this logic and will not need changes to support `ParameterServerStrategy`. ##### Epoch-level callbacks @@ -365,29 +368,11 @@ Since the workers will sync every epoch anyway, fetching the remote values incur ##### Batch-level callbacks -Some users may want to use batch-level `Callback`s. When users use `steps_per_execution=N`, the `Callback`s will only execute every `N` steps, and so batch-level callbacks might not be prohibitively slow for large `N`. However, in most cases, batch-level callbacks will cause a significant slowdown and are likely to be added only in error. We have a few options for handling batch-level Callbacks. +Some users may want to use batch-level `Callback`s. When users use `steps_per_execution=N`, the `Callback`s will only execute every `N` steps, and so batch-level callbacks will not be prohibitively slow for large `N`. However in other cases, batch-level callbacks may cause a significant slowdown. +We will support batch-level `Callback`s, but we will use existing logic in `CallbackList` to detect when batch-level `Callback`s are passed, and only incur the performance penalty of syncing workers each batch when the user has passed batch-level `Callback`s (for context, none of Keras's built-in `Callbacks` other than the `ProgressBar` will incur this penalty). This logic was originally put in place to ensure that TPU async mode was only blocked when needed, and applies equally well to `ParameterServerStrategy` without significant modifications. -###### Option 1: Detect and support batch-level callbacks - -We have a mechanism in `CallbackList` to detect when batch-level callbacks are being used. This mechanism will only block asynchronous TPU execution when batch-level callbacks require it. We can use this mechanism to also only block asynchronous PSStrategy execution when batch-level callbacks require it. This would allow us to support batch-level callbacks, without paying a performance penalty in the case where they are not used. - - -###### Option 2: Detect, warn, and support batch-level callbacks - - - -This case is the same as above, except we warn the user when batch-level callbacks are used with PSStrategy. This is because it is likely an error that the user passed a batch-level callback, due to the performance penalty that is possible with these callbacks. - - -###### Option 3: Detect, error, set toggle, and support batch-level callbacks - -In this case, we error out when batch-level callbacks are passed, unless the user has explicitly specified via some setting that they want to use batch-level callbacks. - - -###### Option 4: Detect and disallow batch-level callbacks - -In this case, we do not support batch-level callbacks at all. +We will also re-use existing logic to log a warning to the user when their batch-level `Callback`s are causing a significant slowdown in training time. This logic also resides in the `CallbackList` object. #### Metrics variables @@ -521,4 +506,4 @@ If the assumption that model.fit user code remains the same across strategies (f * Schedule design review (ETA: Early Dec) * User model testing (ETA: Dec) * Aligned design with approvals on this doc (ETA: End of Dec) -* Demonstrable working prototype (ETA: End of Dec) \ No newline at end of file +* Demonstrable working prototype (ETA: End of Dec)