Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ibis): Support pyspark sql data source #896

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions ibis-server/app/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ class QueryPostgresDTO(QueryDTO):
connection_info: ConnectionUrl | PostgresConnectionInfo = connection_info_field


# Query request body for the PySpark data source: `connection_info` accepts
# either a ConnectionUrl or a PySparkConnectionInfo.
class QueryPySparkDTO(QueryDTO):
connection_info: ConnectionUrl | PySparkConnectionInfo = connection_info_field


class QuerySnowflakeDTO(QueryDTO):
connection_info: SnowflakeConnectionInfo = connection_info_field

Expand Down Expand Up @@ -109,6 +113,17 @@ class PostgresConnectionInfo(BaseModel):
password: SecretStr


class PySparkConnectionInfo(BaseModel):
    # Connection settings used to build the Spark session for the PySpark
    # backend. Documented with `#` comments rather than a class docstring so
    # the model's generated schema description stays unchanged.

    # Spark application name (read via `.get_secret_value()` by the connector).
    app_name: SecretStr = Field(examples=["wrenai"])
    # Spark master URL. `validate_default=True` makes pydantic coerce the plain
    # string default into a SecretStr; without it the default stays a `str` and
    # `.get_secret_value()` raises AttributeError whenever `master` is omitted.
    master: SecretStr = Field(
        default="local[*]",
        validate_default=True,
        description="Spark master URL (e.g., 'local[*]', 'spark://master:7077')",
    )
    # Optional extra Spark settings, applied one by one as key/value pairs.
    configs: dict[str, str] | None = Field(
        default=None, description="Additional Spark configurations"
    )


class SnowflakeConnectionInfo(BaseModel):
user: SecretStr
password: SecretStr
Expand Down Expand Up @@ -137,6 +152,7 @@ class TrinoConnectionInfo(BaseModel):
| MSSqlConnectionInfo
| MySqlConnectionInfo
| PostgresConnectionInfo
| PySparkConnectionInfo
| SnowflakeConnectionInfo
| TrinoConnectionInfo
)
Expand Down
19 changes: 19 additions & 0 deletions ibis-server/app/model/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import ibis
from google.oauth2 import service_account
from ibis import BaseBackend
from pyspark.sql import SparkSession

from app.model import (
BigQueryConnectionInfo,
Expand All @@ -16,13 +17,15 @@
MSSqlConnectionInfo,
MySqlConnectionInfo,
PostgresConnectionInfo,
PySparkConnectionInfo,
QueryBigQueryDTO,
QueryCannerDTO,
QueryClickHouseDTO,
QueryDTO,
QueryMSSqlDTO,
QueryMySqlDTO,
QueryPostgresDTO,
QueryPySparkDTO,
QuerySnowflakeDTO,
QueryTrinoDTO,
SnowflakeConnectionInfo,
Expand All @@ -37,6 +40,7 @@ class DataSource(StrEnum):
mssql = auto()
mysql = auto()
postgres = auto()
pyspark = auto()
snowflake = auto()
trino = auto()

Expand All @@ -60,6 +64,7 @@ class DataSourceExtension(Enum):
mssql = QueryMSSqlDTO
mysql = QueryMySqlDTO
postgres = QueryPostgresDTO
pyspark = QueryPySparkDTO
snowflake = QuerySnowflakeDTO
trino = QueryTrinoDTO

Expand Down Expand Up @@ -143,6 +148,20 @@ def get_postgres_connection(info: PostgresConnectionInfo) -> BaseBackend:
password=info.password.get_secret_value(),
)

@staticmethod
def get_pyspark_connection(info: PySparkConnectionInfo) -> BaseBackend:
    """Build a Spark session from ``info`` and wrap it in an ibis PySpark backend."""
    app_name = info.app_name.get_secret_value()
    master_url = info.master.get_secret_value()
    session_builder = SparkSession.builder.appName(app_name).master(master_url)

    # Apply the optional extra Spark settings; `configs` may be None or empty.
    for option, setting in (info.configs or {}).items():
        session_builder = session_builder.config(option, setting)

    # getOrCreate() creates a session or returns the existing one.
    return ibis.pyspark.connect(session=session_builder.getOrCreate())

@staticmethod
def get_snowflake_connection(info: SnowflakeConnectionInfo) -> BaseBackend:
return ibis.snowflake.connect(
Expand Down
37 changes: 35 additions & 2 deletions ibis-server/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions ibis-server/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ ibis-framework = { version = "9.5.0", extras = [
"mssql",
"mysql",
"postgres",
"pyspark",
"snowflake",
"trino",
] }
Expand All @@ -42,6 +43,7 @@ sqlalchemy = "2.0.36"
pre-commit = "4.0.1"
ruff = "0.8.0"
trino = ">=0.321,<1"
pyspark = "3.5.1"
psycopg2 = ">=2.8.4,<3"
clickhouse-connect = "0.8.7"

Expand All @@ -54,6 +56,7 @@ markers = [
"mssql: mark a test as a mssql test",
"mysql: mark a test as a mysql test",
"postgres: mark a test as a postgres test",
"pyspark: mark a test as a pyspark test",
"snowflake: mark a test as a snowflake test",
"trino: mark a test as a trino test",
"beta: mark a test as a test for beta versions of the engine",
Expand Down
191 changes: 191 additions & 0 deletions ibis-server/tests/routers/v2/connector/test_pyspark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
import base64

# import os
import orjson
import pytest
from fastapi.testclient import TestClient

from app.main import app
from app.model.validator import rules

pytestmark = pytest.mark.pyspark

base_url = "/v2/connector/pyspark"

# PySpark connection payload used as the `connectionInfo` body field by the
# tests in this module.
connection_info = {
"app_name": "MyApp",
"master": "local",
}

# Manifest describing a single `Orders` model backed by
# `select * from tpch.orders`; the `manifest_str` fixture base64-encodes it
# for use as the `manifestStr` body field.
manifest = {
"catalog": "my_catalog",
"schema": "my_schema",
"models": [
{
"name": "Orders",
"properties": {},
"refSql": "select * from tpch.orders",
"columns": [
{"name": "orderkey", "expression": "O_ORDERKEY", "type": "integer"},
{"name": "custkey", "expression": "O_CUSTKEY", "type": "integer"},
{
"name": "orderstatus",
"expression": "O_ORDERSTATUS",
"type": "varchar",
},
{
"name": "totalprice",
"expression": "O_TOTALPRICE",
"type": "float",
},
{"name": "orderdate", "expression": "O_ORDERDATE", "type": "date"},
# Computed column: order key and customer key joined with "_".
{
"name": "order_cust_key",
"expression": "concat(O_ORDERKEY, '_', O_CUSTKEY)",
"type": "varchar",
},
# Constant-valued timestamp columns (plain, with time zone, and NULL).
{
"name": "timestamp",
"expression": "cast('2024-01-01T23:59:59' as timestamp)",
"type": "timestamp",
},
{
"name": "timestamptz",
"expression": "cast('2024-01-01T23:59:59' as timestamp with time zone)",
"type": "timestamp",
},
{
"name": "test_null_time",
"expression": "cast(NULL as timestamp)",
"type": "timestamp",
},
],
"primaryKey": "orderkey",
},
],
}


@pytest.fixture
def manifest_str():
    """Return the module-level manifest as a base64-encoded JSON string."""
    serialized = orjson.dumps(manifest)
    encoded = base64.b64encode(serialized)
    return encoded.decode("utf-8")


with TestClient(app) as client:
# def test_query(manifest_str):
# response = client.post(
# url=f"{base_url}/query",
# json={
# "connectionInfo": connection_info,
# "manifestStr": manifest_str,
# "sql": 'SELECT * FROM "Orders" ORDER BY "orderkey" LIMIT 1',
# },
# )
# assert response.status_code == 200
# result = response.json()
# assert len(result["columns"]) == len(manifest["models"][0]["columns"])
# assert len(result["data"]) == 1
# assert result["data"][0] == [
# 1,
# 36901,
# "O",
# "173665.47",
# "1996-01-02",
# "1_36901",
# "2024-01-01 23:59:59.000000",
# "2024-01-01 23:59:59.000000 UTC",
# None,
# ]
# assert result["dtypes"] == {
# "orderkey": "int64",
# "custkey": "int64",
# "orderstatus": "object",
# "totalprice": "object",
# "orderdate": "object",
# "order_cust_key": "object",
# "timestamp": "object",
# "timestamptz": "object",
# "test_null_time": "datetime64[ns]",
# }

def test_query_without_manifest():
    """A query request lacking `manifestStr` is rejected with a 422 'missing' error."""
    payload = {
        "connectionInfo": connection_info,
        "sql": 'SELECT * FROM "Orders" LIMIT 1',
    }
    response = client.post(url=f"{base_url}/query", json=payload)
    assert response.status_code == 422

    # Only the first validation error is inspected.
    first_error = response.json()["detail"][0]
    assert first_error is not None
    assert first_error["type"] == "missing"
    assert first_error["loc"] == ["body", "manifestStr"]
    assert first_error["msg"] == "Field required"

def test_query_without_sql(manifest_str):
    """A query request lacking `sql` is rejected with a 422 'missing' error."""
    payload = {"connectionInfo": connection_info, "manifestStr": manifest_str}
    response = client.post(url=f"{base_url}/query", json=payload)
    assert response.status_code == 422

    # Only the first validation error is inspected.
    first_error = response.json()["detail"][0]
    assert first_error is not None
    assert first_error["type"] == "missing"
    assert first_error["loc"] == ["body", "sql"]
    assert first_error["msg"] == "Field required"

def test_query_without_connection_info(manifest_str):
    """A query request lacking `connectionInfo` is rejected with a 422 'missing' error."""
    payload = {
        "manifestStr": manifest_str,
        "sql": 'SELECT * FROM "Orders" LIMIT 1',
    }
    response = client.post(url=f"{base_url}/query", json=payload)
    assert response.status_code == 422

    # Only the first validation error is inspected.
    first_error = response.json()["detail"][0]
    assert first_error is not None
    assert first_error["type"] == "missing"
    assert first_error["loc"] == ["body", "connectionInfo"]
    assert first_error["msg"] == "Field required"

# def test_query_with_dry_run(manifest_str):
# response = client.post(
# url=f"{base_url}/query",
# params={"dryRun": True},
# json={
# "connectionInfo": connection_info,
# "manifestStr": manifest_str,
# "sql": 'SELECT * FROM "Orders" LIMIT 1',
# },
# )
# assert response.status_code == 204

def test_query_with_dry_run_and_invalid_sql(manifest_str):
    """A dry-run query with the SQL `SELECT * FROM X` is expected to return 422."""
    payload = {
        "connectionInfo": connection_info,
        "manifestStr": manifest_str,
        "sql": "SELECT * FROM X",
    }
    response = client.post(
        url=f"{base_url}/query",
        params={"dryRun": True},
        json=payload,
    )
    assert response.status_code == 422
    assert response.text is not None

def test_validate_with_unknown_rule(manifest_str):
    """Validating against `unknown_rule` is expected to return 404 and list the rules."""
    payload = {
        "connectionInfo": connection_info,
        "manifestStr": manifest_str,
        "parameters": {"modelName": "Orders", "columnName": "orderkey"},
    }
    response = client.post(url=f"{base_url}/validate/unknown_rule", json=payload)
    assert response.status_code == 404

    expected_message = f"The rule `unknown_rule` is not in the rules, rules: {rules}"
    assert response.text == expected_message
Loading