Skip to content

Commit

Permalink
feat: added unit test with coverage of 68% (#611)
Browse files Browse the repository at this point in the history
Add unit tests for many spanner_django modules that add coverage beyond the built-in django tests.
  • Loading branch information
vi3k6i5 authored May 17, 2021
1 parent 3fa1aeb commit 92ad508
Show file tree
Hide file tree
Showing 14 changed files with 1,057 additions and 103 deletions.
9 changes: 7 additions & 2 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,12 @@ def lint_setup_py(session):
def default(session):
# Install all test dependencies, then install this package in-place.
session.install(
"django~=2.2", "mock", "mock-import", "pytest", "pytest-cov"
"django~=2.2",
"mock",
"mock-import",
"pytest",
"pytest-cov",
"coverage",
)
session.install("-e", ".")

Expand All @@ -79,7 +84,7 @@ def default(session):
"--cov-append",
"--cov-config=.coveragerc",
"--cov-report=",
"--cov-fail-under=20",
"--cov-fail-under=68",
os.path.join("tests", "unit"),
*session.posargs
)
Expand Down
13 changes: 13 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import os
import django
from django.conf import settings

# We manually designate which settings we will be using in an environment
# variable. This is similar to what occurs in the `manage.py` file.
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "tests.settings")


# `pytest` automatically calls this function once when tests are run.
def pytest_configure():
settings.DEBUG = False
django.setup()
46 changes: 46 additions & 0 deletions tests/settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright 2021 Google LLC
#
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd

DEBUG = True
USE_TZ = True

INSTALLED_APPS = [
"django_spanner", # Must be the first entry
"django.contrib.contenttypes",
"django.contrib.auth",
"django.contrib.sites",
"django.contrib.sessions",
"django.contrib.messages",
"django.contrib.staticfiles",
"tests",
]

TIME_ZONE = "UTC"

DATABASES = {
"default": {
"ENGINE": "django_spanner",
"PROJECT": "emulator-local",
"INSTANCE": "django-test-instance",
"NAME": "django-test-db",
}
}
SECRET_KEY = "spanner emulator secret key"

PASSWORD_HASHERS = [
"django.contrib.auth.hashers.MD5PasswordHasher",
]

SITE_ID = 1

CONN_MAX_AGE = 60

ENGINE = "django_spanner"
PROJECT = "emulator-local"
INSTANCE = "django-test-instance"
NAME = "django-test-db"
OPTIONS = {}
AUTOCOMMIT = True
Empty file.
61 changes: 61 additions & 0 deletions tests/unit/django_spanner/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright 2021 Google LLC
#
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd
"""
Different models used for testing django-spanner code.
"""
from django.db import models


# Register transformations for model fields.
class UpperCase(models.Transform):
lookup_name = "upper"
function = "UPPER"
bilateral = True


models.CharField.register_lookup(UpperCase)
models.TextField.register_lookup(UpperCase)


# Models
class ModelDecimalField(models.Model):
field = models.DecimalField()


class ModelCharField(models.Model):
field = models.CharField()


class Item(models.Model):
item_id = models.IntegerField()
name = models.CharField(max_length=10)
created = models.DateTimeField()
modified = models.DateTimeField(blank=True, null=True)

class Meta:
ordering = ["name"]


class Number(models.Model):
num = models.IntegerField()
decimal_num = models.DecimalField(max_digits=5, decimal_places=2)
item = models.ForeignKey(Item, models.CASCADE)


class Author(models.Model):
name = models.CharField(max_length=40)
last_name = models.CharField(max_length=40)
num = models.IntegerField(unique=True)
created = models.DateTimeField()
modified = models.DateTimeField(blank=True, null=True)


class Report(models.Model):
name = models.CharField(max_length=10)
creator = models.ForeignKey(Author, models.CASCADE, null=True)

class Meta:
ordering = ["name"]
33 changes: 33 additions & 0 deletions tests/unit/django_spanner/simple_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright 2021 Google LLC
#
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd

from django_spanner.client import DatabaseClient
from django_spanner.base import DatabaseWrapper
from django_spanner.operations import DatabaseOperations
from unittest import TestCase
import os


class SpannerSimpleTestClass(TestCase):
@classmethod
def setUpClass(cls):
cls.PROJECT = os.environ["GOOGLE_CLOUD_PROJECT"]

cls.INSTANCE_ID = "instance_id"
cls.DATABASE_ID = "database_id"
cls.USER_AGENT = "django_spanner/2.2.0a1"
cls.OPTIONS = {"option": "dummy"}

cls.settings_dict = {
"PROJECT": cls.PROJECT,
"INSTANCE": cls.INSTANCE_ID,
"NAME": cls.DATABASE_ID,
"user_agent": cls.USER_AGENT,
"OPTIONS": cls.OPTIONS,
}
cls.db_client = DatabaseClient(cls.settings_dict)
cls.db_wrapper = cls.connection = DatabaseWrapper(cls.settings_dict)
cls.db_operations = DatabaseOperations(cls.connection)
97 changes: 29 additions & 68 deletions tests/unit/django_spanner/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,59 +4,24 @@
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd

import sys
import unittest
import os

from mock_import import mock_import
from unittest import mock
from tests.unit.django_spanner.simple_test import SpannerSimpleTestClass


@mock_import()
@unittest.skipIf(
sys.version_info < (3, 6), reason="Skipping Python versions <= 3.5"
)
class TestBase(unittest.TestCase):
PROJECT = os.environ["GOOGLE_CLOUD_PROJECT"]
INSTANCE_ID = "instance_id"
DATABASE_ID = "database_id"
USER_AGENT = "django_spanner/2.2.0a1"
OPTIONS = {"option": "dummy"}

settings_dict = {
"PROJECT": PROJECT,
"INSTANCE": INSTANCE_ID,
"NAME": DATABASE_ID,
"user_agent": USER_AGENT,
"OPTIONS": OPTIONS,
}

def _get_target_class(self):
from django_spanner.base import DatabaseWrapper

return DatabaseWrapper

def _make_one(self, *args, **kwargs):
return self._get_target_class()(*args, **kwargs)

class TestBase(SpannerSimpleTestClass):
def test_property_instance(self):
settings_dict = {"INSTANCE": "instance"}
db_wrapper = self._make_one(settings_dict=settings_dict)

with mock.patch("django_spanner.base.spanner") as mock_spanner:
mock_spanner.Client = mock_client = mock.MagicMock()
mock_client().instance = mock_instance = mock.MagicMock()
_ = db_wrapper.instance
mock_instance.assert_called_once_with(settings_dict["INSTANCE"])
_ = self.db_wrapper.instance
mock_instance.assert_called_once_with(self.INSTANCE_ID)

def test_property__nodb_connection(self):
db_wrapper = self._make_one(None)
def test_property_nodb_connection(self):
with self.assertRaises(NotImplementedError):
db_wrapper._nodb_connection()
self.db_wrapper._nodb_connection()

def test_get_connection_params(self):
db_wrapper = self._make_one(self.settings_dict)
params = db_wrapper.get_connection_params()
params = self.db_wrapper.get_connection_params()

self.assertEqual(params["project"], self.PROJECT)
self.assertEqual(params["instance_id"], self.INSTANCE_ID)
Expand All @@ -65,54 +30,50 @@ def test_get_connection_params(self):
self.assertEqual(params["option"], self.OPTIONS["option"])

def test_get_new_connection(self):
db_wrapper = self._make_one(self.settings_dict)
db_wrapper.Database = mock_database = mock.MagicMock()
self.db_wrapper.Database = mock_database = mock.MagicMock()
mock_database.connect = mock_connection = mock.MagicMock()
conn_params = {"test_param": "dummy"}
db_wrapper.get_new_connection(conn_params)
self.db_wrapper.get_new_connection(conn_params)
mock_connection.assert_called_once_with(**conn_params)

def test_init_connection_state(self):
db_wrapper = self._make_one(self.settings_dict)
db_wrapper.connection = mock_connection = mock.MagicMock()
self.db_wrapper.connection = mock_connection = mock.MagicMock()
mock_connection.close = mock_close = mock.MagicMock()
db_wrapper.init_connection_state()
self.db_wrapper.init_connection_state()
mock_close.assert_called_once_with()

def test_create_cursor(self):
db_wrapper = self._make_one(self.settings_dict)
db_wrapper.connection = mock_connection = mock.MagicMock()
self.db_wrapper.connection = mock_connection = mock.MagicMock()
mock_connection.cursor = mock_cursor = mock.MagicMock()
db_wrapper.create_cursor()
self.db_wrapper.create_cursor()
mock_cursor.assert_called_once_with()

def test__set_autocommit(self):
db_wrapper = self._make_one(self.settings_dict)
db_wrapper.connection = mock_connection = mock.MagicMock()
def test_set_autocommit(self):
self.db_wrapper.connection = mock_connection = mock.MagicMock()
mock_connection.autocommit = False
db_wrapper._set_autocommit(True)
self.db_wrapper._set_autocommit(True)
self.assertEqual(mock_connection.autocommit, True)

def test_is_usable(self):
from google.cloud.spanner_dbapi.exceptions import Error

db_wrapper = self._make_one(self.settings_dict)
db_wrapper.connection = None
self.assertFalse(db_wrapper.is_usable())
self.db_wrapper.connection = None
self.assertFalse(self.db_wrapper.is_usable())

db_wrapper.connection = mock_connection = mock.MagicMock()
self.db_wrapper.connection = mock_connection = mock.MagicMock()
mock_connection.is_closed = True
self.assertFalse(db_wrapper.is_usable())
self.assertFalse(self.db_wrapper.is_usable())

mock_connection.is_closed = False
self.assertTrue(db_wrapper.is_usable())
self.assertTrue(self.db_wrapper.is_usable())

def test_is_usable_with_error(self):
from google.cloud.spanner_dbapi.exceptions import Error

self.db_wrapper.connection = mock_connection = mock.MagicMock()
mock_connection.cursor = mock.MagicMock(side_effect=Error)
self.assertFalse(db_wrapper.is_usable())
self.assertFalse(self.db_wrapper.is_usable())

def test__start_transaction_under_autocommit(self):
db_wrapper = self._make_one(self.settings_dict)
db_wrapper.connection = mock_connection = mock.MagicMock()
def test_start_transaction_under_autocommit(self):
self.db_wrapper.connection = mock_connection = mock.MagicMock()
mock_connection.cursor = mock_cursor = mock.MagicMock()
db_wrapper._start_transaction_under_autocommit()
self.db_wrapper._start_transaction_under_autocommit()
mock_cursor.assert_called_once_with()
37 changes: 4 additions & 33 deletions tests/unit/django_spanner/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,41 +4,12 @@
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd

import sys
import unittest
import os

from google.cloud.spanner_dbapi.exceptions import NotSupportedError
from tests.unit.django_spanner.simple_test import SpannerSimpleTestClass

@unittest.skipIf(
sys.version_info < (3, 6), reason="Skipping Python versions <= 3.5"
)
class TestClient(unittest.TestCase):
PROJECT = os.environ["GOOGLE_CLOUD_PROJECT"]
INSTANCE_ID = "instance_id"
DATABASE_ID = "database_id"
USER_AGENT = "django_spanner/2.2.0a1"
OPTIONS = {"option": "dummy"}

settings_dict = {
"PROJECT": PROJECT,
"INSTANCE": INSTANCE_ID,
"NAME": DATABASE_ID,
"user_agent": USER_AGENT,
"OPTIONS": OPTIONS,
}

def _get_target_class(self):
from django_spanner.client import DatabaseClient

return DatabaseClient

def _make_one(self, *args, **kwargs):
return self._get_target_class()(*args, **kwargs)

class TestClient(SpannerSimpleTestClass):
def test_runshell(self):
from google.cloud.spanner_dbapi.exceptions import NotSupportedError

db_wrapper = self._make_one(self.settings_dict)

with self.assertRaises(NotSupportedError):
db_wrapper.runshell(parameters=self.settings_dict)
self.db_client.runshell(parameters=self.settings_dict)
Loading

0 comments on commit 92ad508

Please sign in to comment.