Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/type-error #11240

Merged
merged 2 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from collections.abc import Mapping
from typing import Any

from core.app.app_config.entities import ModelConfigEntity
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.model_providers import model_provider_factory
Expand Down Expand Up @@ -36,7 +39,7 @@ def convert(cls, config: dict) -> ModelConfigEntity:
)

@classmethod
def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, tenant_id: str, config: Mapping[str, Any]) -> tuple[dict, list[str]]:
"""
Validate and set defaults for model config

Expand Down
46 changes: 12 additions & 34 deletions api/core/app/apps/advanced_chat/app_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import logging
import threading
import uuid
from collections.abc import Generator
from typing import Any, Literal, Optional, Union, overload
from collections.abc import Generator, Mapping
from typing import Any, Optional, Union

from flask import Flask, current_app
from pydantic import ValidationError
Expand Down Expand Up @@ -33,37 +33,15 @@


class AdvancedChatAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
stream: Literal[True] = True,
) -> Generator[str, None, None]: ...

@overload
def generate(
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: Literal[False] = False,
) -> dict: ...

def generate(
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: bool = True,
) -> dict[str, Any] | Generator[str, Any, None]:
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str, None, None]:
"""
Generate App response.

Expand Down Expand Up @@ -134,7 +112,7 @@ def generate(
files=file_objs,
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
user_id=user.id,
stream=stream,
stream=streaming,
invoke_from=invoke_from,
extras=extras,
trace_manager=trace_manager,
Expand All @@ -148,12 +126,12 @@ def generate(
invoke_from=invoke_from,
application_generate_entity=application_generate_entity,
conversation=conversation,
stream=stream,
stream=streaming,
)

def single_iteration_generate(
self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, stream: bool = True
) -> dict[str, Any] | Generator[str, Any, None]:
self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, streaming: bool = True
) -> Mapping[str, Any] | Generator[str, None, None]:
"""
Generate App response.

Expand Down Expand Up @@ -182,7 +160,7 @@ def single_iteration_generate(
query="",
files=[],
user_id=user.id,
stream=stream,
stream=streaming,
invoke_from=InvokeFrom.DEBUGGER,
extras={"auto_generate_conversation_name": False},
single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity(
Expand All @@ -197,7 +175,7 @@ def single_iteration_generate(
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
conversation=None,
stream=stream,
stream=streaming,
)

def _generate(
Expand All @@ -209,7 +187,7 @@ def _generate(
application_generate_entity: AdvancedChatAppGenerateEntity,
conversation: Optional[Conversation] = None,
stream: bool = True,
) -> dict[str, Any] | Generator[str, Any, None]:
) -> Mapping[str, Any] | Generator[str, None, None]:
"""
Generate App response.

Expand Down
5 changes: 3 additions & 2 deletions api/core/app/apps/agent_chat/app_config_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import uuid
from typing import Optional
from collections.abc import Mapping
from typing import Any, Optional

from core.agent.entities import AgentEntity
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
Expand Down Expand Up @@ -85,7 +86,7 @@ def get_app_config(
return app_config

@classmethod
def config_validate(cls, tenant_id: str, config: dict) -> dict:
def config_validate(cls, tenant_id: str, config: Mapping[str, Any]) -> dict:
"""
Validate for agent chat app model config

Expand Down
40 changes: 11 additions & 29 deletions api/core/app/apps/agent_chat/app_generator.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import logging
import threading
import uuid
from collections.abc import Generator
from typing import Any, Literal, Union, overload
from collections.abc import Generator, Mapping
from typing import Any, Union

from flask import Flask, current_app
from pydantic import ValidationError
Expand All @@ -28,34 +28,15 @@


class AgentChatAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self,
*,
app_model: App,
user: Union[Account, EndUser],
args: dict,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
stream: Literal[True] = True,
) -> Generator[dict, None, None]: ...

@overload
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: Literal[False] = False,
) -> dict: ...

def generate(
self,
app_model: App,
user: Union[Account, EndUser],
args: Any,
invoke_from: InvokeFrom,
stream: bool = True,
) -> Union[dict, Generator[dict, None, None]]:
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str, None, None]:
"""
Generate App response.

Expand All @@ -65,7 +46,7 @@ def generate(
:param invoke_from: invoke from source
:param stream: is stream
"""
if not stream:
if not streaming:
raise ValueError("Agent Chat App does not support blocking mode")

if not args.get("query"):
Expand Down Expand Up @@ -96,7 +77,8 @@ def generate(

# validate config
override_model_config_dict = AgentChatAppConfigManager.config_validate(
tenant_id=app_model.tenant_id, config=args.get("model_config")
tenant_id=app_model.tenant_id,
config=args["model_config"],
)

# always enable retriever resource in debugger mode
Expand Down Expand Up @@ -141,7 +123,7 @@ def generate(
files=file_objs,
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
user_id=user.id,
stream=stream,
stream=streaming,
invoke_from=invoke_from,
extras=extras,
call_depth=0,
Expand Down Expand Up @@ -182,7 +164,7 @@ def generate(
conversation=conversation,
message=message,
user=user,
stream=stream,
stream=streaming,
)

return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
Expand Down
8 changes: 5 additions & 3 deletions api/core/app/apps/base_app_generate_response_converter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Generator
from collections.abc import Generator, Mapping
from typing import Any, Union

from core.app.entities.app_invoke_entities import InvokeFrom
Expand All @@ -14,8 +14,10 @@ class AppGenerateResponseConverter(ABC):

@classmethod
def convert(
cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
) -> dict[str, Any] | Generator[str, Any, None]:
cls,
response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]],
invoke_from: InvokeFrom,
) -> Mapping[str, Any] | Generator[str, None, None]:
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_full_response(response)
Expand Down
6 changes: 3 additions & 3 deletions api/core/app/apps/chat/app_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def generate(
user: Union[Account, EndUser],
args: Any,
invoke_from: InvokeFrom,
stream: bool = True,
streaming: bool = True,
) -> Union[dict, Generator[str, None, None]]:
"""
Generate App response.
Expand Down Expand Up @@ -142,7 +142,7 @@ def generate(
invoke_from=invoke_from,
extras=extras,
trace_manager=trace_manager,
stream=stream,
stream=streaming,
)

# init generate records
Expand Down Expand Up @@ -179,7 +179,7 @@ def generate(
conversation=conversation,
message=message,
user=user,
stream=stream,
stream=streaming,
)

return ChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
Expand Down
6 changes: 3 additions & 3 deletions api/core/app/apps/completion/app_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def generate(
) -> dict: ...

def generate(
self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True
self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, streaming: bool = True
) -> Union[dict, Generator[str, None, None]]:
"""
Generate App response.
Expand Down Expand Up @@ -119,7 +119,7 @@ def generate(
query=query,
files=file_objs,
user_id=user.id,
stream=stream,
stream=streaming,
invoke_from=invoke_from,
extras=extras,
trace_manager=trace_manager,
Expand Down Expand Up @@ -158,7 +158,7 @@ def generate(
conversation=conversation,
message=message,
user=user,
stream=stream,
stream=streaming,
)

return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
Expand Down
Loading