Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

upgrade core daos to sqlalchemy 2.0 #1362

Merged
merged 16 commits into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .ds.baseline
Original file line number Diff line number Diff line change
Expand Up @@ -249,15 +249,15 @@
"filename": "tests/app/dao/test_users_dao.py",
"hashed_secret": "5baa61e4c9b93f3f0682250b6cf8331b7ee68fd8",
"is_verified": false,
"line_number": 52,
"line_number": 69,
"is_secret": false
},
{
"type": "Secret Keyword",
"filename": "tests/app/dao/test_users_dao.py",
"hashed_secret": "f2c57870308dc87f432e5912d4de6f8e322721ba",
"is_verified": false,
"line_number": 176,
"line_number": 194,
"is_secret": false
}
],
Expand Down Expand Up @@ -384,5 +384,5 @@
}
]
},
"generated_at": "2024-09-27T16:42:53Z"
"generated_at": "2024-10-11T19:26:50Z"
}
106 changes: 64 additions & 42 deletions app/dao/notifications_dao.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from datetime import timedelta

from flask import current_app
from sqlalchemy import asc, desc, or_, select, text, union
from sqlalchemy import asc, delete, desc, func, or_, select, text, union, update
from sqlalchemy.orm import joinedload
from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.sql import functions
Expand Down Expand Up @@ -109,11 +109,12 @@ def _update_notification_status(
def update_notification_status_by_id(
notification_id, status, sent_by=None, provider_response=None, carrier=None
):
notification = (
Notification.query.with_for_update()
stmt = (
select(Notification)
.with_for_update()
.filter(Notification.id == notification_id)
.first()
)
notification = db.session.execute(stmt).scalars().first()

if not notification:
current_app.logger.info(
Expand Down Expand Up @@ -156,9 +157,8 @@ def update_notification_status_by_id(
@autocommit
def update_notification_status_by_reference(reference, status):
# this is used to update emails
notification = Notification.query.filter(
Notification.reference == reference
).first()
stmt = select(Notification).filter(Notification.reference == reference)
notification = db.session.execute(stmt).scalars().first()

if not notification:
current_app.logger.error(
Expand Down Expand Up @@ -200,31 +200,33 @@ def get_notifications_for_job(


def dao_get_notification_count_for_job_id(*, job_id):
return Notification.query.filter_by(job_id=job_id).count()
stmt = select(func.count(Notification.id)).filter_by(job_id=job_id)
return db.session.execute(stmt).scalar()


def dao_get_notification_count_for_service(*, service_id):
notification_count = Notification.query.filter_by(service_id=service_id).count()
return notification_count
stmt = select(func.count(Notification.id)).filter_by(service_id=service_id)
return db.session.execute(stmt).scalar()


def dao_get_failed_notification_count():
failed_count = Notification.query.filter_by(
stmt = select(func.count(Notification.id)).filter_by(
status=NotificationStatus.FAILED
).count()
return failed_count
)
return db.session.execute(stmt).scalar()


def get_notification_with_personalisation(service_id, notification_id, key_type):
filter_dict = {"service_id": service_id, "id": notification_id}
if key_type:
filter_dict["key_type"] = key_type

return (
Notification.query.filter_by(**filter_dict)
stmt = (
select(Notification)
.filter_by(**filter_dict)
.options(joinedload(Notification.template))
.one()
)
return db.session.execute(stmt).scalars().one()


def get_notification_by_id(notification_id, service_id=None, _raise=False):
Expand All @@ -233,9 +235,13 @@ def get_notification_by_id(notification_id, service_id=None, _raise=False):
if service_id:
filters.append(Notification.service_id == service_id)

query = Notification.query.filter(*filters)
stmt = select(Notification).filter(*filters)

return query.one() if _raise else query.first()
return (
db.session.execute(stmt).scalars().one()
if _raise
else db.session.execute(stmt).scalars().first()
)


def get_notifications_for_service(
Expand Down Expand Up @@ -415,12 +421,13 @@ def move_notifications_to_notification_history(
deleted += delete_count_per_call

# Deleting test Notifications, test notifications are not persisted to NotificationHistory
Notification.query.filter(
stmt = delete(Notification).filter(
Notification.notification_type == notification_type,
Notification.service_id == service_id,
Notification.created_at < timestamp_to_delete_backwards_from,
Notification.key_type == KeyType.TEST,
).delete(synchronize_session=False)
)
db.session.execute(stmt)
db.session.commit()

return deleted
Expand All @@ -442,39 +449,49 @@ def dao_timeout_notifications(cutoff_time, limit=100000):
current_statuses = [NotificationStatus.SENDING, NotificationStatus.PENDING]
new_status = NotificationStatus.TEMPORARY_FAILURE

notifications = (
Notification.query.filter(
stmt = (
select(Notification)
.filter(
Notification.created_at < cutoff_time,
Notification.status.in_(current_statuses),
Notification.notification_type.in_(
[NotificationType.SMS, NotificationType.EMAIL]
),
)
.limit(limit)
.all()
)
notifications = db.session.execute(stmt).scalars().all()

Notification.query.filter(
Notification.id.in_([n.id for n in notifications]),
).update(
{"status": new_status, "updated_at": updated_at}, synchronize_session=False
stmt = (
update(Notification)
.filter(Notification.id.in_([n.id for n in notifications]))
.values({"status": new_status, "updated_at": updated_at})
)
db.session.execute(stmt)

db.session.commit()
return notifications


@autocommit
def dao_update_notifications_by_reference(references, update_dict):
updated_count = Notification.query.filter(
Notification.reference.in_(references)
).update(update_dict, synchronize_session=False)
stmt = (
update(Notification)
.filter(Notification.reference.in_(references))
.values(update_dict)
)
result = db.session.execute(stmt)
updated_count = result.rowcount

updated_history_count = 0
if updated_count != len(references):
updated_history_count = NotificationHistory.query.filter(
NotificationHistory.reference.in_(references)
).update(update_dict, synchronize_session=False)
stmt = (
update(NotificationHistory)
.filter(NotificationHistory.reference.in_(references))
.values(update_dict)
)
result = db.session.execute(stmt)
updated_history_count = result.rowcount

return updated_count, updated_history_count

Expand Down Expand Up @@ -541,18 +558,21 @@ def dao_get_notifications_by_recipient_or_reference(


def dao_get_notification_by_reference(reference):
return Notification.query.filter(Notification.reference == reference).one()
stmt = select(Notification).filter(Notification.reference == reference)
return db.session.execute(stmt).scalars().one()


def dao_get_notification_history_by_reference(reference):
try:
# This try except is necessary because in test keys and research mode does not create notification history.
# Otherwise we could just search for the NotificationHistory object
return Notification.query.filter(Notification.reference == reference).one()
stmt = select(Notification).filter(Notification.reference == reference)
return db.session.execute(stmt).scalars().one()
except NoResultFound:
return NotificationHistory.query.filter(
stmt = select(NotificationHistory).filter(
NotificationHistory.reference == reference
).one()
)
return db.session.execute(stmt).scalars().one()


def dao_get_notifications_processing_time_stats(start_date, end_date):
Expand Down Expand Up @@ -590,23 +610,25 @@ def dao_get_notifications_processing_time_stats(start_date, end_date):


def dao_get_last_notification_added_for_job_id(job_id):
last_notification_added = (
Notification.query.filter(Notification.job_id == job_id)
stmt = (
select(Notification)
.filter(Notification.job_id == job_id)
.order_by(Notification.job_row_number.desc())
.first()
)
last_notification_added = db.session.execute(stmt).scalars().first()

return last_notification_added


def notifications_not_yet_sent(should_be_sending_after_seconds, notification_type):
older_than_date = utc_now() - timedelta(seconds=should_be_sending_after_seconds)

notifications = Notification.query.filter(
stmt = select(Notification).filter(
Notification.created_at <= older_than_date,
Notification.notification_type == notification_type,
Notification.status == NotificationStatus.CREATED,
).all()
)
notifications = db.session.execute(stmt).scalars().all()
return notifications


Expand Down
55 changes: 33 additions & 22 deletions app/dao/users_dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import sqlalchemy
from flask import current_app
from sqlalchemy import func, text
from sqlalchemy import delete, func, select, text
from sqlalchemy.orm import joinedload

from app import db
Expand Down Expand Up @@ -37,8 +37,8 @@ def get_login_gov_user(login_uuid, email_address):
login.gov uuids are. Eventually the code that checks by email address
should be removed.
"""

user = User.query.filter_by(login_uuid=login_uuid).first()
stmt = select(User).filter_by(login_uuid=login_uuid)
user = db.session.execute(stmt).scalars().first()
if user:
if user.email_address != email_address:
try:
Expand All @@ -54,7 +54,8 @@ def get_login_gov_user(login_uuid, email_address):

return user
# Remove this 1 July 2025, all users should have login.gov uuids by now
user = User.query.filter(User.email_address.ilike(email_address)).first()
stmt = select(User).filter(User.email_address.ilike(email_address))
user = db.session.execute(stmt).scalars().first()

if user:
save_user_attribute(user, {"login_uuid": login_uuid})
Expand Down Expand Up @@ -102,24 +103,27 @@ def create_user_code(user, code, code_type):
def get_user_code(user, code, code_type):
# Get the most recent codes to try and reduce the
# time searching for the correct code.
codes = VerifyCode.query.filter_by(user=user, code_type=code_type).order_by(
VerifyCode.created_at.desc()
stmt = (
select(VerifyCode)
.filter_by(user=user, code_type=code_type)
.order_by(VerifyCode.created_at.desc())
)
codes = db.session.execute(stmt).scalars().all()
return next((x for x in codes if x.check_code(code)), None)


def delete_codes_older_created_more_than_a_day_ago():
deleted = (
db.session.query(VerifyCode)
.filter(VerifyCode.created_at < utc_now() - timedelta(hours=24))
.delete()
stmt = delete(VerifyCode).filter(
VerifyCode.created_at < utc_now() - timedelta(hours=24)
)

deleted = db.session.execute(stmt)
db.session.commit()
return deleted


def use_user_code(id):
verify_code = VerifyCode.query.get(id)
verify_code = db.session.get(VerifyCode, id)
verify_code.code_used = True
db.session.add(verify_code)
db.session.commit()
Expand All @@ -131,36 +135,42 @@ def delete_model_user(user):


def delete_user_verify_codes(user):
VerifyCode.query.filter_by(user=user).delete()
stmt = delete(VerifyCode).filter_by(user=user)
db.session.execute(stmt)
db.session.commit()


def count_user_verify_codes(user):
query = VerifyCode.query.filter(
stmt = select(func.count(VerifyCode.id)).filter(
VerifyCode.user == user,
VerifyCode.expiry_datetime > utc_now(),
VerifyCode.code_used.is_(False),
)
return query.count()
result = db.session.execute(stmt).scalar()
return result or 0


def get_user_by_id(user_id=None):
if user_id:
return User.query.filter_by(id=user_id).one()
return User.query.filter_by().all()
stmt = select(User).filter_by(id=user_id)
return db.session.execute(stmt).scalars().one()
return get_users()


def get_users():
return User.query.all()
stmt = select(User)
return db.session.execute(stmt).scalars().all()


def get_user_by_email(email):
return User.query.filter(func.lower(User.email_address) == func.lower(email)).one()
stmt = select(User).filter(func.lower(User.email_address) == func.lower(email))
return db.session.execute(stmt).scalars().one()


def get_users_by_partial_email(email):
email = escape_special_characters(email)
return User.query.filter(User.email_address.ilike("%{}%".format(email))).all()
stmt = select(User).filter(User.email_address.ilike("%{}%".format(email)))
return db.session.execute(stmt).scalars().all()


def increment_failed_login_count(user):
Expand Down Expand Up @@ -188,16 +198,17 @@ def get_user_and_accounts(user_id):
# TODO: With sqlalchemy 2.0 change as below because of the breaking change
# at User.organizations.services, we need to verify that the below subqueryload
# that we have put is functionally doing the same thing as before
return (
User.query.filter(User.id == user_id)
stmt = (
select(User)
.filter(User.id == user_id)
.options(
# eagerly load the user's services and organizations, and also the service's org and vice versa
# (so we can see if the user knows about it)
joinedload(User.services).joinedload(Service.organization),
joinedload(User.organizations).subqueryload(Organization.services),
)
.one()
)
return db.session.execute(stmt).scalars().unique().one()


@autocommit
Expand Down
Loading
Loading