diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 9f32da16fd7972..8f1ad56f6999df 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -855,6 +855,8 @@ def _load_state_dict_into_meta_model(
     for old_key, new_key in zip(old_keys, new_keys):
         state_dict[new_key] = state_dict.pop(old_key)
 
+    is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
+
     for param_name, param in state_dict.items():
         # First part of the test is always true as load_state_dict_keys always contains state_dict keys.
         if param_name not in loaded_state_dict_keys or param_name not in expected_keys:
@@ -866,9 +868,10 @@ def _load_state_dict_into_meta_model(
         module_name = param_name
         set_module_kwargs = {}
 
-        # We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params
+        # We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
         # in int/uint/bool and not cast them.
-        if dtype is not None and torch.is_floating_point(param) and param.dtype != torch.float8_e4m3fn:
+        is_param_float8_e4m3fn = is_torch_e4m3fn_available and param.dtype == torch.float8_e4m3fn
+        if dtype is not None and torch.is_floating_point(param) and not is_param_float8_e4m3fn:
             if (
                 keep_in_fp32_modules is not None
                 and any(
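
The change above replaces a direct `param.dtype != torch.float8_e4m3fn` comparison with a feature-detected flag: on PyTorch builds that predate the fp8 dtypes, `torch.float8_e4m3fn` does not exist and referencing it raises `AttributeError`, so the availability is probed once with `hasattr` before the loop. The following is a minimal standalone sketch of that guard pattern, not the transformers implementation; the helper `cast_floating_params` is hypothetical and only illustrates the version-safe check.

```python
import torch


def cast_floating_params(state_dict: dict, dtype: torch.dtype) -> dict:
    # Probe once: older torch builds have no `float8_e4m3fn` attribute, so a
    # direct dtype comparison would raise AttributeError there.
    is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")

    casted = {}
    for name, param in state_dict.items():
        # Leave float8_e4m3fn params (when that dtype exists) and non-floating
        # tensors (int/uint/bool buffers) at their original dtype.
        is_param_float8_e4m3fn = is_torch_e4m3fn_available and param.dtype == torch.float8_e4m3fn
        if torch.is_floating_point(param) and not is_param_float8_e4m3fn:
            casted[name] = param.to(dtype)
        else:
            casted[name] = param
    return casted


# Usage: floating tensors are cast, integer buffers are kept as-is.
state_dict = {"weight": torch.randn(2, 2), "step": torch.tensor(1, dtype=torch.int64)}
casted = cast_floating_params(state_dict, torch.float16)
print(casted["weight"].dtype)  # torch.float16
print(casted["step"].dtype)    # torch.int64
```

Because the `hasattr` result short-circuits the `and`, the per-parameter check never touches `torch.float8_e4m3fn` on versions where it is missing, which is the same behavior the diff introduces via `is_torch_e4m3fn_available`.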