From 52c3928b0995640a3b2862cd9ba0050700922e6b Mon Sep 17 00:00:00 2001 From: Tiexin Guo Date: Tue, 16 Apr 2024 16:02:12 +0800 Subject: [PATCH 1/3] test: refactor some for loop tests into pytest parametrize --- test/test_jujuversion.py | 205 +++++++++++++-------------- test/test_log.py | 118 +++++++-------- test/test_pebble.py | 299 +++++++++++++++++++-------------------- test/test_private.py | 158 ++++++++++----------- 4 files changed, 384 insertions(+), 396 deletions(-) diff --git a/test/test_jujuversion.py b/test/test_jujuversion.py index 644cb747c..f9b2bdb52 100644 --- a/test/test_jujuversion.py +++ b/test/test_jujuversion.py @@ -20,28 +20,25 @@ import ops -class TestJujuVersion(unittest.TestCase): - - def test_parsing(self): - test_cases = [ - ("0.0.0", 0, 0, '', 0, 0), - ("0.0.2", 0, 0, '', 2, 0), - ("0.1.0", 0, 1, '', 0, 0), - ("0.2.3", 0, 2, '', 3, 0), - ("10.234.3456", 10, 234, '', 3456, 0), - ("10.234.3456.1", 10, 234, '', 3456, 1), - ("1.21-alpha12", 1, 21, 'alpha', 12, 0), - ("1.21-alpha1.34", 1, 21, 'alpha', 1, 34), - ("2.7", 2, 7, '', 0, 0) - ] - - for vs, major, minor, tag, patch, build in test_cases: - v = ops.JujuVersion(vs) - assert v.major == major - assert v.minor == minor - assert v.tag == tag - assert v.patch == patch - assert v.build == build +class TestJujuVersion: + @pytest.mark.parametrize("vs,major,minor,tag,patch,build", [ + ("0.0.0", 0, 0, '', 0, 0), + ("0.0.2", 0, 0, '', 2, 0), + ("0.1.0", 0, 1, '', 0, 0), + ("0.2.3", 0, 2, '', 3, 0), + ("10.234.3456", 10, 234, '', 3456, 0), + ("10.234.3456.1", 10, 234, '', 3456, 1), + ("1.21-alpha12", 1, 21, 'alpha', 12, 0), + ("1.21-alpha1.34", 1, 21, 'alpha', 1, 34), + ("2.7", 2, 7, '', 0, 0) + ]) + def test_parsing(self, vs: str, major: int, minor: int, tag: str, patch: int, build: int): + v = ops.JujuVersion(vs) + assert v.major == major + assert v.minor == minor + assert v.tag == tag + assert v.patch == patch + assert v.build == build @unittest.mock.patch('os.environ', new={}) # type: ignore def test_from_environ(self): @@ -93,88 +90,82 @@ def test_supports_exec_service_context(self): assert ops.JujuVersion('3.3.0').supports_exec_service_context assert ops.JujuVersion('3.4.0').supports_exec_service_context - def test_parsing_errors(self): - invalid_versions = [ - "xyz", - "foo.bar", - "foo.bar.baz", - "dead.beef.ca.fe", - "1234567890.2.1", # The major version is too long. - "0.2..1", # Two periods next to each other. - "1.21.alpha1", # Tag comes after period. - "1.21-alpha", # No patch number but a tag is present. - "1.21-alpha1beta", # Non-numeric string after the patch number. - "1.21-alpha-dev", # Tag duplication. - "1.21-alpha_dev3", # Underscore in a tag. - "1.21-alpha123dev3", # Non-numeric string after the patch number. - ] - for v in invalid_versions: - with pytest.raises(RuntimeError): - ops.JujuVersion(v) - - def test_equality(self): - test_cases = [ - ("1.0.0", "1.0.0", True), - ("01.0.0", "1.0.0", True), - ("10.0.0", "9.0.0", False), - ("1.0.0", "1.0.1", False), - ("1.0.1", "1.0.0", False), - ("1.0.0", "1.1.0", False), - ("1.1.0", "1.0.0", False), - ("1.0.0", "2.0.0", False), - ("1.2-alpha1", "1.2.0", False), - ("1.2-alpha2", "1.2-alpha1", False), - ("1.2-alpha2.1", "1.2-alpha2", False), - ("1.2-alpha2.2", "1.2-alpha2.1", False), - ("1.2-beta1", "1.2-alpha1", False), - ("1.2-beta1", "1.2-alpha2.1", False), - ("1.2-beta1", "1.2.0", False), - ("1.2.1", "1.2.0", False), - ("2.0.0", "1.0.0", False), - ("2.0.0.0", "2.0.0", True), - ("2.0.0.0", "2.0.0.0", True), - ("2.0.0.1", "2.0.0.0", False), - ("2.0.1.10", "2.0.0.0", False), - ] - - for a, b, expected in test_cases: - assert (ops.JujuVersion(a) == ops.JujuVersion(b)) == expected - assert (ops.JujuVersion(a) == b) == expected - - def test_comparison(self): - test_cases = [ - ("1.0.0", "1.0.0", False, True), - ("01.0.0", "1.0.0", False, True), - ("10.0.0", "9.0.0", False, False), - ("1.0.0", "1.0.1", True, True), - ("1.0.1", "1.0.0", False, False), - ("1.0.0", "1.1.0", True, True), - ("1.1.0", "1.0.0", False, False), - ("1.0.0", "2.0.0", True, True), - ("1.2-alpha1", "1.2.0", True, True), - ("1.2-alpha2", "1.2-alpha1", False, False), - ("1.2-alpha2.1", "1.2-alpha2", False, False), - ("1.2-alpha2.2", "1.2-alpha2.1", False, False), - ("1.2-beta1", "1.2-alpha1", False, False), - ("1.2-beta1", "1.2-alpha2.1", False, False), - ("1.2-beta1", "1.2.0", True, True), - ("1.2.1", "1.2.0", False, False), - ("2.0.0", "1.0.0", False, False), - ("2.0.0.0", "2.0.0", False, True), - ("2.0.0.0", "2.0.0.0", False, True), - ("2.0.0.1", "2.0.0.0", False, False), - ("2.0.1.10", "2.0.0.0", False, False), - ("2.10.0", "2.8.0", False, False), - ] - - for a, b, expected_strict, expected_weak in test_cases: - with self.subTest(a=a, b=b): - assert (ops.JujuVersion(a) < ops.JujuVersion(b)) == expected_strict - assert (ops.JujuVersion(a) <= ops.JujuVersion(b)) == expected_weak - assert (ops.JujuVersion(b) > ops.JujuVersion(a)) == expected_strict - assert (ops.JujuVersion(b) >= ops.JujuVersion(a)) == expected_weak - # Implicit conversion. - assert (ops.JujuVersion(a) < b) == expected_strict - assert (ops.JujuVersion(a) <= b) == expected_weak - assert (b > ops.JujuVersion(a)) == expected_strict - assert (b >= ops.JujuVersion(a)) == expected_weak + @pytest.mark.parametrize("invalid_version", [ + "xyz", + "foo.bar", + "foo.bar.baz", + "dead.beef.ca.fe", + "1234567890.2.1", # The major version is too long. + "0.2..1", # Two periods next to each other. + "1.21.alpha1", # Tag comes after period. + "1.21-alpha", # No patch number but a tag is present. + "1.21-alpha1beta", # Non-numeric string after the patch number. + "1.21-alpha-dev", # Tag duplication. + "1.21-alpha_dev3", # Underscore in a tag. + "1.21-alpha123dev3", # Non-numeric string after the patch number. + ]) + def test_parsing_errors(self, invalid_version: str): + with pytest.raises(RuntimeError): + ops.JujuVersion(invalid_version) + + @pytest.mark.parametrize("a,b,expected", [ + ("1.0.0", "1.0.0", True), + ("01.0.0", "1.0.0", True), + ("10.0.0", "9.0.0", False), + ("1.0.0", "1.0.1", False), + ("1.0.1", "1.0.0", False), + ("1.0.0", "1.1.0", False), + ("1.1.0", "1.0.0", False), + ("1.0.0", "2.0.0", False), + ("1.2-alpha1", "1.2.0", False), + ("1.2-alpha2", "1.2-alpha1", False), + ("1.2-alpha2.1", "1.2-alpha2", False), + ("1.2-alpha2.2", "1.2-alpha2.1", False), + ("1.2-beta1", "1.2-alpha1", False), + ("1.2-beta1", "1.2-alpha2.1", False), + ("1.2-beta1", "1.2.0", False), + ("1.2.1", "1.2.0", False), + ("2.0.0", "1.0.0", False), + ("2.0.0.0", "2.0.0", True), + ("2.0.0.0", "2.0.0.0", True), + ("2.0.0.1", "2.0.0.0", False), + ("2.0.1.10", "2.0.0.0", False), + ]) + def test_equality(self, a: str, b: str, expected: bool): + assert (ops.JujuVersion(a) == ops.JujuVersion(b)) == expected + assert (ops.JujuVersion(a) == b) == expected + + @pytest.mark.parametrize("a,b,expected_strict,expected_weak", [ + ("1.0.0", "1.0.0", False, True), + ("01.0.0", "1.0.0", False, True), + ("10.0.0", "9.0.0", False, False), + ("1.0.0", "1.0.1", True, True), + ("1.0.1", "1.0.0", False, False), + ("1.0.0", "1.1.0", True, True), + ("1.1.0", "1.0.0", False, False), + ("1.0.0", "2.0.0", True, True), + ("1.2-alpha1", "1.2.0", True, True), + ("1.2-alpha2", "1.2-alpha1", False, False), + ("1.2-alpha2.1", "1.2-alpha2", False, False), + ("1.2-alpha2.2", "1.2-alpha2.1", False, False), + ("1.2-beta1", "1.2-alpha1", False, False), + ("1.2-beta1", "1.2-alpha2.1", False, False), + ("1.2-beta1", "1.2.0", True, True), + ("1.2.1", "1.2.0", False, False), + ("2.0.0", "1.0.0", False, False), + ("2.0.0.0", "2.0.0", False, True), + ("2.0.0.0", "2.0.0.0", False, True), + ("2.0.0.1", "2.0.0.0", False, False), + ("2.0.1.10", "2.0.0.0", False, False), + ("2.10.0", "2.8.0", False, False), + ]) + def test_comparison(self, a: str, b: str, expected_strict: bool, expected_weak: bool): + assert (ops.JujuVersion(a) < ops.JujuVersion(b)) == expected_strict + assert (ops.JujuVersion(a) <= ops.JujuVersion(b)) == expected_weak + assert (ops.JujuVersion(b) > ops.JujuVersion(a)) == expected_strict + assert (ops.JujuVersion(b) >= ops.JujuVersion(a)) == expected_weak + # Implicit conversion. + assert (ops.JujuVersion(a) < b) == expected_strict + assert (ops.JujuVersion(a) <= b) == expected_weak + assert (b > ops.JujuVersion(a)) == expected_strict + assert (b >= ops.JujuVersion(a)) == expected_weak diff --git a/test/test_log.py b/test/test_log.py index 1f6da2fdc..11f052a09 100644 --- a/test/test_log.py +++ b/test/test_log.py @@ -19,6 +19,8 @@ import unittest from unittest.mock import patch +import pytest + import ops.log from ops.model import MAX_LOG_LINE_LEN, _ModelBackend @@ -39,76 +41,78 @@ def juju_log(self, level: str, message: str): self._calls.append((level, line)) -class TestLogging(unittest.TestCase): - - def setUp(self): - self.backend = FakeModelBackend() - - def tearDown(self): - logging.getLogger().handlers.clear() - - def test_default_logging(self): - ops.log.setup_root_logging(self.backend) - - logger = logging.getLogger() +@pytest.fixture() +def backend(): + return FakeModelBackend() + + +@pytest.fixture() +def logger(): + logger = logging.getLogger() + yield logger + logging.getLogger().handlers.clear() + + +class TestLogging: + @pytest.mark.parametrize("message,result", [ + ('critical', ('CRITICAL', 'critical')), + ('error', ('ERROR', 'error')), + ('warning', ('WARNING', 'warning')), + ('info', ('INFO', 'info')), + ('debug', ('DEBUG', 'debug')), + ]) + def test_default_logging(self, + backend: FakeModelBackend, + logger: logging.Logger, + message: str, + result: typing.Tuple[str, str]): + ops.log.setup_root_logging(backend) assert logger.level == logging.DEBUG assert isinstance(logger.handlers[-1], ops.log.JujuLogHandler) - test_cases = [ - (logger.critical, 'critical', ('CRITICAL', 'critical')), - (logger.error, 'error', ('ERROR', 'error')), - (logger.warning, 'warning', ('WARNING', 'warning')), - (logger.info, 'info', ('INFO', 'info')), - (logger.debug, 'debug', ('DEBUG', 'debug')), - ] - - for method, message, result in test_cases: - with self.subTest(message): - method(message) - calls = self.backend.calls(clear=True) - assert calls == [result] + method = getattr(logger, message) + method(message) + calls = backend.calls(clear=True) + assert calls == [result] - def test_handler_filtering(self): - logger = logging.getLogger() + def test_handler_filtering(self, backend: FakeModelBackend, logger: logging.Logger): logger.setLevel(logging.INFO) - logger.addHandler(ops.log.JujuLogHandler(self.backend, logging.WARNING)) + logger.addHandler(ops.log.JujuLogHandler(backend, logging.WARNING)) logger.info('foo') - assert self.backend.calls() == [] + assert backend.calls() == [] logger.warning('bar') - assert self.backend.calls() == [('WARNING', 'bar')] + assert backend.calls() == [('WARNING', 'bar')] - def test_no_stderr_without_debug(self): + def test_no_stderr_without_debug(self, backend: FakeModelBackend, logger: logging.Logger): buffer = io.StringIO() with patch('sys.stderr', buffer): - ops.log.setup_root_logging(self.backend, debug=False) - logger = logging.getLogger() + ops.log.setup_root_logging(backend, debug=False) logger.debug('debug message') logger.info('info message') logger.warning('warning message') logger.critical('critical message') - assert self.backend.calls() == \ - [('DEBUG', 'debug message'), - ('INFO', 'info message'), - ('WARNING', 'warning message'), - ('CRITICAL', 'critical message'), - ] + assert backend.calls() == [ + ('DEBUG', 'debug message'), + ('INFO', 'info message'), + ('WARNING', 'warning message'), + ('CRITICAL', 'critical message'), + ] assert buffer.getvalue() == "" - def test_debug_logging(self): + def test_debug_logging(self, backend: FakeModelBackend, logger: logging.Logger): buffer = io.StringIO() with patch('sys.stderr', buffer): - ops.log.setup_root_logging(self.backend, debug=True) - logger = logging.getLogger() + ops.log.setup_root_logging(backend, debug=True) logger.debug('debug message') logger.info('info message') logger.warning('warning message') logger.critical('critical message') - assert self.backend.calls() == \ - [('DEBUG', 'debug message'), - ('INFO', 'info message'), - ('WARNING', 'warning message'), - ('CRITICAL', 'critical message'), - ] + assert backend.calls() == [ + ('DEBUG', 'debug message'), + ('INFO', 'info message'), + ('WARNING', 'warning message'), + ('CRITICAL', 'critical message'), + ] assert re.search( r"\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d,\d\d\d DEBUG debug message\n" r"\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d,\d\d\d INFO info message\n" @@ -117,31 +121,29 @@ def test_debug_logging(self): buffer.getvalue() ) - def test_reduced_logging(self): - ops.log.setup_root_logging(self.backend) - logger = logging.getLogger() + def test_reduced_logging(self, backend: FakeModelBackend, logger: logging.Logger): + ops.log.setup_root_logging(backend) logger.setLevel(logging.WARNING) logger.debug('debug') logger.info('info') logger.warning('warning') - assert self.backend.calls() == [('WARNING', 'warning')] + assert backend.calls() == [('WARNING', 'warning')] - def test_long_string_logging(self): + def test_long_string_logging(self, backend: FakeModelBackend, logger: logging.Logger): buffer = io.StringIO() with patch('sys.stderr', buffer): - ops.log.setup_root_logging(self.backend, debug=True) - logger = logging.getLogger() + ops.log.setup_root_logging(backend, debug=True) logger.debug('l' * MAX_LOG_LINE_LEN) - assert len(self.backend.calls()) == 1 + assert len(backend.calls()) == 1 - self.backend.calls(clear=True) + backend.calls(clear=True) with patch('sys.stderr', buffer): logger.debug('l' * (MAX_LOG_LINE_LEN + 9)) - calls = self.backend.calls() + calls = backend.calls() assert len(calls) == 3 # Verify that we note that we are splitting the log message. assert "Splitting into multiple chunks" in calls[0][1] diff --git a/test/test_pebble.py b/test/test_pebble.py index 08470dfae..04086921b 100644 --- a/test/test_pebble.py +++ b/test/test_pebble.py @@ -1426,159 +1426,158 @@ def build_mock_change_dict(change_id: str = '70') -> 'pebble._ChangeDict': } -class TestMultipartParser(unittest.TestCase): - class _Case: - def __init__( - self, - name: str, - data: bytes, - want_headers: typing.List[bytes], - want_bodies: typing.List[bytes], - want_bodies_done: typing.List[bool], - max_boundary: int = 14, - max_lookahead: int = 8 * 1024, - error: str = ''): - self.name = name - self.data = data - self.want_headers = want_headers - self.want_bodies = want_bodies - self.want_bodies_done = want_bodies_done - self.max_boundary = max_boundary - self.max_lookahead = max_lookahead - self.error = error - - def test_multipart_parser(self): - tests = [ - TestMultipartParser._Case( - 'baseline', - b'\r\n--qwerty\r\nheader foo\r\n\r\nfoo bar\nfoo bar\r\n--qwerty--\r\n', - [b'header foo\r\n\r\n'], - [b'foo bar\nfoo bar'], - want_bodies_done=[True], - ), - TestMultipartParser._Case( - 'incomplete header', - b'\r\n--qwerty\r\nheader foo\r\n', - [], - [], - want_bodies_done=[], - ), - TestMultipartParser._Case( - 'missing header', - b'\r\n--qwerty\r\nheader foo\r\n' + 40 * b' ', - [], - [], - want_bodies_done=[], - max_lookahead=40, - error='header terminator not found', - ), - TestMultipartParser._Case( - 'incomplete body terminator', - b'\r\n--qwerty\r\nheader foo\r\n\r\nfoo bar\r\n--qwerty\rhello my name is joe and I work in a button factory', # noqa - [b'header foo\r\n\r\n'], - [b'foo bar\r\n--qwerty\rhello my name is joe and I work in a '], - want_bodies_done=[False], - ), - TestMultipartParser._Case( - 'empty body', - b'\r\n--qwerty\r\nheader foo\r\n\r\n\r\n--qwerty\r\n', - [b'header foo\r\n\r\n'], - [b''], - want_bodies_done=[True], - ), - TestMultipartParser._Case( - 'ignore leading garbage', - b'hello my name is joe\r\n\n\n\n\r\n--qwerty\r\nheader foo\r\n\r\nfoo bar\r\n--qwerty\r\n', # noqa - [b'header foo\r\n\r\n'], - [b'foo bar'], - want_bodies_done=[True], - ), - TestMultipartParser._Case( - 'ignore trailing garbage', - b'\r\n--qwerty\r\nheader foo\r\n\r\nfoo bar\r\n--qwerty\r\nhello my name is joe', - [b'header foo\r\n\r\n'], - [b'foo bar'], - want_bodies_done=[True], - ), - TestMultipartParser._Case( - 'boundary allow linear whitespace', - b'\r\n--qwerty \t \r\nheader foo\r\n\r\nfoo bar\r\n--qwerty\r\n', - [b'header foo\r\n\r\n'], - [b'foo bar'], - want_bodies_done=[True], - max_boundary=20, - ), - TestMultipartParser._Case( - 'terminal boundary allow linear whitespace', - b'\r\n--qwerty\r\nheader foo\r\n\r\nfoo bar\r\n--qwerty-- \t \r\n', - [b'header foo\r\n\r\n'], - [b'foo bar'], - want_bodies_done=[True], - max_boundary=20, - ), - TestMultipartParser._Case( - 'multiple parts', - b'\r\n--qwerty \t \r\nheader foo\r\n\r\nfoo bar\r\n--qwerty\r\nheader bar\r\n\r\nfoo baz\r\n--qwerty--\r\n', # noqa - [b'header foo\r\n\r\n', b'header bar\r\n\r\n'], - [b'foo bar', b'foo baz'], - want_bodies_done=[True, True], - ), - TestMultipartParser._Case( - 'ignore after terminal boundary', - b'\r\n--qwerty \t \r\nheader foo\r\n\r\nfoo bar\r\n--qwerty--\r\nheader bar\r\n\r\nfoo baz\r\n--qwerty--\r\n', # noqa - [b'header foo\r\n\r\n'], - [b'foo bar'], - want_bodies_done=[True], - ), - ] - +class MultipartParserTestCase: + def __init__( + self, + name: str, + data: bytes, + want_headers: typing.List[bytes], + want_bodies: typing.List[bytes], + want_bodies_done: typing.List[bool], + max_boundary: int = 14, + max_lookahead: int = 8 * 1024, + error: str = ''): + self.name = name + self.data = data + self.want_headers = want_headers + self.want_bodies = want_bodies + self.want_bodies_done = want_bodies_done + self.max_boundary = max_boundary + self.max_lookahead = max_lookahead + self.error = error + + +class TestMultipartParser: + @pytest.mark.parametrize("test", [ + MultipartParserTestCase( + 'baseline', + b'\r\n--qwerty\r\nheader foo\r\n\r\nfoo bar\nfoo bar\r\n--qwerty--\r\n', + [b'header foo\r\n\r\n'], + [b'foo bar\nfoo bar'], + want_bodies_done=[True], + ), + MultipartParserTestCase( + 'incomplete header', + b'\r\n--qwerty\r\nheader foo\r\n', + [], + [], + want_bodies_done=[], + ), + MultipartParserTestCase( + 'missing header', + b'\r\n--qwerty\r\nheader foo\r\n' + 40 * b' ', + [], + [], + want_bodies_done=[], + max_lookahead=40, + error='header terminator not found', + ), + MultipartParserTestCase( + 'incomplete body terminator', + b'\r\n--qwerty\r\nheader foo\r\n\r\nfoo bar\r\n--qwerty\rhello my name is joe and I work in a button factory', # noqa + [b'header foo\r\n\r\n'], + [b'foo bar\r\n--qwerty\rhello my name is joe and I work in a '], + want_bodies_done=[False], + ), + MultipartParserTestCase( + 'empty body', + b'\r\n--qwerty\r\nheader foo\r\n\r\n\r\n--qwerty\r\n', + [b'header foo\r\n\r\n'], + [b''], + want_bodies_done=[True], + ), + MultipartParserTestCase( + 'ignore leading garbage', + b'hello my name is joe\r\n\n\n\n\r\n--qwerty\r\nheader foo\r\n\r\nfoo bar\r\n--qwerty\r\n', # noqa + [b'header foo\r\n\r\n'], + [b'foo bar'], + want_bodies_done=[True], + ), + MultipartParserTestCase( + 'ignore trailing garbage', + b'\r\n--qwerty\r\nheader foo\r\n\r\nfoo bar\r\n--qwerty\r\nhello my name is joe', + [b'header foo\r\n\r\n'], + [b'foo bar'], + want_bodies_done=[True], + ), + MultipartParserTestCase( + 'boundary allow linear whitespace', + b'\r\n--qwerty \t \r\nheader foo\r\n\r\nfoo bar\r\n--qwerty\r\n', + [b'header foo\r\n\r\n'], + [b'foo bar'], + want_bodies_done=[True], + max_boundary=20, + ), + MultipartParserTestCase( + 'terminal boundary allow linear whitespace', + b'\r\n--qwerty\r\nheader foo\r\n\r\nfoo bar\r\n--qwerty-- \t \r\n', + [b'header foo\r\n\r\n'], + [b'foo bar'], + want_bodies_done=[True], + max_boundary=20, + ), + MultipartParserTestCase( + 'multiple parts', + b'\r\n--qwerty \t \r\nheader foo\r\n\r\nfoo bar\r\n--qwerty\r\nheader bar\r\n\r\nfoo baz\r\n--qwerty--\r\n', # noqa + [b'header foo\r\n\r\n', b'header bar\r\n\r\n'], + [b'foo bar', b'foo baz'], + want_bodies_done=[True, True], + ), + MultipartParserTestCase( + 'ignore after terminal boundary', + b'\r\n--qwerty \t \r\nheader foo\r\n\r\nfoo bar\r\n--qwerty--\r\nheader bar\r\n\r\nfoo baz\r\n--qwerty--\r\n', # noqa + [b'header foo\r\n\r\n'], + [b'foo bar'], + want_bodies_done=[True], + ), + ]) + def test_multipart_parser(self, test: MultipartParserTestCase): chunk_sizes = [1, 2, 3, 4, 5, 7, 13, 17, 19, 23, 29, 31, 37, 42, 50, 100, 1000] marker = b'qwerty' - for i, test in enumerate(tests): - for chunk_size in chunk_sizes: - headers: typing.List[bytes] = [] - bodies: typing.List[bytes] = [] - bodies_done: typing.List[bool] = [] - - # All of the "noqa: B023" here are due to a ruff bug: - # https://github.com/astral-sh/ruff/issues/7847 - # ruff should tell us when the 'noqa's are no longer required. - def handle_header(data: typing.Any): - headers.append(bytes(data)) # noqa: B023 - bodies.append(b'') # noqa: B023 - bodies_done.append(False) # noqa: B023 - - def handle_body(data: bytes, done: bool = False): - bodies[-1] += data # noqa: B023 - bodies_done[-1] = done # noqa: B023 - - parser = pebble._MultipartParser( - marker, - handle_header, - handle_body, - max_boundary_length=test.max_boundary, - max_lookahead=test.max_lookahead) - src = io.BytesIO(test.data) - - try: - while True: - data = src.read(chunk_size) - if not data: - break - parser.feed(data) - except Exception as err: - if not test.error: - self.fail(f'unexpected error: {err}') + for chunk_size in chunk_sizes: + headers: typing.List[bytes] = [] + bodies: typing.List[bytes] = [] + bodies_done: typing.List[bool] = [] + + # All of the "noqa: B023" here are due to a ruff bug: + # https://github.com/astral-sh/ruff/issues/7847 + # ruff should tell us when the 'noqa's are no longer required. + def handle_header(data: typing.Any): + headers.append(bytes(data)) # noqa: B023 + bodies.append(b'') # noqa: B023 + bodies_done.append(False) # noqa: B023 + + def handle_body(data: bytes, done: bool = False): + bodies[-1] += data # noqa: B023 + bodies_done[-1] = done # noqa: B023 + + parser = pebble._MultipartParser( + marker, + handle_header, + handle_body, + max_boundary_length=test.max_boundary, + max_lookahead=test.max_lookahead) + src = io.BytesIO(test.data) + + try: + while True: + data = src.read(chunk_size) + if not data: break - assert test.error == str(err) - else: - if test.error: - self.fail(f'missing expected error: {test.error!r}') - - msg = f'test case {i + 1} ({test.name}), chunk size {chunk_size}' - assert test.want_headers == headers, msg - assert test.want_bodies == bodies, msg - assert test.want_bodies_done == bodies_done, msg + parser.feed(data) + except Exception as err: + if not test.error: + pytest.fail(f'unexpected error: {err}') + break + assert test.error == str(err) + else: + if test.error: + pytest.fail(f'missing expected error: {test.error!r}') + + msg = f'test case ({test.name}), chunk size {chunk_size}' + assert test.want_headers == headers, msg + assert test.want_bodies == bodies, msg + assert test.want_bodies_done == bodies_done, msg class TestClient(unittest.TestCase): diff --git a/test/test_private.py b/test/test_private.py index 54ec5fefe..0e31cec5f 100644 --- a/test/test_private.py +++ b/test/test_private.py @@ -50,7 +50,7 @@ def test_safe_dump(self): yaml.safe_dump(YAMLTest()) -class TestStrconv(unittest.TestCase): +class TestStrconv: def test_parse_rfc3339(self): nzdt = datetime.timezone(datetime.timedelta(hours=13)) utc = datetime.timezone.utc @@ -102,84 +102,80 @@ def test_parse_rfc3339(self): with pytest.raises(ValueError): timeconv.parse_rfc3339('2021-02-10T04:36:22.118970777-99:99') - def test_parse_duration(self): + @pytest.mark.parametrize("input,expected", [ # Test cases taken from Go's time.ParseDuration tests - cases = [ - # simple - ('0', datetime.timedelta(seconds=0)), - ('5s', datetime.timedelta(seconds=5)), - ('30s', datetime.timedelta(seconds=30)), - ('1478s', datetime.timedelta(seconds=1478)), - # sign - ('-5s', datetime.timedelta(seconds=-5)), - ('+5s', datetime.timedelta(seconds=5)), - ('-0', datetime.timedelta(seconds=0)), - ('+0', datetime.timedelta(seconds=0)), - # decimal - ('5.0s', datetime.timedelta(seconds=5)), - ('5.6s', datetime.timedelta(seconds=5.6)), - ('5.s', datetime.timedelta(seconds=5)), - ('.5s', datetime.timedelta(seconds=0.5)), - ('1.0s', datetime.timedelta(seconds=1)), - ('1.00s', datetime.timedelta(seconds=1)), - ('1.004s', datetime.timedelta(seconds=1.004)), - ('1.0040s', datetime.timedelta(seconds=1.004)), - ('100.00100s', datetime.timedelta(seconds=100.001)), - # different units - ('10ns', datetime.timedelta(seconds=0.000_000_010)), - ('11us', datetime.timedelta(seconds=0.000_011)), - ('12µs', datetime.timedelta(seconds=0.000_012)), # U+00B5 # noqa: RUF001 - ('12μs', datetime.timedelta(seconds=0.000_012)), # U+03BC - ('13ms', datetime.timedelta(seconds=0.013)), - ('14s', datetime.timedelta(seconds=14)), - ('15m', datetime.timedelta(seconds=15 * 60)), - ('16h', datetime.timedelta(seconds=16 * 60 * 60)), - # composite durations - ('3h30m', datetime.timedelta(seconds=3 * 60 * 60 + 30 * 60)), - ('10.5s4m', datetime.timedelta(seconds=4 * 60 + 10.5)), - ('-2m3.4s', datetime.timedelta(seconds=-(2 * 60 + 3.4))), - ('1h2m3s4ms5us6ns', datetime.timedelta(seconds=1 * 60 * 60 + 2 * 60 + 3.004_005_006)), - ('39h9m14.425s', datetime.timedelta(seconds=39 * 60 * 60 + 9 * 60 + 14.425)), - # large value - ('52763797000ns', datetime.timedelta(seconds=52.763_797_000)), - # more than 9 digits after decimal point, see https://golang.org/issue/6617 - ('0.3333333333333333333h', datetime.timedelta(seconds=20 * 60)), - # huge string; issue 15011. - ('0.100000000000000000000h', datetime.timedelta(seconds=6 * 60)), - # This value tests the first overflow check in leadingFraction. - ('0.830103483285477580700h', datetime.timedelta(seconds=49 * 60 + 48.372_539_827)), - - # Test precision handling - ('7200000h1us', datetime.timedelta(hours=7_200_000, microseconds=1)) - ] - - for input, expected in cases: - output = timeconv.parse_duration(input) - assert output == expected, \ - f'parse_duration({input!r}): expected {expected!r}, got {output!r}' - - def test_parse_duration_errors(self): - cases = [ - # Test cases taken from Go's time.ParseDuration tests - '', - '3', - '-', - 's', - '.', - '-.', - '.s', - '+.s', - '1d', - '\x85\x85', - '\xffff', - 'hello \xffff world', - - # Additional cases - 'X3h', - '3hY', - 'X3hY', - '3.4.5s', - ] - for input in cases: - with pytest.raises(ValueError): - timeconv.parse_duration(input) + # simple + ('0', datetime.timedelta(seconds=0)), + ('5s', datetime.timedelta(seconds=5)), + ('30s', datetime.timedelta(seconds=30)), + ('1478s', datetime.timedelta(seconds=1478)), + # sign + ('-5s', datetime.timedelta(seconds=-5)), + ('+5s', datetime.timedelta(seconds=5)), + ('-0', datetime.timedelta(seconds=0)), + ('+0', datetime.timedelta(seconds=0)), + # decimal + ('5.0s', datetime.timedelta(seconds=5)), + ('5.6s', datetime.timedelta(seconds=5.6)), + ('5.s', datetime.timedelta(seconds=5)), + ('.5s', datetime.timedelta(seconds=0.5)), + ('1.0s', datetime.timedelta(seconds=1)), + ('1.00s', datetime.timedelta(seconds=1)), + ('1.004s', datetime.timedelta(seconds=1.004)), + ('1.0040s', datetime.timedelta(seconds=1.004)), + ('100.00100s', datetime.timedelta(seconds=100.001)), + # different units + ('10ns', datetime.timedelta(seconds=0.000_000_010)), + ('11us', datetime.timedelta(seconds=0.000_011)), + ('12µs', datetime.timedelta(seconds=0.000_012)), # U+00B5 # noqa: RUF001 + ('12μs', datetime.timedelta(seconds=0.000_012)), # U+03BC + ('13ms', datetime.timedelta(seconds=0.013)), + ('14s', datetime.timedelta(seconds=14)), + ('15m', datetime.timedelta(seconds=15 * 60)), + ('16h', datetime.timedelta(seconds=16 * 60 * 60)), + # composite durations + ('3h30m', datetime.timedelta(seconds=3 * 60 * 60 + 30 * 60)), + ('10.5s4m', datetime.timedelta(seconds=4 * 60 + 10.5)), + ('-2m3.4s', datetime.timedelta(seconds=-(2 * 60 + 3.4))), + ('1h2m3s4ms5us6ns', datetime.timedelta(seconds=1 * 60 * 60 + 2 * 60 + 3.004_005_006)), + ('39h9m14.425s', datetime.timedelta(seconds=39 * 60 * 60 + 9 * 60 + 14.425)), + # large value + ('52763797000ns', datetime.timedelta(seconds=52.763_797_000)), + # more than 9 digits after decimal point, see https://golang.org/issue/6617 + ('0.3333333333333333333h', datetime.timedelta(seconds=20 * 60)), + # huge string; issue 15011. + ('0.100000000000000000000h', datetime.timedelta(seconds=6 * 60)), + # This value tests the first overflow check in leadingFraction. + ('0.830103483285477580700h', datetime.timedelta(seconds=49 * 60 + 48.372_539_827)), + # Test precision handling + ('7200000h1us', datetime.timedelta(hours=7_200_000, microseconds=1)) + ]) + def test_parse_duration(self, input: str, expected: datetime.timedelta): + output = timeconv.parse_duration(input) + assert output == expected, \ + f'parse_duration({input!r}): expected {expected!r}, got {output!r}' + + @pytest.mark.parametrize("input", [ + # Test cases taken from Go's time.ParseDuration tests + '', + '3', + '-', + 's', + '.', + '-.', + '.s', + '+.s', + '1d', + '\x85\x85', + '\xffff', + 'hello \xffff world', + + # Additional cases + 'X3h', + '3hY', + 'X3hY', + '3.4.5s', + ]) + def test_parse_duration_errors(self, input: str): + with pytest.raises(ValueError): + timeconv.parse_duration(input) From b7d7f07722611bb3763d09bfe0c9b524fef2e25e Mon Sep 17 00:00:00 2001 From: Tiexin Guo Date: Fri, 19 Apr 2024 18:51:13 +0800 Subject: [PATCH 2/3] test: remove class from test jujuversion --- test/test_jujuversion.py | 307 ++++++++++++++++++++------------------- 1 file changed, 158 insertions(+), 149 deletions(-) diff --git a/test/test_jujuversion.py b/test/test_jujuversion.py index f9b2bdb52..a2399f74f 100644 --- a/test/test_jujuversion.py +++ b/test/test_jujuversion.py @@ -20,152 +20,161 @@ import ops -class TestJujuVersion: - @pytest.mark.parametrize("vs,major,minor,tag,patch,build", [ - ("0.0.0", 0, 0, '', 0, 0), - ("0.0.2", 0, 0, '', 2, 0), - ("0.1.0", 0, 1, '', 0, 0), - ("0.2.3", 0, 2, '', 3, 0), - ("10.234.3456", 10, 234, '', 3456, 0), - ("10.234.3456.1", 10, 234, '', 3456, 1), - ("1.21-alpha12", 1, 21, 'alpha', 12, 0), - ("1.21-alpha1.34", 1, 21, 'alpha', 1, 34), - ("2.7", 2, 7, '', 0, 0) - ]) - def test_parsing(self, vs: str, major: int, minor: int, tag: str, patch: int, build: int): - v = ops.JujuVersion(vs) - assert v.major == major - assert v.minor == minor - assert v.tag == tag - assert v.patch == patch - assert v.build == build - - @unittest.mock.patch('os.environ', new={}) # type: ignore - def test_from_environ(self): - # JUJU_VERSION is not set - v = ops.JujuVersion.from_environ() - assert v == ops.JujuVersion('0.0.0') - - os.environ['JUJU_VERSION'] = 'no' - with pytest.raises(RuntimeError, match='not a valid Juju version'): - ops.JujuVersion.from_environ() - - os.environ['JUJU_VERSION'] = '2.8.0' - v = ops.JujuVersion.from_environ() - assert v == ops.JujuVersion('2.8.0') - - def test_has_app_data(self): - assert ops.JujuVersion('2.8.0').has_app_data() - assert ops.JujuVersion('2.7.0').has_app_data() - assert not ops.JujuVersion('2.6.9').has_app_data() - - def test_is_dispatch_aware(self): - assert ops.JujuVersion('2.8.0').is_dispatch_aware() - assert not ops.JujuVersion('2.7.9').is_dispatch_aware() - - def test_has_controller_storage(self): - assert ops.JujuVersion('2.8.0').has_controller_storage() - assert not ops.JujuVersion('2.7.9').has_controller_storage() - - def test_has_secrets(self): - assert ops.JujuVersion('3.0.3').has_secrets - assert ops.JujuVersion('3.1.0').has_secrets - assert not ops.JujuVersion('3.0.2').has_secrets - assert not ops.JujuVersion('2.9.30').has_secrets - - def test_supports_open_port_on_k8s(self): - assert ops.JujuVersion('3.0.3').supports_open_port_on_k8s - assert ops.JujuVersion('3.3.0').supports_open_port_on_k8s - assert not ops.JujuVersion('3.0.2').supports_open_port_on_k8s - assert not ops.JujuVersion('2.9.30').supports_open_port_on_k8s - - def test_supports_exec_service_context(self): - assert not ops.JujuVersion('2.9.30').supports_exec_service_context - assert ops.JujuVersion('4.0.0').supports_exec_service_context - assert not ops.JujuVersion('3.0.0').supports_exec_service_context - assert not ops.JujuVersion('3.1.5').supports_exec_service_context - assert ops.JujuVersion('3.1.6').supports_exec_service_context - assert not ops.JujuVersion('3.2.0').supports_exec_service_context - assert ops.JujuVersion('3.2.2').supports_exec_service_context - assert ops.JujuVersion('3.3.0').supports_exec_service_context - assert ops.JujuVersion('3.4.0').supports_exec_service_context - - @pytest.mark.parametrize("invalid_version", [ - "xyz", - "foo.bar", - "foo.bar.baz", - "dead.beef.ca.fe", - "1234567890.2.1", # The major version is too long. - "0.2..1", # Two periods next to each other. - "1.21.alpha1", # Tag comes after period. - "1.21-alpha", # No patch number but a tag is present. - "1.21-alpha1beta", # Non-numeric string after the patch number. - "1.21-alpha-dev", # Tag duplication. - "1.21-alpha_dev3", # Underscore in a tag. - "1.21-alpha123dev3", # Non-numeric string after the patch number. - ]) - def test_parsing_errors(self, invalid_version: str): - with pytest.raises(RuntimeError): - ops.JujuVersion(invalid_version) - - @pytest.mark.parametrize("a,b,expected", [ - ("1.0.0", "1.0.0", True), - ("01.0.0", "1.0.0", True), - ("10.0.0", "9.0.0", False), - ("1.0.0", "1.0.1", False), - ("1.0.1", "1.0.0", False), - ("1.0.0", "1.1.0", False), - ("1.1.0", "1.0.0", False), - ("1.0.0", "2.0.0", False), - ("1.2-alpha1", "1.2.0", False), - ("1.2-alpha2", "1.2-alpha1", False), - ("1.2-alpha2.1", "1.2-alpha2", False), - ("1.2-alpha2.2", "1.2-alpha2.1", False), - ("1.2-beta1", "1.2-alpha1", False), - ("1.2-beta1", "1.2-alpha2.1", False), - ("1.2-beta1", "1.2.0", False), - ("1.2.1", "1.2.0", False), - ("2.0.0", "1.0.0", False), - ("2.0.0.0", "2.0.0", True), - ("2.0.0.0", "2.0.0.0", True), - ("2.0.0.1", "2.0.0.0", False), - ("2.0.1.10", "2.0.0.0", False), - ]) - def test_equality(self, a: str, b: str, expected: bool): - assert (ops.JujuVersion(a) == ops.JujuVersion(b)) == expected - assert (ops.JujuVersion(a) == b) == expected - - @pytest.mark.parametrize("a,b,expected_strict,expected_weak", [ - ("1.0.0", "1.0.0", False, True), - ("01.0.0", "1.0.0", False, True), - ("10.0.0", "9.0.0", False, False), - ("1.0.0", "1.0.1", True, True), - ("1.0.1", "1.0.0", False, False), - ("1.0.0", "1.1.0", True, True), - ("1.1.0", "1.0.0", False, False), - ("1.0.0", "2.0.0", True, True), - ("1.2-alpha1", "1.2.0", True, True), - ("1.2-alpha2", "1.2-alpha1", False, False), - ("1.2-alpha2.1", "1.2-alpha2", False, False), - ("1.2-alpha2.2", "1.2-alpha2.1", False, False), - ("1.2-beta1", "1.2-alpha1", False, False), - ("1.2-beta1", "1.2-alpha2.1", False, False), - ("1.2-beta1", "1.2.0", True, True), - ("1.2.1", "1.2.0", False, False), - ("2.0.0", "1.0.0", False, False), - ("2.0.0.0", "2.0.0", False, True), - ("2.0.0.0", "2.0.0.0", False, True), - ("2.0.0.1", "2.0.0.0", False, False), - ("2.0.1.10", "2.0.0.0", False, False), - ("2.10.0", "2.8.0", False, False), - ]) - def test_comparison(self, a: str, b: str, expected_strict: bool, expected_weak: bool): - assert (ops.JujuVersion(a) < ops.JujuVersion(b)) == expected_strict - assert (ops.JujuVersion(a) <= ops.JujuVersion(b)) == expected_weak - assert (ops.JujuVersion(b) > ops.JujuVersion(a)) == expected_strict - assert (ops.JujuVersion(b) >= ops.JujuVersion(a)) == expected_weak - # Implicit conversion. - assert (ops.JujuVersion(a) < b) == expected_strict - assert (ops.JujuVersion(a) <= b) == expected_weak - assert (b > ops.JujuVersion(a)) == expected_strict - assert (b >= ops.JujuVersion(a)) == expected_weak +@pytest.mark.parametrize("vs,major,minor,tag,patch,build", [ + ("0.0.0", 0, 0, '', 0, 0), + ("0.0.2", 0, 0, '', 2, 0), + ("0.1.0", 0, 1, '', 0, 0), + ("0.2.3", 0, 2, '', 3, 0), + ("10.234.3456", 10, 234, '', 3456, 0), + ("10.234.3456.1", 10, 234, '', 3456, 1), + ("1.21-alpha12", 1, 21, 'alpha', 12, 0), + ("1.21-alpha1.34", 1, 21, 'alpha', 1, 34), + ("2.7", 2, 7, '', 0, 0) +]) +def test_parsing(vs: str, major: int, minor: int, tag: str, patch: int, build: int): + v = ops.JujuVersion(vs) + assert v.major == major + assert v.minor == minor + assert v.tag == tag + assert v.patch == patch + assert v.build == build + + +@unittest.mock.patch('os.environ', new={}) # type: ignore +def test_from_environ(): + # JUJU_VERSION is not set + v = ops.JujuVersion.from_environ() + assert v == ops.JujuVersion('0.0.0') + + os.environ['JUJU_VERSION'] = 'no' + with pytest.raises(RuntimeError, match='not a valid Juju version'): + ops.JujuVersion.from_environ() + + os.environ['JUJU_VERSION'] = '2.8.0' + v = ops.JujuVersion.from_environ() + assert v == ops.JujuVersion('2.8.0') + + +def test_has_app_data(): + assert ops.JujuVersion('2.8.0').has_app_data() + assert ops.JujuVersion('2.7.0').has_app_data() + assert not ops.JujuVersion('2.6.9').has_app_data() + + +def test_is_dispatch_aware(): + assert ops.JujuVersion('2.8.0').is_dispatch_aware() + assert not ops.JujuVersion('2.7.9').is_dispatch_aware() + + +def test_has_controller_storage(): + assert ops.JujuVersion('2.8.0').has_controller_storage() + assert not ops.JujuVersion('2.7.9').has_controller_storage() + + +def test_has_secrets(): + assert ops.JujuVersion('3.0.3').has_secrets + assert ops.JujuVersion('3.1.0').has_secrets + assert not ops.JujuVersion('3.0.2').has_secrets + assert not ops.JujuVersion('2.9.30').has_secrets + + +def test_supports_open_port_on_k8s(): + assert ops.JujuVersion('3.0.3').supports_open_port_on_k8s + assert ops.JujuVersion('3.3.0').supports_open_port_on_k8s + assert not ops.JujuVersion('3.0.2').supports_open_port_on_k8s + assert not ops.JujuVersion('2.9.30').supports_open_port_on_k8s + + +def test_supports_exec_service_context(): + assert not ops.JujuVersion('2.9.30').supports_exec_service_context + assert ops.JujuVersion('4.0.0').supports_exec_service_context + assert not ops.JujuVersion('3.0.0').supports_exec_service_context + assert not ops.JujuVersion('3.1.5').supports_exec_service_context + assert ops.JujuVersion('3.1.6').supports_exec_service_context + assert not ops.JujuVersion('3.2.0').supports_exec_service_context + assert ops.JujuVersion('3.2.2').supports_exec_service_context + assert ops.JujuVersion('3.3.0').supports_exec_service_context + assert ops.JujuVersion('3.4.0').supports_exec_service_context + + +@pytest.mark.parametrize("invalid_version", [ + "xyz", + "foo.bar", + "foo.bar.baz", + "dead.beef.ca.fe", + "1234567890.2.1", # The major version is too long. + "0.2..1", # Two periods next to each other. + "1.21.alpha1", # Tag comes after period. + "1.21-alpha", # No patch number but a tag is present. + "1.21-alpha1beta", # Non-numeric string after the patch number. + "1.21-alpha-dev", # Tag duplication. + "1.21-alpha_dev3", # Underscore in a tag. + "1.21-alpha123dev3", # Non-numeric string after the patch number. +]) +def test_parsing_errors(invalid_version: str): + with pytest.raises(RuntimeError): + ops.JujuVersion(invalid_version) + + +@pytest.mark.parametrize("a,b,expected", [ + ("1.0.0", "1.0.0", True), + ("01.0.0", "1.0.0", True), + ("10.0.0", "9.0.0", False), + ("1.0.0", "1.0.1", False), + ("1.0.1", "1.0.0", False), + ("1.0.0", "1.1.0", False), + ("1.1.0", "1.0.0", False), + ("1.0.0", "2.0.0", False), + ("1.2-alpha1", "1.2.0", False), + ("1.2-alpha2", "1.2-alpha1", False), + ("1.2-alpha2.1", "1.2-alpha2", False), + ("1.2-alpha2.2", "1.2-alpha2.1", False), + ("1.2-beta1", "1.2-alpha1", False), + ("1.2-beta1", "1.2-alpha2.1", False), + ("1.2-beta1", "1.2.0", False), + ("1.2.1", "1.2.0", False), + ("2.0.0", "1.0.0", False), + ("2.0.0.0", "2.0.0", True), + ("2.0.0.0", "2.0.0.0", True), + ("2.0.0.1", "2.0.0.0", False), + ("2.0.1.10", "2.0.0.0", False), +]) +def test_equality(a: str, b: str, expected: bool): + assert (ops.JujuVersion(a) == ops.JujuVersion(b)) == expected + assert (ops.JujuVersion(a) == b) == expected + + +@pytest.mark.parametrize("a,b,expected_strict,expected_weak", [ + ("1.0.0", "1.0.0", False, True), + ("01.0.0", "1.0.0", False, True), + ("10.0.0", "9.0.0", False, False), + ("1.0.0", "1.0.1", True, True), + ("1.0.1", "1.0.0", False, False), + ("1.0.0", "1.1.0", True, True), + ("1.1.0", "1.0.0", False, False), + ("1.0.0", "2.0.0", True, True), + ("1.2-alpha1", "1.2.0", True, True), + ("1.2-alpha2", "1.2-alpha1", False, False), + ("1.2-alpha2.1", "1.2-alpha2", False, False), + ("1.2-alpha2.2", "1.2-alpha2.1", False, False), + ("1.2-beta1", "1.2-alpha1", False, False), + ("1.2-beta1", "1.2-alpha2.1", False, False), + ("1.2-beta1", "1.2.0", True, True), + ("1.2.1", "1.2.0", False, False), + ("2.0.0", "1.0.0", False, False), + ("2.0.0.0", "2.0.0", False, True), + ("2.0.0.0", "2.0.0.0", False, True), + ("2.0.0.1", "2.0.0.0", False, False), + ("2.0.1.10", "2.0.0.0", False, False), + ("2.10.0", "2.8.0", False, False), +]) +def test_comparison(a: str, b: str, expected_strict: bool, expected_weak: bool): + assert (ops.JujuVersion(a) < ops.JujuVersion(b)) == expected_strict + assert (ops.JujuVersion(a) <= ops.JujuVersion(b)) == expected_weak + assert (ops.JujuVersion(b) > ops.JujuVersion(a)) == expected_strict + assert (ops.JujuVersion(b) >= ops.JujuVersion(a)) == expected_weak + # Implicit conversion. + assert (ops.JujuVersion(a) < b) == expected_strict + assert (ops.JujuVersion(a) <= b) == expected_weak + assert (b > ops.JujuVersion(a)) == expected_strict + assert (b >= ops.JujuVersion(a)) == expected_weak From 8c677f4433ad5d30637176da7dcdf3798848d735 Mon Sep 17 00:00:00 2001 From: Tiexin Guo Date: Fri, 19 Apr 2024 18:56:08 +0800 Subject: [PATCH 3/3] test: split test_private into two test files --- test/test_private.py | 181 ------------------------------------------ test/test_timeconv.py | 151 +++++++++++++++++++++++++++++++++++ test/test_yaml.py | 48 +++++++++++ 3 files changed, 199 insertions(+), 181 deletions(-) delete mode 100644 test/test_private.py create mode 100644 test/test_timeconv.py create mode 100644 test/test_yaml.py diff --git a/test/test_private.py b/test/test_private.py deleted file mode 100644 index 0e31cec5f..000000000 --- a/test/test_private.py +++ /dev/null @@ -1,181 +0,0 @@ -# Copyright 2021 Canonical Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import datetime -import io -import unittest - -import pytest -import yaml as base_yaml - -from ops._private import timeconv, yaml - - -class YAMLTest: - pass - - -class TestYAML(unittest.TestCase): - def test_safe_load(self): - d = yaml.safe_load('foo: bar\nbaz: 123\n') - assert len(d) == 2 - assert d['foo'] == 'bar' - assert d['baz'] == 123 - - # Should error -- it's not safe to load an instance of a user-defined class - with pytest.raises(base_yaml.YAMLError): - yaml.safe_load('!!python/object:test.test_helpers.YAMLTest {}') - - def test_safe_dump(self): - s = yaml.safe_dump({'foo': 'bar', 'baz': 123}) - assert s == 'baz: 123\nfoo: bar\n' - - f = io.StringIO() - yaml.safe_dump({'foo': 'bar', 'baz': 123}, stream=f) - assert f.getvalue() == 'baz: 123\nfoo: bar\n' - - # Should error -- it's not safe to dump an instance of a user-defined class - with pytest.raises(base_yaml.YAMLError): - yaml.safe_dump(YAMLTest()) - - -class TestStrconv: - def test_parse_rfc3339(self): - nzdt = datetime.timezone(datetime.timedelta(hours=13)) - utc = datetime.timezone.utc - - assert timeconv.parse_rfc3339('2020-12-25T13:45:50+13:00') == \ - datetime.datetime(2020, 12, 25, 13, 45, 50, 0, tzinfo=nzdt) - - assert timeconv.parse_rfc3339('2020-12-25T13:45:50.123456789+13:00') == \ - datetime.datetime(2020, 12, 25, 13, 45, 50, 123457, tzinfo=nzdt) - - assert timeconv.parse_rfc3339('2021-02-10T04:36:22Z') == \ - datetime.datetime(2021, 2, 10, 4, 36, 22, 0, tzinfo=utc) - - assert timeconv.parse_rfc3339('2021-02-10t04:36:22z') == \ - datetime.datetime(2021, 2, 10, 4, 36, 22, 0, tzinfo=utc) - - assert timeconv.parse_rfc3339('2021-02-10T04:36:22.118970777Z') == \ - datetime.datetime(2021, 2, 10, 4, 36, 22, 118971, tzinfo=utc) - - assert timeconv.parse_rfc3339('2020-12-25T13:45:50.123456789+00:00') == \ - datetime.datetime(2020, 12, 25, 13, 45, 50, 123457, tzinfo=utc) - - assert timeconv.parse_rfc3339('2006-08-28T13:20:00.9999999Z') == \ - datetime.datetime(2006, 8, 28, 13, 20, 0, 999999, tzinfo=utc) - - assert timeconv.parse_rfc3339('2006-12-31T23:59:59.9999999Z') == \ - datetime.datetime(2006, 12, 31, 23, 59, 59, 999999, tzinfo=utc) - - tzinfo = datetime.timezone(datetime.timedelta(hours=-11, minutes=-30)) - assert timeconv.parse_rfc3339('2020-12-25T13:45:50.123456789-11:30') == \ - datetime.datetime(2020, 12, 25, 13, 45, 50, 123457, tzinfo=tzinfo) - - tzinfo = datetime.timezone(datetime.timedelta(hours=4)) - assert timeconv.parse_rfc3339('2000-01-02T03:04:05.006000+04:00') == \ - datetime.datetime(2000, 1, 2, 3, 4, 5, 6000, tzinfo=tzinfo) - - with pytest.raises(ValueError): - timeconv.parse_rfc3339('') - - with pytest.raises(ValueError): - timeconv.parse_rfc3339('foobar') - - with pytest.raises(ValueError): - timeconv.parse_rfc3339('2021-99-99T04:36:22Z') - - with pytest.raises(ValueError): - timeconv.parse_rfc3339('2021-02-10T04:36:22.118970777x') - - with pytest.raises(ValueError): - timeconv.parse_rfc3339('2021-02-10T04:36:22.118970777-99:99') - - @pytest.mark.parametrize("input,expected", [ - # Test cases taken from Go's time.ParseDuration tests - # simple - ('0', datetime.timedelta(seconds=0)), - ('5s', datetime.timedelta(seconds=5)), - ('30s', datetime.timedelta(seconds=30)), - ('1478s', datetime.timedelta(seconds=1478)), - # sign - ('-5s', datetime.timedelta(seconds=-5)), - ('+5s', datetime.timedelta(seconds=5)), - ('-0', datetime.timedelta(seconds=0)), - ('+0', datetime.timedelta(seconds=0)), - # decimal - ('5.0s', datetime.timedelta(seconds=5)), - ('5.6s', datetime.timedelta(seconds=5.6)), - ('5.s', datetime.timedelta(seconds=5)), - ('.5s', datetime.timedelta(seconds=0.5)), - ('1.0s', datetime.timedelta(seconds=1)), - ('1.00s', datetime.timedelta(seconds=1)), - ('1.004s', datetime.timedelta(seconds=1.004)), - ('1.0040s', datetime.timedelta(seconds=1.004)), - ('100.00100s', datetime.timedelta(seconds=100.001)), - # different units - ('10ns', datetime.timedelta(seconds=0.000_000_010)), - ('11us', datetime.timedelta(seconds=0.000_011)), - ('12µs', datetime.timedelta(seconds=0.000_012)), # U+00B5 # noqa: RUF001 - ('12μs', datetime.timedelta(seconds=0.000_012)), # U+03BC - ('13ms', datetime.timedelta(seconds=0.013)), - ('14s', datetime.timedelta(seconds=14)), - ('15m', datetime.timedelta(seconds=15 * 60)), - ('16h', datetime.timedelta(seconds=16 * 60 * 60)), - # composite durations - ('3h30m', datetime.timedelta(seconds=3 * 60 * 60 + 30 * 60)), - ('10.5s4m', datetime.timedelta(seconds=4 * 60 + 10.5)), - ('-2m3.4s', datetime.timedelta(seconds=-(2 * 60 + 3.4))), - ('1h2m3s4ms5us6ns', datetime.timedelta(seconds=1 * 60 * 60 + 2 * 60 + 3.004_005_006)), - ('39h9m14.425s', datetime.timedelta(seconds=39 * 60 * 60 + 9 * 60 + 14.425)), - # large value - ('52763797000ns', datetime.timedelta(seconds=52.763_797_000)), - # more than 9 digits after decimal point, see https://golang.org/issue/6617 - ('0.3333333333333333333h', datetime.timedelta(seconds=20 * 60)), - # huge string; issue 15011. - ('0.100000000000000000000h', datetime.timedelta(seconds=6 * 60)), - # This value tests the first overflow check in leadingFraction. - ('0.830103483285477580700h', datetime.timedelta(seconds=49 * 60 + 48.372_539_827)), - # Test precision handling - ('7200000h1us', datetime.timedelta(hours=7_200_000, microseconds=1)) - ]) - def test_parse_duration(self, input: str, expected: datetime.timedelta): - output = timeconv.parse_duration(input) - assert output == expected, \ - f'parse_duration({input!r}): expected {expected!r}, got {output!r}' - - @pytest.mark.parametrize("input", [ - # Test cases taken from Go's time.ParseDuration tests - '', - '3', - '-', - 's', - '.', - '-.', - '.s', - '+.s', - '1d', - '\x85\x85', - '\xffff', - 'hello \xffff world', - - # Additional cases - 'X3h', - '3hY', - 'X3hY', - '3.4.5s', - ]) - def test_parse_duration_errors(self, input: str): - with pytest.raises(ValueError): - timeconv.parse_duration(input) diff --git a/test/test_timeconv.py b/test/test_timeconv.py new file mode 100644 index 000000000..ec7ba3c14 --- /dev/null +++ b/test/test_timeconv.py @@ -0,0 +1,151 @@ +# Copyright 2024 Canonical Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime + +import pytest + +from ops._private import timeconv + + +def test_parse_rfc3339(): + nzdt = datetime.timezone(datetime.timedelta(hours=13)) + utc = datetime.timezone.utc + + assert timeconv.parse_rfc3339('2020-12-25T13:45:50+13:00') == \ + datetime.datetime(2020, 12, 25, 13, 45, 50, 0, tzinfo=nzdt) + + assert timeconv.parse_rfc3339('2020-12-25T13:45:50.123456789+13:00') == \ + datetime.datetime(2020, 12, 25, 13, 45, 50, 123457, tzinfo=nzdt) + + assert timeconv.parse_rfc3339('2021-02-10T04:36:22Z') == \ + datetime.datetime(2021, 2, 10, 4, 36, 22, 0, tzinfo=utc) + + assert timeconv.parse_rfc3339('2021-02-10t04:36:22z') == \ + datetime.datetime(2021, 2, 10, 4, 36, 22, 0, tzinfo=utc) + + assert timeconv.parse_rfc3339('2021-02-10T04:36:22.118970777Z') == \ + datetime.datetime(2021, 2, 10, 4, 36, 22, 118971, tzinfo=utc) + + assert timeconv.parse_rfc3339('2020-12-25T13:45:50.123456789+00:00') == \ + datetime.datetime(2020, 12, 25, 13, 45, 50, 123457, tzinfo=utc) + + assert timeconv.parse_rfc3339('2006-08-28T13:20:00.9999999Z') == \ + datetime.datetime(2006, 8, 28, 13, 20, 0, 999999, tzinfo=utc) + + assert timeconv.parse_rfc3339('2006-12-31T23:59:59.9999999Z') == \ + datetime.datetime(2006, 12, 31, 23, 59, 59, 999999, tzinfo=utc) + + tzinfo = datetime.timezone(datetime.timedelta(hours=-11, minutes=-30)) + assert timeconv.parse_rfc3339('2020-12-25T13:45:50.123456789-11:30') == \ + datetime.datetime(2020, 12, 25, 13, 45, 50, 123457, tzinfo=tzinfo) + + tzinfo = datetime.timezone(datetime.timedelta(hours=4)) + assert timeconv.parse_rfc3339('2000-01-02T03:04:05.006000+04:00') == \ + datetime.datetime(2000, 1, 2, 3, 4, 5, 6000, tzinfo=tzinfo) + + with pytest.raises(ValueError): + timeconv.parse_rfc3339('') + + with pytest.raises(ValueError): + timeconv.parse_rfc3339('foobar') + + with pytest.raises(ValueError): + timeconv.parse_rfc3339('2021-99-99T04:36:22Z') + + with pytest.raises(ValueError): + timeconv.parse_rfc3339('2021-02-10T04:36:22.118970777x') + + with pytest.raises(ValueError): + timeconv.parse_rfc3339('2021-02-10T04:36:22.118970777-99:99') + + +@pytest.mark.parametrize("input,expected", [ + # Test cases taken from Go's time.ParseDuration tests + # simple + ('0', datetime.timedelta(seconds=0)), + ('5s', datetime.timedelta(seconds=5)), + ('30s', datetime.timedelta(seconds=30)), + ('1478s', datetime.timedelta(seconds=1478)), + # sign + ('-5s', datetime.timedelta(seconds=-5)), + ('+5s', datetime.timedelta(seconds=5)), + ('-0', datetime.timedelta(seconds=0)), + ('+0', datetime.timedelta(seconds=0)), + # decimal + ('5.0s', datetime.timedelta(seconds=5)), + ('5.6s', datetime.timedelta(seconds=5.6)), + ('5.s', datetime.timedelta(seconds=5)), + ('.5s', datetime.timedelta(seconds=0.5)), + ('1.0s', datetime.timedelta(seconds=1)), + ('1.00s', datetime.timedelta(seconds=1)), + ('1.004s', datetime.timedelta(seconds=1.004)), + ('1.0040s', datetime.timedelta(seconds=1.004)), + ('100.00100s', datetime.timedelta(seconds=100.001)), + # different units + ('10ns', datetime.timedelta(seconds=0.000_000_010)), + ('11us', datetime.timedelta(seconds=0.000_011)), + ('12µs', datetime.timedelta(seconds=0.000_012)), # U+00B5 # noqa: RUF001 + ('12μs', datetime.timedelta(seconds=0.000_012)), # U+03BC + ('13ms', datetime.timedelta(seconds=0.013)), + ('14s', datetime.timedelta(seconds=14)), + ('15m', datetime.timedelta(seconds=15 * 60)), + ('16h', datetime.timedelta(seconds=16 * 60 * 60)), + # composite durations + ('3h30m', datetime.timedelta(seconds=3 * 60 * 60 + 30 * 60)), + ('10.5s4m', datetime.timedelta(seconds=4 * 60 + 10.5)), + ('-2m3.4s', datetime.timedelta(seconds=-(2 * 60 + 3.4))), + ('1h2m3s4ms5us6ns', datetime.timedelta(seconds=1 * 60 * 60 + 2 * 60 + 3.004_005_006)), + ('39h9m14.425s', datetime.timedelta(seconds=39 * 60 * 60 + 9 * 60 + 14.425)), + # large value + ('52763797000ns', datetime.timedelta(seconds=52.763_797_000)), + # more than 9 digits after decimal point, see https://golang.org/issue/6617 + ('0.3333333333333333333h', datetime.timedelta(seconds=20 * 60)), + # huge string; issue 15011. + ('0.100000000000000000000h', datetime.timedelta(seconds=6 * 60)), + # This value tests the first overflow check in leadingFraction. + ('0.830103483285477580700h', datetime.timedelta(seconds=49 * 60 + 48.372_539_827)), + # Test precision handling + ('7200000h1us', datetime.timedelta(hours=7_200_000, microseconds=1)) +]) +def test_parse_duration(input: str, expected: datetime.timedelta): + output = timeconv.parse_duration(input) + assert output == expected, \ + f'parse_duration({input!r}): expected {expected!r}, got {output!r}' + + +@pytest.mark.parametrize("input", [ + # Test cases taken from Go's time.ParseDuration tests + '', + '3', + '-', + 's', + '.', + '-.', + '.s', + '+.s', + '1d', + '\x85\x85', + '\xffff', + 'hello \xffff world', + + # Additional cases + 'X3h', + '3hY', + 'X3hY', + '3.4.5s', +]) +def test_parse_duration_errors(input: str): + with pytest.raises(ValueError): + timeconv.parse_duration(input) diff --git a/test/test_yaml.py b/test/test_yaml.py new file mode 100644 index 000000000..39d56ab20 --- /dev/null +++ b/test/test_yaml.py @@ -0,0 +1,48 @@ +# Copyright 2024 Canonical Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io + +import pytest +import yaml as base_yaml + +from ops._private import yaml + + +class YAMLTest: + pass + + +def test_safe_load(): + d = yaml.safe_load('foo: bar\nbaz: 123\n') + assert len(d) == 2 + assert d['foo'] == 'bar' + assert d['baz'] == 123 + + # Should error -- it's not safe to load an instance of a user-defined class + with pytest.raises(base_yaml.YAMLError): + yaml.safe_load('!!python/object:test.test_helpers.YAMLTest {}') + + +def test_safe_dump(): + s = yaml.safe_dump({'foo': 'bar', 'baz': 123}) + assert s == 'baz: 123\nfoo: bar\n' + + f = io.StringIO() + yaml.safe_dump({'foo': 'bar', 'baz': 123}, stream=f) + assert f.getvalue() == 'baz: 123\nfoo: bar\n' + + # Should error -- it's not safe to dump an instance of a user-defined class + with pytest.raises(base_yaml.YAMLError): + yaml.safe_dump(YAMLTest())