
Commit 22aed4a
fix small issue in test
patrickvonplaten committed May 1, 2020
1 parent 2959086 commit 22aed4a
Showing 3 changed files with 6 additions and 4 deletions.
5 changes: 3 additions & 2 deletions src/transformers/modeling_reformer.py
@@ -1560,7 +1560,8 @@ def forward(
         # if needs padding
         least_common_mult_chunk_length = _get_least_common_mult_chunk_len(self.config)
         has_to_pad_to_match_chunk_length = input_shape[-1] % least_common_mult_chunk_length != 0
-        if has_to_pad_to_match_chunk_length is True:
+
+        if has_to_pad_to_match_chunk_length:
             # pad input
             input_ids, inputs_embeds, attention_mask, position_ids, input_shape = self._pad_to_mult_of_chunk_length(
                 input_ids,
@@ -1585,7 +1586,7 @@ def forward(
         sequence_output = encoder_outputs.hidden_states

         # if padding was applied
-        if has_to_pad_to_match_chunk_length is True:
+        if has_to_pad_to_match_chunk_length:
             sequence_output = sequence_output[:, :orig_sequence_length]

         outputs = (sequence_output,)
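
The edit in this file is a readability fix: has_to_pad_to_match_chunk_length is already a plain bool (the result of a != comparison), so testing it with "is True" is redundant. For context, a minimal sketch of the padding check Reformer performs, with hypothetical chunk lengths (the real values come from the model config, and the LCM helper here is an assumption, not the library code):

    import math

    # Hypothetical chunk lengths; the library reads them from the Reformer config.
    lsh_attn_chunk_length = 64
    local_attn_chunk_length = 128

    # Sketch of the least-common-multiple computation behind
    # _get_least_common_mult_chunk_len (an assumption, not the library code).
    least_common_mult_chunk_length = (
        lsh_attn_chunk_length * local_attn_chunk_length
        // math.gcd(lsh_attn_chunk_length, local_attn_chunk_length)
    )

    sequence_length = 1000
    has_to_pad_to_match_chunk_length = sequence_length % least_common_mult_chunk_length != 0
    if has_to_pad_to_match_chunk_length:
        padding_length = least_common_mult_chunk_length - sequence_length % least_common_mult_chunk_length
        print(f"padding by {padding_length} tokens to reach length {sequence_length + padding_length}")
        # -> padding by 24 tokens to reach length 1024
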
1 change: 0 additions & 1 deletion src/transformers/modeling_utils.py
@@ -789,7 +789,6 @@ def generate(
         attention_mask=None,
         decoder_start_token_id=None,
         use_cache=None,
-        num_hashes=None,
         **model_specific_kwargs
     ):
         r""" Generates sequences for models with a LM head. The method currently supports greedy decoding, beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling.
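
Dropping num_hashes from the signature is safe because generate() already accepts **model_specific_kwargs and forwards them to the model, so Reformer callers can keep passing the argument by keyword. A simplified sketch of that passthrough pattern (an assumption about the mechanism, not the actual generate() body):

    def forward(input_ids, num_hashes=None):
        # A Reformer-style forward pass that understands the model-specific argument.
        return {"input_ids": input_ids, "num_hashes": num_hashes}

    def generate(input_ids, max_length=20, **model_specific_kwargs):
        # Keywords that generate() does not name itself are forwarded untouched
        # to the model, so no per-model parameter is needed in the signature.
        return forward(input_ids, **model_specific_kwargs)

    print(generate([0, 1, 2], num_hashes=8))
    # -> {'input_ids': [0, 1, 2], 'num_hashes': 8}
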
4 changes: 3 additions & 1 deletion tests/test_modeling_reformer.py
@@ -1111,7 +1111,9 @@ def test_pretrained_generate_crime_and_punish(self):
         model.eval()

         input_ids = tokenizer.encode("A few months later", return_tensors="pt").to(torch_device)
-        output_ids = model.generate(input_ids, max_length=50, num_beams=4, early_stopping=True, do_sample=False)
+        output_ids = model.generate(
+            input_ids, max_length=50, num_beams=4, early_stopping=True, do_sample=False, num_hashes=8
+        )
         output_text = tokenizer.decode(output_ids[0])
         self.assertEqual(
             output_text,
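
The test compensates for the removed default by passing num_hashes=8 explicitly, which now reaches the model through **model_specific_kwargs; pinning it keeps the beam-search output deterministic so the string comparison below still holds. A hedged usage sketch of the same call outside the test (class and checkpoint names are illustrative assumptions):

    from transformers import ReformerModelWithLMHead, ReformerTokenizer

    # Illustrative checkpoint; the test uses a Reformer trained on Crime and Punishment.
    model = ReformerModelWithLMHead.from_pretrained("google/reformer-crime-and-punishment")
    tokenizer = ReformerTokenizer.from_pretrained("google/reformer-crime-and-punishment")

    input_ids = tokenizer.encode("A few months later", return_tensors="pt")
    output_ids = model.generate(
        input_ids, max_length=50, num_beams=4, early_stopping=True, do_sample=False, num_hashes=8
    )
    print(tokenizer.decode(output_ids[0]))
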
