Add support for Web URLs in RAG Pipeline (#38)
shreyash2106 authored Jul 30, 2024
1 parent 674628d commit 383bda1
Showing 12 changed files with 455 additions and 26 deletions.
22 changes: 21 additions & 1 deletion README.md
@@ -48,7 +48,13 @@ from agrag.agrag import AutoGluonRAG
def ag_rag():
agrag = AutoGluonRAG(config_file="path/to/config")
agrag = AutoGluonRAG(
preset_quality="medium_quality", # or path to config file
web_urls=["https://auto.gluon.ai/stable/index.html"],
base_urls=["https://auto.gluon.ai/stable/"],
parse_urls_recursive=True,
data_dir="s3://autogluon-rag-github-dev/autogluon_docs/"
)
agrag.initialize_rag_pipeline()
agrag.generate_response("What is AutoGluon?")
@@ -69,11 +75,25 @@ model_ids : dict
Example: {"generator_model_id": "mistral.mistral-7b-instruct-v0:2", "retriever_model_id": "BAAI/bge-large-en", "reranker_model_id": "nv_embed"}
data_dir : str
The directory containing the data files that will be used for the RAG pipeline. If this value is not provided when initializing the object, it must be provided in the config file. If both are provided, the value in the class instantiation will be prioritized.
web_urls : List[str]
List of website URLs to be ingested and processed. Each URL will be processed recursively based on its base URL, so that the content of URLs nested within it is also included.
If this value is not provided when initializing the object, it must be provided in the config file. If both are provided, the value in the class instantiation will be prioritized.
base_urls : List[str]
List of optional base URLs to check for links recursively. The base URL controls which URLs will be processed during recursion. The base_url does not need to be the same as the web_url. For example, the web_url can be "https://auto.gluon.ai/stable/index.html" and the corresponding base_url "https://auto.gluon.ai/stable/".
If this value is not provided when initializing the object, it must be provided in the config file. If both are provided, the value in the class instantiation will be prioritized.
login_info: dict
A dictionary containing login credentials for each URL. Required if the target URL requires authentication.
Must be structured as {target_url: {"login_url": <login_url>, "credentials": {"username": "your_username", "password": "your_password"}}}
The target_url must be a URL that is present in the list of web_urls (see the example after this parameter list).
parse_urls_recursive: bool
Whether to parse each provided URL recursively. Setting this to True means that the child links present in each parent webpage will also be processed.
pipeline_batch_size: int
Batch size to use for the pre-processing stages (Data Processing, Embedding, and Vector DB modules). This represents the number of files in each batch.
The default value is 20.
```

**Note**: You may provide both `data_dir` and `web_urls`.
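
For instance, here is a minimal sketch of passing both sources together with `login_info` for a password-protected page. The example.com URLs and the credentials are placeholders, not real endpoints; the remaining arguments mirror the documented constructor parameters above.

```python
from agrag.agrag import AutoGluonRAG

# Placeholder URL that requires authentication (illustration only).
protected_url = "https://docs.example.com/private/index.html"

agrag = AutoGluonRAG(
    preset_quality="medium_quality",
    data_dir="s3://autogluon-rag-github-dev/autogluon_docs/",  # local directory or S3 URI
    web_urls=[protected_url],
    base_urls=["https://docs.example.com/private/"],
    parse_urls_recursive=True,
    login_info={
        protected_url: {
            "login_url": "https://docs.example.com/login",
            "credentials": {"username": "your_username", "password": "your_password"},
        }
    },
)
agrag.initialize_rag_pipeline()
agrag.generate_response("What is AutoGluon?")
```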

The configuration file contains the specific parameters to use for each module in the RAG pipeline. For an example of a config file, please refer to `example_config.yaml` in `src/agrag/configs/`. For specific details about the parameters in each individual module, refer to the `README` files in each module in `src/agrag/modules/`.

There is also a `shared` section in the config file for parameters that do not refer to a specific module. Currently, the parameters in `shared` are:
1 change: 1 addition & 0 deletions pyproject.toml
@@ -36,6 +36,7 @@ classifiers = [
]
license = {file = "LICENSE"}
dependencies = [
"beautifulsoup4>=4.12.0,<5.0",
"boto3>=1.34.124,<2.0",
"datasets>=2.20.0,<3.0",
"evaluate>=0.4.2,<1.0",
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,3 +1,4 @@
beautifulsoup4>=4.12.0,<5.0
boto3>=1.34.124,<2.0
datasets>=2.20.0,<3.0
evaluate>=0.4.2,<1.0
65 changes: 60 additions & 5 deletions src/agrag/agrag.py
@@ -4,6 +4,8 @@

import pandas as pd
import yaml
from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader
from langchain_core.utils.html import extract_sub_links

from agrag.args import Arguments
from agrag.modules.data_processing.data_processing import DataProcessingModule
@@ -34,6 +36,10 @@ def __init__(
preset_quality: Optional[str] = "medium_quality",
model_ids: Dict = None,
data_dir: str = "",
web_urls: List = [],
base_urls: List = [],
login_info: dict = {},
parse_urls_recursive: bool = True,
pipeline_batch_size: int = 0,
):
"""
@@ -50,6 +56,17 @@
Example: {"generator_model_id": "mistral.mistral-7b-instruct-v0:2", "retriever_model_id": "BAAI/bge-large-en", "reranker_model_id": "nv_embed"}
data_dir : str
The directory containing the data files that will be used for the RAG pipeline
web_urls : List[str]
List of website URLs to be ingested and processed.
base_urls : List[str]
List of optional base URLs to check for links recursively. The base URL controls which URLs will be processed during recursion.
The base_url does not need to be the same as the web_url. For example, the web_url can be "https://auto.gluon.ai/stable/index.html" and the corresponding base_url "https://auto.gluon.ai/stable/".
login_info: dict
A dictionary containing login credentials for each URL. Required if the target URL requires authentication.
Must be structured as {target_url: {"login_url": <login_url>, "credentials": {"username": "your_username", "password": "your_password"}}}
The target_url must be a URL that is present in the list of web_urls.
parse_urls_recursive: bool
Whether to parse each provided URL recursively. Setting this to True means that the child links present in each parent webpage will also be processed.
pipeline_batch_size: int
Optional batch size to use for the pre-processing stages (Data Processing, Embedding, and Vector DB modules)
@@ -114,11 +131,15 @@ def __init__(

self.args = Arguments(self.config)

# will short-circuit to provided data_dir if config data_dir also provided
# will short-circuit to provided data_dir if config value also provided
self.data_dir = data_dir or self.args.data_dir
self.web_urls = web_urls or self.args.web_urls
self.base_urls = base_urls or self.args.base_urls
self.parse_urls_recursive = parse_urls_recursive or self.args.parse_urls_recursive
self.login_info = login_info or self.args.login_info

if not self.data_dir:
raise ValueError("data_dir argument must be provided")
if not self.data_dir and not self.web_urls:
raise ValueError("Either data_dir or web_urls argument must be provided")

self.data_processing_module = None
self.embedding_module = None
@@ -153,9 +174,12 @@ def initialize_data_module(self):
"""Initializes the Data Processing module."""
self.data_processing_module = DataProcessingModule(
data_dir=self.data_dir,
web_urls=self.args.web_urls,
chunk_size=self.args.chunk_size,
chunk_overlap=self.args.chunk_overlap,
file_exts=self.args.data_file_extns,
html_tags_to_extract=self.args.html_tags_to_extract,
login_info=self.login_info,
)
logger.info("Data Processing module initialized")

@@ -459,16 +483,47 @@ def batched_processing(self):
"""

logger.info(f"Retrieving and Processing Data from {self.data_processing_module.data_dir}")
logger.info(f"Processing Data from Data Directory: {self.data_processing_module.data_dir}")
file_paths = get_all_file_paths(self.data_processing_module.data_dir, self.data_processing_module.file_exts)

logger.info(f"Processing the Web URLs: {self.data_processing_module.web_urls}")
web_urls = []
if self.parse_urls_recursive:
for idx, url in enumerate(self.web_urls):
loader = RecursiveUrlLoader(url=url, max_depth=1)
docs = loader.load()
urls = extract_sub_links(
raw_html=docs[0].page_content,
url=url,
base_url=self.base_urls[idx],
continue_on_failure=True,
)
urls = [url] + urls
logger.info(
f"\nFound {len(urls)} URLs by recursively parsing the webpage {url} with base URL {self.base_urls[idx]}."
)
web_urls.extend(urls)
if url in self.login_info:
for sub_url in urls:
self.login_info[sub_url] = self.login_info[url]

batch_num = 1
for i in range(0, len(file_paths), self.batch_size):
for i in range(0, max(len(file_paths), len(web_urls)), self.batch_size):
logger.info(f"Batch {batch_num}")

batch_file_paths = file_paths[i : i + self.batch_size]
batch_urls = web_urls[i : i + self.batch_size]

# Data Processing
processed_data = self.data_processing_module.process_files(batch_file_paths)
processed_files_data, last_doc_id = self.data_processing_module.process_files(batch_file_paths, start_doc_id=0)
processed_urls_data = self.data_processing_module.process_urls(batch_urls, start_doc_id=last_doc_id)
processed_data = pd.concat([processed_files_data, processed_urls_data]).reset_index(drop=True)

# Embedding
embeddings = self.generate_embeddings(processed_data)

# Vector DB
self.construct_vector_db(embeddings)

# Clear memory
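The recursive link collection shown above can be exercised on its own. The following is a minimal sketch using the same `RecursiveUrlLoader` and `extract_sub_links` calls as `batched_processing`; it assumes network access to the AutoGluon documentation site.

```python
from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader
from langchain_core.utils.html import extract_sub_links

url = "https://auto.gluon.ai/stable/index.html"
base_url = "https://auto.gluon.ai/stable/"

# Load only the top-level page (max_depth=1), then collect the child links
# on that page that fall under base_url.
docs = RecursiveUrlLoader(url=url, max_depth=1).load()
sub_links = extract_sub_links(
    raw_html=docs[0].page_content,
    url=url,
    base_url=base_url,
    continue_on_failure=True,
)
print(f"Found {len(sub_links)} sub-links under {base_url}")
```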
40 changes: 40 additions & 0 deletions src/agrag/args.py
@@ -106,6 +106,46 @@ def data_dir(self):
def data_dir(self, value):
self.config["data"]["data_dir"] = value

@property
def web_urls(self):
return self.config.get("data", {}).get("web_urls", [])

@web_urls.setter
def web_urls(self, value):
self.config["data"]["web_urls"] = value

@property
def base_urls(self):
return self.config.get("data", {}).get("base_urls", [])

@base_urls.setter
def base_urls(self, value):
self.config["data"]["base_urls"] = value

@property
def html_tags_to_extract(self):
return self.config.get("data", {}).get("html_tags_to_extract")

@html_tags_to_extract.setter
def html_tags_to_extract(self, value):
self.config["data"]["html_tags_to_extract"] = value

@property
def login_info(self):
return self.config.get("data", {}).get("login_info", {})

@login_info.setter
def login_info(self, value):
self.config["data"]["login_info"] = value

@property
def parse_urls_recursive(self):
return self.config.get("data", {}).get("parse_urls_recursive", self.data_defaults.get("PARSE_URLS_RECURSIVE"))

@parse_urls_recursive.setter
def parse_urls_recursive(self, value):
self.config["data"]["parse_urls_recursive"] = value

@property
def chunk_size(self):
return self.config.get("data", {}).get("chunk_size", self.data_defaults.get("CHUNK_SIZE"))
1 change: 1 addition & 0 deletions src/agrag/configs/data_processing/default.yaml
@@ -1,2 +1,3 @@
CHUNK_SIZE: 512
CHUNK_OVERLAP: 64
PARSE_URLS_RECURSIVE: true
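
As a rough illustration (assuming the config has already been parsed into a plain dict), the new `parse_urls_recursive` property in `args.py` falls back to this `PARSE_URLS_RECURSIVE` default when the user config omits it; the dicts below are stand-ins for what is loaded from YAML.

```python
# Stand-ins for the parsed user config and the data-processing defaults above.
data_defaults = {"CHUNK_SIZE": 512, "CHUNK_OVERLAP": 64, "PARSE_URLS_RECURSIVE": True}
user_config = {"data": {"web_urls": ["https://auto.gluon.ai/stable/index.html"]}}

# Same lookup expression as the parse_urls_recursive property in args.py.
parse_urls_recursive = user_config.get("data", {}).get(
    "parse_urls_recursive", data_defaults.get("PARSE_URLS_RECURSIVE")
)
print(parse_urls_recursive)  # True, taken from the default
```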
1 change: 1 addition & 0 deletions src/agrag/constants.py
@@ -4,5 +4,6 @@
EMBEDDING_KEY = "embedding"
EMBEDDING_HIDDEN_DIM_KEY = "embedding_hidden_dim"
SUPPORTED_FILE_EXTENSIONS = [".pdf", ".txt", ".docx", ".doc", ".rtf", ".csv", ".md", ".py", ".log"]
SUPPORTED_HTML_TAGS = ["p", "table"]
EVALUATION_DIR = "./evaluation_data"
EVALUATION_MAX_FILE_SIZE = 5 * 1000 * 1000 # 5 MB
9 changes: 8 additions & 1 deletion src/agrag/main.py
@@ -2,7 +2,14 @@


def ag_rag():
agrag = AutoGluonRAG(preset_quality="medium_quality", data_dir="s3://autogluon-rag-github-dev/autogluon_docs/")
agrag = AutoGluonRAG(
preset_quality="medium_quality",
web_urls=["https://auto.gluon.ai/stable/index.html"],
base_urls=["https://auto.gluon.ai/stable/"],
parse_urls_recursive=True,
data_dir="s3://autogluon-rag-github-dev/autogluon_docs/",
)

agrag.initialize_rag_pipeline()
while True:
query_text = input(
13 changes: 12 additions & 1 deletion src/agrag/modules/data_processing/README.md
@@ -8,10 +8,21 @@ Here are the configurable parameters for this module:
data:
data_dir : The directory containing the data files to be ingested. This can be either a local directory or an S3 URI to a directory in an S3 bucket.
web_urls: List of website URLs to be ingested and processed.
base_urls: List of base URLs to check for links recursively. The base URL controls which URLs will be processed during recursion. The base_url does not need to be the same as the web_url. For example, the web_url can be "https://auto.gluon.ai/stable/index.html" and the corresponding base_url "https://auto.gluon.ai/stable/".
parse_urls_recursive: Whether to parse each provided URL recursively. Setting this to True means that the child links present in each parent webpage will also be processed.
chunk_size : The size of each chunk of text (default is 512).
chunk_overlap : The overlap between consecutive chunks of text (default is 128).
file_exts: List of file extensions to read. Only the following file extensions are supported: ".pdf", ".txt", ".docx", ".doc", ".rtf", ".csv", ".md", ".py", ".log"
file_exts: List of file extensions to support. Default is [".pdf", ".txt", ".docx", ".doc", ".rtf", ".csv", ".md", ".py", ".log"]
html_tags_to_extract: List of HTML tags to extract text from. Default is ["p", "table"]. Currently supported tags are ["p", "table", "li", "div", "span", "<h_tags>"].
login_info: A dictionary containing login credentials for each URL. Required if the target URL requires authentication.
Must be structured as {target_url: {"login_url": <login_url>, "credentials": {"username": "your_username", "password": "your_password"}}}
```
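
A hedged sketch of wiring these options into the module directly, mirroring the `DataProcessingModule` constructor call in `agrag.py` above; the values are illustrative, and `process_urls` is assumed to take the URL batch plus a starting document id, as in `batched_processing`.

```python
from agrag.modules.data_processing.data_processing import DataProcessingModule

module = DataProcessingModule(
    data_dir="s3://autogluon-rag-github-dev/autogluon_docs/",
    web_urls=["https://auto.gluon.ai/stable/index.html"],
    chunk_size=512,
    chunk_overlap=64,
    file_exts=[".pdf", ".txt", ".md"],
    html_tags_to_extract=["p", "table"],
    login_info={},
)

# Process a batch of URLs, starting document ids from 0.
processed_urls_data = module.process_urls(
    ["https://auto.gluon.ai/stable/index.html"], start_doc_id=0
)
```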