OpenAI Functions Support #4683

Merged · 15 commits · Jun 22, 2023
Changes from 7 commits
6 changes: 5 additions & 1 deletion .env.template
@@ -25,10 +25,14 @@ OPENAI_API_KEY=your-openai-api-key
## PROMPT_SETTINGS_FILE - Specifies which Prompt Settings file to use (defaults to prompt_settings.yaml)
# PROMPT_SETTINGS_FILE=prompt_settings.yaml

## OPENAI_API_BASE_URL - Custom url for the OpenAI API, useful for connecting to custom backends. No effect if USE_AZURE is true, leave blank to keep the default url
# the following is an example:
# OPENAI_API_BASE_URL=http://localhost:443/v1

## OPENAI_FUNCTIONS - Enables OpenAI functions: https://platform.openai.com/docs/guides/gpt/function-calling
# the following is an example:
# OPENAI_FUNCTIONS=False

## AUTHORISE COMMAND KEY - Key to authorise commands
# AUTHORISE_COMMAND_KEY=y

8 changes: 6 additions & 2 deletions autogpt/agent/agent.py
@@ -142,7 +142,9 @@ def signal_handler(signum, frame):
)

try:
assistant_reply_json = extract_json_from_response(assistant_reply)
assistant_reply_json = extract_json_from_response(
assistant_reply.content
)
validate_json(assistant_reply_json, self.config)
except json.JSONDecodeError as e:
logger.error(f"Exception while validating assistant reply JSON: {e}")
@@ -160,7 +162,9 @@ def signal_handler(signum, frame):
print_assistant_thoughts(
self.ai_name, assistant_reply_json, self.config
)
command_name, arguments = get_command(assistant_reply_json)
command_name, arguments = get_command(
assistant_reply_json, assistant_reply, self.config
)
if self.config.speak_mode:
say_text(f"I want to execute {command_name}")

8 changes: 6 additions & 2 deletions autogpt/agent/agent_manager.py
@@ -41,7 +41,9 @@ def create_agent(
if plugin_messages := plugin.pre_instruction(messages.raw()):
messages.extend([Message(**raw_msg) for raw_msg in plugin_messages])
# Start GPT instance
agent_reply = create_chat_completion(prompt=messages, config=self.config)
agent_reply = create_chat_completion(
prompt=messages, config=self.config
).content

messages.add("assistant", agent_reply)

@@ -92,7 +94,9 @@ def message_agent(self, key: str | int, message: str) -> str:
messages.extend([Message(**raw_msg) for raw_msg in plugin_messages])

# Start GPT instance
agent_reply = create_chat_completion(prompt=messages, config=self.config)
agent_reply = create_chat_completion(
prompt=messages, config=self.config
).content

messages.add("assistant", agent_reply)

26 changes: 20 additions & 6 deletions autogpt/app.py
@@ -3,6 +3,8 @@
from typing import Dict

from autogpt.agent.agent import Agent
from autogpt.config import Config
from autogpt.llm import ChatModelResponse


def is_valid_int(value: str) -> bool:
@@ -21,11 +23,15 @@ def is_valid_int(value: str) -> bool:
return False


def get_command(response_json: Dict):
def get_command(
assistant_reply_json: Dict, assistant_reply: ChatModelResponse, config: Config
):
"""Parse the response and return the command name and arguments

Args:
response_json (json): The response from the AI
assistant_reply_json (dict): The response object from the AI
assistant_reply (ChatModelResponse): The model response from the AI
config (Config): The config object

Returns:
tuple: The command name and arguments
@@ -35,14 +41,22 @@ def get_command(response_json: Dict):

Exception: If any other error occurs
"""
if config.openai_functions:
assistant_reply_json["command"] = {
"name": assistant_reply.function_call.name,
"args": json.loads(assistant_reply.function_call.arguments),
}
try:
if "command" not in response_json:
if "command" not in assistant_reply_json:
return "Error:", "Missing 'command' object in JSON"

if not isinstance(response_json, dict):
return "Error:", f"'response_json' object is not dictionary {response_json}"
if not isinstance(assistant_reply_json, dict):
return (
"Error:",
f"The previous message sent was not a dictionary {assistant_reply_json}",
)

command = response_json["command"]
command = assistant_reply_json["command"]
if not isinstance(command, dict):
return "Error:", "'command' object is not a dictionary"

12 changes: 11 additions & 1 deletion autogpt/command_decorator.py
@@ -3,6 +3,7 @@

from autogpt.config import Config
from autogpt.models.command import Command
from autogpt.models.command_argument import CommandArgument

# Unique identifier for auto-gpt commands
AUTO_GPT_COMMAND_IDENTIFIER = "auto_gpt_command"
@@ -18,11 +19,20 @@ def command(
"""The command decorator is used to create Command objects from ordinary functions."""

def decorator(func: Callable[..., Any]) -> Command:
typed_arguments = [
CommandArgument(
name=arg_name,
description=argument.get("description"),
type=argument.get("type", "string"),
required=argument.get("required", False),
)
for arg_name, argument in arguments.items()
]
cmd = Command(
name=name,
description=description,
method=func,
signature=arguments,
arguments=typed_arguments,
enabled=enabled,
disabled_reason=disabled_reason,
)
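For context, a hypothetical example of the usage this decorator now supports; the command name and argument fields below are illustrative, not part of this PR.

```python
from autogpt.command_decorator import command

# Each entry in the arguments dict becomes a typed CommandArgument above;
# "type" defaults to "string" and "required" to False when omitted.
@command(
    "write_to_file",
    "Write text to a file",
    {
        "filename": {"type": "string", "description": "The file to write to", "required": True},
        "text": {"type": "string", "description": "The text to write", "required": True},
    },
)
def write_to_file(filename: str, text: str) -> str:
    ...
```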
2 changes: 1 addition & 1 deletion autogpt/config/ai_config.py
@@ -164,5 +164,5 @@ def construct_full_prompt(
if self.api_budget > 0.0:
full_prompt += f"\nIt takes money to let you run. Your API budget is ${self.api_budget:.3f}"
self.prompt_generator = prompt_generator
full_prompt += f"\n\n{prompt_generator.generate_prompt_string()}"
full_prompt += f"\n\n{prompt_generator.generate_prompt_string(config)}"
return full_prompt
2 changes: 2 additions & 0 deletions autogpt/config/config.py
@@ -88,6 +88,8 @@ def __init__(self) -> None:
if self.openai_organization is not None:
openai.organization = self.openai_organization

self.openai_functions = os.getenv("OPENAI_FUNCTIONS", "False") == "True"

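One subtlety of the new flag above, sketched below: the comparison is against the exact string `"True"`, so values like `true` or `1` in `.env` leave the feature disabled.

```python
import os

os.environ["OPENAI_FUNCTIONS"] = "true"  # lowercase: flag stays disabled
assert (os.getenv("OPENAI_FUNCTIONS", "False") == "True") is False

os.environ["OPENAI_FUNCTIONS"] = "True"  # exact match: flag enabled
assert (os.getenv("OPENAI_FUNCTIONS", "False") == "True") is True
```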
self.elevenlabs_api_key = os.getenv("ELEVENLABS_API_KEY")
# ELEVENLABS_VOICE_1_ID is deprecated and included for backwards-compatibility
self.elevenlabs_voice_id = os.getenv(
10 changes: 7 additions & 3 deletions autogpt/json_utils/utilities.py
@@ -29,11 +29,15 @@ def extract_json_from_response(response_content: str) -> dict:


def llm_response_schema(
schema_name: str = LLM_DEFAULT_RESPONSE_FORMAT,
config: Config, schema_name: str = LLM_DEFAULT_RESPONSE_FORMAT
) -> dict[str, Any]:
filename = os.path.join(os.path.dirname(__file__), f"{schema_name}.json")
with open(filename, "r") as f:
return json.load(f)
json_schema = json.load(f)
if config.openai_functions:
del json_schema["properties"]["command"]
json_schema["required"].remove("command")
return json_schema


def validate_json(
@@ -47,7 +51,7 @@ def validate_json(
Returns:
bool: Whether the json_object is valid or not
"""
schema = llm_response_schema(schema_name)
schema = llm_response_schema(config, schema_name)
validator = Draft7Validator(schema)

if errors := sorted(validator.iter_errors(json_object), key=lambda e: e.path):
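A self-contained sketch of the trimming above, using a stand-in for the schema loaded from `f"{schema_name}.json"`: with functions enabled, the `command` object moves out of the JSON reply and into `function_call`, so it is dropped from the schema before validation.

```python
# Stand-in schema; the real file also requires a "thoughts" object.
json_schema = {
    "type": "object",
    "properties": {"thoughts": {"type": "object"}, "command": {"type": "object"}},
    "required": ["thoughts", "command"],
}

openai_functions = True  # i.e. config.openai_functions
if openai_functions:
    del json_schema["properties"]["command"]
    json_schema["required"].remove("command")

assert json_schema["required"] == ["thoughts"]
```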
31 changes: 31 additions & 0 deletions autogpt/llm/base.py
@@ -4,6 +4,8 @@
from math import ceil, floor
from typing import List, Literal, TypedDict

from autogpt.models.command_argument import CommandArgument

MessageRole = Literal["system", "user", "assistant"]
MessageType = Literal["ai_response", "action_result"]

@@ -157,3 +159,32 @@ class ChatModelResponse(LLMResponse):
"""Standard response struct for a response from an LLM model."""

content: str = None
function_call: OpenAIFunctionCall = None


@dataclass
class OpenAIFunctionSpec:
"""Represents a "function" in OpenAI, which is mapped to a Command in Auto-GPT"""

name: str
description: str
parameters: OpenAIFunctionParameter


@dataclass
class OpenAIFunctionCall:
name: str
arguments: List[CommandArgument]


@dataclass
class OpenAIFunctionParameter:
type: str
properties: OpenAIFunctionProperties
required: bool


@dataclass
class OpenAIFunctionProperties:
type: str
description: str
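These dataclasses are serialised with `dataclasses.asdict()` before being sent to the API (see `create_chat_completion` in `autogpt/llm/utils/__init__.py` below). A sketch with assumed values follows; note that in actual use `get_openai_command_specs` passes a dict of properties and a list of required argument names, which is looser than the `properties` and `required` annotations above.

```python
from dataclasses import asdict

# asdict() recurses through nested dataclasses and dicts, yielding the
# JSON-ready entry for the API's `functions` parameter.
spec = OpenAIFunctionSpec(
    name="google",
    description="Google Search",
    parameters=OpenAIFunctionParameter(
        type="object",
        properties={
            "query": OpenAIFunctionProperties(type="string", description="The search query"),
        },
        required=["query"],
    ),
)

payload = asdict(spec)
assert payload["parameters"]["properties"]["query"]["type"] == "string"
```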
6 changes: 5 additions & 1 deletion autogpt/llm/chat.py
@@ -3,6 +3,8 @@
import time
from typing import TYPE_CHECKING

from autogpt.llm.providers.openai import get_openai_command_specs

if TYPE_CHECKING:
from autogpt.agent.agent import Agent

@@ -94,6 +96,7 @@ def chat_with_ai(
current_tokens_used += count_message_tokens([user_input_msg], model)

current_tokens_used += 500 # Reserve space for new_summary_message
current_tokens_used += 500 # Reserve space for the openai functions TODO improve

# Add Messages until the token limit is reached or there are no more messages to add.
for cycle in reversed(list(agent.history.per_cycle(agent.config))):
@@ -193,11 +196,12 @@
assistant_reply = create_chat_completion(
prompt=message_sequence,
config=agent.config,
functions=get_openai_command_specs(agent),
max_tokens=tokens_remaining,
)

# Update full message history
agent.history.append(user_input_msg)
agent.history.add("assistant", assistant_reply, "ai_response")
agent.history.add("assistant", assistant_reply.content, "ai_response")

return assistant_reply
38 changes: 38 additions & 0 deletions autogpt/llm/providers/openai.py
@@ -13,6 +13,9 @@
ChatModelInfo,
EmbeddingModelInfo,
MessageDict,
OpenAIFunctionParameter,
OpenAIFunctionProperties,
OpenAIFunctionSpec,
TextModelInfo,
TText,
)
@@ -267,3 +270,38 @@ def create_embedding(
input=input,
**kwargs,
)


def get_openai_command_specs(agent) -> list[OpenAIFunctionSpec]:
"""Get OpenAI-consumable function specs for the agent's available commands.
see https://platform.openai.com/docs/guides/gpt/function-calling
"""
functions = []
if not agent.config.openai_functions:
return functions
for command in agent.command_registry.commands.values():
properties = {}
required = []

for argument in command.arguments:
properties[argument.name] = OpenAIFunctionProperties(
type=argument.type,
description=argument.description,
)
if argument.required:
required.append(argument.name)
parameters = OpenAIFunctionParameter(
type="object",
properties=properties,
required=required,
)

functions.append(
OpenAIFunctionSpec(
name=command.name,
description=command.description,
parameters=parameters,
)
)

return functions
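A usage sketch with a stubbed agent: the real `agent` carries a `Config` and a `CommandRegistry`, faked here with `SimpleNamespace` objects of the same shape.

```python
from types import SimpleNamespace

url_arg = SimpleNamespace(
    name="url", type="string", description="The URL to browse", required=True
)
stub_agent = SimpleNamespace(
    config=SimpleNamespace(openai_functions=True),
    command_registry=SimpleNamespace(
        commands={
            "browse_website": SimpleNamespace(
                name="browse_website",
                description="Browse a website",
                arguments=[url_arg],
            )
        }
    ),
)

specs = get_openai_command_specs(stub_agent)
assert specs[0].name == "browse_website"
assert specs[0].parameters.required == ["url"]
```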
31 changes: 22 additions & 9 deletions autogpt/llm/utils/__init__.py
@@ -1,5 +1,6 @@
from __future__ import annotations

from dataclasses import asdict
from typing import List, Literal, Optional

from colorama import Fore
@@ -8,8 +9,9 @@
from autogpt.logs import logger

from ..api_manager import ApiManager
from ..base import ChatSequence, Message
from ..base import ChatModelResponse, ChatSequence, Message, OpenAIFunctionSpec
from ..providers import openai as iopenai
from ..providers.openai import OPEN_AI_CHAT_MODELS
from .token_counter import *


@@ -52,7 +54,7 @@ def call_ai_function(
Message("user", arg_str),
],
)
return create_chat_completion(prompt=prompt, temperature=0)
return create_chat_completion(prompt=prompt, temperature=0).content


def create_text_completion(
@@ -88,10 +90,11 @@ def create_text_completion(
def create_chat_completion(
prompt: ChatSequence,
config: Config,
functions: Optional[List[OpenAIFunctionSpec]] = [],
model: Optional[str] = None,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
) -> str:
) -> ChatModelResponse:
"""Create a chat completion using the OpenAI API

Args:
@@ -103,6 +106,7 @@
Returns:
str: The response from the chat completion
"""

if model is None:
model = prompt.model.name
if temperature is None:
Expand Down Expand Up @@ -134,26 +138,35 @@ def create_chat_completion(
chat_completion_kwargs[
"deployment_id"
] = config.get_azure_deployment_id_for_model(model)
if functions:
chat_completion_kwargs["functions"] = [
asdict(function) for function in functions
]

response = iopenai.create_chat_completion(
messages=prompt.raw(),
**chat_completion_kwargs,
)
logger.debug(f"Response: {response}")

resp = ""
if not hasattr(response, "error"):
resp = response.choices[0].message["content"]
else:
if hasattr(response, "error"):
logger.error(response.error)
raise RuntimeError(response.error)

first_message = response.choices[0].message
content = first_message["content"]
function_call = first_message.get("function_call", {})

for plugin in config.plugins:
if not plugin.can_handle_on_response():
continue
resp = plugin.on_response(resp)
content = plugin.on_response(content)

return resp
return ChatModelResponse(
model_info=OPEN_AI_CHAT_MODELS[model],
content=content,
function_call=function_call,
)
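For reference, an illustrative shape of the message consumed above (values are assumed): when the model opts to call a function, `content` may be `None` and `function_call.arguments` arrives as a JSON string, which is why `get_command` decodes it with `json.loads`.

```python
# Illustrative raw message, as found in response.choices[0].message:
first_message = {
    "role": "assistant",
    "content": None,  # may be empty when the model calls a function
    "function_call": {
        "name": "google",
        "arguments": '{"query": "auto-gpt"}',  # JSON string, not a dict
    },
}
content = first_message["content"]
function_call = first_message.get("function_call", {})
```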


def check_model(