Skip to content

Commit

Permalink
chore(wren-ai-service): remove pipeline visualization (#1013)
Browse files Browse the repository at this point in the history
  • Loading branch information
cyyeh authored Dec 18, 2024
1 parent 95a5a8f commit a133518
Show file tree
Hide file tree
Showing 26 changed files with 360 additions and 957 deletions.
738 changes: 359 additions & 379 deletions wren-ai-service/poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion wren-ai-service/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ tqdm = "^4.66.4"
numpy = "^1.26.4"
sqlparse = "^0.5.0"
orjson = "^3.10.3"
sf-hamilton = {version = "^1.69.0", extras = ["visualization"]}
sf-hamilton = {version = "^1.69.0"}
aiohttp = {extras = ["speedups"], version = "^3.10.2"}
ollama-haystack = "^0.0.6"
langfuse = "^2.43.3"
Expand Down
1 change: 0 additions & 1 deletion wren-ai-service/src/pipelines/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,6 @@ def dry_run_pipeline(pipeline_cls: BasicPipeline, pipeline_name: str, **kwargs):
pipeline = pipeline_cls(**pipe_components[pipeline_name])
init_langfuse()

pipeline.visualize(**kwargs)
async_validate(lambda: pipeline.run(**kwargs))

langfuse_context.flush()
31 changes: 0 additions & 31 deletions wren-ai-service/src/pipelines/generation/chart_adjustment.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging
import sys
from pathlib import Path
from typing import Any, Dict

import orjson
Expand Down Expand Up @@ -229,36 +228,6 @@ def __init__(
AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
)

def visualize(
self,
query: str,
sql: str,
adjustment_option: ChartAdjustmentOption,
chart_schema: dict,
data: dict,
language: str,
) -> None:
destination = "outputs/pipelines/generation"
if not Path(destination).exists():
Path(destination).mkdir(parents=True, exist_ok=True)

self._pipe.visualize_execution(
["post_process"],
output_file_path=f"{destination}/chart_adjustment.dot",
inputs={
"query": query,
"sql": sql,
"adjustment_option": adjustment_option,
"chart_schema": chart_schema,
"data": data,
"language": language,
**self._components,
**self._configs,
},
show_legend=True,
orient="LR",
)

@async_timer
@observe(name="Chart Adjustment")
async def run(
Expand Down
27 changes: 0 additions & 27 deletions wren-ai-service/src/pipelines/generation/chart_generation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging
import sys
from pathlib import Path
from typing import Any, Dict

import orjson
Expand Down Expand Up @@ -205,32 +204,6 @@ def __init__(
AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
)

def visualize(
self,
query: str,
sql: str,
data: dict,
language: str,
) -> None:
destination = "outputs/pipelines/generation"
if not Path(destination).exists():
Path(destination).mkdir(parents=True, exist_ok=True)

self._pipe.visualize_execution(
["post_process"],
output_file_path=f"{destination}/chart_generation.dot",
inputs={
"query": query,
"sql": sql,
"data": data,
"language": language,
**self._components,
**self._configs,
},
show_legend=True,
orient="LR",
)

@async_timer
@observe(name="Chart Generation")
async def run(
Expand Down
28 changes: 0 additions & 28 deletions wren-ai-service/src/pipelines/generation/data_assistance.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import logging
import sys
from pathlib import Path
from typing import Any, Optional

from hamilton import base
Expand Down Expand Up @@ -143,33 +142,6 @@ async def _get_streaming_results(query_id):
except TimeoutError:
break

def visualize(
self,
query: str,
db_schemas: list[str],
language: str,
query_id: Optional[str] = None,
history: Optional[AskHistory] = None,
) -> None:
destination = "outputs/pipelines/generation"
if not Path(destination).exists():
Path(destination).mkdir(parents=True, exist_ok=True)

self._pipe.visualize_execution(
["data_assistance"],
output_file_path=f"{destination}/data_assistance.dot",
inputs={
"query": query,
"db_schemas": db_schemas,
"language": language,
"query_id": query_id or "",
"history": history,
**self._components,
},
show_legend=True,
orient="LR",
)

@async_timer
@observe(name="Data Assistance")
async def run(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging
import sys
from pathlib import Path
from typing import Any, List

from hamilton import base
Expand Down Expand Up @@ -203,34 +202,6 @@ def __init__(
AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
)

def visualize(
self,
query: str,
contexts: List[str],
history: AskHistory,
configuration: Configuration = Configuration(),
project_id: str | None = None,
) -> None:
destination = "outputs/pipelines/generation"
if not Path(destination).exists():
Path(destination).mkdir(parents=True, exist_ok=True)

self._pipe.visualize_execution(
["post_process"],
output_file_path=f"{destination}/followup_sql_generation.dot",
inputs={
"query": query,
"documents": contexts,
"history": history,
"project_id": project_id,
"configuration": configuration,
**self._components,
**self._configs,
},
show_legend=True,
orient="LR",
)

@async_timer
@observe(name="Follow-Up SQL Generation")
async def run(
Expand Down
24 changes: 0 additions & 24 deletions wren-ai-service/src/pipelines/generation/intent_classification.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import ast
import logging
import sys
from pathlib import Path
from typing import Any, Literal, Optional

import orjson
Expand Down Expand Up @@ -292,29 +291,6 @@ def __init__(
AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
)

def visualize(
self,
query: str,
id: Optional[str] = None,
history: Optional[AskHistory] = None,
) -> None:
destination = "outputs/pipelines/generation"
if not Path(destination).exists():
Path(destination).mkdir(parents=True, exist_ok=True)

self._pipe.visualize_execution(
["post_process"],
output_file_path=f"{destination}/intent_classification.dot",
inputs={
"query": query,
"id": id or "",
"history": history,
**self._components,
},
show_legend=True,
orient="LR",
)

@async_timer
@observe(name="Intent Classification")
async def run(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import logging
import sys
from datetime import datetime
from pathlib import Path
from typing import Any

import orjson
Expand Down Expand Up @@ -175,38 +174,6 @@ def __init__(
AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
)

def visualize(
self,
mdl: dict,
previous_questions: list[str] = [],
categories: list[str] = [],
language: str = "en",
current_date: str = datetime.now(),
max_questions: int = 5,
max_categories: int = 3,
**_,
) -> None:
destination = "outputs/pipelines/generation"
if not Path(destination).exists():
Path(destination).mkdir(parents=True, exist_ok=True)

self._pipe.visualize_execution(
[self._final],
output_file_path=f"{destination}/question_recommendation.dot",
inputs={
"mdl": mdl,
"previous_questions": previous_questions,
"categories": categories,
"language": language,
"current_date": current_date,
"max_questions": max_questions,
"max_categories": max_categories,
**self._components,
},
show_legend=True,
orient="LR",
)

@observe(name="Question Recommendation")
async def run(
self,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import logging
import sys
from enum import Enum
from pathlib import Path
from typing import Any

import orjson
Expand Down Expand Up @@ -188,27 +187,6 @@ def __init__(
AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
)

def visualize(
self,
mdl: dict,
language: str = "English",
) -> None:
destination = "outputs/pipelines/generation"
if not Path(destination).exists():
Path(destination).mkdir(parents=True, exist_ok=True)

self._pipe.visualize_execution(
[self._final],
output_file_path=f"{destination}/relationship_recommendation.dot",
inputs={
"mdl": mdl,
"language": language,
**self._components,
},
show_legend=True,
orient="LR",
)

@observe(name="Relationship Recommendation")
async def run(
self,
Expand Down
26 changes: 0 additions & 26 deletions wren-ai-service/src/pipelines/generation/semantics_description.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging
import sys
from pathlib import Path
from typing import Any

import orjson
Expand Down Expand Up @@ -206,31 +205,6 @@ def __init__(self, llm_provider: LLMProvider, **_):
AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
)

def visualize(
self,
user_prompt: str,
selected_models: list[str],
mdl: dict,
language: str = "en",
) -> None:
destination = "outputs/pipelines/generation"
if not Path(destination).exists():
Path(destination).mkdir(parents=True, exist_ok=True)

self._pipe.visualize_execution(
[self._final],
output_file_path=f"{destination}/semantics_description.dot",
inputs={
"user_prompt": user_prompt,
"selected_models": selected_models,
"mdl": mdl,
"language": language,
**self._components,
},
show_legend=True,
orient="LR",
)

@observe(name="Semantics Description Generation")
async def run(
self,
Expand Down
28 changes: 0 additions & 28 deletions wren-ai-service/src/pipelines/generation/sql_answer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import logging
import sys
from pathlib import Path
from typing import Any, Optional

from hamilton import base
Expand Down Expand Up @@ -132,33 +131,6 @@ async def _get_streaming_results(query_id):
except TimeoutError:
break

def visualize(
self,
query: str,
sql: str,
sql_data: dict,
language: str,
query_id: Optional[str] = None,
) -> None:
destination = "outputs/pipelines/generation"
if not Path(destination).exists():
Path(destination).mkdir(parents=True, exist_ok=True)

self._pipe.visualize_execution(
["generate_answer"],
output_file_path=f"{destination}/sql_answer.dot",
inputs={
"query": query,
"sql": sql,
"sql_data": sql_data,
"language": language,
"query_id": query_id,
**self._components,
},
show_legend=True,
orient="LR",
)

@async_timer
@observe(name="SQL Answer Generation")
async def run(
Expand Down
Loading

0 comments on commit a133518

Please sign in to comment.