-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support chatglm fine grained dybatch v1. #6798
Support chatglm fine grained dybatch v1. #6798
Conversation
Codecov Report
@@ Coverage Diff @@
## develop #6798 +/- ##
===========================================
- Coverage 60.30% 60.06% -0.24%
===========================================
Files 544 546 +2
Lines 80364 80680 +316
===========================================
Hits 48460 48460
- Misses 31904 32220 +316
|
@@ -181,7 +209,7 @@ def update_model_kwargs_for_generation(cache, just_decoder, next_tokens, eos_tok | |||
model_kwargs["seq_len_decoder"], | |||
model_kwargs["seq_len_decoder"] + 1, | |||
) | |||
return model_kwargs | |||
return model_kwargs, next_tokens |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是否需要返回next_tokens?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的 set_multi_stops 应该可以放到 sample 函数里面来处理吧,所以就可以不在这里返回 next_tokens。
尽量和paddlenlp 现有的函数输入和输出保持一致。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done~
self.tgt_generation_mask[i, 0, 0, :length] = paddle.ones(shape=[1, length], dtype="float16") | ||
|
||
inputs["attention_mask"] = self.attention_mask | ||
inputs["tgt_generation_mask"] = self.tgt_generation_mask |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
商量了下,先暂时这个样子,等后面针对于 chatglm 的tokenizer 再调整一下,此时这里的分支代码就可以删掉了。
|
||
config.tensor_parallel_degree = tensor_parallel_degree | ||
config.tensor_parallel_rank = tensor_parallel_rank | ||
model = LlamaForCausalLMInferenceModel.from_pretrained( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里看下能否走AutoModelForCausalLM那种方式,内部根据config去分发
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
目前还不能通过 AutoModelForCausalLM 来分发,所以初始化目前只能够 hardcode。
不过判断是哪种模型可以通过上面的 config.architectures 来判断。
llm/predictor.py
Outdated
"你好", | ||
"你好啊,请问你叫什么名字", | ||
"你好啊,你在干什么", | ||
# "My name is?" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
注释可以删一下
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
@@ -158,19 +163,42 @@ def update_model_kwargs_for_generation(cache, just_decoder, next_tokens, eos_tok | |||
if cache is None: | |||
# encoder's generation | |||
model_kwargs["tgt_ids"] = paddle.where(just_decoder, model_kwargs["tgt_ids"], next_tokens) | |||
model_kwargs["tgt_pos"] = paddle.where(just_decoder, model_kwargs["tgt_pos"], model_kwargs["tgt_pos"] + 1) | |||
# import pdb;pdb.set_trace() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的 pdb 代码应该删掉
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
@@ -181,7 +209,7 @@ def update_model_kwargs_for_generation(cache, just_decoder, next_tokens, eos_tok | |||
model_kwargs["seq_len_decoder"], | |||
model_kwargs["seq_len_decoder"] + 1, | |||
) | |||
return model_kwargs | |||
return model_kwargs, next_tokens |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的 set_multi_stops 应该可以放到 sample 函数里面来处理吧,所以就可以不在这里返回 next_tokens。
尽量和paddlenlp 现有的函数输入和输出保持一致。
|
||
config.tensor_parallel_degree = tensor_parallel_degree | ||
config.tensor_parallel_rank = tensor_parallel_rank | ||
model = LlamaForCausalLMInferenceModel.from_pretrained( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
目前还不能通过 AutoModelForCausalLM 来分发,所以初始化目前只能够 hardcode。
不过判断是哪种模型可以通过上面的 config.architectures 来判断。
zhengzekang seems not to be a GitHub user. You need a GitHub account to be able to sign the CLA. If you have already a GitHub account, please add the email address used for this commit to your account. You have signed the CLA already but the status is still pending? Let us recheck it. |
self.tgt_generation_mask[i, 0, 0, :length] = paddle.ones(shape=[1, length], dtype="float16") | ||
|
||
inputs["attention_mask"] = self.attention_mask | ||
inputs["tgt_generation_mask"] = self.tgt_generation_mask |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
商量了下,先暂时这个样子,等后面针对于 chatglm 的tokenizer 再调整一下,此时这里的分支代码就可以删掉了。
@@ -12,5 +12,6 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
|
|||
from .chatglm import * |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个的 import 应该是要放到 from .fused_transformer_layers import *
下面的。
@@ -134,7 +139,7 @@ def generate( | |||
return ret | |||
|
|||
@staticmethod | |||
def update_model_kwargs_for_generation(cache, just_decoder, next_tokens, eos_token_id, model_kwargs): | |||
def update_model_kwargs_for_generation(cache, just_decoder, next_tokens, eos_token_id, config, model_kwargs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
通过和 @xiaoxiaohehe001 商量之后,为了给模型足够的控制范围,决定将参数转化为实例函数(去掉 @staticmethod), 这样派生模型就可以通过 self.config 来获取到对应的配置,同时也可以重写对应函数。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
New features
PR changes
Models
Description
Support chatglm fine grained dybatch v1.