diff --git a/cdqa/pipeline/cdqa_sklearn.py b/cdqa/pipeline/cdqa_sklearn.py index a409e92..cdf34b6 100644 --- a/cdqa/pipeline/cdqa_sklearn.py +++ b/cdqa/pipeline/cdqa_sklearn.py @@ -16,32 +16,34 @@ class QAPipeline(BaseEstimator): Parameters ---------- - metadata : pandas.DataFrame + metadata: pandas.DataFrame dataframe containing your corpus of documents metadata header should be of format: date, title, category, link, abstract, paragraphs, content. - model : str or .joblib object of a version of BERT model with sklearn wrapper, optional - bert_version : str + reader: str (path to .joblib) or .joblib object of an instance of BertQA (BERT model with sklearn wrapper), optional + bert_version: str Bert pre-trained model selected in the list: bert-base-uncased, bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base-multilingual-cased, bert-base-chinese. + kwargs: kwargs for BertQA(), BertProcessor() and TfidfRetriever() + Please check documentation for these classes Examples -------- >>> from cdqa.pipeline.qa_pipeline import QAPipeline - >>> qa_pipe = QAPipeline(model='bert_qa_squad_vCPU-sklearn.joblib', metadata=df) - >>> qa_pipe.fit() - >>> prediction = qa_pipe.predict(X='When BNP Paribas was created?') + >>> qa_pipeline = QAPipeline(reader='bert_qa_squad_vCPU-sklearn.joblib') + >>> qa_pipeline.fit(X=df) + >>> prediction = qa_pipeline.predict(X='When BNP Paribas was created?') >>> from cdqa.pipeline.qa_pipeline import QAPipeline - >>> qa_pipe = QAPipeline(metadata=df) - >>> qa_pipe.fit('train-v1.1.json', fit_reader=True) - >>> qa_pipe.fit() - >>> prediction = qa_pipe.predict(X='When BNP Paribas was created?') + >>> qa_pipeline = QAPipeline() + >>> qa_pipeline.fit_reader('train-v1.1.json') + >>> qa_pipeline.fit(X=df) + >>> prediction = qa_pipeline.predict(X='When BNP Paribas was created?') """ - def __init__(self, model=None, **kwargs): + def __init__(self, reader=None, **kwargs): # Separating kwargs kwargs_bertqa = {key: value for key, value in kwargs.items() @@ -53,12 +55,12 @@ def __init__(self, model=None, **kwargs): kwargs_retriever = {key: value for key, value in kwargs.items() if key in TfidfRetriever.__init__.__code__.co_varnames} - if not model: - self.model = BertQA(**kwargs_bertqa) - elif type(model) == str: - self.model = joblib.load(model) + if not reader: + self.reader = BertQA(**kwargs_bertqa) + elif type(reader) == str: + self.reader = joblib.load(reader) else: - self.model = model + self.reader = reader self.processor_train = BertProcessor(is_training=True, **kwargs_processor) @@ -68,35 +70,43 @@ def __init__(self, model=None, **kwargs): self.retriever = TfidfRetriever(**kwargs_retriever) - def fit(self, X=None, y=None, fit_reader=False): - """ Fit the QAPipeline retriever to a list of documents in a dataframe if fit_reader is false, - fit the reader (QABert model) to a json file squad-like with questions and answers + def fit(self, X=None, y=None): + """ Fit the QAPipeline retriever to a list of documents in a dataframe. Parameters ---------- - X: dict or str - Dictionaire with questions and answers in SQUAD format or path to json file in SQUAD format - fit_reader: boolean, default false - Whether to fit reader (BertQA model) or retriever + X: pandas.Dataframe + Dataframe with the following columns: "title", "paragraphs" and "content" """ self.metadata = X + self.retriever.fit(self.metadata['content']) - if not fit_reader: - self.retriever.fit(self.metadata['content']) - else: - if not X: - raise RuntimeError( - 'fit_reader is True, please pass a json file in SQUAD format as input') - train_examples, train_features = self.processor_train.fit_transform(X) - self.model.fit(X=(train_examples, train_features)) + return self + + def fit_reader(self, X=None, y=None): + """Train the reader (BertQA instance) of QAPipeline object + + Parameters + ---------- + X = path to json file in SQUAD format + + """ + + train_examples, train_features = self.processor_train.fit_transform(X) + self.reader.fit(X=(train_examples, train_features)) return self - def predict(self, X): + def predict(self, X=None): """ Compute prediction of an answer to a question + Parameters + ---------- + X = str + Sample (question) to perform a prediction on + """ closest_docs_indices = self.retriever.predict(X, metadata=self.metadata) @@ -104,6 +114,6 @@ def predict(self, X): closest_docs_indices=closest_docs_indices, metadata=self.metadata) examples, features = self.processor_predict.fit_transform(X=squad_examples) - prediction = self.model.predict((examples, features)) + prediction = self.reader.predict((examples, features)) return prediction