diff --git a/README.md b/README.md index 5188ed3..179228e 100644 --- a/README.md +++ b/README.md @@ -50,10 +50,9 @@ func main() { option.WithAPIKey("My API Key"), // defaults to os.LookupEnv("OPENAI_API_KEY") ) chatCompletion, err := client.Chat.Completions.New(context.TODO(), openai.ChatCompletionNewParams{ - Messages: openai.F([]openai.ChatCompletionMessageParamUnion{openai.ChatCompletionUserMessageParam{ - Role: openai.F(openai.ChatCompletionUserMessageParamRoleUser), - Content: openai.F([]openai.ChatCompletionContentPartUnionParam{openai.ChatCompletionContentPartTextParam{Text: openai.F("text"), Type: openai.F(openai.ChatCompletionContentPartTextTypeText)}}), - }}), + Messages: openai.F([]openai.ChatCompletionMessageParamUnion{ + openai.UserMessage("Say this is a test"), + }), Model: openai.F(openai.ChatModelGPT4o), }) if err != nil { @@ -236,10 +235,9 @@ defer cancel() client.Chat.Completions.New( ctx, openai.ChatCompletionNewParams{ - Messages: openai.F([]openai.ChatCompletionMessageParamUnion{openai.ChatCompletionUserMessageParam{ - Role: openai.F(openai.ChatCompletionUserMessageParamRoleUser), - Content: openai.F([]openai.ChatCompletionContentPartUnionParam{openai.ChatCompletionContentPartTextParam{Text: openai.F("text"), Type: openai.F(openai.ChatCompletionContentPartTextTypeText)}}), - }}), + Messages: openai.F([]openai.ChatCompletionMessageParamUnion{ + openai.UserMessage("Say this is a test"), + }), Model: openai.F(openai.ChatModelGPT4o), }, // This sets the per-retry timeout @@ -299,10 +297,9 @@ client := openai.NewClient( client.Chat.Completions.New( context.TODO(), openai.ChatCompletionNewParams{ - Messages: openai.F([]openai.ChatCompletionMessageParamUnion{openai.ChatCompletionUserMessageParam{ - Role: openai.F(openai.ChatCompletionUserMessageParamRoleUser), - Content: openai.F([]openai.ChatCompletionContentPartUnionParam{openai.ChatCompletionContentPartTextParam{Text: openai.F("text"), Type: openai.F(openai.ChatCompletionContentPartTextTypeText)}}), - }}), + Messages: openai.F([]openai.ChatCompletionMessageParamUnion{ + openai.UserMessage("Say this is a test"), + }), Model: openai.F(openai.ChatModelGPT4o), }, option.WithMaxRetries(5), diff --git a/azure/azure_test.go b/azure/azure_test.go index ac694cc..f5b9043 100644 --- a/azure/azure_test.go +++ b/azure/azure_test.go @@ -8,21 +8,14 @@ import ( "github.com/openai/openai-go" "github.com/openai/openai-go/internal/apijson" - "github.com/openai/openai-go/shared" ) func TestJSONRoute(t *testing.T) { chatCompletionParams := openai.ChatCompletionNewParams{ Model: openai.F(openai.ChatModel("arbitraryDeployment")), Messages: openai.F([]openai.ChatCompletionMessageParamUnion{ - openai.ChatCompletionAssistantMessageParam{ - Role: openai.F(openai.ChatCompletionAssistantMessageParamRoleAssistant), - Content: openai.F[openai.ChatCompletionAssistantMessageParamContentUnion](shared.UnionString("You are a helpful assistant")), - }, - openai.ChatCompletionUserMessageParam{ - Role: openai.F(openai.ChatCompletionUserMessageParamRoleUser), - Content: openai.F[openai.ChatCompletionUserMessageParamContentUnion](shared.UnionString("Can you tell me another word for the universe?")), - }, + openai.AssistantMessage("You are a helpful assistant"), + openai.UserMessage("Can you tell me another word for the universe?"), }), } @@ -95,14 +88,8 @@ func TestNoRouteChangeNeeded(t *testing.T) { chatCompletionParams := openai.ChatCompletionNewParams{ Model: openai.F(openai.ChatModel("arbitraryDeployment")), Messages: openai.F([]openai.ChatCompletionMessageParamUnion{ - openai.ChatCompletionAssistantMessageParam{ - Role: openai.F(openai.ChatCompletionAssistantMessageParamRoleAssistant), - Content: openai.F[openai.ChatCompletionAssistantMessageParamContentUnion](shared.UnionString("You are a helpful assistant")), - }, - openai.ChatCompletionUserMessageParam{ - Role: openai.F(openai.ChatCompletionUserMessageParamRoleUser), - Content: openai.F[openai.ChatCompletionUserMessageParamContentUnion](shared.UnionString("Can you tell me another word for the universe?")), - }, + openai.AssistantMessage("You are a helpful assistant"), + openai.UserMessage("Can you tell me another word for the universe?"), }), } diff --git a/chatcompletion.go b/chatcompletion.go index 864910b..688f44a 100644 --- a/chatcompletion.go +++ b/chatcompletion.go @@ -12,8 +12,79 @@ import ( "github.com/openai/openai-go/option" "github.com/openai/openai-go/packages/ssestream" "github.com/openai/openai-go/shared" + "github.com/tidwall/sjson" ) +func UserMessage(content string) ChatCompletionMessageParamUnion { + return UserMessageParts(TextPart(content)) +} + +func UserMessageParts(parts ...ChatCompletionContentPartUnionParam) ChatCompletionUserMessageParam { + return ChatCompletionUserMessageParam{ + Role: F(ChatCompletionUserMessageParamRoleUser), + Content: F(parts), + } +} + +func TextPart(content string) ChatCompletionContentPartTextParam { + return ChatCompletionContentPartTextParam{ + Type: F(ChatCompletionContentPartTextTypeText), + Text: F(content), + } +} + +func RefusalPart(refusal string) ChatCompletionContentPartRefusalParam { + return ChatCompletionContentPartRefusalParam{ + Type: F(ChatCompletionContentPartRefusalTypeRefusal), + Refusal: F(refusal), + } +} + +func ImagePart(url string) ChatCompletionContentPartImageParam { + return ChatCompletionContentPartImageParam{ + Type: F(ChatCompletionContentPartImageTypeImageURL), + ImageURL: F(ChatCompletionContentPartImageImageURLParam{ + URL: F(url), + }), + } +} + +func AssistantMessage(content string) ChatCompletionAssistantMessageParam { + return ChatCompletionAssistantMessageParam{ + Role: F(ChatCompletionAssistantMessageParamRoleAssistant), + Content: F([]ChatCompletionAssistantMessageParamContentUnion{ + TextPart(content), + }), + } +} + +func ToolMessage(toolCallID, content string) ChatCompletionToolMessageParam { + return ChatCompletionToolMessageParam{ + Role: F(ChatCompletionToolMessageParamRoleTool), + ToolCallID: F(toolCallID), + Content: F([]ChatCompletionContentPartTextParam{ + TextPart(content), + }), + } +} + +func SystemMessage(content string) ChatCompletionMessageParamUnion { + return ChatCompletionSystemMessageParam{ + Role: F(ChatCompletionSystemMessageParamRoleSystem), + Content: F([]ChatCompletionContentPartTextParam{ + TextPart(content), + }), + } +} + +func FunctionMessage(name, content string) ChatCompletionMessageParamUnion { + return ChatCompletionFunctionMessageParam{ + Role: F(ChatCompletionFunctionMessageParamRoleFunction), + Name: F(name), + Content: F(content), + } +} + // ChatCompletionService contains methods and other services that help with // interacting with the openai API. // @@ -870,10 +941,35 @@ func (r *ChatCompletionMessage) UnmarshalJSON(data []byte) (err error) { return apijson.UnmarshalRoot(data, r) } +func (r ChatCompletionMessage) MarshalJSON() (data []byte, err error) { + s := "" + s, _ = sjson.Set(s, "role", r.Role) + + if r.FunctionCall.Name != "" { + b, err := apijson.Marshal(r.FunctionCall) + if err != nil { + return nil, err + } + s, _ = sjson.SetRaw(s, "function_call", string(b)) + } else if len(r.ToolCalls) > 0 { + b, err := apijson.Marshal(r.ToolCalls) + if err != nil { + return nil, err + } + s, _ = sjson.SetRaw(s, "tool_calls", string(b)) + } else { + s, _ = sjson.Set(s, "content", r.Content) + } + + return []byte(s), nil +} + func (r chatCompletionMessageJSON) RawJSON() string { return r.raw } +func (r ChatCompletionMessage) implementsChatCompletionMessageParamUnion() {} + // The role of the author of this message. type ChatCompletionMessageRole string @@ -944,6 +1040,8 @@ func (r ChatCompletionMessageParam) implementsChatCompletionMessageParamUnion() // [ChatCompletionUserMessageParam], [ChatCompletionAssistantMessageParam], // [ChatCompletionToolMessageParam], [ChatCompletionFunctionMessageParam], // [ChatCompletionMessageParam]. +// +// This union is additionally satisfied by the return types [ChatCompletionMessage] type ChatCompletionMessageParamUnion interface { implementsChatCompletionMessageParamUnion() } diff --git a/examples/audio-text-to-speech/main.go b/examples/audio-text-to-speech/main.go index 7a268b3..c4f558b 100644 --- a/examples/audio-text-to-speech/main.go +++ b/examples/audio-text-to-speech/main.go @@ -13,7 +13,7 @@ func main() { ctx := context.Background() res, err := client.Audio.Speech.New(ctx, openai.AudioSpeechNewParams{ - Model: openai.F(openai.AudioSpeechNewParamsModelTTS1), + Model: openai.F(openai.AudioModelWhisper1), Input: openai.String(`Why did the chicken cross the road? To get to the other side.`), ResponseFormat: openai.F(openai.AudioSpeechNewParamsResponseFormatPCM), Voice: openai.F(openai.AudioSpeechNewParamsVoiceAlloy), diff --git a/examples/audio-transcriptions/main.go b/examples/audio-transcriptions/main.go index f2f7f2e..354a1d3 100644 --- a/examples/audio-transcriptions/main.go +++ b/examples/audio-transcriptions/main.go @@ -18,7 +18,7 @@ func main() { } transcription, err := client.Audio.Transcriptions.New(ctx, openai.AudioTranscriptionNewParams{ - Model: openai.F(openai.AudioTranscriptionNewParamsModelWhisper1), + Model: openai.F(openai.AudioModelWhisper1), File: openai.F[io.Reader](file), }) if err != nil { diff --git a/examples/chat-completion-streaming/main.go b/examples/chat-completion-streaming/main.go index 9591973..19c8f61 100644 --- a/examples/chat-completion-streaming/main.go +++ b/examples/chat-completion-streaming/main.go @@ -21,13 +21,15 @@ func main() { Messages: openai.F([]openai.ChatCompletionMessageParamUnion{ openai.UserMessage(question), }), - Seed: openai.Int(0), + Seed: openai.Int(1), Model: openai.F(openai.ChatModelGPT4o), }) for stream.Next() { evt := stream.Current() - print(evt.Choices[0].Delta.Content) + if len(evt.Choices) > 0 { + print(evt.Choices[0].Delta.Content) + } } println() diff --git a/examples/chat-completion/main.go b/examples/chat-completion/main.go index a845b5c..399c294 100644 --- a/examples/chat-completion/main.go +++ b/examples/chat-completion/main.go @@ -21,7 +21,7 @@ func main() { Messages: openai.F([]openai.ChatCompletionMessageParamUnion{ openai.UserMessage(question), }), - Seed: openai.Int(0), + Seed: openai.Int(1), Model: openai.F(openai.ChatModelGPT4o), }) if err != nil { diff --git a/examples/image-generation/main.go b/examples/image-generation/main.go index 533e6f4..fe3939e 100644 --- a/examples/image-generation/main.go +++ b/examples/image-generation/main.go @@ -23,7 +23,7 @@ func main() { image, err := client.Images.Generate(ctx, openai.ImageGenerateParams{ Prompt: openai.String(prompt), - Model: openai.F(openai.ImageGenerateParamsModelDallE3), + Model: openai.F(openai.ImageModelDallE3), ResponseFormat: openai.F(openai.ImageGenerateParamsResponseFormatURL), N: openai.Int(1), }) @@ -38,7 +38,7 @@ func main() { image, err = client.Images.Generate(ctx, openai.ImageGenerateParams{ Prompt: openai.String(prompt), - Model: openai.F(openai.ImageGenerateParamsModelDallE3), + Model: openai.F(openai.ImageModelDallE3), ResponseFormat: openai.F(openai.ImageGenerateParamsResponseFormatB64JSON), N: openai.Int(1), }) diff --git a/go.work b/go.work new file mode 100644 index 0000000..a864d2f --- /dev/null +++ b/go.work @@ -0,0 +1,3 @@ +go 1.22.4 + +use ./examples diff --git a/go.work.sum b/go.work.sum new file mode 100644 index 0000000..4efc870 --- /dev/null +++ b/go.work.sum @@ -0,0 +1,24 @@ +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.13.0 h1:GJHeeA2N7xrG3q30L2UXDyuWRzDM900/65j70wcM4Ww= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.13.0/go.mod h1:l38EPgmsp71HHLq9j7De57JcKOWPyhrsW1Awm1JS6K0= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.14.0 h1:nyQWyZvwGTvunIMxi1Y9uXkcyr+I7TeNrr/foo4Kpk8= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.14.0/go.mod h1:l38EPgmsp71HHLq9j7De57JcKOWPyhrsW1Awm1JS6K0= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0 h1:tfLQ34V6F7tVSwoTf/4lH5sE0o6eCJuNDTmH09nDpbc= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0/go.mod h1:9kIvujWAA58nmPmWB1m23fyWic1kYZMxD9CxaWn4Qpg= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 h1:ywEEhmNahHBihViHepv3xPBn1663uRv2t2q/ESv9seY= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0/go.mod h1:iZDifYGJTIgIIkYRNWPENUnqx6bJ2xnSDFI2tjwZNuY= +github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 h1:XHOnouVk1mxXfQidrMEnLlPk9UMeRtyBTnEFtxkV0kU= +github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI= +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= +golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= +golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= +golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= +golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= +golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= +golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=