Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: add type hints to test_charm #1022

Merged
merged 3 commits into from
Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ include = ["ops/*.py", "ops/_private/*.py",
"test/test_lib.py",
"test/test_model.py",
"test/test_testing.py",
"test/test_charm.py",
]
pythonVersion = "3.8" # check no python > 3.8 features are used
pythonPlatform = "All"
Expand Down
119 changes: 61 additions & 58 deletions test/test_charm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import os
import shutil
import tempfile
import typing
import unittest
from pathlib import Path

Expand All @@ -29,7 +30,7 @@
class TestCharm(unittest.TestCase):

def setUp(self):
def restore_env(env):
def restore_env(env: typing.Dict[str, str]):
os.environ.clear()
os.environ.update(env)
self.addCleanup(restore_env, os.environ.copy())
Expand All @@ -51,10 +52,10 @@ class TestCharmEvents(ops.CharmEvents):

# Relations events are defined dynamically and modify the class attributes.
# We use a subclass temporarily to prevent these side effects from leaking.
ops.CharmBase.on = TestCharmEvents()
ops.CharmBase.on = TestCharmEvents() # type: ignore

def cleanup():
ops.CharmBase.on = ops.CharmEvents()
ops.CharmBase.on = ops.CharmEvents() # type: ignore
self.addCleanup(cleanup)

def create_framework(self):
Expand All @@ -69,16 +70,16 @@ def test_basic(self):

class MyCharm(ops.CharmBase):

def __init__(self, *args):
def __init__(self, *args: typing.Any):
super().__init__(*args)

self.started = False
framework.observe(self.on.start, self._on_start)

def _on_start(self, event):
def _on_start(self, event: ops.EventBase):
self.started = True

events = list(MyCharm.on.events())
events: typing.List[str] = list(MyCharm.on.events()) # type: ignore
self.assertIn('install', events)
self.assertIn('custom', events)

Expand All @@ -89,7 +90,7 @@ def _on_start(self, event):
self.assertEqual(charm.started, True)

with self.assertRaisesRegex(TypeError, "observer methods must now be explicitly provided"):
framework.observe(charm.on.start, charm)
framework.observe(charm.on.start, charm) # type: ignore

def test_observe_decorated_method(self):
# we test that charm methods decorated with @functools.wraps(wrapper)
Expand All @@ -98,25 +99,26 @@ def test_observe_decorated_method(self):
# is more careful and it still works, this test is here to ensure that
# it keeps working in future releases, as this is presently the only
# way we know of to cleanly decorate charm event observers.
events = []
events: typing.List[ops.EventBase] = []

def dec(fn):
def dec(fn: typing.Callable[['MyCharm', ops.EventBase], None] # noqa: F821
) -> typing.Callable[..., None]:
# simple decorator that appends to the nonlocal
# `events` list all events it receives
@functools.wraps(fn)
def wrapper(charm, evt):
def wrapper(charm: 'MyCharm', evt: ops.EventBase):
events.append(evt)
fn(charm, evt)
return wrapper

class MyCharm(ops.CharmBase):
def __init__(self, *args):
def __init__(self, *args: typing.Any):
super().__init__(*args)
framework.observe(self.on.start, self._on_start)
self.seen = None

@dec
def _on_start(self, event):
def _on_start(self, event: ops.EventBase):
self.seen = event

framework = self.create_framework()
Expand Down Expand Up @@ -147,18 +149,19 @@ class MyCharm(ops.CharmBase):
def test_relation_events(self):

class MyCharm(ops.CharmBase):
def __init__(self, *args):
def __init__(self, *args: typing.Any):
super().__init__(*args)
self.seen = []
self.seen: typing.List[str] = []
for rel in ('req1', 'req-2', 'pro1', 'pro-2', 'peer1', 'peer-2'):
# Hook up relation events to generic handler.
self.framework.observe(self.on[rel].relation_joined, self.on_any_relation)
self.framework.observe(self.on[rel].relation_changed, self.on_any_relation)
self.framework.observe(self.on[rel].relation_departed, self.on_any_relation)
self.framework.observe(self.on[rel].relation_broken, self.on_any_relation)

def on_any_relation(self, event):
def on_any_relation(self, event: ops.RelationEvent):
assert event.relation.name == 'req1'
assert event.relation.app is not None
assert event.relation.app.name == 'remote'
self.seen.append(type(event).__name__)

Expand Down Expand Up @@ -210,25 +213,25 @@ def test_storage_events(self):
this = self

class MyCharm(ops.CharmBase):
def __init__(self, *args):
def __init__(self, *args: typing.Any):
super().__init__(*args)
self.seen = []
self.seen: typing.List[str] = []
self.framework.observe(self.on['stor1'].storage_attached, self._on_stor1_attach)
self.framework.observe(self.on['stor2'].storage_detaching, self._on_stor2_detach)
self.framework.observe(self.on['stor3'].storage_attached, self._on_stor3_attach)
self.framework.observe(self.on['stor-4'].storage_attached, self._on_stor4_attach)

def _on_stor1_attach(self, event):
def _on_stor1_attach(self, event: ops.StorageAttachedEvent):
self.seen.append(type(event).__name__)
this.assertEqual(event.storage.location, Path("/var/srv/stor1/0"))

def _on_stor2_detach(self, event):
def _on_stor2_detach(self, event: ops.StorageDetachingEvent):
self.seen.append(type(event).__name__)

def _on_stor3_attach(self, event):
def _on_stor3_attach(self, event: ops.StorageAttachedEvent):
self.seen.append(type(event).__name__)

def _on_stor4_attach(self, event):
def _on_stor4_attach(self, event: ops.StorageAttachedEvent):
self.seen.append(type(event).__name__)

# language=YAML
Expand Down Expand Up @@ -320,17 +323,17 @@ def _on_stor4_attach(self, event):
def test_workload_events(self):

class MyCharm(ops.CharmBase):
def __init__(self, *args):
def __init__(self, *args: typing.Any):
super().__init__(*args)
self.seen = []
self.seen: typing.List[str] = []
self.count = 0
for workload in ('container-a', 'containerb'):
# Hook up relation events to generic handler.
self.framework.observe(
self.on[workload].pebble_ready,
self.on_any_pebble_ready)

def on_any_pebble_ready(self, event):
def on_any_pebble_ready(self, event: ops.PebbleReadyEvent):
self.seen.append(type(event).__name__)
self.count += 1

Expand Down Expand Up @@ -437,25 +440,25 @@ def test_action_events(self):

class MyCharm(ops.CharmBase):

def __init__(self, *args):
def __init__(self, *args: typing.Any):
super().__init__(*args)
framework.observe(self.on.foo_bar_action, self._on_foo_bar_action)
framework.observe(self.on.start_action, self._on_start_action)

def _on_foo_bar_action(self, event):
def _on_foo_bar_action(self, event: ops.ActionEvent):
self.seen_action_params = event.params
event.log('test-log')
event.set_results({'res': 'val with spaces'})
event.fail('test-fail')

def _on_start_action(self, event):
def _on_start_action(self, event: ops.ActionEvent):
pass

self._setup_test_action()
framework = self.create_framework()
charm = MyCharm(framework)

events = list(MyCharm.on.events())
events: typing.List[str] = list(MyCharm.on.events()) # type: ignore
self.assertIn('foo_bar_action', events)
self.assertIn('start_action', events)

Expand All @@ -477,12 +480,12 @@ def test_invalid_action_results(self):

class MyCharm(ops.CharmBase):

def __init__(self, *args):
def __init__(self, *args: typing.Any):
super().__init__(*args)
self.res = {}
self.res: typing.Dict[str, typing.Any] = {}
framework.observe(self.on.foo_bar_action, self._on_foo_bar_action)

def _on_foo_bar_action(self, event):
def _on_foo_bar_action(self, event: ops.ActionEvent):
event.set_results(self.res)

self._setup_test_action()
Expand All @@ -500,15 +503,15 @@ def _on_foo_bar_action(self, event):
with self.assertRaises(ValueError):
charm.on.foo_bar_action.emit()

def _test_action_event_defer_fails(self, cmd_type):
def _test_action_event_defer_fails(self, cmd_type: str):

class MyCharm(ops.CharmBase):

def __init__(self, *args):
def __init__(self, *args: typing.Any):
super().__init__(*args)
framework.observe(self.on.start_action, self._on_start_action)

def _on_start_action(self, event):
def _on_start_action(self, event: ops.ActionEvent):
event.defer()

fake_script(self, f"{cmd_type}-get", """echo '{"foo-name": "name", "silent": true}'""")
Expand Down Expand Up @@ -588,31 +591,31 @@ def test_containers_storage_multiple_mounts(self):

def test_secret_events(self):
class MyCharm(ops.CharmBase):
def __init__(self, *args):
def __init__(self, *args: typing.Any):
super().__init__(*args)
self.seen = []
self.seen: typing.List[str] = []
self.framework.observe(self.on.secret_changed, self.on_secret_changed)
self.framework.observe(self.on.secret_rotate, self.on_secret_rotate)
self.framework.observe(self.on.secret_remove, self.on_secret_remove)
self.framework.observe(self.on.secret_expired, self.on_secret_expired)

def on_secret_changed(self, event):
def on_secret_changed(self, event: ops.SecretChangedEvent):
assert event.secret.id == 'secret:changed'
assert event.secret.label is None
self.seen.append(type(event).__name__)

def on_secret_rotate(self, event):
def on_secret_rotate(self, event: ops.SecretRotateEvent):
assert event.secret.id == 'secret:rotate'
assert event.secret.label == 'rot'
self.seen.append(type(event).__name__)

def on_secret_remove(self, event):
def on_secret_remove(self, event: ops.SecretRemoveEvent):
assert event.secret.id == 'secret:remove'
assert event.secret.label == 'rem'
assert event.revision == 7
self.seen.append(type(event).__name__)

def on_secret_expired(self, event):
def on_secret_expired(self, event: ops.SecretExpiredEvent):
assert event.secret.id == 'secret:expired'
assert event.secret.label == 'exp'
assert event.revision == 42
Expand All @@ -635,11 +638,11 @@ def on_secret_expired(self, event):

def test_collect_app_status_leader(self):
class MyCharm(ops.CharmBase):
def __init__(self, *args):
def __init__(self, *args: typing.Any):
super().__init__(*args)
self.framework.observe(self.on.collect_app_status, self._on_collect_status)

def _on_collect_status(self, event):
def _on_collect_status(self, event: ops.CollectStatusEvent):
event.add_status(ops.ActiveStatus())
event.add_status(ops.BlockedStatus('first'))
event.add_status(ops.WaitingStatus('waiting'))
Expand All @@ -658,11 +661,11 @@ def _on_collect_status(self, event):

def test_collect_app_status_no_statuses(self):
class MyCharm(ops.CharmBase):
def __init__(self, *args):
def __init__(self, *args: typing.Any):
super().__init__(*args)
self.framework.observe(self.on.collect_app_status, self._on_collect_status)

def _on_collect_status(self, event):
def _on_collect_status(self, event: ops.CollectStatusEvent):
pass

fake_script(self, 'is-leader', 'echo true')
Expand All @@ -676,11 +679,11 @@ def _on_collect_status(self, event):

def test_collect_app_status_non_leader(self):
class MyCharm(ops.CharmBase):
def __init__(self, *args):
def __init__(self, *args: typing.Any):
super().__init__(*args)
self.framework.observe(self.on.collect_app_status, self._on_collect_status)

def _on_collect_status(self, event):
def _on_collect_status(self, event: ops.CollectStatusEvent):
raise Exception # shouldn't be called

fake_script(self, 'is-leader', 'echo false')
Expand All @@ -694,11 +697,11 @@ def _on_collect_status(self, event):

def test_collect_unit_status(self):
class MyCharm(ops.CharmBase):
def __init__(self, *args):
def __init__(self, *args: typing.Any):
super().__init__(*args)
self.framework.observe(self.on.collect_unit_status, self._on_collect_status)

def _on_collect_status(self, event):
def _on_collect_status(self, event: ops.CollectStatusEvent):
event.add_status(ops.ActiveStatus())
event.add_status(ops.BlockedStatus('first'))
event.add_status(ops.WaitingStatus('waiting'))
Expand All @@ -717,11 +720,11 @@ def _on_collect_status(self, event):

def test_collect_unit_status_no_statuses(self):
class MyCharm(ops.CharmBase):
def __init__(self, *args):
def __init__(self, *args: typing.Any):
super().__init__(*args)
self.framework.observe(self.on.collect_unit_status, self._on_collect_status)

def _on_collect_status(self, event):
def _on_collect_status(self, event: ops.CollectStatusEvent):
pass

fake_script(self, 'is-leader', 'echo false') # called only for collecting app statuses
Expand All @@ -735,15 +738,15 @@ def _on_collect_status(self, event):

def test_collect_app_and_unit_status(self):
class MyCharm(ops.CharmBase):
def __init__(self, *args):
def __init__(self, *args: typing.Any):
super().__init__(*args)
self.framework.observe(self.on.collect_app_status, self._on_collect_app_status)
self.framework.observe(self.on.collect_unit_status, self._on_collect_unit_status)

def _on_collect_app_status(self, event):
def _on_collect_app_status(self, event: ops.CollectStatusEvent):
event.add_status(ops.ActiveStatus())

def _on_collect_unit_status(self, event):
def _on_collect_unit_status(self, event: ops.CollectStatusEvent):
event.add_status(ops.WaitingStatus('blah'))

fake_script(self, 'is-leader', 'echo true')
Expand All @@ -760,12 +763,12 @@ def _on_collect_unit_status(self, event):

def test_add_status_type_error(self):
class MyCharm(ops.CharmBase):
def __init__(self, *args):
def __init__(self, *args: typing.Any):
super().__init__(*args)
self.framework.observe(self.on.collect_app_status, self._on_collect_status)

def _on_collect_status(self, event):
event.add_status('active')
def _on_collect_status(self, event: ops.CollectStatusEvent):
event.add_status('active') # type: ignore

fake_script(self, 'is-leader', 'echo true')

Expand All @@ -775,12 +778,12 @@ def _on_collect_status(self, event):

def test_collect_status_priority(self):
class MyCharm(ops.CharmBase):
def __init__(self, *args, statuses=None):
def __init__(self, *args: typing.Any, statuses: typing.List[str]):
super().__init__(*args)
self.framework.observe(self.on.collect_app_status, self._on_collect_status)
self.statuses = statuses

def _on_collect_status(self, event):
def _on_collect_status(self, event: ops.CollectStatusEvent):
for status in self.statuses:
event.add_status(ops.StatusBase.from_name(status, ''))

Expand Down
Loading