diff --git a/corehq/apps/dump_reload/sql/filters.py b/corehq/apps/dump_reload/sql/filters.py index 165a3ca2065b..68b66abe0383 100644 --- a/corehq/apps/dump_reload/sql/filters.py +++ b/corehq/apps/dump_reload/sql/filters.py @@ -12,6 +12,9 @@ def get_filters(self, domain_name): of the others.""" raise NotImplementedError() + def count(self, domain_name): + return None + class SimpleFilter(DomainFilter): def __init__(self, filter_kwarg): @@ -40,6 +43,9 @@ class UsernameFilter(DomainFilter): def __init__(self, usernames=None): self.usernames = usernames + def count(self, domain_name): + return len(self.usernames) if self.usernames is not None else None + def get_filters(self, domain_name): """ :return: A generator of filters each filtering for at most 500 users. @@ -61,6 +67,9 @@ def __init__(self, field, ids): self.field = field self.ids = ids + def count(self, domain_name): + return len(self.get_ids(domain_name)) + def get_ids(self, domain_name): return self.ids @@ -97,7 +106,10 @@ def _base_queryset(self): return objects.using(self.db_alias).order_by(self.model_class._meta.pk.name) def querysets(self): - return self._base_queryset() + yield self._base_queryset() + + def count(self): + return sum(q.count() for q in self.querysets()) def iterators(self): for queryset in self.querysets(): @@ -115,11 +127,17 @@ def __init__(self, model_label, filter): def build(self, domain, model_class, db_alias): return self.__class__(self.model_label, self.filter).prepare(domain, model_class, db_alias) + def count(self): + count = self.filter.count(self.domain) + if count is not None: + return count + return super(FilteredModelIteratorBuilder, self).count() + def querysets(self): queryset = self._base_queryset() filters = self.filter.get_filters(self.domain) - for filter in filters: - yield queryset.filter(filter) + for filter_ in filters: + yield queryset.filter(filter_) class UniqueFilteredModelIteratorBuilder(FilteredModelIteratorBuilder):