Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More performance enhancements #1850

Merged
merged 12 commits into from
Dec 9, 2024
154 changes: 90 additions & 64 deletions pay-api/src/pay_api/models/payment.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,9 @@
import pytz
from flask import current_app
from marshmallow import fields
from sqlalchemy import Boolean, ForeignKey, String, and_, cast, func, or_
from sqlalchemy import Boolean, ForeignKey, String, and_, cast, func, or_, select
from sqlalchemy.dialects.postgresql import ARRAY, TEXT
from sqlalchemy.orm import relationship
from sqlalchemy.sql import select
from sqlalchemy.orm import contains_eager, lazyload, load_only, relationship

from pay_api.exceptions import BusinessException
from pay_api.utils.constants import DT_SHORT_FORMAT
Expand Down Expand Up @@ -252,50 +251,62 @@ def find_payments_to_consolidate(cls, auth_account_id: str):
def generate_base_transaction_query(cls):
"""Generate a base query."""
return (
db.session.query(
Invoice.id,
Invoice.payment_account_id,
Invoice.corp_type_code,
Invoice.created_on,
Invoice.payment_date,
Invoice.refund_date,
Invoice.invoice_status_code,
Invoice.total,
Invoice.service_fees,
Invoice.paid,
Invoice.refund,
Invoice.folio_number,
Invoice.created_name,
Invoice.invoice_status_code,
Invoice.payment_method_code,
Invoice.details,
Invoice.business_identifier,
Invoice.created_by,
Invoice.filing_id,
Invoice.bcol_account,
Invoice.disbursement_date,
Invoice.disbursement_reversal_date,
Invoice.overdue_date,
PaymentLineItem.id,
PaymentLineItem.description,
PaymentLineItem.gst,
PaymentLineItem.pst,
PaymentAccount.id,
PaymentAccount.auth_account_id,
PaymentAccount.name,
PaymentAccount.billable,
InvoiceReference.id,
InvoiceReference.invoice_number,
InvoiceReference.reference_number,
InvoiceReference.status_code,
)
.outerjoin(PaymentAccount, Invoice.payment_account_id == PaymentAccount.id)
.outerjoin(PaymentLineItem, PaymentLineItem.invoice_id == Invoice.id)
.outerjoin(
db.session.query(Invoice)
.join(PaymentAccount, Invoice.payment_account_id == PaymentAccount.id)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Joins make sense for these, they will always occur

.join(PaymentLineItem, PaymentLineItem.invoice_id == Invoice.id)
.join(
FeeSchedule,
FeeSchedule.fee_schedule_id == PaymentLineItem.fee_schedule_id,
)
.outerjoin(InvoiceReference, InvoiceReference.invoice_id == Invoice.id)
.options(
lazyload("*"),
load_only(
Invoice.id,
Invoice.corp_type_code,
Invoice.created_on,
Invoice.payment_date,
Invoice.refund_date,
Invoice.invoice_status_code,
Invoice.total,
Invoice.service_fees,
Invoice.paid,
Invoice.refund,
Invoice.folio_number,
Invoice.created_name,
Invoice.invoice_status_code,
Invoice.payment_method_code,
Invoice.details,
Invoice.business_identifier,
Invoice.created_by,
Invoice.filing_id,
Invoice.bcol_account,
Invoice.disbursement_date,
Invoice.disbursement_reversal_date,
Invoice.overdue_date,
),
contains_eager(Invoice.payment_line_items)
.load_only(
PaymentLineItem.description,
PaymentLineItem.gst,
PaymentLineItem.pst,
PaymentLineItem.service_fees,
PaymentLineItem.total,
)
.contains_eager(PaymentLineItem.fee_schedule)
.load_only(FeeSchedule.filing_type_code),
contains_eager(Invoice.payment_account).load_only(
PaymentAccount.auth_account_id,
PaymentAccount.name,
PaymentAccount.billable,
PaymentAccount.branch_name,
),
contains_eager(Invoice.references).load_only(
InvoiceReference.invoice_number,
InvoiceReference.reference_number,
InvoiceReference.status_code,
),
)
)

@classmethod
Expand All @@ -320,7 +331,7 @@ def search_purchase_history( # noqa:E501; pylint:disable=too-many-arguments, to
count_future = executor.submit(cls.get_count, auth_account_id, search_filter)
sub_query = cls.generate_subquery(auth_account_id, search_filter, limit, page)
query = query.filter(Invoice.id.in_(sub_query.subquery().select())).order_by(Invoice.id.desc())
result_future = executor.submit(db.session.query(Invoice).from_statement(query).all)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was no good, would get rid of all of nice load_only all the rows would have to be reloaded at serialization

result_future = executor.submit(query.all)
count = count_future.result()
result = result_future.result()
# If maximum number of records is provided, return it as total
Expand All @@ -330,16 +341,14 @@ def search_purchase_history( # noqa:E501; pylint:disable=too-many-arguments, to
# If maximum number of records is provided, set the page with that number
sub_query = cls.generate_subquery(auth_account_id, search_filter, max_no_records, page=None)
result, count = (
db.session.query(Invoice)
.from_statement(query.filter(Invoice.id.in_(sub_query.subquery().select())))
.all(),
query.filter(Invoice.id.in_(sub_query.subquery().select())).all(),
sub_query.count(),
)
else:
count = cls.get_count(auth_account_id, search_filter)
if count > 60000:
raise BusinessException(Error.PAYMENT_SEARCH_TOO_MANY_RECORDS)
result = db.session.query(Invoice).from_statement(query).all()
result = query.all()
return result, count

@classmethod
Expand Down Expand Up @@ -379,18 +388,15 @@ def get_invoices_and_payment_accounts_for_statements(cls, search_filter: Dict):
@classmethod
def get_count(cls, auth_account_id: str, search_filter: Dict):
"""Slimmed downed version for count (less joins)."""
query = cls.generate_base_transaction_query()
query = cls.filter(query, auth_account_id, search_filter)
count = query.group_by(Invoice.id).with_entities(func.count()).count() # pylint:disable=not-callable
query = db.session.query(Invoice.id)
query = cls.filter(query, auth_account_id, search_filter, include_joins=True)
count = query.distinct(Invoice.id).count()
return count

@classmethod
def filter(cls, query, auth_account_id: str, search_filter: Dict):
def filter(cls, query, auth_account_id: str, search_filter: Dict, include_joins=False):
"""For filtering queries."""
if auth_account_id:
query = query.filter(PaymentAccount.auth_account_id == auth_account_id)
if account_name := search_filter.get("accountName", None):
query = query.filter(PaymentAccount.name.ilike(f"%{account_name}%"))
query = cls.filter_payment_account(query, auth_account_id, search_filter, include_joins)
if status_code := search_filter.get("statusCode", None):
query = query.filter(Invoice.invoice_status_code == status_code)
if search_filter.get("status", None):
Expand All @@ -408,14 +414,31 @@ def filter(cls, query, auth_account_id: str, search_filter: Dict):
if invoice_id := search_filter.get("id", None):
query = query.filter(cast(Invoice.id, String).like(f"%{invoice_id}%"))
if invoice_number := search_filter.get("invoiceNumber", None):
if include_joins:
query = query.join(InvoiceReference, InvoiceReference.invoice_id == Invoice.id)
query = query.filter(InvoiceReference.invoice_number.ilike(f"%{invoice_number}%"))

query = cls.filter_corp_type(query, search_filter)
query = cls.filter_payment(query, search_filter)
query = cls.filter_details(query, search_filter)
query = cls.filter_details(query, search_filter, include_joins)
query = cls.filter_date(query, search_filter)
return query

@classmethod
def filter_payment_account(cls, query, auth_account_id, search_filter: dict, include_joins=False):
"""Use subquery to look for payment accounts ahead of time, much faster and easier."""
account_name = search_filter.get("accountName", None)
if auth_account_id:
payment_account_id = (
db.session.query(PaymentAccount.id).filter(PaymentAccount.auth_account_id == auth_account_id).scalar()
)
query = query.filter(Invoice.payment_account_id == (payment_account_id or -1))
if account_name:
if include_joins:
query = query.join(PaymentAccount, PaymentAccount.id == Invoice.payment_account_id)
query = query.filter(PaymentAccount.name.ilike(f"%{account_name}%"))
return query

@classmethod
def filter_corp_type(cls, query, search_filter: dict):
"""Filter for corp type."""
Expand Down Expand Up @@ -470,9 +493,13 @@ def filter_date(cls, query, search_filter: dict):
return query

@classmethod
def filter_details(cls, query, search_filter: dict):
def filter_details(cls, query, search_filter: dict, include_joins=False):
"""Filter by details."""
if line_item := search_filter.get("lineItems", None):
line_item = search_filter.get("lineItems", None)
line_item_or_details = search_filter.get("lineItemsAndDetails", None)
if (line_item or line_item_or_details) and include_joins:
query = query.join(PaymentLineItem, PaymentLineItem.invoice_id == Invoice.id)
if line_item:
query = query.filter(PaymentLineItem.description.ilike(f"%{line_item}%"))
if details := search_filter.get("details", None):
query = query.filter(
Expand All @@ -485,7 +512,7 @@ def filter_details(cls, query, search_filter: dict):
),
)
)
if line_item_or_details := search_filter.get("lineItemsAndDetails", None):
if line_item_or_details:
query = query.filter(
or_(
PaymentLineItem.description.ilike(f"%{line_item_or_details}%"),
Expand All @@ -505,11 +532,10 @@ def filter_details(cls, query, search_filter: dict):
@classmethod
def generate_subquery(cls, auth_account_id, search_filter, limit, page):
"""Generate subquery for invoices, used for pagination."""
subquery = cls.generate_base_transaction_query()
subquery = db.session.query(Invoice.id)
subquery = (
cls.filter(subquery, auth_account_id, search_filter)
.with_entities(Invoice.id)
.group_by(Invoice.id)
cls.filter(subquery, auth_account_id, search_filter, include_joins=True)
.distinct()
.order_by(Invoice.id.desc())
)
if limit:
Expand Down
2 changes: 2 additions & 0 deletions pay-api/src/pay_api/utils/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
class JSONPath(UserDefinedType):
"""Used to define json path when casting."""

cache_ok = True

@property
def python_type(self):
"""Return the python type."""
Expand Down
Loading