
Commit

fix(bigquery): use pyarrow fallback for improved schema detection (#9321)

* fix(bigquery): use pyarrow fallback in schema autodetect

* Improve and refactor pyarrow schema detection

Add more pyarrow types, convert to pyarrow only the columns whose types
could not be detected, etc.

* Use the word "augment" in the helper's name

* Fix failed import in one of the tests
plamut authored Nov 4, 2019
1 parent 518931b commit ed37540
Showing 3 changed files with 296 additions and 13 deletions.
113 changes: 106 additions & 7 deletions bigquery/google/cloud/bigquery/_pandas_helpers.py
@@ -110,8 +110,35 @@ def pyarrow_timestamp():
        "TIME": pyarrow_time,
        "TIMESTAMP": pyarrow_timestamp,
    }
    ARROW_SCALAR_IDS_TO_BQ = {
        # https://arrow.apache.org/docs/python/api/datatypes.html#type-classes
        pyarrow.bool_().id: "BOOL",
        pyarrow.int8().id: "INT64",
        pyarrow.int16().id: "INT64",
        pyarrow.int32().id: "INT64",
        pyarrow.int64().id: "INT64",
        pyarrow.uint8().id: "INT64",
        pyarrow.uint16().id: "INT64",
        pyarrow.uint32().id: "INT64",
        pyarrow.uint64().id: "INT64",
        pyarrow.float16().id: "FLOAT64",
        pyarrow.float32().id: "FLOAT64",
        pyarrow.float64().id: "FLOAT64",
        pyarrow.time32("ms").id: "TIME",
        pyarrow.time64("ns").id: "TIME",
        pyarrow.timestamp("ns").id: "TIMESTAMP",
        pyarrow.date32().id: "DATE",
        pyarrow.date64().id: "DATETIME",  # because millisecond resolution
        pyarrow.binary().id: "BYTES",
        pyarrow.string().id: "STRING",  # also an alias for pyarrow.utf8()
        # The exact scale and precision of the decimal are not important, as
        # only the type ID matters, and it is the same for all decimal128
        # instances.
        pyarrow.decimal128(38, scale=9).id: "NUMERIC",
    }

else:  # pragma: NO COVER
    BQ_TO_ARROW_SCALARS = {}  # pragma: NO COVER
    ARROW_SCALAR_IDS_TO_BQ = {}  # pragma: NO COVER


def bq_to_arrow_struct_data_type(field):
@@ -141,10 +168,11 @@ def bq_to_arrow_data_type(field):
        return pyarrow.list_(inner_type)
    return None

-    if field.field_type.upper() in schema._STRUCT_TYPES:
+    field_type_upper = field.field_type.upper() if field.field_type else ""
+    if field_type_upper in schema._STRUCT_TYPES:
        return bq_to_arrow_struct_data_type(field)

-    data_type_constructor = BQ_TO_ARROW_SCALARS.get(field.field_type.upper())
+    data_type_constructor = BQ_TO_ARROW_SCALARS.get(field_type_upper)
    if data_type_constructor is None:
        return None
    return data_type_constructor()
@@ -183,9 +211,12 @@ def bq_to_arrow_schema(bq_schema):

def bq_to_arrow_array(series, bq_field):
    arrow_type = bq_to_arrow_data_type(bq_field)
+
+    field_type_upper = bq_field.field_type.upper() if bq_field.field_type else ""
+
    if bq_field.mode.upper() == "REPEATED":
        return pyarrow.ListArray.from_pandas(series, type=arrow_type)
-    if bq_field.field_type.upper() in schema._STRUCT_TYPES:
+    if field_type_upper in schema._STRUCT_TYPES:
        return pyarrow.StructArray.from_pandas(series, type=arrow_type)
    return pyarrow.array(series, type=arrow_type)

@@ -267,6 +298,8 @@ def dataframe_to_bq_schema(dataframe, bq_schema):
    bq_schema_unused = set()

    bq_schema_out = []
    unknown_type_fields = []

    for column, dtype in list_columns_and_indexes(dataframe):
        # Use provided type from schema, if present.
        bq_field = bq_schema_index.get(column)
@@ -278,12 +311,12 @@ def dataframe_to_bq_schema(dataframe, bq_schema):
        # Otherwise, try to automatically determine the type based on the
        # pandas dtype.
        bq_type = _PANDAS_DTYPE_TO_BQ.get(dtype.name)
-        if not bq_type:
-            warnings.warn(u"Unable to determine type of column '{}'.".format(column))
-            return None
        bq_field = schema.SchemaField(column, bq_type)
        bq_schema_out.append(bq_field)

+        if bq_field.field_type is None:
+            unknown_type_fields.append(bq_field)
+
    # Catch any schema mismatch. The developer explicitly asked to serialize a
    # column, but it was not found.
    if bq_schema_unused:
@@ -292,7 +325,73 @@ def dataframe_to_bq_schema(dataframe, bq_schema):
                bq_schema_unused
            )
        )
-    return tuple(bq_schema_out)

    # If schema detection was not successful for all columns, also try with
    # pyarrow, if available.
    if unknown_type_fields:
        if not pyarrow:
            msg = u"Could not determine the type of columns: {}".format(
                ", ".join(field.name for field in unknown_type_fields)
            )
            warnings.warn(msg)
            return None  # We cannot detect the schema in full.

        # The augment_schema() helper itself will also issue unknown type
        # warnings if detection still fails for any of the fields.
        bq_schema_out = augment_schema(dataframe, bq_schema_out)

+    return tuple(bq_schema_out) if bq_schema_out else None


def augment_schema(dataframe, current_bq_schema):
    """Try to deduce the unknown field types and return an improved schema.

    This function requires ``pyarrow`` to run. If all the missing types still
    cannot be detected, ``None`` is returned. If all types are already known,
    a shallow copy of the given schema is returned.

    Args:
        dataframe (pandas.DataFrame):
            DataFrame for which some of the field types are still unknown.
        current_bq_schema (Sequence[google.cloud.bigquery.schema.SchemaField]):
            A BigQuery schema for ``dataframe``. The types of some or all of
            the fields may be ``None``.

    Returns:
        Optional[Sequence[google.cloud.bigquery.schema.SchemaField]]
    """
    augmented_schema = []
    unknown_type_fields = []

    for field in current_bq_schema:
        if field.field_type is not None:
            augmented_schema.append(field)
            continue

        arrow_table = pyarrow.array(dataframe[field.name])
        detected_type = ARROW_SCALAR_IDS_TO_BQ.get(arrow_table.type.id)

        if detected_type is None:
            unknown_type_fields.append(field)
            continue

        new_field = schema.SchemaField(
            name=field.name,
            field_type=detected_type,
            mode=field.mode,
            description=field.description,
            fields=field.fields,
        )
        augmented_schema.append(new_field)

    if unknown_type_fields:
        warnings.warn(
            u"Pyarrow could not determine the type of columns: {}.".format(
                ", ".join(field.name for field in unknown_type_fields)
            )
        )
        return None

    return augmented_schema


def dataframe_to_arrow(dataframe, bq_schema):
180 changes: 180 additions & 0 deletions bigquery/tests/unit/test__pandas_helpers.py
@@ -16,6 +16,7 @@
import datetime
import decimal
import functools
import operator
import warnings

import mock
@@ -957,6 +958,185 @@ def test_dataframe_to_parquet_compression_method(module_under_test):


@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
def test_dataframe_to_bq_schema_fallback_needed_wo_pyarrow(module_under_test):
    dataframe = pandas.DataFrame(
        data=[
            {"id": 10, "status": u"FOO", "execution_date": datetime.date(2019, 5, 10)},
            {"id": 20, "status": u"BAR", "created_at": datetime.date(2018, 9, 12)},
        ]
    )

    no_pyarrow_patch = mock.patch(module_under_test.__name__ + ".pyarrow", None)

    with no_pyarrow_patch, warnings.catch_warnings(record=True) as warned:
        detected_schema = module_under_test.dataframe_to_bq_schema(
            dataframe, bq_schema=[]
        )

    assert detected_schema is None

    # a warning should also be issued
    expected_warnings = [
        warning for warning in warned if "could not determine" in str(warning).lower()
    ]
    assert len(expected_warnings) == 1
    msg = str(expected_warnings[0])
    assert "execution_date" in msg and "created_at" in msg


@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
def test_dataframe_to_bq_schema_fallback_needed_w_pyarrow(module_under_test):
    dataframe = pandas.DataFrame(
        data=[
            {"id": 10, "status": u"FOO", "created_at": datetime.date(2019, 5, 10)},
            {"id": 20, "status": u"BAR", "created_at": datetime.date(2018, 9, 12)},
        ]
    )

    with warnings.catch_warnings(record=True) as warned:
        detected_schema = module_under_test.dataframe_to_bq_schema(
            dataframe, bq_schema=[]
        )

    expected_schema = (
        schema.SchemaField("id", "INTEGER", mode="NULLABLE"),
        schema.SchemaField("status", "STRING", mode="NULLABLE"),
        schema.SchemaField("created_at", "DATE", mode="NULLABLE"),
    )
    by_name = operator.attrgetter("name")
    assert sorted(detected_schema, key=by_name) == sorted(expected_schema, key=by_name)

    # there should be no relevant warnings
    unwanted_warnings = [
        warning for warning in warned if "could not determine" in str(warning).lower()
    ]
    assert not unwanted_warnings


@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
def test_dataframe_to_bq_schema_pyarrow_fallback_fails(module_under_test):
    dataframe = pandas.DataFrame(
        data=[
            {"struct_field": {"one": 2}, "status": u"FOO"},
            {"struct_field": {"two": u"222"}, "status": u"BAR"},
        ]
    )

    with warnings.catch_warnings(record=True) as warned:
        detected_schema = module_under_test.dataframe_to_bq_schema(
            dataframe, bq_schema=[]
        )

    assert detected_schema is None

    # a warning should also be issued
    expected_warnings = [
        warning for warning in warned if "could not determine" in str(warning).lower()
    ]
    assert len(expected_warnings) == 1
    assert "struct_field" in str(expected_warnings[0])


@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
def test_augment_schema_type_detection_succeeds(module_under_test):
    dataframe = pandas.DataFrame(
        data=[
            {
                "bool_field": False,
                "int_field": 123,
                "float_field": 3.141592,
                "time_field": datetime.time(17, 59, 47),
                "timestamp_field": datetime.datetime(2005, 5, 31, 14, 25, 55),
                "date_field": datetime.date(2005, 5, 31),
                "bytes_field": b"some bytes",
                "string_field": u"some characters",
                "numeric_field": decimal.Decimal("123.456"),
            }
        ]
    )

    # NOTE: In a pandas DataFrame, the dtype of Python's datetime instances
    # is set to "datetime64[ns]", and pyarrow converts that to
    # pyarrow.TimestampArray. We thus cannot expect to get a DATETIME type
    # when converting back to the BigQuery type.

    current_schema = (
        schema.SchemaField("bool_field", field_type=None, mode="NULLABLE"),
        schema.SchemaField("int_field", field_type=None, mode="NULLABLE"),
        schema.SchemaField("float_field", field_type=None, mode="NULLABLE"),
        schema.SchemaField("time_field", field_type=None, mode="NULLABLE"),
        schema.SchemaField("timestamp_field", field_type=None, mode="NULLABLE"),
        schema.SchemaField("date_field", field_type=None, mode="NULLABLE"),
        schema.SchemaField("bytes_field", field_type=None, mode="NULLABLE"),
        schema.SchemaField("string_field", field_type=None, mode="NULLABLE"),
        schema.SchemaField("numeric_field", field_type=None, mode="NULLABLE"),
    )

    with warnings.catch_warnings(record=True) as warned:
        augmented_schema = module_under_test.augment_schema(dataframe, current_schema)

    # there should be no relevant warnings
    unwanted_warnings = [
        warning for warning in warned if "Pyarrow could not" in str(warning)
    ]
    assert not unwanted_warnings

    # the augmented schema must match the expected
    expected_schema = (
        schema.SchemaField("bool_field", field_type="BOOL", mode="NULLABLE"),
        schema.SchemaField("int_field", field_type="INT64", mode="NULLABLE"),
        schema.SchemaField("float_field", field_type="FLOAT64", mode="NULLABLE"),
        schema.SchemaField("time_field", field_type="TIME", mode="NULLABLE"),
        schema.SchemaField("timestamp_field", field_type="TIMESTAMP", mode="NULLABLE"),
        schema.SchemaField("date_field", field_type="DATE", mode="NULLABLE"),
        schema.SchemaField("bytes_field", field_type="BYTES", mode="NULLABLE"),
        schema.SchemaField("string_field", field_type="STRING", mode="NULLABLE"),
        schema.SchemaField("numeric_field", field_type="NUMERIC", mode="NULLABLE"),
    )
    by_name = operator.attrgetter("name")
    assert sorted(augmented_schema, key=by_name) == sorted(expected_schema, key=by_name)


@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
def test_augment_schema_type_detection_fails(module_under_test):
    dataframe = pandas.DataFrame(
        data=[
            {
                "status": u"FOO",
                "struct_field": {"one": 1},
                "struct_field_2": {"foo": u"123"},
            },
            {
                "status": u"BAR",
                "struct_field": {"two": u"111"},
                "struct_field_2": {"bar": 27},
            },
        ]
    )
    current_schema = [
        schema.SchemaField("status", field_type="STRING", mode="NULLABLE"),
        schema.SchemaField("struct_field", field_type=None, mode="NULLABLE"),
        schema.SchemaField("struct_field_2", field_type=None, mode="NULLABLE"),
    ]

    with warnings.catch_warnings(record=True) as warned:
        augmented_schema = module_under_test.augment_schema(dataframe, current_schema)

    assert augmented_schema is None

    expected_warnings = [
        warning for warning in warned if "could not determine" in str(warning)
    ]
    assert len(expected_warnings) == 1
    warning_msg = str(expected_warnings[0])
    assert "pyarrow" in warning_msg.lower()
    assert "struct_field" in warning_msg and "struct_field_2" in warning_msg


@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
def test_dataframe_to_parquet_dict_sequence_schema(module_under_test):
    dict_schema = [
