Skip to content

Commit

Permalink
feat: use StandardSqlField class for Model.feature_columns and `M…
Browse files Browse the repository at this point in the history
…odel.label_columns` (#1117)
  • Loading branch information
tswast authored Jan 28, 2022
1 parent b67b255 commit 5f50242
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 4 deletions.
15 changes: 11 additions & 4 deletions google/cloud/bigquery/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import google.cloud._helpers # type: ignore
from google.cloud.bigquery import _helpers
from google.cloud.bigquery import standard_sql
from google.cloud.bigquery.encryption_configuration import EncryptionConfiguration


Expand Down Expand Up @@ -171,26 +172,32 @@ def training_runs(self) -> Sequence[Dict[str, Any]]:
)

@property
def feature_columns(self) -> Sequence[Dict[str, Any]]:
def feature_columns(self) -> Sequence[standard_sql.StandardSqlField]:
"""Input feature columns that were used to train this model.
Read-only.
"""
return typing.cast(
resource: Sequence[Dict[str, Any]] = typing.cast(
Sequence[Dict[str, Any]], self._properties.get("featureColumns", [])
)
return [
standard_sql.StandardSqlField.from_api_repr(column) for column in resource
]

@property
def label_columns(self) -> Sequence[Dict[str, Any]]:
def label_columns(self) -> Sequence[standard_sql.StandardSqlField]:
"""Label columns that were used to train this model.
The output of the model will have a ``predicted_`` prefix to these columns.
Read-only.
"""
return typing.cast(
resource: Sequence[Dict[str, Any]] = typing.cast(
Sequence[Dict[str, Any]], self._properties.get("labelColumns", [])
)
return [
standard_sql.StandardSqlField.from_api_repr(column) for column in resource
]

@property
def best_trial_id(self) -> Optional[int]:
Expand Down
40 changes: 40 additions & 0 deletions tests/unit/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,46 @@ def test_build_resource(object_under_test, resource, filter_fields, expected):
assert got == expected


def test_feature_columns(object_under_test):
from google.cloud.bigquery import standard_sql

object_under_test._properties["featureColumns"] = [
{"name": "col_1", "type": {"typeKind": "STRING"}},
{"name": "col_2", "type": {"typeKind": "FLOAT64"}},
]
expected = [
standard_sql.StandardSqlField(
"col_1",
standard_sql.StandardSqlDataType(standard_sql.StandardSqlTypeNames.STRING),
),
standard_sql.StandardSqlField(
"col_2",
standard_sql.StandardSqlDataType(standard_sql.StandardSqlTypeNames.FLOAT64),
),
]
assert object_under_test.feature_columns == expected


def test_label_columns(object_under_test):
from google.cloud.bigquery import standard_sql

object_under_test._properties["labelColumns"] = [
{"name": "col_1", "type": {"typeKind": "STRING"}},
{"name": "col_2", "type": {"typeKind": "FLOAT64"}},
]
expected = [
standard_sql.StandardSqlField(
"col_1",
standard_sql.StandardSqlDataType(standard_sql.StandardSqlTypeNames.STRING),
),
standard_sql.StandardSqlField(
"col_2",
standard_sql.StandardSqlDataType(standard_sql.StandardSqlTypeNames.FLOAT64),
),
]
assert object_under_test.label_columns == expected


def test_set_description(object_under_test):
assert not object_under_test.description
object_under_test.description = "A model description."
Expand Down

0 comments on commit 5f50242

Please sign in to comment.