Skip to content

Commit

Permalink
fix: improve chunk and segment ordering (#29)
Browse files — browse the repository at this point in the history
  • Loading branch information
lsorber authored Oct 13, 2024
1 parent ed215cc commit c9c0c7d
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 4 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ PYTHON_VERSION=310
ACCELERATOR=metal|cu121|cu122|cu123|cu124
PLATFORM=macosx_11_0_arm64|linux_x86_64|win_amd64

# Install llama-python-cpp:
# Install llama-cpp-python:
pip install "https://github.com/abetlen/llama-cpp-python/releases/download/v$LLAMA_CPP_PYTHON_VERSION-$ACCELERATOR/llama_cpp_python-$LLAMA_CPP_PYTHON_VERSION-cp$PYTHON_VERSION-cp$PYTHON_VERSION-$PLATFORM.whl"
```

Expand Down
7 changes: 7 additions & 0 deletions src/raglite/_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ def retrieve_chunks(
engine = create_database_engine(config)
with Session(engine) as session:
chunks = list(session.exec(select(Chunk).where(col(Chunk.id).in_(chunk_ids))).all())
chunks = sorted(chunks, key=lambda chunk: chunk_ids.index(chunk.id))
return chunks


Expand Down Expand Up @@ -217,6 +218,12 @@ def retrieve_segments(
segments.append(segment)
segment = [chunk]
segments.append(segment)
# Rank segments according to the aggregate relevance of their chunks.
chunk_id_to_score = {chunk.id: 1 / (i + 1) for i, chunk in enumerate(chunks)}
segments.sort(
key=lambda segment: sum(chunk_id_to_score.get(chunk.id, 0.0) for chunk in segment),
reverse=True,
)
# Convert the segments into strings.
segments = [
segment[0].headings.strip() + "\n\n" + "".join(chunk.body for chunk in segment).strip() # type: ignore[misc]
Expand Down
2 changes: 1 addition & 1 deletion src/raglite/_split_sentences.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def split_sentences(doc: str, max_len: int | None = None) -> list[str]:
try:
nlp = spacy.load("xx_sent_ud_sm")
except OSError as error:
error_message = "Please install `xx_sent_ud_sm` with `pip install https://github.com/explosion/spacy-models/releases/download/xx_sent_ud_sm-3.8.0/xx_sent_ud_sm-3.8.0-py3-none-any.whl`."
error_message = "Please install `xx_sent_ud_sm` with `pip install https://github.com/explosion/spacy-models/releases/download/xx_sent_ud_sm-3.7.0/xx_sent_ud_sm-3.7.0-py3-none-any.whl`."
raise ImportError(error_message) from error
nlp.add_pipe("_mark_additional_sentence_boundaries", before="senter")
sentences = [sent.text_with_ws for sent in nlp(doc).sents if sent.text.strip()]
Expand Down
5 changes: 3 additions & 2 deletions tests/test_rerank.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,12 @@ def test_reranker(
# Retrieve the chunks.
chunks = retrieve_chunks(chunk_ids, config=raglite_test_config)
assert all(isinstance(chunk, Chunk) for chunk in chunks)
assert all(chunk_id == chunk.id for chunk_id, chunk in zip(chunk_ids, chunks, strict=True))
# Rerank the chunks given an inverted chunk order.
reranked_chunks = rerank(query, chunks[::-1], config=raglite_test_config)
if reranker is not None and "text-embedding-3-small" not in raglite_test_config.embedder:
assert reranked_chunks[:3] == chunks[:3]
assert reranked_chunks[0] == chunks[0]
# Test that we can also rerank given the chunk_ids only.
reranked_chunks = rerank(query, chunk_ids[::-1], config=raglite_test_config)
if reranker is not None and "text-embedding-3-small" not in raglite_test_config.embedder:
assert reranked_chunks[:3] == chunks[:3]
assert reranked_chunks[0] == chunks[0]
1 change: 1 addition & 0 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def test_search(raglite_test_config: RAGLiteConfig, search_method: SearchMethod)
# Retrieve the chunks.
chunks = retrieve_chunks(chunk_ids, config=raglite_test_config)
assert all(isinstance(chunk, Chunk) for chunk in chunks)
assert all(chunk_id == chunk.id for chunk_id, chunk in zip(chunk_ids, chunks, strict=True))
assert any("Definition of Simultaneity" in str(chunk) for chunk in chunks)
# Extend the chunks with their neighbours and group them into contiguous segments.
segments = retrieve_segments(chunk_ids, neighbors=(-1, 1), config=raglite_test_config)
Expand Down

0 comments on commit c9c0c7d

Please sign in to comment.