Add Query_expansion node and its decorator and run function. #67

Merged: 33 commits, Jan 29, 2024

Commits
d50fc82
Create query expansion decorator and not finished run.py
bwook00 Jan 25, 2024
3dd0cc8
Merge branch 'main' into Feature/#56
bwook00 Jan 27, 2024
02650b2
Implement query_expansion decorator(base.py) and not yet run.py witho…
bwook00 Jan 27, 2024
3a0c154
Merge branch 'main' into Feature/#56
bwook00 Jan 27, 2024
0137ba4
Add query_expansion node_type in yaml file.
bwook00 Jan 27, 2024
8f55a4c
delete () in init generator_models
bwook00 Jan 27, 2024
dbf6cf5
Create Query expansion decorator without test
bwook00 Jan 27, 2024
bd25a3d
Add decorator in hyde, query_decompose.py
bwook00 Jan 27, 2024
4765803
Add if prompt is None, prompt is default prompt
bwook00 Jan 27, 2024
29200d2
pop prompt in kwargs
bwook00 Jan 27, 2024
8f6c408
apply wrapped in test hyde
bwook00 Jan 27, 2024
27ff4ab
change module name
bwook00 Jan 27, 2024
e0970a7
query_expansion node. test pass version
bwook00 Jan 27, 2024
4648950
Add node test in test_hyde.py
bwook00 Jan 27, 2024
308f38f
move retrieval module and top to strategy
bwook00 Jan 28, 2024
ebb2e1f
move retrieval module and top to strategy
bwook00 Jan 28, 2024
064266d
Add query_decompose and hyde module
bwook00 Jan 28, 2024
e05c243
not complete test query expansion run
bwook00 Jan 28, 2024
437f710
Merge branch 'main' into Feature/#56
bwook00 Jan 28, 2024
318a2b0
Create query_expansion run.py without testcode
bwook00 Jan 28, 2024
3825ada
Create query_expansion run's testcode
bwook00 Jan 28, 2024
c4f31f2
Add annotation
bwook00 Jan 28, 2024
c9c3f7b
fix hyde
bwook00 Jan 28, 2024
9786b07
fix hyde
bwook00 Jan 28, 2024
0bcac82
Merge branch 'main' into Feature/#56
bwook00 Jan 29, 2024
b108670
resolve conversation
bwook00 Jan 29, 2024
56578fb
[Not complete] Replacing run.py after the pattern of the prompt maker
bwook00 Jan 29, 2024
6819eea
[complete] Replacing run.py after the pattern of the prompt maker
bwook00 Jan 29, 2024
bdd5e7d
Merge branch 'main' into Feature/#56
bwook00 Jan 29, 2024
e430679
fix test code
bwook00 Jan 29, 2024
c488c0d
implement mock llm in hyde test
bwook00 Jan 29, 2024
513ea6e
fix annotation
bwook00 Jan 29, 2024
367f639
add sphinx
bwook00 Jan 29, 2024
53 changes: 53 additions & 0 deletions autorag/nodes/queryexpansion/base.py
@@ -0,0 +1,53 @@
import functools
from pathlib import Path
from typing import List, Union

import pandas as pd

from autorag import generator_models

from autorag.utils import result_to_dataframe, validate_qa_dataset

import logging

logger = logging.getLogger("AutoRAG")


def query_expansion_node(func):
@functools.wraps(func)
@result_to_dataframe(["queries"])
def wrapper(
project_dir: Union[str, Path],
previous_result: pd.DataFrame,
*args, **kwargs) -> List[List[str]]:
validate_qa_dataset(previous_result)

# find query column
assert "query" in previous_result.columns, "previous_result must have query column."
queries = previous_result["query"].tolist()

# set module parameters
llm_str = kwargs.pop("llm")

# pop prompt from kwargs
if "prompt" in kwargs.keys():
prompt = kwargs.pop("prompt")
else:
prompt = ""

# set llm model for query expansion
if llm_str in generator_models:
llm = generator_models[llm_str](**kwargs)
else:
logger.error(f"llm_str {llm_str} does not exist.")
raise KeyError(f"llm_str {llm_str} does not exist.")

# run query expansion function
expanded_queries = func(queries=queries, llm=llm, prompt=prompt)

return expanded_queries

return wrapper



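The decorator above changes how a query expansion module is called: instead of `queries` and an LLM instance, the wrapped function takes `project_dir` and `previous_result`, resolves the `llm` string through `generator_models`, and returns a DataFrame with a `queries` column. A minimal sketch of the decorated call, mirroring the node test at the bottom of this PR (the `max_tokens` value and the DataFrame columns here are illustrative; `previous_result` must pass `validate_qa_dataset`):

```python
import pandas as pd
from autorag.nodes.queryexpansion import hyde

# Illustrative QA frame; the wrapper itself only requires a "query" column,
# plus whatever validate_qa_dataset checks for.
previous_result = pd.DataFrame({
    "qid": ["q1"],
    "query": ["What is visconde structure?"],
    "retrieval_gt": [[["doc-1"]]],
    "generation_gt": [["..."]],
})

# "openai" is looked up in autorag.generator_models; the remaining kwargs
# (max_tokens) are forwarded to that LLM class's constructor.
result_df = hyde(project_dir="./project", previous_result=previous_result,
                 llm="openai", max_tokens=64)
# result_df["queries"] holds a List[str] of expansions per input query.
```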
5 changes: 5 additions & 0 deletions autorag/nodes/queryexpansion/hyde.py
@@ -3,9 +3,12 @@

from llama_index.llms.llm import BaseLLM

from autorag.nodes.queryexpansion.base import query_expansion_node

hyde_prompt = "Please write a passage to answer the question"


@query_expansion_node
def hyde(queries: List[str], llm: BaseLLM,
prompt: str = hyde_prompt) -> List[List[str]]:
"""
Expand All @@ -26,6 +29,8 @@ def hyde(queries: List[str], llm: BaseLLM,

async def hyde_pure(query: str, llm: BaseLLM,
prompt: str = hyde_prompt) -> List[str]:
if prompt == "":
prompt = hyde_prompt
full_prompt = prompt + f"\nQuestion: {query}\nPassage:"
hyde_answer = llm.complete(full_prompt)
return [hyde_answer.text]
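`hyde_pure` substitutes the default `hyde_prompt` when an empty string is passed, then appends the question. A sketch of the prompt one query produces:

```python
query = "How many members are in Newjeans?"
prompt = "Please write a passage to answer the question"  # the hyde_prompt default

full_prompt = prompt + f"\nQuestion: {query}\nPassage:"
# Please write a passage to answer the question
# Question: How many members are in Newjeans?
# Passage:
```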
5 changes: 5 additions & 0 deletions autorag/nodes/queryexpansion/query_decompose.py
@@ -3,6 +3,8 @@

from llama_index.llms.llm import BaseLLM

from autorag.nodes.queryexpansion.base import query_expansion_node

decompose_prompt = """Decompose a question in self-contained sub-questions. Use \"The question needs no decomposition\" when no decomposition is needed.

Example 1:
@@ -50,6 +52,7 @@
"""


@query_expansion_node
def query_decompose(queries: List[str], llm: BaseLLM,
prompt: str = decompose_prompt) -> List[List[str]]:
"""
@@ -77,6 +80,8 @@ async def query_decompose_pure(query: str, llm: BaseLLM,
default prompt comes from Visconde's StrategyQA few-shot prompt.
:return: List[str], list of decomposed query. Return input query if query is not decomposable.
"""
if prompt == "":
prompt = decompose_prompt
full_prompt = "prompt: " + prompt + "\n\n" "question: " + query
answer = llm.complete(full_prompt)
if answer.text == "the question needs no decomposition.":
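`query_decompose_pure` builds its prompt with implicit string concatenation, and the sentinel answer signals a non-decomposable query. A sketch (the query is illustrative):

```python
query = "Is Hamlet more common on IMDB than Comedy of Errors?"
full_prompt = "prompt: " + decompose_prompt + "\n\n" "question: " + query
# The adjacent literals "\n\n" "question: " concatenate to "\n\nquestion: ".

# If the model answers with the sentinel below, the original query is
# returned as-is (per the docstring), so the result is [query].
sentinel = "the question needs no decomposition."
```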
169 changes: 169 additions & 0 deletions autorag/nodes/queryexpansion/run.py
@@ -0,0 +1,169 @@
import logging
import os
import pathlib
from typing import List, Callable, Dict, Optional
from copy import deepcopy

import pandas as pd

from autorag.nodes.retrieval.run import evaluate_retrieval_node
from autorag.strategy import measure_speed, filter_by_threshold, select_best_average
from autorag.utils.util import make_module_file_name, make_combinations, explode
from autorag.support import get_support_modules

logger = logging.getLogger("AutoRAG")


def run_query_expansion_node(modules: List[Callable],
module_params: List[Dict],
previous_result: pd.DataFrame,
node_line_dir: str,
strategies: Dict,
) -> pd.DataFrame:
"""
Run evaluation and select the best module among query expansion node results.
First, retrieval is run using the expanded queries produced by each query expansion module.
The retrieval step is run with every combination of the retrieval_modules given in strategies.
If there are multiple retrieval_modules, all of them are run and the best retrieval result is kept.
If no retrieval_modules are given, bm25 is used as the default.
In this way, the best retrieval result is selected for each query expansion module, and then
the best module is selected among those results.

:param modules: Query expansion modules to run.
:param module_params: Query expansion module parameters.
:param previous_result: Previous result dataframe.
In this case, it would be qa data.
:param node_line_dir: This node line's directory.
:param strategies: Strategies for query expansion node.
:return: The best result dataframe.
"""
if not os.path.exists(node_line_dir):
os.makedirs(node_line_dir)
node_dir = os.path.join(node_line_dir, "query_expansion")
if not os.path.exists(node_dir):
os.makedirs(node_dir)
project_dir = pathlib.PurePath(node_line_dir).parent.parent

# run query expansion
results, execution_times = zip(*map(lambda task: measure_speed(
task[0], project_dir=project_dir, previous_result=previous_result, **task[1]), zip(modules, module_params)))
average_times = list(map(lambda x: x / len(results[0]), execution_times))

# save results to folder
pseudo_module_params = deepcopy(module_params)
for i, module_param in enumerate(pseudo_module_params):
if 'prompt' in module_param:
module_param['prompt'] = str(i)
filepaths = list(map(lambda x: os.path.join(node_dir, make_module_file_name(x[0].__name__, x[1])),
zip(modules, pseudo_module_params)))
list(map(lambda x: x[0].to_parquet(x[1], index=False), zip(results, filepaths))) # execute save to parquet
filenames = list(map(lambda x: os.path.basename(x), filepaths))

# make summary file
summary_df = pd.DataFrame({
'filename': filenames,
'module_name': list(map(lambda module: module.__name__, modules)),
'module_params': module_params,
'execution_time': average_times,
})

# Run evaluation when there is more than one module.
if len(modules) > 1:
# pop general keys from strategies (e.g. metrics, speed_threshold)
general_key = ['metrics', 'speed_threshold']
general_strategy = dict(filter(lambda x: x[0] in general_key, strategies.items()))
extra_strategy = dict(filter(lambda x: x[0] not in general_key, strategies.items()))

# first, filter by threshold if it is enabled.
if general_strategy.get('speed_threshold') is not None:
results, filenames = filter_by_threshold(results, average_times, general_strategy['speed_threshold'],
filenames)

# check metrics in strategy
if general_strategy.get('metrics') is None:
raise ValueError("You must provide at least one metric for query expansion evaluation.")

if extra_strategy.get('top_k') is None:
extra_strategy['top_k'] = 10 # default value

# get retrieval modules from strategy
retrieval_callables, retrieval_params = make_retrieval_callable_params(extra_strategy)

# get retrieval_gt
retrieval_gt = pd.read_parquet(os.path.join(project_dir, "data", "qa.parquet"))['retrieval_gt'].tolist()

# run evaluation
evaluation_results = list(map(lambda result: evaluate_one_query_expansion_node(
retrieval_callables, retrieval_params, result['queries'].tolist(), retrieval_gt,
general_strategy['metrics'], project_dir, previous_result), results))

evaluation_df = pd.DataFrame({
'filename': filenames,
**{f'query_expansion_{metric_name}': list(map(lambda x: x[metric_name].mean(), evaluation_results))
for metric_name in general_strategy['metrics']}
})
summary_df = pd.merge(on='filename', left=summary_df, right=evaluation_df, how='left')

best_result, best_filename = select_best_average(evaluation_results, general_strategy['metrics'], filenames)
# change metric name columns to query_expansion_metric_name
best_result = best_result.rename(columns={
metric_name: f'query_expansion_{metric_name}' for metric_name in strategies['metrics']})
best_result = best_result.drop(columns=['retrieved_contents', 'retrieved_ids', 'retrieve_scores'])
else:
best_result, best_filename = results[0], filenames[0]
best_result = pd.concat([previous_result, best_result], axis=1)

# add 'is_best' column at summary file
summary_df['is_best'] = summary_df['filename'] == best_filename

# save files
summary_df.to_parquet(os.path.join(node_dir, "summary.parquet"), index=False)
best_result.to_parquet(os.path.join(node_dir, f"best_{os.path.splitext(best_filename)[0]}.parquet"), index=False)

return best_result


def evaluate_one_query_expansion_node(retrieval_funcs: List[Callable],
retrieval_params: List[Dict],
expanded_queries: List[List[str]],
retrieval_gt: List[List[str]],
metrics: List[str],
project_dir,
previous_result: pd.DataFrame) -> pd.DataFrame:
previous_result['queries'] = expanded_queries
retrieval_results = list(map(lambda x: x[0](project_dir=project_dir, previous_result=previous_result, **x[1]),
zip(retrieval_funcs, retrieval_params)))
evaluation_results = list(map(lambda x: evaluate_retrieval_node(x, retrieval_gt, metrics),
retrieval_results))
best_result, _ = select_best_average(evaluation_results, metrics)
best_result = pd.concat([previous_result, best_result], axis=1)
return best_result


def make_retrieval_callable_params(strategy_dict: Dict):
"""
strategy_dict looks like this:

.. Code:: json

{
"metrics": ["retrieval_f1", "retrieval_recall"],
"top_k": 50,
"retrieval_modules": [
{"module_type": "bm25"},
{"module_type": "vectordb", "embedding_model": ["openai", "huggingface"]}
]
}

"""
node_dict = deepcopy(strategy_dict)
retrieval_module_list: Optional[List[Dict]] = node_dict.pop('retrieval_modules', None)
if retrieval_module_list is None:
retrieval_module_list = [{
'module_type': 'bm25',
}]
node_params = node_dict
modules = list(map(lambda module_dict: get_support_modules(module_dict.pop('module_type')),
retrieval_module_list))
param_combinations = list(map(lambda module_dict: make_combinations({**module_dict, **node_params}),
retrieval_module_list))
return explode(modules, param_combinations)
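For the example strategy in the docstring (minus the general keys, which `run_query_expansion_node` pops beforehand), a sketch of the shape this returns; `make_combinations` cross-products list-valued parameters and `explode` flattens the module/param pairs:

```python
extra_strategy = {
    "top_k": 50,
    "retrieval_modules": [
        {"module_type": "bm25"},
        {"module_type": "vectordb", "embedding_model": ["openai", "huggingface"]},
    ],
}
modules, params = make_retrieval_callable_params(extra_strategy)
# modules ~ [bm25, vectordb, vectordb]
# params  ~ [{"top_k": 50},
#            {"top_k": 50, "embedding_model": "openai"},
#            {"top_k": 50, "embedding_model": "huggingface"}]
```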
4 changes: 4 additions & 0 deletions autorag/support.py
@@ -1,3 +1,4 @@

from typing import Callable, Dict
import importlib

@@ -14,6 +15,8 @@ def dynamically_find_function(key: str, target_dict: Dict) -> Callable:

def get_support_modules(module_name: str) -> Callable:
support_modules = {
'query_decompose': ('autorag.nodes.queryexpansion', 'query_decompose'),
'hyde': ('autorag.nodes.queryexpansion', 'hyde'),
'bm25': ('autorag.nodes.retrieval', 'bm25'),
'vectordb': ('autorag.nodes.retrieval', 'vectordb'),
'fstring': ('autorag.nodes.promptmaker', 'fstring'),
@@ -24,6 +27,7 @@ def get_support_modules(module_name: str) -> Callable:

def get_support_nodes(node_name: str) -> Callable:
support_nodes = {
'query_expansion': ('autorag.nodes.queryexpansion.run', 'run_query_expansion_node'),
'retrieval': ('autorag.nodes.retrieval.run', 'run_retrieval_node'),
'generator': ('autorag.nodes.generator.run', 'run_generator_node'),
'prompt_maker': ('autorag.nodes.promptmaker.run', 'run_prompt_maker_node'),
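Both registries resolve a YAML `module_type` / `node_type` string to a lazily imported callable. A minimal usage sketch:

```python
from autorag.support import get_support_modules, get_support_nodes

hyde_module = get_support_modules("hyde")        # autorag.nodes.queryexpansion.hyde
run_node = get_support_nodes("query_expansion")  # run_query_expansion_node from run.py
```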
27 changes: 27 additions & 0 deletions docs/source/api_spec/autorag.nodes.promptmaker.rst
@@ -1,6 +1,33 @@
autorag.nodes.promptmaker package
=================================

Submodules
----------

autorag.nodes.promptmaker.base module
-------------------------------------

.. automodule:: autorag.nodes.promptmaker.base
:members:
:undoc-members:
:show-inheritance:

autorag.nodes.promptmaker.fstring module
----------------------------------------

.. automodule:: autorag.nodes.promptmaker.fstring
:members:
:undoc-members:
:show-inheritance:

autorag.nodes.promptmaker.run module
------------------------------------

.. automodule:: autorag.nodes.promptmaker.run
:members:
:undoc-members:
:show-inheritance:

Module contents
---------------

16 changes: 16 additions & 0 deletions docs/source/api_spec/autorag.nodes.queryexpansion.rst
@@ -4,6 +4,14 @@ autorag.nodes.queryexpansion package
Submodules
----------

autorag.nodes.queryexpansion.base module
----------------------------------------

.. automodule:: autorag.nodes.queryexpansion.base
:members:
:undoc-members:
:show-inheritance:

autorag.nodes.queryexpansion.hyde module
----------------------------------------

@@ -20,6 +28,14 @@ autorag.nodes.queryexpansion.query\_decompose module
:undoc-members:
:show-inheritance:

autorag.nodes.queryexpansion.run module
---------------------------------------

.. automodule:: autorag.nodes.queryexpansion.run
:members:
:undoc-members:
:show-inheritance:

Module contents
---------------

8 changes: 8 additions & 0 deletions docs/source/api_spec/autorag.rst
@@ -48,6 +48,14 @@ autorag.strategy module
:undoc-members:
:show-inheritance:

autorag.support module
----------------------

.. automodule:: autorag.support
:members:
:undoc-members:
:show-inheritance:

Module contents
---------------

11 changes: 10 additions & 1 deletion tests/autorag/nodes/queryexpansion/test_hyde.py
@@ -1,12 +1,21 @@
from llama_index.llms.openai import OpenAI

from autorag.nodes.queryexpansion import hyde
from tests.autorag.nodes.queryexpansion.test_query_expansion_base import project_dir, previous_result, \
base_query_expansion_node_test, ingested_vectordb_node

sample_query = ["How many members are in Newjeans?", "What is visconde structure?"]


def test_hyde():
llm = OpenAI(max_tokens=64)
result = hyde(sample_query, llm)
original_hyde = hyde.__wrapped__
result = original_hyde(sample_query, llm, prompt="")
assert len(result[0]) == 1
assert len(result) == 2


def test_hyde_node(ingested_vectordb_node):
result_df = hyde(project_dir=project_dir, previous_result=previous_result,
llm="openai", max_tokens=64)
base_query_expansion_node_test(result_df)
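`test_hyde` unwraps the decorator but still instantiates a real `OpenAI` client; commit c488c0d mentions a mock LLM for exactly this. One hypothetical way to stub it (this `MockLLM` is illustrative, not the repository's actual helper):

```python
from types import SimpleNamespace

class MockLLM:
    """Hypothetical stub mimicking the complete() interface hyde_pure relies on."""
    def complete(self, prompt: str):
        return SimpleNamespace(text="A short passage answering the question.")

def test_hyde_mock():
    result = hyde.__wrapped__(sample_query, MockLLM(), prompt="")
    assert len(result) == 2      # one entry per input query
    assert len(result[0]) == 1   # hyde yields a single passage per query
```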