Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Add Cascade Delete Function for Transactions and Builds Associated with Flows #3848

Merged
merged 7 commits into from
Sep 18, 2024
12 changes: 12 additions & 0 deletions src/backend/base/langflow/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
from typing import TYPE_CHECKING, Any

from fastapi import HTTPException
from langflow.services.database.models.transactions.model import TransactionTable
from langflow.services.database.models.vertex_builds.model import VertexBuildTable
from sqlalchemy import delete
from sqlmodel import Session

from langflow.graph.graph.base import Graph
Expand Down Expand Up @@ -241,3 +244,12 @@ def parse_value(value: Any, input_type: str) -> Any:
return float(value) if value is not None else None
else:
return value


async def cascade_delete_flow(session: Session, flow: Flow):
try:
session.exec(delete(TransactionTable).where(TransactionTable.flow_id == flow.id)) # type: ignore
session.exec(delete(VertexBuildTable).where(VertexBuildTable.flow_id == flow.id)) # type: ignore
session.exec(delete(Flow).where(Flow.id == flow.id)) # type: ignore
except Exception as e:
raise RuntimeError(f"Unable to cascade delete flow: ${flow.id}", e)
8 changes: 4 additions & 4 deletions src/backend/base/langflow/api/v1/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
from loguru import logger
from sqlmodel import Session, and_, col, select

from langflow.api.utils import remove_api_keys, validate_is_component
from langflow.api.utils import cascade_delete_flow, remove_api_keys, validate_is_component
from langflow.api.v1.schemas import FlowListCreate
from langflow.initial_setup.setup import STARTER_FOLDER_NAME
from langflow.services.auth.utils import get_current_active_user
from langflow.services.database.models.flow import Flow, FlowCreate, FlowRead, FlowUpdate
from langflow.services.database.models.flow.utils import delete_flow_by_id, get_webhook_component_in_flow
from langflow.services.database.models.flow.utils import get_webhook_component_in_flow
from langflow.services.database.models.folder.constants import DEFAULT_FOLDER_NAME
from langflow.services.database.models.folder.model import Folder
from langflow.services.database.models.transactions.crud import get_transactions_by_flow_id
Expand Down Expand Up @@ -251,7 +251,7 @@ def update_flow(


@router.delete("/{flow_id}", status_code=200)
def delete_flow(
async def delete_flow(
*,
session: Session = Depends(get_session),
flow_id: UUID,
Expand All @@ -267,7 +267,7 @@ def delete_flow(
)
if not flow:
raise HTTPException(status_code=404, detail="Flow not found")
delete_flow_by_id(str(flow_id), session)
await cascade_delete_flow(session, flow)
session.commit()
return {"message": "Flow deleted successfully"}

Expand Down
13 changes: 8 additions & 5 deletions src/backend/base/langflow/api/v1/folders.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from langflow.api.utils import cascade_delete_flow
import orjson
from fastapi import APIRouter, Depends, File, HTTPException, Response, UploadFile, status
from sqlalchemy import or_, update
Expand Down Expand Up @@ -171,22 +172,24 @@ def update_folder(


@router.delete("/{folder_id}", status_code=204)
def delete_folder(
async def delete_folder(
*,
session: Session = Depends(get_session),
folder_id: str,
current_user: User = Depends(get_current_active_user),
):
try:
flows = session.exec(select(Flow).where(Flow.folder_id == folder_id, Folder.user_id == current_user.id)).all()
if len(flows) > 0:
for flow in flows:
await cascade_delete_flow(session, flow)

folder = session.exec(select(Folder).where(Folder.id == folder_id, Folder.user_id == current_user.id)).first()
if not folder:
raise HTTPException(status_code=404, detail="Folder not found")
session.delete(folder)
session.commit()
flows = session.exec(select(Flow).where(Flow.folder_id == folder_id, Folder.user_id == current_user.id)).all()
for flow in flows:
session.delete(flow)
session.commit()

return Response(status_code=status.HTTP_204_NO_CONTENT)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
Expand Down
12 changes: 0 additions & 12 deletions src/backend/base/langflow/services/database/models/flow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,10 @@
from fastapi import Depends
from langflow.utils.version import get_version_info
from sqlmodel import Session
from sqlalchemy import delete

from langflow.services.deps import get_session

from .model import Flow
from .. import TransactionTable, MessageTable
from loguru import logger


def get_flow_by_id(session: Session = Depends(get_session), flow_id: Optional[str] = None) -> Flow | None:
Expand All @@ -21,15 +18,6 @@ def get_flow_by_id(session: Session = Depends(get_session), flow_id: Optional[st
return session.get(Flow, flow_id)


def delete_flow_by_id(flow_id: str, session: Session) -> None:
"""Delete flow by id."""
# Manually delete flow, transactions and messages because foreign key constraints might be disabled
session.exec(delete(Flow).where(Flow.id == flow_id)) # type: ignore
session.exec(delete(TransactionTable).where(TransactionTable.flow_id == flow_id)) # type: ignore
session.exec(delete(MessageTable).where(MessageTable.flow_id == flow_id)) # type: ignore
logger.info(f"Deleted flow {flow_id}")


def get_webhook_component_in_flow(flow_data: dict):
"""Get webhook component in flow data."""

Expand Down
70 changes: 70 additions & 0 deletions src/backend/tests/unit/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections import namedtuple
from uuid import UUID, uuid4

from langflow.services.database.models.folder.model import FolderCreate
import orjson
import pytest
from fastapi.testclient import TestClient
Expand Down Expand Up @@ -196,6 +197,75 @@ async def test_delete_flows_with_transaction_and_build(
assert response.json() == {"vertex_builds": {}}


@pytest.mark.asyncio
async def test_delete_folder_with_flows_with_transaction_and_build(
client: TestClient, json_flow: str, active_user, logged_in_headers
):
# Create a new folder
folder_name = f"Test Folder {uuid4()}"
folder = FolderCreate(name=folder_name, description="Test folder description", components_list=[], flows_list=[])

response = client.post("/api/v1/folders/", json=folder.model_dump(), headers=logged_in_headers)
assert response.status_code == 201, f"Expected status code 201, but got {response.status_code}"

created_folder = response.json()
folder_id = created_folder["id"]

# Create ten flows
number_of_flows = 10
flows = [FlowCreate(name=f"Flow {i}", description="description", data={}) for i in range(number_of_flows)]
flow_ids = []
for flow in flows:
flow.folder_id = folder_id
response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers)
assert response.status_code == 201
flow_ids.append(response.json()["id"])

# Create a transaction for each flow
for flow_id in flow_ids:
VertexTuple = namedtuple("VertexTuple", ["id"])

await log_transaction(
str(flow_id), source=VertexTuple(id="vid"), target=VertexTuple(id="tid"), status="success"
)

# Create a build for each flow
for flow_id in flow_ids:
build = {
"valid": True,
"params": {},
"data": ResultDataResponse(),
"artifacts": {},
"vertex_id": "vid",
"flow_id": flow_id,
}
log_vertex_build(
flow_id=build["flow_id"],
vertex_id=build["vertex_id"],
valid=build["valid"],
params=build["params"],
data=build["data"],
artifacts=build.get("artifacts"),
)

response = client.request("DELETE", f"api/v1/folders/{folder_id}", headers=logged_in_headers)
assert response.status_code == 204

for flow_id in flow_ids:
response = client.request(
"GET", "api/v1/monitor/transactions", params={"flow_id": flow_id}, headers=logged_in_headers
)
assert response.status_code == 200
assert response.json() == []

for flow_id in flow_ids:
response = client.request(
"GET", "api/v1/monitor/builds", params={"flow_id": flow_id}, headers=logged_in_headers
)
assert response.status_code == 200
assert response.json() == {"vertex_builds": {}}


def test_create_flows(client: TestClient, session: Session, json_flow: str, logged_in_headers):
flow = orjson.loads(json_flow)
data = flow["data"]
Expand Down
Loading