From 892ae055df2a902674bb4496fe6bd6d37d95d64b Mon Sep 17 00:00:00 2001 From: Richard Park Date: Thu, 7 Dec 2023 15:47:32 -0800 Subject: [PATCH 1/4] Fixing issue where the response format wasn't settable using the swagger as we had it. --- sdk/ai/azopenai/autorest.md | 16 +++- .../azopenai/client_chat_completions_test.go | 38 +++++++++ sdk/ai/azopenai/constants.go | 20 ----- sdk/ai/azopenai/interfaces.go | 9 ++ sdk/ai/azopenai/models.go | 42 +++++++++- sdk/ai/azopenai/models_serde.go | 83 ++++++++++++++++++- sdk/ai/azopenai/polymorphic_helpers.go | 23 +++++ 7 files changed, 208 insertions(+), 23 deletions(-) diff --git a/sdk/ai/azopenai/autorest.md b/sdk/ai/azopenai/autorest.md index b7bbb9c684f5..b43ee85b4b73 100644 --- a/sdk/ai/azopenai/autorest.md +++ b/sdk/ai/azopenai/autorest.md @@ -4,7 +4,7 @@ These settings apply only when `--go` is specified on the command line. ``` yaml input-file: -- https://github.com/Azure/azure-rest-api-specs/blob/d402f685809d6d08be9c0b45065cadd7d78ab870/specification/cognitiveservices/data-plane/AzureOpenAI/inference/preview/2023-12-01-preview/generated.json +- https://github.com/Azure/azure-rest-api-specs/blob/3e0e2a93ddb3c9c44ff1baf4952baa24ca98e9db/specification/cognitiveservices/data-plane/AzureOpenAI/inference/preview/2023-12-01-preview/generated.json output-folder: ../azopenai clear-output-folder: false @@ -98,6 +98,20 @@ directive: transform: return $.replace(/InternalOYDAuthTypeRename/g, "configType") ``` +`ChatCompletionsResponseFormat.Type` + +```yaml +directive: + - from: swagger-document + where: $.definitions.ChatCompletionsResponseFormat + transform: $.properties.type["x-ms-client-name"] = "InternalChatCompletionsResponseFormat" + - from: + - models.go + - models_serde.go + where: $ + transform: return $.replace(/InternalChatCompletionsResponseFormat/g, "respType") +``` + ## Model -> DeploymentName ```yaml diff --git a/sdk/ai/azopenai/client_chat_completions_test.go b/sdk/ai/azopenai/client_chat_completions_test.go index a86cd11ffb4e..a0cda81a2468 100644 --- a/sdk/ai/azopenai/client_chat_completions_test.go +++ b/sdk/ai/azopenai/client_chat_completions_test.go @@ -8,6 +8,7 @@ package azopenai_test import ( "context" + "encoding/json" "errors" "io" "net/http" @@ -262,3 +263,40 @@ func TestClient_OpenAI_GetChatCompletions_Vision(t *testing.T) { t.Logf(*resp.Choices[0].Message.Content) } + +func TestGetChatCompletions_usingResponseFormatForJSON(t *testing.T) { + testFn := func(t *testing.T, chatClient *azopenai.Client, deploymentName string) { + body := azopenai.ChatCompletionsOptions{ + DeploymentName: &deploymentName, + Messages: []azopenai.ChatRequestMessageClassification{ + &azopenai.ChatRequestSystemMessage{Content: to.Ptr("You are a helpful assistant designed to output JSON.")}, + &azopenai.ChatRequestUserMessage{ + Content: azopenai.NewChatRequestUserMessageContent("List capital cities and their states"), + }, + }, + // Without this format directive you end up getting JSON, but with a non-JSON preamble, like this: + // "I'm happy to help! Here are some examples of capital cities and their corresponding states:\n\n```json\n{\n" (etc) + ResponseFormat: &azopenai.ChatCompletionsJSONResponseFormat{}, + Temperature: to.Ptr[float32](0.0), + } + + resp, err := chatClient.GetChatCompletions(context.Background(), body, nil) + require.NoError(t, err) + + // validate that it came back as JSON data + var v any + err = json.Unmarshal([]byte(*resp.Choices[0].Message.Content), &v) + require.NoError(t, err) + require.NotEmpty(t, v) + } + + t.Run("OpenAI", func(t *testing.T) { + chatClient := newOpenAIClientForTest(t) + testFn(t, chatClient, "gpt-3.5-turbo-1106") + }) + + t.Run("AzureOpenAI", func(t *testing.T) { + chatClient := newTestClient(t, azureOpenAI.DallE.Endpoint) + testFn(t, chatClient, "gpt-4-1106-preview") + }) +} diff --git a/sdk/ai/azopenai/constants.go b/sdk/ai/azopenai/constants.go index d582f6e468b2..0599b7d67761 100644 --- a/sdk/ai/azopenai/constants.go +++ b/sdk/ai/azopenai/constants.go @@ -181,26 +181,6 @@ func PossibleChatCompletionRequestMessageContentPartTypeValues() []ChatCompletio } } -// ChatCompletionsResponseFormat - The valid response formats Chat Completions can provide. Used to enable JSON mode. -type ChatCompletionsResponseFormat string - -const ( - // ChatCompletionsResponseFormatJSONObject - Use a response format that guarantees emission of a valid JSON object. Only structure - // is guaranteed and contents must - // still be validated. - ChatCompletionsResponseFormatJSONObject ChatCompletionsResponseFormat = "json_object" - // ChatCompletionsResponseFormatText - Use the default, plain text response format. - ChatCompletionsResponseFormatText ChatCompletionsResponseFormat = "text" -) - -// PossibleChatCompletionsResponseFormatValues returns the possible values for the ChatCompletionsResponseFormat const type. -func PossibleChatCompletionsResponseFormatValues() []ChatCompletionsResponseFormat { - return []ChatCompletionsResponseFormat{ - ChatCompletionsResponseFormatJSONObject, - ChatCompletionsResponseFormatText, - } -} - // ChatRole - A description of the intended purpose of a message within a chat completions interaction. type ChatRole string diff --git a/sdk/ai/azopenai/interfaces.go b/sdk/ai/azopenai/interfaces.go index 50d7f51cf05d..0a01f3046248 100644 --- a/sdk/ai/azopenai/interfaces.go +++ b/sdk/ai/azopenai/interfaces.go @@ -27,6 +27,15 @@ type ChatCompletionRequestMessageContentPartClassification interface { GetChatCompletionRequestMessageContentPart() *ChatCompletionRequestMessageContentPart } +// ChatCompletionsResponseFormatClassification provides polymorphic access to related types. +// Call the interface's GetChatCompletionsResponseFormat() method to access the common type. +// Use a type switch to determine the concrete type. The possible types are: +// - *ChatCompletionsJSONResponseFormat, *ChatCompletionsResponseFormat, *ChatCompletionsTextResponseFormat +type ChatCompletionsResponseFormatClassification interface { + // GetChatCompletionsResponseFormat returns the ChatCompletionsResponseFormat content of the underlying type. + GetChatCompletionsResponseFormat() *ChatCompletionsResponseFormat +} + // ChatCompletionsToolCallClassification provides polymorphic access to related types. // Call the interface's GetChatCompletionsToolCall() method to access the common type. // Use a type switch to determine the concrete type. The possible types are: diff --git a/sdk/ai/azopenai/models.go b/sdk/ai/azopenai/models.go index 35e9bff6ee05..f47ec664b313 100644 --- a/sdk/ai/azopenai/models.go +++ b/sdk/ai/azopenai/models.go @@ -634,6 +634,20 @@ func (c *ChatCompletionsFunctionToolDefinition) GetChatCompletionsToolDefinition } } +// ChatCompletionsJSONResponseFormat - A response format for Chat Completions that restricts responses to emitting valid JSON +// objects. +type ChatCompletionsJSONResponseFormat struct { + // REQUIRED; The discriminated type for the response format. + respType *string +} + +// GetChatCompletionsResponseFormat implements the ChatCompletionsResponseFormatClassification interface for type ChatCompletionsJSONResponseFormat. +func (c *ChatCompletionsJSONResponseFormat) GetChatCompletionsResponseFormat() *ChatCompletionsResponseFormat { + return &ChatCompletionsResponseFormat{ + respType: c.respType, + } +} + // ChatCompletionsOptions - The configuration information for a chat completions request. Completions support a wide variety // of tasks and generate text that continues from or "completes" provided prompt data. type ChatCompletionsOptions struct { @@ -689,7 +703,7 @@ type ChatCompletionsOptions struct { PresencePenalty *float32 // An object specifying the format that the model must output. Used to enable JSON mode. - ResponseFormat *ChatCompletionsResponseFormat + ResponseFormat ChatCompletionsResponseFormatClassification // If specified, the system will make a best effort to sample deterministically such that repeated requests with the same // seed and parameters should return the same result. Determinism is not guaranteed, @@ -722,6 +736,32 @@ type ChatCompletionsOptions struct { User *string } +// ChatCompletionsResponseFormat - An abstract representation of a response format configuration usable by Chat Completions. +// Can be used to enable JSON mode. +type ChatCompletionsResponseFormat struct { + // REQUIRED; The discriminated type for the response format. + respType *string +} + +// GetChatCompletionsResponseFormat implements the ChatCompletionsResponseFormatClassification interface for type ChatCompletionsResponseFormat. +func (c *ChatCompletionsResponseFormat) GetChatCompletionsResponseFormat() *ChatCompletionsResponseFormat { + return c +} + +// ChatCompletionsTextResponseFormat - The standard Chat Completions response format that can freely generate text and is +// not guaranteed to produce response content that adheres to a specific schema. +type ChatCompletionsTextResponseFormat struct { + // REQUIRED; The discriminated type for the response format. + respType *string +} + +// GetChatCompletionsResponseFormat implements the ChatCompletionsResponseFormatClassification interface for type ChatCompletionsTextResponseFormat. +func (c *ChatCompletionsTextResponseFormat) GetChatCompletionsResponseFormat() *ChatCompletionsResponseFormat { + return &ChatCompletionsResponseFormat{ + respType: c.respType, + } +} + // ChatCompletionsToolCall - An abstract representation of a tool call that must be resolved in a subsequent request to perform // the requested chat completion. type ChatCompletionsToolCall struct { diff --git a/sdk/ai/azopenai/models_serde.go b/sdk/ai/azopenai/models_serde.go index bd6dbe637ea3..55fbdb549fe9 100644 --- a/sdk/ai/azopenai/models_serde.go +++ b/sdk/ai/azopenai/models_serde.go @@ -1271,6 +1271,33 @@ func (c *ChatCompletionsFunctionToolDefinition) UnmarshalJSON(data []byte) error return nil } +// MarshalJSON implements the json.Marshaller interface for type ChatCompletionsJSONResponseFormat. +func (c ChatCompletionsJSONResponseFormat) MarshalJSON() ([]byte, error) { + objectMap := make(map[string]any) + objectMap["type"] = "json_object" + return json.Marshal(objectMap) +} + +// UnmarshalJSON implements the json.Unmarshaller interface for type ChatCompletionsJSONResponseFormat. +func (c *ChatCompletionsJSONResponseFormat) UnmarshalJSON(data []byte) error { + var rawMsg map[string]json.RawMessage + if err := json.Unmarshal(data, &rawMsg); err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + for key, val := range rawMsg { + var err error + switch key { + case "type": + err = unpopulate(val, "respType", &c.respType) + delete(rawMsg, key) + } + if err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + } + return nil +} + // MarshalJSON implements the json.Marshaller interface for type ChatCompletionsOptions. func (c ChatCompletionsOptions) MarshalJSON() ([]byte, error) { objectMap := make(map[string]any) @@ -1333,7 +1360,7 @@ func (c *ChatCompletionsOptions) UnmarshalJSON(data []byte) error { err = unpopulate(val, "PresencePenalty", &c.PresencePenalty) delete(rawMsg, key) case "response_format": - err = unpopulate(val, "ResponseFormat", &c.ResponseFormat) + c.ResponseFormat, err = unmarshalChatCompletionsResponseFormatClassification(val) delete(rawMsg, key) case "seed": err = unpopulate(val, "Seed", &c.Seed) @@ -1364,6 +1391,60 @@ func (c *ChatCompletionsOptions) UnmarshalJSON(data []byte) error { return nil } +// MarshalJSON implements the json.Marshaller interface for type ChatCompletionsResponseFormat. +func (c ChatCompletionsResponseFormat) MarshalJSON() ([]byte, error) { + objectMap := make(map[string]any) + objectMap["type"] = c.respType + return json.Marshal(objectMap) +} + +// UnmarshalJSON implements the json.Unmarshaller interface for type ChatCompletionsResponseFormat. +func (c *ChatCompletionsResponseFormat) UnmarshalJSON(data []byte) error { + var rawMsg map[string]json.RawMessage + if err := json.Unmarshal(data, &rawMsg); err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + for key, val := range rawMsg { + var err error + switch key { + case "type": + err = unpopulate(val, "respType", &c.respType) + delete(rawMsg, key) + } + if err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + } + return nil +} + +// MarshalJSON implements the json.Marshaller interface for type ChatCompletionsTextResponseFormat. +func (c ChatCompletionsTextResponseFormat) MarshalJSON() ([]byte, error) { + objectMap := make(map[string]any) + objectMap["type"] = "text" + return json.Marshal(objectMap) +} + +// UnmarshalJSON implements the json.Unmarshaller interface for type ChatCompletionsTextResponseFormat. +func (c *ChatCompletionsTextResponseFormat) UnmarshalJSON(data []byte) error { + var rawMsg map[string]json.RawMessage + if err := json.Unmarshal(data, &rawMsg); err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + for key, val := range rawMsg { + var err error + switch key { + case "type": + err = unpopulate(val, "respType", &c.respType) + delete(rawMsg, key) + } + if err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + } + return nil +} + // MarshalJSON implements the json.Marshaller interface for type ChatCompletionsToolCall. func (c ChatCompletionsToolCall) MarshalJSON() ([]byte, error) { objectMap := make(map[string]any) diff --git a/sdk/ai/azopenai/polymorphic_helpers.go b/sdk/ai/azopenai/polymorphic_helpers.go index 39a44c5b11e6..ce97f0faef93 100644 --- a/sdk/ai/azopenai/polymorphic_helpers.go +++ b/sdk/ai/azopenai/polymorphic_helpers.go @@ -58,6 +58,29 @@ func unmarshalAzureChatExtensionConfigurationClassificationArray(rawMsg json.Raw return fArray, nil } +func unmarshalChatCompletionsResponseFormatClassification(rawMsg json.RawMessage) (ChatCompletionsResponseFormatClassification, error) { + if rawMsg == nil { + return nil, nil + } + var m map[string]any + if err := json.Unmarshal(rawMsg, &m); err != nil { + return nil, err + } + var b ChatCompletionsResponseFormatClassification + switch m["type"] { + case "json_object": + b = &ChatCompletionsJSONResponseFormat{} + case "text": + b = &ChatCompletionsTextResponseFormat{} + default: + b = &ChatCompletionsResponseFormat{} + } + if err := json.Unmarshal(rawMsg, b); err != nil { + return nil, err + } + return b, nil +} + func unmarshalChatCompletionsToolCallClassification(rawMsg json.RawMessage) (ChatCompletionsToolCallClassification, error) { if rawMsg == nil { return nil, nil From 02e9072ef4043dca9159570ec69350227a03b4b4 Mon Sep 17 00:00:00 2001 From: Richard Park Date: Thu, 7 Dec 2023 15:54:57 -0800 Subject: [PATCH 2/4] Updated recordings. --- sdk/ai/azopenai/assets.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/ai/azopenai/assets.json b/sdk/ai/azopenai/assets.json index 047ce368b0df..3699c2461b69 100644 --- a/sdk/ai/azopenai/assets.json +++ b/sdk/ai/azopenai/assets.json @@ -2,5 +2,5 @@ "AssetsRepo": "Azure/azure-sdk-assets", "AssetsRepoPrefixPath": "go", "TagPrefix": "go/ai/azopenai", - "Tag": "go/ai/azopenai_9ed7d01267" + "Tag": "go/ai/azopenai_4b8f565947" } From 8f1b0215854153952263ebd401531cf0eb96f7a2 Mon Sep 17 00:00:00 2001 From: Richard Park Date: Thu, 7 Dec 2023 17:44:42 -0800 Subject: [PATCH 3/4] - Fixing bug where ToolChoice was unmodeled. - Pushing release date to Monday. --- sdk/ai/azopenai/CHANGELOG.md | 2 +- sdk/ai/azopenai/assets.json | 2 +- sdk/ai/azopenai/autorest.md | 19 ++++++++- sdk/ai/azopenai/client_functions_test.go | 40 ++++++++++++++++-- sdk/ai/azopenai/custom_models.go | 53 ++++++++++++++++++++++++ sdk/ai/azopenai/models.go | 2 +- sdk/ai/azopenai/models_serde.go | 2 +- 7 files changed, 111 insertions(+), 9 deletions(-) diff --git a/sdk/ai/azopenai/CHANGELOG.md b/sdk/ai/azopenai/CHANGELOG.md index ebb1aeb9d43f..f8373965323e 100644 --- a/sdk/ai/azopenai/CHANGELOG.md +++ b/sdk/ai/azopenai/CHANGELOG.md @@ -1,6 +1,6 @@ # Release History -## 0.4.0 (2023-12-07) +## 0.4.0 (2023-12-11) Support for many of the features mentioned in OpenAI's November Dev Day and Microsoft's 2023 Ignite conference diff --git a/sdk/ai/azopenai/assets.json b/sdk/ai/azopenai/assets.json index 3699c2461b69..3cf47ad12d18 100644 --- a/sdk/ai/azopenai/assets.json +++ b/sdk/ai/azopenai/assets.json @@ -2,5 +2,5 @@ "AssetsRepo": "Azure/azure-sdk-assets", "AssetsRepoPrefixPath": "go", "TagPrefix": "go/ai/azopenai", - "Tag": "go/ai/azopenai_4b8f565947" + "Tag": "go/ai/azopenai_d4fd4783ec" } diff --git a/sdk/ai/azopenai/autorest.md b/sdk/ai/azopenai/autorest.md index b43ee85b4b73..a56983b793aa 100644 --- a/sdk/ai/azopenai/autorest.md +++ b/sdk/ai/azopenai/autorest.md @@ -5,7 +5,6 @@ These settings apply only when `--go` is specified on the command line. ``` yaml input-file: - https://github.com/Azure/azure-rest-api-specs/blob/3e0e2a93ddb3c9c44ff1baf4952baa24ca98e9db/specification/cognitiveservices/data-plane/AzureOpenAI/inference/preview/2023-12-01-preview/generated.json - output-folder: ../azopenai clear-output-folder: false module: github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai @@ -585,3 +584,21 @@ directive: return $.replace(/(func \(c ChatCompletionsOptions\) MarshalJSON\(\).+?populate\(objectMap, "frequency_penalty", c.FrequencyPenalty\))/s, "$1\n" + populateLines) ``` + +Fix ToolChoice discriminated union + +```yaml +directive: + - from: swagger-document + where: $.definitions.ChatCompletionsOptions.properties + transform: $["tool_choice"]["x-ms-client-name"] = "ToolChoiceRenameMe" + - from: + - models.go + - models_serde.go + where: $ + transform: | + return $ + .replace(/^\s+ToolChoiceRenameMe.+$/m, "ToolChoice *ChatCompletionsToolChoice") // update the name _and_ type for the field + .replace(/ToolChoiceRenameMe/g, "ToolChoice") // rename all other references + .replace(/populateAny\(objectMap, "tool_choice", c\.ToolChoice\)/, 'populate(objectMap, "tool_choice", c.ToolChoice)'); // treat field as typed so nil means omit. +``` diff --git a/sdk/ai/azopenai/client_functions_test.go b/sdk/ai/azopenai/client_functions_test.go index f18e07e9b422..5c28f205a0e5 100644 --- a/sdk/ai/azopenai/client_functions_test.go +++ b/sdk/ai/azopenai/client_functions_test.go @@ -28,15 +28,46 @@ type ParamProperty struct { func TestGetChatCompletions_usingFunctions(t *testing.T) { // https://platform.openai.com/docs/guides/gpt/function-calling + useSpecificTool := azopenai.NewChatCompletionsToolChoice( + azopenai.ChatCompletionsToolChoiceFunction{Name: "get_current_weather"}, + ) + t.Run("OpenAI", func(t *testing.T) { chatClient := newOpenAIClientForTest(t) - testChatCompletionsFunctions(t, chatClient, openAI.ChatCompletions) - testChatCompletionsFunctions(t, chatClient, openAI.ChatCompletionsLegacyFunctions) + + testData := []struct { + Model string + ToolChoice *azopenai.ChatCompletionsToolChoice + }{ + // all of these variants use the tool provided - auto just also works since we did provide + // a tool reference and ask a question to use it. + {Model: openAI.ChatCompletions, ToolChoice: nil}, + {Model: openAI.ChatCompletions, ToolChoice: azopenai.ChatCompletionsToolChoiceAuto}, + {Model: openAI.ChatCompletionsLegacyFunctions, ToolChoice: useSpecificTool}, + } + + for _, td := range testData { + testChatCompletionsFunctions(t, chatClient, td.Model, td.ToolChoice) + } }) t.Run("AzureOpenAI", func(t *testing.T) { chatClient := newAzureOpenAIClientForTest(t, azureOpenAI) - testChatCompletionsFunctions(t, chatClient, azureOpenAI.ChatCompletions) + + testData := []struct { + Model string + ToolChoice *azopenai.ChatCompletionsToolChoice + }{ + // all of these variants use the tool provided - auto just also works since we did provide + // a tool reference and ask a question to use it. + {Model: azureOpenAI.ChatCompletions, ToolChoice: nil}, + {Model: azureOpenAI.ChatCompletions, ToolChoice: azopenai.ChatCompletionsToolChoiceAuto}, + {Model: azureOpenAI.ChatCompletions, ToolChoice: useSpecificTool}, + } + + for _, td := range testData { + testChatCompletionsFunctions(t, chatClient, td.Model, td.ToolChoice) + } }) } @@ -120,7 +151,7 @@ func testChatCompletionsFunctionsOlderStyle(t *testing.T, client *azopenai.Clien require.Equal(t, location{Location: "Boston, MA", Unit: "celsius"}, *funcParams) } -func testChatCompletionsFunctions(t *testing.T, chatClient *azopenai.Client, deploymentName string) { +func testChatCompletionsFunctions(t *testing.T, chatClient *azopenai.Client, deploymentName string, toolChoice *azopenai.ChatCompletionsToolChoice) { body := azopenai.ChatCompletionsOptions{ DeploymentName: &deploymentName, Messages: []azopenai.ChatRequestMessageClassification{ @@ -150,6 +181,7 @@ func testChatCompletionsFunctions(t *testing.T, chatClient *azopenai.Client, dep }, }, }, + ToolChoice: toolChoice, Temperature: to.Ptr[float32](0.0), } diff --git a/sdk/ai/azopenai/custom_models.go b/sdk/ai/azopenai/custom_models.go index 48d0348f8460..33429765379e 100644 --- a/sdk/ai/azopenai/custom_models.go +++ b/sdk/ai/azopenai/custom_models.go @@ -132,3 +132,56 @@ func (e *Error) Error() string { return *e.message } + +// ChatCompletionsToolChoice controls which tool is used for this ChatCompletions call. +// You can choose between: +// - [ChatCompletionsToolChoiceAuto] means the model can pick between generating a message or calling a function. +// - [ChatCompletionsToolChoiceNone] means the model will not call a function and instead generates a message +// - Use the [NewChatCompletionsToolChoice] function to specify a specific tool. +type ChatCompletionsToolChoice struct { + value any +} + +var ( + // ChatCompletionsToolChoiceAuto means the model can pick between generating a message or calling a function. + ChatCompletionsToolChoiceAuto *ChatCompletionsToolChoice = &ChatCompletionsToolChoice{value: "auto"} + + // ChatCompletionsToolChoiceNone means the model will not call a function and instead generates a message. + ChatCompletionsToolChoiceNone *ChatCompletionsToolChoice = &ChatCompletionsToolChoice{value: "none"} +) + +// NewChatCompletionsToolChoice creates a ChatCompletionsToolChoice for a specific tool. +func NewChatCompletionsToolChoice[T ChatCompletionsToolChoiceFunction](v T) *ChatCompletionsToolChoice { + return &ChatCompletionsToolChoice{value: v} +} + +// ChatCompletionsToolChoiceFunction can be used to force the model to call a particular function. +type ChatCompletionsToolChoiceFunction struct { + // Name is the name of the function to call. + Name string +} + +// MarshalJSON implements the json.Marshaller interface for type ChatCompletionsToolChoiceFunction. +func (tf ChatCompletionsToolChoiceFunction) MarshalJSON() ([]byte, error) { + type jsonInnerFunc struct { + Name string `json:"name"` + } + + type jsonFormat struct { + Type string `json:"type"` + Function jsonInnerFunc `json:"function"` + } + + return json.Marshal(jsonFormat{ + Type: "function", + //nolint:gosimple,can't use the ChatCompletionsToolChoiceFunction here or marshalling will be circular! + Function: jsonInnerFunc{ + Name: tf.Name, + }, + }) +} + +// UnmarshalJSON implements the json.Unmarshaller interface for type ChatCompletionsToolChoice. +func (tc ChatCompletionsToolChoice) MarshalJSON() ([]byte, error) { + return json.Marshal(tc.value) +} diff --git a/sdk/ai/azopenai/models.go b/sdk/ai/azopenai/models.go index f47ec664b313..ad78bb97283a 100644 --- a/sdk/ai/azopenai/models.go +++ b/sdk/ai/azopenai/models.go @@ -720,7 +720,7 @@ type ChatCompletionsOptions struct { Temperature *float32 // If specified, the model will configure which of the provided tools it can use for the chat completions response. - ToolChoice any + ToolChoice *ChatCompletionsToolChoice // The available tool definitions that the chat completions request can use, including caller-defined functions. Tools []ChatCompletionsToolDefinitionClassification diff --git a/sdk/ai/azopenai/models_serde.go b/sdk/ai/azopenai/models_serde.go index 55fbdb549fe9..4cbcbed31e64 100644 --- a/sdk/ai/azopenai/models_serde.go +++ b/sdk/ai/azopenai/models_serde.go @@ -1316,7 +1316,7 @@ func (c ChatCompletionsOptions) MarshalJSON() ([]byte, error) { populate(objectMap, "seed", c.Seed) populate(objectMap, "stop", c.Stop) populate(objectMap, "temperature", c.Temperature) - populateAny(objectMap, "tool_choice", c.ToolChoice) + populate(objectMap, "tool_choice", c.ToolChoice) populate(objectMap, "tools", c.Tools) populate(objectMap, "top_p", c.TopP) populate(objectMap, "user", c.User) From c30325cd4c1057c893ec3ff36206d8b7cbacc6e2 Mon Sep 17 00:00:00 2001 From: Richard Park Date: Fri, 8 Dec 2023 10:52:53 -0800 Subject: [PATCH 4/4] Fixing doc comment --- sdk/ai/azopenai/custom_models.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/ai/azopenai/custom_models.go b/sdk/ai/azopenai/custom_models.go index 33429765379e..4c9bb7038fd5 100644 --- a/sdk/ai/azopenai/custom_models.go +++ b/sdk/ai/azopenai/custom_models.go @@ -181,7 +181,7 @@ func (tf ChatCompletionsToolChoiceFunction) MarshalJSON() ([]byte, error) { }) } -// UnmarshalJSON implements the json.Unmarshaller interface for type ChatCompletionsToolChoice. +// MarshalJSON implements the json.Marshaller interface for type ChatCompletionsToolChoice. func (tc ChatCompletionsToolChoice) MarshalJSON() ([]byte, error) { return json.Marshal(tc.value) }