Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize case queries in dump_domain_data #34010

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions corehq/apps/dump_reload/sql/dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from corehq.apps.dump_reload.exceptions import DomainDumpError
from corehq.apps.dump_reload.interface import DataDumper
from corehq.apps.dump_reload.sql.filters import (
CaseIDFilter,
FilteredModelIteratorBuilder,
ManyFilters,
SimpleFilter,
Expand All @@ -31,11 +32,13 @@

FilteredModelIteratorBuilder('form_processor.CommCareCase', SimpleFilter('domain')),
FilteredModelIteratorBuilder('form_processor.CommCareCaseIndex', SimpleFilter('domain')),
FilteredModelIteratorBuilder('form_processor.CaseAttachment', SimpleFilter('case__domain')),
FilteredModelIteratorBuilder('form_processor.CaseTransaction', SimpleFilter('case__domain')),
FilteredModelIteratorBuilder('form_processor.CaseAttachment', CaseIDFilter('form_processor.CaseAttachment')),
FilteredModelIteratorBuilder('form_processor.CaseTransaction',
CaseIDFilter('form_processor.CaseTransaction'),
{'case_id': 'gte', 'pk': 'gt'}),
FilteredModelIteratorBuilder('form_processor.LedgerValue', SimpleFilter('domain')),
FilteredModelIteratorBuilder('form_processor.LedgerTransaction', SimpleFilter('case__domain')),

FilteredModelIteratorBuilder('form_processor.LedgerTransaction',
CaseIDFilter('form_processor.LedgerTransaction')),
FilteredModelIteratorBuilder('case_search.DomainsNotInCaseSearchIndex', SimpleFilter('domain')),
FilteredModelIteratorBuilder('case_search.CaseSearchConfig', SimpleFilter('domain')),
FilteredModelIteratorBuilder('case_search.FuzzyProperties', SimpleFilter('domain')),
Expand Down
63 changes: 52 additions & 11 deletions corehq/apps/dump_reload/sql/filters.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
from abc import ABCMeta, abstractmethod

from django.db.models import Q
from corehq.util.queries import queryset_to_iterator
from django.db.models.fields.related import ForeignKey

from dimagi.utils.chunked import chunked

from corehq.apps.dump_reload.util import get_model_class
from corehq.form_processor.models.cases import CommCareCase
from corehq.sql_db.util import (
get_db_aliases_for_partitioned_query,
paginate_query,
)
from corehq.util.queries import queryset_to_iterator


class DomainFilter(metaclass=ABCMeta):
@abstractmethod
def get_filters(self, domain_name):
def get_filters(self, domain_name, db_alias=None):
"""Return a list of filters. Each filter will be applied to a queryset independently
of the others."""
raise NotImplementedError()
Expand All @@ -21,7 +29,7 @@ class SimpleFilter(DomainFilter):
def __init__(self, filter_kwarg):
self.filter_kwarg = filter_kwarg

def get_filters(self, domain_name):
def get_filters(self, domain_name, db_alias):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does db_alias not have a default here while it does for some of the other filters?

return [Q(**{self.filter_kwarg: domain_name})]


Expand All @@ -33,7 +41,7 @@ def __init__(self, *filter_kwargs):
assert filter_kwargs, 'Please set one of more filter_kwargs'
self.filter_kwargs = filter_kwargs

def get_filters(self, domain_name):
def get_filters(self, domain_name, db_alias):
filter_ = Q(**{self.filter_kwargs[0]: domain_name})
for filter_kwarg in self.filter_kwargs[1:]:
filter_ &= Q(**{filter_kwarg: domain_name})
Expand All @@ -47,7 +55,7 @@ def __init__(self, usernames=None):
def count(self, domain_name):
return len(self.usernames) if self.usernames is not None else None

def get_filters(self, domain_name):
def get_filters(self, domain_name, db_alias=None):
"""
:return: A generator of filters each filtering for at most 500 users.
"""
Expand All @@ -72,11 +80,11 @@ def __init__(self, field, ids, chunksize=1000):
def count(self, domain_name):
return len(self.get_ids(domain_name))

def get_ids(self, domain_name):
def get_ids(self, domain_name, db_alias=None):
return self.ids

def get_filters(self, domain_name):
for chunk in chunked(self.get_ids(domain_name), self.chunksize):
def get_filters(self, domain_name, db_alias=None):
for chunk in chunked(self.get_ids(domain_name, db_alias=db_alias), self.chunksize):
query_kwarg = '{}__in'.format(self.field)
yield Q(**{query_kwarg: chunk})

Expand All @@ -86,11 +94,39 @@ def __init__(self, user_id_field, include_web_users=True):
super().__init__(user_id_field, None)
self.include_web_users = include_web_users

def get_ids(self, domain_name):
def get_ids(self, domain_name, db_alias=None):
from corehq.apps.users.dbaccessors import get_all_user_ids_by_domain
return get_all_user_ids_by_domain(domain_name, include_web_users=self.include_web_users)


class CaseIDFilter(IDFilter):
    """
    Filter rows of a case-related model by the case IDs belonging to a domain.

    Case IDs live in partitioned databases, so ``get_ids`` must be told which
    partition (``db_alias``) to query; ``count`` aggregates across all
    partitions.
    """

    def __init__(self, model_label, case_field='case'):
        """
        :param model_label: dotted model label, e.g. 'form_processor.CaseTransaction'
        :param case_field: name of the model field that is a ForeignKey to CommCareCase
        :raises ValueError: if ``case_field`` is not a ForeignKey to CommCareCase
        """
        _, self.model_cls = get_model_class(model_label)
        # Validate explicitly rather than with `assert` inside a broad
        # try/except: asserts are stripped under `python -O`, and the broad
        # except hid the actual failure reason.
        field_obj = next(
            (f for f in self.model_cls._meta.fields if f.name == case_field),
            None,
        )
        if not (isinstance(field_obj, ForeignKey)
                and field_obj.remote_field.model == CommCareCase):
            raise ValueError(
                "CaseIDFilter only supports models with a foreign key relationship to CommCareCase"
            )
        super().__init__(case_field, None, chunksize=500)

    def count(self, domain_name):
        """Total number of rows for this domain across all partitioned dbs."""
        count = 0
        for db in get_db_aliases_for_partitioned_query():
            count += self.model_cls.objects.using(db).filter(case__domain=domain_name).count()
        return count

    def get_ids(self, domain_name, db_alias=None):
        """Yield case_ids for ``domain_name`` from the partition ``db_alias``."""
        if not db_alias:
            raise ValueError("Expected db_alias to be defined for CaseIDFilter")
        source = 'dump_domain_data'
        query = Q(domain=domain_name)
        for row in paginate_query(db_alias, CommCareCase, query, values=['case_id'], load_source=source):
            # paginate_query yields row tuples when `values` is passed; there
            # isn't a good way to get flattened results, so unpack here.
            yield row[0]


class UnfilteredModelIteratorBuilder(object):
def __init__(self, model_label):
self.model_label = model_label
Expand Down Expand Up @@ -122,9 +158,14 @@ def build(self, domain, model_class, db_alias):


class FilteredModelIteratorBuilder(UnfilteredModelIteratorBuilder):
def __init__(self, model_label, filter):
def __init__(self, model_label, filter, paginate_by={}):
"""
:param paginate_by: optional dictionary of {field: conditional, ...} (e.g., {'username': 'gt'})
NOTE: the order of keys matters in this dictionary, as it dictates sort order.
"""
super(FilteredModelIteratorBuilder, self).__init__(model_label)
self.filter = filter
self.paginate_by = paginate_by

def build(self, domain, model_class, db_alias):
return self.__class__(self.model_label, self.filter).prepare(domain, model_class, db_alias)
Expand All @@ -137,7 +178,7 @@ def count(self):

def querysets(self):
queryset = self._base_queryset()
filters = self.filter.get_filters(self.domain)
filters = self.filter.get_filters(self.domain, db_alias=self.db_alias)
for filter_ in filters:
yield queryset.filter(filter_)

Expand Down
59 changes: 59 additions & 0 deletions corehq/apps/dump_reload/tests/test_sql_filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from datetime import datetime

from django.test import TestCase

from corehq.apps.dump_reload.sql.filters import CaseIDFilter
from corehq.form_processor.models.cases import CaseTransaction
from corehq.form_processor.tests.utils import create_case
from corehq.sql_db.util import get_db_aliases_for_partitioned_query


class TestCaseIDFilter(TestCase):
    """
    Given this is used in the context of dumping all data associated with a domain,
    it is important that all cases for a domain are included in this filter's
    get_ids method.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # One partition alias is enough: each test's case data is created on
        # the partition create_case picks, and get_ids is queried per-alias.
        cls.db_alias = get_db_aliases_for_partitioned_query()[0]

    def test_init_raises_exception_if_used_with_model_that_does_not_foreign_key_to_case(self):
        with self.assertRaises(ValueError):
            CaseIDFilter('form_processor.XFormInstance')

    def test_returns_case_ids_for_domain(self):
        create_case('test', case_id='abc123', save=True)
        # `case_filter` rather than `filter`, which shadows the builtin
        case_filter = CaseIDFilter('form_processor.CaseTransaction')
        case_ids = list(case_filter.get_ids('test', self.db_alias))
        self.assertEqual(case_ids, ['abc123'])

    def test_does_not_return_case_ids_from_other_domain(self):
        create_case('test', case_id='abc123', save=True)
        case_filter = CaseIDFilter('form_processor.CaseTransaction')
        case_ids = list(case_filter.get_ids('other', self.db_alias))
        self.assertEqual(case_ids, [])

    def test_deleted_case_ids_are_included(self):
        create_case('test', case_id='abc123', save=True)
        create_case('test', case_id='def456', save=True, deleted=True)
        case_filter = CaseIDFilter('form_processor.CaseTransaction')
        case_ids = list(case_filter.get_ids('test', self.db_alias))
        self.assertCountEqual(case_ids, ['abc123', 'def456'])

    def test_count_correctly_counts_all_objects_related_to_case_id(self):
        # create_case(save=True) creates one transaction; adding another
        # manually means the filter should count 2 rows total.
        case1 = create_case('test', case_id='abc123', save=True)
        CaseTransaction.objects.partitioned_query(case1.case_id).create(
            case=case1, server_date=datetime.utcnow(), type=1
        )
        case_filter = CaseIDFilter('form_processor.CaseTransaction')
        count = case_filter.count('test')
        self.assertEqual(count, 2)

    def test_count_includes_deleted_cases(self):
        create_case('test', case_id='abc123', save=True)
        create_case('test', case_id='def456', save=True, deleted=True)
        case_filter = CaseIDFilter('form_processor.CaseTransaction')
        count = case_filter.count('test')
        self.assertEqual(count, 2)
26 changes: 19 additions & 7 deletions corehq/util/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,21 +99,33 @@ def paginated_queryset(queryset, chunk_size):
return


def queryset_to_iterator(queryset, model_cls, limit=500, ignore_ordering=False, paginate_by=None):
    """
    Pull from queryset in chunks. This is suitable for deep pagination, but
    cannot be used with ordered querysets (results will be sorted by pk).

    :param paginate_by: optional dictionary of {field: conditional, ...}
        (e.g. {'username': 'gt'}) specifying which fields pagination should
        key off of and how. The order of keys matters, since it dictates the
        sort order. Defaults to paginating by primary key.

        NOTE(review): with multiple fields (e.g. {'case_id': 'gte', 'pk': 'gt'})
        the filter requires ALL conditions at once, which is not a true
        lexicographic compound-key comparison — rows with a greater first field
        but a smaller second field would be skipped. Confirm the intended
        callers' data makes this safe.
    """
    if queryset.ordered and not ignore_ordering:
        raise AssertionError("queryset_to_iterator does not respect ordering. "
                             "Pass ignore_ordering=True to continue.")

    # Default is None (not a mutable `{}` default, which would be shared
    # across calls); fall back to pk-based pagination.
    if not paginate_by:
        pk_field = model_cls._meta.pk.name
        paginate_by = {pk_field: "gt"}

    queryset = queryset.order_by(*paginate_by)
    docs = queryset[:limit]
    while docs:
        yield from docs

        if len(docs) < limit:
            # The last page held fewer rows than requested, so there is
            # nothing left to fetch — skip the extra empty-page query.
            break

        # Querysets don't support negative indexing; `docs` is already
        # evaluated by len() above, so positive indexing is cheap.
        last_doc = docs[len(docs) - 1]
        last_doc_values = {}
        for field, condition in paginate_by.items():
            last_doc_values[f"{field}__{condition}"] = getattr(last_doc, field)

        docs = queryset.filter(**last_doc_values)[:limit]
91 changes: 64 additions & 27 deletions corehq/util/tests/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,45 +5,82 @@


class TestQuerysetToIterator(TestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.users = [
User.objects.create_user(f'user{i}@example.com', last_name="Tenenbaum")
for i in range(1, 11)
]

@classmethod
def tearDownClass(cls):
for user in cls.users:
user.delete()
super().tearDownClass()
def test_correct_results_are_returned(self):
query = User.objects.filter(last_name="Tenenbaum")

results = list(queryset_to_iterator(query, User, limit=10))

self.assertEqual(
[u.username for u in results],
[u.username for u in self.users],
)

def test_results_returned_in_one_query_if_limit_is_greater_than_result_size(self):
query = User.objects.filter(last_name="Tenenbaum")

with self.assertNumQueries(1):
# query 1: Users 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
results = list(queryset_to_iterator(query, User, limit=11))

self.assertEqual(len(results), 10)

def test_results_returned_in_two_queries_if_limit_is_equal_to_result_size(self):
query = User.objects.filter(last_name="Tenenbaum")

def test_queryset_to_iterator(self):
with self.assertNumQueries(2):
# query 1: Users 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
# query 2: Check that there are no users past #10
results = list(queryset_to_iterator(query, User, limit=10))

self.assertEqual(len(results), 10)

def test_results_return_in_three_queries_if_limit_is_less_than_or_equal_to_half_of_result_size(self):
query = User.objects.filter(last_name="Tenenbaum")
self.assertEqual(query.count(), 10)

with self.assertNumQueries(4):
with self.assertNumQueries(3):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@esoergel given this behavior was defined in tests as well, I wanted to clarify whether it was intentional. Basically, is there a concern that breaking the pagination loop when the number of docs returned for a page is less than the limit could lead to prematurely exiting pagination? This maps to the addition of:

        if doc_count < limit:
            break

in queryset_to_iterator.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(also I assumed you added this based on commit history, but I admittedly did not look at it very hard so if you don't have context, totally fine)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I think I added this - that does seem reasonable, not sure why I didn't do that in the first place. Reviewing this now, my initial thought was that the dataset can change while the query is executing, which could cause weirdness, but the limit is part of that last query, so the number of results returned should be a valid data point in determining whether there are any remaining.

# query 1: Users 1, 2, 3, 4
# query 2: Users 5, 6, 7, 8
# query 3: Users 9, 10
# query 4: Check that there are no users past #10
all_users = list(queryset_to_iterator(query, User, limit=4))
results = list(queryset_to_iterator(query, User, limit=4))

self.assertEqual(
[u.username for u in all_users],
[u.username for u in self.users],
)
self.assertEqual(len(results), 10)

def test_ordered_queryset(self):
def test_ordered_queryset_raises_assertion_error_when_ignore_ordering_is_false(self):
query = User.objects.filter(last_name="Tenenbaum").order_by('username')

with self.assertRaises(AssertionError):
# ignore_ordering defaults to False
list(queryset_to_iterator(query, User, limit=4))

def test_ordered_queryset_ignored(self):
def test_ordered_queryset_does_not_raise_assertion_error_when_ignore_ordering_is_true(self):
query = User.objects.filter(last_name="Tenenbaum").order_by('username')
all_users = list(queryset_to_iterator(query, User, limit=4, ignore_ordering=True))
# test succeeds is AssertionError is not raised
list(queryset_to_iterator(query, User, limit=4, ignore_ordering=True))

def test_results_ordered_by_pagination_key_when_paginate_by_is_defined(self):
query = User.objects.filter(last_name="Tenenbaum")

results = list(queryset_to_iterator(query, User, limit=4, paginate_by={"username": "gt"}))

self.assertEqual(
[u.username for u in all_users],
[u.username for u in self.users],
)
[u.username for u in results],
['alice-user4@example.com',
'alice-user8@example.com',
'bob-user1@example.com',
'bob-user5@example.com',
'bob-user9@example.com',
'jane-user3@example.com',
'jane-user7@example.com',
'john-user10@example.com',
'john-user2@example.com',
'john-user6@example.com'])

@classmethod
def setUpClass(cls):
super().setUpClass()
first_names = ['alice', 'bob', 'john', 'jane']
cls.users = [
User.objects.create_user(f'{first_names[i % 4]}-user{i}@example.com', last_name="Tenenbaum")
for i in range(1, 11)
]
Loading