[INFER] llama&qwen2 A8W8 support skip_scale #8987
Conversation
Thanks for your contribution!
@@ -700,6 +623,97 @@ def __init__(self, config: FusedMultiTransformerConfig):

         self.linear = fused_linear

+    def init_weight(self):
init_weight -> init_weights ?
I don't think it's really necessary; init_weight_shape and get_weight_dtype are also singular.
@@ -1773,7 +1787,101 @@ def __init__(self, config: FusedMultiTransformerConfig):
         self._add_parameter(ffn2_shift)
         self._add_parameter(ffn2_smooth)

-    def get_weight_create_dype(self):
+    def init_weight(self):
init_weight -> init_weights
+    def get_weight_create_dype(self, layer_name=None, layer_idx=None):
+        if layer_name is not None and layer_idx is not None:
+            if hasattr(self, "weight_scales") and np.all(self.weight_scales[layer_name][layer_idx] == -1):
I've seen this same check in quite a few places. I suggest wrapping
if hasattr(self, "weight_scales") and np.all(self.weight_scales[layer_name][layer_idx] == -1)
into a function, and adding a comment.
Done, wrapped it into skip_quant.
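For reference, a minimal sketch of what that helper might look like (the name skip_quant comes from the reply above; the docstring and exact placement on the class are assumptions):

```python
import numpy as np

def skip_quant(self, layer_name, layer_idx):
    """Return True if quantization was skipped for this layer.

    Convention in this PR: a weight scale of all -1 marks a layer whose
    quantization is skipped, so its weights stay in the default
    floating-point dtype instead of int8.
    """
    return hasattr(self, "weight_scales") and np.all(
        self.weight_scales[layer_name][layer_idx] == -1
    )
```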
@@ -326,6 +335,10 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
         )
         logits = logits / temperature

+        # sample
+        if self.config.top_k is not None and self.config.top_k != 0:
Please remove the top_k logic. It was a business-specific requirement; generic post-processing graphs don't do this.
Removed.
@@ -935,10 +985,16 @@ def set_state_dict(self, state_dict):
                 self.transformer_block.linear_weights[idx].set_value(linear_quanted_weight_tensor)
                 self.transformer_block.linear_weights_scale[idx].set_value(linear_weight_scale_tensor)
             elif "a8w8" in self.quant_type:
+                w_dtype = (
+                    paddle.get_default_dtype()
+                    if np.all(weight_scales_loader.scale["out_linear_weight_scale"][idx] == -1)
Same as the others: wrap checks like if np.all(weight_scales_loader.scale["out_linear_weight_scale"][idx] == -1) into a function and add a comment.
Done
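A sketch of how that check could be factored out of set_state_dict (the helper name resolve_out_linear_dtype is hypothetical, and the "int8" fallback is an assumption based on the a8w8 path):

```python
import numpy as np
import paddle

def resolve_out_linear_dtype(weight_scales_loader, idx):
    # A scale of all -1 marks a layer whose quantization was skipped:
    # keep the default floating-point dtype instead of the int8 dtype
    # used by the quantized a8w8 weights.
    if np.all(weight_scales_loader.scale["out_linear_weight_scale"][idx] == -1):
        return paddle.get_default_dtype()
    return "int8"
```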
             num_key_value_heads=self.num_key_value_heads,
             mp_size=self.config.tensor_parallel_degree,
         )
-        self.transformer_block.act_scales = act_scale_loader.scale
This line can't be deleted; otherwise inference won't be able to find act_scales later.
Fixed.
What about this same line in the llama file?
Merged with #9197
PR types
Bug fixes
PR changes
Models
Description
Support skip-layer quantization for llama and qwen2 W8A8.
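To illustrate the skip convention described above (a hedged sketch, not code from this PR; the -1 sentinel follows the review thread, the data layout is assumed):

```python
import numpy as np

# Per-layer weight scales; a scale of all -1 marks a layer that skips quantization.
weight_scales = {
    "out_linear_weight_scale": [
        np.array([0.02, 0.03]),  # layer 0: quantized (W8A8, int8 weights)
        np.array([-1.0, -1.0]),  # layer 1: skipped, stays in fp16/bf16
    ]
}

for idx, scale in enumerate(weight_scales["out_linear_weight_scale"]):
    if np.all(scale == -1):
        print(f"layer {idx}: skip quantization, keep default dtype")
    else:
        print(f"layer {idx}: quantize weights to int8 with scale {scale}")
```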