Make Hybrid Retrieval RRF module and edit run.py for it #98

Merged: 17 commits (Feb 2, 2024)
Changes from all commits
7 changes: 4 additions & 3 deletions autorag/evaluator.py
@@ -11,7 +11,6 @@
import pandas as pd
import yaml


from autorag.deploy import Runner
from autorag import embedding_models
from autorag.node_line import run_node_line
@@ -20,7 +19,7 @@
from autorag.schema import Node
from autorag.schema.node import module_type_exists, extract_values_from_nodes
from autorag.utils import cast_qa_dataset, cast_corpus_dataset
from autorag.utils.util import load_summary_file
from autorag.utils.util import load_summary_file, convert_string_to_tuple_in_dict

logger = logging.getLogger("AutoRAG")

@@ -110,7 +109,8 @@ def __embed(self, node_lines: Dict[str, List[Node]]):
# ingest VectorDB corpus
logger.info(f'Embedding VectorDB corpus with {embedding_model_str}...')
# Get the collection with GET or CREATE, as it may already exist
collection = vectordb.get_or_create_collection(name=embedding_model_str, metadata={"hnsw:space": "cosine"})
collection = vectordb.get_or_create_collection(name=embedding_model_str,
metadata={"hnsw:space": "cosine"})
# get embedding_model
if embedding_model_str in embedding_models:
embedding_model = embedding_models[embedding_model_str]
@@ -156,6 +156,7 @@ def _load_node_lines(yaml_path: str) -> Dict[str, List[Node]]:
except yaml.YAMLError as exc:
raise ValueError(f"YAML file {yaml_path} could not be loaded.") from exc

yaml_dict = convert_string_to_tuple_in_dict(yaml_dict)
node_lines = yaml_dict['node_lines']
node_line_dict = {}
for node_line in node_lines:
1 change: 1 addition & 0 deletions autorag/nodes/retrieval/__init__.py
@@ -1,3 +1,4 @@
from .base import retrieval_node
from .bm25 import bm25
from .vectordb import vectordb
from .hybrid_rrf import hybrid_rrf
10 changes: 7 additions & 3 deletions autorag/nodes/retrieval/base.py
@@ -8,9 +8,11 @@
import pandas as pd

from autorag import embedding_models
from autorag.strategy import select_best_average
from autorag.utils import fetch_contents, result_to_dataframe, validate_qa_dataset

import logging

logger = logging.getLogger("AutoRAG")


@@ -29,7 +31,7 @@ def retrieval_node(func):
def wrapper(
project_dir: Union[str, Path],
previous_result: pd.DataFrame,
*args, **kwargs) -> Tuple[List[List[str]], List[List[str]], List[List[float]]]:
**kwargs) -> Tuple[List[List[str]], List[List[str]], List[List[float]]]:
validate_qa_dataset(previous_result)
resources_dir = os.path.join(project_dir, "resources")
data_dir = os.path.join(project_dir, "data")
@@ -56,7 +58,7 @@ def wrapper(
# run retrieval function
if func.__name__ == "bm25":
bm25_corpus = load_bm25_corpus(bm25_path)
ids, scores = func(queries=queries, bm25_corpus=bm25_corpus, *args, **kwargs)
ids, scores = func(queries=queries, bm25_corpus=bm25_corpus, **kwargs)
elif func.__name__ == "vectordb":
chroma_collection = load_chroma_collection(db_path=chroma_path, collection_name=embedding_model_str)
if embedding_model_str in embedding_models:
@@ -65,7 +67,9 @@ def wrapper(
logger.error(f"embedding_model_str {embedding_model_str} does not exist.")
raise KeyError(f"embedding_model_str {embedding_model_str} does not exist.")
ids, scores = func(queries=queries, collection=chroma_collection,
embedding_model=embedding_model, *args, **kwargs)
embedding_model=embedding_model, **kwargs)
elif func.__name__ == "hybrid_rrf":
ids, scores = func(**kwargs)
else:
raise ValueError(f"invalid func name for using retrieval_io decorator.")

70 changes: 70 additions & 0 deletions autorag/nodes/retrieval/hybrid_rrf.py
@@ -0,0 +1,70 @@
from typing import List, Tuple

import pandas as pd
import swifter

from autorag.nodes.retrieval import retrieval_node


@retrieval_node
def hybrid_rrf(
ids: Tuple,
scores: Tuple,
top_k: int,
rrf_k: int = 60) -> Tuple[List[List[str]], List[List[float]]]:
"""
Hybrid RRF function.
RRF (Reciprocal Rank Fusion) is a method to fuse multiple retrieval results.
It is common to fuse dense retrieval and sparse retrieval results using RRF.
To use this function, you must input ids and scores as tuples.
It differs from other retrieval modules in that it does not actually execute retrieval,
but only fuses the results of other retrieval functions.
So you have to run at least two retrieval modules before running this function,
collect the ids and scores from each retrieval module,
and pass them to this function as tuples.

:param ids: The tuple of ids that you want to fuse.
The length of this must be the same as the length of scores.
:param scores: The retrieval scores that you want to fuse.
The length of this must be the same as the length of ids.
:param top_k: The number of passages to be retrieved.
:param rrf_k: Hyperparameter for RRF.
Default is 60.
For more information, please visit our documentation.
:return: The tuple of ids and scores fused by RRF.
"""
assert len(ids) == len(scores), "The length of ids and scores must be the same."
assert len(ids) > 1, "You must input more than one retrieval result."
assert top_k > 0, "top_k must be greater than 0."
assert rrf_k > 0, "rrf_k must be greater than 0."

id_df = pd.DataFrame({f'id_{i}': id_list for i, id_list in enumerate(ids)})
score_df = pd.DataFrame({f'score_{i}': score_list for i, score_list in enumerate(scores)})
df = pd.concat([id_df, score_df], axis=1)

def rrf_pure_apply(row):
ids_tuple = tuple(row[[f'id_{i}' for i in range(len(ids))]].values)
scores_tuple = tuple(row[[f'score_{i}' for i in range(len(scores))]].values)
return pd.Series(rrf_pure(ids_tuple, scores_tuple, rrf_k, top_k))

df[['rrf_id', 'rrf_score']] = df.swifter.apply(rrf_pure_apply, axis=1)
return df['rrf_id'].tolist(), df['rrf_score'].tolist()


def rrf_pure(ids: Tuple, scores: Tuple, rrf_k: int, top_k: int) -> Tuple[
List[str], List[float]]:
df = pd.concat([pd.Series(dict(zip(_id, score))) for _id, score in zip(ids, scores)], axis=1)
rank_df = df.rank(ascending=False, method='min')
rank_df = rank_df.fillna(0)
rank_df['rrf'] = rank_df.apply(lambda row: rrf_calculate(row, rrf_k), axis=1)
rank_df = rank_df.sort_values(by='rrf', ascending=False)
return rank_df.index.tolist()[:top_k], rank_df['rrf'].tolist()[:top_k]


def rrf_calculate(row, rrf_k):
result = 0
for r in row:
if r == 0:
continue
result += 1 / (r + rrf_k)
return result
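
As a quick sanity check of the fusion logic above, here is a minimal sketch that calls rrf_pure directly for a single query; hybrid_rrf simply applies this per-query core across a DataFrame via swifter. The document ids and scores below are made up for illustration and are not part of this PR.

from autorag.nodes.retrieval.hybrid_rrf import rrf_pure

# toy results from two retrievers (e.g. a sparse and a dense one); values are assumptions
sparse_ids, sparse_scores = ['doc1', 'doc2', 'doc3'], [5.1, 3.2, 1.0]
dense_ids, dense_scores = ['doc2', 'doc4', 'doc1'], [0.9, 0.7, 0.4]

fused_ids, fused_scores = rrf_pure(
    ids=(sparse_ids, dense_ids),
    scores=(sparse_scores, dense_scores),
    rrf_k=60, top_k=3)
# 'doc2' is ranked 2nd by the sparse retriever and 1st by the dense one,
# so its fused score 1/(2 + 60) + 1/(1 + 60) ≈ 0.0325 is the highest of the four candidates.
print(fused_ids)  # ['doc2', 'doc1', 'doc4']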
127 changes: 102 additions & 25 deletions autorag/nodes/retrieval/run.py
@@ -1,12 +1,13 @@
import logging
import os
import pathlib
from typing import List, Callable, Dict
from typing import List, Callable, Dict, Tuple

import pandas as pd

from autorag.evaluate import evaluate_retrieval
from autorag.strategy import measure_speed, filter_by_threshold, select_best_average
from autorag.utils.util import load_summary_file

logger = logging.getLogger("AutoRAG")

@@ -33,31 +34,62 @@ def run_retrieval_node(modules: List[Callable],
os.makedirs(node_line_dir)
project_dir = pathlib.PurePath(node_line_dir).parent.parent
retrieval_gt = pd.read_parquet(os.path.join(project_dir, "data", "qa.parquet"))['retrieval_gt'].tolist()

results, execution_times = zip(*map(lambda task: measure_speed(
task[0], project_dir=project_dir, previous_result=previous_result, **task[1]), zip(modules, module_params)))
average_times = list(map(lambda x: x / len(results[0]), execution_times))

# run metrics before filtering
if strategies.get('metrics') is None:
raise ValueError("You must at least one metrics for retrieval evaluation.")
results = list(map(lambda x: evaluate_retrieval_node(x, retrieval_gt, strategies.get('metrics')), results))

# save results to folder
save_dir = os.path.join(node_line_dir, "retrieval") # node name
if not os.path.exists(save_dir):
os.makedirs(save_dir)
filepaths = list(map(lambda x: os.path.join(save_dir, f'{x}.parquet'), range(len(modules))))
list(map(lambda x: x[0].to_parquet(x[1], index=False), zip(results, filepaths))) # execute save to parquet
filenames = list(map(lambda x: os.path.basename(x), filepaths))

summary_df = pd.DataFrame({
'filename': filenames,
'module_name': list(map(lambda module: module.__name__, modules)),
'module_params': module_params,
'execution_time': average_times,
**{metric: list(map(lambda result: result[metric].mean(), results)) for metric in strategies.get('metrics')},
})

def run_and_save(input_modules, input_module_params):
result, execution_times = zip(*map(lambda task: measure_speed(
task[0], project_dir=project_dir, previous_result=previous_result, **task[1]),
zip(input_modules, input_module_params)))
average_times = list(map(lambda x: x / len(result[0]), execution_times))

# run metrics before filtering
if strategies.get('metrics') is None:
raise ValueError("You must at least one metrics for retrieval evaluation.")
result = list(map(lambda x: evaluate_retrieval_node(x, retrieval_gt, strategies.get('metrics')), result))

# save results to folder
filepaths = list(map(lambda x: os.path.join(save_dir, f'{x}.parquet'), range(len(input_modules))))
list(map(lambda x: x[0].to_parquet(x[1], index=False), zip(result, filepaths))) # execute save to parquet
filename_list = list(map(lambda x: os.path.basename(x), filepaths))

summary_df = pd.DataFrame({
'filename': filename_list,
'module_name': list(map(lambda module: module.__name__, input_modules)),
'module_params': input_module_params,
'execution_time': average_times,
**{metric: list(map(lambda result: result[metric].mean(), result)) for metric in
strategies.get('metrics')},
})
summary_df.to_csv(os.path.join(save_dir, 'summary.csv'), index=False)
return result, average_times, summary_df

# run retrieval modules except hybrid
hybrid_module_names = ['hybrid_rrf']
non_hybrid_modules, non_hybrid_module_params = zip(*filter(lambda x: x[0].__name__ not in hybrid_module_names,
zip(modules, module_params)))
non_hybrid_results, non_hybrid_times, non_hybrid_summary_df = run_and_save(non_hybrid_modules,
non_hybrid_module_params)

if any([module.__name__ in hybrid_module_names for module in modules]):
hybrid_modules, hybrid_module_params = zip(*filter(lambda x: x[0].__name__ in hybrid_module_names,
zip(modules, module_params)))
target_modules = list(map(lambda x: x.pop('target_modules'), hybrid_module_params))
target_filenames = list(map(lambda x: select_result_for_hybrid(save_dir, x), target_modules))
ids_scores = list(map(lambda x: get_ids_and_scores(save_dir, x), target_filenames))
hybrid_module_params = list(map(lambda x: {**x[0], **x[1]}, zip(hybrid_module_params, ids_scores)))
real_hybrid_times = list(map(lambda filename: get_hybrid_execution_times(save_dir, filename), target_filenames))
hybrid_results, hybrid_times, hybrid_summary_df = run_and_save(hybrid_modules, hybrid_module_params)
hybrid_times = real_hybrid_times.copy()
hybrid_summary_df['execution_time'] = hybrid_times
else:
hybrid_results, hybrid_times, hybrid_summary_df = [], [], pd.DataFrame()

summary = pd.concat([non_hybrid_summary_df, hybrid_summary_df], ignore_index=True)
results = non_hybrid_results + hybrid_results
average_times = non_hybrid_times + hybrid_times
filenames = summary['filename'].tolist()

# filter by strategies
if strategies.get('speed_threshold') is not None:
@@ -66,11 +98,11 @@ def run_retrieval_node(modules: List[Callable],
best_result = pd.concat([previous_result, selected_result], axis=1)

# add summary.csv 'is_best' column
summary_df['is_best'] = summary_df['filename'] == selected_filename
summary['is_best'] = summary['filename'] == selected_filename

# save the result files
best_result.to_parquet(os.path.join(save_dir, f'best_{os.path.splitext(selected_filename)[0]}.parquet'), index=False)
summary_df.to_csv(os.path.join(save_dir, 'summary.csv'), index=False)
summary.to_csv(os.path.join(save_dir, 'summary.csv'), index=False)
return best_result


@@ -90,3 +122,48 @@ def evaluate_this_module(df: pd.DataFrame):
return df['retrieved_contents'].tolist(), df['retrieved_ids'].tolist(), df['retrieve_scores'].tolist()

return evaluate_this_module(result_df)


def select_result_for_hybrid(node_dir: str, target_modules: Tuple) -> List[str]:
"""
Select the best result filename for each target module from summary.csv.

:param node_dir: The directory of the node.
:param target_modules: The names of the target modules.
:return: A list of filenames.
"""
def select_best_among_module(df: pd.DataFrame, module_name: str):
modules_summary = df.loc[lambda row: row['module_name'] == module_name]
if len(modules_summary) == 1:
return modules_summary.iloc[0, :]
elif len(modules_summary) <= 0:
raise ValueError(f"module_name {module_name} does not exist in summary.csv. "
f"You must run {module_name} before running hybrid retrieval.")
metrics = modules_summary.drop(columns=['filename', 'module_name', 'module_params', 'execution_time'])
metric_average = metrics.mean(axis=1)
metric_average = metric_average.reset_index(drop=True)
max_idx = metric_average.idxmax()
best_module = modules_summary.iloc[max_idx, :]
return best_module

summary_df = load_summary_file(os.path.join(node_dir, "summary.csv"))
best_results = list(map(lambda module_name: select_best_among_module(summary_df, module_name), target_modules))
best_filenames = list(map(lambda df: df['filename'], best_results))
return best_filenames


def get_ids_and_scores(node_dir: str, filenames: List[str]) -> Dict:
best_results_df = list(map(lambda filename: pd.read_parquet(os.path.join(node_dir, filename)), filenames))
ids = tuple(map(lambda df: df['retrieved_ids'].apply(list).tolist(), best_results_df))
scores = tuple(map(lambda df: df['retrieve_scores'].apply(list).tolist(), best_results_df))
return {
'ids': ids,
'scores': scores,
}


def get_hybrid_execution_times(node_dir: str, filenames: List[str]) -> float:
summary_df = load_summary_file(os.path.join(node_dir, "summary.csv"))
best_results = summary_df[summary_df['filename'].isin(filenames)]
execution_times = best_results['execution_time'].sum()
return execution_times
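
get_ids_and_scores is the bridge between previously saved retrieval results and hybrid_rrf: it reshapes the stored parquet columns into the ids and scores tuples that the fusion module consumes. Below is a minimal, self-contained sketch; the directory, file names, and row contents are made up for illustration.

import os
import tempfile
import pandas as pd

from autorag.nodes.retrieval.run import get_ids_and_scores

# fake node directory with two saved retrieval results (one query each); data is made up
node_dir = tempfile.mkdtemp()
pd.DataFrame({'retrieved_ids': [['doc1', 'doc2']],
              'retrieve_scores': [[5.1, 3.2]]}).to_parquet(os.path.join(node_dir, '0.parquet'), index=False)
pd.DataFrame({'retrieved_ids': [['doc2', 'doc4']],
              'retrieve_scores': [[0.9, 0.7]]}).to_parquet(os.path.join(node_dir, '1.parquet'), index=False)

kwargs = get_ids_and_scores(node_dir, ['0.parquet', '1.parquet'])
# kwargs is roughly {'ids': ([['doc1', 'doc2']], [['doc2', 'doc4']]),
#                    'scores': ([[5.1, 3.2]], [[0.9, 0.7]])},
# which run_retrieval_node merges into the hybrid module's params before calling hybrid_rrf.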
1 change: 1 addition & 0 deletions autorag/support.py
@@ -25,6 +25,7 @@ def get_support_modules(module_name: str) -> Callable:
'monot5': ('autorag.nodes.passagereranker', 'monot5'),
'tart': ('autorag.nodes.passagereranker', 'tart'),
'upr': ('autorag.nodes.passagereranker', 'upr'),
'hybrid_rrf': ('autorag.nodes.retrieval', 'hybrid_rrf'),
}
return dynamically_find_function(module_name, support_modules)

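With this mapping in place, the new module can be resolved by name through the support registry. A small usage sketch, assuming nothing beyond the registered name:

from autorag.support import get_support_modules

# resolves the registered name to the autorag.nodes.retrieval.hybrid_rrf callable
hybrid_rrf_func = get_support_modules('hybrid_rrf')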
40 changes: 39 additions & 1 deletion autorag/utils/util.py
@@ -6,6 +6,7 @@
import re
import string
from typing import List, Callable, Dict, Optional, Any, Collection
import ast

import pandas as pd
import swifter
@@ -98,7 +99,22 @@ def make_combinations(target_dict: Dict[str, Any]) -> List[Dict[str, Any]]:
"""
dict_with_lists = dict(map(lambda x: (x[0], x[1] if isinstance(x[1], list) else [x[1]]),
target_dict.items()))
dict_with_lists = dict(map(lambda x: (x[0], list(set(x[1]))), dict_with_lists.items()))

def delete_duplicate(x):
def is_hashable(obj):
try:
hash(obj)
return True
except TypeError:
return False

if any([not is_hashable(elem) for elem in x]):
# TODO: add duplication check for unhashable objects
return x
else:
return list(set(x))

dict_with_lists = dict(map(lambda x: (x[0], delete_duplicate(x[1])), dict_with_lists.items()))
combination = list(itertools.product(*dict_with_lists.values()))
combination_dicts = [dict(zip(dict_with_lists.keys(), combo)) for combo in combination]
return combination_dicts
@@ -161,3 +177,25 @@ def lower(text):
return text.lower()

return white_space_fix(remove_articles(remove_punc(lower(s))))


def convert_string_to_tuple_in_dict(d):
"""Recursively converts strings that start with '(' and end with ')' to tuples in a dictionary."""
for key, value in d.items():
# If the value is a dictionary, recurse
if isinstance(value, dict):
convert_string_to_tuple_in_dict(value)
# If the value is a list, iterate through its elements
elif isinstance(value, list):
for i, item in enumerate(value):
# If an item in the list is a dictionary, recurse
if isinstance(item, dict):
convert_string_to_tuple_in_dict(item)
# If an item in the list is a string matching the criteria, convert it to a tuple
elif isinstance(item, str) and item.startswith('(') and item.endswith(')'):
value[i] = ast.literal_eval(item)
# If the value is a string matching the criteria, convert it to a tuple
elif isinstance(value, str) and value.startswith('(') and value.endswith(')'):
d[key] = ast.literal_eval(value)

return d
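
This helper is what lets a YAML config express values such as target_modules as tuple-looking strings, which the evaluator converts before building node lines. A minimal sketch of the conversion; the nested keys and values are illustrative, not taken from this PR.

from autorag.utils.util import convert_string_to_tuple_in_dict

# a dict shaped like a parsed YAML config; keys and values are made up for illustration
config = {'node_lines': [{'nodes': [{'module_type': 'hybrid_rrf',
                                     'target_modules': "('bm25', 'vectordb')",
                                     'rrf_k': 60}]}]}
converted = convert_string_to_tuple_in_dict(config)
print(converted['node_lines'][0]['nodes'][0]['target_modules'])  # ('bm25', 'vectordb')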