diff --git a/cnocr/consts.py b/cnocr/consts.py index b3c2741..a87dfc8 100644 --- a/cnocr/consts.py +++ b/cnocr/consts.py @@ -1,5 +1,5 @@ # coding: utf-8 -# Copyright (C) 2021, [Breezedeus](https://github.com/breezedeus). +# Copyright (C) 2021-2023, [Breezedeus](https://github.com/breezedeus). # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -17,6 +17,7 @@ # specific language governing permissions and limitations # under the License. +import os import string from collections import OrderedDict from pathlib import Path @@ -32,6 +33,7 @@ # 模型版本只对应到第二层,第三层的改动表示模型兼容。 # 如: __version__ = '2.2.*',对应的 MODEL_VERSION 都是 '2.2' MODEL_VERSION = '.'.join(__version__.split('.', maxsplit=2)[:2]) +DOWNLOAD_SOURCE = os.environ.get('CNOCR_DOWNLOAD_SOURCE', 'CN') IMG_STANDARD_HEIGHT = 32 CN_VOCAB_FP = Path(__file__).parent.absolute() / 'label_cn.txt' @@ -135,20 +137,25 @@ HF_HUB_SUBFOLDER = "models/cnocr/%s" % MODEL_VERSION PAID_HF_HUB_REPO_ID = "breezedeus/paid-models" PAID_HF_HUB_SUBFOLDER = "cnocr/%s" % MODEL_VERSION +CN_OSS_ENDPOINT = ( + "https://sg-models.oss-cn-beijing.aliyuncs.com/cnocr/%s/" % MODEL_VERSION +) def format_hf_hub_url(url: str, is_paid_model=False) -> dict: + out_dict = {'filename': url} + if is_paid_model: repo_id = PAID_HF_HUB_REPO_ID subfolder = PAID_HF_HUB_SUBFOLDER else: repo_id = HF_HUB_REPO_ID subfolder = HF_HUB_SUBFOLDER - return { - 'repo_id': repo_id, - 'subfolder': subfolder, - 'filename': url, - } + out_dict['cn_oss'] = CN_OSS_ENDPOINT + out_dict.update( + {'repo_id': repo_id, 'subfolder': subfolder,} + ) + return out_dict class AvailableModels(object): diff --git a/cnocr/ppocr/pp_recognizer.py b/cnocr/ppocr/pp_recognizer.py index d535017..8d97528 100755 --- a/cnocr/ppocr/pp_recognizer.py +++ b/cnocr/ppocr/pp_recognizer.py @@ -33,7 +33,7 @@ from .postprocess import build_post_process from .utility import create_predictor from .consts import PP_SPACE -from ..consts import MODEL_VERSION, AVAILABLE_MODELS +from ..consts import MODEL_VERSION, AVAILABLE_MODELS, DOWNLOAD_SOURCE logger = logging.getLogger(__name__) @@ -106,7 +106,7 @@ def _assert_and_prepare_model_files(self, model_fp, root): % ((self._model_name, self._model_backend),) ) url = AVAILABLE_MODELS.get_url(self._model_name, self._model_backend) - get_model_file(url, self._model_dir) + get_model_file(url, self._model_dir, download_source=DOWNLOAD_SOURCE) # download the .zip file and unzip self._model_fp = model_fp logger.info('use model: %s' % self._model_fp) diff --git a/cnocr/recognizer.py b/cnocr/recognizer.py index 17fe7a8..b97330c 100644 --- a/cnocr/recognizer.py +++ b/cnocr/recognizer.py @@ -28,7 +28,7 @@ import torch from cnstd.utils import get_model_file -from .consts import MODEL_VERSION, AVAILABLE_MODELS +from .consts import MODEL_VERSION, AVAILABLE_MODELS, DOWNLOAD_SOURCE from .models.ocr_model import OcrModel from .utils import ( data_dir, @@ -177,7 +177,7 @@ def _assert_and_prepare_model_files(self, model_fp, root): % ((self._model_name, self._model_backend),) ) url = AVAILABLE_MODELS.get_url(self._model_name, self._model_backend) - get_model_file(url, self._model_dir) + get_model_file(url, self._model_dir, download_source=DOWNLOAD_SOURCE) # download the .zip file and unzip fps = glob( '%s/%s*.%s' % (self._model_dir, self._model_file_prefix, model_ext) )