Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: update builtin tool provider methods to use session management #11938

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions api/controllers/console/workspace/tool_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
from flask import send_file
from flask_login import current_user
from flask_restful import Resource, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden

from configs import dify_config
from controllers.console import api
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from libs.helper import alphanumeric, uuid_value
from libs.login import login_required
from services.tools.api_tools_manage_service import ApiToolManageService
Expand Down Expand Up @@ -91,26 +93,28 @@ def post(self, provider):

args = parser.parse_args()

return BuiltinToolManageService.update_builtin_tool_provider(
user_id,
tenant_id,
provider,
args["credentials"],
)
with Session(db.engine) as session:
result = BuiltinToolManageService.update_builtin_tool_provider(
session=session,
user_id=user_id,
tenant_id=tenant_id,
provider_name=provider,
credentials=args["credentials"],
)
session.commit()
return result


class ToolBuiltinProviderGetCredentialsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
user_id = current_user.id
tenant_id = current_user.current_tenant_id

return BuiltinToolManageService.get_builtin_tool_provider_credentials(
user_id,
tenant_id,
provider,
tenant_id=tenant_id,
provider_name=provider,
)


Expand Down
33 changes: 16 additions & 17 deletions api/services/tools/builtin_tools_manage_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
import logging
from pathlib import Path

from sqlalchemy import select
from sqlalchemy.orm import Session

from configs import dify_config
from core.helper.position_helper import is_filtered
from core.model_runtime.utils.encoders import jsonable_encoder
Expand Down Expand Up @@ -32,7 +35,7 @@ def list_builtin_tool_provider_tools(user_id: str, tenant_id: str, provider: str
tenant_id=tenant_id, provider_controller=provider_controller
)
# check if user has added the provider
builtin_provider: BuiltinToolProvider = (
builtin_provider = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
Expand Down Expand Up @@ -71,19 +74,18 @@ def list_builtin_provider_credentials_schema(provider_name):
return jsonable_encoder([v for _, v in (provider.credentials_schema or {}).items()])

@staticmethod
def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str, credentials: dict):
def update_builtin_tool_provider(
session: Session, user_id: str, tenant_id: str, provider_name: str, credentials: dict
):
"""
update builtin tool provider
"""
# get if the provider exists
provider: BuiltinToolProvider = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider_name,
)
.first()
stmt = select(BuiltinToolProvider).where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider_name,
)
provider = session.scalar(stmt)

try:
# get provider
Expand Down Expand Up @@ -115,29 +117,26 @@ def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: st
encrypted_credentials=json.dumps(credentials),
)

db.session.add(provider)
db.session.commit()
session.add(provider)

else:
provider.encrypted_credentials = json.dumps(credentials)
db.session.add(provider)
db.session.commit()

# delete cache
tool_configuration.delete_tool_credentials_cache()

return {"result": "success"}

@staticmethod
def get_builtin_tool_provider_credentials(user_id: str, tenant_id: str, provider: str):
def get_builtin_tool_provider_credentials(tenant_id: str, provider_name: str):
"""
get builtin tool provider credentials
"""
provider: BuiltinToolProvider = (
provider = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
BuiltinToolProvider.provider == provider_name,
)
.first()
)
Expand All @@ -156,7 +155,7 @@ def delete_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: st
"""
delete tool provider
"""
provider: BuiltinToolProvider = (
provider = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
Expand Down
Loading