feat: automatically adjust number of RAG contexts (#10)
lsorber authored Aug 18, 2024
1 parent 7135e98 commit 5f38bd9
Showing 3 changed files with 18 additions and 18 deletions.
README.md (2 changes: 1 addition & 1 deletion)
@@ -1,6 +1,6 @@
 [![Open in Dev Containers](https://img.shields.io/static/v1?label=Dev%20Containers&message=Open&color=blue&logo=visualstudiocode)](https://vscode.dev/redirect?url=vscode://ms-vscode-remote.remote-containers/cloneInVolume?url=https://github.com/radix-ai/raglite) [![Open in GitHub Codespaces](https://img.shields.io/static/v1?label=GitHub%20Codespaces&message=Open&color=blue&logo=github)](https://github.com/codespaces/new?hide_repo_select=true&ref=main&repo=812973394&skip_quickstart=true)
 
-# RAGLite
+# 🧵 RAGLite
 
 RAGLite is a Python package for Retrieval-Augmented Generation (RAG) with SQLite.
src/raglite/_config.py (18 changes: 6 additions & 12 deletions)
@@ -17,17 +17,15 @@ def default_llm() -> Llama:
         # Llama-3.1-8B-instruct on GPU.
         repo_id = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF"  # https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct
         filename = "*Q4_K_M.gguf"
+        n_ctx = 8192
     else:
         # Phi-3.1-mini-128k-instruct on CPU.
         repo_id = "bartowski/Phi-3.1-mini-128k-instruct-GGUF"  # https://huggingface.co/microsoft/Phi-3-mini-128k-instruct
         filename = "*Q4_K_M.gguf"
+        n_ctx = 4096
     # Load the LLM.
     llm = Llama.from_pretrained(
-        repo_id=repo_id,
-        filename=filename,
-        n_ctx=8192,  # 0 = Use the model's context size (default is 512).
-        n_gpu_layers=-1,  # -1 = Offload all layers to the GPU (default is 0).
-        verbose=False,
+        repo_id=repo_id, filename=filename, n_ctx=n_ctx, n_gpu_layers=-1, verbose=False
     )
     # Enable caching.
     llm.set_cache(LlamaRAMCache())
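
For reference, here is a minimal sketch (not part of this diff) of what the consolidated call does at runtime. It assumes llama-cpp-python is installed and will download the GGUF weights from the Hugging Face Hub on first use:

```python
from llama_cpp import Llama

# Load the CPU default from default_llm() above.
llm = Llama.from_pretrained(
    repo_id="bartowski/Phi-3.1-mini-128k-instruct-GGUF",
    filename="*Q4_K_M.gguf",
    n_ctx=4096,       # Fixed context window; 0 would mean "use the model's own context size".
    n_gpu_layers=-1,  # Offload all layers to the GPU if one is available.
    verbose=False,
)
print(llm.n_ctx())  # 4096; rag() reads this back to budget its contexts (see _rag.py below).
```

Pinning n_ctx per model (8192 for Llama-3.1-8B on GPU, 4096 for Phi-3.1-mini on CPU), presumably to bound the KV cache, is what lets the new rag() derive a token budget from llm.n_ctx().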
@@ -44,14 +42,10 @@ def default_embedder() -> Llama:
     else:
         repo_id = "yishan-wang/snowflake-arctic-embed-m-v1.5-Q8_0-GGUF"  # https://github.com/Snowflake-Labs/arctic-embed
         filename = "*q8_0.gguf"
-    # Load the embedder.
+    # Load the embedder. Setting n_ctx to 0 means that we use the model's context size (the default
+    # value for n_ctx is 512, irrespective of the model).
     embedder = Llama.from_pretrained(
-        repo_id=repo_id,
-        filename=filename,
-        n_ctx=0,  # 0 = Use the model's context size (default is 512).
-        n_gpu_layers=-1,  # -1 = Offload all layers to the GPU (default is 0).
-        verbose=False,
-        embedding=True,
+        repo_id=repo_id, filename=filename, n_ctx=0, n_gpu_layers=-1, verbose=False, embedding=True
     )
     return embedder
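
A matching sketch for the embedder, again an illustration rather than part of the commit: with n_ctx=0, llama-cpp-python uses the model's trained context size instead of its 512-token default, and embedding=True enables Llama.embed():

```python
from llama_cpp import Llama

embedder = Llama.from_pretrained(
    repo_id="yishan-wang/snowflake-arctic-embed-m-v1.5-Q8_0-GGUF",
    filename="*q8_0.gguf",
    n_ctx=0,          # 0 = use the model's own context size (the default would be 512).
    n_gpu_layers=-1,  # Offload all layers to the GPU if one is available.
    verbose=False,
    embedding=True,   # Expose the embedding API for this model.
)
vector = embedder.embed("What does RAGLite do?")
print(len(vector))  # Embedding dimension (768 for snowflake-arctic-embed-m).
```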

src/raglite/_rag.py (16 changes: 11 additions & 5 deletions)
@@ -9,19 +9,25 @@
 def rag(
     prompt: str,
     *,
-    num_contexts: int = 5,
+    max_contexts: int = 5,
     context_neighbors: tuple[int, ...] | None = (-1, 1),
     search: Callable[[str], tuple[list[int], list[float]]] = hybrid_search,
     config: RAGLiteConfig | None = None,
 ) -> Iterator[str]:
     """Retrieval-augmented generation."""
-    # Retrieve relevant chunks.
+    # Reduce the maximum number of contexts to take into account the LLM's context size.
     config = config or RAGLiteConfig()
-    chunk_rowids, _ = search(prompt, num_results=num_contexts, config=config)  # type: ignore[call-arg]
-    chunks = retrieve_segments(chunk_rowids, neighbors=context_neighbors)
+    max_tokens = config.llm.n_ctx() - 256  # Account for the system and user prompts.
+    max_tokens_per_context = round(1.2 * (config.chunk_max_size // 4))
+    max_tokens_per_context *= 1 + len(context_neighbors or [])
+    max_contexts = min(max_contexts, max_tokens // max_tokens_per_context)
+    # Retrieve relevant contexts.
+    chunk_rowids, _ = search(prompt, num_results=max_contexts, config=config)  # type: ignore[call-arg]
+    segments = retrieve_segments(chunk_rowids, neighbors=context_neighbors)
     # Respond with an LLM.
     contexts = "\n\n".join(
-        f'<context index="{i}">\n{chunk.strip()}\n</context>' for i, chunk in enumerate(chunks)
+        f'<context index="{i}">\n{segment.strip()}\n</context>'
+        for i, segment in enumerate(segments)
     )
     system_prompt = f"""
 You are a friendly and knowledgeable assistant that provides complete and insightful answers.
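
The new clamping logic is the heart of this commit: it reserves 256 tokens for the system and user prompts, estimates each context at roughly 4 characters per token with a 20% safety margin, scales that estimate by the number of neighboring chunks retrieved alongside each hit, and caps max_contexts accordingly. A standalone sketch of the same arithmetic, where clamp_contexts is a hypothetical helper and chunk_max_size=1440 is an illustrative value, not necessarily RAGLite's default:

```python
def clamp_contexts(
    max_contexts: int = 5,
    n_ctx: int = 4096,                                 # The LLM's context window in tokens.
    chunk_max_size: int = 1440,                        # Max chunk size in characters (assumed).
    context_neighbors: tuple[int, ...] | None = (-1, 1),
) -> int:
    max_tokens = n_ctx - 256                           # Reserve room for system + user prompts.
    # ~4 characters per token, plus a 20% safety margin:
    max_tokens_per_context = round(1.2 * (chunk_max_size // 4))
    # Each retrieved chunk is expanded with its neighboring chunks:
    max_tokens_per_context *= 1 + len(context_neighbors or ())
    return min(max_contexts, max_tokens // max_tokens_per_context)

print(clamp_contexts())  # 2
```

With a 4096-token window the budget is 3840 tokens; at 432 tokens per chunk times 3 chunks (the hit plus two neighbors), that is 1296 tokens per expanded context, so 3840 // 1296 = 2 and the requested 5 contexts are clamped down to 2.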
