diff --git a/model/scikit/dffml_model_scikit/scikit_base.py b/model/scikit/dffml_model_scikit/scikit_base.py index 37d2a48fa7..f23bead574 100644 --- a/model/scikit/dffml_model_scikit/scikit_base.py +++ b/model/scikit/dffml_model_scikit/scikit_base.py @@ -74,6 +74,7 @@ async def __aenter__(self) -> "Scikit": self.clf_path = self._filepath / "ScikitFeatures.joblib" if self.clf_path.is_file(): self.clf = self.joblib.load(str(self.clf_path)) + self.is_trained = True return self async def __aexit__(self, exc_type, exc_value, traceback): @@ -92,6 +93,7 @@ async def __aenter__(self) -> "ScikitUnsprvised": self.clf_path = self._filepath / "ScikitUnsupervised.json" if self.clf_path.is_file(): self.clf = self.joblib.load(str(self.clf_path)) + self.is_trained = True return self @@ -180,11 +182,12 @@ async def train(self, sources: Sources): "Model does not support multi-output. Please refer the docs to find a suitable model entrypoint." ) self.parent.clf.fit(xdata, ydata) + self.is_trained = True async def predict( self, sources: SourcesContext ) -> AsyncIterator[Tuple[Record, Any, float]]: - if not self.parent.clf_path.is_file(): + if not self.is_trained: raise ModelNotTrained("Train model before prediction.") async for record in sources.with_features(self.features): record_data = [] @@ -239,11 +242,12 @@ async def train(self, sources: Sources): xdata = self.np.array(xdata) self.logger.info("Number of input records: {}".format(len(xdata))) self.parent.clf.fit(xdata) + self.is_trained = True async def predict( self, sources: SourcesContext ) -> AsyncIterator[Tuple[Record, Any, float]]: - if not self.parent.clf_path.is_file(): + if not self.is_trained: raise ModelNotTrained("Train model before prediction.") estimator_type = self.parent.clf._estimator_type if estimator_type == "clusterer": diff --git a/model/scikit/dffml_model_scikit/scikit_model_scorer.py b/model/scikit/dffml_model_scikit/scikit_model_scorer.py index 029e46da4e..9db4e31acd 100644 --- a/model/scikit/dffml_model_scikit/scikit_model_scorer.py +++ b/model/scikit/dffml_model_scikit/scikit_model_scorer.py @@ -26,7 +26,7 @@ class SklearnModelAccuracyContext(AccuracyContext): async def score( self, mctx: ModelContext, sctx: SourcesContext, *features: Feature, ): - if not mctx.parent.clf_path.is_file(): + if not mctx.is_trained: raise ModelNotTrained("Train model before assessing for accuracy.") if mctx.parent.clf._estimator_type not in ("classifier", "regressor"):