Skip to content

Commit

Permalink
BigQuery: Add to_standard_sql() method to SchemaField (#8880)
Browse files Browse the repository at this point in the history
* Add to_standard_sql() method to SchemaField

* Support standard SQL names in to_standard_sql()

* Add support for ARRAY type in to_standard_sql()
  • Loading branch information
plamut authored Aug 5, 2019
1 parent 8c8e360 commit 0ce1ca5
Show file tree
Hide file tree
Showing 2 changed files with 228 additions and 0 deletions.
62 changes: 62 additions & 0 deletions bigquery/google/cloud/bigquery/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,33 @@

"""Schemas for BigQuery tables / queries."""

from google.cloud.bigquery_v2 import types


# Lookup table used to translate a schema field's type name into the
# corresponding ``StandardSqlDataType`` type kind.
#
# SQL types reference:
# https://cloud.google.com/bigquery/data-types#legacy_sql_data_types
# https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types
#
# Keys are upper-case type names. Where a type is spelled differently in
# legacy and standard SQL (e.g. INTEGER / INT64), both spellings are listed
# and map to the same standard SQL type kind.
LEGACY_TO_STANDARD_TYPES = {
    # text and binary data
    "STRING": types.StandardSqlDataType.STRING,
    "BYTES": types.StandardSqlDataType.BYTES,
    # numeric types
    "INTEGER": types.StandardSqlDataType.INT64,
    "INT64": types.StandardSqlDataType.INT64,
    "FLOAT": types.StandardSqlDataType.FLOAT64,
    "FLOAT64": types.StandardSqlDataType.FLOAT64,
    "NUMERIC": types.StandardSqlDataType.NUMERIC,
    # booleans
    "BOOLEAN": types.StandardSqlDataType.BOOL,
    "BOOL": types.StandardSqlDataType.BOOL,
    # geography
    "GEOGRAPHY": types.StandardSqlDataType.GEOGRAPHY,
    # nested records
    "RECORD": types.StandardSqlDataType.STRUCT,
    "STRUCT": types.StandardSqlDataType.STRUCT,
    # date and time types
    "TIMESTAMP": types.StandardSqlDataType.TIMESTAMP,
    "DATE": types.StandardSqlDataType.DATE,
    "TIME": types.StandardSqlDataType.TIME,
    "DATETIME": types.StandardSqlDataType.DATETIME,
    # ARRAY intentionally has no entry: there is no type name to convert from,
    # because an array is expressed by mode="REPEATED" on the field instead.
}
"""String names of the legacy SQL types to integer codes of Standard SQL types."""


class SchemaField(object):
"""Describe a single field within a table schema.
Expand Down Expand Up @@ -146,6 +173,41 @@ def _key(self):
self._fields,
)

def to_standard_sql(self):
    """Convert this field to its standard SQL field representation.

    A field whose mode is ``"REPEATED"`` becomes an ARRAY whose element type
    is derived from the field type; any other field is converted directly
    from its field type. Type names missing from
    ``LEGACY_TO_STANDARD_TYPES`` yield ``TYPE_KIND_UNSPECIFIED``. For STRUCT
    types (either top level or as the array element), the sub-fields are
    converted recursively.

    Returns:
        An instance of :class:`~google.cloud.bigquery_v2.types.StandardSqlField`.
    """
    kinds = types.StandardSqlDataType

    # Kind of the field type itself, ignoring whether the field is repeated.
    mapped_kind = LEGACY_TO_STANDARD_TYPES.get(
        self.field_type, kinds.TYPE_KIND_UNSPECIFIED
    )

    result_type = kinds()
    result_type.type_kind = kinds.ARRAY if self.mode == "REPEATED" else mapped_kind

    # The type message (if any) whose ``struct_type`` must receive the
    # converted sub-fields.
    struct_holder = None

    if result_type.type_kind == kinds.ARRAY:  # noqa: E721
        element_type = result_type.array_element_type
        element_type.type_kind = mapped_kind

        # ARRAY cannot directly contain other arrays, only scalar types and STRUCTs
        # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types#array-type
        if mapped_kind == kinds.STRUCT:  # noqa: E721
            struct_holder = element_type
    elif result_type.type_kind == kinds.STRUCT:  # noqa: E721
        struct_holder = result_type

    if struct_holder is not None:
        struct_holder.struct_type.fields.extend(
            sub_field.to_standard_sql() for sub_field in self.fields
        )

    return types.StandardSqlField(name=self.name, type=result_type)

def __eq__(self, other):
if not isinstance(other, SchemaField):
return NotImplemented
Expand Down
166 changes: 166 additions & 0 deletions bigquery/tests/unit/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ def _get_target_class():

return SchemaField

@staticmethod
def _get_standard_sql_data_type_class():
    """Return the ``StandardSqlDataType`` class from ``bigquery_v2.types``."""
    from google.cloud.bigquery_v2 import types as bq_types

    return bq_types.StandardSqlDataType

def _make_one(self, *args, **kw):
    """Instantiate the class under test, forwarding all arguments to it."""
    target_class = self._get_target_class()
    return target_class(*args, **kw)

Expand Down Expand Up @@ -151,6 +157,166 @@ def test_fields_property(self):
schema_field = self._make_one("boat", "RECORD", fields=fields)
self.assertIs(schema_field.fields, fields)

def test_to_standard_sql_simple_type(self):
    """Scalar type names convert to the expected standard SQL type kind."""
    kinds = self._get_standard_sql_data_type_class()

    # (type name given to SchemaField, expected standard SQL type kind)
    cases = (
        # a few legacy types
        ("INTEGER", kinds.INT64),
        ("FLOAT", kinds.FLOAT64),
        ("BOOLEAN", kinds.BOOL),
        ("DATETIME", kinds.DATETIME),
        # a few standard types
        ("INT64", kinds.INT64),
        ("FLOAT64", kinds.FLOAT64),
        ("BOOL", kinds.BOOL),
        ("GEOGRAPHY", kinds.GEOGRAPHY),
    )

    for type_name, expected_kind in cases:
        converted = self._make_one("some_field", type_name).to_standard_sql()

        self.assertEqual(converted.name, "some_field")
        self.assertEqual(converted.type.type_kind, expected_kind)
        # scalar types must not populate the array / struct sub-type
        self.assertFalse(converted.type.HasField("sub_type"))

def test_to_standard_sql_struct_type(self):
    """Nested RECORD / STRUCT fields convert to nested standard SQL STRUCTs.

    Expected result object::

        name: "image_usage"
        type {
            type_kind: STRUCT
            struct_type {
                fields {
                    name: "image_content"
                    type {type_kind: BYTES}
                }
                fields {
                    name: "last_used"
                    type {
                        type_kind: STRUCT
                        struct_type {
                            fields {
                                name: "date_field"
                                type {type_kind: DATE}
                            }
                            fields {
                                name: "time_field"
                                type {type_kind: TIME}
                            }
                        }
                    }
                }
            }
        }
    """
    from google.cloud.bigquery_v2 import types

    kinds = self._get_standard_sql_data_type_class()

    def expected_field(name, kind, members=None):
        # Build a StandardSqlField; ``members`` are the struct sub-fields.
        field = types.StandardSqlField(name=name, type=kinds(type_kind=kind))
        if members:
            field.type.struct_type.fields.extend(members)
        return field

    # expected standard SQL representation, built from the inside out
    expected_last_used = expected_field(
        "last_used",
        kinds.STRUCT,
        [
            expected_field("date_field", kinds.DATE),
            expected_field("time_field", kinds.TIME),
        ],
    )
    expected_result = expected_field(
        "image_usage",
        kinds.STRUCT,
        [expected_field("image_content", kinds.BYTES), expected_last_used],
    )

    # equivalent legacy SchemaField objects
    legacy_last_used = self._make_one(
        "last_used",
        "RECORD",
        fields=(
            self._make_one("date_field", "DATE"),
            self._make_one("time_field", "TIME"),
        ),
    )
    legacy_image_content = self._make_one("image_content", "BYTES")

    # both spellings of the nested type must give the same result
    for type_name in ("RECORD", "STRUCT"):
        legacy_field = self._make_one(
            "image_usage",
            type_name,
            fields=(legacy_image_content, legacy_last_used),
        )
        self.assertEqual(legacy_field.to_standard_sql(), expected_result)

def test_to_standard_sql_array_type_simple(self):
    """A repeated scalar field converts to an ARRAY of that scalar type."""
    from google.cloud.bigquery_v2 import types

    kinds = self._get_standard_sql_data_type_class()

    # expected: ARRAY whose element type is INT64
    array_of_int64 = kinds(
        type_kind=kinds.ARRAY,
        array_element_type=kinds(type_kind=kinds.INT64),
    )
    expected_result = types.StandardSqlField(
        name="valid_numbers", type=array_of_int64
    )

    # a "repeated" SchemaField is the legacy way of describing an array
    repeated_field = self._make_one("valid_numbers", "INT64", mode="REPEATED")

    self.assertEqual(repeated_field.to_standard_sql(), expected_result)

def test_to_standard_sql_array_type_struct(self):
    """A repeated RECORD field converts to an ARRAY of STRUCTs."""
    from google.cloud.bigquery_v2 import types

    kinds = self._get_standard_sql_data_type_class()

    # expected element type: a person STRUCT with "name" and "age" members
    person_type = kinds(type_kind=kinds.STRUCT)
    person_type.struct_type.fields.extend(
        [
            types.StandardSqlField(
                name="name", type=kinds(type_kind=kinds.STRING)
            ),
            types.StandardSqlField(
                name="age", type=kinds(type_kind=kinds.INT64)
            ),
        ]
    )

    # expected result: an ARRAY of person structs
    expected_result = types.StandardSqlField(
        name="known_people",
        type=kinds(type_kind=kinds.ARRAY, array_element_type=person_type),
    )

    # legacy repeated SchemaField describing the same data
    repeated_record = self._make_one(
        "known_people",
        "RECORD",
        fields=(
            self._make_one("name", "STRING"),
            self._make_one("age", "INTEGER"),
        ),
        mode="REPEATED",
    )

    self.assertEqual(repeated_record.to_standard_sql(), expected_result)

def test_to_standard_sql_unknown_type(self):
    """An unrecognized type name converts to TYPE_KIND_UNSPECIFIED."""
    kinds = self._get_standard_sql_data_type_class()

    converted = self._make_one("weird_field", "TROOLEAN").to_standard_sql()

    self.assertEqual(converted.name, "weird_field")
    self.assertEqual(converted.type.type_kind, kinds.TYPE_KIND_UNSPECIFIED)
    # no array / struct sub-type may be populated for an unknown type
    self.assertFalse(converted.type.HasField("sub_type"))

def test___eq___wrong_type(self):
field = self._make_one("test", "STRING")
other = object()
Expand Down

0 comments on commit 0ce1ca5

Please sign in to comment.