Skip to content

Commit

Permalink
Use cached datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
akihironitta committed Nov 29, 2020
1 parent 2e903c3 commit 4771734
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions tests/callbacks/test_data_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
)
@mock.patch("pl_bolts.callbacks.data_monitor.DataMonitorBase.log_histogram")
def test_base_log_interval_override(
log_histogram, tmpdir, log_every_n_steps, max_steps, expected_calls
log_histogram, tmpdir, log_every_n_steps, max_steps, expected_calls, datadir
):
""" Test logging interval set by log_every_n_steps argument. """
monitor = TrainingDataMonitor(log_every_n_steps=log_every_n_steps)
model = LitMNIST(num_workers=0)
model = LitMNIST(data_dir=datadir, num_workers=0)
trainer = Trainer(
default_root_dir=tmpdir,
log_every_n_steps=1,
Expand All @@ -43,11 +43,11 @@ def test_base_log_interval_override(
)
@mock.patch("pl_bolts.callbacks.data_monitor.DataMonitorBase.log_histogram")
def test_base_log_interval_fallback(
log_histogram, tmpdir, log_every_n_steps, max_steps, expected_calls
log_histogram, tmpdir, log_every_n_steps, max_steps, expected_calls, datadir
):
""" Test that if log_every_n_steps not set in the callback, fallback to what is defined in the Trainer. """
monitor = TrainingDataMonitor()
model = LitMNIST(num_workers=0)
model = LitMNIST(data_dir=datadir, num_workers=0)
trainer = Trainer(
default_root_dir=tmpdir,
log_every_n_steps=log_every_n_steps,
Expand Down Expand Up @@ -81,10 +81,10 @@ def test_base_unsupported_logger_warning(tmpdir):


@mock.patch("pl_bolts.callbacks.data_monitor.TrainingDataMonitor.log_histogram")
def test_training_data_monitor(log_histogram, tmpdir):
def test_training_data_monitor(log_histogram, tmpdir, datadir):
""" Test that the TrainingDataMonitor logs histograms of data points going into training_step. """
monitor = TrainingDataMonitor()
model = LitMNIST()
model = LitMNIST(data_dir=datadir)
trainer = Trainer(
default_root_dir=tmpdir, log_every_n_steps=1, callbacks=[monitor],
)
Expand Down

0 comments on commit 4771734

Please sign in to comment.