feat: append assistant message for chat
jiacai2050 committed Jun 1, 2024
1 parent 104d7b2 commit 17797ed
Showing 6 changed files with 51 additions and 26 deletions.
Makefile: 7 changes (5 additions & 2 deletions)
@@ -10,15 +10,18 @@ build: clean
 clean:
 	rm -rf build dist shgpt.egg-info
 
+fix:
+	ruff check --fix
+	ruff format
 
 lint:
 	ruff check
-	ruff format
+	ruff format --check
 
 shell:
 	hatch shell
 
 roles:
 	@ python download-roles.py
 
-.PHONY: tui repl build clean lint shell roles
+.PHONY: tui repl build clean fix lint shell roles
shgpt/api/ollama.py: 9 changes (6 additions & 3 deletions)
@@ -11,6 +11,7 @@ def __init__(self, base_url, model, role, timeout):
         self.http_session = TimeoutSession(timeout=timeout)
         self.model = model
         self.role = role
+        self.max_messages = MAX_CHAT_MESSAGES
         self.system_message = (
             None
             if role == "default"
@@ -24,8 +25,8 @@ def chat(self, prompt, stream=True):
             f"generate: {prompt} to {url} with model {self.model} role {self.role} and stream {stream}"
         )
         self.messages.append({"role": "user", "content": prompt})
-        if len(self.messages) > 10:
-            self.messages = self.messages[-10:]
+        if len(self.messages) > self.max_messages:
+            self.messages = self.messages[-self.max_messages :]
         payload = {
             "messages": [] if self.system_message is None else [self.system_message],
             "model": self.model,
@@ -42,7 +43,9 @@ def chat(self, prompt, stream=True):
         answer = ""
        for item in r.iter_content(chunk_size=None):
             resp = json.loads(item)
-            if resp["done"] is False:
+            if resp["done"]:
+                self.messages.append({"role": "assistant", "content": answer})
+            else:
                 content = resp["message"]["content"]
                 answer += content
                 yield content
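
The hunks above change chat() so that the message list is capped at max_messages entries and, once the stream reports done, the accumulated answer is stored as an assistant turn. A minimal sketch of that sliding-window behaviour follows; the trim_messages helper and the sample data are illustrative only, not code from this commit.

# Illustrative only: mirrors the trimming and assistant-append logic above.
def trim_messages(messages, max_messages=10):
    """Keep only the most recent max_messages chat turns."""
    if len(messages) > max_messages:
        return messages[-max_messages:]
    return messages


messages = [{"role": "user", "content": f"question {i}"} for i in range(12)]
messages = trim_messages(messages)
assert len(messages) == 10  # the two oldest turns are dropped

# When the final chunk arrives (resp["done"] is truthy), the full streamed
# answer is appended so the next chat() call sends it back as context.
messages.append({"role": "assistant", "content": "full streamed answer"})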
shgpt/app.py: 29 changes (9 additions & 20 deletions)
@@ -1,33 +1,19 @@
 import argparse
 import sys
-from os import makedirs, path
+from os import makedirs
 from .api.ollama import Ollama
 from .version import __version__
 from .utils.conf import *
 from .utils.common import *
 from .tui.app import ShellGPTApp
+from .history import History
 
 
 def init_app():
     print(f"Create {CONF_PATH}...")
     makedirs(CONF_PATH, exist_ok=True)
 
 
-def setup_readline():
-    import readline
-    import atexit
-
-    history = path.join(CONF_PATH, "history")
-    try:
-        readline.read_history_file(history)
-        readline.set_history_length(MAX_HISTORY)
-    except FileNotFoundError:
-        debug_print(f"History file not found: {history}")
-        pass
-
-    atexit.register(readline.write_history_file, history)
-
-
 def read_action(cmd):
     if IS_TTY:
         action = input("(E)xecute, (Y)ank or Continue(default): ")
@@ -43,8 +29,8 @@ def __init__(self, url, model, role, timeout):
         self.is_shell = role == "shell"
         self.llm = Ollama(url, model, role, timeout)
 
-    def tui(self, initial_prompt):
-        app = ShellGPTApp(self.llm, initial_prompt)
+    def tui(self, history, initial_prompt):
+        app = ShellGPTApp(self.llm, history, initial_prompt)
         app.run()
 
     def repl(self, initial_prompt):
@@ -161,10 +147,13 @@ def main():
         sys.exit(1)
 
     sg = ShellGPT(args.ollama_url, args.ollama_model, role, args.timeout)
+    history = History()
+    if prompt != "":
+        history.add(prompt)
+
     if app_mode == AppMode.Direct:
         sg.infer(prompt)
     elif app_mode == AppMode.TUI:
-        sg.tui(prompt)
+        sg.tui(history, prompt)
     else:
-        setup_readline()
         sg.repl(prompt)
shgpt/history.py: 27 changes (27 additions & 0 deletions)
@@ -0,0 +1,27 @@
+import readline
+import atexit
+from .utils.conf import CONF_PATH, MAX_HISTORY
+from os import path
+
+
+class History:
+    def __init__(self, filename=path.join(CONF_PATH, "history")):
+        self.filename = filename
+        try:
+            readline.read_history_file(filename)
+            readline.set_history_length(MAX_HISTORY)
+        except FileNotFoundError:
+            pass
+        atexit.register(self.save)
+
+    def save(self):
+        try:
+            readline.write_history_file(self.filename)
+        except Exception:
+            pass
+
+    def add(self, line):
+        readline.add_history(line)
+
+    def __del__(self):
+        self.save()
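
A hypothetical usage sketch of the new History class; the import path assumes the shgpt/history.py layout shown above, and the prompt string is made up.

# History wraps the readline history handling that setup_readline() in app.py
# used to do: load on construction, save at exit.
from shgpt.history import History

history = History()  # reads the history file under CONF_PATH, if it already exists
history.add("list files changed in last commit")  # recorded via readline.add_history
# save() is registered with atexit and also called from __del__, so the
# in-memory history is written back via readline.write_history_file on exit.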
shgpt/tui/app.py: 4 changes (3 additions & 1 deletion)
@@ -50,8 +50,9 @@ class ShellGPTApp(App):
         Binding("ctrl+r", "run", "Run code block"),
     ]
 
-    def __init__(self, llm, initial_prompt):
+    def __init__(self, llm, history, initial_prompt):
         self.llm = llm
+        self.history = history
         self.has_inflight_req = False
         self.initial_prompt = initial_prompt
         super().__init__()
@@ -102,6 +103,7 @@ def infer_inner(self) -> None:
             return
 
         debug_print(f"infer {prompt}")
+        self.history.add(prompt)
         # llm infer
         resp = self.llm.chat(prompt, True)
         buf = ""
shgpt/utils/conf.py: 1 change (1 addition & 0 deletions)
@@ -8,6 +8,7 @@
 OLLAMA_MODEL = environ.get("SHELLGPT_OLLAMA_MODEL", "llama3")
 INFER_TIMEOUT = int(environ.get("SHELLGPT_INFER_TIMEOUT", "15"))  # seconds
 MAX_HISTORY = int(environ.get("SHELLGPT_MAX_HISTORY", "1000"))
+MAX_CHAT_MESSAGES = int(environ.get("SHELLGPT_MAX_CHAT_MESSAGES", "10"))
 
 # Built-in roles for different workloads.
 ROLE_CONTENT = {
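
The new MAX_CHAT_MESSAGES knob is evaluated once when shgpt.utils.conf is imported, so the environment variable has to be set before shgpt starts. An illustrative override (the value 20 is arbitrary):

import os

# Must be set before shgpt.utils.conf is imported; otherwise the default
# "10" from the diff above applies.
os.environ["SHELLGPT_MAX_CHAT_MESSAGES"] = "20"
# conf then computes MAX_CHAT_MESSAGES = 20, and Ollama.__init__ copies it
# into self.max_messages, which bounds the chat window trimmed above.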
