diff --git a/paddlenlp/experimental/transformers/fused_transformer_layers.py b/paddlenlp/experimental/transformers/fused_transformer_layers.py
index 5aedfb3de9fd..451ad31d6d74 100644
--- a/paddlenlp/experimental/transformers/fused_transformer_layers.py
+++ b/paddlenlp/experimental/transformers/fused_transformer_layers.py
@@ -1299,8 +1299,12 @@ def init_weight_shape(self, config):
             self.moe_ffn2_weight_shape = [self.config.moe_config.num_experts, self.dim_feedforward, self.embed_dim]
 
             if config.quant_type == "weight_only_int4":
-                self.moe_ffn1_weight_shape[2] //= 2
-                self.moe_ffn2_weight_shape[2] //= 2
+                if config.moe_config.has_shared_expert():
+                    self.moe_ffn1_weight_shape[2] //= 2
+                    self.moe_ffn2_weight_shape[1] //= 2
+                else:
+                    self.moe_ffn1_weight_shape[2] //= 2
+                    self.moe_ffn2_weight_shape[2] //= 2
 
         if self.config.moe_config.has_shared_expert():
             self.shared_expert_ffn1_weight_shape = [
@@ -1315,6 +1319,9 @@ def init_weight_shape(self, config):
                 self.embed_dim,
                 1,
             ]
+            if config.quant_type == "weight_only_int4":
+                self.shared_expert_ffn1_weight_shape[0] //= 2
+                self.shared_expert_ffn2_weight_shape[0] //= 2
 
     def compute_qkv_linear(self, ln_out, i):
         return weight_only_linear(
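
For context, weight_only_int4 stores two 4-bit weights per byte, so one axis of each packed weight tensor is halved. The sketch below is illustrative only, not PaddleNLP code: it mirrors the halving rules introduced by this patch in a standalone function. The base shapes other than moe_ffn2 (which is visible in the hunk context) are assumptions made for the example.

# Illustrative sketch of the int4 shape bookkeeping added by the patch above.
# Only the //= 2 rules are taken from the diff; base shapes marked "assumed"
# are placeholders, and this helper is hypothetical, not part of PaddleNLP.
def packed_weight_shapes(num_experts, embed_dim, dim_feedforward,
                         quant_type, has_shared_expert):
    moe_ffn1 = [num_experts, embed_dim, dim_feedforward]   # assumed base shape
    moe_ffn2 = [num_experts, dim_feedforward, embed_dim]   # from the hunk context

    if quant_type == "weight_only_int4":
        # int4 packs two values per byte, so the packed axis is halved.
        moe_ffn1[2] //= 2
        # With a shared expert, FFN2 is packed along axis 1; otherwise along
        # axis 2. This distinction is the behavioural change of the patch.
        moe_ffn2[1 if has_shared_expert else 2] //= 2

    shapes = {"moe_ffn1": moe_ffn1, "moe_ffn2": moe_ffn2}

    if has_shared_expert:
        shared_ffn1 = [dim_feedforward, embed_dim]          # assumed base shape
        shared_ffn2 = [embed_dim, dim_feedforward]          # assumed base shape
        if quant_type == "weight_only_int4":
            # The patch halves axis 0 of both shared-expert weights.
            shared_ffn1[0] //= 2
            shared_ffn2[0] //= 2
        shapes["shared_expert_ffn1"] = shared_ffn1
        shapes["shared_expert_ffn2"] = shared_ffn2

    return shapes


if __name__ == "__main__":
    # Example with made-up dimensions, just to show the halved axes.
    print(packed_weight_shapes(num_experts=8, embed_dim=4096, dim_feedforward=11008,
                               quant_type="weight_only_int4", has_shared_expert=True))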