Skip to content

Commit

Permalink
fix comments
Browse files Browse the repository at this point in the history
  • Loading branch information
chenmoneygithub committed Oct 25, 2024
1 parent c686a5e commit 8cf1378
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 52 deletions.
13 changes: 4 additions & 9 deletions composer/loggers/mlflow_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,6 @@ class MLFlowLogger(LoggerDestination):
log_duplicated_metric_every_n_steps (int, optional): The number of steps to wait before
logging the duplicated metric value. Duplicated metric value means the new step has the
same value as the previous step. (default: ``100``)
log_duplicated_metric_every_n_millis (int, optional): The number of milliseconds to wait
before logging the duplicated metric value. (default: ``600000``)
"""

def __init__(
Expand All @@ -146,7 +144,6 @@ def __init__(
resume: bool = False,
logging_buffer_seconds: Optional[int] = 10,
log_duplicated_metric_every_n_steps: int = 100,
log_duplicated_metric_every_n_millis: int = 600000,
) -> None:
try:
import mlflow
Expand Down Expand Up @@ -189,7 +186,6 @@ def __init__(
mlflow.set_system_metrics_sampling_interval(5)

self.log_duplicated_metric_every_n_steps = log_duplicated_metric_every_n_steps
self.log_duplicated_metric_every_n_millis = log_duplicated_metric_every_n_millis
self._metrics_cache = {}

self._rank_zero_only = rank_zero_only
Expand Down Expand Up @@ -403,25 +399,24 @@ def log_metrics(self, metrics: dict[str, Any], step: Optional[int] = None) -> No

metrics_to_log = {}
step = step or 0
current_time_millis = int(time.time() * 1000)
for k, v in metrics.items():
if any(fnmatch.fnmatch(k, pattern) for pattern in self.ignore_metrics):
continue
if k in self._metrics_cache:
value, last_step, last_time = self._metrics_cache[k]
if value == v and step < last_step + self.log_duplicated_metric_every_n_steps and current_time_millis < last_time + self.log_duplicated_metric_every_n_millis:
value, last_step = self._metrics_cache[k]
if value == v and step < last_step + self.log_duplicated_metric_every_n_steps:
# Skip logging the metric if it has the same value as the last step and it's
# within the step and time window.
continue
else:
# Log the metric if it has a different value or it's outside the step and time
# window, and update the metrics cache.
self._metrics_cache[k] = (v, step, current_time_millis)
self._metrics_cache[k] = (v, step)
metrics_to_log[self.rename(k)] = float(v)
else:
# Log the metric if it's the first time it's being logged, and update the metrics
# cache.
self._metrics_cache[k] = (v, step, current_time_millis)
self._metrics_cache[k] = (v, step)
metrics_to_log[self.rename(k)] = float(v)

log_metrics(
Expand Down
60 changes: 17 additions & 43 deletions tests/loggers/test_mlflow_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,7 @@ def test_mlflow_ignore_metrics(self, num_batches, device, ignore_metrics, expect
logger = MLFlowLogger(
tracking_uri=tmp_path / Path('my-test-mlflow-uri'),
ignore_metrics=ignore_metrics,
log_duplicated_metric_every_n_steps=0,
)

file_path = self.run_trainer(logger, num_batches)
Expand Down Expand Up @@ -854,52 +855,25 @@ def test_mlflow_logging_with_metrics_dedupping(tmp_path):
run_name='test_run',
logging_buffer_seconds=2,
log_duplicated_metric_every_n_steps=3,
log_duplicated_metric_every_n_millis=10000,
)
test_mlflow_logger.init(state=mock_state, logger=mock_logger)
# # Test dedupping of metrics and duplicated metrics get logged per
# # `log_duplicated_metric_every_n_steps` steps.
# steps = 10
# for i in range(steps):
# # 'foo' always have different values, while 'bar' always have the same value.
# metrics = {
# 'foo': i,
# 'bar': 0,
# }
# test_mlflow_logger.log_metrics(metrics, step=i)

# if i % 3 == 0:
# # 'bar' will be logged every 3 steps.
# mock_log_metrics.assert_called_with(metrics={'foo': float(i), 'bar': 0.0}, step=i, synchronous=False)
# else:
# # 'bar' will not be logged.
# mock_log_metrics.assert_called_with(metrics={'foo': float(i)}, step=i, synchronous=False)

# Test dedupping of metrics and duplicated metrics get logged per
# `log_duplicated_metric_every_n_millis` milliseconds.
timestamps = [0, 5, 10, 15, 20, 25, 30, 35, 40, 45]
# Reset the metrics cache.
test_mlflow_logger._metrics_cache = {}
with patch('time.time', side_effect=timestamps):
for i in range(len(timestamps)):
# 'foo' always have different values, while 'bar' always have the same value.
metrics = {
'foo': i,
'bar': 0,
}
test_mlflow_logger.log_metrics(metrics, step=0)

if i % 2 == 0:
# 'bar' will be logged every 2 steps.
mock_log_metrics.assert_called_with(
metrics={
'foo': float(i),
'bar': 0.0,
}, step=0, synchronous=False,
)
else:
# 'bar' will not be logged.
mock_log_metrics.assert_called_with(metrics={'foo': float(i)}, step=0, synchronous=False)
# `log_duplicated_metric_every_n_steps` steps.
steps = 10
for i in range(steps):
# 'foo' always have different values, while 'bar' always have the same value.
metrics = {
'foo': i,
'bar': 0,
}
test_mlflow_logger.log_metrics(metrics, step=i)

if i % 3 == 0:
# 'bar' will be logged every 3 steps.
mock_log_metrics.assert_called_with(metrics={'foo': float(i), 'bar': 0.0}, step=i, synchronous=False)
else:
# 'bar' will not be logged.
mock_log_metrics.assert_called_with(metrics={'foo': float(i)}, step=i, synchronous=False)

test_mlflow_logger.post_close()

Expand Down

0 comments on commit 8cf1378

Please sign in to comment.