diff --git a/docs/source/concept_guides/low_precision_training.md b/docs/source/concept_guides/low_precision_training.md
index 96278842f48..2d8df0ca4e7 100644
--- a/docs/source/concept_guides/low_precision_training.md
+++ b/docs/source/concept_guides/low_precision_training.md
@@ -50,7 +50,7 @@ The `TransformerEngine` can receive many different arguments that customize how
 * `margin`: The margin to use for the gradient scaling.
 * `interval`: The interval to use for how often the scaling factor is recomputed.
-* `fp8_format``: The format to use for the FP8 recipe. Must be one of `E4M3` or `HYBRID`.
+* `fp8_format`: The format to use for the FP8 recipe. Must be one of `HYBRID` or `E4M3`. (Generally `HYBRID` for training, `E4M3` for evaluation)
 * `amax_history_len`: The length of the history to use for the scaling factor computation
 * `amax_compute_algo`: The algorithm to use for the scaling factor computation. Must be one of `max` or `most_recent`.
 * `override_linear_precision`: Whether or not to execute `fprop`, `dgrad`, and `wgrad` GEMMS in higher precision.
diff --git a/docs/source/usage_guides/low_precision_training.md b/docs/source/usage_guides/low_precision_training.md
index 415175e53bd..f8f7d83df0e 100644
--- a/docs/source/usage_guides/low_precision_training.md
+++ b/docs/source/usage_guides/low_precision_training.md
@@ -56,7 +56,7 @@ fp8_config:
   amax_compute_algorithm: max
   amax_history_length: 1024
   backend: TE
-  fp8_format: E4M3
+  fp8_format: HYBRID
   interval: 1
   margin: 0
   override_linear_precision: false
@@ -117,7 +117,7 @@ fp8_config:
   amax_compute_algorithm: max
   amax_history_length: 1024
   backend: TE
-  fp8_format: E4M3
+  fp8_format: HYBRID
   interval: 1
   margin: 0
   override_linear_precision: false
diff --git a/src/accelerate/commands/config/cluster.py b/src/accelerate/commands/config/cluster.py
index 5901c4ba103..0862b9c9b09 100644
--- a/src/accelerate/commands/config/cluster.py
+++ b/src/accelerate/commands/config/cluster.py
@@ -735,8 +735,8 @@ def get_cluster_input():
         )
         fp8_config["fp8_format"] = _ask_options(
             "Which weight format should be used?",
-            ["E4M3", "HYBRID"],
-            lambda x: "E4M3" if x == 0 else "HYBRID",
+            ["HYBRID", "E4M3"],
+            lambda x: "HYBRID" if x == 0 else "E4M3",
             default=0,
         )
         fp8_config["amax_history_length"] = _ask_field(
diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py
index 919b7fadc2b..1151cd73fda 100644
--- a/src/accelerate/utils/dataclasses.py
+++ b/src/accelerate/utils/dataclasses.py
@@ -313,8 +313,9 @@ class FP8RecipeKwargs(KwargsHandler):
             The margin to use for the gradient scaling.
         interval (`int`, *optional*, default to 1):
             The interval to use for how often the scaling factor is recomputed.
-        fp8_format (`str`, *optional*, default to "E4M3"):
-            The format to use for the FP8 recipe. Must be one of `E4M3` or `HYBRID`.
+        fp8_format (`str`, *optional*, default to "HYBRID"):
+            The format to use for the FP8 recipe. Must be one of `HYBRID` or `E4M3`. (Generally `HYBRID` for training,
+            `E4M3` for evaluation)
         amax_history_len (`int`, *optional*, default to 1024):
             The length of the history to use for the scaling factor computation
         amax_compute_algo (`str`, *optional*, default to "most_recent"):
@@ -364,7 +365,7 @@ def __post_init__(self):
         if self.interval is None:
             self.interval = int(os.environ.get(env_prefix + "INTERVAL", 1))
         if self.fp8_format is None:
-            self.fp8_format = os.environ.get(env_prefix + "FORMAT", "E4M3")
+            self.fp8_format = os.environ.get(env_prefix + "FORMAT", "HYBRID")
         self.fp8_format = self.fp8_format.upper()
         if self.fp8_format not in get_args(FP8Format):
             raise ValueError(f"`fp8_format` must be one of {' or '.join(get_args(FP8Format))}.")
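
For context on what the new default means in practice, here is a minimal sketch of how it surfaces through the public API. It is a sketch only: the `FP8RecipeKwargs` field names come from the dataclass touched above, and the `Accelerator` wiring assumes the standard `kwargs_handlers` mechanism.

```python
from accelerate import Accelerator
from accelerate.utils import FP8RecipeKwargs

# With this change, leaving `fp8_format` unset selects "HYBRID"
# (E4M3 for forward tensors, E5M2 for gradients) instead of "E4M3",
# matching the recipe generally recommended for training.
fp8_kwargs = FP8RecipeKwargs(
    margin=0,
    interval=1,
    amax_history_len=1024,
    amax_compute_algo="max",
)

accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[fp8_kwargs])

# The previous default remains one argument away, e.g. for evaluation runs:
eval_kwargs = FP8RecipeKwargs(fp8_format="E4M3")
```

Note that the `cluster.py` hunk keeps `default=0` and instead reorders the options list; that reordering is what flips the interactive `accelerate config` default to `HYBRID`.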