Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NewFeature] add paddlenlp command #3538

Merged
merged 49 commits into from
Nov 16, 2022
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
dbbd3ea
add cli solution
wj-Mcat Oct 22, 2022
d94606d
add model weight converter command
wj-Mcat Oct 22, 2022
417ff77
Merge branch 'develop' into add-cli
wj-Mcat Oct 25, 2022
46c6e6f
add converter
wj-Mcat Oct 27, 2022
b96ed14
Merge branch 'add-cli' of github.com:wj-Mcat/PaddleNLP into add-cli
wj-Mcat Oct 27, 2022
841b1ad
update cli setup
wj-Mcat Oct 27, 2022
073295a
Merge branch 'develop' into add-cli
wj-Mcat Oct 27, 2022
14c5a35
Merge branch 'develop' into add-cli
wj-Mcat Oct 27, 2022
371bf0a
update cli
wj-Mcat Oct 27, 2022
ed8cf48
Merge branch 'develop' of github.com:wj-Mcat/PaddleNLP into add-cli
wj-Mcat Oct 28, 2022
689751f
update search & download
wj-Mcat Nov 1, 2022
3eb24b3
Merge branch 'add-cli' of github.com:wj-Mcat/PaddleNLP into add-cli
wj-Mcat Nov 1, 2022
1fe1695
update table header name
wj-Mcat Nov 1, 2022
8f51693
rename logger msg
wj-Mcat Nov 1, 2022
76e7e28
upgrade command arguments
wj-Mcat Nov 1, 2022
e0e3510
merge develop branch
wj-Mcat Nov 1, 2022
a43c6b4
complete converter
wj-Mcat Nov 1, 2022
0e7c5f2
add comments to online convert
wj-Mcat Nov 1, 2022
52d7ff9
add comment to cli/main
wj-Mcat Nov 1, 2022
d6da5b6
add clip converter
wj-Mcat Nov 1, 2022
59d9330
update clip converter
wj-Mcat Nov 2, 2022
c0c04b1
complete clip converter
wj-Mcat Nov 2, 2022
4882cd1
add init_class field
wj-Mcat Nov 2, 2022
2af7fee
Merge branch 'develop' into add-cli
wj-Mcat Nov 2, 2022
f803018
Merge branch 'develop' into add-cli
wj-Mcat Nov 2, 2022
7ab3966
update converter
wj-Mcat Nov 2, 2022
fe34a5c
Merge branch 'add-cli' of github.com:wj-Mcat/PaddleNLP into add-cli
wj-Mcat Nov 2, 2022
87cdfc8
add state-dict checker
wj-Mcat Nov 3, 2022
7357f79
imporve converter gpu memory usage case
wj-Mcat Nov 7, 2022
4f5303d
remove unused fields from config
wj-Mcat Nov 7, 2022
cb97e96
add remove unused fileds
wj-Mcat Nov 7, 2022
74a1710
add config fields to converter
wj-Mcat Nov 7, 2022
3a29bef
add cli
wj-Mcat Nov 9, 2022
e0b4df4
Merge branch 'develop' into add-cli
wj-Mcat Nov 9, 2022
0d3477b
add convert for stable-diffusion
wj-Mcat Nov 9, 2022
a8770cb
Merge branch 'add-cli' of github.com:wj-Mcat/PaddleNLP into add-cli
wj-Mcat Nov 9, 2022
03dafef
remove ppdiffusers converter
wj-Mcat Nov 11, 2022
9040b1a
Merge branch 'develop' into add-cli
wj-Mcat Nov 16, 2022
598258f
Merge branch 'develop' into add-cli
wj-Mcat Nov 16, 2022
e792ca3
update description of cli
wj-Mcat Nov 16, 2022
f1da84e
Merge branch 'add-cli' of github.com:wj-Mcat/PaddleNLP into add-cli
wj-Mcat Nov 16, 2022
1c142d7
update exit code
wj-Mcat Nov 16, 2022
e198d1b
Merge branch 'develop' into add-cli
wj-Mcat Nov 16, 2022
1b77e4e
add cli dependency to dev tag
wj-Mcat Nov 16, 2022
f663930
merge develop branch
wj-Mcat Nov 16, 2022
fff7eea
fix typo
wj-Mcat Nov 16, 2022
871b4cf
Merge branch 'develop' into add-cli
wj-Mcat Nov 16, 2022
171d370
Merge branch 'develop' into add-cli
wj-Mcat Nov 16, 2022
784bb93
Merge branch 'develop' into add-cli
wj-Mcat Nov 16, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions paddlenlp/cli/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .main import main
118 changes: 118 additions & 0 deletions paddlenlp/cli/converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import shutil
from typing import Type
from paddlenlp.utils.log import logger
from paddlenlp.utils.converter import Converter, load_all_converters
from paddlenlp.transformers.bert.converter import BertConverter


def convert_from_local_dir(pretrained_dir: str, output: str):
"""convert weight from local dir

Args:
pretrained_dir (str): the pretrained dir
output (str): the output dir
"""
# 1. checking the related files
files = os.listdir(pretrained_dir)
assert 'pytorch_model.bin' in files, f"`pytorch_model.bin` file must exist in dir<{pretrained_dir}>"
assert 'config.json' in files, f"`config.json` file must exist in dir<{pretrained_dir}>"

# 2. get model architecture from config.json
config_file = os.path.join(pretrained_dir, 'config.json')
with open(config_file, 'r', encoding='utf-8') as f:
config = json.load(f)

architectures = config.pop("architectures", []) or config.pop(
"init_class", None)
if not architectures:
raise ValueError(
"can not find the model weight architectures with field: <architectures> and <init_class>"
)

if isinstance(architectures, str):
architectures = [architectures]

if len(architectures) > 1:
raise ValueError("only support one model architecture")

architecture = architectures[0]

# 3. retrieve Model Converter
target_converter_classes = [
converter_class for converter_class in load_all_converters()
if architecture in converter_class.architectures
]
if not target_converter_classes:
logger.error(
f"can not find target Converter based on architecture<{architecture}>"
)
if len(target_converter_classes) > 1:
logger.warning(
f"{len(target_converter_classes)} found, we will adopt the first one as the target converter ..."
)

target_converter_class: Type[Converter] = target_converter_classes[0]

# 4. do converting
converter = target_converter_class()
converter.convert(pretrained_dir, output_dir=output)


def convert_from_local_file(weight_file_path: str, output: str):
"""convert from the local weitht file
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

weitht -> weight


Args:
weight_file_path (str): the path of the weight file
output (str): the output dir
"""
# 1. check the name of weight file
if not os.path.isdir(weight_file_path):
weight_file_dir, filename = os.path.split(weight_file_path)
if filename != "pytorch_model.bin":
shutil.copy(weight_file_path,
os.path.join(weight_file_dir, 'pytorch_model.bin'))

weight_file_path = weight_file_dir
convert_from_local_dir(weight_file_path, output)


def convert_from_online_model(model_name: str, cache_dir: str, output_dir):
"""convert the model which is not maintained in paddlenlp community, eg: vblagoje/bert-english-uncased-finetuned-pos

TODO(wj-Mcat): to deeply test this method

Args:
model_name (str): the name of model
cache_dir (str): the cache_dir to save pytorch model
output_dir (_type_): the output dir
"""
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

由于时间原因,现在暂不测试这个函数,等模型中心弄完之后再来深入测试online_convert的相关功能。

# 1. auto save
from transformers import AutoModel
model = AutoModel.from_pretrained(model_name)
model.save_pretrained(cache_dir)

# 2. resolve the converter
config_file = os.path.join(cache_dir, 'config.json')
with open(config_file, 'r', encoding='utf-8') as f:
config = json.load(f)

architectures = config['architectures']

converter = BertConverter()
converter.convert(input_dir=cache_dir, output_dir=output_dir)
62 changes: 62 additions & 0 deletions paddlenlp/cli/download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from typing import List, Tuple
from dataclasses import dataclass

from paddlenlp.utils.log import logger
from paddlenlp.utils.env import MODEL_HOME
from paddlenlp.utils.downloader import COMMUNITY_MODEL_PREFIX, get_path_from_url

COMMUNITY_MODEL_CONFIG_FILE_NAME = "community_models.json"


def load_community_models() -> List[Tuple[str, str]]:
"""load community models based on remote models.json

Returns:
List[Tuple[str, str]]: the name tuples of community models
"""
# 1. check & download community models.json
local_community_model_config_path = os.path.join(MODEL_HOME,
"community_models.json")

if not os.path.exists(local_community_model_config_path):
logger.info("download community model configuration from server ...")
remote_community_model_path = os.path.join(
COMMUNITY_MODEL_PREFIX, COMMUNITY_MODEL_CONFIG_FILE_NAME)
cache_dir = os.path.join(MODEL_HOME)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

community_models.json缓存文件放到~/.paddlenlp/models目录下。

local_community_model_config_path = get_path_from_url(
remote_community_model_path, root_dir=cache_dir)

# 2. load configuration
#
# config = {
# "model_name": {
# "type": "",
# "files": ["", ""]
# }
# }
#

with open(local_community_model_config_path, 'r', encoding='utf-8') as f:
config = json.load(f)

model_names = set()
for model_name, obj in config.items():
model_names.add(("community", model_name, obj.get("type", model_name)))
logger.info(f"find {len(model_names)} community models ...")
return model_names
157 changes: 157 additions & 0 deletions paddlenlp/cli/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from genericpath import isdir
import os
import json
from typing import Type, List, Tuple, Optional
import typer
from typer import Typer
import shutil
import importlib, inspect
from paddlenlp import __version__
from paddlenlp.transformers import AutoModel, AutoTokenizer, PretrainedModel, PretrainedTokenizer
from paddlenlp.utils.log import logger
from paddlenlp.utils.downloader import is_url
from paddlenlp.cli.converter import convert_from_local_file, convert_from_local_dir, convert_from_online_model
from paddlenlp.cli.utils.tabulate import tabulate, print_example_code
from paddlenlp.cli.download import load_community_models


def load_all_models(include_community: bool = False) -> List[Tuple[str, str]]:
"""load all model_name infos

Returns:
List[Tuple[str, str]]: [model_type, model_name]
"""
module = importlib.import_module("paddlenlp.transformers")
model_names = set()
for attr_name in dir(module):
if attr_name.startswith("_"):
continue
obj = getattr(module, attr_name)
if not inspect.isclass(obj):
continue
if not issubclass(obj, PretrainedModel):
continue

obj: Type[PretrainedModel] = obj
if not obj.__name__.endswith("PretrainedModel"):
continue
configurations = obj.pretrained_init_configuration
for model_name in configurations.keys():
model_names.add(("official", obj.base_model_prefix, model_name))
logger.info(f"find {len(model_names)} official models ...")

if include_community:
# load & extend community models
community_model_names = load_community_models()
for model_name in community_model_names:
model_names.add(model_name)

return model_names


app = Typer()


@app.command()
def download(model_name: str,
cache_dir: str = typer.Option(
'./pretrained_models',
'--cache-dir',
'-c',
help="cache_dir for download pretrained model"),
force_download: bool = typer.Option(
False,
'--force-download',
'-f',
help="force download pretrained model")):
"""download the paddlenlp models with command, you can specific `model_name`

>>> paddlenlp download bert \n
>>> paddlenlp download -c ./my-models -f bert \n

Args:\n
model_name (str): pretarined model name, you can checkout all of model from source code. \n
cache_dir (str, optional): the cache_dir. Defaults to "./models".
"""
if not os.path.isabs(cache_dir):
cache_dir = os.path.join(os.getcwd(), cache_dir)

if is_url(model_name):
logger.error("<MODEL_NAME> can not be url")
return

cache_dir = os.path.join(cache_dir, model_name)
if force_download:
shutil.rmtree(cache_dir, ignore_errors=True)

model: PretrainedModel = AutoModel.from_pretrained(model_name)
model.save_pretrained(cache_dir)

tokenizer: PretrainedTokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.save_pretrained(cache_dir)

logger.info(f"successfully saved model into <{cache_dir}>")


@app.command()
def search(query=typer.Argument(..., help='the query of searching model'),
include_community: bool = typer.Option(
False,
"--include-community",
'-i',
help="whether searching community models")):
"""search the model with query, eg: paddlenlp search bert

>>> paddlenlp search bert \n
>>> paddlenlp search -i bert \n

Args: \n
query (Optional[str]): the str fragment of bert-name \n
include_community (Optional[bool]): whether searching community models
"""
logger.info("start to search models ...")
model_names = load_all_models(include_community)

tables = []
for model_category, model_type, model_name in model_names:
if not query or query in model_name:
tables.append([model_category, model_type, model_name])
tabulate(tables,
headers=["model source", 'model type', 'model name'],
highlight_word=query)
print_example_code()

logger.info(f"the retrieved number of models results is {len(tables)} ...")


@app.command(help="convert pytorch models to paddle model")
def convert(input: Optional[str] = None, output: Optional[str] = None):
logger.info("starting to convert models ...")
if os.path.isdir(input):
convert_from_local_dir(pretrained_dir=input, output=output)
else:
# TODO(wj-Mcat): should complete the online converting
convert_from_online_model()


def main():
"""the PaddleNLPCLI entry"""
app()


if __name__ == "__main__":
main()
2 changes: 2 additions & 0 deletions paddlenlp/cli/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
typer
rich
13 changes: 13 additions & 0 deletions paddlenlp/cli/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Loading