Add server side parameters to session connection method #823

Merged
merged 10 commits into from
Jul 24, 2023
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20230707-104150.yaml
@@ -0,0 +1,6 @@
kind: Features
body: Support server_side_parameters for Spark session connection method
time: 2023-07-07T10:41:50.01541+02:00
custom:
Author: alarocca-apixio
Issue: "690"
5 changes: 4 additions & 1 deletion dbt/adapters/spark/connections.py
@@ -350,6 +350,7 @@ def open(cls, connection: Connection) -> Connection:

creds = connection.credentials
exc = None
+ handle: Any

for i in range(1 + creds.connect_retries):
try:
@@ -460,7 +461,9 @@ def open(cls, connection: Connection) -> Connection:
SessionConnectionWrapper,
)

- handle = SessionConnectionWrapper(Connection()) # type: ignore
+ handle = SessionConnectionWrapper(
+ Connection(server_side_parameters=creds.server_side_parameters)
+ )
else:
raise dbt.exceptions.DbtProfileError(
f"invalid credential method: {creds.method}"
17 changes: 13 additions & 4 deletions dbt/adapters/spark/session.py
@@ -4,7 +4,7 @@

import datetime as dt
from types import TracebackType
- from typing import Any, List, Optional, Tuple, Union
+ from typing import Any, Dict, List, Optional, Tuple, Union

from dbt.events import AdapterLogger
from dbt.utils import DECIMALS
@@ -24,9 +24,10 @@ class Cursor:
https://github.com/mkleehammer/pyodbc/wiki/Cursor
"""

- def __init__(self) -> None:
+ def __init__(self, *, server_side_parameters: Optional[Dict[str, Any]] = None) -> None:
Contributor

I'm unfamiliar with this mechanic. What does the *, do in the signature? Is it similar to unpacking with my_arg, *_? Do we need it?

Collaborator Author

@JCZuurmond JCZuurmond Jul 22, 2023

Yes, it is similar to the unpacking you mentioned. I guess that in a signature this mechanic is called "packing" as it is the opposite of unpacking. (I could not quickly find a PEP about this.)

What it does: it packs the remaining positional arguments. With a bare *, there is nothing to pack them into, so no further positional arguments are accepted. If you put a parameter name after the *, the remaining positional arguments are packed into that parameter (as a tuple):

def sum(*numbers):
    total = 0
    for number in numbers:
        total += number
    return total
    
sum(1, 2, 3)
sum(5, 7, 9, 11)

The trick here is that a * without a name forces all parameters after it to be keyword-only. I like to use that to improve readability:

Connection(foo)

vs

Connection(server_side_parameters=foo)

The latter is more readable.

foo is of course a badly chosen variable name. But I expect the first positional argument of a Connection to be a connection string, as in conn = pyodbc.connect(connection_str, autocommit=True), or at least something other than server_side_parameters, which is an (optional) additional parameter.

The * is not required; I added it to improve code readability.
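
As a minimal sketch of the behaviour described in this thread: the Connection class below is a stripped-down, hypothetical stand-in (not the adapter's actual class), and total is an illustrative helper. It shows that a bare * rejects positional arguments, while a named *numbers collects them into a tuple.

```python
from typing import Any, Dict, Optional


class Connection:
    """Hypothetical stand-in used only to illustrate the bare ``*``."""

    def __init__(self, *, server_side_parameters: Optional[Dict[str, Any]] = None) -> None:
        # Every parameter after the bare * must be passed by keyword.
        self.server_side_parameters = server_side_parameters or {}


def total(*numbers: int) -> int:
    # A named *numbers collects the positional arguments into a tuple.
    return sum(numbers)


params = {"spark.sql.shuffle.partitions": "8"}

# Keyword call: accepted.
conn = Connection(server_side_parameters=params)

# Positional call: rejected with a TypeError.
try:
    Connection(params)  # type: ignore
except TypeError as exc:
    positional_error: Optional[str] = str(exc)
else:
    positional_error = None
```

The keyword-only form is the behaviour specified in PEP 3102 ("Keyword-Only Arguments"); the call site is forced to name the argument, which is the readability gain mentioned above.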

self._df: Optional[DataFrame] = None
self._rows: Optional[List[Row]] = None
+ self.server_side_parameters = server_side_parameters or {}

def __enter__(self) -> Cursor:
return self
@@ -106,7 +107,12 @@ def execute(self, sql: str, *parameters: Any) -> None:
"""
if len(parameters) > 0:
sql = sql % parameters
- spark_session = SparkSession.builder.enableHiveSupport().getOrCreate()
+ builder = SparkSession.builder.enableHiveSupport()
+
+ for parameter, value in self.server_side_parameters.items():
+ builder = builder.config(parameter, value)
+
+ spark_session = builder.getOrCreate()
self._df = spark_session.sql(sql)

def fetchall(self) -> Optional[List[Row]]:
@@ -159,6 +165,9 @@ class Connection:
https://github.com/mkleehammer/pyodbc/wiki/Connection
"""

+ def __init__(self, *, server_side_parameters: Optional[Dict[Any, str]] = None) -> None:
Contributor

Same question as above.

+ self.server_side_parameters = server_side_parameters or {}

def cursor(self) -> Cursor:
"""
Get a cursor.
@@ -168,7 +177,7 @@ def cursor(self) -> Cursor:
out : Cursor
The cursor.
"""
- return Cursor()
+ return Cursor(server_side_parameters=self.server_side_parameters)


class SessionConnectionWrapper(object):
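
For readers following the session.py change: Cursor.execute now folds each entry of server_side_parameters into the session builder before calling getOrCreate(). A rough sketch of that loop is below. FakeBuilder and apply_parameters are hypothetical names introduced here; FakeBuilder mimics only the chaining behaviour of SparkSession.builder.config(key, value), so the example runs without PySpark installed.

```python
from typing import Any, Dict, Optional


class FakeBuilder:
    """Hypothetical stand-in for SparkSession.builder (not the real PySpark class)."""

    def __init__(self) -> None:
        self.options: Dict[str, Any] = {}

    def config(self, key: str, value: Any) -> "FakeBuilder":
        # Record the option and return the builder so that calls can be chained.
        self.options[key] = value
        return self


def apply_parameters(
    builder: FakeBuilder, server_side_parameters: Optional[Dict[str, Any]] = None
) -> FakeBuilder:
    # `or {}` turns a missing value into an empty dict, so the loop is a no-op.
    for parameter, value in (server_side_parameters or {}).items():
        builder = builder.config(parameter, value)
    return builder


configured = apply_parameters(
    FakeBuilder(),
    {"spark.sql.shuffle.partitions": "8", "spark.sql.ansi.enabled": "true"},
)
untouched = apply_parameters(FakeBuilder())
```

With real PySpark, getOrCreate() may return a session that already exists, in which case some builder options might not take effect; that is an assumption worth checking against the PySpark documentation rather than something this diff establishes.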