Skip to content

Commit

Permalink
Add support for AzureAI (#38)
Browse files Browse the repository at this point in the history
* Update the config command to set OpenAI deployments.

* Add documentation for using Azure OpenAI.
  • Loading branch information
jlewi authored Apr 9, 2024
1 parent 75db6ad commit fa47aed
Show file tree
Hide file tree
Showing 7 changed files with 334 additions and 20 deletions.
29 changes: 23 additions & 6 deletions app/cmd/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,30 @@ func NewSetConfigCmd() *cobra.Command {
}

pieces := strings.Split(args[0], "=")
if len(pieces) < 2 {
return errors.New("Invalid usage; set expects an argument in the form <NAME>=<VALUE>")
}
cfgName := pieces[0]
cfgValue := pieces[1]
viper.Set(cfgName, cfgValue)
fConfig := config.GetConfig()

var fConfig *config.Config
switch cfgName {
case "azureOpenAI.deployments":
if len(pieces) != 3 {
return errors.New("Invalid argument; argument is not in the form azureOpenAI.deployments=<model>=<deployment>")
}

d := config.AzureDeployment{
Model: pieces[1],
Deployment: pieces[2],
}

fConfig = config.GetConfig()
config.SetAzureDeployment(fConfig, d)
default:
if len(pieces) < 2 {
return errors.New("Invalid usage; set expects an argument in the form <NAME>=<VALUE>")
}
cfgValue := pieces[1]
viper.Set(cfgName, cfgValue)
fConfig = config.GetConfig()
}

file := viper.ConfigFileUsed()
if file == "" {
Expand Down
2 changes: 1 addition & 1 deletion app/pkg/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func (a *Agent) completeWithRetries(ctx context.Context, req *v1alpha1.GenerateR
},
}
request := openai.ChatCompletionRequest{
Model: openai.GPT3Dot5Turbo0125,
Model: oai.DefaultModel,
Messages: messages,
MaxTokens: 2000,
Temperature: temperature,
Expand Down
19 changes: 19 additions & 0 deletions app/pkg/config/azure.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package config

func SetAzureDeployment(cfg *Config, d AzureDeployment) {
if cfg.AzureOpenAI == nil {
cfg.AzureOpenAI = &AzureOpenAIConfig{}
}
if cfg.AzureOpenAI.Deployments == nil {
cfg.AzureOpenAI.Deployments = make([]AzureDeployment, 0, 1)
}
// First check if there is a deployment for the model and if there is update it
for i := range cfg.AzureOpenAI.Deployments {
if cfg.AzureOpenAI.Deployments[i].Model == d.Model {
cfg.AzureOpenAI.Deployments[i].Deployment = d.Deployment
return
}
}

cfg.AzureOpenAI.Deployments = append(cfg.AzureOpenAI.Deployments, d)
}
27 changes: 27 additions & 0 deletions app/pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ type Config struct {
Server ServerConfig `json:"server" yaml:"server"`
Assets AssetConfig `json:"assets" yaml:"assets"`
OpenAI OpenAIConfig `json:"openai" yaml:"openai"`
// AzureOpenAI contains configuration for Azure OpenAI. A non nil value means use Azure OpenAI.
AzureOpenAI *AzureOpenAIConfig `json:"azureOpenAI,omitempty" yaml:"azureOpenAI,omitempty"`
}

// ServerConfig configures the server
Expand Down Expand Up @@ -71,6 +73,31 @@ type OpenAIConfig struct {
APIKeyFile string `json:"apiKeyFile" yaml:"apiKeyFile"`
}

type AzureOpenAIConfig struct {
// APIKeyFile is the path to the file containing the API key
APIKeyFile string `json:"apiKeyFile" yaml:"apiKeyFile"`

// BaseURL is the baseURL for the API.
// This can be obtained using the Azure CLI with the command:
// az cognitiveservices account show \
// --name <myResourceName> \
// --resource-group <myResourceGroupName> \
// | jq -r .properties.endpoint
BaseURL string `json:"baseURL" yaml:"baseURL"`

// Deployments is a list of Azure deployments of various models.
Deployments []AzureDeployment `json:"deployments" yaml:"deployments"`
}

type AzureDeployment struct {
// Deployment is the Azure Deployment name
Deployment string `json:"deployment" yaml:"deployment"`

// Model is the OpenAI name for this model
// This is used to map OpenAI models to Azure deployments
Model string `json:"model" yaml:"model"`
}

type CorsConfig struct {
// AllowedOrigins is a list of origins allowed to make cross-origin requests.
AllowedOrigins []string `json:"allowedOrigins" yaml:"allowedOrigins"`
Expand Down
120 changes: 107 additions & 13 deletions app/pkg/oai/client.go
Original file line number Diff line number Diff line change
@@ -1,27 +1,31 @@
package oai

import (
"net/url"
"strings"

"github.com/go-logr/zapr"
"go.uber.org/zap"

"github.com/hashicorp/go-retryablehttp"
"github.com/jlewi/foyle/app/pkg/config"
"github.com/jlewi/hydros/pkg/files"
"github.com/pkg/errors"
"github.com/sashabaranov/go-openai"
)

const (
DefaultModel = openai.GPT3Dot5Turbo0125

// AzureOpenAIVersion is the version of the Azure OpenAI API to use.
// For a list of versions see:
// https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions
AzureOpenAIVersion = "2024-02-01"
)

// NewClient helper function to create a new OpenAI client from a config
func NewClient(cfg config.Config) (*openai.Client, error) {
if cfg.OpenAI.APIKeyFile == "" {
return nil, errors.New("OpenAI APIKeyFile is required")
}
apiKeyBytes, err := files.Read(cfg.OpenAI.APIKeyFile)
if err != nil {
return nil, errors.Wrapf(err, "could not read OpenAI APIKeyFile: %v", cfg.OpenAI.APIKeyFile)
}
// make sure there is no leading or trailing whitespace
apiKey := strings.TrimSpace(string(apiKeyBytes))

log := zapr.NewLogger(zap.L())
// ************************************************************************
// Setup middleware
// ************************************************************************
Expand All @@ -32,9 +36,99 @@ func NewClient(cfg config.Config) (*openai.Client, error) {
retryClient := retryablehttp.NewClient()
httpClient := retryClient.StandardClient()

clientCfg := openai.DefaultConfig(apiKey)
clientCfg.HTTPClient = httpClient
client := openai.NewClientWithConfig(clientCfg)
var clientConfig openai.ClientConfig
if cfg.AzureOpenAI != nil {
var clientErr error
clientConfig, clientErr = buildAzureConfig(cfg)

if clientErr != nil {
return nil, clientErr
}
} else {
log.Info("Configuring OpenAI client")
apiKey, err := readAPIKey(cfg.OpenAI.APIKeyFile)
if err != nil {
return nil, err
}
clientConfig = openai.DefaultConfig(apiKey)
}
clientConfig.HTTPClient = httpClient
client := openai.NewClientWithConfig(clientConfig)

return client, nil
}

// buildAzureConfig helper function to create a new Azure OpenAI client config
func buildAzureConfig(cfg config.Config) (openai.ClientConfig, error) {
apiKey, err := readAPIKey(cfg.AzureOpenAI.APIKeyFile)
if err != nil {
return openai.ClientConfig{}, err
}
u, err := url.Parse(cfg.AzureOpenAI.BaseURL)
if err != nil {
return openai.ClientConfig{}, errors.Wrapf(err, "could not parse Azure OpenAI BaseURL: %v", cfg.AzureOpenAI.BaseURL)
}

if u.Scheme != "https" {
return openai.ClientConfig{}, errors.Errorf("Azure BaseURL %s is not valid; it must use the scheme https", cfg.AzureOpenAI.BaseURL)
}

// Check that all required models are deployed
required := map[string]bool{
DefaultModel: true,
}

for _, d := range cfg.AzureOpenAI.Deployments {
delete(required, d.Model)
}

if len(required) > 0 {
models := make([]string, 0, len(required))
for m := range required {
models = append(models, m)
}
return openai.ClientConfig{}, errors.Errorf("Missing Azure deployments for for OpenAI models %v; update AzureOpenAIConfig.deployments in your configuration to specify deployments for these models ", strings.Join(models, ", "))
}
log := zapr.NewLogger(zap.L())
log.Info("Configuring Azure OpenAI", "baseURL", cfg.AzureOpenAI.BaseURL, "deployments", cfg.AzureOpenAI.Deployments)
clientConfig := openai.DefaultAzureConfig(apiKey, cfg.AzureOpenAI.BaseURL)
clientConfig.APIVersion = AzureOpenAIVersion
mapper := AzureModelMapper{
modelToDeployment: make(map[string]string),
}
for _, m := range cfg.AzureOpenAI.Deployments {
mapper.modelToDeployment[m.Model] = m.Deployment
}
clientConfig.AzureModelMapperFunc = mapper.Map

return clientConfig, nil
}

// AzureModelMapper maps OpenAI models to Azure deployments
type AzureModelMapper struct {
modelToDeployment map[string]string
}

// Map maps an OpenAI model to an Azure deployment
func (m AzureModelMapper) Map(model string) string {
log := zapr.NewLogger(zap.L())
deployment, ok := m.modelToDeployment[model]
if !ok {
log.Error(errors.Errorf("No AzureAI deployment found for model %v", model), "missing deployment", "model", model)
return "missing-deployment"
}
return deployment
}

func readAPIKey(apiKeyFile string) (string, error) {
if apiKeyFile == "" {
return "", errors.New("APIKeyFile is required")
}
apiKeyBytes, err := files.Read(apiKeyFile)
if err != nil {
return "", errors.Wrapf(err, "could not read APIKeyFile: %v", apiKeyFile)
}
// make sure there is no leading or trailing whitespace
apiKey := strings.TrimSpace(string(apiKeyBytes))
return apiKey, nil
}
45 changes: 45 additions & 0 deletions app/pkg/oai/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package oai

import (
"os"
"testing"

"github.com/jlewi/foyle/app/pkg/config"
)

func Test_BuildAzureAIConfig(t *testing.T) {
f, err := os.CreateTemp("", "key.txt")
if err != nil {
t.Fatalf("Error creating temp file: %v", err)
}
if _, err := f.WriteString("somekey"); err != nil {
t.Fatalf("Error writing to temp file: %v", err)
}

cfg := &config.Config{
AzureOpenAI: &config.AzureOpenAIConfig{
APIKeyFile: f.Name(),
BaseURL: "https://someurl.com",
Deployments: []config.AzureDeployment{
{
Model: DefaultModel,
Deployment: "somedeployment",
},
},
},
}

if err := f.Close(); err != nil {
t.Fatalf("Error closing temp file: %v", err)
}
defer os.Remove(f.Name())

clientConfig, err := buildAzureConfig(*cfg)
if err != nil {
t.Fatalf("Error building Azure config: %v", err)
}

if clientConfig.BaseURL != "https://someurl.com" {
t.Fatalf("Expected BaseURL to be https://someurl.com but got %v", clientConfig.BaseURL)
}
}
Loading

0 comments on commit fa47aed

Please sign in to comment.