Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: enhance shell() to know when it is interactive #66

Merged
merged 17 commits into from
Sep 24, 2024
173 changes: 163 additions & 10 deletions src/goose/toolkit/developer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pathlib import Path
from subprocess import CompletedProcess, run
from typing import List, Dict
import os
from goose.utils.ask import ask_an_ai
from goose.utils.check_shell_command import is_dangerous_command

from exchange import Message
Expand All @@ -11,6 +11,12 @@
from rich.prompt import Confirm
from rich.table import Table
from rich.text import Text
import subprocess
import threading
import queue
import re
import time


from goose.toolkit.base import Toolkit, tool
from goose.toolkit.utils import get_language, render_template
Expand Down Expand Up @@ -146,11 +152,9 @@ def shell(self, command: str) -> str:
command (str): The shell command to run. It can support multiline statements
if you need to run more than one at a time
"""

self.notifier.status("planning to run shell command")
# Log the command being executed in a visually structured format (Markdown).
# The `.log` method is used here to log the command execution in the application's UX
# this method is dynamically attached to functions in the Goose framework to handle user-visible
# logging and integrates with the overall UI logging system
self.notifier.log(Panel.fit(Markdown(f"```bash\n{command}\n```"), title="shell"))

if is_dangerous_command(command):
Expand All @@ -159,16 +163,165 @@ def shell(self, command: str) -> str:
if not keep_unsafe_command_prompt(command):
raise RuntimeError(
f"The command {command} was rejected as dangerous by the user."
+ " Do not proceed further, instead ask for instructions."
" Do not proceed further, instead ask for instructions."
)
self.notifier.start()
self.notifier.status("running shell command")
result: CompletedProcess = run(command, shell=True, text=True, capture_output=True, check=False)
if result.returncode == 0:
output = "Command succeeded"

# Define patterns that might indicate the process is waiting for input
interaction_patterns = [
r"Do you want to", # Common prompt phrase
r"Enter password", # Password prompt
r"Are you sure", # Confirmation prompt
r"\(y/N\)", # Yes/No prompt
r"Press any key to continue", # Awaiting keypress
r"Waiting for input", # General waiting message
r"\?\s", # Prompts starting with '? '
]
compiled_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in interaction_patterns]

# Start the process
proc = subprocess.Popen(
command,
shell=True,
stdin=subprocess.DEVNULL, # Close stdin to prevent the process from waiting for input
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)

# Queues to store the output
stdout_queue = queue.Queue()
stderr_queue = queue.Queue()

# Function to read stdout and stderr without blocking
def reader_thread(pipe: any, output_queue: any) -> None:
try:
for line in iter(pipe.readline, ""):
output_queue.put(line)
# Check for prompt patterns
for pattern in compiled_patterns:
if pattern.search(line):
output_queue.put("INTERACTION_DETECTED")
finally:
pipe.close()

# Start threads to read stdout and stderr
stdout_thread = threading.Thread(target=reader_thread, args=(proc.stdout, stdout_queue))
stderr_thread = threading.Thread(target=reader_thread, args=(proc.stderr, stderr_queue))
stdout_thread.start()
stderr_thread.start()

# Collect output
output = ""
error = ""

# Initialize timer and recent lines list
last_line_time = time.time()
recent_lines = []

# Continuously read output
while True:
# Check if process has terminated
if proc.poll() is not None:
break

# Process output from stdout
try:
while True:
line = stdout_queue.get_nowait()
if line == "INTERACTION_DETECTED":
return (
"Command requires interactive input. If unclear, prompt user for required input "
f"or ask to run outside of goose.\nOutput:\n{output}\nError:\n{error}"
)

else:
output += line
recent_lines.append(line)
recent_lines = recent_lines[-10:] # Keep only the last 10 lines
last_line_time = time.time() # Reset timer
except queue.Empty:
pass

# Process output from stderr
try:
while True:
line = stderr_queue.get_nowait()
if line == "INTERACTION_DETECTED":
return (
"Command requires interactive input. If unclear, prompt user for required input "
f"or ask to run outside of goose.\nOutput:\n{output}\nError:\n{error}"
)
else:
error += line
recent_lines.append(line)
recent_lines = recent_lines[-10:] # Keep only the last 10 lines
last_line_time = time.time() # Reset timer
except queue.Empty:
pass

# Check if no new lines have been received for 10 seconds
if time.time() - last_line_time > 10:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious how this handles cases where log lines are replaced (e.g. print(..., end="\r"). This may be handled differently in stdout vs in a terminal console, but figured I'd float it as something to double check before merging.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hrm - might be worth trying out

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe would timeout and then consider if it should continue or not

# Call maybe_prompt with the last 2 to 10 recent lines
lines_to_check = recent_lines[-10:]
self.notifier.log(f"Still working\n{''.join(lines_to_check)}")
if not lines_to_check or len(recent_lines) == 0:
lines_to_check = list(["busy..."])
response = ask_an_ai(
input=("\n").join(lines_to_check),
prompt="This looks to see if the lines provided from running a command are potentially waiting"
+ " for something, running a server or something that will not termiinate in a shell."
+ " Return [Yes], if so [No] otherwise.",
exchange=self.exchange_view.accelerator,
)
if response.content[0].text == "[Yes]":
answer = (
f"The command {command} looks to be a long running task. "
f"Do not run it in goose but tell user to run it outside, "
f"unless the user explicitly tells you to run it (and then, "
f"remind them they will need to cancel it as long running)."
)
return answer
else:
self.notifier.log(f"Will continue to run {command}")

# Reset last_line_time to avoid repeated calls
last_line_time = time.time()

# Brief sleep to prevent high CPU usage
threading.Event().wait(0.1)

# Wait for process to complete
proc.wait()

# Ensure all output is read
stdout_thread.join()
stderr_thread.join()

# Retrieve any remaining output from queues
try:
while True:
line = stdout_queue.get_nowait()
output += line
except queue.Empty:
pass

try:
while True:
line = stderr_queue.get_nowait()
error += line
except queue.Empty:
pass

# Determine the result based on the return code
if proc.returncode == 0:
result = "Command succeeded"
else:
output = f"Command failed with returncode {result.returncode}"
return "\n".join([output, result.stdout, result.stderr])
result = f"Command failed with returncode {proc.returncode}"

# Return the combined result and outputs if we made it this far
return "\n".join([result, output, error])

@tool
def write_file(self, path: str, content: str) -> str:
Expand Down