-
Notifications
You must be signed in to change notification settings - Fork 13.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: refactor all `get_sqla_engine` to use contextmanager in codebase #21943
Changes from all commits
158da8d
face73f
95d079e
87c0d79
11b240b
1bfdbda
4146d5a
54fc147
fdc6ca3
66c0801
1f9ec5e
8811a99
82d7532
1f829ac
d53d116
752161d
0ac6fb1
31f3c1d
e089a8d
7ce5836
b05f0e8
12b05bd
89020b5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -166,17 +166,26 @@ def load_data( | |
if database.sqlalchemy_uri == current_app.config.get("SQLALCHEMY_DATABASE_URI"): | ||
logger.info("Loading data inside the import transaction") | ||
connection = session.connection() | ||
df.to_sql( | ||
dataset.table_name, | ||
con=connection, | ||
schema=dataset.schema, | ||
if_exists="replace", | ||
chunksize=CHUNKSIZE, | ||
dtype=dtype, | ||
index=False, | ||
method="multi", | ||
) | ||
else: | ||
logger.warning("Loading data outside the import transaction") | ||
connection = database.get_sqla_engine() | ||
|
||
df.to_sql( | ||
dataset.table_name, | ||
con=connection, | ||
schema=dataset.schema, | ||
if_exists="replace", | ||
chunksize=CHUNKSIZE, | ||
dtype=dtype, | ||
index=False, | ||
method="multi", | ||
) | ||
with database.get_sqla_engine_with_context() as engine: | ||
df.to_sql( | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This will only run in the […] |
||
dataset.table_name, | ||
con=engine, | ||
schema=dataset.schema, | ||
if_exists="replace", | ||
chunksize=CHUNKSIZE, | ||
dtype=dtype, | ||
index=False, | ||
method="multi", | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,6 +23,7 @@ | |
from typing import ( | ||
Any, | ||
Callable, | ||
ContextManager, | ||
Dict, | ||
List, | ||
Match, | ||
|
@@ -471,8 +472,16 @@ def get_engine( | |
database: "Database", | ||
schema: Optional[str] = None, | ||
source: Optional[utils.QuerySource] = None, | ||
) -> Engine: | ||
return database.get_sqla_engine(schema=schema, source=source) | ||
) -> ContextManager[Engine]: | ||
""" | ||
Return an engine context manager. | ||
|
||
>>> with DBEngineSpec.get_engine(database, schema, source) as engine: | ||
... connection = engine.connect() | ||
... connection.execute(sql) | ||
|
||
""" | ||
return database.get_sqla_engine_with_context(schema=schema, source=source) | ||
|
||
@classmethod | ||
def get_timestamp_expr( | ||
|
@@ -894,17 +903,17 @@ def df_to_sql( | |
:param to_sql_kwargs: The kwargs to be passed to pandas.DataFrame.to_sql` method | ||
""" | ||
|
||
engine = cls.get_engine(database) | ||
to_sql_kwargs["name"] = table.table | ||
|
||
if table.schema: | ||
# Only add schema when it is preset and non empty. | ||
to_sql_kwargs["schema"] = table.schema | ||
|
||
if engine.dialect.supports_multivalues_insert: | ||
to_sql_kwargs["method"] = "multi" | ||
with cls.get_engine(database) as engine: | ||
if engine.dialect.supports_multivalues_insert: | ||
to_sql_kwargs["method"] = "multi" | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
I think we need to improve this a little bit... here we're building an engine just to check an attribute in the dialect, which means we're setting up and tearing down an SSH connection just to read an attribute. :-( Maybe we should add a
```python
@classmethod
def get_dialect(database, schema, source):
    engine = database.get_sqla_engine(schema=schema, source=source)
    return engine.dialect
```
Then when we only need the dialect we can call this method, which is cheaper.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
in this case we still need the […]
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Ah, good point, I missed that. |
||
|
||
df.to_sql(con=engine, **to_sql_kwargs) | ||
df.to_sql(con=engine, **to_sql_kwargs) | ||
|
||
@classmethod | ||
def convert_dttm( # pylint: disable=unused-argument | ||
|
@@ -1277,13 +1286,15 @@ def estimate_query_cost( | |
parsed_query = sql_parse.ParsedQuery(sql) | ||
statements = parsed_query.get_statements() | ||
|
||
engine = cls.get_engine(database, schema=schema, source=source) | ||
costs = [] | ||
with closing(engine.raw_connection()) as conn: | ||
cursor = conn.cursor() | ||
for statement in statements: | ||
processed_statement = cls.process_statement(statement, database) | ||
costs.append(cls.estimate_statement_cost(processed_statement, cursor)) | ||
with cls.get_engine(database, schema=schema, source=source) as engine: | ||
with closing(engine.raw_connection()) as conn: | ||
cursor = conn.cursor() | ||
for statement in statements: | ||
processed_statement = cls.process_statement(statement, database) | ||
costs.append( | ||
cls.estimate_statement_cost(processed_statement, cursor) | ||
) | ||
return costs | ||
|
||
@classmethod | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -340,8 +340,12 @@ def df_to_sql( | |
if not table.schema: | ||
raise Exception("The table schema must be defined") | ||
|
||
engine = cls.get_engine(database) | ||
to_gbq_kwargs = {"destination_table": str(table), "project_id": engine.url.host} | ||
to_gbq_kwargs = {} | ||
with cls.get_engine(database) as engine: | ||
to_gbq_kwargs = { | ||
"destination_table": str(table), | ||
"project_id": engine.url.host, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I wonder if we should add an attribute to DB engine specs annotating whether they support SSH tunnels or not? BigQuery, e.g., will probably never support it. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Created a ticket for this. |
||
} | ||
|
||
# Add credentials if they are set on the SQLAlchemy dialect. | ||
creds = engine.dialect.credentials_info | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -185,8 +185,6 @@ def df_to_sql( | |
:param to_sql_kwargs: The kwargs to be passed to pandas.DataFrame.to_sql` method | ||
""" | ||
|
||
engine = cls.get_engine(database) | ||
|
||
if to_sql_kwargs["if_exists"] == "append": | ||
raise SupersetException("Append operation not currently supported") | ||
|
||
|
@@ -205,7 +203,8 @@ def df_to_sql( | |
if table_exists: | ||
raise SupersetException("Table already exists") | ||
elif to_sql_kwargs["if_exists"] == "replace": | ||
engine.execute(f"DROP TABLE IF EXISTS {str(table)}") | ||
with cls.get_engine(database) as engine: | ||
engine.execute(f"DROP TABLE IF EXISTS {str(table)}") | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It's interesting that sometimes we use […] |
||
|
||
def _get_hive_type(dtype: np.dtype) -> str: | ||
hive_type_by_dtype = { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't need to do this now, but eventually we could have a `database.get_raw_connection` context manager, to simplify things a little bit. This way we could rewrite this as: […]
And the implementation of `get_raw_connection()` would take care of closing the connection: […]