Skip to content

Commit

Permalink
fix: refactor pricing model and context length handling
Browse files Browse the repository at this point in the history
Refactor to move pricing model definition to a separate class, enhancing modularity and maintainability. This change also streamlines the handling of context length across the CLI and API interactions, improving consistency and reducing redundancy.
  • Loading branch information
liblaf committed Jun 21, 2024
1 parent 775c921 commit 32aace9
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 16 deletions.
14 changes: 7 additions & 7 deletions src/aic/api/openrouter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@
import pydantic


class Pricing(pydantic.BaseModel):
prompt: float
completion: float
request: float
image: float


class Model(pydantic.BaseModel):
id: str
name: str
description: str

class Pricing(pydantic.BaseModel):
prompt: float
completion: float
request: float
image: float

pricing: Pricing
context_length: int

Expand Down
3 changes: 3 additions & 0 deletions src/aic/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def main(
base_url: Annotated[Optional[str], typer.Option(envvar="OPENAI_BASE_URL")] = None, # noqa: UP007
model: Annotated[Optional[str], typer.Option()] = None, # noqa: UP007
max_tokens: Annotated[Optional[int], typer.Option()] = None, # noqa: UP007
context_length: Annotated[Optional[int], typer.Option()] = None, # noqa: UP007
verify: Annotated[bool, typer.Option()] = True,
) -> None:
log.init()
Expand All @@ -36,4 +37,6 @@ def main(
cfg.model = model
if max_tokens is not None:
cfg.max_tokens = max_tokens
if context_length is not None:
cfg.context_length = context_length
cli_main.main(*pathspec, cfg=cfg, verify=verify)
6 changes: 5 additions & 1 deletion src/aic/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ def main(*pathspec: str, cfg: config.Config, verify: bool) -> None:
git.status(*pathspec)
diff: str = git.diff(*pathspec)
model_info: openrouter.Model = openrouter.get_model(cfg.model)
if cfg.pricing is not None:
model_info.pricing = cfg.pricing
if cfg.context_length is not None:
model_info.context_length = cfg.context_length
client = openai.OpenAI(api_key=cfg.api_key, base_url=cfg.base_url)
prompt_builder = prompt.Prompt()
prompt_builder.ask()
Expand Down Expand Up @@ -66,7 +70,7 @@ def format_tokens(prompt_tokens: int, completion_tokens: int) -> str:


def format_cost(
prompt_tokens: int, completion_tokens: int, pricing: openrouter.Model.Pricing
prompt_tokens: int, completion_tokens: int, pricing: openrouter.Pricing
) -> str:
prompt_cost: float = prompt_tokens * pricing.prompt
completion_cost: float = completion_tokens * pricing.completion
Expand Down
11 changes: 3 additions & 8 deletions src/aic/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,15 @@
import pydantic
import typer

from aic.api import openrouter


class Config(pydantic.BaseModel):
api_key: str | None = None
base_url: str | None = None
model: str = "gpt-3.5-turbo"
max_tokens: int = 500

class Pricing(pydantic.BaseModel):
prompt: float | None = None
completion: float | None = None
request: float | None = None
image: float | None = None

pricing: Pricing | None = None
pricing: openrouter.Pricing | None = None
context_length: int | None = None


Expand Down

0 comments on commit 32aace9

Please sign in to comment.