Nightly (#461)
* Fix prompt

* Update chat_templates.py

* fix_untrained_tokens

* Update llama.py

* add tokens

* Update _utils.py

* Update tokenizer_utils.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* pad_token

* Update chat_templates.py

* Update chat_templates.py

* tokenizer

* Update save.py

* Update chat_templates.py

* Update chat_templates.py

* patch tokenizer padding

* Update tokenizer_utils.py

* Update save.py

* Fix: loading models with resized vocabulary (#377)

* new: vocab resize on load

* new: gitignore

* GGUF fix

* Readme (#390)

* Update README.md

* Update README.md

---------

Co-authored-by: Michael Han <107991372+shimmyshimmer@users.noreply.github.com>

* Update README.md

* Delete .gitignore

* Phi-3

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Fix reserved tokens

* Update save.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update chat_templates.py

* Update save.py

* Update _utils.py

* Update chat_templates.py

* Adds dependencies and extras for torch 2.3.0 with new xformers versions (#415)

* Adds dependencies and extras for torch 2.3.0 with new xformers versions

* Add 2.3.0 section to readme

* Support Qwen2 (#428)

* support Qwen2

* support Qwen2

* Delete README.md

* Revert "Delete README.md"

This reverts commit 026b05f.

* Update README.md

* Qwen2 == Mistral

* Update llama.py

* Update __init__.py

* Update README.md

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>

* Update save.py

* Update save.py

* Update _utils.py

* Update save.py

* Update save.py

* Update save.py

* test_hf_gguf_equivalence

* Update chat_templates.py

* Update chat_templates.py

* --pad-vocab

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Unspecified max_seq_length

* possible_pad_token

* Update tokenizer_utils.py

---------

Co-authored-by: Igor Kilbas <whitemarsstudios@gmail.com>
Co-authored-by: Michael Han <107991372+shimmyshimmer@users.noreply.github.com>
Co-authored-by: Nathan Azrak <42650258+nathan-az@users.noreply.github.com>
Co-authored-by: Yang JianXin <995462226@qq.com>
5 people authored May 13, 2024
1 parent d4512f7 commit 47ffd39
Showing 5 changed files with 35 additions and 11 deletions.
12 changes: 7 additions & 5 deletions unsloth/models/_utils.py
@@ -180,12 +180,14 @@ def patch_tokenizer(model, tokenizer):
         # Try unk_token
         possible_pad_token = tokenizer.unk_token
     pass

     if possible_pad_token is None:
-        # Failure!!
-        raise RuntimeError(
-            "Unsloth: Tokenizer's pad_token cannot be = eos_token, and we couldn't find a\n"\
-            "replacement of either <|reserved... or <|placeholder..."
-        )
+        # Failure to find a good replacement!! We shall manually add one!
+        new_pad_token = "<|PAD_TOKEN|>"
+        while new_pad_token in tokenizer.get_vocab():
+            new_pad_token += "#"
+        pass
+        possible_pad_token = new_pad_token
     pass

     name = model.config._name_or_path if model is not None else "Model"
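The fallback above guarantees `patch_tokenizer` always ends up with a usable pad token: if neither a reserved/placeholder token nor `unk_token` is available, it mints a fresh `<|PAD_TOKEN|>`, appending `#` until the string no longer collides with the vocabulary. A minimal standalone sketch of the same idea (the checkpoint name and the final resize call are illustrative, not part of this commit):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")  # illustrative

if tokenizer.pad_token is None or tokenizer.pad_token == tokenizer.eos_token:
    # Prefer unk_token as the pad token if the tokenizer has one.
    possible_pad_token = tokenizer.unk_token
    if possible_pad_token is None:
        # Mint a brand-new token, extending it until it cannot
        # collide with anything already in the vocabulary.
        new_pad_token = "<|PAD_TOKEN|>"
        while new_pad_token in tokenizer.get_vocab():
            new_pad_token += "#"
        possible_pad_token = new_pad_token
    tokenizer.add_special_tokens({"pad_token": possible_pad_token})
    # If the token was truly new, the model's embeddings must grow too:
    # model.resize_token_embeddings(len(tokenizer))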
7 changes: 6 additions & 1 deletion unsloth/models/llama.py
@@ -1005,7 +1005,7 @@ def pre_patch():
     @staticmethod
     def from_pretrained(
         model_name = "unsloth/llama-2-7b-bnb-4bit",
-        max_seq_length = 4096,
+        max_seq_length = None,
         dtype = None,
         load_in_4bit = True,
         token = None,
@@ -1050,6 +1050,11 @@ def from_pretrained(
         model_max_seq_length = \
             AutoConfig.from_pretrained(model_name, token = token).max_position_embeddings

+        # If max_seq_length is not specified, use maximum from config
+        if max_seq_length is None:
+            max_seq_length = model_max_seq_length
+        pass
+
         if (rope_scaling is None) and (max_seq_length > model_max_seq_length):
             rope_scaling = max_seq_length / model_max_seq_length
             logger.warning_once(
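The two hunks above work together: `max_seq_length = None` now means "use the model's native context window", and linear RoPE scaling is only computed when the requested length exceeds it. A sketch of the combined logic, assuming only what the diff shows (the helper name is illustrative):

from transformers import AutoConfig

def resolve_max_seq_length(model_name, max_seq_length = None, rope_scaling = None, token = None):
    # Read the model's native context window from its config.
    model_max_seq_length = AutoConfig.from_pretrained(
        model_name, token = token,
    ).max_position_embeddings

    # If max_seq_length is not specified, use the maximum from the config.
    if max_seq_length is None:
        max_seq_length = model_max_seq_length

    # Otherwise fall back to linear RoPE scaling for longer contexts.
    if (rope_scaling is None) and (max_seq_length > model_max_seq_length):
        rope_scaling = max_seq_length / model_max_seq_length
    return max_seq_length, rope_scaling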
4 changes: 2 additions & 2 deletions unsloth/models/loader.py
@@ -67,8 +67,8 @@ def _get_model_name(model_name, load_in_4bit = True):
 class FastLanguageModel(FastLlamaModel):
     @staticmethod
     def from_pretrained(
-        model_name = "unsloth/mistral-7b-bnb-4bit",
-        max_seq_length = 4096,
+        model_name = "unsloth/llama-3-8b-bnb-4bit",
+        max_seq_length = None,
         dtype = None,
         load_in_4bit = True,
         token = None,
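With the new defaults, callers can omit `max_seq_length` entirely. A hedged usage sketch (4-bit loading as in the default signature):

from unsloth import FastLanguageModel

# max_seq_length may now be omitted: it defaults to the model's
# max_position_embeddings instead of a hard-coded 4096.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name   = "unsloth/llama-3-8b-bnb-4bit",
    load_in_4bit = True,
)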
7 changes: 6 additions & 1 deletion unsloth/models/mistral.py
@@ -290,7 +290,7 @@ def pre_patch():
     @staticmethod
     def from_pretrained(
         model_name = "unsloth/mistral-7b-bnb-4bit",
-        max_seq_length = 4096,
+        max_seq_length = None,
         dtype = None,
         load_in_4bit = True,
         token = None,
@@ -340,6 +340,11 @@ def from_pretrained(
         model_config = AutoConfig.from_pretrained(model_name, token = token)
         model_max_seq_length = model_config.max_position_embeddings

+        # If max_seq_length is not specified, use maximum from config
+        if max_seq_length is None:
+            max_seq_length = model_max_seq_length
+        pass
+
         # Mistral does NOT support RoPE Scaling sadly so we have to error out.
         if max_seq_length > model_max_seq_length:
             raise RuntimeError(
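Mistral gets the same `None` default, but since it has no RoPE-scaling path, requests beyond the native window must fail fast. A sketch of the guard (function name and error wording are approximated, not copied from the source):

def check_mistral_max_seq_length(max_seq_length, model_max_seq_length, model_name):
    # If max_seq_length is not specified, use the maximum from the config.
    if max_seq_length is None:
        max_seq_length = model_max_seq_length
    # Mistral does NOT support RoPE scaling, so error out instead of scaling.
    if max_seq_length > model_max_seq_length:
        raise RuntimeError(
            f"Unsloth: {model_name} supports at most {model_max_seq_length} "
            f"tokens, but max_seq_length = {max_seq_length} was requested."
        )
    return max_seq_length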
16 changes: 14 additions & 2 deletions unsloth/tokenizer_utils.py
@@ -304,6 +304,7 @@ class SentencePieceTokenTypes(IntEnum):
     if len(added_tokens_json) == 0: return

     added_tokens_json = dict(sorted(added_tokens_json.items(), key = lambda item: item[1]))
+    new_size = sentence_piece_size + len(added_tokens_json)

     # Confirm added_tokens_json is correct
     added_tokens_ids = np.array(list(added_tokens_json.values()))
@@ -312,7 +313,11 @@
     if (added_tokens_ids.min() != sentence_piece_size): return

     # Edit sentence piece tokens with added_tokens_json
-    logger.warning("Unsloth: Extending tokenizer.model with added_tokens.json!")
+    logger.warning(
+        f"Unsloth: Extending {saved_location}/tokenizer.model with added_tokens.json.\n"\
+        f"Originally tokenizer.model is of size ({sentence_piece_size}).\n"\
+        f"But we need to extend to sentencepiece vocab size ({new_size})."
+    )
     new_tokens = deepcopy(tokenizer_file.pieces[-len(added_tokens_ids):])
     for new_token, added_token in zip(new_tokens, added_tokens_json.keys()):
         new_token.piece = added_token.encode("utf-8")
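The hunk above clones trailing `pieces` entries of the sentencepiece `ModelProto` and renames them to the entries of `added_tokens.json`, so GGUF conversion sees a vocabulary consistent with the resized model. A simplified standalone sketch (file paths are illustrative; the real code also validates that the added token ids sit exactly at the end of the vocab):

import json
from copy import deepcopy
from sentencepiece import sentencepiece_model_pb2

# Load the raw sentencepiece model and the extra tokens HF saved separately.
tokenizer_file = sentencepiece_model_pb2.ModelProto()
with open("tokenizer.model", "rb") as f:
    tokenizer_file.ParseFromString(f.read())
with open("added_tokens.json") as f:
    added_tokens_json = json.load(f)

# Append one piece per added token, in token-id order, reusing an
# existing piece as a template so its score/type fields stay valid.
for token, _ in sorted(added_tokens_json.items(), key = lambda item: item[1]):
    new_piece = deepcopy(tokenizer_file.pieces[-1])
    new_piece.piece = token.encode("utf-8")
    tokenizer_file.pieces.append(new_piece)

with open("tokenizer.model", "wb") as f:
    f.write(tokenizer_file.SerializeToString())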
@@ -357,7 +362,10 @@ def load_correct_tokenizer(
             padding_side = padding_side,
             token = token,
             trust_remote_code = trust_remote_code,
+            # Cannot just use use_fast = False as per https://twitter.com/danielhanchen/status/1789659394302718373
             use_fast = False,
+            legacy = False,
+            from_slow = True,
             cache_dir = cache_dir,
         )
     except:
@@ -512,7 +520,10 @@ def check_tokenizer(
             model_max_length = model_max_length,
             padding_side = padding_side,
             token = token,
+            # Cannot just use use_fast = False as per https://twitter.com/danielhanchen/status/1789659394302718373
             use_fast = False,
+            legacy = False,
+            from_slow = True,
             cache_dir = cache_dir,
         )
         return check_tokenizer(
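Both call sites above pass the same trio of flags: per the linked tweet, `use_fast = False` alone is not enough to sidestep the tokenizer bug, so the commit also forces `legacy = False` and `from_slow = True`. A hedged sketch of the equivalent standalone call (checkpoint name illustrative):

from transformers import AutoTokenizer

# Cannot just use use_fast = False; legacy = False and from_slow = True
# are also needed to select the corrected conversion path.
tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mistral-7B-v0.1",  # illustrative checkpoint
    use_fast  = False,
    legacy    = False,
    from_slow = True,
)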
@@ -725,7 +736,8 @@ def fix_sft_trainer_tokenizer():
         "test_text = dataset[0][dataset_text_field] if (formatting_func is None or not use_formatting_func) else formatting_func(dataset[0])\n"\
         "chat_template = getattr(tokenizer, 'chat_template', None)\n"\
         "chat_template = '' if chat_template is None else chat_template\n"\
-        "has_bos_token_already = test_text.startswith(tokenizer.bos_token) or tokenizer.bos_token in chat_template\n"\
+        "has_bos_token_already = (test_text.startswith(tokenizer.bos_token) or tokenizer.bos_token in chat_template) "\
+        "if getattr(tokenizer, 'bos_token', None) is not None else False\n"\
         "add_special_tokens = False if has_bos_token_already else add_special_tokens\n\n"

     check_text = check_text.split("\n")
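The final hunk patches the source string injected into SFTTrainer so that tokenizers without a BOS token (such as Qwen2's) no longer crash on `tokenizer.bos_token` being `None`. The generated logic is equivalent to this sketch (function name and signature are illustrative):

def resolve_add_special_tokens(test_text, tokenizer, add_special_tokens = True):
    chat_template = getattr(tokenizer, "chat_template", None) or ""
    bos_token = getattr(tokenizer, "bos_token", None)
    # Only check for a duplicated BOS token if the tokenizer has one at all.
    has_bos_token_already = (
        (test_text.startswith(bos_token) or bos_token in chat_template)
        if bos_token is not None else False
    )
    # Skip adding special tokens when BOS would otherwise be duplicated.
    return False if has_bos_token_already else add_special_tokens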
