Skip to content

Commit

Permalink
Fix issue with count method on CaseIDFilter
Browse files Browse the repository at this point in the history
  • Loading branch information
gherceg committed Jan 26, 2024
1 parent dd5a46e commit 29d55f7
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 16 deletions.
10 changes: 6 additions & 4 deletions corehq/apps/dump_reload/sql/dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@

FilteredModelIteratorBuilder('form_processor.CommCareCase', SimpleFilter('domain')),
FilteredModelIteratorBuilder('form_processor.CommCareCaseIndex', SimpleFilter('domain')),
FilteredModelIteratorBuilder('form_processor.CaseAttachment', CaseIDFilter()),
FilteredModelIteratorBuilder('form_processor.CaseTransaction', CaseIDFilter(), {'case_id': 'gte', 'pk': 'gt'}),
FilteredModelIteratorBuilder('form_processor.CaseAttachment', CaseIDFilter('form_processor.CaseAttachment')),
FilteredModelIteratorBuilder('form_processor.CaseTransaction',
CaseIDFilter('form_processor.CaseTransaction'),
{'case_id': 'gte', 'pk': 'gt'}),
FilteredModelIteratorBuilder('form_processor.LedgerValue', SimpleFilter('domain')),
FilteredModelIteratorBuilder('form_processor.LedgerTransaction', CaseIDFilter()),

FilteredModelIteratorBuilder('form_processor.LedgerTransaction',
CaseIDFilter('form_processor.LedgerTransaction')),
FilteredModelIteratorBuilder('case_search.DomainsNotInCaseSearchIndex', SimpleFilter('domain')),
FilteredModelIteratorBuilder('case_search.CaseSearchConfig', SimpleFilter('domain')),
FilteredModelIteratorBuilder('case_search.FuzzyProperties', SimpleFilter('domain')),
Expand Down
25 changes: 20 additions & 5 deletions corehq/apps/dump_reload/sql/filters.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from abc import ABCMeta, abstractmethod

from django.db.models import Q
from django.db.models.fields.related import ForeignKey

from dimagi.utils.chunked import chunked

from corehq.apps.dump_reload.util import get_model_class
from corehq.form_processor.models.cases import CommCareCase
from corehq.sql_db.util import paginate_query
from corehq.sql_db.util import (
get_db_aliases_for_partitioned_query,
paginate_query,
)
from corehq.util.queries import queryset_to_iterator


Expand Down Expand Up @@ -95,13 +100,23 @@ def get_ids(self, domain_name, db_alias=None):


class CaseIDFilter(IDFilter):
def __init__(self, case_field='case'):
def __init__(self, model_label, case_field='case'):
_, self.model_cls = get_model_class(model_label)
try:
case_field, = [f for f in self.model_cls._meta.fields if f.name == 'case']
assert isinstance(case_field, ForeignKey)
assert case_field.remote_field.model == CommCareCase
except Exception:
raise ValueError(
"CaseIDFilter only supports models with a foreign key relationship to CommCareCase"
)
super().__init__(case_field, None, chunksize=500)

def count(self, domain_name):
active_case_count = len(CommCareCase.objects.get_case_ids_in_domain(domain_name))
deleted_case_count = len(CommCareCase.objects.get_deleted_case_ids_in_domain(domain_name))
return active_case_count + deleted_case_count
count = 0
for db in get_db_aliases_for_partitioned_query():
count += self.model_cls.objects.using(db).filter(case__domain=domain_name).count()
return count

def get_ids(self, domain_name, db_alias=None):
assert db_alias, "Expected db_alias to be defined for CaseIDFilter"
Expand Down
30 changes: 23 additions & 7 deletions corehq/apps/dump_reload/tests/test_sql_filters.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from datetime import datetime

from django.test import TestCase

from corehq.apps.dump_reload.sql.filters import CaseIDFilter
from corehq.form_processor.models.cases import CaseTransaction
from corehq.form_processor.tests.utils import create_case
from corehq.sql_db.util import get_db_aliases_for_partitioned_query

Expand All @@ -11,29 +14,42 @@ class TestCaseIDFilter(TestCase):
that all cases for a a domain are included in this filter's get_ids method.
"""

def test_returns_cases_for_domain(self):
def test_init_raises_exception_if_used_with_model_that_does_not_foreign_key_to_case(self):
with self.assertRaises(ValueError):
CaseIDFilter('form_processor.XFormInstance')

def test_returns_case_ids_for_domain(self):
create_case('test', case_id='abc123', save=True)
filter = CaseIDFilter()
filter = CaseIDFilter('form_processor.CaseTransaction')
case_ids = list(filter.get_ids('test', self.db_alias))
self.assertEqual(case_ids, ['abc123'])

def test_does_not_return_cases_from_other_domain(self):
def test_does_not_return_case_ids_from_other_domain(self):
create_case('test', case_id='abc123', save=True)
filter = CaseIDFilter()
filter = CaseIDFilter('form_processor.CaseTransaction')
case_ids = list(filter.get_ids('other', self.db_alias))
self.assertEqual(case_ids, [])

def test_deleted_cases_are_included(self):
def test_deleted_case_ids_are_included(self):
create_case('test', case_id='abc123', save=True)
create_case('test', case_id='def456', save=True, deleted=True)
filter = CaseIDFilter()
filter = CaseIDFilter('form_processor.CaseTransaction')
case_ids = list(filter.get_ids('test', self.db_alias))
self.assertCountEqual(case_ids, ['abc123', 'def456'])

def test_count_correctly_counts_all_objects_related_to_case_id(self):
case1 = create_case('test', case_id='abc123', save=True)
CaseTransaction.objects.partitioned_query(case1.case_id).create(
case=case1, server_date=datetime.utcnow(), type=1
)
filter = CaseIDFilter('form_processor.CaseTransaction')
count = filter.count('test')
self.assertEqual(count, 2)

def test_count_includes_deleted_cases(self):
create_case('test', case_id='abc123', save=True)
create_case('test', case_id='def456', save=True, deleted=True)
filter = CaseIDFilter()
filter = CaseIDFilter('form_processor.CaseTransaction')
count = filter.count('test')
self.assertEqual(count, 2)

Expand Down

0 comments on commit 29d55f7

Please sign in to comment.