Skip to content

Commit

Permalink
Feature add whitelist parameter issue 2548 (#2550)
Browse files Browse the repository at this point in the history
* Remove unnecessary code in BaseOutputHandler

Closes #2438

* Add ReduceLROnPlateauScheduler

Closes #1754

* Fix indentation issue

* Fix another indentation issue

* Fix PEP8 related issues

* Fix other PEP8 related issues

* Fix hopefully the last PEP8 related issue

* Fix hopefully the last PEP8 related issue

* Remove ReduceLROnPlateau's specific params and add link to it

Also fix bug in min_lr check

* Fix state_dict bug and add a test

* Update docs

* Add doctest and fix typo

* Add whitelist param and refactor

Closes #2548

* Fix docstrings and a bug

* Change reduction parameter

* Fix zero_grad place in trainer step

Closes #2459 with help of PR #2470

* autopep8 fix

* Fix bugs

* Fix bugs in loggers

* Fix bug in test_create_supervised

* Change reduction type hint in base_logger

* Fix mypy error

* Fix bug causing missing clearml histograms

Co-authored-by: vfdev <vfdev.5@gmail.com>
Co-authored-by: sadra-barikbin <sadra-barikbin@users.noreply.github.com>
  • Loading branch information
3 people authored May 3, 2022
1 parent a74b5b0 commit 4f6b8b1
Show file tree
Hide file tree
Showing 11 changed files with 767 additions and 270 deletions.
19 changes: 16 additions & 3 deletions examples/contrib/mnist/mnist_with_clearml_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,16 +117,29 @@ def compute_metrics(engine):
)

clearml_logger.attach(
trainer, log_handler=WeightsScalarHandler(model), event_name=Events.ITERATION_COMPLETED(every=100)
trainer,
log_handler=WeightsScalarHandler(model, whitelist=["fc1"]),
event_name=Events.ITERATION_COMPLETED(every=100),
)

clearml_logger.attach(trainer, log_handler=WeightsHistHandler(model), event_name=Events.EPOCH_COMPLETED(every=100))
def is_conv(n, _):
    # Whitelist predicate: keep only parameters whose name mentions a conv layer.
    # The second argument (the parameter tensor) is intentionally ignored.
    return n.count("conv") > 0

clearml_logger.attach(
trainer,
log_handler=WeightsHistHandler(model, whitelist=is_conv),
event_name=Events.ITERATION_COMPLETED(every=100),
)

clearml_logger.attach(
trainer, log_handler=GradsScalarHandler(model), event_name=Events.ITERATION_COMPLETED(every=100)
)

clearml_logger.attach(trainer, log_handler=GradsHistHandler(model), event_name=Events.EPOCH_COMPLETED(every=100))
clearml_logger.attach(
trainer,
log_handler=GradsHistHandler(model, whitelist=["fc2.weight"]),
event_name=Events.ITERATION_COMPLETED(every=100),
)

handler = Checkpoint(
{"model": model},
Expand Down
22 changes: 14 additions & 8 deletions examples/contrib/mnist/mnist_with_tensorboard_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,22 +133,28 @@ def compute_metrics(engine):

tb_logger.attach_opt_params_handler(trainer, event_name=Events.ITERATION_COMPLETED(every=100), optimizer=optimizer)

tb_logger.attach(trainer, log_handler=WeightsScalarHandler(model), event_name=Events.ITERATION_COMPLETED(every=100))
tb_logger.attach(
trainer,
log_handler=WeightsScalarHandler(model, whitelist=["fc1"]),
event_name=Events.ITERATION_COMPLETED(every=100),
)

def is_conv(n, _):
    # Parameter filter passed as `whitelist=`: accept a parameter when its
    # name contains the substring "conv"; the parameter value is unused.
    return n.find("conv") >= 0

tb_logger.attach(
trainer,
log_handler=WeightsHistHandler(
model,
whitelist=[
"conv",
],
),
log_handler=WeightsHistHandler(model, whitelist=is_conv),
event_name=Events.ITERATION_COMPLETED(every=100),
)

tb_logger.attach(trainer, log_handler=GradsScalarHandler(model), event_name=Events.ITERATION_COMPLETED(every=100))

tb_logger.attach(trainer, log_handler=GradsHistHandler(model), event_name=Events.ITERATION_COMPLETED(every=100))
tb_logger.attach(
trainer,
log_handler=GradsHistHandler(model, whitelist=["fc2.weight"]),
event_name=Events.ITERATION_COMPLETED(every=100),
)

def score_function(engine):
return engine.state.metrics["accuracy"]
Expand Down
70 changes: 49 additions & 21 deletions ignite/contrib/handlers/base_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,43 @@ def __call__(self, engine: Engine, logger: Any, event_name: Union[str, Events])
pass


class BaseWeightsHandler(BaseHandler):
    """
    Base handler for logging a model's weights or their gradients.

    Collects the subset of ``model.named_parameters()`` selected by
    ``whitelist`` once at construction time and exposes it as
    ``self.weights`` for subclasses to iterate over.

    Args:
        model: model whose named parameters will be logged.
        tag: optional common prefix for the logged entries.
        whitelist: restricts which parameters are logged. Either a list of
            parameter-name prefixes (a parameter is kept when its name starts
            with any of them) or a callable ``(name, parameter) -> bool``
            returning True for parameters to keep. If None (default), all
            parameters are kept.

    Raises:
        TypeError: if ``model`` is not a ``torch.nn.Module``.
    """

    def __init__(
        self,
        model: nn.Module,
        tag: Optional[str] = None,
        whitelist: Optional[Union[List[str], Callable[[str, nn.Parameter], bool]]] = None,
    ):

        if not isinstance(model, torch.nn.Module):
            raise TypeError(f"Argument model should be of type torch.nn.Module, but given {type(model)}")

        self.model = model
        self.tag = tag

        if whitelist is None:
            # No filter: keep every named parameter.
            weights = dict(model.named_parameters())
        elif callable(whitelist):
            # Predicate filter: keep the parameters the callable accepts.
            weights = {n: p for n, p in model.named_parameters() if whitelist(n, p)}
        else:
            # Prefix filter: keep parameters whose name starts with any of the
            # given prefixes (single pass; avoids the redundant nested loop
            # that re-assigned an entry once per matching prefix).
            weights = {n: p for n, p in model.named_parameters() if any(n.startswith(item) for item in whitelist)}

        self.weights = weights.items()


class BaseOptimizerParamsHandler(BaseHandler):
"""
Base handler for logging optimizer parameters
Expand Down Expand Up @@ -136,42 +173,33 @@ def key_str_tf(tag: str, name: str, *args: str) -> str:
return metrics_state_attrs_dict


class BaseWeightsScalarHandler(BaseHandler):
class BaseWeightsScalarHandler(BaseWeightsHandler):
"""
Helper handler to log model's weights as scalars.
Helper handler to log model's weights or gradients as scalars.
"""

def __init__(self, model: nn.Module, reduction: Callable = torch.norm, tag: Optional[str] = None):
if not isinstance(model, torch.nn.Module):
raise TypeError(f"Argument model should be of type torch.nn.Module, but given {type(model)}")
def __init__(
self,
model: nn.Module,
reduction: Callable[[torch.Tensor], Union[float, torch.Tensor]] = torch.norm,
tag: Optional[str] = None,
whitelist: Optional[Union[List[str], Callable[[str, nn.Parameter], bool]]] = None,
):

super(BaseWeightsScalarHandler, self).__init__(model, tag=tag, whitelist=whitelist)

if not callable(reduction):
raise TypeError(f"Argument reduction should be callable, but given {type(reduction)}")

def _is_0D_tensor(t: torch.Tensor) -> bool:
def _is_0D_tensor(t: Any) -> bool:
return isinstance(t, torch.Tensor) and t.ndimension() == 0

# Test reduction function on a tensor
o = reduction(torch.ones(4, 2))
if not (isinstance(o, numbers.Number) or _is_0D_tensor(o)):
raise TypeError(f"Output of the reduction function should be a scalar, but got {type(o)}")

self.model = model
self.reduction = reduction
self.tag = tag


class BaseWeightsHistHandler(BaseHandler):
"""
Helper handler to log model's weights as histograms.
"""

def __init__(self, model: nn.Module, tag: Optional[str] = None):
if not isinstance(model, torch.nn.Module):
raise TypeError(f"Argument model should be of type torch.nn.Module, but given {type(model)}")

self.model = model
self.tag = tag


class BaseLogger(metaclass=ABCMeta):
Expand Down
Loading

0 comments on commit 4f6b8b1

Please sign in to comment.