Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stop using model-defined truncation in perplexity calculation #333

Merged
merged 3 commits into from
Nov 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions measurements/perplexity/perplexity.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ def _info(self):
reference_urls=["https://huggingface.co/docs/transformers/perplexity"],
)

def _compute(self, data, model_id, batch_size: int = 16, add_start_token: bool = True, device=None):
def _compute(
self, data, model_id, batch_size: int = 16, add_start_token: bool = True, device=None, max_length=None
):

if device is not None:
assert device in ["gpu", "cpu", "cuda"], "device should be either gpu or cpu."
Expand All @@ -126,20 +128,20 @@ def _compute(self, data, model_id, batch_size: int = 16, add_start_token: bool =
# assign one of the special tokens to also be the pad token
tokenizer.add_special_tokens({"pad_token": existing_special_tokens[0]})

if add_start_token:
if add_start_token and max_length:
# leave room for <BOS> token to be added:
assert (
tokenizer.bos_token is not None
), "Input model must already have a BOS token if using add_start_token=True. Please use a different model, or set add_start_token=False"
max_tokenized_len = model.config.max_length - 1
max_tokenized_len = max_length - 1
else:
max_tokenized_len = model.config.max_length
max_tokenized_len = max_length

encodings = tokenizer(
data,
add_special_tokens=False,
padding=True,
truncation=True,
truncation=True if max_tokenized_len else False,
max_length=max_tokenized_len,
return_tensors="pt",
return_attention_mask=True,
Expand Down
12 changes: 7 additions & 5 deletions metrics/perplexity/perplexity.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ def _info(self):
reference_urls=["https://huggingface.co/docs/transformers/perplexity"],
)

def _compute(self, predictions, model_id, batch_size: int = 16, add_start_token: bool = True, device=None):
def _compute(
self, predictions, model_id, batch_size: int = 16, add_start_token: bool = True, device=None, max_length=None
):

if device is not None:
assert device in ["gpu", "cpu", "cuda"], "device should be either gpu or cpu."
Expand All @@ -126,20 +128,20 @@ def _compute(self, predictions, model_id, batch_size: int = 16, add_start_token:
# assign one of the special tokens to also be the pad token
tokenizer.add_special_tokens({"pad_token": existing_special_tokens[0]})

if add_start_token:
if add_start_token and max_length:
# leave room for <BOS> token to be added:
assert (
tokenizer.bos_token is not None
), "Input model must already have a BOS token if using add_start_token=True. Please use a different model, or set add_start_token=False"
max_tokenized_len = model.config.max_length - 1
max_tokenized_len = max_length - 1
else:
max_tokenized_len = model.config.max_length
max_tokenized_len = max_length

encodings = tokenizer(
predictions,
add_special_tokens=False,
padding=True,
truncation=True,
truncation=True if max_tokenized_len else False,
max_length=max_tokenized_len,
return_tensors="pt",
return_attention_mask=True,
Expand Down