Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: centralize global variable management #3284

Merged
merged 16 commits into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 38 additions & 51 deletions src/backend/base/langflow/api/v1/variable.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from datetime import datetime, timezone
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException
from sqlmodel import Session, select
from sqlalchemy.exc import NoResultFound
from sqlmodel import Session

from langflow.services.auth import utils as auth_utils
from langflow.services.auth.utils import get_current_active_user
from langflow.services.database.models.user.model import User
from langflow.services.database.models.variable import Variable, VariableCreate, VariableRead, VariableUpdate
from langflow.services.deps import get_session, get_settings_service
from langflow.services.database.models.variable import VariableCreate, VariableRead, VariableUpdate
from langflow.services.deps import get_session, get_settings_service, get_variable_service
from langflow.services.variable.service import GENERIC_TYPE, DatabaseVariableService

router = APIRouter(prefix="/variables", tags=["Variables"])

Expand All @@ -20,36 +20,30 @@ def create_variable(
variable: VariableCreate,
current_user: User = Depends(get_current_active_user),
settings_service=Depends(get_settings_service),
variable_service: DatabaseVariableService = Depends(get_variable_service),
):
"""Create a new variable."""
try:
# check if variable name already exists
variable_exists = session.exec(
select(Variable).where(
Variable.name == variable.name,
Variable.user_id == current_user.id,
)
).first()
if variable_exists:
raise HTTPException(status_code=400, detail="Variable name already exists")

variable_dict = variable.model_dump()
variable_dict["user_id"] = current_user.id

db_variable = Variable.model_validate(variable_dict)
if not db_variable.name and not db_variable.value:
if not variable.name and not variable.value:
raise HTTPException(status_code=400, detail="Variable name and value cannot be empty")
elif not db_variable.name:

if not variable.name:
raise HTTPException(status_code=400, detail="Variable name cannot be empty")
elif not db_variable.value:

if not variable.value:
raise HTTPException(status_code=400, detail="Variable value cannot be empty")
encrypted = auth_utils.encrypt_api_key(db_variable.value, settings_service=settings_service)
db_variable.value = encrypted
db_variable.user_id = current_user.id
session.add(db_variable)
session.commit()
session.refresh(db_variable)
return db_variable

if variable.name in variable_service.list_variables(user_id=current_user.id, session=session):
raise HTTPException(status_code=400, detail="Variable name already exists")

return variable_service.create_variable(
user_id=current_user.id,
name=variable.name,
value=variable.value,
default_fields=variable.default_fields or [],
_type=variable.type or GENERIC_TYPE,
session=session,
)
except Exception as e:
if isinstance(e, HTTPException):
raise e
Expand All @@ -61,11 +55,12 @@ def read_variables(
*,
session: Session = Depends(get_session),
current_user: User = Depends(get_current_active_user),
variable_service: DatabaseVariableService = Depends(get_variable_service),
):
"""Read all variables."""
try:
variables = session.exec(select(Variable).where(Variable.user_id == current_user.id)).all()
return variables
return variable_service.get_all(user_id=current_user.id, session=session)

except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e

Expand All @@ -77,22 +72,19 @@ def update_variable(
variable_id: UUID,
variable: VariableUpdate,
current_user: User = Depends(get_current_active_user),
variable_service: DatabaseVariableService = Depends(get_variable_service),
):
"""Update a variable."""
try:
db_variable = session.exec(
select(Variable).where(Variable.id == variable_id, Variable.user_id == current_user.id)
).first()
if not db_variable:
raise HTTPException(status_code=404, detail="Variable not found")

variable_data = variable.model_dump(exclude_unset=True)
for key, value in variable_data.items():
setattr(db_variable, key, value)
db_variable.updated_at = datetime.now(timezone.utc)
session.commit()
session.refresh(db_variable)
return db_variable
return variable_service.update_variable_fields(
user_id=current_user.id,
variable_id=variable_id,
variable=variable,
session=session,
)
except NoResultFound:
raise HTTPException(status_code=404, detail="Variable not found")

except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e

Expand All @@ -103,15 +95,10 @@ def delete_variable(
session: Session = Depends(get_session),
variable_id: UUID,
current_user: User = Depends(get_current_active_user),
variable_service: DatabaseVariableService = Depends(get_variable_service),
):
"""Delete a variable."""
try:
db_variable = session.exec(
select(Variable).where(Variable.id == variable_id, Variable.user_id == current_user.id)
).first()
if not db_variable:
raise HTTPException(status_code=404, detail="Variable not found")
session.delete(db_variable)
session.commit()
variable_service.delete_variable(user_id=current_user.id, variable_id=variable_id, session=session)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e
47 changes: 40 additions & 7 deletions src/backend/base/langflow/services/variable/service.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Optional, Union
from uuid import UUID

Expand All @@ -8,7 +9,7 @@

from langflow.services.auth import utils as auth_utils
from langflow.services.base import Service
from langflow.services.database.models.variable.model import Variable, VariableCreate
from langflow.services.database.models.variable.model import Variable, VariableCreate, VariableUpdate
from langflow.services.deps import get_session
from langflow.services.variable.base import VariableService

Expand Down Expand Up @@ -76,21 +77,25 @@ def get_variable(
# credential = session.query(Variable).filter(Variable.user_id == user_id, Variable.name == name).first()
variable = session.exec(select(Variable).where(Variable.user_id == user_id, Variable.name == name)).first()

if not variable or not variable.value:
raise ValueError(f"{name} variable not found.")

if variable.type == CREDENTIAL_TYPE and field == "session_id": # type: ignore
raise TypeError(
f"variable {name} of type 'Credential' cannot be used in a Session ID field "
"because its purpose is to prevent the exposure of values."
)

# we decrypt the value
if not variable or not variable.value:
raise ValueError(f"{name} variable not found.")
decrypted = auth_utils.decrypt_api_key(variable.value, settings_service=self.settings_service)
return decrypted

def get_all(self, user_id: Union[UUID, str], session: Session = Depends(get_session)) -> list[Optional[Variable]]:
return list(session.exec(select(Variable).where(Variable.user_id == user_id)).all())

def list_variables(self, user_id: Union[UUID, str], session: Session = Depends(get_session)) -> list[Optional[str]]:
variables = session.exec(select(Variable).where(Variable.user_id == user_id)).all()
return [variable.name for variable in variables]
variables = self.get_all(user_id=user_id, session=session)
return [variable.name for variable in variables if variable]

def update_variable(
self,
Expand All @@ -109,13 +114,41 @@ def update_variable(
session.refresh(variable)
return variable

def update_variable_fields(
self,
user_id: Union[UUID, str],
variable_id: Union[UUID, str],
variable: VariableUpdate,
session: Session = Depends(get_session),
):
query = select(Variable).where(Variable.id == variable_id, Variable.user_id == user_id)
db_variable = session.exec(query).one()

variable_data = variable.model_dump(exclude_unset=True)
for key, value in variable_data.items():
setattr(db_variable, key, value)
db_variable.updated_at = datetime.now(timezone.utc)
encrypted = auth_utils.encrypt_api_key(db_variable.value, settings_service=self.settings_service)
variable.value = encrypted

session.add(db_variable)
session.commit()
session.refresh(db_variable)
return db_variable

def delete_variable(
self,
user_id: Union[UUID, str],
name: str,
name: str | None = None,
variable_id: UUID | None = None,
session: Session = Depends(get_session),
):
variable = session.exec(select(Variable).where(Variable.user_id == user_id, Variable.name == name)).first()
stmt = select(Variable).where(Variable.user_id == user_id)
if name:
stmt = stmt.where(Variable.name == name)
if variable_id:
stmt = stmt.where(Variable.id == variable_id)
variable = session.exec(stmt).first()
if not variable:
raise ValueError(f"{name} variable not found.")
session.delete(variable)
Expand Down
Loading
Loading