Skip to content

Commit

Permalink
fix: updating save_load notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
hmgomes committed May 30, 2024
1 parent eaad513 commit c4e11b0
Showing 1 changed file with 100 additions and 44 deletions.
144 changes: 100 additions & 44 deletions notebooks/save_and_load_model.ipynb
Original file line number Diff line number Diff line change
@@ -1,106 +1,162 @@
{
"cells": [
{
"metadata": {},
"cell_type": "markdown",
"id": "3b115501e7927be",
"metadata": {},
"source": [
"# Save and Load a Model\n",
"\n",
"In this tutorial, we illustrate the process of saving and loading a model using CapyMOA. The ElectricityTiny dataset serves as our data source, and we employ the AdaptiveRandomForestClassifier for model training. The trained model is then saved to a file, specifically 'capymoa_model.pkl'. Subsequently, we reload the model from this file and evaluate its performance on the ElectricityTiny dataset once more. As a final step, we delete the model file."
],
"id": "3b115501e7927be"
"In this tutorial, we illustrate the process of saving and loading a model using CapyMOA. \n",
"\n",
"* We use the SEA synthetic generator as the data source, and the AdaptiveRandomForestClassifier as the learner.\n",
"* The trained model is saved to a file, specifically 'capymoa_model.pkl'.\n",
"* Subsequently, we reload the model from the file and resume training and evaluating its performance on the SEA data.\n",
"* As a final step, we delete the model file."
]
},
{
"cell_type": "markdown",
"id": "b4f5698f-c632-44ef-b337-0549dc5a5168",
"metadata": {},
"source": [
"## 1. Training and saving the model\n",
"\n",
"* We train the model on 5k instances from SEA using the `evaluate_prequential` function\n",
"* We proceed to save the model with `save_model(learner, \"capymoa_ARF_model.pkl\")`"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "b7ca1c5addd95ba3",
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-29T08:18:33.715465Z",
"start_time": "2024-05-29T08:18:27.317959Z"
}
},
"source": [
"from capymoa.classifier import AdaptiveRandomForestClassifier\n",
"from capymoa.evaluation import ClassificationEvaluator\n",
"from capymoa.datasets import ElectricityTiny\n",
"from capymoa.misc import save_model, load_model\n",
"\n",
"stream = ElectricityTiny()\n",
"schema = stream.get_schema()\n",
"learner = AdaptiveRandomForestClassifier(schema)\n",
"evaluator = ClassificationEvaluator(schema)\n",
"while stream.has_more_instances():\n",
" instance = stream.next_instance()\n",
" score = learner.predict(instance)\n",
" evaluator.update(instance.y_index, score)\n",
" learner.train(instance)\n",
" \n",
"acc = evaluator.accuracy()\n",
"print(f\"ACC: {acc:.2f}\")\n",
"save_model(learner, \"capymoa_model.pkl\") # Save model to capymoa_model.pkl"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ACC: 89.00\n"
"Accuracy: 87.96\n"
]
}
],
"execution_count": 1
"source": [
"from capymoa.classifier import AdaptiveRandomForestClassifier\n",
"from capymoa.evaluation import prequential_evaluation\n",
"from capymoa.stream.generator import SEA\n",
"from capymoa.misc import save_model, load_model\n",
"\n",
"stream = SEA()\n",
"learner = AdaptiveRandomForestClassifier(schema=stream.get_schema(), ensemble_size=10)\n",
"\n",
"results = prequential_evaluation(stream=stream, learner=learner, max_instances=5000)\n",
"\n",
"print(f\"Accuracy: {results['cumulative'].accuracy():.2f}\")\n",
"save_model(learner, \"capymoa_ARF_model.pkl\") # Save model to capymoa_model.pkl"
]
},
{
"cell_type": "markdown",
"id": "934db64f-b4ca-4628-9ad0-9a8d669a2c6b",
"metadata": {},
"source": [
"## 2. Loading and resuming training\n",
"\n",
"* We use `os.path.getsize()` to inspect the size (KB) of the saved file.\n",
"* We don't restart the synthetic stream, we just continue processing it through another call to `prequential_evaluation`\n",
"* Finally, we observe the accuracy "
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "2f7dd29e2ed65686",
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-29T08:18:37.737826Z",
"start_time": "2024-05-29T08:18:33.717028Z"
}
},
"source": [
"restored_learner = load_model(\"capymoa_model.pkl\") # Load model from capymoa_model.pkl\n",
"stream = ElectricityTiny()\n",
"schema = stream.get_schema()\n",
"evaluator = ClassificationEvaluator(schema)\n",
"while stream.has_more_instances():\n",
" instance = stream.next_instance()\n",
" score = restored_learner.predict(instance)\n",
" evaluator.update(instance.y_index, score)\n",
" learner.train(instance)\n",
"print(f\"ACC: {acc:.2f}\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ACC: 89.00\n"
"The saved model size: 616.66 KB\n",
"Updated accuracy: 89.32\n"
]
}
],
"execution_count": 2
"source": [
"import os\n",
"\n",
"model_file = 'capymoa_ARF_model.pkl'\n",
"\n",
"model_size = os.path.getsize(model_file)\n",
"print(f\"The saved model size: {model_size / 1024:.2f} KB\")\n",
"\n",
"restored_learner = load_model(\"capymoa_ARF_model.pkl\") # Load model from capymoa_model.pkl\n",
"\n",
"# Train for more 50k instances on the restored model\n",
"results = prequential_evaluation(stream=stream, learner=restored_learner, max_instances=5000)\n",
" \n",
"print(f\"Updated accuracy: {results['cumulative'].accuracy():.2f}\")"
]
},
{
"cell_type": "markdown",
"id": "d86a6b63-4a4d-4753-bda5-19a632aad41d",
"metadata": {},
"source": [
"## 3. Cleanup \n",
"\n",
"* As a last step, we delete the model"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "e00a292713d154ee",
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-29T08:18:37.977906Z",
"start_time": "2024-05-29T08:18:37.739209Z"
}
},
"source": [
"!rm capymoa_model.pkl # Remove the model file"
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"File capymoa_ARF_model.pkl has been deleted.\n"
]
}
],
"source": [
"if os.path.exists(model_file):\n",
" os.remove(model_file)\n",
" print(f\"File {model_file} has been deleted.\")\n",
"else:\n",
" print(f\"File {model_file} not found.\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "59c7b60f-908e-4640-a6b8-50650d6c9287",
"metadata": {},
"outputs": [],
"execution_count": 3
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand Down

0 comments on commit c4e11b0

Please sign in to comment.