From 8b69d57c4b69327d126007d9765673dd88bfb355 Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Thu, 19 Dec 2024 10:04:17 +0800 Subject: [PATCH] [improvement][chat] Fix the issue with the DatabaseMatchStrategy variable under multi-threading (#1963) --- .../common/calcite/Configuration.java | 2 +- .../common/calcite/SqlDialectFactory.java | 9 +-- .../FieldAliasReplaceNameVisitor.java | 6 +- .../common/jsqlparser/SqlReplaceHelper.java | 1 + .../provider/ZhipuModelFactory.java | 5 +- .../jsqlparser/SqlReplaceHelperTest.java | 57 ++++++++++--------- .../chat/mapper/DatabaseMatchStrategy.java | 6 +- 7 files changed, 44 insertions(+), 42 deletions(-) diff --git a/common/src/main/java/com/tencent/supersonic/common/calcite/Configuration.java b/common/src/main/java/com/tencent/supersonic/common/calcite/Configuration.java index 516baeccf7..4bae99e581 100644 --- a/common/src/main/java/com/tencent/supersonic/common/calcite/Configuration.java +++ b/common/src/main/java/com/tencent/supersonic/common/calcite/Configuration.java @@ -81,7 +81,7 @@ public static SqlParser.Config getParserConfig(EngineType engineType) { .setUnquotedCasing(Casing.TO_UPPER).setConformance(sqlDialect.getConformance()) .setLex(Lex.BIG_QUERY); if (EngineType.HANADB.equals(engineType)) { - parserConfig = parserConfig.setQuoting(Quoting.DOUBLE_QUOTE); + parserConfig = parserConfig.setQuoting(Quoting.DOUBLE_QUOTE); } parserConfig = parserConfig.setQuotedCasing(Casing.UNCHANGED); parserConfig = parserConfig.setUnquotedCasing(Casing.UNCHANGED); diff --git a/common/src/main/java/com/tencent/supersonic/common/calcite/SqlDialectFactory.java b/common/src/main/java/com/tencent/supersonic/common/calcite/SqlDialectFactory.java index c549ce88a4..ab7aea92f6 100644 --- a/common/src/main/java/com/tencent/supersonic/common/calcite/SqlDialectFactory.java +++ b/common/src/main/java/com/tencent/supersonic/common/calcite/SqlDialectFactory.java @@ -21,10 +21,11 @@ public class SqlDialectFactory { .withDatabaseProduct(DatabaseProduct.BIG_QUERY).withLiteralQuoteString("'") .withLiteralEscapedQuoteString("''").withUnquotedCasing(Casing.UNCHANGED) .withQuotedCasing(Casing.UNCHANGED).withCaseSensitive(false); - public static final Context HANADB_CONTEXT = SqlDialect.EMPTY_CONTEXT - .withDatabaseProduct(DatabaseProduct.BIG_QUERY).withLiteralQuoteString("'") - .withIdentifierQuoteString("\"").withLiteralEscapedQuoteString("''").withUnquotedCasing(Casing.UNCHANGED) - .withQuotedCasing(Casing.UNCHANGED).withCaseSensitive(true); + public static final Context HANADB_CONTEXT = + SqlDialect.EMPTY_CONTEXT.withDatabaseProduct(DatabaseProduct.BIG_QUERY) + .withLiteralQuoteString("'").withIdentifierQuoteString("\"") + .withLiteralEscapedQuoteString("''").withUnquotedCasing(Casing.UNCHANGED) + .withQuotedCasing(Casing.UNCHANGED).withCaseSensitive(true); private static Map sqlDialectMap; static { diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FieldAliasReplaceNameVisitor.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FieldAliasReplaceNameVisitor.java index 91f391bc03..37bff0bc56 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FieldAliasReplaceNameVisitor.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FieldAliasReplaceNameVisitor.java @@ -1,15 +1,13 @@ package com.tencent.supersonic.common.jsqlparser; +import net.sf.jsqlparser.expression.Alias; import net.sf.jsqlparser.statement.select.SelectItem; import net.sf.jsqlparser.statement.select.SelectItemVisitorAdapter; +import org.apache.commons.lang3.StringUtils; import java.util.HashMap; import java.util.Map; -import org.apache.commons.lang3.StringUtils; - -import net.sf.jsqlparser.expression.Alias; - public class FieldAliasReplaceNameVisitor extends SelectItemVisitorAdapter { private Map fieldNameMap; diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java index 7d4cd30f9f..111443a02a 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java @@ -465,6 +465,7 @@ public static String replaceAliasFieldName(String sql, Map field } return selectStatement.toString(); } + public static String replaceAlias(String sql) { Select selectStatement = SqlSelectHelper.getSelect(sql); if (!(selectStatement instanceof PlainSelect)) { diff --git a/common/src/main/java/dev/langchain4j/provider/ZhipuModelFactory.java b/common/src/main/java/dev/langchain4j/provider/ZhipuModelFactory.java index 4fc14af222..b326db6826 100644 --- a/common/src/main/java/dev/langchain4j/provider/ZhipuModelFactory.java +++ b/common/src/main/java/dev/langchain4j/provider/ZhipuModelFactory.java @@ -9,6 +9,7 @@ import dev.langchain4j.model.zhipu.ZhipuAiEmbeddingModel; import org.springframework.beans.factory.InitializingBean; import org.springframework.stereotype.Service; + import static java.time.Duration.ofSeconds; @Service @@ -32,8 +33,8 @@ public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelCo return ZhipuAiEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl()) .apiKey(embeddingModelConfig.getApiKey()).model(embeddingModelConfig.getModelName()) .maxRetries(embeddingModelConfig.getMaxRetries()).callTimeout(ofSeconds(60)) - .connectTimeout(ofSeconds(60)).writeTimeout(ofSeconds(60)).readTimeout(ofSeconds(60)) - .logRequests(embeddingModelConfig.getLogRequests()) + .connectTimeout(ofSeconds(60)).writeTimeout(ofSeconds(60)) + .readTimeout(ofSeconds(60)).logRequests(embeddingModelConfig.getLogRequests()) .logResponses(embeddingModelConfig.getLogResponses()).build(); } diff --git a/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelperTest.java index 7bd5ee25ec..631ab59cf2 100644 --- a/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelperTest.java @@ -326,34 +326,35 @@ void testReplaceAlias() { @Test void testReplaceAliasFieldName() { - Map map = new HashMap<>(); - map.put("总访问次数", "\"总访问次数\""); - map.put("访问次数", "\"访问次数\""); - String sql = "select 部门, sum(访问次数) as 总访问次数 from 超音数 where " - + "datediff('day', 数据日期, '2023-09-05') <= 3 group by 部门 order by 总访问次数 desc limit 10"; - String replaceSql = SqlReplaceHelper.replaceAliasFieldName(sql, map); - System.out.println(replaceSql); - Assert.assertEquals("SELECT 部门, sum(访问次数) AS \"总访问次数\" FROM 超音数 WHERE " - + "datediff('day', 数据日期, '2023-09-05') <= 3 GROUP BY 部门 ORDER BY \"总访问次数\" DESC LIMIT 10", - replaceSql); - - sql = "select 部门, sum(访问次数) as 总访问次数 from 超音数 where " - + "(datediff('day', 数据日期, '2023-09-05') <= 3) and 数据日期 = '2023-10-10' " - + "group by 部门 order by 总访问次数 desc limit 10"; - replaceSql = SqlReplaceHelper.replaceAliasFieldName(sql, map); - System.out.println(replaceSql); - Assert.assertEquals("SELECT 部门, sum(访问次数) AS \"总访问次数\" FROM 超音数 WHERE " - + "(datediff('day', 数据日期, '2023-09-05') <= 3) AND 数据日期 = '2023-10-10' " - + "GROUP BY 部门 ORDER BY \"总访问次数\" DESC LIMIT 10", replaceSql); - - sql = "select 部门, sum(访问次数) as 访问次数 from 超音数 where " - + "(datediff('day', 数据日期, '2023-09-05') <= 3) and 数据日期 = '2023-10-10' " - + "group by 部门 order by 访问次数 desc limit 10"; - replaceSql = SqlReplaceHelper.replaceAliasFieldName(sql, map); - System.out.println(replaceSql); - Assert.assertEquals("SELECT 部门, sum(\"访问次数\") AS \"访问次数\" FROM 超音数 WHERE (datediff('day', 数据日期, " - + "'2023-09-05') <= 3) AND 数据日期 = '2023-10-10' GROUP BY 部门 ORDER BY \"访问次数\" DESC LIMIT 10", - replaceSql); + Map map = new HashMap<>(); + map.put("总访问次数", "\"总访问次数\""); + map.put("访问次数", "\"访问次数\""); + String sql = "select 部门, sum(访问次数) as 总访问次数 from 超音数 where " + + "datediff('day', 数据日期, '2023-09-05') <= 3 group by 部门 order by 总访问次数 desc limit 10"; + String replaceSql = SqlReplaceHelper.replaceAliasFieldName(sql, map); + System.out.println(replaceSql); + Assert.assertEquals("SELECT 部门, sum(访问次数) AS \"总访问次数\" FROM 超音数 WHERE " + + "datediff('day', 数据日期, '2023-09-05') <= 3 GROUP BY 部门 ORDER BY \"总访问次数\" DESC LIMIT 10", + replaceSql); + + sql = "select 部门, sum(访问次数) as 总访问次数 from 超音数 where " + + "(datediff('day', 数据日期, '2023-09-05') <= 3) and 数据日期 = '2023-10-10' " + + "group by 部门 order by 总访问次数 desc limit 10"; + replaceSql = SqlReplaceHelper.replaceAliasFieldName(sql, map); + System.out.println(replaceSql); + Assert.assertEquals("SELECT 部门, sum(访问次数) AS \"总访问次数\" FROM 超音数 WHERE " + + "(datediff('day', 数据日期, '2023-09-05') <= 3) AND 数据日期 = '2023-10-10' " + + "GROUP BY 部门 ORDER BY \"总访问次数\" DESC LIMIT 10", replaceSql); + + sql = "select 部门, sum(访问次数) as 访问次数 from 超音数 where " + + "(datediff('day', 数据日期, '2023-09-05') <= 3) and 数据日期 = '2023-10-10' " + + "group by 部门 order by 访问次数 desc limit 10"; + replaceSql = SqlReplaceHelper.replaceAliasFieldName(sql, map); + System.out.println(replaceSql); + Assert.assertEquals( + "SELECT 部门, sum(\"访问次数\") AS \"访问次数\" FROM 超音数 WHERE (datediff('day', 数据日期, " + + "'2023-09-05') <= 3) AND 数据日期 = '2023-10-10' GROUP BY 部门 ORDER BY \"访问次数\" DESC LIMIT 10", + replaceSql); } @Test diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/DatabaseMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/DatabaseMatchStrategy.java index d1a406b5f2..5cc2d2fba9 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/DatabaseMatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/DatabaseMatchStrategy.java @@ -27,12 +27,12 @@ @Slf4j public class DatabaseMatchStrategy extends SingleMatchStrategy { - private List allElements; + private ThreadLocal> allElements = ThreadLocal.withInitial(ArrayList::new); @Override public Map> match(ChatQueryContext chatQueryContext, List terms, Set detectDataSetIds) { - this.allElements = getSchemaElements(chatQueryContext); + allElements.set(getSchemaElements(chatQueryContext)); return super.match(chatQueryContext, terms, detectDataSetIds); } @@ -43,7 +43,7 @@ public List detectByStep(ChatQueryContext chatQueryContext, } Double metricDimensionThresholdConfig = getThreshold(chatQueryContext); - Map> nameToItems = getNameToItems(allElements); + Map> nameToItems = getNameToItems(allElements.get()); List results = new ArrayList<>(); for (Entry> entry : nameToItems.entrySet()) { String name = entry.getKey();