diff --git a/README.md b/README.md
index d500cf8..9ce4633 100644
--- a/README.md
+++ b/README.md
@@ -8,6 +8,11 @@
 Static Badge
+## ✨ Latest News
+
+- [11/06/2024]: Our paper is available on arXiv. You can access it [here](https://arxiv.org/abs/2411.02959).
+- [11/05/2024]: The open-source toolkit and models are released. You can apply HtmlRAG in your own RAG systems now.
+
 We propose HtmlRAG, which uses HTML instead of plain text as the format of external knowledge in RAG systems. To tackle the long context brought by HTML, we propose **Lossless HTML Cleaning** and **Two-Step Block-Tree-Based HTML Pruning**.
 - **Lossless HTML Cleaning**: This cleaning process removes only completely irrelevant content and compresses redundant structures, retaining all semantic information in the original HTML. The compressed HTML produced by lossless cleaning is suitable for RAG systems that use long-context LLMs and are not willing to lose any information before generation.
@@ -24,6 +29,11 @@ We provide a simple toolkit to apply HtmlRAG in your own RAG systems.
 ### 📦 Installation
 
+Install the package using pip:
+```bash
+pip install htmlrag
+```
+Or install the package from source:
 ```bash
 cd toolkit/
 pip install -e .
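For orientation, here is a minimal sketch of the quick-start flow the new install section leads into, assuming the lowercase package name `htmlrag` used throughout this diff; `clean_html` is the entry point imported in `jupyter/module_test.ipynb` below, and the example HTML mirrors that notebook:

```python
# A minimal HtmlRAG quick-start sketch (names taken from this diff, not a full pipeline).
from htmlrag import clean_html

question = "When was the bellagio in las vegas built?"
html = """
<html>
<p>The Bellagio is a luxury hotel and casino located on the Las Vegas Strip in Paradise, Nevada. It was built in 1998.</p>
<div>
<p>Some other text</p>
<p>Some other text</p>
</div>
</html>
"""

# Lossless cleaning: drops irrelevant content and compresses redundant
# structure while keeping all semantic information.
simplified_html = clean_html(html)
print(simplified_html)
```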
diff --git a/jupyter/module_test.ipynb b/jupyter/module_test.ipynb
index ae89076..68e1ce7 100644
--- a/jupyter/module_test.ipynb
+++ b/jupyter/module_test.ipynb
@@ -2,39 +2,14 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 1,
    "id": "initial_id",
    "metadata": {
+    "collapsed": true,
     "ExecuteTime": {
-     "end_time": "2024-10-20T05:33:09.950503Z",
-     "start_time": "2024-10-20T05:32:33.270934Z"
-    },
-    "collapsed": true
-   },
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "/data_train/search/InternData/jiejuntan/anaconda3/envs/py39/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
-      "  from .autonotebook import tqdm as notebook_tqdm\n"
-     ]
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "\n",
-      "When was the bellagio in las vegas built?\n",
-      "<html>\n",
-      "<p>The Bellagio is a luxury hotel and casino located on the Las Vegas Strip in Paradise, Nevada. It was built in 1998.</p>\n",
-      "<div>\n",
-      "<p>Some other text</p>\n",
-      "<p>Some other text</p>\n",
-      "</div>\n",
-      "</html>\n",
-      "\n"
-     ]
+     "end_time": "2024-11-06T06:49:10.896569Z",
+     "start_time": "2024-11-06T06:49:10.685326Z"
     }
-   ],
+   },
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
@@ -76,7 +51,21 @@
    "# <p>Some other text</p>\n",
    "# </div>\n",
    "# </html>"
-   ]
+   ],
+   "outputs": [
+    {
+     "ename": "ModuleNotFoundError",
+     "evalue": "No module named 'HtmlRAG'",
+     "output_type": "error",
+     "traceback": [
+      "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
+      "\u001B[0;31mModuleNotFoundError\u001B[0m                       Traceback (most recent call last)",
+      "Cell \u001B[0;32mIn[1], line 3\u001B[0m\n\u001B[1;32m 1\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01msys\u001B[39;00m\n\u001B[1;32m 2\u001B[0m sys\u001B[38;5;241m.\u001B[39mpath\u001B[38;5;241m.\u001B[39mappend(\u001B[38;5;124m'\u001B[39m\u001B[38;5;124m..\u001B[39m\u001B[38;5;124m'\u001B[39m)\n\u001B[0;32m----> 3\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mHtmlRAG\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m clean_html\n\u001B[1;32m 5\u001B[0m question\u001B[38;5;241m=\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mWhen was the bellagio in las vegas built?\u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[1;32m 6\u001B[0m html\u001B[38;5;241m=\u001B[39m\u001B[38;5;124m\"\"\"\u001B[39m\n\u001B[1;32m 7\u001B[0m \u001B[38;5;124m\u001B[39m\n\u001B[1;32m 8\u001B[0m \u001B[38;5;124m\u001B[39m\n\u001B[0;32m (...)\u001B[0m\n\u001B[1;32m 25\u001B[0m \u001B[38;5;124m\u001B[39m\n\u001B[1;32m 26\u001B[0m \u001B[38;5;124m\"\"\"\u001B[39m\n",
+      "\u001B[0;31mModuleNotFoundError\u001B[0m: No module named 'HtmlRAG'"
+     ]
+    }
+   ],
+   "execution_count": 1
  },
  {
   "cell_type": "code",
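The error output committed above stems from importing the pre-rename package name; this diff renames `toolkit/HtmlRAG` to `toolkit/htmlrag`. Assuming the notebook runs from `jupyter/`, the first cell would presumably need something like:

```python
# Corrected import for the renamed lowercase package
# (was: sys.path.append('..'); from HtmlRAG import clean_html).
import sys
sys.path.append('../toolkit')  # assumption: the package now lives at toolkit/htmlrag
from htmlrag import clean_html
```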
diff --git a/toolkit/README.md b/toolkit/README.md
index cec8b85..618a317 100644
--- a/toolkit/README.md
+++ b/toolkit/README.md
@@ -1,11 +1,22 @@
 # 🤖🔍 HtmlRAG
 
-[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
+<div>
+License
+Static Badge
+</div>
 
 A toolkit to apply HtmlRAG in your own RAG systems.
 
 ## 📦 Installation
 
+Install the package using pip:
+```bash
+pip install htmlrag
+```
+Or install the package from source:
 ```bash
 pip install -e .
 ```
@@ -89,8 +100,21 @@ for block in block_tree:
 ```python
 from htmlrag import EmbedHTMLPruner
 
-embed_html_pruner = EmbedHTMLPruner(embed_model="bm25")
-block_rankings = embed_html_pruner.calculate_block_rankings(question, simplified_html, block_tree)
+embed_model = "/train_data_load/huggingface/tjj_hf/bge-large-en/"
+query_instruction_for_retrieval = "Instruct: Given a web search query, retrieve relevant passages that answer the query\nQuery: "
+embed_html_pruner = EmbedHTMLPruner(embed_model=embed_model, local_inference=True, query_instruction_for_retrieval=query_instruction_for_retrieval)
+# Alternatively, you can use a remote TEI model; refer to https://github.com/huggingface/text-embeddings-inference.
+# tei_endpoint = "http://YOUR_TEI_ENDPOINT"
+# embed_html_pruner = EmbedHTMLPruner(embed_model=embed_model, local_inference=False, query_instruction_for_retrieval=query_instruction_for_retrieval, endpoint=tei_endpoint)
+block_rankings = embed_html_pruner.calculate_block_rankings(question, simplified_html, block_tree)
+print(block_rankings)
+
+# [0, 2, 1]
+
+# Alternatively, you can use BM25 to rank the blocks
+from htmlrag import BM25HTMLPruner
+bm25_html_pruner = BM25HTMLPruner()
+block_rankings = bm25_html_pruner.calculate_block_rankings(question, simplified_html, block_tree)
 print(block_rankings)
 
 # [0, 2, 1]
@@ -100,8 +124,7 @@ from transformers import AutoTokenizer
 
 chat_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-70B-Instruct")
 max_context_window = 60
-pruned_html = embed_html_pruner.prune_HTML(simplified_html, block_tree, block_rankings, chat_tokenizer,
-                                           max_context_window)
+pruned_html = embed_html_pruner.prune_HTML(simplified_html, block_tree, block_rankings, chat_tokenizer, max_context_window)
 print(pruned_html)
 
 # <html>
diff --git a/toolkit/HtmlRAG/__init__.py b/toolkit/htmlrag/__init__.py
similarity index 100%
rename from toolkit/HtmlRAG/__init__.py
rename to toolkit/htmlrag/__init__.py
diff --git a/toolkit/HtmlRAG/html_utils.py b/toolkit/htmlrag/html_utils.py
similarity index 100%
rename from toolkit/HtmlRAG/html_utils.py
rename to toolkit/htmlrag/html_utils.py
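A note on the BM25 path above: the README only calls `prune_HTML` on the embedding pruner, but `prune_HTML` is defined on the shared `Pruner` base class in `toolkit/htmlrag/pruner.py` below, so the BM25 pruner can presumably reuse it unchanged. A sketch under that assumption, with `question`, `simplified_html`, and `block_tree` as in the README example:

```python
# Sketch: pruning with BM25 rankings via the inherited prune_HTML
# (assumes question, simplified_html, block_tree from the README example above).
from htmlrag import BM25HTMLPruner
from transformers import AutoTokenizer

bm25_html_pruner = BM25HTMLPruner()
block_rankings = bm25_html_pruner.calculate_block_rankings(question, simplified_html, block_tree)

chat_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-70B-Instruct")
max_context_window = 60
# prune_HTML is inherited from the Pruner base class (see pruner.py below)
pruned_html = bm25_html_pruner.prune_HTML(simplified_html, block_tree, block_rankings, chat_tokenizer, max_context_window)
print(pruned_html)
```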
"intfloat/e5-mistral-7b-instruct": + self.query_instruction_for_retrieval = "Instruct: Given a web search query, retrieve relevant passages that answer the query\nQuery: " + if query_instruction_for_retrieval: + self.query_instruction_for_retrieval = query_instruction_for_retrieval + + if local_inference: + from langchain_huggingface import HuggingFaceEmbeddings + self.embedder = HuggingFaceEmbeddings(model_name=embed_model) else: - if embed_model == "bgelargeen": - self.query_instruction_for_retrieval = "Represent this sentence for searching relevant passages: " - elif embed_model == "e5-mistral": - self.query_instruction_for_retrieval = "Instruct: Given a web search query, retrieve relevant passages that answer the query\nQuery: " from langchain_huggingface import HuggingFaceEndpointEmbeddings - embedder = HuggingFaceEndpointEmbeddings( - model=url, + model=endpoint, huggingfacehub_api_token="a-default-token", model_kwargs={"truncate": True}) self.embedder = embedder @@ -101,14 +103,29 @@ def calculate_block_rankings(self, question: str, html: str, block_tree: List[Tu node_docs.append(Document(page_content=path_tags[pidx].get_text(), metadata={"path_idx": pidx})) batch_size = 256 - if self.embed_model == "bm25": - retriever=BM25Retriever.from_documents(node_docs) - else: - db = FAISS.from_documents(node_docs[:batch_size], self.embedder) - if len(node_docs) > batch_size: - for doc_batch_idx in range(batch_size, len(node_docs), batch_size): - db.add_documents(node_docs[doc_batch_idx:doc_batch_idx + batch_size]) - retriever = db.as_retriever(search_kwargs={"k": len(node_docs)}) + from langchain_community.vectorstores import FAISS + db = FAISS.from_documents(node_docs[:batch_size], self.embedder) + if len(node_docs) > batch_size: + for doc_batch_idx in range(batch_size, len(node_docs), batch_size): + db.add_documents(node_docs[doc_batch_idx:doc_batch_idx + batch_size]) + retriever = db.as_retriever(search_kwargs={"k": len(node_docs)}) + ranked_docs = retriever.invoke(question) + block_rankings = [doc.metadata["path_idx"] for doc in ranked_docs] + + return block_rankings + + +class BM25HTMLPruner(Pruner): + def calculate_block_rankings(self, question: str, html: str, block_tree: List[Tuple]): + path_tags = [b[0] for b in block_tree] + paths = [b[1] for b in block_tree] + + node_docs = [] + for pidx in range(len(paths)): + node_docs.append(Document(page_content=path_tags[pidx].get_text(), metadata={"path_idx": pidx})) + from langchain_community.retrievers import BM25Retriever + retriever = BM25Retriever.from_documents(node_docs) + retriever.from_documents(node_docs) ranked_docs = retriever.invoke(question) block_rankings = [doc.metadata["path_idx"] for doc in ranked_docs] diff --git a/toolkit/pyproject.toml b/toolkit/pyproject.toml index 63cd1de..32a6383 100644 --- a/toolkit/pyproject.toml +++ b/toolkit/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "htmlrag" -version = "0.0.1" +version = "0.0.3" authors = [ { name="Example Author", email="author@example.com" }, ] @@ -23,6 +23,7 @@ dependencies = [ "pydantic-settings", "bs4", "transformers", + "sentence-transformers", "torch", "numpy", "langchain", @@ -30,15 +31,16 @@ dependencies = [ "langchain-huggingface", "anytree", "langchain-community >= 0.3.0", + "faiss-cpu", ] [project.urls] -Homepage = "https://github.com/pypa/sampleproject" -Issues = "https://github.com/pypa/sampleproject/issues" +Homepage = "https://github.com/plageon/HtmlRAG" +Issues = "https://github.com/plageon/HtmlRAG/issues" 
diff --git a/toolkit/pyproject.toml b/toolkit/pyproject.toml
index 63cd1de..32a6383 100644
--- a/toolkit/pyproject.toml
+++ b/toolkit/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "htmlrag"
-version = "0.0.1"
+version = "0.0.3"
 authors = [
     { name="Example Author", email="author@example.com" },
 ]
@@ -23,6 +23,7 @@ dependencies = [
     "pydantic-settings",
     "bs4",
     "transformers",
+    "sentence-transformers",
     "torch",
     "numpy",
     "langchain",
@@ -30,15 +31,16 @@ dependencies = [
     "langchain-huggingface",
     "anytree",
     "langchain-community >= 0.3.0",
+    "faiss-cpu",
 ]
 
 [project.urls]
-Homepage = "https://github.com/pypa/sampleproject"
-Issues = "https://github.com/pypa/sampleproject/issues"
+Homepage = "https://github.com/plageon/HtmlRAG"
+Issues = "https://github.com/plageon/HtmlRAG/issues"
 
 [tool.setuptools.packages.find]
 # All the following settings are optional:
 where = ["."]
-include = ["HtmlRAG"]
+include = ["htmlrag"]
 exclude = ["tests"]
 namespaces = true
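Finally, with the package renamed to lowercase and bumped to 0.0.3, a quick post-install sanity check might look like this; all three names are exported according to the toolkit README and notebook in this diff:

```python
# Verify that `pip install htmlrag` resolves the renamed package and its entry points.
from htmlrag import clean_html, EmbedHTMLPruner, BM25HTMLPruner

print(clean_html("<html><body><p>It works.</p></body></html>"))
```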