diff --git a/cdqa/utils/converter.py b/cdqa/utils/converter.py index 9e78b82..888ac20 100644 --- a/cdqa/utils/converter.py +++ b/cdqa/utils/converter.py @@ -3,61 +3,62 @@ from tqdm import tqdm import uuid + def df2squad(df, squad_version='v2.0', output_dir=None, filename=None): - """ - Converts a pandas dataframe with columns ['title', 'content'] to a json file with SQuAD format. + """ + Converts a pandas dataframe with columns ['title', 'content'] to a json file with SQuAD format. - Parameters - ---------- - df : pandas.DataFrame - a pandas dataframe with columns ['title', 'content'] - squad_version : str, optional - the SQuAD dataset version format (the default is 'v2.0') - output_dir : str, optional - Enable export of output (the default is None) - filename : str, optional - [description] + Parameters + ---------- + df : pandas.DataFrame + a pandas dataframe with columns ['title', 'content'] + squad_version : str, optional + the SQuAD dataset version format (the default is 'v2.0') + output_dir : str, optional + Enable export of output (the default is None) + filename : str, optional + [description] - Returns - ------- - json_data: dict - A json object with SQuAD format + Returns + ------- + json_data: dict + A json object with SQuAD format - Examples - -------- - >>> from ast import literal_eval - >>> import pandas as pd - >>> from cdqa.utils.converter import df2squad, filter_paragraphs + Examples + -------- + >>> from ast import literal_eval + >>> import pandas as pd + >>> from cdqa.utils.converter import df2squad, filter_paragraphs - >>> df = pd.read_csv('../data/bnpp_newsroom_v1.0/bnpp_newsroom-v1.0.csv', converters={'paragraphs': literal_eval}) - >>> df['paragraphs'] = df['paragraphs'].apply(filter_paragraphs) + >>> df = pd.read_csv('../data/bnpp_newsroom_v1.0/bnpp_newsroom-v1.0.csv', converters={'paragraphs': literal_eval}) + >>> df['paragraphs'] = df['paragraphs'].apply(filter_paragraphs) - >>> json_data = df2squad(df=df, squad_version='v2.0', output_dir='../data', filename='bnpp_newsroom-v1.0') - """ + >>> json_data = df2squad(df=df, squad_version='v2.0', output_dir='../data', filename='bnpp_newsroom-v1.0') + """ + json_data = {} + json_data['version'] = squad_version + json_data['data'] = [] - json_data = {} - json_data['version'] = squad_version - json_data['data'] = [] + for index, row in tqdm(df.iterrows()): + temp = {'title': row['title'], + 'paragraphs': []} + for paragraph in row['paragraphs']: + temp['paragraphs'].append({'context': paragraph, + 'qas': []}) + json_data['data'].append(temp) - for index, row in tqdm(df.iterrows()): - temp = {'title': row['title'], - 'paragraphs': []} - for paragraph in row['paragraphs']: - temp['paragraphs'].append({'context': paragraph, - 'qas': []}) - json_data['data'].append(temp) + if output_dir: + with open(os.path.join(output_dir, '{}.json'.format(filename)), 'w') as outfile: + json.dump(json_data, outfile) - if output_dir: - with open(os.path.join(output_dir, '{}.json'.format(filename)), 'w') as outfile: - json.dump(json_data, outfile) + return json_data - return json_data def generate_squad_examples(question, closest_docs_indices, metadata): """ Creates a SQuAD examples json object for a given for a given question using outputs of retriever and document database. - + Parameters ---------- question : [type] @@ -66,7 +67,7 @@ def generate_squad_examples(question, closest_docs_indices, metadata): [description] metadata : [type] [description] - + Returns ------- squad_examples: list @@ -81,35 +82,35 @@ def generate_squad_examples(question, closest_docs_indices, metadata): """ - squad_examples = [] - + metadata_sliced = metadata.loc[closest_docs_indices] - + for index, row in tqdm(metadata_sliced.iterrows()): temp = {'title': row['title'], - 'paragraphs': []} - + 'paragraphs': []} + for paragraph in row['paragraphs']: temp['paragraphs'].append({'context': paragraph, 'qas': [{'answers': [], 'question': question, 'id': str(uuid.uuid1())}] - }) + }) - squad_examples.append(temp) + squad_examples.append(temp) return squad_examples -def filter_paragraphs(paragraphs): + +def filter_paragraphs(paragraphs, min_length=10, max_length=250): """ Filters out paragraphs shorter than X words and longer than Y words - + Parameters ---------- paragraphs : [type] [description] - + Returns ------- [type] @@ -126,5 +127,6 @@ def filter_paragraphs(paragraphs): """ - paragraphs_filtered = [paragraph for paragraph in paragraphs if len(paragraph.split()) >= 10 and len(paragraph.split()) <= 250] + paragraphs_filtered = [paragraph for paragraph in paragraphs if len( + paragraph.split()) >= min_length and len(paragraph.split()) <= max_length] return paragraphs_filtered