[Unified Checkpoint] Fix fp32 dtype for using newest paddle #9360

Merged (1 commit, Nov 4, 2024)
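The change removes the guarded `from paddle.base import core` import and replaces every `core.VarDesc.VarType.FP32` comparison with `paddle.float32`, which recent Paddle releases expose as a public dtype that compares directly against `Tensor.dtype`. A minimal standalone sketch of the check (not taken from the PR; the tensors are illustrative):

import paddle

# A low-precision parameter and its fp32 master copy (illustrative tensors).
param = paddle.zeros([2, 2]).astype("float16")
master = param.astype("float32")

# On newer Paddle the public dtype object compares directly with Tensor.dtype,
# so the private core.VarDesc.VarType.FP32 enum import is no longer needed.
print(param.dtype == paddle.float32)   # False -> treated as a low-precision weight
print(master.dtype == paddle.float32)  # True  -> treated as an fp32 master weight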
7 changes: 1 addition & 6 deletions paddlenlp/trainer/unified_checkpoint/check_completion.py
@@ -30,11 +30,6 @@
 from paddlenlp.utils.log import logger
 from paddlenlp.utils.nested import flatten_list
 
-try:
-    from paddle.base import core
-except:
-    core = None
-
 from .utils import (
     get_expected_state_dict,
     is_sharding_split_param_mode,
@@ -200,7 +195,7 @@ def check_dynamic_load(args, weight_map, existed_files, is_master_weights=False,
         if args.use_expert_parallel and dp_rank > 0 and not getattr(state_dict[key], "no_sync", False):
             continue
 
-        if is_master_weights and state_dict[key].dtype == core.VarDesc.VarType.FP32:
+        if is_master_weights and state_dict[key].dtype == paddle.float32:
             continue
 
         if not is_master_weights:
7 changes: 1 addition & 6 deletions paddlenlp/trainer/unified_checkpoint/load_dynamic.py
@@ -22,11 +22,6 @@
 import paddle.distributed as dist
 from paddle.distributed import fleet
 
-try:
-    from paddle.base import core
-except:
-    core = None
-
 from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM
 from paddlenlp.transformers.model_utils import _load_state_dict_into_model
 from paddlenlp.transformers.utils import device_guard, is_safetensors_available
@@ -474,7 +469,7 @@ def check_optimizer_param(parameter):
         key_name = key.split("/")
         static_name = struct2static_name_mappings[key_name[0]]
         if has_master_weights:
-            if model_state_dict[key_name[0]].dtype != core.VarDesc.VarType.FP32:
+            if model_state_dict[key_name[0]].dtype != paddle.float32:
                 key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
             else:
                 key_name = "_".join([static_name, key_name[1]])
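The hunk above maps a structured optimizer key such as "param_name/moment1_0" back to Paddle's static naming, inserting an extra segment when the parameter is low precision and therefore keeps an fp32 master copy. A hypothetical standalone sketch of that mapping; the FP32_MASTER value, the key layout, and the non-master branch are assumptions, not taken from the PR:

import paddle

FP32_MASTER = "fp32_master_0"  # illustrative; PaddleNLP defines the real constant

def optimizer_key_to_static_name(key, struct2static_name_mappings, model_state_dict, has_master_weights):
    # key looks like "structured_param_name/optimizer_state_suffix"
    struct_name, state_suffix = key.split("/")
    static_name = struct2static_name_mappings[struct_name]
    if has_master_weights:
        if model_state_dict[struct_name].dtype != paddle.float32:
            # Low-precision params keep an fp32 master copy, so their optimizer
            # states carry the extra FP32_MASTER segment in the static name.
            return "_".join([static_name, FP32_MASTER, state_suffix])
        return "_".join([static_name, state_suffix])
    # Without master weights the static name is joined directly (simplified here).
    return "_".join([static_name, state_suffix])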
8 changes: 2 additions & 6 deletions paddlenlp/trainer/unified_checkpoint/load_local.py
@@ -16,13 +16,9 @@
 import gc
 import os
 
+import paddle
 from tqdm.auto import tqdm
 
-try:
-    from paddle.base import core
-except:
-    core = None
-
 from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM
 from paddlenlp.transformers.model_utils import (
     _load_state_dict_into_model,
@@ -252,7 +248,7 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
         key_name = key.split("/")
         static_name = struct2static_name_mappings[key_name[0]]
         if has_master_weights:
-            if model_state_dict[key_name[0]].dtype != core.VarDesc.VarType.FP32:
+            if model_state_dict[key_name[0]].dtype != paddle.float32:
                 key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
             else:
                 key_name = "_".join([static_name, key_name[1]])
9 changes: 2 additions & 7 deletions paddlenlp/trainer/unified_checkpoint/load_save_single_card.py
@@ -19,11 +19,6 @@
 
 import paddle
 
-try:
-    from paddle.base import core
-except:
-    core = None
-
 from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM
 from paddlenlp.transformers.model_utils import (
     _load_state_dict_into_model,
@@ -120,7 +115,7 @@ def save_single_card_optimizer(model, optimizer, output_dir):
     fp32_weight = {}
     for k, v in state_dict.items():
         static2struct_name_mappings[v.name] = k
-        if master_weights is not None and v.dtype == core.VarDesc.VarType.FP32:
+        if master_weights is not None and v.dtype == paddle.float32:
             fp32_weight[k] = v
 
     # rename optimizer param
@@ -226,7 +221,7 @@ def load_single_card_optimizer(model, optimizer, resume_from_checkpoint: str):
         key_name = key.split("/")
         static_name = struct2static_name_mappings[key_name[0]]
        if has_master_weights:
-            if model_state_dict[key_name[0]].dtype != core.VarDesc.VarType.FP32:
+            if model_state_dict[key_name[0]].dtype != paddle.float32:
                 key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
             else:
                 key_name = "_".join([static_name, key_name[1]])
9 changes: 2 additions & 7 deletions paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py
@@ -19,11 +19,6 @@
 import paddle
 from paddle.distributed import fleet
 
-try:
-    from paddle.base import core
-except:
-    core = None
-
 from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM
 from paddlenlp.trainer.argparser import strtobool
 from paddlenlp.trainer.utils.helper import distributed_isfile
@@ -281,7 +276,7 @@ def load_non_merge_optimizer(self, model, optimizer, resume_from_checkpoint):
             key_name = key.split("/")
             static_name = struct2static_name_mappings[key_name[0]]
             if has_master_weights:
-                if model_state_dict[key_name[0]].dtype != core.VarDesc.VarType.FP32:
+                if model_state_dict[key_name[0]].dtype != paddle.float32:
                     key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
                 else:
                     key_name = "_".join([static_name, key_name[1]])
@@ -529,7 +524,7 @@ def unified_optimizer_into_shards(
     fp32_weight = {}
     for k, v in state_dict.items():
         static2struct_name_mappings[v.name] = k
-        if master_weights is not None and v.dtype == core.VarDesc.VarType.FP32:
+        if master_weights is not None and v.dtype == paddle.float32:
             if args.dataset_rank > 0:  # deal with different dataset rank.
                 continue
             fp32_weight[k] = v
11 changes: 1 addition & 10 deletions paddlenlp/trainer/unified_checkpoint/utils.py
@@ -21,11 +21,6 @@
 import paddle.distributed as dist
 from paddle.distributed import fleet
 
-try:
-    from paddle.base import core
-except:
-    core = None
-
 from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM
 from paddlenlp.trainer.trainer_utils import ExplicitEnum, ShardingOption
 from paddlenlp.trainer.utils.helper import distributed_isfile
@@ -231,11 +226,7 @@ def get_expected_keys(args, sharded_metadata, model, optimizer, is_master_weight
     expected_keys = []
     for key in list(sharded_metadata["all_optimizer_keys"]):
         key_name = key.split("/")[0]
-        if (
-            is_master_weights
-            and key_name in model_state_dict
-            and model_state_dict[key_name].dtype == core.VarDesc.VarType.FP32
-        ):
+        if is_master_weights and key_name in model_state_dict and model_state_dict[key_name].dtype == paddle.float32:
             continue
 
         if args.use_expert_parallel and args.data_parallel_rank > 0: