diff --git a/morpheus/loaders/sql_loader.py b/morpheus/loaders/sql_loader.py index e3619f78c3..098d0df53c 100644 --- a/morpheus/loaders/sql_loader.py +++ b/morpheus/loaders/sql_loader.py @@ -18,6 +18,7 @@ import pandas as pd from sqlalchemy import create_engine +from sqlalchemy import engine import cudf @@ -40,7 +41,7 @@ def _parse_query_data( Parameters ---------- query_data : Dict[str, Union[str, Optional[Dict[str, Any]]]] - The dictionary containing the connection string, query, and params (optional). + The dictionary containing the query and params (optional). Returns ------- @@ -48,22 +49,19 @@ def _parse_query_data( A dictionary containing parsed connection string, query, and params (if present). """ - return { - "connection_string": query_data["connection_string"], - "query": query_data["query"], - "params": query_data.get("params", None) - } + return {"query": query_data["query"], "params": query_data.get("params", None)} -def _read_sql(connection_string: str, query: str, params: typing.Optional[typing.Dict[str, typing.Any]] = None) -> \ - typing.Dict[str, pd.DataFrame]: +def _read_sql(engine_obj: engine.Engine, + query: str, + params: typing.Optional[typing.Dict[str, typing.Any]] = None) -> typing.Dict[str, pd.DataFrame]: """ Creates a DataFrame from a SQL query. Parameters ---------- - connection_string : str - Connection string to the database. + engine_obj : engine.Engine + SQL engine instance. query : str SQL query. params : Optional[Dict[str, Any]], default=None @@ -75,14 +73,10 @@ def _read_sql(connection_string: str, query: str, params: typing.Optional[typing A dictionary containing a DataFrame of the SQL query result. 
""" - # TODO(Devin): PERFORMANCE OPTIMIZATION - # TODO(Devin): Add connection pooling -- Probably needs to go on the actual loader - engine = create_engine(connection_string) - if (params is None): - df = pd.read_sql(query, engine) + df = pd.read_sql(query, engine_obj) else: - df = pd.read_sql(query, engine, params=params) + df = pd.read_sql(query, engine_obj, params=params) return {"df": df} @@ -132,14 +126,24 @@ def sql_loader(control_message: ControlMessage, task: typing.Dict[str, typing.An with CMDefaultFailureContextManager(control_message): final_df = None + engine_registry = {} sql_config = task["sql_config"] queries = sql_config["queries"] - - for query_data in queries: - aggregate_df = functools.partial(_aggregate_df, df_aggregate=final_df) - execution_chain = ExecutionChain(function_chain=[_parse_query_data, _read_sql, aggregate_df]) - final_df = execution_chain(query_data=query_data) + try: + for query_data in queries: + conn_str = query_data.pop("connection_string") + if conn_str not in engine_registry: + engine_registry[conn_str] = create_engine(conn_str) + + aggregate_df = functools.partial(_aggregate_df, df_aggregate=final_df) + read_sql = functools.partial(_read_sql, engine_obj=engine_registry[conn_str]) + execution_chain = ExecutionChain(function_chain=[_parse_query_data, read_sql, aggregate_df]) + final_df = execution_chain(query_data=query_data) + finally: + # Dispose all open connections. 
+ for engine_obj in engine_registry.values(): + engine_obj.dispose() control_message.payload(MessageMeta(final_df)) diff --git a/morpheus/utils/control_message_utils.py b/morpheus/utils/control_message_utils.py index b1bd40fb36..7d6d7a9254 100644 --- a/morpheus/utils/control_message_utils.py +++ b/morpheus/utils/control_message_utils.py @@ -83,10 +83,11 @@ def cm_default_failure_context_manager(raise_on_failure: bool = False) -> typing def decorator(func): @wraps(func) - def wrapper(control_messsage: ControlMessage, *args, **kwargs): - with CMDefaultFailureContextManager(control_message=control_messsage, + def wrapper(control_message: ControlMessage, *args, **kwargs): + ret_cm = control_message + with CMDefaultFailureContextManager(control_message=control_message, raise_on_failure=raise_on_failure) as ctx_mgr: - cm_ensure_payload_not_null(control_message=control_messsage) + cm_ensure_payload_not_null(control_message=control_message) ret_cm = func(ctx_mgr.control_message, *args, **kwargs) return ret_cm