Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(bigquery): use pyarrow fallback for improved schema detection #9321

Merged
merged 6 commits into from
Nov 4, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 106 additions & 7 deletions bigquery/google/cloud/bigquery/_pandas_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,35 @@ def pyarrow_timestamp():
"TIME": pyarrow_time,
"TIMESTAMP": pyarrow_timestamp,
}
ARROW_SCALAR_IDS_TO_BQ = {
# https://arrow.apache.org/docs/python/api/datatypes.html#type-classes
pyarrow.bool_().id: "BOOL",
pyarrow.int8().id: "INT64",
pyarrow.int16().id: "INT64",
pyarrow.int32().id: "INT64",
pyarrow.int64().id: "INT64",
pyarrow.uint8().id: "INT64",
pyarrow.uint16().id: "INT64",
pyarrow.uint32().id: "INT64",
pyarrow.uint64().id: "INT64",
pyarrow.float16().id: "FLOAT64",
pyarrow.float32().id: "FLOAT64",
pyarrow.float64().id: "FLOAT64",
pyarrow.time32("ms").id: "TIME",
pyarrow.time64("ns").id: "TIME",
pyarrow.timestamp("ns").id: "TIMESTAMP",
pyarrow.date32().id: "DATE",
pyarrow.date64().id: "DATETIME", # because millisecond resolution
pyarrow.binary().id: "BYTES",
pyarrow.string().id: "STRING", # also alias for pyarrow.utf8()
pyarrow.decimal128(38, scale=9).id: "NUMERIC",
# The exact decimal's scale and precision are not important, as only
# the type ID matters, and it's the same for all decimal128 instances.
}

else: # pragma: NO COVER
BQ_TO_ARROW_SCALARS = {} # pragma: NO COVER
ARROW_SCALAR_IDS_TO_BQ = {} # pragma: NO_COVER


def bq_to_arrow_struct_data_type(field):
Expand Down Expand Up @@ -141,10 +168,11 @@ def bq_to_arrow_data_type(field):
return pyarrow.list_(inner_type)
return None

if field.field_type.upper() in schema._STRUCT_TYPES:
field_type_upper = field.field_type.upper() if field.field_type else ""
if field_type_upper in schema._STRUCT_TYPES:
return bq_to_arrow_struct_data_type(field)

data_type_constructor = BQ_TO_ARROW_SCALARS.get(field.field_type.upper())
data_type_constructor = BQ_TO_ARROW_SCALARS.get(field_type_upper)
if data_type_constructor is None:
return None
return data_type_constructor()
Expand Down Expand Up @@ -183,9 +211,12 @@ def bq_to_arrow_schema(bq_schema):

def bq_to_arrow_array(series, bq_field):
arrow_type = bq_to_arrow_data_type(bq_field)

field_type_upper = bq_field.field_type.upper() if bq_field.field_type else ""

if bq_field.mode.upper() == "REPEATED":
return pyarrow.ListArray.from_pandas(series, type=arrow_type)
if bq_field.field_type.upper() in schema._STRUCT_TYPES:
if field_type_upper in schema._STRUCT_TYPES:
return pyarrow.StructArray.from_pandas(series, type=arrow_type)
return pyarrow.array(series, type=arrow_type)

Expand Down Expand Up @@ -267,6 +298,8 @@ def dataframe_to_bq_schema(dataframe, bq_schema):
bq_schema_unused = set()

bq_schema_out = []
unknown_type_fields = []

for column, dtype in list_columns_and_indexes(dataframe):
# Use provided type from schema, if present.
bq_field = bq_schema_index.get(column)
Expand All @@ -278,12 +311,12 @@ def dataframe_to_bq_schema(dataframe, bq_schema):
# Otherwise, try to automatically determine the type based on the
# pandas dtype.
bq_type = _PANDAS_DTYPE_TO_BQ.get(dtype.name)
if not bq_type:
warnings.warn(u"Unable to determine type of column '{}'.".format(column))
return None
bq_field = schema.SchemaField(column, bq_type)
bq_schema_out.append(bq_field)

if bq_field.field_type is None:
unknown_type_fields.append(bq_field)

# Catch any schema mismatch. The developer explicitly asked to serialize a
# column, but it was not found.
if bq_schema_unused:
Expand All @@ -292,7 +325,73 @@ def dataframe_to_bq_schema(dataframe, bq_schema):
bq_schema_unused
)
)
return tuple(bq_schema_out)

# If schema detection was not successful for all columns, also try with
# pyarrow, if available.
if unknown_type_fields:
if not pyarrow:
msg = u"Could not determine the type of columns: {}".format(
", ".join(field.name for field in unknown_type_fields)
)
warnings.warn(msg)
return None # We cannot detect the schema in full.

# The augment_schema() helper itself will also issue unknown type
# warnings if detection still fails for any of the fields.
bq_schema_out = augment_schema(dataframe, bq_schema_out)

return tuple(bq_schema_out) if bq_schema_out else None


def augment_schema(dataframe, current_bq_schema):
"""Try to deduce the unknown field types and return an improved schema.

This function requires ``pyarrow`` to run. If all the missing types still
cannot be detected, ``None`` is returned. If all types are already known,
a shallow copy of the given schema is returned.

Args:
dataframe (pandas.DataFrame):
DataFrame for which some of the field types are still unknown.
current_bq_schema (Sequence[google.cloud.bigquery.schema.SchemaField]):
A BigQuery schema for ``dataframe``. The types of some or all of
the fields may be ``None``.
Returns:
Optional[Sequence[google.cloud.bigquery.schema.SchemaField]]
"""
augmented_schema = []
unknown_type_fields = []

for field in current_bq_schema:
if field.field_type is not None:
augmented_schema.append(field)
continue

arrow_table = pyarrow.array(dataframe[field.name])
detected_type = ARROW_SCALAR_IDS_TO_BQ.get(arrow_table.type.id)

if detected_type is None:
unknown_type_fields.append(field)
continue

new_field = schema.SchemaField(
name=field.name,
field_type=detected_type,
mode=field.mode,
description=field.description,
fields=field.fields,
)
augmented_schema.append(new_field)

if unknown_type_fields:
warnings.warn(
u"Pyarrow could not determine the type of columns: {}.".format(
", ".join(field.name for field in unknown_type_fields)
)
)
return None

return augmented_schema


def dataframe_to_arrow(dataframe, bq_schema):
Expand Down
180 changes: 180 additions & 0 deletions bigquery/tests/unit/test__pandas_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import datetime
import decimal
import functools
import operator
import warnings

import mock
Expand Down Expand Up @@ -957,6 +958,185 @@ def test_dataframe_to_parquet_compression_method(module_under_test):


@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
def test_dataframe_to_bq_schema_fallback_needed_wo_pyarrow(module_under_test):
dataframe = pandas.DataFrame(
data=[
{"id": 10, "status": u"FOO", "execution_date": datetime.date(2019, 5, 10)},
{"id": 20, "status": u"BAR", "created_at": datetime.date(2018, 9, 12)},
]
)

no_pyarrow_patch = mock.patch(module_under_test.__name__ + ".pyarrow", None)

with no_pyarrow_patch, warnings.catch_warnings(record=True) as warned:
detected_schema = module_under_test.dataframe_to_bq_schema(
dataframe, bq_schema=[]
)

assert detected_schema is None

# a warning should also be issued
expected_warnings = [
warning for warning in warned if "could not determine" in str(warning).lower()
]
assert len(expected_warnings) == 1
msg = str(expected_warnings[0])
assert "execution_date" in msg and "created_at" in msg


@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
def test_dataframe_to_bq_schema_fallback_needed_w_pyarrow(module_under_test):
dataframe = pandas.DataFrame(
data=[
{"id": 10, "status": u"FOO", "created_at": datetime.date(2019, 5, 10)},
{"id": 20, "status": u"BAR", "created_at": datetime.date(2018, 9, 12)},
]
)

with warnings.catch_warnings(record=True) as warned:
detected_schema = module_under_test.dataframe_to_bq_schema(
dataframe, bq_schema=[]
)

expected_schema = (
schema.SchemaField("id", "INTEGER", mode="NULLABLE"),
schema.SchemaField("status", "STRING", mode="NULLABLE"),
schema.SchemaField("created_at", "DATE", mode="NULLABLE"),
)
by_name = operator.attrgetter("name")
assert sorted(detected_schema, key=by_name) == sorted(expected_schema, key=by_name)

# there should be no relevant warnings
unwanted_warnings = [
warning for warning in warned if "could not determine" in str(warning).lower()
]
assert not unwanted_warnings


@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
def test_dataframe_to_bq_schema_pyarrow_fallback_fails(module_under_test):
dataframe = pandas.DataFrame(
data=[
{"struct_field": {"one": 2}, "status": u"FOO"},
{"struct_field": {"two": u"222"}, "status": u"BAR"},
]
)

with warnings.catch_warnings(record=True) as warned:
detected_schema = module_under_test.dataframe_to_bq_schema(
dataframe, bq_schema=[]
)

assert detected_schema is None

# a warning should also be issued
expected_warnings = [
warning for warning in warned if "could not determine" in str(warning).lower()
]
assert len(expected_warnings) == 1
assert "struct_field" in str(expected_warnings[0])


@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
def test_augment_schema_type_detection_succeeds(module_under_test):
dataframe = pandas.DataFrame(
data=[
{
"bool_field": False,
"int_field": 123,
"float_field": 3.141592,
"time_field": datetime.time(17, 59, 47),
"timestamp_field": datetime.datetime(2005, 5, 31, 14, 25, 55),
"date_field": datetime.date(2005, 5, 31),
"bytes_field": b"some bytes",
"string_field": u"some characters",
"numeric_field": decimal.Decimal("123.456"),
}
]
)

# NOTE: In Pandas dataframe, the dtype of Python's datetime instances is
# set to "datetime64[ns]", and pyarrow converts that to pyarrow.TimestampArray.
# We thus cannot expect to get a DATETIME date when converting back to the
# BigQuery type.

current_schema = (
schema.SchemaField("bool_field", field_type=None, mode="NULLABLE"),
schema.SchemaField("int_field", field_type=None, mode="NULLABLE"),
schema.SchemaField("float_field", field_type=None, mode="NULLABLE"),
schema.SchemaField("time_field", field_type=None, mode="NULLABLE"),
schema.SchemaField("timestamp_field", field_type=None, mode="NULLABLE"),
schema.SchemaField("date_field", field_type=None, mode="NULLABLE"),
schema.SchemaField("bytes_field", field_type=None, mode="NULLABLE"),
schema.SchemaField("string_field", field_type=None, mode="NULLABLE"),
schema.SchemaField("numeric_field", field_type=None, mode="NULLABLE"),
)

with warnings.catch_warnings(record=True) as warned:
augmented_schema = module_under_test.augment_schema(dataframe, current_schema)

# there should be no relevant warnings
unwanted_warnings = [
warning for warning in warned if "Pyarrow could not" in str(warning)
]
assert not unwanted_warnings

# the augmented schema must match the expected
expected_schema = (
schema.SchemaField("bool_field", field_type="BOOL", mode="NULLABLE"),
schema.SchemaField("int_field", field_type="INT64", mode="NULLABLE"),
schema.SchemaField("float_field", field_type="FLOAT64", mode="NULLABLE"),
schema.SchemaField("time_field", field_type="TIME", mode="NULLABLE"),
schema.SchemaField("timestamp_field", field_type="TIMESTAMP", mode="NULLABLE"),
schema.SchemaField("date_field", field_type="DATE", mode="NULLABLE"),
schema.SchemaField("bytes_field", field_type="BYTES", mode="NULLABLE"),
schema.SchemaField("string_field", field_type="STRING", mode="NULLABLE"),
schema.SchemaField("numeric_field", field_type="NUMERIC", mode="NULLABLE"),
)
by_name = operator.attrgetter("name")
assert sorted(augmented_schema, key=by_name) == sorted(expected_schema, key=by_name)


@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
def test_augment_schema_type_detection_fails(module_under_test):
dataframe = pandas.DataFrame(
data=[
{
"status": u"FOO",
"struct_field": {"one": 1},
"struct_field_2": {"foo": u"123"},
},
{
"status": u"BAR",
"struct_field": {"two": u"111"},
"struct_field_2": {"bar": 27},
},
]
)
current_schema = [
schema.SchemaField("status", field_type="STRING", mode="NULLABLE"),
schema.SchemaField("struct_field", field_type=None, mode="NULLABLE"),
schema.SchemaField("struct_field_2", field_type=None, mode="NULLABLE"),
]

with warnings.catch_warnings(record=True) as warned:
augmented_schema = module_under_test.augment_schema(dataframe, current_schema)

assert augmented_schema is None

expected_warnings = [
warning for warning in warned if "could not determine" in str(warning)
]
assert len(expected_warnings) == 1
warning_msg = str(expected_warnings[0])
assert "pyarrow" in warning_msg.lower()
assert "struct_field" in warning_msg and "struct_field_2" in warning_msg


@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
def test_dataframe_to_parquet_dict_sequence_schema(module_under_test):
dict_schema = [
Expand Down
Loading