Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass existing server_side_parameters to session connection wrapper and use to configure SparkSession. #691

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dbt/adapters/spark/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ def open(cls, connection):
SessionConnectionWrapper,
)

handle = SessionConnectionWrapper(Connection())
handle = SessionConnectionWrapper(Connection(), creds.server_side_parameters)
Copy link
Contributor

@Fokko Fokko May 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would rather pass this in the Connection constructor.

Suggested change
handle = SessionConnectionWrapper(Connection(), creds.server_side_parameters)
handle = SessionConnectionWrapper(Connection(creds.server_side_parameters))

And then the connection can pass it down to the constructor of the Cursor. The config doesn't change between .execute() calls, so I think it makes more sense to pass it in the constructors.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Fokko I haven't tested this, but is the new code what you had in mind?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@alarocca-apixio That's exactly what I had in mind, thanks!

else:
raise dbt.exceptions.DbtProfileError(
f"invalid credential method: {creds.method}"
Expand Down
18 changes: 12 additions & 6 deletions dbt/adapters/spark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def close(self) -> None:
self._df = None
self._rows = None

def execute(self, sql: str, *parameters: Any) -> None:
def execute(self, sql: str, server_side_parameters, *parameters: Any) -> None:
"""
Execute a sql statement.

Expand All @@ -106,7 +106,12 @@ def execute(self, sql: str, *parameters: Any) -> None:
"""
if len(parameters) > 0:
sql = sql % parameters
spark_session = SparkSession.builder.enableHiveSupport().getOrCreate()
builder = SparkSession.builder.enableHiveSupport()

for k, v in server_side_parameters.items():
builder = builder.config(k, v)

spark_session = builder.getOrCreate()
self._df = spark_session.sql(sql)

def fetchall(self) -> Optional[List[Row]]:
Expand Down Expand Up @@ -172,10 +177,11 @@ def cursor(self) -> Cursor:


class SessionConnectionWrapper(object):
"""Connection wrapper for the sessoin connection method."""
"""Connection wrapper for the session connection method."""

def __init__(self, handle):
def __init__(self, handle, server_side_parameters):
self.handle = handle
self.server_side_parameters = server_side_parameters
self._cursor = None

def cursor(self):
Expand All @@ -200,10 +206,10 @@ def execute(self, sql, bindings=None):
sql = sql.strip()[:-1]

if bindings is None:
self._cursor.execute(sql)
self._cursor.execute(sql, self.server_side_parameters)
else:
bindings = [self._fix_binding(binding) for binding in bindings]
self._cursor.execute(sql, *bindings)
self._cursor.execute(sql, self.server_side_parameters, *bindings)

@property
def description(self):
Expand Down