Refactor TransformersTokenizer and change fallback behavior for byte_tokens
#973
@@ -1,6 +1,7 @@
import os
import re
import textwrap
import warnings

from typing import Sequence, Union
@@ -38,7 +39,7 @@
class TransformersTokenizer(Tokenizer):
    def __init__(
        self,
        model,
        model: Union[str, "transformers_package.PreTrainedModel"],
        transformers_tokenizer: Union[
            "transformers_package.PreTrainedTokenizer",
            "transformers_package.PreTrainedTokenizerFast",
@@ -49,118 +50,20 @@ def __init__(
        **kwargs,
    ):
        if transformers_tokenizer is None:
            transformers_tokenizer = self._tokenizer(model, **kwargs)
            if isinstance(model, str):
                transformers_tokenizer = self._tokenizer(model, **kwargs)
            else:
                raise ValueError(
                    "A model object was passed in, but no tokenizer was provided. Please provide a tokenizer."
                )
        else:
            is_ptt = isinstance(transformers_tokenizer, transformers_package.PreTrainedTokenizer)
            is_ptt_fast = isinstance(
                transformers_tokenizer, transformers_package.PreTrainedTokenizerFast
            )
            assert is_ptt or is_ptt_fast

        self._orig_tokenizer = transformers_tokenizer
        special_tokens_map = {
            id: token for token, id in transformers_tokenizer.get_added_vocab().items()
        }

        # build out the set of byte_string tokens
        byte_tokens = [b""] * len(transformers_tokenizer)
        if hasattr(transformers_tokenizer, "byte_decoder"):
            byte_decoder = transformers_tokenizer.byte_decoder

            for i in range(len(transformers_tokenizer)):
                byte_coded = bytes(
                    [byte_decoder[c] for c in transformers_tokenizer.convert_ids_to_tokens(i)]
                )
                byte_tokens[i] = byte_coded

        elif hasattr(transformers_tokenizer, "sp_model"):
            space_prefix = "▁".encode()
            for i in range(len(transformers_tokenizer)):
                if i in special_tokens_map:
                    byte_coded = special_tokens_map[i].encode()
                else:
                    byte_coded = re.sub(
                        rb"<0x(..)>",
                        lambda x: bytes.fromhex(x[1].decode()),
                        transformers_tokenizer.sp_model.id_to_piece(i).encode(),
                    )
                byte_tokens[i] = byte_coded.replace(space_prefix, b" ")

        elif hasattr(transformers_tokenizer, "get_vocab"):
            vocab = transformers_tokenizer.get_vocab()
            byte_encoder = self._bytes_to_unicode()
            byte_decoder = {v: k for k, v in byte_encoder.items()}

            for i in range(len(transformers_tokenizer)):
                if i in special_tokens_map:
                    byte_coded = special_tokens_map[i].encode()
                else:
                    token = transformers_tokenizer.convert_ids_to_tokens(i)
                    if isinstance(token, bytes):
                        byte_coded = token
                    elif isinstance(token, str):
                        if hasattr(transformers_tokenizer, "convert_tokens_to_string"):
                            token_str = transformers_tokenizer.convert_tokens_to_string([token])
                            roundtrip_id = transformers_tokenizer.encode(token_str)[0]
                            if roundtrip_id == i:
                                byte_coded = token_str.encode()
                            else:
                                byte_coded = bytes([byte_decoder[c] for c in token])
                        else:
                            byte_coded = token.encode()
                    else:
                        raise ValueError(f"Unexpected token type: {type(token)}")
                byte_tokens[i] = byte_coded

        else:
            byte_decoder = transformers_package.AutoTokenizer.from_pretrained(
                "gpt2", use_fast=False
            ).byte_decoder  # fall back to gpt2 mapping

            # some special tokens may not have their whitespace encoded...
            byte_decoder[" "] = 32
            byte_decoder["\n"] = 10
            byte_decoder["\r"] = 13
            byte_decoder["\t"] = 9
            byte_decoder["▁"] = 32

            # run a quick spot check to verify we can rebuild complex multi-token unicode symbols
            s = "’•¶∂ƒ˙∆£Ħ爨ൠᅘ∰፨"
            reconstructed = b""
            try:
                input_ids = transformers_tokenizer(s)["input_ids"]
                for i in input_ids:
                    nxt_bytes = []
                    token_str = transformers_tokenizer.convert_ids_to_tokens(i)
                    for c in token_str:
                        nxt_bytes.append(byte_decoder[c])
                    reconstructed += bytes(nxt_bytes)
                # Check if the tokenizer has a bos_token attribute, and if it does, check
                # if it's at the start of the reconstructed bytes
                # Some tokenizers add this automatically as part of the call function, so
                # we need to remove it to compare
                if hasattr(transformers_tokenizer, "bos_token") and reconstructed.startswith(
                    transformers_tokenizer.bos_token.encode()
                ):
                    reconstructed = reconstructed[len(transformers_tokenizer.bos_token) :]
            except Exception as e:
                msg = textwrap.dedent(
                    f"""
                    The tokenizer being used is unable to convert a special character in {s}.
                    For models with sentencepiece based tokenizers (e.g. llama, phi-3-mini),
                    installing sentencepiece often fixes this issue (pip install sentencepiece).
                    """
                )
                raise ValueError(msg) from e
            assert (
                reconstructed.decode() == s
            ), "The passed tokenizer does not have a byte_decoder property and using a standard gpt2 byte_decoder fails!"

            for i in range(len(transformers_tokenizer)):
                byte_coded = bytes(
                    [byte_decoder[c] for c in transformers_tokenizer.convert_ids_to_tokens(i)]
                )
                byte_tokens[i] = byte_coded
        byte_tokens = self._byte_tokens(transformers_tokenizer)

        # Chat Template logic
        if chat_template is None and hasattr(self._orig_tokenizer, "chat_template"):
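To make the new constructor contract concrete, here is a rough usage sketch. This is illustrative only: it assumes TransformersTokenizer has already been imported from guidance, that transformers_tokenizer can be passed as a keyword argument with the other parameters left at their defaults, and that "gpt2" is available locally or from the Hub. None of this snippet is part of the diff.

from transformers import AutoModelForCausalLM, AutoTokenizer

# 1) A model name string: the tokenizer is loaded internally via _tokenizer(...)
tok = TransformersTokenizer("gpt2", transformers_tokenizer=None)

# 2) A model object plus an explicit tokenizer: accepted
hf_model = AutoModelForCausalLM.from_pretrained("gpt2")
hf_tok = AutoTokenizer.from_pretrained("gpt2", use_fast=False)
tok = TransformersTokenizer(hf_model, transformers_tokenizer=hf_tok)

# 3) A model object with no tokenizer: now raises ValueError instead of guessing
try:
    TransformersTokenizer(hf_model, transformers_tokenizer=None)
except ValueError as e:
    print(e)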
@@ -174,6 +77,204 @@ def __init__(
            transformers_tokenizer.eos_token_id,
        )
    def _tokenizer(self, model: str, **kwargs) -> Union[
        "transformers_package.PreTrainedTokenizer",
        "transformers_package.PreTrainedTokenizerFast",
    ]:
        # make sure transformers is installed
        if not has_transformers:
            raise Exception("Please install transformers with `pip install transformers`")

        try:
            tokenizer = transformers_package.AutoTokenizer.from_pretrained(
                model, use_fast=False, **kwargs
            )
        except ImportError as e:
            # HuggingFace needs us to install something (sentencepiece, protobuf, etc. for some non-fast tokenizers)
            raise RuntimeError(
                f"Could not load tokenizer for model {model}. Please install the necessary dependencies for the tokenizer (see traceback for info)."
            ) from e
Comment: We'll now fail early rather than falling back to the fast tokenizer if the cause of the exception here is a missing dependency.

Comment: Interesting -- was guidance just eating a silent ImportError?

Comment: Yes, we were trying to use the […] Raising the […] Best case for me would be to only catch a specific exception class, e.g. ImportError.

Comment: Now I dislike the runtime error. We should try to be as consistent as possible in this repo if we are raising for "you are missing a dependency" reasons.

Comment: Agreed on dropping the runtime error.
        except Exception as e:
            # Fall back for other exceptions
            warnings.warn(f"Falling back to fast tokenizer. Could not load tokenizer for model {model} due to exception {e.__class__.__name__}: {e}")
Comment: Added a warning. 100% unsure of whether this branch is ever taken now...

Comment: Yeah, good question. I agree -- it feels like we shouldn't hit this fallback anymore, particularly on a pathway dependent on import issues.

Comment: As Richard said, this is sadly just really hard to test. At least without making the community do it for us...

Comment: ... or blowing up our test matrix even more.
            tokenizer = transformers_package.AutoTokenizer.from_pretrained(
                model, use_fast=True, **kwargs
            )  # fall back to the fast tokenizer

        return tokenizer
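To make the thread above concrete, here is a minimal, hypothetical sketch of catching only a specific exception class and raising a consistent "missing dependency" error instead of a generic RuntimeError. The names load_slow_tokenizer and MissingDependencyError are illustrative and are not part of this PR or of guidance.

from transformers import AutoTokenizer


class MissingDependencyError(Exception):
    """Raised when loading a tokenizer needs an optional package (e.g. sentencepiece)."""


def load_slow_tokenizer(model: str, **kwargs):
    try:
        # use_fast=False can require sentencepiece/protobuf for some models
        return AutoTokenizer.from_pretrained(model, use_fast=False, **kwargs)
    except ImportError as e:
        # Fail early with a dependency-specific error rather than silently
        # falling back to the fast tokenizer.
        raise MissingDependencyError(
            f"Loading the slow tokenizer for {model!r} requires an optional dependency: {e}"
        ) from e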

    def _byte_tokens(
        self,
        transformers_tokenizer: Union[
            "transformers_package.PreTrainedTokenizer",
            "transformers_package.PreTrainedTokenizerFast",
        ],
    ) -> list[bytes]:

        if (
            hasattr(transformers_tokenizer, "byte_decoder")
            and self._byte_decoder_has_all_bytes(
                transformers_tokenizer.byte_decoder,
                transformers_tokenizer.get_vocab()
            )
        ):
            return self._byte_tokens_from_byte_decoder(transformers_tokenizer)

        if hasattr(transformers_tokenizer, "sp_model"):
            return self._byte_tokens_from_sp_model(transformers_tokenizer)

        try:
            return self._byte_tokens_from_vocab(transformers_tokenizer)
        except ValueError:
            pass
Comment: Again, should we issue a warning?

Comment: Added, ditto on the language.

        return self._byte_tokens_fallback(transformers_tokenizer)
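As a rough sketch of the "issue a warning" suggestion above, here is one way the vocab path could warn before dropping to the last-resort fallback. byte_tokens_with_warned_fallback is a hypothetical helper, not code from this PR; the callables stand in for the class's own _byte_tokens_from_vocab and _byte_tokens_fallback methods.

import warnings
from typing import Callable, List


def byte_tokens_with_warned_fallback(
    from_vocab: Callable[[], List[bytes]],
    fallback: Callable[[], List[bytes]],
) -> List[bytes]:
    # Try the vocab-based construction first; if it fails, warn instead of
    # silently swallowing the ValueError, then use the gpt2 byte_decoder fallback.
    try:
        return from_vocab()
    except ValueError as e:
        warnings.warn(
            f"Could not build byte tokens from the vocabulary ({e}); "
            "falling back to the gpt2 byte_decoder mapping."
        )
        return fallback()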

    def _byte_tokens_from_byte_decoder(
        self,
        transformers_tokenizer: Union[
            "transformers_package.PreTrainedTokenizer",
            "transformers_package.PreTrainedTokenizerFast",
        ],
    ) -> list[bytes]:
        byte_tokens = [b""] * len(transformers_tokenizer)
        byte_decoder: dict[str, int] = transformers_tokenizer.byte_decoder
        for i in range(len(transformers_tokenizer)):
            byte_coded = bytes(
                [byte_decoder[c] for c in transformers_tokenizer.convert_ids_to_tokens(i)]
            )
            byte_tokens[i] = byte_coded
        return byte_tokens

    def _byte_tokens_from_sp_model(
        self,
        transformers_tokenizer: Union[
            "transformers_package.PreTrainedTokenizer",
            "transformers_package.PreTrainedTokenizerFast",
        ],
    ) -> list[bytes]:
        byte_tokens = [b""] * len(transformers_tokenizer)
        special_tokens_map = {
            id: token for token, id in transformers_tokenizer.get_added_vocab().items()
        }
        space_prefix = "▁".encode()
        for i in range(len(transformers_tokenizer)):
            if i in special_tokens_map:
                byte_coded = special_tokens_map[i].encode()
            else:
                byte_coded = re.sub(
                    rb"<0x(..)>",
                    lambda x: bytes.fromhex(x[1].decode()),
                    transformers_tokenizer.sp_model.id_to_piece(i).encode(),
                )
            byte_tokens[i] = byte_coded.replace(space_prefix, b" ")
        return byte_tokens
Comment: This path used to call get_vocab() […]

Comment: I think originally I may have used that method to get the tokenizer key/value pairs before replacing it with just iterating through a range the same length as the tokenizer. I guess I didn't realize that we just weren't using the vocabulary at all anymore, so I think removing it makes sense if the functionality stays the same.

Comment: Got it, thanks!

    def _byte_tokens_from_vocab(
        self,
        transformers_tokenizer: Union[
            "transformers_package.PreTrainedTokenizer",
            "transformers_package.PreTrainedTokenizerFast",
        ],
    ) -> list[bytes]:
        byte_tokens = [b""] * len(transformers_tokenizer)
        special_tokens_map = {
            id: token for token, id in transformers_tokenizer.get_added_vocab().items()
        }
        byte_encoder = self._bytes_to_unicode()
        byte_decoder = {v: k for k, v in byte_encoder.items()}

        for i in range(len(transformers_tokenizer)):
            if i in special_tokens_map:
                byte_coded = special_tokens_map[i].encode()
            else:
                token = transformers_tokenizer.convert_ids_to_tokens(i)
                if isinstance(token, bytes):
                    byte_coded = token
                elif isinstance(token, str):
                    if hasattr(transformers_tokenizer, "convert_tokens_to_string"):
                        token_str = transformers_tokenizer.convert_tokens_to_string([token])
                        encoded_str = transformers_tokenizer.encode(token_str)
                        if len(encoded_str) != 1:
                            raise ValueError(f"Round-trip encoding of tokens [{token}] failed! Got {encoded_str}")
                        roundtrip_id = encoded_str[0]
Comment on lines +231 to +234: FYI, here's where that […]
                        if roundtrip_id == i:
                            byte_coded = token_str.encode()
                        else:
                            byte_coded = bytes([byte_decoder[c] for c in token])
                    else:
                        byte_coded = token.encode()
                else:
                    raise ValueError(f"Unexpected token type: {type(token)}")
            byte_tokens[i] = byte_coded
        return byte_tokens

    def _byte_tokens_fallback(
        self,
        transformers_tokenizer: Union[
            "transformers_package.PreTrainedTokenizer",
            "transformers_package.PreTrainedTokenizerFast",
        ],
    ) -> list[bytes]:
        byte_tokens = [b""] * len(transformers_tokenizer)
        byte_decoder: dict[str, int] = transformers_package.AutoTokenizer.from_pretrained(
            "gpt2", use_fast=False
        ).byte_decoder  # fall back to gpt2 mapping

        # some special tokens may not have their whitespace encoded...
        byte_decoder[" "] = 32
        byte_decoder["\n"] = 10
        byte_decoder["\r"] = 13
        byte_decoder["\t"] = 9
        byte_decoder["▁"] = 32

        # run a quick spot check to verify we can rebuild complex multi-token unicode symbols
        s = "’•¶∂ƒ˙∆£Ħ爨ൠᅘ∰፨"
        reconstructed = b""
        try:
            input_ids = transformers_tokenizer(s)["input_ids"]
            for i in input_ids:
                nxt_bytes = []
                token_str = transformers_tokenizer.convert_ids_to_tokens(i)
                for c in token_str:
                    nxt_bytes.append(byte_decoder[c])
                reconstructed += bytes(nxt_bytes)
            # Check if the tokenizer has a bos_token attribute, and if it does, check
            # if it's at the start of the reconstructed bytes
            # Some tokenizers add this automatically as part of the call function, so
            # we need to remove it to compare
            if hasattr(transformers_tokenizer, "bos_token") and reconstructed.startswith(
                transformers_tokenizer.bos_token.encode()
            ):
                reconstructed = reconstructed[len(transformers_tokenizer.bos_token) :]
Comment: IIRC, these reconstitution checks should happen more generally than just in the fast-fallback path. Very possible that guidance wasn't doing this before, but I think it should happen after all the paths are exhausted and a single method is picked.

Comment: I don't believe it was doing it in any case but the fallback, but I could be wrong. But agreed with you that this is reasonable to do always... at least to give a warning if not an exception.
        except Exception as e:
            msg = textwrap.dedent(
                f"""
                The tokenizer being used is unable to convert a special character in {s}.
                For models with sentencepiece based tokenizers (e.g. llama, phi-3-mini),
                installing sentencepiece often fixes this issue (pip install sentencepiece).
                """
            )
            raise ValueError(msg) from e
Comment: I'm unsure of whether this method will ever be called now that we "fail early" with respect to the fast tokenizer. If we hit this exception, the message may be wrong now, as missing an import probably isn't the cause of the exception anymore...
        assert (
            reconstructed.decode() == s
        ), "The passed tokenizer does not have a byte_decoder property and using a standard gpt2 byte_decoder fails!"

        for i in range(len(transformers_tokenizer)):
            byte_coded = bytes(
                [byte_decoder[c] for c in transformers_tokenizer.convert_ids_to_tokens(i)]
            )
            byte_tokens[i] = byte_coded
        return byte_tokens
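Picking up the review note above about running the reconstitution check after a single construction path has been selected, here is a minimal sketch under that assumption. check_byte_tokens is a hypothetical helper (not in this PR), and it sidesteps the bos_token stripping by passing add_special_tokens=False to the tokenizer call.

from typing import List


def check_byte_tokens(transformers_tokenizer, byte_tokens: List[bytes]) -> None:
    # Spot-check that the chosen byte_tokens mapping can rebuild a string of
    # complex multi-token unicode symbols, whichever construction path produced it.
    s = "’•¶∂ƒ˙∆£Ħ爨ൠᅘ∰፨"
    input_ids = transformers_tokenizer(s, add_special_tokens=False)["input_ids"]
    reconstructed = b"".join(byte_tokens[i] for i in input_ids)
    if reconstructed.decode("utf-8", errors="replace") != s:
        raise ValueError(
            "byte_tokens could not reconstruct a complex unicode test string; "
            "the selected byte mapping is probably incomplete."
        )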

    def _byte_decoder_has_all_bytes(self, byte_decoder: dict[str, int], vocab: dict[str, int]) -> bool:
        # This is here because some tokenizers are bad and don't have all the bytes (I'm looking at you, microsoft/phi2)
        all_bytes = set()
        for x in vocab.keys():
            for y in x:
                all_bytes.add(y)
        return set(byte_decoder.keys()) >= all_bytes

    def _bytes_to_unicode(self):
        bs = (
            list(range(ord("!"), ord("~") + 1))

@@ -190,38 +291,6 @@ def _bytes_to_unicode(self):
        cs = [chr(n) for n in cs]
        return dict(zip(bs, cs))

    def _tokenizer(self, model, **kwargs) -> Union[
        "transformers_package.PreTrainedTokenizer",
        "transformers_package.PreTrainedTokenizerFast",
    ]:
        # intantiate the tokenizer
        if isinstance(model, str):
            # make sure transformers is installed
            if not has_transformers:
                raise Exception("Please install transformers with `pip install transformers`")

            try:
                tokenizer = transformers_package.AutoTokenizer.from_pretrained(
                    model, use_fast=False, **kwargs
                )
                # This is here because some tokenizers are bad and don't have all the bytes (I'm looking at you, microsoft/phi2)
                if hasattr(tokenizer, "byte_decoder"):
                    all_bytes = set()
                    for x in tokenizer.get_vocab().keys():
                        for y in x:
                            all_bytes.add(y)
                    assert set(tokenizer.byte_decoder.keys()).intersection(all_bytes) == all_bytes
            except:
                tokenizer = transformers_package.AutoTokenizer.from_pretrained(
                    model, use_fast=True, **kwargs
                )  # fall back to the fast tokenizer

        assert (
            tokenizer is not None
        ), "You must give a model name when you provide a tokenizer object!"

        return tokenizer

    def encode(self, byte_string: bytes) -> Sequence[int]:
        assert isinstance(byte_string, bytes)
        # HF tokenizers take in strings apparently
Comment: Previously, we asserted _byte_decoder_has_all_bytes here, potentially causing a fallback to fast. I instead moved that check into _byte_tokens. Note: we could call _byte_tokens inside of this try/except instead of further up the stack if we want to fall back to fast on failure... I'm not sure which semantics are better.
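For the alternative floated in that comment (building byte_tokens inside the try/except so a failure still falls back to the fast tokenizer), here is a hedged sketch of what such a variant method could look like. It assumes it lives on the same class (so transformers_package and self._byte_tokens are available) and is purely illustrative; it is not what this PR implements.

    def _tokenizer_with_fast_fallback(self, model: str, **kwargs) -> Union[
        "transformers_package.PreTrainedTokenizer",
        "transformers_package.PreTrainedTokenizerFast",
    ]:
        try:
            tokenizer = transformers_package.AutoTokenizer.from_pretrained(
                model, use_fast=False, **kwargs
            )
            # Validate eagerly: if byte_tokens cannot be built for the slow
            # tokenizer, treat that like any other load failure below.
            self._byte_tokens(tokenizer)
        except ImportError:
            raise  # a missing dependency should still fail early
        except Exception as e:
            warnings.warn(
                f"Falling back to fast tokenizer for {model}: {e.__class__.__name__}: {e}"
            )
            tokenizer = transformers_package.AutoTokenizer.from_pretrained(
                model, use_fast=True, **kwargs
            )
        return tokenizer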