Tweak defaults for quantized-typed FP8 TE weights (#3018)
* Tweak defaults

* Can't forget about CLI

* Update docs
muellerzr authored Aug 19, 2024
1 parent 589fddd commit 7ec8eab
Showing 4 changed files with 9 additions and 8 deletions.
2 changes: 1 addition & 1 deletion docs/source/concept_guides/low_precision_training.md
@@ -50,7 +50,7 @@ The `TransformerEngine` can receive many different arguments that customize how

* `margin`: The margin to use for the gradient scaling.
* `interval`: The interval to use for how often the scaling factor is recomputed.
-* `fp8_format`: The format to use for the FP8 recipe. Must be one of `E4M3` or `HYBRID`.
+* `fp8_format`: The format to use for the FP8 recipe. Must be one of `HYBRID` or `E4M3`. (Generally `HYBRID` for training, `E4M3` for evaluation)
* `amax_history_len`: The length of the history to use for the scaling factor computation
* `amax_compute_algo`: The algorithm to use for the scaling factor computation. Must be one of `max` or `most_recent`.
* `override_linear_precision`: Whether or not to execute `fprop`, `dgrad`, and `wgrad` GEMMS in higher precision.
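Taken together, the arguments documented above form a single TE recipe; a minimal sketch of a training-oriented configuration with the new `HYBRID` default (a plain dict standing in for the real `FP8RecipeKwargs` dataclass, so this is illustrative, not the actual API):

```python
# Illustrative sketch only: a plain dict mirroring the documented recipe
# arguments, not the actual accelerate.utils.FP8RecipeKwargs object.
te_recipe = {
    "margin": 0,                         # gradient-scaling margin
    "interval": 1,                       # how often the scaling factor is recomputed
    "fp8_format": "HYBRID",              # new default: HYBRID for training, E4M3 for eval
    "amax_history_len": 1024,            # history length for scaling-factor computation
    "amax_compute_algo": "max",          # or "most_recent"
    "override_linear_precision": False,  # run fprop/dgrad/wgrad GEMMs in higher precision
}

print(te_recipe["fp8_format"])
```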
4 changes: 2 additions & 2 deletions docs/source/usage_guides/low_precision_training.md
@@ -56,7 +56,7 @@ fp8_config:
amax_compute_algorithm: max
amax_history_length: 1024
backend: TE
-fp8_format: E4M3
+fp8_format: HYBRID
interval: 1
margin: 0
override_linear_precision: false
@@ -117,7 +117,7 @@ fp8_config:
amax_compute_algorithm: max
amax_history_length: 1024
backend: TE
-fp8_format: E4M3
+fp8_format: HYBRID
interval: 1
margin: 0
override_linear_precision: false
4 changes: 2 additions & 2 deletions src/accelerate/commands/config/cluster.py
@@ -735,8 +735,8 @@ def get_cluster_input():
)
fp8_config["fp8_format"] = _ask_options(
"Which weight format should be used?",
-["E4M3", "HYBRID"],
-lambda x: "E4M3" if x == 0 else "HYBRID",
+["HYBRID", "E4M3"],
+lambda x: "HYBRID" if x == 0 else "E4M3",
default=0,
)
fp8_config["amax_history_length"] = _ask_field(
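The CLI change above only reorders the interactive choices so that index 0, the default selection, now maps to `HYBRID`; a hedged sketch of that selection logic (`pick_fp8_format` is a hypothetical stand-in for the lambda passed to `_ask_options`):

```python
# Hypothetical stand-in for the callback given to _ask_options in
# cluster.py: after this commit, the default choice (index 0) is "HYBRID".
FP8_CHOICES = ["HYBRID", "E4M3"]

def pick_fp8_format(index: int = 0) -> str:
    return "HYBRID" if index == 0 else "E4M3"

print(pick_fp8_format())   # default selection
print(pick_fp8_format(1))  # explicit second choice
```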
7 changes: 4 additions & 3 deletions src/accelerate/utils/dataclasses.py
@@ -313,8 +313,9 @@ class FP8RecipeKwargs(KwargsHandler):
The margin to use for the gradient scaling.
interval (`int`, *optional*, default to 1):
The interval to use for how often the scaling factor is recomputed.
-fp8_format (`str`, *optional*, default to "E4M3"):
-    The format to use for the FP8 recipe. Must be one of `E4M3` or `HYBRID`.
+fp8_format (`str`, *optional*, default to "HYBRID"):
+    The format to use for the FP8 recipe. Must be one of `HYBRID` or `E4M3`. (Generally `HYBRID` for training,
+    `E4M3` for evaluation)
amax_history_len (`int`, *optional*, default to 1024):
The length of the history to use for the scaling factor computation
amax_compute_algo (`str`, *optional*, default to "most_recent"):
@@ -364,7 +365,7 @@ def __post_init__(self):
if self.interval is None:
self.interval = int(os.environ.get(env_prefix + "INTERVAL", 1))
if self.fp8_format is None:
-self.fp8_format = os.environ.get(env_prefix + "FORMAT", "E4M3")
+self.fp8_format = os.environ.get(env_prefix + "FORMAT", "HYBRID")
self.fp8_format = self.fp8_format.upper()
if self.fp8_format not in get_args(FP8Format):
raise ValueError(f"`fp8_format` must be one of {' or '.join(get_args(FP8Format))}.")
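The `__post_init__` hunk falls back to an environment variable before applying the new default; a self-contained sketch of that resolution order (the `ACCELERATE_FP8_` prefix and the tuple of valid formats are assumptions based on the surrounding code, not taken verbatim from this diff):

```python
import os

# Stand-in for get_args(FP8Format); the accepted values match the docs above.
FP8_FORMATS = ("E4M3", "HYBRID")

def resolve_fp8_format(value=None, env_prefix="ACCELERATE_FP8_"):
    # Mirrors the diff's resolution order: an explicit value wins, then the
    # environment variable, then the new "HYBRID" default; the result is
    # upper-cased and validated, as in __post_init__.
    if value is None:
        value = os.environ.get(env_prefix + "FORMAT", "HYBRID")
    value = value.upper()
    if value not in FP8_FORMATS:
        raise ValueError(f"`fp8_format` must be one of {' or '.join(FP8_FORMATS)}.")
    return value

print(resolve_fp8_format("e4m3"))  # explicit values are normalized to upper case
```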
