From f4d1db0e1bad81c1f0fc565bfbc38b1d5be9bdae Mon Sep 17 00:00:00 2001
From: Peter Lamut
Date: Wed, 22 Jan 2020 13:49:18 +0000
Subject: [PATCH] test(bigquery): add tests for concatenating categorical
 columns

---
 bigquery/tests/unit/test_table.py | 168 ++++++++++++++++++++++++++++++
 1 file changed, 168 insertions(+)

diff --git a/bigquery/tests/unit/test_table.py b/bigquery/tests/unit/test_table.py
index 6e8958cdc46c..079ec6e000d3 100644
--- a/bigquery/tests/unit/test_table.py
+++ b/bigquery/tests/unit/test_table.py
@@ -3242,6 +3242,174 @@ def test_to_dataframe_w_bqstorage_snapshot(self):
         with pytest.raises(ValueError):
             row_iterator.to_dataframe(bqstorage_client)
 
+    @unittest.skipIf(pandas is None, "Requires `pandas`")
+    @unittest.skipIf(
+        bigquery_storage_v1beta1 is None, "Requires `google-cloud-bigquery-storage`"
+    )
+    @unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
+    def test_to_dataframe_concat_categorical_dtype_w_pyarrow(self):
+        from google.cloud.bigquery import schema
+        from google.cloud.bigquery import table as mut
+        from google.cloud.bigquery_storage_v1beta1 import reader
+
+        arrow_fields = [
+            # Not alphabetical to test column order.
+            pyarrow.field("col_str", pyarrow.utf8()),
+            # The backend returns strings, and without other info, pyarrow
+            # contains string data in categorical columns, too (rather than the
+            # Dictionary type that corresponds to pandas.Categorical).
+            pyarrow.field("col_category", pyarrow.utf8()),
+        ]
+        arrow_schema = pyarrow.schema(arrow_fields)
+
+        # create a mock BQ storage client
+        bqstorage_client = mock.create_autospec(
+            bigquery_storage_v1beta1.BigQueryStorageClient
+        )
+        bqstorage_client.transport = mock.create_autospec(
+            big_query_storage_grpc_transport.BigQueryStorageGrpcTransport
+        )
+        session = bigquery_storage_v1beta1.types.ReadSession(
+            streams=[{"name": "/projects/proj/dataset/dset/tables/tbl/streams/1234"}],
+            arrow_schema={"serialized_schema": arrow_schema.serialize().to_pybytes()},
+        )
+        bqstorage_client.create_read_session.return_value = session
+
+        mock_rowstream = mock.create_autospec(reader.ReadRowsStream)
+        bqstorage_client.read_rows.return_value = mock_rowstream
+
+        # prepare the iterator over mocked rows
+        mock_rows = mock.create_autospec(reader.ReadRowsIterable)
+        mock_rowstream.rows.return_value = mock_rows
+        page_items = [
+            [
+                pyarrow.array(["foo", "bar", "baz"]),  # col_str
+                pyarrow.array(["low", "medium", "low"]),  # col_category
+            ],
+            [
+                pyarrow.array(["foo_page2", "bar_page2", "baz_page2"]),  # col_str
+                pyarrow.array(["medium", "high", "low"]),  # col_category
+            ],
+        ]
+
+        mock_pages = []
+
+        for record_list in page_items:
+            page_record_batch = pyarrow.RecordBatch.from_arrays(
+                record_list, schema=arrow_schema
+            )
+            mock_page = mock.create_autospec(reader.ReadRowsPage)
+            mock_page.to_arrow.return_value = page_record_batch
+            mock_pages.append(mock_page)
+
+        type(mock_rows).pages = mock.PropertyMock(return_value=mock_pages)
+
+        schema = [
+            schema.SchemaField("col_str", "IGNORED"),
+            schema.SchemaField("col_category", "IGNORED"),
+        ]
+
+        row_iterator = mut.RowIterator(
+            _mock_client(),
+            None,  # api_request: ignored
+            None,  # path: ignored
+            schema,
+            table=mut.TableReference.from_string("proj.dset.tbl"),
+            selected_fields=schema,
+        )
+
+        # run the method under test
+        got = row_iterator.to_dataframe(
+            bqstorage_client=bqstorage_client,
+            dtypes={
+                "col_category": pandas.core.dtypes.dtypes.CategoricalDtype(
+                    categories=["low", "medium", "high"], ordered=False,
+                ),
+            },
+        )
+
+        # Are the columns in the expected order?
+        column_names = ["col_str", "col_category"]
+        self.assertEqual(list(got), column_names)
+
+        # Have expected number of rows?
+        total_pages = len(mock_pages)  # a single stream, so all pages are in mock_pages
+        total_rows = len(page_items[0][0]) * total_pages
+        self.assertEqual(len(got.index), total_rows)
+
+        # Are column types correct?
+        expected_dtypes = [
+            pandas.core.dtypes.dtypes.np.dtype("O"),  # the default for string data
+            pandas.core.dtypes.dtypes.CategoricalDtype(
+                categories=["low", "medium", "high"], ordered=False,
+            ),
+        ]
+        self.assertEqual(list(got.dtypes), expected_dtypes)
+
+        # And the data in the categorical column?
+        self.assertEqual(
+            list(got["col_category"]),
+            ["low", "medium", "low", "medium", "high", "low"],
+        )
+
+        # Don't close the client if it was passed in.
+        bqstorage_client.transport.channel.close.assert_not_called()
+
+    @unittest.skipIf(pandas is None, "Requires `pandas`")
+    def test_to_dataframe_concat_categorical_dtype_wo_pyarrow(self):
+        from google.cloud.bigquery.schema import SchemaField
+
+        schema = [
+            SchemaField("col_str", "STRING"),
+            SchemaField("col_category", "STRING"),
+        ]
+        row_data = [
+            [u"foo", u"low"],
+            [u"bar", u"medium"],
+            [u"baz", u"low"],
+            [u"foo_page2", u"medium"],
+            [u"bar_page2", u"high"],
+            [u"baz_page2", u"low"],
+        ]
+        path = "/foo"
+
+        rows = [{"f": [{"v": field} for field in row]} for row in row_data[:3]]
+        rows_page2 = [{"f": [{"v": field} for field in row]} for row in row_data[3:]]
+        api_request = mock.Mock(
+            side_effect=[{"rows": rows, "pageToken": "NEXTPAGE"}, {"rows": rows_page2}]
+        )
+
+        row_iterator = self._make_one(_mock_client(), api_request, path, schema)
+
+        with mock.patch("google.cloud.bigquery.table.pyarrow", None):
+            got = row_iterator.to_dataframe(
+                dtypes={
+                    "col_category": pandas.core.dtypes.dtypes.CategoricalDtype(
+                        categories=["low", "medium", "high"], ordered=False,
+                    ),
+                },
+            )
+
+        self.assertIsInstance(got, pandas.DataFrame)
+        self.assertEqual(len(got), 6)  # verify the number of rows
+        expected_columns = [field.name for field in schema]
+        self.assertEqual(list(got), expected_columns)  # verify the column names
+
+        # Are column types correct?
+        expected_dtypes = [
+            pandas.core.dtypes.dtypes.np.dtype("O"),  # the default for string data
+            pandas.core.dtypes.dtypes.CategoricalDtype(
+                categories=["low", "medium", "high"], ordered=False,
+            ),
+        ]
+        self.assertEqual(list(got.dtypes), expected_dtypes)
+
+        # And the data in the categorical column?
+        self.assertEqual(
+            list(got["col_category"]),
+            ["low", "medium", "low", "medium", "high", "low"],
+        )
+
 
 class TestPartitionRange(unittest.TestCase):
     def _get_target_class(self):
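
Reviewer note (not part of the patch): the sketch below is a minimal, standalone illustration of the behavior the two tests verify, namely concatenating several pages of string data into one DataFrame while coercing one column to a pandas categorical dtype. It assumes only that pandas and pyarrow are installed; the schema and values mirror the test fixtures, and pandas.api.types.CategoricalDtype is the public alias of the CategoricalDtype class the tests reference via pandas.core.dtypes.dtypes.

    # Standalone sketch: concatenate per-page Arrow record batches into a
    # single pandas DataFrame, then apply a CategoricalDtype to one column,
    # similar in spirit to RowIterator.to_dataframe(dtypes=...).
    import pandas
    import pyarrow

    arrow_schema = pyarrow.schema(
        [
            pyarrow.field("col_str", pyarrow.utf8()),
            pyarrow.field("col_category", pyarrow.utf8()),
        ]
    )

    # Two "pages" of results, mirroring the mocked read session in the test.
    pages = [
        pyarrow.RecordBatch.from_arrays(
            [
                pyarrow.array(["foo", "bar", "baz"]),
                pyarrow.array(["low", "medium", "low"]),
            ],
            schema=arrow_schema,
        ),
        pyarrow.RecordBatch.from_arrays(
            [
                pyarrow.array(["foo_page2", "bar_page2", "baz_page2"]),
                pyarrow.array(["medium", "high", "low"]),
            ],
            schema=arrow_schema,
        ),
    ]

    # Combine all pages into one Arrow table, convert to pandas, and apply
    # the caller-supplied dtype to the categorical column.
    df = pyarrow.Table.from_batches(pages).to_pandas()
    df["col_category"] = df["col_category"].astype(
        pandas.api.types.CategoricalDtype(
            categories=["low", "medium", "high"], ordered=False
        )
    )

    print(df.dtypes)                 # col_str: object, col_category: category
    print(list(df["col_category"]))  # ['low', 'medium', 'low', 'medium', 'high', 'low']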