Skip to content

Commit

Permalink
(improvement)(Headless) Dataset supports query mode settings, and the…
Browse files Browse the repository at this point in the history
… chat layer supports tag mode (#802)
  • Loading branch information
lexluo09 authored Mar 12, 2024
1 parent 1d91a97 commit ae7acb4
Show file tree
Hide file tree
Showing 26 changed files with 171 additions and 207 deletions.
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
package com.tencent.supersonic.chat.api.pojo;

import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.headless.api.pojo.QueryConfig;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig;
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
import lombok.Data;

import java.util.HashSet;
import java.util.Optional;
import java.util.Set;
import lombok.Data;

@Data
public class DataSetSchema {
Expand All @@ -22,6 +22,7 @@ public class DataSetSchema {
private Set<SchemaElement> tagValues = new HashSet<>();
private SchemaElement entity = new SchemaElement();
private QueryConfig queryConfig;
private QueryType queryType;

public SchemaElement getElement(SchemaElementType elementType, long elementID) {
Optional<SchemaElement> element = Optional.empty();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,10 @@ private void addDateIfNotExist(QueryContext queryContext, SemanticParseInfo sema
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
List<String> whereFields = SqlSelectHelper.getWhereFields(correctS2SQL);
if (CollectionUtils.isEmpty(whereFields) || !TimeDimensionEnum.containsZhTimeDimension(whereFields)) {

Pair<String, String> startEndDate = S2SqlDateHelper.getStartEndDate(queryContext,
semanticParseInfo.getDataSetId(), semanticParseInfo.getQueryType());

if (StringUtils.isNotBlank(startEndDate.getLeft())
&& StringUtils.isNotBlank(startEndDate.getRight())) {
correctS2SQL = SqlAddHelper.addParenthesisToWhere(correctS2SQL);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,26 +1,14 @@
package com.tencent.supersonic.chat.core.parser;

import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.chat.core.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;

import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;

/**
* QueryTypeParser resolves the query type as one of METRIC, TAG, or ID.
Expand All @@ -38,50 +26,12 @@ public void parse(QueryContext queryContext, ChatContext chatContext) {
// 1.init S2SQL
semanticQuery.initS2Sql(queryContext.getSemanticSchema(), user);
// 2.set queryType
QueryType queryType = getQueryType(queryContext, semanticQuery);
semanticQuery.getParseInfo().setQueryType(queryType);
}
}

private QueryType getQueryType(QueryContext queryContext, SemanticQuery semanticQuery) {
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
SqlInfo sqlInfo = parseInfo.getSqlInfo();
if (Objects.isNull(sqlInfo) || StringUtils.isBlank(sqlInfo.getS2SQL())) {
return QueryType.ID;
}
//1. entity queryType
Long dataSetId = parseInfo.getDataSetId();
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) {
// Inspect the non-time fields of the WHERE clause: any entity name means ID; all tag names means TAG.
List<String> whereFields = SqlSelectHelper.getWhereFields(sqlInfo.getS2SQL())
.stream().filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
.collect(Collectors.toList());

if (CollectionUtils.isNotEmpty(whereFields)) {
Set<String> ids = semanticSchema.getEntities(dataSetId).stream().map(SchemaElement::getName)
.collect(Collectors.toSet());
if (CollectionUtils.isNotEmpty(ids) && ids.stream().anyMatch(whereFields::contains)) {
return QueryType.ID;
}
Set<String> tags = semanticSchema.getTags(dataSetId).stream().map(SchemaElement::getName)
.collect(Collectors.toSet());
if (CollectionUtils.isNotEmpty(tags) && tags.containsAll(whereFields)) {
return QueryType.TAG;
}
}
}
//2. metric queryType
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL());
List<SchemaElement> metrics = semanticSchema.getMetrics(dataSetId);
if (CollectionUtils.isNotEmpty(metrics)) {
Set<String> metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet());
boolean containMetric = selectFields.stream().anyMatch(metricNameSet::contains);
if (containMetric) {
return QueryType.METRIC;
}
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
Long dataSetId = parseInfo.getDataSetId();
DataSetSchema dataSetSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId);
parseInfo.setQueryType(dataSetSchema.getQueryType());
}
return QueryType.ID;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,21 @@

import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
import com.tencent.supersonic.chat.core.query.semantic.SemanticInterpreter;
import com.tencent.supersonic.chat.core.query.QueryManager;
import com.tencent.supersonic.chat.core.query.llm.LLMSemanticQuery;
import com.tencent.supersonic.chat.core.query.semantic.SemanticInterpreter;
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
import com.tencent.supersonic.chat.core.utils.QueryReqBuilder;
import com.tencent.supersonic.common.pojo.Aggregator;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import java.util.HashMap;
Expand Down Expand Up @@ -53,12 +51,6 @@ public QueryResult execute(User user) throws SqlParseException {
QueryStructReq queryStructReq = convertQueryStruct();
SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer();

OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
if (optimizationConfig.isUseS2SqlSwitch()) {
queryStructReq.setS2SQL(parseInfo.getSqlInfo().getS2SQL());
queryStructReq.setS2SQL(parseInfo.getSqlInfo().getQuerySQL());
}

SemanticQueryResp semanticQueryResp = semanticInterpreter.queryByStruct(queryStructReq, user);
String text = generateTableText(semanticQueryResp);
Map<String, Object> properties = parseInfo.getProperties();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,40 +2,38 @@
package com.tencent.supersonic.chat.core.query.rule;

import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
import com.tencent.supersonic.chat.core.query.semantic.SemanticInterpreter;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.BaseSemanticQuery;
import com.tencent.supersonic.chat.core.query.QueryManager;
import com.tencent.supersonic.chat.core.query.semantic.SemanticInterpreter;
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
import com.tencent.supersonic.chat.core.utils.QueryReqBuilder;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import lombok.ToString;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.ToString;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;

@Slf4j
@ToString
Expand Down Expand Up @@ -107,30 +105,20 @@ private void fillSchemaElement(SemanticParseInfo parseInfo, SemanticSchema seman
parseInfo.setDataSet(semanticSchema.getDataSet(dataSetIds.iterator().next()));
Map<Long, List<SchemaElementMatch>> dim2Values = new HashMap<>();
Map<Long, List<SchemaElementMatch>> id2Values = new HashMap<>();
Map<Long, List<SchemaElementMatch>> tag2Values = new HashMap<>();

for (SchemaElementMatch schemaMatch : parseInfo.getElementMatches()) {
SchemaElement element = schemaMatch.getElement();
element.setOrder(1 - schemaMatch.getSimilarity());
switch (element.getType()) {
case ID:
SchemaElement entityElement = semanticSchema.getElement(SchemaElementType.ENTITY, element.getId());
if (entityElement != null) {
if (id2Values.containsKey(element.getId())) {
id2Values.get(element.getId()).add(schemaMatch);
} else {
id2Values.put(element.getId(), new ArrayList<>(Arrays.asList(schemaMatch)));
}
}
addToValues(semanticSchema, SchemaElementType.ENTITY, id2Values, schemaMatch);
break;
case TAG_VALUE:
addToValues(semanticSchema, SchemaElementType.TAG, tag2Values, schemaMatch);
break;
case VALUE:
SchemaElement dimElement = semanticSchema.getElement(SchemaElementType.DIMENSION, element.getId());
if (dimElement != null) {
if (dim2Values.containsKey(element.getId())) {
dim2Values.get(element.getId()).add(schemaMatch);
} else {
dim2Values.put(element.getId(), new ArrayList<>(Arrays.asList(schemaMatch)));
}
}
addToValues(semanticSchema, SchemaElementType.DIMENSION, dim2Values, schemaMatch);
break;
case DIMENSION:
parseInfo.getDimensions().add(element);
Expand All @@ -145,43 +133,53 @@ private void fillSchemaElement(SemanticParseInfo parseInfo, SemanticSchema seman
}
}

if (!id2Values.isEmpty()) {
for (Map.Entry<Long, List<SchemaElementMatch>> entry : id2Values.entrySet()) {
addFilters(parseInfo, semanticSchema, entry, SchemaElementType.ENTITY);
}
}
addToFilters(id2Values, parseInfo, semanticSchema, SchemaElementType.ENTITY);
addToFilters(dim2Values, parseInfo, semanticSchema, SchemaElementType.DIMENSION);
addToFilters(tag2Values, parseInfo, semanticSchema, SchemaElementType.TAG);
}

if (!dim2Values.isEmpty()) {
for (Map.Entry<Long, List<SchemaElementMatch>> entry : dim2Values.entrySet()) {
addFilters(parseInfo, semanticSchema, entry, SchemaElementType.DIMENSION);
/**
 * Turns grouped schema matches into dimension filters on {@code parseInfo}.
 *
 * <p>For each map entry, the schema element of type {@code entity} whose id is the entry key
 * supplies the filter's name and bizName. An entry holding exactly one match produces an
 * EQUALS filter on that match's word and also sets the entity on {@code parseInfo}; any other
 * entry produces an IN filter over the words of all its matches.
 *
 * @param id2Values      matches grouped by element id; a null or empty map is a no-op
 * @param parseInfo      parse info that receives the filters (and possibly the entity)
 * @param semanticSchema schema used to look up elements by type and id
 * @param entity         element type used for the per-entry lookup
 */
private void addToFilters(Map<Long, List<SchemaElementMatch>> id2Values, SemanticParseInfo parseInfo,
        SemanticSchema semanticSchema, SchemaElementType entity) {
    if (Objects.isNull(id2Values) || id2Values.isEmpty()) {
        return;
    }
    for (Entry<Long, List<SchemaElementMatch>> grouped : id2Values.entrySet()) {
        Long elementId = grouped.getKey();
        SchemaElement schemaElement = semanticSchema.getElement(entity, elementId);
        List<SchemaElementMatch> matches = grouped.getValue();

        // Name and bizName come from the looked-up schema element in both cases.
        QueryFilter filter = new QueryFilter();
        filter.setBizName(schemaElement.getBizName());
        filter.setName(schemaElement.getName());

        if (matches.size() != 1) {
            List<String> words = matches.stream()
                    .map(SchemaElementMatch::getWord)
                    .collect(Collectors.toList());
            filter.setValue(words);
            filter.setOperator(FilterOperatorEnum.IN);
            filter.setElementID(elementId);
        } else {
            SchemaElementMatch onlyMatch = matches.get(0);
            filter.setValue(onlyMatch.getWord());
            filter.setOperator(FilterOperatorEnum.EQUALS);
            filter.setElementID(onlyMatch.getElement().getId());
            // NOTE(review): the entity is always looked up as ENTITY by this entry's key,
            // regardless of the 'entity' type argument - confirm this is intended for
            // DIMENSION and TAG groups.
            parseInfo.setEntity(semanticSchema.getElement(SchemaElementType.ENTITY, elementId));
        }
        parseInfo.getDimensionFilters().add(filter);
    }
}

private void addFilters(SemanticParseInfo parseInfo, SemanticSchema semanticSchema,
Entry<Long, List<SchemaElementMatch>> entry, SchemaElementType elementType) {
SchemaElement dimension = semanticSchema.getElement(elementType, entry.getKey());

if (entry.getValue().size() == 1) {
SchemaElementMatch schemaMatch = entry.getValue().get(0);
QueryFilter dimensionFilter = new QueryFilter();
dimensionFilter.setValue(schemaMatch.getWord());
dimensionFilter.setBizName(dimension.getBizName());
dimensionFilter.setName(dimension.getName());
dimensionFilter.setOperator(FilterOperatorEnum.EQUALS);
dimensionFilter.setElementID(schemaMatch.getElement().getId());
parseInfo.getDimensionFilters().add(dimensionFilter);
parseInfo.setEntity(semanticSchema.getElement(SchemaElementType.ENTITY, entry.getKey()));
} else {
QueryFilter dimensionFilter = new QueryFilter();
List<String> vals = new ArrayList<>();
entry.getValue().stream().forEach(i -> vals.add(i.getWord()));
dimensionFilter.setValue(vals);
dimensionFilter.setBizName(dimension.getBizName());
dimensionFilter.setName(dimension.getName());
dimensionFilter.setOperator(FilterOperatorEnum.IN);
dimensionFilter.setElementID(entry.getKey());
parseInfo.getDimensionFilters().add(dimensionFilter);
/**
 * Files {@code schemaMatch} under its element's id in {@code id2Values}.
 *
 * <p>The match is recorded only when {@code semanticSchema} returns a non-null element of type
 * {@code entity} for that id; otherwise the map is left untouched.
 *
 * @param semanticSchema schema used to check that the element exists
 * @param entity         element type used for the existence lookup
 * @param id2Values      map from element id to the matches collected so far; mutated in place
 * @param schemaMatch    the match to record
 */
private void addToValues(SemanticSchema semanticSchema, SchemaElementType entity,
        Map<Long, List<SchemaElementMatch>> id2Values, SchemaElementMatch schemaMatch) {
    Long elementId = schemaMatch.getElement().getId();
    if (semanticSchema.getElement(entity, elementId) == null) {
        return;
    }
    if (!id2Values.containsKey(elementId)) {
        // First match for this id: start a new mutable list seeded with it.
        id2Values.put(elementId, new ArrayList<>(Arrays.asList(schemaMatch)));
        return;
    }
    id2Values.get(elementId).add(schemaMatch);
}

Expand All @@ -199,11 +197,6 @@ public QueryResult execute(User user) {
QueryResult queryResult = new QueryResult();
QueryStructReq queryStructReq = convertQueryStruct();

OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
if (optimizationConfig.isUseS2SqlSwitch()) {
queryStructReq.setS2SQL(parseInfo.getSqlInfo().getS2SQL());
queryStructReq.setCorrectS2SQL(parseInfo.getSqlInfo().getCorrectS2SQL());
}
SemanticQueryResp queryResp = semanticInterpreter.queryByStruct(queryStructReq, user);

if (queryResp != null) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package com.tencent.supersonic.chat.core.query.rule.tag;

import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.DIMENSION;
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.ID;
import static com.tencent.supersonic.chat.core.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
import static com.tencent.supersonic.chat.core.query.rule.QueryMatchOption.OptionType.REQUIRED;
import static com.tencent.supersonic.chat.core.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.ID;
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.TAG;

import org.springframework.stereotype.Component;

Expand All @@ -14,7 +14,7 @@ public class TagDetailQuery extends TagSemanticQuery {

public TagDetailQuery() {
super();
queryMatcher.addOption(DIMENSION, REQUIRED, AT_LEAST, 1)
queryMatcher.addOption(TAG, REQUIRED, AT_LEAST, 1)
.addOption(ID, REQUIRED, AT_LEAST, 1);
}

Expand Down
Loading

0 comments on commit ae7acb4

Please sign in to comment.