Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Django 5.0 db_default support #313

Merged
merged 15 commits into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion azure-pipelines.yml
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ jobs:

(Get-Content $pwd/testapp/settings.py).replace('localhost', $IP) | Set-Content $pwd/testapp/settings.py

Invoke-WebRequest https://download.microsoft.com/download/E/6/B/E6BFDC7A-5BCD-4C51-9912-635646DA801E/en-US/17.5.2.1/x64/msodbcsql.msi -OutFile msodbcsql.msi
Invoke-WebRequest https://download.microsoft.com/download/6/f/f/6ffefc73-39ab-4cc0-bb7c-4093d64c2669/en-US/17.10.5.1/x64/msodbcsql.msi -OutFile msodbcsql.msi
msiexec /quiet /passive /qn /i msodbcsql.msi IACCEPTMSODBCSQLLICENSETERMS=YES
Get-OdbcDriver
displayName: Install ODBC
Expand Down
2 changes: 1 addition & 1 deletion mssql/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,7 @@ def format_sql(self, sql, params):
sql = smart_str(sql, self.driver_charset)

# pyodbc uses '?' instead of '%s' as parameter placeholder.
if params is not None:
if params is not None and params != []:
sql = sql % tuple('?' * len(params))

return sql
Expand Down
5 changes: 4 additions & 1 deletion mssql/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
can_introspect_small_integer_field = True
can_return_columns_from_insert = True
can_return_id_from_insert = True
can_return_rows_from_bulk_insert = False
can_return_rows_from_bulk_insert = True
can_rollback_ddl = True
can_use_chunked_reads = False
for_update_after_from = True
Expand Down Expand Up @@ -56,6 +56,9 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_partially_nullable_unique_constraints = True
supports_partial_indexes = True
supports_functions_in_partial_indexes = True
supports_default_keyword_in_insert = True
supports_expression_defaults = True
supports_default_keyword_in_bulk_insert = True

@cached_property
def has_zoneinfo_database(self):
Expand Down
28 changes: 28 additions & 0 deletions mssql/introspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ def get_constraints(self, cursor, table_name):
# Potentially misleading: primary key and unique constraints still have indexes attached to them.
# Should probably be updated with the additional info from the sys.indexes table we fetch later on.
"index": False,
"default": False,
}
# Record the details
constraints[constraint]['columns'].append(column)
Expand Down Expand Up @@ -313,6 +314,32 @@ def get_constraints(self, cursor, table_name):
"foreign_key": None,
"check": True,
"index": False,
"default": False,
}
# Record the details
constraints[constraint]['columns'].append(column)
# Now get DEFAULT constraint columns
cursor.execute(f"""
SELECT d.name AS constraint_name, pc.name AS column_name
FROM sys.default_constraints AS d
INNER JOIN sys.columns pc ON
d.parent_object_id = pc.object_id AND
d.parent_column_id = pc.column_id
WHERE
type_desc = 'DEFAULT_CONSTRAINT'
""")
for constraint, column in cursor.fetchall():
# If we're the first column, make the record
if constraint not in constraints:
constraints[constraint] = {
"columns": [],
"primary_key": False,
"unique": False,
"unique_constraint": False,
"foreign_key": None,
"check": False,
"index": False,
"default": True,
}
# Record the details
constraints[constraint]['columns'].append(column)
Expand Down Expand Up @@ -356,6 +383,7 @@ def get_constraints(self, cursor, table_name):
"unique_constraint": unique_constraint,
"foreign_key": None,
"check": False,
"default": False,
"index": True,
"orders": [],
"type": Index.suffix if type_ in (1, 2) else desc.lower(),
Expand Down
90 changes: 85 additions & 5 deletions mssql/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
Table,
)
from django import VERSION as django_version
from django.db.models import Index, UniqueConstraint
from django.db.models import NOT_PROVIDED, Index, UniqueConstraint
from django.db.models.fields import AutoField, BigAutoField
from django.db.models.sql.where import AND
from django.db.transaction import TransactionManagementError
Expand Down Expand Up @@ -69,6 +69,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
sql_alter_column_type = "ALTER COLUMN %(column)s %(type)s"
sql_create_column = "ALTER TABLE %(table)s ADD %(column)s %(definition)s"
sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s"
sql_delete_default = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
sql_delete_index = "DROP INDEX %(name)s ON %(table)s"
sql_delete_table = """
DECLARE @sql_foreign_constraint_name nvarchar(128)
Expand Down Expand Up @@ -138,6 +139,48 @@ def _alter_column_default_sql(self, model, old_field, new_field, drop=False):
},
params,
)

def _alter_column_database_default_sql(
self, model, old_field, new_field, drop=False
):
"""
Hook to specialize column database default alteration.

Return a (sql, params) fragment to add or drop (depending on the drop
argument) a default to new_field's column.
"""
column = self.quote_name(new_field.column)

if drop:
# SQL Server requires the name of the default constraint
result = self.execute(
self._sql_select_default_constraint_name % {
"table": self.quote_value(model._meta.db_table),
"column": self.quote_value(new_field.column),
},
has_result=True
)
if result:
for row in result:
column = self.quote_name(next(iter(row)))

sql = self.sql_alter_column_no_default
default_sql = ""
params = []
else:
sql = self.sql_alter_column_default
default_sql, params = self.db_default_sql(new_field)

new_db_params = new_field.db_parameters(connection=self.connection)
return (
sql
% {
"column": column,
"type": new_db_params["type"],
"default": default_sql,
},
params,
)

def _alter_column_null_sql(self, model, old_field, new_field):
"""
Expand Down Expand Up @@ -460,6 +503,22 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type,
self._delete_unique_constraints(model, old_field, new_field, strict)
# Drop indexes, SQL Server requires explicit deletion
self._delete_indexes(model, old_field, new_field)
# db_default change?
if django_version >= (5,0):
if new_field.db_default is not NOT_PROVIDED:
if (
old_field.db_default is NOT_PROVIDED
or new_field.db_default != old_field.db_default
):
actions.append(
self._alter_column_database_default_sql(model, old_field, new_field)
)
elif old_field.db_default is not NOT_PROVIDED:
actions.append(
self._alter_column_database_default_sql(
model, old_field, new_field, drop=True
)
)
# When changing a column NULL constraint to NOT NULL with a given
# default value, we need to perform 4 steps:
# 1. Add a default for new incoming writes
Expand All @@ -476,6 +535,8 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type,
new_default is not None and
not self.skip_default(new_field)
)
if django_version >= (5,0):
needs_database_default = needs_database_default and new_field.db_default is NOT_PROVIDED
if needs_database_default:
actions.append(self._alter_column_default_sql(model, old_field, new_field))
# Nullability change?
Expand Down Expand Up @@ -503,7 +564,7 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type,
post_actions.append((create_index_sql_statement, ()))
# Only if we have a default and there is a change from NULL to NOT NULL
four_way_default_alteration = (
new_field.has_default() and
(new_field.has_default() or (django_version >= (5,0) and new_field.db_default is not NOT_PROVIDED)) and
(old_field.null and not new_field.null)
)
if actions or null_actions:
Expand All @@ -525,14 +586,19 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type,
params,
)
if four_way_default_alteration:
if django_version >= (5,0) and new_field.db_default is not NOT_PROVIDED:
default_sql, params = self.db_default_sql(new_field)
else:
default_sql = "%s"
params = [new_default]
# Update existing rows with default value
self.execute(
self.sql_update_with_default % {
"table": self.quote_name(model._meta.db_table),
"column": self.quote_name(new_field.column),
"default": "%s",
"default": default_sql,
},
[new_default],
params,
)
# Since we didn't run a NOT NULL change before we need to do it
# now
Expand Down Expand Up @@ -891,6 +957,9 @@ def add_field(self, model, field):
# It might not actually have a column behind it
if definition is None:
return
# Nullable columns with default values require 'WITH VALUES' to set existing rows
if 'DEFAULT' in definition and field.null:
definition = definition.replace('NULL', 'WITH VALUES')

if (self.connection.features.supports_nullable_unique_constraints and
not field.many_to_many and field.null and field.unique):
Expand All @@ -915,7 +984,11 @@ def add_field(self, model, field):
self.execute(sql, params)
# Drop the default if we need to
# (Django usually does not use in-database defaults)
if not self.skip_default(field) and self.effective_default(field) is not None:
if (
((django_version >= (5,0) and field.db_default is NOT_PROVIDED) or django_version < (5,0))
and not self.skip_default(field)
and self.effective_default(field) is not None
):
changes_sql, params = self._alter_column_default_sql(model, None, field, drop=True)
sql = self.sql_alter_column % {
"table": self.quote_name(model._meta.db_table),
Expand Down Expand Up @@ -1288,6 +1361,13 @@ def remove_field(self, model, field):
"table": self.quote_name(model._meta.db_table),
"name": self.quote_name(name),
})
# Drop default constraint, SQL Server requires explicit deletion
for name, infodict in constraints.items():
if field.column in infodict['columns'] and infodict['default']:
self.execute(self.sql_delete_default % {
"table": self.quote_name(model._meta.db_table),
"name": self.quote_name(name),
})
# Delete the column
sql = self.sql_delete_column % {
"table": self.quote_name(model._meta.db_table),
Expand Down
Loading