Refactor TransformersTokenizer and change fallback behavior for byte_tokens #973

Status: Merged (25 commits, Aug 16, 2024)

Changes from 10 commits

Commits (25):
b4ce184  Encapsulate byte_tokens logic (hudson-ai, Jul 30, 2024)
9b66001  split different ways of getting byte_tokens into separate methods (hudson-ai, Jul 30, 2024)
b28f447  annotations (hudson-ai, Jul 30, 2024)
f6fd1df  no need to assert byte decoder has all bytes unless we're trying to g… (hudson-ai, Jul 30, 2024)
9a3c918  Slightly more informative exception (hudson-ai, Jul 30, 2024)
7dde112  mypy says we always have a grammar... instead do unconditionally in t… (hudson-ai, Jul 30, 2024)
d4dfd72  model must be a str by the time we call _tokenizer (hudson-ai, Jul 30, 2024)
60554f7  reorganize methods for readability (hudson-ai, Jul 30, 2024)
8ee05cd  Add some comments about alternate except behavior (hudson-ai, Jul 30, 2024)
3a6c55f  Add special case for ImportError; warn user when falling back to fast… (hudson-ai, Jul 30, 2024)
9272c7d  raise ImportError (hudson-ai, Aug 8, 2024)
4ec2b7b  just reraise caught ImportError (no RuntimeError) (hudson-ai, Aug 8, 2024)
979488f  Merge branch 'main' into transformers_fast_slow (hudson-ai, Aug 12, 2024)
7e08688  reuse work in _byte_tokens_from_byte_decoder (hudson-ai, Aug 12, 2024)
37d824f  factor out check_byte_decoder to reuse across cases (hudson-ai, Aug 12, 2024)
2e49e70  more informative method name (hudson-ai, Aug 12, 2024)
0b8536e  annotations (hudson-ai, Aug 12, 2024)
28bf6fb  more informative method name (hudson-ai, Aug 12, 2024)
86b6c89  reorder methods (hudson-ai, Aug 12, 2024)
e4a4f45  check byte decoder has all bytes in both byte_decoder branches (hudson-ai, Aug 12, 2024)
cb1462e  encapsulate both check_byte_decoder funcs into one (hudson-ai, Aug 12, 2024)
5d28242  get_vocab (hudson-ai, Aug 13, 2024)
023dcd7  check right byte decoder (hudson-ai, Aug 13, 2024)
fbd68b2  try to build byte_tokens in the slow branch and fall back to fast bra… (hudson-ai, Aug 15, 2024)
5ee72e2  add more warnings (hudson-ai, Aug 15, 2024)
345 changes: 207 additions & 138 deletions guidance/models/transformers/_transformers.py
@@ -1,6 +1,7 @@
import os
import re
import textwrap
import warnings

from typing import Sequence, Union

@@ -38,7 +39,7 @@
class TransformersTokenizer(Tokenizer):
def __init__(
self,
model,
model: Union[str, "transformers_package.PreTrainedModel"],
transformers_tokenizer: Union[
"transformers_package.PreTrainedTokenizer",
"transformers_package.PreTrainedTokenizerFast",
@@ -49,118 +50,20 @@ def __init__(
**kwargs,
):
if transformers_tokenizer is None:
transformers_tokenizer = self._tokenizer(model, **kwargs)
if isinstance(model, str):
transformers_tokenizer = self._tokenizer(model, **kwargs)
else:
raise ValueError(
"A model object was passed in, but no tokenizer was provided. Please provide a tokenizer."
)
else:
is_ptt = isinstance(transformers_tokenizer, transformers_package.PreTrainedTokenizer)
is_ptt_fast = isinstance(
transformers_tokenizer, transformers_package.PreTrainedTokenizerFast
)
assert is_ptt or is_ptt_fast

self._orig_tokenizer = transformers_tokenizer
special_tokens_map = {
id: token for token, id in transformers_tokenizer.get_added_vocab().items()
}

# build out the set of byte_string tokens
byte_tokens = [b""] * len(transformers_tokenizer)
if hasattr(transformers_tokenizer, "byte_decoder"):
byte_decoder = transformers_tokenizer.byte_decoder

for i in range(len(transformers_tokenizer)):
byte_coded = bytes(
[byte_decoder[c] for c in transformers_tokenizer.convert_ids_to_tokens(i)]
)
byte_tokens[i] = byte_coded

elif hasattr(transformers_tokenizer, "sp_model"):
space_prefix = "▁".encode()
for i in range(len(transformers_tokenizer)):
if i in special_tokens_map:
byte_coded = special_tokens_map[i].encode()
else:
byte_coded = re.sub(
rb"<0x(..)>",
lambda x: bytes.fromhex(x[1].decode()),
transformers_tokenizer.sp_model.id_to_piece(i).encode(),
)
byte_tokens[i] = byte_coded.replace(space_prefix, b" ")

elif hasattr(transformers_tokenizer, "get_vocab"):
vocab = transformers_tokenizer.get_vocab()
byte_encoder = self._bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}

for i in range(len(transformers_tokenizer)):
if i in special_tokens_map:
byte_coded = special_tokens_map[i].encode()
else:
token = transformers_tokenizer.convert_ids_to_tokens(i)
if isinstance(token, bytes):
byte_coded = token
elif isinstance(token, str):
if hasattr(transformers_tokenizer, "convert_tokens_to_string"):
token_str = transformers_tokenizer.convert_tokens_to_string([token])
roundtrip_id = transformers_tokenizer.encode(token_str)[0]
if roundtrip_id == i:
byte_coded = token_str.encode()
else:
byte_coded = bytes([byte_decoder[c] for c in token])
else:
byte_coded = token.encode()
else:
raise ValueError(f"Unexpected token type: {type(token)}")
byte_tokens[i] = byte_coded

else:
byte_decoder = transformers_package.AutoTokenizer.from_pretrained(
"gpt2", use_fast=False
).byte_decoder # fall back to gpt2 mapping

# some special tokens may not have their whitespace encoded...
byte_decoder[" "] = 32
byte_decoder["\n"] = 10
byte_decoder["\r"] = 13
byte_decoder["\t"] = 9
byte_decoder["▁"] = 32

# run a quick spot check to verify we can rebuild complex multi-token unicode symbols
s = "’•¶∂ƒ˙∆£Ħ爨ൠᅘ∰፨"
reconstructed = b""
try:
input_ids = transformers_tokenizer(s)["input_ids"]
for i in input_ids:
nxt_bytes = []
token_str = transformers_tokenizer.convert_ids_to_tokens(i)
for c in token_str:
nxt_bytes.append(byte_decoder[c])
reconstructed += bytes(nxt_bytes)
# Check if the tokenizer has a bos_token attribute, and if it does, check
# if it's at the start of the reconstructed bytes
# Some tokenizers add this automatically as part of the call function, so
# we need to remove it to compare
if hasattr(transformers_tokenizer, "bos_token") and reconstructed.startswith(
transformers_tokenizer.bos_token.encode()
):
reconstructed = reconstructed[len(transformers_tokenizer.bos_token) :]
except Exception as e:
msg = textwrap.dedent(
f"""
The tokenizer being used is unable to convert a special character in {s}.
For models with sentencepiece based tokenizers (e.g. llama, phi-3-mini),
installing sentencepiece often fixes this issue (pip install sentencepiece).
"""
)
raise ValueError(msg) from e
assert (
reconstructed.decode() == s
), "The passed tokenizer does not have a byte_decoder property and using a standard gpt2 byte_decoder fails!"

for i in range(len(transformers_tokenizer)):
byte_coded = bytes(
[byte_decoder[c] for c in transformers_tokenizer.convert_ids_to_tokens(i)]
)
byte_tokens[i] = byte_coded
byte_tokens = self._byte_tokens(transformers_tokenizer)

# Chat Template logic
if chat_template is None and hasattr(self._orig_tokenizer, "chat_template"):
@@ -174,6 +77,204 @@ def __init__(
transformers_tokenizer.eos_token_id,
)

def _tokenizer(self, model: str, **kwargs) -> Union[
"transformers_package.PreTrainedTokenizer",
"transformers_package.PreTrainedTokenizerFast",
]:
# make sure transformers is installed
if not has_transformers:
raise Exception("Please install transformers with `pip install transformers`")

try:
tokenizer = transformers_package.AutoTokenizer.from_pretrained(
model, use_fast=False, **kwargs
)
Collaborator Author:
Previously, we asserted _byte_decoder_has_all_bytes here, potentially causing a fallback to fast. I instead moved that check into _byte_tokens.

Note: we could call _byte_tokens inside of this try/except instead of further up the stack if we want to fall back to fast on failure... I'm not sure which semantics are better.

except ImportError as e:
# HuggingFace needs us to install something (sentencepiece, protobuf, etc. for some non-fast tokenizers)
raise RuntimeError(
f"Could not load tokenizer for model {model}. Please install the necessary dependencies for the tokenizer (see traceback for info)."
) from e
Collaborator Author:
We'll now fail early rather than falling back to the fast tokenizer if the cause of the exception here is a missing dependency.

Collaborator:
Interesting -- was guidance just silently eating an ImportError when e.g. attempting to load phi-3 without sentencepiece? I really like this change if we can catch those reasonably reliably.

Collaborator Author:
Yes, we were trying to use the fast=False tokenizer and had a catch-all except, which would fall back to the fast tokenizer.

Raising the ImportError seems like a good idea to me, as it gives the user something actionable and specific to do to fix the issue. I'm still falling back to the fast tokenizer for any other exception, but maybe that's a bad idea if we know fast tokenizers are generally unreliable...

Best case for me would be to catch only a specific exception class, e.g. NotImplementedError, if that's what's thrown when there is no non-fast implementation. Unbounded except statements that silently fall back are scary (hence me adding a warning...).

Collaborator Author:
Now I dislike the RuntimeError. We should try to be as consistent as possible in this repo if we are raising for "you are missing a dependency" reasons.

Collaborator:
Agreed on dropping the RuntimeError.

except Exception as e:
# Fall back for other exceptions
warnings.warn(f"Falling back to fast tokenizer. Could not load tokenizer for model {model} due to exception {e.__class__.__name__}: {e}")
Collaborator Author:
Added a warning. Completely unsure whether this branch is ever taken now...

Collaborator:
Yeah, good question. I agree -- it feels like we shouldn't hit this fallback anymore, particularly on a pathway dependent on import issues.

Collaborator Author:
As Richard said, this is sadly just really hard to test. At least without making the community do it for us...

Collaborator:
... or blowing up our test matrix even more.

tokenizer = transformers_package.AutoTokenizer.from_pretrained(
model, use_fast=True, **kwargs
) # fall back to the fast tokenizer

return tokenizer
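
A caller-side sketch of the semantics discussed in the thread above (illustrative only: the model id is a hypothetical example, and importing from the private `_transformers` module is an assumption about guidance's layout; at this commit a missing dependency is wrapped in a RuntimeError, while later commits in the PR re-raise the ImportError directly):

```python
# Illustrative sketch, not part of this diff: how the new failure modes surface
# when the tokenizer is built from a model name string.
from guidance.models.transformers._transformers import TransformersTokenizer

try:
    # Hypothetical sentencepiece-based checkpoint; any such model behaves similarly.
    tok = TransformersTokenizer("microsoft/Phi-3-mini-4k-instruct", None)
except (ImportError, RuntimeError) as e:
    # A missing dependency (sentencepiece, protobuf, ...) now fails loudly here
    # instead of silently falling back to the fast tokenizer.
    print(f"Install the missing tokenizer dependency: {e}")
# Any other slow-tokenizer failure emits a warning and retries with use_fast=True.
```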

def _byte_tokens(
self,
transformers_tokenizer: Union[
"transformers_package.PreTrainedTokenizer",
"transformers_package.PreTrainedTokenizerFast",
],
) -> list[bytes]:

if (
hasattr(transformers_tokenizer, "byte_decoder")
and self._byte_decoder_has_all_bytes(
transformers_tokenizer.byte_decoder,
transformers_tokenizer.get_vocab()
)
):
return self._byte_tokens_from_byte_decoder(transformers_tokenizer)

if hasattr(transformers_tokenizer, "sp_model"):
return self._byte_tokens_from_sp_model(transformers_tokenizer)

try:
return self._byte_tokens_from_vocab(transformers_tokenizer)
except ValueError:
pass
Collaborator:
Again, should we issue a warning?

Collaborator Author:
Added, ditto on the language.


return self._byte_tokens_fallback(transformers_tokenizer)

def _byte_tokens_from_byte_decoder(
self,
transformers_tokenizer: Union[
"transformers_package.PreTrainedTokenizer",
"transformers_package.PreTrainedTokenizerFast",
],
) -> list[bytes]:
byte_tokens = [b""] * len(transformers_tokenizer)
byte_decoder: dict[str, int] = transformers_tokenizer.byte_decoder
for i in range(len(transformers_tokenizer)):
byte_coded = bytes(
[byte_decoder[c] for c in transformers_tokenizer.convert_ids_to_tokens(i)]
)
byte_tokens[i] = byte_coded
return byte_tokens

def _byte_tokens_from_sp_model(
self,
transformers_tokenizer: Union[
"transformers_package.PreTrainedTokenizer",
"transformers_package.PreTrainedTokenizerFast",
],
) -> list[bytes]:
byte_tokens = [b""] * len(transformers_tokenizer)
special_tokens_map = {
id: token for token, id in transformers_tokenizer.get_added_vocab().items()
}
space_prefix = "▁".encode()
for i in range(len(transformers_tokenizer)):
if i in special_tokens_map:
byte_coded = special_tokens_map[i].encode()
else:
byte_coded = re.sub(
rb"<0x(..)>",
lambda x: bytes.fromhex(x[1].decode()),
transformers_tokenizer.sp_model.id_to_piece(i).encode(),
)
byte_tokens[i] = byte_coded.replace(space_prefix, b" ")
return byte_tokens
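
A standalone sketch of the piece-to-bytes conversion used above, run on made-up sentencepiece pieces rather than a real `sp_model` (illustrative only):

```python
import re

def piece_to_bytes(piece: str) -> bytes:
    # Hex-escaped raw-byte pieces like "<0x0A>" become single bytes...
    byte_coded = re.sub(
        rb"<0x(..)>",
        lambda m: bytes.fromhex(m[1].decode()),
        piece.encode(),
    )
    # ...and the sentencepiece word-boundary marker "▁" becomes a plain space.
    return byte_coded.replace("▁".encode(), b" ")

assert piece_to_bytes("▁Hello") == b" Hello"
assert piece_to_bytes("<0x0A>") == b"\n"
```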

def _byte_tokens_from_vocab(
Collaborator Author:
This path used to call transformers_tokenizer.get_vocab and then do nothing with it. It was also only happening if the tokenizer actually had that method. As far as I can tell, it always has that method (again, could be a version thing?).

Contributor:
I think I originally used that method to get the tokenizer key/value pairs before replacing it with just iterating through a range the same length as the tokenizer. I didn't realize we weren't using the vocabulary at all anymore, so removing it makes sense if the functionality stays the same.

Collaborator Author:
Got it, thanks!

self,
transformers_tokenizer: Union[
"transformers_package.PreTrainedTokenizer",
"transformers_package.PreTrainedTokenizerFast",
],
) -> list[bytes]:
byte_tokens = [b""] * len(transformers_tokenizer)
special_tokens_map = {
id: token for token, id in transformers_tokenizer.get_added_vocab().items()
}
byte_encoder = self._bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}

for i in range(len(transformers_tokenizer)):
if i in special_tokens_map:
byte_coded = special_tokens_map[i].encode()
else:
token = transformers_tokenizer.convert_ids_to_tokens(i)
if isinstance(token, bytes):
byte_coded = token
elif isinstance(token, str):
if hasattr(transformers_tokenizer, "convert_tokens_to_string"):
token_str = transformers_tokenizer.convert_tokens_to_string([token])
encoded_str = transformers_tokenizer.encode(token_str)
if len(encoded_str) != 1:
raise ValueError(f"Round-trip encoding of tokens [{token}] failed! Got {encoded_str}")
roundtrip_id = encoded_str[0]
Collaborator Author (on lines +231 to +234):
FYI, here's where that IndexError was happening before.

if roundtrip_id == i:
byte_coded = token_str.encode()
else:
byte_coded = bytes([byte_decoder[c] for c in token])
else:
byte_coded = token.encode()
else:
raise ValueError(f"Unexpected token type: {type(token)}")
byte_tokens[i] = byte_coded
return byte_tokens
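
For context on the comment above: the old code indexed `encode(token_str)[0]` directly, which raised an opaque IndexError whenever the round trip produced no ids; the new length check turns both the empty and the multi-id case into an explicit ValueError. A minimal contrast (hypothetical encode results, no real tokenizer involved):

```python
def old_roundtrip_id(encoded: list[int]) -> int:
    return encoded[0]  # IndexError if encode() returned []

def new_roundtrip_id(token: str, encoded: list[int]) -> int:
    if len(encoded) != 1:  # empty or multi-id round trips now fail with a clear message
        raise ValueError(f"Round-trip encoding of tokens [{token}] failed! Got {encoded}")
    return encoded[0]
```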

def _byte_tokens_fallback(
self,
transformers_tokenizer: Union[
"transformers_package.PreTrainedTokenizer",
"transformers_package.PreTrainedTokenizerFast",
],
) -> list[bytes]:
byte_tokens = [b""] * len(transformers_tokenizer)
byte_decoder: dict[str, int] = transformers_package.AutoTokenizer.from_pretrained(
"gpt2", use_fast=False
).byte_decoder # fall back to gpt2 mapping

# some special tokens may not have their whitespace encoded...
byte_decoder[" "] = 32
byte_decoder["\n"] = 10
byte_decoder["\r"] = 13
byte_decoder["\t"] = 9
byte_decoder["▁"] = 32

# run a quick spot check to verify we can rebuild complex multi-token unicode symbols
s = "’•¶∂ƒ˙∆£Ħ爨ൠᅘ∰፨"
reconstructed = b""
try:
input_ids = transformers_tokenizer(s)["input_ids"]
for i in input_ids:
nxt_bytes = []
token_str = transformers_tokenizer.convert_ids_to_tokens(i)
for c in token_str:
nxt_bytes.append(byte_decoder[c])
reconstructed += bytes(nxt_bytes)
# Check if the tokenizer has a bos_token attribute, and if it does, check
# if it's at the start of the reconstructed bytes
# Some tokenizers add this automatically as part of the call function, so
# we need to remove it to compare
if hasattr(transformers_tokenizer, "bos_token") and reconstructed.startswith(
transformers_tokenizer.bos_token.encode()
):
reconstructed = reconstructed[len(transformers_tokenizer.bos_token) :]
Harsha-Nori (Collaborator), Aug 1, 2024:
IIRC, these reconstitution checks should happen more generally than just in the fast-fallback path. Very possible that guidance wasn't doing this before, but I think it should happen after all the paths are exhausted and a single method is picked.

Collaborator Author:
I don't believe it was doing this in any case but the fallback, though I could be wrong. But agreed with you that it's reasonable to do always... at least to give a warning, if not an exception.

except Exception as e:
msg = textwrap.dedent(
f"""
The tokenizer being used is unable to convert a special character in {s}.
For models with sentencepiece based tokenizers (e.g. llama, phi-3-mini),
installing sentencepiece often fixes this issue (pip install sentencepiece).
"""
)
raise ValueError(msg) from e
Collaborator Author:
I'm unsure whether this method will ever be called now that we "fail early" with respect to the fast tokenizer. If we hit this exception, the message may be wrong now, as a missing import probably isn't the cause of the exception anymore...

assert (
reconstructed.decode() == s
), "The passed tokenizer does not have a byte_decoder property and using a standard gpt2 byte_decoder fails!"

for i in range(len(transformers_tokenizer)):
byte_coded = bytes(
[byte_decoder[c] for c in transformers_tokenizer.convert_ids_to_tokens(i)]
)
byte_tokens[i] = byte_coded
return byte_tokens
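
One hypothetical shape for the shared reconstitution check suggested in the thread above: run it once after whichever byte_tokens path was selected, and warn rather than raise on mismatch. The name, signature, and warn-vs-raise choice are assumptions, not part of this diff:

```python
def _check_reconstruction(self, transformers_tokenizer, byte_tokens: list[bytes]) -> None:
    s = "’•¶∂ƒ˙∆£Ħ爨ൠᅘ∰፨"  # same spot-check string as the fallback path above
    # add_special_tokens=False avoids having to strip a bos_token before comparing
    input_ids = transformers_tokenizer(s, add_special_tokens=False)["input_ids"]
    reconstructed = b"".join(byte_tokens[i] for i in input_ids)
    if reconstructed.decode(errors="replace") != s:
        warnings.warn(
            f"byte_tokens could not round-trip {s!r}; generated text may be mangled."
        )
```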

def _byte_decoder_has_all_bytes(self, byte_decoder: dict[str, int], vocab: dict[str, int]) -> bool:
# This is here because some tokenizers are bad and don't have all the bytes (I'm looking at you, microsoft/phi2)
all_bytes = set()
for x in vocab.keys():
for y in x:
all_bytes.add(y)
return set(byte_decoder.keys()) >= all_bytes
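
A toy illustration (made-up data) of what this guards against: a vocab whose token strings contain a character the byte_decoder cannot map back to a byte, as on e.g. microsoft/phi-2:

```python
byte_decoder = {"Ġ": 32, "a": 97}
vocab = {"Ġa": 0, "ā": 1}  # "ā" never appears in byte_decoder
# set(byte_decoder) >= {"Ġ", "a", "ā"} is False, so _byte_tokens skips the
# byte_decoder branch and tries the sp_model / vocab / gpt2-fallback paths instead.
```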

def _bytes_to_unicode(self):
bs = (
list(range(ord("!"), ord("~") + 1))
@@ -190,38 +291,6 @@ def _bytes_to_unicode(self):
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
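
The body of `_bytes_to_unicode` is collapsed in this diff view; for reference, the widely used GPT-2 byte-to-unicode mapping it follows looks roughly like this (a standard reference sketch, not necessarily byte-for-byte identical to the collapsed code):

```python
def bytes_to_unicode() -> dict[int, str]:
    # Printable latin-1 ranges map to themselves...
    bs = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    # ...and every remaining byte is shifted to an unused code point above 255.
    for b in range(256):
        if b not in bs:
            bs.append(b)
            cs.append(256 + n)
            n += 1
    return dict(zip(bs, [chr(c) for c in cs]))
```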

def _tokenizer(self, model, **kwargs) -> Union[
"transformers_package.PreTrainedTokenizer",
"transformers_package.PreTrainedTokenizerFast",
]:
# instantiate the tokenizer
if isinstance(model, str):
# make sure transformers is installed
if not has_transformers:
raise Exception("Please install transformers with `pip install transformers`")

try:
tokenizer = transformers_package.AutoTokenizer.from_pretrained(
model, use_fast=False, **kwargs
)
# This is here because some tokenizers are bad and don't have all the bytes (I'm looking at you, microsoft/phi2)
if hasattr(tokenizer, "byte_decoder"):
all_bytes = set()
for x in tokenizer.get_vocab().keys():
for y in x:
all_bytes.add(y)
assert set(tokenizer.byte_decoder.keys()).intersection(all_bytes) == all_bytes
except:
tokenizer = transformers_package.AutoTokenizer.from_pretrained(
model, use_fast=True, **kwargs
) # fall back to the fast tokenizer

assert (
tokenizer is not None
), "You must give a model name when you provide a tokenizer object!"

return tokenizer

def encode(self, byte_string: bytes) -> Sequence[int]:
assert isinstance(byte_string, bytes)
# HF tokenizers take in strings apparently