Skip to content

Commit

Permalink
Fix aggregate queries with case expressions (#354)
Browse files Browse the repository at this point in the history
* Fix aggregate queries with case expressions
  • Loading branch information
dauinsight authored Mar 11, 2024
1 parent 0899188 commit adc01a5
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 5 deletions.
55 changes: 55 additions & 0 deletions mssql/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import time
import struct
import datetime
from decimal import Decimal
from uuid import UUID

from django.core.exceptions import ImproperlyConfigured
from django.utils.functional import cached_property
Expand Down Expand Up @@ -571,6 +573,36 @@ def __init__(self, cursor, connection):
self.last_sql = ''
self.last_params = ()

def _as_sql_type(self, typ, value):
if isinstance(value, str):
length = len(value)
if length == 0:
return 'NVARCHAR'
elif length > 4000:
return 'NVARCHAR(max)'
return 'NVARCHAR(%s)' % len(value)
elif typ == int:
if value < 0x7FFFFFFF and value > -0x7FFFFFFF:
return 'INT'
else:
return 'BIGINT'
elif typ == float:
return 'DOUBLE PRECISION'
elif typ == bool:
return 'BIT'
elif isinstance(value, Decimal):
return 'NUMERIC'
elif isinstance(value, datetime.datetime):
return 'DATETIME2'
elif isinstance(value, datetime.date):
return 'DATE'
elif isinstance(value, datetime.time):
return 'TIME'
elif isinstance(value, UUID):
return 'uniqueidentifier'
else:
raise NotImplementedError('Not supported type %s (%s)' % (type(value), repr(value)))

def close(self):
if self.active:
self.active = False
Expand All @@ -588,6 +620,27 @@ def format_sql(self, sql, params):

return sql

def format_group_by_params(self, query, params):
if params:
# Insert None params directly into the query
if None in params:
null_params = ['NULL' if param is None else '%s' for param in params]
query = query % tuple(null_params)
params = tuple(p for p in params if p is not None)
params = [(param, type(param)) for param in params]
params_dict = {param: '@var%d' % i for i, param in enumerate(set(params))}
args = [params_dict[param] for param in params]

variables = []
params = []
for key, value in params_dict.items():
datatype = self._as_sql_type(key[1], key[0])
variables.append("%s %s = %%s " % (value, datatype))
params.append(key[0])
query = ('DECLARE %s \n' % ','.join(variables)) + (query % tuple(args))

return query, params

def format_params(self, params):
fp = []
if params is not None:
Expand Down Expand Up @@ -616,6 +669,8 @@ def format_params(self, params):

def execute(self, sql, params=None):
self.last_sql = sql
if 'GROUP BY' in sql:
sql, params = self.format_group_by_params(sql, params)
sql = self.format_sql(sql, params)
params = self.format_params(params)
self.last_params = params
Expand Down
2 changes: 0 additions & 2 deletions testapp/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,8 +284,6 @@
'aggregation.test_filter_argument.FilteredAggregateTests.test_filtered_aggregate_ref_subquery_annotation',
'aggregation.tests.AggregateAnnotationPruningTests.test_referenced_group_by_annotation_kept',
'aggregation.tests.AggregateAnnotationPruningTests.test_referenced_window_requires_wrapping',
'aggregation.tests.AggregateAnnotationPruningTests.test_unused_aliased_aggregate_and_annotation_reverse_fk',
'aggregation.tests.AggregateAnnotationPruningTests.test_unused_aliased_aggregate_and_annotation_reverse_fk_grouped',
'aggregation.tests.AggregateTestCase.test_group_by_nested_expression_with_params',
'expressions.tests.BasicExpressionsTests.test_aggregate_subquery_annotation',
'queries.test_qs_combinators.QuerySetSetOperationTests.test_union_order_with_null_first_last',
Expand Down
15 changes: 12 additions & 3 deletions testapp/tests/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
from unittest import skipUnless

from django import VERSION
from django.db.models import IntegerField, F
from django.db.models import CharField, IntegerField, F
from django.db.models.expressions import Case, Exists, OuterRef, Subquery, Value, When, ExpressionWrapper
from django.test import TestCase, skipUnlessDBFeature

from django.db.models.aggregates import Count
from ..models import Author, Comment, Post, Editor, ModelWithNullableFieldsOfDifferentTypes
from django.db.models.aggregates import Count, Sum

from ..models import Author, Book, Comment, Post, Editor, ModelWithNullableFieldsOfDifferentTypes


DJANGO3 = VERSION[0] >= 3
Expand Down Expand Up @@ -85,6 +86,14 @@ def test_order_by_exists(self):
self.assertSequenceEqual(authors_by_posts, [author_without_posts, self.author])


class TestGroupBy(TestCase):
def test_group_by_case(self):
annotated_queryset = Book.objects.annotate(age=Case(
When(id__gt=1000, then=Value("new")),
default=Value("old"),
output_field=CharField())).values('age').annotate(sum=Sum('id'))
self.assertEqual(list(annotated_queryset.all()), [])

@skipUnless(DJANGO3, "Django 3 specific tests")
@skipUnlessDBFeature("order_by_nulls_first")
class TestOrderBy(TestCase):
Expand Down

1 comment on commit adc01a5

@henrikek
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new function format_group_by_params in this commit creates problems. If the SQL query contains the percent sign that should remain, for example in a LIKE SQL query, there will be a problem on line 619 in the format_sql function as the escaped percent sign has already been handled in line 640 in the format_group_by_params function. So you get an exception TypeError: not enough arguments for format string. So the solution should be to somehow merge the functions format_group_by_params and format_sql as they destroy each other's SQL code. I have tried to find a good solution for a code change but have not succeeded.

Please sign in to comment.