DDP-related improvements to data module and logging #594

Merged (5 commits) on Sep 18, 2023
Showing changes from 3 commits
2 changes: 1 addition & 1 deletion configs/data/mnist.yaml
@@ -1,6 +1,6 @@
_target_: src.data.mnist_datamodule.MNISTDataModule
data_dir: ${paths.data_dir}
batch_size: 128
batch_size: 128 # Needs to be divisible by the number of devices (e.g., if in a distributed setup)
train_val_test_split: [55_000, 5_000, 10_000]
num_workers: 0
pin_memory: False
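
The new comment above encodes the constraint that the datamodule enforces later in this PR: the configured batch size is a global batch size and must split evenly across devices. A quick illustrative check (the device counts below are assumed examples, not taken from the PR):

batch_size = 128  # global batch size from configs/data/mnist.yaml
for devices in (1, 2, 4, 8, 3):  # assumed device counts
    if batch_size % devices == 0:
        print(devices, "->", batch_size // devices, "samples per device")
    else:
        print(devices, "-> would raise a RuntimeError in MNISTDataModule.setup()")
# 1, 2, 4 and 8 divide 128 evenly; 3 does not.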
2 changes: 1 addition & 1 deletion configs/hydra/default.yaml
@@ -16,4 +16,4 @@ job_logging:
handlers:
file:
# Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242
filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log
filename: ${hydra.runtime.output_dir}/${task_name}.log
16 changes: 13 additions & 3 deletions src/data/mnist_datamodule.py
@@ -83,6 +83,8 @@ def __init__(
self.data_val: Optional[Dataset] = None
self.data_test: Optional[Dataset] = None

self.batch_size_per_device = batch_size

@property
def num_classes(self) -> int:
"""Get the number of classes.
@@ -112,6 +114,14 @@ def setup(self, stage: Optional[str] = None) -> None:

:param stage: The stage to setup. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``.
"""
# Divide batch size by the number of devices.
if self.trainer is not None:
if self.hparams.batch_size % self.trainer.world_size != 0:
raise RuntimeError(
f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({self.trainer.world_size})."
)
self.batch_size_per_device = self.hparams.batch_size // self.trainer.world_size

# load and split datasets only if not loaded already
if not self.data_train and not self.data_val and not self.data_test:
trainset = MNIST(self.hparams.data_dir, train=True, transform=self.transforms)
@@ -130,7 +140,7 @@ def train_dataloader(self) -> DataLoader[Any]:
"""
return DataLoader(
dataset=self.data_train,
batch_size=self.hparams.batch_size,
batch_size=self.batch_size_per_device,
num_workers=self.hparams.num_workers,
pin_memory=self.hparams.pin_memory,
shuffle=True,
@@ -143,7 +153,7 @@ def val_dataloader(self) -> DataLoader[Any]:
"""
return DataLoader(
dataset=self.data_val,
batch_size=self.hparams.batch_size,
batch_size=self.batch_size_per_device,
num_workers=self.hparams.num_workers,
pin_memory=self.hparams.pin_memory,
shuffle=False,
@@ -156,7 +166,7 @@ def test_dataloader(self) -> DataLoader[Any]:
"""
return DataLoader(
dataset=self.data_test,
batch_size=self.hparams.batch_size,
batch_size=self.batch_size_per_device,
num_workers=self.hparams.num_workers,
pin_memory=self.hparams.pin_memory,
shuffle=False,
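
Taken together, the datamodule changes split the configured global batch size evenly across DDP processes, and the dataloaders consume the per-device value. A minimal stand-in sketch of that bookkeeping, assuming a world size of 4 (the fake trainer below is illustrative only and is not how Lightning attaches `self.trainer`):

class _FakeTrainer:
    world_size = 4  # assumed number of DDP processes

class _FakeDataModule:
    def __init__(self, batch_size: int) -> None:
        self.batch_size = batch_size
        self.batch_size_per_device = batch_size  # default when no trainer is attached
        self.trainer = _FakeTrainer()

    def setup(self) -> None:
        # Mirrors the divisibility check added to MNISTDataModule.setup().
        if self.batch_size % self.trainer.world_size != 0:
            raise RuntimeError(
                f"Batch size ({self.batch_size}) is not divisible by "
                f"the number of devices ({self.trainer.world_size})."
            )
        self.batch_size_per_device = self.batch_size // self.trainer.world_size

dm = _FakeDataModule(batch_size=128)
dm.setup()
print(dm.batch_size_per_device)  # 32: each of the 4 processes loads 32 samples per step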
18 changes: 12 additions & 6 deletions src/eval.py
@@ -24,12 +24,18 @@
# more info: https://github.com/ashleve/rootutils
# ------------------------------------------------------------------------------------ #

from src import utils
from src.utils import (
extras,
get_ranked_pylogger,
instantiate_loggers,
log_hyperparameters,
task_wrapper,
)

log = utils.get_pylogger(__name__)
log = get_ranked_pylogger(__name__)


@utils.task_wrapper
@task_wrapper
def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Evaluates given checkpoint on a datamodule testset.

@@ -48,7 +54,7 @@ def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
model: LightningModule = hydra.utils.instantiate(cfg.model)

log.info("Instantiating loggers...")
logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger"))
logger: List[Logger] = instantiate_loggers(cfg.get("logger"))

log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger)
@@ -63,7 +69,7 @@ def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:

if logger:
log.info("Logging hyperparameters!")
utils.log_hyperparameters(object_dict)
log_hyperparameters(object_dict)

log.info("Starting testing!")
trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)
@@ -84,7 +90,7 @@ def main(cfg: DictConfig) -> None:
"""
# apply extra utilities
# (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
utils.extras(cfg)
extras(cfg)

evaluate(cfg)

24 changes: 16 additions & 8 deletions src/train.py
@@ -26,12 +26,20 @@
# more info: https://github.com/ashleve/rootutils
# ------------------------------------------------------------------------------------ #

from src import utils
from src.utils import (
extras,
get_metric_value,
get_ranked_pylogger,
instantiate_callbacks,
instantiate_loggers,
log_hyperparameters,
task_wrapper,
)

log = utils.get_pylogger(__name__)
log = get_ranked_pylogger(__name__)


@utils.task_wrapper
@task_wrapper
def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Trains the model. Can additionally evaluate on a testset, using best weights obtained during
training.
@@ -53,10 +61,10 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
model: LightningModule = hydra.utils.instantiate(cfg.model)

log.info("Instantiating callbacks...")
callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))
callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks"))

log.info("Instantiating loggers...")
logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger"))
logger: List[Logger] = instantiate_loggers(cfg.get("logger"))

log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)
@@ -72,7 +80,7 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:

if logger:
log.info("Logging hyperparameters!")
utils.log_hyperparameters(object_dict)
log_hyperparameters(object_dict)

if cfg.get("compile"):
log.info("Compiling model!")
@@ -110,13 +118,13 @@ def main(cfg: DictConfig) -> Optional[float]:
"""
# apply extra utilities
# (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
utils.extras(cfg)
extras(cfg)

# train the model
metric_dict, _ = train(cfg)

# safely retrieve metric value for hydra-based hyperparameter optimization
metric_value = utils.get_metric_value(
metric_value = get_metric_value(
metric_dict=metric_dict, metric_name=cfg.get("optimized_metric")
)

2 changes: 1 addition & 1 deletion src/utils/__init__.py
@@ -1,5 +1,5 @@
from src.utils.instantiators import instantiate_callbacks, instantiate_loggers
from src.utils.logging_utils import log_hyperparameters
from src.utils.pylogger import get_pylogger
from src.utils.pylogger import get_ranked_pylogger
from src.utils.rich_utils import enforce_tags, print_config_tree
from src.utils.utils import extras, get_metric_value, task_wrapper
2 changes: 1 addition & 1 deletion src/utils/instantiators.py
@@ -7,7 +7,7 @@

from src.utils import pylogger

log = pylogger.get_pylogger(__name__)
log = pylogger.get_ranked_pylogger(__name__)


def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
4 changes: 2 additions & 2 deletions src/utils/logging_utils.py
@@ -1,11 +1,11 @@
from typing import Any, Dict

from lightning.pytorch.utilities import rank_zero_only
from lightning_utilities.core.rank_zero import rank_zero_only
from omegaconf import OmegaConf

from src.utils import pylogger

log = pylogger.get_pylogger(__name__)
log = pylogger.get_ranked_pylogger(__name__)


@rank_zero_only
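
Only the import path for `rank_zero_only` changes in this file; the decorator still restricts the wrapped call to the rank-0 process. A small hedged sketch of that pattern (the function name and RANK-from-environment fallback are illustrative, not from this file):

import os

from lightning_utilities.core.rank_zero import rank_zero_only

# `rank_zero_only.rank` must be set before use; Lightning normally does this,
# but a bare script can derive it from the environment (assumption).
rank_zero_only.rank = int(os.environ.get("RANK", 0))

@rank_zero_only
def save_summary(text: str) -> None:
    # Executed on rank 0 only; on other ranks the decorated call returns None.
    print(text)

save_summary("hyperparameters logged")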
41 changes: 35 additions & 6 deletions src/utils/pylogger.py
@@ -1,21 +1,50 @@
import logging
from functools import wraps
from typing import Callable, Optional, ParamSpec, TypeVar

from lightning.pytorch.utilities import rank_zero_only
from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only


def get_pylogger(name: str = __name__) -> logging.Logger:
"""Initializes a multi-GPU-friendly python command line logger.
def get_ranked_pylogger(name: str = __name__) -> logging.Logger:
"""Initializes a multi-GPU-friendly python command line logger that logs on all processes with
their rank prefixed in the log message.

:param name: The name of the logger, defaults to ``__name__``.

:return: A logger object.
"""
T = TypeVar("T")
P = ParamSpec("P")

def _rank_prefixed_log(fn: Callable[P, T]) -> Callable[P, Optional[T]]:
"""Wrap a logging function to prefix its message with the rank of the process it's being
logged from.

If `'rank'` is provided in the wrapped function's kwargs, the log will only be emitted
on that rank/process.
"""

@wraps(fn)
def wrapped_fn(*args: P.args, **kwargs: P.kwargs) -> Optional[T]:
rank = getattr(rank_zero_only, "rank", None)
if rank is None:
raise RuntimeError("The `rank_zero_only.rank` needs to be set before use")
rank_to_log = kwargs.pop("rank", None)  # pop `rank` so it is not forwarded to the stdlib logging call
msg = rank_prefixed_message(args[0], rank)
if rank_to_log is None:
return fn(msg, *args[1:], **kwargs)
elif rank == rank_to_log:
return fn(msg, *args[1:], **kwargs)
else:
return None

return wrapped_fn

logger = logging.getLogger(name)

# this ensures all logging levels get marked with the rank zero decorator
# otherwise logs would get multiplied for each GPU process in multi-GPU setup
# This ensures all logging levels get marked with the _rank_prefixed_log decorator.
logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
for level in logging_levels:
setattr(logger, level, rank_zero_only(getattr(logger, level)))
setattr(logger, level, _rank_prefixed_log(getattr(logger, level)))

return logger
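
For reference, a usage sketch of the new logger factory (the module name and messages are illustrative; it assumes `rank_zero_only.rank` has already been set, which Lightning normally does at startup):

from src.utils.pylogger import get_ranked_pylogger

log = get_ranked_pylogger(__name__)

# Emitted on every process, with the rank prefixed to the message,
# e.g. something like "[rank: 1] Setting up datamodule..." on rank 1.
log.info("Setting up datamodule...")

# Per the wrapper's docstring, passing `rank` limits the log to that process.
log.info("Saving checkpoint...", rank=0)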
4 changes: 2 additions & 2 deletions src/utils/rich_utils.py
@@ -5,13 +5,13 @@
import rich.syntax
import rich.tree
from hydra.core.hydra_config import HydraConfig
from lightning.pytorch.utilities import rank_zero_only
from lightning_utilities.core.rank_zero import rank_zero_only
from omegaconf import DictConfig, OmegaConf, open_dict
from rich.prompt import Prompt

from src.utils import pylogger

log = pylogger.get_pylogger(__name__)
log = pylogger.get_ranked_pylogger(__name__)


@rank_zero_only
2 changes: 1 addition & 1 deletion src/utils/utils.py
@@ -6,7 +6,7 @@

from src.utils import pylogger, rich_utils

log = pylogger.get_pylogger(__name__)
log = pylogger.get_ranked_pylogger(__name__)


def extras(cfg: DictConfig) -> None: