Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug where order is wrong after adding objects #56

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 21 additions & 48 deletions sortedm2m/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from django.utils.functional import curry

from .compat import get_foreignkey_field_kwargs
from .compat import get_model
from .compat import get_model_name
from .forms import SortedMultipleChoiceField

Expand Down Expand Up @@ -93,7 +92,7 @@ def _add_items(self, source_field_name, target_field_name, *objs):
# *objs - objects to add. Either object instances, or primary keys of object instances.

# If there aren't any objects, there is nothing to do.
from django.db.models import Model
from django.db.models import Max, Model
if objs:
# Django uses a set here, we need to use a list to keep the
# correct ordering.
Expand Down Expand Up @@ -122,7 +121,8 @@ def _add_items(self, source_field_name, target_field_name, *objs):
new_ids.append(obj)

db = router.db_for_write(self.through, instance=self.instance)
vals = (self.through._default_manager.using(db)
manager = self.through._default_manager.using(db)
vals = (manager
.values_list(target_field_name, flat=True)
.filter(**{
source_field_name: self._fk_val,
Expand All @@ -144,27 +144,23 @@ def _add_items(self, source_field_name, target_field_name, *objs):
signals.m2m_changed.send(sender=rel.through, action='pre_add',
instance=self.instance, reverse=self.reverse,
model=self.model, pk_set=new_ids_set, using=db)

# Add the ones that aren't there already
sort_field_name = self.through._sort_field_name
sort_field = self.through._meta.get_field_by_name(sort_field_name)[0]
if django.VERSION < (1, 6):
for obj_id in new_ids:
self.through._default_manager.using(db).create(**{
'%s_id' % source_field_name: self._fk_val, # Django 1.5 compatibility
'%s_id' % target_field_name: obj_id,
sort_field_name: sort_field.get_default(),
with atomic(using=db):
fk_val = self._fk_val
source_queryset = manager.filter(**{'%s_id' % source_field_name: fk_val})
sort_field_name = self.through._sort_field_name
sort_value_max = source_queryset.aggregate(max=Max(sort_field_name))['max'] or 0

manager.bulk_create([
self.through(**{
'%s_id' % source_field_name: fk_val,
'%s_id' % target_field_name: pk,
sort_field_name: sort_value_max + i + 1,
})
else:
with transaction.atomic():
sort_field_default = sort_field.get_default()
self.through._default_manager.using(db).bulk_create([
self.through(**{
'%s_id' % source_field_name: self._fk_val,
'%s_id' % target_field_name: v,
sort_field_name: sort_field_default + i,
})
for i, v in enumerate(new_ids)
])
for i, pk in enumerate(new_ids)
])

if self.reverse or source_field_name == self.source_field_name:
# Don't send the signal when we are inserting the
# duplicate data row for symmetrical reverse entries.
Expand Down Expand Up @@ -195,9 +191,8 @@ class SortedManyToManyField(ManyToManyField):
'''
def __init__(self, to, sorted=True, **kwargs):
self.sorted = sorted
self.sort_value_field_name = kwargs.pop(
'sort_value_field_name',
SORT_VALUE_FIELD_NAME)
self.sort_value_field_name = kwargs.pop('sort_value_field_name', SORT_VALUE_FIELD_NAME)

super(SortedManyToManyField, self).__init__(to, **kwargs)
if self.sorted:
self.help_text = kwargs.get('help_text', None)
Expand Down Expand Up @@ -316,31 +311,9 @@ def get_rel_to_model_and_object_name(self, klass):
to_object_name = to_model._meta.object_name
return to_model, to_object_name

def get_intermediate_model_sort_value_field_default(self, klass):
def default_sort_value(name):
model = get_model(klass._meta.app_label, name)
# Django 1.5 support.
if django.VERSION < (1, 6):
return model._default_manager.count()
else:
from django.db.utils import ProgrammingError, OperationalError
try:
# We need to catch if the model is not yet migrated in the
# database. The default function is still called in this case while
# running the migration. So we mock the return value of 0.
with transaction.atomic():
return model._default_manager.count()
except (ProgrammingError, OperationalError):
return 0

name = self.get_intermediate_model_name(klass)
default_sort_value = curry(default_sort_value, name)
return default_sort_value

def get_intermediate_model_sort_value_field(self, klass):
default_sort_value = self.get_intermediate_model_sort_value_field_default(klass)
field_name = self.sort_value_field_name
field = models.IntegerField(default=default_sort_value)
field = models.IntegerField(default=0)
return field_name, field

def get_intermediate_model_from_field(self, klass):
Expand Down