Commit

Merge branch 'main' into bug_fix_embedding
shreyash2106 committed Jul 31, 2024
2 parents f564c03 + 7d1e707 commit f6267e1
Showing 5 changed files with 25 additions and 13 deletions.
14 changes: 8 additions & 6 deletions src/agrag/agrag.py
@@ -174,7 +174,7 @@ def initialize_data_module(self):
"""Initializes the Data Processing module."""
self.data_processing_module = DataProcessingModule(
data_dir=self.data_dir,
- web_urls=self.args.web_urls,
+ web_urls=self.web_urls,
chunk_size=self.args.chunk_size,
chunk_overlap=self.args.chunk_overlap,
file_exts=self.args.data_file_extns,
@@ -486,10 +486,8 @@ def batched_processing(self):
"""

logger.info(f"Processing Data from Data Directory: {self.data_processing_module.data_dir}")
file_paths = get_all_file_paths(self.data_processing_module.data_dir, self.data_processing_module.file_exts)

logger.info(f"Processing the Web URLs: {self.data_processing_module.web_urls}")
web_urls = []
if self.parse_urls_recursive:
for idx, url in enumerate(self.web_urls):
@@ -511,16 +509,20 @@ def batched_processing(self):
self.login_info[sub_url] = self.login_info[url]

batch_num = 1

for i in range(0, max(len(file_paths), len(web_urls)), self.batch_size):
logger.info(f"Batch {batch_num}")

batch_file_paths = file_paths[i : i + self.batch_size]
batch_urls = web_urls[i : i + self.batch_size]

# Data Processing
- processed_data = self.data_processing_module.process_files(batch_file_paths)
- processed_files_data, last_doc_id = self.data_processing_module.process_files(file_paths, start_doc_id=0)
- processed_urls_data = self.data_processing_module.process_urls(batch_urls, start_doc_id=last_doc_id)
+ processed_files_data, last_doc_id = self.data_processing_module.process_files(
+     batch_file_paths, start_doc_id=0
+ )
+ processed_urls_data = self.data_processing_module.process_urls(
+     batch_urls, login_info=self.login_info, start_doc_id=last_doc_id
+ )
processed_data = pd.concat([processed_files_data, processed_urls_data]).reset_index(drop=True)

# Embedding
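The batched_processing change above makes each batch process its files first, thread the returned last_doc_id into process_urls (which now also receives login_info), and concatenate the two results. A minimal sketch of that loop shape, assuming process_files returns (DataFrame, last_doc_id) and process_urls returns a DataFrame; both are hypothetical stand-ins for the DataProcessingModule methods:

import pandas as pd

def run_batches(file_paths, web_urls, batch_size, process_files, process_urls, login_info=None):
    # Sketch only: process_files/process_urls stand in for the module's methods.
    combined = []
    for i in range(0, max(len(file_paths), len(web_urls)), batch_size):
        batch_file_paths = file_paths[i : i + batch_size]
        batch_urls = web_urls[i : i + batch_size]

        # Files first; their last document ID seeds the URL pass so IDs stay
        # unique across both sources within a batch.
        files_df, last_doc_id = process_files(batch_file_paths, start_doc_id=0)
        urls_df = process_urls(batch_urls, login_info=login_info, start_doc_id=last_doc_id)

        combined.append(pd.concat([files_df, urls_df]).reset_index(drop=True))
    return combined
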
4 changes: 2 additions & 2 deletions src/agrag/args.py
@@ -124,7 +124,7 @@ def base_urls(self, value):

@property
def html_tags_to_extract(self):
return self.config.get("data", {}).get("html_tags_to_extract")
return self.config.get("data", {}).get("html_tags_to_extract", self.data_defaults.get("SUPPORTED_HTML_TAGS"))

@html_tags_to_extract.setter
def html_tags_to_extract(self, value):
@@ -164,7 +164,7 @@ def chunk_overlap(self, value):

@property
def data_file_extns(self):
return self.config.get("data", {}).get("file_extns", [])
return self.config.get("data", {}).get("file_extns", self.data_defaults.get("SUPPORTED_FILE_EXTENSIONS"))

@data_file_extns.setter
def data_file_extns(self, value):
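Both properties above now fall back to packaged defaults rather than None or an empty list when the user config omits the key. A small sketch of the nested dict.get fallback, assuming data_defaults is a dict loaded from default.yaml (names mirror the diff; the surrounding class is omitted):

# User config wins when the key is present; otherwise the packaged default is used.
config = {"data": {"file_extns": [".pdf", ".md"]}}
data_defaults = {
    "SUPPORTED_FILE_EXTENSIONS": [".pdf", ".txt", ".docx"],  # abbreviated
    "SUPPORTED_HTML_TAGS": ["p", "table"],
}

file_extns = config.get("data", {}).get("file_extns", data_defaults.get("SUPPORTED_FILE_EXTENSIONS"))
html_tags = config.get("data", {}).get("html_tags_to_extract", data_defaults.get("SUPPORTED_HTML_TAGS"))

print(file_extns)  # ['.pdf', '.md']  -- user value kept
print(html_tags)   # ['p', 'table']   -- key absent, default applied
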
2 changes: 2 additions & 0 deletions src/agrag/configs/data_processing/default.yaml
@@ -1,3 +1,5 @@
CHUNK_SIZE: 512
CHUNK_OVERLAP: 64
PARSE_URLS_RECURSIVE: true
+ SUPPORTED_FILE_EXTENSIONS: [".pdf", ".txt", ".docx", ".doc", ".rtf", ".csv", ".md", ".py", ".log"]
+ SUPPORTED_HTML_TAGS: ["p", "table"]
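These two keys are what the args.py fallbacks above read. A hypothetical loader (the variable name and loading mechanism are assumptions, not the repository's actual code) showing how the file could become that mapping:

import yaml

# Path taken from the diff header; how agrag actually loads it may differ.
with open("src/agrag/configs/data_processing/default.yaml") as f:
    data_defaults = yaml.safe_load(f)

print(data_defaults["SUPPORTED_FILE_EXTENSIONS"][:3])  # ['.pdf', '.txt', '.docx']
print(data_defaults["SUPPORTED_HTML_TAGS"])            # ['p', 'table']
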
16 changes: 13 additions & 3 deletions src/agrag/modules/data_processing/data_processing.py
@@ -94,6 +94,14 @@ def __init__(
)
self.html_tags_to_extract = kwargs.get("html_tags_to_extract", SUPPORTED_HTML_TAGS)

+ if self.data_dir:
+     logger.info(f"Processing Data from Data Directory: {self.data_dir}")
+     logger.info(f"\n Extracting text from the following document types: {self.file_exts}.")
+
+ if self.web_urls:
+     logger.info(f"Processing the Web URLs: {self.web_urls}")
+     logger.info(f"\n Extracting text from the following HTML tags: {self.html_tags_to_extract}.")

def chunk_data_naive(self, text: str) -> List[str]:
"""
Naively chunks text into segments of a specified size without any overlap.
@@ -182,7 +190,7 @@ def process_files(self, file_paths: List[str], start_doc_id: int = 0) -> pd.Data
pd.DataFrame
A DataFrame of processed text chunks from the given files.
"""
- processed_data = []
+ processed_data = [pd.DataFrame([])]
doc_id_counter = start_doc_id

with concurrent.futures.ThreadPoolExecutor() as executor:
@@ -252,12 +260,14 @@ def process_urls(self, urls: List[str], login_info: dict = {}, start_doc_id: int
pd.DataFrame
A DataFrame of processed text chunks from the given URLs.
"""
- processed_data = []
+ processed_data = [pd.DataFrame([])]
doc_id_counter = start_doc_id

with concurrent.futures.ThreadPoolExecutor() as executor:
results = executor.map(
- self.process_url, urls, [login_info] * len(urls), range(doc_id_counter, doc_id_counter + len(urls))
+ lambda url, doc_id: self.process_url(url, doc_id, login_info),
+ urls,
+ range(doc_id_counter, doc_id_counter + len(urls)),
)
for result in results:
processed_data.append(result)
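Two patterns in this file are worth noting: processed_data is seeded with an empty DataFrame so the later pd.concat always has at least one object to concatenate even when no inputs produce results, and process_url is wrapped in a lambda so executor.map only iterates over the arguments that vary (URL and document ID) while login_info is captured once. A minimal sketch of both, with fetch as a hypothetical stand-in for process_url:

import concurrent.futures
import pandas as pd

def fetch(url, doc_id, login_info):
    # Placeholder: a real implementation would download and chunk the page.
    return pd.DataFrame({"doc_id": [doc_id], "url": [url]})

def process_urls(urls, login_info=None, start_doc_id=0):
    # Seeding with an empty frame keeps pd.concat from raising when urls is empty.
    processed_data = [pd.DataFrame([])]
    with concurrent.futures.ThreadPoolExecutor() as executor:
        # The lambda fixes login_info; map iterates over url and doc_id only.
        results = executor.map(
            lambda url, doc_id: fetch(url, doc_id, login_info),
            urls,
            range(start_doc_id, start_doc_id + len(urls)),
        )
        processed_data.extend(results)
    return pd.concat(processed_data).reset_index(drop=True)

print(process_urls([]))                       # empty frame, no exception
print(process_urls(["https://example.com"]))  # one row, doc_id 0
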
2 changes: 0 additions & 2 deletions src/agrag/modules/data_processing/utils.py
@@ -270,8 +270,6 @@ def bs4_extractor(html: str, tags_to_extract: List[str] = ["p", "table"]) -> str

extracted_text = []

logger.info(f"\n Extracting text from the following HTML tags: {tags_to_extract}.")

for tag in tags_to_extract:
elements = soup.find_all(tag)
for element in elements:
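With the per-call logging moved into DataProcessingModule.__init__ (shown earlier in this commit), bs4_extractor is left as a pure extraction helper. A hedged sketch of a function with the signature shown in the hunk header, assuming BeautifulSoup; the actual implementation in utils.py may join and clean the text differently:

from typing import List
from bs4 import BeautifulSoup

def bs4_extractor(html: str, tags_to_extract: List[str] = ["p", "table"]) -> str:
    # Collect the text of every occurrence of each requested tag, in tag order.
    soup = BeautifulSoup(html, "html.parser")
    extracted_text = []
    for tag in tags_to_extract:
        for element in soup.find_all(tag):
            extracted_text.append(element.get_text(separator=" ", strip=True))
    return "\n".join(extracted_text)

print(bs4_extractor("<p>hello</p><table><tr><td>42</td></tr></table>"))  # "hello\n42"
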
