Allow passing schema as dicts in pandas helpers
plamut committed Oct 28, 2019
1 parent 323573e commit 65301f7
Showing 2 changed files with 181 additions and 10 deletions.
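For context, here is a minimal sketch (not part of the commit) of what the change enables, mirroring the new unit tests added below: schema fields can now be passed to the pandas helpers as plain dicts instead of SchemaField instances. The DataFrame contents and column names here are illustrative only.

import pandas

from google.cloud.bigquery import _pandas_helpers

dataframe = pandas.DataFrame({"name": [u"alice", u"bob"], "age": [33, 29]})

# Schema fields expressed as dicts rather than SchemaField objects.
dict_schema = [
    {"name": "name", "type": "STRING", "mode": "REQUIRED"},
    {"name": "age", "type": "INTEGER", "mode": "NULLABLE"},
]

# Returns a tuple of SchemaField objects; the dict entries are converted
# and any columns not listed are autodetected from the DataFrame dtypes.
bq_schema = _pandas_helpers.dataframe_to_bq_schema(dataframe, dict_schema)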
54 changes: 44 additions & 10 deletions bigquery/google/cloud/bigquery/_pandas_helpers.py
@@ -239,7 +239,10 @@ def dataframe_to_bq_schema(dataframe, bq_schema):
Args:
dataframe (pandas.DataFrame):
DataFrame for which the client determines the BigQuery schema.
bq_schema (Sequence[google.cloud.bigquery.schema.SchemaField]):
bq_schema (Sequence[Union[ \
:class:`~google.cloud.bigquery.schema.SchemaField`, \
Mapping[str, str] \
]]):
A BigQuery schema. Use this argument to override the autodetected
type for some or all of the DataFrame columns.
@@ -249,6 +252,7 @@ def dataframe_to_bq_schema(dataframe, bq_schema):
any column cannot be determined.
"""
if bq_schema:
bq_schema = schema._to_schema_fields(bq_schema)
for field in bq_schema:
if field.field_type in schema._STRUCT_TYPES:
raise ValueError(
@@ -297,7 +301,10 @@ def dataframe_to_arrow(dataframe, bq_schema):
Args:
dataframe (pandas.DataFrame):
DataFrame to convert to Arrow table.
bq_schema (Sequence[google.cloud.bigquery.schema.SchemaField]):
bq_schema (Sequence[Union[ \
:class:`~google.cloud.bigquery.schema.SchemaField`, \
Mapping[str, str] \
]]):
Desired BigQuery schema. Number of columns must match number of
columns in the DataFrame.
@@ -310,6 +317,8 @@ def dataframe_to_arrow(dataframe, bq_schema):
column_and_index_names = set(
name for name, _ in list_columns_and_indexes(dataframe)
)

bq_schema = schema._to_schema_fields(bq_schema)
bq_field_names = set(field.name for field in bq_schema)

extra_fields = bq_field_names - column_and_index_names
@@ -354,7 +363,10 @@ def dataframe_to_parquet(dataframe, bq_schema, filepath, parquet_compression="SN
Args:
dataframe (pandas.DataFrame):
DataFrame to convert to Parquet file.
bq_schema (Sequence[google.cloud.bigquery.schema.SchemaField]):
bq_schema (Sequence[Union[ \
:class:`~google.cloud.bigquery.schema.SchemaField`, \
Mapping[str, str] \
]]):
Desired BigQuery schema. Number of columns must match number of
columns in the DataFrame.
filepath (str):
@@ -368,6 +380,7 @@ def dataframe_to_parquet(dataframe, bq_schema, filepath, parquet_compression="SN
if pyarrow is None:
raise ValueError("pyarrow is required for BigQuery schema conversion.")

bq_schema = schema._to_schema_fields(bq_schema)
arrow_table = dataframe_to_arrow(dataframe, bq_schema)
pyarrow.parquet.write_table(arrow_table, filepath, compression=parquet_compression)

@@ -388,20 +401,24 @@ def _tabledata_list_page_to_arrow(page, column_names, arrow_types):
return pyarrow.RecordBatch.from_arrays(arrays, names=column_names)


def download_arrow_tabledata_list(pages, schema):
def download_arrow_tabledata_list(pages, bq_schema):
"""Use tabledata.list to construct an iterable of RecordBatches.
Args:
pages (Iterator[:class:`google.api_core.page_iterator.Page`]):
An iterator over the result pages.
schema (Sequence[google.cloud.bigquery.schema.SchemaField]):
bq_schema (Sequence[Union[ \
:class:`~google.cloud.bigquery.schema.SchemaField`, \
Mapping[str, str] \
]]):
A description of the fields in result pages.
Yields:
:class:`pyarrow.RecordBatch`
The next page of records as a ``pyarrow`` record batch.
"""
column_names = bq_to_arrow_schema(schema) or [field.name for field in schema]
arrow_types = [bq_to_arrow_data_type(field) for field in schema]
bq_schema = schema._to_schema_fields(bq_schema)
column_names = bq_to_arrow_schema(bq_schema) or [field.name for field in bq_schema]
arrow_types = [bq_to_arrow_data_type(field) for field in bq_schema]

for page in pages:
yield _tabledata_list_page_to_arrow(page, column_names, arrow_types)
@@ -422,9 +439,26 @@ def _tabledata_list_page_to_dataframe(page, column_names, dtypes):
return pandas.DataFrame(columns, columns=column_names)


def download_dataframe_tabledata_list(pages, schema, dtypes):
"""Use (slower, but free) tabledata.list to construct a DataFrame."""
column_names = [field.name for field in schema]
def download_dataframe_tabledata_list(pages, bq_schema, dtypes):
"""Use (slower, but free) tabledata.list to construct a DataFrame.
Args:
pages (Iterator[:class:`google.api_core.page_iterator.Page`]):
An iterator over the result pages.
bq_schema (Sequence[Union[ \
:class:`~google.cloud.bigquery.schema.SchemaField`, \
Mapping[str, str] \
]]):
A description of the fields in result pages.
dtypes (Mapping[str, numpy.dtype]):
The types of columns in result data to hint construction of the
resulting DataFrame. Not all column types have to be specified.
Yields:
:class:`pandas.DataFrame`
The next page of records as a ``pandas.DataFrame``.
"""
bq_schema = schema._to_schema_fields(bq_schema)
column_names = [field.name for field in bq_schema]
for page in pages:
yield _tabledata_list_page_to_dataframe(page, column_names, dtypes)

137 changes: 137 additions & 0 deletions bigquery/tests/unit/test__pandas_helpers.py
@@ -701,6 +701,32 @@ def test_list_columns_and_indexes_with_multiindex(module_under_test):
assert columns_and_indexes == expected


@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
def test_dataframe_to_bq_schema_dict_sequence(module_under_test):
df_data = collections.OrderedDict(
[
("str_column", [u"hello", u"world"]),
("int_column", [42, 8]),
("bool_column", [True, False]),
]
)
dataframe = pandas.DataFrame(df_data)

dict_schema = [
{"name": "str_column", "type": "STRING", "mode": "NULLABLE"},
{"name": "bool_column", "type": "BOOL", "mode": "REQUIRED"},
]

returned_schema = module_under_test.dataframe_to_bq_schema(dataframe, dict_schema)

expected_schema = (
schema.SchemaField("str_column", "STRING", "NULLABLE"),
schema.SchemaField("int_column", "INTEGER", "NULLABLE"),
schema.SchemaField("bool_column", "BOOL", "REQUIRED"),
)
assert returned_schema == expected_schema


@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
def test_dataframe_to_arrow_with_multiindex(module_under_test):
@@ -856,6 +882,28 @@ def test_dataframe_to_arrow_with_unknown_type(module_under_test):
assert arrow_schema[3].name == "field03"


@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
def test_dataframe_to_arrow_dict_sequence_schema(module_under_test):
dict_schema = [
{"name": "field01", "type": "STRING", "mode": "REQUIRED"},
{"name": "field02", "type": "BOOL", "mode": "NULLABLE"},
]

dataframe = pandas.DataFrame(
{"field01": [u"hello", u"world"], "field02": [True, False]}
)

arrow_table = module_under_test.dataframe_to_arrow(dataframe, dict_schema)
arrow_schema = arrow_table.schema

expected_fields = [
pyarrow.field("field01", "string", nullable=False),
pyarrow.field("field02", "bool", nullable=True),
]
assert list(arrow_schema) == expected_fields


@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
def test_dataframe_to_parquet_without_pyarrow(module_under_test, monkeypatch):
monkeypatch.setattr(module_under_test, "pyarrow", None)
@@ -908,6 +956,36 @@ def test_dataframe_to_parquet_compression_method(module_under_test):
assert call_args.kwargs.get("compression") == "ZSTD"


@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
def test_dataframe_to_parquet_dict_sequence_schema(module_under_test):
dict_schema = [
{"name": "field01", "type": "STRING", "mode": "REQUIRED"},
{"name": "field02", "type": "BOOL", "mode": "NULLABLE"},
]

dataframe = pandas.DataFrame(
{"field01": [u"hello", u"world"], "field02": [True, False]}
)

write_table_patch = mock.patch.object(
module_under_test.pyarrow.parquet, "write_table", autospec=True
)
to_arrow_patch = mock.patch.object(
module_under_test, "dataframe_to_arrow", autospec=True
)

with write_table_patch, to_arrow_patch as fake_to_arrow:
module_under_test.dataframe_to_parquet(dataframe, dict_schema, None)

expected_schema_arg = [
schema.SchemaField("field01", "STRING", mode="REQUIRED"),
schema.SchemaField("field02", "BOOL", mode="NULLABLE"),
]
schema_arg = fake_to_arrow.call_args.args[1]
assert schema_arg == expected_schema_arg


@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
def test_download_arrow_tabledata_list_unknown_field_type(module_under_test):
fake_page = api_core.page_iterator.Page(
@@ -977,3 +1055,62 @@ def test_download_arrow_tabledata_list_known_field_type(module_under_test):
col = result.columns[1]
assert type(col) is pyarrow.lib.StringArray
assert list(col) == ["2.2", "22.22", "222.222"]


@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
def test_download_arrow_tabledata_list_dict_sequence_schema(module_under_test):
fake_page = api_core.page_iterator.Page(
parent=mock.Mock(),
items=[{"page_data": "foo"}],
item_to_value=api_core.page_iterator._item_to_value_identity,
)
fake_page._columns = [[1, 10, 100], ["2.2", "22.22", "222.222"]]
pages = [fake_page]

dict_schema = [
{"name": "population_size", "type": "INTEGER", "mode": "NULLABLE"},
{"name": "non_alien_field", "type": "STRING", "mode": "NULLABLE"},
]

results_gen = module_under_test.download_arrow_tabledata_list(pages, dict_schema)
result = next(results_gen)

assert len(result.columns) == 2
col = result.columns[0]
assert type(col) is pyarrow.lib.Int64Array
assert list(col) == [1, 10, 100]
col = result.columns[1]
assert type(col) is pyarrow.lib.StringArray
assert list(col) == ["2.2", "22.22", "222.222"]


@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
def test_download_dataframe_tabledata_list_dict_sequence_schema(module_under_test):
fake_page = api_core.page_iterator.Page(
parent=mock.Mock(),
items=[{"page_data": "foo"}],
item_to_value=api_core.page_iterator._item_to_value_identity,
)
fake_page._columns = [[1, 10, 100], ["2.2", "22.22", "222.222"]]
pages = [fake_page]

dict_schema = [
{"name": "population_size", "type": "INTEGER", "mode": "NULLABLE"},
{"name": "non_alien_field", "type": "STRING", "mode": "NULLABLE"},
]

results_gen = module_under_test.download_dataframe_tabledata_list(
pages, dict_schema, dtypes={}
)
result = next(results_gen)

expected_result = pandas.DataFrame(
collections.OrderedDict(
[
("population_size", [1, 10, 100]),
("non_alien_field", ["2.2", "22.22", "222.222"]),
]
)
)
assert result.equals(expected_result)
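As a closing usage note (not part of the diff), the dict-based schema also flows through the Parquet helper; a short sketch mirroring test_dataframe_to_parquet_dict_sequence_schema above, assuming pandas and pyarrow are installed and that the output path is writable:

import pandas

from google.cloud.bigquery import _pandas_helpers

dataframe = pandas.DataFrame(
    {"field01": [u"hello", u"world"], "field02": [True, False]}
)
dict_schema = [
    {"name": "field01", "type": "STRING", "mode": "REQUIRED"},
    {"name": "field02", "type": "BOOL", "mode": "NULLABLE"},
]

# The dict entries are converted to SchemaField objects before the
# DataFrame is written out as Parquet ("/tmp/data.parquet" is illustrative).
_pandas_helpers.dataframe_to_parquet(dataframe, dict_schema, "/tmp/data.parquet")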
