Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

upgrade test queries to sqlalchemy 2.0 #1382

Merged
merged 28 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions .ds.baseline
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@
"filename": "tests/app/db.py",
"hashed_secret": "5baa61e4c9b93f3f0682250b6cf8331b7ee68fd8",
"is_verified": false,
"line_number": 87,
"line_number": 90,
"is_secret": false
}
],
Expand All @@ -277,15 +277,15 @@
"filename": "tests/app/notifications/test_receive_notification.py",
"hashed_secret": "913a73b565c8e2c8ed94497580f619397709b8b6",
"is_verified": false,
"line_number": 24,
"line_number": 26,
"is_secret": false
},
{
"type": "Base64 High Entropy String",
"filename": "tests/app/notifications/test_receive_notification.py",
"hashed_secret": "d70eab08607a4d05faa2d0d6647206599e9abc65",
"is_verified": false,
"line_number": 54,
"line_number": 56,
"is_secret": false
}
],
Expand All @@ -305,7 +305,7 @@
"filename": "tests/app/service/test_rest.py",
"hashed_secret": "5baa61e4c9b93f3f0682250b6cf8331b7ee68fd8",
"is_verified": false,
"line_number": 1275,
"line_number": 1284,
"is_secret": false
}
],
Expand Down Expand Up @@ -341,15 +341,15 @@
"filename": "tests/app/user/test_rest.py",
"hashed_secret": "5baa61e4c9b93f3f0682250b6cf8331b7ee68fd8",
"is_verified": false,
"line_number": 106,
"line_number": 108,
"is_secret": false
},
{
"type": "Secret Keyword",
"filename": "tests/app/user/test_rest.py",
"hashed_secret": "0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33",
"is_verified": false,
"line_number": 810,
"line_number": 826,
"is_secret": false
}
],
Expand Down Expand Up @@ -384,5 +384,5 @@
}
]
},
"generated_at": "2024-10-28T20:26:27Z"
"generated_at": "2024-10-31T21:25:32Z"
}
47 changes: 30 additions & 17 deletions app/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from faker import Faker
from flask import current_app, json
from notifications_python_client.authentication import create_jwt_token
from sqlalchemy import and_, text
from sqlalchemy import and_, select, text, update
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm.exc import NoResultFound

Expand Down Expand Up @@ -123,8 +123,8 @@ def purge_functional_test_data(user_email_prefix):
if getenv("NOTIFY_ENVIRONMENT", "") not in ["development", "test"]:
current_app.logger.error("Can only be run in development")
return

users = User.query.filter(User.email_address.like(f"{user_email_prefix}%")).all()
stmt = select(User).where(User.email_address.like(f"{user_email_prefix}%"))
users = db.session.execute(stmt).scalars().all()
for usr in users:
# Make sure the full email includes a uuid in it
# Just in case someone decides to use a similar email address.
Expand Down Expand Up @@ -338,9 +338,10 @@ def boolean_or_none(field):
email_branding = None
email_branding_column = columns[5].strip()
if len(email_branding_column) > 0:
email_branding = EmailBranding.query.filter(
stmt = select(EmailBranding).where(
EmailBranding.name == email_branding_column
).one()
)
email_branding = db.session.execute(stmt).scalars().one()
data = {
"name": columns[0],
"active": True,
Expand Down Expand Up @@ -406,10 +407,14 @@ def populate_organization_agreement_details_from_file(file_name):

@notify_command(name="associate-services-to-organizations")
def associate_services_to_organizations():
services = Service.get_history_model().query.filter_by(version=1).all()
stmt = select(Service.get_history_model()).where(
Service.get_history_model().version == 1
)
services = db.session.execute(stmt).scalars().all()

for s in services:
created_by_user = User.query.filter_by(id=s.created_by_id).first()
stmt = select(User).where(User.id == s.created_by_id)
created_by_user = db.session.execute(stmt).scalars().first()
organization = dao_get_organization_by_email_address(
created_by_user.email_address
)
Expand Down Expand Up @@ -467,15 +472,16 @@ def populate_go_live(file_name):

@notify_command(name="fix-billable-units")
def fix_billable_units():
query = Notification.query.filter(
stmt = select(Notification).where(
Notification.notification_type == NotificationType.SMS,
Notification.status != NotificationStatus.CREATED,
Notification.sent_at == None, # noqa
Notification.billable_units == 0,
Notification.key_type != KeyType.TEST,
)
all = db.session.execute(stmt).scalars().all()

for notification in query.all():
for notification in all:
template_model = dao_get_template_by_id(
notification.template_id, notification.template_version
)
Expand All @@ -490,9 +496,12 @@ def fix_billable_units():
f"Updating notification: {notification.id} with {template.fragment_count} billable_units"
)

Notification.query.filter(Notification.id == notification.id).update(
{"billable_units": template.fragment_count}
stmt = (
update(Notification)
.where(Notification.id == notification.id)
.values({"billable_units": template.fragment_count})
)
db.session.execute(stmt)
db.session.commit()
current_app.logger.info("End fix_billable_units")

Expand Down Expand Up @@ -637,8 +646,9 @@ def populate_annual_billing_with_defaults(year, missing_services_only):
This is useful to ensure all services start the new year with the correct annual billing.
"""
if missing_services_only:
active_services = (
Service.query.filter(Service.active)
stmt = (
select(Service)
.where(Service.active)
.outerjoin(
AnnualBilling,
and_(
Expand All @@ -647,10 +657,11 @@ def populate_annual_billing_with_defaults(year, missing_services_only):
),
)
.filter(AnnualBilling.id == None) # noqa
.all()
)
active_services = db.session.execute(stmt).scalars().all()
else:
active_services = Service.query.filter(Service.active).all()
stmt = select(Service).where(Service.active)
active_services = db.session.execute(stmt).scalars().all()
previous_year = year - 1
services_with_zero_free_allowance = (
db.session.query(AnnualBilling.service_id)
Expand Down Expand Up @@ -750,7 +761,8 @@ def create_user_jwt(token):


def _update_template(id, name, template_type, content, subject):
template = Template.query.filter_by(id=id).first()
stmt = select(Template).where(Template.id == id)
template = db.session.execute(stmt).scalars().first()
if not template:
template = Template(id=id)
template.service_id = "d6aa2c68-a2d9-4437-ab19-3ae8eb202553"
Expand All @@ -761,7 +773,8 @@ def _update_template(id, name, template_type, content, subject):
template.content = "\n".join(content)
template.subject = subject

history = TemplateHistory.query.filter_by(id=id).first()
stmt = select(TemplateHistory).where(TemplateHistory.id == id)
history = db.session.execute(stmt).scalars().first()
if not history:
history = TemplateHistory(id=id)
history.service_id = "d6aa2c68-a2d9-4437-ab19-3ae8eb202553"
Expand Down
23 changes: 16 additions & 7 deletions app/dao/invited_user_dao.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from datetime import timedelta

from sqlalchemy import select

from app import db
from app.enums import InvitedUserStatus
from app.models import InvitedUser
Expand All @@ -12,30 +14,37 @@ def save_invited_user(invited_user):


def get_invited_user_by_service_and_id(service_id, invited_user_id):
return InvitedUser.query.filter(

stmt = select(InvitedUser).where(
InvitedUser.service_id == service_id,
InvitedUser.id == invited_user_id,
).one()
)
return db.session.execute(stmt).scalars().one()


def get_expired_invite_by_service_and_id(service_id, invited_user_id):
return InvitedUser.query.filter(
stmt = select(InvitedUser).where(
InvitedUser.service_id == service_id,
InvitedUser.id == invited_user_id,
InvitedUser.status == InvitedUserStatus.EXPIRED,
).one()
)
return db.session.execute(stmt).scalars().one()


def get_invited_user_by_id(invited_user_id):
return InvitedUser.query.filter(InvitedUser.id == invited_user_id).one()
stmt = select(InvitedUser).where(InvitedUser.id == invited_user_id)
return db.session.execute(stmt).scalars().one()


def get_expired_invited_users_for_service(service_id):
return InvitedUser.query.filter(InvitedUser.service_id == service_id).all()
# TODO why does this return all invited users?
stmt = select(InvitedUser).where(InvitedUser.service_id == service_id)
return db.session.execute(stmt).scalars().all()


def get_invited_users_for_service(service_id):
return InvitedUser.query.filter(InvitedUser.service_id == service_id).all()
stmt = select(InvitedUser).where(InvitedUser.service_id == service_id)
return db.session.execute(stmt).scalars().all()


def expire_invitations_created_more_than_two_days_ago():
Expand Down
1 change: 1 addition & 0 deletions app/dao/notifications_dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def get_notifications_for_job(
):
if page_size is None:
page_size = current_app.config["PAGE_SIZE"]

query = Notification.query.filter_by(service_id=service_id, job_id=job_id)
query = _filter_query(query, filter_dict)
return query.order_by(asc(Notification.job_row_number)).paginate(
Expand Down
27 changes: 16 additions & 11 deletions app/dao/provider_details_dao.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from datetime import datetime

from flask import current_app
from sqlalchemy import desc, func
from sqlalchemy import desc, func, select

from app import db
from app.dao.dao_utils import autocommit
Expand All @@ -11,11 +11,12 @@


def get_provider_details_by_id(provider_details_id):
return ProviderDetails.query.get(provider_details_id)
return db.session.get(ProviderDetails, provider_details_id)


def get_provider_details_by_identifier(identifier):
return ProviderDetails.query.filter_by(identifier=identifier).one()
stmt = select(ProviderDetails).where(ProviderDetails.identifier == identifier)
return db.session.execute(stmt).scalars().one()


def get_alternative_sms_provider(identifier):
Expand All @@ -25,12 +26,14 @@ def get_alternative_sms_provider(identifier):


def dao_get_provider_versions(provider_id):
return (
ProviderDetailsHistory.query.filter_by(id=provider_id)
stmt = (
select(ProviderDetailsHistory)
.where(ProviderDetailsHistory.id == provider_id)
.order_by(desc(ProviderDetailsHistory.version))
.limit(100) # limit results instead of adding pagination
.all()
.limit(100)
)
# limit results instead of adding pagination
return db.session.execute(stmt).scalars().all()


def _get_sms_providers_for_update(time_threshold):
Expand All @@ -42,14 +45,15 @@ def _get_sms_providers_for_update(time_threshold):
release the transaction in that case
"""
# get current priority of both providers
q = (
ProviderDetails.query.filter(
stmt = (
select(ProviderDetails)
.where(
ProviderDetails.notification_type == NotificationType.SMS,
ProviderDetails.active,
)
.with_for_update()
.all()
)
q = db.session.execute(stmt).scalars().all()

# if something updated recently, don't update again. If the updated_at is null, treat it as min time
if any(
Expand All @@ -72,7 +76,8 @@ def get_provider_details_by_notification_type(
if supports_international:
filters.append(ProviderDetails.supports_international == supports_international)

return ProviderDetails.query.filter(*filters).all()
stmt = select(ProviderDetails).where(*filters)
return db.session.execute(stmt).scalars().all()


@autocommit
Expand Down
Loading
Loading