Skip to content

Commit

Permalink
feat: add decimal/numeric support
Browse files Browse the repository at this point in the history
  • Loading branch information
vi3k6i5 committed May 10, 2021
1 parent ad8e43e commit 3d014ae
Show file tree
Hide file tree
Showing 11 changed files with 413 additions and 34 deletions.
2 changes: 1 addition & 1 deletion django_spanner/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
"CharField": "STRING(%(max_length)s)",
"DateField": "DATE",
"DateTimeField": "TIMESTAMP",
"DecimalField": "FLOAT64",
"DecimalField": "NUMERIC",
"DurationField": "INT64",
"EmailField": "STRING(%(max_length)s)",
"FileField": "STRING(%(max_length)s)",
Expand Down
1 change: 1 addition & 0 deletions django_spanner/introspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
TypeCode.INT64: "IntegerField",
TypeCode.STRING: "CharField",
TypeCode.TIMESTAMP: "DateTimeField",
TypeCode.NUMERIC: "DecimalField",
}

def get_field_type(self, data_type, description):
Expand Down
8 changes: 1 addition & 7 deletions django_spanner/lookups.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd

from django.db.models import DecimalField
from django.db.models.lookups import (
Contains,
EndsWith,
Expand Down Expand Up @@ -233,13 +232,8 @@ def cast_param_to_float(self, compiler, connection):
"""
sql, params = self.as_sql(compiler, connection)
if params:
# Cast to DecimaField lookup values to float because
# google.cloud.spanner_v1._helpers._make_value_pb() doesn't serialize
# decimal.Decimal.
if isinstance(self.lhs.output_field, DecimalField):
params[0] = float(params[0])
# Cast remote field lookups that must be integer but come in as string.
elif hasattr(self.lhs.output_field, "get_path_info"):
if hasattr(self.lhs.output_field, "get_path_info"):
for i, field in enumerate(
self.lhs.output_field.get_path_info()[-1].target_fields
):
Expand Down
26 changes: 2 additions & 24 deletions django_spanner/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,12 +203,12 @@ def adapt_decimalfield_value(
:param decimal_places: (Optional) The number of decimal places to store
with the number.
:rtype: float
:rtype: Decimal
:returns: Formatted value.
"""
if value is None:
return None
return float(value)
return Decimal(value)

def adapt_timefield_value(self, value):
"""
Expand Down Expand Up @@ -244,8 +244,6 @@ def get_db_converters(self, expression):
internal_type = expression.output_field.get_internal_type()
if internal_type == "DateTimeField":
converters.append(self.convert_datetimefield_value)
elif internal_type == "DecimalField":
converters.append(self.convert_decimalfield_value)
elif internal_type == "TimeField":
converters.append(self.convert_timefield_value)
elif internal_type == "BinaryField":
Expand Down Expand Up @@ -311,26 +309,6 @@ def convert_datetimefield_value(self, value, expression, connection):
else dt
)

def convert_decimalfield_value(self, value, expression, connection):
"""Convert Spanner DecimalField value for Django.
:type value: float
:param value: A decimal field.
:type expression: :class:`django.db.models.expressions.BaseExpression`
:param expression: A query expression.
:type connection: :class:`~google.cloud.cpanner_dbapi.connection.Connection`
:param connection: Reference to a Spanner database connection.
:rtype: :class:`Decimal`
:returns: A converted decimal field.
"""
if value is None:
return value
# Cloud Spanner returns a float.
return Decimal(str(value))

def convert_timefield_value(self, value, expression, connection):
"""Convert Spanner TimeField value for Django.
Expand Down
57 changes: 55 additions & 2 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from __future__ import absolute_import

import os
import pathlib
import shutil

import nox
Expand All @@ -25,7 +26,9 @@

DEFAULT_PYTHON_VERSION = "3.8"
SYSTEM_TEST_PYTHON_VERSIONS = ["3.8"]
UNIT_TEST_PYTHON_VERSIONS = ["3.6", "3.7", "3.8"]
UNIT_TEST_PYTHON_VERSIONS = ["3.6", "3.7", "3.8", "3.9"]

CURRENT_DIRECTORY = pathlib.Path(__file__).parent.absolute()


@nox.session(python=DEFAULT_PYTHON_VERSION)
Expand Down Expand Up @@ -81,7 +84,7 @@ def default(session):
"--cov-report=",
"--cov-fail-under=20",
os.path.join("tests", "unit"),
*session.posargs
*session.posargs,
)


Expand All @@ -91,6 +94,56 @@ def unit(session):
default(session)


@nox.session(python=SYSTEM_TEST_PYTHON_VERSIONS)
def system(session):
"""Run the system test suite."""
constraints_path = str(
CURRENT_DIRECTORY / "testing" / f"constraints-{session.python}.txt"
)
system_test_path = os.path.join("tests", "system.py")
system_test_folder_path = os.path.join("tests", "system")

# Check the value of `RUN_SYSTEM_TESTS` env var. It defaults to true.
if os.environ.get("RUN_SYSTEM_TESTS", "true") == "false":
session.skip("RUN_SYSTEM_TESTS is set to false, skipping")
# Sanity check: Only run tests if the environment variable is set.
if not os.environ.get(
"GOOGLE_APPLICATION_CREDENTIALS", ""
) and not os.environ.get("SPANNER_EMULATOR_HOST", ""):
session.skip(
"Credentials or emulator host must be set via environment variable"
)

system_test_exists = os.path.exists(system_test_path)
system_test_folder_exists = os.path.exists(system_test_folder_path)
# Sanity check: only run tests if found.
if not system_test_exists and not system_test_folder_exists:
session.skip("System tests were not found")

# Use pre-release gRPC for system tests.
session.install("--pre", "grpcio")

# Install all test dependencies, then install this package into the
# virtualenv's dist-packages.
session.install(
"django~=2.2",
"mock",
"pytest",
"google-cloud-testutils",
"-c",
constraints_path,
)
session.install("-e", ".[tracing]", "-c", constraints_path)

# Run py.test against the system tests.
if system_test_exists:
session.run("py.test", "--quiet", system_test_path, *session.posargs)
if system_test_folder_exists:
session.run(
"py.test", "--quiet", system_test_folder_path, *session.posargs
)


@nox.session(python=DEFAULT_PYTHON_VERSION)
def cover(session):
"""Run the final coverage report.
Expand Down
19 changes: 19 additions & 0 deletions tests/system/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright 2021 Google LLC
#
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd

import os
import django
from django.conf import settings

# We manually designate which settings we will be using in an environment
# variable. This is similar to what occurs in the `manage.py` file.
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "tests.system.settings")


# `pytest` automatically calls this function once when tests are run.
def pytest_configure():
settings.DEBUG = False
django.setup()
Empty file.
23 changes: 23 additions & 0 deletions tests/system/django_spanner/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright 2021 Google LLC
#
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd

"""
Different models used by system tests in django-spanner code.
"""
from django.db import models


class Author(models.Model):
first_name = models.CharField(max_length=20)
last_name = models.CharField(max_length=20)
ratting = models.DecimalField()


class Number(models.Model):
num = models.DecimalField()

def __str__(self):
return str(self.num)
117 changes: 117 additions & 0 deletions tests/system/django_spanner/test_decimal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Copyright 2021 Google LLC
#
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd

from .models import Author, Number
from django.test import TransactionTestCase
from django.db import connection, ProgrammingError
from django.db.utils import IntegrityError
from decimal import Decimal
from tests.system.django_spanner.utils import (
setup_instance,
teardown_instance,
setup_database,
teardown_database,
USE_EMULATOR,
)


class TestDecimal(TransactionTestCase):
@classmethod
def setUpClass(cls):
setup_instance()
setup_database()
with connection.schema_editor() as editor:
# Create the tables
editor.create_model(Author)
editor.create_model(Number)

@classmethod
def tearDownClass(cls):
with connection.schema_editor() as editor:
# delete the table
editor.delete_model(Author)
editor.delete_model(Number)
teardown_database()
teardown_instance()

def ratting_transform(self, value):
return value["ratting"]

def values_transform(self, value):
return value.num

def assertValuesEqual(
self, queryset, expected_values, transformer, ordered=True
):
self.assertQuerysetEqual(
queryset, expected_values, transformer, ordered
)

def test_insert_and_search_decimal_value(self):
"""
Tests model object creation with Author model.
"""
author_kent = Author(
first_name="Arthur", last_name="Kent", ratting=Decimal("4.1"),
)
author_kent.save()
qs1 = Author.objects.filter(ratting__gte=3).values("ratting")
self.assertValuesEqual(
qs1, [Decimal("4.1")], self.ratting_transform,
)
# Delete data from Author table.
Author.objects.all().delete()

def test_decimal_filter(self):
"""
Tests decimal filter query.
"""
# Insert data into Number table.
Number.objects.bulk_create(
Number(num=Decimal(i) / Decimal(10)) for i in range(10)
)
qs1 = Number.objects.filter(num__lte=Decimal(2) / Decimal(10))
self.assertValuesEqual(
qs1,
[Decimal(i) / Decimal(10) for i in range(3)],
self.values_transform,
ordered=False,
)
# Delete data from Number table.
Number.objects.all().delete()

def test_decimal_precision_limit(self):
"""
Tests decimal object precission limit.
"""
num_val = Number(num=Decimal(1) / Decimal(3))
if USE_EMULATOR:
msg = "The NUMERIC type supports 38 digits of precision and 9 digits of scale."
with self.assertRaisesRegex(IntegrityError, msg):
num_val.save()
else:
msg = "400 Invalid value for bind parameter a0: Expected NUMERIC."
with self.assertRaisesRegex(ProgrammingError, msg):
num_val.save()

def test_decimal_update(self):
"""
Tests decimal object update.
"""
author_kent = Author(
first_name="Arthur", last_name="Kent", ratting=Decimal("4.1"),
)
author_kent.save()
author_kent.ratting = Decimal("4.2")
author_kent.save()
qs1 = Author.objects.filter(ratting__gte=Decimal("4.2")).values(
"ratting"
)
self.assertValuesEqual(
qs1, [Decimal("4.2")], self.ratting_transform,
)
# Delete data from Author table.
Author.objects.all().delete()
Loading

0 comments on commit 3d014ae

Please sign in to comment.