forked from sashabaranov/go-openai
-
Notifications
You must be signed in to change notification settings - Fork 0
/
completion_test.go
125 lines (115 loc) · 3.45 KB
/
completion_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
package openai_test
import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"testing"
"time"
)
// TestCompletionsWrongModel verifies that CreateCompletion reports
// ErrCompletionUnsupportedModel when the request names GPT3Dot5Turbo as model.
func TestCompletionsWrongModel(t *testing.T) {
	cfg := DefaultConfig("whatever")
	cfg.BaseURL = "http://localhost/v1"

	req := CompletionRequest{
		MaxTokens: 5,
		Model:     GPT3Dot5Turbo,
	}
	_, err := NewClientWithConfig(cfg).CreateCompletion(context.Background(), req)
	if errors.Is(err, ErrCompletionUnsupportedModel) {
		return
	}
	t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel, but returned: %v", err)
}
func TestCompletionWithStream(t *testing.T) {
config := DefaultConfig("whatever")
client := NewClientWithConfig(config)
ctx := context.Background()
req := CompletionRequest{Stream: true}
_, err := client.CreateCompletion(ctx, req)
if !errors.Is(err, ErrCompletionStreamNotSupported) {
t.Fatalf("CreateCompletion didn't return ErrCompletionStreamNotSupported")
}
}
// TestCompletions exercises CreateCompletion against the mocked server, with
// handleCompletionEndpoint registered for "/v1/completions", and expects the
// call to succeed.
func TestCompletions(t *testing.T) {
	api, mock, shutdown := setupOpenAITestServer()
	defer shutdown()
	mock.RegisterHandler("/v1/completions", handleCompletionEndpoint)

	_, err := api.CreateCompletion(context.Background(), CompletionRequest{
		MaxTokens: 5,
		Model:     "ada",
		Prompt:    "Lorem ipsum",
	})
	checks.NoError(t, err, "CreateCompletion error")
}
// handleCompletionEndpoint Handles the completion endpoint by the test server.
func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
var err error
var resBytes []byte
// completions only accepts POST requests
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
var completionReq CompletionRequest
if completionReq, err = getCompletionBody(r); err != nil {
http.Error(w, "could not read request", http.StatusInternalServerError)
return
}
res := CompletionResponse{
ID: strconv.Itoa(int(time.Now().Unix())),
Object: "test-object",
Created: time.Now().Unix(),
// would be nice to validate Model during testing, but
// this may not be possible with how much upkeep
// would be required / wouldn't make much sense
Model: completionReq.Model,
}
// create completions
n := completionReq.N
if n == 0 {
n = 1
}
for i := 0; i < n; i++ {
// generate a random string of length completionReq.Length
completionStr := strings.Repeat("a", completionReq.MaxTokens)
if completionReq.Echo {
completionStr = completionReq.Prompt.(string) + completionStr
}
res.Choices = append(res.Choices, CompletionChoice{
Text: completionStr,
Index: i,
})
}
inputTokens := numTokens(completionReq.Prompt.(string)) * n
completionTokens := completionReq.MaxTokens * n
res.Usage = Usage{
PromptTokens: inputTokens,
CompletionTokens: completionTokens,
TotalTokens: inputTokens + completionTokens,
}
resBytes, _ = json.Marshal(res)
fmt.Fprintln(w, string(resBytes))
}
// getCompletionBody reads the whole request body and decodes it as JSON into a
// CompletionRequest. On a read or decode failure it returns the zero
// CompletionRequest together with the underlying error.
func getCompletionBody(r *http.Request) (CompletionRequest, error) {
	raw, readErr := io.ReadAll(r.Body)
	if readErr != nil {
		return CompletionRequest{}, readErr
	}

	var parsed CompletionRequest
	if decodeErr := json.Unmarshal(raw, &parsed); decodeErr != nil {
		return CompletionRequest{}, decodeErr
	}
	return parsed, nil
}