diff --git a/google/cloud/bigquery/model.py b/google/cloud/bigquery/model.py index 52fe6276e..4d2bc346c 100644 --- a/google/cloud/bigquery/model.py +++ b/google/cloud/bigquery/model.py @@ -23,6 +23,7 @@ import google.cloud._helpers # type: ignore from google.cloud.bigquery import _helpers +from google.cloud.bigquery import standard_sql from google.cloud.bigquery.encryption_configuration import EncryptionConfiguration @@ -171,26 +172,32 @@ def training_runs(self) -> Sequence[Dict[str, Any]]: ) @property - def feature_columns(self) -> Sequence[Dict[str, Any]]: + def feature_columns(self) -> Sequence[standard_sql.StandardSqlField]: """Input feature columns that were used to train this model. Read-only. """ - return typing.cast( + resource: Sequence[Dict[str, Any]] = typing.cast( Sequence[Dict[str, Any]], self._properties.get("featureColumns", []) ) + return [ + standard_sql.StandardSqlField.from_api_repr(column) for column in resource + ] @property - def label_columns(self) -> Sequence[Dict[str, Any]]: + def label_columns(self) -> Sequence[standard_sql.StandardSqlField]: """Label columns that were used to train this model. The output of the model will have a ``predicted_`` prefix to these columns. Read-only. """ - return typing.cast( + resource: Sequence[Dict[str, Any]] = typing.cast( Sequence[Dict[str, Any]], self._properties.get("labelColumns", []) ) + return [ + standard_sql.StandardSqlField.from_api_repr(column) for column in resource + ] @property def best_trial_id(self) -> Optional[int]: diff --git a/tests/unit/model/test_model.py b/tests/unit/model/test_model.py index 3cc1dd4c4..1ae988414 100644 --- a/tests/unit/model/test_model.py +++ b/tests/unit/model/test_model.py @@ -273,6 +273,46 @@ def test_build_resource(object_under_test, resource, filter_fields, expected): assert got == expected +def test_feature_columns(object_under_test): + from google.cloud.bigquery import standard_sql + + object_under_test._properties["featureColumns"] = [ + {"name": "col_1", "type": {"typeKind": "STRING"}}, + {"name": "col_2", "type": {"typeKind": "FLOAT64"}}, + ] + expected = [ + standard_sql.StandardSqlField( + "col_1", + standard_sql.StandardSqlDataType(standard_sql.StandardSqlTypeNames.STRING), + ), + standard_sql.StandardSqlField( + "col_2", + standard_sql.StandardSqlDataType(standard_sql.StandardSqlTypeNames.FLOAT64), + ), + ] + assert object_under_test.feature_columns == expected + + +def test_label_columns(object_under_test): + from google.cloud.bigquery import standard_sql + + object_under_test._properties["labelColumns"] = [ + {"name": "col_1", "type": {"typeKind": "STRING"}}, + {"name": "col_2", "type": {"typeKind": "FLOAT64"}}, + ] + expected = [ + standard_sql.StandardSqlField( + "col_1", + standard_sql.StandardSqlDataType(standard_sql.StandardSqlTypeNames.STRING), + ), + standard_sql.StandardSqlField( + "col_2", + standard_sql.StandardSqlDataType(standard_sql.StandardSqlTypeNames.FLOAT64), + ), + ] + assert object_under_test.label_columns == expected + + def test_set_description(object_under_test): assert not object_under_test.description object_under_test.description = "A model description."