[UnitTest]add bloom testing (#5483)
* add bloom testing

* complete bloom test-modeling

* fix variable assignment

* update parallel_matmul method
wj-Mcat authored Apr 4, 2023
1 parent 08319b8 commit 5acbba4
Showing 3 changed files with 712 additions and 14 deletions.
2 changes: 1 addition & 1 deletion paddlenlp/transformers/bloom/configuration.py
@@ -136,7 +136,7 @@ def __init__(
         self.dtype = dtype
         self.slow_but_exact = slow_but_exact
         self.mp_degree = mp_degree
-        self.pp_degree = mp_degree
+        self.pp_degree = pp_degree
         self.mp_rank = mp_rank
         self.use_recompute = use_recompute
         self.use_pure_fp16 = use_pure_fp16
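The one-line configuration fix above is easy to miss: before it, BloomConfig stored the tensor-parallel degree in pp_degree, so pipeline parallelism could never be configured independently of mp_degree. A minimal sketch of the corrected behaviour, assuming the usual top-level re-export of BloomConfig and purely illustrative degree values:

from paddlenlp.transformers import BloomConfig  # assuming the usual top-level export

# Illustrative values only: with the fix, the two degrees can differ.
config = BloomConfig(mp_degree=2, pp_degree=4)
assert config.mp_degree == 2
assert config.pp_degree == 4  # before this commit this held 2, the mp_degree value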
47 changes: 34 additions & 13 deletions paddlenlp/transformers/bloom/modeling.py
@@ -56,9 +56,7 @@


 def parallel_matmul(x: Tensor, y: Tensor, parallel_output=True):
-    hcg = fleet.get_hybrid_communicate_group()
-    model_parallel_group = hcg.get_model_parallel_group()
-    world_size = hcg.get_model_parallel_world_size()
+    world_size = paddle.distributed.get_world_size()
     if world_size > 1:
         # if not running under distributed.launch, it will raise AttributeError: 'Fleet' object has no attribute '_hcg'
         hcg = fleet.get_hybrid_communicate_group()
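The reworked entry to parallel_matmul defers any fleet lookup until paddle.distributed.get_world_size() reports more than one card, which is what the in-code comment is about: without distributed.launch there is no hybrid communicate group, and the old unconditional call raised AttributeError. A hedged sketch of the resulting control flow; the model-parallel branch is elided, and the plain single-card matmul fallback is an assumption rather than something copied from the diff:

import paddle
from paddle import Tensor
from paddle.distributed import fleet


def parallel_matmul_sketch(x: Tensor, y: Tensor, parallel_output: bool = True) -> Tensor:
    # Sketch of the guard introduced above, not the full PaddleNLP method.
    world_size = paddle.distributed.get_world_size()
    if world_size > 1:
        # Safe to touch fleet state only under distributed.launch; otherwise
        # fleet has no hybrid communicate group and this call would raise.
        hcg = fleet.get_hybrid_communicate_group()
        model_parallel_group = hcg.get_model_parallel_group()
        # ... the real method shards the matmul over model_parallel_group and
        # either gathers the result or keeps it sharded, per `parallel_output`.
        raise NotImplementedError("model-parallel branch elided in this sketch")
    # Single-card fallback: an ordinary matmul (the real method may transpose y).
    return paddle.matmul(x, y)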
@@ -748,14 +746,20 @@ def _get_name_mappings(cls, config: BloomConfig) -> list[StateDictNameMapping]:
             ]
         )

-        model_class_name = config.architectures[0]
         mappings = [StateDictNameMapping(*mapping, index=index) for index, mapping in enumerate(hard_mapping)]
+        model_class_name = config.architectures[0]
+
         if model_class_name != "BloomModel":
             for mapping in mappings:
                 mapping.source_name = "transformer." + mapping.source_name
                 mapping.target_name = "bloom." + mapping.target_name

+        if model_class_name == "BloomForSequenceClassification":
+            mappings.append(StateDictNameMapping("score.weight", "score.weight", "transpose"))
+        if model_class_name == "BloomForTokenClassification":
+            mappings.append(StateDictNameMapping("classifier.weight", "classifier.weight", "transpose"))
+            mappings.append(StateDictNameMapping("classifier.bias", "classifier.bias"))

         return mappings
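The two new branches extend the HF-to-Paddle state-dict mapping so that BloomForSequenceClassification and BloomForTokenClassification checkpoints convert as well; linear weights carry a "transpose" action because torch stores Linear weights as [out, in] while paddle stores [in, out]. A toy illustration of the prefix-then-append logic, using a stand-in dataclass rather than the real StateDictNameMapping, with a hypothetical checkpoint key:

from dataclasses import dataclass
from typing import Optional


@dataclass
class _Mapping:  # stand-in for StateDictNameMapping, for illustration only
    source_name: str              # key in the torch checkpoint
    target_name: str              # key in the paddle state dict
    action: Optional[str] = None  # e.g. "transpose" for Linear weights


mappings = [_Mapping("h.0.self_attention.dense.weight",
                     "h.0.self_attention.dense.weight", "transpose")]

model_class_name = "BloomForSequenceClassification"  # i.e. config.architectures[0]
if model_class_name != "BloomModel":
    # Downstream classes wrap the base model under a `bloom.` attribute, while HF
    # checkpoints nest it under `transformer.`, hence the symmetric prefixes.
    for m in mappings:
        m.source_name = "transformer." + m.source_name
        m.target_name = "bloom." + m.target_name
if model_class_name == "BloomForSequenceClassification":
    # The classification head lives outside the wrapped base model: no prefix.
    mappings.append(_Mapping("score.weight", "score.weight", "transpose"))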


@@ -901,16 +905,15 @@ def forward(
         seq_length_with_past = seq_length_with_past + past_key_values_length

         if attention_mask is None:
-            if input_ids is not None:
-                attention_mask = paddle.ones([batch_size, seq_length], dtype=paddle.get_default_dtype())
+            attention_mask = paddle.ones([batch_size, seq_length_with_past], dtype=paddle.get_default_dtype())

         alibi = build_alibi_tensor(attention_mask, self.config.n_head, dtype=hidden_states.dtype)
         causal_mask = self._prepare_attn_mask(
             attention_mask,
             input_shape=(batch_size, seq_length),
             past_key_values_length=past_key_values_length,
         )
-        if self.config.mp_rank >= 0:
+        if self.config.mp_degree > 1:
             block_size = self.config.n_head // self.config.mp_degree
             alibi = alibi[:, self.mp_rank * block_size : (self.mp_rank + 1) * block_size]
             alibi = alibi.reshape([batch_size * block_size, 1, seq_length_with_past])
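With tensor parallelism each rank keeps only n_head // mp_degree attention heads, so the ALiBi bias has to be cut down to the local heads and flattened to [batch * local_heads, 1, seq_len]; the condition change (mp_degree > 1 instead of mp_rank >= 0) also stops the slice from running on single-card setups. A small worked example of the arithmetic, with assumed toy sizes and a random stand-in for build_alibi_tensor:

import paddle

# Toy sizes, assumed only to illustrate the head-slicing arithmetic above.
batch_size, n_head, seq_length_with_past = 2, 8, 5
mp_degree, mp_rank = 2, 1                      # this rank owns the second half of the heads

alibi = paddle.randn([batch_size, n_head, seq_length_with_past])  # stand-in for build_alibi_tensor(...)
block_size = n_head // mp_degree               # 4 heads live on each rank
local = alibi[:, mp_rank * block_size : (mp_rank + 1) * block_size]
local = local.reshape([batch_size * block_size, 1, seq_length_with_past])
print(local.shape)                             # [8, 1, 5]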
@@ -1081,6 +1084,20 @@ def get_output_embeddings(self):
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings

+    @staticmethod
+    def prepare_attention_mask_for_generation(input_ids, pad_token_id, eos_token_id):
+        is_pad_token_in_inputs_ids = (pad_token_id is not None) and paddle.any(
+            input_ids == pad_token_id
+        ).numpy().item()
+        is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (
+            (eos_token_id is not None) and (pad_token_id != eos_token_id)
+        )
+        if is_pad_token_in_inputs_ids and is_pad_token_not_equal_to_eos_token_id:
+            attention_mask = (input_ids == pad_token_id).astype(paddle.get_default_dtype()) * -1e9
+        else:
+            attention_mask = paddle.zeros_like(input_ids, dtype=paddle.get_default_dtype())
+        return attention_mask
+
     def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
         # only last token for inputs_ids if past is defined in kwargs
         if past:
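The new prepare_attention_mask_for_generation builds an additive mask straight from pad positions: padded slots get -1e9, which drives their post-softmax attention weight to roughly zero, while everything else stays 0; when pads are absent or indistinguishable from eos it falls back to an all-zero mask. A toy reproduction of the main branch with assumed token ids (pad id 0, eos id 2):

import paddle

# Assumed toy batch: 0 is the pad id; it differs from the eos id (2),
# so the method above takes the additive-mask branch.
input_ids = paddle.to_tensor([[5, 6, 7, 0],
                              [5, 6, 0, 0]])
pad_token_id = 0

attention_mask = (input_ids == pad_token_id).astype(paddle.get_default_dtype()) * -1e9
# Rows are zero everywhere except -1e9 at the padded positions, i.e.
# column 3 in the first row and columns 2-3 in the second.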
@@ -1219,28 +1236,32 @@ def forward(

         if input_ids is not None:
             batch_size = input_ids.shape[0]
+            sequence_length = input_ids.shape[1]
         else:
             batch_size = inputs_embeds.shape[0]
+            sequence_length = inputs_embeds.shape[1]

         if self.config.pad_token_id is None and batch_size != 1:
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")

         if self.config.pad_token_id is None:
-            sequence_lengths = -1
-
+            pooled_logits = logits[:, -1]
         else:
             if input_ids is not None:
+                # select the last word of batch sentence
                 sequence_lengths = paddle.where(input_ids != self.config.pad_token_id, 1, 0).sum(axis=-1) - 1
+                # sequence_lengths = paddle.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
+                sequence_lengths += paddle.to_tensor([i * input_ids.shape[1] for i in range(batch_size)])
+                pooled_logits = paddle.index_select(
+                    logits.reshape([batch_size * sequence_length, -1]), sequence_lengths, axis=0
+                )

             else:
                 sequence_lengths = -1
+                pooled_logits = logits[:, -1]
                 logger.warning(
                     f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                     "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
                 )

-        pooled_logits = logits[paddle.arange(batch_size), sequence_lengths]
-
         loss = None
         if labels is not None:
             if self.config.problem_type is None:
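The rewritten classification branch pools each example's logits at its last non-pad token by indexing into a flattened [batch * seq_len, num_labels] view: count the non-pad tokens per row, subtract one for a zero-based position, then add row_index * seq_len so the index lands in the right slice of the flattened tensor. A toy walk-through with assumed shapes (pad id 0, a single label column for readability); the boolean-sum form used here is equivalent to the paddle.where expression above:

import paddle

# Assumed toy inputs: pad id 0, batch of 2, sequence length 4.
pad_token_id = 0
input_ids = paddle.to_tensor([[11, 12, 13, 0],
                              [21, 22, 0, 0]])
batch_size, sequence_length = input_ids.shape
logits = paddle.arange(batch_size * sequence_length, dtype="float32").reshape(
    [batch_size, sequence_length, 1]
)  # stand-in for the classifier logits

# Last non-pad position per row: [2, 1]
sequence_lengths = (input_ids != pad_token_id).astype("int64").sum(axis=-1) - 1
# Offsets into the flattened [batch * seq_len, num_labels] view: [2, 5]
sequence_lengths += paddle.to_tensor([i * sequence_length for i in range(batch_size)])
pooled_logits = paddle.index_select(
    logits.reshape([batch_size * sequence_length, -1]), sequence_lengths, axis=0
)
print(pooled_logits.numpy().ravel())  # [2., 5.]: the logits at each row's last real token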