Skip to content

Commit

Permalink
Fix: Check database errors properly.
Browse files Browse the repository at this point in the history
  - Have sane testing of object states
  • Loading branch information
terjekv committed Jun 19, 2023
1 parent 3e05e80 commit 034e595
Showing 1 changed file with 20 additions and 9 deletions.
29 changes: 20 additions & 9 deletions mreg/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from itertools import combinations
from unittest import mock

from copy import deepcopy

from unittest import mock

from django.db import DatabaseError
Expand Down Expand Up @@ -299,18 +301,27 @@ def setUp(self):
@mock.patch('django.db.models.Model.save')
def test_update_serialno_handles_database_error(self, mock_save):
mock_save.side_effect = DatabaseError
zone_sample = deepcopy(self.zone_sample)

# This will raise a DatabaseError, which we will catch and ignore.
# During the course of this update provcess, the serial number will be
# incremented and the updated_at timestamp will be updated.
self.zone_sample.update_serialno(force=True)


# The save() method should have been called exactly once.
mock_save.assert_called_once()

# Assert that the zone's serial number has not changed
self.assertEqual(self.zone_sample.serialno, self.zone_sample.serialno)

# Assert that updated is still False
self.assertEqual(self.zone_sample.updated, False)

# Assert that serialno_updated_at has not changed
self.assertEqual(self.zone_sample.serialno_updated_at, self.zone_sample.serialno_updated_at)
# The serial number will be incremented during the run, before we try to save
# the object. Checking this increment verifies that we are running through
# the code to get ready to save.
self.assertEqual(zone_sample.serialno + 1, self.zone_sample.serialno)

# Refetch the object from the database.
zone_sample_db = ForwardZone.objects.get(name='example.org')

# Check that the serial number in the database is the same as the one we originally had.
self.assertEqual(zone_sample.serialno, zone_sample_db.serialno)
self.assertEqual(zone_sample.serialno_updated_at, zone_sample_db.serialno_updated_at)

def test_model_can_create_ns(self):
"""Test that the model is able to create an Ns."""
Expand Down

0 comments on commit 034e595

Please sign in to comment.