diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml
index ada9dd3ff..70cfaa138 100644
--- a/.github/workflows/unit-test.yml
+++ b/.github/workflows/unit-test.yml
@@ -26,6 +26,16 @@ jobs:
         python-version: "3.8"
         cache: 'pip'
         cache-dependency-path: 'environments/**_requires.txt'
+    - name: Increase swapfile
+      run: |
+        df -h
+        free -h
+        sudo swapoff -a
+        sudo fallocate -l 12G /swapfile
+        sudo chmod 600 /swapfile
+        sudo mkswap /swapfile
+        sudo swapon /swapfile
+        sudo swapon --show
     - name: Install dependencies
       run: |
         python -m pip install --upgrade pip
diff --git a/app.py b/app.py
index 9ccc8e255..82499c062 100644
--- a/app.py
+++ b/app.py
@@ -40,7 +40,7 @@ def convert_to_jsonl(df):
 
 @st.cache_data
 def get_diversity_model(lang):
-    model_key = prepare_model(lang, 'spacy')
+    model_key = prepare_model('spacy', lang=lang)
     diversity_model = MODEL_ZOO.get(model_key)
     return diversity_model
 
diff --git a/data_juicer/analysis/diversity_analysis.py b/data_juicer/analysis/diversity_analysis.py
index 8d7c479b4..6a6a0b260 100644
--- a/data_juicer/analysis/diversity_analysis.py
+++ b/data_juicer/analysis/diversity_analysis.py
@@ -4,7 +4,7 @@
 import spacy
 from loguru import logger
 
-from data_juicer.utils.model_utils import MODEL_ZOO, prepare_model
+from data_juicer.utils.model_utils import get_model, prepare_model
 
 
 # Modify from self_instruct, please refer to
@@ -110,8 +110,8 @@ def compute(self, lang_or_model=None, column_name='text'):
         # load diversity model
         lang_or_model = lang_or_model if lang_or_model else self.lang_or_model
         if isinstance(lang_or_model, str):
-            diversity_model = MODEL_ZOO.get(
-                prepare_model(lang_or_model, 'spacy'))
+            model_key = prepare_model('spacy', lang=lang_or_model)
+            diversity_model = get_model(model_key)
         else:
             diversity_model = lang_or_model
 
diff --git a/data_juicer/ops/filter/alphanumeric_filter.py b/data_juicer/ops/filter/alphanumeric_filter.py
index 68ce2560b..111c97ddc 100644
--- a/data_juicer/ops/filter/alphanumeric_filter.py
+++ b/data_juicer/ops/filter/alphanumeric_filter.py
@@ -51,7 +51,8 @@ def __init__(self,
         if tokenization:
             self.model_key = prepare_model(
                 model_type='huggingface',
-                model_key='EleutherAI/pythia-6.9b-deduped')
+                model_name_or_path='EleutherAI/pythia-6.9b-deduped',
+                return_model=False)
 
     def compute_stats(self, sample):
         if self.tokenization:
@@ -60,7 +61,7 @@ def compute_stats(self, sample):
             alpha_count = sum(
                 map(lambda char: 1 if char.isalpha() else 0,
                     sample[self.text_key]))
-            tokenizer = get_model(self.model_key, model_type='huggingface')
+            tokenizer = get_model(self.model_key)
             token_count = len(
                 get_words_from_document(
                     sample[self.text_key],
diff --git a/data_juicer/ops/filter/flagged_words_filter.py b/data_juicer/ops/filter/flagged_words_filter.py
index f2dcad826..c63036914 100644
--- a/data_juicer/ops/filter/flagged_words_filter.py
+++ b/data_juicer/ops/filter/flagged_words_filter.py
@@ -71,8 +71,8 @@ def __init__(self,
                 val for vals in self.FLAGGED_WORDS.values() for val in vals
             ]
         if tokenization:
-            self.model_key = prepare_model(lang=lang,
-                                           model_type='sentencepiece')
+            self.model_key = prepare_model(model_type='sentencepiece',
+                                           lang=lang)
 
     def compute_stats(self, sample, context=False):
         # check if it's computed already
@@ -84,9 +84,7 @@ def compute_stats(self, sample, context=False):
         if context and words_key in sample[Fields.context]:
             words = sample[Fields.context][words_key]
         else:
-            tokenizer = get_model(self.model_key,
-                                  lang=self.lang,
-                                  model_type='sentencepiece')
+            tokenizer = get_model(self.model_key)
             words = get_words_from_document(
                sample[self.text_key],
                token_func=tokenizer.encode_as_pieces if tokenizer else None)
diff --git a/data_juicer/ops/filter/image_text_matching_filter.py b/data_juicer/ops/filter/image_text_matching_filter.py
index 05b048114..688b78420 100644
--- a/data_juicer/ops/filter/image_text_matching_filter.py
+++ b/data_juicer/ops/filter/image_text_matching_filter.py
@@ -68,7 +68,8 @@ def __init__(self,
             raise ValueError(f'Keep strategy [{any_or_all}] is not supported. '
                              f'Can only be one of ["any", "all"].')
         self.any = (any_or_all == 'any')
-        self.model_key = prepare_model(model_type='hf_blip', model_key=hf_blip)
+        self.model_key = prepare_model(model_type='huggingface',
+                                       model_name_or_path=hf_blip)
         self.reduce_mode = reduce_mode
         self.horizontal_flip = horizontal_flip
         self.vertical_flip = vertical_flip
@@ -118,7 +119,7 @@ def compute_stats(self, sample, context=False):
                     truncation=True,
                     max_length=model.config.text_config.
                     max_position_embeddings,
-                    padding=True)
+                    padding=True).to(model.device)
 
                 outputs = model(**inputs)
                 itm_scores = outputs.itm_score.detach().cpu().softmax(
diff --git a/data_juicer/ops/filter/image_text_similarity_filter.py b/data_juicer/ops/filter/image_text_similarity_filter.py
index 3bc5ffef6..d67356cd1 100644
--- a/data_juicer/ops/filter/image_text_similarity_filter.py
+++ b/data_juicer/ops/filter/image_text_similarity_filter.py
@@ -69,7 +69,8 @@ def __init__(self,
             raise ValueError(f'Keep strategy [{any_or_all}] is not supported. '
                              f'Can only be one of ["any", "all"].')
         self.any = (any_or_all == 'any')
-        self.model_key = prepare_model(model_type='hf_clip', model_key=hf_clip)
+        self.model_key = prepare_model(model_type='huggingface',
+                                       model_name_or_path=hf_clip)
         self.reduce_mode = reduce_mode
         self.horizontal_flip = horizontal_flip
         self.vertical_flip = vertical_flip
@@ -118,7 +119,7 @@ def compute_stats(self, sample, context=False):
                     truncation=True,
                     max_length=model.config.text_config.
                     max_position_embeddings,
-                    padding=True)
+                    padding=True).to(model.device)
 
                 outputs = model(**inputs)
                 chunk_logits = outputs.logits_per_text.detach().cpu() / 100.0
diff --git a/data_juicer/ops/filter/language_id_score_filter.py b/data_juicer/ops/filter/language_id_score_filter.py
index cae60f99b..6b71cf112 100644
--- a/data_juicer/ops/filter/language_id_score_filter.py
+++ b/data_juicer/ops/filter/language_id_score_filter.py
@@ -54,7 +54,7 @@ def compute_stats(self, sample):
             return sample
 
         text = sample[self.text_key].lower().replace('\n', ' ')
-        ft_model = get_model(self.model_key, model_type='fasttext')
+        ft_model = get_model(self.model_key)
         if ft_model is None:
             err_msg = 'Model not loaded. Please retry later.'
            logger.error(err_msg)
diff --git a/data_juicer/ops/filter/perplexity_filter.py b/data_juicer/ops/filter/perplexity_filter.py
index d125d548c..64408b872 100644
--- a/data_juicer/ops/filter/perplexity_filter.py
+++ b/data_juicer/ops/filter/perplexity_filter.py
@@ -42,9 +42,9 @@ def __init__(self,
         super().__init__(*args, **kwargs)
         self.max_ppl = max_ppl
         self.lang = lang
-        self.sp_model_key = prepare_model(lang=lang,
-                                          model_type='sentencepiece')
-        self.kl_model_key = prepare_model(lang=lang, model_type='kenlm')
+        self.sp_model_key = prepare_model(model_type='sentencepiece',
+                                          lang=lang)
+        self.kl_model_key = prepare_model(model_type='kenlm', lang=lang)
 
     def compute_stats(self, sample, context=False):
         # check if it's computed already
@@ -56,8 +56,7 @@ def compute_stats(self, sample, context=False):
         if context and words_key in sample[Fields.context]:
             words = sample[Fields.context][words_key]
         else:
-            tokenizer = get_model(self.sp_model_key, self.lang,
-                                  'sentencepiece')
+            tokenizer = get_model(self.sp_model_key)
             words = get_words_from_document(
                 sample[self.text_key],
                 token_func=tokenizer.encode_as_pieces if tokenizer else None)
@@ -66,7 +65,7 @@ def compute_stats(self, sample, context=False):
         text = ' '.join(words)
         # compute perplexity
         logits, length = 0, 0
-        kenlm_model = get_model(self.kl_model_key, self.lang, 'kenlm')
+        kenlm_model = get_model(self.kl_model_key)
         for line in text.splitlines():
             logits += kenlm_model.score(line)
             length += (len(line.split()) + 1)
diff --git a/data_juicer/ops/filter/phrase_grounding_recall_filter.py b/data_juicer/ops/filter/phrase_grounding_recall_filter.py
index 45d381b51..f19f21843 100644
--- a/data_juicer/ops/filter/phrase_grounding_recall_filter.py
+++ b/data_juicer/ops/filter/phrase_grounding_recall_filter.py
@@ -129,9 +129,8 @@ def __init__(self,
             raise ValueError(f'Keep strategy [{any_or_all}] is not supported. '
                             f'Can only be one of ["any", "all"].')
         self.any = (any_or_all == 'any')
-        self.model_type = 'hf_owlvit'
-        self.model_key = prepare_model(model_type=self.model_type,
-                                       model_key=hf_owlvit)
+        self.model_key = prepare_model(model_type='huggingface',
+                                       model_name_or_path=hf_owlvit)
         self.reduce_mode = reduce_mode
         self.horizontal_flip = horizontal_flip
         self.vertical_flip = vertical_flip
@@ -164,8 +163,7 @@ def compute_stats(self, sample, context=False):
         text = sample[self.text_key]
         offset = 0
         recalls = []
-        model, processor = get_model(self.model_key,
-                                     model_type=self.model_type)
+        model, processor = get_model(self.model_key)
 
         for chunk in text.split(SpecialTokens.eoc):
             count = chunk.count(SpecialTokens.image)
@@ -195,7 +193,7 @@ def compute_stats(self, sample, context=False):
                 images=images_this_chunk,
                 return_tensors='pt',
                 padding=True,
-                truncation=True)
+                truncation=True).to(model.device)
 
             with torch.no_grad():
                 outputs = model(**inputs)
diff --git a/data_juicer/ops/filter/stopwords_filter.py b/data_juicer/ops/filter/stopwords_filter.py
index 3d73f752f..f61542e13 100644
--- a/data_juicer/ops/filter/stopwords_filter.py
+++ b/data_juicer/ops/filter/stopwords_filter.py
@@ -61,7 +61,6 @@ def __init__(self,
         self.words_aug_group_sizes = words_aug_group_sizes
         self.words_aug_join_char = words_aug_join_char
         self.model_key = None
-        self.lang = lang
 
         self.STOPWORDS = load_words_asset(words_dir=stopwords_dir,
                                           words_type='stopwords')
@@ -70,8 +69,8 @@ def __init__(self,
                 val for vals in self.STOPWORDS.values() for val in vals
             ]
         if tokenization:
-            self.model_key = prepare_model(lang=lang,
-                                           model_type='sentencepiece')
+            self.model_key = prepare_model(model_type='sentencepiece',
+                                           lang=lang)
 
     def compute_stats(self, sample, context=False):
         # check if it's computed already
@@ -83,9 +82,7 @@ def compute_stats(self, sample, context=False):
         if context and words_key in sample[Fields.context]:
             words = sample[Fields.context][words_key]
         else:
-            tokenizer = get_model(self.model_key,
-                                  lang=self.lang,
-                                  model_type='sentencepiece')
+            tokenizer = get_model(self.model_key)
             words = get_words_from_document(
                 sample[self.text_key],
                 token_func=tokenizer.encode_as_pieces if tokenizer else None)
diff --git a/data_juicer/ops/filter/token_num_filter.py b/data_juicer/ops/filter/token_num_filter.py
index 11ebd0ce5..342e77cbd 100644
--- a/data_juicer/ops/filter/token_num_filter.py
+++ b/data_juicer/ops/filter/token_num_filter.py
@@ -44,14 +44,15 @@ def __init__(self,
         self.max_num = max_num
         self.hf_tokenizer = hf_tokenizer
         self.model_key = prepare_model(model_type='huggingface',
-                                       model_key=hf_tokenizer)
+                                       model_name_or_path=hf_tokenizer,
+                                       return_model=False)
 
     def compute_stats(self, sample):
         # check if it's computed already
         if StatsKeys.num_token in sample[Fields.stats]:
             return sample
 
-        tokenizer = get_model(self.model_key, model_type='huggingface')
+        tokenizer = get_model(self.model_key)
         tokens = get_words_from_document(
             sample[self.text_key],
             token_func=tokenizer.tokenize if tokenizer else None)
diff --git a/data_juicer/ops/filter/word_num_filter.py b/data_juicer/ops/filter/word_num_filter.py
index cc740c9d0..08ec90ee8 100644
--- a/data_juicer/ops/filter/word_num_filter.py
+++ b/data_juicer/ops/filter/word_num_filter.py
@@ -51,8 +51,8 @@ def __init__(self,
         self.lang = lang
 
         if tokenization:
-            self.model_key = prepare_model(lang=lang,
-                                           model_type='sentencepiece')
+            self.model_key = prepare_model(model_type='sentencepiece',
+                                           lang=lang)
 
     def compute_stats(self, sample, context=False):
         # check if it's computed already
@@ -63,9 +63,7 @@ def compute_stats(self, sample, context=False):
         if context and words_key in sample[Fields.context]:
             words = sample[Fields.context][words_key]
         else:
-            tokenizer = get_model(self.model_key,
-                                  lang=self.lang,
-                                  model_type='sentencepiece')
+            tokenizer = get_model(self.model_key)
             words = get_words_from_document(
                 sample[self.text_key],
                 token_func=tokenizer.encode_as_pieces if tokenizer else None)
diff --git a/data_juicer/ops/filter/word_repetition_filter.py b/data_juicer/ops/filter/word_repetition_filter.py
index 126895a1a..187c23e06 100644
--- a/data_juicer/ops/filter/word_repetition_filter.py
+++ b/data_juicer/ops/filter/word_repetition_filter.py
@@ -56,8 +56,8 @@ def __init__(self,
         self.lang = lang
 
         if tokenization:
-            self.model_key = prepare_model(lang=lang,
-                                           model_type='sentencepiece')
+            self.model_key = prepare_model(model_type='sentencepiece',
+                                           lang=lang)
 
     def compute_stats(self, sample, context=False):
         # check if it's computed already
@@ -69,9 +69,7 @@ def compute_stats(self, sample, context=False):
         if context and words_key in sample[Fields.context]:
             words = sample[Fields.context][words_key]
         else:
-            tokenizer = get_model(self.model_key,
-                                  lang=self.lang,
-                                  model_type='sentencepiece')
+            tokenizer = get_model(self.model_key)
             words = get_words_from_document(
                 sample[self.text_key],
                 token_func=tokenizer.encode_as_pieces if tokenizer else None)
diff --git a/data_juicer/ops/mapper/generate_caption_mapper.py b/data_juicer/ops/mapper/generate_caption_mapper.py
index 056ebe20c..032743604 100644
--- a/data_juicer/ops/mapper/generate_caption_mapper.py
+++ b/data_juicer/ops/mapper/generate_caption_mapper.py
@@ -87,13 +87,8 @@ def __init__(self,
                 f'Can only be one of '
                 f'["random_any", "similar_one_simhash", "all"].')
 
-        self.model_key = prepare_model(model_type='hf_blip',
-                                       model_key=hf_blip2,
-                                       usage='conditional_generation')
-        model, img_processor = get_model(model_key=self.model_key,
-                                         usage='conditional_generation')
-        self.model_in_ctx = model
-        self.img_processor_in_ctx = img_processor
+        self.model_key = prepare_model(model_type='huggingface',
+                                       model_name_or_path=hf_blip2)
         self.caption_num = caption_num
         self.keep_candidate_mode = keep_candidate_mode
         self.keep_original_sample = keep_original_sample
@@ -151,6 +146,8 @@ def _process_single_sample(self, ori_sample):
         # the generated text will be placed following each SpecialTokens.img
         # and the original special tokens are kept in an order-preserving way.
 
+        model, processor = get_model(self.model_key)
+
         # do generation for each image chunk by chunk
         for chunk in ori_sample[self.text_key].split(SpecialTokens.eoc):
             # skip empty chunks or contents after the last eoc token
@@ -182,13 +179,13 @@ def _process_single_sample(self, ori_sample):
             else:
                 prompt_texts = None
 
-            inputs = self.img_processor_in_ctx(images=image_chunk,
-                                               text=prompt_texts,
-                                               return_tensors='pt')
+            inputs = processor(images=image_chunk,
+                               text=prompt_texts,
+                               return_tensors='pt').to(model.device)
             for i in range(self.caption_num):
-                generated_ids = self.model_in_ctx.generate(**inputs,
-                                                           do_sample=True)
-                generated_text = self.img_processor_in_ctx.batch_decode(
+                generated_ids = model.generate(**inputs,
+                                               do_sample=True).to(model.device)
+                generated_text = processor.batch_decode(
                     generated_ids, skip_special_tokens=True)
                 generated_text_candidates_single_chunk[i] = generated_text
diff --git a/data_juicer/ops/mapper/remove_words_with_incorrect_substrings_mapper.py b/data_juicer/ops/mapper/remove_words_with_incorrect_substrings_mapper.py
index 4835d4de5..605a75e3b 100644
--- a/data_juicer/ops/mapper/remove_words_with_incorrect_substrings_mapper.py
+++ b/data_juicer/ops/mapper/remove_words_with_incorrect_substrings_mapper.py
@@ -40,8 +40,8 @@ def __init__(self,
         self.substrings = substrings
         self.lang = lang
         if tokenization:
-            self.model_key = prepare_model(lang=lang,
-                                           model_type='sentencepiece')
+            self.model_key = prepare_model(model_type='sentencepiece',
+                                           lang=lang)
 
     def should_keep_word_with_incorrect_substrings(self, word, substrings):
         word = strip(word, SPECIAL_CHARACTERS)
@@ -50,9 +50,7 @@ def should_keep_word_with_incorrect_substrings(self, word, substrings):
 
     def process(self, sample):
         if self.tokenization:
-            tokenizer = get_model(self.model_key,
-                                  lang=self.lang,
-                                  model_type='sentencepiece')
+            tokenizer = get_model(self.model_key)
             sentences = get_words_from_document(
                 sample[self.text_key],
                 token_func=tokenizer.encode_as_pieces if tokenizer else None)
diff --git a/data_juicer/ops/mapper/sentence_split_mapper.py b/data_juicer/ops/mapper/sentence_split_mapper.py
index 12e1372c8..522c01300 100644
--- a/data_juicer/ops/mapper/sentence_split_mapper.py
+++ b/data_juicer/ops/mapper/sentence_split_mapper.py
@@ -24,13 +24,11 @@ def __init__(self, lang: str = 'en', *args, **kwargs):
         """
         super().__init__(*args, **kwargs)
         self.lang = lang
-        self.model_key = prepare_model(lang=lang, model_type='nltk')
+        self.model_key = prepare_model(model_type='nltk', lang=lang)
 
     def process(self, sample):
 
-        nltk_model = get_model(self.model_key,
-                               lang=self.lang,
-                               model_type='nltk')
+        nltk_model = get_model(self.model_key)
         sample[self.text_key] = get_sentences_from_document(
             sample[self.text_key],
             model_func=nltk_model.tokenize if nltk_model else None)
diff --git a/data_juicer/utils/model_utils.py b/data_juicer/utils/model_utils.py
index 6e7497268..b989649f6 100644
--- a/data_juicer/utils/model_utils.py
+++ b/data_juicer/utils/model_utils.py
@@ -1,82 +1,92 @@
+import fnmatch
 import os
+from functools import partial
 
 import wget
 from loguru import logger
 
-from .cache_utils import DATA_JUICER_MODELS_CACHE
+from .cache_utils import DATA_JUICER_MODELS_CACHE as DJMC
 
-# Default directory to store models
-MODEL_PATH = DATA_JUICER_MODELS_CACHE
+MODEL_ZOO = {}
+
+# Default cached models links for downloading
+MODEL_LINKS = 'https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/' \
+              'data_juicer/models/'
 
-# Default backup cached models links for downloading
+# Backup cached models links for downloading
 BACKUP_MODEL_LINKS = {
     # language identification model from fasttext
     'lid.176.bin':
     'https://dl.fbaipublicfiles.com/fasttext/supervised-models/',
 
     # tokenizer and language model for English from sentencepiece and KenLM
-    '%s.sp.model':
+    '*.sp.model':
     'https://huggingface.co/edugp/kenlm/resolve/main/wikipedia/',
-    '%s.arpa.bin':
+    '*.arpa.bin':
    'https://huggingface.co/edugp/kenlm/resolve/main/wikipedia/',
 
     # sentence split model from nltk punkt
-    'punkt.%s.pickle':
+    'punkt.*.pickle':
     'https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/'
     'data_juicer/models/'
 }
 
-# Default cached models links for downloading
-MODEL_LINKS = 'https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/' \
-              'data_juicer/models/'
 
-MODEL_ZOO = {}
+def get_backup_model_link(model_name):
+    for pattern, url in BACKUP_MODEL_LINKS.items():
+        if fnmatch.fnmatch(model_name, pattern):
+            return url
+    return None
 
 
-def check_model(model_name, args=(), force=False):
+def check_model(model_name, force=False):
     """
-    Check whether a model exists in MODEL_PATH. If exists, return its full path
+    Check whether a model exists in DATA_JUICER_MODELS_CACHE.
+    If exists, return its full path.
     Else, download it from cached models links.
 
     :param model_name: a specified model name
-    :param args: optional extra args of model.
     :param force: Whether to download model forcefully or not, Sometimes
         the model file maybe incomplete for some reason, so need to
         download again forcefully.
     """
-    if not os.path.exists(MODEL_PATH):
-        os.makedirs(MODEL_PATH)
+    # check for local model
+    if os.path.exists(model_name):
+        return model_name
+
+    if not os.path.exists(DJMC):
+        os.makedirs(DJMC)
 
     # check if the specified model exists. If it does not exist, download it
-    true_model_name = model_name % args
-    mdp = os.path.join(MODEL_PATH, true_model_name)
+    cached_model_path = os.path.join(DJMC, model_name)
     if force:
-        if os.path.exists(mdp):
-            os.remove(mdp)
+        if os.path.exists(cached_model_path):
+            os.remove(cached_model_path)
             logger.info(
-                f'Model [{true_model_name}] invalid, force to downloading...')
+                f'Model [{cached_model_path}] invalid, force to downloading...'
+            )
     else:
         logger.info(
-            f'Model [{true_model_name}] not found . Downloading...')
+            f'Model [{cached_model_path}] not found. Downloading...')
 
     try:
-        model_link = os.path.join(MODEL_LINKS, true_model_name)
-        wget.download(model_link, mdp, bar=None)
+        model_link = os.path.join(MODEL_LINKS, model_name)
+        wget.download(model_link, cached_model_path, bar=None)
     except:  # noqa: E722
         try:
             backup_model_link = os.path.join(
-                BACKUP_MODEL_LINKS[model_name], true_model_name)
-            wget.download(backup_model_link, mdp, bar=None)
+                get_backup_model_link(model_name), model_name)
+            wget.download(backup_model_link, cached_model_path, bar=None)
         except:  # noqa: E722
             logger.error(
-                f'Downloading model [{true_model_name}] error. '
-                f'Please retry later or download it into {MODEL_PATH} '
+                f'Downloading model [{model_name}] error. '
+                f'Please retry later or download it into {DJMC} '
                f'manually from {model_link} or {backup_model_link} ')
             exit(1)
-    return mdp
+    return cached_model_path
 
 
-def prepare_fasttext_model(model_name):
+def prepare_fasttext_model(model_name='lid.176.bin'):
     """
     Prepare and load a fasttext model.
 
@@ -84,6 +94,7 @@ def prepare_fasttext_model(model_name):
     :return: model instance.
""" import fasttext + logger.info('Loading fasttext language identification model...') try: ft_model = fasttext.load_model(check_model(model_name)) @@ -92,7 +103,7 @@ def prepare_fasttext_model(model_name): return ft_model -def prepare_sentencepiece_model(model_name, lang): +def prepare_sentencepiece_model(lang, name_pattern='{}.sp.model'): """ Prepare and load a sentencepiece model. @@ -101,16 +112,19 @@ def prepare_sentencepiece_model(model_name, lang): :return: model instance. """ import sentencepiece + + model_name = name_pattern.format(lang) + logger.info('Loading sentencepiece model...') sentencepiece_model = sentencepiece.SentencePieceProcessor() try: - sentencepiece_model.load(check_model(model_name, lang)) + sentencepiece_model.load(check_model(model_name)) except: # noqa: E722 - sentencepiece_model.load(check_model(model_name, lang, force=True)) + sentencepiece_model.load(check_model(model_name, force=True)) return sentencepiece_model -def prepare_kenlm_model(model_name, lang): +def prepare_kenlm_model(lang, name_pattern='{}.arpa.bin'): """ Prepare and load a kenlm model. @@ -119,15 +133,18 @@ def prepare_kenlm_model(model_name, lang): :return: model instance. """ import kenlm + + model_name = name_pattern.format(lang) + logger.info('Loading kenlm language model...') try: - kenlm_model = kenlm.Model(check_model(model_name, lang)) + kenlm_model = kenlm.Model(check_model(model_name)) except: # noqa: E722 - kenlm_model = kenlm.Model(check_model(model_name, lang, force=True)) + kenlm_model = kenlm.Model(check_model(model_name, force=True)) return kenlm_model -def prepare_nltk_model(model_name, lang): +def prepare_nltk_model(lang, name_pattern='punkt.{}.pickle'): """ Prepare and load a nltk punkt model. @@ -135,6 +152,7 @@ def prepare_nltk_model(model_name, lang): :param lang: language to render model name :return: model instance. """ + from nltk.data import load nltk_to_punkt = { 'en': 'english', @@ -145,114 +163,74 @@ def prepare_nltk_model(model_name, lang): assert lang in nltk_to_punkt.keys( ), 'lang must be one of the following: {}'.format( list(nltk_to_punkt.keys())) + model_name = name_pattern.format(nltk_to_punkt[lang]) - from nltk.data import load logger.info('Loading nltk punkt split model...') try: - nltk_model = load(check_model(model_name, nltk_to_punkt[lang])) + nltk_model = load(check_model(model_name)) except: # noqa: E722 - nltk_model = load( - check_model(model_name, nltk_to_punkt[lang], force=True)) + nltk_model = load(check_model(model_name, force=True)) return nltk_model -def prepare_huggingface_tokenizer(tokenizer_name): - """ - Prepare and load a tokenizer from HuggingFace. - - :param tokenizer_name: input tokenizer name - :return: a tokenizer instance. - """ - from transformers import AutoTokenizer - logger.info(f'Loading tokenizer {tokenizer_name} from HuggingFace...') - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, - trust_remote_code=True) - return tokenizer - - -def prepare_huggingface_clip(clip_name): - """ - Prepare and load a clip and processor from HuggingFace. - - :param clip_name: input clip name - :return: a pair of clip instance and processor instance. 
- """ - from transformers import CLIPModel, CLIPProcessor - - model = CLIPModel.from_pretrained(clip_name) - processor = CLIPProcessor.from_pretrained(clip_name) - logger.info(f'Loading clip and processor {clip_name} from HuggingFace...') - - return model, processor - - -def prepare_huggingface_blip( - blip_name, - usage=None, -): +def prepare_huggingface_model(model_name_or_path, + return_model=True, + trust_remote_code=False): """ - Prepare and load a blip and processor from HuggingFace. + Prepare and load a HuggingFace model with the correspoding processor. - :param blip_name: input blip name in huggingface hup - :param usage: a string indicating the type for processor and model wrapper - :return: a pair of blip instance and processor instance. + :param model_name: model name or path + :param return_model: return model or not + :param trust_remote_code: passed to transformers + :return: a tuple (model, input processor) if `return_model` is True; + otherwise, only the processor is returned. """ - model = None - processor = None - if usage is None: - usage = 'image_text_retrieval' - if 'blip2' in blip_name: - if usage == 'conditional_generation': - from transformers import (Blip2ForConditionalGeneration, - Blip2Processor) - model = Blip2ForConditionalGeneration.from_pretrained(blip_name) - processor = Blip2Processor.from_pretrained(blip_name) - elif 'blip' in blip_name: - if usage == 'image_text_retrieval': - from transformers import BlipForImageTextRetrieval, BlipProcessor - model = BlipForImageTextRetrieval.from_pretrained(blip_name) - processor = BlipProcessor.from_pretrained(blip_name) - - if model is None or processor is None: - raise NotImplementedError('Unsupported model preparing behavior for ' - f'your given blip_name={blip_name} and ' - f'usage={usage}') - - logger.info(f'Loaded blip and processor {blip_name} from HuggingFace...') - return model, processor - - -def prepare_huggingface_owlvit(owlvit_name): + import transformers + from transformers import (AutoConfig, AutoImageProcessor, AutoProcessor, + AutoTokenizer) + from transformers.models.auto.image_processing_auto import \ + IMAGE_PROCESSOR_MAPPING_NAMES + from transformers.models.auto.processing_auto import \ + PROCESSOR_MAPPING_NAMES + from transformers.models.auto.tokenization_auto import \ + TOKENIZER_MAPPING_NAMES + + config = AutoConfig.from_pretrained(model_name_or_path) + # TODO: What happens when there are more than one? + arch = config.architectures[0] + model_class = getattr(transformers, arch) + model_type = config.model_type + if model_type in PROCESSOR_MAPPING_NAMES: + processor = AutoProcessor.from_pretrained( + model_name_or_path, trust_remote_code=trust_remote_code) + elif model_type in IMAGE_PROCESSOR_MAPPING_NAMES: + processor = AutoImageProcessor.from_pretrained( + model_name_or_path, trust_remote_code=trust_remote_code) + elif model_type in TOKENIZER_MAPPING_NAMES: + processor = AutoTokenizer.from_pretrained( + model_name_or_path, trust_remote_code=trust_remote_code) + else: + processor = None + + if return_model: + model = model_class.from_pretrained(model_name_or_path) + return (model, processor) if return_model else processor + + +def prepare_spacy_model(lang, name_pattern='{}_core_web_md-3.5.0'): """ - Prepare and load an OwlViT and processor from HuggingFace. + Prepare spacy model for specific language. - :param owlvit_name: input OwlViT name - :return: a pair of OwlViT instance and processor instance. 
- """ - from transformers import OwlViTForObjectDetection, OwlViTProcessor - - model = OwlViTForObjectDetection.from_pretrained(owlvit_name) - processor = OwlViTProcessor.from_pretrained(owlvit_name) - logger.info(f'Loading OwlViT and processor {owlvit_name} from ' - f'HuggingFace...') - - return (model, processor) - - -def prepare_diversity_model(model_name, lang): - """ - Prepare diversity model for specific language. - - :param model_name: the model name to be loaded. - :param lang: language of diversity model. Should be one of ["zh", + :param lang: language of sapcy model. Should be one of ["zh", "en"] - :return: corresponding diversity model + :return: corresponding spacy model """ import spacy + assert lang in ['zh', 'en'], 'Diversity only support zh and en' - model_name = model_name % lang + model_name = name_pattern.format(lang) logger.info(f'Loading spacy model [{model_name}]...') - compressed_model = '%s.zip' % model_name + compressed_model = '{}.zip'.format(model_name) # decompress the compressed model if it's not decompressed def decompress_model(compressed_model_path): @@ -262,7 +240,7 @@ def decompress_model(compressed_model_path): return decompressed_model_path import zipfile with zipfile.ZipFile(compressed_model_path) as zf: - zf.extractall(MODEL_PATH) + zf.extractall(DJMC) return decompressed_model_path try: @@ -274,65 +252,34 @@ def decompress_model(compressed_model_path): return diversity_model -def prepare_model(lang='en', - model_type='sentencepiece', - model_key=None, - usage=None): - """ - Prepare and load a model or a tokenizer from MODEL_ZOO. - - :param lang: which lang model to load - :param model_type: model or tokenizer type - :param model_key: tokenizer name, only used when - prepare HuggingFace tokenizer - :param usage: detailed usage to indicate some specific type - of the model or the tokenizer - :return: a model or tokenizer instance - """ +MODEL_FUNCTION_MAPPING = { + 'fasttext': prepare_fasttext_model, + 'sentencepiece': prepare_sentencepiece_model, + 'kenlm': prepare_kenlm_model, + 'nltk': prepare_nltk_model, + 'huggingface': prepare_huggingface_model, + 'spacy': prepare_spacy_model, +} - type_to_name = { - 'fasttext': ('lid.176.bin', prepare_fasttext_model), - 'sentencepiece': ('%s.sp.model', prepare_sentencepiece_model), - 'kenlm': ('%s.arpa.bin', prepare_kenlm_model), - 'nltk': ('punkt.%s.pickle', prepare_nltk_model), - 'huggingface': ('%s', prepare_huggingface_tokenizer), - 'hf_clip': ('%s', prepare_huggingface_clip), - 'hf_blip': ('%s', prepare_huggingface_blip), - 'hf_owlvit': ('%s', prepare_huggingface_owlvit), - 'spacy': ('%s_core_web_md-3.5.0', prepare_diversity_model), - } - assert model_type in type_to_name.keys( - ), 'model_type must be one of the following: {}'.format( - list(type_to_name.keys())) - if model_key is None: - model_key = model_type + '_' + lang - if model_key not in MODEL_ZOO.keys(): - model_name, model_func = type_to_name[model_type] - if model_type in ['fasttext']: - MODEL_ZOO[model_key] = model_func(model_name) - elif model_type in ['huggingface', 'hf_clip']: - MODEL_ZOO[model_key] = model_func(model_key) - elif model_type in ['hf_blip']: - MODEL_ZOO[model_key] = model_func(model_key, usage) - elif model_type == 'hf_owlvit': - MODEL_ZOO[model_key] = model_func(model_key) - else: - MODEL_ZOO[model_key] = model_func(model_name, lang) +def prepare_model(model_type, **model_kwargs): + assert (model_type in MODEL_FUNCTION_MAPPING.keys() + ), 'model_type must be one of the following: {}'.format( + 
+                list(MODEL_FUNCTION_MAPPING.keys()))
+    global MODEL_ZOO
+    model_func = MODEL_FUNCTION_MAPPING[model_type]
+    model_key = partial(model_func, **model_kwargs)
+    # always instantiate once for possible caching
+    model_objects = model_key()
+    MODEL_ZOO[model_key] = model_objects
     return model_key
 
 
-def get_model(model_key, lang='en', model_type='sentencepiece', usage=None):
-    """
-    Get a model or a tokenizer from MODEL_ZOO.
-
-    :param model_key: name of the model or tokenzier
-    """
+def get_model(model_key=None):
+    global MODEL_ZOO
     if model_key is None:
+        logger.warning('Please specify model_key to get models')
         return None
     if model_key not in MODEL_ZOO:
-        prepare_model(lang=lang,
-                      model_type=model_type,
-                      model_key=model_key,
-                      usage=usage)
-    return MODEL_ZOO.get(model_key, None)
+        MODEL_ZOO[model_key] = model_key()
+    return MODEL_ZOO[model_key]
diff --git a/demos/tool_quality_classifier/quality_classifier/qc_utils.py b/demos/tool_quality_classifier/quality_classifier/qc_utils.py
index 86187acbd..00174a3e2 100644
--- a/demos/tool_quality_classifier/quality_classifier/qc_utils.py
+++ b/demos/tool_quality_classifier/quality_classifier/qc_utils.py
@@ -127,13 +127,7 @@ def get_keep_method_udf(keep_method):
 
 
 def tokenize_dataset(ds, tokenizer):
-    if os.path.exists(tokenizer):
-        # if it's a local model
-        tkn = spm.SentencePieceProcessor()
-        tkn.load(tokenizer)
-    else:
-        # else, try to load it from our remote model list
-        tkn = prepare_sentencepiece_model(tokenizer, ())
+    tkn = prepare_sentencepiece_model('', tokenizer)
     tokenizer_udf = udf(lambda text: tkn.encode_as_pieces(text),
                         ArrayType(StringType()))
     logger.info('Tokenize texts using specific tokenizer...')
diff --git a/tests/ops/filter/test_image_size_filter.py b/tests/ops/filter/test_image_size_filter.py
index b9dc8fa01..46cfff62f 100644
--- a/tests/ops/filter/test_image_size_filter.py
+++ b/tests/ops/filter/test_image_size_filter.py
@@ -109,8 +109,7 @@ def test_all(self):
         }]
         tgt_list = []
         dataset = Dataset.from_list(ds_list)
-        op = ImageSizeFilter(min_size="120kb", max_size="180KB",
-                             any_or_all='all')
+        op = ImageSizeFilter(min_size="120kb", max_size="180KB", any_or_all='all')
         self._run_image_size_filter(dataset, tgt_list, op)
 
 
diff --git a/tools/quality_classifier/qc_utils.py b/tools/quality_classifier/qc_utils.py
index 448ea2368..eca8845d8 100644
--- a/tools/quality_classifier/qc_utils.py
+++ b/tools/quality_classifier/qc_utils.py
@@ -2,7 +2,6 @@
 import zipfile
 
 import numpy as np
-import sentencepiece as spm
 import wget
 from loguru import logger
 from pyspark.ml import Pipeline, PipelineModel
@@ -193,13 +192,7 @@ def tokenize_dataset(ds, tokenizer):
     :return: a dataset with an extra column "words" that stores the tokenized
         texts
     """
-    if os.path.exists(tokenizer):
-        # if it's a local model
-        tkn = spm.SentencePieceProcessor()
-        tkn.load(tokenizer)
-    else:
-        # else, try to load it from our remote model list
-        tkn = prepare_sentencepiece_model(tokenizer, ())
+    tkn = prepare_sentencepiece_model('', tokenizer)
     # create a PySpark udf to tokenize the dataset
     tokenizer_udf = udf(lambda text: tkn.encode_as_pieces(text),
                         ArrayType(StringType()))
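
For reference, a minimal sketch of how the refactored helpers in data_juicer/utils/model_utils.py are used after this patch. The argument values are taken from the ops changed above; everything else is illustrative only, not part of the patch:

# prepare_model() now takes the model type first and forwards the remaining
# keyword arguments to the matching prepare_* function; it eagerly loads the
# model once and returns a key (a functools.partial) cached in MODEL_ZOO.
from data_juicer.utils.model_utils import get_model, prepare_model

model_key = prepare_model(model_type='huggingface',
                          model_name_or_path='EleutherAI/pythia-6.9b-deduped',
                          return_model=False)
tokenizer = get_model(model_key)  # looks the cached object up by its key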