Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support parameterized NUMERIC, BIGNUMERIC, STRING, and BYTES types #673

Merged
merged 16 commits into from
May 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 55 additions & 16 deletions google/cloud/bigquery/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,15 @@ class SchemaField(object):

policy_tags (Optional[PolicyTagList]): The policy tag list for the field.

precision (Optional[int]):
Precision (number of digits) of fields with NUMERIC or BIGNUMERIC type.

scale (Optional[int]):
Scale (digits after decimal) of fields with NUMERIC or BIGNUMERIC type.

max_length (Optional[int]):
Maximum length of fields with STRING or BYTES type.

"""

def __init__(
Expand All @@ -77,6 +86,9 @@ def __init__(
description=_DEFAULT_VALUE,
fields=(),
policy_tags=None,
precision=_DEFAULT_VALUE,
scale=_DEFAULT_VALUE,
max_length=_DEFAULT_VALUE,
):
self._properties = {
"name": name,
Expand All @@ -86,9 +98,22 @@ def __init__(
self._properties["mode"] = mode.upper()
if description is not _DEFAULT_VALUE:
self._properties["description"] = description
if precision is not _DEFAULT_VALUE:
self._properties["precision"] = precision
if scale is not _DEFAULT_VALUE:
self._properties["scale"] = scale
if max_length is not _DEFAULT_VALUE:
self._properties["maxLength"] = max_length
self._fields = tuple(fields)
self._policy_tags = policy_tags

@staticmethod
def __get_int(api_repr, name):
    """Read an optional integer property from an API resource mapping.

    Args:
        api_repr (Mapping[str, Any]): Field resource to read from.
        name (str): Key of the property to look up.

    Returns:
        ``_DEFAULT_VALUE`` when ``name`` is absent from ``api_repr``,
        ``None`` when it is present but explicitly null, otherwise the
        stored value converted with ``int()``.
    """
    v = api_repr.get(name, _DEFAULT_VALUE)
    # Only convert real values: the "unset" sentinel must be handed back
    # unchanged, and ``int(None)`` would raise TypeError for an explicit null.
    if v is not _DEFAULT_VALUE and v is not None:
        v = int(v)
    return v

@classmethod
def from_api_repr(cls, api_repr: dict) -> "SchemaField":
"""Return a ``SchemaField`` object deserialized from a dictionary.
Expand All @@ -113,6 +138,9 @@ def from_api_repr(cls, api_repr: dict) -> "SchemaField":
description=description,
name=api_repr["name"],
policy_tags=PolicyTagList.from_api_repr(api_repr.get("policyTags")),
precision=cls.__get_int(api_repr, "precision"),
scale=cls.__get_int(api_repr, "scale"),
max_length=cls.__get_int(api_repr, "maxLength"),
)

@property
Expand Down Expand Up @@ -148,6 +176,21 @@ def description(self):
"""Optional[str]: description for the field."""
return self._properties.get("description")

@property
def precision(self):
    """Optional[int]: Number of digits of a NUMERIC or BIGNUMERIC field.

    ``None`` when no ``"precision"`` entry is stored for this field.
    """
    properties = self._properties
    return properties.get("precision")

@property
def scale(self):
    """Optional[int]: Digits after the decimal point of a NUMERIC or
    BIGNUMERIC field.

    ``None`` when no ``"scale"`` entry is stored for this field.
    """
    properties = self._properties
    return properties.get("scale")

@property
def max_length(self):
    """Optional[int]: Maximum length of a STRING or BYTES field.

    Read from the ``"maxLength"`` entry of the field's properties; ``None``
    when that entry is not stored.
    """
    properties = self._properties
    return properties.get("maxLength")

@property
def fields(self):
"""Optional[tuple]: Subfields contained in this field.
Expand Down Expand Up @@ -191,9 +234,19 @@ def _key(self):
Returns:
Tuple: The contents of this :class:`~google.cloud.bigquery.schema.SchemaField`.
"""
field_type = self.field_type.upper()
if field_type == "STRING" or field_type == "BYTES":
if self.max_length is not None:
field_type = f"{field_type}({self.max_length})"
elif field_type.endswith("NUMERIC"):
if self.precision is not None:
if self.scale is not None:
field_type = f"{field_type}({self.precision}, {self.scale})"
else:
field_type = f"{field_type}({self.precision})"
return (
self.name,
self.field_type.upper(),
field_type,
# Mode is always str, if not given it defaults to a str value
self.mode.upper(), # pytype: disable=attribute-error
self.description,
Expand Down Expand Up @@ -269,21 +322,7 @@ def _parse_schema_resource(info):
Sequence[google.cloud.bigquery.schema.SchemaField]:
A list of parsed fields; the list is empty if no "fields" key is found.
"""
if "fields" not in info:
return ()

schema = []
for r_field in info["fields"]:
name = r_field["name"]
field_type = r_field["type"]
mode = r_field.get("mode", "NULLABLE")
description = r_field.get("description")
sub_fields = _parse_schema_resource(r_field)
policy_tags = PolicyTagList.from_api_repr(r_field.get("policyTags"))
schema.append(
SchemaField(name, field_type, mode, description, sub_fields, policy_tags)
)
return schema
return [SchemaField.from_api_repr(f) for f in info.get("fields", ())]


def _build_schema_resource(fields):
Expand Down
29 changes: 29 additions & 0 deletions tests/system/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2173,6 +2173,35 @@ def test_list_rows_page_size(self):
page = next(pages)
self.assertEqual(page.num_items, num_last_page)

def test_parameterized_types_round_trip(self):
    """Parameterized column types survive a get_table/create_table round trip.

    A table is created through DDL with plain and parameterized NUMERIC,
    BIGNUMERIC, STRING and BYTES columns. Its schema is then read back and
    used to create a second table, whose schema must render the same
    (name, type) pairs.
    """
    client = Config.CLIENT
    table_id = f"{Config.DATASET}.test_parameterized_types_round_trip"
    fields = (
        ("n", "NUMERIC"),
        ("n9", "NUMERIC(9)"),
        ("n92", "NUMERIC(9, 2)"),
        ("bn", "BIGNUMERIC"),
        ("bn9", "BIGNUMERIC(38)"),
        ("bn92", "BIGNUMERIC(38, 22)"),
        ("s", "STRING"),
        ("s9", "STRING(9)"),
        ("b", "BYTES"),
        ("b9", "BYTES(9)"),
    )

    # Register the table for cleanup before creating it.
    self.to_delete.insert(0, Table(f"{client.project}.{table_id}"))
    column_defs = ", ".join(" ".join(f) for f in fields)
    ddl = "create table {} ({})".format(table_id, column_defs)
    client.query(ddl).result()
    table = client.get_table(table_id)

    # Re-create the same schema under a second name through the client API.
    table_id2 = table_id + "2"
    self.to_delete.insert(0, Table(f"{client.project}.{table_id2}"))
    client.create_table(Table(f"{client.project}.{table_id2}", table.schema))
    table2 = client.get_table(table_id2)

    round_tripped = tuple(s._key()[:2] for s in table2.schema)
    self.assertEqual(round_tripped, fields)

def temp_dataset(self, dataset_id, location=None):
project = Config.CLIENT.project
dataset_ref = bigquery.DatasetReference(project, dataset_id)
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -1302,7 +1302,7 @@ def _verifySchema(self, query, resource):
self.assertEqual(found.description, expected.get("description"))
self.assertEqual(found.fields, expected.get("fields", ()))
else:
self.assertEqual(query.schema, ())
self.assertEqual(query.schema, [])

def test_ctor_defaults(self):
query = self._make_one(self._make_resource())
Expand All @@ -1312,7 +1312,7 @@ def test_ctor_defaults(self):
self.assertIsNone(query.page_token)
self.assertEqual(query.project, self.PROJECT)
self.assertEqual(query.rows, [])
self.assertEqual(query.schema, ())
self.assertEqual(query.schema, [])
self.assertIsNone(query.total_rows)
self.assertIsNone(query.total_bytes_processed)

Expand Down
123 changes: 123 additions & 0 deletions tests/unit/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import unittest

import mock
import pytest


class TestSchemaField(unittest.TestCase):
Expand Down Expand Up @@ -715,3 +716,125 @@ def test___hash__not_equals(self):
set_one = {policy1}
set_two = {policy2}
self.assertNotEqual(set_one, set_two)


@pytest.mark.parametrize(
    "api,expect,key2",
    [
        (
            dict(name="n", type="NUMERIC"),
            ("n", "NUMERIC", None, None, None),
            ("n", "NUMERIC"),
        ),
        (
            dict(name="n", type="NUMERIC", precision=9),
            ("n", "NUMERIC", 9, None, None),
            ("n", "NUMERIC(9)"),
        ),
        (
            dict(name="n", type="NUMERIC", precision=9, scale=2),
            ("n", "NUMERIC", 9, 2, None),
            ("n", "NUMERIC(9, 2)"),
        ),
        (
            dict(name="n", type="BIGNUMERIC"),
            ("n", "BIGNUMERIC", None, None, None),
            ("n", "BIGNUMERIC"),
        ),
        (
            dict(name="n", type="BIGNUMERIC", precision=40),
            ("n", "BIGNUMERIC", 40, None, None),
            ("n", "BIGNUMERIC(40)"),
        ),
        (
            dict(name="n", type="BIGNUMERIC", precision=40, scale=2),
            ("n", "BIGNUMERIC", 40, 2, None),
            ("n", "BIGNUMERIC(40, 2)"),
        ),
        (
            dict(name="n", type="STRING"),
            ("n", "STRING", None, None, None),
            ("n", "STRING"),
        ),
        (
            dict(name="n", type="STRING", maxLength=9),
            ("n", "STRING", None, None, 9),
            ("n", "STRING(9)"),
        ),
        (
            dict(name="n", type="BYTES"),
            ("n", "BYTES", None, None, None),
            ("n", "BYTES"),
        ),
        (
            dict(name="n", type="BYTES", maxLength=9),
            ("n", "BYTES", None, None, 9),
            ("n", "BYTES(9)"),
        ),
        # String-encoded integers must be converted to ``int`` on the way in.
        (
            dict(name="n", type="NUMERIC", precision="9", scale="2"),
            ("n", "NUMERIC", 9, 2, None),
            ("n", "NUMERIC(9, 2)"),
        ),
        (
            dict(name="n", type="BIGNUMERIC", precision="40"),
            ("n", "BIGNUMERIC", 40, None, None),
            ("n", "BIGNUMERIC(40)"),
        ),
        (
            dict(name="n", type="STRING", maxLength="9"),
            ("n", "STRING", None, None, 9),
            ("n", "STRING(9)"),
        ),
    ],
)
def test_from_api_repr_parameterized(api, expect, key2):
    """``SchemaField.from_api_repr`` reads precision, scale and max length.

    Args:
        api: Field resource passed to ``from_api_repr``.
        expect: Expected ``(name, field_type, precision, scale, max_length)``.
        key2: Expected first two items of ``SchemaField._key()``, i.e. the
            name and the type rendered with its parameters.
    """
    from google.cloud.bigquery.schema import SchemaField

    field = SchemaField.from_api_repr(api)

    assert (
        field.name,
        field.field_type,
        field.precision,
        field.scale,
        field.max_length,
    ) == expect

    assert field._key()[:2] == key2


@pytest.mark.parametrize(
    "field,api",
    [
        (
            {"name": "n", "field_type": "NUMERIC"},
            {"name": "n", "type": "NUMERIC", "mode": "NULLABLE"},
        ),
        (
            {"name": "n", "field_type": "NUMERIC", "precision": 9},
            {"name": "n", "type": "NUMERIC", "mode": "NULLABLE", "precision": 9},
        ),
        (
            {"name": "n", "field_type": "NUMERIC", "precision": 9, "scale": 2},
            {
                "name": "n",
                "type": "NUMERIC",
                "mode": "NULLABLE",
                "precision": 9,
                "scale": 2,
            },
        ),
        (
            {"name": "n", "field_type": "BIGNUMERIC"},
            {"name": "n", "type": "BIGNUMERIC", "mode": "NULLABLE"},
        ),
        (
            {"name": "n", "field_type": "BIGNUMERIC", "precision": 40},
            {"name": "n", "type": "BIGNUMERIC", "mode": "NULLABLE", "precision": 40},
        ),
        (
            {"name": "n", "field_type": "BIGNUMERIC", "precision": 40, "scale": 2},
            {
                "name": "n",
                "type": "BIGNUMERIC",
                "mode": "NULLABLE",
                "precision": 40,
                "scale": 2,
            },
        ),
        (
            {"name": "n", "field_type": "STRING"},
            {"name": "n", "type": "STRING", "mode": "NULLABLE"},
        ),
        (
            {"name": "n", "field_type": "STRING", "max_length": 9},
            {"name": "n", "type": "STRING", "mode": "NULLABLE", "maxLength": 9},
        ),
        (
            {"name": "n", "field_type": "BYTES"},
            {"name": "n", "type": "BYTES", "mode": "NULLABLE"},
        ),
        (
            {"name": "n", "field_type": "BYTES", "max_length": 9},
            {"name": "n", "type": "BYTES", "mode": "NULLABLE", "maxLength": 9},
        ),
    ],
)
def test_to_api_repr_parameterized(field, api):
    """``SchemaField.to_api_repr`` emits precision, scale and ``maxLength``.

    Args:
        field: Keyword arguments used to construct the ``SchemaField``.
        api: Resource dictionary that ``to_api_repr()`` is expected to return.
    """
    from google.cloud.bigquery.schema import SchemaField

    schema_field = SchemaField(**field)
    assert schema_field.to_api_repr() == api