diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 1dd73f5e26..5edf785c0d 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -4778,6 +4778,11 @@ class Posexplode(Func): pass +# https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-predict#mlpredict_function +class Predict(Func): + arg_types = {"this": True, "expression": True, "params_struct": False} + + class Pow(Binary, Func): _sql_names = ["POWER", "POW"] diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 3cf4a7bd04..216e9c088e 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -2887,6 +2887,14 @@ def columnprefix_sql(self, expression: exp.ColumnPrefix) -> str: def opclass_sql(self, expression: exp.Opclass) -> str: return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}" + def predict_sql(self, expression: exp.Predict) -> str: + model = self.sql(expression, "this") + model = f"MODEL {model}" + table = self.sql(expression, "expression") + table = f"TABLE {table}" if not isinstance(expression.expression, exp.Subquery) else table + parameters = self.sql(expression, "params_struct") + return self.func("PREDICT", model, table, parameters or None) + def cached_generator( cache: t.Optional[t.Dict[int, str]] = None diff --git a/sqlglot/parser.py b/sqlglot/parser.py index ddaa9d6ebe..76b6ee51e1 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -789,6 +789,7 @@ class Parser(metaclass=_Parser): "MATCH": lambda self: self._parse_match_against(), "OPENJSON": lambda self: self._parse_open_json(), "POSITION": lambda self: self._parse_position(), + "PREDICT": lambda self: self._parse_predict(), "SAFE_CAST": lambda self: self._parse_cast(False), "STRING_AGG": lambda self: self._parse_string_agg(), "SUBSTRING": lambda self: self._parse_substring(), @@ -4407,6 +4408,20 @@ def _parse_position(self, haystack_first: bool = False) -> exp.StrPosition: exp.StrPosition, this=haystack, substr=needle, position=seq_get(args, 2) ) + def _parse_predict(self) -> exp.Predict: + self._match_text_seq("MODEL") + this = self._parse_table() + + self._match(TokenType.COMMA) + self._match_text_seq("TABLE") + + return self.expression( + exp.Predict, + this=this, + expression=self._parse_table(), + params_struct=self._match(TokenType.COMMA) and self._parse_bitwise(), + ) + def _parse_join_hint(self, func_name: str) -> exp.JoinHint: args = self._parse_csv(self._parse_table) return exp.JoinHint(this=func_name.upper(), expressions=args) diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 8d172eaec8..5d17cc90a5 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -85,6 +85,24 @@ def test_bigquery(self): self.validate_identity("ROLLBACK TRANSACTION") self.validate_identity("CAST(x AS BIGNUMERIC)") self.validate_identity("SELECT y + 1 FROM x GROUP BY y + 1 ORDER BY 1") + self.validate_identity( + "SELECT * FROM ML.PREDICT(MODEL mydataset.mymodel, (SELECT label, column1, column2 FROM mydataset.mytable))" + ) + self.validate_identity( + "SELECT label, predicted_label1, predicted_label AS predicted_label2 FROM ML.PREDICT(MODEL mydataset.mymodel2, (SELECT * EXCEPT (predicted_label), predicted_label AS predicted_label1 FROM ML.PREDICT(MODEL mydataset.mymodel1, TABLE mydataset.mytable)))" + ) + self.validate_identity( + "SELECT * FROM ML.PREDICT(MODEL mydataset.mymodel, (SELECT custom_label, column1, column2 FROM mydataset.mytable), STRUCT(0.55 AS threshold))" + ) + self.validate_identity( + "SELECT * FROM ML.PREDICT(MODEL `my_project`.my_dataset.my_model, (SELECT * FROM input_data))" + ) + self.validate_identity( + "SELECT * FROM ML.PREDICT(MODEL my_dataset.vision_model, (SELECT uri, ML.RESIZE_IMAGE(ML.DECODE_IMAGE(data), 480, 480, FALSE) AS input FROM my_dataset.object_table))" + ) + self.validate_identity( + "SELECT * FROM ML.PREDICT(MODEL my_dataset.vision_model, (SELECT uri, ML.CONVERT_COLOR_SPACE(ML.RESIZE_IMAGE(ML.DECODE_IMAGE(data), 224, 280, TRUE), 'YIQ') AS input FROM my_dataset.object_table WHERE content_type = 'image/jpeg'))" + ) self.validate_identity( "DATE(CAST('2016-12-25 05:30:00+07' AS DATETIME), 'America/Los_Angeles')" )