Skip to content
This repository has been archived by the owner on Aug 25, 2024. It is now read-only.

Commit

Permalink
model: pytorch: Updated to include is_trained flag
Browse files Browse the repository at this point in the history
  • Loading branch information
programmer290399 authored and pdxjohnny committed Nov 9, 2021
1 parent 8554f09 commit dae433a
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class PytorchAccuracyContext(AccuracyContext):
async def score(
self, mctx: ModelContext, sctx: Sources, *features: Features
):
if not mctx.parent.model_path.exists():
if not mctx.is_trained:
raise ModelNotTrained("Train model before assessing for accuracy.")

dataset, size = await mctx.dataset_generator(sctx)
Expand Down
4 changes: 3 additions & 1 deletion model/pytorch/dffml_model_pytorch/pytorch_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,14 +298,15 @@ async def train(self, sources: Sources):
"Best Validation Accuracy: {:4f}".format(best_acc)
)
self.parent.model.load_state_dict(best_model_wts)
self.is_trained = True

async def predict(
self, sources: SourcesContext
) -> AsyncIterator[Tuple[Record, Any, float]]:
"""
Uses trained data to make a prediction about the quality of a record.
"""
if not self.parent.model_path.exists():
if not self.is_trained:
raise ModelNotTrained("Train model before prediction.")

self.parent.model.eval()
Expand Down Expand Up @@ -399,6 +400,7 @@ async def __aenter__(self) -> "PyTorchModel":
if self.model_path.exists():
self.logger.info(f"Using saved model from {self.model_path}")
self.model = torch.load(self.model_path)
self.is_trained = True
else:
self.model = self.createModel()
return self
Expand Down

0 comments on commit dae433a

Please sign in to comment.