
feat(wren-ai-service): sql pairs #961

Status: Open. Wants to merge 52 commits into base branch main (diff below shows changes from 31 commits).

Commits (52):
- d9e5f94 add api template and remove user_id (cyyeh, Nov 25, 2024)
- 346705f fix conflict (cyyeh, Dec 3, 2024)
- 345ff22 fix conflict (cyyeh, Dec 3, 2024)
- a17ee4a update (cyyeh, Nov 26, 2024)
- 3eb4911 fix conflict (cyyeh, Dec 3, 2024)
- 983c5ab reorganize tests (cyyeh, Nov 28, 2024)
- 43ebbf8 fix typo (cyyeh, Nov 28, 2024)
- ba8d18c fix typo (cyyeh, Nov 28, 2024)
- efe9e07 fix typo (cyyeh, Nov 28, 2024)
- 95138a2 fix typo (cyyeh, Nov 28, 2024)
- c252c5b add sql_pairs_preparation tests (cyyeh, Nov 28, 2024)
- 0bde109 fix sql_pairs_deletion and add test (cyyeh, Nov 28, 2024)
- 42ad598 add test (cyyeh, Nov 28, 2024)
- 35fdc04 update (cyyeh, Nov 28, 2024)
- 9f25ecf fix sql pairs (cyyeh, Nov 29, 2024)
- 87b90db fix conflict (cyyeh, Dec 3, 2024)
- 63999d4 update api doc (cyyeh, Nov 29, 2024)
- f0b6085 remove timer (cyyeh, Dec 4, 2024)
- efe1a3b fix conflicts (cyyeh, Dec 9, 2024)
- 438598a fix missing code (cyyeh, Dec 9, 2024)
- db43e2f Merge branch 'main' into feat/ai-service/sql-pairs (cyyeh, Dec 9, 2024)
- 80292db Merge branch 'main' into feat/ai-service/sql-pairs (cyyeh, Dec 9, 2024)
- 688a615 fix conflicts (cyyeh, Dec 12, 2024)
- e5b8f7b fix conflicts (cyyeh, Dec 16, 2024)
- a0dd555 refactor (cyyeh, Dec 16, 2024)
- 7c48e9d Merge branch 'main' into feat/ai-service/sql-pairs (cyyeh, Dec 16, 2024)
- 5a7a337 fix generator run (cyyeh, Dec 16, 2024)
- af26063 Merge branch 'main' into feat/ai-service/sql-pairs (cyyeh, Dec 16, 2024)
- bc2c209 Merge branch 'main' into feat/ai-service/sql-pairs (cyyeh, Dec 18, 2024)
- 0079454 merge (cyyeh, Dec 18, 2024)
- 485630a Merge branch 'main' into feat/ai-service/sql-pairs (cyyeh, Dec 19, 2024)
- 9af56db Merge branch 'main' into feat/ai-service/sql-pairs (cyyeh, Dec 19, 2024)
- d556c07 Merge branch 'main' into feat/ai-service/sql-pairs (cyyeh, Dec 19, 2024)
- 7ab7ff4 reformat configs (cyyeh, Dec 19, 2024)
- dc1a3ab remove pipeline visualization (cyyeh, Dec 19, 2024)
- 28a8eea remove redundant test file (cyyeh, Dec 19, 2024)
- 1a9160d reformat test code (cyyeh, Dec 19, 2024)
- ffc836d refine routing (cyyeh, Dec 19, 2024)
- ffcf94c fix component return type (cyyeh, Dec 19, 2024)
- 23348ce change delete_sql_pairs pipeline order (cyyeh, Dec 19, 2024)
- 78b8409 Merge branch 'main' into feat/ai-service/sql-pairs (cyyeh, Dec 19, 2024)
- b3dba11 Merge branch 'main' into feat/ai-service/sql-pairs (cyyeh, Dec 20, 2024)
- d309bd8 Merge branch 'main' into feat/ai-service/sql-pairs (cyyeh, Dec 20, 2024)
- b28790c Merge branch 'main' into feat/ai-service/sql-pairs (cyyeh, Dec 21, 2024)
- 08fbecc fix conflict (cyyeh, Dec 23, 2024)
- 624b7a7 Merge branch 'main' into feat/ai-service/sql-pairs (cyyeh, Dec 23, 2024)
- defc559 Merge branch 'main' into feat/ai-service/sql-pairs (cyyeh, Dec 23, 2024)
- ea91017 Merge branch 'main' into feat/ai-service/sql-pairs (cyyeh, Dec 25, 2024)
- 1afc694 fix conflicts (cyyeh, Dec 26, 2024)
- 1e21fcb fix conflict (cyyeh, Dec 26, 2024)
- 7d96661 Merge branch 'main' into feat/ai-service/sql-pairs (cyyeh, Dec 26, 2024)
- b74d8cf Merge branch 'main' into feat/ai-service/sql-pairs (cyyeh, Dec 27, 2024)
11 changes: 10 additions & 1 deletion deployment/kustomizations/base/cm.yaml
@@ -152,6 +152,16 @@ data:
document_store: qdrant
- name: data_assistance
llm: litellm_llm.gpt-4o-mini
- name: sql_pairs_preparation
document_store: qdrant
embedder: openai_embedder.text-embedding-3-large
llm: litellm_llm.gpt-4o-mini
Member commented:

How about renaming it to sql_pairs_indexing? I think that would better match the naming style of the other pipelines.

- name: sql_pairs_deletion
document_store: qdrant
embedder: openai_embedder.text-embedding-3-large
- name: sql_pairs_retrieval
document_store: qdrant
embedder: openai_embedder.text-embedding-3-large
- name: sql_executor
engine: wren_ui
- name: chart_generation
@@ -168,6 +178,5 @@ data:
query_cache_ttl: 3600
langfuse_host: https://cloud.langfuse.com
langfuse_enable: true
enable_timer: false
logging_level: DEBUG
development: false
12 changes: 11 additions & 1 deletion docker/config.example.yaml
@@ -100,6 +100,17 @@ pipes:
document_store: qdrant
- name: data_assistance
llm: litellm_llm.gpt-4o-mini
- document_store: qdrant
embedder: openai_embedder.text-embedding-3-large
llm: litellm_llm.gpt-4o-mini
name: sql_pairs_preparation
- document_store: qdrant
embedder: openai_embedder.text-embedding-3-large
name: sql_pairs_deletion
- document_store: qdrant
embedder: openai_embedder.text-embedding-3-large
name: sql_pairs_retrieval
llm: litellm_llm.gpt-4o-mini
Member commented:

Would you move the name field to the top of each pipe entry? That would be easier for us to read.

Member (author) replied:

Fixed.

- name: preprocess_sql_data
llm: litellm_llm.gpt-4o-mini
- name: sql_executor
@@ -118,6 +129,5 @@ settings:
query_cache_ttl: 3600
langfuse_host: https://cloud.langfuse.com
langfuse_enable: true
enable_timer: false
logging_level: DEBUG
development: false
4 changes: 0 additions & 4 deletions wren-ai-service/README.md
@@ -98,10 +98,6 @@ For a comprehensive understanding of how to evaluate the pipelines, please refer

### Estimate the Speed of the Pipeline

- to evaluate the speed of the pipeline, you can enable the timer
- add environment variables `ENABLE_TIMER=1` in `.env.dev`
- restart wren ai service
- check the logs in the terminal
- to run the load test
- setup `DATASET_NAME` in `.env.dev`
- adjust test config if needed
1 change: 0 additions & 1 deletion wren-ai-service/docs/configuration.md
@@ -134,7 +134,6 @@ The configuration file (`config.yaml`) is structured into several sections, each
query_cache_ttl: <cache_ttl_in_seconds>
langfuse_host: <langfuse_endpoint>
langfuse_enable: <true/false>
enable_timer: <true/false>
logging_level: <log_level>
development: <true/false>
```
4 changes: 2 additions & 2 deletions wren-ai-service/eval/evaluation.py
@@ -118,8 +118,8 @@ def _score_metrics(self, test_case: LLMTestCase, result: TestResult) -> None:
@observe(name="Summary Trace", capture_input=False, capture_output=False)
def _average_score(self, meta: dict) -> None:
langfuse_context.update_current_trace(
session_id=meta["session_id"],
user_id=meta["user_id"],
session_id=meta.get("session_id"),
user_id=meta.get("user_id"),
metadata=trace_metadata(meta, type="summary"),
)

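The eval changes above swap `meta["session_id"]` indexing for `meta.get("session_id")`, so a metadata dict that lacks those keys no longer raises. A minimal self-contained sketch of the difference (the `meta` contents here are hypothetical, not taken from the repo):

```python
# dict[key] raises KeyError on a missing key; dict.get(key) returns
# None instead, which is a safer default for optional trace metadata.
meta = {"session_id": "abc123"}  # hypothetical eval metadata, no user_id

session_id = meta.get("session_id")  # "abc123"
user_id = meta.get("user_id")        # None rather than a KeyError

try:
    user_id_strict = meta["user_id"]
except KeyError:
    user_id_strict = None  # the pre-change lookup would crash here

print(session_id, user_id, user_id_strict)
```

This is why the diff touches every `update_current_trace` call site: any eval run whose metadata omits `user_id` would previously abort with a `KeyError` before the trace was recorded.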
8 changes: 4 additions & 4 deletions wren-ai-service/eval/pipelines.py
@@ -142,8 +142,8 @@ async def process(self, query: dict) -> dict:
}

langfuse_context.update_current_trace(
session_id=self._meta["session_id"],
user_id=self._meta["user_id"],
session_id=self._meta.get("session_id"),
user_id=self._meta.get("user_id"),
metadata=trace_metadata(self._meta, type=prediction["type"]),
)

@@ -159,8 +159,8 @@ async def flat(self, prediction: dict, **kwargs) -> dict:

langfuse_context.update_current_trace(
name=f"Prediction Process - Shallow Trace for {prediction['input']} ",
session_id=self._meta["session_id"],
user_id=self._meta["user_id"],
session_id=self._meta.get("session_id"),
user_id=self._meta.get("user_id"),
metadata={
**trace_metadata(self._meta, type=prediction["type"]),
"source_trace_id": prediction["source_trace_id"],
1 change: 0 additions & 1 deletion wren-ai-service/src/config.py
@@ -46,7 +46,6 @@ class Settings(BaseSettings):
langfuse_enable: bool = Field(default=True)

# debug config
enable_timer: bool = Field(default=False)
logging_level: str = Field(default="INFO")
development: bool = Field(default=False)

102 changes: 47 additions & 55 deletions wren-ai-service/src/globals.py
@@ -6,31 +6,7 @@
from src.config import Settings
from src.core.pipeline import PipelineComponent
from src.core.provider import EmbedderProvider, LLMProvider
from src.pipelines import indexing
from src.pipelines.generation import (
chart_adjustment,
chart_generation,
data_assistance,
followup_sql_generation,
intent_classification,
question_recommendation,
relationship_recommendation,
semantics_description,
sql_answer,
sql_breakdown,
sql_correction,
sql_expansion,
sql_explanation,
sql_generation,
sql_regeneration,
sql_summary,
)
from src.pipelines.retrieval import (
historical_question,
preprocess_sql_data,
retrieval,
sql_executor,
)
from src.pipelines import generation, indexing, retrieval
from src.web.v1.services.ask import AskService
from src.web.v1.services.ask_details import AskDetailsService
from src.web.v1.services.chart import ChartService
@@ -41,8 +17,9 @@
from src.web.v1.services.semantics_preparation import SemanticsPreparationService
from src.web.v1.services.sql_answer import SqlAnswerService
from src.web.v1.services.sql_expansion import SqlExpansionService
from src.web.v1.services.sql_explanation import SQLExplanationService
from src.web.v1.services.sql_regeneration import SQLRegenerationService
from src.web.v1.services.sql_explanation import SqlExplanationService
from src.web.v1.services.sql_pairs_preparation import SqlPairsPreparationService
from src.web.v1.services.sql_regeneration import SqlRegenerationService

logger = logging.getLogger("wren-ai-service")

@@ -59,8 +36,9 @@ class ServiceContainer:
chart_adjustment_service: ChartAdjustmentService
sql_answer_service: SqlAnswerService
sql_expansion_service: SqlExpansionService
sql_explanation_service: SQLExplanationService
sql_regeneration_service: SQLRegenerationService
sql_explanation_service: SqlExplanationService
sql_regeneration_service: SqlRegenerationService
sql_pairs_preparation_service: SqlPairsPreparationService


@dataclass
@@ -80,7 +58,7 @@ def create_service_container(
return ServiceContainer(
semantics_description=SemanticsDescription(
pipelines={
"semantics_description": semantics_description.SemanticsDescription(
"semantics_description": generation.SemanticsDescription(
**pipe_components["semantics_description"],
)
},
@@ -103,10 +81,10 @@
),
ask_service=AskService(
pipelines={
"intent_classification": intent_classification.IntentClassification(
"intent_classification": generation.IntentClassification(
**pipe_components["intent_classification"],
),
"data_assistance": data_assistance.DataAssistance(
"data_assistance": generation.DataAssistance(
**pipe_components["data_assistance"]
),
"retrieval": retrieval.Retrieval(
@@ -115,63 +93,66 @@
table_column_retrieval_size=settings.table_column_retrieval_size,
allow_using_db_schemas_without_pruning=settings.allow_using_db_schemas_without_pruning,
),
"historical_question": historical_question.HistoricalQuestion(
"historical_question_retrieval": retrieval.HistoricalQuestionRetrieval(
**pipe_components["historical_question_retrieval"],
),
"sql_generation": sql_generation.SQLGeneration(
"sql_pairs_retrieval": retrieval.SqlPairsRetrieval(
**pipe_components["sql_pairs_retrieval"],
),
"sql_generation": generation.SQLGeneration(
**pipe_components["sql_generation"],
),
"sql_correction": sql_correction.SQLCorrection(
"sql_correction": generation.SQLCorrection(
**pipe_components["sql_correction"],
),
"followup_sql_generation": followup_sql_generation.FollowUpSQLGeneration(
"followup_sql_generation": generation.FollowUpSQLGeneration(
**pipe_components["followup_sql_generation"],
),
"sql_summary": sql_summary.SQLSummary(
"sql_summary": generation.SQLSummary(
**pipe_components["sql_summary"],
),
},
**query_cache,
),
chart_service=ChartService(
pipelines={
"sql_executor": sql_executor.SQLExecutor(
"sql_executor": retrieval.SQLExecutor(
**pipe_components["sql_executor"],
),
"chart_generation": chart_generation.ChartGeneration(
"chart_generation": generation.ChartGeneration(
**pipe_components["chart_generation"],
),
},
**query_cache,
),
chart_adjustment_service=ChartAdjustmentService(
pipelines={
"sql_executor": sql_executor.SQLExecutor(
"sql_executor": retrieval.SQLExecutor(
**pipe_components["sql_executor"],
),
"chart_adjustment": chart_adjustment.ChartAdjustment(
"chart_adjustment": generation.ChartAdjustment(
**pipe_components["chart_adjustment"],
),
},
**query_cache,
),
sql_answer_service=SqlAnswerService(
pipelines={
"preprocess_sql_data": preprocess_sql_data.PreprocessSqlData(
"preprocess_sql_data": retrieval.PreprocessSqlData(
**pipe_components["preprocess_sql_data"],
),
"sql_answer": sql_answer.SQLAnswer(
"sql_answer": generation.SQLAnswer(
**pipe_components["sql_answer"],
),
},
**query_cache,
),
ask_details_service=AskDetailsService(
pipelines={
"sql_breakdown": sql_breakdown.SQLBreakdown(
"sql_breakdown": generation.SQLBreakdown(
**pipe_components["sql_breakdown"],
),
"sql_summary": sql_summary.SQLSummary(
"sql_summary": generation.SQLSummary(
**pipe_components["sql_summary"],
),
},
@@ -184,45 +165,45 @@
table_retrieval_size=settings.table_retrieval_size,
table_column_retrieval_size=settings.table_column_retrieval_size,
),
"sql_expansion": sql_expansion.SQLExpansion(
"sql_expansion": generation.SQLExpansion(
**pipe_components["sql_expansion"],
),
"sql_correction": sql_correction.SQLCorrection(
"sql_correction": generation.SQLCorrection(
**pipe_components["sql_correction"],
),
"sql_summary": sql_summary.SQLSummary(
"sql_summary": generation.SQLSummary(
**pipe_components["sql_summary"],
),
},
**query_cache,
),
sql_explanation_service=SQLExplanationService(
sql_explanation_service=SqlExplanationService(
pipelines={
"sql_explanation": sql_explanation.SQLExplanation(
"sql_explanation": generation.SQLExplanation(
**pipe_components["sql_explanation"],
)
},
**query_cache,
),
sql_regeneration_service=SQLRegenerationService(
sql_regeneration_service=SqlRegenerationService(
pipelines={
"sql_regeneration": sql_regeneration.SQLRegeneration(
"sql_regeneration": generation.SQLRegeneration(
**pipe_components["sql_regeneration"],
)
},
**query_cache,
),
relationship_recommendation=RelationshipRecommendation(
pipelines={
"relationship_recommendation": relationship_recommendation.RelationshipRecommendation(
"relationship_recommendation": generation.RelationshipRecommendation(
**pipe_components["relationship_recommendation"],
)
},
**query_cache,
),
question_recommendation=QuestionRecommendation(
pipelines={
"question_recommendation": question_recommendation.QuestionRecommendation(
"question_recommendation": generation.QuestionRecommendation(
**pipe_components["question_recommendation"],
),
"retrieval": retrieval.Retrieval(
@@ -231,12 +212,23 @@
table_column_retrieval_size=settings.table_column_retrieval_size,
allow_using_db_schemas_without_pruning=settings.allow_using_db_schemas_without_pruning,
),
"sql_generation": sql_generation.SQLGeneration(
"sql_generation": generation.SQLGeneration(
**pipe_components["sql_generation"],
),
},
**query_cache,
),
sql_pairs_preparation_service=SqlPairsPreparationService(
pipelines={
"sql_pairs_preparation": indexing.SqlPairsPreparation(
**pipe_components["sql_pairs_preparation"],
),
"sql_pairs_deletion": indexing.SqlPairsDeletion(
**pipe_components["sql_pairs_deletion"],
),
},
**query_cache,
),
)


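The import cleanup in globals.py collapses a long list of per-module imports into `from src.pipelines import generation, indexing, retrieval`, so call sites write `generation.SQLGeneration(...)` instead of `sql_generation.SQLGeneration(...)`. That style relies on each package re-exporting its pipeline classes, presumably from its `__init__.py`, which this diff does not show. A self-contained stand-in for the call-site pattern (class and parameter names are taken from the diff; everything else is hypothetical):

```python
# Stand-in for a pipeline class normally defined in a submodule such as
# src/pipelines/generation/sql_generation.py.
class SQLGeneration:
    def __init__(self, **components):
        self.components = components

class generation:
    # Plays the role of the package namespace: its __init__.py would
    # re-export submodule classes, e.g.
    #     from .sql_generation import SQLGeneration
    SQLGeneration = SQLGeneration

# Call-site style after the refactor:
pipe = generation.SQLGeneration(llm="litellm_llm.gpt-4o-mini")
print(pipe.components["llm"])  # litellm_llm.gpt-4o-mini
```

The practical upside is visible in the diff itself: adding a pipeline (like the new `SqlPairsRetrieval`) no longer requires touching the import block in globals.py, only the package's re-export list.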
17 changes: 16 additions & 1 deletion wren-ai-service/src/pipelines/common.py
@@ -6,7 +6,7 @@
import aiohttp
import orjson
import pytz
from haystack import component
from haystack import Document, component

from src.core.engine import (
Engine,
@@ -19,6 +19,21 @@
logger = logging.getLogger("wren-ai-service")


@component
class ScoreFilter:
@component.output_types(
documents=List[Document],
)
def run(self, documents: List[Document], score: float = 0.9):
return {
"documents": sorted(
filter(lambda document: document.score >= score, documents),
key=lambda document: document.score,
reverse=True,
)
}


@component
class SQLBreakdownGenPostProcessor:
def __init__(self, engine: Engine):
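The new `ScoreFilter` component added to common.py keeps only documents whose retrieval score meets a threshold (default 0.9) and returns the survivors highest-score first. A self-contained sketch of the same logic, using a plain dataclass as a stand-in for haystack's `Document` (the example documents and scores are invented):

```python
from dataclasses import dataclass
from typing import List

@dataclass
class Document:  # stand-in for haystack.Document
    content: str
    score: float

def score_filter(documents: List[Document], score: float = 0.9) -> List[Document]:
    # Same logic as ScoreFilter.run: drop documents below the threshold,
    # then sort the survivors by score, highest first.
    return sorted(
        (d for d in documents if d.score >= score),
        key=lambda d: d.score,
        reverse=True,
    )

docs = [
    Document("exact match", 0.97),
    Document("weak match", 0.42),
    Document("close match", 0.91),
]
print([d.content for d in score_filter(docs)])  # ['exact match', 'close match']
```

In the sql-pairs retrieval pipeline this acts as a precision gate: only stored question/SQL pairs that are near-duplicates of the incoming query (score at or above the threshold) are passed along as few-shot context.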