Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

feat: refactor all get_sqla_engine to use contextmanager in codebase #21943

Merged
merged 23 commits into from
Nov 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -960,13 +960,13 @@ def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
if self.fetch_values_predicate:
qry = qry.where(self.get_fetch_values_predicate())

engine = self.database.get_sqla_engine()
sql = qry.compile(engine, compile_kwargs={"literal_binds": True})
sql = self._apply_cte(sql, cte)
sql = self.mutate_query_from_config(sql)
with self.database.get_sqla_engine_with_context() as engine:
sql = qry.compile(engine, compile_kwargs={"literal_binds": True})
sql = self._apply_cte(sql, cte)
sql = self.mutate_query_from_config(sql)

df = pd.read_sql_query(sql=sql, con=engine)
return df[column_name].to_list()
df = pd.read_sql_query(sql=sql, con=engine)
return df[column_name].to_list()

def mutate_query_from_config(self, sql: str) -> str:
"""Apply config's SQL_QUERY_MUTATOR
Expand Down
39 changes: 23 additions & 16 deletions superset/connectors/sqla/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> List[ResultSetColumnType]:
)

db_engine_spec = dataset.database.db_engine_spec
engine = dataset.database.get_sqla_engine(schema=dataset.schema)
sql = dataset.get_template_processor().process_template(
dataset.sql, **dataset.template_params_dict
)
Expand All @@ -137,13 +136,18 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> List[ResultSetColumnType]:
# TODO(villebro): refactor to use same code that's used by
# sql_lab.py:execute_sql_statements
try:
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
query = dataset.database.apply_limit_to_sql(statements[0], limit=1)
db_engine_spec.execute(cursor, query)
result = db_engine_spec.fetch_data(cursor, limit=1)
result_set = SupersetResultSet(result, cursor.description, db_engine_spec)
cols = result_set.columns
with dataset.database.get_sqla_engine_with_context(
schema=dataset.schema
) as engine:
with closing(engine.raw_connection()) as conn:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need to do this now, but eventually we could have a database.get_raw_connection context manager, to simplify things a little bit. This way we could rewrite this as:

with dataset.database.get_raw_connection(schema=dataset.schema) as conn:
    cursor = conn.cursor()
    ...

And the implementation of get_raw_connection() would take care of closing the connection:

@contextmanager
def get_raw_connection(...):
    with get_sqla_engine_with_context(...) as engine:
        with closing(engine.raw_connection()) as conn:
             yield conn

cursor = conn.cursor()
query = dataset.database.apply_limit_to_sql(statements[0], limit=1)
db_engine_spec.execute(cursor, query)
result = db_engine_spec.fetch_data(cursor, limit=1)
result_set = SupersetResultSet(
result, cursor.description, db_engine_spec
)
cols = result_set.columns
except Exception as ex:
raise SupersetGenericDBErrorException(message=str(ex)) from ex
return cols
Expand All @@ -155,14 +159,17 @@ def get_columns_description(
) -> List[ResultSetColumnType]:
db_engine_spec = database.db_engine_spec
try:
with closing(database.get_sqla_engine().raw_connection()) as conn:
cursor = conn.cursor()
query = database.apply_limit_to_sql(query, limit=1)
cursor.execute(query)
db_engine_spec.execute(cursor, query)
result = db_engine_spec.fetch_data(cursor, limit=1)
result_set = SupersetResultSet(result, cursor.description, db_engine_spec)
return result_set.columns
with database.get_sqla_engine_with_context() as engine:
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
query = database.apply_limit_to_sql(query, limit=1)
cursor.execute(query)
db_engine_spec.execute(cursor, query)
result = db_engine_spec.fetch_data(cursor, limit=1)
result_set = SupersetResultSet(
result, cursor.description, db_engine_spec
)
return result_set.columns
except Exception as ex:
raise SupersetGenericDBErrorException(message=str(ex)) from ex

Expand Down
52 changes: 26 additions & 26 deletions superset/databases/commands/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ def run(self) -> None: # pylint: disable=too-many-statements
database.set_sqlalchemy_uri(uri)
database.db_engine_spec.mutate_db_for_connection_test(database)

engine = database.get_sqla_engine()
event_logger.log_with_context(
action="test_connection_attempt",
engine=database.db_engine_spec.__name__,
Expand All @@ -100,31 +99,32 @@ def ping(engine: Engine) -> bool:
with closing(engine.raw_connection()) as conn:
return engine.dialect.do_ping(conn)

try:
alive = func_timeout(
int(app.config["TEST_DATABASE_CONNECTION_TIMEOUT"].total_seconds()),
ping,
args=(engine,),
)
except (sqlite3.ProgrammingError, RuntimeError):
# SQLite can't run on a separate thread, so ``func_timeout`` fails
# RuntimeError catches the equivalent error from duckdb.
alive = engine.dialect.do_ping(engine)
except FunctionTimedOut as ex:
raise SupersetTimeoutException(
error_type=SupersetErrorType.CONNECTION_DATABASE_TIMEOUT,
message=(
"Please check your connection details and database settings, "
"and ensure that your database is accepting connections, "
"then try connecting again."
),
level=ErrorLevel.ERROR,
extra={"sqlalchemy_uri": database.sqlalchemy_uri},
) from ex
except Exception as ex: # pylint: disable=broad-except
alive = False
# So we stop losing the original message if any
ex_str = str(ex)
with database.get_sqla_engine_with_context() as engine:
try:
alive = func_timeout(
app.config["TEST_DATABASE_CONNECTION_TIMEOUT"].total_seconds(),
ping,
args=(engine,),
)
except (sqlite3.ProgrammingError, RuntimeError):
# SQLite can't run on a separate thread, so ``func_timeout`` fails
# RuntimeError catches the equivalent error from duckdb.
alive = engine.dialect.do_ping(engine)
except FunctionTimedOut as ex:
raise SupersetTimeoutException(
error_type=SupersetErrorType.CONNECTION_DATABASE_TIMEOUT,
message=(
"Please check your connection details and database settings, "
"and ensure that your database is accepting connections, "
"then try connecting again."
),
level=ErrorLevel.ERROR,
extra={"sqlalchemy_uri": database.sqlalchemy_uri},
) from ex
except Exception as ex: # pylint: disable=broad-except
alive = False
# So we stop losing the original message if any
ex_str = str(ex)

if not alive:
raise DBAPIError(ex_str or None, None, None)
Expand Down
31 changes: 16 additions & 15 deletions superset/databases/commands/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,21 +101,22 @@ def run(self) -> None:
database.set_sqlalchemy_uri(sqlalchemy_uri)
database.db_engine_spec.mutate_db_for_connection_test(database)

engine = database.get_sqla_engine()
try:
with closing(engine.raw_connection()) as conn:
alive = engine.dialect.do_ping(conn)
except Exception as ex:
url = make_url_safe(sqlalchemy_uri)
context = {
"hostname": url.host,
"password": url.password,
"port": url.port,
"username": url.username,
"database": url.database,
}
errors = database.db_engine_spec.extract_errors(ex, context)
raise DatabaseTestConnectionFailedError(errors) from ex
alive = False
with database.get_sqla_engine_with_context() as engine:
try:
with closing(engine.raw_connection()) as conn:
alive = engine.dialect.do_ping(conn)
except Exception as ex:
url = make_url_safe(sqlalchemy_uri)
context = {
"hostname": url.host,
"password": url.password,
"port": url.port,
"username": url.username,
"database": url.database,
}
errors = database.db_engine_spec.extract_errors(ex, context)
raise DatabaseTestConnectionFailedError(errors) from ex

if not alive:
raise DatabaseOfflineError(
Expand Down
33 changes: 21 additions & 12 deletions superset/datasets/commands/importers/v1/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,17 +166,26 @@ def load_data(
if database.sqlalchemy_uri == current_app.config.get("SQLALCHEMY_DATABASE_URI"):
logger.info("Loading data inside the import transaction")
connection = session.connection()
df.to_sql(
dataset.table_name,
con=connection,
schema=dataset.schema,
if_exists="replace",
chunksize=CHUNKSIZE,
dtype=dtype,
index=False,
method="multi",
)
else:
logger.warning("Loading data outside the import transaction")
connection = database.get_sqla_engine()

df.to_sql(
dataset.table_name,
con=connection,
schema=dataset.schema,
if_exists="replace",
chunksize=CHUNKSIZE,
dtype=dtype,
index=False,
method="multi",
)
with database.get_sqla_engine_with_context() as engine:
df.to_sql(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will only run in the else block, but it needs to run in the if database.sqlalchemy_uri == current_app.config.get("SQLALCHEMY_DATABASE_URI"): block as well.

dataset.table_name,
con=engine,
schema=dataset.schema,
if_exists="replace",
chunksize=CHUNKSIZE,
dtype=dtype,
index=False,
method="multi",
)
35 changes: 23 additions & 12 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from typing import (
Any,
Callable,
ContextManager,
Dict,
List,
Match,
Expand Down Expand Up @@ -471,8 +472,16 @@ def get_engine(
database: "Database",
schema: Optional[str] = None,
source: Optional[utils.QuerySource] = None,
) -> Engine:
return database.get_sqla_engine(schema=schema, source=source)
) -> ContextManager[Engine]:
"""
Return an engine context manager.

>>> with DBEngineSpec.get_engine(database, schema, source) as engine:
... connection = engine.connect()
... connection.execute(sql)

"""
return database.get_sqla_engine_with_context(schema=schema, source=source)

@classmethod
def get_timestamp_expr(
Expand Down Expand Up @@ -894,17 +903,17 @@ def df_to_sql(
:param to_sql_kwargs: The kwargs to be passed to pandas.DataFrame.to_sql` method
"""

engine = cls.get_engine(database)
to_sql_kwargs["name"] = table.table

if table.schema:
# Only add schema when it is preset and non empty.
to_sql_kwargs["schema"] = table.schema

if engine.dialect.supports_multivalues_insert:
to_sql_kwargs["method"] = "multi"
with cls.get_engine(database) as engine:
if engine.dialect.supports_multivalues_insert:
to_sql_kwargs["method"] = "multi"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need to improve this a little bit... here we're building an engine just to check an attribute in the dialect, which means we're setting up and tearing down an SSH connection just to read an attribute. :-(

Maybe we should add a get_dialect method to the DB engine spec, that builds the engine without the context manager:

@classmethod
def get_dialect(database, schema, source):
     engine = database.get_sqla_engine(schema=schema, source=source)
     return engine.dialect

Then when we only need the dialect we can call this method, which is cheaper.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in this case we still need the engine though, so it makes more sense to just use get_engine instead of just the dialect

https://github.com/apache/superset/pull/21943/files/7ce583678e1770d472527abb8270dd22e666b9c0#diff-2e62d64ef1113e48efdfeb2acbaa522fca13e49e6a00c2cfd4f74efc4ae1b45cR916

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, good point, I missed that.


df.to_sql(con=engine, **to_sql_kwargs)
df.to_sql(con=engine, **to_sql_kwargs)

@classmethod
def convert_dttm( # pylint: disable=unused-argument
Expand Down Expand Up @@ -1277,13 +1286,15 @@ def estimate_query_cost(
parsed_query = sql_parse.ParsedQuery(sql)
statements = parsed_query.get_statements()

engine = cls.get_engine(database, schema=schema, source=source)
costs = []
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
for statement in statements:
processed_statement = cls.process_statement(statement, database)
costs.append(cls.estimate_statement_cost(processed_statement, cursor))
with cls.get_engine(database, schema=schema, source=source) as engine:
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
for statement in statements:
processed_statement = cls.process_statement(statement, database)
costs.append(
cls.estimate_statement_cost(processed_statement, cursor)
)
return costs

@classmethod
Expand Down
8 changes: 6 additions & 2 deletions superset/db_engine_specs/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,8 +340,12 @@ def df_to_sql(
if not table.schema:
raise Exception("The table schema must be defined")

engine = cls.get_engine(database)
to_gbq_kwargs = {"destination_table": str(table), "project_id": engine.url.host}
to_gbq_kwargs = {}
with cls.get_engine(database) as engine:
to_gbq_kwargs = {
"destination_table": str(table),
"project_id": engine.url.host,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we should add an attribute to DB engine specs indicating whether they support SSH tunnels. BigQuery, e.g., will probably never support it.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

created a ticket for this

}

# Add credentials if they are set on the SQLAlchemy dialect.
creds = engine.dialect.credentials_info
Expand Down
10 changes: 5 additions & 5 deletions superset/db_engine_specs/gsheets.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,11 @@ def extra_table_metadata(
table_name: str,
schema_name: Optional[str],
) -> Dict[str, Any]:
engine = cls.get_engine(database, schema=schema_name)
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
cursor.execute(f'SELECT GET_METADATA("{table_name}")')
results = cursor.fetchone()[0]
with cls.get_engine(database, schema=schema_name) as engine:
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
cursor.execute(f'SELECT GET_METADATA("{table_name}")')
results = cursor.fetchone()[0]

try:
metadata = json.loads(results)
Expand Down
5 changes: 2 additions & 3 deletions superset/db_engine_specs/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,6 @@ def df_to_sql(
:param to_sql_kwargs: The kwargs to be passed to pandas.DataFrame.to_sql` method
"""

engine = cls.get_engine(database)

if to_sql_kwargs["if_exists"] == "append":
raise SupersetException("Append operation not currently supported")

Expand All @@ -205,7 +203,8 @@ def df_to_sql(
if table_exists:
raise SupersetException("Table already exists")
elif to_sql_kwargs["if_exists"] == "replace":
engine.execute(f"DROP TABLE IF EXISTS {str(table)}")
with cls.get_engine(database) as engine:
engine.execute(f"DROP TABLE IF EXISTS {str(table)}")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's interesting that sometimes we use engine.raw_connection().execute, and other times we use engine.execute. Ideally we should standardize on the latter wherever possible, since it's more concise.


def _get_hive_type(dtype: np.dtype) -> str:
hive_type_by_dtype = {
Expand Down
33 changes: 16 additions & 17 deletions superset/db_engine_specs/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,12 +462,11 @@ def get_view_names(
).strip()
params = {}

engine = cls.get_engine(database, schema=schema)

with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
cursor.execute(sql, params)
results = cursor.fetchall()
with cls.get_engine(database, schema=schema) as engine:
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
cursor.execute(sql, params)
results = cursor.fetchall()

return sorted([row[0] for row in results])

Expand Down Expand Up @@ -989,17 +988,17 @@ def get_create_view(
# pylint: disable=import-outside-toplevel
from pyhive.exc import DatabaseError

engine = cls.get_engine(database, schema)
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
sql = f"SHOW CREATE VIEW {schema}.{table}"
try:
cls.execute(cursor, sql)

except DatabaseError: # not a VIEW
return None
rows = cls.fetch_data(cursor, 1)
return rows[0][0]
with cls.get_engine(database, schema=schema) as engine:
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
sql = f"SHOW CREATE VIEW {schema}.{table}"
try:
cls.execute(cursor, sql)

except DatabaseError: # not a VIEW
return None
rows = cls.fetch_data(cursor, 1)
return rows[0][0]

@classmethod
def get_tracking_url(cls, cursor: "Cursor") -> Optional[str]:
Expand Down
Loading