DDP-related improvements to data module and logging #594

Merged: 5 commits, Sep 18, 2023

2 changes: 1 addition & 1 deletion configs/data/mnist.yaml
@@ -1,6 +1,6 @@
 _target_: src.data.mnist_datamodule.MNISTDataModule
 data_dir: ${paths.data_dir}
-batch_size: 128
+batch_size: 128 # Needs to be divisible by the number of devices (e.g., if in a distributed setup)
 train_val_test_split: [55_000, 5_000, 10_000]
 num_workers: 0
 pin_memory: False
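
To make the new comment concrete, here is a minimal sketch of the per-device batch-size arithmetic that the datamodule changes below implement (the numbers are illustrative, not taken from this PR): the configured value stays the global batch size, and each DDP process receives an equal share of it.

# Illustrative sketch only; `global_batch_size` and `world_size` are example values.
global_batch_size = 128
world_size = 4  # e.g. 4 GPUs in a single-node DDP run

if global_batch_size % world_size != 0:
    raise RuntimeError(
        f"Batch size ({global_batch_size}) is not divisible by the number of devices ({world_size})."
    )

batch_size_per_device = global_batch_size // world_size
print(batch_size_per_device)  # 32 samples per device, 128 per optimizer step overall
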
2 changes: 1 addition & 1 deletion configs/hydra/default.yaml
@@ -16,4 +16,4 @@ job_logging:
   handlers:
     file:
       # Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242
-      filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log
+      filename: ${hydra.runtime.output_dir}/${task_name}.log
16 changes: 13 additions & 3 deletions src/data/mnist_datamodule.py
@@ -83,6 +83,8 @@ def __init__(
         self.data_val: Optional[Dataset] = None
         self.data_test: Optional[Dataset] = None
 
+        self.batch_size_per_device = batch_size
+
     @property
     def num_classes(self) -> int:
         """Get the number of classes.
@@ -112,6 +114,14 @@ def setup(self, stage: Optional[str] = None) -> None:
 
         :param stage: The stage to setup. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``.
         """
+        # Divide batch size by the number of devices.
+        if self.trainer is not None:
+            if self.hparams.batch_size % self.trainer.world_size != 0:
+                raise RuntimeError(
+                    f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({self.trainer.world_size})."
+                )
+            self.batch_size_per_device = self.hparams.batch_size // self.trainer.world_size
+
         # load and split datasets only if not loaded already
         if not self.data_train and not self.data_val and not self.data_test:
             trainset = MNIST(self.hparams.data_dir, train=True, transform=self.transforms)
@@ -130,7 +140,7 @@ def train_dataloader(self) -> DataLoader[Any]:
         """
         return DataLoader(
             dataset=self.data_train,
-            batch_size=self.hparams.batch_size,
+            batch_size=self.batch_size_per_device,
             num_workers=self.hparams.num_workers,
             pin_memory=self.hparams.pin_memory,
             shuffle=True,
@@ -143,7 +153,7 @@ def val_dataloader(self) -> DataLoader[Any]:
         """
         return DataLoader(
             dataset=self.data_val,
-            batch_size=self.hparams.batch_size,
+            batch_size=self.batch_size_per_device,
             num_workers=self.hparams.num_workers,
             pin_memory=self.hparams.pin_memory,
             shuffle=False,
@@ -156,7 +166,7 @@ def test_dataloader(self) -> DataLoader[Any]:
         """
         return DataLoader(
             dataset=self.data_test,
-            batch_size=self.hparams.batch_size,
+            batch_size=self.batch_size_per_device,
             num_workers=self.hparams.num_workers,
             pin_memory=self.hparams.pin_memory,
             shuffle=False,
18 changes: 12 additions & 6 deletions src/eval.py
@@ -24,12 +24,18 @@
 # more info: https://github.com/ashleve/rootutils
 # ------------------------------------------------------------------------------------ #
 
-from src import utils
+from src.utils import (
+    RankedLogger,
+    extras,
+    instantiate_loggers,
+    log_hyperparameters,
+    task_wrapper,
+)
 
-log = utils.get_pylogger(__name__)
+log = RankedLogger(__name__, rank_zero_only=True)
 
 
-@utils.task_wrapper
+@task_wrapper
 def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
     """Evaluates given checkpoint on a datamodule testset.
 
@@ -48,7 +54,7 @@ def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
     model: LightningModule = hydra.utils.instantiate(cfg.model)
 
     log.info("Instantiating loggers...")
-    logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger"))
+    logger: List[Logger] = instantiate_loggers(cfg.get("logger"))
 
     log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
     trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger)
@@ -63,7 +69,7 @@ def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
 
     if logger:
         log.info("Logging hyperparameters!")
-        utils.log_hyperparameters(object_dict)
+        log_hyperparameters(object_dict)
 
     log.info("Starting testing!")
     trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)
@@ -84,7 +90,7 @@ def main(cfg: DictConfig) -> None:
     """
     # apply extra utilities
    # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
-    utils.extras(cfg)
+    extras(cfg)
 
     evaluate(cfg)
 
24 changes: 16 additions & 8 deletions src/train.py
@@ -26,12 +26,20 @@
 # more info: https://github.com/ashleve/rootutils
 # ------------------------------------------------------------------------------------ #
 
-from src import utils
+from src.utils import (
+    RankedLogger,
+    extras,
+    get_metric_value,
+    instantiate_callbacks,
+    instantiate_loggers,
+    log_hyperparameters,
+    task_wrapper,
+)
 
-log = utils.get_pylogger(__name__)
+log = RankedLogger(__name__, rank_zero_only=True)
 
 
-@utils.task_wrapper
+@task_wrapper
 def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
     """Trains the model. Can additionally evaluate on a testset, using best weights obtained during
     training.
@@ -53,10 +61,10 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
     model: LightningModule = hydra.utils.instantiate(cfg.model)
 
     log.info("Instantiating callbacks...")
-    callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))
+    callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks"))
 
     log.info("Instantiating loggers...")
-    logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger"))
+    logger: List[Logger] = instantiate_loggers(cfg.get("logger"))
 
     log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
     trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)
@@ -72,7 +80,7 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
 
     if logger:
         log.info("Logging hyperparameters!")
-        utils.log_hyperparameters(object_dict)
+        log_hyperparameters(object_dict)
 
     if cfg.get("compile"):
         log.info("Compiling model!")
@@ -110,13 +118,13 @@ def main(cfg: DictConfig) -> Optional[float]:
     """
     # apply extra utilities
     # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
-    utils.extras(cfg)
+    extras(cfg)
 
     # train the model
     metric_dict, _ = train(cfg)
 
     # safely retrieve metric value for hydra-based hyperparameter optimization
-    metric_value = utils.get_metric_value(
+    metric_value = get_metric_value(
         metric_dict=metric_dict, metric_name=cfg.get("optimized_metric")
     )
 
2 changes: 1 addition & 1 deletion src/utils/__init__.py
@@ -1,5 +1,5 @@
 from src.utils.instantiators import instantiate_callbacks, instantiate_loggers
 from src.utils.logging_utils import log_hyperparameters
-from src.utils.pylogger import get_pylogger
+from src.utils.pylogger import RankedLogger
 from src.utils.rich_utils import enforce_tags, print_config_tree
 from src.utils.utils import extras, get_metric_value, task_wrapper
2 changes: 1 addition & 1 deletion src/utils/instantiators.py
@@ -7,7 +7,7 @@
 
 from src.utils import pylogger
 
-log = pylogger.get_pylogger(__name__)
+log = pylogger.RankedLogger(__name__, rank_zero_only=True)
 
 
 def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
4 changes: 2 additions & 2 deletions src/utils/logging_utils.py
@@ -1,11 +1,11 @@
 from typing import Any, Dict
 
-from lightning.pytorch.utilities import rank_zero_only
+from lightning_utilities.core.rank_zero import rank_zero_only
 from omegaconf import OmegaConf
 
 from src.utils import pylogger
 
-log = pylogger.get_pylogger(__name__)
+log = pylogger.RankedLogger(__name__, rank_zero_only=True)
 
 
 @rank_zero_only
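
As context for the import swap above: the `rank_zero_only` decorator from `lightning_utilities.core.rank_zero` runs the wrapped function only on the process whose rank is stored in `rank_zero_only.rank`, which Lightning sets when it launches the processes (if it is unset, the call raises the same RuntimeError quoted in the pylogger diff below). A small hedged sketch of that behavior, not part of this PR:

import os

from lightning_utilities.core.rank_zero import rank_zero_only

# Lightning normally sets this when the Trainer spawns DDP processes; here it is set
# by hand (falling back to the RANK env var) purely for illustration.
rank_zero_only.rank = int(os.environ.get("RANK", 0))

@rank_zero_only
def save_config() -> None:
    # Executes only on rank 0; on other ranks the call is a no-op.
    print("Only rank 0 executes this.")

save_config()
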
56 changes: 43 additions & 13 deletions src/utils/pylogger.py
@@ -1,21 +1,51 @@
 import logging
+from typing import Mapping, Optional
 
-from lightning.pytorch.utilities import rank_zero_only
+from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only
 
 
-def get_pylogger(name: str = __name__) -> logging.Logger:
-    """Initializes a multi-GPU-friendly python command line logger.
+class RankedLogger(logging.LoggerAdapter):
+    """A multi-GPU-friendly python command line logger."""
 
-    :param name: The name of the logger, defaults to ``__name__``.
+    def __init__(
+        self,
+        name: str = __name__,
+        rank_zero_only: bool = False,
+        extra: Optional[Mapping[str, object]] = None,
+    ) -> None:
+        """Initializes a multi-GPU-friendly python command line logger that logs on all processes
+        with their rank prefixed in the log message.
 
-    :return: A logger object.
-    """
-    logger = logging.getLogger(name)
+        :param name: The name of the logger. Default is ``__name__``.
+        :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`.
+        :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.
+        """
+        logger = logging.getLogger(name)
+        super().__init__(logger=logger, extra=extra)
+        self.rank_zero_only = rank_zero_only
 
-    # this ensures all logging levels get marked with the rank zero decorator
-    # otherwise logs would get multiplied for each GPU process in multi-GPU setup
-    logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
-    for level in logging_levels:
-        setattr(logger, level, rank_zero_only(getattr(logger, level)))
+    def log(self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs) -> None:
+        """Delegate a log call to the underlying logger, after prefixing its message with the rank
+        of the process it's being logged from. If `'rank'` is provided, then the log will only
+        occur on that rank/process.
 
-    return logger
+        :param level: The level to log at. Look at `logging.__init__.py` for more information.
+        :param msg: The message to log.
+        :param rank: The rank to log at.
+        :param args: Additional args to pass to the underlying logging function.
+        :param kwargs: Any additional keyword args to pass to the underlying logging function.
+        """
+        if self.isEnabledFor(level):
+            msg, kwargs = self.process(msg, kwargs)
+            current_rank = getattr(rank_zero_only, "rank", None)
+            if current_rank is None:
+                raise RuntimeError("The `rank_zero_only.rank` needs to be set before use")
+            msg = rank_prefixed_message(msg, current_rank)
+            if self.rank_zero_only:
+                if current_rank == 0:
+                    self.logger.log(level, msg, *args, **kwargs)
+            else:
+                if rank is None:
+                    self.logger.log(level, msg, *args, **kwargs)
+                elif current_rank == rank:
+                    self.logger.log(level, msg, *args, **kwargs)
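
To illustrate how the new logger behaves, here is a hedged usage sketch. It is not part of the diff; it assumes `rank_zero_only.rank` has already been set, which Lightning does when it launches the processes, and sets it manually only so the snippet runs standalone.

import logging

from lightning_utilities.core.rank_zero import rank_zero_only

from src.utils.pylogger import RankedLogger

logging.basicConfig(level=logging.INFO)  # so the messages are actually emitted
rank_zero_only.rank = 0  # normally set by Lightning; set by hand here for illustration

log = RankedLogger(__name__, rank_zero_only=False)
log.info("Logged on every rank, prefixed with the process rank.")
log.info("Logged only on rank 1 (silent here, since this process is rank 0).", rank=1)

rank_zero_log = RankedLogger(__name__, rank_zero_only=True)
rank_zero_log.info("Logged only on rank 0, regardless of any `rank` argument.")
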
4 changes: 2 additions & 2 deletions src/utils/rich_utils.py
@@ -5,13 +5,13 @@
 import rich.syntax
 import rich.tree
 from hydra.core.hydra_config import HydraConfig
-from lightning.pytorch.utilities import rank_zero_only
+from lightning_utilities.core.rank_zero import rank_zero_only
 from omegaconf import DictConfig, OmegaConf, open_dict
 from rich.prompt import Prompt
 
 from src.utils import pylogger
 
-log = pylogger.get_pylogger(__name__)
+log = pylogger.RankedLogger(__name__, rank_zero_only=True)
 
 
 @rank_zero_only
2 changes: 1 addition & 1 deletion src/utils/utils.py
@@ -6,7 +6,7 @@
 
 from src.utils import pylogger, rich_utils
 
-log = pylogger.get_pylogger(__name__)
+log = pylogger.RankedLogger(__name__, rank_zero_only=True)
 
 
 def extras(cfg: DictConfig) -> None: