Skip to content

Commit

Permalink
feat(ci): add GitHub Action for Python tests, fix EarlyStopping logic…
Browse files Browse the repository at this point in the history
…, and add unit tests (#27)

This PR introduces several key changes:
1. A GitHub Actions workflow for continuous integration, running tests
across multiple Python versions.
2. Fixes to the `EarlyStopping` class logic, changing `best_score` to
`best_val_loss` to improve clarity and correctness.
3. A new test suite for the `EarlyStopping` class to ensure its correct
behavior after the logic fix.

### Changes:
- **Added Python application tests**:
- A new workflow named "Python Application Tests" runs on GitHub
Actions.
- Tests are executed across multiple Python versions (`3.9`, `3.10`,
`3.11`, `3.12`).
- The workflow is triggered on pushes and pull requests to the `main`
branch.
  - Pip dependencies are cached to optimize performance.

- **Fixed EarlyStopping logic**:
- The variable `best_score` was renamed to `best_val_loss` for clarity,
improving the code's readability and matching the purpose of the
variable.
- Logic for early stopping was updated to properly handle edge cases for
delta and stopping conditions based on validation loss behavior.

- **Added Unit Tests for `EarlyStopping`**:
- The file `test_early_stopping.py` contains a thorough set of tests
covering:
    - Initialization and attribute checks.
    - Behavior when validation loss improves.
    - Handling when validation loss does not improve.
    - Proper functionality of patience and delta parameters.
    - Edge cases for early stopping triggers and verbose output.

### Why this change is important:
- Automates testing across different Python versions to ensure
compatibility and reliability.
- Fixes logic issues in the `EarlyStopping` class, making it more robust
and clear.
- Ensures the `EarlyStopping` class functions as expected through the
added unit tests.
- Improves development velocity by automatically running tests on every
push and pull request.
  • Loading branch information
Bjarten authored Oct 15, 2024
1 parent 686db54 commit ffe12ee
Show file tree
Hide file tree
Showing 8 changed files with 500 additions and 107 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install python-semantic-release==9.11.0 # Locking to a specific version
pip install python-semantic-release
- name: Configure Git
run: |
Expand Down
41 changes: 41 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
name: Python Application Tests

on:
push:
branches: [ main ]
pull_request:
branches: [ main ]

jobs:
test:
runs-on: ubuntu-latest

strategy:
matrix:
python-version: ['3.9', '3.10', '3.11', '3.12']

steps:
# Step 1: Checkout the repository
- name: Checkout repository
uses: actions/checkout@v4

# Step 2: Set up Python
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'

# Step 3: Install dependencies
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
# Step 4: Install pytest (if not in requirements.txt)
- name: Install pytest
run: pip install pytest

# Step 5: Run tests
- name: Run tests
run: pytest
197 changes: 104 additions & 93 deletions MNIST_Early_Stopping_example.ipynb

Large diffs are not rendered by default.

29 changes: 16 additions & 13 deletions pytorchtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@

import numpy as np
import torch


class EarlyStopping:
"""Early stops the training if validation loss doesn't improve after a given patience."""
def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
"""
Args:
patience (int): How long to wait after last time validation loss improved.
Default: 7
verbose (bool): If True, prints a message for each validation loss improvement.
verbose (bool): If True, prints a message for each validation loss improvement.
Default: False
delta (float): Minimum change in the monitored quantity to qualify as an improvement.
Default: 0
Expand All @@ -22,31 +24,32 @@ def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', tra
self.patience = patience
self.verbose = verbose
self.counter = 0
self.best_score = None
self.best_val_loss = None
self.early_stop = False
self.val_loss_min = np.Inf
self.val_loss_min = np.inf
self.delta = delta
self.path = path
self.trace_func = trace_func
def __call__(self, val_loss, model):

score = -val_loss
def __call__(self, val_loss, model):

if self.best_score is None:
self.best_score = score
if self.best_val_loss is None:
self.best_val_loss = val_loss
self.save_checkpoint(val_loss, model)
elif score < self.best_score + self.delta:
elif val_loss < self.best_val_loss - self.delta:
# Significant improvement detected
self.best_val_loss = val_loss
self.save_checkpoint(val_loss, model)
self.counter = 0 # Reset counter since improvement occurred
else:
# No significant improvement
self.counter += 1
self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
if self.counter >= self.patience:
self.early_stop = True
else:
self.best_score = score
self.save_checkpoint(val_loss, model)
self.counter = 0

def save_checkpoint(self, val_loss, model):
'''Saves model when validation loss decrease.'''
'''Saves model when validation loss decreases.'''
if self.verbose:
self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
torch.save(model.state_dict(), self.path)
Expand Down
4 changes: 4 additions & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# requirements-dev.txt
-r requirements.txt
pytest
pytest-mock
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# requirements.txt
matplotlib
numpy
torchvision
Empty file added tests/__init__.py
Empty file.
Loading

0 comments on commit ffe12ee

Please sign in to comment.