-
-
Notifications
You must be signed in to change notification settings - Fork 218
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Optimize case queries in dump_domain_data #34010
base: master
Are you sure you want to change the base?
Changes from 2 commits
2f0e213
1a43fd4
ed47ec6
dd5a46e
6696b9b
885b626
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,17 @@ | ||
from abc import ABCMeta, abstractmethod | ||
|
||
from django.db.models import Q | ||
from corehq.util.queries import queryset_to_iterator | ||
|
||
from dimagi.utils.chunked import chunked | ||
|
||
from corehq.form_processor.models.cases import CommCareCase | ||
from corehq.sql_db.util import paginate_query | ||
from corehq.util.queries import queryset_to_iterator | ||
|
||
|
||
class DomainFilter(metaclass=ABCMeta): | ||
@abstractmethod | ||
def get_filters(self, domain_name): | ||
def get_filters(self, domain_name, db_alias=None): | ||
"""Return a list of filters. Each filter will be applied to a queryset independently | ||
of the others.""" | ||
raise NotImplementedError() | ||
|
@@ -21,7 +24,7 @@ class SimpleFilter(DomainFilter): | |
def __init__(self, filter_kwarg): | ||
self.filter_kwarg = filter_kwarg | ||
|
||
def get_filters(self, domain_name): | ||
def get_filters(self, domain_name, db_alias): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why does |
||
return [Q(**{self.filter_kwarg: domain_name})] | ||
|
||
|
||
|
@@ -33,7 +36,7 @@ def __init__(self, *filter_kwargs): | |
assert filter_kwargs, 'Please set one of more filter_kwargs' | ||
self.filter_kwargs = filter_kwargs | ||
|
||
def get_filters(self, domain_name): | ||
def get_filters(self, domain_name, db_alias): | ||
filter_ = Q(**{self.filter_kwargs[0]: domain_name}) | ||
for filter_kwarg in self.filter_kwargs[1:]: | ||
filter_ &= Q(**{filter_kwarg: domain_name}) | ||
|
@@ -47,7 +50,7 @@ def __init__(self, usernames=None): | |
def count(self, domain_name): | ||
return len(self.usernames) if self.usernames is not None else None | ||
|
||
def get_filters(self, domain_name): | ||
def get_filters(self, domain_name, db_alias=None): | ||
""" | ||
:return: A generator of filters each filtering for at most 500 users. | ||
""" | ||
|
@@ -72,11 +75,11 @@ def __init__(self, field, ids, chunksize=1000): | |
def count(self, domain_name): | ||
return len(self.get_ids(domain_name)) | ||
|
||
def get_ids(self, domain_name): | ||
def get_ids(self, domain_name, db_alias=None): | ||
return self.ids | ||
|
||
def get_filters(self, domain_name): | ||
for chunk in chunked(self.get_ids(domain_name), self.chunksize): | ||
def get_filters(self, domain_name, db_alias=None): | ||
for chunk in chunked(self.get_ids(domain_name, db_alias=db_alias), self.chunksize): | ||
query_kwarg = '{}__in'.format(self.field) | ||
yield Q(**{query_kwarg: chunk}) | ||
|
||
|
@@ -86,11 +89,29 @@ def __init__(self, user_id_field, include_web_users=True): | |
super().__init__(user_id_field, None) | ||
self.include_web_users = include_web_users | ||
|
||
def get_ids(self, domain_name): | ||
def get_ids(self, domain_name, db_alias=None): | ||
from corehq.apps.users.dbaccessors import get_all_user_ids_by_domain | ||
return get_all_user_ids_by_domain(domain_name, include_web_users=self.include_web_users) | ||
|
||
|
||
class CaseIDFilter(IDFilter): | ||
def __init__(self, case_field='case'):
    """Set up an ID filter keyed on ``case_field``.

    No ID collection is handed to the parent initializer (``None`` is
    passed in its place).
    """
    no_static_ids = None
    super().__init__(case_field, no_static_ids)
|
||
def count(self, domain_name): | ||
active_case_count = len(CommCareCase.objects.get_case_ids_in_domain(domain_name)) | ||
deleted_case_count = len(CommCareCase.objects.get_deleted_case_ids_in_domain(domain_name)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This loads all of the domain's case ids into memory, which I think we want to avoid. It would be better to use a The new manager method could simply raise There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If this is just for progress, there is also the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah this was lazy on my part. Looking closer at where this used, the Builder object references it here, but I don't see where the builder object's count method is called, and if I set a breakpoint in that method and run There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So the count method was introduced after the StatsCounter object, and it looks like in the context of ICDS #28895. Is it possible this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ahhh it is used in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So the |
||
return active_case_count + deleted_case_count | ||
|
||
def get_ids(self, domain_name, db_alias=None):
    """Lazily yield the ``case_id`` of each CommCareCase row in the domain.

    :param domain_name: domain whose cases are selected (``Q(domain=...)``).
    :param db_alias: database alias to paginate over; required here even
        though the signature defaults it to ``None``.
    :raises AssertionError: on first iteration when ``db_alias`` is falsy.
    """
    if not db_alias:
        # Raised explicitly rather than via ``assert`` so the guard is not
        # stripped when Python runs with -O.
        raise AssertionError("Expected db_alias to be defined for CaseIDFilter")
    rows = paginate_query(
        db_alias,
        CommCareCase,
        Q(domain=domain_name),
        values=['case_id'],
        load_source='dump_domain_data',
    )
    for row in rows:
        # Only ``case_id`` was requested, so it sits at position 0 of each
        # row; unwrap it so callers receive plain IDs.
        yield row[0]
|
||
|
||
class UnfilteredModelIteratorBuilder(object): | ||
def __init__(self, model_label): | ||
self.model_label = model_label | ||
|
@@ -137,7 +158,7 @@ def count(self): | |
|
||
def querysets(self): | ||
queryset = self._base_queryset() | ||
filters = self.filter.get_filters(self.domain) | ||
filters = self.filter.get_filters(self.domain, db_alias=self.db_alias) | ||
for filter_ in filters: | ||
yield queryset.filter(filter_) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from django.test import TestCase | ||
|
||
from corehq.apps.dump_reload.sql.filters import CaseIDFilter | ||
from corehq.form_processor.tests.utils import create_case | ||
from corehq.sql_db.util import get_db_aliases_for_partitioned_query | ||
|
||
|
||
class TestCaseIDFilter(TestCase):
    """
    CaseIDFilter is used when dumping all data associated with a domain, so
    its get_ids method must yield every case in that domain (deleted cases
    included) and nothing from any other domain.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # All tests query the first alias reported for partitioned queries.
        cls.db_alias = get_db_aliases_for_partitioned_query()[0]

    def test_returns_cases_for_domain(self):
        create_case('test', case_id='abc123', save=True)
        case_filter = CaseIDFilter()
        found_ids = list(case_filter.get_ids('test', self.db_alias))
        self.assertEqual(found_ids, ['abc123'])

    def test_does_not_return_cases_from_other_domain(self):
        create_case('test', case_id='abc123', save=True)
        case_filter = CaseIDFilter()
        found_ids = list(case_filter.get_ids('other', self.db_alias))
        self.assertEqual(found_ids, [])

    def test_deleted_cases_are_included(self):
        create_case('test', case_id='abc123', save=True)
        create_case('test', case_id='def456', save=True, deleted=True)
        case_filter = CaseIDFilter()
        found_ids = list(case_filter.get_ids('test', self.db_alias))
        # Order is not asserted, only membership.
        self.assertCountEqual(found_ids, ['abc123', 'def456'])

    def test_count_includes_deleted_cases(self):
        create_case('test', case_id='abc123', save=True)
        create_case('test', case_id='def456', save=True, deleted=True)
        case_filter = CaseIDFilter()
        total = case_filter.count('test')
        self.assertEqual(total, 2)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there caching in CaseIDFilter or will it re-fetch all the case IDs every time?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No caching at the moment. On my first attempt I did cache the list of all case ids, but since I wanted to use pagination and a generator instead, I put that on the back burner (it seemed a bit trickier to cache that result). It certainly would be useful to cache, though, since this filter is used in multiple places.