From adc01a58af835f0ba08c69947b081cea0b6049d1 Mon Sep 17 00:00:00 2001 From: dauinsight <145612907+dauinsight@users.noreply.github.com> Date: Mon, 11 Mar 2024 14:33:45 -0700 Subject: [PATCH] Fix aggregate queries with case expressions (#354) * Fix aggregate queries with case expressions --- mssql/base.py | 55 +++++++++++++++++++++++++++++++ testapp/settings.py | 2 -- testapp/tests/test_expressions.py | 15 +++++++-- 3 files changed, 67 insertions(+), 5 deletions(-) diff --git a/mssql/base.py b/mssql/base.py index 033660c..8adc94b 100644 --- a/mssql/base.py +++ b/mssql/base.py @@ -9,6 +9,8 @@ import time import struct import datetime +from decimal import Decimal +from uuid import UUID from django.core.exceptions import ImproperlyConfigured from django.utils.functional import cached_property @@ -571,6 +573,36 @@ def __init__(self, cursor, connection): self.last_sql = '' self.last_params = () + def _as_sql_type(self, typ, value): + if isinstance(value, str): + length = len(value) + if length == 0: + return 'NVARCHAR' + elif length > 4000: + return 'NVARCHAR(max)' + return 'NVARCHAR(%s)' % len(value) + elif typ == int: + if value < 0x7FFFFFFF and value > -0x7FFFFFFF: + return 'INT' + else: + return 'BIGINT' + elif typ == float: + return 'DOUBLE PRECISION' + elif typ == bool: + return 'BIT' + elif isinstance(value, Decimal): + return 'NUMERIC' + elif isinstance(value, datetime.datetime): + return 'DATETIME2' + elif isinstance(value, datetime.date): + return 'DATE' + elif isinstance(value, datetime.time): + return 'TIME' + elif isinstance(value, UUID): + return 'uniqueidentifier' + else: + raise NotImplementedError('Not supported type %s (%s)' % (type(value), repr(value))) + def close(self): if self.active: self.active = False @@ -588,6 +620,27 @@ def format_sql(self, sql, params): return sql + def format_group_by_params(self, query, params): + if params: + # Insert None params directly into the query + if None in params: + null_params = ['NULL' if param is None else '%s' for param in params] + query = query % tuple(null_params) + params = tuple(p for p in params if p is not None) + params = [(param, type(param)) for param in params] + params_dict = {param: '@var%d' % i for i, param in enumerate(set(params))} + args = [params_dict[param] for param in params] + + variables = [] + params = [] + for key, value in params_dict.items(): + datatype = self._as_sql_type(key[1], key[0]) + variables.append("%s %s = %%s " % (value, datatype)) + params.append(key[0]) + query = ('DECLARE %s \n' % ','.join(variables)) + (query % tuple(args)) + + return query, params + def format_params(self, params): fp = [] if params is not None: @@ -616,6 +669,8 @@ def format_params(self, params): def execute(self, sql, params=None): self.last_sql = sql + if 'GROUP BY' in sql: + sql, params = self.format_group_by_params(sql, params) sql = self.format_sql(sql, params) params = self.format_params(params) self.last_params = params diff --git a/testapp/settings.py b/testapp/settings.py index 21d7d2f..2fe5aad 100644 --- a/testapp/settings.py +++ b/testapp/settings.py @@ -284,8 +284,6 @@ 'aggregation.test_filter_argument.FilteredAggregateTests.test_filtered_aggregate_ref_subquery_annotation', 'aggregation.tests.AggregateAnnotationPruningTests.test_referenced_group_by_annotation_kept', 'aggregation.tests.AggregateAnnotationPruningTests.test_referenced_window_requires_wrapping', - 'aggregation.tests.AggregateAnnotationPruningTests.test_unused_aliased_aggregate_and_annotation_reverse_fk', - 'aggregation.tests.AggregateAnnotationPruningTests.test_unused_aliased_aggregate_and_annotation_reverse_fk_grouped', 'aggregation.tests.AggregateTestCase.test_group_by_nested_expression_with_params', 'expressions.tests.BasicExpressionsTests.test_aggregate_subquery_annotation', 'queries.test_qs_combinators.QuerySetSetOperationTests.test_union_order_with_null_first_last', diff --git a/testapp/tests/test_expressions.py b/testapp/tests/test_expressions.py index 2cfeb11..be7794e 100644 --- a/testapp/tests/test_expressions.py +++ b/testapp/tests/test_expressions.py @@ -5,12 +5,13 @@ from unittest import skipUnless from django import VERSION -from django.db.models import IntegerField, F +from django.db.models import CharField, IntegerField, F from django.db.models.expressions import Case, Exists, OuterRef, Subquery, Value, When, ExpressionWrapper from django.test import TestCase, skipUnlessDBFeature -from django.db.models.aggregates import Count -from ..models import Author, Comment, Post, Editor, ModelWithNullableFieldsOfDifferentTypes +from django.db.models.aggregates import Count, Sum + +from ..models import Author, Book, Comment, Post, Editor, ModelWithNullableFieldsOfDifferentTypes DJANGO3 = VERSION[0] >= 3 @@ -85,6 +86,14 @@ def test_order_by_exists(self): self.assertSequenceEqual(authors_by_posts, [author_without_posts, self.author]) +class TestGroupBy(TestCase): + def test_group_by_case(self): + annotated_queryset = Book.objects.annotate(age=Case( + When(id__gt=1000, then=Value("new")), + default=Value("old"), + output_field=CharField())).values('age').annotate(sum=Sum('id')) + self.assertEqual(list(annotated_queryset.all()), []) + @skipUnless(DJANGO3, "Django 3 specific tests") @skipUnlessDBFeature("order_by_nulls_first") class TestOrderBy(TestCase):