diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 702352f..7cf9f4a 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -95,23 +95,3 @@ jobs: GITHUB_TOKEN: ${{ steps.generate_token.outputs.token }} HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }} SCOOP_TAP_GITHUB_TOKEN: ${{ secrets.SCOOP_TAP_GITHUB_TOKEN }} - - # Update release PR -# - uses: actions/github-script@d7906e4ad0b1822421a7e6a35d5ca353c962f410 # v6 -# if: ${{ steps.release.outputs.release_created }} -# with: -# github-token: ${{ steps.generate_token.outputs.token }} -# script: | -# github.rest.issues.removeLabel({ -# owner: context.repo.owner, -# repo: context.repo.repo, -# issue_number: ${{ steps.release.outputs.pr }}, -# name: 'autorelease: tagged' -# }); -# -# github.rest.issues.addLabels({ -# owner: context.repo.owner, -# repo: context.repo.repo, -# issue_number: ${{ steps.release.outputs.pr }}, -# labels: ['autorelease: published'] -# }); diff --git a/cli/config_test.go b/cli/config_test.go deleted file mode 100644 index b7052a2..0000000 --- a/cli/config_test.go +++ /dev/null @@ -1,124 +0,0 @@ -// Copyright (c) 2023 Tim -// -// Permission is hereby granted, free of charge, to any person obtaining a copy of -// this software and associated documentation files (the "Software"), to deal in -// the Software without restriction, including without limitation the rights to -// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of -// the Software, and to permit persons to whom the Software is furnished to do so, -// subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all -// copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS -// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR -// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER -// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN -// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -// -// SPDX-License-Identifier: MIT - -package cli - -import ( - "path/filepath" - "testing" - - "github.com/spf13/viper" - "github.com/stretchr/testify/require" - "github.com/tbckr/sgpt/v2/api" -) - -func TestConfigCmd(t *testing.T) { - mem := &exitMemento{} - - config := createTestConfig(t) - - root := newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), api.MockClient("", nil)) - - root.Execute([]string{"config"}) - require.Equal(t, 0, mem.code) -} - -func TestConfigCmdInit(t *testing.T) { - mem := &exitMemento{} - - configDir := t.TempDir() - - config := viper.New() - config.AddConfigPath(configDir) - config.SetConfigName("config") - config.SetConfigType("yaml") - config.Set("TESTING", 1) - - newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), api.MockClient("", nil)).Execute([]string{"config", "init"}) - require.Equal(t, 0, mem.code) - - require.FileExists(t, filepath.Join(configDir, "config.yaml")) - // config must only contain values for model, maxtokens, temperature, topp - require.NoError(t, config.ReadInConfig()) - // TESTING may be in the config, because this is a test - require.Equal(t, 5, len(config.AllSettings())) - for _, key := range []string{"model", "maxtokens", "temperature", "topp", "testing"} { - require.Contains(t, config.AllSettings(), key) - } -} - -func TestConfigCmdInitAlreadyExists(t *testing.T) { - mem := &exitMemento{} - - configDir := t.TempDir() - - config := viper.New() - config.AddConfigPath(configDir) - config.SetConfigName("config") - config.SetConfigType("yaml") - config.Set("TESTING", 1) - - newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), api.MockClient("", nil)).Execute([]string{"config", "init"}) - require.Equal(t, 0, mem.code) - - require.FileExists(t, filepath.Join(configDir, "config.yaml")) - - newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), api.MockClient("", nil)).Execute([]string{"config", "init"}) - require.Equal(t, 1, mem.code) -} - -func TestConfigCmdShowConfig(t *testing.T) { - mem := &exitMemento{} - - configDir := t.TempDir() - - config := viper.New() - config.AddConfigPath(configDir) - config.SetConfigName("config") - config.SetConfigType("yaml") - config.Set("TESTING", 1) - require.NoError(t, setViperDefaults(config)) - - newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), api.MockClient("", nil)).Execute([]string{"config", "init"}) - require.Equal(t, 0, mem.code) - - require.FileExists(t, filepath.Join(configDir, "config.yaml")) - - newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), api.MockClient("", nil)).Execute([]string{"config", "show"}) - require.Equal(t, 0, mem.code) -} - -func TestConfigCmdShowConfigNonExistent(t *testing.T) { - mem := &exitMemento{} - - configDir := t.TempDir() - - config := viper.New() - config.AddConfigPath(configDir) - config.SetConfigName("config") - config.SetConfigType("yaml") - config.Set("TESTING", 1) - - require.NoFileExists(t, filepath.Join(configDir, "config.yaml")) - - newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), api.MockClient("", nil)).Execute([]string{"config", "show"}) - require.Equal(t, 1, mem.code) -} diff --git a/cmd/sgpt/main.go b/cmd/sgpt/main.go index c67d8d0..e23f68e 100644 --- a/cmd/sgpt/main.go +++ b/cmd/sgpt/main.go @@ -24,7 +24,7 @@ package main import ( "os" - "github.com/tbckr/sgpt/v2/cli" + "github.com/tbckr/sgpt/v2/pkg/cli" ) func main() { diff --git a/docs/getting-started.md b/docs/getting-started.md index 7132b9e..e807db6 100644 --- a/docs/getting-started.md +++ b/docs/getting-started.md @@ -36,6 +36,9 @@ $ echo -n "mass of sun" | sgpt The mass of the sun is approximately 1.989 x 10^30 kilograms. ``` +If you want to stream the completion to the command line, you can add the `--stream` flag. This will stream the output +to the command line as it is generated. + ## Code Generation Capabilities By adding the `code` command to your prompt, you can generate code based on given instructions by using the diff --git a/docs/usage/query-models.md b/docs/usage/query-models.md index 59b37aa..cf54f18 100644 --- a/docs/usage/query-models.md +++ b/docs/usage/query-models.md @@ -9,6 +9,9 @@ $ sgpt "mass of sun" The mass of the sun is approximately 1.989 x 10^30 kilograms. ``` +If you want to stream the completion to the command line, you can add the `--stream` flag. This will stream the output +to the command line as it is generated. + You can also pass prompts to SGPT using pipes: ```shell diff --git a/go.mod b/go.mod index d3c60f0..0721ac0 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( github.com/fsnotify/fsnotify v1.6.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/jarcoal/httpmock v1.3.1 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/muesli/mango v0.1.0 // indirect diff --git a/go.sum b/go.sum index a93123b..61cd7e6 100644 --- a/go.sum +++ b/go.sum @@ -130,6 +130,8 @@ github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1: github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/jarcoal/httpmock v1.3.1 h1:iUx3whfZWVf3jT01hQTO/Eo5sAYtB2/rqaUuOtpInww= +github.com/jarcoal/httpmock v1.3.1/go.mod h1:3yb8rc4BI7TCBhFY8ng0gjuLKJNquuDNiPaZjnENuYg= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= diff --git a/internal/testlib/env.go b/internal/testlib/env.go new file mode 100644 index 0000000..643d22b --- /dev/null +++ b/internal/testlib/env.go @@ -0,0 +1,38 @@ +// Copyright (c) 2023 Tim +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +// +// SPDX-License-Identifier: MIT + +package testlib + +import ( + "os" + "testing" + + "github.com/stretchr/testify/require" +) + +func SetAPIKey(t *testing.T) { + err := os.Setenv("OPENAI_API_KEY", "test") + require.NoError(t, err) + + t.Cleanup(func() { + _ = os.Unsetenv("OPENAI_API_KEY") + }) +} diff --git a/internal/testlib/httmock.go b/internal/testlib/httmock.go new file mode 100644 index 0000000..5e7d8f5 --- /dev/null +++ b/internal/testlib/httmock.go @@ -0,0 +1,95 @@ +// Copyright (c) 2023 Tim +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +// +// SPDX-License-Identifier: MIT + +package testlib + +import ( + "bytes" + "fmt" + "net/http" + + "github.com/jarcoal/httpmock" +) + +const ( + baseURL = "https://api.openai.com/v1" + chatCompletionSuffix = "/chat/completions" +) + +func RegisterExpectedChatResponse(response string) { + httpmock.RegisterResponder( + "POST", + fmt.Sprintf("%s%s", baseURL, chatCompletionSuffix), + httpmock.NewStringResponder( + 200, + fmt.Sprintf(`{ + "choices": [ + { + "index": 0, + "finish_reason": "length", + "message": { + "role": "assistant", + "content": "%s" + } + } + ] + }`, response), + ), + ) +} + +func RegisterExpectedChatResponseStream(response string) { + httpmock.RegisterResponder( + "POST", + fmt.Sprintf("%s%s", baseURL, chatCompletionSuffix), + func(request *http.Request) (*http.Response, error) { + // Reference: https://github.com/sashabaranov/go-openai/blob/a09cb0c528c110a6955a9ee9a5d021a57ed44b90/chat_stream_test.go#L39 + data := createStreamedMessages(response) + resp := httpmock.NewBytesResponse(200, data) + resp.Header.Set("Content-Type", "text/event-stream") + return resp, nil + }, + ) +} + +func createStreamedMessages(response string) []byte { + // Data is written in the format of Server-Sent Events (SSE). + // The first message is the event name, followed by the data. + // Right now, there are two events: "message" and "done". The "done" event is sent when the completion is finished. + const ( + eventMessage = "event: message\n" + + dataMessageTemplate = "data: %s\n\n" + messageTemplate = `{"id":"1","object":"completion","created":1598069254,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"%c"},"finish_reason":"max_tokens"}]}` + + eventDone = "event: done\n" + dataDone = "data: [DONE]\n\n" + ) + var buff bytes.Buffer + for _, char := range response { + buff.WriteString(eventMessage) + buff.WriteString(fmt.Sprintf(dataMessageTemplate, fmt.Sprintf(messageTemplate, char))) + } + buff.WriteString(eventDone) + buff.WriteString(dataDone) + + return buff.Bytes() +} diff --git a/internal/testlib/testctx.go b/internal/testlib/testctx.go new file mode 100644 index 0000000..007d974 --- /dev/null +++ b/internal/testlib/testctx.go @@ -0,0 +1,57 @@ +// Copyright (c) 2023 Tim +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +// +// SPDX-License-Identifier: MIT + +package testlib + +import ( + "testing" + + "github.com/spf13/viper" +) + +type TestCtx struct { + Config *viper.Viper + + ConfigDir string + CacheDir string + PersonasDir string +} + +func NewTestCtx(t *testing.T) *TestCtx { + cacheDir := t.TempDir() + configDir := t.TempDir() + personasDir := t.TempDir() + + config := viper.New() + config.AddConfigPath(configDir) + config.SetConfigName("config") + config.SetConfigType("yaml") + config.Set("cacheDir", cacheDir) + config.Set("personas", personasDir) + config.Set("TESTING", 1) + + return &TestCtx{ + Config: config, + ConfigDir: configDir, + CacheDir: cacheDir, + PersonasDir: personasDir, + } +} diff --git a/api/api.go b/pkg/api/api.go similarity index 51% rename from api/api.go rename to pkg/api/api.go index 5347d82..0e3e66e 100644 --- a/api/api.go +++ b/pkg/api/api.go @@ -23,54 +23,42 @@ package api import ( "context" + "errors" "fmt" + "io" "log/slog" "net/http" "os" "strings" - "github.com/tbckr/sgpt/v2/chat" - "github.com/sashabaranov/go-openai" "github.com/spf13/viper" - "github.com/tbckr/sgpt/v2/modifiers" + "github.com/tbckr/sgpt/v2/pkg/chat" + "github.com/tbckr/sgpt/v2/pkg/modifiers" ) const ( + // envKeyOpenAIApi is the environment variable key for the OpenAI API key. envKeyOpenAIApi = "OPENAI_API_KEY" ) var ( - DefaultModel = strings.Clone(openai.GPT3Dot5Turbo) + // DefaultModel is the default model used for chat completions. + DefaultModel = strings.Clone(openai.GPT3Dot5Turbo) + // ErrMissingAPIKey is returned, if the OPENAI_API_KEY environment variable is not set. ErrMissingAPIKey = fmt.Errorf("%s env variable is not set", envKeyOpenAIApi) ) +// OpenAIClient is a client for the OpenAI API. type OpenAIClient struct { - api *openai.Client - retrieveResponseFn func(*openai.Client, context.Context, openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) -} - -func MockClient(response string, err error) func() (*OpenAIClient, error) { - return func() (*OpenAIClient, error) { - return &OpenAIClient{ - api: nil, - retrieveResponseFn: func(_ *openai.Client, _ context.Context, _ openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) { - return openai.ChatCompletionResponse{ - Choices: []openai.ChatCompletionChoice{ - { - Message: openai.ChatCompletionMessage{ - Role: openai.ChatMessageRoleAssistant, - Content: response, - }, - }, - }, - }, nil - }, - }, err - } + HTTPClient *http.Client + config *viper.Viper + api *openai.Client + out io.Writer } -func CreateClient() (*OpenAIClient, error) { +// CreateClient creates a new OpenAI client with the given config and output writer. +func CreateClient(config *viper.Viper, out io.Writer) (*OpenAIClient, error) { // Check, if api key was set apiKey, exists := os.LookupEnv(envKeyOpenAIApi) if !exists { @@ -82,9 +70,10 @@ func CreateClient() (*OpenAIClient, error) { transport := &http.Transport{ Proxy: http.ProxyFromEnvironment, } - clientConfig.HTTPClient = &http.Client{ + httpClient := &http.Client{ Transport: transport, } + clientConfig.HTTPClient = httpClient // Check, if API base url was set baseURL, isSet := os.LookupEnv("OPENAI_API_BASE") @@ -96,23 +85,24 @@ func CreateClient() (*OpenAIClient, error) { // Create client client := &OpenAIClient{ - api: openai.NewClientWithConfig(clientConfig), - // This is necessary to be able to mock the api in tests - retrieveResponseFn: func(api *openai.Client, ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) { - return api.CreateChatCompletion(ctx, req) - }, + HTTPClient: httpClient, + config: config, + api: openai.NewClientWithConfig(clientConfig), + out: out, } - slog.Debug("OpenAI client created") return client, nil } -func (c *OpenAIClient) GetChatCompletion(ctx context.Context, config *viper.Viper, chatID, prompt, modifier string) (string, error) { +// CreateCompletion creates a completion for the given prompt and modifier. If chatID is provided, the chat is reused +// and the completion is added to the chat with this ID. If no chatID is provided, only the modifier and prompt are +// used to create the completion. The completion is printed to the out writer of the client and returned as a string. +func (c *OpenAIClient) CreateCompletion(ctx context.Context, chatID, prompt, modifier string) (string, error) { var err error var chatSessionManager chat.SessionManager var messages []openai.ChatCompletionMessage - chatSessionManager, err = chat.NewFilesystemChatSessionManager(config) + chatSessionManager, err = chat.NewFilesystemChatSessionManager(c.config) if err != nil { return "", err } @@ -146,7 +136,7 @@ func (c *OpenAIClient) GetChatCompletion(ctx context.Context, config *viper.Vipe // then add modifier message if !isChat || (isChat && !chatExists) { var modifierPrompt string - modifierPrompt, err = modifiers.GetChatModifier(config, modifier) + modifierPrompt, err = modifiers.GetChatModifier(c.config, modifier) if err != nil { return "", err } @@ -166,24 +156,30 @@ func (c *OpenAIClient) GetChatCompletion(ctx context.Context, config *viper.Vipe }) slog.Debug("Added prompt message") - // Do request + // Create request req := openai.ChatCompletionRequest{ Messages: messages, - Model: config.GetString("model"), - MaxTokens: config.GetInt("max-tokens"), - Temperature: float32(config.GetFloat64("temperature")), - TopP: float32(config.GetFloat64("top-p")), + Model: c.config.GetString("model"), + MaxTokens: c.config.GetInt("max-tokens"), + Temperature: float32(c.config.GetFloat64("temperature")), + TopP: float32(c.config.GetFloat64("top-p")), + Stream: c.config.GetBool("stream"), + } + + // Retrieve response + // Retrieve the completion and print to the out writer. The received message is returned to save it to the chat and + // to return it as a string (copy to clipboard). + var receivedMessage openai.ChatCompletionMessage + if c.config.GetBool("stream") { + receivedMessage, err = c.retrieveChatCompletionStream(ctx, req) + } else { + receivedMessage, err = c.retrieveChatCompletion(ctx, req) } - var resp openai.ChatCompletionResponse - resp, err = c.retrieveResponseFn(c.api, ctx, req) if err != nil { return "", err } - receivedMessage := resp.Choices[0].Message - slog.Debug("Received message from OpenAI API") - // Remove surrounding white spaces - receivedMessage.Content = strings.TrimSpace(receivedMessage.Content) + slog.Debug("Received message from OpenAI API") // If a session was provided, save received message to this chat if isChat { @@ -196,3 +192,58 @@ func (c *OpenAIClient) GetChatCompletion(ctx context.Context, config *viper.Vipe // Return received message return receivedMessage.Content, nil } + +func (c *OpenAIClient) retrieveChatCompletion(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionMessage, error) { + resp, err := c.api.CreateChatCompletion(ctx, req) + if err != nil { + return openai.ChatCompletionMessage{}, err + } + slog.Debug("Received response") + receivedMessage := resp.Choices[0].Message + + _, err = fmt.Fprintln(c.out, receivedMessage.Content) + if err != nil { + return openai.ChatCompletionMessage{}, err + } + slog.Debug("Printed response") + + return receivedMessage, nil +} + +func (c *OpenAIClient) retrieveChatCompletionStream(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionMessage, error) { + stream, err := c.api.CreateChatCompletionStream(ctx, req) + if err != nil { + return openai.ChatCompletionMessage{}, err + } + defer stream.Close() + slog.Debug("Streaming response") + + var receivedMessage openai.ChatCompletionMessage + for { + response, streamErr := stream.Recv() + if errors.Is(streamErr, io.EOF) { + slog.Debug("Stream finished") + break + } + if streamErr != nil { + slog.Debug("Stream error encountered") + return openai.ChatCompletionMessage{}, streamErr + } + + receivedContent := response.Choices[0].Delta.Content + // 1. Append received content to message + receivedMessage.Content += receivedContent + // 2. Print received content + _, err = fmt.Fprint(c.out, receivedContent) + if err != nil { + return openai.ChatCompletionMessage{}, err + } + } + // Print final linebreak + _, err = fmt.Fprintf(c.out, "\n") + if err != nil { + slog.Warn("Could not print final linebreak") + } + // Return received message to save it to the chat session + return receivedMessage, nil +} diff --git a/api/api_test.go b/pkg/api/api_test.go similarity index 53% rename from api/api_test.go rename to pkg/api/api_test.go index 4be1755..9bcd2aa 100644 --- a/api/api_test.go +++ b/pkg/api/api_test.go @@ -22,51 +22,28 @@ package api import ( + "bytes" "context" + "io" "os" "path/filepath" - "strings" + "sync" "testing" + "github.com/jarcoal/httpmock" "github.com/sashabaranov/go-openai" - "github.com/spf13/viper" "github.com/stretchr/testify/require" - "github.com/tbckr/sgpt/v2/chat" + "github.com/tbckr/sgpt/v2/internal/testlib" + "github.com/tbckr/sgpt/v2/pkg/chat" ) -func createTestConfig(t *testing.T) *viper.Viper { - cacheDir := createTempDir(t, "cache") - configDir := createTempDir(t, "config") - - config := viper.New() - config.AddConfigPath(configDir) - config.SetConfigName("config") - config.SetConfigType("yaml") - config.Set("cacheDir", cacheDir) - config.Set("TESTING", 1) - - return config -} - -func createTempDir(t *testing.T, suffix string) string { - if suffix != "" { - suffix = "_" + suffix - } - tempFilepath, err := os.MkdirTemp("", strings.Join([]string{"sgpt_temp_*", suffix}, "")) - require.NoError(t, err) - t.Cleanup(func() { - require.NoError(t, os.RemoveAll(tempFilepath)) - }) - return tempFilepath -} - func TestCreateClient(t *testing.T) { // Set the api key err := os.Setenv("OPENAI_API_KEY", "test") require.NoError(t, err) var client *OpenAIClient - client, err = CreateClient() + client, err = CreateClient(nil, nil) require.NoError(t, err) require.NotNil(t, client) } @@ -76,27 +53,47 @@ func TestCreateClientMissingApiKey(t *testing.T) { require.NoError(t, err) var client *OpenAIClient - client, err = CreateClient() + client, err = CreateClient(nil, nil) require.Error(t, err) require.ErrorIs(t, err, ErrMissingAPIKey) require.Nil(t, client) } func TestSimplePrompt(t *testing.T) { + testCtx := testlib.NewTestCtx(t) + testlib.SetAPIKey(t) + + var wg sync.WaitGroup + reader, writer := io.Pipe() + + client, err := CreateClient(testCtx.Config, writer) + require.NoError(t, err) + prompt := "Say: Hello World!" expected := "Hello World!" - client, err := MockClient(strings.Clone(expected), nil)() - require.NoError(t, err) - config := createTestConfig(t) + httpmock.ActivateNonDefault(client.HTTPClient) + t.Cleanup(httpmock.DeactivateAndReset) + testlib.RegisterExpectedChatResponse(expected) + + wg.Add(1) + go func() { + defer wg.Done() + var buf bytes.Buffer + _, errReader := io.Copy(&buf, reader) + require.NoError(t, errReader) + require.NoError(t, reader.Close()) + require.Equal(t, expected+"\n", buf.String()) + }() var result string - result, err = client.GetChatCompletion(context.Background(), config, "", prompt, "txt") + result, err = client.CreateCompletion(context.Background(), "", prompt, "txt") require.NoError(t, err) require.Equal(t, expected, result) + require.NoError(t, writer.Close()) // Cache dir should be empty - cacheDir := config.GetString("cacheDir") + cacheDir := testCtx.Config.GetString("cacheDir") err = filepath.Walk(cacheDir, func(path string, info os.FileInfo, err error) error { if path == cacheDir { // Skip the root dir @@ -107,25 +104,85 @@ func TestSimplePrompt(t *testing.T) { return nil }) require.NoError(t, err) + + wg.Wait() } -func TestPromptSaveAsChat(t *testing.T) { +func TestStreamSimplePrompt(t *testing.T) { + testCtx := testlib.NewTestCtx(t) + testlib.SetAPIKey(t) + + var wg sync.WaitGroup + reader, writer := io.Pipe() + + client, err := CreateClient(testCtx.Config, writer) + require.NoError(t, err) + prompt := "Say: Hello World!" expected := "Hello World!" - client, err := MockClient(strings.Clone(expected), nil)() + httpmock.ActivateNonDefault(client.HTTPClient) + t.Cleanup(httpmock.DeactivateAndReset) + testlib.RegisterExpectedChatResponseStream(expected) + + testCtx.Config.Set("stream", true) + + wg.Add(1) + go func() { + defer wg.Done() + var buf bytes.Buffer + _, errReader := io.Copy(&buf, reader) + require.NoError(t, errReader) + require.NoError(t, reader.Close()) + require.Equal(t, expected+"\n", buf.String()) + }() + + var result string + result, err = client.CreateCompletion(context.Background(), "", prompt, "txt") require.NoError(t, err) - config := createTestConfig(t) + require.Equal(t, expected, result) + require.NoError(t, writer.Close()) + + wg.Wait() +} + +func TestPromptSaveAsChat(t *testing.T) { + testCtx := testlib.NewTestCtx(t) + testlib.SetAPIKey(t) + + var wg sync.WaitGroup + reader, writer := io.Pipe() + + client, err := CreateClient(testCtx.Config, writer) + require.NoError(t, err) + + prompt := "Say: Hello World!" + expected := "Hello World!" + + httpmock.ActivateNonDefault(client.HTTPClient) + t.Cleanup(httpmock.DeactivateAndReset) + testlib.RegisterExpectedChatResponse(expected) + + wg.Add(1) + go func() { + defer wg.Done() + var buf bytes.Buffer + _, errReader := io.Copy(&buf, reader) + require.NoError(t, errReader) + require.NoError(t, reader.Close()) + require.Equal(t, expected+"\n", buf.String()) + }() var result string - result, err = client.GetChatCompletion(context.Background(), config, "test_chat", prompt, "txt") + result, err = client.CreateCompletion(context.Background(), "test_chat", prompt, "txt") require.NoError(t, err) require.Equal(t, expected, result) + require.NoError(t, writer.Close()) - require.FileExists(t, filepath.Join(config.GetString("cacheDir"), "test_chat")) + require.FileExists(t, filepath.Join(testCtx.Config.GetString("cacheDir"), "test_chat")) var manager chat.SessionManager - manager, err = chat.NewFilesystemChatSessionManager(config) + manager, err = chat.NewFilesystemChatSessionManager(testCtx.Config) require.NoError(t, err) var messages []openai.ChatCompletionMessage @@ -140,18 +197,29 @@ func TestPromptSaveAsChat(t *testing.T) { // Check if the response was added require.Equal(t, openai.ChatMessageRoleAssistant, messages[1].Role) require.Equal(t, expected, messages[1].Content) + + wg.Wait() } func TestPromptLoadChat(t *testing.T) { + testCtx := testlib.NewTestCtx(t) + testlib.SetAPIKey(t) + + var wg sync.WaitGroup + reader, writer := io.Pipe() + + client, err := CreateClient(testCtx.Config, writer) + require.NoError(t, err) + prompt := "Repeat last message" expected := "World!" - client, err := MockClient(strings.Clone(expected), nil)() - require.NoError(t, err) - config := createTestConfig(t) + httpmock.ActivateNonDefault(client.HTTPClient) + t.Cleanup(httpmock.DeactivateAndReset) + testlib.RegisterExpectedChatResponse(expected) var manager chat.SessionManager - manager, err = chat.NewFilesystemChatSessionManager(config) + manager, err = chat.NewFilesystemChatSessionManager(testCtx.Config) require.NoError(t, err) err = manager.SaveSession("test_chat", []openai.ChatCompletionMessage{ @@ -166,10 +234,21 @@ func TestPromptLoadChat(t *testing.T) { }) require.NoError(t, err) + wg.Add(1) + go func() { + defer wg.Done() + var buf bytes.Buffer + _, errReader := io.Copy(&buf, reader) + require.NoError(t, errReader) + require.NoError(t, reader.Close()) + require.Equal(t, expected+"\n", buf.String()) + }() + var result string - result, err = client.GetChatCompletion(context.Background(), config, "test_chat", prompt, "txt") + result, err = client.CreateCompletion(context.Background(), "test_chat", prompt, "txt") require.NoError(t, err) require.Equal(t, expected, result) + require.NoError(t, writer.Close()) var messages []openai.ChatCompletionMessage messages, err = manager.GetSession("test_chat") @@ -183,15 +262,27 @@ func TestPromptLoadChat(t *testing.T) { // Check if the response was added require.Equal(t, openai.ChatMessageRoleAssistant, messages[3].Role) require.Equal(t, expected, messages[3].Content) + + wg.Wait() } func TestPromptWithModifier(t *testing.T) { + testCtx := testlib.NewTestCtx(t) + testlib.SetAPIKey(t) + + var wg sync.WaitGroup + reader, writer := io.Pipe() + + client, err := CreateClient(testCtx.Config, writer) + require.NoError(t, err) + prompt := "Print Hello World" + response := `echo \"Hello World\"` expected := `echo "Hello World"` - client, err := MockClient(strings.Clone(expected), nil)() - require.NoError(t, err) - config := createTestConfig(t) + httpmock.ActivateNonDefault(client.HTTPClient) + t.Cleanup(httpmock.DeactivateAndReset) + testlib.RegisterExpectedChatResponse(response) err = os.Setenv("SHELL", "/bin/bash") require.NoError(t, err) @@ -199,17 +290,28 @@ func TestPromptWithModifier(t *testing.T) { require.NoError(t, os.Unsetenv("SHELL")) }) - config.Set("chat", "test_chat") + testCtx.Config.Set("chat", "test_chat") + + wg.Add(1) + go func() { + defer wg.Done() + var buf bytes.Buffer + _, errReader := io.Copy(&buf, reader) + require.NoError(t, errReader) + require.NoError(t, reader.Close()) + require.Equal(t, expected+"\n", buf.String()) + }() var result string - result, err = client.GetChatCompletion(context.Background(), config, "test_chat", prompt, "sh") + result, err = client.CreateCompletion(context.Background(), "test_chat", prompt, "sh") require.NoError(t, err) require.Equal(t, expected, result) + require.NoError(t, writer.Close()) - require.FileExists(t, filepath.Join(config.GetString("cacheDir"), "test_chat")) + require.FileExists(t, filepath.Join(testCtx.Config.GetString("cacheDir"), "test_chat")) var manager chat.SessionManager - manager, err = chat.NewFilesystemChatSessionManager(config) + manager, err = chat.NewFilesystemChatSessionManager(testCtx.Config) require.NoError(t, err) var messages []openai.ChatCompletionMessage @@ -227,4 +329,6 @@ func TestPromptWithModifier(t *testing.T) { // Check if the response was added require.Equal(t, openai.ChatMessageRoleAssistant, messages[2].Role) require.Equal(t, expected, messages[2].Content) + + wg.Wait() } diff --git a/chat/chat.go b/pkg/chat/chat.go similarity index 100% rename from chat/chat.go rename to pkg/chat/chat.go diff --git a/chat/filesystem.go b/pkg/chat/filesystem.go similarity index 100% rename from chat/filesystem.go rename to pkg/chat/filesystem.go diff --git a/chat/filesystem_test.go b/pkg/chat/filesystem_test.go similarity index 100% rename from chat/filesystem_test.go rename to pkg/chat/filesystem_test.go diff --git a/cli/chat.go b/pkg/cli/chat.go similarity index 93% rename from cli/chat.go rename to pkg/cli/chat.go index f72b52f..0f77924 100644 --- a/cli/chat.go +++ b/pkg/cli/chat.go @@ -27,10 +27,11 @@ import ( "io" "strings" + chat2 "github.com/tbckr/sgpt/v2/pkg/chat" + "github.com/sashabaranov/go-openai" "github.com/spf13/cobra" "github.com/spf13/viper" - "github.com/tbckr/sgpt/v2/chat" ) const ( @@ -96,7 +97,7 @@ List all chat sessions. Args: cobra.NoArgs, ValidArgsFunction: cobra.NoFileCompletions, RunE: func(cmd *cobra.Command, _ []string) error { - chatSessionManager, err := chat.NewFilesystemChatSessionManager(config) + chatSessionManager, err := chat2.NewFilesystemChatSessionManager(config) if err != nil { return err } @@ -129,7 +130,7 @@ Show the conversation for the given chat session. Args: cobra.ExactArgs(1), ValidArgsFunction: cobra.NoFileCompletions, RunE: func(cmd *cobra.Command, args []string) error { - chatSessionManager, err := chat.NewFilesystemChatSessionManager(config) + chatSessionManager, err := chat2.NewFilesystemChatSessionManager(config) if err != nil { return err } @@ -166,7 +167,7 @@ Remove the specified chat session. The --all flag removes all chat sessions. Args: cobra.RangeArgs(0, 1), ValidArgsFunction: cobra.NoFileCompletions, RunE: func(cmd *cobra.Command, args []string) error { - chatSessionManager, err := chat.NewFilesystemChatSessionManager(config) + chatSessionManager, err := chat2.NewFilesystemChatSessionManager(config) if err != nil { return err } @@ -203,7 +204,7 @@ func showConversation(out io.Writer, messages []openai.ChatCompletionMessage) er return nil } -func deleteChatSessions(manager chat.SessionManager, out io.Writer, chatSessions []string) error { +func deleteChatSessions(manager chat2.SessionManager, out io.Writer, chatSessions []string) error { for _, chatSession := range chatSessions { err := manager.DeleteSession(chatSession) if err != nil { diff --git a/cli/chat_test.go b/pkg/cli/chat_test.go similarity index 68% rename from cli/chat_test.go rename to pkg/cli/chat_test.go index 1f66b22..a50e838 100644 --- a/cli/chat_test.go +++ b/pkg/cli/chat_test.go @@ -29,33 +29,34 @@ import ( "testing" "github.com/sashabaranov/go-openai" - "github.com/tbckr/sgpt/v2/chat" + + "github.com/tbckr/sgpt/v2/pkg/chat" + + "github.com/tbckr/sgpt/v2/internal/testlib" "github.com/stretchr/testify/require" - "github.com/tbckr/sgpt/v2/api" ) func TestChatCmd(t *testing.T) { + testCtx := testlib.NewTestCtx(t) mem := &exitMemento{} - config := createTestConfig(t) - - root := newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), api.MockClient("", nil)) + root := newRootCmd(mem.Exit, testCtx.Config, nil, nil) root.Execute([]string{"chat"}) require.Equal(t, 0, mem.code) } func TestChatCmdListEmptySessions(t *testing.T) { - expected := "" - + testCtx := testlib.NewTestCtx(t) mem := &exitMemento{} + var wg sync.WaitGroup reader, writer := io.Pipe() - config := createTestConfig(t) + expected := "" - root := newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), api.MockClient("", nil)) + root := newRootCmd(mem.Exit, testCtx.Config, nil, nil) root.cmd.SetOut(writer) wg.Add(1) @@ -76,30 +77,30 @@ func TestChatCmdListEmptySessions(t *testing.T) { } func TestChatCmdListOneSession(t *testing.T) { - expected := "test\n" - + testCtx := testlib.NewTestCtx(t) mem := &exitMemento{} + var wg sync.WaitGroup reader, writer := io.Pipe() - config := createTestConfig(t) + expected := "test\n" - manager, err := chat.NewFilesystemChatSessionManager(config) + manager, err := chat.NewFilesystemChatSessionManager(testCtx.Config) require.NoError(t, err) messages := createTestMessages() err = manager.SaveSession("test", messages) require.NoError(t, err) - root := newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), api.MockClient("", nil)) + root := newRootCmd(mem.Exit, testCtx.Config, nil, nil) root.cmd.SetOut(writer) wg.Add(1) go func() { defer wg.Done() var buf bytes.Buffer - _, err := io.Copy(&buf, reader) - require.NoError(t, err) + _, errReader := io.Copy(&buf, reader) + require.NoError(t, errReader) require.NoError(t, reader.Close()) require.Equal(t, expected, buf.String()) }() @@ -112,15 +113,15 @@ func TestChatCmdListOneSession(t *testing.T) { } func TestChatCmdListTwoSessions(t *testing.T) { - expected := "test\ntest2\n" - + testCtx := testlib.NewTestCtx(t) mem := &exitMemento{} + var wg sync.WaitGroup reader, writer := io.Pipe() - config := createTestConfig(t) + expected := "test\ntest2\n" - manager, err := chat.NewFilesystemChatSessionManager(config) + manager, err := chat.NewFilesystemChatSessionManager(testCtx.Config) require.NoError(t, err) messages := createTestMessages() @@ -129,15 +130,15 @@ func TestChatCmdListTwoSessions(t *testing.T) { err = manager.SaveSession("test2", messages) require.NoError(t, err) - root := newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), api.MockClient("", nil)) + root := newRootCmd(mem.Exit, testCtx.Config, nil, nil) root.cmd.SetOut(writer) wg.Add(1) go func() { defer wg.Done() var buf bytes.Buffer - _, err := io.Copy(&buf, reader) - require.NoError(t, err) + _, errReader := io.Copy(&buf, reader) + require.NoError(t, errReader) require.NoError(t, reader.Close()) require.Equal(t, expected, buf.String()) }() @@ -150,28 +151,28 @@ func TestChatCmdListTwoSessions(t *testing.T) { } func TestChatCmdShowSession(t *testing.T) { + testCtx := testlib.NewTestCtx(t) mem := &exitMemento{} + var wg sync.WaitGroup reader, writer := io.Pipe() - config := createTestConfig(t) - - manager, err := chat.NewFilesystemChatSessionManager(config) + manager, err := chat.NewFilesystemChatSessionManager(testCtx.Config) require.NoError(t, err) messages := createTestMessages() err = manager.SaveSession("test", messages) require.NoError(t, err) - root := newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), api.MockClient("", nil)) + root := newRootCmd(mem.Exit, testCtx.Config, nil, nil) root.cmd.SetOut(writer) wg.Add(1) go func() { defer wg.Done() var buf bytes.Buffer - _, err := io.Copy(&buf, reader) - require.NoError(t, err) + _, errReader := io.Copy(&buf, reader) + require.NoError(t, errReader) require.NoError(t, reader.Close()) require.Contains(t, buf.String(), "You are a chat bot.") require.Contains(t, buf.String(), "I am a chat bot.") @@ -185,57 +186,55 @@ func TestChatCmdShowSession(t *testing.T) { } func TestChatCmdShowSessionMissingName(t *testing.T) { + testCtx := testlib.NewTestCtx(t) mem := &exitMemento{} - config := createTestConfig(t) - - root := newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), api.MockClient("", nil)) + root := newRootCmd(mem.Exit, testCtx.Config, nil, nil) root.Execute([]string{"chat", "show"}) require.Equal(t, 1, mem.code) } func TestChatCmdShowSessionNonExistent(t *testing.T) { + testCtx := testlib.NewTestCtx(t) mem := &exitMemento{} - config := createTestConfig(t) - - manager, err := chat.NewFilesystemChatSessionManager(config) + manager, err := chat.NewFilesystemChatSessionManager(testCtx.Config) require.NoError(t, err) messages := createTestMessages() err = manager.SaveSession("test", messages) require.NoError(t, err) - root := newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), api.MockClient("", nil)) + root := newRootCmd(mem.Exit, testCtx.Config, nil, nil) root.Execute([]string{"chat", "show", "test2"}) require.Equal(t, 1, mem.code) } func TestChatCmdShowSessionWithAlias(t *testing.T) { + testCtx := testlib.NewTestCtx(t) mem := &exitMemento{} + var wg sync.WaitGroup reader, writer := io.Pipe() - config := createTestConfig(t) - - manager, err := chat.NewFilesystemChatSessionManager(config) + manager, err := chat.NewFilesystemChatSessionManager(testCtx.Config) require.NoError(t, err) messages := createTestMessages() err = manager.SaveSession("test", messages) require.NoError(t, err) - root := newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), api.MockClient("", nil)) + root := newRootCmd(mem.Exit, testCtx.Config, nil, nil) root.cmd.SetOut(writer) wg.Add(1) go func() { defer wg.Done() var buf bytes.Buffer - _, err := io.Copy(&buf, reader) - require.NoError(t, err) + _, errReader := io.Copy(&buf, reader) + require.NoError(t, errReader) require.NoError(t, reader.Close()) require.Contains(t, buf.String(), "You are a chat bot.") require.Contains(t, buf.String(), "I am a chat bot.") @@ -249,75 +248,71 @@ func TestChatCmdShowSessionWithAlias(t *testing.T) { } func TestChatCmdRmSession(t *testing.T) { + testCtx := testlib.NewTestCtx(t) mem := &exitMemento{} - config := createTestConfig(t) - - manager, err := chat.NewFilesystemChatSessionManager(config) + manager, err := chat.NewFilesystemChatSessionManager(testCtx.Config) require.NoError(t, err) messages := createTestMessages() err = manager.SaveSession("test", messages) require.NoError(t, err) - require.FileExists(t, filepath.Join(config.GetString("cacheDir"), "test")) + require.FileExists(t, filepath.Join(testCtx.Config.GetString("cacheDir"), "test")) - root := newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), api.MockClient("", nil)) + root := newRootCmd(mem.Exit, testCtx.Config, nil, nil) root.Execute([]string{"chat", "rm", "test"}) require.Equal(t, 0, mem.code) - require.NoFileExists(t, filepath.Join(config.GetString("cacheDir"), "test")) + require.NoFileExists(t, filepath.Join(testCtx.Config.GetString("cacheDir"), "test")) } func TestChatCmdRmSessionNonExistent(t *testing.T) { + testCtx := testlib.NewTestCtx(t) mem := &exitMemento{} - config := createTestConfig(t) - - manager, err := chat.NewFilesystemChatSessionManager(config) + manager, err := chat.NewFilesystemChatSessionManager(testCtx.Config) require.NoError(t, err) messages := createTestMessages() err = manager.SaveSession("test", messages) require.NoError(t, err) - require.FileExists(t, filepath.Join(config.GetString("cacheDir"), "test")) + require.FileExists(t, filepath.Join(testCtx.Config.GetString("cacheDir"), "test")) - root := newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), api.MockClient("", nil)) + root := newRootCmd(mem.Exit, testCtx.Config, nil, nil) root.Execute([]string{"chat", "rm", "test2"}) require.Equal(t, 0, mem.code) - require.FileExists(t, filepath.Join(config.GetString("cacheDir"), "test")) + require.FileExists(t, filepath.Join(testCtx.Config.GetString("cacheDir"), "test")) } func TestChatCmdRmSessionAll(t *testing.T) { + testCtx := testlib.NewTestCtx(t) mem := &exitMemento{} - config := createTestConfig(t) - - manager, err := chat.NewFilesystemChatSessionManager(config) + manager, err := chat.NewFilesystemChatSessionManager(testCtx.Config) require.NoError(t, err) messages := createTestMessages() err = manager.SaveSession("test", messages) require.NoError(t, err) - require.FileExists(t, filepath.Join(config.GetString("cacheDir"), "test")) + require.FileExists(t, filepath.Join(testCtx.Config.GetString("cacheDir"), "test")) err = manager.SaveSession("test2", messages) require.NoError(t, err) - require.FileExists(t, filepath.Join(config.GetString("cacheDir"), "test2")) + require.FileExists(t, filepath.Join(testCtx.Config.GetString("cacheDir"), "test2")) - root := newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), api.MockClient("", nil)) + root := newRootCmd(mem.Exit, testCtx.Config, nil, nil) root.Execute([]string{"chat", "rm", "--all"}) require.Equal(t, 0, mem.code) - require.NoFileExists(t, filepath.Join(config.GetString("cacheDir"), "test")) - require.NoFileExists(t, filepath.Join(config.GetString("cacheDir"), "test2")) + require.NoFileExists(t, filepath.Join(testCtx.Config.GetString("cacheDir"), "test")) + require.NoFileExists(t, filepath.Join(testCtx.Config.GetString("cacheDir"), "test2")) } func TestChatCmdRmSessionMissingName(t *testing.T) { + testCtx := testlib.NewTestCtx(t) mem := &exitMemento{} - config := createTestConfig(t) - - root := newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), api.MockClient("", nil)) + root := newRootCmd(mem.Exit, testCtx.Config, nil, nil) root.Execute([]string{"chat", "rm"}) require.Equal(t, 1, mem.code) diff --git a/cli/check.go b/pkg/cli/check.go similarity index 89% rename from cli/check.go rename to pkg/cli/check.go index 46bc7c5..6c61654 100644 --- a/cli/check.go +++ b/pkg/cli/check.go @@ -23,18 +23,20 @@ package cli import ( "fmt" + "io" "strings" + "github.com/tbckr/sgpt/v2/pkg/api" + "github.com/spf13/cobra" "github.com/spf13/viper" - "github.com/tbckr/sgpt/v2/api" ) type checkCmd struct { cmd *cobra.Command } -func newCheckCmd(config *viper.Viper, createClientFn func() (*api.OpenAIClient, error)) *checkCmd { +func newCheckCmd(config *viper.Viper, createClientFn func(*viper.Viper, io.Writer) (*api.OpenAIClient, error)) *checkCmd { check := &checkCmd{} cmd := &cobra.Command{ Use: "check", @@ -49,7 +51,7 @@ This command will return an error if the API key is not set or invalid. if err != nil { return err } - _, err = createClientFn() + _, err = createClientFn(config, cmd.OutOrStdout()) if err != nil { return err } diff --git a/cli/check_test.go b/pkg/cli/check_test.go similarity index 72% rename from cli/check_test.go rename to pkg/cli/check_test.go index 22f3fdf..31d52fc 100644 --- a/cli/check_test.go +++ b/pkg/cli/check_test.go @@ -22,35 +22,30 @@ package cli import ( - "os" "testing" + "github.com/tbckr/sgpt/v2/pkg/api" + + "github.com/tbckr/sgpt/v2/internal/testlib" + "github.com/stretchr/testify/require" - "github.com/tbckr/sgpt/v2/api" ) func TestCheckCmd(t *testing.T) { + testCtx := testlib.NewTestCtx(t) + testlib.SetAPIKey(t) mem := &exitMemento{} - config := createTestConfig(t) + testlib.SetAPIKey(t) - err := os.Setenv("OPENAI_API_KEY", "test") - require.NoError(t, err) - t.Cleanup(func() { - os.Unsetenv("OPENAI_API_KEY") - }) - - newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), api.MockClient("", nil)).Execute([]string{"check"}) + newRootCmd(mem.Exit, testCtx.Config, mockIsPipedShell(false, nil), api.CreateClient).Execute([]string{"check"}) require.Equal(t, 0, mem.code) } func TestCheckCmdUnsetEnvAPIKey(t *testing.T) { + testCtx := testlib.NewTestCtx(t) mem := &exitMemento{} - config := createTestConfig(t) - err := os.Unsetenv("OPENAI_API_KEY") - require.NoError(t, err) - - newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), api.MockClient("", api.ErrMissingAPIKey)).Execute([]string{"check"}) + newRootCmd(mem.Exit, testCtx.Config, mockIsPipedShell(false, nil), api.CreateClient).Execute([]string{"check"}) require.Equal(t, 1, mem.code) } diff --git a/cli/completion_test.go b/pkg/cli/completion_test.go similarity index 74% rename from cli/completion_test.go rename to pkg/cli/completion_test.go index e7c6825..403758d 100644 --- a/cli/completion_test.go +++ b/pkg/cli/completion_test.go @@ -25,15 +25,16 @@ import ( "io" "testing" + "github.com/tbckr/sgpt/v2/internal/testlib" + "github.com/stretchr/testify/require" ) func TestCompletionCmd(t *testing.T) { + testCtx := testlib.NewTestCtx(t) mem := &exitMemento{} - config := createTestConfig(t) - - root := newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), nil) + root := newRootCmd(mem.Exit, testCtx.Config, mockIsPipedShell(false, nil), nil) root.cmd.SetOut(io.Discard) root.Execute([]string{"completion"}) @@ -41,11 +42,10 @@ func TestCompletionCmd(t *testing.T) { } func TestCompletionCmdBash(t *testing.T) { + testCtx := testlib.NewTestCtx(t) mem := &exitMemento{} - config := createTestConfig(t) - - root := newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), nil) + root := newRootCmd(mem.Exit, testCtx.Config, mockIsPipedShell(false, nil), nil) root.cmd.SetOut(io.Discard) root.Execute([]string{"completion", "bash"}) @@ -53,11 +53,10 @@ func TestCompletionCmdBash(t *testing.T) { } func TestCompletionCmdFish(t *testing.T) { + testCtx := testlib.NewTestCtx(t) mem := &exitMemento{} - config := createTestConfig(t) - - root := newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), nil) + root := newRootCmd(mem.Exit, testCtx.Config, mockIsPipedShell(false, nil), nil) root.cmd.SetOut(io.Discard) root.Execute([]string{"completion", "fish"}) @@ -65,11 +64,10 @@ func TestCompletionCmdFish(t *testing.T) { } func TestCompletionCmdPowershell(t *testing.T) { + testCtx := testlib.NewTestCtx(t) mem := &exitMemento{} - config := createTestConfig(t) - - root := newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), nil) + root := newRootCmd(mem.Exit, testCtx.Config, mockIsPipedShell(false, nil), nil) root.cmd.SetOut(io.Discard) root.Execute([]string{"completion", "powershell"}) @@ -77,11 +75,10 @@ func TestCompletionCmdPowershell(t *testing.T) { } func TestCompletionCmdZsh(t *testing.T) { + testCtx := testlib.NewTestCtx(t) mem := &exitMemento{} - config := createTestConfig(t) - - root := newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), nil) + root := newRootCmd(mem.Exit, testCtx.Config, mockIsPipedShell(false, nil), nil) root.cmd.SetOut(io.Discard) root.Execute([]string{"completion", "zsh"}) @@ -89,11 +86,10 @@ func TestCompletionCmdZsh(t *testing.T) { } func TestCompletionCmdUnknownCompletion(t *testing.T) { + testCtx := testlib.NewTestCtx(t) mem := &exitMemento{} - config := createTestConfig(t) - - root := newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), nil) + root := newRootCmd(mem.Exit, testCtx.Config, mockIsPipedShell(false, nil), nil) root.cmd.SetOut(io.Discard) root.Execute([]string{"completion", "abcd"}) @@ -101,11 +97,10 @@ func TestCompletionCmdUnknownCompletion(t *testing.T) { } func TestCompletionCmdTooManyArgs(t *testing.T) { + testCtx := testlib.NewTestCtx(t) mem := &exitMemento{} - config := createTestConfig(t) - - root := newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), nil) + root := newRootCmd(mem.Exit, testCtx.Config, mockIsPipedShell(false, nil), nil) root.cmd.SetOut(io.Discard) root.Execute([]string{"completion", "abcd", "efgh"}) diff --git a/cli/config.go b/pkg/cli/config.go similarity index 100% rename from cli/config.go rename to pkg/cli/config.go diff --git a/pkg/cli/config_test.go b/pkg/cli/config_test.go new file mode 100644 index 0000000..c3949df --- /dev/null +++ b/pkg/cli/config_test.go @@ -0,0 +1,94 @@ +// Copyright (c) 2023 Tim +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +// +// SPDX-License-Identifier: MIT + +package cli + +import ( + "path/filepath" + "testing" + + "github.com/tbckr/sgpt/v2/internal/testlib" + + "github.com/stretchr/testify/require" +) + +func TestConfigCmd(t *testing.T) { + testCtx := testlib.NewTestCtx(t) + mem := &exitMemento{} + + root := newRootCmd(mem.Exit, testCtx.Config, nil, nil) + + root.Execute([]string{"config"}) + require.Equal(t, 0, mem.code) +} + +func TestConfigCmdInit(t *testing.T) { + testCtx := testlib.NewTestCtx(t) + mem := &exitMemento{} + + newRootCmd(mem.Exit, testCtx.Config, nil, nil).Execute([]string{"config", "init"}) + require.Equal(t, 0, mem.code) + + require.FileExists(t, filepath.Join(testCtx.ConfigDir, "config.yaml")) + // config must only contain values for model, maxtokens, temperature, topp + require.NoError(t, testCtx.Config.ReadInConfig()) + // TESTING may be in the config, because this is a test + require.Equal(t, 8, len(testCtx.Config.AllSettings())) + for _, key := range []string{"model", "maxtokens", "temperature", "topp", "cachedir", "personas", "stream", "testing"} { + require.Contains(t, testCtx.Config.AllSettings(), key) + } +} + +func TestConfigCmdInitAlreadyExists(t *testing.T) { + testCtx := testlib.NewTestCtx(t) + mem := &exitMemento{} + + newRootCmd(mem.Exit, testCtx.Config, nil, nil).Execute([]string{"config", "init"}) + require.Equal(t, 0, mem.code) + + require.FileExists(t, filepath.Join(testCtx.ConfigDir, "config.yaml")) + + newRootCmd(mem.Exit, testCtx.Config, nil, nil).Execute([]string{"config", "init"}) + require.Equal(t, 1, mem.code) +} + +func TestConfigCmdShowConfig(t *testing.T) { + testCtx := testlib.NewTestCtx(t) + mem := &exitMemento{} + + newRootCmd(mem.Exit, testCtx.Config, nil, nil).Execute([]string{"config", "init"}) + require.Equal(t, 0, mem.code) + + require.FileExists(t, filepath.Join(testCtx.ConfigDir, "config.yaml")) + + newRootCmd(mem.Exit, testCtx.Config, nil, nil).Execute([]string{"config", "show"}) + require.Equal(t, 0, mem.code) +} + +func TestConfigCmdShowConfigNonExistent(t *testing.T) { + testCtx := testlib.NewTestCtx(t) + mem := &exitMemento{} + + require.NoFileExists(t, filepath.Join(testCtx.ConfigDir, "config.yaml")) + + newRootCmd(mem.Exit, testCtx.Config, nil, nil).Execute([]string{"config", "show"}) + require.Equal(t, 1, mem.code) +} diff --git a/cli/error.go b/pkg/cli/error.go similarity index 100% rename from cli/error.go rename to pkg/cli/error.go diff --git a/cli/licenses.go b/pkg/cli/licenses.go similarity index 100% rename from cli/licenses.go rename to pkg/cli/licenses.go diff --git a/cli/licenses_test.go b/pkg/cli/licenses_test.go similarity index 92% rename from cli/licenses_test.go rename to pkg/cli/licenses_test.go index 16ab9b9..3848ab3 100644 --- a/cli/licenses_test.go +++ b/pkg/cli/licenses_test.go @@ -28,11 +28,13 @@ import ( "sync" "testing" + "github.com/tbckr/sgpt/v2/internal/testlib" + "github.com/stretchr/testify/require" - "github.com/tbckr/sgpt/v2/api" ) func TestLicensesCmd(t *testing.T) { + testCtx := testlib.NewTestCtx(t) mem := &exitMemento{} expected := `To see the open source packages included in SGPT and their respective license information, visit: @@ -40,9 +42,7 @@ their respective license information, visit: var wg sync.WaitGroup reader, writer := io.Pipe() - config := createTestConfig(t) - - root := newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), api.MockClient("", nil)) + root := newRootCmd(mem.Exit, testCtx.Config, mockIsPipedShell(false, nil), nil) root.cmd.SetOut(writer) wg.Add(1) diff --git a/cli/man.go b/pkg/cli/man.go similarity index 100% rename from cli/man.go rename to pkg/cli/man.go diff --git a/cli/man_test.go b/pkg/cli/man_test.go similarity index 84% rename from cli/man_test.go rename to pkg/cli/man_test.go index 250ecc0..6e86548 100644 --- a/cli/man_test.go +++ b/pkg/cli/man_test.go @@ -25,26 +25,26 @@ import ( "io" "testing" + "github.com/tbckr/sgpt/v2/internal/testlib" + "github.com/stretchr/testify/require" ) func TestManCmd(t *testing.T) { + testCtx := testlib.NewTestCtx(t) mem := &exitMemento{} - config := createTestConfig(t) - - root := newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), nil) + root := newRootCmd(mem.Exit, testCtx.Config, mockIsPipedShell(false, nil), nil) root.cmd.SetOut(io.Discard) root.Execute([]string{"man"}) require.Equal(t, 0, mem.code) } func TestManCmdUnknowArgs(t *testing.T) { + testCtx := testlib.NewTestCtx(t) mem := &exitMemento{} - config := createTestConfig(t) - - root := newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), nil) + root := newRootCmd(mem.Exit, testCtx.Config, mockIsPipedShell(false, nil), nil) root.cmd.SetOut(io.Discard) root.Execute([]string{"man", "abcd"}) diff --git a/cli/root.go b/pkg/cli/root.go similarity index 92% rename from cli/root.go rename to pkg/cli/root.go index 4b9dbf7..e970749 100644 --- a/cli/root.go +++ b/pkg/cli/root.go @@ -23,17 +23,18 @@ package cli import ( "errors" - "fmt" + "io" "log/slog" "os" "strings" + "github.com/tbckr/sgpt/v2/pkg/api" + "github.com/tbckr/sgpt/v2/pkg/fs" + "github.com/tbckr/sgpt/v2/pkg/shell" + "github.com/atotto/clipboard" "github.com/spf13/cobra" "github.com/spf13/viper" - "github.com/tbckr/sgpt/v2/api" - "github.com/tbckr/sgpt/v2/fs" - "github.com/tbckr/sgpt/v2/shell" ) type rootCmd struct { @@ -101,7 +102,7 @@ func (r *rootCmd) Execute(args []string) { r.exit(0) } -func newRootCmd(exit func(int), config *viper.Viper, isPipedShell func() (bool, error), createClientFn func() (*api.OpenAIClient, error)) *rootCmd { +func newRootCmd(exit func(int), config *viper.Viper, isPipedShell func() (bool, error), createClientFn func(*viper.Viper, io.Writer) (*api.OpenAIClient, error)) *rootCmd { root := &rootCmd{ exit: exit, } @@ -208,17 +209,13 @@ ls | sort // Create client var client *api.OpenAIClient - client, err = createClientFn() + client, err = createClientFn(config, cmd.OutOrStdout()) if err != nil { return err } var response string - response, err = client.GetChatCompletion(cmd.Context(), config, root.chat, prompt, mode) - if err != nil { - return err - } - _, err = fmt.Fprintln(cmd.OutOrStdout(), response) + response, err = client.CreateCompletion(cmd.Context(), root.chat, prompt, mode) if err != nil { return err } @@ -293,6 +290,12 @@ func createFlagsWithConfigBinding(cmd *cobra.Command, config *viper.Viper) { bindErrors = append(bindErrors, err) } + cmd.Flags().Bool("stream", false, "stream output") + err = config.BindPFlag("stream", cmd.Flags().Lookup("stream")) + if err != nil { + bindErrors = append(bindErrors, err) + } + if len(bindErrors) > 0 { for _, err = range bindErrors { slog.Error("Failed to bind flag to viper", "error", err) @@ -329,6 +332,13 @@ func setViperDefaults(config *viper.Viper) error { return err } config.SetDefault("cacheDir", appCacheDir) + // personas dir + var personasDir string + personasDir, err = fs.GetPersonasPath() + if err != nil { + return err + } + config.SetDefault("personas", personasDir) // model config.SetDefault("model", api.DefaultModel) @@ -338,8 +348,8 @@ func setViperDefaults(config *viper.Viper) error { config.SetDefault("temperature", 1) // top-p config.SetDefault("topP", 1) - // execute - config.SetDefault("execute", false) + // stream + config.SetDefault("stream", false) return nil } diff --git a/cli/root_test.go b/pkg/cli/root_test.go similarity index 55% rename from cli/root_test.go rename to pkg/cli/root_test.go index f543aa8..d0fea72 100644 --- a/cli/root_test.go +++ b/pkg/cli/root_test.go @@ -32,38 +32,45 @@ import ( "sync" "testing" + "github.com/tbckr/sgpt/v2/pkg/chat" + "github.com/atotto/clipboard" "github.com/sashabaranov/go-openai" + + "github.com/jarcoal/httpmock" "github.com/stretchr/testify/require" - "github.com/tbckr/sgpt/v2/api" - "github.com/tbckr/sgpt/v2/chat" + "github.com/tbckr/sgpt/v2/internal/testlib" + "github.com/tbckr/sgpt/v2/pkg/api" ) -func TestCreateViperConfig(t *testing.T) { - config, err := createViperConfig() - require.NoError(t, err) - require.NotNil(t, config) -} - func TestRootCmd_SimplePrompt(t *testing.T) { - prompt := "Say: Hello World!" - expected := "Hello World!\n" - + testCtx := testlib.NewTestCtx(t) + testlib.SetAPIKey(t) mem := &exitMemento{} + var wg sync.WaitGroup reader, writer := io.Pipe() - config := createTestConfig(t) + client, err := api.CreateClient(testCtx.Config, writer) + require.NoError(t, err) - root := newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), api.MockClient(strings.Clone(expected), nil)) + prompt := "Say: Hello World!" + response := "Hello World!" + expected := "Hello World!\n" + + httpmock.ActivateNonDefault(client.HTTPClient) + t.Cleanup(httpmock.DeactivateAndReset) + testlib.RegisterExpectedChatResponse(response) + + root := newRootCmd(mem.Exit, testCtx.Config, mockIsPipedShell(false, nil), useMockClient(client)) root.cmd.SetOut(writer) wg.Add(1) go func() { defer wg.Done() var buf bytes.Buffer - _, err := io.Copy(&buf, reader) - require.NoError(t, err) + _, errReader := io.Copy(&buf, reader) + require.NoError(t, errReader) require.NoError(t, reader.Close()) require.Equal(t, expected, buf.String()) }() @@ -76,24 +83,33 @@ func TestRootCmd_SimplePrompt(t *testing.T) { } func TestRootCmd_SimplePromptOnly(t *testing.T) { - prompt := "Say: Hello World!" - expected := "Hello World!\n" - + testCtx := testlib.NewTestCtx(t) + testlib.SetAPIKey(t) mem := &exitMemento{} + var wg sync.WaitGroup reader, writer := io.Pipe() - config := createTestConfig(t) + client, err := api.CreateClient(testCtx.Config, writer) + require.NoError(t, err) + + prompt := "Say: Hello World!" + response := "Hello World!" + expected := "Hello World!\n" + + httpmock.ActivateNonDefault(client.HTTPClient) + t.Cleanup(httpmock.DeactivateAndReset) + testlib.RegisterExpectedChatResponse(response) - root := newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), api.MockClient(strings.Clone(expected), nil)) + root := newRootCmd(mem.Exit, testCtx.Config, mockIsPipedShell(false, nil), useMockClient(client)) root.cmd.SetOut(writer) wg.Add(1) go func() { defer wg.Done() var buf bytes.Buffer - _, err := io.Copy(&buf, reader) - require.NoError(t, err) + _, errReader := io.Copy(&buf, reader) + require.NoError(t, errReader) require.NoError(t, reader.Close()) require.Equal(t, expected, buf.String()) }() @@ -108,77 +124,137 @@ func TestRootCmd_SimplePromptOnly(t *testing.T) { func TestRootCmd_SimpleClipboard(t *testing.T) { skipInCI(t) + testCtx := testlib.NewTestCtx(t) + testlib.SetAPIKey(t) + mem := &exitMemento{} + + var wg sync.WaitGroup + reader, writer := io.Pipe() + + client, err := api.CreateClient(testCtx.Config, writer) + require.NoError(t, err) + prompt := "Say: Hello World!" - expected := "Hello World!" + response := "Hello World!" + expected := "Hello World!\n" - mem := &exitMemento{} - config := createTestConfig(t) + httpmock.ActivateNonDefault(client.HTTPClient) + t.Cleanup(httpmock.DeactivateAndReset) + testlib.RegisterExpectedChatResponse(response) - root := newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), api.MockClient(strings.Clone(expected), nil)) + root := newRootCmd(mem.Exit, testCtx.Config, mockIsPipedShell(false, nil), useMockClient(client)) + root.cmd.SetOut(writer) + + wg.Add(1) + go func() { + defer wg.Done() + var buf bytes.Buffer + _, errReader := io.Copy(&buf, reader) + require.NoError(t, errReader) + require.NoError(t, reader.Close()) + require.Equal(t, expected, buf.String()) + }() root.Execute([]string{"--clipboard", prompt}) require.Equal(t, 0, mem.code) + require.NoError(t, writer.Close()) + textInClipboard, _ := clipboard.ReadAll() - require.Equal(t, expected, textInClipboard) + // The clipboard should not have the trailing newline + require.Equal(t, response, textInClipboard) + + wg.Wait() } func TestRootCmd_SimplePromptOverrideValuesWithConfigFile(t *testing.T) { - prompt := "Say: Hello World!" + testCtx := testlib.NewTestCtx(t) + testlib.SetAPIKey(t) mem := &exitMemento{} - configDir := t.TempDir() + var wg sync.WaitGroup + reader, writer := io.Pipe() - config, err := createViperConfig() + client, err := api.CreateClient(testCtx.Config, writer) require.NoError(t, err) - config.SetConfigFile(filepath.Join(configDir, "config.yaml")) - config.Set("TESTING", 1) + prompt := "Say: Hello World!" + response := "Hello World!" + expected := "Hello World!\n" + + httpmock.ActivateNonDefault(client.HTTPClient) + t.Cleanup(httpmock.DeactivateAndReset) + testlib.RegisterExpectedChatResponse(response) var configFile *os.File - configFile, err = os.Create(filepath.Join(configDir, "config.yaml")) + configFile, err = os.Create(filepath.Join(testCtx.ConfigDir, "config.yaml")) require.NoError(t, err) _, err = configFile.WriteString(fmt.Sprintf("model: \"%s\"\n", openai.GPT4)) require.NoError(t, err) - root := newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), api.MockClient("Hello World", nil)) + root := newRootCmd(mem.Exit, testCtx.Config, mockIsPipedShell(false, nil), useMockClient(client)) + + wg.Add(1) + go func() { + defer wg.Done() + var buf bytes.Buffer + _, errReader := io.Copy(&buf, reader) + require.NoError(t, errReader) + require.NoError(t, reader.Close()) + require.Equal(t, expected, buf.String()) + }() root.Execute([]string{"txt", prompt}) require.Equal(t, 0, mem.code) + require.NoError(t, writer.Close()) + + require.Equal(t, openai.GPT4, testCtx.Config.GetString("model")) - require.Equal(t, openai.GPT4, config.GetString("model")) + wg.Wait() } func TestRootCmd_SimplePromptNoPrompt(t *testing.T) { + testCtx := testlib.NewTestCtx(t) + testlib.SetAPIKey(t) mem := &exitMemento{} - config := createTestConfig(t) + client, err := api.CreateClient(testCtx.Config, nil) + require.NoError(t, err) - root := newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), api.MockClient("", nil)) + root := newRootCmd(mem.Exit, testCtx.Config, mockIsPipedShell(false, nil), useMockClient(client)) root.Execute([]string{}) require.Equal(t, 1, mem.code) } func TestRootCmd_SimplePromptVerbose(t *testing.T) { - prompt := "Say: Hello World!" - expected := "Hello World!\n" - + testCtx := testlib.NewTestCtx(t) + testlib.SetAPIKey(t) mem := &exitMemento{} + var wg sync.WaitGroup reader, writer := io.Pipe() - config := createTestConfig(t) + client, err := api.CreateClient(testCtx.Config, writer) + require.NoError(t, err) - root := newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), api.MockClient(strings.Clone(expected), nil)) + prompt := "Say: Hello World!" + response := "Hello World!" + expected := "Hello World!\n" + + httpmock.ActivateNonDefault(client.HTTPClient) + t.Cleanup(httpmock.DeactivateAndReset) + testlib.RegisterExpectedChatResponse(response) + + root := newRootCmd(mem.Exit, testCtx.Config, mockIsPipedShell(false, nil), useMockClient(client)) root.cmd.SetOut(writer) wg.Add(1) go func() { defer wg.Done() var buf bytes.Buffer - _, err := io.Copy(&buf, reader) - require.NoError(t, err) + _, errReader := io.Copy(&buf, reader) + require.NoError(t, errReader) require.NoError(t, reader.Close()) require.Equal(t, expected, buf.String()) }() @@ -191,17 +267,26 @@ func TestRootCmd_SimplePromptVerbose(t *testing.T) { } func TestRootCmd_SimplePromptViaPipedShell(t *testing.T) { - prompt := "Say: Hello World!" - expected := "Hello World!\n" - + testCtx := testlib.NewTestCtx(t) + testlib.SetAPIKey(t) mem := &exitMemento{} + var wg sync.WaitGroup stdinReader, stdinWriter := io.Pipe() stdoutReader, stdoutWriter := io.Pipe() - config := createTestConfig(t) + client, err := api.CreateClient(testCtx.Config, stdoutWriter) + require.NoError(t, err) + + prompt := "Say: Hello World!" + response := "Hello World!" + expected := "Hello World!\n" + + httpmock.ActivateNonDefault(client.HTTPClient) + t.Cleanup(httpmock.DeactivateAndReset) + testlib.RegisterExpectedChatResponse(response) - root := newRootCmd(mem.Exit, config, mockIsPipedShell(true, nil), api.MockClient(strings.Clone(expected), nil)) + root := newRootCmd(mem.Exit, testCtx.Config, mockIsPipedShell(true, nil), useMockClient(client)) root.cmd.SetIn(stdinReader) root.cmd.SetOut(stdoutWriter) @@ -217,8 +302,8 @@ func TestRootCmd_SimplePromptViaPipedShell(t *testing.T) { go func() { defer wg.Done() var buf bytes.Buffer - _, err := io.Copy(&buf, stdoutReader) - require.NoError(t, err) + _, errReader := io.Copy(&buf, stdoutReader) + require.NoError(t, errReader) require.NoError(t, stdoutReader.Close()) require.Equal(t, expected, buf.String()) }() @@ -232,14 +317,20 @@ func TestRootCmd_SimplePromptViaPipedShell(t *testing.T) { } func TestRootCmd_PipedShell_NoInput(t *testing.T) { + testCtx := testlib.NewTestCtx(t) + testlib.SetAPIKey(t) mem := &exitMemento{} + var wg sync.WaitGroup stdinReader, stdinWriter := io.Pipe() + stdoutReader, stdoutWriter := io.Pipe() - config := createTestConfig(t) + client, err := api.CreateClient(testCtx.Config, stdoutWriter) + require.NoError(t, err) - root := newRootCmd(mem.Exit, config, mockIsPipedShell(true, nil), api.MockClient("", nil)) + root := newRootCmd(mem.Exit, testCtx.Config, mockIsPipedShell(true, nil), useMockClient(client)) root.cmd.SetIn(stdinReader) + root.cmd.SetOut(stdoutWriter) wg.Add(1) go func() { @@ -249,39 +340,78 @@ func TestRootCmd_PipedShell_NoInput(t *testing.T) { require.NoError(t, errWrite) }() + wg.Add(1) + go func() { + defer wg.Done() + var buf bytes.Buffer + _, errReader := io.Copy(&buf, stdoutReader) + require.NoError(t, errReader) + require.NoError(t, stdoutReader.Close()) + require.Equal(t, "", buf.String()) + }() + root.Execute([]string{}) require.Equal(t, 1, mem.code) require.NoError(t, stdinReader.Close()) + require.NoError(t, stdoutWriter.Close()) wg.Wait() } func TestRootCmd_SimplePrompt_PipedShellError(t *testing.T) { - prompt := "Say: Hello World!" - + testCtx := testlib.NewTestCtx(t) + testlib.SetAPIKey(t) mem := &exitMemento{} - config := createTestConfig(t) - testError := errors.New("test error") + var wg sync.WaitGroup + reader, writer := io.Pipe() + + client, err := api.CreateClient(testCtx.Config, writer) + require.NoError(t, err) + + prompt := "Say: Hello World!" + + root := newRootCmd(mem.Exit, testCtx.Config, mockIsPipedShell(true, errors.New("test error")), useMockClient(client)) + root.cmd.SetOut(writer) - root := newRootCmd(mem.Exit, config, mockIsPipedShell(true, testError), api.MockClient("", nil)) + wg.Add(1) + go func() { + defer wg.Done() + var buf bytes.Buffer + _, errReader := io.Copy(&buf, reader) + require.NoError(t, errReader) + require.NoError(t, reader.Close()) + require.Equal(t, "", buf.String()) + }() root.Execute([]string{prompt}) require.Equal(t, 1, mem.code) + require.NoError(t, writer.Close()) + + wg.Wait() } func TestRootCmd_SimplePromptViaPipedShellAndModifier(t *testing.T) { - prompt := "Say: Hello World!" - expected := "Hello World!\n" - + testCtx := testlib.NewTestCtx(t) + testlib.SetAPIKey(t) mem := &exitMemento{} + var wg sync.WaitGroup stdinReader, stdinWriter := io.Pipe() stdoutReader, stdoutWriter := io.Pipe() - config := createTestConfig(t) + client, err := api.CreateClient(testCtx.Config, stdoutWriter) + require.NoError(t, err) + + prompt := "Say: Hello World!" + response := "Hello World!" + expected := "Hello World!\n" + + httpmock.ActivateNonDefault(client.HTTPClient) + t.Cleanup(httpmock.DeactivateAndReset) + testlib.RegisterExpectedChatResponse(response) - root := newRootCmd(mem.Exit, config, mockIsPipedShell(true, nil), api.MockClient(strings.Clone(expected), nil)) + root := newRootCmd(mem.Exit, testCtx.Config, mockIsPipedShell(true, nil), useMockClient(client)) root.cmd.SetIn(stdinReader) root.cmd.SetOut(stdoutWriter) @@ -297,8 +427,8 @@ func TestRootCmd_SimplePromptViaPipedShellAndModifier(t *testing.T) { go func() { defer wg.Done() var buf bytes.Buffer - _, err := io.Copy(&buf, stdoutReader) - require.NoError(t, err) + _, errReader := io.Copy(&buf, stdoutReader) + require.NoError(t, errReader) require.NoError(t, stdoutReader.Close()) require.Equal(t, expected, buf.String()) }() @@ -312,30 +442,39 @@ func TestRootCmd_SimplePromptViaPipedShellAndModifier(t *testing.T) { } func TestRootCmd_SimpleShellPrompt(t *testing.T) { - prompt := `echo "Hello World"` - expected := "Hello World\n" - + testCtx := testlib.NewTestCtx(t) + testlib.SetAPIKey(t) mem := &exitMemento{} + var wg sync.WaitGroup reader, writer := io.Pipe() - config := createTestConfig(t) + client, err := api.CreateClient(testCtx.Config, writer) + require.NoError(t, err) - err := os.Setenv("SHELL", "/bin/bash") + prompt := `echo "Hello World"` + response := "Hello World!" + expected := "Hello World!\n" + + httpmock.ActivateNonDefault(client.HTTPClient) + t.Cleanup(httpmock.DeactivateAndReset) + testlib.RegisterExpectedChatResponse(response) + + err = os.Setenv("SHELL", "/bin/bash") require.NoError(t, err) t.Cleanup(func() { require.NoError(t, os.Unsetenv("SHELL")) }) - root := newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), api.MockClient(strings.Clone(expected), nil)) + root := newRootCmd(mem.Exit, testCtx.Config, mockIsPipedShell(false, nil), useMockClient(client)) root.cmd.SetOut(writer) wg.Add(1) go func() { defer wg.Done() var buf bytes.Buffer - _, err := io.Copy(&buf, reader) - require.NoError(t, err) + _, errReader := io.Copy(&buf, reader) + require.NoError(t, errReader) require.NoError(t, reader.Close()) require.Equal(t, expected, buf.String()) }() @@ -348,23 +487,32 @@ func TestRootCmd_SimpleShellPrompt(t *testing.T) { } func TestRootCmd_SimpleShellPromptWithExecution(t *testing.T) { - prompt := `Print: Hello World` - expected := "echo \"Hello World\"\n" - + testCtx := testlib.NewTestCtx(t) + testlib.SetAPIKey(t) mem := &exitMemento{} + var wg sync.WaitGroup stdinReader, stdinWriter := io.Pipe() stdoutReader, stdoutWriter := io.Pipe() - config := createTestConfig(t) + client, err := api.CreateClient(testCtx.Config, stdoutWriter) + require.NoError(t, err) + + prompt := "Print: Hello World" + response := `echo \"Hello World\"` + expected := "echo \"Hello World\"\n" - err := os.Setenv("SHELL", "/bin/bash") + httpmock.ActivateNonDefault(client.HTTPClient) + t.Cleanup(httpmock.DeactivateAndReset) + testlib.RegisterExpectedChatResponse(response) + + err = os.Setenv("SHELL", "/bin/bash") require.NoError(t, err) t.Cleanup(func() { - require.NoError(t, os.Unsetenv("SHELL")) + _ = os.Unsetenv("SHELL") }) - root := newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), api.MockClient(strings.Clone(expected), nil)) + root := newRootCmd(mem.Exit, testCtx.Config, mockIsPipedShell(false, nil), useMockClient(client)) root.cmd.SetIn(stdinReader) root.cmd.SetOut(stdoutWriter) @@ -380,8 +528,8 @@ func TestRootCmd_SimpleShellPromptWithExecution(t *testing.T) { go func() { defer wg.Done() var buf bytes.Buffer - _, err := io.Copy(&buf, stdoutReader) - require.NoError(t, err) + _, errReader := io.Copy(&buf, stdoutReader) + require.NoError(t, errReader) require.NoError(t, stdoutReader.Close()) stdoutOutput := expected + "Do you want to execute this command? (Y/n) Hello World\n" require.Equal(t, stdoutOutput, buf.String()) @@ -397,24 +545,33 @@ func TestRootCmd_SimpleShellPromptWithExecution(t *testing.T) { } func TestRootCmd_SimplePromptWithChat(t *testing.T) { - prompt := "Say: Hello World!" - expected := "Hello World!\n" - + testCtx := testlib.NewTestCtx(t) + testlib.SetAPIKey(t) mem := &exitMemento{} + var wg sync.WaitGroup reader, writer := io.Pipe() - config := createTestConfig(t) + client, err := api.CreateClient(testCtx.Config, writer) + require.NoError(t, err) + + prompt := "Say: Hello World!" + response := "Hello World!" + expected := "Hello World!\n" - root := newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), api.MockClient(strings.Clone(expected), nil)) + httpmock.ActivateNonDefault(client.HTTPClient) + t.Cleanup(httpmock.DeactivateAndReset) + testlib.RegisterExpectedChatResponse(response) + + root := newRootCmd(mem.Exit, testCtx.Config, mockIsPipedShell(false, nil), useMockClient(client)) root.cmd.SetOut(writer) wg.Add(1) go func() { defer wg.Done() var buf bytes.Buffer - _, err := io.Copy(&buf, reader) - require.NoError(t, err) + _, errReader := io.Copy(&buf, reader) + require.NoError(t, errReader) require.NoError(t, reader.Close()) require.Equal(t, expected, buf.String()) }() @@ -423,9 +580,9 @@ func TestRootCmd_SimplePromptWithChat(t *testing.T) { require.Equal(t, 0, mem.code) require.NoError(t, writer.Close()) - require.FileExists(t, filepath.Join(config.GetString("cacheDir"), "test_chat")) + require.FileExists(t, filepath.Join(testCtx.Config.GetString("cacheDir"), "test_chat")) - manager, err := chat.NewFilesystemChatSessionManager(config) + manager, err := chat.NewFilesystemChatSessionManager(testCtx.Config) require.NoError(t, err) var messages []openai.ChatCompletionMessage @@ -445,37 +602,47 @@ func TestRootCmd_SimplePromptWithChat(t *testing.T) { } func TestRootCmd_SimplePromptWithChatAndCustomPersona(t *testing.T) { - persona := "This is my custom persona" - prompt := "Say: Hello World!" - expected := "Hello World!\n" - + testCtx := testlib.NewTestCtx(t) + testlib.SetAPIKey(t) mem := &exitMemento{} + var wg sync.WaitGroup reader, writer := io.Pipe() - config := createTestConfig(t) + client, err := api.CreateClient(testCtx.Config, writer) + require.NoError(t, err) + + persona := "This is my custom persona" + prompt := "Say: Hello World!" + response := "Hello World!" + expected := "Hello World!\n" + + httpmock.ActivateNonDefault(client.HTTPClient) + t.Cleanup(httpmock.DeactivateAndReset) + testlib.RegisterExpectedChatResponse(response) - err := os.Setenv("SHELL", "/bin/bash") + err = os.Setenv("SHELL", "/bin/bash") require.NoError(t, err) t.Cleanup(func() { require.NoError(t, os.Unsetenv("SHELL")) }) - fileHandler, err := os.Create(filepath.Join(config.GetString("personas"), "my-persona")) + var fileHandler *os.File + fileHandler, err = os.Create(filepath.Join(testCtx.Config.GetString("personas"), "my-persona")) require.NoError(t, err) _, err = fileHandler.WriteString(persona) require.NoError(t, err) require.NoError(t, fileHandler.Close()) - root := newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), api.MockClient(strings.Clone(expected), nil)) + root := newRootCmd(mem.Exit, testCtx.Config, mockIsPipedShell(false, nil), useMockClient(client)) root.cmd.SetOut(writer) wg.Add(1) go func() { defer wg.Done() var buf bytes.Buffer - _, err := io.Copy(&buf, reader) - require.NoError(t, err) + _, errReader := io.Copy(&buf, reader) + require.NoError(t, errReader) require.NoError(t, reader.Close()) require.Equal(t, expected, buf.String()) }() @@ -484,10 +651,10 @@ func TestRootCmd_SimplePromptWithChatAndCustomPersona(t *testing.T) { require.Equal(t, 0, mem.code) require.NoError(t, writer.Close()) - require.FileExists(t, filepath.Join(config.GetString("cacheDir"), "test_chat")) + require.FileExists(t, filepath.Join(testCtx.Config.GetString("cacheDir"), "test_chat")) var manager chat.SessionManager - manager, err = chat.NewFilesystemChatSessionManager(config) + manager, err = chat.NewFilesystemChatSessionManager(testCtx.Config) require.NoError(t, err) var messages []openai.ChatCompletionMessage @@ -511,17 +678,27 @@ func TestRootCmd_SimplePromptWithChatAndCustomPersona(t *testing.T) { } func TestRootCmd_ChatConversation(t *testing.T) { - prompt := "Repeat last message" - expected := "World!\n" - + testCtx := testlib.NewTestCtx(t) + testlib.SetAPIKey(t) mem := &exitMemento{} + var wg sync.WaitGroup reader, writer := io.Pipe() - config := createTestConfig(t) + client, err := api.CreateClient(testCtx.Config, writer) + require.NoError(t, err) + + prompt := "Repeat last message" + response := "World!" + expected := "World!\n" + + httpmock.ActivateNonDefault(client.HTTPClient) + t.Cleanup(httpmock.DeactivateAndReset) + testlib.RegisterExpectedChatResponse(response) // Create an existing chat session - manager, err := chat.NewFilesystemChatSessionManager(config) + var manager chat.SessionManager + manager, err = chat.NewFilesystemChatSessionManager(testCtx.Config) require.NoError(t, err) err = manager.SaveSession("test_chat", []openai.ChatCompletionMessage{ { @@ -535,15 +712,15 @@ func TestRootCmd_ChatConversation(t *testing.T) { }) require.NoError(t, err) - root := newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), api.MockClient(strings.Clone(expected), nil)) + root := newRootCmd(mem.Exit, testCtx.Config, mockIsPipedShell(false, nil), useMockClient(client)) root.cmd.SetOut(writer) wg.Add(1) go func() { defer wg.Done() var buf bytes.Buffer - _, err := io.Copy(&buf, reader) - require.NoError(t, err) + _, errReader := io.Copy(&buf, reader) + require.NoError(t, errReader) require.NoError(t, reader.Close()) require.Equal(t, expected, buf.String()) }() @@ -552,7 +729,7 @@ func TestRootCmd_ChatConversation(t *testing.T) { require.Equal(t, 0, mem.code) require.NoError(t, writer.Close()) - require.FileExists(t, filepath.Join(config.GetString("cacheDir"), "test_chat")) + require.FileExists(t, filepath.Join(testCtx.Config.GetString("cacheDir"), "test_chat")) var messages []openai.ChatCompletionMessage messages, err = manager.GetSession("test_chat") diff --git a/cli/util_test.go b/pkg/cli/util_test.go similarity index 80% rename from cli/util_test.go rename to pkg/cli/util_test.go index ef8b269..b9592ef 100644 --- a/cli/util_test.go +++ b/pkg/cli/util_test.go @@ -22,12 +22,21 @@ package cli import ( + "io" "os" "testing" + "github.com/tbckr/sgpt/v2/pkg/api" + "github.com/spf13/viper" ) +var useMockClient = func(mockClient *api.OpenAIClient) func(*viper.Viper, io.Writer) (*api.OpenAIClient, error) { + return func(_ *viper.Viper, _ io.Writer) (*api.OpenAIClient, error) { + return mockClient, nil + } +} + type exitMemento struct { code int } @@ -42,22 +51,6 @@ func mockIsPipedShell(isPiped bool, err error) func() (bool, error) { } } -func createTestConfig(t *testing.T) *viper.Viper { - configDir := t.TempDir() - cacheDir := t.TempDir() - personasDir := t.TempDir() - - config := viper.New() - config.AddConfigPath(configDir) - config.SetConfigName("config") - config.SetConfigType("yaml") - config.Set("cacheDir", cacheDir) - config.Set("personas", personasDir) - config.Set("TESTING", 1) - - return config -} - func skipInCI(t *testing.T) { if os.Getenv("CI") != "" { t.Skip("Skipping test on CI") diff --git a/cli/version.go b/pkg/cli/version.go similarity index 100% rename from cli/version.go rename to pkg/cli/version.go diff --git a/cli/version_test.go b/pkg/cli/version_test.go similarity index 83% rename from cli/version_test.go rename to pkg/cli/version_test.go index 7758c3d..f517a76 100644 --- a/cli/version_test.go +++ b/pkg/cli/version_test.go @@ -27,15 +27,16 @@ import ( "strings" "testing" + "github.com/tbckr/sgpt/v2/internal/testlib" + "github.com/stretchr/testify/require" ) func TestVersionCmd(t *testing.T) { + testCtx := testlib.NewTestCtx(t) mem := &exitMemento{} - config := createTestConfig(t) - - root := newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), nil) + root := newRootCmd(mem.Exit, testCtx.Config, mockIsPipedShell(false, nil), nil) cmd := root.cmd outBytes := bytes.NewBufferString("") @@ -54,11 +55,10 @@ func TestVersionCmd(t *testing.T) { } func TestVersionCmdFull(t *testing.T) { + testCtx := testlib.NewTestCtx(t) mem := &exitMemento{} - config := createTestConfig(t) - - root := newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), nil) + root := newRootCmd(mem.Exit, testCtx.Config, mockIsPipedShell(false, nil), nil) cmd := root.cmd outBytes := bytes.NewBufferString("") @@ -77,10 +77,9 @@ func TestVersionCmdFull(t *testing.T) { } func TestVersionCmdUnknowArg(t *testing.T) { + testCtx := testlib.NewTestCtx(t) mem := &exitMemento{} - config := createTestConfig(t) - - newRootCmd(mem.Exit, config, mockIsPipedShell(false, nil), nil).Execute([]string{"version", "abcd"}) + newRootCmd(mem.Exit, testCtx.Config, mockIsPipedShell(false, nil), nil).Execute([]string{"version", "abcd"}) require.Equal(t, 1, mem.code) } diff --git a/fs/fs.go b/pkg/fs/fs.go similarity index 100% rename from fs/fs.go rename to pkg/fs/fs.go diff --git a/fs/fs_test.go b/pkg/fs/fs_test.go similarity index 100% rename from fs/fs_test.go rename to pkg/fs/fs_test.go diff --git a/modifiers/defaults.go b/pkg/modifiers/defaults.go similarity index 100% rename from modifiers/defaults.go rename to pkg/modifiers/defaults.go diff --git a/modifiers/defaults_test.go b/pkg/modifiers/defaults_test.go similarity index 100% rename from modifiers/defaults_test.go rename to pkg/modifiers/defaults_test.go diff --git a/modifiers/modifiers.go b/pkg/modifiers/modifiers.go similarity index 99% rename from modifiers/modifiers.go rename to pkg/modifiers/modifiers.go index 31a27f9..33256b5 100644 --- a/modifiers/modifiers.go +++ b/pkg/modifiers/modifiers.go @@ -31,9 +31,9 @@ import ( "strings" "text/template" - "github.com/spf13/viper" + "github.com/tbckr/sgpt/v2/pkg/fs" - "github.com/tbckr/sgpt/v2/fs" + "github.com/spf13/viper" ) const ( diff --git a/modifiers/modifiers_test.go b/pkg/modifiers/modifiers_test.go similarity index 100% rename from modifiers/modifiers_test.go rename to pkg/modifiers/modifiers_test.go diff --git a/modifiers/prompts.yml b/pkg/modifiers/prompts.yml similarity index 100% rename from modifiers/prompts.yml rename to pkg/modifiers/prompts.yml diff --git a/shell/shell.go b/pkg/shell/shell.go similarity index 100% rename from shell/shell.go rename to pkg/shell/shell.go diff --git a/shell/shell_test.go b/pkg/shell/shell_test.go similarity index 100% rename from shell/shell_test.go rename to pkg/shell/shell_test.go