Skip to content

Commit

Permalink
BigQuery: Autofetch table schema on load if not provided (googleapis#…
Browse files Browse the repository at this point in the history
…9108)

* Autofetch table schema on load if not provided

* Avoid fetching table schema if WRITE_TRUNCATE job

* Skip dataframe columns list check

A similar check is already performed on the server, and server-side
errors are preferred to client errors.

* Raise table NotFound in auto Pandas schema tests

A mock should raise this error instead of returning a table to
trigger schema generation from Pandas dtypes.

* Use list_columns_and_indexes() for names list
  • Loading branch information
plamut authored and emar-kar committed Sep 18, 2019
1 parent e4fe8b7 commit 6031f06
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 7 deletions.
2 changes: 1 addition & 1 deletion bigquery/google/cloud/bigquery/_pandas_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def list_columns_and_indexes(dataframe):
"""Return all index and column names with dtypes.
Returns:
Sequence[Tuple[dtype, str]]:
Sequence[Tuple[str, dtype]]:
Returns a sorted list of indexes and column names with
corresponding dtypes. If an index is missing a name or has the
same name as a column, the index is omitted.
Expand Down
21 changes: 21 additions & 0 deletions bigquery/google/cloud/bigquery/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1547,6 +1547,27 @@ def load_table_from_dataframe(
if location is None:
location = self.location

# If table schema is not provided, we try to fetch the existing table
# schema, and check if dataframe schema is compatible with it - except
# for WRITE_TRUNCATE jobs, the existing schema does not matter then.
if (
not job_config.schema
and job_config.write_disposition != job.WriteDisposition.WRITE_TRUNCATE
):
try:
table = self.get_table(destination)
except google.api_core.exceptions.NotFound:
table = None
else:
columns_and_indexes = frozenset(
name
for name, _ in _pandas_helpers.list_columns_and_indexes(dataframe)
)
# schema fields not present in the dataframe are not needed
job_config.schema = [
field for field in table.schema if field.name in columns_and_indexes
]

job_config.schema = _pandas_helpers.dataframe_to_bq_schema(
dataframe, job_config.schema
)
Expand Down
147 changes: 141 additions & 6 deletions bigquery/tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import gzip
import io
import json
import operator
import unittest
import warnings

Expand Down Expand Up @@ -5279,15 +5280,23 @@ def test_load_table_from_file_bad_mode(self):
def test_load_table_from_dataframe(self):
from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES
from google.cloud.bigquery import job
from google.cloud.bigquery.schema import SchemaField

client = self._make_client()
records = [{"id": 1, "age": 100}, {"id": 2, "age": 60}]
dataframe = pandas.DataFrame(records)

get_table_patch = mock.patch(
"google.cloud.bigquery.client.Client.get_table",
autospec=True,
return_value=mock.Mock(
schema=[SchemaField("id", "INTEGER"), SchemaField("age", "INTEGER")]
),
)
load_patch = mock.patch(
"google.cloud.bigquery.client.Client.load_table_from_file", autospec=True
)
with load_patch as load_table_from_file:
with load_patch as load_table_from_file, get_table_patch:
client.load_table_from_dataframe(dataframe, self.TABLE_REF)

load_table_from_file.assert_called_once_with(
Expand All @@ -5314,15 +5323,23 @@ def test_load_table_from_dataframe(self):
def test_load_table_from_dataframe_w_client_location(self):
from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES
from google.cloud.bigquery import job
from google.cloud.bigquery.schema import SchemaField

client = self._make_client(location=self.LOCATION)
records = [{"id": 1, "age": 100}, {"id": 2, "age": 60}]
dataframe = pandas.DataFrame(records)

get_table_patch = mock.patch(
"google.cloud.bigquery.client.Client.get_table",
autospec=True,
return_value=mock.Mock(
schema=[SchemaField("id", "INTEGER"), SchemaField("age", "INTEGER")]
),
)
load_patch = mock.patch(
"google.cloud.bigquery.client.Client.load_table_from_file", autospec=True
)
with load_patch as load_table_from_file:
with load_patch as load_table_from_file, get_table_patch:
client.load_table_from_dataframe(dataframe, self.TABLE_REF)

load_table_from_file.assert_called_once_with(
Expand All @@ -5349,20 +5366,33 @@ def test_load_table_from_dataframe_w_client_location(self):
def test_load_table_from_dataframe_w_custom_job_config(self):
from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES
from google.cloud.bigquery import job
from google.cloud.bigquery.schema import SchemaField

client = self._make_client()
records = [{"id": 1, "age": 100}, {"id": 2, "age": 60}]
dataframe = pandas.DataFrame(records)
job_config = job.LoadJobConfig()
job_config = job.LoadJobConfig(
write_disposition=job.WriteDisposition.WRITE_TRUNCATE
)

get_table_patch = mock.patch(
"google.cloud.bigquery.client.Client.get_table",
autospec=True,
return_value=mock.Mock(
schema=[SchemaField("id", "INTEGER"), SchemaField("age", "INTEGER")]
),
)
load_patch = mock.patch(
"google.cloud.bigquery.client.Client.load_table_from_file", autospec=True
)
with load_patch as load_table_from_file:
with load_patch as load_table_from_file, get_table_patch as get_table:
client.load_table_from_dataframe(
dataframe, self.TABLE_REF, job_config=job_config, location=self.LOCATION
)

# no need to fetch and inspect table schema for WRITE_TRUNCATE jobs
assert not get_table.called

load_table_from_file.assert_called_once_with(
client,
mock.ANY,
Expand All @@ -5378,6 +5408,7 @@ def test_load_table_from_dataframe_w_custom_job_config(self):

sent_config = load_table_from_file.mock_calls[0][2]["job_config"]
assert sent_config.source_format == job.SourceFormat.PARQUET
assert sent_config.write_disposition == job.WriteDisposition.WRITE_TRUNCATE

@unittest.skipIf(pandas is None, "Requires `pandas`")
@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
Expand Down Expand Up @@ -5421,7 +5452,12 @@ def test_load_table_from_dataframe_w_automatic_schema(self):
"google.cloud.bigquery.client.Client.load_table_from_file", autospec=True
)

with load_patch as load_table_from_file:
get_table_patch = mock.patch(
"google.cloud.bigquery.client.Client.get_table",
autospec=True,
side_effect=google.api_core.exceptions.NotFound("Table not found"),
)
with load_patch as load_table_from_file, get_table_patch:
client.load_table_from_dataframe(
dataframe, self.TABLE_REF, location=self.LOCATION
)
Expand Down Expand Up @@ -5449,6 +5485,100 @@ def test_load_table_from_dataframe_w_automatic_schema(self):
SchemaField("ts_col", "TIMESTAMP"),
)

@unittest.skipIf(pandas is None, "Requires `pandas`")
@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
def test_load_table_from_dataframe_w_index_and_auto_schema(self):
from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES
from google.cloud.bigquery import job
from google.cloud.bigquery.schema import SchemaField

client = self._make_client()
df_data = collections.OrderedDict(
[("int_col", [10, 20, 30]), ("float_col", [1.0, 2.0, 3.0])]
)
dataframe = pandas.DataFrame(
df_data,
index=pandas.Index(name="unique_name", data=["one", "two", "three"]),
)

load_patch = mock.patch(
"google.cloud.bigquery.client.Client.load_table_from_file", autospec=True
)

get_table_patch = mock.patch(
"google.cloud.bigquery.client.Client.get_table",
autospec=True,
return_value=mock.Mock(
schema=[
SchemaField("int_col", "INTEGER"),
SchemaField("float_col", "FLOAT"),
SchemaField("unique_name", "STRING"),
]
),
)
with load_patch as load_table_from_file, get_table_patch:
client.load_table_from_dataframe(
dataframe, self.TABLE_REF, location=self.LOCATION
)

load_table_from_file.assert_called_once_with(
client,
mock.ANY,
self.TABLE_REF,
num_retries=_DEFAULT_NUM_RETRIES,
rewind=True,
job_id=mock.ANY,
job_id_prefix=None,
location=self.LOCATION,
project=None,
job_config=mock.ANY,
)

sent_config = load_table_from_file.mock_calls[0][2]["job_config"]
assert sent_config.source_format == job.SourceFormat.PARQUET

sent_schema = sorted(sent_config.schema, key=operator.attrgetter("name"))
expected_sent_schema = [
SchemaField("float_col", "FLOAT"),
SchemaField("int_col", "INTEGER"),
SchemaField("unique_name", "STRING"),
]
assert sent_schema == expected_sent_schema

@unittest.skipIf(pandas is None, "Requires `pandas`")
@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
def test_load_table_from_dataframe_unknown_table(self):
from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES

client = self._make_client()
records = [{"id": 1, "age": 100}, {"id": 2, "age": 60}]
dataframe = pandas.DataFrame(records)

get_table_patch = mock.patch(
"google.cloud.bigquery.client.Client.get_table",
autospec=True,
side_effect=google.api_core.exceptions.NotFound("Table not found"),
)
load_patch = mock.patch(
"google.cloud.bigquery.client.Client.load_table_from_file", autospec=True
)
with load_patch as load_table_from_file, get_table_patch:
# there should be no error
client.load_table_from_dataframe(dataframe, self.TABLE_REF)

load_table_from_file.assert_called_once_with(
client,
mock.ANY,
self.TABLE_REF,
num_retries=_DEFAULT_NUM_RETRIES,
rewind=True,
job_id=mock.ANY,
job_id_prefix=None,
location=None,
project=None,
job_config=mock.ANY,
)

@unittest.skipIf(pandas is None, "Requires `pandas`")
@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
def test_load_table_from_dataframe_struct_fields_error(self):
Expand Down Expand Up @@ -5741,6 +5871,11 @@ def test_load_table_from_dataframe_wo_pyarrow_custom_compression(self):
records = [{"id": 1, "age": 100}, {"id": 2, "age": 60}]
dataframe = pandas.DataFrame(records)

get_table_patch = mock.patch(
"google.cloud.bigquery.client.Client.get_table",
autospec=True,
side_effect=google.api_core.exceptions.NotFound("Table not found"),
)
load_patch = mock.patch(
"google.cloud.bigquery.client.Client.load_table_from_file", autospec=True
)
Expand All @@ -5749,7 +5884,7 @@ def test_load_table_from_dataframe_wo_pyarrow_custom_compression(self):
dataframe, "to_parquet", wraps=dataframe.to_parquet
)

with load_patch, pyarrow_patch, to_parquet_patch as to_parquet_spy:
with load_patch, get_table_patch, pyarrow_patch, to_parquet_patch as to_parquet_spy:
client.load_table_from_dataframe(
dataframe,
self.TABLE_REF,
Expand Down

0 comments on commit 6031f06

Please sign in to comment.