Autofetch table schema on load if not provided
plamut committed Aug 27, 2019
1 parent 86bb5cf commit adeb233
Showing 2 changed files with 139 additions and 5 deletions.
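
With this change, Client.load_table_from_dataframe fetches the destination table's schema whenever job_config.schema is not provided: schema fields matching the dataframe columns are reused, dataframe columns the table lacks raise a ValueError, and a missing table simply falls back to schema autodetection. A minimal usage sketch of the new behavior (the project, dataset, and table names are hypothetical, not part of this commit):

    import pandas
    from google.cloud import bigquery

    client = bigquery.Client()

    # Hypothetical destination table with columns (id INTEGER, age INTEGER,
    # name STRING); the table ID is made up for illustration.
    table_ref = bigquery.TableReference.from_string("my-project.my_dataset.users")

    dataframe = pandas.DataFrame([{"id": 1, "age": 100}, {"id": 2, "age": 60}])

    # No schema is passed, so the client now calls get_table() and keeps only
    # the schema fields matching the dataframe columns ("id" and "age").
    load_job = client.load_table_from_dataframe(dataframe, table_ref)
    load_job.result()  # wait for the load job to finish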
23 changes: 23 additions & 0 deletions bigquery/google/cloud/bigquery/client.py
@@ -1531,6 +1531,29 @@ def load_table_from_dataframe(
         if location is None:
             location = self.location
 
+        # If table schema is not provided, we try to fetch the existing table
+        # schema, and check if dataframe schema is compatible with it.
+        if not job_config.schema:
+            try:
+                table = self.get_table(destination)
+            except google.api_core.exceptions.NotFound:
+                table = None
+            else:
+                table_col_names = {field.name for field in table.schema}
+                dframe_col_names = set(dataframe.columns)
+
+                in_dframe_only = dframe_col_names - table_col_names
+                if in_dframe_only:
+                    raise ValueError(
+                        "Dataframe contains columns that are not present in "
+                        "table: {}".format(in_dframe_only)
+                    )
+
+                # schema fields not present in the dataframe are not needed
+                job_config.schema = [
+                    field for field in table.schema if field.name in dframe_col_names
+                ]
+
         job_config.schema = _pandas_helpers.dataframe_to_bq_schema(
             dataframe, job_config.schema
         )
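
Stripped of the client plumbing, the new branch above is a set comparison between the fetched table schema and the dataframe columns. A standalone sketch of that logic with hypothetical inputs:

    from google.cloud.bigquery.schema import SchemaField

    # Hypothetical fetched table schema and dataframe column names.
    table_schema = [
        SchemaField("id", "INTEGER"),
        SchemaField("age", "INTEGER"),
        SchemaField("name", "STRING"),
    ]
    dframe_col_names = {"id", "age"}

    table_col_names = {field.name for field in table_schema}
    in_dframe_only = dframe_col_names - table_col_names
    assert not in_dframe_only  # a non-empty set raises ValueError above

    # Fields the dataframe does not provide ("name") are dropped before the
    # pruned schema is passed on to _pandas_helpers.dataframe_to_bq_schema().
    pruned = [field for field in table_schema if field.name in dframe_col_names]
    assert [field.name for field in pruned] == ["id", "age"]

The pruned schema still goes through dataframe_to_bq_schema, so any columns it does not cover remain subject to type autodetection.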
121 changes: 116 additions & 5 deletions bigquery/tests/unit/test_client.py
@@ -5228,15 +5228,23 @@ def test_load_table_from_file_bad_mode(self):
     def test_load_table_from_dataframe(self):
         from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES
         from google.cloud.bigquery import job
+        from google.cloud.bigquery.schema import SchemaField
 
         client = self._make_client()
         records = [{"id": 1, "age": 100}, {"id": 2, "age": 60}]
         dataframe = pandas.DataFrame(records)
 
+        get_table_patch = mock.patch(
+            "google.cloud.bigquery.client.Client.get_table",
+            autospec=True,
+            return_value=mock.Mock(
+                schema=[SchemaField("id", "INTEGER"), SchemaField("age", "INTEGER")]
+            ),
+        )
         load_patch = mock.patch(
             "google.cloud.bigquery.client.Client.load_table_from_file", autospec=True
         )
-        with load_patch as load_table_from_file:
+        with load_patch as load_table_from_file, get_table_patch:
             client.load_table_from_dataframe(dataframe, self.TABLE_REF)
 
         load_table_from_file.assert_called_once_with(
@@ -5263,15 +5271,23 @@ def test_load_table_from_dataframe(self):
     def test_load_table_from_dataframe_w_client_location(self):
         from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES
         from google.cloud.bigquery import job
+        from google.cloud.bigquery.schema import SchemaField
 
         client = self._make_client(location=self.LOCATION)
         records = [{"id": 1, "age": 100}, {"id": 2, "age": 60}]
         dataframe = pandas.DataFrame(records)
 
+        get_table_patch = mock.patch(
+            "google.cloud.bigquery.client.Client.get_table",
+            autospec=True,
+            return_value=mock.Mock(
+                schema=[SchemaField("id", "INTEGER"), SchemaField("age", "INTEGER")]
+            ),
+        )
         load_patch = mock.patch(
             "google.cloud.bigquery.client.Client.load_table_from_file", autospec=True
         )
-        with load_patch as load_table_from_file:
+        with load_patch as load_table_from_file, get_table_patch:
             client.load_table_from_dataframe(dataframe, self.TABLE_REF)
 
         load_table_from_file.assert_called_once_with(
@@ -5298,16 +5314,24 @@ def test_load_table_from_dataframe_w_client_location(self):
     def test_load_table_from_dataframe_w_custom_job_config(self):
         from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES
         from google.cloud.bigquery import job
+        from google.cloud.bigquery.schema import SchemaField
 
         client = self._make_client()
         records = [{"id": 1, "age": 100}, {"id": 2, "age": 60}]
         dataframe = pandas.DataFrame(records)
         job_config = job.LoadJobConfig()
 
+        get_table_patch = mock.patch(
+            "google.cloud.bigquery.client.Client.get_table",
+            autospec=True,
+            return_value=mock.Mock(
+                schema=[SchemaField("id", "INTEGER"), SchemaField("age", "INTEGER")]
+            ),
+        )
         load_patch = mock.patch(
             "google.cloud.bigquery.client.Client.load_table_from_file", autospec=True
         )
-        with load_patch as load_table_from_file:
+        with load_patch as load_table_from_file, get_table_patch:
             client.load_table_from_dataframe(
                 dataframe, self.TABLE_REF, job_config=job_config, location=self.LOCATION
             )
@@ -5370,7 +5394,20 @@ def test_load_table_from_dataframe_w_automatic_schema(self):
"google.cloud.bigquery.client.Client.load_table_from_file", autospec=True
)

with load_patch as load_table_from_file:
get_table_patch = mock.patch(
"google.cloud.bigquery.client.Client.get_table",
autospec=True,
return_value=mock.Mock(
schema=[
SchemaField("int_col", "INTEGER"),
SchemaField("float_col", "FLOAT"),
SchemaField("bool_col", "BOOLEAN"),
SchemaField("dt_col", "DATETIME"),
SchemaField("ts_col", "TIMESTAMP"),
]
),
)
with load_patch as load_table_from_file, get_table_patch:
client.load_table_from_dataframe(
dataframe, self.TABLE_REF, location=self.LOCATION
)
@@ -5398,6 +5435,71 @@ def test_load_table_from_dataframe_w_automatic_schema(self):
             SchemaField("ts_col", "TIMESTAMP"),
         )
 
+    @unittest.skipIf(pandas is None, "Requires `pandas`")
+    @unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
+    def test_load_table_from_dataframe_unknown_df_columns(self):
+        from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES
+        from google.cloud.bigquery import job
+        from google.cloud.bigquery.schema import SchemaField
+
+        client = self._make_client()
+        records = [{"id": 1, "typo_age": 100}, {"id": 2, "typo_age": 60}]
+        dataframe = pandas.DataFrame(records)
+
+        get_table_patch = mock.patch(
+            "google.cloud.bigquery.client.Client.get_table",
+            autospec=True,
+            return_value=mock.Mock(
+                schema=[SchemaField("id", "INTEGER"), SchemaField("age", "INTEGER")]
+            ),
+        )
+        load_patch = mock.patch(
+            "google.cloud.bigquery.client.Client.load_table_from_file", autospec=True
+        )
+        with pytest.raises(ValueError) as exc_info, load_patch, get_table_patch:
+            client.load_table_from_dataframe(dataframe, self.TABLE_REF)
+
+        err_msg = str(exc_info.value)
+        assert "Dataframe contains columns that are not present in table" in err_msg
+        assert "typo_age" in err_msg
+        assert "id" not in err_msg
+
+    @unittest.skipIf(pandas is None, "Requires `pandas`")
+    @unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
+    def test_load_table_from_dataframe_unknown_table(self):
+        from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES
+        from google.cloud.bigquery import job
+        from google.cloud.bigquery.schema import SchemaField
+
+        client = self._make_client()
+        records = [{"id": 1, "age": 100}, {"id": 2, "age": 60}]
+        dataframe = pandas.DataFrame(records)
+
+        get_table_patch = mock.patch(
+            "google.cloud.bigquery.client.Client.get_table",
+            autospec=True,
+            side_effect=google.api_core.exceptions.NotFound("Table not found"),
+        )
+        load_patch = mock.patch(
+            "google.cloud.bigquery.client.Client.load_table_from_file", autospec=True
+        )
+        with load_patch as load_table_from_file, get_table_patch:
+            # there should be no error
+            client.load_table_from_dataframe(dataframe, self.TABLE_REF)
+
+        load_table_from_file.assert_called_once_with(
+            client,
+            mock.ANY,
+            self.TABLE_REF,
+            num_retries=_DEFAULT_NUM_RETRIES,
+            rewind=True,
+            job_id=mock.ANY,
+            job_id_prefix=None,
+            location=None,
+            project=None,
+            job_config=mock.ANY,
+        )
+
     @unittest.skipIf(pandas is None, "Requires `pandas`")
     @unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
     def test_load_table_from_dataframe_struct_fields_error(self):
@@ -5686,10 +5788,19 @@ def test_load_table_from_dataframe_w_schema_arrow_custom_compression(self):
     @unittest.skipIf(pandas is None, "Requires `pandas`")
     @unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
     def test_load_table_from_dataframe_wo_pyarrow_custom_compression(self):
+        from google.cloud.bigquery.schema import SchemaField
+
         client = self._make_client()
         records = [{"id": 1, "age": 100}, {"id": 2, "age": 60}]
         dataframe = pandas.DataFrame(records)
 
+        get_table_patch = mock.patch(
+            "google.cloud.bigquery.client.Client.get_table",
+            autospec=True,
+            return_value=mock.Mock(
+                schema=[SchemaField("id", "INTEGER"), SchemaField("age", "INTEGER")]
+            ),
+        )
         load_patch = mock.patch(
             "google.cloud.bigquery.client.Client.load_table_from_file", autospec=True
         )
@@ -5698,7 +5809,7 @@ def test_load_table_from_dataframe_wo_pyarrow_custom_compression(self):
             dataframe, "to_parquet", wraps=dataframe.to_parquet
         )
 
-        with load_patch, pyarrow_patch, to_parquet_patch as to_parquet_spy:
+        with load_patch, get_table_patch, pyarrow_patch, to_parquet_patch as to_parquet_spy:
             client.load_table_from_dataframe(
                 dataframe,
                 self.TABLE_REF,
