Skip to content

Commit

Permalink
fix: ctrl+y doesn't work
Browse files Browse the repository at this point in the history
  • Loading branch information
jiacai2050 committed Jun 1, 2024
1 parent a0c0431 commit 104d7b2
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 59 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ description = "Chat with LLM in your terminal, be it shell generator, story tell
readme = "README.md"
keywords = ["llm", "shell", "gpt"]
license = "GPL-3.0"
requires-python = ">=3.0.0"
dependencies = [
"requests",
"pyperclip",
Expand Down
19 changes: 0 additions & 19 deletions shgpt/api/history.py

This file was deleted.

35 changes: 14 additions & 21 deletions shgpt/api/ollama.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,39 @@
import json
import os
from ..utils.http import TimeoutSession
from ..utils.common import *
from ..utils.conf import *
from .history import DummyHistory, FileHistory

HIST_SEP = "=========="


# https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
class Ollama(object):
def __init__(self, base_url, model, role, timeout):
self.base_url = base_url
self.http_session = TimeoutSession(timeout=timeout)
if ENABLE_HISTORY:
self.history_file = FileHistory(os.path.join(CONF_PATH, "history"))
else:
self.history_file = DummyHistory()
self.model = model
self.role = role
self.system_message = (
None
if role == "default"
else {"role": "system", "content": ROLE_CONTENT.get(self.role)}
)
self.messages = []

def generate(self, prompt, stream=True):
def chat(self, prompt, stream=True):
url = self.base_url + "/api/chat"
debug_print(
f"generate: {prompt} to {url} with model {self.model} role {self.role} and stream {stream}"
)
system_content = ROLE_CONTENT.get(self.role, self.role)
self.messages.append({"role": "user", "content": prompt})
if len(self.messages) > 10:
self.messages = self.messages[-10:]
payload = {
"messages": [
{"role": "system", "content": system_content, "name": "ShellGPT"},
{"role": "user", "content": prompt, "name": "user"},
],
"messages": [] if self.system_message is None else [self.system_message],
"model": self.model,
"stream": stream,
}
for m in self.messages:
payload["messages"].append(m)

debug_print(f"Infer message: {payload}")
r = self.http_session.post(url, json=payload, stream=stream)
if r.status_code != 200:
Expand All @@ -46,10 +46,3 @@ def generate(self, prompt, stream=True):
content = resp["message"]["content"]
answer += content
yield content
else:
self.history_file.write(rf"""{now_ms()},{resp['eval_duration']},{resp['eval_count']}
{prompt}
{HIST_SEP}
{answer}
{HIST_SEP}
""")
22 changes: 19 additions & 3 deletions shgpt/app.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import argparse
import sys
from os import makedirs
from os import makedirs, path
from .api.ollama import Ollama
from .version import __version__
from .utils.conf import *
Expand All @@ -13,6 +13,21 @@ def init_app():
makedirs(CONF_PATH, exist_ok=True)


def setup_readline():
    """Enable readline editing and persistent history for the REPL.

    Loads previous input history from ``CONF_PATH/history`` when present,
    caps the in-memory history at ``MAX_HISTORY`` entries, and registers an
    ``atexit`` hook so the history is written back on interpreter shutdown.
    """
    import readline
    import atexit

    history = path.join(CONF_PATH, "history")
    # Apply the size cap unconditionally: previously this ran only after a
    # successful load, so a fresh install (no history file yet) was never
    # limited to MAX_HISTORY entries.
    readline.set_history_length(MAX_HISTORY)
    try:
        readline.read_history_file(history)
    except FileNotFoundError:
        # First run: nothing to load yet; the atexit hook below creates it.
        debug_print(f"History file not found: {history}")

    atexit.register(readline.write_history_file, history)


def read_action(cmd):
if IS_TTY:
action = input("(E)xecute, (Y)ank or Continue(default): ")
Expand All @@ -25,7 +40,6 @@ def read_action(cmd):

class ShellGPT(object):
def __init__(self, url, model, role, timeout):
self.role = role
self.is_shell = role == "shell"
self.llm = Ollama(url, model, role, timeout)

Expand Down Expand Up @@ -54,7 +68,7 @@ def infer(self, prompt):

buf = ""
try:
for r in self.llm.generate(prompt):
for r in self.llm.chat(prompt):
buf += r
if self.is_shell is False:
print(r, end="")
Expand All @@ -78,6 +92,7 @@ def main():
prog="shgpt",
description="Chat with LLM in your terminal, be it shell generator, story teller, linux-terminal, etc.",
)

parser.add_argument(
"-V", "--version", action="version", version="%(prog)s " + __version__
)
Expand Down Expand Up @@ -151,4 +166,5 @@ def main():
elif app_mode == AppMode.TUI:
sg.tui(prompt)
else:
setup_readline()
sg.repl(prompt)
26 changes: 13 additions & 13 deletions shgpt/tui/app.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from textual.app import App, ComposeResult
from textual.app import App, ComposeResult, Binding
from textual.widgets import Header, Footer, Static, TextArea, Button
from typing import Optional
from ..utils.common import *
Expand All @@ -24,9 +24,9 @@ def compose(self) -> ComposeResult:


class ButtonDispatch(Static):
def __init__(self, copy_handler, run_handler):
def __init__(self, yank_handler, run_handler):
super().__init__()
self.copy_handler = copy_handler
self.yank_handler = yank_handler
self.run_handler = run_handler

def compose(self) -> ComposeResult:
Expand All @@ -36,18 +36,18 @@ def compose(self) -> ComposeResult:
def on_button_pressed(self, event: Button.Pressed) -> None:
"""Event handler called when a button is pressed."""
if event.button.id == "copy":
self.copy_handler()
self.yank_handler()
elif event.button.id == "run":
self.run_handler()


class ShellGPTApp(App):
CSS_PATH = "app.tcss"
BINDINGS = [
("ctrl+j", "infer", "Infer answer"),
("ctrl+d", "toggle_dark", "Toggle dark mode"),
("ctrl+y", "Yank", "Yank code block"),
("ctrl+r", "run", "Run code block"),
Binding("ctrl+j", "infer", "Infer answer"),
Binding("ctrl+d", "toggle_dark", "Toggle dark mode"),
Binding("ctrl+y", "yank", "Yank code block", priority=True),
Binding("ctrl+r", "run", "Run code block"),
]

def __init__(self, llm, initial_prompt):
Expand All @@ -62,7 +62,7 @@ def compose(self) -> ComposeResult:
yield PromptInput(self.initial_prompt)
yield AnswerOutput()
yield ButtonDispatch(
lambda: self.action_copy(),
lambda: self.action_yank(),
lambda: self.action_run(),
)
yield CommandOutput()
Expand All @@ -79,7 +79,7 @@ def action_infer(self) -> None:

self.has_inflight_req = True
try:
self.action_infer_inner()
self.infer_inner()
except Exception as e:
answer_output = self.query_one("#answer_output")
answer_output.load_text(f"Error when infer: {e}")
Expand All @@ -96,14 +96,14 @@ def get_answer_output(self) -> Optional[str]:
text = out.text.strip()
return None if text == "" else text

def action_infer_inner(self) -> None:
def infer_inner(self) -> None:
prompt = self.get_prompt_input()
if prompt is None:
return

debug_print(f"infer {prompt}")
# llm infer
resp = self.llm.generate(prompt, True)
resp = self.llm.chat(prompt, True)
buf = ""
for item in resp:
buf += item
Expand All @@ -112,7 +112,7 @@ def action_infer_inner(self) -> None:
answer_output = self.query_one("#answer_output")
answer_output.load_text(buf)

def action_copy(self) -> None:
def action_yank(self) -> None:
text = self.get_answer_output()
if text is None:
return
Expand Down
6 changes: 3 additions & 3 deletions shgpt/utils/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
OLLAMA_URL = environ.get("SHELLGPT_OLLAMA_URL", "http://127.0.0.1:11434")
OLLAMA_MODEL = environ.get("SHELLGPT_OLLAMA_MODEL", "llama3")
INFER_TIMEOUT = int(environ.get("SHELLGPT_INFER_TIMEOUT", "15")) # seconds
ENABLE_HISTORY = int(environ.get("SHELLGPT_ENABLE_HISTORY", "0")) == 1
MAX_HISTORY = int(environ.get("SHELLGPT_MAX_HISTORY", "1000"))

# There are different roles for different types of prompts
# Built-in roles for different workloads.
ROLE_CONTENT = {
"default": "",
"default": None,
"code": """
Provide only code as output without any description.
Provide only code in plain text format without Markdown formatting.
Expand Down

0 comments on commit 104d7b2

Please sign in to comment.