-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Download重构 #8020
Download重构 #8020
Changes from 41 commits
66744bb
40b27c4
68b5f8c
e342983
fcc392b
3aa76ab
d6dfcf0
0705617
f9c5af7
275e52b
76cd0da
9bdc94e
df82769
5148bc6
6a0085b
7006332
620aacc
ae6169f
7268671
fe24034
85f37cb
37b3c25
d8c552d
c22851a
e392644
40842fd
b44f8ed
a18ca41
03d5047
6bb0544
0364a65
b60d218
850796f
8ce5dfe
af7bb9d
3109368
d25e6cd
ee497e5
ed4d372
d829bc5
eb06571
793784f
286b80a
119c648
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,10 +24,9 @@ | |
from paddle.framework import core | ||
|
||
from paddlenlp.transformers import PretrainedModel | ||
from paddlenlp.utils.download import get_file | ||
|
||
# TODO(fangzeyang) Temporary fix and replace by paddle framework downloader later | ||
from paddlenlp.utils.downloader import COMMUNITY_MODEL_PREFIX, get_path_from_url | ||
from paddlenlp.utils.env import MODEL_HOME | ||
from paddlenlp.utils.log import logger | ||
|
||
__all__ = ["FasterPretrainedModel", "ActScalesLoader", "WeightScalesLoader"] | ||
|
@@ -96,6 +95,11 @@ | |
pretrained_models = list(cls.pretrained_init_configuration.keys()) | ||
resource_files = {} | ||
init_configuration = {} | ||
pretrained_model_name_or_path = str(pretrained_model_name_or_path) | ||
[Review comment] paddlenlp/experimental/model_utils.py 这些代码有CI测试覆盖吗?
[Reply] experimental目录下没有专门新增单测，但是transformers下有新增单测，只是加上单测会导致ci失败，但是在本地是可以正常运行的
[Reply] @JunnYu 这里CE可以覆盖吗？对推理而言风险比较大。
[Reply] 我那里的CE都是动态图的，不会涉及到experimental的部分 |
||
cache_dir = kwargs.pop("cache_dir", None) | ||
from_hf_hub = kwargs.pop("from_hf_hub", False) | ||
from_aistudio = kwargs.pop("from_aistudio", False) | ||
subfolder = kwargs.pop("subfolder", "") | ||
|
||
# From built-in pretrained models | ||
if pretrained_model_name_or_path in pretrained_models: | ||
|
@@ -106,40 +110,27 @@ | |
elif os.path.isdir(pretrained_model_name_or_path): | ||
for file_id, file_name in cls.resource_files_names.items(): | ||
full_file_name = os.path.join(pretrained_model_name_or_path, file_name) | ||
resource_files[file_id] = full_file_name | ||
if os.path.isfile(full_file_name): | ||
resource_files[file_id] = full_file_name | ||
resource_files["model_config_file"] = os.path.join(pretrained_model_name_or_path, cls.model_config_file) | ||
else: | ||
# Assuming from community-contributed pretrained models | ||
for file_id, file_name in cls.resource_files_names.items(): | ||
full_file_name = "/".join([COMMUNITY_MODEL_PREFIX, pretrained_model_name_or_path, file_name]) | ||
resource_files[file_id] = full_file_name | ||
resource_files["model_config_file"] = "/".join( | ||
[COMMUNITY_MODEL_PREFIX, pretrained_model_name_or_path, cls.model_config_file] | ||
) | ||
resource_files[file_id] = file_name | ||
|
||
default_root = os.path.join(MODEL_HOME, pretrained_model_name_or_path) | ||
# default_root = os.path.join(MODEL_HOME, pretrained_model_name_or_path) | ||
resolved_resource_files = {} | ||
for file_id, file_path in resource_files.items(): | ||
if file_path is None or os.path.isfile(file_path): | ||
resolved_resource_files[file_id] = file_path | ||
continue | ||
path = os.path.join(default_root, file_path.split("/")[-1]) | ||
if os.path.exists(path): | ||
logger.info("Already cached %s" % path) | ||
resolved_resource_files[file_id] = path | ||
else: | ||
logger.info("Downloading %s and saved to %s" % (file_path, default_root)) | ||
try: | ||
resolved_resource_files[file_id] = get_path_from_url(file_path, default_root) | ||
except RuntimeError as err: | ||
logger.error(err) | ||
raise RuntimeError( | ||
f"Can't load weights for '{pretrained_model_name_or_path}'.\n" | ||
f"Please make sure that '{pretrained_model_name_or_path}' is:\n" | ||
"- a correct model-identifier of built-in pretrained models,\n" | ||
"- or a correct model-identifier of community-contributed pretrained models,\n" | ||
"- or the correct path to a directory containing relevant modeling files(model_weights and model_config).\n" | ||
) | ||
resolved_resource_files[file_id] = get_file( | ||
pretrained_model_name_or_path, | ||
[file_path], | ||
subfolder, | ||
cache_dir=cache_dir, | ||
from_aistudio=from_aistudio, | ||
from_hf_hub=from_hf_hub, | ||
) | ||
|
||
# Prepare model initialization kwargs | ||
# Did we saved some inputs and kwargs to reload ? | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,20 +20,11 @@ | |
from collections import defaultdict | ||
from typing import Dict, List, Type | ||
|
||
from huggingface_hub import hf_hub_download | ||
|
||
from ... import __version__ | ||
from ...utils.downloader import ( | ||
COMMUNITY_MODEL_PREFIX, | ||
get_path_from_url_with_filelock, | ||
url_file_exists, | ||
) | ||
from ...utils.download import get_file | ||
from ...utils.import_utils import import_module | ||
from ...utils.log import logger | ||
from ..aistudio_utils import aistudio_download | ||
from ..configuration_utils import PretrainedConfig | ||
from ..model_utils import PretrainedModel | ||
from ..utils import resolve_cache_dir | ||
|
||
__all__ = [ | ||
"AutoConfig", | ||
|
@@ -170,13 +161,8 @@ | |
config = AutoConfig.from_pretrained("bert-base-uncased") | ||
config.save_pretrained('./bert-base-uncased') | ||
""" | ||
subfolder = kwargs.get("subfolder", "") | ||
if subfolder is None: | ||
subfolder = "" | ||
from_aistudio = kwargs.pop("from_aistudio", False) | ||
from_hf_hub = kwargs.pop("from_hf_hub", False) | ||
cache_dir = kwargs.pop("cache_dir", None) | ||
cache_dir = resolve_cache_dir(from_hf_hub=from_hf_hub, from_aistudio=from_aistudio, cache_dir=cache_dir) | ||
|
||
# cache_dir = resolve_cache_dir(from_hf_hub=from_hf_hub, from_aistudio=from_aistudio, cache_dir=cache_dir) | ||
ZHUI marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
if not cls.name2class: | ||
cls.name2class = {} | ||
|
@@ -192,72 +178,33 @@ | |
pretrained_model_name_or_path, *model_args, **kwargs | ||
) | ||
|
||
# From local dir path | ||
elif os.path.isdir(pretrained_model_name_or_path): | ||
config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_file) | ||
if not os.path.exists(config_file): | ||
# try to load legacy config file | ||
legacy_config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.legacy_config_file) | ||
if not os.path.exists(legacy_config_file): | ||
raise ValueError( | ||
f"config file<{cls.config_file}> or legacy config file<{cls.legacy_config_file}> not found" | ||
) | ||
subfolder = kwargs.get("subfolder", "") | ||
if subfolder is None: | ||
subfolder = "" | ||
from_aistudio = kwargs.pop("from_aistudio", False) | ||
from_hf_hub = kwargs.pop("from_hf_hub", False) | ||
cache_dir = kwargs.pop("cache_dir", None) | ||
|
||
logger.warning(f"loading legacy config file<{cls.legacy_config_file}> ...") | ||
config_file = legacy_config_file | ||
config_file = get_file( | ||
pretrained_model_name_or_path, | ||
[cls.config_file, cls.legacy_config_file], | ||
subfolder, | ||
cache_dir=cache_dir, | ||
from_hf_hub=from_hf_hub, | ||
from_aistudio=from_aistudio, | ||
) | ||
|
||
if os.path.exists(config_file): | ||
[Review comment] 是否一定是 exists 的？不存在的话，报错是不是在 get_file 内部？
[Reply] 如果下载失败的话是在get_file内部报错，如果repo没有该文件get_file会返回None，会在这报错 |
||
config_class = cls._get_config_class_from_config(pretrained_model_name_or_path, config_file) | ||
logger.info("We are using %s to load '%s'." % (config_class, pretrained_model_name_or_path)) | ||
if config_class is cls: | ||
return cls.from_file(config_file) | ||
return config_class.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) | ||
elif from_aistudio: | ||
file = aistudio_download( | ||
repo_id=pretrained_model_name_or_path, | ||
filename=cls.config_file, | ||
subfolder=subfolder, | ||
cache_dir=cache_dir, | ||
) | ||
return cls.from_pretrained(os.path.dirname(file)) | ||
elif from_hf_hub: | ||
file = hf_hub_download( | ||
repo_id=pretrained_model_name_or_path, | ||
filename=cls.config_file, | ||
cache_dir=cache_dir, | ||
subfolder=subfolder, | ||
library_name="PaddleNLP", | ||
library_version=__version__, | ||
) | ||
# from local dir path | ||
return cls.from_pretrained(os.path.dirname(file)) | ||
|
||
# Assuming from community-contributed pretrained models | ||
return config_class.from_pretrained(config_file, *model_args, **kwargs) | ||
else: | ||
url_list = [COMMUNITY_MODEL_PREFIX, pretrained_model_name_or_path, cls.config_file] | ||
legacy_url_list = [COMMUNITY_MODEL_PREFIX, pretrained_model_name_or_path, cls.legacy_config_file] | ||
cache_dir = os.path.join(cache_dir, pretrained_model_name_or_path, subfolder) | ||
if subfolder != "": | ||
url_list.insert(2, subfolder) | ||
legacy_url_list.insert(2, subfolder) | ||
community_config_path = "/".join(url_list) | ||
legacy_community_config_path = "/".join(legacy_url_list) | ||
|
||
if not url_file_exists(community_config_path): | ||
if not url_file_exists(legacy_community_config_path): | ||
raise RuntimeError( | ||
f"Can't load Config for '{pretrained_model_name_or_path}'.\n" | ||
f"Please make sure that '{pretrained_model_name_or_path}' is:\n" | ||
"- a correct model-identifier of built-in pretrained models,\n" | ||
"- or a correct model-identifier of community-contributed pretrained models,\n" | ||
"- or the correct path to a directory containing relevant config files.\n" | ||
) | ||
logger.warning(f"loading legacy config file<{cls.legacy_config_file}> ...") | ||
community_config_path = legacy_community_config_path | ||
|
||
resolved_config_file = get_path_from_url_with_filelock(community_config_path, cache_dir) | ||
config_class = cls._get_config_class_from_config(pretrained_model_name_or_path, resolved_config_file) | ||
logger.info("We are using %s to load '%s'." % (config_class, pretrained_model_name_or_path)) | ||
if config_class is cls: | ||
return cls.from_file(resolved_config_file, **kwargs) | ||
|
||
return config_class.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) | ||
raise RuntimeError( | ||
f"Can't load config for '{pretrained_model_name_or_path}'.\n" | ||
f"Please make sure that '{pretrained_model_name_or_path}' is:\n" | ||
"- a correct model-identifier of built-in pretrained models,\n" | ||
"- or a correct model-identifier of community-contributed pretrained models,\n" | ||
"- or the correct path to a directory containing relevant config files.\n" | ||
) |
[Review comment] why disable it?
[Reply] 因为之前测试发现这里会报错，所以去除了
[Reply] model_args.to_static 改成 training_args.to_static 你看一下