diff --git a/README.md b/README.md index f3152b4929..dff7b545a5 100644 --- a/README.md +++ b/README.md @@ -28,8 +28,8 @@ | MsSQL | √ | × | √ | × | √ | × | × | × | × | × | | Redis | √ | × | √ | × | × | × | × | × | × | × | | PgSQL | √ | × | √ | × | × | × | × | × | × | × | -| Oracle | √ | √ | √ | √ | √ | × | × | × | × | × | -| MongoDB | √ | √ | √ | × | × | × | × | √ | × | × | +| Oracle | √ | √ | √ | √ | √ | × | √ | × | × | × | +| MongoDB | √ | √ | √ | × | × | × | √ | √ | × | × | | Phoenix | √ | × | √ | × | × | × | × | × | × | × | | ODPS | √ | × | × | × | × | × | × | × | × | × | | ClickHouse | √ | √ | √ | × | × | × | × | × | × | × | diff --git a/sql/binlog.py b/sql/binlog.py index b36013f1e4..812f00b905 100644 --- a/sql/binlog.py +++ b/sql/binlog.py @@ -66,11 +66,9 @@ def del_binlog(request): result = {"status": 1, "msg": "实例不存在", "data": []} return HttpResponse(json.dumps(result), content_type="application/json") - # escape - binlog = MySQLdb.escape_string(binlog).decode("utf-8") - if binlog: query_engine = get_engine(instance=instance) + binlog = query_engine.escape_string(binlog) query_result = query_engine.query(sql=rf"purge master logs to '{binlog}';") if query_result.error is None: result = {"status": 0, "msg": "清理成功", "data": ""} diff --git a/sql/data_dictionary.py b/sql/data_dictionary.py index 2ee48b8d5e..4a702d8075 100644 --- a/sql/data_dictionary.py +++ b/sql/data_dictionary.py @@ -29,6 +29,7 @@ def table_list(request): instance_name=instance_name, db_type=db_type ) query_engine = get_engine(instance=instance) + db_name = query_engine.escape_string(db_name) data = query_engine.get_group_tables_by_db(db_name=db_name) res = {"status": 0, "data": data} except Instance.DoesNotExist: @@ -50,6 +51,7 @@ def table_info(request): db_name = request.GET.get("db_name", "") tb_name = request.GET.get("tb_name", "") db_type = request.GET.get("db_type", "") + if instance_name and db_name and tb_name: data = {} try: @@ -57,6 +59,8 @@ def table_info(request): instance_name=instance_name, db_type=db_type ) query_engine = get_engine(instance=instance) + db_name = query_engine.escape_string(db_name) + tb_name = query_engine.escape_string(tb_name) data["meta_data"] = query_engine.get_table_meta_data( db_name=db_name, tb_name=tb_name ) @@ -91,8 +95,6 @@ def export(request): """导出数据字典""" instance_name = request.GET.get("instance_name", "") db_name = request.GET.get("db_name", "") - # escape - db_name = MySQLdb.escape_string(db_name).decode("utf-8") try: instance = user_instances( @@ -104,7 +106,7 @@ def export(request): # 普通用户仅可以获取指定数据库的字典信息 if db_name: - dbs = [db_name] + dbs = [query_engine.escape_string(db_name)] # 管理员可以导出整个实例的字典信息 elif request.user.is_superuser: dbs = query_engine.get_all_databases().rows diff --git a/sql/engines/__init__.py b/sql/engines/__init__.py index cbdef337d7..a4f461bc1f 100644 --- a/sql/engines/__init__.py +++ b/sql/engines/__init__.py @@ -86,6 +86,10 @@ def info(self): """返回引擎简介""" return "Base engine" + def escape_string(self, value: str) -> str: + """参数转义""" + return value + @property def auto_backup(self): """是否支持备份""" @@ -167,7 +171,15 @@ def filter_sql(self, sql="", limit_num=0): """给查询语句增加结果级限制或者改写语句, 返回修改后的语句""" return sql.strip() - def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs): + def query( + self, + db_name=None, + sql="", + limit_num=0, + close_conn=True, + parameters=None, + **kwargs + ): """实际查询 返回一个ResultSet""" return ResultSet() @@ -180,7 +192,7 @@ def execute_check(self, db_name=None, sql=""): """执行语句的检查 返回一个ReviewSet""" return ReviewSet() - def execute(self): + def execute(self, **kwargs): """执行语句 返回一个ReviewSet""" return ReviewSet() diff --git a/sql/engines/clickhouse.py b/sql/engines/clickhouse.py index 22216c4be5..33dedfdd31 100644 --- a/sql/engines/clickhouse.py +++ b/sql/engines/clickhouse.py @@ -1,5 +1,6 @@ # -*- coding: UTF-8 -*- from clickhouse_driver import connect +from clickhouse_driver.util.escape import escape_chars_map from sql.utils.sql_utils import get_syntax_type from .models import ResultSet, ReviewResult, ReviewSet from common.utils.timer import FuncTimer @@ -49,6 +50,10 @@ def name(self): def info(self): return "ClickHouse engine" + def escape_string(self, value: str) -> str: + """字符串参数转义""" + return "%s" % "".join(escape_chars_map.get(c, c) for c in value) + @property def auto_backup(self): """是否支持备份""" @@ -63,11 +68,9 @@ def server_version(self): def get_table_engine(self, tb_name): """获取某个table的engine type""" - sql = f"""select engine - from system.tables - where database='{tb_name.split('.')[0]}' - and name='{tb_name.split('.')[1]}'""" - query_result = self.query(sql=sql) + db, tb = tb_name.split(".") + sql = f"""select engine from system.tables where database=%(db)s and name=%(tb)s""" + query_result = self.query(sql=sql, parameters={"db": db, "tb": tb}) if query_result.rows: result = {"status": 1, "engine": query_result.rows[0][0]} else: @@ -104,15 +107,20 @@ def get_all_columns_by_tb(self, db_name, tb_name, **kwargs): from system.columns where - database = '{db_name}' - and table = '{tb_name}';""" - result = self.query(db_name=db_name, sql=sql) + database = %(db_name)s + and table = %(tb_name)s;""" + result = self.query( + db_name=db_name, + sql=sql, + parameters={"db_name": db_name, "tb_name": tb_name}, + ) column_list = [row[0] for row in result.rows] result.rows = column_list return result def describe_table(self, db_name, tb_name, **kwargs): """return ResultSet 类似查询""" + tb_name = self.escape_string(tb_name) sql = f"show create table `{tb_name}`;" result = self.query(db_name=db_name, sql=sql) @@ -121,13 +129,21 @@ def describe_table(self, db_name, tb_name, **kwargs): ) return result - def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs): + def query( + self, + db_name=None, + sql="", + limit_num=0, + close_conn=True, + parameters=None, + **kwargs, + ): """返回 ResultSet""" result_set = ResultSet(full_sql=sql) try: conn = self.get_connection(db_name=db_name) cursor = conn.cursor() - cursor.execute(sql) + cursor.execute(sql, parameters) if int(limit_num) > 0: rows = cursor.fetchmany(size=int(limit_num)) else: @@ -462,14 +478,14 @@ def execute_workflow(self, workflow): break return execute_result - def execute(self, db_name=None, sql="", close_conn=True): + def execute(self, db_name=None, sql="", close_conn=True, parameters=None): """原生执行语句""" result = ResultSet(full_sql=sql) conn = self.get_connection(db_name=db_name) try: cursor = conn.cursor() for statement in sqlparse.split(sql): - cursor.execute(statement) + cursor.execute(statement, parameters) cursor.close() except Exception as e: logger.warning(f"ClickHouse语句执行报错,语句:{sql},错误信息{e}") diff --git a/sql/engines/goinception.py b/sql/engines/goinception.py index c228d39c3b..419a866d76 100644 --- a/sql/engines/goinception.py +++ b/sql/engines/goinception.py @@ -63,6 +63,10 @@ def get_backup_connection(): autocommit=True, ) + def escape_string(self, value: str) -> str: + """字符串参数转义""" + return MySQLdb.escape_string(value).decode("utf-8") + def execute_check(self, instance=None, db_name=None, sql=""): """inception check""" # 判断如果配置了隧道则连接隧道 @@ -282,8 +286,8 @@ def set_variable(self, variable_name, variable_value): def osc_control(self, **kwargs): """控制osc执行,获取进度、终止、暂停、恢复等""" - sqlsha1 = MySQLdb.escape_string(kwargs.get("sqlsha1")).decode("utf-8") - command = MySQLdb.escape_string(kwargs.get("command")).decode("utf-8") + sqlsha1 = self.escape_string(kwargs.get("sqlsha1", "")) + command = self.escape_string(kwargs.get("command", "")) if command == "get": sql = f"inception get osc_percent '{sqlsha1}';" else: diff --git a/sql/engines/mssql.py b/sql/engines/mssql.py index 712ba07224..9a56a0dadf 100644 --- a/sql/engines/mssql.py +++ b/sql/engines/mssql.py @@ -24,11 +24,21 @@ def get_connection(self, db_name=None): self.password, self.instance.charset or "UTF8", ) + if db_name: + connstr = f"{connstr};DATABASE={db_name}" if self.conn: return self.conn self.conn = pyodbc.connect(connstr) return self.conn + @property + def name(self): + return "MsSQL" + + @property + def info(self): + return "MsSQL engine" + def get_all_databases(self): """获取数据库列表, 返回一个ResultSet""" sql = "SELECT name FROM master.sys.databases order by name" @@ -44,10 +54,8 @@ def get_all_databases(self): def get_all_tables(self, db_name, **kwargs): """获取table 列表, 返回一个ResultSet""" sql = """SELECT TABLE_NAME - FROM {0}.INFORMATION_SCHEMA.TABLES - WHERE TABLE_TYPE = 'BASE TABLE' order by TABLE_NAME;""".format( - db_name - ) + FROM INFORMATION_SCHEMA.TABLES + WHERE TABLE_TYPE = 'BASE TABLE' order by TABLE_NAME;""" result = self.query(db_name=db_name, sql=sql) tb_list = [row[0] for row in result.rows if row[0] not in ["test"]] result.rows = tb_list @@ -80,7 +88,7 @@ def get_group_tables_by_db(self, db_name): def get_table_meta_data(self, db_name, tb_name, **kwargs): """数据字典页面使用:获取表格的元信息,返回一个dict{column_list: [], rows: []}""" sql = f""" - SELECT space.*,table_comment,index_length,IDENT_CURRENT('{tb_name}') as auto_increment + SELECT space.*,table_comment,index_length,IDENT_CURRENT(?) as auto_increment FROM ( SELECT t.NAME AS table_name, @@ -99,7 +107,7 @@ def get_table_meta_data(self, db_name, tb_name, **kwargs): INNER JOIN sys.allocation_units a ON p.partition_id = a.container_id WHERE - t.NAME ='{tb_name}' + t.NAME =? AND t.is_ms_shipped = 0 AND i.OBJECT_ID > 255 GROUP BY @@ -120,7 +128,7 @@ def get_table_meta_data(self, db_name, tb_name, **kwargs): t.NAME AS table_name, SUM(page_count * 8) AS index_length FROM sys.dm_db_index_physical_stats( - db_id(), object_id('{tb_name}'), NULL, NULL, 'DETAILED') AS s + db_id(), object_id(?), NULL, NULL, 'DETAILED') AS s JOIN sys.indexes AS i ON s.[object_id] = i.[object_id] AND s.index_id = i.index_id INNER JOIN @@ -129,7 +137,15 @@ def get_table_meta_data(self, db_name, tb_name, **kwargs): ) AS index_size ON index_size.table_name = space.table_name; """ - _meta_data = self.query(db_name, sql) + _meta_data = self.query( + db_name, + sql, + parameters=( + tb_name, + tb_name, + tb_name, + ), + ) return {"column_list": _meta_data.column_list, "rows": _meta_data.rows[0]} def get_table_desc_data(self, db_name, tb_name, **kwargs): @@ -140,8 +156,15 @@ def get_table_desc_data(self, db_name, tb_name, **kwargs): COLLATION_NAME 列字符集, IS_NULLABLE 是否为空, COLUMN_DEFAULT 默认值 - from INFORMATION_SCHEMA.columns where TABLE_CATALOG='{db_name}' and TABLE_NAME = '{tb_name}';""" - _desc_data = self.query(db_name, sql) + from INFORMATION_SCHEMA.columns where TABLE_CATALOG=? and TABLE_NAME = ?;""" + _desc_data = self.query( + db_name, + sql, + parameters=( + db_name, + tb_name, + ), + ) return {"column_list": _desc_data.column_list, "rows": _desc_data.rows} def get_table_index_data(self, db_name, tb_name, **kwargs): @@ -152,9 +175,9 @@ def get_table_index_data(self, db_name, tb_name, **kwargs): i.name AS 索引名, is_unique as 唯一性,is_primary_key as 是否主建 FROM sys.indexes AS i - WHERE i.object_id = OBJECT_ID('{tb_name}') + WHERE i.object_id = OBJECT_ID(?) group by i.name,i.object_id,i.index_id,is_unique,is_primary_key;""" - _index_data = self.query(db_name, sql) + _index_data = self.query(db_name, sql, parameters=(tb_name,)) return {"column_list": _index_data.column_list, "rows": _index_data.rows} def get_tables_metas_data(self, db_name, **kwargs): @@ -189,8 +212,10 @@ def get_tables_metas_data(self, db_name, **kwargs): COLLATION_NAME, IS_NULLABLE, COLUMN_DEFAULT - from INFORMATION_SCHEMA.columns where TABLE_CATALOG='{db_name}' and TABLE_NAME = '{tb["TABLE_NAME"]}';""" - query_result = self.query(db_name=db_name, sql=sql_cols, close_conn=False) + from INFORMATION_SCHEMA.columns where TABLE_CATALOG=? and TABLE_NAME = '{tb["TABLE_NAME"]}';""" + query_result = self.query( + db_name=db_name, sql=sql_cols, close_conn=False, parameters=(db_name,) + ) columns = [] # 转换查询结果为dict @@ -216,19 +241,17 @@ def describe_table(self, db_name, tb_name, **kwargs): c.scale ColumnScale, c.isnullable ColumnNull, case when i.id is not null then 'Y' else 'N' end TablePk - from (select name,id,uid from {0}..sysobjects where (xtype='U' or xtype='V') ) o - inner join {0}..syscolumns c on o.id=c.id - inner join {0}..systypes t on c.xtype=t.xusertype - left join {0}..sysusers u on u.uid=o.uid - left join (select name,id,uid,parent_obj from {0}..sysobjects where xtype='PK' ) opk on opk.parent_obj=o.id - left join (select id,name,indid from {0}..sysindexes) ie on ie.id=o.id and ie.name=opk.name - left join {0}..sysindexkeys i on i.id=o.id and i.colid=c.colid and i.indid=ie.indid + from (select name,id,uid from sysobjects where (xtype='U' or xtype='V') ) o + inner join syscolumns c on o.id=c.id + inner join systypes t on c.xtype=t.xusertype + left join sysusers u on u.uid=o.uid + left join (select name,id,uid,parent_obj from sysobjects where xtype='PK' ) opk on opk.parent_obj=o.id + left join (select id,name,indid from sysindexes) ie on ie.id=o.id and ie.name=opk.name + left join sysindexkeys i on i.id=o.id and i.colid=c.colid and i.indid=ie.indid WHERE O.name NOT LIKE 'MS%' AND O.name NOT LIKE 'SY%' - and O.name='{1}' - order by o.name,c.colid""".format( - db_name, tb_name - ) - result = self.query(sql=sql) + and O.name=? + order by o.name,c.colid""" + result = self.query(db_name=db_name, sql=sql, parameters=(tb_name,)) return result def query_check(self, db_name=None, sql=""): @@ -300,15 +323,25 @@ def filter_sql(self, sql="", limit_num=0): return sql_lower.replace("select", "select top {}".format(limit_num)) return sql.strip() - def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs): + def query( + self, + db_name=None, + sql="", + limit_num=0, + close_conn=True, + parameters: tuple = None, + **kwargs, + ): """返回 ResultSet""" result_set = ResultSet(full_sql=sql) try: - conn = self.get_connection() + conn = self.get_connection(db_name) cursor = conn.cursor() - if db_name: - cursor.execute("use [{}];".format(db_name)) - cursor.execute(sql) + # https://github.com/mkleehammer/pyodbc/wiki/Cursor#executesql-parameters + if parameters: + cursor.execute(sql, *parameters) + else: + cursor.execute(sql) if int(limit_num) > 0: rows = cursor.fetchmany(int(limit_num)) else: @@ -371,7 +404,7 @@ def execute_workflow(self, workflow): db_name=workflow.db_name, sql=workflow.sqlworkflowcontent.sql_content ) - def execute(self, db_name=None, sql="", close_conn=True): + def execute(self, db_name=None, sql="", close_conn=True, parameters=None): """执行sql语句 返回 Review set""" execute_result = ReviewSet(full_sql=sql) conn = self.get_connection(db_name=db_name) diff --git a/sql/engines/mysql.py b/sql/engines/mysql.py index e2bc7430f6..2b05b6488b 100644 --- a/sql/engines/mysql.py +++ b/sql/engines/mysql.py @@ -98,6 +98,10 @@ def name(self): def info(self): return "MySQL engine" + def escape_string(self, value: str) -> str: + """字符串参数转义""" + return MySQLdb.escape_string(value).decode("utf-8") + @property def auto_backup(self): """是否支持备份""" @@ -166,16 +170,14 @@ def get_all_tables(self, db_name, **kwargs): return result def get_group_tables_by_db(self, db_name): - # escape - db_name = MySQLdb.escape_string(db_name).decode("utf-8") data = {} sql = f"""SELECT TABLE_NAME, TABLE_COMMENT FROM information_schema.TABLES WHERE - TABLE_SCHEMA='{db_name}';""" - result = self.query(db_name=db_name, sql=sql) + TABLE_SCHEMA=%(db_name)s;""" + result = self.query(db_name=db_name, sql=sql, parameters={"db_name": db_name}) for row in result.rows: table_name, table_cmt = row[0], row[1] if table_name[0] not in data: @@ -185,9 +187,6 @@ def get_group_tables_by_db(self, db_name): def get_table_meta_data(self, db_name, tb_name, **kwargs): """数据字典页面使用:获取表格的元信息,返回一个dict{column_list: [], rows: []}""" - # escape - db_name = MySQLdb.escape_string(db_name).decode("utf-8") - tb_name = MySQLdb.escape_string(tb_name).decode("utf-8") sql = f"""SELECT TABLE_NAME as table_name, ENGINE as engine, @@ -208,9 +207,11 @@ def get_table_meta_data(self, db_name, tb_name, **kwargs): FROM information_schema.TABLES WHERE - TABLE_SCHEMA='{db_name}' - AND TABLE_NAME='{tb_name}'""" - _meta_data = self.query(db_name, sql) + TABLE_SCHEMA=%(db_name)s + AND TABLE_NAME=%(tb_name)s""" + _meta_data = self.query( + db_name, sql, parameters={"db_name": db_name, "tb_name": tb_name} + ) return {"column_list": _meta_data.column_list, "rows": _meta_data.rows[0]} def get_table_desc_data(self, db_name, tb_name, **kwargs): @@ -227,10 +228,12 @@ def get_table_desc_data(self, db_name, tb_name, **kwargs): FROM information_schema.COLUMNS WHERE - TABLE_SCHEMA = '{db_name}' - AND TABLE_NAME = '{tb_name}' + TABLE_SCHEMA = %(db_name)s + AND TABLE_NAME = %(tb_name)s ORDER BY ORDINAL_POSITION;""" - _desc_data = self.query(db_name, sql) + _desc_data = self.query( + db_name, sql, parameters={"db_name": db_name, "tb_name": tb_name} + ) return {"column_list": _desc_data.column_list, "rows": _desc_data.rows} def get_table_index_data(self, db_name, tb_name, **kwargs): @@ -247,18 +250,23 @@ def get_table_index_data(self, db_name, tb_name, **kwargs): FROM information_schema.STATISTICS WHERE - TABLE_SCHEMA = '{db_name}' - AND TABLE_NAME = '{tb_name}';""" - _index_data = self.query(db_name, sql) + TABLE_SCHEMA = %(db_name)s + AND TABLE_NAME = %(tb_name)s;""" + _index_data = self.query( + db_name, sql, parameters={"db_name": db_name, "tb_name": tb_name} + ) return {"column_list": _index_data.column_list, "rows": _index_data.rows} def get_tables_metas_data(self, db_name, **kwargs): """获取数据库所有表格信息,用作数据字典导出接口""" sql_tbs = ( - f"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA='{db_name}';" + f"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA=%(db_name)s;" ) tbs = self.query( - sql=sql_tbs, cursorclass=MySQLdb.cursors.DictCursor, close_conn=False + sql=sql_tbs, + cursorclass=MySQLdb.cursors.DictCursor, + close_conn=False, + parameters={"db_name": db_name}, ).rows table_metas = [] for tb in tbs: @@ -285,10 +293,13 @@ def get_tables_metas_data(self, db_name, **kwargs): def get_bind_users(self, db_name: str): sql_get_bind_users = f"""select group_concat(distinct(GRANTEE)),TABLE_SCHEMA from information_schema.SCHEMA_PRIVILEGES - where TABLE_SCHEMA='{db_name}' + where TABLE_SCHEMA=%(db_name)s group by TABLE_SCHEMA;""" return self.query( - "information_schema", sql_get_bind_users, close_conn=False + "information_schema", + sql_get_bind_users, + close_conn=False, + parameters={"db_name": db_name}, ).rows def get_all_databases_summary(self): @@ -348,9 +359,9 @@ def get_instance_users_summary(self): def create_instance_user(self, **kwargs): """实例账号管理功能,创建实例账号""" # escape - user = MySQLdb.escape_string(kwargs.get("user", "")).decode("utf-8") - host = MySQLdb.escape_string(kwargs.get("host", "")).decode("utf-8") - password1 = MySQLdb.escape_string(kwargs.get("password1", "")).decode("utf-8") + user = self.escape_string(kwargs.get("user", "")) + host = self.escape_string(kwargs.get("host", "")) + password1 = self.escape_string(kwargs.get("password1", "")) remark = kwargs.get("remark", "") # 在一个事务内执行 hosts = host.split("|") @@ -376,14 +387,14 @@ def create_instance_user(self, **kwargs): def drop_instance_user(self, user_host: str, **kwarg): """实例账号管理功能,删除实例账号""" # escape - user_host = MySQLdb.escape_string(user_host).decode("utf-8") + user_host = self.escape_string(user_host) return self.execute(db_name="mysql", sql=f"DROP USER {user_host};") def reset_instance_user_pwd(self, user_host: str, reset_pwd: str, **kwargs): """实例账号管理功能,重置实例账号密码""" # escape - user_host = MySQLdb.escape_string(user_host).decode("utf-8") - reset_pwd = MySQLdb.escape_string(reset_pwd).decode("utf-8") + user_host = self.escape_string(user_host) + reset_pwd = self.escape_string(reset_pwd) return self.execute( db_name="mysql", sql=f"ALTER USER {user_host} IDENTIFIED BY '{reset_pwd}';" ) @@ -401,16 +412,21 @@ def get_all_columns_by_tb(self, db_name, tb_name, **kwargs): FROM information_schema.COLUMNS WHERE - TABLE_SCHEMA = '{db_name}' - AND TABLE_NAME = '{tb_name}' + TABLE_SCHEMA = %(db_name)s + AND TABLE_NAME = %(tb_name)s ORDER BY ORDINAL_POSITION;""" - result = self.query(db_name=db_name, sql=sql) + result = self.query( + db_name=db_name, + sql=sql, + parameters=({"db_name": db_name, "tb_name": tb_name}), + ) column_list = [row[0] for row in result.rows] result.rows = column_list return result def describe_table(self, db_name, tb_name, **kwargs): """return ResultSet 类似查询""" + tb_name = self.escape_string(tb_name) sql = f"show create table `{tb_name}`;" result = self.query(db_name=db_name, sql=sql) return result @@ -431,7 +447,15 @@ def result_set_binary_as_hex(result_set): result_set.rows = tuple(new_rows) return result_set - def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs): + def query( + self, + db_name=None, + sql="", + limit_num=0, + close_conn=True, + parameters=None, + **kwargs, + ): """返回 ResultSet""" result_set = ResultSet(full_sql=sql) max_execution_time = kwargs.get("max_execution_time", 0) @@ -444,7 +468,7 @@ def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs): cursor.execute(f"set session max_execution_time={max_execution_time};") except MySQLdb.OperationalError: pass - effect_row = cursor.execute(sql) + effect_row = cursor.execute(sql, parameters) if int(limit_num) > 0: rows = cursor.fetchmany(size=int(limit_num)) else: @@ -624,14 +648,14 @@ def execute_workflow(self, workflow): # inception执行 return self.inc_engine.execute(workflow) - def execute(self, db_name=None, sql="", close_conn=True): + def execute(self, db_name=None, sql="", close_conn=True, parameters=None): """原生执行语句""" result = ResultSet(full_sql=sql) conn = self.get_connection(db_name=db_name) try: cursor = conn.cursor() for statement in sqlparse.split(sql): - cursor.execute(statement) + cursor.execute(statement, parameters) conn.commit() cursor.close() except Exception as e: @@ -679,7 +703,7 @@ def processlist(self, command_type): """获取连接信息""" base_sql = "select id, user, host, db, command, time, state, ifnull(info,'') as info from information_schema.processlist" # escape - command_type = MySQLdb.escape_string(command_type).decode("utf-8") + command_type = self.escape_string(command_type) if not command_type: command_type = "Query" if command_type == "All": diff --git a/sql/engines/oracle.py b/sql/engines/oracle.py index 160165bae2..fe188f28cb 100644 --- a/sql/engines/oracle.py +++ b/sql/engines/oracle.py @@ -160,17 +160,22 @@ def _get_all_schemas(self): def get_all_tables(self, db_name, **kwargs): """获取table 列表, 返回一个ResultSet""" - sql = f"""SELECT table_name FROM all_tables WHERE nvl(tablespace_name, 'no tablespace') NOT IN ('SYSTEM', 'SYSAUX') AND OWNER = '{db_name}' AND IOT_NAME IS NULL AND DURATION IS NULL order by table_name - """ - result = self.query(db_name=db_name, sql=sql) + sql = f"""SELECT table_name + FROM all_tables + WHERE nvl(tablespace_name, 'no tablespace') NOT IN ('SYSTEM', 'SYSAUX') + AND OWNER = :db_name AND IOT_NAME IS NULL + AND DURATION IS NULL order by table_name""" + result = self.query(db_name=db_name, sql=sql, parameters={"db_name": db_name}) tb_list = [row[0] for row in result.rows if row[0] not in ["test"]] result.rows = tb_list return result def get_group_tables_by_db(self, db_name): data = {} - table_list_sql = f"""SELECT table_name, comments FROM dba_tab_comments WHERE owner = '{db_name}'""" - result = self.query(db_name=db_name, sql=table_list_sql) + table_list_sql = f"""SELECT table_name, comments FROM dba_tab_comments WHERE owner = :db_name""" + result = self.query( + db_name=db_name, sql=table_list_sql, parameters={"db_name": db_name} + ) for row in result.rows: table_name, table_cmt = row[0], row[1] if table_name[0] not in data: @@ -205,9 +210,13 @@ def get_table_meta_data(self, db_name, tb_name, **kwargs): and bss.TABLE_NAME = tcs.table_name WHERE - tcs.OWNER='{db_name}' - AND tcs.TABLE_NAME='{tb_name}'""" - _meta_data = self.query(db_name=db_name, sql=meta_data_sql) + tcs.OWNER=:db_name + AND tcs.TABLE_NAME=:tb_name""" + _meta_data = self.query( + db_name=db_name, + sql=meta_data_sql, + parameters={"db_name": db_name, "tb_name": tb_name}, + ) return {"column_list": _meta_data.column_list, "rows": _meta_data.rows[0]} def get_table_desc_data(self, db_name, tb_name, **kwargs): @@ -249,10 +258,14 @@ def get_table_desc_data(self, db_name, tb_name, **kwargs): and acs.table_name = ics.TABLE_NAME and acs.index_name = ics.INDEX_NAME WHERE - bcs.OWNER='{db_name}' - AND bcs.TABLE_NAME='{tb_name}' + bcs.OWNER=:db_name + AND bcs.TABLE_NAME=:tb_name ORDER BY bcs.COLUMN_NAME""" - _desc_data = self.query(db_name=db_name, sql=desc_sql) + _desc_data = self.query( + db_name=db_name, + sql=desc_sql, + parameters={"db_name": db_name, "tb_name": tb_name}, + ) return {"column_list": _desc_data.column_list, "rows": _desc_data.rows} def get_table_index_data(self, db_name, tb_name, **kwargs): @@ -272,9 +285,11 @@ def get_table_index_data(self, db_name, tb_name, **kwargs): on ais.owner = pis.owner and ais.index_name = pis.index_name WHERE - ais.owner = '{db_name}' - AND ais.table_name = '{tb_name}'""" - _index_data = self.query(db_name, index_sql) + ais.owner = :db_name + AND ais.table_name = :tb_name""" + _index_data = self.query( + db_name, index_sql, parameters={"db_name": db_name, "tb_name": tb_name} + ) return {"column_list": _index_data.column_list, "rows": _index_data.rows} def get_tables_metas_data(self, db_name, **kwargs): @@ -324,9 +339,11 @@ def get_tables_metas_data(self, db_name, **kwargs): on t1.OWNER = bcs.OWNER AND t1.TABLE_NAME = bcs.TABLE_NAME AND t1.column_name = bcs.COLUMN_NAME - WHERE bcs.OWNER = '{db_name}' + WHERE bcs.OWNER = :db_name order by bcs.TABLE_NAME, comments""" - cols_req = self.query(sql=sql_cols, close_conn=False).rows + cols_req = self.query( + sql=sql_cols, close_conn=False, parameters={"db_name": db_name} + ).rows # 给查询结果定义列名,query_engine.query的游标是0 1 2 cols_df = pd.DataFrame( @@ -371,8 +388,8 @@ def get_tables_metas_data(self, db_name, **kwargs): def get_all_objects(self, db_name, **kwargs): """获取object_name 列表, 返回一个ResultSet""" - sql = f"""SELECT object_name FROM all_objects WHERE OWNER = '{db_name}' """ - result = self.query(db_name=db_name, sql=sql) + sql = f"""SELECT object_name FROM all_objects WHERE OWNER = :db_name """ + result = self.query(db_name=db_name, sql=sql, parameters={"db_name": db_name}) tb_list = [row[0] for row in result.rows if row[0] not in ["test"]] result.rows = tb_list return result @@ -398,9 +415,13 @@ def describe_table(self, db_name, tb_name, **kwargs): WHERE a.table_name = b.table_name and a.owner = b.OWNER and a.COLUMN_NAME = b.COLUMN_NAME - and a.table_name = '{tb_name}' and a.owner = '{db_name}' order by column_id + and a.table_name = :tb_name and a.owner = :db_name order by column_id """ - result = self.query(db_name=db_name, sql=sql) + result = self.query( + db_name=db_name, + sql=sql, + parameters={"db_name": db_name, "tb_name": tb_name}, + ) return result def object_name_check(self, db_name=None, object_name=""): @@ -426,8 +447,13 @@ def object_name_check(self, db_name=None, object_name=""): object_name = object_name.replace('"', "") else: object_name = object_name.upper() - sql = f""" SELECT object_name FROM all_objects WHERE OWNER = '{schema_name}' and OBJECT_NAME = '{object_name}' """ - result = self.query(db_name=db_name, sql=sql, close_conn=False) + sql = f""" SELECT object_name FROM all_objects WHERE OWNER = :schema_name and OBJECT_NAME = :object_name """ + result = self.query( + db_name=db_name, + sql=sql, + close_conn=False, + parameters={"schema_name": schema_name, "object_name": object_name}, + ) if result.affected_rows > 0: return True else: @@ -567,7 +593,10 @@ def explain_check(self, db_name=None, sql="", close_conn=False): conn = self.get_connection() cursor = conn.cursor() if db_name: - cursor.execute(f' ALTER SESSION SET CURRENT_SCHEMA = "{db_name}" ') + cursor.execute( + f" ALTER SESSION SET CURRENT_SCHEMA = :db_name ", + {"db_name": db_name}, + ) if re.match(r"^explain", sql, re.I): sql = sql else: @@ -626,21 +655,32 @@ def filter_sql(self, sql="", limit_num=0): sql = f"select sql_audit.* from ({sql.rstrip(';')}) sql_audit where rownum <= {limit_num}" return sql.strip() - def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs): + def query( + self, + db_name=None, + sql="", + limit_num=0, + close_conn=True, + parameters=None, + **kwargs, + ): """返回 ResultSet""" result_set = ResultSet(full_sql=sql) try: conn = self.get_connection() cursor = conn.cursor() if db_name: - cursor.execute(f' ALTER SESSION SET CURRENT_SCHEMA = "{db_name}" ') + cursor.execute( + f" ALTER SESSION SET CURRENT_SCHEMA = :db_name ", + {"db_name": db_name}, + ) sql = sql.rstrip(";") # 支持oralce查询SQL执行计划语句 if re.match(r"^explain", sql, re.I): cursor.execute(sql) # 重置SQL文本,获取SQL执行计划 sql = f"select PLAN_TABLE_OUTPUT from table(dbms_xplan.display)" - cursor.execute(sql) + cursor.execute(sql, parameters or []) fields = cursor.description if any(x[1] == cx_Oracle.CLOB for x in fields): rows = [ @@ -1349,23 +1389,26 @@ def sqltuningadvisor(self, db_name=None, sql="", close_conn=True, **kwargs): my_task_name VARCHAR2(30); my_sqltext CLOB; BEGIN - my_sqltext := '{sql}'; + my_sqltext := :sql; my_task_name := DBMS_SQLTUNE.CREATE_TUNING_TASK( sql_text => my_sqltext, - user_name => '{db_name}', + user_name => :db_name, scope => 'COMPREHENSIVE', time_limit => 30, - task_name => '{task_name}', + task_name => :task_name, description => 'tuning'); - DBMS_SQLTUNE.EXECUTE_TUNING_TASK( task_name => '{task_name}'); + DBMS_SQLTUNE.EXECUTE_TUNING_TASK( task_name => :task_name); END;""" task_begin = 1 - cursor.execute(create_task_sql) + cursor.execute( + create_task_sql, + {"sql": sql, "db_name": db_name, "task_name": task_name}, + ) # 获取分析报告 get_task_sql = ( - f"""select DBMS_SQLTUNE.REPORT_TUNING_TASK( '{task_name}') from dual""" + f"""select DBMS_SQLTUNE.REPORT_TUNING_TASK(:task_name) from dual""" ) - cursor.execute(get_task_sql) + cursor.execute(get_task_sql, {"task_name": task_name}) fields = cursor.description if any(x[1] == cx_Oracle.CLOB for x in fields): rows = [ @@ -1392,7 +1435,7 @@ def sqltuningadvisor(self, db_name=None, sql="", close_conn=True, **kwargs): self.close() return result_set - def execute(self, db_name=None, sql="", close_conn=True): + def execute(self, db_name=None, sql="", close_conn=True, parameters=None): """原生执行语句""" result = ResultSet(full_sql=sql) conn = self.get_connection(db_name=db_name) @@ -1400,7 +1443,7 @@ def execute(self, db_name=None, sql="", close_conn=True): cursor = conn.cursor() for statement in sqlparse.split(sql): statement = statement.rstrip(";") - cursor.execute(statement) + cursor.execute(statement, parameters or []) except Exception as e: logger.warning(f"Oracle语句执行报错,语句:{sql},错误信息{traceback.format_exc()}") result.error = str(e) @@ -1497,11 +1540,11 @@ def tablespace(self, offset=0, row_count=14): and a.tablespace_name = c.tablespace_name and a.tablespace_name = d.tablespace_name order by total_space desc ) e - where rownum <={} - ) f where f.rownumber >={};""".format( - row_count, offset + where rownum <=:row_count + ) f where f.rownumber >=:offset;""" + return self.query( + sql=sql, parameters={"row_count": row_count, "offset": offset} ) - return self.query(sql=sql) def tablespace_count(self): """获取表空间数量""" diff --git a/sql/engines/pgsql.py b/sql/engines/pgsql.py index c7b5f72e24..2c303a1644 100644 --- a/sql/engines/pgsql.py +++ b/sql/engines/pgsql.py @@ -93,8 +93,10 @@ def get_all_tables(self, db_name, **kwargs): schema_name = kwargs.get("schema_name") sql = f"""SELECT table_name FROM information_schema.tables - where table_schema ='{schema_name}';""" - result = self.query(db_name=db_name, sql=sql) + where table_schema =%(schema_name)s;""" + result = self.query( + db_name=db_name, sql=sql, parameters={"schema_name": schema_name} + ) tb_list = [row[0] for row in result.rows if row[0] not in ["test"]] result.rows = tb_list return result @@ -110,9 +112,13 @@ def get_all_columns_by_tb(self, db_name, tb_name, **kwargs): schema_name = kwargs.get("schema_name") sql = f"""SELECT column_name FROM information_schema.columns - where table_name='{tb_name}' - and table_schema ='{schema_name}';""" - result = self.query(db_name=db_name, sql=sql) + where table_name=%(tb_name)s + and table_schema=%(schema_name)s;""" + result = self.query( + db_name=db_name, + sql=sql, + parameters={"schema_name": schema_name, "tb_name": tb_name}, + ) column_list = [row[0] for row in result.rows] result.rows = column_list return result @@ -139,10 +145,15 @@ def describe_table(self, db_name, tb_name, **kwargs): information_schema.columns col left join pg_description des on col.table_name::regclass = des.objoid and col.ordinal_position = des.objsubid - where table_name = '{tb_name}' - and col.table_schema = '{schema_name}' + where table_name = %(tb_name)s + and col.table_schema = %(schema_name)s order by ordinal_position;""" - result = self.query(db_name=db_name, schema_name=schema_name, sql=sql) + result = self.query( + db_name=db_name, + schema_name=schema_name, + sql=sql, + parameters={"schema_name": schema_name, "tb_name": tb_name}, + ) return result def query_check(self, db_name=None, sql=""): @@ -164,7 +175,15 @@ def query_check(self, db_name=None, sql=""): result["msg"] = "SQL语句中含有 * " return result - def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs): + def query( + self, + db_name=None, + sql="", + limit_num=0, + close_conn=True, + parameters=None, + **kwargs, + ): """返回 ResultSet""" schema_name = kwargs.get("schema_name") result_set = ResultSet(full_sql=sql) @@ -177,8 +196,10 @@ def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs): except: pass if schema_name: - cursor.execute(f"SET search_path TO {schema_name};") - cursor.execute(sql) + cursor.execute( + f"SET search_path TO %(schema_name)s;", {"schema_name": schema_name} + ) + cursor.execute(sql, parameters) effect_row = cursor.rowcount if int(limit_num) > 0: rows = cursor.fetchmany(size=int(limit_num)) diff --git a/sql/engines/phoenix.py b/sql/engines/phoenix.py index c68c911d19..6383c2537e 100644 --- a/sql/engines/phoenix.py +++ b/sql/engines/phoenix.py @@ -31,8 +31,8 @@ def get_all_databases(self): def get_all_tables(self, db_name, **kwargs): """获取table 列表, 返回一个ResultSet""" - sql = f"SELECT DISTINCT TABLE_NAME FROM SYSTEM.CATALOG WHERE TABLE_SCHEM = '{db_name}'" - result = self.query(db_name=db_name, sql=sql) + sql = f"SELECT DISTINCT TABLE_NAME FROM SYSTEM.CATALOG WHERE TABLE_SCHEM = ?" + result = self.query(db_name=db_name, sql=sql, parameters=(db_name,)) result.rows = [row[0] for row in result.rows if row[0] is not None] return result @@ -40,14 +40,20 @@ def get_all_columns_by_tb(self, db_name, tb_name, **kwargs): """获取所有字段, 返回一个ResultSet""" sql = f""" SELECT DISTINCT COLUMN_NAME FROM SYSTEM.CATALOG - WHERE TABLE_SCHEM = '{db_name}' AND table_name = '{tb_name}' AND column_name is not null""" - return self.query(sql=sql) + WHERE TABLE_SCHEM = ? AND table_name = ? AND column_name is not null""" + return self.query( + sql=sql, + parameters=( + db_name, + tb_name, + ), + ) def describe_table(self, db_name, tb_name, **kwargs): """return ResultSet""" sql = f"""SELECT COLUMN_NAME,SqlTypeName(DATA_TYPE) FROM SYSTEM.CATALOG - WHERE TABLE_SCHEM = '{db_name}' and table_name = '{tb_name}' and column_name is not null""" - result = self.query(sql=sql) + WHERE TABLE_SCHEM = ? and table_name = ? and column_name is not null""" + result = self.query(sql=sql, parameters=(db_name, tb_name)) return result def query_check(self, db_name=None, sql=""): @@ -87,13 +93,21 @@ def filter_sql(self, sql="", limit_num=0): sql = f"{sql};" return sql.strip() - def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs): + def query( + self, + db_name=None, + sql="", + limit_num=0, + close_conn=True, + parameters=None, + **kwargs, + ): """返回 ResultSet""" result_set = ResultSet(full_sql=sql) try: conn = self.get_connection() cursor = conn.cursor() - cursor.execute(sql) + cursor.execute(sql, parameters) if int(limit_num) > 0: rows = cursor.fetchmany(int(limit_num)) else: @@ -142,7 +156,7 @@ def execute_workflow(self, workflow): db_name=workflow.db_name, sql=workflow.sqlworkflowcontent.sql_content ) - def execute(self, db_name=None, sql="", close_conn=True): + def execute(self, db_name=None, sql="", close_conn=True, parameters=None): """原生执行语句""" execute_result = ReviewSet(full_sql=sql) conn = self.get_connection(db_name=db_name) @@ -151,7 +165,7 @@ def execute(self, db_name=None, sql="", close_conn=True): split_sql = sqlparse.split(sql) for statement in split_sql: try: - cursor.execute(statement.rstrip(";")) + cursor.execute(statement.rstrip(";"), parameters) except Exception as e: logger.error(f"Phoenix命令执行报错,语句:{sql}, 错误信息:{traceback.format_exc()}") execute_result.error = str(e) diff --git a/sql/instance.py b/sql/instance.py index c041f7035f..e7a3105f9a 100644 --- a/sql/instance.py +++ b/sql/instance.py @@ -163,7 +163,6 @@ def param_edit(request): instance_id = request.POST.get("instance_id") variable_name = request.POST.get("variable_name") variable_value = request.POST.get("runtime_value") - try: ins = Instance.objects.get(id=instance_id) except Instance.DoesNotExist: @@ -172,6 +171,8 @@ def param_edit(request): # 修改参数 engine = get_engine(instance=ins) + variable_name = engine.escape_string(variable_name) + variable_value = engine.escape_string(variable_value) # 校验是否配置模板 if not ParamTemplate.objects.filter(variable_name=variable_name).exists(): result = {"status": 1, "msg": "请先在参数模板中配置该参数!", "data": []} @@ -320,12 +321,10 @@ def instance_resource(request): result = {"status": 0, "msg": "ok", "data": []} try: - # escape - db_name = MySQLdb.escape_string(db_name).decode("utf-8") - schema_name = MySQLdb.escape_string(schema_name).decode("utf-8") - tb_name = MySQLdb.escape_string(tb_name).decode("utf-8") - query_engine = get_engine(instance=instance) + db_name = query_engine.escape_string(db_name) + schema_name = query_engine.escape_string(schema_name) + tb_name = query_engine.escape_string(tb_name) if resource_type == "database": resource = query_engine.get_all_databases() elif resource_type == "schema" and db_name: @@ -363,10 +362,14 @@ def describe(request): db_name = request.POST.get("db_name") schema_name = request.POST.get("schema_name") tb_name = request.POST.get("tb_name") + result = {"status": 0, "msg": "ok", "data": []} try: query_engine = get_engine(instance=instance) + db_name = query_engine.escape_string(db_name) + schema_name = query_engine.escape_string(schema_name) + tb_name = query_engine.escape_string(tb_name) query_result = query_engine.describe_table( db_name, tb_name, schema_name=schema_name ) diff --git a/sql/instance_account.py b/sql/instance_account.py index be3188ed56..0e085cd00c 100644 --- a/sql/instance_account.py +++ b/sql/instance_account.py @@ -180,7 +180,7 @@ def grant(request): privs = json.loads(request.POST.get("privs")) # escape - user_host = MySQLdb.escape_string(user_host).decode("utf-8") + user_host = engine.escape_string(user_host) # 全局权限 if priv_type == 0: @@ -331,14 +331,14 @@ def lock(request): return JsonResponse({"status": 1, "msg": "你所在组未关联该实例", "data": []}) # escape - user_host = MySQLdb.escape_string(user_host).decode("utf-8") + engine = get_engine(instance=instance) + user_host = engine.escape_string(user_host) if is_locked == "N": lock_sql = f"ALTER USER {user_host} ACCOUNT LOCK;" elif is_locked == "Y": lock_sql = f"ALTER USER {user_host} ACCOUNT UNLOCK;" - engine = get_engine(instance=instance) exec_result = engine.execute(db_name="mysql", sql=lock_sql) if exec_result.error: return JsonResponse({"status": 1, "msg": exec_result.error}) diff --git a/sql/instance_database.py b/sql/instance_database.py index 61418d7e15..223b3f36b2 100644 --- a/sql/instance_database.py +++ b/sql/instance_database.py @@ -93,7 +93,7 @@ def create(request): engine = get_engine(instance=instance) if instance.db_type == "mysql": # escape - db_name = MySQLdb.escape_string(db_name).decode("utf-8") + db_name = engine.escape_string(db_name) exec_result = engine.execute( db_name="information_schema", sql=f"create database {db_name};" ) diff --git a/sql/sql_optimize.py b/sql/sql_optimize.py index b147fa9c98..62a9f8a80b 100644 --- a/sql/sql_optimize.py +++ b/sql/sql_optimize.py @@ -163,8 +163,6 @@ def optimize_sqltuning(request): except Instance.DoesNotExist: result = {"status": 1, "msg": "你所在组未关联该实例!", "data": []} return HttpResponse(json.dumps(result), content_type="application/json") - # escape - db_name = MySQLdb.escape_string(db_name).decode("utf-8") sql_tunning = SqlTuning( instance_name=instance_name, db_name=db_name, sqltext=sqltext @@ -235,6 +233,7 @@ def explain(request): # 执行获取执行计划语句 query_engine = get_engine(instance=instance) + db_name = query_engine.escape_string(db_name) sql_result = query_engine.query(str(db_name), sql_content).to_sep_dict() result["data"] = sql_result @@ -287,6 +286,7 @@ def optimize_sqltuningadvisor(request): # 执行获取优化报告 query_engine = get_engine(instance=instance) + db_name = query_engine.escape_string(db_name) sql_result = query_engine.sqltuningadvisor(str(db_name), sql_content).to_sep_dict() result["data"] = sql_result diff --git a/sql/sql_tuning.py b/sql/sql_tuning.py index 973406cba9..4dac2a46cf 100644 --- a/sql/sql_tuning.py +++ b/sql/sql_tuning.py @@ -13,7 +13,7 @@ def __init__(self, instance_name, db_name, sqltext): instance = Instance.objects.get(instance_name=instance_name) query_engine = get_engine(instance=instance) self.engine = query_engine - self.db_name = db_name + self.db_name = self.engine.escape_string(db_name) self.sqltext = sqltext self.sql_variable = """ select diff --git a/sql/tests.py b/sql/tests.py index 5a3e8202c9..357a1ccea0 100644 --- a/sql/tests.py +++ b/sql/tests.py @@ -2539,7 +2539,7 @@ def test_param_edit_variable_not_config( data = { "instance_id": self.master.id, "variable_name": "1", - "variable_value": "false", + "runtime_value": "false", } r = self.client.post(path="/param/edit/", data=data) self.assertEqual( diff --git a/sql_api/api_instance.py b/sql_api/api_instance.py index 4cb50b51ba..6787ca4149 100644 --- a/sql_api/api_instance.py +++ b/sql_api/api_instance.py @@ -187,12 +187,10 @@ def post(self, request): instance = Instance.objects.get(pk=instance_id) try: - # escape - db_name = MySQLdb.escape_string(db_name).decode("utf-8") - schema_name = MySQLdb.escape_string(schema_name).decode("utf-8") - tb_name = MySQLdb.escape_string(tb_name).decode("utf-8") - query_engine = get_engine(instance=instance) + db_name = query_engine.escape_string(db_name) + schema_name = query_engine.escape_string(schema_name) + tb_name = query_engine.escape_string(tb_name) if resource_type == "database": resource = query_engine.get_all_databases() elif resource_type == "schema" and db_name: diff --git a/sql_api/api_workflow.py b/sql_api/api_workflow.py index fa702cefb5..480a704e1a 100644 --- a/sql_api/api_workflow.py +++ b/sql_api/api_workflow.py @@ -1,3 +1,4 @@ +import MySQLdb from django.contrib.auth.decorators import permission_required from django.utils.decorators import method_decorator from rest_framework import views, generics, status, serializers, permissions @@ -60,9 +61,11 @@ def post(self, request): instance = serializer.get_instance() # 交给engine进行检测 try: + db_name = request.data["db_name"] check_engine = get_engine(instance=instance) + db_name = check_engine.escape_string(db_name) check_result = check_engine.execute_check( - db_name=request.data["db_name"], sql=request.data["full_sql"].strip() + db_name=db_name, sql=request.data["full_sql"].strip() ) except Exception as e: raise serializers.ValidationError({"errors": f"{e}"})