Refactor/model management (conservative change) (#196)
* unify prepare_*
* increase swap
drcege authored Jan 30, 2024
1 parent 1e512c5 commit 543f36c
Showing 21 changed files with 196 additions and 268 deletions.
10 changes: 10 additions & 0 deletions .github/workflows/unit-test.yml
@@ -26,6 +26,16 @@ jobs:
         python-version: "3.8"
         cache: 'pip'
         cache-dependency-path: 'environments/**_requires.txt'
+    - name: Increase swapfile
+      run: |
+        df -h
+        free -h
+        sudo swapoff -a
+        sudo fallocate -l 12G /swapfile
+        sudo chmod 600 /swapfile
+        sudo mkswap /swapfile
+        sudo swapon /swapfile
+        sudo swapon --show
     - name: Install dependencies
       run: |
         python -m pip install --upgrade pip
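The added step follows the standard recipe for enlarging swap on a GitHub-hosted runner (swapoff, fallocate, mkswap, swapon, then swapon --show to verify). Presumably the extra 12 GB of swap is what lets the unit tests load the larger Hugging Face models touched below without the runner being OOM-killed; the commit message itself only says "increase swap".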
2 changes: 1 addition & 1 deletion app.py
@@ -40,7 +40,7 @@ def convert_to_jsonl(df):
 
 @st.cache_data
 def get_diversity_model(lang):
-    model_key = prepare_model(lang, 'spacy')
+    model_key = prepare_model('spacy', lang=lang)
     diversity_model = MODEL_ZOO.get(model_key)
     return diversity_model
 
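As a rough sketch of the unified calling convention (assuming a checkout with this change): prepare_model now takes the model type as its first argument, passes model-specific options as keywords, and returns a key that is later resolved to the loaded model.

from data_juicer.utils.model_utils import MODEL_ZOO, prepare_model

# 'spacy' is the model type; the language is now an explicit keyword.
model_key = prepare_model('spacy', lang='en')
diversity_model = MODEL_ZOO.get(model_key)  # app.py still reads the zoo directly

doc = diversity_model('Data-Juicer unifies model preparation.')
print([token.pos_ for token in doc])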
6 changes: 3 additions & 3 deletions data_juicer/analysis/diversity_analysis.py
@@ -4,7 +4,7 @@
 import spacy
 from loguru import logger
 
-from data_juicer.utils.model_utils import MODEL_ZOO, prepare_model
+from data_juicer.utils.model_utils import get_model, prepare_model
 
 
 # Modify from self_instruct, please refer to
@@ -110,8 +110,8 @@ def compute(self, lang_or_model=None, column_name='text'):
         # load diversity model
         lang_or_model = lang_or_model if lang_or_model else self.lang_or_model
         if isinstance(lang_or_model, str):
-            diversity_model = MODEL_ZOO.get(
-                prepare_model(lang_or_model, 'spacy'))
+            model_key = prepare_model('spacy', lang=lang_or_model)
+            diversity_model = get_model(model_key)
         else:
             diversity_model = lang_or_model
 
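Elsewhere the commit swaps direct registry reads for the get_model accessor; a minimal sketch under the same assumption:

from data_juicer.utils.model_utils import get_model, prepare_model

model_key = prepare_model('spacy', lang='en')
diversity_model = get_model(model_key)  # replaces MODEL_ZOO.get(model_key)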
5 changes: 3 additions & 2 deletions data_juicer/ops/filter/alphanumeric_filter.py
@@ -51,7 +51,8 @@ def __init__(self,
         if tokenization:
             self.model_key = prepare_model(
                 model_type='huggingface',
-                model_key='EleutherAI/pythia-6.9b-deduped')
+                model_name_or_path='EleutherAI/pythia-6.9b-deduped',
+                return_model=False)
 
     def compute_stats(self, sample):
         if self.tokenization:
@@ -60,7 +61,7 @@ def compute_stats(self, sample):
             alpha_count = sum(
                 map(lambda char: 1
                     if char.isalpha() else 0, sample[self.text_key]))
-            tokenizer = get_model(self.model_key, model_type='huggingface')
+            tokenizer = get_model(self.model_key)
             token_count = len(
                 get_words_from_document(
                     sample[self.text_key],
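The new return_model=False flag apparently lets an op request only the tokenizer of a Hugging Face checkpoint, so the multi-gigabyte pythia-6.9b weights never need to be materialized for a token count. A sketch, assuming this branch:

from data_juicer.utils.model_utils import get_model, prepare_model

tok_key = prepare_model(model_type='huggingface',
                        model_name_or_path='EleutherAI/pythia-6.9b-deduped',
                        return_model=False)
tokenizer = get_model(tok_key)  # just the tokenizer, per return_model=False
print(len(tokenizer.tokenize('Hello, Data-Juicer!')))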
8 changes: 3 additions & 5 deletions data_juicer/ops/filter/flagged_words_filter.py
@@ -71,8 +71,8 @@ def __init__(self,
             val for vals in self.FLAGGED_WORDS.values() for val in vals
         ]
         if tokenization:
-            self.model_key = prepare_model(lang=lang,
-                                           model_type='sentencepiece')
+            self.model_key = prepare_model(model_type='sentencepiece',
+                                           lang=lang)
 
     def compute_stats(self, sample, context=False):
         # check if it's computed already
@@ -84,9 +84,7 @@ def compute_stats(self, sample, context=False):
         if context and words_key in sample[Fields.context]:
             words = sample[Fields.context][words_key]
         else:
-            tokenizer = get_model(self.model_key,
-                                  lang=self.lang,
-                                  model_type='sentencepiece')
+            tokenizer = get_model(self.model_key)
             words = get_words_from_document(
                 sample[self.text_key],
                 token_func=tokenizer.encode_as_pieces if tokenizer else None)
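For the SentencePiece-backed ops, the object behind the key is a plain SentencePiece processor, so tokenization goes through encode_as_pieces, with a whitespace fallback when no model could be loaded (the `if tokenizer else None` guard above). A sketch under the same assumptions:

from data_juicer.utils.model_utils import get_model, prepare_model

sp_key = prepare_model(model_type='sentencepiece', lang='en')
tokenizer = get_model(sp_key)
if tokenizer:
    print(tokenizer.encode_as_pieces('Some text to screen for flagged words.'))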
5 changes: 3 additions & 2 deletions data_juicer/ops/filter/image_text_matching_filter.py
@@ -68,7 +68,8 @@ def __init__(self,
             raise ValueError(f'Keep strategy [{any_or_all}] is not supported. '
                              f'Can only be one of ["any", "all"].')
         self.any = (any_or_all == 'any')
-        self.model_key = prepare_model(model_type='hf_blip', model_key=hf_blip)
+        self.model_key = prepare_model(model_type='huggingface',
+                                       model_name_or_path=hf_blip)
         self.reduce_mode = reduce_mode
         self.horizontal_flip = horizontal_flip
         self.vertical_flip = vertical_flip
@@ -118,7 +119,7 @@ def compute_stats(self, sample, context=False):
                     truncation=True,
                     max_length=model.config.text_config.
                     max_position_embeddings,
-                    padding=True)
+                    padding=True).to(model.device)
 
                 outputs = model(**inputs)
                 itm_scores = outputs.itm_score.detach().cpu().softmax(
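With the bespoke 'hf_blip', 'hf_clip', and 'hf_owlvit' types collapsed into one 'huggingface' type, get_model hands back a (model, processor) pair, and the new .to(model.device) calls suggest the model may now live on an accelerator, so processor outputs must be moved to the same device before the forward pass. A rough sketch, assuming this branch and the public BLIP ITM checkpoint:

import torch
from PIL import Image

from data_juicer.utils.model_utils import get_model, prepare_model

key = prepare_model(model_type='huggingface',
                    model_name_or_path='Salesforce/blip-itm-base-coco')
model, processor = get_model(key)

image = Image.new('RGB', (224, 224))  # stand-in for a real sample image
inputs = processor(text=['a photo of a cat'], images=[image],
                   return_tensors='pt', padding=True).to(model.device)
with torch.no_grad():
    itm_scores = model(**inputs).itm_score.softmax(dim=-1)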
5 changes: 3 additions & 2 deletions data_juicer/ops/filter/image_text_similarity_filter.py
@@ -69,7 +69,8 @@ def __init__(self,
             raise ValueError(f'Keep strategy [{any_or_all}] is not supported. '
                              f'Can only be one of ["any", "all"].')
         self.any = (any_or_all == 'any')
-        self.model_key = prepare_model(model_type='hf_clip', model_key=hf_clip)
+        self.model_key = prepare_model(model_type='huggingface',
+                                       model_name_or_path=hf_clip)
         self.reduce_mode = reduce_mode
         self.horizontal_flip = horizontal_flip
         self.vertical_flip = vertical_flip
@@ -118,7 +119,7 @@ def compute_stats(self, sample, context=False):
                     truncation=True,
                     max_length=model.config.text_config.
                     max_position_embeddings,
-                    padding=True)
+                    padding=True).to(model.device)
 
                 outputs = model(**inputs)
                 chunk_logits = outputs.logits_per_text.detach().cpu() / 100.0
2 changes: 1 addition & 1 deletion data_juicer/ops/filter/language_id_score_filter.py
@@ -54,7 +54,7 @@ def compute_stats(self, sample):
             return sample
 
         text = sample[self.text_key].lower().replace('\n', ' ')
-        ft_model = get_model(self.model_key, model_type='fasttext')
+        ft_model = get_model(self.model_key)
         if ft_model is None:
             err_msg = 'Model not loaded. Please retry later.'
             logger.error(err_msg)
11 changes: 5 additions & 6 deletions data_juicer/ops/filter/perplexity_filter.py
@@ -42,9 +42,9 @@ def __init__(self,
         super().__init__(*args, **kwargs)
         self.max_ppl = max_ppl
         self.lang = lang
-        self.sp_model_key = prepare_model(lang=lang,
-                                          model_type='sentencepiece')
-        self.kl_model_key = prepare_model(lang=lang, model_type='kenlm')
+        self.sp_model_key = prepare_model(model_type='sentencepiece',
+                                          lang=lang)
+        self.kl_model_key = prepare_model(model_type='kenlm', lang=lang)
 
     def compute_stats(self, sample, context=False):
         # check if it's computed already
@@ -56,8 +56,7 @@
         if context and words_key in sample[Fields.context]:
             words = sample[Fields.context][words_key]
         else:
-            tokenizer = get_model(self.sp_model_key, self.lang,
-                                  'sentencepiece')
+            tokenizer = get_model(self.sp_model_key)
             words = get_words_from_document(
                 sample[self.text_key],
                 token_func=tokenizer.encode_as_pieces if tokenizer else None)
@@ -66,7 +65,7 @@
         text = ' '.join(words)
         # compute perplexity
         logits, length = 0, 0
-        kenlm_model = get_model(self.kl_model_key, self.lang, 'kenlm')
+        kenlm_model = get_model(self.kl_model_key)
         for line in text.splitlines():
             logits += kenlm_model.score(line)
             length += (len(line.split()) + 1)
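To make the two-model flow concrete, here is a hedged sketch of the whole computation; the closing lines of the hunk are cut off above, so the final formula assumes KenLM's log10 convention:

from data_juicer.utils.model_utils import get_model, prepare_model

tokenizer = get_model(prepare_model(model_type='sentencepiece', lang='en'))
kenlm_model = get_model(prepare_model(model_type='kenlm', lang='en'))

text = ' '.join(tokenizer.encode_as_pieces('An example sentence to score.'))

logits, length = 0.0, 0
for line in text.splitlines():
    logits += kenlm_model.score(line)  # KenLM scores are log10 probabilities
    length += len(line.split()) + 1
ppl = 10.0 ** (-logits / length)  # presumed final step, not shown in the hunk
print(ppl)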
10 changes: 4 additions & 6 deletions data_juicer/ops/filter/phrase_grounding_recall_filter.py
@@ -129,9 +129,8 @@ def __init__(self,
             raise ValueError(f'Keep strategy [{any_or_all}] is not supported. '
                              f'Can only be one of ["any", "all"].')
         self.any = (any_or_all == 'any')
-        self.model_type = 'hf_owlvit'
-        self.model_key = prepare_model(model_type=self.model_type,
-                                       model_key=hf_owlvit)
+        self.model_key = prepare_model(model_type='huggingface',
+                                       model_name_or_path=hf_owlvit)
         self.reduce_mode = reduce_mode
         self.horizontal_flip = horizontal_flip
         self.vertical_flip = vertical_flip
@@ -164,8 +163,7 @@ def compute_stats(self, sample, context=False):
         text = sample[self.text_key]
         offset = 0
         recalls = []
-        model, processor = get_model(self.model_key,
-                                     model_type=self.model_type)
+        model, processor = get_model(self.model_key)
 
         for chunk in text.split(SpecialTokens.eoc):
             count = chunk.count(SpecialTokens.image)
@@ -195,7 +193,7 @@ def compute_stats(self, sample, context=False):
                     images=images_this_chunk,
                     return_tensors='pt',
                     padding=True,
-                    truncation=True)
+                    truncation=True).to(model.device)
 
                 with torch.no_grad():
                     outputs = model(**inputs)
9 changes: 3 additions & 6 deletions data_juicer/ops/filter/stopwords_filter.py
@@ -61,7 +61,6 @@ def __init__(self,
         self.words_aug_group_sizes = words_aug_group_sizes
         self.words_aug_join_char = words_aug_join_char
         self.model_key = None
-        self.lang = lang
 
         self.STOPWORDS = load_words_asset(words_dir=stopwords_dir,
                                           words_type='stopwords')
@@ -70,8 +69,8 @@ def __init__(self,
             val for vals in self.STOPWORDS.values() for val in vals
         ]
         if tokenization:
-            self.model_key = prepare_model(lang=lang,
-                                           model_type='sentencepiece')
+            self.model_key = prepare_model(model_type='sentencepiece',
+                                           lang=lang)
 
     def compute_stats(self, sample, context=False):
         # check if it's computed already
@@ -83,9 +82,7 @@
         if context and words_key in sample[Fields.context]:
             words = sample[Fields.context][words_key]
         else:
-            tokenizer = get_model(self.model_key,
-                                  lang=self.lang,
-                                  model_type='sentencepiece')
+            tokenizer = get_model(self.model_key)
             words = get_words_from_document(
                 sample[self.text_key],
                 token_func=tokenizer.encode_as_pieces if tokenizer else None)
5 changes: 3 additions & 2 deletions data_juicer/ops/filter/token_num_filter.py
@@ -44,14 +44,15 @@ def __init__(self,
         self.max_num = max_num
         self.hf_tokenizer = hf_tokenizer
         self.model_key = prepare_model(model_type='huggingface',
-                                       model_key=hf_tokenizer)
+                                       model_name_or_path=hf_tokenizer,
+                                       return_model=False)
 
     def compute_stats(self, sample):
         # check if it's computed already
         if StatsKeys.num_token in sample[Fields.stats]:
             return sample
 
-        tokenizer = get_model(self.model_key, model_type='huggingface')
+        tokenizer = get_model(self.model_key)
         tokens = get_words_from_document(
             sample[self.text_key],
             token_func=tokenizer.tokenize if tokenizer else None)
8 changes: 3 additions & 5 deletions data_juicer/ops/filter/word_num_filter.py
@@ -51,8 +51,8 @@ def __init__(self,
         self.lang = lang
 
         if tokenization:
-            self.model_key = prepare_model(lang=lang,
-                                           model_type='sentencepiece')
+            self.model_key = prepare_model(model_type='sentencepiece',
+                                           lang=lang)
 
     def compute_stats(self, sample, context=False):
         # check if it's computed already
@@ -63,9 +63,7 @@
         if context and words_key in sample[Fields.context]:
             words = sample[Fields.context][words_key]
         else:
-            tokenizer = get_model(self.model_key,
-                                  lang=self.lang,
-                                  model_type='sentencepiece')
+            tokenizer = get_model(self.model_key)
             words = get_words_from_document(
                 sample[self.text_key],
                 token_func=tokenizer.encode_as_pieces if tokenizer else None)
8 changes: 3 additions & 5 deletions data_juicer/ops/filter/word_repetition_filter.py
@@ -56,8 +56,8 @@ def __init__(self,
         self.lang = lang
 
         if tokenization:
-            self.model_key = prepare_model(lang=lang,
-                                           model_type='sentencepiece')
+            self.model_key = prepare_model(model_type='sentencepiece',
+                                           lang=lang)
 
     def compute_stats(self, sample, context=False):
         # check if it's computed already
@@ -69,9 +69,7 @@
         if context and words_key in sample[Fields.context]:
             words = sample[Fields.context][words_key]
         else:
-            tokenizer = get_model(self.model_key,
-                                  lang=self.lang,
-                                  model_type='sentencepiece')
+            tokenizer = get_model(self.model_key)
             words = get_words_from_document(
                 sample[self.text_key],
                 token_func=tokenizer.encode_as_pieces if tokenizer else None)
23 changes: 10 additions & 13 deletions data_juicer/ops/mapper/generate_caption_mapper.py
@@ -87,13 +87,8 @@ def __init__(self,
                 f'Can only be one of '
                 f'["random_any", "similar_one_simhash", "all"].')
 
-        self.model_key = prepare_model(model_type='hf_blip',
-                                       model_key=hf_blip2,
-                                       usage='conditional_generation')
-        model, img_processor = get_model(model_key=self.model_key,
-                                         usage='conditional_generation')
-        self.model_in_ctx = model
-        self.img_processor_in_ctx = img_processor
+        self.model_key = prepare_model(model_type='huggingface',
+                                       model_name_or_path=hf_blip2)
         self.caption_num = caption_num
         self.keep_candidate_mode = keep_candidate_mode
         self.keep_original_sample = keep_original_sample
@@ -151,6 +146,8 @@ def _process_single_sample(self, ori_sample):
         # the generated text will be placed following each SpecialTokens.img
         # and the original special tokens are kept in an order-preserving way.
 
+        model, processor = get_model(self.model_key)
+
         # do generation for each image chunk by chunk
         for chunk in ori_sample[self.text_key].split(SpecialTokens.eoc):
             # skip empty chunks or contents after the last eoc token
@@ -182,13 +179,13 @@ def _process_single_sample(self, ori_sample):
             else:
                 prompt_texts = None
 
-            inputs = self.img_processor_in_ctx(images=image_chunk,
-                                               text=prompt_texts,
-                                               return_tensors='pt')
+            inputs = processor(images=image_chunk,
+                               text=prompt_texts,
+                               return_tensors='pt').to(model.device)
             for i in range(self.caption_num):
-                generated_ids = self.model_in_ctx.generate(**inputs,
-                                                           do_sample=True)
-                generated_text = self.img_processor_in_ctx.batch_decode(
+                generated_ids = model.generate(**inputs,
+                                               do_sample=True).to(model.device)
+                generated_text = processor.batch_decode(
                     generated_ids, skip_special_tokens=True)
                 generated_text_candidates_single_chunk[i] = generated_text
 
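Note the design shift in this mapper: __init__ now records only a model key, and the (model, processor) pair is resolved inside _process_single_sample via get_model. A plausible motivation, though the commit does not say so, is that keeping heavyweight model objects off the op instance keeps it cheap to serialize for parallel workers. A hypothetical minimal op following the same pattern:

from data_juicer.utils.model_utils import get_model, prepare_model

class CaptionSketch:
    """Hypothetical op: key in __init__, load at call time."""

    def __init__(self, hf_blip2='Salesforce/blip2-opt-2.7b'):
        # only a string key lives on self; no tensors get pickled with the op
        self.model_key = prepare_model(model_type='huggingface',
                                       model_name_or_path=hf_blip2)

    def caption(self, image):
        model, processor = get_model(self.model_key)  # resolved per call
        inputs = processor(images=image, return_tensors='pt').to(model.device)
        generated_ids = model.generate(**inputs, do_sample=True)
        return processor.batch_decode(generated_ids, skip_special_tokens=True)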
data_juicer/ops/mapper/remove_words_with_incorrect_substrings_mapper.py
@@ -40,8 +40,8 @@ def __init__(self,
         self.substrings = substrings
         self.lang = lang
         if tokenization:
-            self.model_key = prepare_model(lang=lang,
-                                           model_type='sentencepiece')
+            self.model_key = prepare_model(model_type='sentencepiece',
+                                           lang=lang)
 
     def should_keep_word_with_incorrect_substrings(self, word, substrings):
         word = strip(word, SPECIAL_CHARACTERS)
@@ -50,9 +50,7 @@ def should_keep_word_with_incorrect_substrings(self, word, substrings):
 
     def process(self, sample):
         if self.tokenization:
-            tokenizer = get_model(self.model_key,
-                                  lang=self.lang,
-                                  model_type='sentencepiece')
+            tokenizer = get_model(self.model_key)
            sentences = get_words_from_document(
                sample[self.text_key],
                token_func=tokenizer.encode_as_pieces if tokenizer else None)
6 changes: 2 additions & 4 deletions data_juicer/ops/mapper/sentence_split_mapper.py
@@ -24,13 +24,11 @@ def __init__(self, lang: str = 'en', *args, **kwargs):
         """
         super().__init__(*args, **kwargs)
         self.lang = lang
-        self.model_key = prepare_model(lang=lang, model_type='nltk')
+        self.model_key = prepare_model(model_type='nltk', lang=lang)
 
     def process(self, sample):
 
-        nltk_model = get_model(self.model_key,
-                               lang=self.lang,
-                               model_type='nltk')
+        nltk_model = get_model(self.model_key)
         sample[self.text_key] = get_sentences_from_document(
             sample[self.text_key],
             model_func=nltk_model.tokenize if nltk_model else None)
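The NLTK-backed splitter follows the same two-step pattern; judging by the tokenize call above, the resolved object is a sentence tokenizer, so a sketch on this branch would be:

from data_juicer.utils.model_utils import get_model, prepare_model

nltk_key = prepare_model(model_type='nltk', lang='en')
nltk_model = get_model(nltk_key)
print(nltk_model.tokenize('First sentence. And a second one.'))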
