Skip to content

Commit

Permalink
Add country column to aggregate_monthly_plays (#7801)
Browse files Browse the repository at this point in the history
  • Loading branch information
stereosteve authored Mar 13, 2024
1 parent e6267f1 commit e841490
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 21 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
begin;

alter table aggregate_monthly_plays add column if not exists country text not null default '';

ALTER TABLE aggregate_monthly_plays DROP CONSTRAINT aggregate_monthly_plays_pkey;

ALTER TABLE aggregate_monthly_plays ADD PRIMARY KEY (play_item_id, "timestamp", country);

commit;
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from sqlalchemy import Column, Date, Integer, text
from sqlalchemy import Column, Date, Integer, String, text

from src.models.base import Base
from src.models.model_utils import RepresentableMixin
Expand All @@ -13,4 +13,5 @@ class AggregateMonthlyPlay(Base, RepresentableMixin):
timestamp = Column(
Date, primary_key=True, nullable=False, server_default=text("CURRENT_TIMESTAMP")
)
country = Column(String, primary_key=True, nullable=False, server_default="")
count = Column(Integer, nullable=False)
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from typing import TypedDict

from sqlalchemy import text
from sqlalchemy.orm.session import Session

from src.models.social.aggregate_monthly_plays import AggregateMonthlyPlay
from src.models.tracks.track import Track
from src.utils.db_session import get_db_read_replica
from src.utils.helpers import query_result_to_list


class GetUserListenCountsMonthlyArgs(TypedDict):
Expand Down Expand Up @@ -33,23 +31,24 @@ def get_user_listen_counts_monthly(args: GetUserListenCountsMonthlyArgs):

db = get_db_read_replica()
with db.scoped_session() as session:
user_listen_counts_monthly = _get_user_listen_counts_monthly(session, args)
return query_result_to_list(user_listen_counts_monthly)
return list(_get_user_listen_counts_monthly(session, args))


def _get_user_listen_counts_monthly(
session: Session, args: GetUserListenCountsMonthlyArgs
):
user_id = args["user_id"]
start_time = args["start_time"]
end_time = args["end_time"]

query = (
session.query(AggregateMonthlyPlay)
.join(Track, Track.track_id == AggregateMonthlyPlay.play_item_id)
.filter(Track.owner_id == user_id)
.filter(Track.is_current == True)
.filter(AggregateMonthlyPlay.timestamp >= start_time)
.filter(AggregateMonthlyPlay.timestamp < end_time)
sql = text(
"""
select
play_item_id,
timestamp,
sum(count) as count
from aggregate_monthly_plays
where play_item_id in (select track_id from tracks where owner_id = :user_id)
and timestamp >= :start_time
and timestamp < :end_time
group by play_item_id, timestamp
"""
)
return query.all()

return session.execute(sql, args)
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,25 @@
select
play_item_id,
date_trunc('month', created_at) as timestamp,
coalesce(country, '') as country,
count(play_item_id) as count
from
plays p
where
p.id > :prev_id_checkpoint
and p.id <= :new_id_checkpoint
group by
play_item_id, date_trunc('month', created_at)
play_item_id, date_trunc('month', created_at), coalesce(country, '')
)
insert into
aggregate_monthly_plays (play_item_id, timestamp, count)
aggregate_monthly_plays (play_item_id, timestamp, country, count)
select
new_plays.play_item_id,
new_plays.timestamp,
new_plays.country,
new_plays.count
from
new_plays on conflict (play_item_id, timestamp) do
new_plays on conflict (play_item_id, timestamp, country) do
update
set
count = aggregate_monthly_plays.count + excluded.count
Expand Down

0 comments on commit e841490

Please sign in to comment.