diff --git a/CHANGELOG.md b/CHANGELOG.md index 2978f79a77..0470267580 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/) and this project adheres to [Semantic Versioning](http://semver.org/). +## [0.5.0] - Unreleased + +### Changed + +- Watchers are now called immediately when setting the attribute if they are synchronous. https://github.com/Textualize/textual/pull/1145 + ## [0.4.0] - 2022-11-08 https://textual.textualize.io/blog/2022/11/08/version-040/#version-040 diff --git a/src/textual/reactive.py b/src/textual/reactive.py index 2ee3ba5dfd..7cd88d6300 100644 --- a/src/textual/reactive.py +++ b/src/textual/reactive.py @@ -2,7 +2,16 @@ from functools import partial from inspect import isawaitable -from typing import TYPE_CHECKING, Any, Callable, Generic, Type, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Generic, + Type, + TypeVar, + Union, +) from . import events from ._callback import count_parameters, invoke @@ -146,6 +155,7 @@ def __set_name__(self, owner: Type[MessageTarget], name: str) -> None: setattr(owner, f"_default_{name}", default) def __get__(self, obj: Reactable, obj_type: type[object]) -> ReactiveType: + _rich_traceback_omit = True value: _NotSet | ReactiveType = getattr(obj, self.internal_name, _NOT_SET) if isinstance(value, _NotSet): # No value present, we need to set the default @@ -155,11 +165,12 @@ def __get__(self, obj: Reactable, obj_type: type[object]) -> ReactiveType: # Set and return the value setattr(obj, self.internal_name, default_value) if self._init: - self._check_watchers(obj, self.name, default_value, first_set=True) + self._check_watchers(obj, self.name, default_value) return default_value return value def __set__(self, obj: Reactable, value: ReactiveType) -> None: + _rich_traceback_omit = True name = self.name current_value = getattr(obj, name) # Check for validate function @@ -176,15 +187,13 @@ def __set__(self, obj: Reactable, value: ReactiveType) -> None: # Store the internal value setattr(obj, self.internal_name, value) # Check all watchers - self._check_watchers(obj, name, current_value, first_set=first_set) + self._check_watchers(obj, name, current_value) # Refresh according to descriptor flags if self._layout or self._repaint: obj.refresh(repaint=self._repaint, layout=self._layout) @classmethod - def _check_watchers( - cls, obj: Reactable, name: str, old_value: Any, first_set: bool = False - ) -> None: + def _check_watchers(cls, obj: Reactable, name: str, old_value: Any): """Check watchers, and call watch methods / computes Args: @@ -193,60 +202,68 @@ def _check_watchers( old_value (Any): The old (previous) value of the attribute. first_set (bool, optional): True if this is the first time setting the value. Defaults to False. """ + _rich_traceback_omit = True # Get the current value. internal_name = f"_reactive_{name}" value = getattr(obj, internal_name) - async def update_watcher( - obj: Reactable, watch_function: Callable, old_value: Any, value: Any - ) -> None: - """Call watch function, and run compute. + async def await_watcher(awaitable: Awaitable) -> None: + """Coroutine to await an awaitable returned from a watcher""" + _rich_traceback_omit = True + await awaitable + # Watcher may have changed the state, so run compute again + obj.post_message_no_wait( + events.Callback(sender=obj, callback=partial(Reactive._compute, obj)) + ) + + def invoke_watcher( + watch_function: Callable, old_value: object, value: object + ) -> bool: + """Invoke a watch function. Args: - obj (Reactable): Reactable object. - watch_function (Callable): Watch method. - old_value (Any): Old value. - value (Any): new value. + watch_function (Callable): A watch function, which may be sync or async. + old_value (object): The old value of the attribute. + value (object): The new value of the attribute. + + Returns: + bool: True if the watcher was run, or False if it was posted. """ - _rich_traceback_guard = True - # Call watch with one or two parameters + _rich_traceback_omit = True if count_parameters(watch_function) == 2: watch_result = watch_function(old_value, value) else: watch_result = watch_function(value) - # Optionally await result if isawaitable(watch_result): - await watch_result - # Run computes - await Reactive._compute(obj) + # Result is awaitable, so we need to await it within an async context + obj.post_message_no_wait( + events.Callback( + sender=obj, callback=partial(await_watcher, watch_result) + ) + ) + return False + else: + return True - # Check for watch method + # Compute is only required if a watcher runs immediately, not if they were posted. + require_compute = False watch_function = getattr(obj, f"watch_{name}", None) if callable(watch_function): - # Post a callback message, so we can call the watch method in an orderly async manner - obj.post_message_no_wait( - events.Callback( - sender=obj, - callback=partial( - update_watcher, obj, watch_function, old_value, value - ), - ) + require_compute = require_compute or invoke_watcher( + watch_function, old_value, value ) - # Check for watchers set via `watch` watchers: list[Callable] = getattr(obj, "__watchers", {}).get(name, []) for watcher in watchers: - obj.post_message_no_wait( - events.Callback( - sender=obj, - callback=partial(update_watcher, obj, watcher, old_value, value), - ) + require_compute = require_compute or invoke_watcher( + watcher, old_value, value ) - # Run computes - obj.post_message_no_wait( - events.Callback(sender=obj, callback=partial(Reactive._compute, obj)) - ) + if require_compute: + # Run computes + obj.post_message_no_wait( + events.Callback(sender=obj, callback=partial(Reactive._compute, obj)) + ) @classmethod async def _compute(cls, obj: Reactable) -> None: @@ -301,10 +318,13 @@ class var(Reactive[ReactiveType]): Args: default (ReactiveType | Callable[[], ReactiveType]): A default value or callable that returns a default. + init (bool, optional): Call watchers on initialize (post mount). Defaults to True. """ - def __init__(self, default: ReactiveType | Callable[[], ReactiveType]) -> None: - super().__init__(default, layout=False, repaint=False, init=True) + def __init__( + self, default: ReactiveType | Callable[[], ReactiveType], init: bool = True + ) -> None: + super().__init__(default, layout=False, repaint=False, init=init) def watch( diff --git a/tests/test_reactive.py b/tests/test_reactive.py new file mode 100644 index 0000000000..7158b46f24 --- /dev/null +++ b/tests/test_reactive.py @@ -0,0 +1,26 @@ +from textual.app import App, ComposeResult +from textual.reactive import reactive + + +class WatchApp(App): + + count = reactive(0, init=False) + + test_count = 0 + + def watch_count(self, value: int) -> None: + self.test_count = value + + +async def test_watch(): + """Test that changes to a watched reactive attribute happen immediately.""" + app = WatchApp() + async with app.run_test(): + app.count += 1 + assert app.test_count == 1 + app.count += 1 + assert app.test_count == 2 + app.count -= 1 + assert app.test_count == 1 + app.count -= 1 + assert app.test_count == 0