Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(sqllab): add latest partition support for BigQuery #30760

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 77 additions & 47 deletions superset/db_engine_specs/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,21 @@

from __future__ import annotations

import logging
import re
import urllib
from datetime import datetime
from re import Pattern
from textwrap import dedent
from typing import Any, TYPE_CHECKING, TypedDict

import pandas as pd
from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin
from deprecation import deprecated
from flask_babel import gettext as __
from marshmallow import fields, Schema
from marshmallow.exceptions import ValidationError
from sqlalchemy import column, types
from sqlalchemy import column, func, types
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
Expand All @@ -49,6 +50,11 @@
from superset.utils import core as utils, json
from superset.utils.hashing import md5_sha_from_str

if TYPE_CHECKING:
from sqlalchemy.sql.expression import Select

logger = logging.getLogger(__name__)

try:
from google.cloud import bigquery
from google.oauth2 import service_account
Expand Down Expand Up @@ -284,66 +290,90 @@ def _truncate_label(cls, label: str) -> str:
return "_" + md5_sha_from_str(label)

@classmethod
def where_latest_partition(
    cls,
    database: Database,
    table: Table,
    query: Select,
    columns: list[ResultSetColumnType] | None = None,
) -> Select | None:
    """
    Narrow ``query`` to the latest partition of a time-partitioned table.

    :param database: Database the table belongs to
    :param table: Table to inspect for time partitioning
    :param query: SELECT statement to augment
    :param columns: Unused here; kept for interface compatibility with the
        base engine spec
    :return: The query, with a partition predicate added when applicable
    """
    if partition_column := cls.get_time_partition_column(database, table):
        max_partition_id = cls.get_max_partition_id(database, table)
        # Fix: a freshly created partitioned table has no partitions yet,
        # so MAX(partition_id) is NULL/None — skip the predicate instead
        # of comparing against PARSE_DATE(..., NULL).
        if max_partition_id is not None:
            # NOTE(review): "%Y%m%d" assumes DAY partition granularity;
            # HOUR/MONTH/YEAR partitioned tables use different partition_id
            # formats — confirm before relying on this for those tables.
            query = query.where(
                column(partition_column)
                == func.PARSE_DATE("%Y%m%d", max_partition_id)
            )
    return query

@classmethod
def get_max_partition_id(
    cls,
    database: Database,
    table: Table,
) -> str | None:
    """
    Return the most recent partition_id of a partitioned BigQuery table.

    Queries ``INFORMATION_SCHEMA.PARTITIONS`` for the table. The result is
    a partition-id string (e.g. "20240101" for daily partitioning), or
    None when the table has no partitions yet (MAX over zero rows is NULL).

    Fix: the return annotation previously claimed ``Select | None``, but the
    method returns the scalar value read from the result dataframe.
    """
    # NOTE(review): schema/table names are interpolated into the SQL text;
    # strip/escape the quoting characters so a crafted name cannot break
    # out of the backtick identifier or the string literal.
    schema = table.schema.replace("`", "")
    table_name = table.table.replace("\\", "\\\\").replace("'", "\\'")
    sql = dedent(f"""\
        SELECT
          MAX(partition_id) AS max_partition_id
        FROM `{schema}.INFORMATION_SCHEMA.PARTITIONS`
        WHERE table_name = '{table_name}'
    """)
    df = database.get_df(sql)
    # Single-row aggregate result: the value at (0, 0) is the max id.
    return df.iat[0, 0]

:param database: The database to inspect
:param inspector: The SQLAlchemy inspector
:param table: The table instance to inspect
:returns: The indexes
"""
@classmethod
def get_time_partition_column(
    cls,
    database: Database,
    table: Table,
) -> str | None:
    """
    Return the column a BigQuery table is time-partitioned on.

    :param database: Database the table belongs to
    :param table: Table to inspect
    :return: The partitioning column name, or None when the table is not
        time-partitioned
    """
    with cls.get_engine(
        database, catalog=table.catalog, schema=table.schema
    ) as engine:
        client = cls._get_client(engine, database)
        bq_table = client.get_table(f"{table.schema}.{table.table}")
        # time_partitioning is None for non-partitioned tables; its
        # `field` attribute names the partitioning column otherwise.
        partitioning = bq_table.time_partitioning
        return partitioning.field if partitioning else None

@classmethod
def get_extra_table_metadata(
    cls,
    database: Database,
    table: Table,
) -> dict[str, Any]:
    """
    Return BigQuery-specific extra metadata for a table.

    For time-partitioned tables this reports the partition column, the
    latest partition id, a "select latest partition" query, and a synthetic
    index entry describing the partitioning. Non-partitioned tables yield
    an empty dict.

    Fix: previously an engine was acquired via ``cls.get_engine`` even when
    the table had no time-partition column and the method returned ``{}``;
    the early return below avoids that wasted connection.
    """
    partition_column = cls.get_time_partition_column(database, table)
    if not partition_column:
        return {}

    max_partition_id = cls.get_max_partition_id(database, table)
    with cls.get_engine(
        database, catalog=table.catalog, schema=table.schema
    ) as engine:
        # Build the "latest partition" preview query for SQL Lab.
        sql = cls.select_star(
            database,
            table,
            engine,
            indent=False,
            show_cols=False,
            latest_partition=True,
        )
    return {
        "partitions": {
            "cols": [partition_column],
            "latest": {partition_column: max_partition_id},
            "partitionQuery": sql,
        },
        "indexes": [
            {
                "name": "partitioned",
                "cols": [partition_column],
                "type": "partitioned",
            }
        ],
    }

@classmethod
def epoch_to_dttm(cls) -> str:
Expand Down
133 changes: 40 additions & 93 deletions tests/integration_tests/db_engine_specs/bigquery_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import unittest.mock as mock
from contextlib import contextmanager

import pytest
from pandas import DataFrame
Expand All @@ -32,6 +33,15 @@
)


@contextmanager
def mock_engine_with_credentials(*args, **kwargs):
    """Yield a mock engine whose dialect carries BigQuery credentials info."""
    fake_engine = mock.Mock()
    # Callers of get_engine() read credentials_info off the dialect.
    fake_engine.dialect.credentials_info = {"key": "value"}
    yield fake_engine


class TestBigQueryDbEngineSpec(TestDbEngineSpec):
def test_bigquery_sqla_column_label(self):
"""
Expand Down Expand Up @@ -111,108 +121,45 @@ def values(self):
result = BigQueryEngineSpec.fetch_data(None, 0)
assert result == [1, 2]

@mock.patch.object(
    BigQueryEngineSpec, "get_engine", side_effect=mock_engine_with_credentials
)
@mock.patch.object(BigQueryEngineSpec, "get_time_partition_column")
@mock.patch.object(BigQueryEngineSpec, "get_max_partition_id")
@mock.patch.object(BigQueryEngineSpec, "quote_table", return_value="`table_name`")
def test_get_extra_table_metadata(
    self,
    mock_quote_table,
    mock_get_max_partition_id,
    mock_get_time_partition_column,
    mock_get_engine,
):
    """
    DB Eng Specs (bigquery): Test extra table metadata
    """
    database = mock.Mock()
    sql = "SELECT * FROM `table_name`"
    database.compile_sqla_query.return_value = sql
    tbl = Table("some_table", "some_schema")

    # Non-partitioned table: no extra metadata is reported.
    mock_get_time_partition_column.return_value = None
    mock_get_max_partition_id.return_value = None
    result = BigQueryEngineSpec.get_extra_table_metadata(database, tbl)
    assert result == {}

    # Partitioned table: partition info plus a synthetic index entry.
    # Fix: removed a leftover debug print(result) from the committed test.
    mock_get_time_partition_column.return_value = "ds"
    mock_get_max_partition_id.return_value = "19690101"
    result = BigQueryEngineSpec.get_extra_table_metadata(database, tbl)
    assert result == {
        "indexes": [{"cols": ["ds"], "name": "partitioned", "type": "partitioned"}],
        "partitions": {
            "cols": ["ds"],
            "latest": {"ds": "19690101"},
            "partitionQuery": sql,
        },
    }

def test_get_indexes(self):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

database = mock.Mock()
inspector = mock.Mock()
schema = "foo"
table_name = "bar"

inspector.get_indexes = mock.Mock(
return_value=[
{
"name": "partition",
"column_names": [None],
"unique": False,
}
]
)

assert (
BigQueryEngineSpec.get_indexes(
database,
inspector,
Table(table_name, schema),
)
== []
)

inspector.get_indexes = mock.Mock(
return_value=[
{
"name": "partition",
"column_names": ["dttm"],
"unique": False,
}
]
)

assert BigQueryEngineSpec.get_indexes(
database,
inspector,
Table(table_name, schema),
) == [
{
"name": "partition",
"column_names": ["dttm"],
"unique": False,
}
]

inspector.get_indexes = mock.Mock(
return_value=[
{
"name": "partition",
"column_names": ["dttm", None],
"unique": False,
}
]
)

assert BigQueryEngineSpec.get_indexes(
database,
inspector,
Table(table_name, schema),
) == [
{
"name": "partition",
"column_names": ["dttm"],
"unique": False,
}
]

@mock.patch("superset.db_engine_specs.bigquery.BigQueryEngineSpec.get_engine")
@mock.patch("superset.db_engine_specs.bigquery.pandas_gbq")
Expand Down
Loading