
Updated BaseOutputHandler to accept state attributes - Tensorboard #2137

Merged · 18 commits · Aug 8, 2021
35 changes: 20 additions & 15 deletions ignite/contrib/handlers/base_logger.py
@@ -52,6 +52,7 @@ def __init__(
         metric_names: Optional[Union[str, List[str]]] = None,
         output_transform: Optional[Callable] = None,
         global_step_transform: Optional[Callable] = None,
+        state_attributes: Optional[List[str]] = None,
     ):
 
         if metric_names is not None:
@@ -63,8 +64,8 @@ def __init__(
         if output_transform is not None and not callable(output_transform):
             raise TypeError(f"output_transform should be a function, got {type(output_transform)} instead.")
 
-        if output_transform is None and metric_names is None:
-            raise ValueError("Either metric_names or output_transform should be defined")
+        if output_transform is None and metric_names is None and state_attributes is None:
+            raise ValueError("Either metric_names, output_transform or state_attributes should be defined")
 
         if global_step_transform is not None and not callable(global_step_transform):
             raise TypeError(f"global_step_transform should be a function, got {type(global_step_transform)} instead.")
@@ -78,16 +79,17 @@ def global_step_transform(engine: Engine, event_name: Union[str, Events]) -> int
         self.metric_names = metric_names
         self.output_transform = output_transform
         self.global_step_transform = global_step_transform
+        self.state_attributes = state_attributes
 
-    def _setup_output_metrics(
+    def _setup_output_metrics_state_attrs(
         self, engine: Engine, log_text: Optional[bool] = False, key_tuple: Optional[bool] = True
     ) -> Dict[Any, Any]:
-        """Helper method to setup metrics to log
+        """Helper method to setup metrics and state attributes to log
         """
-        metrics = OrderedDict()
+        metrics_state_attrs = OrderedDict()
         if self.metric_names is not None:
             if isinstance(self.metric_names, str) and self.metric_names == "all":
-                metrics = OrderedDict(engine.state.metrics)
+                metrics_state_attrs = OrderedDict(engine.state.metrics)
             else:
                 for name in self.metric_names:
                     if name not in engine.state.metrics:
@@ -96,17 +98,20 @@ def _setup_output_metrics(
                             f"in engine's state metrics: {list(engine.state.metrics.keys())}"
                         )
                         continue
-                    metrics[name] = engine.state.metrics[name]
+                    metrics_state_attrs[name] = engine.state.metrics[name]
 
         if self.output_transform is not None:
             output_dict = self.output_transform(engine.state.output)
 
             if not isinstance(output_dict, dict):
                 output_dict = {"output": output_dict}
 
-            metrics.update({name: value for name, value in output_dict.items()})
+            metrics_state_attrs.update({name: value for name, value in output_dict.items()})
 
-        metrics_dict = {}  # type: Dict[Any, Union[str, float, numbers.Number]]
+        if self.state_attributes is not None:
+            metrics_state_attrs.update({name: getattr(engine.state, name, None) for name in self.state_attributes})
+
+        metrics_state_attrs_dict = OrderedDict()  # type: Dict[Any, Union[str, float, numbers.Number]]
 
         def key_tuple_tf(tag: str, name: str, *args: str) -> Tuple[str, ...]:
             return (tag, name) + args
@@ -116,20 +121,20 @@ def key_str_tf(tag: str, name: str, *args: str) -> str:
 
         key_tf = key_tuple_tf if key_tuple else key_str_tf
 
-        for name, value in metrics.items():
+        for name, value in metrics_state_attrs.items():
             if isinstance(value, numbers.Number):
-                metrics_dict[key_tf(self.tag, name)] = value
+                metrics_state_attrs_dict[key_tf(self.tag, name)] = value
             elif isinstance(value, torch.Tensor) and value.ndimension() == 0:
-                metrics_dict[key_tf(self.tag, name)] = value.item()
+                metrics_state_attrs_dict[key_tf(self.tag, name)] = value.item()
             elif isinstance(value, torch.Tensor) and value.ndimension() == 1:
                 for i, v in enumerate(value):
-                    metrics_dict[key_tf(self.tag, name, str(i))] = v.item()
+                    metrics_state_attrs_dict[key_tf(self.tag, name, str(i))] = v.item()
             else:
                 if isinstance(value, str) and log_text:
-                    metrics_dict[key_tf(self.tag, name)] = value
+                    metrics_state_attrs_dict[key_tf(self.tag, name)] = value
                 else:
                     warnings.warn(f"Logger output_handler can not log metrics value type {type(value)}")
-        return metrics_dict
+        return metrics_state_attrs_dict
 
 
 class BaseWeightsScalarHandler(BaseHandler):
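To make the renamed helper's behavior concrete, here is a minimal sketch of how metrics and state attributes get flattened into loggable keys. It mirrors the new tests further down; `PrintHandler` is a hypothetical stand-in, since `BaseOutputHandler` leaves `__call__` abstract (the tests use a similar `DummyOutputHandler`):

```python
import torch
from ignite.engine import Engine, State
from ignite.contrib.handlers.base_logger import BaseOutputHandler


class PrintHandler(BaseOutputHandler):
    """Hypothetical concrete subclass; BaseOutputHandler's __call__ is abstract."""

    def __call__(self, engine, logger, event_name):
        print(self._setup_output_metrics_state_attrs(engine, key_tuple=False))


engine = Engine(lambda e, b: None)
engine.state = State(metrics={"a": 0, "b": 1})
engine.state.alpha = 3.899                        # plain number -> "tag/alpha"
engine.state.beta = torch.tensor(5.499)           # 0-dim tensor -> .item() -> "tag/beta"
engine.state.gamma = torch.tensor([2106.0, 6.0])  # 1-dim tensor -> "tag/gamma/0", "tag/gamma/1"

handler = PrintHandler("tag", metric_names=["a", "b"], state_attributes=["alpha", "beta", "gamma"])
handler(engine, logger=None, event_name=None)
# {"tag/a": 0, "tag/b": 1, "tag/alpha": 3.899, "tag/beta": 5.499,
#  "tag/gamma/0": 2106.0, "tag/gamma/1": 6.0}
```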
2 changes: 1 addition & 1 deletion ignite/contrib/handlers/clearml_logger.py
@@ -307,7 +307,7 @@ def __call__(self, engine: Engine, logger: ClearMLLogger, event_name: Union[str,
         if not isinstance(logger, ClearMLLogger):
             raise RuntimeError("Handler OutputHandler works only with ClearMLLogger")
 
-        metrics = self._setup_output_metrics(engine)
+        metrics = self._setup_output_metrics_state_attrs(engine)
 
         global_step = self.global_step_transform(engine, event_name)  # type: ignore[misc]
 
2 changes: 1 addition & 1 deletion ignite/contrib/handlers/mlflow_logger.py
@@ -219,7 +219,7 @@ def __call__(self, engine: Engine, logger: MLflowLogger, event_name: Union[str,
         if not isinstance(logger, MLflowLogger):
             raise TypeError("Handler 'OutputHandler' works only with MLflowLogger")
 
-        rendered_metrics = self._setup_output_metrics(engine)
+        rendered_metrics = self._setup_output_metrics_state_attrs(engine)
 
         global_step = self.global_step_transform(engine, event_name)  # type: ignore[misc]
 
2 changes: 1 addition & 1 deletion ignite/contrib/handlers/neptune_logger.py
@@ -327,7 +327,7 @@ def __call__(self, engine: Engine, logger: NeptuneLogger, event_name: Union[str,
         if not isinstance(logger, NeptuneLogger):
             raise TypeError("Handler OutputHandler works only with NeptuneLogger")
 
-        metrics = self._setup_output_metrics(engine, key_tuple=False)
+        metrics = self._setup_output_metrics_state_attrs(engine, key_tuple=False)
 
         global_step = self.global_step_transform(engine, event_name)  # type: ignore[misc]
 
2 changes: 1 addition & 1 deletion ignite/contrib/handlers/polyaxon_logger.py
@@ -229,7 +229,7 @@ def __call__(self, engine: Engine, logger: PolyaxonLogger, event_name: Union[str
         if not isinstance(logger, PolyaxonLogger):
             raise RuntimeError("Handler 'OutputHandler' works only with PolyaxonLogger")
 
-        metrics = self._setup_output_metrics(engine, key_tuple=False)
+        metrics = self._setup_output_metrics_state_attrs(engine, key_tuple=False)
 
         global_step = self.global_step_transform(engine, event_name)  # type: ignore[misc]
 
25 changes: 22 additions & 3 deletions ignite/contrib/handlers/tensorboard_logger.py
@@ -174,7 +174,7 @@ def _create_opt_params_handler(self, *args: Any, **kwargs: Any) -> "OptimizerPar
 
 
 class OutputHandler(BaseOutputHandler):
-    """Helper handler to log engine's output and/or metrics
+    """Helper handler to log engine's output, engine's state attributes and/or metrics
 
     Examples:
 
@@ -234,6 +234,21 @@ def global_step_transform(*args, **kwargs):
                 global_step_transform=global_step_transform
             )
 
+        Another example where the State Attributes ``trainer.state.alpha`` and ``trainer.state.beta``
+        are also logged along with the NLL and Accuracy after each iteration:
+
+        .. code-block:: python
+
+            tb_logger.attach(
+                trainer,
+                log_handler=OutputHandler(
+                    tag="training",
+                    metric_names=["nll", "accuracy"],
+                    state_attributes=["alpha", "beta"],
+                ),
+                event_name=Events.ITERATION_COMPLETED
+            )
+
     Args:
         tag: common title for all produced plots. For example, "training"
         metric_names: list of metric names to plot or a string "all" to plot all available
@@ -247,6 +262,7 @@ def global_step_transform(*args, **kwargs):
             Default is None, global_step based on attached engine. If provided,
             uses function output as global_step. To setup global step from another engine, please use
            :meth:`~ignite.contrib.handlers.tensorboard_logger.global_step_from_engine`.
+        state_attributes: list of attributes of the ``trainer.state`` to plot.
 
     Note:
 
@@ -265,15 +281,18 @@ def __init__(
         metric_names: Optional[List[str]] = None,
         output_transform: Optional[Callable] = None,
         global_step_transform: Optional[Callable] = None,
+        state_attributes: Optional[List[str]] = None,
     ):
-        super(OutputHandler, self).__init__(tag, metric_names, output_transform, global_step_transform)
+        super(OutputHandler, self).__init__(
+            tag, metric_names, output_transform, global_step_transform, state_attributes
+        )
 
     def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: Union[str, EventEnum]) -> None:
 
         if not isinstance(logger, TensorboardLogger):
             raise RuntimeError("Handler 'OutputHandler' works only with TensorboardLogger")
 
-        metrics = self._setup_output_metrics(engine, key_tuple=False)
+        metrics = self._setup_output_metrics_state_attrs(engine, key_tuple=False)
 
         global_step = self.global_step_transform(engine, event_name)  # type: ignore[misc]
         if not isinstance(global_step, int):
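A runnable sketch of the full loop the new docstring example assumes. The `alpha`/`beta` values, the dummy `train_step`, and the `/tmp/tb_logs` path are illustrative, not part of the PR; any numeric attribute assigned to `engine.state` can be listed in `state_attributes`:

```python
from ignite.engine import Engine, Events
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OutputHandler


def train_step(engine, batch):
    return {"nll": 0.1, "accuracy": 0.9}  # stand-in for a real update function


trainer = Engine(train_step)


@trainer.on(Events.ITERATION_COMPLETED)
def update_state_attributes(engine):
    # Attributes only need to exist on the state by the time the logger fires;
    # this handler is registered first, so it runs before the OutputHandler below.
    engine.state.alpha = 0.1 * engine.state.iteration
    engine.state.beta = 0.9


tb_logger = TensorboardLogger(log_dir="/tmp/tb_logs")
tb_logger.attach(
    trainer,
    log_handler=OutputHandler(
        tag="training",
        output_transform=lambda out: out,    # logs training/nll, training/accuracy
        state_attributes=["alpha", "beta"],  # logs training/alpha, training/beta
    ),
    event_name=Events.ITERATION_COMPLETED,
)
trainer.run([0] * 8, max_epochs=2)
tb_logger.close()
```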
2 changes: 1 addition & 1 deletion ignite/contrib/handlers/tqdm_logger.py
@@ -266,7 +266,7 @@ def __call__(self, engine: Engine, logger: ProgressBar, event_name: Union[str, E
             desc += f" [{global_step}/{max_num_of_closing_events}]"
         logger.pbar.set_description(desc)  # type: ignore[attr-defined]
 
-        rendered_metrics = self._setup_output_metrics(engine, log_text=True)
+        rendered_metrics = self._setup_output_metrics_state_attrs(engine, log_text=True)
         metrics = OrderedDict()
         for key, value in rendered_metrics.items():
             key = "_".join(key[1:])  # tqdm has tag as description
2 changes: 1 addition & 1 deletion ignite/contrib/handlers/visdom_logger.py
@@ -352,7 +352,7 @@ def __call__(self, engine: Engine, logger: VisdomLogger, event_name: Union[str,
         if not isinstance(logger, VisdomLogger):
             raise RuntimeError("Handler 'OutputHandler' works only with VisdomLogger")
 
-        metrics = self._setup_output_metrics(engine, key_tuple=False)
+        metrics = self._setup_output_metrics_state_attrs(engine, key_tuple=False)
 
         global_step = self.global_step_transform(engine, event_name)  # type: ignore[misc]
 
2 changes: 1 addition & 1 deletion ignite/contrib/handlers/wandb_logger.py
@@ -276,7 +276,7 @@ def __call__(self, engine: Engine, logger: WandBLogger, event_name: Union[str, E
                 " Please check the output of global_step_transform."
             )
 
-        metrics = self._setup_output_metrics(engine, log_text=True, key_tuple=False)
+        metrics = self._setup_output_metrics_state_attrs(engine, log_text=True, key_tuple=False)
         logger.log(metrics, step=global_step, sync=self.sync)
 
 
68 changes: 61 additions & 7 deletions tests/ignite/contrib/handlers/test_base_logger.py
@@ -39,7 +39,7 @@ def test_base_output_handler_wrong_setup():
     with pytest.raises(TypeError, match="output_transform should be a function"):
         DummyOutputHandler("tag", metric_names=None, output_transform="abc")
 
-    with pytest.raises(ValueError, match="Either metric_names or output_transform should be defined"):
+    with pytest.raises(ValueError, match="Either metric_names, output_transform or state_attributes should be defined"):
         DummyOutputHandler("tag", None, None)
 
     with pytest.raises(TypeError, match="global_step_transform should be a function"):
@@ -58,36 +58,90 @@ def test_base_output_handler_setup_output_metrics():
 
     # Only metric_names
     handler = DummyOutputHandler("tag", metric_names=["a", "b"], output_transform=None)
-    metrics = handler._setup_output_metrics(engine=engine, key_tuple=False)
+    metrics = handler._setup_output_metrics_state_attrs(engine=engine, key_tuple=False)
     assert metrics == {"tag/a": 0, "tag/b": 1}
 
     # Only metric_names with a warning
     handler = DummyOutputHandler("tag", metric_names=["a", "c"], output_transform=None)
     with pytest.warns(UserWarning):
-        metrics = handler._setup_output_metrics(engine=engine, key_tuple=False)
+        metrics = handler._setup_output_metrics_state_attrs(engine=engine, key_tuple=False)
     assert metrics == {"tag/a": 0}
 
     # Only output as "output"
     handler = DummyOutputHandler("tag", metric_names=None, output_transform=lambda x: x)
-    metrics = handler._setup_output_metrics(engine=engine, key_tuple=False)
+    metrics = handler._setup_output_metrics_state_attrs(engine=engine, key_tuple=False)
     assert metrics == {"tag/output": engine.state.output}
 
     # Only output as "loss"
     handler = DummyOutputHandler("tag", metric_names=None, output_transform=lambda x: {"loss": x})
-    metrics = handler._setup_output_metrics(engine=engine, key_tuple=False)
+    metrics = handler._setup_output_metrics_state_attrs(engine=engine, key_tuple=False)
     assert metrics == {"tag/loss": engine.state.output}
 
     # Metrics and output
     handler = DummyOutputHandler("tag", metric_names=["a", "b"], output_transform=lambda x: {"loss": x})
-    metrics = handler._setup_output_metrics(engine=engine, key_tuple=False)
+    metrics = handler._setup_output_metrics_state_attrs(engine=engine, key_tuple=False)
     assert metrics == {"tag/a": 0, "tag/b": 1, "tag/loss": engine.state.output}
 
     # All metrics
     handler = DummyOutputHandler("tag", metric_names="all", output_transform=None)
-    metrics = handler._setup_output_metrics(engine=engine, key_tuple=False)
+    metrics = handler._setup_output_metrics_state_attrs(engine=engine, key_tuple=False)
     assert metrics == {"tag/a": 0, "tag/b": 1}
 
 
+def test_base_output_handler_setup_output_state_attrs():
+    engine = Engine(lambda engine, batch: None)
+    true_metrics = {"a": 0, "b": 1}
+    engine.state = State(metrics=true_metrics)
+    engine.state.alpha = 3.899
+    engine.state.beta = torch.tensor(5.499)
+    engine.state.gamma = torch.tensor([2106.0, 6.0])
+    engine.state.output = 12345
+
+    # Only State Attributes
+    handler = DummyOutputHandler(
+        tag="tag", metric_names=None, output_transform=None, state_attributes=["alpha", "beta", "gamma"]
+    )
+    state_attrs = handler._setup_output_metrics_state_attrs(engine=engine, key_tuple=False)
+    assert state_attrs == {
+        "tag/alpha": 3.899,
+        "tag/beta": torch.tensor(5.499),
+        "tag/gamma/0": 2106.0,
+        "tag/gamma/1": 6.0,
+    }
+
+    # Metrics and Attributes
+    handler = DummyOutputHandler(
+        tag="tag", metric_names=["a", "b"], output_transform=None, state_attributes=["alpha", "beta", "gamma"]
+    )
+    state_attrs = handler._setup_output_metrics_state_attrs(engine=engine, key_tuple=False)
+    assert state_attrs == {
+        "tag/a": 0,
+        "tag/b": 1,
+        "tag/alpha": 3.899,
+        "tag/beta": torch.tensor(5.499),
+        "tag/gamma/0": 2106.0,
+        "tag/gamma/1": 6.0,
+    }
+
+    # Metrics, Attributes and output
+    handler = DummyOutputHandler(
+        tag="tag",
+        metric_names="all",
+        output_transform=lambda x: {"loss": x},
+        state_attributes=["alpha", "beta", "gamma"],
+    )
+    state_attrs = handler._setup_output_metrics_state_attrs(engine=engine, key_tuple=False)
+    assert state_attrs == {
+        "tag/a": 0,
+        "tag/b": 1,
+        "tag/alpha": 3.899,
+        "tag/beta": torch.tensor(5.499),
+        "tag/gamma/0": 2106.0,
+        "tag/gamma/1": 6.0,
+        "tag/loss": engine.state.output,
+    }
+
+
 def test_opt_params_handler_on_non_torch_optimizers():
     tensor = torch.zeros([1], requires_grad=True)
     base_optimizer = torch.optim.SGD([tensor], lr=0.1234)
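One edge case the new tests leave implicit, sketched in the same style (reusing `DummyOutputHandler` and the `engine` set up in `test_base_output_handler_setup_output_state_attrs` above; the expected warning follows from the `getattr(engine.state, name, None)` default in `base_logger.py`):

```python
# A name listed in state_attributes but never set on the state resolves to
# None, which is neither a number, a 0/1-dim tensor, nor loggable text, so
# the helper emits a UserWarning and skips it rather than raising.
handler = DummyOutputHandler("tag", metric_names=None, output_transform=None, state_attributes=["delta"])
with pytest.warns(UserWarning, match="can not log metrics value type"):
    state_attrs = handler._setup_output_metrics_state_attrs(engine=engine, key_tuple=False)
assert state_attrs == {}
```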
22 changes: 14 additions & 8 deletions tests/ignite/contrib/handlers/test_tensorboard_logger.py
@@ -99,9 +99,11 @@ def test_output_handler_metric_names():
     wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
 
     assert mock_logger.writer.add_scalar.call_count == 2
-    mock_logger.writer.add_scalar.assert_has_calls([call("tag/a", 12.23, 5), call("tag/b", 23.45, 5),], any_order=True)
+    mock_logger.writer.add_scalar.assert_has_calls(
+        [call("tag/a", 12.23, 5), call("tag/b", 23.45, 5),], any_order=True,
+    )
 
-    wrapper = OutputHandler("tag", metric_names=["a",])
+    wrapper = OutputHandler("tag", metric_names=["a",],)
 
     mock_engine = MagicMock()
     mock_engine.state = State(metrics={"a": torch.Tensor([0.0, 1.0, 2.0, 3.0])})
@@ -131,7 +133,9 @@ def test_output_handler_metric_names():
     wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
 
     assert mock_logger.writer.add_scalar.call_count == 1
-    mock_logger.writer.add_scalar.assert_has_calls([call("tag/a", 55.56, 7),], any_order=True)
+    mock_logger.writer.add_scalar.assert_has_calls(
+        [call("tag/a", 55.56, 7),], any_order=True,
+    )
 
     # all metrics
     wrapper = OutputHandler("tag", metric_names="all")
@@ -145,7 +149,9 @@ def test_output_handler_metric_names():
     wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
 
     assert mock_logger.writer.add_scalar.call_count == 2
-    mock_logger.writer.add_scalar.assert_has_calls([call("tag/a", 12.23, 5), call("tag/b", 23.45, 5),], any_order=True)
+    mock_logger.writer.add_scalar.assert_has_calls(
+        [call("tag/a", 12.23, 5), call("tag/b", 23.45, 5),], any_order=True,
+    )
 
     # log a torch tensor (ndimension = 0)
     wrapper = OutputHandler("tag", metric_names="all")
@@ -160,7 +166,7 @@ def test_output_handler_metric_names():
 
     assert mock_logger.writer.add_scalar.call_count == 2
     mock_logger.writer.add_scalar.assert_has_calls(
-        [call("tag/a", torch.tensor(12.23).item(), 5), call("tag/b", torch.tensor(23.45).item(), 5),], any_order=True
+        [call("tag/a", torch.tensor(12.23).item(), 5), call("tag/b", torch.tensor(23.45).item(), 5),], any_order=True,
     )
 
 
@@ -322,7 +328,7 @@ def test_weights_scalar_handler_frozen_layers(dummy_model_factory):
     wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED)
 
     mock_logger.writer.add_scalar.assert_has_calls(
-        [call("weights_norm/fc2/weight", 12.0, 5), call("weights_norm/fc2/bias", math.sqrt(12.0), 5),], any_order=True
+        [call("weights_norm/fc2/weight", 12.0, 5), call("weights_norm/fc2/bias", math.sqrt(12.0), 5),], any_order=True,
     )
 
     with pytest.raises(AssertionError):
@@ -478,12 +484,12 @@ def test_grads_scalar_handler_frozen_layers(dummy_model_factory, norm_mock):
     wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED)
 
     mock_logger.writer.add_scalar.assert_has_calls(
-        [call("grads_norm/fc2/weight", ANY, 5), call("grads_norm/fc2/bias", ANY, 5),], any_order=True
+        [call("grads_norm/fc2/weight", ANY, 5), call("grads_norm/fc2/bias", ANY, 5),], any_order=True,
     )
 
     with pytest.raises(AssertionError):
         mock_logger.writer.add_scalar.assert_has_calls(
-            [call("grads_norm/fc1/weight", ANY, 5), call("grads_norm/fc1/bias", ANY, 5),], any_order=True
+            [call("grads_norm/fc1/weight", ANY, 5), call("grads_norm/fc1/bias", ANY, 5),], any_order=True,
         )
     assert mock_logger.writer.add_scalar.call_count == 2
     assert norm_mock.call_count == 2