Skip to content
This repository has been archived by the owner on Aug 25, 2024. It is now read-only.

Commit

Permalink
model: scikit: Updated to include is_trained flag
Browse files Browse the repository at this point in the history
  • Loading branch information
programmer290399 authored and pdxjohnny committed Nov 9, 2021
1 parent dae433a commit 5635dc8
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
8 changes: 6 additions & 2 deletions model/scikit/dffml_model_scikit/scikit_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ async def __aenter__(self) -> "Scikit":
self.clf_path = self._filepath / "ScikitFeatures.joblib"
if self.clf_path.is_file():
self.clf = self.joblib.load(str(self.clf_path))
self.is_trained = True
return self

async def __aexit__(self, exc_type, exc_value, traceback):
Expand All @@ -92,6 +93,7 @@ async def __aenter__(self) -> "ScikitUnsprvised":
self.clf_path = self._filepath / "ScikitUnsupervised.json"
if self.clf_path.is_file():
self.clf = self.joblib.load(str(self.clf_path))
self.is_trained = True

return self

Expand Down Expand Up @@ -180,11 +182,12 @@ async def train(self, sources: Sources):
"Model does not support multi-output. Please refer the docs to find a suitable model entrypoint."
)
self.parent.clf.fit(xdata, ydata)
self.is_trained = True

async def predict(
self, sources: SourcesContext
) -> AsyncIterator[Tuple[Record, Any, float]]:
if not self.parent.clf_path.is_file():
if not self.is_trained:
raise ModelNotTrained("Train model before prediction.")
async for record in sources.with_features(self.features):
record_data = []
Expand Down Expand Up @@ -239,11 +242,12 @@ async def train(self, sources: Sources):
xdata = self.np.array(xdata)
self.logger.info("Number of input records: {}".format(len(xdata)))
self.parent.clf.fit(xdata)
self.is_trained = True

async def predict(
self, sources: SourcesContext
) -> AsyncIterator[Tuple[Record, Any, float]]:
if not self.parent.clf_path.is_file():
if not self.is_trained:
raise ModelNotTrained("Train model before prediction.")
estimator_type = self.parent.clf._estimator_type
if estimator_type == "clusterer":
Expand Down
2 changes: 1 addition & 1 deletion model/scikit/dffml_model_scikit/scikit_model_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class SklearnModelAccuracyContext(AccuracyContext):
async def score(
self, mctx: ModelContext, sctx: SourcesContext, *features: Feature,
):
if not mctx.parent.clf_path.is_file():
if not mctx.is_trained:
raise ModelNotTrained("Train model before assessing for accuracy.")

if mctx.parent.clf._estimator_type not in ("classifier", "regressor"):
Expand Down

0 comments on commit 5635dc8

Please sign in to comment.