feat: better handling of attachments
lsorber committed Oct 21, 2024
1 parent a55777a commit e868f05
Showing 3 changed files with 39 additions and 19 deletions.
39 changes: 25 additions & 14 deletions src/raglite/_chainlit.py
@@ -14,6 +14,7 @@
     rerank_chunks,
     retrieve_chunks,
 )
+from raglite._markdown import document_to_markdown

 async_insert_document = cl.make_async(insert_document)
 async_hybrid_search = cl.make_async(hybrid_search)
@@ -24,6 +24,8 @@
 @cl.on_chat_start
 async def start_chat() -> None:
     """Initialize the chat."""
+    # Disable tokenizers parallelism to avoid the deadlock warning.
+    os.environ["TOKENIZERS_PARALLELISM"] = "false"
     # Add Chainlit settings with which the user can configure the RAGLite config.
     default_config = RAGLiteConfig()
     config = RAGLiteConfig(
@@ -56,10 +59,10 @@ async def update_config(settings: cl.ChatSettings) -> None:
     # Run a search to prime the pipeline if it's a local pipeline.
     # TODO: Don't do this for SQLite once we switch from PyNNDescent to sqlite-vec.
     if str(config.db_url).startswith("sqlite") or config.embedder.startswith("llama-cpp-python"):
-        async with cl.Step(name="initialize", type="retrieval"):
-            query = "Hello world"
-            chunk_ids, _ = await async_hybrid_search(query=query, config=config)
-            _ = await async_rerank_chunks(query=query, chunk_ids=chunk_ids, config=config)
+        # async with cl.Step(name="initialize", type="retrieval"):
+        query = "Hello world"
+        chunk_ids, _ = await async_hybrid_search(query=query, config=config)
+        _ = await async_rerank_chunks(query=query, chunk_ids=chunk_ids, config=config)


 @cl.on_message
@@ -68,17 +71,27 @@ async def handle_message(user_message: cl.Message) -> None:
     # Get the config and message history from the user session.
     config: RAGLiteConfig = cl.user_session.get("config")  # type: ignore[no-untyped-call]
     # Insert any attached documents into the database.
+    inline_attachments = []
     for file in user_message.elements:
         if file.path:
-            async with cl.Step(name="insert", type="run") as step:
-                step.input = Path(file.path).name
-                await async_insert_document(Path(file.path), config=config)
+            doc_md = document_to_markdown(Path(file.path))
+            if len(doc_md) // 3 <= 5 * (config.chunk_max_size // 3):
+                # Document is small enough to attach to the context.
+                inline_attachments.append(f"{Path(file.path).name}:\n\n{doc_md}")
+            else:
+                # Document is too large and must be inserted into the database.
+                async with cl.Step(name="insert", type="run") as step:
+                    step.input = Path(file.path).name
+                    await async_insert_document(Path(file.path), config=config)
+    # Append any inline attachments to the user prompt.
+    user_prompt = f"{user_message.content}\n\n" + "\n\n".join(
+        f'<attachment index="{i}">\n{attachment.strip()}\n</attachment>'
+        for i, attachment in enumerate(inline_attachments)
+    )
     # Search for relevant contexts for RAG.
     async with cl.Step(name="search", type="retrieval") as step:
         step.input = user_message.content
-        chunk_ids, _ = await async_hybrid_search(
-            query=user_message.content, num_results=10, config=config
-        )
+        chunk_ids, _ = await async_hybrid_search(query=user_prompt, num_results=10, config=config)
         chunks = await async_retrieve_chunks(chunk_ids=chunk_ids, config=config)
         step.output = chunks
         step.elements = [  # Show the top 3 chunks inline.
@@ -87,17 +100,15 @@ async def handle_message(user_message: cl.Message) -> None:
     # Rerank the chunks.
     async with cl.Step(name="rerank", type="rerank") as step:
         step.input = chunks
-        chunks = await async_rerank_chunks(
-            query=user_message.content, chunk_ids=chunks, config=config
-        )
+        chunks = await async_rerank_chunks(query=user_prompt, chunk_ids=chunks, config=config)
         step.output = chunks
         step.elements = [  # Show the top 3 chunks inline.
             cl.Text(content=str(chunk), display="inline") for chunk in chunks[:3]
         ]
     # Stream the LLM response.
     assistant_message = cl.Message(content="")
     async for token in async_rag(
-        prompt=user_message.content,
+        prompt=user_prompt,
         search=chunks,
         messages=cl.chat_context.to_openai()[-5:],  # type: ignore[no-untyped-call]
         config=config,
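Note on the diff above: whether an attachment is inlined hinges on a chars-to-tokens heuristic, where `len(text) // 3` approximates the token count, so a document qualifies only if it fits within roughly five chunks' worth of tokens; qualifying documents are then wrapped in `<attachment>` tags and appended to the user prompt. A minimal sketch of that flow (the `chunk_max_size` value and the sample document are assumptions for illustration, not RAGLite defaults):

```python
# Minimal sketch of the inline-attachment flow above. The chunk_max_size value
# and the sample document are illustrative assumptions, not RAGLite defaults.
chunk_max_size = 1440  # Assumed max chunk size in characters.

def fits_inline(doc_md: str) -> bool:
    """Inline a document only if it fits in ~5 chunks, at ~3 chars per token."""
    return len(doc_md) // 3 <= 5 * (chunk_max_size // 3)

inline_attachments = []
for name, doc_md in [("note.md", "# Note\n\nShort enough to inline.")]:
    if fits_inline(doc_md):
        inline_attachments.append(f"{name}:\n\n{doc_md}")

# Small documents are spliced into the prompt as <attachment> blocks.
user_prompt = "Summarize my attachment.\n\n" + "\n\n".join(
    f'<attachment index="{i}">\n{attachment.strip()}\n</attachment>'
    for i, attachment in enumerate(inline_attachments)
)
print(user_prompt)
```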
17 changes: 13 additions & 4 deletions src/raglite/_markdown.py
@@ -203,10 +203,19 @@ def document_to_markdown(doc_path: Path) -> str:
         pages = dictionary_output(doc_path, sort=True, keep_chars=False)
         doc = "\n\n".join(parsed_pdf_to_markdown(pages))
     else:
-        # Use pandoc for everything else.
-        import pypandoc
-
-        doc = pypandoc.convert_file(doc_path, to="gfm")
+        try:
+            # Use pandoc for everything else.
+            import pypandoc
+
+            doc = pypandoc.convert_file(doc_path, to="gfm")
+        except ImportError as error:
+            error_message = (
+                "To convert files to Markdown with pandoc, please install the `pandoc` extra."
+            )
+            raise ImportError(error_message) from error
+        except RuntimeError:
+            # File format not supported, fall back to reading the text.
+            doc = doc_path.read_text()
     # Improve Markdown quality.
     doc = mdformat.text(doc)
     return doc
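With this change, `document_to_markdown` degrades more gracefully: a missing `pypandoc` now raises an `ImportError` with an install hint, and file formats pandoc rejects (pypandoc surfaces these as a `RuntimeError`) fall back to the file's raw text. A hypothetical usage sketch (the file path is illustrative):

```python
from pathlib import Path

from raglite._markdown import document_to_markdown

# A .docx converts via pandoc when the `pandoc` extra is installed; a format
# pandoc rejects falls back to the file's raw text instead of crashing.
markdown = document_to_markdown(Path("meeting_notes.docx"))  # Illustrative path.
print(markdown[:200])
```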
2 changes: 1 addition & 1 deletion src/raglite/_rag.py
Expand Up @@ -39,7 +39,7 @@ def _max_contexts(
     # Reduce the maximum number of contexts to take into account the LLM's context size.
     max_context_tokens = (
         max_tokens
-        - sum(len(m["content"]) // 3 for m in messages or [])  # Previous messages.
+        - sum(len(message["content"]) // 3 for message in messages or [])  # Previous messages.
         - len(RAG_SYSTEM_PROMPT) // 3  # System prompt.
         - len(prompt) // 3  # User prompt.
     )
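The `// 3` divisions here are the same roughly-three-characters-per-token heuristic used for attachments in `_chainlit.py`. A worked sketch of the budget arithmetic with assumed sizes (all numbers illustrative):

```python
# Worked example of the context budget above, with assumed (illustrative) sizes.
max_tokens = 8192  # Assumed LLM context window.
previous_messages = ["Hi!", "Hello, how can I help?"]  # Stand-ins for `messages`.
system_prompt_chars = 600  # Stand-in for len(RAG_SYSTEM_PROMPT).
user_prompt_chars = 900  # Stand-in for len(prompt).

max_context_tokens = (
    max_tokens
    - sum(len(m) // 3 for m in previous_messages)  # 3 // 3 + 22 // 3 = 1 + 7 = 8
    - system_prompt_chars // 3  # 200
    - user_prompt_chars // 3  # 300
)
print(max_context_tokens)  # 8192 - 8 - 200 - 300 = 7684
```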
