Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: provide input via stdin and as an argument reopened #239

Merged
merged 4 commits into from
Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,12 @@ repos:
language: golang
require_serial: true
pass_filenames: false
- id: govulncheck
name: govulncheck
description: Check for vulnerable dependencies
entry: govulncheck ./...
types: [ go ]
language: golang
require_serial: true
pass_filenames: false
# TODO: deactivated as long as this issues is not resolved: https://github.com/golang/go/issues/65608
# - id: govulncheck
# name: govulncheck
# description: Check for vulnerable dependencies
# entry: govulncheck ./...
# types: [ go ]
# language: golang
# require_serial: true
# pass_filenames: false
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,13 @@ $ echo -n "mass of sun" | sgpt
The mass of the sun is approximately 1.989 x 10^30 kilograms.
```

You can also add another prompt to the piped data by specifying the `stdin` modifier and then specifying the prompt:

```shell
$ echo "Say: Hello World!" | sgpt stdin 'Replace every "World" word with "ChatGPT"'
Hello ChatGPT!
```

If you want to stream the completion to the command line, you can add the `--stream` flag. This will stream the output
to the command line as it is generated.

Expand Down
7 changes: 7 additions & 0 deletions docs/getting-started.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ $ echo -n "mass of sun" | sgpt
The mass of the sun is approximately 1.989 x 10^30 kilograms.
```

You can also add another prompt to the piped data by specifying the `stdin` modifier and then specifying the prompt:

```shell
$ echo "Say: Hello World!" | sgpt stdin 'Replace every "World" word with "ChatGPT"'
Hello ChatGPT!
```

If you want to stream the completion to the command line, you can add the `--stream` flag. This will stream the output
to the command line as it is generated.

Expand Down
44 changes: 25 additions & 19 deletions pkg/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func CreateClient(config *viper.Viper, out io.Writer) (*OpenAIClient, error) {
// CreateCompletion creates a completion for the given prompt and modifier. If chatID is provided, the chat is reused
// and the completion is added to the chat with this ID. If no chatID is provided, only the modifier and prompt are
// used to create the completion. The completion is printed to the out writer of the client and returned as a string.
func (c *OpenAIClient) CreateCompletion(ctx context.Context, chatID, prompt, modifier string, input []string) (string, error) {
func (c *OpenAIClient) CreateCompletion(ctx context.Context, chatID string, prompt []string, modifier string, input []string) (string, error) {
var messages []openai.ChatCompletionMessage
var err error

Expand All @@ -128,12 +128,12 @@ func (c *OpenAIClient) CreateCompletion(ctx context.Context, chatID, prompt, mod
messages = append(messages, loadedMessages...)

// Add prompt to messages
var promptMessage openai.ChatCompletionMessage
promptMessage, err = c.createPromptMessage(prompt, input)
var promptMessages []openai.ChatCompletionMessage
promptMessages, err = c.createPromptMessages(prompt, input)
if err != nil {
return "", err
}
messages = append(messages, promptMessage)
messages = append(messages, promptMessages...)
slog.Debug("Added prompt message")

// Create request
Expand Down Expand Up @@ -218,20 +218,22 @@ func (c *OpenAIClient) loadChatMessages(isChat bool, chatID, modifier string) (m
return
}

func (c *OpenAIClient) createPromptMessage(prompt string, input []string) (message openai.ChatCompletionMessage, err error) {
func (c *OpenAIClient) createPromptMessages(prompts, input []string) (messages []openai.ChatCompletionMessage, err error) {
if len(input) > 0 {
slog.Warn("The GPT-4 Vision API is in beta and may not work as expected")
// Request to the gpt-4-vision API
slog.Warn("The GPT-4 Vision API is in beta and may not work as expected")

var messageParts []openai.ChatMessagePart
// Add prompt to message
messageParts := []openai.ChatMessagePart{
{
// We append the stdin as part of the prompt as a message part
for _, p := range prompts {
messageParts = append(messageParts, openai.ChatMessagePart{
Type: openai.ChatMessagePartTypeText,
Text: prompt,
},
Text: p,
})
}

// Add images to message
// Add images to messages
for _, i := range input {
// By default, assume that the input is a URL
imageData := i
Expand All @@ -241,7 +243,7 @@ func (c *OpenAIClient) createPromptMessage(prompt string, input []string) (messa
// Input is a file, load image data
imageData, err = c.buildImageFileData(i)
if err != nil {
return openai.ChatCompletionMessage{}, err
return []openai.ChatCompletionMessage{}, err
}
}

Expand All @@ -253,19 +255,23 @@ func (c *OpenAIClient) createPromptMessage(prompt string, input []string) (messa
})
}

message = openai.ChatCompletionMessage{
messages = append(messages, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleUser,
MultiContent: messageParts,
}
})
} else {
// Normal prompt
message = openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleUser,
Content: prompt,
// We append the stdin as part of the prompt
// This means we just add the prompt as a message
for _, p := range prompts {
messages = append(messages, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleUser,
Content: p,
})
}
}
slog.Debug("Added prompt message")
return message, nil
slog.Debug("Added prompt messages")
return messages, nil
}

func (c *OpenAIClient) buildImageFileData(inputFile string) (imageData string, err error) {
Expand Down
30 changes: 15 additions & 15 deletions pkg/api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func TestSimplePrompt(t *testing.T) {
client, err := CreateClient(testCtx.Config, writer)
require.NoError(t, err)

prompt := "Say: Hello World!"
prompt := []string{"Say: Hello World!"}
expected := "Hello World!"

httpmock.ActivateNonDefault(client.HTTPClient)
Expand Down Expand Up @@ -119,7 +119,7 @@ func TestStreamSimplePrompt(t *testing.T) {
client, err := CreateClient(testCtx.Config, writer)
require.NoError(t, err)

prompt := "Say: Hello World!"
prompt := []string{"Say: Hello World!"}
expected := "Hello World!"

httpmock.ActivateNonDefault(client.HTTPClient)
Expand Down Expand Up @@ -157,7 +157,7 @@ func TestPromptSaveAsChat(t *testing.T) {
client, err := CreateClient(testCtx.Config, writer)
require.NoError(t, err)

prompt := "Say: Hello World!"
prompt := []string{"Say: Hello World!"}
expected := "Hello World!"

httpmock.ActivateNonDefault(client.HTTPClient)
Expand Down Expand Up @@ -193,7 +193,7 @@ func TestPromptSaveAsChat(t *testing.T) {

// Check if the prompt was added
require.Equal(t, openai.ChatMessageRoleUser, messages[0].Role)
require.Equal(t, prompt, messages[0].Content)
require.Equal(t, prompt[0], messages[0].Content)

// Check if the response was added
require.Equal(t, openai.ChatMessageRoleAssistant, messages[1].Role)
Expand All @@ -212,7 +212,7 @@ func TestPromptLoadChat(t *testing.T) {
client, err := CreateClient(testCtx.Config, writer)
require.NoError(t, err)

prompt := "Repeat last message"
prompt := []string{"Repeat last message"}
expected := "World!"

httpmock.ActivateNonDefault(client.HTTPClient)
Expand Down Expand Up @@ -258,7 +258,7 @@ func TestPromptLoadChat(t *testing.T) {

// Check if the prompt was added
require.Equal(t, openai.ChatMessageRoleUser, messages[2].Role)
require.Equal(t, prompt, messages[2].Content)
require.Equal(t, prompt[0], messages[2].Content)

// Check if the response was added
require.Equal(t, openai.ChatMessageRoleAssistant, messages[3].Role)
Expand All @@ -277,7 +277,7 @@ func TestPromptWithModifier(t *testing.T) {
client, err := CreateClient(testCtx.Config, writer)
require.NoError(t, err)

prompt := "Print Hello World"
prompt := []string{"Print Hello World!"}
response := `echo \"Hello World\"`
expected := `echo "Hello World"`

Expand Down Expand Up @@ -325,7 +325,7 @@ func TestPromptWithModifier(t *testing.T) {

// Check if the prompt was added
require.Equal(t, openai.ChatMessageRoleUser, messages[1].Role)
require.Equal(t, prompt, messages[1].Content)
require.Equal(t, prompt[0], messages[1].Content)

// Check if the response was added
require.Equal(t, openai.ChatMessageRoleAssistant, messages[2].Role)
Expand All @@ -344,7 +344,7 @@ func TestSimplePromptWithLocalImage(t *testing.T) {
client, err := CreateClient(testCtx.Config, writer)
require.NoError(t, err)

prompt := "what can you see on the picture?"
prompt := []string{"what can you see on the picture?"}
expected := "The image shows a character that appears to be a stylized robot. It has"
inputImage := "testdata/marvin.jpg"

Expand Down Expand Up @@ -381,7 +381,7 @@ func TestSimplePromptWithLocalImageAndChat(t *testing.T) {
client, err := CreateClient(testCtx.Config, writer)
require.NoError(t, err)

prompt := "what can you see on the picture?"
prompt := []string{"what can you see on the picture?"}
expected := "The image shows a character that appears to be a stylized robot. It has"
inputImage := "testdata/marvin.jpg"

Expand Down Expand Up @@ -423,7 +423,7 @@ func TestSimplePromptWithLocalImageAndChat(t *testing.T) {
require.Len(t, messages[0].MultiContent, 2)
// Check, if the prompt is a multi content message
require.Equal(t, "text", string(messages[0].MultiContent[0].Type))
require.Equal(t, prompt, messages[0].MultiContent[0].Text)
require.Equal(t, prompt[0], messages[0].MultiContent[0].Text)
// Check, if the image was added
require.Equal(t, "image_url", string(messages[0].MultiContent[1].Type))
require.NotEmpty(t, messages[0].MultiContent[1].ImageURL.URL)
Expand All @@ -446,7 +446,7 @@ func TestSimplePromptWithURLImageAndChat(t *testing.T) {
client, err := CreateClient(testCtx.Config, writer)
require.NoError(t, err)

prompt := "what can you see on the picture?"
prompt := []string{"what can you see on the picture?"}
expected := "The image shows a character that appears to be a stylized robot. It has"
inputImage := "https://upload.wikimedia.org/wikipedia/en/c/cb/Marvin_%28HHGG%29.jpg"

Expand Down Expand Up @@ -488,7 +488,7 @@ func TestSimplePromptWithURLImageAndChat(t *testing.T) {
require.Len(t, messages[0].MultiContent, 2)
// Check, if the prompt is a multi content message
require.Equal(t, "text", string(messages[0].MultiContent[0].Type))
require.Equal(t, prompt, messages[0].MultiContent[0].Text)
require.Equal(t, prompt[0], messages[0].MultiContent[0].Text)
// Check, if the image was added
require.Equal(t, "image_url", string(messages[0].MultiContent[1].Type))
require.Equal(t, inputImage, messages[0].MultiContent[1].ImageURL.URL)
Expand All @@ -510,7 +510,7 @@ func TestSimplePromptWithMixedImagesAndChat(t *testing.T) {
client, err := CreateClient(testCtx.Config, writer)
require.NoError(t, err)

prompt := "what is the difference between those two pictures?"
prompt := []string{"what is the difference between those two pictures?"}
expected := "The two images provided appear to be identical. Both show the same depiction of a"
inputImageFile := "testdata/marvin.jpg"
inputImageURL := "https://upload.wikimedia.org/wikipedia/en/c/cb/Marvin_%28HHGG%29.jpg"
Expand Down Expand Up @@ -554,7 +554,7 @@ func TestSimplePromptWithMixedImagesAndChat(t *testing.T) {

// Check, if the prompt is a multi content message
require.Equal(t, "text", string(messages[0].MultiContent[0].Type))
require.Equal(t, prompt, messages[0].MultiContent[0].Text)
require.Equal(t, prompt[0], messages[0].MultiContent[0].Text)
// Check, if the URL image was added
require.Equal(t, "image_url", string(messages[0].MultiContent[1].Type))
require.Equal(t, inputImageURL, messages[0].MultiContent[1].ImageURL.URL)
Expand Down
19 changes: 12 additions & 7 deletions pkg/cli/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,25 +173,30 @@ ls | sort
return err
}

var prompt, input string
var prompts []string
mode := "txt"

if isPiped {
var stdinInput string
slog.Debug("Piped shell detected")
// input is provided via stdin
input, err = fs.ReadString(cmd.InOrStdin())
stdinInput, err = fs.ReadString(cmd.InOrStdin())
if err != nil {
return err
}
if len(input) == 0 {
if len(stdinInput) == 0 {
slog.Debug("No input via pipe provided")
return ErrMissingInput
}
prompt = input
prompts = append(prompts, stdinInput)
// mode is provided via command line args
if len(args) == 1 {
slog.Debug("Mode provided via command line args")
mode = args[0]
} else if len(args) == 2 {
slog.Debug("Mode and prompt provided via command line args")
mode = args[0]
prompts = append(prompts, args[1])
}

} else {
Expand All @@ -201,12 +206,12 @@ ls | sort
} else if len(args) == 1 {
// input is provided via command line args
slog.Debug("No mode provided via command line args - using default mode")
prompt = args[0]
prompts = append(prompts, args[0])
} else {
// input and mode are provided via command line args
slog.Debug("Mode and prompt provided via command line args")
mode = strings.ToLower(args[0])
prompt = args[1]
prompts = append(prompts, args[1])
}
}

Expand All @@ -218,7 +223,7 @@ ls | sort
}

var response string
response, err = client.CreateCompletion(cmd.Context(), root.chat, prompt, mode, root.input)
response, err = client.CreateCompletion(cmd.Context(), root.chat, prompts, mode, root.input)
if err != nil {
return err
}
Expand Down
51 changes: 51 additions & 0 deletions pkg/cli/root_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,57 @@ func TestRootCmd_SimplePromptViaPipedShellAndModifier(t *testing.T) {
wg.Wait()
}

func TestRootCmd_PipedShellAndModifierAndPrompt(t *testing.T) {
testCtx := testlib.NewTestCtx(t)
testlib.SetAPIKey(t)
mem := &exitMemento{}

var wg sync.WaitGroup
stdinReader, stdinWriter := io.Pipe()
stdoutReader, stdoutWriter := io.Pipe()

client, err := api.CreateClient(testCtx.Config, stdoutWriter)
require.NoError(t, err)

stdinPrompt := "Say: Hello World!"
prompt := "Replace every 'World' word with 'ChatGPT'"
response := "Hello ChatGPT!"
expected := "Hello ChatGPT!\n"

httpmock.ActivateNonDefault(client.HTTPClient)
t.Cleanup(httpmock.DeactivateAndReset)
testlib.RegisterExpectedChatResponse(response)

root := newRootCmd(mem.Exit, testCtx.Config, mockIsPipedShell(true, nil), useMockClient(client))
root.cmd.SetIn(stdinReader)
root.cmd.SetOut(stdoutWriter)

wg.Add(1)
go func() {
defer wg.Done()
_, errWrite := stdinWriter.Write([]byte(stdinPrompt))
require.NoError(t, stdinWriter.Close())
require.NoError(t, errWrite)
}()

wg.Add(1)
go func() {
defer wg.Done()
var buf bytes.Buffer
_, errReader := io.Copy(&buf, stdoutReader)
require.NoError(t, errReader)
require.NoError(t, stdoutReader.Close())
require.Equal(t, expected, buf.String())
}()

root.Execute([]string{"stdin", prompt})
require.Equal(t, 0, mem.code)
require.NoError(t, stdinReader.Close())
require.NoError(t, stdoutWriter.Close())

wg.Wait()
}

func TestRootCmd_SimpleShellPrompt(t *testing.T) {
testCtx := testlib.NewTestCtx(t)
testlib.SetAPIKey(t)
Expand Down
Loading
Loading