Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CI #1

Merged
merged 8 commits into from
Dec 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 40 additions & 19 deletions .github/workflows/run-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,13 @@ jobs:
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
- uses: pre-commit/action@v2.0.3
- uses: pre-commit/action@v3.0.0
test:
name: Run Tests
if: ${{ !contains(github.event.pull_request.title, 'WIP') }}
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.7, 3.8, 3.9]
python-version: ['3.8', '3.9', '3.10']
steps:
- name: Check out the code
uses: actions/checkout@v3
Expand All @@ -30,32 +29,41 @@ jobs:
python-version: ${{ matrix.python-version }}

- name: Install Poetry
uses: snok/install-poetry@v1.1.1
uses: snok/install-poetry@v1.3.3
with:
version: 1.2.1

- name: Setup Poetry
run: |
poetry config virtualenvs.in-project true

- name: Cache
id: cache
uses: actions/cache@v3.2.2
with:
version: 1.1.4
path: '.venv'
key: run-tests-${{ hashFiles('poetry.lock') }}

- name: Install Dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: |
poetry config virtualenvs.create false
pip install -U certifi
if [ -d ".venv" ]; then rm -rf .venv; fi
poetry run pip install -U certifi
poetry install

- name: Run Tests
run: pytest --cov=elegy --cov-report=term-missing --cov-report=xml
run: poetry run pytest --cov=ciclo --cov-report=term-missing --cov-report=xml

- name: Upload coverage
uses: codecov/codecov-action@v1

- name: Test Examples
run: bash scripts/test-examples.sh
uses: codecov/codecov-action@v3.1.1

test-import:
name: Test Import without Dev Dependencies
if: ${{ !contains(github.event.pull_request.title, 'WIP') }}
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.7, 3.8, 3.9]
python-version: ['3.8', '3.9', '3.10']
steps:
- name: Check out the code
uses: actions/checkout@v3
Expand All @@ -67,15 +75,28 @@ jobs:
python-version: ${{ matrix.python-version }}

- name: Install Poetry
uses: snok/install-poetry@v1.1.1
uses: snok/install-poetry@v1.3.3
with:
version: 1.1.4
version: 1.2.1

- name: Setup Poetry
run: |
poetry config virtualenvs.in-project true

- name: Cache
id: cache
uses: actions/cache@v3.2.2
with:
path: '.venv'
key: test-import-${{ hashFiles('poetry.lock') }}

- name: Install Dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: |
pip install -U certifi
poetry config virtualenvs.create false
if [ -d ".venv" ]; then rm -rf .venv; fi
poetry run pip install -U certifi
poetry install --no-dev

- name: Test Import Elegy
run: python -c "import elegy"
- name: Test Import
run: |
poetry run python -c "import ciclo"
214 changes: 117 additions & 97 deletions ciclo/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from enum import Enum, auto
from typing import Any, Callable, Dict, Optional, Tuple, Union, overload

from flax.training import checkpoints as flax_checkpoints
from pkbar import Kbar
from tqdm import tqdm

Expand All @@ -24,16 +23,15 @@
from ciclo.utils import get_batch_size, is_scalar


# import wandb Run
def _get_Run():
if importlib.util.find_spec("wandb") is not None:
from wandb.wandb_run import Run
else:
locals()["Run"] = Any
return Run
def unavailable_dependency(msg: str) -> Any:
class DependencyNotAvailable(LoopCallbackBase[S]):
def __init__(self, *args: Any, **kwargs: Any) -> None:
raise RuntimeError(msg)

def __loop_callback__(self, loop_state: LoopState[S]) -> CallbackOutput[S]:
raise RuntimeError(msg)

Run = _get_Run()
return DependencyNotAvailable


class OptimizationMode(str, Enum):
Expand Down Expand Up @@ -96,83 +94,95 @@ def __loop_callback__(self, loop_state: LoopState[S]) -> CallbackOutput[S]:
return self(loop_state.state)


class checkpoint(LoopCallbackBase[S]):
def __init__(
self,
ckpt_dir: Union[str, os.PathLike],
prefix: str = "checkpoint_",
keep: int = 1,
overwrite: bool = False,
keep_every_n_steps: Optional[int] = None,
async_manager: Optional[flax_checkpoints.AsyncManager] = None,
monitor: Optional[str] = None,
mode: Union[str, OptimizationMode] = "min",
):
if isinstance(mode, str):
mode = OptimizationMode[mode]

if mode not in OptimizationMode:
raise ValueError(
f"Invalid mode: {mode}, expected one of {list(OptimizationMode)}"
)
else:
self.mode = mode

self.ckpt_dir = ckpt_dir
self.prefix = prefix
self.keep = keep
self.overwrite = overwrite
self.keep_every_n_steps = keep_every_n_steps
self.async_manager = async_manager
self.monitor = monitor
self.minimize = self.mode == OptimizationMode.min
self._best: Optional[float] = None

def __call__(
self, elapsed: Elapsed, state: S, logs: Optional[LogsLike] = None
) -> None:
save_checkpoint = True
step_or_metric = elapsed.steps
overwrite = self.overwrite
if importlib.util.find_spec("tensorflow") is not None:
from flax.training import checkpoints as flax_checkpoints

class checkpoint(LoopCallbackBase[S]):
def __init__(
self,
ckpt_dir: Union[str, os.PathLike],
prefix: str = "checkpoint_",
keep: int = 1,
overwrite: bool = False,
keep_every_n_steps: Optional[int] = None,
async_manager: Optional[flax_checkpoints.AsyncManager] = None,
monitor: Optional[str] = None,
mode: Union[str, OptimizationMode] = "min",
):
if isinstance(mode, str):
mode = OptimizationMode[mode]

if self.monitor is not None:
if logs is None:
if mode not in OptimizationMode:
raise ValueError(
"checkpoint callback requires logs to monitor a metric"
f"Invalid mode: {mode}, expected one of {list(OptimizationMode)}"
)
if not isinstance(logs, Logs):
logs = Logs(logs)

try:
value = logs.entry_value(self.monitor)
except KeyError:
raise ValueError(f"Monitored value '{self.monitor}' not found in logs")

if (
self._best is None
or (self.minimize and value < self._best)
or (not self.minimize and value > self._best)
):
self._best = value
step_or_metric = value if self.mode == OptimizationMode.max else -value
else:
save_checkpoint = False

if save_checkpoint:
flax_checkpoints.save_checkpoint(
ckpt_dir=self.ckpt_dir,
target=state,
step=step_or_metric,
prefix=self.prefix,
keep=self.keep,
overwrite=overwrite,
keep_every_n_steps=self.keep_every_n_steps,
async_manager=self.async_manager,
)
self.mode = mode

self.ckpt_dir = ckpt_dir
self.prefix = prefix
self.keep = keep
self.overwrite = overwrite
self.keep_every_n_steps = keep_every_n_steps
self.async_manager = async_manager
self.monitor = monitor
self.minimize = self.mode == OptimizationMode.min
self._best: Optional[float] = None

def __call__(
self, elapsed: Elapsed, state: S, logs: Optional[LogsLike] = None
) -> None:
save_checkpoint = True
step_or_metric = elapsed.steps
overwrite = self.overwrite

if self.monitor is not None:
if logs is None:
raise ValueError(
"checkpoint callback requires logs to monitor a metric"
)
if not isinstance(logs, Logs):
logs = Logs(logs)

try:
value = logs.entry_value(self.monitor)
except KeyError:
raise ValueError(
f"Monitored value '{self.monitor}' not found in logs"
)

if (
self._best is None
or (self.minimize and value < self._best)
or (not self.minimize and value > self._best)
):
self._best = value
step_or_metric = (
value if self.mode == OptimizationMode.max else -value
)
else:
save_checkpoint = False

if save_checkpoint:
flax_checkpoints.save_checkpoint(
ckpt_dir=self.ckpt_dir,
target=state,
step=step_or_metric,
prefix=self.prefix,
keep=self.keep,
overwrite=overwrite,
keep_every_n_steps=self.keep_every_n_steps,
async_manager=self.async_manager,
)

def __loop_callback__(self, loop_state: LoopState[S]) -> CallbackOutput[S]:
self(loop_state.elapsed, loop_state.state, loop_state.accumulated_logs)
return {}, loop_state.state
def __loop_callback__(self, loop_state: LoopState[S]) -> CallbackOutput[S]:
self(loop_state.elapsed, loop_state.state, loop_state.accumulated_logs)
return {}, loop_state.state

else:
checkpoint = unavailable_dependency(
"'tensorflow' package is not available, please install it to use the 'checkpoint' callback"
)


class early_stopping(LoopCallbackBase[S]):
Expand Down Expand Up @@ -443,25 +453,35 @@ def __loop_callback__(self, loop_state: LoopState[S]) -> CallbackOutput[S]:
return {}, loop_state.state


class wandb_logger(LoopCallbackBase[S]):
def __init__(self, run: Run):
self.run = run
if importlib.util.find_spec("wandb") is not None:
from wandb.wandb_run import Run

def __call__(self, elapsed: Elapsed, logs: LogsLike) -> None:
data = {}
for collection, collection_logs in logs.items():
for key, value in collection_logs.items():
if is_scalar(value):
if key in data:
key = f"{collection}.{key}"
data[key] = value
class wandb_logger(LoopCallbackBase[S]):
def __init__(self, run: Run):
from wandb.wandb_run import Run

if len(data) > 0:
self.run.log(data, step=elapsed.steps)
self.run: Run = run

def __loop_callback__(self, loop_state: LoopState[S]) -> CallbackOutput[S]:
self(loop_state.elapsed, loop_state.logs)
return {}, loop_state.state
def __call__(self, elapsed: Elapsed, logs: LogsLike) -> None:
data = {}
for collection, collection_logs in logs.items():
for key, value in collection_logs.items():
if is_scalar(value):
if key in data:
key = f"{collection}.{key}"
data[key] = value

if len(data) > 0:
self.run.log(data, step=elapsed.steps)

def __loop_callback__(self, loop_state: LoopState[S]) -> CallbackOutput[S]:
self(loop_state.elapsed, loop_state.logs)
return {}, loop_state.state

else:
wandb_logger = unavailable_dependency(
"'wandb' package is not available, please install it to use the 'wandb_logger' callback"
)


class NoOp(LoopCallbackBase[S]):
Expand Down
Loading