Skip to content

Commit

Permalink
perf: Improve performance of the apply changes method by grouping row…
Browse files Browse the repository at this point in the history
…s. (#88)

Improves the speed when there are thousands of entries in the sync backlog by grouping them together. This makes sure that the inserts for a row are done in a single query (INSERT ... ON CONFLICT DO UPDATE), and it also groups the fields of newly created entities together, making the queries more efficient. In testing, this improved the performance of budgets with a high number of entries, while keeping the performance of newly created budgets (without many changes) the same.
  • Loading branch information
bvanelli authored Oct 21, 2024
1 parent 0e391ee commit 68d640d
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 10 deletions.
21 changes: 14 additions & 7 deletions actual/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from os import PathLike
from typing import IO, List, Union

from sqlalchemy import insert, update
from sqlmodel import MetaData, Session, create_engine, select

from actual.api import ActualServer
Expand All @@ -23,6 +22,7 @@
Accounts,
MessagesClock,
Transactions,
apply_change,
get_attribute_from_reflected_table_name,
get_class_from_reflected_table_name,
reflect_model,
Expand Down Expand Up @@ -287,6 +287,8 @@ def apply_changes(self, messages: List[Message]):
if not self.engine:
raise UnknownFileId("No valid file available, download one with download_budget()")
with Session(self.engine) as s:
# use the current value to group updates together to the same row
current_table, current_id, current_value = None, None, {}
for message in messages:
if message.dataset == "prefs":
# write it to metadata.json instead
Expand All @@ -303,12 +305,17 @@ def apply_changes(self, messages: List[Message]):
f"Actual found a column not supported by the library: "
f"column '{message.column}' at table '{message.dataset}' not found\n"
)
entry = s.exec(select(table).where(table.columns.id == message.row)).one_or_none()
if not entry:
s.exec(insert(table).values(id=message.row))
s.exec(update(table).values({column: message.get_value()}).where(table.columns.id == message.row))
# this seems to be required for sqlmodel, remove if not needed anymore when querying from cache
s.flush()
# if the current id exists, and it's different from the next one, we update the values
next_id = message.row
if current_id and (current_id != next_id or table != current_table):
apply_change(s, current_table, current_id, current_value)
current_table, current_id, current_value = table, next_id, {column: message.get_value()}
# otherwise update the cache with the current value
else:
current_table, current_id, current_value[column] = table, next_id, message.get_value()
# if after finishing all values there is a value left, update it too
if current_table is not None and current_id is not None and current_value is not None:
apply_change(s, current_table, current_id, current_value)
s.commit()

def get_metadata(self) -> dict:
Expand Down
17 changes: 15 additions & 2 deletions actual/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@

import datetime
import decimal
from typing import List, Optional, Union
from typing import Dict, List, Optional, Union

from sqlalchemy import MetaData, Table, engine, event, inspect
from sqlalchemy.dialects.sqlite import insert
from sqlalchemy.orm import class_mapper, object_session
from sqlmodel import (
Boolean,
Expand All @@ -27,6 +28,7 @@
Integer,
LargeBinary,
Relationship,
Session,
SQLModel,
Text,
func,
Expand Down Expand Up @@ -108,7 +110,18 @@ def get_attribute_by_table_name(table_name: str, column_name: str, reverse: bool
)


def strong_reference_session(session):
def apply_change(
    session: Session, table: Table, table_id: str, values: Dict[Column, Union[str, int, float, None]]
) -> None:
    """Upsert a group of column changes into `table`, using the `id` column as the conflict target.

    A single `INSERT ... ON CONFLICT` statement is issued: if no row with `table_id` exists, one is inserted with the
    provided `values`; if the id already exists, the same `values` are written onto the existing row.

    :param session: session used to execute the statement.
    :param table: table that receives the change.
    :param table_id: value of the `id` column identifying the row.
    :param values: mapping of columns to their new values. When empty, the row is only created if it is missing.
    """
    insert_stmt = insert(table).values({"id": table_id, **values})
    if values:
        insert_stmt = insert_stmt.on_conflict_do_update(index_elements=["id"], set_=values)
    else:
        # `on_conflict_do_update` rejects an empty `set_` dictionary, so when there is nothing to update we only
        # make sure that the row exists
        insert_stmt = insert_stmt.on_conflict_do_nothing(index_elements=["id"])
    session.exec(insert_stmt)  # noqa: Insert type here is correct


def strong_reference_session(session: Session):
@event.listens_for(session, "before_flush")
def before_flush(sess, flush_context, instances):
if len(sess.deleted):
Expand Down
31 changes: 30 additions & 1 deletion tests/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pytest

from actual import Actual, ActualError
from actual import Actual, ActualError, reflect_model
from actual.database import Notes
from actual.queries import (
create_account,
Expand Down Expand Up @@ -239,3 +239,32 @@ def test_session_error(mocker):
with Actual(token="foo") as actual:
with pytest.raises(ActualError, match="No session defined"):
print(actual.session) # try to access the session, should raise an exception


def test_apply_changes(session, mocker):
    """Replay rolled-back sync messages through `Actual.apply_changes` and check that the rows are stored again."""
    mocker.patch("actual.Actual.validate")
    actual = Actual(token="foo")
    # attach the test database to the client by hand
    engine = session.bind
    actual._session = session
    actual.engine = engine
    actual._meta = reflect_model(engine)
    # create the entities, flushing them without a commit
    account = create_account(session, "Bank")
    transaction = create_transaction(session, date(2024, 1, 4), account, amount=35.7)
    session.flush()
    count_before_edit = len(session.info["messages"])
    # a single field edit is expected to add exactly one message
    transaction.notes = "foobar"
    session.flush()
    assert len(session.info["messages"]) == count_before_edit + 1
    pending_messages = session.info["messages"]
    # discard the pending work, then replay the recorded messages against the database
    session.rollback()
    actual.apply_changes(pending_messages)
    # the account must be present again with the same data
    stored_accounts = get_accounts(session, "Bank")
    assert len(stored_accounts) == 1
    stored_account = stored_accounts[0]
    assert stored_account.id == account.id
    assert stored_account.name == account.name
    # and so must the transaction, including the edited notes
    stored_transactions = get_transactions(session)
    assert len(stored_transactions) == 1
    stored_transaction = stored_transactions[0]
    assert stored_transaction.id == transaction.id
    assert stored_transaction.notes == transaction.notes
    assert stored_transaction.get_date() == transaction.get_date()
    assert stored_transaction.get_amount() == transaction.get_amount()

0 comments on commit 68d640d

Please sign in to comment.