Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

engine查询增加参数化选项,修改sql参数化方式 #2112

Merged
merged 5 commits into from
Apr 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
| MsSQL | √ | × | √ | × | √ | × | × | × | × | × |
| Redis | √ | × | √ | × | × | × | × | × | × | × |
| PgSQL | √ | × | √ | × | × | × | × | × | × | × |
| Oracle | √ | √ | √ | √ | √ | × | × | × | × | × |
| MongoDB | √ | √ | √ | × | × | × | × | √ | × | × |
| Oracle | √ | √ | √ | √ | √ | × | | × | × | × |
| MongoDB | √ | √ | √ | × | × | × | | √ | × | × |
| Phoenix | √ | × | √ | × | × | × | × | × | × | × |
| ODPS | √ | × | × | × | × | × | × | × | × | × |
| ClickHouse | √ | √ | √ | × | × | × | × | × | × | × |
Expand Down
4 changes: 1 addition & 3 deletions sql/binlog.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,9 @@ def del_binlog(request):
result = {"status": 1, "msg": "实例不存在", "data": []}
return HttpResponse(json.dumps(result), content_type="application/json")

# escape
binlog = MySQLdb.escape_string(binlog).decode("utf-8")

if binlog:
query_engine = get_engine(instance=instance)
binlog = query_engine.escape_string(binlog)
query_result = query_engine.query(sql=rf"purge master logs to '{binlog}';")
if query_result.error is None:
result = {"status": 0, "msg": "清理成功", "data": ""}
Expand Down
8 changes: 5 additions & 3 deletions sql/data_dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def table_list(request):
instance_name=instance_name, db_type=db_type
)
query_engine = get_engine(instance=instance)
db_name = query_engine.escape_string(db_name)
data = query_engine.get_group_tables_by_db(db_name=db_name)
res = {"status": 0, "data": data}
except Instance.DoesNotExist:
Expand All @@ -50,13 +51,16 @@ def table_info(request):
db_name = request.GET.get("db_name", "")
tb_name = request.GET.get("tb_name", "")
db_type = request.GET.get("db_type", "")

if instance_name and db_name and tb_name:
data = {}
try:
instance = Instance.objects.get(
instance_name=instance_name, db_type=db_type
)
query_engine = get_engine(instance=instance)
db_name = query_engine.escape_string(db_name)
tb_name = query_engine.escape_string(tb_name)
data["meta_data"] = query_engine.get_table_meta_data(
db_name=db_name, tb_name=tb_name
)
Expand Down Expand Up @@ -91,8 +95,6 @@ def export(request):
"""导出数据字典"""
instance_name = request.GET.get("instance_name", "")
db_name = request.GET.get("db_name", "")
# escape
db_name = MySQLdb.escape_string(db_name).decode("utf-8")

try:
instance = user_instances(
Expand All @@ -104,7 +106,7 @@ def export(request):

# 普通用户仅可以获取指定数据库的字典信息
if db_name:
dbs = [db_name]
dbs = [query_engine.escape_string(db_name)]
# 管理员可以导出整个实例的字典信息
elif request.user.is_superuser:
dbs = query_engine.get_all_databases().rows
Expand Down
16 changes: 14 additions & 2 deletions sql/engines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ def info(self):
"""返回引擎简介"""
return "Base engine"

def escape_string(self, value: str) -> str:
"""参数转义"""
return value

@property
def auto_backup(self):
"""是否支持备份"""
Expand Down Expand Up @@ -167,7 +171,15 @@ def filter_sql(self, sql="", limit_num=0):
"""给查询语句增加结果级限制或者改写语句, 返回修改后的语句"""
return sql.strip()

def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs):
def query(
self,
db_name=None,
sql="",
limit_num=0,
close_conn=True,
parameters=None,
**kwargs
):
"""实际查询 返回一个ResultSet"""
return ResultSet()

Expand All @@ -180,7 +192,7 @@ def execute_check(self, db_name=None, sql=""):
"""执行语句的检查 返回一个ReviewSet"""
return ReviewSet()

def execute(self):
def execute(self, **kwargs):
"""执行语句 返回一个ReviewSet"""
return ReviewSet()

Expand Down
40 changes: 28 additions & 12 deletions sql/engines/clickhouse.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# -*- coding: UTF-8 -*-
from clickhouse_driver import connect
from clickhouse_driver.util.escape import escape_chars_map
from sql.utils.sql_utils import get_syntax_type
from .models import ResultSet, ReviewResult, ReviewSet
from common.utils.timer import FuncTimer
Expand Down Expand Up @@ -49,6 +50,10 @@ def name(self):
def info(self):
return "ClickHouse engine"

def escape_string(self, value: str) -> str:
"""字符串参数转义"""
return "%s" % "".join(escape_chars_map.get(c, c) for c in value)

@property
def auto_backup(self):
"""是否支持备份"""
Expand All @@ -63,11 +68,9 @@ def server_version(self):

def get_table_engine(self, tb_name):
"""获取某个table的engine type"""
sql = f"""select engine
from system.tables
where database='{tb_name.split('.')[0]}'
and name='{tb_name.split('.')[1]}'"""
query_result = self.query(sql=sql)
db, tb = tb_name.split(".")
sql = f"""select engine from system.tables where database=%(db)s and name=%(tb)s"""
query_result = self.query(sql=sql, parameters={"db": db, "tb": tb})
if query_result.rows:
result = {"status": 1, "engine": query_result.rows[0][0]}
else:
Expand Down Expand Up @@ -104,15 +107,20 @@ def get_all_columns_by_tb(self, db_name, tb_name, **kwargs):
from
system.columns
where
database = '{db_name}'
and table = '{tb_name}';"""
result = self.query(db_name=db_name, sql=sql)
database = %(db_name)s
and table = %(tb_name)s;"""
result = self.query(
db_name=db_name,
sql=sql,
parameters={"db_name": db_name, "tb_name": tb_name},
)
column_list = [row[0] for row in result.rows]
result.rows = column_list
return result

def describe_table(self, db_name, tb_name, **kwargs):
"""return ResultSet 类似查询"""
tb_name = self.escape_string(tb_name)
sql = f"show create table `{tb_name}`;"
result = self.query(db_name=db_name, sql=sql)

Expand All @@ -121,13 +129,21 @@ def describe_table(self, db_name, tb_name, **kwargs):
)
return result

def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs):
def query(
self,
db_name=None,
sql="",
limit_num=0,
close_conn=True,
parameters=None,
**kwargs,
):
"""返回 ResultSet"""
result_set = ResultSet(full_sql=sql)
try:
conn = self.get_connection(db_name=db_name)
cursor = conn.cursor()
cursor.execute(sql)
cursor.execute(sql, parameters)
if int(limit_num) > 0:
rows = cursor.fetchmany(size=int(limit_num))
else:
Expand Down Expand Up @@ -462,14 +478,14 @@ def execute_workflow(self, workflow):
break
return execute_result

def execute(self, db_name=None, sql="", close_conn=True):
def execute(self, db_name=None, sql="", close_conn=True, parameters=None):
"""原生执行语句"""
result = ResultSet(full_sql=sql)
conn = self.get_connection(db_name=db_name)
try:
cursor = conn.cursor()
for statement in sqlparse.split(sql):
cursor.execute(statement)
cursor.execute(statement, parameters)
cursor.close()
except Exception as e:
logger.warning(f"ClickHouse语句执行报错,语句:{sql},错误信息{e}")
Expand Down
8 changes: 6 additions & 2 deletions sql/engines/goinception.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ def get_backup_connection():
autocommit=True,
)

def escape_string(self, value: str) -> str:
"""字符串参数转义"""
return MySQLdb.escape_string(value).decode("utf-8")

def execute_check(self, instance=None, db_name=None, sql=""):
"""inception check"""
# 判断如果配置了隧道则连接隧道
Expand Down Expand Up @@ -282,8 +286,8 @@ def set_variable(self, variable_name, variable_value):

def osc_control(self, **kwargs):
"""控制osc执行,获取进度、终止、暂停、恢复等"""
sqlsha1 = MySQLdb.escape_string(kwargs.get("sqlsha1")).decode("utf-8")
command = MySQLdb.escape_string(kwargs.get("command")).decode("utf-8")
sqlsha1 = self.escape_string(kwargs.get("sqlsha1", ""))
command = self.escape_string(kwargs.get("command", ""))
if command == "get":
sql = f"inception get osc_percent '{sqlsha1}';"
else:
Expand Down
Loading