diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst
index 26c6bd7bd..54e2c5ef9 100644
--- a/docs/versionhistory.rst
+++ b/docs/versionhistory.rst
@@ -10,6 +10,7 @@ APScheduler, see the :doc:`migration section <migration>`.
   (this breaks data store compatibility)
 - **BREAKING** Switched to using the timezone aware timestamp column type on Oracle
 - **BREAKING** Fixed precision issue with interval columns on MySQL
+- **BREAKING** Fixed datetime comparison issues on SQLite and MySQL
 - **BREAKING** Worked around datetime microsecond precision issue on MongoDB
 - **BREAKING** Renamed the ``worker_id`` field to ``scheduler_id`` in the
   ``JobAcquired`` and ``JobReleased`` events
diff --git a/src/apscheduler/_schedulers/async_.py b/src/apscheduler/_schedulers/async_.py
index 465964510..61daeb9b7 100644
--- a/src/apscheduler/_schedulers/async_.py
+++ b/src/apscheduler/_schedulers/async_.py
@@ -304,8 +304,8 @@ async def configure_task(
                 )
                 modified = True
             else:
-                if func is not unset and task.func is not func:
-                    task.func = func
+                if func is not unset and task.func != func_ref:
+                    task.func = func_ref
                     modified = True
 
                 if job_executor is not unset and task.job_executor != job_executor:
@@ -397,11 +397,7 @@ async def add_schedule(
         if ismethod(func_or_task_id):
             args = (func_or_task_id.__self__,) + args
             func_or_task_id = func_or_task_id.__func__
-        elif (
-            isbuiltin(func_or_task_id)
-            and func_or_task_id.__self__ is not None
-            and not ismodule(func_or_task_id.__self__)
-        ):
+        elif isbuiltin(func_or_task_id) and not ismodule(func_or_task_id.__self__):
             args = (func_or_task_id.__self__,) + args
             method_class = type(func_or_task_id.__self__)
             func_or_task_id = getattr(method_class, func_or_task_id.__name__)
@@ -509,11 +505,7 @@ async def add_job(
         if ismethod(func_or_task_id):
             args = (func_or_task_id.__self__,) + args
             func_or_task_id = func_or_task_id.__func__
-        elif (
-            isbuiltin(func_or_task_id)
-            and func_or_task_id.__self__ is not None
-            and not ismodule(func_or_task_id.__self__)
-        ):
+        elif isbuiltin(func_or_task_id) and not ismodule(func_or_task_id.__self__):
             args = (func_or_task_id.__self__,) + args
             method_class = type(func_or_task_id.__self__)
             func_or_task_id = getattr(method_class, func_or_task_id.__name__)
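
Both ``add_schedule()`` and ``add_job()`` now use the same test to unwrap a bound
builtin method into a class attribute plus an explicit ``self`` argument. A minimal
standalone sketch of the distinction that test relies on (illustrative code, not part
of the patch):

```python
from inspect import isbuiltin, ismodule

items: list[int] = []
bound = items.append  # builtin method bound to an instance
module_level = len    # builtin function whose __self__ is the builtins module

# Both are builtins and both expose __self__, so a hasattr() check cannot
# tell them apart; only the ismodule() test can:
assert isbuiltin(bound) and isbuiltin(module_level)
assert ismodule(module_level.__self__)
assert not ismodule(bound.__self__)

# Rebind as (class attribute, explicit self), as the scheduler does:
unbound = getattr(type(bound.__self__), bound.__name__)
unbound(items, 42)
assert items == [42]
```
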
diff --git a/src/apscheduler/datastores/sqlalchemy.py b/src/apscheduler/datastores/sqlalchemy.py
index 416a77a2c..967646e57 100644
--- a/src/apscheduler/datastores/sqlalchemy.py
+++ b/src/apscheduler/datastores/sqlalchemy.py
@@ -15,14 +15,15 @@
 import tenacity
 from anyio import CancelScope, to_thread
 from sqlalchemy import (
-    TIMESTAMP,
     BigInteger,
     Column,
+    DateTime,
     Enum,
     Integer,
     Interval,
     LargeBinary,
     MetaData,
+    SmallInteger,
     Table,
     TypeDecorator,
     Unicode,
@@ -128,6 +129,8 @@ class SQLAlchemyDataStore(BaseExternalDataStore):
     max_idle_time: float = attrs.field(default=60)
 
     _supports_update_returning: bool = attrs.field(init=False, default=False)
+    _supports_tzaware_timestamps: bool = attrs.field(init=False, default=False)
+    _supports_native_interval: bool = attrs.field(init=False, default=False)
     _metadata: MetaData = attrs.field(init=False)
     _t_metadata: Table = attrs.field(init=False)
     _t_tasks: Table = attrs.field(init=False)
@@ -138,6 +141,11 @@ class SQLAlchemyDataStore(BaseExternalDataStore):
     def __attrs_post_init__(self) -> None:
         # Generate the table definitions
         prefix = f"{self.schema}." if self.schema else ""
+        self._supports_tzaware_timestamps = self.engine.dialect.name in (
+            "postgresql",
+            "oracle",
+        )
+        self._supports_native_interval = self.engine.dialect.name == "postgresql"
         self._metadata = self.get_table_definitions()
         self._t_metadata = self._metadata.tables[prefix + "metadata"]
         self._t_tasks = self._metadata.tables[prefix + "tasks"]
@@ -222,13 +230,43 @@ def _temporary_failure_exceptions(self) -> tuple[type[Exception], ...]:
         return InterfaceError, OSError
 
+    def _convert_incoming_next_fire_time(self, data: dict[str, Any]) -> dict[str, Any]:
+        if not self._supports_tzaware_timestamps:
+            utcoffset_minutes = data.pop("next_fire_time_utcoffset", None)
+            if utcoffset_minutes is not None:
+                tz = timezone(timedelta(minutes=utcoffset_minutes))
+                timestamp = data["next_fire_time"] / 1000_000
+                data["next_fire_time"] = datetime.fromtimestamp(timestamp, tz=tz)
+
+        return data
+
+    def _convert_outgoing_next_fire_time(self, data: dict[str, Any]) -> dict[str, Any]:
+        if not self._supports_tzaware_timestamps:
+            next_fire_time = data["next_fire_time"]
+            if next_fire_time is not None:
+                data["next_fire_time"] = int(next_fire_time.timestamp() * 1000_000)
+                data["next_fire_time_utcoffset"] = (
+                    next_fire_time.utcoffset().total_seconds() // 60
+                )
+            else:
+                data["next_fire_time_utcoffset"] = None
+
+        return data
+
     def get_table_definitions(self) -> MetaData:
-        if self.engine.dialect.name in ("postgresql", "oracle"):
-            timestamp_type: TypeEngine[datetime] = TIMESTAMP(timezone=True)
+        if self._supports_tzaware_timestamps:
+            timestamp_type: TypeEngine[datetime] = DateTime(timezone=True)
+            next_fire_time_tzoffset_columns: tuple[Column, ...] = (
+                Column("next_fire_time", timestamp_type, index=True),
+            )
         else:
             timestamp_type = EmulatedTimestampTZ()
+            next_fire_time_tzoffset_columns = (
+                Column("next_fire_time", BigInteger, index=True),
+                Column("next_fire_time_utcoffset", SmallInteger),
+            )
 
-        if self.engine.dialect.name == "postgresql":
+        if self._supports_native_interval:
             interval_type = Interval(second_precision=6)
         else:
             interval_type = EmulatedInterval()
@@ -256,7 +294,7 @@ def get_table_definitions(self) -> MetaData:
             Column("coalesce", Enum(CoalescePolicy), nullable=False),
             Column("misfire_grace_time", interval_type),
             Column("max_jitter", interval_type),
-            Column("next_fire_time", timestamp_type, index=True),
+            *next_fire_time_tzoffset_columns,
             Column("last_fire_time", timestamp_type),
             Column("acquired_by", Unicode(500)),
             Column("acquired_until", timestamp_type),
@@ -343,7 +381,12 @@ async def _deserialize_schedules(self, result: Result) -> list[Schedule]:
         schedules: list[Schedule] = []
         for row in result:
             try:
-                schedules.append(Schedule.unmarshal(self.serializer, row._asdict()))
+                schedules.append(
+                    Schedule.unmarshal(
+                        self.serializer,
+                        self._convert_incoming_next_fire_time(row._asdict()),
+                    )
+                )
             except SerializationError as exc:
                 await self._event_broker.publish(
                     ScheduleDeserializationFailed(schedule_id=row.id, exception=exc)
@@ -448,7 +491,9 @@ async def add_schedule(
         self, schedule: Schedule, conflict_policy: ConflictPolicy
     ) -> None:
         event: DataStoreEvent
-        values = schedule.marshal(self.serializer)
+        values = self._convert_outgoing_next_fire_time(
+            schedule.marshal(self.serializer)
+        )
         insert = self._t_schedules.insert().values(**values)
         try:
             async for attempt in self._retry():
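
On dialects without timezone-aware timestamp columns, ``next_fire_time`` is now
persisted as two columns: epoch microseconds (``BigInteger``) plus the UTC offset in
whole minutes (``SmallInteger``). A standalone round-trip sketch of the same scheme
(the helper names are mine, not the data store's API):

```python
from datetime import datetime, timedelta, timezone


def marshal(dt: datetime) -> tuple[int, int]:
    """Split an aware datetime into (epoch microseconds, UTC offset in minutes)."""
    # round() keeps this sketch's round trip exact: float math can land just
    # below the true integer at microsecond resolution
    return round(dt.timestamp() * 1000_000), int(dt.utcoffset().total_seconds() // 60)


def unmarshal(micros: int, offset_minutes: int) -> datetime:
    """Rebuild the aware datetime in its original fixed-offset timezone."""
    tz = timezone(timedelta(minutes=offset_minutes))
    return datetime.fromtimestamp(micros / 1000_000, tz=tz)


original = datetime(
    2024, 5, 1, 12, 30, 0, 123456, tzinfo=timezone(timedelta(hours=5, minutes=30))
)
assert unmarshal(*marshal(original)) == original
```
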
@@ -540,12 +585,19 @@ async def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedul
                 async with self._begin_transaction() as conn:
                     now = datetime.now(timezone.utc)
                     acquired_until = now + timedelta(seconds=self.lock_expiration_delay)
+                    if self._supports_tzaware_timestamps:
+                        comparison = self._t_schedules.c.next_fire_time <= now
+                    else:
+                        comparison = self._t_schedules.c.next_fire_time <= int(
+                            now.timestamp() * 1000_000
+                        )
+
                     schedules_cte = (
                         select(self._t_schedules.c.id)
                         .where(
                             and_(
                                 self._t_schedules.c.next_fire_time.isnot(None),
-                                self._t_schedules.c.next_fire_time <= now,
+                                comparison,
                                 or_(
                                     self._t_schedules.c.acquired_until.is_(None),
                                     self._t_schedules.c.acquired_until < now,
@@ -581,6 +633,9 @@ async def release_schedules(
         self, scheduler_id: str, schedules: list[Schedule]
     ) -> None:
         task_ids = {schedule.id: schedule.task_id for schedule in schedules}
+        next_fire_times = {
+            schedule.id: schedule.next_fire_time for schedule in schedules
+        }
         async for attempt in self._retry():
             with attempt:
                 async with self._begin_transaction() as conn:
@@ -602,21 +657,43 @@ async def release_schedules(
                                 finished_schedule_ids.append(schedule.id)
                                 continue
 
-                            update_args.append(
-                                {
-                                    "p_id": schedule.id,
-                                    "p_trigger": serialized_trigger,
-                                    "p_next_fire_time": schedule.next_fire_time,
-                                }
-                            )
+                            if self._supports_tzaware_timestamps:
+                                update_args.append(
+                                    {
+                                        "p_id": schedule.id,
+                                        "p_trigger": serialized_trigger,
+                                        "p_next_fire_time": schedule.next_fire_time,
+                                    }
+                                )
+                            else:
+                                update_args.append(
+                                    {
+                                        "p_id": schedule.id,
+                                        "p_trigger": serialized_trigger,
+                                        "p_next_fire_time": int(
+                                            schedule.next_fire_time.timestamp()
+                                            * 1000_000
+                                        ),
+                                        "p_next_fire_time_utcoffset": (
+                                            schedule.next_fire_time.utcoffset().total_seconds()
+                                            // 60
+                                        ),
+                                    }
+                                )
                         else:
                             finished_schedule_ids.append(schedule.id)
 
                     # Update schedules that have a next fire time
                     if update_args:
+                        extra_values = {}
                         p_id: BindParameter = bindparam("p_id")
                         p_trigger: BindParameter = bindparam("p_trigger")
                         p_next_fire_time: BindParameter = bindparam("p_next_fire_time")
 
+                        if not self._supports_tzaware_timestamps:
+                            extra_values["next_fire_time_utcoffset"] = bindparam(
+                                "p_next_fire_time_utcoffset"
+                            )
+
                         update = (
                             self._t_schedules.update()
                             .where(
@@ -630,11 +707,9 @@ async def release_schedules(
                                 next_fire_time=p_next_fire_time,
                                 acquired_by=None,
                                 acquired_until=None,
+                                **extra_values,
                             )
                         )
-                        next_fire_times = {
-                            arg["p_id"]: arg["p_next_fire_time"] for arg in update_args
-                        }
                         # TODO: actually check which rows were updated?
                         await self._execute(conn, update, update_args)
                         updated_ids = list(next_fire_times)
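
The ``p_``-prefixed names matter here: in an "executemany"-style bulk UPDATE,
SQLAlchemy generates bind parameters named after the columns for the SET clause, so
the per-row parameters need distinct names to avoid colliding with them. Roughly the
same pattern in isolation, against a cut-down stand-in table (not the real schedules
schema):

```python
from sqlalchemy import (
    Column,
    Integer,
    MetaData,
    Table,
    Unicode,
    bindparam,
    create_engine,
)

metadata = MetaData()
schedules = Table(
    "schedules",
    metadata,
    Column("id", Unicode(500), primary_key=True),
    Column("next_fire_time", Integer),
)
engine = create_engine("sqlite://")
metadata.create_all(engine)

with engine.begin() as conn:
    conn.execute(
        schedules.insert(),
        [{"id": "a", "next_fire_time": 1}, {"id": "b", "next_fire_time": 2}],
    )
    # One compiled UPDATE statement, executed once per parameter dict
    update = (
        schedules.update()
        .where(schedules.c.id == bindparam("p_id"))
        .values(next_fire_time=bindparam("p_next_fire_time"))
    )
    conn.execute(
        update,
        [
            {"p_id": "a", "p_next_fire_time": 100},
            {"p_id": "b", "p_next_fire_time": 200},
        ],
    )
```
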
@@ -668,8 +743,12 @@
                 )
 
     async def get_next_schedule_run_time(self) -> datetime | None:
+        columns = [self._t_schedules.c.next_fire_time]
+        if not self._supports_tzaware_timestamps:
+            columns.append(self._t_schedules.c.next_fire_time_utcoffset)
+
         statement = (
-            select(self._t_schedules.c.next_fire_time)
+            select(*columns)
             .where(self._t_schedules.c.next_fire_time.isnot(None))
             .order_by(self._t_schedules.c.next_fire_time)
             .limit(1)
@@ -679,6 +758,13 @@ async def get_next_schedule_run_time(self) -> datetime | None:
                 async with self._begin_transaction() as conn:
                     result = await self._execute(conn, statement)
 
+        if not self._supports_tzaware_timestamps:
+            if row := result.first():
+                tz = timezone(timedelta(minutes=row[1]))
+                return datetime.fromtimestamp(row[0] / 1000_000, tz=tz)
+            else:
+                return None
+
         return result.scalar()
 
     async def add_job(self, job: Job) -> None:
diff --git a/tests/test_schedulers.py b/tests/test_schedulers.py
index 6c8e3c2d7..0983f5b7c 100644
--- a/tests/test_schedulers.py
+++ b/tests/test_schedulers.py
@@ -192,11 +192,13 @@ async def test_configure_task(self, raw_datastore: DataStore) -> None:
         assert isinstance(event, TaskUpdated)
         assert event.task_id == "mytask"
 
-    async def test_add_remove_schedule(self, raw_datastore: DataStore) -> None:
+    async def test_add_remove_schedule(
+        self, raw_datastore: DataStore, timezone: ZoneInfo
+    ) -> None:
         send, receive = create_memory_object_stream[Event](3)
         async with AsyncScheduler(data_store=raw_datastore) as scheduler:
             scheduler.subscribe(send.send)
-            now = datetime.now(UTC)
+            now = datetime.now(timezone)
             trigger = DateTrigger(now)
             schedule_id = await scheduler.add_schedule(
                 dummy_async_job, trigger, id="foo"
             )
@@ -360,9 +362,10 @@ async def test_callable_types(
         expected_result: object,
         use_scheduling: bool,
         raw_datastore: DataStore,
+        timezone: ZoneInfo,
     ) -> None:
         send, receive = create_memory_object_stream[Event](4)
-        now = datetime.now(UTC)
+        now = datetime.now(timezone)
         async with AsyncScheduler(data_store=raw_datastore) as scheduler:
             scheduler.subscribe(send.send, {JobReleased})
             await scheduler.start_in_background()
@@ -382,10 +385,10 @@ async def test_callable_types(
         assert result.return_value == expected_result
 
     async def test_scheduled_job_missed_deadline(
-        self, raw_datastore: DataStore
+        self, raw_datastore: DataStore, timezone: ZoneInfo
     ) -> None:
         send, receive = create_memory_object_stream[Event](4)
-        now = datetime.now(UTC)
+        now = datetime.now(timezone)
         trigger = DateTrigger(now)
         async with AsyncScheduler(data_store=raw_datastore) as scheduler:
             await scheduler.add_schedule(
@@ -453,9 +456,10 @@ async def test_coalesce_policy(
         expected_jobs: int,
         first_fire_time_delta: timedelta,
         raw_datastore: DataStore,
+        timezone: ZoneInfo,
     ) -> None:
         send, receive = create_memory_object_stream[Event](4)
-        now = datetime.now(UTC)
+        now = datetime.now(timezone)
         first_start_time = now - timedelta(minutes=3, seconds=5)
         trigger = IntervalTrigger(minutes=1, start_time=first_start_time)
         async with AsyncScheduler(
@@ -513,7 +517,7 @@ async def test_jitter(
     ) -> None:
         send, receive = create_memory_object_stream[Event](4)
         jitter = 1.569374
-        now = datetime.now(UTC)
+        now = datetime.now(timezone)
         fake_uniform = mocker.patch("random.uniform")
         fake_uniform.configure_mock(side_effect=lambda a, b: jitter)
         async with AsyncScheduler(
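
These test changes swap ``datetime.now(UTC)`` for a ``timezone`` fixture so the
emulated-timestamp path is exercised with a non-zero UTC offset. The fixture itself
is not shown in this patch and presumably lives in the suite's conftest; a minimal
stand-in consistent with the signatures above would be:

```python
from zoneinfo import ZoneInfo

import pytest


@pytest.fixture
def timezone() -> ZoneInfo:
    # The exact zone is an assumption; any zone with a non-zero UTC offset
    # exercises the offset round-tripping in the data store.
    return ZoneInfo("Europe/Berlin")
```
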
@@ -632,7 +636,7 @@ async def test_add_job_get_result_no_ready_yet(self) -> None:
         with pytest.raises(JobLookupError), fail_after(1):
             await scheduler.get_job_result(job_id, wait=False)
 
-    async def test_contextvars(self, mocker: MockerFixture) -> None:
+    async def test_contextvars(self, mocker: MockerFixture, timezone: ZoneInfo) -> None:
         def check_contextvars() -> None:
             assert current_async_scheduler.get() is scheduler
             info = current_job.get()
@@ -649,7 +653,7 @@ def check_contextvars() -> None:
         fake_uniform = mocker.patch("random.uniform")
         fake_uniform.configure_mock(return_value=2.16)
         send, receive = create_memory_object_stream[Event](1)
-        now = datetime.now(UTC)
+        now = datetime.now(timezone)
         async with AsyncScheduler() as scheduler:
             await scheduler.configure_task("contextvars", func=check_contextvars)
             await scheduler.add_schedule(
@@ -743,11 +747,11 @@ def test_configure_task(self) -> None:
         assert isinstance(event, TaskUpdated)
         assert event.task_id == "mytask"
 
-    def test_add_remove_schedule(self) -> None:
+    def test_add_remove_schedule(self, timezone: ZoneInfo) -> None:
         queue = Queue()
         with Scheduler() as scheduler:
             scheduler.subscribe(queue.put_nowait)
-            now = datetime.now(UTC)
+            now = datetime.now(timezone)
             trigger = DateTrigger(now)
             schedule_id = scheduler.add_schedule(dummy_async_job, trigger, id="foo")
             assert schedule_id == "foo"
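
Why the changelog calls the old behavior a datetime comparison issue: the
``EmulatedTimestampTZ`` fallback stores values as ISO 8601 strings, and string
comparison is lexicographic, so two timestamps with different UTC offsets can sort in
the wrong order (MySQL's offset-less column types suffer a similar problem). Epoch
microseconds restore a total order with plain integer comparison, which pure Python
can illustrate without a database:

```python
from datetime import datetime, timedelta, timezone

a = datetime(2024, 5, 1, 23, 30, tzinfo=timezone.utc)
b = datetime(2024, 5, 2, 1, 0, tzinfo=timezone(timedelta(hours=3)))  # 22:00 UTC

assert b < a  # in absolute time, b fires before a
assert a.isoformat() < b.isoformat()  # but the ISO strings sort the other way
assert int(b.timestamp() * 1000_000) < int(a.timestamp() * 1000_000)  # integers agree
```
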