Skip to content

Commit

Permalink
Fix AlignModelTest tests (#21923)
Browse files Browse the repository at this point in the history
* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
  • Loading branch information
ydshieh and ydshieh authored Mar 3, 2023
1 parent c5a1ff9 commit d4306da
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions tests/models/align/test_modeling_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class AlignVisionModelTester:
def __init__(
self,
parent,
batch_size=13,
batch_size=12,
image_size=32,
num_channels=3,
kernel_sizes=[3, 3, 5],
Expand Down Expand Up @@ -234,7 +234,7 @@ class AlignTextModelTester:
def __init__(
self,
parent,
batch_size=13,
batch_size=12,
seq_length=7,
is_training=True,
use_input_mask=True,
Expand Down Expand Up @@ -521,6 +521,15 @@ def _create_and_check_torchscript(self, config, inputs_dict):
model_state_dict = model.state_dict()
loaded_model_state_dict = loaded_model.state_dict()

non_persistent_buffers = {}
for key in loaded_model_state_dict.keys():
if key not in model_state_dict.keys():
non_persistent_buffers[key] = loaded_model_state_dict[key]

loaded_model_state_dict = {
key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
}

self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))

models_equal = True
Expand Down

0 comments on commit d4306da

Please sign in to comment.