Skip to content

Commit

Permalink
feat: add ability to pass request-specific env vars to chat completion
Browse files Browse the repository at this point in the history
This will allow authentication per-request in model providers.

Signed-off-by: Donnie Adams <donnie@acorn.io>
  • Loading branch information
thedadams committed Nov 4, 2024
1 parent 50489f2 commit bda5f60
Show file tree
Hide file tree
Showing 11 changed files with 54 additions and 52 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ require (
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
github.com/google/uuid v1.6.0
github.com/gptscript-ai/broadcaster v0.0.0-20240625175512-c43682019b86
github.com/gptscript-ai/chat-completion-client v0.0.0-20240813051153-a440ada7e3c3
github.com/gptscript-ai/chat-completion-client v0.0.0-20241104122544-5fe75f07c131
github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb
github.com/gptscript-ai/go-gptscript v0.9.5-rc5.0.20240927213153-2af51434b93e
github.com/gptscript-ai/tui v0.0.0-20240923192013-172e51ccf1d6
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,8 @@ github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY=
github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c=
github.com/gptscript-ai/broadcaster v0.0.0-20240625175512-c43682019b86 h1:m9yLtIEd0z1ia8qFjq3u0Ozb6QKwidyL856JLJp6nbA=
github.com/gptscript-ai/broadcaster v0.0.0-20240625175512-c43682019b86/go.mod h1:lK3K5EZx4dyT24UG3yCt0wmspkYqrj4D/8kxdN3relk=
github.com/gptscript-ai/chat-completion-client v0.0.0-20240813051153-a440ada7e3c3 h1:EQiFTZv+BnOWJX2B9XdF09fL2Zj7h19n1l23TpWCafc=
github.com/gptscript-ai/chat-completion-client v0.0.0-20240813051153-a440ada7e3c3/go.mod h1:7P/o6/IWa1KqsntVf68hSnLKuu3+xuqm6lYhch1w4jo=
github.com/gptscript-ai/chat-completion-client v0.0.0-20241104122544-5fe75f07c131 h1:y2FcmT4X8U606gUS0teX5+JWX9K/NclsLEhHiyrd+EU=
github.com/gptscript-ai/chat-completion-client v0.0.0-20241104122544-5fe75f07c131/go.mod h1:7P/o6/IWa1KqsntVf68hSnLKuu3+xuqm6lYhch1w4jo=
github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb h1:ky2J2CzBOskC7Jgm2VJAQi2x3p7FVGa+2/PcywkFJuc=
github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb/go.mod h1:DJAo1xTht1LDkNYFNydVjTHd576TC7MlpsVRl3oloVw=
github.com/gptscript-ai/go-gptscript v0.9.5-rc5.0.20240927213153-2af51434b93e h1:WpNae0NBx+Ri8RB3SxF8DhadDKU7h+jfWPQterDpbJA=
Expand Down
11 changes: 0 additions & 11 deletions pkg/context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,3 @@ func GetLogger(ctx context.Context) mvl.Logger {

return l
}

// envKey is the unexported key type under which the environment slice is
// stored in a context; being a distinct type, it cannot collide with keys
// defined by other packages.
type envKey struct{}

// WithEnv returns a context derived from ctx that carries env.
func WithEnv(ctx context.Context, env []string) context.Context {
	key := envKey{}
	return context.WithValue(ctx, key, env)
}

// GetEnv returns the environment slice stored in ctx by WithEnv. It returns
// nil when ctx holds no value under envKey or the value is not a []string.
func GetEnv(ctx context.Context) []string {
	if stored, ok := ctx.Value(envKey{}).([]string); ok {
		return stored
	}
	return nil
}
5 changes: 2 additions & 3 deletions pkg/engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,13 @@ import (
"sync"

"github.com/gptscript-ai/gptscript/pkg/config"
gcontext "github.com/gptscript-ai/gptscript/pkg/context"
"github.com/gptscript-ai/gptscript/pkg/counter"
"github.com/gptscript-ai/gptscript/pkg/types"
"github.com/gptscript-ai/gptscript/pkg/version"
)

type Model interface {
Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error)
Call(ctx context.Context, messageRequest types.CompletionRequest, env []string, status chan<- types.CompletionStatus) (*types.CompletionMessage, error)
ProxyInfo() (string, string, error)
}

Expand Down Expand Up @@ -389,7 +388,7 @@ func (e *Engine) complete(ctx context.Context, state *State) (*Return, error) {
}
}()

resp, err := e.Model.Call(gcontext.WithEnv(ctx, e.Env), state.Completion, progress)
resp, err := e.Model.Call(ctx, state.Completion, e.Env, progress)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/llm/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func (r *Registry) ServeHTTP(w http.ResponseWriter, req *http.Request) {

var (
model string
data = map[string]any{}
data map[string]any
)

if json.Unmarshal(inBytes, &data) == nil {
Expand All @@ -65,7 +65,7 @@ func (r *Registry) ServeHTTP(w http.ResponseWriter, req *http.Request) {
model = builtin.GetDefaultModel()
}

c, err := r.getClient(req.Context(), model)
c, err := r.getClient(req.Context(), model, nil)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
Expand Down
16 changes: 8 additions & 8 deletions pkg/llm/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
)

type Client interface {
Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error)
Call(ctx context.Context, messageRequest types.CompletionRequest, env []string, status chan<- types.CompletionStatus) (*types.CompletionMessage, error)
ListModels(ctx context.Context, providers ...string) (result []string, _ error)
Supports(ctx context.Context, modelName string) (bool, error)
}
Expand Down Expand Up @@ -78,7 +78,7 @@ func (r *Registry) fastPath(modelName string) Client {
return r.clients[0]
}

func (r *Registry) getClient(ctx context.Context, modelName string) (Client, error) {
func (r *Registry) getClient(ctx context.Context, modelName string, env []string) (Client, error) {
if c := r.fastPath(modelName); c != nil {
return c, nil
}
Expand All @@ -101,7 +101,7 @@ func (r *Registry) getClient(ctx context.Context, modelName string) (Client, err

if len(errs) > 0 && oaiClient != nil {
// Prompt the user to enter their OpenAI API key and try again.
if err := oaiClient.RetrieveAPIKey(ctx); err != nil {
if err := oaiClient.RetrieveAPIKey(ctx, env); err != nil {
return nil, err
}
ok, err := oaiClient.Supports(ctx, modelName)
Expand All @@ -119,13 +119,13 @@ func (r *Registry) getClient(ctx context.Context, modelName string) (Client, err
return nil, errors.Join(errs...)
}

func (r *Registry) Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
func (r *Registry) Call(ctx context.Context, messageRequest types.CompletionRequest, env []string, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
if messageRequest.Model == "" {
return nil, fmt.Errorf("model is required")
}

if c := r.fastPath(messageRequest.Model); c != nil {
return c.Call(ctx, messageRequest, status)
return c.Call(ctx, messageRequest, env, status)
}

var errs []error
Expand All @@ -140,20 +140,20 @@ func (r *Registry) Call(ctx context.Context, messageRequest types.CompletionRequ

errs = append(errs, err)
} else if ok {
return client.Call(ctx, messageRequest, status)
return client.Call(ctx, messageRequest, env, status)
}
}

if len(errs) > 0 && oaiClient != nil {
// Prompt the user to enter their OpenAI API key and try again.
if err := oaiClient.RetrieveAPIKey(ctx); err != nil {
if err := oaiClient.RetrieveAPIKey(ctx, env); err != nil {
return nil, err
}
ok, err := oaiClient.Supports(ctx, messageRequest.Model)
if err != nil {
return nil, err
} else if ok {
return oaiClient.Call(ctx, messageRequest, status)
return oaiClient.Call(ctx, messageRequest, env, status)
}
}

Expand Down
39 changes: 27 additions & 12 deletions pkg/openai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (

openai "github.com/gptscript-ai/chat-completion-client"
"github.com/gptscript-ai/gptscript/pkg/cache"
gcontext "github.com/gptscript-ai/gptscript/pkg/context"
"github.com/gptscript-ai/gptscript/pkg/counter"
"github.com/gptscript-ai/gptscript/pkg/credentials"
"github.com/gptscript-ai/gptscript/pkg/hash"
Expand Down Expand Up @@ -303,9 +302,9 @@ func toMessages(request types.CompletionRequest, compat bool) (result []openai.C
return
}

func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, env []string, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
if err := c.ValidAuth(); err != nil {
if err := c.RetrieveAPIKey(ctx); err != nil {
if err := c.RetrieveAPIKey(ctx, env); err != nil {
return nil, err
}
}
Expand Down Expand Up @@ -401,15 +400,15 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
if err != nil {
return nil, err
} else if !ok {
result, err = c.call(ctx, request, id, status)
result, err = c.call(ctx, request, id, env, status)

// If we got back a context length exceeded error, keep retrying and shrinking the message history until we pass.
var apiError *openai.APIError
if errors.As(err, &apiError) && apiError.Code == "context_length_exceeded" && messageRequest.Chat {
// Decrease maxTokens by 10% to make garbage collection more aggressive.
// The retry loop will further decrease maxTokens if needed.
maxTokens := decreaseTenPercent(messageRequest.MaxTokens)
result, err = c.contextLimitRetryLoop(ctx, request, id, maxTokens, status)
result, err = c.contextLimitRetryLoop(ctx, request, id, env, maxTokens, status)
}
if err != nil {
return nil, err
Expand Down Expand Up @@ -443,7 +442,7 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
return &result, nil
}

func (c *Client) contextLimitRetryLoop(ctx context.Context, request openai.ChatCompletionRequest, id string, maxTokens int, status chan<- types.CompletionStatus) (types.CompletionMessage, error) {
func (c *Client) contextLimitRetryLoop(ctx context.Context, request openai.ChatCompletionRequest, id string, env []string, maxTokens int, status chan<- types.CompletionStatus) (types.CompletionMessage, error) {
var (
response types.CompletionMessage
err error
Expand All @@ -452,7 +451,7 @@ func (c *Client) contextLimitRetryLoop(ctx context.Context, request openai.ChatC
for range 10 { // maximum 10 tries
// Try to drop older messages again, with a decreased max tokens.
request.Messages = dropMessagesOverCount(maxTokens, request.Messages)
response, err = c.call(ctx, request, id, status)
response, err = c.call(ctx, request, id, env, status)
if err == nil {
return response, nil
}
Expand Down Expand Up @@ -542,7 +541,7 @@ func override(left, right string) string {
return left
}

func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest, transactionID string, partial chan<- types.CompletionStatus) (types.CompletionMessage, error) {
func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest, transactionID string, env []string, partial chan<- types.CompletionStatus) (types.CompletionMessage, error) {
streamResponse := os.Getenv("GPTSCRIPT_INTERNAL_OPENAI_STREAMING") != "false"

partial <- types.CompletionStatus{
Expand All @@ -553,11 +552,27 @@ func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest,
},
}

var (
headers map[string]string
modelProviderEnv []string
)
for _, e := range env {
if strings.HasPrefix(e, "GPTSCRIPT_MODEL_PROVIDER_") {
modelProviderEnv = append(modelProviderEnv, e)
}
}

if len(modelProviderEnv) > 0 {
headers = map[string]string{
"X-GPTScript-Env": strings.Join(modelProviderEnv, ","),
}
}

slog.Debug("calling openai", "message", request.Messages)

if !streamResponse {
request.StreamOptions = nil
resp, err := c.c.CreateChatCompletion(ctx, request)
resp, err := c.c.CreateChatCompletion(ctx, request, headers)
if err != nil {
return types.CompletionMessage{}, err
}
Expand All @@ -582,7 +597,7 @@ func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest,
}), nil
}

stream, err := c.c.CreateChatCompletionStream(ctx, request)
stream, err := c.c.CreateChatCompletionStream(ctx, request, headers)
if err != nil {
return types.CompletionMessage{}, err
}
Expand Down Expand Up @@ -614,8 +629,8 @@ func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest,
}
}

func (c *Client) RetrieveAPIKey(ctx context.Context) error {
k, err := prompt.GetModelProviderCredential(ctx, c.credStore, BuiltinCredName, "OPENAI_API_KEY", "Please provide your OpenAI API key:", gcontext.GetEnv(ctx))
func (c *Client) RetrieveAPIKey(ctx context.Context, env []string) error {
k, err := prompt.GetModelProviderCredential(ctx, c.credStore, BuiltinCredName, "OPENAI_API_KEY", "Please provide your OpenAI API key:", env)
if err != nil {
return err
}
Expand Down
19 changes: 9 additions & 10 deletions pkg/remote/remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"sync"

"github.com/gptscript-ai/gptscript/pkg/cache"
gcontext "github.com/gptscript-ai/gptscript/pkg/context"
"github.com/gptscript-ai/gptscript/pkg/credentials"
"github.com/gptscript-ai/gptscript/pkg/engine"
env2 "github.com/gptscript-ai/gptscript/pkg/env"
Expand Down Expand Up @@ -42,13 +41,13 @@ func New(r *runner.Runner, envs []string, cache *cache.Client, credStore credent
}
}

func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, env []string, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
_, provider := c.parseModel(messageRequest.Model)
if provider == "" {
return nil, fmt.Errorf("failed to find remote model %s", messageRequest.Model)
}

client, err := c.load(ctx, provider)
client, err := c.load(ctx, provider, env...)
if err != nil {
return nil, err
}
Expand All @@ -60,7 +59,7 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
modelName = toolName
}
messageRequest.Model = modelName
return client.Call(ctx, messageRequest, status)
return client.Call(ctx, messageRequest, env, status)
}

func (c *Client) ListModels(ctx context.Context, providers ...string) (result []string, _ error) {
Expand Down Expand Up @@ -111,7 +110,7 @@ func isHTTPURL(toolName string) bool {
strings.HasPrefix(toolName, "https://")
}

func (c *Client) clientFromURL(ctx context.Context, apiURL string) (*openai.Client, error) {
func (c *Client) clientFromURL(ctx context.Context, apiURL string, envs []string) (*openai.Client, error) {
parsed, err := url.Parse(apiURL)
if err != nil {
return nil, err
Expand All @@ -121,7 +120,7 @@ func (c *Client) clientFromURL(ctx context.Context, apiURL string) (*openai.Clie

if key == "" && !isLocalhost(apiURL) {
var err error
key, err = c.retrieveAPIKey(ctx, env, apiURL)
key, err = c.retrieveAPIKey(ctx, env, apiURL, envs)
if err != nil {
return nil, err
}
Expand All @@ -134,7 +133,7 @@ func (c *Client) clientFromURL(ctx context.Context, apiURL string) (*openai.Clie
})
}

func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, error) {
func (c *Client) load(ctx context.Context, toolName string, env ...string) (*openai.Client, error) {
c.clientsLock.Lock()
defer c.clientsLock.Unlock()

Expand All @@ -144,7 +143,7 @@ func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, err
}

if isHTTPURL(toolName) {
remoteClient, err := c.clientFromURL(ctx, toolName)
remoteClient, err := c.clientFromURL(ctx, toolName, env)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -183,8 +182,8 @@ func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, err
return oClient, nil
}

func (c *Client) retrieveAPIKey(ctx context.Context, env, url string) (string, error) {
return prompt.GetModelProviderCredential(ctx, c.credStore, url, env, fmt.Sprintf("Please provide your API key for %s", url), append(gcontext.GetEnv(ctx), c.envs...))
// retrieveAPIKey asks prompt.GetModelProviderCredential for the API key of
// the provider at url, passing the request-specific envs followed by the
// client's own c.envs as the environment.
// NOTE(review): url and env appear to be the credential name and the key's
// environment-variable name respectively — confirm against
// prompt.GetModelProviderCredential.
//
// The full slice expression caps the capacity of envs at its length, so
// append always allocates a fresh backing array. Without it, append could
// write c.envs into spare capacity of the caller's slice, which may be shared
// with other in-flight requests.
func (c *Client) retrieveAPIKey(ctx context.Context, env, url string, envs []string) (string, error) {
	combined := append(envs[:len(envs):len(envs)], c.envs...)
	return prompt.GetModelProviderCredential(ctx, c.credStore, url, env, fmt.Sprintf("Please provide your API key for %s", url), combined)
}

func isLocalhost(url string) bool {
Expand Down
2 changes: 1 addition & 1 deletion pkg/runner/output.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func (r *Runner) handleOutput(callCtx engine.Context, monitor Monitor, env []str
if err != nil {
return nil, fmt.Errorf("marshaling input for output filter: %w", err)
}
res, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, outputToolRef.ToolID, string(inputData), "", engine.OutputToolCategory)
res, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, outputToolRef.ToolID, inputData, "", engine.OutputToolCategory)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/tests/judge/judge.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func (j *Judge[T]) Equal(ctx context.Context, expected, actual T, criteria strin
},
},
}
response, err := j.client.CreateChatCompletion(ctx, request)
response, err := j.client.CreateChatCompletion(ctx, request, nil)
if err != nil {
return false, "", fmt.Errorf("failed to create chat completion request: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/tests/tester/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func (c *Client) ProxyInfo() (string, string, error) {
return "test-auth", "test-url", nil
}

func (c *Client) Call(_ context.Context, messageRequest types.CompletionRequest, _ chan<- types.CompletionStatus) (resp *types.CompletionMessage, respErr error) {
func (c *Client) Call(_ context.Context, messageRequest types.CompletionRequest, _ []string, _ chan<- types.CompletionStatus) (resp *types.CompletionMessage, respErr error) {
msgData, err := json.MarshalIndent(messageRequest, "", " ")
require.NoError(c.t, err)

Expand Down

0 comments on commit bda5f60

Please sign in to comment.