Skip to content

Commit

Permalink
Implemented fit_reader() method and fixed fit() method. (#131)
Browse files Browse the repository at this point in the history
* replaced self.model by self.reader

* Implemented fit_reader(), fixed fit() and updated doc
  • Loading branch information
andrelmfarias authored and fmikaelian committed May 14, 2019
1 parent 29da9b6 commit 4ce5502
Showing 1 changed file with 43 additions and 33 deletions.
76 changes: 43 additions & 33 deletions cdqa/pipeline/cdqa_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,32 +16,34 @@ class QAPipeline(BaseEstimator):
Parameters
----------
metadata : pandas.DataFrame
metadata: pandas.DataFrame
dataframe containing your corpus of documents metadata
header should be of format: date, title, category, link, abstract, paragraphs, content.
model : str or .joblib object of a version of BERT model with sklearn wrapper, optional
bert_version : str
reader: str (path to .joblib) or .joblib object of an instance of BertQA (BERT model with sklearn wrapper), optional
bert_version: str
Bert pre-trained model selected in the list: bert-base-uncased,
bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased,
bert-base-multilingual-cased, bert-base-chinese.
kwargs: kwargs for BertQA(), BertProcessor() and TfidfRetriever()
Please check documentation for these classes
Examples
--------
>>> from cdqa.pipeline.qa_pipeline import QAPipeline
>>> qa_pipe = QAPipeline(model='bert_qa_squad_vCPU-sklearn.joblib', metadata=df)
>>> qa_pipe.fit()
>>> prediction = qa_pipe.predict(X='When BNP Paribas was created?')
>>> qa_pipeline = QAPipeline(reader='bert_qa_squad_vCPU-sklearn.joblib')
>>> qa_pipeline.fit(X=df)
>>> prediction = qa_pipeline.predict(X='When BNP Paribas was created?')
>>> from cdqa.pipeline.qa_pipeline import QAPipeline
>>> qa_pipe = QAPipeline(metadata=df)
>>> qa_pipe.fit('train-v1.1.json', fit_reader=True)
>>> qa_pipe.fit()
>>> prediction = qa_pipe.predict(X='When BNP Paribas was created?')
>>> qa_pipeline = QAPipeline()
>>> qa_pipeline.fit_reader('train-v1.1.json')
>>> qa_pipeline.fit(X=df)
>>> prediction = qa_pipeline.predict(X='When BNP Paribas was created?')
"""

def __init__(self, model=None, **kwargs):
def __init__(self, reader=None, **kwargs):

# Separating kwargs
kwargs_bertqa = {key: value for key, value in kwargs.items()
Expand All @@ -53,12 +55,12 @@ def __init__(self, model=None, **kwargs):
kwargs_retriever = {key: value for key, value in kwargs.items()
if key in TfidfRetriever.__init__.__code__.co_varnames}

if not model:
self.model = BertQA(**kwargs_bertqa)
elif type(model) == str:
self.model = joblib.load(model)
if not reader:
self.reader = BertQA(**kwargs_bertqa)
elif type(reader) == str:
self.reader = joblib.load(reader)
else:
self.model = model
self.reader = reader

self.processor_train = BertProcessor(is_training=True,
**kwargs_processor)
Expand All @@ -68,42 +70,50 @@ def __init__(self, model=None, **kwargs):

self.retriever = TfidfRetriever(**kwargs_retriever)

def fit(self, X=None, y=None, fit_reader=False):
""" Fit the QAPipeline retriever to a list of documents in a dataframe if fit_reader is false,
fit the reader (QABert model) to a json file squad-like with questions and answers
def fit(self, X=None, y=None):
""" Fit the QAPipeline retriever to a list of documents in a dataframe.
Parameters
----------
X: dict or str
Dictionaire with questions and answers in SQUAD format or path to json file in SQUAD format
fit_reader: boolean, default false
Whether to fit reader (BertQA model) or retriever
X: pandas.Dataframe
Dataframe with the following columns: "title", "paragraphs" and "content"
"""

self.metadata = X
self.retriever.fit(self.metadata['content'])

if not fit_reader:
self.retriever.fit(self.metadata['content'])
else:
if not X:
raise RuntimeError(
'fit_reader is True, please pass a json file in SQUAD format as input')
train_examples, train_features = self.processor_train.fit_transform(X)
self.model.fit(X=(train_examples, train_features))
return self

def fit_reader(self, X=None, y=None):
"""Train the reader (BertQA instance) of QAPipeline object
Parameters
----------
X = path to json file in SQUAD format
"""

train_examples, train_features = self.processor_train.fit_transform(X)
self.reader.fit(X=(train_examples, train_features))

return self

def predict(self, X):
def predict(self, X=None):
""" Compute prediction of an answer to a question
Parameters
----------
X = str
Sample (question) to perform a prediction on
"""

closest_docs_indices = self.retriever.predict(X, metadata=self.metadata)
squad_examples = generate_squad_examples(question=X,
closest_docs_indices=closest_docs_indices,
metadata=self.metadata)
examples, features = self.processor_predict.fit_transform(X=squad_examples)
prediction = self.model.predict((examples, features))
prediction = self.reader.predict((examples, features))

return prediction

0 comments on commit 4ce5502

Please sign in to comment.