Skip to content

Commit

Permalink
Fix possible error when the QA retrieval_gt shape differs (#774)
Browse files Browse the repository at this point in the history
* add apply_recursive function to util

* resolve possible error in the retrieval run.py

* resolve possible error in other run.py while using retrieval_gt

* add qa evolve api docs rst

---------

Co-authored-by: jeffrey <vkefhdl1@gmail.com>
  • Loading branch information
vkehfdl1 and jeffrey authored Oct 1, 2024
1 parent 0b90dbd commit 8a090fd
Show file tree
Hide file tree
Showing 8 changed files with 98 additions and 33 deletions.
9 changes: 2 additions & 7 deletions autorag/nodes/passageaugmenter/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from autorag.nodes.retrieval.run import evaluate_retrieval_node
from autorag.schema.metricinput import MetricInput
from autorag.strategy import measure_speed, filter_by_threshold, select_best
from autorag.utils.util import apply_recursive, to_list

logger = logging.getLogger("AutoRAG")

Expand All @@ -26,13 +27,7 @@ def run_passage_augmenter_node(
os.path.join(project_dir, "data", "qa.parquet"), engine="pyarrow"
)
retrieval_gt = qa_df["retrieval_gt"].tolist()
retrieval_gt = [
[
[str(uuid) for uuid in sub_array] if sub_array.size > 0 else []
for sub_array in inner_array
]
for inner_array in retrieval_gt
]
retrieval_gt = apply_recursive(lambda x: str(x), to_list(retrieval_gt))

results, execution_times = zip(
*map(
Expand Down
10 changes: 3 additions & 7 deletions autorag/nodes/passagefilter/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from autorag.nodes.retrieval.run import evaluate_retrieval_node
from autorag.schema.metricinput import MetricInput
from autorag.strategy import measure_speed, filter_by_threshold, select_best
from autorag.utils.util import to_list, apply_recursive


def run_passage_filter_node(
Expand Down Expand Up @@ -37,13 +38,8 @@ def run_passage_filter_node(
os.path.join(project_dir, "data", "qa.parquet"), engine="pyarrow"
)
retrieval_gt = qa_df["retrieval_gt"].tolist()
retrieval_gt = [
[
[str(uuid) for uuid in sub_array] if sub_array.size > 0 else []
for sub_array in inner_array
]
for inner_array in retrieval_gt
]
retrieval_gt = apply_recursive(lambda x: str(x), to_list(retrieval_gt))

# make rows to metric_inputs
metric_inputs = [
MetricInput(retrieval_gt=ret_gt, query=query, generation_gt=gen_gt)
Expand Down
10 changes: 3 additions & 7 deletions autorag/nodes/passagereranker/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from autorag.nodes.retrieval.run import evaluate_retrieval_node
from autorag.schema.metricinput import MetricInput
from autorag.strategy import measure_speed, filter_by_threshold, select_best
from autorag.utils.util import apply_recursive, to_list

logger = logging.getLogger("AutoRAG")

Expand Down Expand Up @@ -40,13 +41,8 @@ def run_passage_reranker_node(
os.path.join(project_dir, "data", "qa.parquet"), engine="pyarrow"
)
retrieval_gt = qa_df["retrieval_gt"].tolist()
retrieval_gt = [
[
[str(uuid) for uuid in sub_array] if sub_array.size > 0 else []
for sub_array in inner_array
]
for inner_array in retrieval_gt
]
retrieval_gt = apply_recursive(lambda x: str(x), to_list(retrieval_gt))

# make rows to metric_inputs
metric_inputs = [
MetricInput(retrieval_gt=ret_gt, query=query, generation_gt=gen_gt)
Expand Down
10 changes: 2 additions & 8 deletions autorag/nodes/retrieval/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from autorag.schema.metricinput import MetricInput
from autorag.strategy import measure_speed, filter_by_threshold, select_best
from autorag.support import get_support_modules
from autorag.utils.util import get_best_row, to_list
from autorag.utils.util import get_best_row, to_list, apply_recursive

logger = logging.getLogger("AutoRAG")

Expand Down Expand Up @@ -47,13 +47,7 @@ def run_retrieval_node(
os.path.join(project_dir, "data", "qa.parquet"), engine="pyarrow"
)
retrieval_gt = qa_df["retrieval_gt"].tolist()
retrieval_gt = [
[
[str(uuid) for uuid in sub_array] if sub_array.size > 0 else []
for sub_array in inner_array
]
for inner_array in retrieval_gt
]
retrieval_gt = apply_recursive(lambda x: str(x), to_list(retrieval_gt))
# make rows to metric_inputs
metric_inputs = [
MetricInput(retrieval_gt=ret_gt, query=query, generation_gt=gen_gt)
Expand Down
20 changes: 20 additions & 0 deletions autorag/utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,3 +620,23 @@ def pop_params(func: Callable, kwargs: Dict) -> Dict:
if key in target_params:
init_params[key] = kwargs.pop(key)
return init_params


def apply_recursive(func, data):
    """
    Recursively apply a function to every leaf element of a nested container.

    list, tuple, set, np.ndarray, and pd.Series values are treated as containers
    and are always returned as plain lists. Any other value is treated as a leaf
    and is passed directly to ``func``.

    :param func: Function to apply to each leaf element.
    :param data: A leaf value, or an arbitrarily nested list, tuple, set,
        np.ndarray, or pd.Series.
    :return: ``func(data)`` when data is a leaf. Otherwise, a (nested) list with
        the function applied to each leaf element.
    """
    if isinstance(data, np.ndarray) and data.ndim == 0:
        # A 0-d array is an ndarray but cannot be iterated (TypeError),
        # so unwrap the single value it holds and process that instead.
        return apply_recursive(func, data[()])
    if isinstance(data, (list, tuple, set, np.ndarray, pd.Series)):
        return [apply_recursive(func, item) for item in data]
    return func(data)
37 changes: 37 additions & 0 deletions docs/source/api_spec/autorag.data.qa.evolve.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
autorag.data.qa.evolve package
==============================

Submodules
----------

autorag.data.qa.evolve.llama\_index\_query\_evolve module
---------------------------------------------------------

.. automodule:: autorag.data.qa.evolve.llama_index_query_evolve
   :members:
   :undoc-members:
   :show-inheritance:

autorag.data.qa.evolve.openai\_query\_evolve module
---------------------------------------------------

.. automodule:: autorag.data.qa.evolve.openai_query_evolve
   :members:
   :undoc-members:
   :show-inheritance:

autorag.data.qa.evolve.prompt module
------------------------------------

.. automodule:: autorag.data.qa.evolve.prompt
   :members:
   :undoc-members:
   :show-inheritance:

Module contents
---------------

.. automodule:: autorag.data.qa.evolve
   :members:
   :undoc-members:
   :show-inheritance:
8 changes: 4 additions & 4 deletions tests/autorag/nodes/retrieval/test_run_retrieval_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
def node_line_dir():
with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as test_project_dir:
sample_project_dir = os.path.join(resources_dir, "sample_project")
# copy & paste all folders and files in sample_project folder
# copy & paste all folders and files in the sample_project folder
shutil.copytree(sample_project_dir, test_project_dir, dirs_exist_ok=True)

chroma_path = os.path.join(test_project_dir, "resources", "chroma")
Expand All @@ -39,9 +39,9 @@ def node_line_dir():
corpus_df = pd.read_parquet(corpus_path)
vectordb_ingest(collection, corpus_df, MockEmbedding(1536))

test_trail_dir = os.path.join(test_project_dir, "test_trial")
os.makedirs(test_trail_dir)
node_line_dir = os.path.join(test_trail_dir, "test_node_line")
test_trial_dir = os.path.join(test_project_dir, "test_trial")
os.makedirs(test_trial_dir)
node_line_dir = os.path.join(test_trial_dir, "test_node_line")
os.makedirs(node_line_dir)
yield node_line_dir

Expand Down
27 changes: 27 additions & 0 deletions tests/autorag/utils/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
get_event_loop,
find_key_values,
pop_params,
apply_recursive,
)
from tests.mock import MockLLM

Expand Down Expand Up @@ -609,3 +610,29 @@ def func_mixed(param1, param2, *args, **kwargs):
result = pop_params(func_mixed, kwargs)
assert result == expected
assert kwargs == {"extra_param": "extra_value"}


def test_apply_recursive():
    """Check apply_recursive on flat, nested, empty, scalar, and mixed-container inputs."""

    def double(value):
        return value * 2

    # Each entry pairs an input with the output expected after doubling every leaf.
    cases = [
        # flat list
        ([1, 2, 3, 4], [2, 4, 6, 8]),
        # two-level nesting
        ([[1, 2], [3, 4]], [[2, 4], [6, 8]]),
        # uneven, deeper nesting
        ([[1, [2, 3]], [4, [5, 6]]], [[2, [4, 6]], [8, [10, 12]]]),
        # empty list stays empty
        ([], []),
        # a bare scalar is passed straight to the function
        (5, 10),
        # tuples, arrays, and series all come back as plain lists
        (
            [(4, 5), (6, 7), [5, [6, 7]], np.array([4, 5]), pd.Series([4, 5])],
            [[8, 10], [12, 14], [10, [12, 14]], [8, 10], [8, 10]],
        ),
    ]

    for given, expected in cases:
        outcome = apply_recursive(double, given)
        assert outcome == expected

0 comments on commit 8a090fd

Please sign in to comment.