Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(blocks): Add support for mutually exclusive input fields #8856

Merged
72 changes: 72 additions & 0 deletions autogpt_platform/backend/backend/blocks/test_mutual_exclusive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from typing import List, Union

from pydantic import BaseModel
from typing_extensions import Literal

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField


class PollOption(BaseModel):
text: str


class Poll(BaseModel):
discriminator: Literal["poll"]
some_input_4: List[PollOption]


class MediaUpload(BaseModel):
discriminator: Literal["media"]
some_input: str
some_input_2: str


class PollDuration(BaseModel):
discriminator: Literal["duration"]
some_input: int


class TweetBlock(Block):
class Input(BlockSchema):
tweet_text: str = SchemaField(
title="Tweet Text",
description="The main text content of the tweet",
)

attachment: Union[Poll, MediaUpload, PollDuration] = SchemaField(
discriminator="discriminator",
title="Tweet Attachment",
description="Optional tweet attachment (poll, media, or duration)",
)

class Output(BlockSchema):
result: str = SchemaField(
description="Shows the tweet content and any attachments"
)

def __init__(self):
super().__init__(
id="b7faa910-b074-11ef-bee7-477f51db4711",
description="Create a tweet with optional attachments",
categories={BlockCategory.BASIC},
input_schema=TweetBlock.Input,
output_schema=TweetBlock.Output,
)

def run(self, input_data: Input, **kwargs) -> BlockOutput:
tweet_content = [f"Tweet Text: {input_data.tweet_text}"]
if isinstance(input_data.attachment, Poll):
options = [opt.text for opt in input_data.attachment.some_input_4]
tweet_content.append(f"Poll Options: {', '.join(options)}")

if isinstance(input_data.attachment, MediaUpload):
tweet_content.append(f"Media URL: {input_data.attachment.some_input}")
tweet_content.append(f"Media URL 2: {input_data.attachment.some_input_2}")

if isinstance(input_data.attachment, PollDuration):
tweet_content.append(
f"Poll Duration: {input_data.attachment.some_input} hours"
)

yield "result", "\n".join(tweet_content)
1 change: 1 addition & 0 deletions autogpt_platform/backend/backend/data/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def ref_to_dict(obj):
}
elif isinstance(obj, list):
return [ref_to_dict(item) for item in obj]

return obj

cls.cached_jsonschema = cast(dict[str, Any], ref_to_dict(model))
Expand Down
155 changes: 147 additions & 8 deletions autogpt_platform/frontend/src/components/node-input-components.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import {
BlockIONumberSubSchema,
BlockIOBooleanSubSchema,
} from "@/lib/autogpt-server-api/types";
import React, { FC, useCallback, useEffect, useState } from "react";
import React, { FC, useCallback, useEffect, useMemo, useState } from "react";
import { Button } from "./ui/button";
import { Switch } from "./ui/switch";
import {
Expand Down Expand Up @@ -326,13 +326,26 @@ export const NodeGenericInputField: FC<{
}
}

if ("oneOf" in propSchema) {
// At the time of writing, this isn't used in the backend -> no impl. needed
console.error(
`Unsupported 'oneOf' in schema for '${propKey}'!`,
propSchema,
if (
"oneOf" in propSchema &&
propSchema.oneOf &&
"discriminator" in propSchema &&
propSchema.discriminator
) {
return (
<NodeOneOfDiscriminatorField
nodeId={nodeId}
propKey={propKey}
propSchema={propSchema}
currentValue={currentValue}
errors={errors}
connections={connections}
handleInputChange={handleInputChange}
handleInputClick={handleInputClick}
className={className}
displayName={displayName}
/>
);
return null;
}

if (!("type" in propSchema)) {
Expand Down Expand Up @@ -451,6 +464,132 @@ export const NodeGenericInputField: FC<{
}
};

const NodeOneOfDiscriminatorField: FC<{
nodeId: string;
propKey: string;
propSchema: any;
currentValue?: any;
errors: { [key: string]: string | undefined };
connections: any;
handleInputChange: (key: string, value: any) => void;
handleInputClick: (key: string) => void;
className?: string;
displayName?: string;
}> = ({
nodeId,
propKey,
propSchema,
currentValue,
errors,
connections,
handleInputChange,
handleInputClick,
className,
displayName,
}) => {
const discriminator = propSchema.discriminator;

const discriminatorProperty = discriminator.propertyName;

const variantOptions = useMemo(() => {
const oneOfVariants = propSchema.oneOf || [];

return oneOfVariants
.map((variant: any) => {
const variantDiscValue =
variant.properties?.[discriminatorProperty]?.const;

return {
value: variantDiscValue,
schema: variant,
};
})
.filter((v: any) => v.value != null);
}, [discriminatorProperty, propSchema.oneOf]);

const currentVariant = variantOptions.find(
(opt: any) => currentValue?.[discriminatorProperty] === opt.value,
);

const [chosenType, setChosenType] = useState<string>(
currentVariant?.value || "",
);

const handleVariantChange = (newType: string) => {
setChosenType(newType);
const chosenVariant = variantOptions.find(
(opt: any) => opt.value === newType,
);
if (chosenVariant) {
const initialValue = {
[discriminatorProperty]: newType,
};
handleInputChange(propKey, initialValue);
}
};

const chosenVariantSchema = variantOptions.find(
(opt: any) => opt.value === chosenType,
)?.schema;

return (
<div className={cn("flex flex-col space-y-2", className)}>
<Select value={chosenType || ""} onValueChange={handleVariantChange}>
<SelectTrigger className="w-full">
<SelectValue placeholder="Select a type..." />
</SelectTrigger>
<SelectContent>
{variantOptions.map((opt: any) => (
<SelectItem key={opt.value} value={opt.value}>
{beautifyString(opt.value)}
</SelectItem>
))}
</SelectContent>
</Select>

{chosenVariantSchema && (
<div className={cn(className, "w-full flex-col")}>
{Object.entries(chosenVariantSchema.properties).map(
([someKey, childSchema]) => {
if (someKey === "discriminator") {
return null;
}
const childKey = propKey ? `${propKey}.${someKey}` : someKey;
return (
<div
key={childKey}
className="flex w-full flex-row justify-between space-y-2"
>
<span className="mr-2 mt-3 dark:text-gray-300">
{(childSchema as BlockIOSubSchema).title ||
beautifyString(someKey)}
</span>
<NodeGenericInputField
nodeId={nodeId}
key={propKey}
propKey={childKey}
propSchema={childSchema as BlockIOSubSchema}
currentValue={
currentValue ? currentValue[someKey] : undefined
}
errors={errors}
connections={connections}
handleInputChange={handleInputChange}
handleInputClick={handleInputClick}
displayName={
chosenVariantSchema.title || beautifyString(someKey)
}
/>
</div>
);
},
)}
</div>
)}
</div>
);
};

const NodeCredentialsInput: FC<{
selfKey: string;
value: any;
Expand Down Expand Up @@ -849,7 +988,7 @@ const NodeStringInput: FC<{
placeholder={
schema?.placeholder || `Enter ${beautifyString(displayName)}`
}
className="pr-8 read-only:cursor-pointer read-only:text-gray-500 dark:text-white"
className="pr-8 read-only:cursor-pointer read-only:text-gray-500"
/>
<Button
variant="ghost"
Expand Down
Loading