Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] Fix Control Message Utils & SQL Max Connections Exhaust #1243

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 25 additions & 21 deletions morpheus/loaders/sql_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import pandas as pd
from sqlalchemy import create_engine
from sqlalchemy import engine

import cudf

Expand All @@ -40,30 +41,27 @@ def _parse_query_data(
Parameters
----------
query_data : Dict[str, Union[str, Optional[Dict[str, Any]]]]
The dictionary containing the connection string, query, and params (optional).
The dictionary containing the query, and params (optional).

Returns
-------
Dict[str, Union[str, Optional[Dict[str, Any]]]]
A dictionary containing parsed connection string, query, and params (if present).
"""

return {
"connection_string": query_data["connection_string"],
"query": query_data["query"],
"params": query_data.get("params", None)
}
return {"query": query_data["query"], "params": query_data.get("params", None)}


def _read_sql(connection_string: str, query: str, params: typing.Optional[typing.Dict[str, typing.Any]] = None) -> \
typing.Dict[str, pd.DataFrame]:
def _read_sql(engine_obj: engine.Engine,
query: str,
params: typing.Optional[typing.Dict[str, typing.Any]] = None) -> typing.Dict[str, pd.DataFrame]:
"""
Creates a DataFrame from a SQL query.

Parameters
----------
connection_string : str
Connection string to the database.
engine_obj : engine.Engine
SQL engine instance.
query : str
SQL query.
params : Optional[Dict[str, Any]], default=None
Expand All @@ -75,14 +73,10 @@ def _read_sql(connection_string: str, query: str, params: typing.Optional[typing
A dictionary containing a DataFrame of the SQL query result.
"""

# TODO(Devin): PERFORMANCE OPTIMIZATION
# TODO(Devin): Add connection pooling -- Probably needs to go on the actual loader
engine = create_engine(connection_string)

if (params is None):
df = pd.read_sql(query, engine)
df = pd.read_sql(query, engine_obj)
else:
df = pd.read_sql(query, engine, params=params)
df = pd.read_sql(query, engine_obj, params=params)

return {"df": df}

Expand Down Expand Up @@ -132,14 +126,24 @@ def sql_loader(control_message: ControlMessage, task: typing.Dict[str, typing.An

with CMDefaultFailureContextManager(control_message):
final_df = None
engine_registry = {}

sql_config = task["sql_config"]
queries = sql_config["queries"]

for query_data in queries:
aggregate_df = functools.partial(_aggregate_df, df_aggregate=final_df)
execution_chain = ExecutionChain(function_chain=[_parse_query_data, _read_sql, aggregate_df])
final_df = execution_chain(query_data=query_data)
try:
for query_data in queries:
conn_str = query_data.pop("connection_string")
if conn_str not in engine_registry:
engine_registry[conn_str] = create_engine(conn_str)

aggregate_df = functools.partial(_aggregate_df, df_aggregate=final_df)
read_sql = functools.partial(_read_sql, engine_obj=engine_registry[conn_str])
execution_chain = ExecutionChain(function_chain=[_parse_query_data, read_sql, aggregate_df])
final_df = execution_chain(query_data=query_data)
finally:
# Dispose all open connections.
for engine_obj in engine_registry.values():
engine_obj.dispose()

control_message.payload(MessageMeta(final_df))

Expand Down
7 changes: 4 additions & 3 deletions morpheus/utils/control_message_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,11 @@ def cm_default_failure_context_manager(raise_on_failure: bool = False) -> typing
def decorator(func):

@wraps(func)
def wrapper(control_messsage: ControlMessage, *args, **kwargs):
with CMDefaultFailureContextManager(control_message=control_messsage,
def wrapper(control_message: ControlMessage, *args, **kwargs):
ret_cm = control_message
with CMDefaultFailureContextManager(control_message=control_message,
raise_on_failure=raise_on_failure) as ctx_mgr:
cm_ensure_payload_not_null(control_message=control_messsage)
cm_ensure_payload_not_null(control_message=control_message)
ret_cm = func(ctx_mgr.control_message, *args, **kwargs)

return ret_cm
Expand Down