Skip to content

Commit

Permalink
First -i/--image input prototype, refs #331
Browse files Browse the repository at this point in the history
  • Loading branch information
simonw committed Mar 4, 2024
1 parent de6af1c commit eaf50d8
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 4 deletions.
24 changes: 22 additions & 2 deletions llm/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@
set_alias,
remove_alias,
)

from .migrations import migrate
from .plugins import pm
import base64
import pathlib
import pydantic
import re
import readline
from runpy import run_module
import shutil
Expand All @@ -36,6 +36,7 @@
import sys
import textwrap
from typing import cast, Optional, Iterable, Union, Tuple
import urllib
import warnings
import yaml

Expand Down Expand Up @@ -83,10 +84,28 @@ def cli():
"""


class FileOrUrl(click.ParamType):
    """Click parameter type that accepts "-" (stdin), an http(s) URL, or a
    local file path.

    convert() returns a readable file-like object: sys.stdin for "-", the
    response object from urllib.request.urlopen() for URLs, or the file
    opened in binary mode for local paths. Calls self.fail() (which raises
    a click usage error) for anything else.
    """

    name = "file_or_url"

    def convert(self, value, param, ctx):
        # "-" is the conventional marker for "read from standard input"
        if value == "-":
            return sys.stdin
        if re.match(r"^https?://", value):
            # Function-scope import: a bare `import urllib` does not
            # reliably expose the `urllib.request` submodule, so importing
            # it here guarantees urlopen is available.
            import urllib.request

            return urllib.request.urlopen(value)
        # is_file() already returns False for nonexistent paths, so a
        # separate exists() check is unnecessary.
        path = pathlib.Path(value)
        if path.is_file():
            return path.open("rb")
        self.fail(f"{value} is not a valid file path or URL", param, ctx)


@cli.command(name="prompt")
@click.argument("prompt", required=False)
@click.option("-s", "--system", help="System prompt to use")
@click.option("model_id", "-m", "--model", help="Model to use")
@click.option(
"images", "-i", "--image", type=FileOrUrl(), multiple=True, help="Images for prompt"
)
@click.option(
"options",
"-o",
Expand Down Expand Up @@ -126,6 +145,7 @@ def prompt(
prompt,
system,
model_id,
images,
options,
template,
param,
Expand Down Expand Up @@ -272,7 +292,7 @@ def read_prompt():
prompt_method = conversation.prompt

try:
response = prompt_method(prompt, system, **validated_options)
response = prompt_method(prompt, system, images=images, **validated_options)
if should_stream:
for chunk in response:
print(chunk, end="")
Expand Down
21 changes: 21 additions & 0 deletions llm/default_plugins/openai_models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
from llm import EmbeddingModel, Model, hookimpl
import llm
from llm.utils import dicts_to_table_string, remove_dict_none_values, logging_client
Expand Down Expand Up @@ -31,6 +32,7 @@ def register_models(register):
register(Chat("gpt-4-1106-preview"))
register(Chat("gpt-4-0125-preview"))
register(Chat("gpt-4-turbo-preview"), aliases=("gpt-4-turbo", "4-turbo", "4t"))
register(Chat("gpt-4-vision-preview", images=True), aliases=("4v",))
# The -instruct completion model
register(
Completion("gpt-3.5-turbo-instruct", default_max_tokens=256),
Expand Down Expand Up @@ -264,6 +266,7 @@ def __init__(
api_version=None,
api_engine=None,
headers=None,
images=False,
):
self.model_id = model_id
self.key = key
Expand All @@ -273,6 +276,7 @@ def __init__(
self.api_version = api_version
self.api_engine = api_engine
self.headers = headers
self.supports_images = images

def __str__(self):
return "OpenAI Chat: {}".format(self.model_id)
Expand All @@ -297,6 +301,23 @@ def execute(self, prompt, stream, response, conversation=None):
if prompt.system and prompt.system != current_system:
messages.append({"role": "system", "content": prompt.system})
messages.append({"role": "user", "content": prompt.prompt})
if prompt.images:
for image in prompt.images:
messages.append(
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": "data:image/jpeg;base64,{}".format(
base64.b64encode(image.read()).decode("utf-8")
)
},
}
],
}
)
response._prompt_json = {"messages": messages}
kwargs = self.build_kwargs(prompt)
client = self.get_client()
Expand Down
26 changes: 24 additions & 2 deletions llm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,38 @@
from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Union
from abc import ABC, abstractmethod
import json
from pathlib import Path
from pydantic import BaseModel
from ulid import ULID

CONVERSATION_NAME_LENGTH = 32


@dataclass
class PromptImage:
    """A single image attached to a prompt.

    Typically exactly one of the three sources is populated: a local file
    path, a remote URL, or raw image bytes already in memory. All fields
    default to None so callers only need to supply the source they have.
    """

    # Path to a local image file, if the image came from disk
    filepath: Optional[Path] = None
    # Remote URL the image should be fetched from
    url: Optional[str] = None
    # Raw image data, if already loaded in memory (note: field name
    # shadows the builtin `bytes`; kept for interface compatibility)
    bytes: Optional[bytes] = None


@dataclass
class Prompt:
    """A single prompt to be sent to a model.

    Although decorated with @dataclass (which generates __repr__ and
    __eq__ from the fields declared below), instances are constructed via
    the explicit __init__ so that `options` defaults to an empty dict and
    `images`/`system`/`prompt_json` default to None.
    """

    # The user's prompt text
    prompt: str
    # The model this prompt targets (forward reference — Model is defined later)
    model: "Model"
    # Optional system prompt
    system: Optional[str]
    # Optional serialized JSON representation of the prompt
    prompt_json: Optional[str]
    # Model options (forward reference to the model's Options class)
    options: "Options"
    # Images attached to the prompt, if any
    images: Optional[List[PromptImage]]

    def __init__(
        self, prompt, model, system=None, images=None, prompt_json=None, options=None
    ):
        self.prompt = prompt
        self.model = model
        self.system = system
        self.prompt_json = prompt_json
        # NOTE(review): falls back to a plain dict, not an Options
        # instance, despite the field annotation — confirm downstream
        # code tolerates both.
        self.options = options or {}
        self.images = images


@dataclass
Expand Down Expand Up @@ -246,6 +258,7 @@ class Model(ABC, _get_key_mixin):
needs_key: Optional[str] = None
key_env_var: Optional[str] = None
can_stream: bool = False
supports_images: bool = False

class Options(_Options):
pass
Expand All @@ -272,10 +285,19 @@ def prompt(
prompt: Optional[str],
system: Optional[str] = None,
stream: bool = True,
images: Optional[List[PromptImage]] = None,
**options
):
if images and not self.supports_images:
raise ValueError("This model does not support images")
return self.response(
Prompt(prompt, system=system, model=self, options=self.Options(**options)),
Prompt(
prompt,
system=system,
model=self,
images=images,
options=self.Options(**options),
),
stream=stream,
)

Expand Down

0 comments on commit eaf50d8

Please sign in to comment.