From 7d1e7072d395cfc4f80d18e716bedeb558643e86 Mon Sep 17 00:00:00 2001
From: Shreyash Sridhar Iyengar <54525027+shreyash2106@users.noreply.github.com>
Date: Wed, 31 Jul 2024 15:05:18 -0700
Subject: [PATCH] Fix bug and Improve logging in Data Processing module for
 websites (#42)

---
 src/agrag/agrag.py                              | 14 ++++++++------
 src/agrag/args.py                               |  4 ++--
 src/agrag/configs/data_processing/default.yaml  |  2 ++
 .../modules/data_processing/data_processing.py  | 16 +++++++++++++---
 src/agrag/modules/data_processing/utils.py      |  2 --
 5 files changed, 25 insertions(+), 13 deletions(-)

diff --git a/src/agrag/agrag.py b/src/agrag/agrag.py
index c49c7da3..fdaed792 100644
--- a/src/agrag/agrag.py
+++ b/src/agrag/agrag.py
@@ -174,7 +174,7 @@ def initialize_data_module(self):
         """Initializes the Data Processing module."""
         self.data_processing_module = DataProcessingModule(
             data_dir=self.data_dir,
-            web_urls=self.args.web_urls,
+            web_urls=self.web_urls,
             chunk_size=self.args.chunk_size,
             chunk_overlap=self.args.chunk_overlap,
             file_exts=self.args.data_file_extns,
@@ -486,10 +486,8 @@ def batched_processing(self):
 
         """
 
-        logger.info(f"Processing Data from Data Directory: {self.data_processing_module.data_dir}")
         file_paths = get_all_file_paths(self.data_processing_module.data_dir, self.data_processing_module.file_exts)
 
-        logger.info(f"Processing the Web URLs: {self.data_processing_module.web_urls}")
         web_urls = []
         if self.parse_urls_recursive:
             for idx, url in enumerate(self.web_urls):
@@ -511,6 +509,7 @@ def batched_processing(self):
                         self.login_info[sub_url] = self.login_info[url]
 
         batch_num = 1
+
         for i in range(0, max(len(file_paths), len(web_urls)), self.batch_size):
             logger.info(f"Batch {batch_num}")
 
@@ -518,9 +517,12 @@ def batched_processing(self):
             batch_file_paths = file_paths[i : i + self.batch_size]
             batch_urls = web_urls[i : i + self.batch_size]
             # Data Processing
-            processed_data = self.data_processing_module.process_files(batch_file_paths)
-            processed_files_data, last_doc_id = self.data_processing_module.process_files(file_paths, start_doc_id=0)
-            processed_urls_data = self.data_processing_module.process_urls(batch_urls, start_doc_id=last_doc_id)
+            processed_files_data, last_doc_id = self.data_processing_module.process_files(
+                batch_file_paths, start_doc_id=0
+            )
+            processed_urls_data = self.data_processing_module.process_urls(
+                batch_urls, login_info=self.login_info, start_doc_id=last_doc_id
+            )
             processed_data = pd.concat([processed_files_data, processed_urls_data]).reset_index(drop=True)
 
             # Embedding
diff --git a/src/agrag/args.py b/src/agrag/args.py
index 426cf85d..d992e597 100644
--- a/src/agrag/args.py
+++ b/src/agrag/args.py
@@ -124,7 +124,7 @@ def base_urls(self, value):
 
     @property
     def html_tags_to_extract(self):
-        return self.config.get("data", {}).get("html_tags_to_extract")
+        return self.config.get("data", {}).get("html_tags_to_extract", self.data_defaults.get("SUPPORTED_HTML_TAGS"))
 
     @html_tags_to_extract.setter
     def html_tags_to_extract(self, value):
@@ -164,7 +164,7 @@ def chunk_overlap(self, value):
 
     @property
     def data_file_extns(self):
-        return self.config.get("data", {}).get("file_extns", [])
+        return self.config.get("data", {}).get("file_extns", self.data_defaults.get("SUPPORTED_FILE_EXTENSIONS"))
 
     @data_file_extns.setter
     def data_file_extns(self, value):
diff --git a/src/agrag/configs/data_processing/default.yaml b/src/agrag/configs/data_processing/default.yaml
index e0fbf4c8..27e3e4c1 100644
--- a/src/agrag/configs/data_processing/default.yaml
+++ b/src/agrag/configs/data_processing/default.yaml
@@ -1,3 +1,5 @@
 CHUNK_SIZE: 512
 CHUNK_OVERLAP: 64
 PARSE_URLS_RECURSIVE: true
+SUPPORTED_FILE_EXTENSIONS: [".pdf", ".txt", ".docx", ".doc", ".rtf", ".csv", ".md", ".py", ".log"]
+SUPPORTED_HTML_TAGS: ["p", "table"]
diff --git a/src/agrag/modules/data_processing/data_processing.py b/src/agrag/modules/data_processing/data_processing.py
index 958065bc..c5e91839 100644
--- a/src/agrag/modules/data_processing/data_processing.py
+++ b/src/agrag/modules/data_processing/data_processing.py
@@ -94,6 +94,14 @@ def __init__(
         )
         self.html_tags_to_extract = kwargs.get("html_tags_to_extract", SUPPORTED_HTML_TAGS)
 
+        if self.data_dir:
+            logger.info(f"Processing Data from Data Directory: {self.data_dir}")
+            logger.info(f"\n Extracting text from the following document types: {self.file_exts}.")
+
+        if self.web_urls:
+            logger.info(f"Processing the Web URLs: {self.web_urls}")
+            logger.info(f"\n Extracting text from the following HTML tags: {self.html_tags_to_extract}.")
+
     def chunk_data_naive(self, text: str) -> List[str]:
         """
         Naively chunks text into segments of a specified size without any overlap.
@@ -182,7 +190,7 @@ def process_files(self, file_paths: List[str], start_doc_id: int = 0) -> pd.Data
         pd.DataFrame
             A DataFrame of processed text chunks from the given files.
         """
-        processed_data = []
+        processed_data = [pd.DataFrame([])]
         doc_id_counter = start_doc_id
 
         with concurrent.futures.ThreadPoolExecutor() as executor:
@@ -252,12 +260,14 @@ def process_urls(self, urls: List[str], login_info: dict = {}, start_doc_id: int
         pd.DataFrame
             A DataFrame of processed text chunks from the given URLs.
         """
-        processed_data = []
+        processed_data = [pd.DataFrame([])]
        doc_id_counter = start_doc_id
 
         with concurrent.futures.ThreadPoolExecutor() as executor:
             results = executor.map(
-                self.process_url, urls, [login_info] * len(urls), range(doc_id_counter, doc_id_counter + len(urls))
+                lambda url, doc_id: self.process_url(url, doc_id, login_info),
+                urls,
+                range(doc_id_counter, doc_id_counter + len(urls)),
             )
             for result in results:
                 processed_data.append(result)
diff --git a/src/agrag/modules/data_processing/utils.py b/src/agrag/modules/data_processing/utils.py
index 012db4ca..09b99c24 100644
--- a/src/agrag/modules/data_processing/utils.py
+++ b/src/agrag/modules/data_processing/utils.py
@@ -270,8 +270,6 @@ def bs4_extractor(html: str, tags_to_extract: List[str] = ["p", "table"]) -> str
 
     extracted_text = []
 
-    logger.info(f"\n Extracting text from the following HTML tags: {tags_to_extract}.")
-
     for tag in tags_to_extract:
         elements = soup.find_all(tag)
         for element in elements:
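
Note on the process_urls() change above: concurrent.futures.Executor.map pairs its argument iterables positionally, so the previous call executor.map(self.process_url, urls, [login_info] * len(urls), range(...)) only lines up if process_url takes (url, login_info, doc_id) in that order. The patch instead closes over login_info with a lambda and forwards (url, doc_id, login_info), and batched_processing now threads the last_doc_id returned by process_files into process_urls so document IDs stay unique across files and URLs within a batch. The following is a minimal, self-contained sketch of the same executor.map-with-lambda pattern; fetch_url is a hypothetical stand-in for DataProcessingModule.process_url, not code from this repository.

    import concurrent.futures

    # Hypothetical stand-in for DataProcessingModule.process_url(url, doc_id, login_info).
    def fetch_url(url: str, doc_id: int, login_info: dict) -> dict:
        creds = login_info.get(url)  # shared mapping, captured by the lambda below
        return {"doc_id": doc_id, "url": url, "authenticated": creds is not None}

    urls = ["https://example.com/a", "https://example.com/b"]
    login_info = {"https://example.com/b": {"username": "demo"}}
    start_doc_id = 10  # e.g. the last_doc_id returned by an earlier process_files() call

    with concurrent.futures.ThreadPoolExecutor() as executor:
        # map() zips the two iterables element-wise; the lambda closes over the single
        # shared login_info dict instead of materializing [login_info] * len(urls).
        results = executor.map(
            lambda url, doc_id: fetch_url(url, doc_id, login_info),
            urls,
            range(start_doc_id, start_doc_id + len(urls)),
        )
        for result in results:
            print(result)

Closing over the shared dict keeps the argument order explicit at the call site, which appears to be the mismatch the patch title refers to as the bug.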