Skip to content

Commit

Permalink
fix: image path extract
Browse files Browse the repository at this point in the history
  • Loading branch information
jiacai2050 committed Jun 3, 2024
1 parent b57fb6f commit a551d64
Showing 9 changed files with 64 additions and 23 deletions.
10 changes: 8 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -21,7 +21,13 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: chartboost/ruff-action@v1
- name: format
uses: chartboost/ruff-action@v1
with:
args: 'format --check'
- uses: chartboost/ruff-action@v1
- name: unit test
uses: chartboost/ruff-action@v1
with:
args: 'run python -m unittest'
- name: lint
uses: chartboost/ruff-action@v1
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -21,6 +21,9 @@ lint:
shell:
hatch shell

ut:
hatch run python -m unittest -v

roles:
@ python download-roles.py

3 changes: 2 additions & 1 deletion shgpt/api/ollama.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from ..utils.http import TimeoutSession
from ..utils.common import debug_print, prepare_prompt
from ..utils.common import base64_image, debug_print, prepare_prompt
from ..utils.conf import MAX_CHAT_MESSAGES, OLLAMA_IMAGE_MODEL, ROLE_CONTENT


@@ -29,6 +29,7 @@ def chat(self, prompt, stream=True):
after, imgs = prepare_prompt(prompt)
model = self.model
if len(imgs) > 0:
imgs = [base64_image(img) for img in imgs]
self.messages.append({'role': 'user', 'content': after, 'images': imgs})
model = OLLAMA_IMAGE_MODEL
else:
4 changes: 2 additions & 2 deletions shgpt/app.py
Original file line number Diff line number Diff line change
@@ -16,7 +16,7 @@
execute_cmd,
copy_text,
read_stdin,
get_executable_script,
extract_code,
set_verbose,
debug_print,
AppMode,
@@ -66,7 +66,7 @@ def infer(self, prompt):
print(r, end='')

if self.is_shell:
shell = get_executable_script(buf)
shell = extract_code(buf)
if shell is not None:
buf = shell
print(buf)
4 changes: 2 additions & 2 deletions shgpt/tui/app.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from textual.app import App, ComposeResult, Binding
from textual.widgets import Header, Footer, Static, TextArea, Button
from typing import Optional
from ..utils.common import copy_text, execute_cmd, debug_print, get_executable_script
from ..utils.common import copy_text, execute_cmd, debug_print, extract_code


class PromptInput(Static):
@@ -110,7 +110,7 @@ def infer_inner(self) -> None:
for item in resp:
buf += item

script = get_executable_script(buf)
script = extract_code(buf)
if script is not None:
buf = script

23 changes: 8 additions & 15 deletions shgpt/utils/common.py
Original file line number Diff line number Diff line change
@@ -8,7 +8,7 @@
import sys
import pyperclip

from shgpt.utils.conf import CONF_PATH, IS_TTY
from shgpt.utils.conf import DEFAULT_IMAGE_DIR, IS_TTY

IS_VERBOSE = False

@@ -29,12 +29,12 @@ def debug_print(msg):
print(msg)


def get_executable_script(text: str) -> Optional[str]:
script_blocks = re.findall('```(.*?)\n(.*?)```', text, re.DOTALL)
if len(script_blocks) == 0:
def extract_code(text: str) -> Optional[str]:
code = re.findall('```(?:.*?)\n(.*?)```', text, re.DOTALL)
if len(code) == 0:
return None
else:
return script_blocks[0][1].strip()
return code[0].strip()


def now_ms():
@@ -66,7 +66,7 @@ def base64_image(image_path: str) -> str:


# https://www.debuggex.com/r/6b2cfvu8bb_stYGu
FILE_PATH_RE = re.compile(r'(\/|@@)(.*?)(?:\s|$)', re.I | re.M)
FILE_PATH_RE = re.compile(r' (\/|@@)(.*?)(?:\s|$)', re.I | re.M)


def extract_paths(txt):
@@ -77,17 +77,10 @@ def gen_path(prefix, left):
if prefix == '/':
return prefix + left
else:
return os.path.join(CONF_PATH, left)
return os.path.join(DEFAULT_IMAGE_DIR, left)


def prepare_prompt(raw):
imgs = [
base64_image(gen_path(prefix, path)) for (prefix, path) in extract_paths(raw)
]
imgs = [gen_path(prefix, path) for (prefix, path) in extract_paths(raw)]
after = raw if len(imgs) == 0 else re.sub(FILE_PATH_RE, '', raw)
return after, imgs


if __name__ == '__main__':
print(prepare_prompt('hello world /tmp/xxx.png @@xxx.png'))
print(prepare_prompt('hello world!'))
2 changes: 1 addition & 1 deletion shgpt/utils/conf.py
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@
OLLAMA_MODEL = environ.get('SHELLGPT_OLLAMA_MODEL', 'llama3')
OLLAMA_IMAGE_MODEL = environ.get('SHELLGPT_OLLAMA_IMAGE_MODEL', 'llava')
OLLAMA_TEMPERATURE = float(environ.get('SHELLGPT_OLLAMA_TEMPERATURE', '0.8'))
DEFAULT_IMAGE_DIR = path.expanduser(environ.get('SHELLGPT_IMAGE_DIR', '~/Downloads/'))
DEFAULT_IMAGE_DIR = path.expanduser(environ.get('SHELLGPT_IMAGE_DIR', '~/Pictures'))
INFER_TIMEOUT = int(environ.get('SHELLGPT_INFER_TIMEOUT', '15')) # seconds
MAX_HISTORY = int(environ.get('SHELLGPT_MAX_HISTORY', '1000'))
MAX_CHAT_MESSAGES = int(environ.get('SHELLGPT_MAX_CHAT_MESSAGES', '10'))
Empty file added tests/__init__.py
Empty file.
38 changes: 38 additions & 0 deletions tests/test_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from shgpt.utils.common import extract_code, prepare_prompt
from shgpt.utils.conf import DEFAULT_IMAGE_DIR
from os import path
import unittest


class ATestCommon(unittest.TestCase):
def test_prepare_prompt(self):
for args, expected in [
('hello', ('hello', [])),
('hello @@test.png', ('hello', [path.join(DEFAULT_IMAGE_DIR, 'test.png')])),
('hello /tmp/test.png', ('hello', ['/tmp/test.png'])),
('hello/tmp/test.png', ('hello/tmp/test.png', [])),
('hello@@test.png', ('hello@@test.png', [])),
]:
self.assertEqual(prepare_prompt(args), expected)

def test_extract_code(self):
for args, expected in [
('1+1', None),
(
"""
```
1+1
```
""",
'1+1',
),
(
"""
```python
1+1
```
""",
'1+1',
),
]:
self.assertEqual(extract_code(args), expected)

0 comments on commit a551d64

Please sign in to comment.