diff --git a/server/app/pseudo_tweets/routes.py b/server/app/pseudo_tweets/routes.py
index 6bb822c3..47259e1c 100644
--- a/server/app/pseudo_tweets/routes.py
+++ b/server/app/pseudo_tweets/routes.py
@@ -26,7 +26,6 @@
     TweetRead,
     TweetUpdate,
 )
-from ..tweets_common.types import Month
 from . import router
 
 
@@ -45,8 +44,8 @@ def get_pseudo_overview(all: bool = False, session: Session = Depends(get_sessio
 @router.get("/count", response_model=TweetCount)
 def get_count(
     topics: Optional[List[Topics]] = Query(None),
-    day: Optional[date] = None,
-    month: Optional[Month] = None,
+    start_date: Optional[date] = None,
+    end_date: Optional[date] = None,
     all: bool = False,
     session: Session = Depends(get_session),
 ):
@@ -56,7 +55,7 @@ def get_count(
 
     Model = get_combined_model() if all else PseudoTweet
 
-    return get_filtered_count(Model, topics, day, month, session)
+    return get_filtered_count(Model, topics, start_date, end_date, session)
 
 
 @router.get("/", response_model=List[TweetRead])
@@ -64,15 +63,15 @@ def read_pseudo_tweets(
     offset: NonNegativeInt = 0,
     limit: conint(le=10, gt=0) = 10,
     topics: Optional[List[Topics]] = Query(None),
-    day: Optional[date] = None,
-    month: Optional[Month] = Query(None, description="Month in %Y-%m format"),
+    start_date: Optional[date] = None,
+    end_date: Optional[date] = None,
     maximize_labels: bool = False,
     session: Session = Depends(get_session),
 ):
     """
     Read pseudo tweets within the offset and limit
     """
-    selection = get_filtered_selection(topics, PseudoTweet, day, month)
+    selection = get_filtered_selection(PseudoTweet, topics, start_date, end_date)
 
     # others should be exclusively provided, hence the last check
     is_others = topics is not None and len(topics) and topics[0] == Topics.others
diff --git a/server/app/tweets/routes.py b/server/app/tweets/routes.py
index 7af05beb..d71f8d8f 100644
--- a/server/app/tweets/routes.py
+++ b/server/app/tweets/routes.py
@@ -24,7 +24,6 @@
     TweetRead,
     TweetUpdate,
 )
-from ..tweets_common.types import Month
 from . import router
 
 
@@ -41,15 +40,15 @@ def get_tweet_overview(session: Session = Depends(get_session)):
 @router.get("/count", response_model=TweetCount)
 def get_count(
     topics: Optional[List[Topics]] = Query(None),
-    day: Optional[date] = None,
-    month: Optional[Month] = None,
+    start_date: Optional[date] = None,
+    end_date: Optional[date] = None,
     session: Session = Depends(get_session),
 ):
     """
     Get the count of tweets for the given filters
     """
 
-    return get_filtered_count(Tweet, topics, day, month, session)
+    return get_filtered_count(Tweet, topics, start_date, end_date, session)
 
 
 @router.get("/", response_model=List[TweetRead])
@@ -57,14 +56,14 @@ def read_tweets(
     offset: NonNegativeInt = 0,
     limit: conint(le=10, gt=0) = 10,
     topics: Optional[List[Topics]] = Query(None),
-    day: Optional[date] = None,
-    month: Optional[Month] = Query(None, description="Month in %Y-%m format"),
+    start_date: Optional[date] = None,
+    end_date: Optional[date] = None,
     session: Session = Depends(get_session),
 ):
     """
     Read tweets within the offset and limit
     """
-    selection = get_filtered_selection(topics, Tweet, day, month)
+    selection = get_filtered_selection(Tweet, topics, start_date, end_date)
 
     tweets = session.exec(
         selection.order_by(Tweet.id.desc()).offset(offset).limit(limit)
diff --git a/server/app/tweets_common/helper_functions.py b/server/app/tweets_common/helper_functions.py
index 7468e2d9..3d8e5fde 100644
--- a/server/app/tweets_common/helper_functions.py
+++ b/server/app/tweets_common/helper_functions.py
@@ -7,49 +7,66 @@
 from sqlmodel.sql.expression import Select
 
 from .models import PseudoTweet, Topics, Tweet, TweetRead, TweetUpdate
-from .types import Month
 
 # Make a Generic Type to get the original type completion back
 ModelType = TypeVar("ModelType", Tweet, PseudoTweet)
 
 
-def get_filtered_selection(
-    topics: Optional[Collection[Topics]],
+def get_selection_filter(
     Model: ModelType,
-    day: Optional[date] = None,
-    month: Optional[Month] = None,
-    fields: Optional[Collection[str]] = None,
+    topics: Optional[Collection[Topics]],
+    start_date: Optional[date],
+    end_date: Optional[date],
+    selection: Select[tuple],
+    others_filter,
 ):
     """
-    Get selection query with filter depending upon topics provided
+    Filter the selection by various dimensions.
+    `selection` is the selection before filtering
+    `others_filter` signifies the filter to be used for others column.
+    (direct text or creation of others column)
     """
 
-    selection = get_scalar_select(Model, fields)
-
     if topics is not None:
         if Topics.others in topics:
             if len(topics) > 1:
                 raise HTTPException(400, "Can't filter by others and other topics.")
 
             # If others is defined in the selection, directly provide the column
-            filter = (
-                text(Topics.others)
-                if fields is None or not len(fields) or "others" in fields
-                else get_others_column(Model)
-            )
+            filter = others_filter
         else:
             filter = and_(*tuple(getattr(Model, topic) for topic in topics))
 
         selection = selection.filter(filter)
 
-    if day is not None or month is not None:
-        # If both specified, use day only
-        filter = (
-            func.date(Model.created_at) == day
-            if day is not None
-            else func.strftime("%Y-%m", Model.created_at) == month
-        )
-        selection = selection.filter(filter)
+    created_date = func.date(Model.created_at)
+
+    if start_date is not None:
+        selection = selection.filter(created_date >= start_date)
+
+    if end_date is not None:
+        selection = selection.filter(created_date < end_date)
+    return selection
+
+
+def get_filtered_selection(
+    Model: ModelType,
+    topics: Optional[Collection[Topics]],
+    start_date: Optional[date] = None,
+    end_date: Optional[date] = None,
+    fields: Optional[Collection[str]] = None,
+):
+    """
+    Get selection query with filter depending upon topics provided
+    """
+
+    selection, contains_others = get_scalar_select(Model, fields)
+
+    others_filter = text(Topics.others) if contains_others else get_others_column(Model)
+
+    selection = get_selection_filter(
+        Model, topics, start_date, end_date, selection, others_filter
+    )
 
     return selection
 
@@ -58,9 +75,10 @@ def get_a_tweet(session: Session, tweet_id: PositiveInt, Model: ModelType) -> tu
     """
     Get a not-None tweet from the database with others column as a dictonary
     """
-    tweet = session.exec(
-        get_scalar_select(Model).where(Model.id == tweet_id)
-    ).one_or_none()
+
+    # get_scalar_select returns a (selection, contains_others) pair: unpack first
+    selection, _ = get_scalar_select(Model)
+    tweet = session.exec(selection.where(Model.id == tweet_id)).one_or_none()
 
     assert_not_null(tweet, tweet_id, Model)
 
@@ -100,7 +118,7 @@ def assert_not_null(tweet: Optional[ModelType], id: PositiveInt, Model: ModelTyp
 
 def get_scalar_select(
     Model: ModelType, fields: Optional[Collection[str]] = None
-) -> Select[tuple]:
+) -> Tuple[Select[tuple], bool]:
     """
     Get a select statement for the Model with others column
     """
@@ -112,10 +130,12 @@ def get_scalar_select(
 
     db_fields = list(getattr(Model, field) for field in fields)
 
-    if is_fields_empty or "others" in fields:
+    contains_others = is_fields_empty or "others" in fields
+
+    if contains_others:
         db_fields.append(get_others_column(Model))
 
-    return select(*db_fields)
+    return select(*db_fields), contains_others
 
 
 MapperReturnType = TypeVar("MapperReturnType")
@@ -203,8 +223,8 @@ def get_model_attr(attr: str):
 def get_filtered_count(
     Model: ModelType,
     topics: Optional[Collection[Topics]],
-    day: Optional[date],
-    month: Optional[Month],
+    start_date: Optional[date],
+    end_date: Optional[date],
     session: Session,
 ):
     def get_sum_column(column: str):
@@ -225,25 +245,10 @@ def get_sum_column(column: str):
         func.count().label("total"),
     )
 
-    if topics is not None:
-        if Topics.others in topics:
-            if len(topics) > 1:
-                raise HTTPException(400, "Can't filter by others and other topics.")
-
-            # If others is defined in the selection, directly provide the column
-            filter = text(Topics.others)
-        else:
-            filter = and_(*tuple(getattr(Model, topic) for topic in topics))
+    others_filter = text(Topics.others)
 
-        selection = selection.filter(filter)
-
-    if day is not None or month is not None:
-        # If both specified, use day only
-        filter = (
-            func.date(Model.created_at) == day
-            if day is not None
-            else func.strftime("%Y-%m", Model.created_at) == month
-        )
-        selection = selection.filter(filter)
+    selection = get_selection_filter(
+        Model, topics, start_date, end_date, selection, others_filter
+    )
 
     return session.exec(selection).one()
diff --git a/server/app/tweets_common/routes.py b/server/app/tweets_common/routes.py
index 35d8d74b..fb6e18f2 100644
--- a/server/app/tweets_common/routes.py
+++ b/server/app/tweets_common/routes.py
@@ -12,7 +12,6 @@
 from ..database import get_session
 from . import router
 from .models import PredictionOutput, PseudoTweet, Topics, Tweet
-from .types import Month
 from .word_cloud_helper import get_word_count_distribution
 
 CACHE_TIMEOUT = 6 * 60 * 60  # 6 hours
@@ -22,8 +21,8 @@
 @timed_lru_cache(seconds=CACHE_TIMEOUT, maxsize=64)
 def get_word_cloud(
     topics: Optional[Tuple[Topics, ...]] = Query(None),
-    day: Optional[date] = None,
-    month: Optional[Month] = None,
+    start_date: Optional[date] = None,
+    end_date: Optional[date] = None,
     session: Session = Depends(get_session),
 ):
     """
@@ -32,10 +31,10 @@ def get_word_cloud(
 
     fields = ("text",)
 
-    tweet_selection = get_filtered_selection(topics, Tweet, day, month, fields)
-    pseudo_tweet_selection = get_filtered_selection(
-        topics, PseudoTweet, day, month, fields
-    )
+    args = (topics, start_date, end_date, fields)
+
+    tweet_selection = get_filtered_selection(Tweet, *args)
+    pseudo_tweet_selection = get_filtered_selection(PseudoTweet, *args)
 
     combined_model = union_all(tweet_selection, pseudo_tweet_selection).subquery().c
 
@@ -43,7 +42,9 @@ def get_word_cloud(
     combined_tweets = session.exec(select(combined_model.text)).all()
 
     # change list of tweets to tuple to allow caching
-    word_freq = get_word_count_distribution(tuple(combined_tweets))
+    combined_tweets = tuple(combined_tweets)
+
+    word_freq = get_word_count_distribution(combined_tweets)
 
     return word_freq
 