Add spark-connect connection method
vakarisbk committed Feb 17, 2024
1 parent 5d90ff9 commit 4e7f5d5
Showing 20 changed files with 220 additions and 53 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20231004-191452.yaml
@@ -0,0 +1,6 @@
kind: Features
body: Add support for Spark Connect
time: 2023-10-04T19:14:52.858895+03:00
custom:
  Author: vakarisbk
  Issue: "899"
29 changes: 28 additions & 1 deletion dagger/run_dbt_spark_tests.py
@@ -85,6 +85,29 @@ def get_spark_container(client: dagger.Client) -> (dagger.Service, str):
    return spark_ctr, "spark_db"


def get_spark_connect_container(client: dagger.Client) -> (dagger.Container, str):
    spark_ctr_base = (
        client.container()
        .from_("spark:3.5.0-scala2.12-java17-ubuntu")
        .with_exec(
            [
                "/opt/spark/bin/spark-submit",
                "--class",
                "org.apache.spark.sql.connect.service.SparkConnectServer",
                "--conf",
                "spark.sql.catalogImplementation=hive",
                "--packages",
                "org.apache.spark:spark-connect_2.12:3.5.0",
                "--conf",
                "spark.jars.ivy=/tmp",
            ]
        )
        .with_exposed_port(15002)
        .as_service()
    )
    return spark_ctr_base, "localhost"
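For orientation, a client inside the test container can reach this service through PySpark's Spark Connect entry point. A minimal sketch, assuming pyspark[connect] is installed and using the alias and port defined above:

from pyspark.sql import SparkSession

# Talk to the Spark Connect server exposed on port 15002 by the service above.
spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()
spark.sql("SELECT 1 AS ok").show()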


async def test_spark(test_args):
    async with dagger.Connection(dagger.Config(log_output=sys.stderr)) as client:
        test_profile = test_args.profile
@@ -133,7 +156,11 @@ async def test_spark(test_args):
            )

        elif test_profile == "spark_session":
            tst_container = tst_container.with_exec(["pip", "install", "pyspark"])
            tst_container = tst_container.with_exec(["apt-get", "install", "openjdk-17-jre", "-y"])

        elif test_profile == "spark_connect":
            spark_ctr, spark_host = get_spark_connect_container(client)
            tst_container = tst_container.with_service_binding(alias=spark_host, service=spark_ctr)
            tst_container = tst_container.with_exec(["apt-get", "install", "openjdk-17-jre", "-y"])

        tst_container = tst_container.with_(env_variables(TESTING_ENV_VARS))
62 changes: 61 additions & 1 deletion dbt/adapters/spark/connections.py
@@ -60,6 +60,7 @@ class SparkConnectionMethod(StrEnum):
    HTTP = "http"
    ODBC = "odbc"
    SESSION = "session"
    CONNECT = "connect"


@dataclass
@@ -154,6 +155,21 @@ def __post_init__(self) -> None:
                    f"ImportError({e.msg})"
                ) from e

        if self.method == SparkConnectionMethod.CONNECT:
            try:
                import pyspark  # noqa: F401 F811
                import grpc  # noqa: F401
                import pyarrow  # noqa: F401
                import pandas  # noqa: F401
            except ImportError as e:
                raise dbt.exceptions.DbtRuntimeError(
                    f"{self.method} connection method requires "
                    "additional dependencies. \n"
                    "Install the additional required dependencies with "
                    "`pip install dbt-spark[connect]`\n\n"
                    f"ImportError({e.msg})"
                ) from e

        if self.method != SparkConnectionMethod.SESSION:
            self.host = self.host.rstrip("/")
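A quick way to reproduce the guard above outside dbt is to probe for the same modules; an illustrative sketch, not part of the commit:

import importlib.util

# Same modules the CONNECT guard above imports; report any that are missing.
missing = [m for m in ("pyspark", "grpc", "pyarrow", "pandas") if importlib.util.find_spec(m) is None]
if missing:
    raise SystemExit(f"missing {missing}; install with: pip install 'dbt-spark[connect]'")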

@@ -524,8 +540,52 @@ def open(cls, connection: Connection) -> Connection:
                SessionConnectionWrapper,
            )

            # Pass session type (session or connect) into SessionConnectionWrapper
            handle = SessionConnectionWrapper(
                Connection(
                    conn_method=creds.method,
                    conn_url="localhost",
                    server_side_parameters=creds.server_side_parameters,
                )
            )
        elif creds.method == SparkConnectionMethod.CONNECT:
            # Create the url

            host = creds.host
            port = creds.port
            token = creds.token
            use_ssl = creds.use_ssl
            user = creds.user

            # URL Format: sc://localhost:15002/;user_id=str;token=str;use_ssl=bool
            if not host.startswith("sc://"):
                base_url = f"sc://{host}"
            base_url += f":{str(port)}"

            url_extensions = []
            if user:
                url_extensions.append(f"user_id={user}")
            if use_ssl:
                url_extensions.append(f"use_ssl={use_ssl}")
            if token:
                url_extensions.append(f"token={token}")

            conn_url = base_url + ";".join(url_extensions)

            logger.debug("connection url: {}".format(conn_url))

            from .session import (  # noqa: F401
                Connection,
                SessionConnectionWrapper,
            )

            # Pass session type (session or connect) into SessionConnectionWrapper
            handle = SessionConnectionWrapper(
                Connection(
                    conn_method=creds.method,
                    conn_url=conn_url,
                    server_side_parameters=creds.server_side_parameters,
                )
            )
        else:
            raise DbtConfigError(f"invalid credential method: {creds.method}")
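For orientation (an illustration, not part of the commit): with the minimal target used by the test suite — host "localhost", port 15002, no user, token, or SSL — the extension list stays empty and the branch above hands the bare URL to the session-module wrapper:

from dbt.adapters.spark.connections import SparkConnectionMethod
from dbt.adapters.spark.session import Connection, SessionConnectionWrapper

# The CONNECT branch builds "sc://localhost:15002" for this target and wraps it.
handle = SessionConnectionWrapper(
    Connection(
        conn_method=SparkConnectionMethod.CONNECT,
        conn_url="sc://localhost:15002",
        server_side_parameters={},
    )
)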
35 changes: 29 additions & 6 deletions dbt/adapters/spark/session.py
@@ -6,7 +6,7 @@
from types import TracebackType
from typing import Any, Dict, List, Optional, Tuple, Union, Sequence

from dbt.adapters.spark.connections import SparkConnectionMethod, SparkConnectionWrapper
from dbt.adapters.events.logging import AdapterLogger
from dbt_common.utils.encoding import DECIMALS
from dbt_common.exceptions import DbtRuntimeError
@@ -27,9 +27,17 @@ class Cursor:
    https://github.com/mkleehammer/pyodbc/wiki/Cursor
    """

    def __init__(
        self,
        *,
        conn_method: SparkConnectionMethod,
        conn_url: str,
        server_side_parameters: Optional[Dict[str, Any]] = None,
    ) -> None:
        self._df: Optional[DataFrame] = None
        self._rows: Optional[List[Row]] = None
        self.conn_method: SparkConnectionMethod = conn_method
        self.conn_url: str = conn_url
        self.server_side_parameters = server_side_parameters or {}

    def __enter__(self) -> Cursor:
@@ -113,12 +121,15 @@ def execute(self, sql: str, *parameters: Any) -> None:
        if len(parameters) > 0:
            sql = sql % parameters

        builder = SparkSession.builder

        for parameter, value in self.server_side_parameters.items():
            builder = builder.config(parameter, value)

        if self.conn_method == SparkConnectionMethod.CONNECT:
            spark_session = builder.remote(self.conn_url).getOrCreate()
        elif self.conn_method == SparkConnectionMethod.SESSION:
            spark_session = builder.enableHiveSupport().getOrCreate()

        try:
            self._df = spark_session.sql(sql)
@@ -175,7 +186,15 @@ class Connection:
    https://github.com/mkleehammer/pyodbc/wiki/Connection
    """

    def __init__(
        self,
        *,
        conn_method: SparkConnectionMethod,
        conn_url: str,
        server_side_parameters: Optional[Dict[Any, str]] = None,
    ) -> None:
        self.conn_method = conn_method
        self.conn_url = conn_url
        self.server_side_parameters = server_side_parameters or {}

    def cursor(self) -> Cursor:
@@ -187,7 +206,11 @@ def cursor(self) -> Cursor:
        out : Cursor
            The cursor.
        """
        return Cursor(
            conn_method=self.conn_method,
            conn_url=self.conn_url,
            server_side_parameters=self.server_side_parameters,
        )


class SessionConnectionWrapper(SparkConnectionWrapper):
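Taken together, the session-module changes let a Connection carry either method; a hypothetical usage sketch (values illustrative, assumes a Spark Connect server at localhost:15002 and the connect extra installed):

from dbt.adapters.spark.connections import SparkConnectionMethod
from dbt.adapters.spark.session import Connection

conn = Connection(
    conn_method=SparkConnectionMethod.CONNECT,
    conn_url="sc://localhost:15002",
    server_side_parameters={"spark.sql.session.timeZone": "UTC"},
)
cursor = conn.cursor()            # Cursor now knows the method and URL
cursor.execute("SELECT 1 AS id")  # routed through builder.remote(...) above
print(cursor.fetchall())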
3 changes: 3 additions & 0 deletions requirements.txt
@@ -6,5 +6,8 @@ sqlparams>=3.0.0
thrift>=0.13.0
sqlparse>=0.4.2 # not directly required, pinned by Snyk to avoid a vulnerability

#spark-connect
pyspark[connect]>=3.5.0,<4

types-PyYAML
types-python-dateutil
12 changes: 11 additions & 1 deletion setup.py
@@ -49,7 +49,16 @@ def _get_plugin_version_dict():
    "thrift>=0.11.0,<0.17.0",
]
session_extras = ["pyspark>=3.0.0,<4.0.0"]
connect_extras = [
    "pyspark==3.5.0",
    "pandas>=1.05",
    "pyarrow>=4.0.0",
    "numpy>=1.15",
    "grpcio>=1.46,<1.57",
    "grpcio-status>=1.46,<1.57",
    "googleapis-common-protos==1.56.4",
]
all_extras = odbc_extras + pyhive_extras + session_extras + connect_extras

setup(
    name=package_name,
@@ -71,6 +80,7 @@ def _get_plugin_version_dict():
        "ODBC": odbc_extras,
        "PyHive": pyhive_extras,
        "session": session_extras,
        "connect": connect_extras,
        "all": all_extras,
    },
    zip_safe=False,
12 changes: 7 additions & 5 deletions tests/conftest.py
@@ -30,6 +30,8 @@ def dbt_profile_target(request):
        target = databricks_http_cluster_target()
    elif profile_type == "spark_session":
        target = spark_session_target()
    elif profile_type == "spark_connect":
        target = spark_connect_target()
    else:
        raise ValueError(f"Invalid profile type '{profile_type}'")
    return target
@@ -95,11 +97,11 @@ def databricks_http_cluster_target():


def spark_session_target():
    return {"type": "spark", "host": "localhost", "method": "session"}


def spark_connect_target():
    return {"type": "spark", "host": "localhost", "port": 15002, "method": "connect"}


@pytest.fixture(autouse=True)
2 changes: 1 addition & 1 deletion tests/functional/adapter/dbt_clone/test_dbt_clone.py
@@ -14,7 +14,7 @@
)


@pytest.mark.skip_profile("apache_spark", "spark_session", "spark_connect")
class TestSparkBigqueryClonePossible(BaseClonePossible):
    @pytest.fixture(scope="class")
    def models(self):
@@ -5,7 +5,7 @@
)


@pytest.mark.skip_profile("spark_session", "apache_spark", "spark_connect")
class TestMergeExcludeColumns(BaseMergeExcludeColumns):
    @pytest.fixture(scope="class")
    def project_config_update(self):
@@ -32,7 +32,7 @@ def project_config_update(self):
        }


@pytest.mark.skip_profile("databricks_sql_endpoint", "spark_session", "spark_connect")
class TestInsertOverwriteOnSchemaChange(IncrementalOnSchemaChangeIgnoreFail):
    @pytest.fixture(scope="class")
    def project_config_update(self):
@@ -45,7 +45,7 @@ def project_config_update(self):
        }


@pytest.mark.skip_profile("apache_spark", "spark_session", "spark_connect")
class TestDeltaOnSchemaChange(BaseIncrementalOnSchemaChangeSetup):
    @pytest.fixture(scope="class")
    def project_config_update(self):
@@ -27,7 +27,7 @@
"""


@pytest.mark.skip_profile("spark_session", "apache_spark", "spark_connect")
class TestIncrementalPredicatesMergeSpark(BaseIncrementalPredicates):
    @pytest.fixture(scope="class")
    def project_config_update(self):
@@ -46,7 +46,7 @@ def models(self):
        }


@pytest.mark.skip_profile("spark_session", "apache_spark", "spark_connect")
class TestPredicatesMergeSpark(BaseIncrementalPredicates):
    @pytest.fixture(scope="class")
    def project_config_update(self):
@@ -2,7 +2,7 @@
from dbt.tests.adapter.incremental.test_incremental_unique_id import BaseIncrementalUniqueKey


@pytest.mark.skip_profile("spark_session", "apache_spark", "spark_connect")
class TestUniqueKeySpark(BaseIncrementalUniqueKey):
    @pytest.fixture(scope="class")
    def project_config_update(self):
@@ -103,7 +103,11 @@ def run_and_test(self, project):
        check_relations_equal(project.adapter, ["merge_update_columns", "expected_partial_upsert"])

    @pytest.mark.skip_profile(
        "apache_spark",
        "databricks_http_cluster",
        "databricks_sql_endpoint",
        "spark_session",
        "spark_connect",
    )
    def test_delta_strategies(self, project):
        self.run_and_test(project)
6 changes: 3 additions & 3 deletions tests/functional/adapter/persist_docs/test_persist_docs.py
@@ -15,7 +15,7 @@
)


@pytest.mark.skip_profile("apache_spark", "spark_session", "spark_connect")
class TestPersistDocsDeltaTable:
    @pytest.fixture(scope="class")
    def models(self):
@@ -78,7 +78,7 @@ def test_delta_comments(self, project):
        assert result[2].startswith("Some stuff here and then a call to")


@pytest.mark.skip_profile("apache_spark", "spark_session", "spark_connect")
class TestPersistDocsDeltaView:
    @pytest.fixture(scope="class")
    def models(self):
@@ -120,7 +120,7 @@ def test_delta_comments(self, project):
        assert result[2] is None


@pytest.mark.skip_profile("apache_spark", "spark_session", "spark_connect")
class TestPersistDocsMissingColumn:
    @pytest.fixture(scope="class")
    def project_config_update(self):
4 changes: 2 additions & 2 deletions tests/functional/adapter/test_basic.py
@@ -50,7 +50,7 @@ class TestGenericTestsSpark(BaseGenericTests):

# These tests were not enabled in the dbtspec files, so skipping here.
# Error encountered was: Error running query: java.lang.ClassNotFoundException: delta.DefaultSource
@pytest.mark.skip_profile("apache_spark", "spark_session", "spark_connect")
class TestSnapshotCheckColsSpark(BaseSnapshotCheckCols):
    @pytest.fixture(scope="class")
    def project_config_update(self):
@@ -66,7 +66,7 @@ def project_config_update(self):

# These tests were not enabled in the dbtspec files, so skipping here.
# Error encountered was: Error running query: java.lang.ClassNotFoundException: delta.DefaultSource
@pytest.mark.skip_profile("apache_spark", "spark_session", "spark_connect")
class TestSnapshotTimestampSpark(BaseSnapshotTimestamp):
    @pytest.fixture(scope="class")
    def project_config_update(self):