Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize case queries in dump_domain_data #34010

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions corehq/apps/dump_reload/sql/dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from corehq.apps.dump_reload.exceptions import DomainDumpError
from corehq.apps.dump_reload.interface import DataDumper
from corehq.apps.dump_reload.sql.filters import (
CaseIDFilter,
FilteredModelIteratorBuilder,
ManyFilters,
SimpleFilter,
Expand All @@ -31,10 +32,10 @@

FilteredModelIteratorBuilder('form_processor.CommCareCase', SimpleFilter('domain')),
FilteredModelIteratorBuilder('form_processor.CommCareCaseIndex', SimpleFilter('domain')),
FilteredModelIteratorBuilder('form_processor.CaseAttachment', SimpleFilter('case__domain')),
FilteredModelIteratorBuilder('form_processor.CaseTransaction', SimpleFilter('case__domain')),
FilteredModelIteratorBuilder('form_processor.CaseAttachment', CaseIDFilter()),
FilteredModelIteratorBuilder('form_processor.CaseTransaction', CaseIDFilter()),
FilteredModelIteratorBuilder('form_processor.LedgerValue', SimpleFilter('domain')),
FilteredModelIteratorBuilder('form_processor.LedgerTransaction', SimpleFilter('case__domain')),
FilteredModelIteratorBuilder('form_processor.LedgerTransaction', CaseIDFilter()),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there caching in CaseIDFilter or will it re-fetch all the case IDs every time?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No caching at the moment. On my first attempt I did cache the list of all case ids, but since I wanted to use pagination and a generator instead, I put that on the back burner (it seemed a bit trickier to cache that result). It certainly would be useful to cache, though, since this filter is used in multiple places.


FilteredModelIteratorBuilder('case_search.DomainsNotInCaseSearchIndex', SimpleFilter('domain')),
FilteredModelIteratorBuilder('case_search.CaseSearchConfig', SimpleFilter('domain')),
Expand Down
41 changes: 31 additions & 10 deletions corehq/apps/dump_reload/sql/filters.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from abc import ABCMeta, abstractmethod

from django.db.models import Q
from corehq.util.queries import queryset_to_iterator

from dimagi.utils.chunked import chunked

from corehq.form_processor.models.cases import CommCareCase
from corehq.sql_db.util import paginate_query
from corehq.util.queries import queryset_to_iterator


class DomainFilter(metaclass=ABCMeta):
@abstractmethod
def get_filters(self, domain_name):
def get_filters(self, domain_name, db_alias=None):
"""Return a list of filters. Each filter will be applied to a queryset independently
of the others."""
raise NotImplementedError()
Expand All @@ -21,7 +24,7 @@ class SimpleFilter(DomainFilter):
def __init__(self, filter_kwarg):
self.filter_kwarg = filter_kwarg

def get_filters(self, domain_name):
def get_filters(self, domain_name, db_alias):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does db_alias not have a default here while it does for some of the other filters?

return [Q(**{self.filter_kwarg: domain_name})]


Expand All @@ -33,7 +36,7 @@ def __init__(self, *filter_kwargs):
assert filter_kwargs, 'Please set one of more filter_kwargs'
self.filter_kwargs = filter_kwargs

def get_filters(self, domain_name):
def get_filters(self, domain_name, db_alias):
filter_ = Q(**{self.filter_kwargs[0]: domain_name})
for filter_kwarg in self.filter_kwargs[1:]:
filter_ &= Q(**{filter_kwarg: domain_name})
Expand All @@ -47,7 +50,7 @@ def __init__(self, usernames=None):
def count(self, domain_name):
return len(self.usernames) if self.usernames is not None else None

def get_filters(self, domain_name):
def get_filters(self, domain_name, db_alias=None):
"""
:return: A generator of filters each filtering for at most 500 users.
"""
Expand All @@ -72,11 +75,11 @@ def __init__(self, field, ids, chunksize=1000):
def count(self, domain_name):
return len(self.get_ids(domain_name))

def get_ids(self, domain_name):
def get_ids(self, domain_name, db_alias=None):
return self.ids

def get_filters(self, domain_name):
for chunk in chunked(self.get_ids(domain_name), self.chunksize):
def get_filters(self, domain_name, db_alias=None):
for chunk in chunked(self.get_ids(domain_name, db_alias=db_alias), self.chunksize):
query_kwarg = '{}__in'.format(self.field)
yield Q(**{query_kwarg: chunk})

Expand All @@ -86,11 +89,29 @@ def __init__(self, user_id_field, include_web_users=True):
super().__init__(user_id_field, None)
self.include_web_users = include_web_users

def get_ids(self, domain_name):
def get_ids(self, domain_name, db_alias=None):
from corehq.apps.users.dbaccessors import get_all_user_ids_by_domain
return get_all_user_ids_by_domain(domain_name, include_web_users=self.include_web_users)


class CaseIDFilter(IDFilter):
def __init__(self, case_field='case'):
super().__init__(case_field, None)

def count(self, domain_name):
active_case_count = len(CommCareCase.objects.get_case_ids_in_domain(domain_name))
deleted_case_count = len(CommCareCase.objects.get_deleted_case_ids_in_domain(domain_name))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This loads all of the domain's case ids into memory, which I think we want to avoid. It would be better to use a .count() query per shard. This could be implemented as CommCareCaseManager.count_cases_in_domain(domain_name, include_deleted=True).

The new manager method could simply raise NotImplementedError if include_deleted is false since there is no use case for that branch at this time.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is just for progress, there is also the corehq.sql_db.util.estimate_row_count function which uses the query plan to get an estimated count.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this was lazy on my part. Looking closer at where this is used, the Builder object references it here, but I don't see where the builder object's count method is called, and if I set a breakpoint in that method and run dump_domain_data it doesn't get hit. So I'm tempted to just remove this count method altogether, but can dig a bit more to see if we made an intentional change to the StatsCounter at some point instead.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the count method was introduced after the StatsCounter object, and it looks like that was in the context of ICDS #28895. Is it possible this count code isn't currently applicable to dump_domain_data because there was only custom ICDS code that took advantage of it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ahhh it is used in print_domain_stats

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the count should return the count of objects the CaseIDFilter is set up for, not the count of cases. This made it a bit trickier, but I updated the filter in 6696b9b to handle this and added tests to verify that behavior.

return active_case_count + deleted_case_count

def get_ids(self, domain_name, db_alias=None):
assert db_alias, "Expected db_alias to be defined for CaseIDFilter"
source = 'dump_domain_data'
query = Q(domain=domain_name)
for row in paginate_query(db_alias, CommCareCase, query, values=['case_id'], load_source=source):
# there isn't a good way to return flattened results
yield row[0]


class UnfilteredModelIteratorBuilder(object):
def __init__(self, model_label):
self.model_label = model_label
Expand Down Expand Up @@ -137,7 +158,7 @@ def count(self):

def querysets(self):
queryset = self._base_queryset()
filters = self.filter.get_filters(self.domain)
filters = self.filter.get_filters(self.domain, db_alias=self.db_alias)
for filter_ in filters:
yield queryset.filter(filter_)

Expand Down
43 changes: 43 additions & 0 deletions corehq/apps/dump_reload/tests/test_sql_filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from django.test import TestCase

from corehq.apps.dump_reload.sql.filters import CaseIDFilter
from corehq.form_processor.tests.utils import create_case
from corehq.sql_db.util import get_db_aliases_for_partitioned_query


class TestCaseIDFilter(TestCase):
    """
    Given this is used in the context of dumping all data associated with a domain, it is important
    that all cases for a domain are included in this filter's get_ids method.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # CaseIDFilter.get_ids asserts that a db alias is supplied, so every test
        # passes the first alias returned for partitioned queries.
        cls.db_alias = get_db_aliases_for_partitioned_query()[0]

    def test_returns_cases_for_domain(self):
        create_case('test', case_id='abc123', save=True)
        case_filter = CaseIDFilter()
        case_ids = list(case_filter.get_ids('test', self.db_alias))
        self.assertEqual(case_ids, ['abc123'])

    def test_does_not_return_cases_from_other_domain(self):
        create_case('test', case_id='abc123', save=True)
        case_filter = CaseIDFilter()
        case_ids = list(case_filter.get_ids('other', self.db_alias))
        self.assertEqual(case_ids, [])

    def test_deleted_cases_are_included(self):
        create_case('test', case_id='abc123', save=True)
        create_case('test', case_id='def456', save=True, deleted=True)
        case_filter = CaseIDFilter()
        case_ids = list(case_filter.get_ids('test', self.db_alias))
        # Order of the returned ids is not asserted, only membership.
        self.assertCountEqual(case_ids, ['abc123', 'def456'])

    def test_count_includes_deleted_cases(self):
        create_case('test', case_id='abc123', save=True)
        create_case('test', case_id='def456', save=True, deleted=True)
        case_filter = CaseIDFilter()
        count = case_filter.count('test')
        self.assertEqual(count, 2)
Loading