Skip to content

Commit

Permalink
Allow structured outputs via function calling (#828)
Browse files Browse the repository at this point in the history
  • Loading branch information
greysteil authored Aug 16, 2024
1 parent dd7f582 commit d86425a
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 0 deletions.
76 changes: 76 additions & 0 deletions api_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,79 @@ func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) {
}
}
}

func TestChatCompletionStructuredOutputsFunctionCalling(t *testing.T) {
apiToken := os.Getenv("OPENAI_TOKEN")
if apiToken == "" {
t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.")
}

var err error
c := openai.NewClient(apiToken)
ctx := context.Background()

resp, err := c.CreateChatCompletion(
ctx,
openai.ChatCompletionRequest{
Model: openai.GPT4oMini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleSystem,
Content: "Please enter a string, and we will convert it into the following naming conventions:" +
"1. PascalCase: Each word starts with an uppercase letter, with no spaces or separators." +
"2. CamelCase: The first word starts with a lowercase letter, " +
"and subsequent words start with an uppercase letter, with no spaces or separators." +
"3. KebabCase: All letters are lowercase, with words separated by hyphens `-`." +
"4. SnakeCase: All letters are lowercase, with words separated by underscores `_`.",
},
{
Role: openai.ChatMessageRoleUser,
Content: "Hello World",
},
},
Tools: []openai.Tool{
{
Type: openai.ToolTypeFunction,
Function: &openai.FunctionDefinition{
Name: "display_cases",
Strict: true,
Parameters: &jsonschema.Definition{
Type: jsonschema.Object,
Properties: map[string]jsonschema.Definition{
"PascalCase": {
Type: jsonschema.String,
},
"CamelCase": {
Type: jsonschema.String,
},
"KebabCase": {
Type: jsonschema.String,
},
"SnakeCase": {
Type: jsonschema.String,
},
},
Required: []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"},
AdditionalProperties: false,
},
},
},
},
ToolChoice: openai.ToolChoice{
Type: openai.ToolTypeFunction,
Function: openai.ToolFunction{
Name: "display_cases",
},
},
},
)
checks.NoError(t, err, "CreateChatCompletion (use structured outputs response) returned error")
var result = make(map[string]string)
err = json.Unmarshal([]byte(resp.Choices[0].Message.ToolCalls[0].Function.Arguments), &result)
checks.NoError(t, err, "CreateChatCompletion (use structured outputs response) unmarshal error")
for _, key := range []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"} {
if _, ok := result[key]; !ok {
t.Errorf("key:%s does not exist.", key)
}
}
}
1 change: 1 addition & 0 deletions chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ type ToolFunction struct {
type FunctionDefinition struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Strict bool `json:"strict,omitempty"`
// Parameters is an object describing the function.
// You can pass json.RawMessage to describe the schema,
// or you can pass in a struct which serializes to the proper JSON schema.
Expand Down
26 changes: 26 additions & 0 deletions chat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,32 @@ func TestChatCompletionsFunctions(t *testing.T) {
})
checks.NoError(t, err, "CreateChatCompletion with functions error")
})
t.Run("StructuredOutputs", func(t *testing.T) {
type testMessage struct {
Count int `json:"count"`
Words []string `json:"words"`
}
msg := testMessage{
Count: 2,
Words: []string{"hello", "world"},
}
_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
MaxTokens: 5,
Model: openai.GPT3Dot5Turbo0613,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: "Hello!",
},
},
Functions: []openai.FunctionDefinition{{
Name: "test",
Strict: true,
Parameters: &msg,
}},
})
checks.NoError(t, err, "CreateChatCompletion with functions error")
})
}

func TestAzureChatCompletions(t *testing.T) {
Expand Down

0 comments on commit d86425a

Please sign in to comment.