diff --git a/corehq/apps/dump_reload/sql/dump.py b/corehq/apps/dump_reload/sql/dump.py index 50ae345cc2eea..78b4f76a972ab 100644 --- a/corehq/apps/dump_reload/sql/dump.py +++ b/corehq/apps/dump_reload/sql/dump.py @@ -32,11 +32,13 @@ FilteredModelIteratorBuilder('form_processor.CommCareCase', SimpleFilter('domain')), FilteredModelIteratorBuilder('form_processor.CommCareCaseIndex', SimpleFilter('domain')), - FilteredModelIteratorBuilder('form_processor.CaseAttachment', CaseIDFilter()), - FilteredModelIteratorBuilder('form_processor.CaseTransaction', CaseIDFilter(), {'case_id': 'gte', 'pk': 'gt'}), + FilteredModelIteratorBuilder('form_processor.CaseAttachment', CaseIDFilter('form_processor.CaseAttachment')), + FilteredModelIteratorBuilder('form_processor.CaseTransaction', + CaseIDFilter('form_processor.CaseTransaction'), + {'case_id': 'gte', 'pk': 'gt'}), FilteredModelIteratorBuilder('form_processor.LedgerValue', SimpleFilter('domain')), - FilteredModelIteratorBuilder('form_processor.LedgerTransaction', CaseIDFilter()), - + FilteredModelIteratorBuilder('form_processor.LedgerTransaction', + CaseIDFilter('form_processor.LedgerTransaction')), FilteredModelIteratorBuilder('case_search.DomainsNotInCaseSearchIndex', SimpleFilter('domain')), FilteredModelIteratorBuilder('case_search.CaseSearchConfig', SimpleFilter('domain')), FilteredModelIteratorBuilder('case_search.FuzzyProperties', SimpleFilter('domain')), diff --git a/corehq/apps/dump_reload/sql/filters.py b/corehq/apps/dump_reload/sql/filters.py index 549186127d398..7248ca095edad 100644 --- a/corehq/apps/dump_reload/sql/filters.py +++ b/corehq/apps/dump_reload/sql/filters.py @@ -1,11 +1,16 @@ from abc import ABCMeta, abstractmethod from django.db.models import Q +from django.db.models.fields.related import ForeignKey from dimagi.utils.chunked import chunked +from corehq.apps.dump_reload.util import get_model_class from corehq.form_processor.models.cases import CommCareCase -from corehq.sql_db.util import paginate_query +from corehq.sql_db.util import ( + get_db_aliases_for_partitioned_query, + paginate_query, +) from corehq.util.queries import queryset_to_iterator @@ -95,13 +100,23 @@ def get_ids(self, domain_name, db_alias=None): class CaseIDFilter(IDFilter): - def __init__(self, case_field='case'): + def __init__(self, model_label, case_field='case'): + _, self.model_cls = get_model_class(model_label) + try: + case_field, = [f for f in self.model_cls._meta.fields if f.name == 'case'] + assert isinstance(case_field, ForeignKey) + assert case_field.remote_field.model == CommCareCase + except Exception: + raise ValueError( + "CaseIDFilter only supports models with a foreign key relationship to CommCareCase" + ) super().__init__(case_field, None, chunksize=500) def count(self, domain_name): - active_case_count = len(CommCareCase.objects.get_case_ids_in_domain(domain_name)) - deleted_case_count = len(CommCareCase.objects.get_deleted_case_ids_in_domain(domain_name)) - return active_case_count + deleted_case_count + count = 0 + for db in get_db_aliases_for_partitioned_query(): + count += self.model_cls.objects.using(db).filter(case__domain=domain_name).count() + return count def get_ids(self, domain_name, db_alias=None): assert db_alias, "Expected db_alias to be defined for CaseIDFilter" diff --git a/corehq/apps/dump_reload/tests/test_sql_filters.py b/corehq/apps/dump_reload/tests/test_sql_filters.py index 54a5ea2dd0ee2..780bb1906fa96 100644 --- a/corehq/apps/dump_reload/tests/test_sql_filters.py +++ b/corehq/apps/dump_reload/tests/test_sql_filters.py @@ -1,6 +1,9 @@ +from datetime import datetime + from django.test import TestCase from corehq.apps.dump_reload.sql.filters import CaseIDFilter +from corehq.form_processor.models.cases import CaseTransaction from corehq.form_processor.tests.utils import create_case from corehq.sql_db.util import get_db_aliases_for_partitioned_query @@ -11,29 +14,42 @@ class TestCaseIDFilter(TestCase): that all cases for a a domain are included in this filter's get_ids method. """ - def test_returns_cases_for_domain(self): + def test_init_raises_exception_if_used_with_model_that_does_not_foreign_key_to_case(self): + with self.assertRaises(ValueError): + CaseIDFilter('form_processor.XFormInstance') + + def test_returns_case_ids_for_domain(self): create_case('test', case_id='abc123', save=True) - filter = CaseIDFilter() + filter = CaseIDFilter('form_processor.CaseTransaction') case_ids = list(filter.get_ids('test', self.db_alias)) self.assertEqual(case_ids, ['abc123']) - def test_does_not_return_cases_from_other_domain(self): + def test_does_not_return_case_ids_from_other_domain(self): create_case('test', case_id='abc123', save=True) - filter = CaseIDFilter() + filter = CaseIDFilter('form_processor.CaseTransaction') case_ids = list(filter.get_ids('other', self.db_alias)) self.assertEqual(case_ids, []) - def test_deleted_cases_are_included(self): + def test_deleted_case_ids_are_included(self): create_case('test', case_id='abc123', save=True) create_case('test', case_id='def456', save=True, deleted=True) - filter = CaseIDFilter() + filter = CaseIDFilter('form_processor.CaseTransaction') case_ids = list(filter.get_ids('test', self.db_alias)) self.assertCountEqual(case_ids, ['abc123', 'def456']) + def test_count_correctly_counts_all_objects_related_to_case_id(self): + case1 = create_case('test', case_id='abc123', save=True) + CaseTransaction.objects.partitioned_query(case1.case_id).create( + case=case1, server_date=datetime.utcnow(), type=1 + ) + filter = CaseIDFilter('form_processor.CaseTransaction') + count = filter.count('test') + self.assertEqual(count, 2) + def test_count_includes_deleted_cases(self): create_case('test', case_id='abc123', save=True) create_case('test', case_id='def456', save=True, deleted=True) - filter = CaseIDFilter() + filter = CaseIDFilter('form_processor.CaseTransaction') count = filter.count('test') self.assertEqual(count, 2)