Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backend: Added start_date and end_date #190

Merged
merged 1 commit into from
May 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions server/app/pseudo_tweets/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
TweetRead,
TweetUpdate,
)
from ..tweets_common.types import Month
from . import router


Expand All @@ -45,8 +44,8 @@ def get_pseudo_overview(all: bool = False, session: Session = Depends(get_sessio
@router.get("/count", response_model=TweetCount)
def get_count(
topics: Optional[List[Topics]] = Query(None),
day: Optional[date] = None,
month: Optional[Month] = None,
start_date: Optional[date] = None,
end_date: Optional[date] = None,
all: bool = False,
session: Session = Depends(get_session),
):
Expand All @@ -56,23 +55,23 @@ def get_count(

Model = get_combined_model() if all else PseudoTweet

return get_filtered_count(Model, topics, day, month, session)
return get_filtered_count(Model, topics, start_date, end_date, session)


@router.get("/", response_model=List[TweetRead])
def read_pseudo_tweets(
offset: NonNegativeInt = 0,
limit: conint(le=10, gt=0) = 10,
topics: Optional[List[Topics]] = Query(None),
day: Optional[date] = None,
month: Optional[Month] = Query(None, description="Month in %Y-%m format"),
start_date: Optional[date] = None,
end_date: Optional[date] = None,
maximize_labels: bool = False,
session: Session = Depends(get_session),
):
"""
Read pseudo tweets within the offset and limit
"""
selection = get_filtered_selection(topics, PseudoTweet, day, month)
selection = get_filtered_selection(PseudoTweet, topics, start_date, end_date)

# others should be exclusively provided, hence the last check
is_others = topics is not None and len(topics) and topics[0] == Topics.others
Expand Down
13 changes: 6 additions & 7 deletions server/app/tweets/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
TweetRead,
TweetUpdate,
)
from ..tweets_common.types import Month
from . import router


Expand All @@ -41,30 +40,30 @@ def get_tweet_overview(session: Session = Depends(get_session)):
@router.get("/count", response_model=TweetCount)
def get_count(
topics: Optional[List[Topics]] = Query(None),
day: Optional[date] = None,
month: Optional[Month] = None,
start_date: Optional[date] = None,
end_date: Optional[date] = None,
session: Session = Depends(get_session),
):
"""
Get the count of tweets for the given filters
"""

return get_filtered_count(Tweet, topics, day, month, session)
return get_filtered_count(Tweet, topics, start_date, end_date, session)


@router.get("/", response_model=List[TweetRead])
def read_tweets(
offset: NonNegativeInt = 0,
limit: conint(le=10, gt=0) = 10,
topics: Optional[List[Topics]] = Query(None),
day: Optional[date] = None,
month: Optional[Month] = Query(None, description="Month in %Y-%m format"),
start_date: Optional[date] = None,
end_date: Optional[date] = None,
session: Session = Depends(get_session),
):
"""
Read tweets within the offset and limit
"""
selection = get_filtered_selection(topics, Tweet, day, month)
selection = get_filtered_selection(Tweet, topics, start_date, end_date)

tweets = session.exec(
selection.order_by(Tweet.id.desc()).offset(offset).limit(limit)
Expand Down
103 changes: 54 additions & 49 deletions server/app/tweets_common/helper_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,49 +7,66 @@
from sqlmodel.sql.expression import Select

from .models import PseudoTweet, Topics, Tweet, TweetRead, TweetUpdate
from .types import Month

# Make a Generic Type to get the original type completion back
ModelType = TypeVar("ModelType", Tweet, PseudoTweet)


def get_filtered_selection(
topics: Optional[Collection[Topics]],
def get_selection_filter(
Model: ModelType,
day: Optional[date] = None,
month: Optional[Month] = None,
fields: Optional[Collection[str]] = None,
topics: Optional[Collection[Topics]],
start_date: Optional[date],
end_date: Optional[date],
selection: Select[tuple],
others_filter,
):
"""
Get selection query with filter depending upon topics provided
Filter the selection by various dimensions.
`selection` is the selection before filtering
`others_filter` signifies the filter to be used for others column.
(direct text or creation of others column)
"""

selection = get_scalar_select(Model, fields)

if topics is not None:
if Topics.others in topics:
if len(topics) > 1:
raise HTTPException(400, "Can't filter by others and other topics.")

# If others is defined in the selection, directly provide the column
filter = (
text(Topics.others)
if fields is None or not len(fields) or "others" in fields
else get_others_column(Model)
)
filter = others_filter
else:
filter = and_(*tuple(getattr(Model, topic) for topic in topics))

selection = selection.filter(filter)

if day is not None or month is not None:
# If both specified, use day only
filter = (
func.date(Model.created_at) == day
if day is not None
else func.strftime("%Y-%m", Model.created_at) == month
)
selection = selection.filter(filter)
created_date = func.date(Model.created_at)

if start_date is not None:
selection = selection.filter(created_date >= start_date)

if end_date is not None:
selection = selection.filter(created_date < end_date)
return selection


def get_filtered_selection(
Model: ModelType,
topics: Optional[Collection[Topics]],
start_date: Optional[date] = None,
end_date: Optional[date] = None,
fields: Optional[Collection[str]] = None,
):
"""
Get selection query with filter depending upon topics provided
"""

selection, contains_others = get_scalar_select(Model, fields)

others_filter = text(Topics.others) if contains_others else get_others_column(Model)

selection = get_selection_filter(
Model, topics, start_date, end_date, selection, others_filter
)

return selection

Expand All @@ -58,9 +75,10 @@ def get_a_tweet(session: Session, tweet_id: PositiveInt, Model: ModelType) -> tu
"""
Get a not-None tweet from the database with others column as a dictionary
"""
tweet = session.exec(
get_scalar_select(Model).where(Model.id == tweet_id)
).one_or_none()

selection, _ = get_scalar_select(Model).where(Model.id == tweet_id)

tweet = session.exec(selection).one_or_none()

assert_not_null(tweet, tweet_id, Model)

Expand Down Expand Up @@ -100,7 +118,7 @@ def assert_not_null(tweet: Optional[ModelType], id: PositiveInt, Model: ModelTyp

def get_scalar_select(
Model: ModelType, fields: Optional[Collection[str]] = None
) -> Select[tuple]:
) -> Tuple[Select[tuple], bool]:
"""
Get a select statement for the Model with others column
"""
Expand All @@ -112,10 +130,12 @@ def get_scalar_select(

db_fields = list(getattr(Model, field) for field in fields)

if is_fields_empty or "others" in fields:
contains_others = is_fields_empty or "others" in fields

if contains_others:
db_fields.append(get_others_column(Model))

return select(*db_fields)
return select(*db_fields), contains_others


MapperReturnType = TypeVar("MapperReturnType")
Expand Down Expand Up @@ -203,8 +223,8 @@ def get_model_attr(attr: str):
def get_filtered_count(
Model: ModelType,
topics: Optional[Collection[Topics]],
day: Optional[date],
month: Optional[Month],
start_date: Optional[date],
end_date: Optional[date],
session: Session,
):
def get_sum_column(column: str):
Expand All @@ -225,25 +245,10 @@ def get_sum_column(column: str):
func.count().label("total"),
)

if topics is not None:
if Topics.others in topics:
if len(topics) > 1:
raise HTTPException(400, "Can't filter by others and other topics.")

# If others is defined in the selection, directly provide the column
filter = text(Topics.others)
else:
filter = and_(*tuple(getattr(Model, topic) for topic in topics))
others_filter = text(Topics.others)

selection = selection.filter(filter)

if day is not None or month is not None:
# If both specified, use day only
filter = (
func.date(Model.created_at) == day
if day is not None
else func.strftime("%Y-%m", Model.created_at) == month
)
selection = selection.filter(filter)
selection = get_selection_filter(
Model, topics, start_date, end_date, selection, others_filter
)

return session.exec(selection).one()
17 changes: 9 additions & 8 deletions server/app/tweets_common/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from ..database import get_session
from . import router
from .models import PredictionOutput, PseudoTweet, Topics, Tweet
from .types import Month
from .word_cloud_helper import get_word_count_distribution

CACHE_TIMEOUT = 6 * 60 * 60 # 6 hours
Expand All @@ -22,8 +21,8 @@
@timed_lru_cache(seconds=CACHE_TIMEOUT, maxsize=64)
def get_word_cloud(
topics: Optional[Tuple[Topics, ...]] = Query(None),
day: Optional[date] = None,
month: Optional[Month] = None,
start_date: Optional[date] = None,
end_date: Optional[date] = None,
session: Session = Depends(get_session),
):
"""
Expand All @@ -32,18 +31,20 @@ def get_word_cloud(

fields = ("text",)

tweet_selection = get_filtered_selection(topics, Tweet, day, month, fields)
pseudo_tweet_selection = get_filtered_selection(
topics, PseudoTweet, day, month, fields
)
args = (topics, start_date, end_date, fields)

tweet_selection = get_filtered_selection(Tweet, *args)
pseudo_tweet_selection = get_filtered_selection(PseudoTweet, *args)

combined_model = union_all(tweet_selection, pseudo_tweet_selection).subquery().c

# The text column is selected manually here; update this if the selected fields change
combined_tweets = session.exec(select(combined_model.text)).all()

# change list of tweets to tuple to allow caching
word_freq = get_word_count_distribution(tuple(combined_tweets))
combined_tweets = tuple(combined_tweets)

word_freq = get_word_count_distribution(combined_tweets)

return word_freq

Expand Down