Correct attention mask dtype for Flax GPT2 (huggingface#25636)
* Correct attention mask dtype

* reformat code

* add a test for boolean mask

* convert test to fast test

* delete unwanted print

* use assertTrue for testing
liutianlin0121 authored and parambharat committed Sep 26, 2023
1 parent e164f2d commit 6e16f63
Showing 2 changed files with 30 additions and 1 deletion.
src/transformers/models/gpt2/modeling_flax_gpt2.py (3 additions, 1 deletion)

@@ -753,7 +753,9 @@ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: O
         extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
         if attention_mask is not None:
             position_ids = attention_mask.cumsum(axis=-1) - 1
-            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
+            extended_attention_mask = lax.dynamic_update_slice(
+                extended_attention_mask, attention_mask.astype("i4"), (0, 0)
+            )
         else:
             position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
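
The cast is the substance of the fix: lax.dynamic_update_slice expects its update to have the same dtype as the target buffer, and lax-level ops do not implicitly promote a boolean mask to "i4". A minimal standalone sketch of the casting step (not part of the commit; the shapes and mask values below are invented for illustration):

import jax.numpy as jnp
from jax import lax

# Generation-time buffer, always built as int32 ("i4") in prepare_inputs_for_generation.
extended_attention_mask = jnp.ones((2, 8), dtype="i4")

# A user-supplied mask is often boolean, e.g. the result of `input_ids != pad_token_id`.
attention_mask = jnp.array([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=bool)

# Casting to "i4" first keeps the update dtype consistent with the buffer;
# without it the boolean update is rejected by lax.dynamic_update_slice.
patched = lax.dynamic_update_slice(extended_attention_mask, attention_mask.astype("i4"), (0, 0))
print(patched)
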
tests/models/gpt2/test_modeling_flax_gpt2.py (27 additions)

@@ -187,6 +187,26 @@ def check_use_cache_forward_with_attn_mask(self, model_class_name, config, input
         diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
         self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
 
+    def check_bool_attention_mask_in_generation(self, model_class_name, config, input_ids, attention_mask):
+        model = model_class_name(config)
+
+        output_int_att_mask = model.generate(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            max_new_tokens=3,
+        )
+
+        output_bool_att_mask = model.generate(
+            input_ids=input_ids,
+            attention_mask=attention_mask.astype(bool),
+            max_new_tokens=3,
+        )
+
+        self.parent.assertTrue(
+            (output_bool_att_mask.sequences == output_int_att_mask.sequences).all(),
+            "Generated responses differ between boolean and integer attention mask",
+        )
+
 
 @require_flax
 class FlaxGPT2ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
@@ -208,6 +208,13 @@ def test_use_cache_forward_with_attn_mask(self):
                 model_class_name, config, input_ids, attention_mask
             )
 
+    def test_bool_attention_mask_in_generation(self):
+        for model_class_name in self.all_generative_model_classes:
+            config, input_ids, attention_mask = self.model_tester.prepare_config_and_inputs()
+            self.model_tester.check_bool_attention_mask_in_generation(
+                model_class_name, config, input_ids, attention_mask
+            )
+
     @slow
     def test_batch_generation(self):
         tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token="</s>", padding_side="left")
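
For a picture of what the new test exercises end to end, here is a standalone sketch along the same lines (not part of the commit; the tiny config values, token ids, and pad/eos ids are invented so it runs without downloading a checkpoint):

import jax.numpy as jnp
from transformers import FlaxGPT2LMHeadModel, GPT2Config

# Tiny randomly initialized model; the sizes are arbitrary stand-ins, not the real gpt2 config.
config = GPT2Config(
    vocab_size=99, n_positions=32, n_embd=32, n_layer=2, n_head=2,
    bos_token_id=0, eos_token_id=1, pad_token_id=1,
)
model = FlaxGPT2LMHeadModel(config)

input_ids = jnp.array([[1, 1, 5, 6, 7]], dtype="i4")  # left-padded with pad_token_id
bool_mask = input_ids != config.pad_token_id          # comparisons yield a boolean mask
int_mask = bool_mask.astype("i4")                     # the integer form used previously

out_bool = model.generate(input_ids=input_ids, attention_mask=bool_mask, max_new_tokens=3)
out_int = model.generate(input_ids=input_ids, attention_mask=int_mask, max_new_tokens=3)

# With the cast in place, both masks should yield identical greedy sequences.
assert (out_bool.sequences == out_int.sequences).all()
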
