diff --git a/deployment/kustomizations/base/cm.yaml b/deployment/kustomizations/base/cm.yaml
index 145fbcd64..8fec91f76 100644
--- a/deployment/kustomizations/base/cm.yaml
+++ b/deployment/kustomizations/base/cm.yaml
@@ -128,8 +128,6 @@ data:
       - name: sql_answer
        llm: litellm_llm.gpt-4o-mini-2024-07-18
        engine: wren_ui
-      - name: preprocess_sql_data
-        llm: litellm_llm.gpt-4o-mini-2024-07-18
       - name: sql_breakdown
        llm: litellm_llm.gpt-4o-mini-2024-07-18
        engine: wren_ui
@@ -154,6 +152,19 @@ data:
        document_store: qdrant
       - name: data_assistance
        llm: litellm_llm.gpt-4o-mini-2024-07-18
+      - name: sql_pairs_preparation
+        document_store: qdrant
+        embedder: openai_embedder.text-embedding-3-large
+        llm: litellm_llm.gpt-4o-mini-2024-07-18
+      - name: sql_pairs_deletion
+        document_store: qdrant
+        embedder: openai_embedder.text-embedding-3-large
+      - name: sql_pairs_retrieval
+        document_store: qdrant
+        embedder: openai_embedder.text-embedding-3-large
+        llm: litellm_llm.gpt-4o-mini-2024-07-18
+      - name: preprocess_sql_data
+        llm: litellm_llm.gpt-4o-mini-2024-07-18
       - name: sql_executor
        engine: wren_ui
       - name: chart_generation
@@ -170,6 +181,5 @@ data:
       query_cache_ttl: 3600
       langfuse_host: https://cloud.langfuse.com
       langfuse_enable: true
-      enable_timer: false
       logging_level: DEBUG
       development: false
diff --git a/docker/config.example.yaml b/docker/config.example.yaml
index 8f2f8a804..25cadacf1 100644
--- a/docker/config.example.yaml
+++ b/docker/config.example.yaml
@@ -104,6 +104,17 @@ pipes:
     document_store: qdrant
   - name: data_assistance
     llm: litellm_llm.gpt-4o-mini-2024-07-18
+  - name: sql_pairs_preparation
+    document_store: qdrant
+    embedder: openai_embedder.text-embedding-3-large
+    llm: litellm_llm.gpt-4o-mini-2024-07-18
+  - name: sql_pairs_deletion
+    document_store: qdrant
+    embedder: openai_embedder.text-embedding-3-large
+  - name: sql_pairs_retrieval
+    document_store: qdrant
+    embedder: openai_embedder.text-embedding-3-large
+    llm: litellm_llm.gpt-4o-mini-2024-07-18
   - name: preprocess_sql_data
     llm: litellm_llm.gpt-4o-mini-2024-07-18
   - name: sql_executor
@@ -122,6 +133,5 @@ settings:
   query_cache_ttl: 3600
   langfuse_host: https://cloud.langfuse.com
   langfuse_enable: true
-  enable_timer: false
   logging_level: DEBUG
   development: false
diff --git a/wren-ai-service/README.md b/wren-ai-service/README.md
index d7abca89b..063dc12fe 100644
--- a/wren-ai-service/README.md
+++ b/wren-ai-service/README.md
@@ -98,10 +98,6 @@ For a comprehensive understanding of how to evaluate the pipelines, please refer
 
 ### Estimate the Speed of the Pipeline
 
-- to evaluate the speed of the pipeline, you can enable the timer
-  - add environment variables `ENABLE_TIMER=1` in `.env.dev`
-  - restart wren ai service
-  - check the logs in the terminal
 - to run the load test
   - setup `DATASET_NAME` in `.env.dev`
   - adjust test config if needed
diff --git a/wren-ai-service/docs/configuration.md b/wren-ai-service/docs/configuration.md
index c886409b1..c88c39ce8 100644
--- a/wren-ai-service/docs/configuration.md
+++ b/wren-ai-service/docs/configuration.md
@@ -134,7 +134,6 @@ The configuration file (`config.yaml`) is structured into several sections, each
   query_cache_ttl: 
   langfuse_host: 
   langfuse_enable: 
-  enable_timer: 
   logging_level: 
   development: 
 ```
diff --git a/wren-ai-service/eval/evaluation.py b/wren-ai-service/eval/evaluation.py
index 797759f10..acf3f76e9 100644
--- a/wren-ai-service/eval/evaluation.py
+++ b/wren-ai-service/eval/evaluation.py
@@ -118,8 +118,8 @@ def _score_metrics(self, test_case: LLMTestCase, result: TestResult) -> None:
 
     @observe(name="Summary Trace", capture_input=False, capture_output=False)
     def _average_score(self, meta: dict) -> None:
         langfuse_context.update_current_trace(
-            session_id=meta["session_id"],
-            user_id=meta["user_id"],
+            session_id=meta.get("session_id"),
+            user_id=meta.get("user_id"),
             metadata=trace_metadata(meta, type="summary"),
         )
diff --git a/wren-ai-service/eval/pipelines.py b/wren-ai-service/eval/pipelines.py
index acf363557..d7a09f106 100644
--- a/wren-ai-service/eval/pipelines.py
+++ b/wren-ai-service/eval/pipelines.py
@@ -142,8 +142,8 @@ async def process(self, query: dict) -> dict:
         }
 
         langfuse_context.update_current_trace(
-            session_id=self._meta["session_id"],
-            user_id=self._meta["user_id"],
+            session_id=self._meta.get("session_id"),
+            user_id=self._meta.get("user_id"),
             metadata=trace_metadata(self._meta, type=prediction["type"]),
         )
 
@@ -159,8 +159,8 @@ async def flat(self, prediction: dict, **kwargs) -> dict:
 
         langfuse_context.update_current_trace(
             name=f"Prediction Process - Shallow Trace for {prediction['input']} ",
-            session_id=self._meta["session_id"],
-            user_id=self._meta["user_id"],
+            session_id=self._meta.get("session_id"),
+            user_id=self._meta.get("user_id"),
             metadata={
                 **trace_metadata(self._meta, type=prediction["type"]),
                 "source_trace_id": prediction["source_trace_id"],
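Editor's note: the eval harness switch from subscripting to `dict.get` matters because `session_id` and `user_id` are optional in the eval metadata. A minimal sketch of the difference (the `meta` contents below are illustrative, not from the repo):

```python
# With subscripting, a missing optional key aborts the whole evaluation run.
meta = {"user_id": "eval-user"}  # no "session_id" key

try:
    session_id = meta["session_id"]  # old behavior: raises KeyError
except KeyError:
    session_id = None

# New behavior: a missing key simply becomes None, which Langfuse treats
# as "attribute not set" on the trace.
session_id = meta.get("session_id")  # -> None
```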
diff --git a/wren-ai-service/src/config.py b/wren-ai-service/src/config.py
index 4b5e8bf94..0fa694c2d 100644
--- a/wren-ai-service/src/config.py
+++ b/wren-ai-service/src/config.py
@@ -46,7 +46,6 @@ class Settings(BaseSettings):
     langfuse_enable: bool = Field(default=True)
 
     # debug config
-    enable_timer: bool = Field(default=False)
     logging_level: str = Field(default="INFO")
     development: bool = Field(default=False)
 
diff --git a/wren-ai-service/src/globals.py b/wren-ai-service/src/globals.py
index 3c8f5c758..af6c75898 100644
--- a/wren-ai-service/src/globals.py
+++ b/wren-ai-service/src/globals.py
@@ -6,31 +6,7 @@
 from src.config import Settings
 from src.core.pipeline import PipelineComponent
 from src.core.provider import EmbedderProvider, LLMProvider
-from src.pipelines import indexing
-from src.pipelines.generation import (
-    chart_adjustment,
-    chart_generation,
-    data_assistance,
-    followup_sql_generation,
-    intent_classification,
-    question_recommendation,
-    relationship_recommendation,
-    semantics_description,
-    sql_answer,
-    sql_breakdown,
-    sql_correction,
-    sql_expansion,
-    sql_explanation,
-    sql_generation,
-    sql_regeneration,
-    sql_summary,
-)
-from src.pipelines.retrieval import (
-    historical_question,
-    preprocess_sql_data,
-    retrieval,
-    sql_executor,
-)
+from src.pipelines import generation, indexing, retrieval
 from src.web.v1.services.ask import AskService
 from src.web.v1.services.ask_details import AskDetailsService
 from src.web.v1.services.chart import ChartService
@@ -41,8 +17,9 @@
 from src.web.v1.services.semantics_preparation import SemanticsPreparationService
 from src.web.v1.services.sql_answer import SqlAnswerService
 from src.web.v1.services.sql_expansion import SqlExpansionService
-from src.web.v1.services.sql_explanation import SQLExplanationService
-from src.web.v1.services.sql_regeneration import SQLRegenerationService
+from src.web.v1.services.sql_explanation import SqlExplanationService
+from src.web.v1.services.sql_pairs_preparation import SqlPairsPreparationService
+from src.web.v1.services.sql_regeneration import SqlRegenerationService
 
 logger = logging.getLogger("wren-ai-service")
 
@@ -59,8 +36,9 @@ class ServiceContainer:
     chart_adjustment_service: ChartAdjustmentService
     sql_answer_service: SqlAnswerService
     sql_expansion_service: SqlExpansionService
-    sql_explanation_service: SQLExplanationService
-    sql_regeneration_service: SQLRegenerationService
+    sql_explanation_service: SqlExplanationService
+    sql_regeneration_service: SqlRegenerationService
+    sql_pairs_preparation_service: SqlPairsPreparationService
 
 
 @dataclass
@@ -80,7 +58,7 @@ def create_service_container(
     return ServiceContainer(
         semantics_description=SemanticsDescription(
             pipelines={
-                "semantics_description": semantics_description.SemanticsDescription(
+                "semantics_description": generation.SemanticsDescription(
                     **pipe_components["semantics_description"],
                 )
             },
@@ -103,10 +81,10 @@ def create_service_container(
         ),
         ask_service=AskService(
             pipelines={
-                "intent_classification": intent_classification.IntentClassification(
+                "intent_classification": generation.IntentClassification(
                     **pipe_components["intent_classification"],
                 ),
-                "data_assistance": data_assistance.DataAssistance(
+                "data_assistance": generation.DataAssistance(
                     **pipe_components["data_assistance"]
                 ),
                 "retrieval": retrieval.Retrieval(
@@ -115,19 +93,22 @@ def create_service_container(
                     table_column_retrieval_size=settings.table_column_retrieval_size,
                     allow_using_db_schemas_without_pruning=settings.allow_using_db_schemas_without_pruning,
                 ),
-                "historical_question": historical_question.HistoricalQuestion(
+                "historical_question_retrieval": retrieval.HistoricalQuestionRetrieval(
                     **pipe_components["historical_question_retrieval"],
                 ),
-                "sql_generation": sql_generation.SQLGeneration(
+                "sql_pairs_retrieval": retrieval.SqlPairsRetrieval(
+                    **pipe_components["sql_pairs_retrieval"],
+                ),
+                "sql_generation": generation.SQLGeneration(
                     **pipe_components["sql_generation"],
                 ),
-                "sql_correction": sql_correction.SQLCorrection(
+                "sql_correction": generation.SQLCorrection(
                     **pipe_components["sql_correction"],
                 ),
-                "followup_sql_generation": followup_sql_generation.FollowUpSQLGeneration(
+                "followup_sql_generation": generation.FollowUpSQLGeneration(
                     **pipe_components["followup_sql_generation"],
                 ),
-                "sql_summary": sql_summary.SQLSummary(
+                "sql_summary": generation.SQLSummary(
                     **pipe_components["sql_summary"],
                 ),
             },
@@ -135,10 +116,10 @@ def create_service_container(
         ),
         chart_service=ChartService(
             pipelines={
-                "sql_executor": sql_executor.SQLExecutor(
+                "sql_executor": retrieval.SQLExecutor(
                     **pipe_components["sql_executor"],
                 ),
-                "chart_generation": chart_generation.ChartGeneration(
+                "chart_generation": generation.ChartGeneration(
                     **pipe_components["chart_generation"],
                 ),
             },
@@ -146,10 +127,10 @@ def create_service_container(
         ),
         chart_adjustment_service=ChartAdjustmentService(
             pipelines={
-                "sql_executor": sql_executor.SQLExecutor(
+                "sql_executor": retrieval.SQLExecutor(
                     **pipe_components["sql_executor"],
                 ),
-                "chart_adjustment": chart_adjustment.ChartAdjustment(
+                "chart_adjustment": generation.ChartAdjustment(
                     **pipe_components["chart_adjustment"],
                 ),
             },
@@ -157,10 +138,10 @@ def create_service_container(
         ),
         sql_answer_service=SqlAnswerService(
             pipelines={
-                "preprocess_sql_data": preprocess_sql_data.PreprocessSqlData(
+                "preprocess_sql_data": retrieval.PreprocessSqlData(
                     **pipe_components["preprocess_sql_data"],
                 ),
-                "sql_answer": sql_answer.SQLAnswer(
+                "sql_answer": generation.SQLAnswer(
                     **pipe_components["sql_answer"],
                 ),
             },
@@ -168,10 +149,10 @@ def create_service_container(
         ),
         ask_details_service=AskDetailsService(
             pipelines={
-                "sql_breakdown": sql_breakdown.SQLBreakdown(
+                "sql_breakdown": generation.SQLBreakdown(
                     **pipe_components["sql_breakdown"],
                 ),
-                "sql_summary": sql_summary.SQLSummary(
+                "sql_summary": generation.SQLSummary(
                     **pipe_components["sql_summary"],
                 ),
             },
@@ -184,29 +165,29 @@ def create_service_container(
                     table_retrieval_size=settings.table_retrieval_size,
                     table_column_retrieval_size=settings.table_column_retrieval_size,
                 ),
-                "sql_expansion": sql_expansion.SQLExpansion(
+                "sql_expansion": generation.SQLExpansion(
                     **pipe_components["sql_expansion"],
                 ),
-                "sql_correction": sql_correction.SQLCorrection(
+                "sql_correction": generation.SQLCorrection(
                     **pipe_components["sql_correction"],
                 ),
-                "sql_summary": sql_summary.SQLSummary(
+                "sql_summary": generation.SQLSummary(
                     **pipe_components["sql_summary"],
                 ),
             },
             **query_cache,
         ),
-        sql_explanation_service=SQLExplanationService(
+        sql_explanation_service=SqlExplanationService(
             pipelines={
-                "sql_explanation": sql_explanation.SQLExplanation(
+                "sql_explanation": generation.SQLExplanation(
                     **pipe_components["sql_explanation"],
                 )
             },
             **query_cache,
         ),
-        sql_regeneration_service=SQLRegenerationService(
+        sql_regeneration_service=SqlRegenerationService(
             pipelines={
-                "sql_regeneration": sql_regeneration.SQLRegeneration(
+                "sql_regeneration": generation.SQLRegeneration(
                     **pipe_components["sql_regeneration"],
                 )
             },
@@ -214,7 +195,7 @@ def create_service_container(
         ),
         relationship_recommendation=RelationshipRecommendation(
             pipelines={
-                "relationship_recommendation": relationship_recommendation.RelationshipRecommendation(
+                "relationship_recommendation": generation.RelationshipRecommendation(
                     **pipe_components["relationship_recommendation"],
                 )
             },
@@ -222,7 +203,7 @@ def create_service_container(
         ),
         question_recommendation=QuestionRecommendation(
             pipelines={
-                "question_recommendation": question_recommendation.QuestionRecommendation(
+                "question_recommendation": generation.QuestionRecommendation(
                     **pipe_components["question_recommendation"],
                 ),
                 "retrieval": retrieval.Retrieval(
@@ -231,12 +212,23 @@ def create_service_container(
                     **pipe_components["retrieval"],
                     table_retrieval_size=settings.table_retrieval_size,
                     table_column_retrieval_size=settings.table_column_retrieval_size,
                     allow_using_db_schemas_without_pruning=settings.allow_using_db_schemas_without_pruning,
                 ),
-                "sql_generation": sql_generation.SQLGeneration(
+                "sql_generation": generation.SQLGeneration(
                     **pipe_components["sql_generation"],
                 ),
             },
             **query_cache,
         ),
+        sql_pairs_preparation_service=SqlPairsPreparationService(
+            pipelines={
+                "sql_pairs_preparation": indexing.SqlPairsPreparation(
+                    **pipe_components["sql_pairs_preparation"],
+                ),
+                "sql_pairs_deletion": indexing.SqlPairsDeletion(
+                    **pipe_components["sql_pairs_deletion"],
+                ),
+            },
+            **query_cache,
+        ),
     )
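Editor's note: the `globals.py` rewrite collapses roughly twenty deep module imports into three package namespaces. It only works because the package `__init__.py` files (shown next) re-export each pipeline class. A sketch of the pattern, runnable inside the repo:

```python
# Before: every call site imported the pipeline's module directly.
from src.pipelines.generation import sql_generation
pipe = sql_generation.SQLGeneration(**components)

# After: the package __init__ re-exports the class, so call sites only need
# the package namespace (see the generation/__init__.py diff below).
from src.pipelines import generation
pipe = generation.SQLGeneration(**components)
```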
diff --git a/wren-ai-service/src/pipelines/generation/__init__.py b/wren-ai-service/src/pipelines/generation/__init__.py
index e69de29bb..559cec5ac 100644
--- a/wren-ai-service/src/pipelines/generation/__init__.py
+++ b/wren-ai-service/src/pipelines/generation/__init__.py
@@ -0,0 +1,35 @@
+from .chart_adjustment import ChartAdjustment
+from .chart_generation import ChartGeneration
+from .data_assistance import DataAssistance
+from .followup_sql_generation import FollowUpSQLGeneration
+from .intent_classification import IntentClassification
+from .question_recommendation import QuestionRecommendation
+from .relationship_recommendation import RelationshipRecommendation
+from .semantics_description import SemanticsDescription
+from .sql_answer import SQLAnswer
+from .sql_breakdown import SQLBreakdown
+from .sql_correction import SQLCorrection
+from .sql_expansion import SQLExpansion
+from .sql_explanation import SQLExplanation
+from .sql_generation import SQLGeneration
+from .sql_regeneration import SQLRegeneration
+from .sql_summary import SQLSummary
+
+__all__ = [
+    "SQLRegeneration",
+    "ChartGeneration",
+    "ChartAdjustment",
+    "DataAssistance",
+    "FollowUpSQLGeneration",
+    "IntentClassification",
+    "QuestionRecommendation",
+    "RelationshipRecommendation",
+    "SemanticsDescription",
+    "SQLAnswer",
+    "SQLBreakdown",
+    "SQLCorrection",
+    "SQLExpansion",
+    "SQLExplanation",
+    "SQLGeneration",
+    "SQLSummary",
+]
diff --git a/wren-ai-service/src/pipelines/generation/chart_adjustment.py b/wren-ai-service/src/pipelines/generation/chart_adjustment.py
index a7ebd947c..d99a4b452 100644
--- a/wren-ai-service/src/pipelines/generation/chart_adjustment.py
+++ b/wren-ai-service/src/pipelines/generation/chart_adjustment.py
@@ -24,7 +24,6 @@
     StackedBarChartSchema,
     chart_generation_instructions,
 )
-from src.utils import async_timer, timer
 from src.web.v1.services.chart_adjustment import ChartAdjustmentOption
 
 logger = logging.getLogger("wren-ai-service")
@@ -133,7 +132,6 @@ def run(
 
 
 ## Start of Pipeline
-@timer
 @observe(capture_input=False)
 def preprocess_data(
     data: Dict[str, Any], chart_data_preprocessor: ChartDataPreprocessor
@@ -141,7 +139,6 @@ def preprocess_data(
     return chart_data_preprocessor.run(data)
 
 
-@timer
 @observe(capture_input=False)
 def prompt(
     query: str,
@@ -164,20 +161,18 @@ def prompt(
     )
 
 
-@async_timer
 @observe(as_type="generation", capture_input=False)
 async def generate_chart_adjustment(prompt: dict, generator: Any) -> dict:
     return await generator(prompt=prompt.get("prompt"))
 
 
-@timer
 @observe(capture_input=False)
 def post_process(
     generate_chart_adjustment: dict,
     vega_schema: Dict[str, Any],
     post_processor: ChartAdjustmentPostProcessor,
 ) -> dict:
-    return post_processor.run(generate_chart_adjustment.get("replies"), vega_schema)
+    return post_processor(generate_chart_adjustment.get("replies"), vega_schema)
 
 
 ## End of Pipeline
@@ -232,7 +227,6 @@ def __init__(
             AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
         )
 
-    @async_timer
     @observe(name="Chart Adjustment")
     async def run(
         self,
diff --git a/wren-ai-service/src/pipelines/generation/chart_generation.py b/wren-ai-service/src/pipelines/generation/chart_generation.py
index 5cb906f1f..80caba1df 100644
--- a/wren-ai-service/src/pipelines/generation/chart_generation.py
+++ b/wren-ai-service/src/pipelines/generation/chart_generation.py
@@ -24,7 +24,6 @@
     StackedBarChartSchema,
     chart_generation_instructions,
 )
-from src.utils import async_timer, timer
 
 logger = logging.getLogger("wren-ai-service")
 
@@ -112,7 +111,6 @@ def run(
 
 
 ## Start of Pipeline
-@timer
 @observe(capture_input=False)
 def preprocess_data(
     data: Dict[str, Any], chart_data_preprocessor: ChartDataPreprocessor
@@ -120,7 +118,6 @@ def preprocess_data(
     return chart_data_preprocessor.run(data)
 
 
-@timer
 @observe(capture_input=False)
 def prompt(
     query: str,
@@ -139,13 +136,11 @@ def prompt(
     )
 
 
-@async_timer
 @observe(as_type="generation", capture_input=False)
 async def generate_chart(prompt: dict, generator: Any) -> dict:
     return await generator(prompt=prompt.get("prompt"))
 
 
-@timer
 @observe(capture_input=False)
 def post_process(
     generate_chart: dict,
@@ -208,7 +203,6 @@ def __init__(
             AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
         )
 
-    @async_timer
     @observe(name="Chart Generation")
     async def run(
         self,
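Editor's note: with `@timer`/`@async_timer` (and the `enable_timer` setting) removed across these files, per-step latency is presumably expected to come from the Langfuse spans that `@observe` already creates around every step, making a second timing wrapper redundant. A hedged sketch of the idea — the function body is illustrative:

```python
from langfuse.decorators import observe

@observe(capture_input=False)
def preprocess_data(data: dict) -> dict:
    # The observe decorator records start/end timestamps on the span, so the
    # step's duration shows up in the Langfuse trace without a custom timer.
    return data
```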
diff --git a/wren-ai-service/src/pipelines/generation/data_assistance.py b/wren-ai-service/src/pipelines/generation/data_assistance.py
index d9b99afaa..4f0acc852 100644
--- a/wren-ai-service/src/pipelines/generation/data_assistance.py
+++ b/wren-ai-service/src/pipelines/generation/data_assistance.py
@@ -10,7 +10,6 @@
 
 from src.core.pipeline import BasicPipeline
 from src.core.provider import LLMProvider
-from src.utils import async_timer, timer
 from src.web.v1.services.ask import AskHistory
 
 logger = logging.getLogger("wren-ai-service")
@@ -48,7 +47,6 @@
 
 
 ## Start of Pipeline
-@timer
 @observe(capture_input=False)
 def prompt(
     query: str,
@@ -73,7 +71,6 @@ def prompt(
     )
 
 
-@async_timer
 @observe(as_type="generation", capture_input=False)
 async def data_assistance(prompt: dict, generator: Any, query_id: str) -> dict:
     return await generator(prompt=prompt.get("prompt"), query_id=query_id)
@@ -142,7 +139,6 @@ async def _get_streaming_results(query_id):
             except TimeoutError:
                 break
 
-    @async_timer
     @observe(name="Data Assistance")
     async def run(
         self,
diff --git a/wren-ai-service/src/pipelines/generation/followup_sql_generation.py b/wren-ai-service/src/pipelines/generation/followup_sql_generation.py
index 7371c768a..efb6eb7ef 100644
--- a/wren-ai-service/src/pipelines/generation/followup_sql_generation.py
+++ b/wren-ai-service/src/pipelines/generation/followup_sql_generation.py
@@ -1,6 +1,6 @@
 import logging
 import sys
-from typing import Any, List
+from typing import Any, Dict, List
 
 from hamilton import base
 from hamilton.async_driver import AsyncDriver
@@ -18,7 +18,6 @@
     construct_instructions,
     sql_generation_system_prompt,
 )
-from src.utils import async_timer, timer
 from src.web.v1.services import Configuration
 from src.web.v1.services.ask import AskHistory
 
@@ -105,6 +104,16 @@
 Instructions: {{ instructions }}
 {% endif %}
 
+{% if sql_samples %}
+### SQL SAMPLES ###
+{% for sample in sql_samples %}
+Summary:
+{{sample.summary}}
+SQL:
+{{sample.sql}}
+{% endfor %}
+{% endif %}
+
 ### INPUT ###
 User's Follow-up Question:
 {{ query }}
@@ -113,7 +122,6 @@
 
 
 ## Start of Pipeline
-@timer
 @observe(capture_input=False)
 def prompt(
     query: str,
@@ -122,6 +130,7 @@ def prompt(
     history: AskHistory,
     alert: str,
     configuration: Configuration,
     prompt_builder: PromptBuilder,
+    sql_samples: List[Dict] | None = None,
 ) -> dict:
     previous_query_summaries = [step.summary for step in history.steps if step.summary]
 
@@ -133,16 +142,15 @@ def prompt(
         alert=alert,
         instructions=construct_instructions(configuration),
         current_time=show_current_time(configuration.timezone),
+        sql_samples=sql_samples,
     )
 
 
-@async_timer
 @observe(as_type="generation", capture_input=False)
 async def generate_sql_in_followup(prompt: dict, generator: Any) -> dict:
     return await generator(prompt=prompt.get("prompt"))
 
 
-@async_timer
 @observe(capture_input=False)
 async def post_process(
     generate_sql_in_followup: dict,
@@ -202,7 +210,6 @@ def __init__(
             AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
         )
 
-    @async_timer
     @observe(name="Follow-Up SQL Generation")
     async def run(
         self,
@@ -210,6 +217,7 @@ async def run(
         contexts: List[str],
         history: AskHistory,
         configuration: Configuration = Configuration(),
+        sql_samples: List[Dict] | None = None,
         project_id: str | None = None,
     ):
         logger.info("Follow-Up SQL Generation pipeline is running...")
@@ -221,6 +229,7 @@ async def run(
                 "history": history,
                 "project_id": project_id,
                 "configuration": configuration,
+                "sql_samples": sql_samples,
                 **self._components,
                 **self._configs,
             },
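Editor's note: the follow-up prompt gains an optional SQL-samples section that renders only when `sql_samples` is passed. Haystack's `PromptBuilder` is Jinja2-based, so plain `jinja2` approximates the rendering; the sample values below are illustrative:

```python
from jinja2 import Template

section = Template(
    "{% if sql_samples %}### SQL SAMPLES ###\n"
    "{% for sample in sql_samples %}Summary:\n{{sample.summary}}\nSQL:\n{{sample.sql}}\n{% endfor %}"
    "{% endif %}"
)

print(section.render(sql_samples=[
    {"summary": "Total revenue by month", "sql": "SELECT ..."},
]))
# Omitting sql_samples (or passing None) renders an empty string,
# so existing callers are unaffected.
```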
diff --git a/wren-ai-service/src/pipelines/generation/intent_classification.py b/wren-ai-service/src/pipelines/generation/intent_classification.py
index 62022acd0..bddd06801 100644
--- a/wren-ai-service/src/pipelines/generation/intent_classification.py
+++ b/wren-ai-service/src/pipelines/generation/intent_classification.py
@@ -14,7 +14,6 @@
 from src.core.pipeline import BasicPipeline
 from src.core.provider import DocumentStoreProvider, EmbedderProvider, LLMProvider
 from src.pipelines.common import build_table_ddl
-from src.utils import async_timer, timer
 from src.web.v1.services.ask import AskHistory
 
 logger = logging.getLogger("wren-ai-service")
@@ -109,7 +108,6 @@
 
 
 ## Start of Pipeline
-@async_timer
 @observe(capture_input=False, capture_output=False)
 async def embedding(
     query: str, embedder: Any, history: Optional[AskHistory] = None
@@ -123,7 +121,6 @@ async def embedding(
     return await embedder.run(query)
 
 
-@async_timer
 @observe(capture_input=False)
 async def table_retrieval(embedding: dict, id: str, table_retriever: Any) -> dict:
     filters = {
@@ -144,7 +141,6 @@ async def table_retrieval(embedding: dict, id: str, table_retriever: Any) -> dic
     )
 
 
-@async_timer
 @observe(capture_input=False)
 async def dbschema_retrieval(
     table_retrieval: dict, embedding: dict, id: str, dbschema_retriever: Any
@@ -181,7 +177,6 @@ async def dbschema_retrieval(
     return results["documents"]
 
 
-@timer
 @observe()
 def construct_db_schemas(dbschema_retrieval: list[Document]) -> list[str]:
     db_schemas = {}
@@ -219,7 +214,6 @@ def construct_db_schemas(dbschema_retrieval: list[Document]) -> list[str]:
     return db_schemas_in_ddl
 
 
-@timer
 @observe(capture_input=False)
 def prompt(
     query: str,
@@ -238,13 +232,11 @@ def prompt(
     )
 
 
-@async_timer
 @observe(as_type="generation", capture_input=False)
 async def classify_intent(prompt: dict, generator: Any) -> dict:
     return await generator(prompt=prompt.get("prompt"))
 
 
-@timer
 @observe(capture_input=False)
 def post_process(classify_intent: dict, construct_db_schemas: list[str]) -> dict:
     try:
@@ -317,7 +309,6 @@ def __init__(
             AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
         )
 
-    @async_timer
     @observe(name="Intent Classification")
     async def run(
         self, query: str, id: Optional[str] = None, history: Optional[AskHistory] = None
diff --git a/wren-ai-service/src/pipelines/generation/sql_answer.py b/wren-ai-service/src/pipelines/generation/sql_answer.py
index d27751241..c1337bf55 100644
--- a/wren-ai-service/src/pipelines/generation/sql_answer.py
+++ b/wren-ai-service/src/pipelines/generation/sql_answer.py
@@ -10,7 +10,6 @@
 
 from src.core.pipeline import BasicPipeline
 from src.core.provider import LLMProvider
-from src.utils import async_timer, timer
 
 logger = logging.getLogger("wren-ai-service")
 
@@ -46,7 +45,6 @@
 
 
 ## Start of Pipeline
-@timer
 @observe(capture_input=False)
 def prompt(
     query: str,
@@ -63,7 +61,6 @@ def prompt(
     )
 
 
-@async_timer
 @observe(as_type="generation", capture_input=False)
 async def generate_answer(prompt: dict, generator: Any, query_id: str) -> dict:
     return await generator(prompt=prompt.get("prompt"), query_id=query_id)
@@ -132,7 +129,6 @@ async def _get_streaming_results(query_id):
             except TimeoutError:
                 break
 
-    @async_timer
     @observe(name="SQL Answer Generation")
     async def run(
         self,
diff --git a/wren-ai-service/src/pipelines/generation/sql_breakdown.py b/wren-ai-service/src/pipelines/generation/sql_breakdown.py
index a84ab8962..5f50f4036 100644
--- a/wren-ai-service/src/pipelines/generation/sql_breakdown.py
+++ b/wren-ai-service/src/pipelines/generation/sql_breakdown.py
@@ -15,10 +15,6 @@
     TEXT_TO_SQL_RULES,
     SQLBreakdownGenPostProcessor,
 )
-from src.utils import (
-    async_timer,
-    timer,
-)
 
 logger = logging.getLogger("wren-ai-service")
 
@@ -117,7 +113,6 @@
 
 
 ## Start of Pipeline
-@timer
 @observe(capture_input=False)
 def prompt(
     query: str,
@@ -131,13 +126,11 @@ def prompt(
     )
 
 
-@async_timer
 @observe(as_type="generation", capture_input=False)
 async def generate_sql_details(prompt: dict, generator: Any) -> dict:
     return await generator(prompt=prompt.get("prompt"))
 
 
-@async_timer
 @observe(capture_input=False)
 async def post_process(
     generate_sql_details: dict,
@@ -198,7 +191,6 @@ def __init__(
             AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
         )
 
-    @async_timer
     @observe(name="SQL Breakdown Generation")
     async def run(
         self,
diff --git a/wren-ai-service/src/pipelines/generation/sql_correction.py b/wren-ai-service/src/pipelines/generation/sql_correction.py
index 7b09c59e9..43d2ef78c 100644
--- a/wren-ai-service/src/pipelines/generation/sql_correction.py
+++ b/wren-ai-service/src/pipelines/generation/sql_correction.py
@@ -18,7 +18,6 @@
     SQLGenPostProcessor,
     sql_generation_system_prompt,
 )
-from src.utils import async_timer, timer
 
 logger = logging.getLogger("wren-ai-service")
 
@@ -55,7 +54,6 @@
 
 
 ## Start of Pipeline
-@timer
 @observe(capture_input=False)
 def prompts(
     documents: List[Document],
@@ -73,7 +71,6 @@ def prompts(
     ]
 
 
-@async_timer
 @observe(as_type="generation", capture_input=False)
 async def generate_sql_corrections(prompts: list[dict], generator: Any) -> list[dict]:
     tasks = []
@@ -84,7 +81,6 @@ async def generate_sql_corrections(prompts: list[dict], generator: Any) -> list[
     return await asyncio.gather(*tasks)
 
 
-@async_timer
 @observe(capture_input=False)
 async def post_process(
     generate_sql_corrections: list[dict],
@@ -142,7 +138,6 @@ def __init__(
             AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
         )
 
-    @async_timer
     @observe(name="SQL Correction")
     async def run(
         self,
diff --git a/wren-ai-service/src/pipelines/generation/sql_expansion.py b/wren-ai-service/src/pipelines/generation/sql_expansion.py
index 7373edc6e..78f930b16 100644
--- a/wren-ai-service/src/pipelines/generation/sql_expansion.py
+++ b/wren-ai-service/src/pipelines/generation/sql_expansion.py
@@ -13,7 +13,6 @@
 from src.core.provider import LLMProvider
 from src.pipelines.common import show_current_time
 from src.pipelines.generation.utils.sql import SQLGenPostProcessor
-from src.utils import async_timer, timer
 from src.web.v1.services import Configuration
 from src.web.v1.services.ask import AskHistory
 
@@ -52,7 +51,6 @@
 
 
 ## Start of Pipeline
-@timer
 @observe(capture_input=False)
 def prompt(
     query: str,
@@ -69,13 +67,11 @@ def prompt(
     )
 
 
-@async_timer
 @observe(as_type="generation", capture_input=False)
 async def generate_sql_expansion(prompt: dict, generator: Any) -> dict:
     return await generator(prompt=prompt.get("prompt"))
 
 
-@async_timer
 @observe(capture_input=False)
 async def post_process(
     generate_sql_expansion: dict,
@@ -131,7 +127,6 @@ def __init__(
             AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
         )
 
-    @async_timer
     @observe(name="Sql Expansion Generation")
     async def run(
         self,
diff --git a/wren-ai-service/src/pipelines/generation/sql_explanation.py b/wren-ai-service/src/pipelines/generation/sql_explanation.py
index 5105faf00..99c9a8228 100644
--- a/wren-ai-service/src/pipelines/generation/sql_explanation.py
+++ b/wren-ai-service/src/pipelines/generation/sql_explanation.py
@@ -13,7 +13,6 @@
 
 from src.core.pipeline import BasicPipeline
 from src.core.provider import LLMProvider
-from src.utils import async_timer, timer
 from src.web.v1.services.sql_explanation import StepWithAnalysisResult
 
 logger = logging.getLogger("wren-ai-service")
logging.getLogger("wren-ai-service") @@ -467,7 +466,6 @@ def run( ## Start of Pipeline -@timer @observe(capture_input=False) def preprocess( sql_analysis_results: List[dict], pre_processor: SQLAnalysisPreprocessor @@ -475,7 +473,6 @@ def preprocess( return pre_processor.run(sql_analysis_results) -@timer @observe(capture_input=False) def prompts( question: str, @@ -530,7 +527,6 @@ def prompts( ] -@async_timer @observe(as_type="generation", capture_input=False) async def generate_sql_explanation(prompts: List[dict], generator: Any) -> List[dict]: async def _task(prompt: str, generator: Any): @@ -540,7 +536,6 @@ async def _task(prompt: str, generator: Any): return await asyncio.gather(*tasks) -@timer @observe(capture_input=False) def post_process( generate_sql_explanation: List[dict], @@ -610,7 +605,6 @@ def __init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) - @async_timer @observe(name="SQL Explanation Generation") async def run( self, diff --git a/wren-ai-service/src/pipelines/generation/sql_generation.py b/wren-ai-service/src/pipelines/generation/sql_generation.py index f33ca79e5..972f82441 100644 --- a/wren-ai-service/src/pipelines/generation/sql_generation.py +++ b/wren-ai-service/src/pipelines/generation/sql_generation.py @@ -18,7 +18,6 @@ construct_instructions, sql_generation_system_prompt, ) -from src.utils import async_timer, timer from src.web.v1.services import Configuration logger = logging.getLogger("wren-ai-service") @@ -56,11 +55,11 @@ ] } -{% if samples %} +{% if sql_samples %} ### SAMPLES ### -{% for sample in samples %} -Question: -{{sample.question}} +{% for sample in sql_samples %} +Summary: +{{sample.summary}} SQL: {{sample.sql}} {% endfor %} @@ -75,7 +74,6 @@ ## Start of Pipeline -@timer @observe(capture_input=False) def prompt( query: str, @@ -84,7 +82,7 @@ def prompt( text_to_sql_rules: str, prompt_builder: PromptBuilder, configuration: Configuration | None = None, - samples: List[Dict] | None = None, + sql_samples: List[Dict] | None = None, ) -> dict: return prompt_builder.run( query=query, @@ -92,12 +90,11 @@ def prompt( exclude=exclude, text_to_sql_rules=text_to_sql_rules, instructions=construct_instructions(configuration), - samples=samples, + sql_samples=sql_samples, current_time=show_current_time(configuration.timezone), ) -@async_timer @observe(as_type="generation", capture_input=False) async def generate_sql( prompt: dict, @@ -106,7 +103,6 @@ async def generate_sql( return await generator(prompt=prompt.get("prompt")) -@async_timer @observe(capture_input=False) async def post_process( generate_sql: dict, @@ -162,7 +158,6 @@ def __init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) - @async_timer @observe(name="SQL Generation") async def run( self, @@ -170,7 +165,7 @@ async def run( contexts: List[str], exclude: List[Dict], configuration: Configuration = Configuration(), - samples: List[Dict] | None = None, + sql_samples: List[Dict] | None = None, project_id: str | None = None, ): logger.info("SQL Generation pipeline is running...") @@ -180,7 +175,7 @@ async def run( "query": query, "documents": contexts, "exclude": exclude, - "samples": samples, + "sql_samples": sql_samples, "project_id": project_id, "configuration": configuration, **self._components, diff --git a/wren-ai-service/src/pipelines/generation/sql_regeneration.py b/wren-ai-service/src/pipelines/generation/sql_regeneration.py index a317b88f1..1ddb19f53 100644 --- a/wren-ai-service/src/pipelines/generation/sql_regeneration.py +++ 
diff --git a/wren-ai-service/src/pipelines/generation/sql_regeneration.py b/wren-ai-service/src/pipelines/generation/sql_regeneration.py
index a317b88f1..1ddb19f53 100644
--- a/wren-ai-service/src/pipelines/generation/sql_regeneration.py
+++ b/wren-ai-service/src/pipelines/generation/sql_regeneration.py
@@ -13,7 +13,6 @@
 from src.core.pipeline import BasicPipeline
 from src.core.provider import LLMProvider
 from src.pipelines.generation.utils.sql import SQLBreakdownGenPostProcessor
-from src.utils import async_timer, timer
 from src.web.v1.services.sql_regeneration import (
     SQLExplanationWithUserCorrections,
 )
@@ -95,7 +94,6 @@ def run(
 
 
 ## Start of Pipeline
-@timer
 @observe(capture_input=False)
 def preprocess(
     description: str,
@@ -108,7 +106,6 @@ def preprocess(
     )
 
 
-@timer
 @observe(capture_input=False)
 def sql_regeneration_prompt(
     preprocess: Dict[str, Any],
@@ -117,7 +114,6 @@ def sql_regeneration_prompt(
     return prompt_builder.run(results=preprocess["results"])
 
 
-@async_timer
 @observe(as_type="generation", capture_input=False)
 async def generate_sql_regeneration(
     sql_regeneration_prompt: dict,
@@ -126,7 +122,6 @@ async def generate_sql_regeneration(
     return await generator(prompt=sql_regeneration_prompt.get("prompt"))
 
 
-@async_timer
 @observe(capture_input=False)
 async def sql_regeneration_post_process(
     generate_sql_regeneration: dict,
@@ -187,7 +182,6 @@ def __init__(
             AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
         )
 
-    @async_timer
     @observe(name="SQL-Regeneration Generation")
     async def run(
         self,
diff --git a/wren-ai-service/src/pipelines/generation/sql_summary.py b/wren-ai-service/src/pipelines/generation/sql_summary.py
index 717bd5002..4823f902e 100644
--- a/wren-ai-service/src/pipelines/generation/sql_summary.py
+++ b/wren-ai-service/src/pipelines/generation/sql_summary.py
@@ -12,7 +12,6 @@
 
 from src.core.pipeline import BasicPipeline
 from src.core.provider import LLMProvider
-from src.utils import async_timer, timer
 
 logger = logging.getLogger("wren-ai-service")
 
@@ -73,7 +72,6 @@ def run(self, sqls: List[str], replies: List[str]):
 
 
 ## Start of Pipeline
-@timer
 @observe(capture_input=False)
 def prompt(
     query: str,
@@ -88,13 +86,11 @@ def prompt(
     )
 
 
-@async_timer
 @observe(as_type="generation", capture_input=False)
 async def generate_sql_summary(prompt: dict, generator: Any) -> dict:
     return await generator(prompt=prompt.get("prompt"))
 
 
-@timer
 def post_process(
     generate_sql_summary: dict,
     sqls: List[str],
@@ -142,7 +138,6 @@ def __init__(
             AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
         )
 
-    @async_timer
     @observe(name="SQL Summary")
     async def run(
         self,
diff --git a/wren-ai-service/src/pipelines/indexing/__init__.py b/wren-ai-service/src/pipelines/indexing/__init__.py
index be638b142..4b8685755 100644
--- a/wren-ai-service/src/pipelines/indexing/__init__.py
+++ b/wren-ai-service/src/pipelines/indexing/__init__.py
@@ -87,9 +87,39 @@ async def run(
     return {"documents_written": documents_written}
 
 
+@component
+class SqlPairsCleaner:
+    def __init__(self, sql_pairs_store: DocumentStore) -> None:
+        self._sql_pairs_store = sql_pairs_store
+
+    @component.output_types()
+    async def run(self, sql_pair_ids: List[str], id: Optional[str] = None) -> None:
+        filters = {
+            "operator": "AND",
+            "conditions": [
+                {"field": "sql_pair_id", "operator": "in", "value": sql_pair_ids},
+            ],
+        }
+
+        if id:
+            filters["conditions"].append(
+                {"field": "project_id", "operator": "==", "value": id}
+            )
+
+        return await self._sql_pairs_store.delete_documents(filters)
+
+
 # Put the pipelines imports here to avoid circular imports and make them accessible directly to the rest of packages
 from .db_schema import DBSchema  # noqa: E402
 from .historical_question import HistoricalQuestion  # noqa: E402
+from .sql_pairs_deletion import SqlPairsDeletion  # noqa: E402
+from .sql_pairs_preparation import SqlPairsPreparation  # noqa: E402
 from .table_description import TableDescription  # noqa: E402
 
-__all__ = ["DBSchema", "TableDescription", "HistoricalQuestion"]
+__all__ = [
+    "DBSchema",
+    "TableDescription",
+    "HistoricalQuestion",
+    "SqlPairsDeletion",
+    "SqlPairsPreparation",
+]
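Editor's note: `SqlPairsCleaner` builds a standard Haystack metadata filter and hands it to `delete_documents`. Written out for two pair ids scoped to one project (the values are illustrative):

```python
filters = {
    "operator": "AND",
    "conditions": [
        {"field": "sql_pair_id", "operator": "in", "value": ["pair-1", "pair-2"]},
        # Appended only when a project id is given, so deletions stay
        # scoped to that project's documents.
        {"field": "project_id", "operator": "==", "value": "project-1"},
    ],
}
```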
diff --git a/wren-ai-service/src/pipelines/indexing/sql_pairs_deletion.py b/wren-ai-service/src/pipelines/indexing/sql_pairs_deletion.py
new file mode 100644
index 000000000..944258f82
--- /dev/null
+++ b/wren-ai-service/src/pipelines/indexing/sql_pairs_deletion.py
@@ -0,0 +1,67 @@
+import logging
+import sys
+from typing import Any, Dict, List, Optional
+
+from hamilton import base
+from hamilton.async_driver import AsyncDriver
+from langfuse.decorators import observe
+
+from src.core.pipeline import BasicPipeline
+from src.core.provider import DocumentStoreProvider
+from src.pipelines.indexing import SqlPairsCleaner
+
+logger = logging.getLogger("wren-ai-service")
+
+
+## Start of Pipeline
+@observe(capture_input=False, capture_output=False)
+async def delete_sql_pairs(
+    sql_pairs_cleaner: SqlPairsCleaner,
+    sql_pair_ids: List[str],
+    id: Optional[str] = None,
+) -> None:
+    return await sql_pairs_cleaner.run(sql_pair_ids=sql_pair_ids, id=id)
+
+
+## End of Pipeline
+
+
+class SqlPairsDeletion(BasicPipeline):
+    def __init__(
+        self,
+        document_store_provider: DocumentStoreProvider,
+        **kwargs,
+    ) -> None:
+        sql_pairs_store = document_store_provider.get_store(dataset_name="sql_pairs")
+
+        self._components = {
+            "sql_pairs_cleaner": SqlPairsCleaner(sql_pairs_store),
+        }
+
+        super().__init__(
+            AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
+        )
+
+    @observe(name="SQL Pairs Deletion")
+    async def run(
+        self, sql_pair_ids: List[str], id: Optional[str] = None
+    ) -> Dict[str, Any]:
+        logger.info("SQL Pairs Deletion pipeline is running...")
+        return await self._pipe.execute(
+            ["delete_sql_pairs"],
+            inputs={
+                "sql_pair_ids": sql_pair_ids,
+                "id": id or "",
+                **self._components,
+            },
+        )
+
+
+if __name__ == "__main__":
+    from src.pipelines.common import dry_run_pipeline
+
+    dry_run_pipeline(
+        SqlPairsDeletion,
+        "sql_pairs_deletion",
+        sql_pair_ids=[],
+    )
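Editor's note: a hedged usage sketch for the deletion pipeline, matching the `run` signature above; `provider` stands in for a configured `DocumentStoreProvider` from the service container:

```python
import asyncio

async def delete_pairs(provider):
    deletion = SqlPairsDeletion(document_store_provider=provider)
    # Removes the two pairs from the "sql_pairs" store, scoped to one project.
    await deletion.run(sql_pair_ids=["pair-1", "pair-2"], id="project-1")

# asyncio.run(delete_pairs(provider))
```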
diff --git a/wren-ai-service/src/pipelines/indexing/sql_pairs_preparation.py b/wren-ai-service/src/pipelines/indexing/sql_pairs_preparation.py
new file mode 100644
index 000000000..55ab102bc
--- /dev/null
+++ b/wren-ai-service/src/pipelines/indexing/sql_pairs_preparation.py
@@ -0,0 +1,226 @@
+import asyncio
+import logging
+import sys
+import uuid
+from typing import Any, Dict, List, Optional
+
+import orjson
+from hamilton import base
+from hamilton.async_driver import AsyncDriver
+from haystack import Document, component
+from haystack.components.builders.prompt_builder import PromptBuilder
+from haystack.document_stores.types import DuplicatePolicy
+from langfuse.decorators import observe
+from pydantic import BaseModel
+
+from src.core.pipeline import BasicPipeline
+from src.core.provider import DocumentStoreProvider, EmbedderProvider, LLMProvider
+from src.pipelines.indexing import AsyncDocumentWriter, SqlPairsCleaner
+from src.web.v1.services.sql_pairs_preparation import SqlPair
+
+logger = logging.getLogger("wren-ai-service")
+
+
+sql_intention_generation_system_prompt = """
+### TASK ###
+
+You are a data analyst great at generating a concise and readable summary of a SQL query.
+
+### INSTRUCTIONS ###
+
+- The summary should be concise and readable.
+- The summary should be no longer than 20 words.
+- Don't rephrase keywords in the SQL query; just use them as they are.
+
+### OUTPUT ###
+
+You need to output a JSON object as follows:
+{
+    "intention": ""
+}
+"""
+
+sql_intention_generation_user_prompt_template = """
+### INPUT ###
+SQL: {{sql}}
+
+Please think step by step
+"""
+
+
+@component
+class SqlPairsDescriptionConverter:
+    @component.output_types(documents=List[Document])
+    def run(self, sql_pairs: List[Dict[str, Any]], id: Optional[str] = None):
+        logger.info("Converting SQL pairs to documents...")
+
+        return {
+            "documents": [
+                Document(
+                    id=str(uuid.uuid4()),
+                    meta=(
+                        {
+                            "sql_pair_id": sql_pair.get("id"),
+                            "project_id": id,
+                            "sql": sql_pair.get("sql"),
+                        }
+                        if id
+                        else {
+                            "sql_pair_id": sql_pair.get("id"),
+                            "sql": sql_pair.get("sql"),
+                        }
+                    ),
+                    content=sql_pair.get("intention"),
+                )
+                for sql_pair in sql_pairs
+            ]
+        }
+
+
+## Start of Pipeline
+@observe(capture_input=False)
+def prompts(
+    sql_pairs: List[SqlPair],
+    prompt_builder: PromptBuilder,
+) -> List[dict]:
+    return [prompt_builder.run(sql=sql_pair.sql) for sql_pair in sql_pairs]
+
+
+@observe(as_type="generation", capture_input=False)
+async def generate_sql_intention(
+    prompts: List[dict],
+    sql_intention_generator: Any,
+) -> List[dict]:
+    async def _task(prompt: str, generator: Any):
+        return await generator(prompt=prompt.get("prompt"))
+
+    tasks = [_task(prompt, sql_intention_generator) for prompt in prompts]
+    return await asyncio.gather(*tasks)
+
+
+@observe()
+def post_process(
+    generate_sql_intention: List[dict],
+    sql_pairs: List[SqlPair],
+) -> List[Dict[str, Any]]:
+    intentions = [
+        orjson.loads(result["replies"][0])["intention"]
+        for result in generate_sql_intention
+    ]
+
+    return [
+        {"id": sql_pair.id, "sql": sql_pair.sql, "intention": intention}
+        for sql_pair, intention in zip(sql_pairs, intentions)
+    ]
+
+
+@observe(capture_input=False)
+def convert_sql_pairs_to_documents(
+    post_process: List[Dict[str, Any]],
+    sql_pairs_description_converter: SqlPairsDescriptionConverter,
+    id: Optional[str] = None,
+) -> Dict[str, Any]:
+    return sql_pairs_description_converter.run(sql_pairs=post_process, id=id)
+
+
+@observe(capture_input=False, capture_output=False)
+async def embed_sql_pairs(
+    convert_sql_pairs_to_documents: Dict[str, Any],
+    document_embedder: Any,
+) -> Dict[str, Any]:
+    return await document_embedder.run(
+        documents=convert_sql_pairs_to_documents["documents"]
+    )
+
+
+@observe(capture_input=False, capture_output=False)
+async def delete_sql_pairs(
+    sql_pairs_cleaner: SqlPairsCleaner,
+    sql_pairs: List[SqlPair],
+    embed_sql_pairs: Dict[str, Any],
+    id: Optional[str] = None,
+) -> List[SqlPair]:
+    sql_pair_ids = [sql_pair.id for sql_pair in sql_pairs]
+    await sql_pairs_cleaner.run(sql_pair_ids=sql_pair_ids, id=id)
+
+    return embed_sql_pairs
+
+
+@observe(capture_input=False)
+async def write_sql_pairs(
+    embed_sql_pairs: Dict[str, Any],
+    sql_pairs_writer: AsyncDocumentWriter,
+) -> None:
+    return await sql_pairs_writer.run(documents=embed_sql_pairs["documents"])
+
+
+## End of Pipeline
+class SqlIntentionGenerationResult(BaseModel):
+    intention: str
+
+
+SQL_INTENTION_GENERATION_MODEL_KWARGS = {
+    "response_format": {
+        "type": "json_schema",
+        "json_schema": {
+            "name": "matched_schema",
+            "schema": SqlIntentionGenerationResult.model_json_schema(),
+        },
+    }
+}
+
+
+class SqlPairsPreparation(BasicPipeline):
+    def __init__(
+        self,
+        embedder_provider: EmbedderProvider,
+        llm_provider: LLMProvider,
+        document_store_provider: DocumentStoreProvider,
+        **kwargs,
+    ) -> None:
+        sql_pairs_store = document_store_provider.get_store(dataset_name="sql_pairs")
+
+        self._components = {
+            "sql_pairs_cleaner": SqlPairsCleaner(sql_pairs_store),
+            "prompt_builder": PromptBuilder(
+                template=sql_intention_generation_user_prompt_template
+            ),
+            "sql_intention_generator": llm_provider.get_generator(
+                system_prompt=sql_intention_generation_system_prompt,
+                generation_kwargs=SQL_INTENTION_GENERATION_MODEL_KWARGS,
+            ),
+            "document_embedder": embedder_provider.get_document_embedder(),
+            "sql_pairs_description_converter": SqlPairsDescriptionConverter(),
+            "sql_pairs_writer": AsyncDocumentWriter(
+                document_store=sql_pairs_store,
+                policy=DuplicatePolicy.OVERWRITE,
+            ),
+        }
+
+        super().__init__(
+            AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
+        )
+
+    @observe(name="SQL Pairs Preparation")
+    async def run(
+        self, sql_pairs: List[SqlPair], id: Optional[str] = None
+    ) -> Dict[str, Any]:
+        logger.info("SQL Pairs Preparation pipeline is running...")
+        return await self._pipe.execute(
+            ["write_sql_pairs"],
+            inputs={
+                "sql_pairs": sql_pairs,
+                "id": id or "",
+                **self._components,
+            },
+        )
+
+
+if __name__ == "__main__":
+    from src.pipelines.common import dry_run_pipeline
+
+    dry_run_pipeline(
+        SqlPairsPreparation,
+        "sql_pairs_preparation",
+        sql_pairs=[],
+    )
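Editor's note: the preparation pipeline summarizes each SQL statement into an "intention" via the LLM, embeds that intention as the document content with the raw SQL in metadata, deletes any stale copies, then overwrites the store. A hedged sketch of driving it end to end; the provider objects come from the configured service container in real deployments:

```python
import asyncio
from src.web.v1.services.sql_pairs_preparation import SqlPair

async def prepare_pairs(llm_provider, embedder_provider, document_store_provider):
    preparation = SqlPairsPreparation(
        embedder_provider=embedder_provider,
        llm_provider=llm_provider,
        document_store_provider=document_store_provider,
    )
    # Indexes one pair into the "sql_pairs" dataset, scoped to a project.
    await preparation.run(
        sql_pairs=[SqlPair(id="pair-1", sql="SELECT COUNT(*) FROM orders")],
        id="project-1",
    )
```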
diff --git a/wren-ai-service/src/pipelines/retrieval/__init__.py b/wren-ai-service/src/pipelines/retrieval/__init__.py
index e69de29bb..82e319b4c 100644
--- a/wren-ai-service/src/pipelines/retrieval/__init__.py
+++ b/wren-ai-service/src/pipelines/retrieval/__init__.py
@@ -0,0 +1,13 @@
+from .historical_question_retrieval import HistoricalQuestionRetrieval
+from .preprocess_sql_data import PreprocessSqlData
+from .retrieval import Retrieval
+from .sql_executor import SQLExecutor
+from .sql_pairs_retrieval import SqlPairsRetrieval
+
+__all__ = [
+    "HistoricalQuestionRetrieval",
+    "PreprocessSqlData",
+    "Retrieval",
+    "SQLExecutor",
+    "SqlPairsRetrieval",
+]
diff --git a/wren-ai-service/src/pipelines/retrieval/historical_question.py b/wren-ai-service/src/pipelines/retrieval/historical_question_retrieval.py
similarity index 85%
rename from wren-ai-service/src/pipelines/retrieval/historical_question.py
rename to wren-ai-service/src/pipelines/retrieval/historical_question_retrieval.py
index 0d2a1ceab..9be740cae 100644
--- a/wren-ai-service/src/pipelines/retrieval/historical_question.py
+++ b/wren-ai-service/src/pipelines/retrieval/historical_question_retrieval.py
@@ -10,29 +10,11 @@
 
 from src.core.pipeline import BasicPipeline
 from src.core.provider import DocumentStoreProvider, EmbedderProvider
-from src.utils import (
-    async_timer,
-    timer,
-)
+from src.pipelines.common import ScoreFilter
 
 logger = logging.getLogger("wren-ai-service")
 
 
-@component
-class ScoreFilter:
-    @component.output_types(
-        documents=List[Document],
-    )
-    def run(self, documents: List[Document], score: float = 0.9):
-        return {
-            "documents": sorted(
-                filter(lambda document: document.score >= score, documents),
-                key=lambda document: document.score,
-                reverse=True,
-            )
-        }
-
-
 @component
 class OutputFormatter:
     @component.output_types(
@@ -54,7 +36,6 @@ def run(self, documents: List[Document]):
 
 
 ## Start of Pipeline
-@async_timer
 @observe(capture_input=False)
 async def count_documents(store: QdrantDocumentStore, id: Optional[str] = None) -> int:
     filters = (
@@ -71,7 +52,6 @@ async def count_documents(store: QdrantDocumentStore, id: Optional[str] = None)
     return document_count
 
 
-@async_timer
 @observe(capture_input=False, capture_output=False)
 async def embedding(count_documents: int, query: str, embedder: Any) -> dict:
     if count_documents:
@@ -80,7 +60,6 @@ async def embedding(count_documents: int, query: str, embedder: Any) -> dict:
     return {}
 
 
-@async_timer
 @observe(capture_input=False)
 async def retrieval(embedding: dict, id: str, retriever: Any) -> dict:
     if embedding:
@@ -104,16 +83,14 @@ async def retrieval(embedding: dict, id: str, retriever: Any) -> dict:
     return {}
 
 
-@timer
 @observe(capture_input=False)
 def filtered_documents(retrieval: dict, score_filter: ScoreFilter) -> dict:
     if retrieval:
-        return score_filter.run(documents=retrieval.get("documents"))
+        return score_filter.run(documents=retrieval.get("documents"), score=0.9)
 
     return {}
 
 
-@timer
 @observe(capture_input=False)
 def formatted_output(
     filtered_documents: dict, output_formatter: OutputFormatter
@@ -127,7 +104,7 @@ def formatted_output(
     return {}
 
 
-class HistoricalQuestion(BasicPipeline):
+class HistoricalQuestionRetrieval(BasicPipeline):
     def __init__(
         self,
         embedder_provider: EmbedderProvider,
@@ -150,10 +127,9 @@ def __init__(
             AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
        )
 
-    @async_timer
     @observe(name="Historical Question")
     async def run(self, query: str, id: Optional[str] = None):
-        logger.info("HistoricalQuestion pipeline is running...")
+        logger.info("Historical Question Retrieval pipeline is running...")
         return await self._pipe.execute(
             ["formatted_output"],
             inputs={
@@ -168,7 +144,7 @@ async def run(self, query: str, id: Optional[str] = None):
     from src.pipelines.common import dry_run_pipeline
 
     dry_run_pipeline(
-        HistoricalQuestion,
+        HistoricalQuestionRetrieval,
         "historical_question_retrieval",
         query="this is a test query",
     )
diff --git a/wren-ai-service/src/pipelines/retrieval/preprocess_sql_data.py b/wren-ai-service/src/pipelines/retrieval/preprocess_sql_data.py
index b5e5ffb29..9cf094e22 100644
--- a/wren-ai-service/src/pipelines/retrieval/preprocess_sql_data.py
+++ b/wren-ai-service/src/pipelines/retrieval/preprocess_sql_data.py
@@ -9,13 +9,11 @@
 
 from src.core.pipeline import BasicPipeline
 from src.core.provider import LLMProvider
-from src.utils import timer
 
 logger = logging.getLogger("wren-ai-service")
 
 
 ## Start of Pipeline
-@timer
 @observe(capture_input=False, capture_output=False)
 def preprocess(
     sql_data: Dict,
@@ -67,7 +65,6 @@ def __init__(
 
         super().__init__(Driver({}, sys.modules[__name__], adapter=base.DictResult()))
 
-    @timer
     @observe(name="Preprocess SQL Data")
     def run(
         self,
diff --git a/wren-ai-service/src/pipelines/retrieval/retrieval.py b/wren-ai-service/src/pipelines/retrieval/retrieval.py
index 5fdafeb1a..584fffc3f 100644
--- a/wren-ai-service/src/pipelines/retrieval/retrieval.py
+++ b/wren-ai-service/src/pipelines/retrieval/retrieval.py
@@ -15,7 +15,6 @@
 from src.core.pipeline import BasicPipeline
 from src.core.provider import DocumentStoreProvider, EmbedderProvider, LLMProvider
 from src.pipelines.common import build_table_ddl
-from src.utils import async_timer, timer
 from src.web.v1.services.ask import AskHistory
 
 logger = logging.getLogger("wren-ai-service")
@@ -115,7 +114,6 @@ def _build_view_ddl(content: dict) -> str:
 
 
 ## Start of Pipeline
-@async_timer
 @observe(capture_input=False, capture_output=False)
 async def embedding(
     query: str, embedder: Any, history: Optional[AskHistory] = None
@@ -132,7 +130,6 @@ async def embedding(
     return await embedder.run(query)
 
 
-@async_timer
 @observe(capture_input=False)
 async def table_retrieval(embedding: dict, id: str, table_retriever: Any) -> dict:
     filters = {
@@ -153,7 +150,6 @@ async def table_retrieval(embedding: dict, id: str, table_retriever: Any) -> dic
     )
 
 
-@async_timer
 @observe(capture_input=False)
 async def dbschema_retrieval(
     table_retrieval: dict, embedding: dict, id: str, dbschema_retriever: Any
@@ -188,7 +184,6 @@ async def dbschema_retrieval(
     return results["documents"]
 
 
-@timer
 @observe()
 def construct_db_schemas(dbschema_retrieval: list[Document]) -> list[dict]:
     db_schemas = {}
@@ -217,7 +212,6 @@ def construct_db_schemas(dbschema_retrieval: list[Document]) -> list[dict]:
     return list(db_schemas.values())
 
 
-@timer
 @observe(capture_input=False)
 def check_using_db_schemas_without_pruning(
     construct_db_schemas: list[dict],
@@ -256,7 +250,6 @@ def check_using_db_schemas_without_pruning(
     }
 
 
-@timer
 @observe(capture_input=False)
 def prompt(
     query: str,
@@ -287,7 +280,6 @@ def prompt(
     return {}
 
 
-@async_timer
 @observe(as_type="generation", capture_input=False)
 async def filter_columns_in_tables(
     prompt: dict, table_columns_selection_generator: Any
@@ -298,7 +290,6 @@ async def filter_columns_in_tables(
     return {}
 
 
-@timer
 @observe()
 def construct_retrieval_results(
     check_using_db_schemas_without_pruning: dict,
@@ -420,7 +411,6 @@ def __init__(
             AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
         )
 
-    @async_timer
     @observe(name="Ask Retrieval")
     async def run(
         self,
diff --git a/wren-ai-service/src/pipelines/retrieval/sql_executor.py b/wren-ai-service/src/pipelines/retrieval/sql_executor.py
index 50e72791d..e0a721f84 100644
--- a/wren-ai-service/src/pipelines/retrieval/sql_executor.py
+++ b/wren-ai-service/src/pipelines/retrieval/sql_executor.py
@@ -10,7 +10,6 @@
 
 from src.core.engine import Engine
 from src.core.pipeline import BasicPipeline
-from src.utils import async_timer
 
 logger = logging.getLogger("wren-ai-service")
 
@@ -42,7 +41,6 @@ async def run(
 
 
 ## Start of Pipeline
-@async_timer
 @observe(capture_input=False)
 async def execute_sql(
     sql: str,
@@ -70,7 +68,6 @@ def __init__(
             AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
         )
 
-    @async_timer
     @observe(name="SQL Execution")
     async def run(
         self, sql: str, project_id: str | None = None, limit: int = 500
diff --git a/wren-ai-service/src/pipelines/retrieval/sql_pairs_retrieval.py b/wren-ai-service/src/pipelines/retrieval/sql_pairs_retrieval.py
new file mode 100644
index 000000000..abba3b016
--- /dev/null
+++ b/wren-ai-service/src/pipelines/retrieval/sql_pairs_retrieval.py
@@ -0,0 +1,148 @@
+import logging
+import sys
+from typing import Any, Dict, List, Optional
+
+from hamilton import base
+from hamilton.async_driver import AsyncDriver
+from haystack import Document, component
+from haystack_integrations.document_stores.qdrant import QdrantDocumentStore
+from langfuse.decorators import observe
+
+from src.core.pipeline import BasicPipeline
+from src.core.provider import DocumentStoreProvider, EmbedderProvider
+from src.pipelines.common import ScoreFilter
+
+logger = logging.getLogger("wren-ai-service")
+
+
+@component
+class OutputFormatter:
+    @component.output_types(
+        documents=List[Optional[Dict]],
+    )
+    def run(self, documents: List[Document]):
+        formatted_docs = []
+
+        for doc in documents:
+            formatted = {
+                "summary": doc.content,
+                "sql": doc.meta.get("sql"),
+            }
+            formatted_docs.append(formatted)
+
+        return {"documents": formatted_docs}
+
+
+## Start of Pipeline
+@observe(capture_input=False)
+async def count_documents(store: QdrantDocumentStore, id: Optional[str] = None) -> int:
+    filters = (
+        {
+            "operator": "AND",
+            "conditions": [
+                {"field": "project_id", "operator": "==", "value": id},
+            ],
+        }
+        if id
+        else None
+    )
+    document_count = await store.count_documents(filters=filters)
+    return document_count
+
+
+@observe(capture_input=False, capture_output=False)
+async def embedding(count_documents: int, query: str, embedder: Any) -> dict:
+    if count_documents:
+        return await embedder.run(query)
+
+    return {}
+
+
+@observe(capture_input=False)
+async def retrieval(embedding: dict, id: str, retriever: Any) -> dict:
+    if embedding:
+        filters = (
+            {
+                "operator": "AND",
+                "conditions": [
+                    {"field": "project_id", "operator": "==", "value": id},
+                ],
+            }
+            if id
+            else None
+        )
+
+        res = await retriever.run(
+            query_embedding=embedding.get("embedding"),
+            filters=filters,
+        )
+        return dict(documents=res.get("documents"))
+
+    return {}
+
+
+@observe(capture_input=False)
+def filtered_documents(retrieval: dict, score_filter: ScoreFilter) -> dict:
+    if retrieval:
+        return score_filter.run(documents=retrieval.get("documents"), score=0.7)
+
+    return {}
+
+
+@observe(capture_input=False)
+def formatted_output(
+    filtered_documents: dict, output_formatter: OutputFormatter
+) -> dict:
+    if filtered_documents:
+        return output_formatter.run(documents=filtered_documents.get("documents"))
+
+    return {"documents": []}
+
+
+## End of Pipeline
+
+
+class SqlPairsRetrieval(BasicPipeline):
+    def __init__(
+        self,
+        embedder_provider: EmbedderProvider,
+        document_store_provider: DocumentStoreProvider,
+        **kwargs,
+    ) -> None:
+        store = document_store_provider.get_store(dataset_name="sql_pairs")
+        self._components = {
+            "store": store,
+            "embedder": embedder_provider.get_text_embedder(),
+            "retriever": document_store_provider.get_retriever(
+                document_store=store,
+            ),
+            "score_filter": ScoreFilter(),
+            # TODO: add an LLM filter to drop low-scoring documents, in case ScoreFilter is not accurate enough
+            "output_formatter": OutputFormatter(),
+        }
+
+        super().__init__(
+            AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
+        )
+
+    @observe(name="SqlPairs Retrieval")
+    async def run(self, query: str, id: Optional[str] = None):
+        logger.info("SqlPairs Retrieval pipeline is running...")
+        return await self._pipe.execute(
+            ["formatted_output"],
+            inputs={
+                "query": query,
+                "id": id or "",
+                **self._components,
+            },
+        )
+
+
+if __name__ == "__main__":
+    from src.pipelines.common import dry_run_pipeline
+
+    dry_run_pipeline(
+        SqlPairsRetrieval,
+        "sql_pairs_retrieval",
+        query="this is a test query",
+    )
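Editor's note: both retrieval pipelines now pass an explicit threshold to the shared `ScoreFilter` moved into `src.pipelines.common` (its implementation appears in the historical_question diff above): 0.9 for historical questions, a looser 0.7 for SQL pairs. A hedged sketch of its behavior, runnable inside the repo:

```python
from haystack import Document
from src.pipelines.common import ScoreFilter

docs = [Document(content="a", score=0.95), Document(content="b", score=0.65)]
score_filter = ScoreFilter()

# Keeps only docs at/above the threshold, sorted by score descending.
kept_strict = score_filter.run(documents=docs, score=0.9)["documents"]  # historical questions
kept_loose = score_filter.run(documents=docs, score=0.7)["documents"]   # sql pairs
```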
diff --git a/wren-ai-service/src/utils.py b/wren-ai-service/src/utils.py
index 16984a9f7..0ca20aee0 100644
--- a/wren-ai-service/src/utils.py
+++ b/wren-ai-service/src/utils.py
@@ -1,8 +1,6 @@
-import asyncio
 import functools
 import logging
 import os
-import time
 from pathlib import Path
 
 from dotenv import load_dotenv
@@ -61,54 +59,6 @@ def load_env_vars() -> str:
     return "prod"
 
 
-def timer(func):
-    @functools.wraps(func)
-    def wrapper_timer(*args, **kwargs):
-        from src.config import settings
-
-        if settings.enable_timer:
-            startTime = time.perf_counter()
-            result = func(*args, **kwargs)
-            endTime = time.perf_counter()
-            elapsed_time = endTime - startTime
-
-            logger.info(
-                f"{func.__qualname__} Elapsed time: {elapsed_time:0.4f} seconds"
-            )
-
-            return result
-
-        return func(*args, **kwargs)
-
-    return wrapper_timer
-
-
-def async_timer(func):
-    async def process(func, *args, **kwargs):
-        assert asyncio.iscoroutinefunction(func)
-        return await func(*args, **kwargs)
-
-    @functools.wraps(func)
-    async def wrapper_timer(*args, **kwargs):
-        from src.config import settings
-
-        if settings.enable_timer:
-            startTime = time.perf_counter()
-            result = await process(func, *args, **kwargs)
-            endTime = time.perf_counter()
-            elapsed_time = endTime - startTime
-
-            logger.info(
-                f"{func.__qualname__} Elapsed time: {elapsed_time:0.4f} seconds"
-            )
-
-            return result
-
-        return await process(func, *args, **kwargs)
-
-    return wrapper_timer
-
-
 def remove_trailing_slash(endpoint: str) -> str:
     return endpoint.rstrip("/") if endpoint.endswith("/") else endpoint
 
diff --git a/wren-ai-service/src/web/development.py b/wren-ai-service/src/web/development.py
index 442ae6e1c..d517b0872 100644
--- a/wren-ai-service/src/web/development.py
+++ b/wren-ai-service/src/web/development.py
@@ -6,7 +6,6 @@
 
 from fastapi import APIRouter, BackgroundTasks
 
-from src.utils import async_timer
 from src.web.v1.services.ask import (
     AskError,
     AskRequest,
@@ -20,7 +19,6 @@
 test_ask_results = {}
 
 
-@async_timer
 async def dummy_ask_task(ask_request: AskRequest):
     """
     settings:
diff --git a/wren-ai-service/src/web/v1/routers/__init__.py b/wren-ai-service/src/web/v1/routers/__init__.py
index d3ebd402e..3c4c83bd9 100644
--- a/wren-ai-service/src/web/v1/routers/__init__.py
+++ b/wren-ai-service/src/web/v1/routers/__init__.py
@@ -8,10 +8,11 @@
     question_recommendation,
     relationship_recommendation,
     semantics_description,
-    semantics_preparations,
+    semantics_preparation,
     sql_answers,
     sql_expansions,
     sql_explanations,
+    sql_pairs_preparation,
     sql_regenerations,
 )
 
@@ -21,11 +22,12 @@
 router.include_router(question_recommendation.router)
 router.include_router(relationship_recommendation.router)
 router.include_router(semantics_description.router)
-router.include_router(semantics_preparations.router)
+router.include_router(semantics_preparation.router)
 router.include_router(sql_answers.router)
 router.include_router(sql_expansions.router)
 router.include_router(sql_explanations.router)
 router.include_router(chart.router)
 router.include_router(chart_adjustment.router)
 router.include_router(sql_regenerations.router)
+router.include_router(sql_pairs_preparation.router)
 # connected subrouter
diff --git a/wren-ai-service/src/web/v1/routers/ask.py b/wren-ai-service/src/web/v1/routers/ask.py
index 84ec54bad..3f0ac4950 100644
--- a/wren-ai-service/src/web/v1/routers/ask.py
+++ b/wren-ai-service/src/web/v1/routers/ask.py
@@ -34,7 +34,6 @@
     - `project_id`: (Optional) Identifier for the project to fetch relevant data.
     - `mdl_hash`: (Optional) Hash or ID related to the model to be used for the query.
     - `thread_id`: (Optional) Thread identifier for the query.
-    - `user_id`: (Optional) User identifier.
     - `history`: (Optional) Query history (SQL steps).
     - `configurations`: (Optional) Configuration such as fiscal year.
 - **Response**:
- **Response**: diff --git a/wren-ai-service/src/web/v1/routers/ask_details.py b/wren-ai-service/src/web/v1/routers/ask_details.py index abd1a2fc2..3665b5ebf 100644 --- a/wren-ai-service/src/web/v1/routers/ask_details.py +++ b/wren-ai-service/src/web/v1/routers/ask_details.py @@ -33,7 +33,6 @@ "mdl_hash": "optional-hash", # Optional model hash for reference "thread_id": "optional-thread-id", # Optional thread identifier "project_id": "optional-project-id", # Optional project identifier - "user_id": "optional-user-id" # Optional user identifier } - Response: AskDetailsResponse { diff --git a/wren-ai-service/src/web/v1/routers/semantics_preparations.py b/wren-ai-service/src/web/v1/routers/semantics_preparation.py similarity index 98% rename from wren-ai-service/src/web/v1/routers/semantics_preparations.py rename to wren-ai-service/src/web/v1/routers/semantics_preparation.py index a09241789..f2b128bb4 100644 --- a/wren-ai-service/src/web/v1/routers/semantics_preparations.py +++ b/wren-ai-service/src/web/v1/routers/semantics_preparation.py @@ -32,7 +32,6 @@ "mdl": "model_data_string", # String representing the model data to be indexed "mdl_hash": "unique_hash", # Unique identifier for the model (hash or ID) "project_id": "optional_project_id", # Optional project identifier - "user_id": "optional_user_id" # Optional user identifier } - Response: SemanticsPreparationResponse { diff --git a/wren-ai-service/src/web/v1/routers/sql_answers.py b/wren-ai-service/src/web/v1/routers/sql_answers.py index 1e83ec294..405af7234 100644 --- a/wren-ai-service/src/web/v1/routers/sql_answers.py +++ b/wren-ai-service/src/web/v1/routers/sql_answers.py @@ -33,7 +33,6 @@ "sql": "SELECT * FROM table_name WHERE condition", # Actual SQL statement "sql_data": , # Preprocessed SQL data "thread_id": "unique-thread-id", # Optional thread identifier for tracking - "user_id": "user-id" # Optional user identifier for tracking } - Response: SqlAnswerResponse { diff --git a/wren-ai-service/src/web/v1/routers/sql_expansions.py b/wren-ai-service/src/web/v1/routers/sql_expansions.py index 15fb9d954..170ab4006 100644 --- a/wren-ai-service/src/web/v1/routers/sql_expansions.py +++ b/wren-ai-service/src/web/v1/routers/sql_expansions.py @@ -36,7 +36,6 @@ "project_id": "project-identifier", # Identifier for the project "mdl_hash": "hash-of-model", # Hash of the model (if applicable) "thread_id": "thread-identifier", # Identifier for the thread (if applicable) - "user_id": "user-identifier" # Identifier for the user making the request } - Response: SqlExpansionResponse { diff --git a/wren-ai-service/src/web/v1/routers/sql_explanations.py b/wren-ai-service/src/web/v1/routers/sql_explanations.py index 66918ad9f..83325f28d 100644 --- a/wren-ai-service/src/web/v1/routers/sql_explanations.py +++ b/wren-ai-service/src/web/v1/routers/sql_explanations.py @@ -39,7 +39,6 @@ "mdl_hash": "hash_value", # Optional hash for the model used "thread_id": "thread-123", # Optional identifier for the thread "project_id": "project-456", # Optional identifier for the project - "user_id": "user-789" # Optional identifier for the user } - Response: SQLExplanationResponse { diff --git a/wren-ai-service/src/web/v1/routers/sql_pairs_preparation.py b/wren-ai-service/src/web/v1/routers/sql_pairs_preparation.py new file mode 100644 index 000000000..7202ed1df --- /dev/null +++ b/wren-ai-service/src/web/v1/routers/sql_pairs_preparation.py @@ -0,0 +1,123 @@ +import uuid +from dataclasses import asdict + +from fastapi import APIRouter, BackgroundTasks, Depends + +from 
src.globals import (
+    ServiceContainer,
+    ServiceMetadata,
+    get_service_container,
+    get_service_metadata,
+)
+from src.web.v1.services.sql_pairs_preparation import (
+    DeleteSqlPairsRequest,
+    DeleteSqlPairsResponse,
+    SqlPairsPreparationRequest,
+    SqlPairsPreparationResponse,
+    SqlPairsPreparationStatusRequest,
+    SqlPairsPreparationStatusResponse,
+)
+
+router = APIRouter()
+
+
+"""
+Sql Pairs Preparation Router
+
+This router manages the endpoints for uploading SQL pairs, deleting them, and retrieving the status of those operations.
+
+Endpoints:
+1. **POST /sql-pairs**
+   - Initiates the preparation of SQL pairs for processing.
+   - **Request Body**: SqlPairsPreparationRequest
+     - `sql_pairs`: List of SQL pairs, each containing:
+       - `sql`: The SQL statement
+       - `id`: Unique identifier for the SQL pair
+     - `project_id`: (Optional) Identifier for the project context
+   - **Response**: SqlPairsPreparationResponse
+     - `sql_pairs_preparation_id`: A unique identifier (UUID) for tracking the preparation process
+
+2. **DELETE /sql-pairs**
+   - Deletes specified SQL pairs.
+   - **Request Body**: DeleteSqlPairsRequest
+     - `ids`: List of SQL pair IDs to delete
+     - `project_id`: (Optional) Project identifier
+   - **Response**: DeleteSqlPairsResponse
+     - `sql_pairs_preparation_id`: A unique identifier (UUID) for tracking the deletion process
+
+3. **GET /sql-pairs/{sql_pairs_preparation_id}**
+   - Retrieves the current status of a SQL pairs preparation or deletion process.
+   - **Path Parameter**:
+     - `sql_pairs_preparation_id`: The unique identifier of the process
+   - **Response**: SqlPairsPreparationStatusResponse
+     - `status`: Current status ("indexing", "deleting", "finished", or "failed")
+     - `error`: (Optional) Error information if the process failed, including:
+       - `code`: Error code ("OTHERS")
+       - `message`: Detailed error message
+
+Process:
+1. Submit SQL pairs using the POST endpoint to initiate preparation. This returns a preparation ID.
+2. Use the DELETE endpoint to remove specific SQL pairs from the system.
+3. Track the status of any operation using the GET endpoint with the preparation ID.
+
+Note: All operations are processed asynchronously using background tasks. The status can be polled
+via the GET endpoint. Statuses are cached with a TTL of 120 seconds.
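+
+Example (a hypothetical client sketch — the base URL is an assumption, not part of this change;
+adjust it to your deployment. Payloads follow the schemas documented above):
+
+    import httpx
+
+    base = "http://localhost:5556/v1"  # assumed service address
+
+    resp = httpx.post(f"{base}/sql-pairs", json={
+        "sql_pairs": [{"sql": "SELECT * FROM book", "id": "1"}],
+        "project_id": "my-project",
+    })
+    prep_id = resp.json()["sql_pairs_preparation_id"]
+
+    # Poll until the background task reports "finished" or "failed".
+    status = httpx.get(f"{base}/sql-pairs/{prep_id}").json()["status"]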
+""" + + +@router.post("/sql-pairs") +async def prepare_sql_pairs( + prepare_sql_pairs_request: SqlPairsPreparationRequest, + background_tasks: BackgroundTasks, + service_container: ServiceContainer = Depends(get_service_container), + service_metadata: ServiceMetadata = Depends(get_service_metadata), +) -> SqlPairsPreparationResponse: + id = str(uuid.uuid4()) + prepare_sql_pairs_request.query_id = id + service_container.sql_pairs_preparation_service._prepare_sql_pairs_statuses[ + id + ] = SqlPairsPreparationStatusResponse( + status="indexing", + ) + + background_tasks.add_task( + service_container.sql_pairs_preparation_service.prepare_sql_pairs, + prepare_sql_pairs_request, + service_metadata=asdict(service_metadata), + ) + return SqlPairsPreparationResponse(sql_pairs_preparation_id=id) + + +@router.delete("/sql-pairs") +async def delete_sql_pairs( + delete_sql_pairs_request: DeleteSqlPairsRequest, + background_tasks: BackgroundTasks, + service_container: ServiceContainer = Depends(get_service_container), + service_metadata: ServiceMetadata = Depends(get_service_metadata), +) -> DeleteSqlPairsResponse: + id = str(uuid.uuid4()) + delete_sql_pairs_request.query_id = id + service_container.sql_pairs_preparation_service._prepare_sql_pairs_statuses[ + id + ] = SqlPairsPreparationStatusResponse( + status="deleting", + ) + + background_tasks.add_task( + service_container.sql_pairs_preparation_service.delete_sql_pairs, + delete_sql_pairs_request, + service_metadata=asdict(service_metadata), + ) + return DeleteSqlPairsResponse(sql_pairs_preparation_id=id) + + +@router.get("/sql-pairs/{sql_pairs_preparation_id}") +async def get_sql_pairs_preparation_status( + sql_pairs_preparation_id: str, + service_container: ServiceContainer = Depends(get_service_container), +) -> SqlPairsPreparationStatusResponse: + return service_container.sql_pairs_preparation_service.get_prepare_sql_pairs_status( + SqlPairsPreparationStatusRequest( + sql_pairs_preparation_id=sql_pairs_preparation_id + ) + ) diff --git a/wren-ai-service/src/web/v1/routers/sql_regenerations.py b/wren-ai-service/src/web/v1/routers/sql_regenerations.py index 4fc26b9cc..657e5ecf0 100644 --- a/wren-ai-service/src/web/v1/routers/sql_regenerations.py +++ b/wren-ai-service/src/web/v1/routers/sql_regenerations.py @@ -41,7 +41,6 @@ "mdl_hash": null, # Optional hash for model identification "thread_id": null, # Optional identifier for the processing thread "project_id": null, # Optional project identifier - "user_id": null # Optional user identifier } - Response: SQLRegenerationResponse { diff --git a/wren-ai-service/src/web/v1/services/ask.py b/wren-ai-service/src/web/v1/services/ask.py index c8e22018b..bfa92698a 100644 --- a/wren-ai-service/src/web/v1/services/ask.py +++ b/wren-ai-service/src/web/v1/services/ask.py @@ -7,7 +7,7 @@ from pydantic import AliasChoices, BaseModel, Field from src.core.pipeline import BasicPipeline -from src.utils import async_timer, trace_metadata +from src.utils import trace_metadata from src.web.v1.services import Configuration, SSEEvent from src.web.v1.services.ask_details import SQLBreakdown @@ -29,7 +29,6 @@ class AskRequest(BaseModel): # so we need to support as a choice, and will remove it in the future mdl_hash: Optional[str] = Field(validation_alias=AliasChoices("mdl_hash", "id")) thread_id: Optional[str] = None - user_id: Optional[str] = None history: Optional[AskHistory] = None configurations: Optional[Configuration] = Configuration() @@ -122,7 +121,6 @@ def _get_failed_dry_run_results(self, 
invalid_generation_results: list[dict]): filter(lambda x: x["type"] == "DRY_RUN", invalid_generation_results) ) - @async_timer @observe(name="Ask Question") @trace_metadata async def ask( @@ -244,7 +242,9 @@ async def ask( intent_reasoning=intent_reasoning, ) - historical_question = await self._pipelines["historical_question"].run( + historical_question = await self._pipelines[ + "historical_question_retrieval" + ].run( query=ask_request.query, id=ask_request.project_id, ) @@ -267,6 +267,13 @@ async def ask( for result in historical_question_result ] else: + sql_samples = ( + await self._pipelines["sql_pairs_retrieval"].run( + query=ask_request.query, + id=ask_request.project_id, + ) + )["formatted_output"].get("documents", []) + if ask_request.history: text_to_sql_generation_results = await self._pipelines[ "followup_sql_generation" @@ -276,6 +283,7 @@ async def ask( history=ask_request.history, project_id=ask_request.project_id, configuration=ask_request.configurations, + sql_samples=sql_samples, ) else: text_to_sql_generation_results = await self._pipelines[ @@ -286,6 +294,7 @@ async def ask( exclude=historical_question_result, project_id=ask_request.project_id, configuration=ask_request.configurations, + sql_samples=sql_samples, ) if sql_valid_results := text_to_sql_generation_results[ diff --git a/wren-ai-service/src/web/v1/services/ask_details.py b/wren-ai-service/src/web/v1/services/ask_details.py index 4405172ad..4c90cf0ac 100644 --- a/wren-ai-service/src/web/v1/services/ask_details.py +++ b/wren-ai-service/src/web/v1/services/ask_details.py @@ -7,7 +7,7 @@ from pydantic import BaseModel from src.core.engine import add_quotes -from src.utils import async_timer, trace_metadata +from src.utils import trace_metadata from src.web.v1.services import Configuration logger = logging.getLogger("wren-ai-service") @@ -27,7 +27,6 @@ class AskDetailsRequest(BaseModel): mdl_hash: Optional[str] = None thread_id: Optional[str] = None project_id: Optional[str] = None - user_id: Optional[str] = None configurations: Configuration = Configuration() @property @@ -82,7 +81,6 @@ async def _add_summary_to_sql(self, sql: str, query: str, language: str): ) return sql_summary_results["post_process"]["sql_summary_results"] - @async_timer @observe(name="Ask Details(Breakdown SQL)") @trace_metadata async def ask_details( diff --git a/wren-ai-service/src/web/v1/services/chart.py b/wren-ai-service/src/web/v1/services/chart.py index 6b28c728e..b7b21f184 100644 --- a/wren-ai-service/src/web/v1/services/chart.py +++ b/wren-ai-service/src/web/v1/services/chart.py @@ -6,7 +6,7 @@ from pydantic import BaseModel from src.core.pipeline import BasicPipeline -from src.utils import async_timer, trace_metadata +from src.utils import trace_metadata from src.web.v1.services import Configuration logger = logging.getLogger("wren-ai-service") @@ -93,7 +93,6 @@ def _is_stopped(self, query_id: str): return False - @async_timer @observe(name="Generate Chart") @trace_metadata async def chart( diff --git a/wren-ai-service/src/web/v1/services/chart_adjustment.py b/wren-ai-service/src/web/v1/services/chart_adjustment.py index 0e2a050d6..ae574389d 100644 --- a/wren-ai-service/src/web/v1/services/chart_adjustment.py +++ b/wren-ai-service/src/web/v1/services/chart_adjustment.py @@ -6,7 +6,7 @@ from pydantic import BaseModel from src.core.pipeline import BasicPipeline -from src.utils import async_timer, trace_metadata +from src.utils import trace_metadata from src.web.v1.services import Configuration logger = 
logging.getLogger("wren-ai-service") @@ -106,7 +106,6 @@ def _is_stopped(self, query_id: str): return False - @async_timer @observe(name="Adjust Chart") @trace_metadata async def chart_adjustment( diff --git a/wren-ai-service/src/web/v1/services/semantics_preparation.py b/wren-ai-service/src/web/v1/services/semantics_preparation.py index 81cf351e4..c4e23a286 100644 --- a/wren-ai-service/src/web/v1/services/semantics_preparation.py +++ b/wren-ai-service/src/web/v1/services/semantics_preparation.py @@ -19,7 +19,6 @@ class SemanticsPreparationRequest(BaseModel): # so we need to support as a choice, and will remove it in the future mdl_hash: str = Field(validation_alias=AliasChoices("mdl_hash", "id")) project_id: Optional[str] = None - user_id: Optional[str] = None class SemanticsPreparationResponse(BaseModel): diff --git a/wren-ai-service/src/web/v1/services/sql_answer.py b/wren-ai-service/src/web/v1/services/sql_answer.py index 5c5ba05e5..0b74c90af 100644 --- a/wren-ai-service/src/web/v1/services/sql_answer.py +++ b/wren-ai-service/src/web/v1/services/sql_answer.py @@ -7,7 +7,7 @@ from pydantic import BaseModel from src.core.pipeline import BasicPipeline -from src.utils import async_timer, trace_metadata +from src.utils import trace_metadata from src.web.v1.services import Configuration, SSEEvent logger = logging.getLogger("wren-ai-service") @@ -20,7 +20,6 @@ class SqlAnswerRequest(BaseModel): sql: str sql_data: Dict thread_id: Optional[str] = None - user_id: Optional[str] = None configurations: Optional[Configuration] = Configuration() @property @@ -63,7 +62,6 @@ def __init__( maxsize=maxsize, ttl=ttl ) - @async_timer @observe(name="SQL Answer") @trace_metadata async def sql_answer( diff --git a/wren-ai-service/src/web/v1/services/sql_expansion.py b/wren-ai-service/src/web/v1/services/sql_expansion.py index 07af68158..a91d17905 100644 --- a/wren-ai-service/src/web/v1/services/sql_expansion.py +++ b/wren-ai-service/src/web/v1/services/sql_expansion.py @@ -6,7 +6,7 @@ from pydantic import BaseModel from src.core.pipeline import BasicPipeline -from src.utils import async_timer, remove_sql_summary_duplicates, trace_metadata +from src.utils import remove_sql_summary_duplicates, trace_metadata from src.web.v1.services import Configuration from src.web.v1.services.ask import AskError, AskHistory from src.web.v1.services.ask_details import SQLBreakdown @@ -23,7 +23,6 @@ class SqlExpansionRequest(BaseModel): project_id: Optional[str] = None mdl_hash: Optional[str] = None thread_id: Optional[str] = None - user_id: Optional[str] = None configurations: Optional[Configuration] = Configuration() @property @@ -99,7 +98,6 @@ def _get_failed_dry_run_results(self, invalid_generation_results: list[dict]): filter(lambda x: x["type"] == "DRY_RUN", invalid_generation_results) ) - @async_timer @observe(name="SQL Expansion") @trace_metadata async def sql_expansion( diff --git a/wren-ai-service/src/web/v1/services/sql_explanation.py b/wren-ai-service/src/web/v1/services/sql_explanation.py index be780bfc5..6e4fbb470 100644 --- a/wren-ai-service/src/web/v1/services/sql_explanation.py +++ b/wren-ai-service/src/web/v1/services/sql_explanation.py @@ -6,8 +6,6 @@ from haystack import Pipeline from pydantic import BaseModel -from src.utils import async_timer - logger = logging.getLogger("wren-ai-service") @@ -25,7 +23,6 @@ class SQLExplanationRequest(BaseModel): mdl_hash: Optional[str] = None thread_id: Optional[str] = None project_id: Optional[str] = None - user_id: Optional[str] = None @property def query_id(self) -> 
str: @@ -55,7 +52,7 @@ class SQLExplanationResultError(BaseModel): error: Optional[SQLExplanationResultError] = None -class SQLExplanationService: +class SqlExplanationService: def __init__( self, pipelines: Dict[str, Pipeline], @@ -67,7 +64,6 @@ def __init__( str, SQLExplanationResultResponse ] = TTLCache(maxsize=maxsize, ttl=ttl) - @async_timer async def sql_explanation( self, sql_explanation_request: SQLExplanationRequest, diff --git a/wren-ai-service/src/web/v1/services/sql_pairs_preparation.py b/wren-ai-service/src/web/v1/services/sql_pairs_preparation.py new file mode 100644 index 000000000..39a9f63cd --- /dev/null +++ b/wren-ai-service/src/web/v1/services/sql_pairs_preparation.py @@ -0,0 +1,188 @@ +import logging +from typing import Dict, List, Literal, Optional + +from cachetools import TTLCache +from langfuse.decorators import observe +from pydantic import BaseModel + +from src.core.pipeline import BasicPipeline +from src.utils import trace_metadata + +logger = logging.getLogger("wren-ai-service") + + +# POST /v1/sql-pairs +class SqlPair(BaseModel): + sql: str + id: str + + +class SqlPairsPreparationRequest(BaseModel): + _query_id: str | None = None + sql_pairs: List[SqlPair] + project_id: Optional[str] = None + + @property + def query_id(self) -> str: + return self._query_id + + @query_id.setter + def query_id(self, query_id: str): + self._query_id = query_id + + +class SqlPairsPreparationResponse(BaseModel): + sql_pairs_preparation_id: str + + +# DELETE /v1/sql-pairs +class DeleteSqlPairsRequest(BaseModel): + _query_id: str | None = None + ids: List[str] + project_id: Optional[str] = None + + @property + def query_id(self) -> str: + return self._query_id + + @query_id.setter + def query_id(self, query_id: str): + self._query_id = query_id + + +class DeleteSqlPairsResponse(BaseModel): + sql_pairs_preparation_id: str + + +# GET /v1/sql-pairs/{sql_pairs_preparation_id} +class SqlPairsPreparationStatusRequest(BaseModel): + sql_pairs_preparation_id: str + + +class SqlPairsPreparationStatusResponse(BaseModel): + class SqlPairsPreparationError(BaseModel): + code: Literal["OTHERS"] + message: str + + status: Literal["indexing", "deleting", "finished", "failed"] + error: Optional[SqlPairsPreparationError] = None + + +class SqlPairsPreparationService: + def __init__( + self, + pipelines: Dict[str, BasicPipeline], + maxsize: int = 1_000_000, + ttl: int = 120, + ): + self._pipelines = pipelines + self._prepare_sql_pairs_statuses: Dict[ + str, SqlPairsPreparationStatusResponse + ] = TTLCache(maxsize=maxsize, ttl=ttl) + + @observe(name="Prepare SQL Pairs") + @trace_metadata + async def prepare_sql_pairs( + self, + prepare_sql_pairs_request: SqlPairsPreparationRequest, + **kwargs, + ): + results = { + "metadata": { + "error_type": "", + "error_message": "", + }, + } + + try: + await self._pipelines["sql_pairs_preparation"].run( + sql_pairs=prepare_sql_pairs_request.sql_pairs, + id=prepare_sql_pairs_request.project_id, + ) + + self._prepare_sql_pairs_statuses[ + prepare_sql_pairs_request.query_id + ] = SqlPairsPreparationStatusResponse( + status="finished", + ) + except Exception as e: + logger.exception(f"Failed to prepare SQL pairs: {e}") + + self._prepare_sql_pairs_statuses[ + prepare_sql_pairs_request.query_id + ] = SqlPairsPreparationStatusResponse( + status="failed", + error=SqlPairsPreparationStatusResponse.SqlPairsPreparationError( + code="OTHERS", + message=f"Failed to prepare SQL pairs: {e}", + ), + ) + + results["metadata"]["error_type"] = "INDEXING_FAILED" + 
results["metadata"]["error_message"] = str(e) + + return results + + @observe(name="Delete SQL Pairs") + @trace_metadata + async def delete_sql_pairs( + self, + delete_sql_pairs_request: DeleteSqlPairsRequest, + **kwargs, + ): + results = { + "metadata": { + "error_type": "", + "error_message": "", + }, + } + + try: + await self._pipelines["sql_pairs_deletion"].run( + sql_pair_ids=delete_sql_pairs_request.ids, + id=delete_sql_pairs_request.project_id, + ) + + self._prepare_sql_pairs_statuses[ + delete_sql_pairs_request.query_id + ] = SqlPairsPreparationStatusResponse( + status="finished", + ) + except Exception as e: + logger.exception(f"Failed to delete SQL pairs: {e}") + + self._prepare_sql_pairs_statuses[ + delete_sql_pairs_request.query_id + ] = SqlPairsPreparationStatusResponse( + status="failed", + error=SqlPairsPreparationStatusResponse.SqlPairsPreparationError( + code="OTHERS", + message=f"Failed to delete SQL pairs: {e}", + ), + ) + + results["metadata"]["error_type"] = "DELETION_FAILED" + results["metadata"]["error_message"] = str(e) + + return results + + def get_prepare_sql_pairs_status( + self, prepare_sql_pairs_status_request: SqlPairsPreparationStatusRequest + ) -> SqlPairsPreparationStatusResponse: + if ( + result := self._prepare_sql_pairs_statuses.get( + prepare_sql_pairs_status_request.sql_pairs_preparation_id + ) + ) is None: + logger.exception( + f"id is not found for SqlPairsPreparation: {prepare_sql_pairs_status_request.sql_pairs_preparation_id}" + ) + return SqlPairsPreparationStatusResponse( + status="failed", + error=SqlPairsPreparationStatusResponse.SqlPairsPreparationError( + code="OTHERS", + message=f"{prepare_sql_pairs_status_request.sql_pairs_preparation_id} is not found", + ), + ) + + return result diff --git a/wren-ai-service/src/web/v1/services/sql_regeneration.py b/wren-ai-service/src/web/v1/services/sql_regeneration.py index 6713144e0..618ca790f 100644 --- a/wren-ai-service/src/web/v1/services/sql_regeneration.py +++ b/wren-ai-service/src/web/v1/services/sql_regeneration.py @@ -5,7 +5,6 @@ from haystack import Pipeline from pydantic import BaseModel -from src.utils import async_timer from src.web.v1.services.ask_details import SQLBreakdown logger = logging.getLogger("wren-ai-service") @@ -43,7 +42,6 @@ class SQLRegenerationRequest(BaseModel): mdl_hash: Optional[str] = None thread_id: Optional[str] = None project_id: Optional[str] = None - user_id: Optional[str] = None @property def query_id(self) -> str: @@ -77,7 +75,7 @@ class SQLRegenerationError(BaseModel): error: Optional[SQLRegenerationError] = None -class SQLRegenerationService: +class SqlRegenerationService: def __init__( self, pipelines: Dict[str, Pipeline], @@ -89,7 +87,6 @@ def __init__( str, SQLRegenerationResultResponse ] = TTLCache(maxsize=maxsize, ttl=ttl) - @async_timer async def sql_regeneration( self, sql_regeneration_request: SQLRegenerationRequest, diff --git a/wren-ai-service/tests/data/config.test.yaml b/wren-ai-service/tests/data/config.test.yaml index 006f4d15c..ae366d17f 100644 --- a/wren-ai-service/tests/data/config.test.yaml +++ b/wren-ai-service/tests/data/config.test.yaml @@ -42,7 +42,7 @@ pipes: llm: openai_llm.gpt-4o-mini embedder: openai_embedder.text-embedding-3-large document_store: qdrant - - name: historical_question + - name: historical_question_retrieval embedder: openai_embedder.text-embedding-3-large document_store: qdrant - name: sql_generation @@ -87,6 +87,5 @@ settings: query_cache_ttl: 3600 langfuse_host: https://cloud.langfuse.com langfuse_enable: false - 
enable_timer: false logging_level: INFO development: false diff --git a/wren-ai-service/tests/pytest/pipelines/generation/__init__.py b/wren-ai-service/tests/pytest/pipelines/generation/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/wren-ai-service/tests/pytest/pipelines/test_ask.py b/wren-ai-service/tests/pytest/pipelines/generation/test_ask.py similarity index 100% rename from wren-ai-service/tests/pytest/pipelines/test_ask.py rename to wren-ai-service/tests/pytest/pipelines/generation/test_ask.py diff --git a/wren-ai-service/tests/pytest/pipelines/test_ask_details.py b/wren-ai-service/tests/pytest/pipelines/generation/test_ask_details.py similarity index 100% rename from wren-ai-service/tests/pytest/pipelines/test_ask_details.py rename to wren-ai-service/tests/pytest/pipelines/generation/test_ask_details.py diff --git a/wren-ai-service/tests/pytest/pipelines/indexing/__init__.py b/wren-ai-service/tests/pytest/pipelines/indexing/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/wren-ai-service/tests/pytest/pipelines/indexing/test_sql_pairs_deletion.py b/wren-ai-service/tests/pytest/pipelines/indexing/test_sql_pairs_deletion.py new file mode 100644 index 000000000..096f97c96 --- /dev/null +++ b/wren-ai-service/tests/pytest/pipelines/indexing/test_sql_pairs_deletion.py @@ -0,0 +1,42 @@ +import pytest + +from src.config import settings +from src.core.provider import DocumentStoreProvider +from src.pipelines.indexing.sql_pairs_deletion import SqlPairsDeletion +from src.pipelines.indexing.sql_pairs_preparation import SqlPair, SqlPairsPreparation +from src.providers import generate_components + + +@pytest.mark.asyncio +async def test_sql_pairs_deletion(): + pipe_components = generate_components(settings.components) + document_store_provider: DocumentStoreProvider = pipe_components[ + "sql_pairs_preparation" + ]["document_store_provider"] + store = document_store_provider.get_store( + dataset_name="sql_pairs", + recreate_index=True, + ) + + sql_pairs = [ + SqlPair(sql="SELECT * FROM book", id="1"), + SqlPair(sql="SELECT * FROM author", id="2"), + ] + sql_pairs_preparation = SqlPairsPreparation( + **pipe_components["sql_pairs_preparation"] + ) + await sql_pairs_preparation.run( + sql_pairs=sql_pairs, + id="fake-id", + ) + + sql_pairs_deletion = SqlPairsDeletion(**pipe_components["sql_pairs_deletion"]) + await sql_pairs_deletion.run( + id="fake-id-2", sql_pair_ids=[sql_pair.id for sql_pair in sql_pairs] + ) + assert await store.count_documents() == 2 + + await sql_pairs_deletion.run( + id="fake-id", sql_pair_ids=[sql_pair.id for sql_pair in sql_pairs] + ) + assert await store.count_documents() == 0 diff --git a/wren-ai-service/tests/pytest/pipelines/indexing/test_sql_pairs_preparation.py b/wren-ai-service/tests/pytest/pipelines/indexing/test_sql_pairs_preparation.py new file mode 100644 index 000000000..f2eeefbee --- /dev/null +++ b/wren-ai-service/tests/pytest/pipelines/indexing/test_sql_pairs_preparation.py @@ -0,0 +1,79 @@ +import pytest + +from src.config import settings +from src.core.provider import DocumentStoreProvider +from src.pipelines.indexing.sql_pairs_preparation import SqlPair, SqlPairsPreparation +from src.providers import generate_components + + +@pytest.mark.asyncio +async def test_sql_pairs_preparation_saving_to_document_store(): + pipe_components = generate_components(settings.components) + document_store_provider: DocumentStoreProvider = pipe_components[ + "sql_pairs_preparation" + ]["document_store_provider"] + store = 
document_store_provider.get_store( + dataset_name="sql_pairs", + recreate_index=True, + ) + + sql_pairs_preparation = SqlPairsPreparation( + **pipe_components["sql_pairs_preparation"] + ) + await sql_pairs_preparation.run( + sql_pairs=[ + SqlPair(sql="SELECT * FROM book", id="1"), + SqlPair(sql="SELECT * FROM author", id="2"), + ], + id="fake-id", + ) + + assert await store.count_documents() == 2 + documents = store.filter_documents() + for document in documents: + assert document.content, "content should not be empty" + assert document.meta, "meta should not be empty" + assert document.meta.get("sql_pair_id"), "sql_pair_id should be in meta" + assert document.meta.get("sql"), "sql should be in meta" + + +@pytest.mark.asyncio +async def test_sql_pairs_preparation_saving_to_document_store_with_multiple_project_ids(): + pipe_components = generate_components(settings.components) + document_store_provider: DocumentStoreProvider = pipe_components[ + "sql_pairs_preparation" + ]["document_store_provider"] + store = document_store_provider.get_store( + dataset_name="sql_pairs", + recreate_index=True, + ) + + sql_pairs_preparation = SqlPairsPreparation( + **pipe_components["sql_pairs_preparation"] + ) + await sql_pairs_preparation.run( + sql_pairs=[ + SqlPair(sql="SELECT * FROM book", id="1"), + SqlPair(sql="SELECT * FROM author", id="2"), + ], + id="fake-id", + ) + + await sql_pairs_preparation.run( + sql_pairs=[ + SqlPair(sql="SELECT * FROM book", id="1"), + SqlPair(sql="SELECT * FROM author", id="2"), + ], + id="fake-id-2", + ) + + assert await store.count_documents() == 4 + documents = store.filter_documents( + filters={ + "operator": "AND", + "conditions": [ + {"field": "project_id", "operator": "==", "value": "fake-id"}, + ], + } + ) + assert len(documents) == 2 diff --git a/wren-ai-service/tests/pytest/pipelines/retrieval/__init__.py b/wren-ai-service/tests/pytest/pipelines/retrieval/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/wren-ai-service/tests/pytest/services/mocks.py b/wren-ai-service/tests/pytest/services/mocks.py index c4e145262..dadee6994 100644 --- a/wren-ai-service/tests/pytest/services/mocks.py +++ b/wren-ai-service/tests/pytest/services/mocks.py @@ -1,7 +1,6 @@ from typing import Optional -from src.pipelines.generation import intent_classification, sql_generation, sql_summary -from src.pipelines.retrieval import historical_question, retrieval +from src.pipelines import generation, retrieval from src.web.v1.services import Configuration from src.web.v1.services.ask import AskHistory @@ -14,7 +13,7 @@ async def run(self, query: str, id: Optional[str] = None): return {"construct_retrieval_results": self._documents} -class HistoricalQuestionMock(historical_question.HistoricalQuestion): +class HistoricalQuestionMock(retrieval.HistoricalQuestionRetrieval): def __init__(self, documents: list = []): self._documents = documents @@ -22,7 +21,7 @@ async def run(self, query: str, id: Optional[str] = None): return {"formatted_output": {"documents": self._documents}} -class IntentClassificationMock(intent_classification.IntentClassification): +class IntentClassificationMock(generation.IntentClassification): def __init__(self, intent: str = "MISLEADING_QUERY"): self._intent = intent @@ -32,7 +31,7 @@ async def run( return {"post_process": {"intent": self._intent, "db_schemas": []}} -class GenerationMock(sql_generation.SQLGeneration): +class GenerationMock(generation.SQLGeneration): def __init__(self, valid: list = [], invalid: list = []): self._valid = valid 
self._invalid = invalid @@ -53,7 +52,7 @@ async def run( } -class SQLSummaryMock(sql_summary.SQLSummary): +class SQLSummaryMock(generation.SQLSummary): """ Example for the results: [ diff --git a/wren-ai-service/tests/pytest/services/test_ask.py b/wren-ai-service/tests/pytest/services/test_ask.py index a5493ac12..4600953d4 100644 --- a/wren-ai-service/tests/pytest/services/test_ask.py +++ b/wren-ai-service/tests/pytest/services/test_ask.py @@ -7,14 +7,7 @@ import pytest from src.config import settings -from src.pipelines import indexing -from src.pipelines.generation import ( - data_assistance, - intent_classification, - sql_correction, - sql_generation, -) -from src.pipelines.retrieval import historical_question, retrieval +from src.pipelines import generation, indexing, retrieval from src.providers import generate_components from src.web.v1.services.ask import ( AskRequest, @@ -41,22 +34,22 @@ def ask_service(): return AskService( { - "intent_classification": intent_classification.IntentClassification( + "intent_classification": generation.IntentClassification( **pipe_components["intent_classification"], ), - "data_assistance": data_assistance.DataAssistance( + "data_assistance": generation.DataAssistance( **pipe_components["data_assistance"], ), "retrieval": retrieval.Retrieval( **pipe_components["db_schema_retrieval"], ), - "historical_question": historical_question.HistoricalQuestion( + "historical_question_retrieval": retrieval.HistoricalQuestionRetrieval( **pipe_components["historical_question_retrieval"], ), - "sql_generation": sql_generation.SQLGeneration( + "sql_generation": generation.SQLGeneration( **pipe_components["sql_generation"], ), - "sql_correction": sql_correction.SQLCorrection( + "sql_correction": generation.SQLCorrection( **pipe_components["sql_correction"], ), } @@ -167,7 +160,7 @@ def _ask_service_ttl_mock(query: str): f"mock document 2 for {query}", ] ), - "historical_question": HistoricalQuestionMock(), + "historical_question_retrieval": HistoricalQuestionMock(), "sql_generation": GenerationMock( valid=[{"sql": "select count(*) from books"}], ), diff --git a/wren-ai-service/tests/pytest/services/test_ask_details.py b/wren-ai-service/tests/pytest/services/test_ask_details.py index c5b251ebc..ca97fc27a 100644 --- a/wren-ai-service/tests/pytest/services/test_ask_details.py +++ b/wren-ai-service/tests/pytest/services/test_ask_details.py @@ -3,7 +3,7 @@ import pytest from src.config import settings -from src.pipelines.generation import sql_breakdown, sql_summary +from src.pipelines import generation from src.providers import generate_components from src.web.v1.services.ask_details import ( AskDetailsRequest, @@ -17,10 +17,10 @@ def ask_details_service(): pipe_components = generate_components(settings.components) return AskDetailsService( { - "sql_breakdown": sql_breakdown.SQLBreakdown( + "sql_breakdown": generation.SQLBreakdown( **pipe_components["sql_breakdown"], ), - "sql_summary": sql_summary.SQLSummary( + "sql_summary": generation.SQLSummary( **pipe_components["sql_summary"], ), } diff --git a/wren-ai-service/tests/pytest/services/test_sql_pairs_preparation.py b/wren-ai-service/tests/pytest/services/test_sql_pairs_preparation.py new file mode 100644 index 000000000..40ef4387c --- /dev/null +++ b/wren-ai-service/tests/pytest/services/test_sql_pairs_preparation.py @@ -0,0 +1,164 @@ +import uuid + +import pytest + +from src.config import settings +from src.core.provider import DocumentStoreProvider +from src.globals import create_service_container +from src.providers 
import generate_components +from src.web.v1.services.sql_pairs_preparation import ( + DeleteSqlPairsRequest, + SqlPair, + SqlPairsPreparationRequest, + SqlPairsPreparationService, + SqlPairsPreparationStatusRequest, +) + + +@pytest.fixture +def sql_pairs_preparation_service(): + pipe_components = generate_components(settings.components) + service_container = create_service_container(pipe_components, settings) + + document_store_provider: DocumentStoreProvider = pipe_components[ + "sql_pairs_preparation" + ]["document_store_provider"] + document_store_provider.get_store( + dataset_name="sql_pairs", + recreate_index=True, + ) + + return service_container.sql_pairs_preparation_service + + +@pytest.fixture +def service_metadata(): + return { + "pipes_metadata": { + "mock": { + "generation_model": "mock-llm-model", + "generation_model_kwargs": {}, + "embedding_model": "mock-embedding-model", + "embedding_model_dim": 768, + }, + }, + "service_version": "0.8.0-mock", + } + + +@pytest.mark.asyncio +async def test_sql_pairs_preparation( + sql_pairs_preparation_service: SqlPairsPreparationService, + service_metadata: dict, +): + request = SqlPairsPreparationRequest( + sql_pairs=[ + SqlPair(sql="SELECT * FROM book", id="1"), + SqlPair(sql="SELECT * FROM author", id="2"), + ], + project_id="fake-id", + ) + request.query_id = str(uuid.uuid4()) + await sql_pairs_preparation_service.prepare_sql_pairs( + request, + service_metadata=service_metadata, + ) + + sql_pairs_preparation_response = ( + sql_pairs_preparation_service.get_prepare_sql_pairs_status( + SqlPairsPreparationStatusRequest(sql_pairs_preparation_id=request.query_id) + ) + ) + while ( + sql_pairs_preparation_response.status != "finished" + and sql_pairs_preparation_response.status != "failed" + ): + sql_pairs_preparation_response = ( + sql_pairs_preparation_service.get_prepare_sql_pairs_status( + SqlPairsPreparationStatusRequest( + sql_pairs_preparation_id=request.query_id + ) + ) + ) + + assert sql_pairs_preparation_response.status == "finished" + pipe_components = generate_components(settings.components) + document_store_provider: DocumentStoreProvider = pipe_components[ + "sql_pairs_preparation" + ]["document_store_provider"] + store = document_store_provider.get_store( + dataset_name="sql_pairs", + ) + assert await store.count_documents() == 2 + + +@pytest.mark.asyncio +async def test_sql_pairs_deletion( + sql_pairs_preparation_service: SqlPairsPreparationService, + service_metadata: dict, +): + request = SqlPairsPreparationRequest( + sql_pairs=[ + SqlPair(sql="SELECT * FROM book", id="1"), + SqlPair(sql="SELECT * FROM author", id="2"), + ], + project_id="fake-id", + ) + request.query_id = str(uuid.uuid4()) + await sql_pairs_preparation_service.prepare_sql_pairs( + request, + service_metadata=service_metadata, + ) + + sql_pairs_preparation_response = ( + sql_pairs_preparation_service.get_prepare_sql_pairs_status( + SqlPairsPreparationStatusRequest(sql_pairs_preparation_id=request.query_id) + ) + ) + while ( + sql_pairs_preparation_response.status != "finished" + and sql_pairs_preparation_response.status != "failed" + ): + sql_pairs_preparation_response = ( + sql_pairs_preparation_service.get_prepare_sql_pairs_status( + SqlPairsPreparationStatusRequest( + sql_pairs_preparation_id=request.query_id + ) + ) + ) + + assert sql_pairs_preparation_response.status == "finished" + + deletion_request = DeleteSqlPairsRequest( + ids=["1", "2"], + project_id="fake-id", + ) + deletion_request.query_id = request.query_id + await 
sql_pairs_preparation_service.delete_sql_pairs(deletion_request) + + sql_pairs_preparation_response = ( + sql_pairs_preparation_service.get_prepare_sql_pairs_status( + SqlPairsPreparationStatusRequest(sql_pairs_preparation_id=request.query_id) + ) + ) + while ( + sql_pairs_preparation_response.status != "finished" + and sql_pairs_preparation_response.status != "failed" + ): + sql_pairs_preparation_response = ( + sql_pairs_preparation_service.get_prepare_sql_pairs_status( + SqlPairsPreparationStatusRequest( + sql_pairs_preparation_id=request.query_id + ) + ) + ) + + assert sql_pairs_preparation_response.status == "finished" + pipe_components = generate_components(settings.components) + document_store_provider: DocumentStoreProvider = pipe_components[ + "sql_pairs_preparation" + ]["document_store_provider"] + store = document_store_provider.get_store( + dataset_name="sql_pairs", + ) + assert await store.count_documents() == 0 diff --git a/wren-ai-service/tests/pytest/test_config.py b/wren-ai-service/tests/pytest/test_config.py index 7b12acbed..7b522e0dd 100644 --- a/wren-ai-service/tests/pytest/test_config.py +++ b/wren-ai-service/tests/pytest/test_config.py @@ -22,7 +22,6 @@ def test_settings_default_values(): assert settings.langfuse_enable is True assert settings.logging_level == "INFO" - assert settings.enable_timer is False assert settings.development is False assert settings.config_path == "config.yaml" diff --git a/wren-ai-service/tests/pytest/test_main.py b/wren-ai-service/tests/pytest/test_main.py index 4bcead5a9..cb82fa9bf 100644 --- a/wren-ai-service/tests/pytest/test_main.py +++ b/wren-ai-service/tests/pytest/test_main.py @@ -23,7 +23,7 @@ def app(): } -def test_semantics_preparations(app): +def test_semantics_preparation(app): with TestClient(app) as client: semantics_preperation_id = GLOBAL_DATA["semantics_preperation_id"] diff --git a/wren-ai-service/tests/pytest/test_utils.py b/wren-ai-service/tests/pytest/test_utils.py index 422cfef09..32842f8fc 100644 --- a/wren-ai-service/tests/pytest/test_utils.py +++ b/wren-ai-service/tests/pytest/test_utils.py @@ -65,7 +65,6 @@ class Request: project_id = "mock-project-id" thread_id = "mock-thread-id" mdl_hash = "mock-mdl-hash" - user_id = "mock-user-id" query = "mock-user-query" @utils.trace_metadata @@ -75,7 +74,7 @@ async def my_function(_: str, b: Request, **kwargs): asyncio.run(my_function("", Request(), service_metadata=asdict(service_metadata))) function.assert_called_once_with( - user_id="mock-user-id", + user_id=None, session_id="mock-thread-id", release=service_metadata.service_version, metadata={ diff --git a/wren-ai-service/tools/config/config.example.yaml b/wren-ai-service/tools/config/config.example.yaml index 2ba4c8ecc..1d6dbb184 100644 --- a/wren-ai-service/tools/config/config.example.yaml +++ b/wren-ai-service/tools/config/config.example.yaml @@ -122,10 +122,25 @@ pipes: document_store: qdrant - name: data_assistance llm: litellm_llm.gpt-4o-mini-2024-07-18 + - name: sql_pairs_preparation + document_store: qdrant + embedder: openai_embedder.text-embedding-3-large + llm: litellm_llm.gpt-4o-mini-2024-07-18 + - name: sql_pairs_deletion + document_store: qdrant + embedder: openai_embedder.text-embedding-3-large + - name: sql_pairs_retrieval + document_store: qdrant + embedder: openai_embedder.text-embedding-3-large + llm: litellm_llm.gpt-4o-mini-2024-07-18 - name: preprocess_sql_data llm: litellm_llm.gpt-4o-mini-2024-07-18 - name: sql_executor engine: wren_ui + - name: chart_generation + llm: 
litellm_llm.gpt-4o-mini-2024-07-18 + - name: chart_adjustment + llm: litellm_llm.gpt-4o-mini-2024-07-18 --- settings: host: 127.0.0.1 @@ -138,6 +153,5 @@ settings: query_cache_ttl: 3600 langfuse_host: https://cloud.langfuse.com langfuse_enable: true - enable_timer: false logging_level: DEBUG development: true diff --git a/wren-ai-service/tools/config/config.full.yaml b/wren-ai-service/tools/config/config.full.yaml index 742cafaaf..61bbc3965 100644 --- a/wren-ai-service/tools/config/config.full.yaml +++ b/wren-ai-service/tools/config/config.full.yaml @@ -141,9 +141,24 @@ pipes: document_store: qdrant - name: data_assistance llm: litellm_llm.gpt-4o-mini-2024-07-18 + - name: sql_pairs_preparation + document_store: qdrant + embedder: openai_embedder.text-embedding-3-large + llm: litellm_llm.gpt-4o-mini-2024-07-18 + - name: sql_pairs_deletion + document_store: qdrant + embedder: openai_embedder.text-embedding-3-large + - name: sql_pairs_retrieval + document_store: qdrant + embedder: openai_embedder.text-embedding-3-large + llm: litellm_llm.gpt-4o-mini-2024-07-18 + - name: preprocess_sql_data + llm: litellm_llm.gpt-4o-mini-2024-07-18 - name: sql_executor engine: wren_ui - - name: preprocess_sql_data + - name: chart_generation + llm: litellm_llm.gpt-4o-mini-2024-07-18 + - name: chart_adjustment llm: litellm_llm.gpt-4o-mini-2024-07-18 --- @@ -158,6 +173,5 @@ settings: query_cache_ttl: 3600 langfuse_host: https://cloud.langfuse.com langfuse_enable: true - enable_timer: false logging_level: INFO development: false diff --git a/wren-ui/README.md b/wren-ui/README.md index 51317147c..8b6d45069 100644 --- a/wren-ui/README.md +++ b/wren-ui/README.md @@ -138,7 +138,6 @@ wren-ai-service: EMBEDDER_OPENAI_API_KEY: ${EMBEDDER_OPENAI_API_KEY} LLM_AZURE_OPENAI_API_KEY: ${LLM_AZURE_OPENAI_API_KEY} EMBEDDER_AZURE_OPENAI_API_KEY: ${EMBEDDER_AZURE_OPENAI_API_KEY} - ENABLE_TIMER: ${AI_SERVICE_ENABLE_TIMER} LOGGING_LEVEL: ${AI_SERVICE_LOGGING_LEVEL} networks: - wren
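
As a closing illustration, here is a condensed sketch of the retrieval-to-generation wiring that the ask.py hunk above introduces. The helper below is hypothetical, not the service's actual control flow; intent classification, history handling, and error paths are omitted.

    # Hypothetical helper illustrating the wiring added to AskService.ask:
    # retrieved SQL pairs are passed to text-to-SQL generation as samples.
    async def generate_sql_with_samples(pipelines, query: str, project_id: str):
        sql_samples = (
            await pipelines["sql_pairs_retrieval"].run(query=query, id=project_id)
        )["formatted_output"].get("documents", [])

        # Other arguments to sql_generation (contexts, exclude, configuration)
        # are omitted here for brevity.
        return await pipelines["sql_generation"].run(
            query=query,
            project_id=project_id,
            sql_samples=sql_samples,
        )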