From 1569e231a0dc1a75cb072638e68d01f6b615a2a3 Mon Sep 17 00:00:00 2001 From: Zamil Majdy Date: Tue, 5 Nov 2024 18:24:08 +0700 Subject: [PATCH] Merge input & output schema --- .../backend/backend/blocks/basic.py | 58 ++++++++---- .../backend/backend/data/graph.py | 91 +++++++++++-------- .../backend/test/data/test_graph.py | 50 +++++++--- 3 files changed, 129 insertions(+), 70 deletions(-) diff --git a/autogpt_platform/backend/backend/blocks/basic.py b/autogpt_platform/backend/backend/blocks/basic.py index 4acbaf5a58f8..6e8a2906d7f8 100644 --- a/autogpt_platform/backend/backend/blocks/basic.py +++ b/autogpt_platform/backend/backend/blocks/basic.py @@ -148,9 +148,12 @@ class Input(BlockSchema): description="The value to be passed as input.", default=None, ) - description: str = SchemaField( + title: str | None = SchemaField( + description="The title of the input.", default=None, advanced=True + ) + description: str | None = SchemaField( description="The description of the input.", - default="", + default=None, advanced=True, ) placeholder_values: List[Any] = SchemaField( @@ -163,6 +166,16 @@ class Input(BlockSchema): default=False, advanced=True, ) + advanced: bool = SchemaField( + description="Whether to show the input in the advanced section, if the field is not required.", + default=False, + advanced=True, + ) + secret: bool = SchemaField( + description="Whether the input should be treated as a secret.", + default=False, + advanced=True, + ) class Output(BlockSchema): result: Any = SchemaField(description="The value passed as input.") @@ -195,6 +208,7 @@ def __init__(self): ], categories={BlockCategory.INPUT, BlockCategory.BASIC}, block_type=BlockType.INPUT, + static_output=True, ) def run(self, input_data: Input, **kwargs) -> BlockOutput: @@ -205,28 +219,25 @@ class AgentOutputBlock(Block): """ Records the output of the graph for users to see. - Attributes: - recorded_value: The value to be recorded as output. - name: The name of the output. - description: The description of the output. - fmt_string: The format string to be used to format the recorded_value. - - Outputs: - output: The formatted recorded_value if fmt_string is provided and the recorded_value - can be formatted, otherwise the raw recorded_value. - Behavior: - If fmt_string is provided and the recorded_value is of a type that can be formatted, - the block attempts to format the recorded_value using the fmt_string. - If formatting fails or no fmt_string is provided, the raw recorded_value is output. + If `format` is provided and the `value` is of a type that can be formatted, + the block attempts to format the recorded_value using the `format`. + If formatting fails or no `format` is provided, the raw `value` is output. """ class Input(BlockSchema): - value: Any = SchemaField(description="The value to be recorded as output.") + value: Any = SchemaField( + description="The value to be recorded as output.", + default=None, + advanced=False, + ) name: str = SchemaField(description="The name of the output.") - description: str = SchemaField( + title: str | None = SchemaField( + description="The title of the input.", default=None, advanced=True + ) + description: str | None = SchemaField( description="The description of the output.", - default="", + default=None, advanced=True, ) format: str = SchemaField( @@ -234,6 +245,16 @@ class Input(BlockSchema): default="", advanced=True, ) + advanced: bool = SchemaField( + description="Whether to treat the output as advanced.", + default=False, + advanced=True, + ) + secret: bool = SchemaField( + description="Whether the output should be treated as a secret.", + default=False, + advanced=True, + ) class Output(BlockSchema): output: Any = SchemaField(description="The value recorded as output.") @@ -271,6 +292,7 @@ def __init__(self): ], categories={BlockCategory.OUTPUT, BlockCategory.BASIC}, block_type=BlockType.OUTPUT, + static_output=True, ) def run(self, input_data: Input, **kwargs) -> BlockOutput: diff --git a/autogpt_platform/backend/backend/data/graph.py b/autogpt_platform/backend/backend/data/graph.py index 3f7fc00a52f2..d490208262a3 100644 --- a/autogpt_platform/backend/backend/data/graph.py +++ b/autogpt_platform/backend/backend/data/graph.py @@ -3,13 +3,13 @@ import uuid from collections import defaultdict from datetime import datetime, timezone -from typing import Any, Literal +from typing import Any, Literal, Type from prisma.models import AgentGraph, AgentGraphExecution, AgentNode, AgentNodeLink from prisma.types import AgentGraphWhereInput -from pydantic import BaseModel from pydantic.fields import computed_field +from backend.blocks.basic import AgentInputBlock, AgentOutputBlock from backend.data.block import BlockInput, BlockType, get_block, get_blocks from backend.data.db import BaseDbModel, transaction from backend.data.execution import ExecutionStatus @@ -19,19 +19,6 @@ logger = logging.getLogger(__name__) -class InputSchemaItem(BaseModel): - node_id: str - title: str | None = None - description: str | None = None - default: Any | None = None - - -class OutputSchemaItem(BaseModel): - node_id: str - title: str | None = None - description: str | None = None - - class Link(BaseDbModel): source_id: str sink_id: str @@ -118,36 +105,60 @@ class Graph(BaseDbModel): nodes: list[Node] = [] links: list[Link] = [] - @computed_field - @property - def input_schema(self) -> dict[str, InputSchemaItem]: + @staticmethod + def _generate_schema( + type_class: Type[AgentInputBlock.Input] | Type[AgentOutputBlock.Input], + data: list[dict], + ) -> dict[str, Any]: + props = [] + for p in data: + try: + props.append(type_class(**p)) + except Exception as e: + logger.warning(f"Invalid {type_class}: {p}, {e}") + return { - node.input_default["name"]: InputSchemaItem( - node_id=node.id, - title=node.input_default.get("title"), - description=node.input_default.get("description"), - default=node.input_default.get("value"), - ) - for node in self.nodes - if (b := get_block(node.block_id)) - and b.block_type == BlockType.INPUT - and "name" in node.input_default + "type": "object", + "properties": { + p.name: { + "secret": p.secret, + "advanced": p.advanced, + "title": p.title or p.name, + **({"description": p.description} if p.description else {}), + **({"default": p.value} if p.value is not None else {}), + } + for p in props + }, + "required": [p.name for p in props if p.value is None], } @computed_field @property - def output_schema(self) -> dict[str, OutputSchemaItem]: - return { - node.input_default["name"]: OutputSchemaItem( - node_id=node.id, - title=node.input_default.get("title"), - description=node.input_default.get("description"), - ) - for node in self.nodes - if (b := get_block(node.block_id)) - and b.block_type == BlockType.OUTPUT - and "name" in node.input_default - } + def input_schema(self) -> dict[str, Any]: + return self._generate_schema( + AgentInputBlock.Input, + [ + node.input_default + for node in self.nodes + if (b := get_block(node.block_id)) + and b.block_type == BlockType.INPUT + and "name" in node.input_default + ], + ) + + @computed_field + @property + def output_schema(self) -> dict[str, Any]: + return self._generate_schema( + AgentOutputBlock.Input, + [ + node.input_default + for node in self.nodes + if (b := get_block(node.block_id)) + and b.block_type == BlockType.OUTPUT + and "name" in node.input_default + ], + ) @property def starting_nodes(self) -> list[Node]: diff --git a/autogpt_platform/backend/test/data/test_graph.py b/autogpt_platform/backend/test/data/test_graph.py index a311f1d2f8cf..050e20fdc04b 100644 --- a/autogpt_platform/backend/test/data/test_graph.py +++ b/autogpt_platform/backend/test/data/test_graph.py @@ -1,9 +1,12 @@ +from typing import Any from uuid import UUID import pytest from backend.blocks.basic import AgentInputBlock, AgentOutputBlock, StoreValueBlock +from backend.data.block import BlockSchema from backend.data.graph import Graph, Link, Node +from backend.data.model import SchemaField from backend.data.user import DEFAULT_USER_ID from backend.server.model import CreateGraph from backend.util.test import SpinTestServer @@ -38,7 +41,7 @@ async def test_graph_creation(server: SpinTestServer): source_id="node_1", sink_id="node_2", source_name="output", - sink_name="input", + sink_name="name", ), ], ) @@ -85,11 +88,18 @@ async def test_get_input_schema(server: SpinTestServer): description="Test input schema", nodes=[ Node( - id="node_0", + id="node_0_a", block_id=input_block, - input_default={"name": "in_key", "title": "Input Key"}, + input_default={"name": "in_key_a", "title": "Key A", "value": "A"}, + metadata={"id": "node_0_a"}, ), - Node(id="node_1", block_id=value_block), + Node( + id="node_0_b", + block_id=input_block, + input_default={"name": "in_key_b", "advanced": True}, + metadata={"id": "node_0_b"}, + ), + Node(id="node_1", block_id=value_block, metadata={"id": "node_1"}), Node( id="node_2", block_id=output_block, @@ -97,13 +107,20 @@ async def test_get_input_schema(server: SpinTestServer): "name": "out_key", "description": "This is an output key", }, + metadata={"id": "node_2"}, ), ], links=[ Link( - source_id="node_0", + source_id="node_0_a", sink_id="node_1", - source_name="output", + source_name="result", + sink_name="input", + ), + Link( + source_id="node_0_b", + sink_id="node_1", + source_name="result", sink_name="input", ), Link( @@ -120,12 +137,21 @@ async def test_get_input_schema(server: SpinTestServer): create_graph, DEFAULT_USER_ID ) + class ExpectedInputSchema(BlockSchema): + in_key_a: Any = SchemaField(title="Key A", default="A", advanced=False) + in_key_b: Any = SchemaField(title="in_key_b", advanced=True) + + class ExpectedOutputSchema(BlockSchema): + out_key: Any = SchemaField( + description="This is an output key", + title="out_key", + advanced=False, + ) + input_schema = created_graph.input_schema - assert len(input_schema) == 1 - assert input_schema["in_key"].node_id == created_graph.nodes[0].id - assert input_schema["in_key"].title == "Input Key" + input_schema["title"] = "ExpectedInputSchema" + assert input_schema == ExpectedInputSchema.jsonschema() output_schema = created_graph.output_schema - assert len(output_schema) == 1 - assert output_schema["out_key"].node_id == created_graph.nodes[2].id - assert output_schema["out_key"].description == "This is an output key" + output_schema["title"] = "ExpectedOutputSchema" + assert output_schema == ExpectedOutputSchema.jsonschema()