Skip to content

Commit

Permalink
engine查询增加参数化选项,修改sql参数化方式 (#2112)
Browse files Browse the repository at this point in the history
* engine增加escape_string用于处理字符串参数转义

* engine查询增加参数化选项,修改sql参数化方式
  • Loading branch information
hhyo authored Apr 9, 2023
1 parent bc8a1e3 commit 7921044
Show file tree
Hide file tree
Showing 18 changed files with 339 additions and 168 deletions.
4 changes: 1 addition & 3 deletions sql/binlog.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,9 @@ def del_binlog(request):
result = {"status": 1, "msg": "实例不存在", "data": []}
return HttpResponse(json.dumps(result), content_type="application/json")

# escape
binlog = MySQLdb.escape_string(binlog).decode("utf-8")

if binlog:
query_engine = get_engine(instance=instance)
binlog = query_engine.escape_string(binlog)
query_result = query_engine.query(sql=rf"purge master logs to '{binlog}';")
if query_result.error is None:
result = {"status": 0, "msg": "清理成功", "data": ""}
Expand Down
8 changes: 5 additions & 3 deletions sql/data_dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def table_list(request):
instance_name=instance_name, db_type=db_type
)
query_engine = get_engine(instance=instance)
db_name = query_engine.escape_string(db_name)
data = query_engine.get_group_tables_by_db(db_name=db_name)
res = {"status": 0, "data": data}
except Instance.DoesNotExist:
Expand All @@ -50,13 +51,16 @@ def table_info(request):
db_name = request.GET.get("db_name", "")
tb_name = request.GET.get("tb_name", "")
db_type = request.GET.get("db_type", "")

if instance_name and db_name and tb_name:
data = {}
try:
instance = Instance.objects.get(
instance_name=instance_name, db_type=db_type
)
query_engine = get_engine(instance=instance)
db_name = query_engine.escape_string(db_name)
tb_name = query_engine.escape_string(tb_name)
data["meta_data"] = query_engine.get_table_meta_data(
db_name=db_name, tb_name=tb_name
)
Expand Down Expand Up @@ -91,8 +95,6 @@ def export(request):
"""导出数据字典"""
instance_name = request.GET.get("instance_name", "")
db_name = request.GET.get("db_name", "")
# escape
db_name = MySQLdb.escape_string(db_name).decode("utf-8")

try:
instance = user_instances(
Expand All @@ -104,7 +106,7 @@ def export(request):

# 普通用户仅可以获取指定数据库的字典信息
if db_name:
dbs = [db_name]
dbs = [query_engine.escape_string(db_name)]
# 管理员可以导出整个实例的字典信息
elif request.user.is_superuser:
dbs = query_engine.get_all_databases().rows
Expand Down
16 changes: 14 additions & 2 deletions sql/engines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ def info(self):
"""返回引擎简介"""
return "Base engine"

def escape_string(self, value: str) -> str:
"""参数转义"""
return value

@property
def auto_backup(self):
"""是否支持备份"""
Expand Down Expand Up @@ -167,7 +171,15 @@ def filter_sql(self, sql="", limit_num=0):
"""给查询语句增加结果级限制或者改写语句, 返回修改后的语句"""
return sql.strip()

def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs):
def query(
self,
db_name=None,
sql="",
limit_num=0,
close_conn=True,
parameters=None,
**kwargs
):
"""实际查询 返回一个ResultSet"""
return ResultSet()

Expand All @@ -180,7 +192,7 @@ def execute_check(self, db_name=None, sql=""):
"""执行语句的检查 返回一个ReviewSet"""
return ReviewSet()

def execute(self):
def execute(self, **kwargs):
"""执行语句 返回一个ReviewSet"""
return ReviewSet()

Expand Down
40 changes: 28 additions & 12 deletions sql/engines/clickhouse.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# -*- coding: UTF-8 -*-
from clickhouse_driver import connect
from clickhouse_driver.util.escape import escape_chars_map
from sql.utils.sql_utils import get_syntax_type
from .models import ResultSet, ReviewResult, ReviewSet
from common.utils.timer import FuncTimer
Expand Down Expand Up @@ -49,6 +50,10 @@ def name(self):
def info(self):
return "ClickHouse engine"

def escape_string(self, value: str) -> str:
"""字符串参数转义"""
return "%s" % "".join(escape_chars_map.get(c, c) for c in value)

@property
def auto_backup(self):
"""是否支持备份"""
Expand All @@ -63,11 +68,9 @@ def server_version(self):

def get_table_engine(self, tb_name):
"""获取某个table的engine type"""
sql = f"""select engine
from system.tables
where database='{tb_name.split('.')[0]}'
and name='{tb_name.split('.')[1]}'"""
query_result = self.query(sql=sql)
db, tb = tb_name.split(".")
sql = f"""select engine from system.tables where database=%(db)s and name=%(tb)s"""
query_result = self.query(sql=sql, parameters={"db": db, "tb": tb})
if query_result.rows:
result = {"status": 1, "engine": query_result.rows[0][0]}
else:
Expand Down Expand Up @@ -104,15 +107,20 @@ def get_all_columns_by_tb(self, db_name, tb_name, **kwargs):
from
system.columns
where
database = '{db_name}'
and table = '{tb_name}';"""
result = self.query(db_name=db_name, sql=sql)
database = %(db_name)s
and table = %(tb_name)s;"""
result = self.query(
db_name=db_name,
sql=sql,
parameters={"db_name": db_name, "tb_name": tb_name},
)
column_list = [row[0] for row in result.rows]
result.rows = column_list
return result

def describe_table(self, db_name, tb_name, **kwargs):
"""return ResultSet 类似查询"""
tb_name = self.escape_string(tb_name)
sql = f"show create table `{tb_name}`;"
result = self.query(db_name=db_name, sql=sql)

Expand All @@ -121,13 +129,21 @@ def describe_table(self, db_name, tb_name, **kwargs):
)
return result

def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs):
def query(
self,
db_name=None,
sql="",
limit_num=0,
close_conn=True,
parameters=None,
**kwargs,
):
"""返回 ResultSet"""
result_set = ResultSet(full_sql=sql)
try:
conn = self.get_connection(db_name=db_name)
cursor = conn.cursor()
cursor.execute(sql)
cursor.execute(sql, parameters)
if int(limit_num) > 0:
rows = cursor.fetchmany(size=int(limit_num))
else:
Expand Down Expand Up @@ -462,14 +478,14 @@ def execute_workflow(self, workflow):
break
return execute_result

def execute(self, db_name=None, sql="", close_conn=True):
def execute(self, db_name=None, sql="", close_conn=True, parameters=None):
"""原生执行语句"""
result = ResultSet(full_sql=sql)
conn = self.get_connection(db_name=db_name)
try:
cursor = conn.cursor()
for statement in sqlparse.split(sql):
cursor.execute(statement)
cursor.execute(statement, parameters)
cursor.close()
except Exception as e:
logger.warning(f"ClickHouse语句执行报错,语句:{sql},错误信息{e}")
Expand Down
8 changes: 6 additions & 2 deletions sql/engines/goinception.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ def get_backup_connection():
autocommit=True,
)

def escape_string(self, value: str) -> str:
"""字符串参数转义"""
return MySQLdb.escape_string(value).decode("utf-8")

def execute_check(self, instance=None, db_name=None, sql=""):
"""inception check"""
# 判断如果配置了隧道则连接隧道
Expand Down Expand Up @@ -282,8 +286,8 @@ def set_variable(self, variable_name, variable_value):

def osc_control(self, **kwargs):
"""控制osc执行,获取进度、终止、暂停、恢复等"""
sqlsha1 = MySQLdb.escape_string(kwargs.get("sqlsha1")).decode("utf-8")
command = MySQLdb.escape_string(kwargs.get("command")).decode("utf-8")
sqlsha1 = self.escape_string(kwargs.get("sqlsha1", ""))
command = self.escape_string(kwargs.get("command", ""))
if command == "get":
sql = f"inception get osc_percent '{sqlsha1}';"
else:
Expand Down
Loading

0 comments on commit 7921044

Please sign in to comment.