Skip to content

Commit

Permalink
[Bug fixes] update dtype variable init (#5646)
Browse files Browse the repository at this point in the history
* update dtype variable init

* update dtype checking

* update dtype

* remove low_cpu_mem_usage

* update dtype guard
  • Loading branch information
wj-Mcat authored Apr 13, 2023
1 parent 08a0005 commit 376891b
Showing 1 changed file with 5 additions and 20 deletions.
25 changes: 5 additions & 20 deletions paddlenlp/transformers/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1523,6 +1523,7 @@ def from_pretrained_v2(
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", None)
cache_dir = kwargs.pop("cache_dir", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
dtype = kwargs.pop("dtype", None)

cache_dir = resolve_cache_dir(pretrained_model_name_or_path, from_hf_hub, cache_dir)

Expand All @@ -1539,38 +1540,22 @@ def from_pretrained_v2(
**kwargs,
)

dtype = kwargs.pop("dtype", config.dtype)

init_contexts = []
if low_cpu_mem_usage:
load_state_as_np = True
# Instantiate model.
init_contexts.append(no_init_weights(_enable=True))
if is_paddle_support_lazy_init():
init_contexts.append(paddle.LazyGuard())

# Fix me for loading dtype paddle.int64 but cast paddle.float32
# if dtype is None, use config.dtype instead
if dtype is None and config.dtype is not None:
if dtype is None:
dtype = config.dtype

if dtype:
init_contexts.append(dtype_guard(dtype))

if not os.path.exists(os.path.join(cache_dir, CONFIG_NAME)):
config.save_pretrained(cache_dir)

# PretrainedConfig auto contains dtype field
dtype = kwargs.pop("dtype", config.get("dtype", None))
init_contexts = []
if low_cpu_mem_usage:
load_state_as_np = True
# Instantiate model.
init_contexts.append(no_init_weights(_enable=True))
if is_paddle_support_lazy_init():
init_contexts.append(paddle.LazyGuard())
if dtype:
init_contexts.append(dtype_guard(dtype))

if dtype:
init_contexts.append(dtype_guard(dtype))

# 2. resolve model_weight file
support_conversion = cls.support_conversion(config) and ENABLE_TORCH_CHECKPOINT
Expand Down

0 comments on commit 376891b

Please sign in to comment.