Skip to content

Commit

Permalink
Feat(bigquery): add support for ML.PREDICT function (#2375)
Browse files Browse the repository at this point in the history
  • Loading branch information
georgesittas authored Oct 4, 2023
1 parent 347ac51 commit 160d841
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 0 deletions.
5 changes: 5 additions & 0 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4778,6 +4778,11 @@ class Posexplode(Func):
pass


# https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-predict#mlpredict_function
class Predict(Func):
    # Argument slots of the PREDICT node. The generator in this change renders
    # "this" after the MODEL keyword, "expression" as the input table/subquery,
    # and "params_struct" as a trailing argument when one is present.
    # NOTE(review): True/False presumably mark a slot as required/optional,
    # following the convention of the other expression classes -- confirm.
    arg_types = {
        "this": True,
        "expression": True,
        "params_struct": False,
    }


class Pow(Binary, Func):
_sql_names = ["POWER", "POW"]

Expand Down
8 changes: 8 additions & 0 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2887,6 +2887,14 @@ def columnprefix_sql(self, expression: exp.ColumnPrefix) -> str:
def opclass_sql(self, expression: exp.Opclass) -> str:
    """Render an Opclass node as its `this` SQL, one space, then its `expression` SQL."""
    target_sql = self.sql(expression, "this")
    opclass_name_sql = self.sql(expression, "expression")
    return f"{target_sql} {opclass_name_sql}"

def predict_sql(self, expression: exp.Predict) -> str:
    """Render a Predict node as a `PREDICT(...)` function call.

    Args:
        expression: the Predict node to render. Its "this" arg is the model,
            "expression" is the input relation and "params_struct" is the
            optional trailing argument.

    Returns:
        The result of `self.func("PREDICT", ...)` applied to the rendered
        model, input relation and (when non-empty) parameters.
    """
    # The model reference is always emitted behind the MODEL keyword.
    model_sql = f"MODEL {self.sql(expression, 'this')}"

    source_sql = self.sql(expression, "expression")
    if not isinstance(expression.expression, exp.Subquery):
        # Only a subquery is emitted bare; anything else gets the TABLE keyword.
        source_sql = f"TABLE {source_sql}"

    params_sql = self.sql(expression, "params_struct")
    # An empty rendering is handed over as None rather than as "".
    trailing_arg = params_sql if params_sql else None

    return self.func("PREDICT", model_sql, source_sql, trailing_arg)


def cached_generator(
cache: t.Optional[t.Dict[int, str]] = None
Expand Down
15 changes: 15 additions & 0 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,7 @@ class Parser(metaclass=_Parser):
"MATCH": lambda self: self._parse_match_against(),
"OPENJSON": lambda self: self._parse_open_json(),
"POSITION": lambda self: self._parse_position(),
"PREDICT": lambda self: self._parse_predict(),
"SAFE_CAST": lambda self: self._parse_cast(False),
"STRING_AGG": lambda self: self._parse_string_agg(),
"SUBSTRING": lambda self: self._parse_substring(),
Expand Down Expand Up @@ -4407,6 +4408,20 @@ def _parse_position(self, haystack_first: bool = False) -> exp.StrPosition:
exp.StrPosition, this=haystack, substr=needle, position=seq_get(args, 2)
)

def _parse_predict(self) -> exp.Predict:
    """Parse the argument list of a PREDICT call into an `exp.Predict` node.

    The accepted shape is `[MODEL] <table> [,] [TABLE] <table> [, <expr>]`.
    The keyword and first-comma matches are attempted but their results are
    not checked, so their absence does not stop parsing here.

    Returns:
        An `exp.Predict` built through `self.expression`, with the first
        parsed table as "this", the second as "expression", and the optional
        trailing expression as "params_struct".
    """
    self._match_text_seq("MODEL")
    model = self._parse_table()

    # Separator and TABLE keyword in front of the input relation.
    self._match(TokenType.COMMA)
    self._match_text_seq("TABLE")
    source = self._parse_table()

    # The trailing argument is parsed only if a comma match succeeds;
    # otherwise the falsy match result itself is stored.
    params = self._match(TokenType.COMMA) and self._parse_bitwise()

    return self.expression(
        exp.Predict, this=model, expression=source, params_struct=params
    )

def _parse_join_hint(self, func_name: str) -> exp.JoinHint:
    """Build a JoinHint from a comma-separated list of tables.

    Args:
        func_name: name of the hint; it is upper-cased before being stored.

    Returns:
        An `exp.JoinHint` whose "this" is the upper-cased name and whose
        "expressions" are the results of `self._parse_csv(self._parse_table)`.
    """
    hinted_tables = self._parse_csv(self._parse_table)
    hint_name = func_name.upper()
    return exp.JoinHint(this=hint_name, expressions=hinted_tables)
Expand Down
18 changes: 18 additions & 0 deletions tests/dialects/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,24 @@ def test_bigquery(self):
self.validate_identity("ROLLBACK TRANSACTION")
self.validate_identity("CAST(x AS BIGNUMERIC)")
self.validate_identity("SELECT y + 1 FROM x GROUP BY y + 1 ORDER BY 1")
self.validate_identity(
"SELECT * FROM ML.PREDICT(MODEL mydataset.mymodel, (SELECT label, column1, column2 FROM mydataset.mytable))"
)
self.validate_identity(
"SELECT label, predicted_label1, predicted_label AS predicted_label2 FROM ML.PREDICT(MODEL mydataset.mymodel2, (SELECT * EXCEPT (predicted_label), predicted_label AS predicted_label1 FROM ML.PREDICT(MODEL mydataset.mymodel1, TABLE mydataset.mytable)))"
)
self.validate_identity(
"SELECT * FROM ML.PREDICT(MODEL mydataset.mymodel, (SELECT custom_label, column1, column2 FROM mydataset.mytable), STRUCT(0.55 AS threshold))"
)
self.validate_identity(
"SELECT * FROM ML.PREDICT(MODEL `my_project`.my_dataset.my_model, (SELECT * FROM input_data))"
)
self.validate_identity(
"SELECT * FROM ML.PREDICT(MODEL my_dataset.vision_model, (SELECT uri, ML.RESIZE_IMAGE(ML.DECODE_IMAGE(data), 480, 480, FALSE) AS input FROM my_dataset.object_table))"
)
self.validate_identity(
"SELECT * FROM ML.PREDICT(MODEL my_dataset.vision_model, (SELECT uri, ML.CONVERT_COLOR_SPACE(ML.RESIZE_IMAGE(ML.DECODE_IMAGE(data), 224, 280, TRUE), 'YIQ') AS input FROM my_dataset.object_table WHERE content_type = 'image/jpeg'))"
)
self.validate_identity(
"DATE(CAST('2016-12-25 05:30:00+07' AS DATETIME), 'America/Los_Angeles')"
)
Expand Down

0 comments on commit 160d841

Please sign in to comment.