Skip to content

Commit

Permalink
use environment variable to determine the download source for models
Browse files Browse the repository at this point in the history
  • Loading branch information
breezedeus committed Oct 9, 2023
1 parent 330f958 commit cf72d25
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 10 deletions.
19 changes: 13 additions & 6 deletions cnocr/consts.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding: utf-8
# Copyright (C) 2021, [Breezedeus](https://github.com/breezedeus).
# Copyright (C) 2021-2023, [Breezedeus](https://github.com/breezedeus).
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
Expand All @@ -17,6 +17,7 @@
# specific language governing permissions and limitations
# under the License.

import os
import string
from collections import OrderedDict
from pathlib import Path
Expand All @@ -32,6 +33,7 @@
# 模型版本只对应到第二层,第三层的改动表示模型兼容。
# 如: __version__ = '2.2.*',对应的 MODEL_VERSION 都是 '2.2'
MODEL_VERSION = '.'.join(__version__.split('.', maxsplit=2)[:2])
DOWNLOAD_SOURCE = os.environ.get('CNOCR_DOWNLOAD_SOURCE', 'CN')

IMG_STANDARD_HEIGHT = 32
CN_VOCAB_FP = Path(__file__).parent.absolute() / 'label_cn.txt'
Expand Down Expand Up @@ -135,20 +137,25 @@
HF_HUB_SUBFOLDER = "models/cnocr/%s" % MODEL_VERSION
PAID_HF_HUB_REPO_ID = "breezedeus/paid-models"
PAID_HF_HUB_SUBFOLDER = "cnocr/%s" % MODEL_VERSION
CN_OSS_ENDPOINT = (
"https://sg-models.oss-cn-beijing.aliyuncs.com/cnocr/%s/" % MODEL_VERSION
)


def format_hf_hub_url(url: str, is_paid_model=False) -> dict:
out_dict = {'filename': url}

if is_paid_model:
repo_id = PAID_HF_HUB_REPO_ID
subfolder = PAID_HF_HUB_SUBFOLDER
else:
repo_id = HF_HUB_REPO_ID
subfolder = HF_HUB_SUBFOLDER
return {
'repo_id': repo_id,
'subfolder': subfolder,
'filename': url,
}
out_dict['cn_oss'] = CN_OSS_ENDPOINT
out_dict.update(
{'repo_id': repo_id, 'subfolder': subfolder,}
)
return out_dict


class AvailableModels(object):
Expand Down
4 changes: 2 additions & 2 deletions cnocr/ppocr/pp_recognizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from .postprocess import build_post_process
from .utility import create_predictor
from .consts import PP_SPACE
from ..consts import MODEL_VERSION, AVAILABLE_MODELS
from ..consts import MODEL_VERSION, AVAILABLE_MODELS, DOWNLOAD_SOURCE


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -106,7 +106,7 @@ def _assert_and_prepare_model_files(self, model_fp, root):
% ((self._model_name, self._model_backend),)
)
url = AVAILABLE_MODELS.get_url(self._model_name, self._model_backend)
get_model_file(url, self._model_dir)
get_model_file(url, self._model_dir, download_source=DOWNLOAD_SOURCE) # download the .zip file and unzip

self._model_fp = model_fp
logger.info('use model: %s' % self._model_fp)
Expand Down
4 changes: 2 additions & 2 deletions cnocr/recognizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import torch
from cnstd.utils import get_model_file

from .consts import MODEL_VERSION, AVAILABLE_MODELS
from .consts import MODEL_VERSION, AVAILABLE_MODELS, DOWNLOAD_SOURCE
from .models.ocr_model import OcrModel
from .utils import (
data_dir,
Expand Down Expand Up @@ -177,7 +177,7 @@ def _assert_and_prepare_model_files(self, model_fp, root):
% ((self._model_name, self._model_backend),)
)
url = AVAILABLE_MODELS.get_url(self._model_name, self._model_backend)
get_model_file(url, self._model_dir)
get_model_file(url, self._model_dir, download_source=DOWNLOAD_SOURCE) # download the .zip file and unzip
fps = glob(
'%s/%s*.%s' % (self._model_dir, self._model_file_prefix, model_ext)
)
Expand Down

0 comments on commit cf72d25

Please sign in to comment.