Skip to content

Commit

Permalink
update batch params
Browse files Browse the repository at this point in the history
  • Loading branch information
josh146 authored Nov 9, 2021
1 parent 5a4d8d8 commit 859d9d7
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tests/transforms/test_batch_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def circuit(data, x, weights):

spy = mocker.spy(circuit.device, "batch_execute")
res = circuit(data, x, weights)
assert res.shape == (batch_size, 1, 4)
assert res.shape == (batch_size, 4)
assert len(spy.call_args[0][0]) == batch_size


Expand All @@ -61,7 +61,7 @@ def circuit(data):

spy = mocker.spy(circuit.device, "batch_execute")
res = circuit(data)
assert res.shape == (batch_size, 1, 4)
assert res.shape == (batch_size, 4)
assert len(spy.call_args[0][0]) == batch_size


Expand Down Expand Up @@ -244,7 +244,7 @@ def circuit(x, weights):

spy = mocker.spy(circuit.device, "batch_execute")
res = circuit(x, weights)
assert res.shape == (batch_size, 1, 4)
assert res.shape == (batch_size, 4)
assert len(spy.call_args[0][0]) == batch_size


Expand Down

0 comments on commit 859d9d7

Please sign in to comment.