Skip to content

Commit

Permalink
Worked around timezone awareness issues on the next_fire_time column
Browse files Browse the repository at this point in the history
An extra column had to be used in order to make the column indexable (MariaDB does not support functional indexes at the moment).
  • Loading branch information
agronholm committed Nov 12, 2023
1 parent 07e8940 commit bbb9dfe
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 42 deletions.
1 change: 1 addition & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ APScheduler, see the :doc:`migration section <migration>`.
(this breaks data store compatibility)
- **BREAKING** Switched to using the timezone aware timestamp column type on Oracle
- **BREAKING** Fixed precision issue with interval columns on MySQL
- **BREAKING** Fixed datetime comparison issues on SQLite and MySQL
- **BREAKING** Worked around datetime microsecond precision issue on MongoDB
- **BREAKING** Renamed the ``worker_id`` field to ``scheduler_id`` in the
``JobAcquired`` and ``JobReleased`` events
Expand Down
16 changes: 4 additions & 12 deletions src/apscheduler/_schedulers/async_.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,8 +304,8 @@ async def configure_task(
)
modified = True
else:
if func is not unset and task.func is not func:
task.func = func
if func is not unset and task.func != func_ref:
task.func = func_ref
modified = True

if job_executor is not unset and task.job_executor != job_executor:
Expand Down Expand Up @@ -397,11 +397,7 @@ async def add_schedule(
if ismethod(func_or_task_id):
args = (func_or_task_id.__self__,) + args
func_or_task_id = func_or_task_id.__func__
elif (
isbuiltin(func_or_task_id)
and func_or_task_id.__self__ is not None
and not ismodule(func_or_task_id.__self__)
):
elif isbuiltin(func_or_task_id) and hasattr(func_or_task_id, "__self__"):
args = (func_or_task_id.__self__,) + args
method_class = type(func_or_task_id.__self__)
func_or_task_id = getattr(method_class, func_or_task_id.__name__)
Expand Down Expand Up @@ -509,11 +505,7 @@ async def add_job(
if ismethod(func_or_task_id):
args = (func_or_task_id.__self__,) + args
func_or_task_id = func_or_task_id.__func__
elif (
isbuiltin(func_or_task_id)
and func_or_task_id.__self__ is not None
and not ismodule(func_or_task_id.__self__)
):
elif isbuiltin(func_or_task_id) and not ismodule(func_or_task_id.__self__):
args = (func_or_task_id.__self__,) + args
method_class = type(func_or_task_id.__self__)
func_or_task_id = getattr(method_class, func_or_task_id.__name__)
Expand Down
124 changes: 105 additions & 19 deletions src/apscheduler/datastores/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@
import tenacity
from anyio import CancelScope, to_thread
from sqlalchemy import (
TIMESTAMP,
BigInteger,
Column,
DateTime,
Enum,
Integer,
Interval,
LargeBinary,
MetaData,
SmallInteger,
Table,
TypeDecorator,
Unicode,
Expand Down Expand Up @@ -128,6 +129,8 @@ class SQLAlchemyDataStore(BaseExternalDataStore):
max_idle_time: float = attrs.field(default=60)

_supports_update_returning: bool = attrs.field(init=False, default=False)
_supports_tzaware_timestamps: bool = attrs.field(init=False, default=False)
_supports_native_interval: bool = attrs.field(init=False, default=False)
_metadata: MetaData = attrs.field(init=False)
_t_metadata: Table = attrs.field(init=False)
_t_tasks: Table = attrs.field(init=False)
Expand All @@ -138,6 +141,11 @@ class SQLAlchemyDataStore(BaseExternalDataStore):
def __attrs_post_init__(self) -> None:
# Generate the table definitions
prefix = f"{self.schema}." if self.schema else ""
self._supports_tzaware_timestamps = self.engine.dialect in (
"postgresql",
"oracle",
)
self._supports_native_interval = self.engine.dialect == "postgresql"
self._metadata = self.get_table_definitions()
self._t_metadata = self._metadata.tables[prefix + "metadata"]
self._t_tasks = self._metadata.tables[prefix + "tasks"]
Expand Down Expand Up @@ -222,13 +230,43 @@ def _temporary_failure_exceptions(self) -> tuple[type[Exception], ...]:

return InterfaceError, OSError

def _convert_incoming_next_fire_time(self, data: dict[str, Any]) -> dict[str, Any]:
if not self._supports_tzaware_timestamps:
utcoffset_minutes = data.pop("next_fire_time_utcoffset", None)
if utcoffset_minutes is not None:
tz = timezone(timedelta(minutes=utcoffset_minutes))
timestamp = data["next_fire_time"] / 1000_000
data["next_fire_time"] = datetime.fromtimestamp(timestamp, tz=tz)

return data

def _convert_outgoing_next_fire_time(self, data: dict[str, Any]) -> dict[str, Any]:
if not self._supports_tzaware_timestamps:
next_fire_time = data["next_fire_time"]
if next_fire_time is not None:
data["next_fire_time"] = int(next_fire_time.timestamp() * 1000_000)
data["next_fire_time_utcoffset"] = (
next_fire_time.utcoffset().total_seconds() // 60
)
else:
data["next_fire_time_utcoffset"] = None

return data

def get_table_definitions(self) -> MetaData:
if self.engine.dialect.name in ("postgresql", "oracle"):
timestamp_type: TypeEngine[datetime] = TIMESTAMP(timezone=True)
if self._supports_tzaware_timestamps:
timestamp_type: TypeEngine[datetime] = DateTime(timezone=True)
next_fire_time_tzoffset_columns: tuple[Column, ...] = (
Column("next_fire_time", timestamp_type, index=True),
)
else:
timestamp_type = EmulatedTimestampTZ()
next_fire_time_tzoffset_columns = (
Column("next_fire_time", BigInteger, index=True),
Column("next_fire_time_utcoffset", SmallInteger),
)

if self.engine.dialect.name == "postgresql":
if self._supports_native_interval:
interval_type = Interval(second_precision=6)
else:
interval_type = EmulatedInterval()
Expand Down Expand Up @@ -256,7 +294,7 @@ def get_table_definitions(self) -> MetaData:
Column("coalesce", Enum(CoalescePolicy), nullable=False),
Column("misfire_grace_time", interval_type),
Column("max_jitter", interval_type),
Column("next_fire_time", timestamp_type, index=True),
*next_fire_time_tzoffset_columns,
Column("last_fire_time", timestamp_type),
Column("acquired_by", Unicode(500)),
Column("acquired_until", timestamp_type),
Expand Down Expand Up @@ -343,7 +381,12 @@ async def _deserialize_schedules(self, result: Result) -> list[Schedule]:
schedules: list[Schedule] = []
for row in result:
try:
schedules.append(Schedule.unmarshal(self.serializer, row._asdict()))
schedules.append(
Schedule.unmarshal(
self.serializer,
self._convert_incoming_next_fire_time(row._asdict()),
)
)
except SerializationError as exc:
await self._event_broker.publish(
ScheduleDeserializationFailed(schedule_id=row.id, exception=exc)
Expand Down Expand Up @@ -448,7 +491,9 @@ async def add_schedule(
self, schedule: Schedule, conflict_policy: ConflictPolicy
) -> None:
event: DataStoreEvent
values = schedule.marshal(self.serializer)
values = self._convert_outgoing_next_fire_time(
schedule.marshal(self.serializer)
)
insert = self._t_schedules.insert().values(**values)
try:
async for attempt in self._retry():
Expand Down Expand Up @@ -540,12 +585,19 @@ async def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedul
async with self._begin_transaction() as conn:
now = datetime.now(timezone.utc)
acquired_until = now + timedelta(seconds=self.lock_expiration_delay)
if self._supports_tzaware_timestamps:
comparison = self._t_schedules.c.next_fire_time <= now
else:
comparison = self._t_schedules.c.next_fire_time <= int(
now.timestamp() * 1000_000
)

schedules_cte = (
select(self._t_schedules.c.id)
.where(
and_(
self._t_schedules.c.next_fire_time.isnot(None),
self._t_schedules.c.next_fire_time <= now,
comparison,
or_(
self._t_schedules.c.acquired_until.is_(None),
self._t_schedules.c.acquired_until < now,
Expand Down Expand Up @@ -581,6 +633,9 @@ async def release_schedules(
self, scheduler_id: str, schedules: list[Schedule]
) -> None:
task_ids = {schedule.id: schedule.task_id for schedule in schedules}
next_fire_times = {
schedule.id: schedule.next_fire_time for schedule in schedules
}
async for attempt in self._retry():
with attempt:
async with self._begin_transaction() as conn:
Expand All @@ -602,21 +657,43 @@ async def release_schedules(
finished_schedule_ids.append(schedule.id)
continue

update_args.append(
{
"p_id": schedule.id,
"p_trigger": serialized_trigger,
"p_next_fire_time": schedule.next_fire_time,
}
)
if self._supports_tzaware_timestamps:
update_args.append(
{
"p_id": schedule.id,
"p_trigger": serialized_trigger,
"p_next_fire_time": schedule.next_fire_time,
}
)
else:
update_args.append(
{
"p_id": schedule.id,
"p_trigger": serialized_trigger,
"p_next_fire_time": int(
schedule.next_fire_time.timestamp()
* 1000_000
),
"p_next_fire_time_utcoffset": (
schedule.next_fire_time.utcoffset().total_seconds()
// 60
),
}
)
else:
finished_schedule_ids.append(schedule.id)

# Update schedules that have a next fire time
if update_args:
extra_values = {}
p_id: BindParameter = bindparam("p_id")
p_trigger: BindParameter = bindparam("p_trigger")
p_next_fire_time: BindParameter = bindparam("p_next_fire_time")
if not self._supports_tzaware_timestamps:
extra_values["next_fire_time_utcoffset"] = bindparam(
"p_next_fire_time_utcoffset"
)

update = (
self._t_schedules.update()
.where(
Expand All @@ -630,11 +707,9 @@ async def release_schedules(
next_fire_time=p_next_fire_time,
acquired_by=None,
acquired_until=None,
**extra_values,
)
)
next_fire_times = {
arg["p_id"]: arg["p_next_fire_time"] for arg in update_args
}
# TODO: actually check which rows were updated?
await self._execute(conn, update, update_args)
updated_ids = list(next_fire_times)
Expand Down Expand Up @@ -668,8 +743,12 @@ async def release_schedules(
)

async def get_next_schedule_run_time(self) -> datetime | None:
columns = [self._t_schedules.c.next_fire_time]
if not self._supports_tzaware_timestamps:
columns.append(self._t_schedules.c.next_fire_time_utcoffset)

statenent = (
select(self._t_schedules.c.next_fire_time)
select(*columns)
.where(self._t_schedules.c.next_fire_time.isnot(None))
.order_by(self._t_schedules.c.next_fire_time)
.limit(1)
Expand All @@ -679,6 +758,13 @@ async def get_next_schedule_run_time(self) -> datetime | None:
async with self._begin_transaction() as conn:
result = await self._execute(conn, statenent)

if not self._supports_tzaware_timestamps:
if row := result.first():
tz = timezone(timedelta(minutes=row[1]))
return datetime.fromtimestamp(row[0] / 1000_000, tz=tz)
else:
return None

return result.scalar()

async def add_job(self, job: Job) -> None:
Expand Down
26 changes: 15 additions & 11 deletions tests/test_schedulers.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,11 +192,13 @@ async def test_configure_task(self, raw_datastore: DataStore) -> None:
assert isinstance(event, TaskUpdated)
assert event.task_id == "mytask"

async def test_add_remove_schedule(self, raw_datastore: DataStore) -> None:
async def test_add_remove_schedule(
self, raw_datastore: DataStore, timezone: ZoneInfo
) -> None:
send, receive = create_memory_object_stream[Event](3)
async with AsyncScheduler(data_store=raw_datastore) as scheduler:
scheduler.subscribe(send.send)
now = datetime.now(UTC)
now = datetime.now(timezone)
trigger = DateTrigger(now)
schedule_id = await scheduler.add_schedule(
dummy_async_job, trigger, id="foo"
Expand Down Expand Up @@ -360,9 +362,10 @@ async def test_callable_types(
expected_result: object,
use_scheduling: bool,
raw_datastore: DataStore,
timezone: ZoneInfo,
) -> None:
send, receive = create_memory_object_stream[Event](4)
now = datetime.now(UTC)
now = datetime.now(timezone)
async with AsyncScheduler(data_store=raw_datastore) as scheduler:
scheduler.subscribe(send.send, {JobReleased})
await scheduler.start_in_background()
Expand All @@ -382,10 +385,10 @@ async def test_callable_types(
assert result.return_value == expected_result

async def test_scheduled_job_missed_deadline(
self, raw_datastore: DataStore
self, raw_datastore: DataStore, timezone: ZoneInfo
) -> None:
send, receive = create_memory_object_stream[Event](4)
now = datetime.now(UTC)
now = datetime.now(timezone)
trigger = DateTrigger(now)
async with AsyncScheduler(data_store=raw_datastore) as scheduler:
await scheduler.add_schedule(
Expand Down Expand Up @@ -453,9 +456,10 @@ async def test_coalesce_policy(
expected_jobs: int,
first_fire_time_delta: timedelta,
raw_datastore: DataStore,
timezone: ZoneInfo,
) -> None:
send, receive = create_memory_object_stream[Event](4)
now = datetime.now(UTC)
now = datetime.now(timezone)
first_start_time = now - timedelta(minutes=3, seconds=5)
trigger = IntervalTrigger(minutes=1, start_time=first_start_time)
async with AsyncScheduler(
Expand Down Expand Up @@ -513,7 +517,7 @@ async def test_jitter(
) -> None:
send, receive = create_memory_object_stream[Event](4)
jitter = 1.569374
now = datetime.now(UTC)
now = datetime.now(timezone)
fake_uniform = mocker.patch("random.uniform")
fake_uniform.configure_mock(side_effect=lambda a, b: jitter)
async with AsyncScheduler(
Expand Down Expand Up @@ -632,7 +636,7 @@ async def test_add_job_get_result_no_ready_yet(self) -> None:
with pytest.raises(JobLookupError), fail_after(1):
await scheduler.get_job_result(job_id, wait=False)

async def test_contextvars(self, mocker: MockerFixture) -> None:
async def test_contextvars(self, mocker: MockerFixture, timezone: ZoneInfo) -> None:
def check_contextvars() -> None:
assert current_async_scheduler.get() is scheduler
info = current_job.get()
Expand All @@ -649,7 +653,7 @@ def check_contextvars() -> None:
fake_uniform = mocker.patch("random.uniform")
fake_uniform.configure_mock(return_value=2.16)
send, receive = create_memory_object_stream[Event](1)
now = datetime.now(UTC)
now = datetime.now(timezone)
async with AsyncScheduler() as scheduler:
await scheduler.configure_task("contextvars", func=check_contextvars)
await scheduler.add_schedule(
Expand Down Expand Up @@ -743,11 +747,11 @@ def test_configure_task(self) -> None:
assert isinstance(event, TaskUpdated)
assert event.task_id == "mytask"

def test_add_remove_schedule(self) -> None:
def test_add_remove_schedule(self, timezone: ZoneInfo) -> None:
queue = Queue()
with Scheduler() as scheduler:
scheduler.subscribe(queue.put_nowait)
now = datetime.now(UTC)
now = datetime.now(timezone)
trigger = DateTrigger(now)
schedule_id = scheduler.add_schedule(dummy_async_job, trigger, id="foo")
assert schedule_id == "foo"
Expand Down

0 comments on commit bbb9dfe

Please sign in to comment.