Skip to content

Commit

Permalink
Delete ids and scores from module_params in the summary, and add target_modules and target_module_params for deploy
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffrey committed Feb 3, 2024
1 parent 293b8f7 commit 25978e8
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 1 deletion.
25 changes: 24 additions & 1 deletion autorag/nodes/retrieval/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,15 @@ def run_and_save(input_modules, input_module_params, filename_start: int):
target_modules = list(map(lambda x: x.pop('target_modules'), hybrid_module_params))
target_filenames = list(map(lambda x: select_result_for_hybrid(save_dir, x), target_modules))
ids_scores = list(map(lambda x: get_ids_and_scores(save_dir, x), target_filenames))
target_module_params = list(map(lambda x: get_module_params(save_dir, x), target_filenames))
hybrid_module_params = list(map(lambda x: {**x[0], **x[1]}, zip(hybrid_module_params, ids_scores)))
real_hybrid_times = list(map(lambda filename: get_hybrid_execution_times(save_dir, filename), target_filenames))
hybrid_results, hybrid_times, hybrid_summary_df = run_and_save(hybrid_modules, hybrid_module_params, filename_first)
hybrid_results, hybrid_times, hybrid_summary_df = run_and_save(hybrid_modules, hybrid_module_params,
filename_first)
filename_first += len(hybrid_modules)
hybrid_times = real_hybrid_times.copy()
hybrid_summary_df['execution_time'] = hybrid_times
hybrid_summary_df = edit_summary_df_params(hybrid_summary_df, target_modules, target_module_params)
else:
hybrid_results, hybrid_times, hybrid_summary_df = [], [], pd.DataFrame()

Expand Down Expand Up @@ -158,6 +161,26 @@ def select_best_among_module(df: pd.DataFrame, module_name: str):
return best_filenames


def get_module_params(node_dir: str, filenames: List[str]) -> List[Dict]:
    """Look up the ``module_params`` recorded in a node's summary for each result file.

    The returned list follows the order of *filenames* — not the row order of
    ``summary.csv`` — so it stays aligned with other per-filename lists the
    caller builds (e.g. ``target_modules`` and ``ids_scores``). A plain
    ``isin`` filter would instead yield summary-row order and could pair
    params with the wrong hybrid entry.

    :param node_dir: Directory of the node that contains ``summary.csv``.
    :param filenames: Result filenames to look up; each must appear in the summary.
    :return: ``module_params`` dicts, one per filename, in the same order.
    :raises KeyError: If a filename is missing from the summary.
    """
    summary_df = load_summary_file(os.path.join(node_dir, "summary.csv"))
    params_by_filename = summary_df.set_index('filename')['module_params']
    return [params_by_filename[filename] for filename in filenames]


def edit_summary_df_params(summary_df: pd.DataFrame, target_modules, target_module_params) -> pd.DataFrame:
    """Rewrite hybrid rows' ``module_params`` for saving to the summary file.

    Removes the bulky runtime-only ``ids`` and ``scores`` entries and records
    ``target_modules`` / ``target_module_params`` instead, which is what a
    later deploy step needs to reconstruct the hybrid module.

    :param summary_df: Hybrid summary dataframe with a ``module_params`` column of dicts.
    :param target_modules: One tuple/list of target module names per row.
    :param target_module_params: One list of target-module param dicts per row.
    :return: The same dataframe with ``module_params`` rewritten in place.
    """
    def _strip_ids_scores(params: Dict) -> Dict:
        # Build a NEW dict rather than `del`-ing keys in place: the original
        # dict objects are shared with hybrid_module_params in the caller, and
        # in-place deletion would corrupt them. Filtering (instead of `del`)
        # also tolerates rows where 'ids'/'scores' are absent.
        return {key: value for key, value in params.items()
                if key not in ('ids', 'scores')}

    deploy_params = [{'target_modules': modules, 'target_module_params': params}
                     for modules, params in zip(target_modules, target_module_params)]
    # Merge per-row, letting the deploy keys win; no temporary column needed.
    summary_df['module_params'] = [
        {**_strip_ids_scores(original), **extra}
        for original, extra in zip(summary_df['module_params'], deploy_params)
    ]
    return summary_df


def get_ids_and_scores(node_dir: str, filenames: List[str]) -> Dict:
best_results_df = list(map(lambda filename: pd.read_parquet(os.path.join(node_dir, filename)), filenames))
ids = tuple(map(lambda df: df['retrieved_ids'].apply(list).tolist(), best_results_df))
Expand Down
10 changes: 10 additions & 0 deletions tests/autorag/nodes/retrieval/test_run_retrieval_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,16 @@ def test_run_retrieval_node(node_line_dir):
assert summary_df['filename'].nunique() == len(summary_df)
assert len(summary_df[summary_df['is_best'] == True]) == 1

# test summary_df hybrid retrieval convert well
assert all(summary_df['module_params'].apply(lambda x: 'ids' not in x))
assert all(summary_df['module_params'].apply(lambda x: 'scores' not in x))
hybrid_summary_df = summary_df[summary_df['module_name'].str.contains('hybrid')]
assert all(hybrid_summary_df['module_params'].apply(lambda x: 'target_modules' in x))
assert all(hybrid_summary_df['module_params'].apply(lambda x: 'target_module_params' in x))
assert all(hybrid_summary_df['module_params'].apply(lambda x: x['target_modules'] == ('bm25', 'vectordb')))
assert all(hybrid_summary_df['module_params'].apply(
lambda x: x['target_module_params'] == [{'top_k': 4}, {'top_k': 4, 'embedding_model': 'openai'}]))

# test the best file is saved properly
best_filename = summary_df[summary_df['is_best'] == True]['filename'].values[0]
best_path = os.path.join(node_line_dir, "retrieval", f'best_{best_filename}')
Expand Down

0 comments on commit 25978e8

Please sign in to comment.