From 4373be09eea52c7151bb3a84897ec5d51cc88ca9 Mon Sep 17 00:00:00 2001 From: jiacai2050 Date: Mon, 3 Jun 2024 21:03:14 +0800 Subject: [PATCH] fix: image path extract --- .github/workflows/ci.yml | 16 ++++++++++------ Makefile | 3 +++ shgpt/api/ollama.py | 3 ++- shgpt/app.py | 4 ++-- shgpt/tui/app.py | 4 ++-- shgpt/utils/common.py | 23 ++++++++--------------- shgpt/utils/conf.py | 2 +- tests/__init__.py | 0 tests/test_common.py | 38 ++++++++++++++++++++++++++++++++++++++ 9 files changed, 66 insertions(+), 27 deletions(-) create mode 100644 tests/__init__.py create mode 100644 tests/test_common.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fb1cec1..8656f6a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -16,12 +16,16 @@ on: - '**.toml' jobs: - ruff: - name: Ruff check + ci: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: chartboost/ruff-action@v1 - with: - args: 'format --check' - - uses: chartboost/ruff-action@v1 + - name: Install + run: | + pip install hatch + - name: lint + run: | + make lint + - name: unit test + run: + make ut diff --git a/Makefile b/Makefile index d800c9c..ecca6be 100644 --- a/Makefile +++ b/Makefile @@ -21,6 +21,9 @@ lint: shell: hatch shell +ut: + hatch run python -m unittest -v + roles: @ python download-roles.py diff --git a/shgpt/api/ollama.py b/shgpt/api/ollama.py index d421b44..9fa7729 100644 --- a/shgpt/api/ollama.py +++ b/shgpt/api/ollama.py @@ -1,6 +1,6 @@ import json from ..utils.http import TimeoutSession -from ..utils.common import debug_print, prepare_prompt +from ..utils.common import base64_image, debug_print, prepare_prompt from ..utils.conf import MAX_CHAT_MESSAGES, OLLAMA_IMAGE_MODEL, ROLE_CONTENT @@ -29,6 +29,7 @@ def chat(self, prompt, stream=True): after, imgs = prepare_prompt(prompt) model = self.model if len(imgs) > 0: + imgs = [base64_image(img) for img in imgs] self.messages.append({'role': 'user', 'content': after, 'images': imgs}) model = OLLAMA_IMAGE_MODEL else: diff --git a/shgpt/app.py b/shgpt/app.py index 53cf0da..fc9a1f7 100644 --- a/shgpt/app.py +++ b/shgpt/app.py @@ -16,7 +16,7 @@ execute_cmd, copy_text, read_stdin, - get_executable_script, + extract_code, set_verbose, debug_print, AppMode, @@ -66,7 +66,7 @@ def infer(self, prompt): print(r, end='') if self.is_shell: - shell = get_executable_script(buf) + shell = extract_code(buf) if shell is not None: buf = shell print(buf) diff --git a/shgpt/tui/app.py b/shgpt/tui/app.py index 2398787..cfb6ada 100644 --- a/shgpt/tui/app.py +++ b/shgpt/tui/app.py @@ -1,7 +1,7 @@ from textual.app import App, ComposeResult, Binding from textual.widgets import Header, Footer, Static, TextArea, Button from typing import Optional -from ..utils.common import copy_text, execute_cmd, debug_print, get_executable_script +from ..utils.common import copy_text, execute_cmd, debug_print, extract_code class PromptInput(Static): @@ -110,7 +110,7 @@ def infer_inner(self) -> None: for item in resp: buf += item - script = get_executable_script(buf) + script = extract_code(buf) if script is not None: buf = script diff --git a/shgpt/utils/common.py b/shgpt/utils/common.py index 4f68c1c..2893e7d 100644 --- a/shgpt/utils/common.py +++ b/shgpt/utils/common.py @@ -8,7 +8,7 @@ import sys import pyperclip -from shgpt.utils.conf import CONF_PATH, IS_TTY +from shgpt.utils.conf import DEFAULT_IMAGE_DIR, IS_TTY IS_VERBOSE = False @@ -29,12 +29,12 @@ def debug_print(msg): print(msg) -def get_executable_script(text: str) -> Optional[str]: - script_blocks = re.findall('```(.*?)\n(.*?)```', text, re.DOTALL) - if len(script_blocks) == 0: +def extract_code(text: str) -> Optional[str]: + code = re.findall('```(?:.*?)\n(.*?)```', text, re.DOTALL) + if len(code) == 0: return None else: - return script_blocks[0][1].strip() + return code[0].strip() def now_ms(): @@ -66,7 +66,7 @@ def base64_image(image_path: str) -> str: # https://www.debuggex.com/r/6b2cfvu8bb_stYGu -FILE_PATH_RE = re.compile(r'(\/|@@)(.*?)(?:\s|$)', re.I | re.M) +FILE_PATH_RE = re.compile(r' (\/|@@)(.*?)(?:\s|$)', re.I | re.M) def extract_paths(txt): @@ -77,17 +77,10 @@ def gen_path(prefix, left): if prefix == '/': return prefix + left else: - return os.path.join(CONF_PATH, left) + return os.path.join(DEFAULT_IMAGE_DIR, left) def prepare_prompt(raw): - imgs = [ - base64_image(gen_path(prefix, path)) for (prefix, path) in extract_paths(raw) - ] + imgs = [gen_path(prefix, path) for (prefix, path) in extract_paths(raw)] after = raw if len(imgs) == 0 else re.sub(FILE_PATH_RE, '', raw) return after, imgs - - -if __name__ == '__main__': - print(prepare_prompt('hello world /tmp/xxx.png @@xxx.png')) - print(prepare_prompt('hello world!')) diff --git a/shgpt/utils/conf.py b/shgpt/utils/conf.py index cc8d32e..7c5c048 100644 --- a/shgpt/utils/conf.py +++ b/shgpt/utils/conf.py @@ -9,7 +9,7 @@ OLLAMA_MODEL = environ.get('SHELLGPT_OLLAMA_MODEL', 'llama3') OLLAMA_IMAGE_MODEL = environ.get('SHELLGPT_OLLAMA_IMAGE_MODEL', 'llava') OLLAMA_TEMPERATURE = float(environ.get('SHELLGPT_OLLAMA_TEMPERATURE', '0.8')) -DEFAULT_IMAGE_DIR = path.expanduser(environ.get('SHELLGPT_IMAGE_DIR', '~/Downloads/')) +DEFAULT_IMAGE_DIR = path.expanduser(environ.get('SHELLGPT_IMAGE_DIR', '~/Pictures')) INFER_TIMEOUT = int(environ.get('SHELLGPT_INFER_TIMEOUT', '15')) # seconds MAX_HISTORY = int(environ.get('SHELLGPT_MAX_HISTORY', '1000')) MAX_CHAT_MESSAGES = int(environ.get('SHELLGPT_MAX_CHAT_MESSAGES', '10')) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_common.py b/tests/test_common.py new file mode 100644 index 0000000..b639cb2 --- /dev/null +++ b/tests/test_common.py @@ -0,0 +1,38 @@ +from shgpt.utils.common import extract_code, prepare_prompt +from shgpt.utils.conf import DEFAULT_IMAGE_DIR +from os import path +import unittest + + +class ATestCommon(unittest.TestCase): + def test_prepare_prompt(self): + for args, expected in [ + ('hello', ('hello', [])), + ('hello @@test.png', ('hello', [path.join(DEFAULT_IMAGE_DIR, 'test.png')])), + ('hello /tmp/test.png', ('hello', ['/tmp/test.png'])), + ('hello/tmp/test.png', ('hello/tmp/test.png', [])), + ('hello@@test.png', ('hello@@test.png', [])), + ]: + self.assertEqual(prepare_prompt(args), expected) + + def test_extract_code(self): + for args, expected in [ + ('1+1', None), + ( + """ + ``` + 1+1 + ``` + """, + '1+1', + ), + ( + """ + ```python + 1+1 + ``` + """, + '1+1', + ), + ]: + self.assertEqual(extract_code(args), expected)