Skip to content

Commit

Permalink
add file download
Browse files Browse the repository at this point in the history
  • Loading branch information
shibing624 committed Sep 9, 2022
1 parent 3d25a03 commit aeb9aa8
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 6 deletions.
21 changes: 18 additions & 3 deletions textgen/language_modeling/songnet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@
from textgen.language_modeling.songnet_utils import (
ZHCharTokenizer, s2t, s2xy, s2xy_polish,
SongNetDataLoader,
BOS, EOS,
BOS,
EOS,
PRETRAINED_MODELS,
LOCAL_DIR,
http_get,
)

has_cuda = torch.cuda.is_available()
Expand Down Expand Up @@ -594,7 +598,7 @@ class SongNetModel:
def __init__(
self,
model_type='songnet',
model_name='shibing624/songnet-base-chinese-couplet',
model_name='songnet-base-chinese',
args=None,
use_cuda=has_cuda,
cuda_device=-1,
Expand Down Expand Up @@ -644,6 +648,17 @@ def __init__(
self.results = {}

if model_name:
bin_path = os.path.join(model_name, 'pytorch_model.bin')
if not os.path.exists(bin_path):
if model_name in PRETRAINED_MODELS:
local_model_dir = os.path.join(LOCAL_DIR, model_name)
local_bin_path = os.path.join(local_model_dir, 'pytorch_model.bin')
if not os.path.exists(bin_path):
url = PRETRAINED_MODELS[model_name]
http_get(url, local_model_dir)
else:
logger.warning(f'Model {bin_path} not exists, use local model {local_bin_path}')
model_name = local_model_dir
self.tokenizer = ZHCharTokenizer.from_pretrained(model_name, **kwargs)
self.model = SongNet(
self.tokenizer,
Expand All @@ -655,7 +670,7 @@ def __init__(
num_layers=self.args.num_layers,
smoothing_factor=self.args.smoothing_factor,
)
self.model.load_state_dict(torch.load(os.path.join(model_name, 'pytorch_model.bin')))
self.model.load_state_dict(torch.load(bin_path))

self.args.model_type = model_type
if model_name is None:
Expand Down
95 changes: 92 additions & 3 deletions textgen/language_modeling/songnet_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,18 @@
@description:
"""
import os
import sys
import random

import numpy as np
import torch
from loguru import logger
import shutil
import tarfile
import zipfile
import six
import requests

from tqdm.autonotebook import tqdm

PAD, UNK, BOS, EOS = '<pad>', '<unk>', '<bos>', '<eos>'
BOC, EOC = '<boc>', '<eoc>'
Expand All @@ -18,7 +25,14 @@
PS = ['<p-1>'] + ['<p' + str(i) + '>' for i in range(512)] # position
TS = ['<t-1>'] + ['<t' + str(i) + '>' for i in range(32)] # other types
PUNCS = {",", ".", "?", "!", ":", ",", "。", "?", "!", ":"}

# Mapping of pretrained SongNet model names to their download URLs.
# An empty URL means no hosted checkpoint is registered for that name yet.
# NOTE(review): the 'songnet-base-chinese' URL points to a pycorrector
# `convseq2seq_correction.tar.gz` release asset — this looks like a
# copy-paste placeholder from another project; confirm the real SongNet
# checkpoint URL before relying on auto-download.
PRETRAINED_MODELS = {
    'songnet-base-chinese':
        'https://github.com/shibing624/pycorrector/releases/download/0.4.5/convseq2seq_correction.tar.gz',
    'songnet-base-chinese-couplet': '',
    'songnet-base-chinese-poem': '',
    'songnet-base-chinese-songci': '',
}
# Local cache directory where downloaded model files are stored.
LOCAL_DIR = os.path.expanduser('~/.cache/torch/shibing624/')

class ZHCharTokenizer(object):
def __init__(self, vocab_file, specials=None):
Expand Down Expand Up @@ -68,7 +82,7 @@ def token2idx(self, x):
return self._token2idx.get(x, self.unk_idx)

def __repr__(self):
return f"ZHCharTokenizer<_token2idx size:{len(self._token2idx)}>"
return f"ZHCharTokenizer<vocab size:{len(self._token2idx)}>"

@classmethod
def from_pretrained(cls, model_dir, *init_inputs, **kwargs):
Expand Down Expand Up @@ -377,3 +391,78 @@ def preprocess_data(line, max_length, min_length):
if len(ys) < min_length:
return None
return xs_tpl, xs_seg, xs_pos, ys, ys_tpl, ys_seg, ys_pos


def http_get(url, path, extract: bool = True):
    """
    Download a URL to a given path on disk.

    The payload is streamed to ``path + "_part"`` first and renamed only on
    success, so an interrupted download never masquerades as a complete file.

    :param url: source URL to fetch.
    :param path: destination file path; parent directories are created as needed.
    :param extract: if True, attempt to unpack ``path`` as a tar/zip archive
        into its containing directory after the download finishes.
    :raises requests.HTTPError: if the server responds with a non-200 status.
    """
    parent_dir = os.path.dirname(path)
    if parent_dir != '':
        os.makedirs(parent_dir, exist_ok=True)

    # Stream so large model files are never held fully in memory; the
    # context manager closes the connection deterministically (the original
    # leaked the response object).
    with requests.get(url, stream=True) as req:
        if req.status_code != 200:
            print("Exception when trying to download {}. Response {}".format(url, req.status_code),
                  file=sys.stderr)
            # raise_for_status() raises HTTPError here, so nothing below
            # runs on failure (the original had an unreachable `return`).
            req.raise_for_status()

        download_filepath = path + "_part"
        content_length = req.headers.get('Content-Length')
        total = int(content_length) if content_length is not None else None
        progress = tqdm(unit="B", total=total, unit_scale=True)
        try:
            with open(download_filepath, "wb") as file_binary:
                for chunk in req.iter_content(chunk_size=1024):
                    if chunk:  # filter out keep-alive new chunks
                        progress.update(len(chunk))
                        file_binary.write(chunk)
        finally:
            # Always release the progress bar, even if the download or a
            # disk write fails partway (the original leaked it on error).
            progress.close()

    os.rename(download_filepath, path)

    if extract:
        data_dir = os.path.dirname(os.path.abspath(path))
        _extract_archive(path, data_dir, 'auto')


def _extract_archive(file_path, path='.', archive_format='auto'):
"""
Extracts an archive if it matches tar, tar.gz, tar.bz, or zip formats.
:param file_path: path to the archive file
:param path: path to extract the archive file
:param archive_format: Archive format to try for extracting the file.
Options are 'auto', 'tar', 'zip', and None.
'tar' includes tar, tar.gz, and tar.bz files.
The default 'auto' is ['tar', 'zip'].
None or an empty list will return no matches found.
:return: True if a match was found and an archive extraction was completed,
False otherwise.
"""
if archive_format is None:
return False
if archive_format == 'auto':
archive_format = ['tar', 'zip']
if isinstance(archive_format, six.string_types):
archive_format = [archive_format]

for archive_type in archive_format:
if archive_type == 'tar':
open_fn = tarfile.open
is_match_fn = tarfile.is_tarfile
if archive_type == 'zip':
open_fn = zipfile.ZipFile
is_match_fn = zipfile.is_zipfile

if is_match_fn(file_path):
with open_fn(file_path) as archive:
try:
archive.extractall(path)
except (tarfile.TarError, RuntimeError,
KeyboardInterrupt):
if os.path.exists(path):
if os.path.isfile(path):
os.remove(path)
else:
shutil.rmtree(path)
raise
return True
return False

0 comments on commit aeb9aa8

Please sign in to comment.