Commit: Add experimental cache
abetlen committed Apr 15, 2023
1 parent a6372a7 commit 92c0771
Showing 2 changed files with 69 additions and 5 deletions.
69 changes: 65 additions & 4 deletions llama_cpp/llama.py
@@ -11,6 +11,15 @@
from .llama_types import *


class LlamaCache:
    """Cache for a llama.cpp model.
    NOTE: This implementation currently only tells the Llama class to avoid reprocessing bytes and continue from the last
    completion. It does not actually cache the results."""

    pass


class Llama:
    """High-level Python wrapper for a llama.cpp model."""
@@ -82,6 +91,14 @@ def __init__(
        self.n_past = 0
        self.all_logits: List[List[float]] = []  # TODO: Use an array instead of a list.

        ### HACK: This is a hack to work around the fact that the llama.cpp API does not yet support
        ### saving and restoring state, this allows us to continue a completion if the last
        ### completion_bytes is a prefix to the prompt passed in. However this is actually incorrect
        ### because it does not take into account stop tokens which have been processed by the model.
        self._completion_bytes: List[bytes] = []
        self._cache: Optional[LlamaCache] = None
        ###

        self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)

        if not os.path.exists(model_path):
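The comment above, together with the LlamaCache docstring, describes the whole mechanism: nothing is stored in the cache object itself. Instead, when the bytes of the last prompt-plus-completion turn out to be a prefix of the next prompt, the existing model state is kept and only the remainder is evaluated. A minimal sketch of that idea (the byte strings are invented for illustration, not taken from the commit):

    # Sketch of the prefix-continuation idea; the byte strings are made up.
    previous_bytes = b"Q: What is the capital of France? A: Paris."  # prompt + completion already evaluated
    next_prompt = b"Q: What is the capital of France? A: Paris. Q: And of Spain? A:"

    if next_prompt.startswith(previous_bytes):
        # Cache hit: keep the model state and evaluate only the new suffix.
        remaining = next_prompt[len(previous_bytes):]
    else:
        # Cache miss: reset the model and evaluate the full prompt from scratch.
        remaining = next_prompt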
@@ -135,6 +152,14 @@ def detokenize(self, tokens: List[llama_cpp.llama_token]) -> bytes:
            output += llama_cpp.llama_token_to_str(self.ctx, token)
        return output

    def set_cache(self, cache: Optional[LlamaCache]):
        """Set the cache.
        Args:
            cache: The cache to set.
        """
        self._cache = cache

    def reset(self):
        """Reset the model state."""
        self.last_n_tokens_data.extend(
@@ -245,6 +270,17 @@ def generate(
            The generated tokens.
        """
        assert self.ctx is not None
        ### HACK
        if (
            reset
            and self._cache
            and len(self.tokens) > 0
            and self.tokens == tokens[: len(self.tokens)]
        ):
            if self.verbose:
                print("generate cache hit", file=sys.stderr)
            reset = False
        ###
        if reset:
            self.reset()
        while True:
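The reset above is skipped only when the tokens the model has already evaluated (self.tokens) form a prefix of the tokens passed to this generate() call. A tiny illustration of that comparison, using made-up token ids (not from the commit):

    # Made-up token ids, purely to illustrate the prefix comparison performed above.
    already_evaluated = [1, 15043, 29892]        # stand-in for self.tokens
    requested = [1, 15043, 29892, 3186, 29991]   # stand-in for the tokens argument

    cache_hit = already_evaluated == requested[: len(already_evaluated)]
    # cache_hit is True here, so the model state is kept instead of being reset.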
@@ -361,13 +397,29 @@ def _create_completion(
                "logprobs is not supported for models created with logits_all=False"
            )

        ### HACK
        reset: bool = True
        _prompt: bytes = prompt.encode("utf-8")
        _completion: bytes = b"".join(self._completion_bytes)
        if len(_completion) and self._cache and _prompt.startswith(_completion):
            if self.verbose:
                print("completion cache hit", file=sys.stderr)
            reset = False
            _prompt = _prompt[len(_completion) :]
            prompt_tokens = self.tokenize(b" " + _prompt)
            self._completion_bytes.append(_prompt)
        else:
            self._completion_bytes = [prompt.encode("utf-8")]
        ###

        finish_reason = "length"
        for token in self.generate(
            prompt_tokens,
            top_k=top_k,
            top_p=top_p,
            temp=temperature,
            repeat_penalty=repeat_penalty,
            reset=reset,
        ):
            if token == llama_cpp.llama_token_eos():
                text = self.detokenize(completion_tokens)
@@ -397,6 +449,9 @@ def _create_completion(
                            break
                text = all_text[: len(all_text) - longest]
                returned_characters += len(text[start:])
                ### HACK
                self._completion_bytes.append(text[start:])
                ###
                yield {
                    "id": completion_id,
                    "object": "text_completion",
@@ -418,6 +473,9 @@ def _create_completion(
                break

        if stream:
            ### HACK
            self._completion_bytes.append(text[returned_characters:])
            ###
            yield {
                "id": completion_id,
                "object": "text_completion",
@@ -434,13 +492,16 @@ def _create_completion(
            }
            return

-        text = text.decode("utf-8")
+        ### HACK
+        self._completion_bytes.append(text)
+        ###
+        text_str = text.decode("utf-8")

        if echo:
-            text = prompt + text
+            text_str = prompt + text_str

        if suffix is not None:
-            text = text + suffix
+            text_str = text_str + suffix

        logprobs_or_none: Optional[CompletionLogprobs] = None
        if logprobs is not None:
@@ -493,7 +554,7 @@ def _create_completion(
            "model": self.model_path,
            "choices": [
                {
-                    "text": text,
+                    "text": text_str,
                    "index": 0,
                    "logprobs": logprobs_or_none,
                    "finish_reason": finish_reason,
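Taken together, the llama.py changes mean the cache is switched on from the Python API by attaching a LlamaCache instance via set_cache(). A hypothetical usage sketch (model path and prompts are placeholders, not part of the commit):

    from llama_cpp import Llama, LlamaCache

    llama = Llama(model_path="./models/ggml-model.bin")  # placeholder path
    llama.set_cache(LlamaCache())  # enable the experimental prefix continuation

    # The first call evaluates the full prompt; its bytes are remembered internally.
    first = llama("Q: Name the planets in the solar system. A:", max_tokens=32)

    # A follow-up prompt that starts with the previous prompt plus its completion
    # should hit the cache, so only the new suffix is tokenized and evaluated.
    second = llama(
        "Q: Name the planets in the solar system. A:"
        + first["choices"][0]["text"]
        + " Q: Which one is the largest? A:",
        max_tokens=32,
    )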
5 changes: 4 additions & 1 deletion llama_cpp/server/__main__.py
@@ -35,6 +35,7 @@ class Settings(BaseSettings):
    embedding: bool = True
    last_n_tokens_size: int = 64
    logits_all: bool = False
    cache: bool = False  # WARNING: This is an experimental feature


app = FastAPI(
@@ -60,6 +61,9 @@ class Settings(BaseSettings):
    n_ctx=settings.n_ctx,
    last_n_tokens_size=settings.last_n_tokens_size,
)
if settings.cache:
    cache = llama_cpp.LlamaCache()
    llama.set_cache(cache)
llama_lock = Lock()


@@ -68,7 +72,6 @@ def get_llama():
        yield llama


-
class CreateCompletionRequest(BaseModel):
    prompt: Union[str, List[str]]
    suffix: Optional[str] = Field(None)
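With the new Settings field, the experimental cache can be enabled when launching the bundled server. Settings is a pydantic BaseSettings model, so the flag should map to an environment variable in the usual way; an illustrative invocation (model path is a placeholder):

    MODEL=./models/ggml-model.bin CACHE=true python3 -m llama_cpp.server

When the flag is set, the startup code above creates a llama_cpp.LlamaCache and passes it to llama.set_cache().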
