diff --git a/packages/discovery-provider/ddl/migrations/0060_aggregate_monthly_plays_country.sql b/packages/discovery-provider/ddl/migrations/0060_aggregate_monthly_plays_country.sql new file mode 100644 index 00000000000..8e11b592391 --- /dev/null +++ b/packages/discovery-provider/ddl/migrations/0060_aggregate_monthly_plays_country.sql @@ -0,0 +1,9 @@ +begin; + +alter table aggregate_monthly_plays add column if not exists country text not null default ''; + +ALTER TABLE aggregate_monthly_plays DROP CONSTRAINT aggregate_monthly_plays_pkey; + +ALTER TABLE aggregate_monthly_plays ADD PRIMARY KEY (play_item_id, "timestamp", country); + +commit; diff --git a/packages/discovery-provider/src/models/social/aggregate_monthly_plays.py b/packages/discovery-provider/src/models/social/aggregate_monthly_plays.py index 529d330a66a..cd94ee3ce1b 100644 --- a/packages/discovery-provider/src/models/social/aggregate_monthly_plays.py +++ b/packages/discovery-provider/src/models/social/aggregate_monthly_plays.py @@ -1,4 +1,4 @@ -from sqlalchemy import Column, Date, Integer, text +from sqlalchemy import Column, Date, Integer, String, text from src.models.base import Base from src.models.model_utils import RepresentableMixin @@ -13,4 +13,5 @@ class AggregateMonthlyPlay(Base, RepresentableMixin): timestamp = Column( Date, primary_key=True, nullable=False, server_default=text("CURRENT_TIMESTAMP") ) + country = Column(String, primary_key=True, nullable=False, server_default="") count = Column(Integer, nullable=False) diff --git a/packages/discovery-provider/src/queries/get_user_listen_counts_monthly.py b/packages/discovery-provider/src/queries/get_user_listen_counts_monthly.py index e3630552f23..90919e0db3d 100644 --- a/packages/discovery-provider/src/queries/get_user_listen_counts_monthly.py +++ b/packages/discovery-provider/src/queries/get_user_listen_counts_monthly.py @@ -1,11 +1,9 @@ from typing import TypedDict +from sqlalchemy import text from sqlalchemy.orm.session import Session -from src.models.social.aggregate_monthly_plays import AggregateMonthlyPlay -from src.models.tracks.track import Track from src.utils.db_session import get_db_read_replica -from src.utils.helpers import query_result_to_list class GetUserListenCountsMonthlyArgs(TypedDict): @@ -33,23 +31,24 @@ def get_user_listen_counts_monthly(args: GetUserListenCountsMonthlyArgs): db = get_db_read_replica() with db.scoped_session() as session: - user_listen_counts_monthly = _get_user_listen_counts_monthly(session, args) - return query_result_to_list(user_listen_counts_monthly) + return list(_get_user_listen_counts_monthly(session, args)) def _get_user_listen_counts_monthly( session: Session, args: GetUserListenCountsMonthlyArgs ): - user_id = args["user_id"] - start_time = args["start_time"] - end_time = args["end_time"] - - query = ( - session.query(AggregateMonthlyPlay) - .join(Track, Track.track_id == AggregateMonthlyPlay.play_item_id) - .filter(Track.owner_id == user_id) - .filter(Track.is_current == True) - .filter(AggregateMonthlyPlay.timestamp >= start_time) - .filter(AggregateMonthlyPlay.timestamp < end_time) + sql = text( + """ + select + play_item_id, + timestamp, + sum(count) as count + from aggregate_monthly_plays + where play_item_id in (select track_id from tracks where owner_id = :user_id) + and timestamp >= :start_time + and timestamp < :end_time + group by play_item_id, timestamp + """ ) - return query.all() + + return session.execute(sql, args) diff --git a/packages/discovery-provider/src/tasks/index_aggregate_monthly_plays.py b/packages/discovery-provider/src/tasks/index_aggregate_monthly_plays.py index b2ce416f3f0..352a4477dee 100644 --- a/packages/discovery-provider/src/tasks/index_aggregate_monthly_plays.py +++ b/packages/discovery-provider/src/tasks/index_aggregate_monthly_plays.py @@ -25,6 +25,7 @@ select play_item_id, date_trunc('month', created_at) as timestamp, + coalesce(country, '') as country, count(play_item_id) as count from plays p @@ -32,16 +33,17 @@ p.id > :prev_id_checkpoint and p.id <= :new_id_checkpoint group by - play_item_id, date_trunc('month', created_at) + play_item_id, date_trunc('month', created_at), coalesce(country, '') ) insert into - aggregate_monthly_plays (play_item_id, timestamp, count) + aggregate_monthly_plays (play_item_id, timestamp, country, count) select new_plays.play_item_id, new_plays.timestamp, + new_plays.country, new_plays.count from - new_plays on conflict (play_item_id, timestamp) do + new_plays on conflict (play_item_id, timestamp, country) do update set count = aggregate_monthly_plays.count + excluded.count