-
-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* support Qwen2 * support Qwen2 * Delete README.md * Revert "Delete README.md" This reverts commit 026b05f. * Update README.md * Qwen2 == Mistral * Update llama.py * Update __init__.py * Update README.md --------- Co-authored-by: Daniel Han <danielhanchen@gmail.com>
- Loading branch information
1 parent
7c53652
commit cf83fe3
Showing
6 changed files
with
101 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from .llama import * | ||
from .mistral import FastMistralModel | ||
import os | ||
from ._utils import __version__ | ||
|
||
from transformers.models.qwen2.modeling_qwen2 import ( | ||
Qwen2Attention, | ||
Qwen2DecoderLayer, | ||
Qwen2Model, | ||
Qwen2ForCausalLM, | ||
) | ||
# For Pytorch 2.1.1 | ||
try: | ||
from transformers.models.qwen2.modeling_qwen2 import ( | ||
Qwen2SdpaAttention, | ||
Qwen2FlashAttention2, | ||
) | ||
except: | ||
Qwen2SdpaAttention = Qwen2Attention | ||
Qwen2FlashAttention2 = Qwen2Attention | ||
pass | ||
|
||
|
||
class FastQwen2Model(FastLlamaModel): | ||
|
||
@staticmethod | ||
def pre_patch(): | ||
Qwen2Attention .forward = LlamaAttention_fast_forward | ||
Qwen2SdpaAttention .forward = LlamaAttention_fast_forward | ||
Qwen2FlashAttention2.forward = LlamaAttention_fast_forward | ||
Qwen2DecoderLayer .forward = LlamaDecoderLayer_fast_forward | ||
Qwen2Model .forward = LlamaModel_fast_forward | ||
Qwen2ForCausalLM .forward = CausalLM_fast_forward(LlamaModel_fast_forward_inference) | ||
PeftModelForCausalLM.forward = PeftModelForCausalLM_fast_forward | ||
|
||
# Solves https://github.com/unslothai/unsloth/issues/168 | ||
# Static KV Cache was introduced in 4.38.0, causing training to be much slower. | ||
# Inferene can now be CUDAGraphed, but we shall retain the old rotary embeddings. | ||
# https://github.com/huggingface/transformers/pull/27931 | ||
# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py | ||
import transformers.models.qwen2.modeling_qwen2 | ||
transformers.models.qwen2.modeling_qwen2.Qwen2RotaryEmbedding = LlamaRotaryEmbedding | ||
return | ||
pass | ||
|
||
|
||
@staticmethod | ||
def from_pretrained( | ||
model_name = "Qwen/Qwen1.5-7B", | ||
max_seq_length = 4096, | ||
dtype = None, | ||
load_in_4bit = True, | ||
token = None, | ||
device_map = "sequential", | ||
rope_scaling = None, # Qwen2 does not support RoPE scaling | ||
fix_tokenizer = True, | ||
model_patcher = None, | ||
tokenizer_name = None, | ||
trust_remote_code = False, | ||
**kwargs, | ||
): | ||
return FastMistralModel.from_pretrained( | ||
model_name = model_name, | ||
max_seq_length = max_seq_length, | ||
dtype = dtype, | ||
load_in_4bit = load_in_4bit, | ||
token = token, | ||
device_map = device_map, | ||
rope_scaling = rope_scaling, | ||
fix_tokenizer = fix_tokenizer, | ||
model_patcher = FastQwen2Model, | ||
tokenizer_name = tokenizer_name, | ||
trust_remote_code = trust_remote_code, | ||
**kwargs, | ||
) | ||
pass | ||
pass |