Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor tests to use pytest parametrized #89

Merged
merged 6 commits into from
Nov 11, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/unittests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:
- name: Build PIP Package
run: bazel-bin/build_pip_pkg artifacts
- name: Install PIP Package
run: pip3 install artifacts/*.whl
run: pip3 install pytest artifacts/*.whl
- name: Run Python Tests
run: bazel test larq_compute_engine:py_tests --python_top=//larq_compute_engine:pyruntime --test_output=errors

Expand All @@ -41,7 +41,7 @@ jobs:
- name: Build PIP Package
run: bazel-bin/build_pip_pkg artifacts
- name: Install PIP Package
run: pip3 install artifacts/*.whl
run: pip3 install pytest artifacts/*.whl
- name: Run Python Tests
run: bazel test larq_compute_engine:py_tests --python_top=//larq_compute_engine:pyruntime --test_output=errors

Expand Down
2 changes: 2 additions & 0 deletions larq_compute_engine/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ py_test(
":compute_engine_ops_py",
":compute_engine_utils_py",
],
size = "small",
python_version = "PY3",
srcs_version = "PY3",
)
Expand All @@ -281,6 +282,7 @@ py_test(
":compute_engine_ops_py",
":compute_engine_utils_py",
],
size = "large",
python_version = "PY3",
srcs_version = "PY3",
)
Expand Down
101 changes: 42 additions & 59 deletions larq_compute_engine/python/ops/bconv2d_ops_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Tests for compute engine ops."""
import numpy as np
import tensorflow as tf
import itertools
import sys
import pytest


try:
Expand All @@ -16,67 +17,49 @@
from ..utils import eval_op


class BConv2DTest(tf.test.TestCase):
def __test_bconv(self, bconv_op):
data_types = [np.float32, np.float64]
data_formats = ["NHWC"]
in_sizes = [[10, 10], [11, 11], [10, 11]]
filter_sizes = [[3, 3], [4, 5]]
in_channels = [1, 31, 32, 33, 64]
out_channels = [1, 16]
hw_strides = [[1, 1], [2, 2]]
paddings = ["VALID", "SAME"]
@pytest.mark.parametrize("bconv_op", [bconv2d8, bconv2d32, bconv2d64])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
@pytest.mark.parametrize("data_format", ["NHWC"])
@pytest.mark.parametrize("in_size", [[10, 10], [11, 11], [10, 11]])
@pytest.mark.parametrize("filter_size", [[3, 3], [4, 5]])
@pytest.mark.parametrize("in_channel", [1, 31, 32, 33, 64])
@pytest.mark.parametrize("out_channel", [1, 16])
@pytest.mark.parametrize("hw_stride", [[1, 1], [2, 2]])
@pytest.mark.parametrize("padding", ["VALID", "SAME"])
def test_bconv(
bconv_op,
dtype,
data_format,
in_size,
filter_size,
in_channel,
out_channel,
hw_stride,
padding,
):
batch_size = out_channel
h, w = in_size
fh, fw = filter_size
if data_format == "NHWC":
ishape = [batch_size, h, w, in_channel]
strides = [1] + hw_stride + [1]
else:
raise ValueError("Unknown data_format: " + str(data_format))
fshape = [fh, fw, in_channel, out_channel]

args_lists = [
data_types,
data_formats,
in_sizes,
filter_sizes,
in_channels,
out_channels,
hw_strides,
paddings,
]
for args in itertools.product(*args_lists):
dtype, data_format, in_size, filter_size, in_channel, out_channel, hw_stride, padding = (
args
)
sample_list = [-1, 1]
inp = np.random.choice(sample_list, np.prod(ishape)).astype(dtype)
inp = np.reshape(inp, ishape)

batch_size = out_channel
h, w = in_size
fh, fw = filter_size
if data_format == "NHWC":
ishape = [batch_size, h, w, in_channel]
strides = [1] + hw_stride + [1]
else:
raise ValueError("Unknown data_format: " + str(data_format))
fshape = [fh, fw, in_channel, out_channel]
filt = np.random.choice(sample_list, np.prod(fshape)).astype(dtype)
filt = np.reshape(filt, fshape)

sample_list = [-1, 1]
inp = np.random.choice(sample_list, np.prod(ishape)).astype(dtype)
inp = np.reshape(inp, ishape)

filt = np.random.choice(sample_list, np.prod(fshape)).astype(dtype)
filt = np.reshape(filt, fshape)

with self.test_session():
output = eval_op(
bconv_op(inp, filt, strides, padding, data_format=data_format)
)
expected = eval_op(
tf.nn.conv2d(inp, filt, strides, padding, data_format=data_format)
)
self.assertAllClose(output, expected)

def test_bconv2d8(self):
self.__test_bconv(bconv2d8)

def test_bconv2d32(self):
self.__test_bconv(bconv2d32)

def test_bconv2d64(self):
self.__test_bconv(bconv2d64)
output = eval_op(bconv_op(inp, filt, strides, padding, data_format=data_format))
expected = eval_op(
tf.nn.conv2d(inp, filt, strides, padding, data_format=data_format)
)
np.testing.assert_allclose(output, expected)


if __name__ == "__main__":
tf.test.main()
sys.exit(pytest.main([__file__]))
44 changes: 15 additions & 29 deletions larq_compute_engine/python/ops/bsign_ops_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Tests for compute engine ops."""
import numpy as np
import tensorflow as tf
import sys
import pytest


try:
Expand All @@ -11,36 +12,21 @@
from compute_engine_ops.python.utils import eval_op


class SignTest(tf.test.TestCase):
def run_test_for_integers(self, dtype):
with self.test_session():
x = np.array([[2, -5], [-3, 0]]).astype(dtype)
expected_output = np.array([[1, -1], [-1, 1]])
self.assertAllClose(eval_op(bsign(x)), expected_output)
@pytest.mark.parametrize("dtype", [np.int8, np.int32, np.int64])
def test_sign_int(dtype):
x = np.array([[2, -5], [-3, 0]]).astype(dtype)
expected_output = np.array([[1, -1], [-1, 1]])
np.testing.assert_allclose(eval_op(bsign(x)), expected_output)

# Test for +0 and -0 floating points.
# We have sign(+0) = 1 and sign(-0) = -1
def run_test_for_floating(self, dtype):
with self.test_session():
x = np.array([[0.1, -5.8], [-3.0, 0.00], [0.0, -0.0]]).astype(dtype)
expected_output = np.array([[1, -1], [-1, 1], [1, -1]])
self.assertAllClose(eval_op(bsign(x)), expected_output)

def test_sign_int8(self):
self.run_test_for_integers(np.int8)

def test_sign_int32(self):
self.run_test_for_integers(np.int32)

def test_sign_int64(self):
self.run_test_for_integers(np.int64)

def test_sign_float32(self):
self.run_test_for_floating(np.float32)

def test_sign_float64(self):
self.run_test_for_floating(np.float64)
# Test for +0 and -0 floating points.
# We have sign(+0) = 1 and sign(-0) = -1
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_sign_float(dtype):
x = np.array([[0.1, -5.8], [-3.0, 0.00], [0.0, -0.0]]).astype(dtype)
expected_output = np.array([[1, -1], [-1, 1], [1, -1]])
np.testing.assert_allclose(eval_op(bsign(x)), expected_output)


if __name__ == "__main__":
tf.test.main()
sys.exit(pytest.main([__file__]))
10 changes: 4 additions & 6 deletions larq_compute_engine/python/utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
"""Utils for testing compute engine ops."""
from tensorflow import __version__
import tensorflow as tf

from distutils.version import LooseVersion


def tf_2_or_newer():
return LooseVersion(__version__) >= LooseVersion("2.0")
return LooseVersion(tf.__version__) >= LooseVersion("2.0")


def eval_op(op):
if tf_2_or_newer():
return op # op.numpy() also works
else:
return op.eval()
return tf.keras.backend.get_value(op)