Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add paddle hub #31873

Merged
merged 38 commits into from
Apr 25, 2021
Merged
Show file tree
Hide file tree
Changes from 34 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
6ef62e0
add hub list function
LielinJiang Mar 9, 2021
a71dff5
add help load func
lyuwenyu Mar 23, 2021
64570eb
add load and help funcs.
lyuwenyu Mar 24, 2021
af43bfd
add local repo support for list/help/load func
lyuwenyu Mar 24, 2021
00a9786
add local repo support for list/help/load func
lyuwenyu Mar 24, 2021
1b613cb
add _check_dependencies hubconf
lyuwenyu Mar 24, 2021
884bbf9
update hub docs, test=develop
lyuwenyu Mar 24, 2021
2c52a28
update hub docs, test=develop
lyuwenyu Mar 24, 2021
f9dd61d
update docs
lyuwenyu Mar 25, 2021
c8d2ae3
upadte docs
lyuwenyu Mar 31, 2021
8acff14
Merge remote-tracking branch 'upstream/develop' into hub_L
lyuwenyu Mar 31, 2021
2237120
init unittest for hapi/hub
lyuwenyu Apr 6, 2021
8185cc4
add hub unittest, test=develop
lyuwenyu Apr 7, 2021
c4347b2
support source `gitee`
lyuwenyu Apr 7, 2021
aa0eac9
update hapi hub test
lyuwenyu Apr 19, 2021
dfa6d03
for ci test
lyuwenyu Apr 20, 2021
ec5183d
add hub test
lyuwenyu Apr 20, 2021
cc7bbe7
Merge remote-tracking branch 'upstream/develop' into hub_L
lyuwenyu Apr 20, 2021
58d9632
remove *args in `load` for py2
lyuwenyu Apr 20, 2021
8ecb192
ci timeout problem
lyuwenyu Apr 20, 2021
8d544b9
replace importlib to __import__ for py2
lyuwenyu Apr 21, 2021
17e6517
replace importlib to __import__ for py2
lyuwenyu Apr 21, 2021
697ccbd
test local, github timeout
lyuwenyu Apr 21, 2021
3d7d9d4
add exception test
lyuwenyu Apr 21, 2021
3552ea2
fix ci timeout problem
lyuwenyu Apr 22, 2021
b8d5d28
fix ci timeout problem
lyuwenyu Apr 22, 2021
efb1a83
fix ci timeout problem
lyuwenyu Apr 22, 2021
cb03c54
fix ci timeout problem
lyuwenyu Apr 22, 2021
ba15ac1
tests
lyuwenyu Apr 22, 2021
ae0bb17
fix docs, bugs of import, and more unittest
lyuwenyu Apr 23, 2021
59e57e3
update
lyuwenyu Apr 23, 2021
204e047
update
lyuwenyu Apr 23, 2021
56e1042
update
lyuwenyu Apr 23, 2021
aea875a
Merge remote-tracking branch 'upstream/develop' into hub_L
lyuwenyu Apr 23, 2021
67d17ca
update
lyuwenyu Apr 23, 2021
cad8c80
update
lyuwenyu Apr 23, 2021
7408e84
update default branch
lyuwenyu Apr 25, 2021
7829730
update, remove default branch var
lyuwenyu Apr 25, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,8 @@
from .hapi import callbacks
from .hapi import summary
from .hapi import flops
from .hapi import hub

import paddle.text
import paddle.vision

Expand Down
1 change: 1 addition & 0 deletions python/paddle/hapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from . import logger
from . import callbacks
from . import model_summary
from . import hub

from . import model
from .model import *
Expand Down
277 changes: 277 additions & 0 deletions python/paddle/hapi/hub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,277 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
import sys
import shutil
import zipfile
from paddle.utils.download import get_path_from_url

# Branch used by _parse_repo_info when the repo string carries no
# explicit ':branch' suffix.
MAIN_BRANCH = 'main'
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

变量名称是不是应该叫DEFAULT_BRANCH
另外,是不是github的时候default是main, gitee的时候default是master,比较合理。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good idea

# Base cache location; not referenced by the code visible in this file.
DEFAULT_CACHE_DIR = '~/.cache'
# Name of the optional hubconf attribute that lists packages which must be
# importable before an entrypoint is loaded.
VAR_DEPENDENCY = 'dependencies'
# File that is imported from the repo directory to discover entrypoints.
MODULE_HUBCONF = 'hubconf.py'
# Directory where downloaded repo archives are unpacked and cached.
HUB_DIR = os.path.expanduser(os.path.join('~', '.cache', 'paddle', 'hub'))


def _remove_if_exists(path):
if os.path.exists(path):
if os.path.isfile(path):
os.remove(path)
else:
shutil.rmtree(path)


def _import_module(name, repo_dir):
sys.path.insert(0, repo_dir)
try:
hub_module = __import__(name)
sys.modules.pop(name)
except ImportError:
print('Cannot import `{}`, please make sure `{}`.py in repo root dir'.
format(name, name))
sys.path.remove(repo_dir)
sys.path.remove(repo_dir)

return hub_module


def _git_archive_link(repo_owner, repo_name, branch, source):
if source == 'github':
return 'https://github.com/{}/{}/archive/{}.zip'.format(
repo_owner, repo_name, branch)
elif source == 'gitee':
return 'https://gitee.com/{}/{}/repository/archive/{}.zip'.format(
repo_owner, repo_name, branch)


def _parse_repo_info(github):
    """Split a repo string into (repo_owner, repo_name, branch).

    The expected form is "repo_owner/repo_name[:branch]"; when the ':branch'
    part is absent, MAIN_BRANCH is used.
    """
    if ':' in github:
        owner_and_name, branch = github.split(':')
    else:
        owner_and_name, branch = github, MAIN_BRANCH
    repo_owner, repo_name = owner_and_name.split('/')
    return repo_owner, repo_name, branch


def _make_dirs(dirname):
try:
from pathlib import Path
except ImportError:
from pathlib2 import Path
Path(dirname).mkdir(exist_ok=True)


def _get_cache_or_reload(repo, force_reload, verbose=True, source='github'):
    """Return the local directory for `repo`, downloading it when needed.

    Args:
        repo (str): repo string understood by _parse_repo_info.
        force_reload (bool): when true, ignore an existing cached copy.
        verbose (bool): when true, report on stderr that the cache is used.
        source (str): passed to _git_archive_link to pick the hosting site.

    Returns:
        str: path of the cached repo directory under HUB_DIR.
    """
    cache_root = HUB_DIR
    if not os.path.exists(cache_root):
        _make_dirs(cache_root)

    repo_owner, repo_name, branch = _parse_repo_info(repo)
    # A branch name may contain '/', which would otherwise be read as a path
    # separator when used in a file or folder name.
    safe_branch = branch.replace('/', '_')
    # The top-level folder name inside the archive is only known after the
    # download, so the cache uses its own normalized folder name instead.
    repo_dir = os.path.join(cache_root,
                            '_'.join([repo_owner, repo_name, safe_branch]))

    if (not force_reload) and os.path.exists(repo_dir):
        if verbose:
            sys.stderr.write('Using cache found in {}\n'.format(repo_dir))
        return repo_dir

    # NOTE(review): this assumes get_path_from_url stores the archive at
    # exactly this path -- confirm, in particular for branches with '/'.
    archive_path = os.path.join(cache_root, safe_branch + '.zip')
    _remove_if_exists(archive_path)

    url = _git_archive_link(repo_owner, repo_name, branch, source=source)
    get_path_from_url(url, cache_root, decompress=False)

    with zipfile.ZipFile(archive_path) as archive:
        # The first archive entry is taken as the top-level folder.
        top_level_name = archive.infolist()[0].filename
        extracted_repo = os.path.join(cache_root, top_level_name)
        _remove_if_exists(extracted_repo)
        archive.extractall(cache_root)

    _remove_if_exists(archive_path)
    _remove_if_exists(repo_dir)
    # Give the extracted folder its normalized cache name.
    shutil.move(extracted_repo, repo_dir)
    return repo_dir


def _load_entry_from_hubconf(m, name):
'''load entry from hubconf
'''
if not isinstance(name, str):
raise ValueError(
'Invalid input: model should be a str of function name')

func = getattr(m, name, None)

if func is None or not callable(func):
raise RuntimeError('Cannot find callable {} in hubconf'.format(name))

return func


def _check_module_exists(name):
try:
__import__(name)
return True
except ImportError:
return False


def _check_dependencies(m):
    """Verify that every package listed by hubconf module `m` is importable.

    The list is read from the attribute named by VAR_DEPENDENCY; a module
    without that attribute is accepted as is.

    Raises:
        RuntimeError: naming every listed package that cannot be imported.
    """
    dependencies = getattr(m, VAR_DEPENDENCY, None)
    if dependencies is None:
        return

    missing_deps = []
    for pkg in dependencies:
        if not _check_module_exists(pkg):
            missing_deps.append(pkg)

    if missing_deps:
        raise RuntimeError('Missing dependencies: {}'.format(', '.join(
            missing_deps)))


def list(repo_dir, source='github', force_reload=False):
    r"""
    Return the names of all entrypoints defined in a repo's hubconf.

    Args:
        repo_dir(str): remote repo string or local path
            remote repo (str): a str with format "repo_owner/repo_name[:tag_name]" with an optional
                tag/branch. The default branch is `main` if not specified.
            local path (str): local repo path
        source (str): `github` | `gitee` | `local`, default is `github`
        force_reload (bool, optional): whether to discard the existing cache and force a fresh download, default is `False`.
    Returns:
        entrypoints: a list of the public callable names found in hubconf

    Example:
        ```python
        import paddle

        paddle.hub.list('lyuwenyu/paddlehub_demo:main', source='github', force_reload=False)

        ```
    """
    remote_sources = ('github', 'gitee')
    if source not in remote_sources and source != 'local':
        raise ValueError(
            'Unknown source: "{}". Allowed values: "github" | "gitee" | "local".'.
            format(source))

    # Remote repos are resolved to their cached local directory first.
    if source in remote_sources:
        repo_dir = _get_cache_or_reload(
            repo_dir, force_reload, True, source=source)

    module_name = MODULE_HUBCONF.split('.')[0]
    hub_module = _import_module(module_name, repo_dir)

    # Keep callables whose name does not start with an underscore.
    entrypoints = []
    for attr in dir(hub_module):
        if callable(getattr(hub_module, attr)) and not attr.startswith('_'):
            entrypoints.append(attr)

    return entrypoints


def help(repo_dir, model, source='github', force_reload=False):
    """
    Return the docstring of a hubconf entrypoint.

    Args:
        repo_dir(str): remote repo string or local path
            remote repo (str): a str with format "repo_owner/repo_name[:tag_name]" with an optional
                tag/branch. The default branch is `main` if not specified.
            local path (str): local repo path
        model (str): name of the entrypoint to describe
        source (str): `github` | `gitee` | `local`, default is `github`
        force_reload (bool, optional): whether to discard the existing cache and force a fresh download, default is `False`
    Return:
        the `__doc__` of the requested entrypoint

    Example:
        ```python
        import paddle

        paddle.hub.help('lyuwenyu/paddlehub_demo:main', model='MM', source='github')
        ```
    """
    remote_sources = ('github', 'gitee')
    if source not in remote_sources and source != 'local':
        raise ValueError(
            'Unknown source: "{}". Allowed values: "github" | "gitee" | "local".'.
            format(source))

    # Remote repos are resolved to their cached local directory first.
    if source in remote_sources:
        repo_dir = _get_cache_or_reload(
            repo_dir, force_reload, True, source=source)

    module_name = MODULE_HUBCONF.split('.')[0]
    hub_module = _import_module(module_name, repo_dir)

    return _load_entry_from_hubconf(hub_module, model).__doc__


def load(repo_dir, model, source='github', force_reload=False, **kwargs):
    """
    Call a hubconf entrypoint and return its result.

    Args:
        repo_dir(str): remote repo string or local path
            remote repo (str): a str with format "repo_owner/repo_name[:tag_name]" with an optional
                tag/branch. The default branch is `main` if not specified.
            local path (str): local repo path
        model (str): name of the entrypoint to call
        source (str): `github` | `gitee` | `local`, default is `github`
        force_reload (bool, optional): whether to discard the existing cache and force a fresh download, default is `False`
        **kwargs: keyword arguments forwarded to the entrypoint
    Return:
        the object returned by the entrypoint (a paddle model)
    Example:
        ```python
        import paddle
        paddle.hub.load('lyuwenyu/paddlehub_demo:main', model='MM', source='github')
        ```
    """
    remote_sources = ('github', 'gitee')
    if source not in remote_sources and source != 'local':
        raise ValueError(
            'Unknown source: "{}". Allowed values: "github" | "gitee" | "local".'.
            format(source))

    # Remote repos are resolved to their cached local directory first.
    if source in remote_sources:
        repo_dir = _get_cache_or_reload(
            repo_dir, force_reload, True, source=source)

    module_name = MODULE_HUBCONF.split('.')[0]
    hub_module = _import_module(module_name, repo_dir)

    # Check declared dependencies before looking up and calling the entry.
    _check_dependencies(hub_module)

    build_model = _load_entry_from_hubconf(hub_module, model)
    return build_model(**kwargs)
1 change: 1 addition & 0 deletions python/paddle/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,4 @@ set_tests_properties(test_vision_models PROPERTIES TIMEOUT 120)
set_tests_properties(test_dataset_uci_housing PROPERTIES TIMEOUT 120)
set_tests_properties(test_dataset_imdb PROPERTIES TIMEOUT 300)
set_tests_properties(test_pretrained_model PROPERTIES TIMEOUT 600)
set_tests_properties(test_hapi_hub PROPERTIES TIMEOUT 300)
24 changes: 24 additions & 0 deletions python/paddle/tests/hubconf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

dependencies = ['paddle']

import paddle
from test_hapi_hub_model import MM as _MM


def MM(out_channels=8, pretrained=False):
    '''This is a test demo for paddle hub
    '''
    # NOTE(review): `pretrained` is accepted but not used in this body.
    return _MM(out_channels)
Loading