Skip to content
This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

Commit

Permalink
getting datagen tests working
Browse files Browse the repository at this point in the history
  • Loading branch information
alecgunny committed Nov 22, 2022
1 parent 961a7aa commit 640de50
Show file tree
Hide file tree
Showing 8 changed files with 74 additions and 59 deletions.
14 changes: 7 additions & 7 deletions projects/sandbox/datagen/environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,17 @@ channels:
- conda-forge
dependencies:
- python=3.9
- gwpy=2.1.0
- astropy<5.0.0
- gwpy
- astropy
# frame file I/O dependencies
- python-ldas-tools-framecpp
- python-nds2-client
# omicron dependencies
- python-ligo-lw<1.8.0
- python-ligo-lw
- pyomicron
- omicron=2.4.2
- dqsegdb2<1.1.0
- uproot>=4.3
- omicron
- dqsegdb2
- uproot
# injection dependencies
- bilby<1.2
- lalsuite>=7.4
- lalsuite
4 changes: 3 additions & 1 deletion projects/sandbox/datagen/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@ typeo = {git = "https://github.com/ML4GW/typeo.git", branch = "main"}
ml4gw = {path = "../../../ml4gw", develop = true}
bilby = "<1.2" # set this so ml4gw doesn't update it on accident

"bbhnet.parallelize" = {path = "../../../libs/parallelize", develop = true}
"bbhnet.io" = {path = "../../../libs/io", develop = true}
"bbhnet.logging" = {path = "../../../libs/logging", develop = true}

[tool.poetry.group.dev.dependencies]
pytest = "^6.2"
pytest = "^7.0"

[build-system]
requires = ["poetry>=1.2"]
Expand Down
6 changes: 4 additions & 2 deletions projects/sandbox/datagen/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@

@pytest.fixture(scope="function")
def datadir(tmp_path):
datadir = tmp_path.mkdir(parents=True, exist_ok=False) / "data"
datadir = tmp_path / "data"
datadir.mkdir(parents=True, exist_ok=False)
return datadir


@pytest.fixture(scope="function")
def logdir(tmp_path):
logdir = tmp_path.mkdir(parents=True, exist_ok=False) / "log"
logdir = tmp_path / "log"
logdir.mkdir(parents=True, exist_ok=False)
yield logdir
logging.shutdown()
4 changes: 3 additions & 1 deletion projects/sandbox/datagen/tests/scripts/test_background.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ def test_generate_background(
ts = TimeSeries(np.ones(n_samples), times=times)

mock_ts = patch("gwpy.timeseries.TimeSeries.read", return_value=ts)
mock_datafind = patch("generate_background.find_urls", return_value=None)
mock_datafind = patch(
"datagen.scripts.background.find_urls", return_value=None
)
with mock_ts, mock_datafind:
generate_background(
start,
Expand Down
1 change: 0 additions & 1 deletion projects/sandbox/datagen/tests/scripts/test_glitches.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def trig_file(ifo):


def test_generate_glitch_dataset(
data_dir,
ifo,
window,
sample_rate,
Expand Down
95 changes: 53 additions & 42 deletions projects/sandbox/datagen/tests/scripts/test_timeslides.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from unittest.mock import patch
from concurrent.futures._base import FINISHED
from unittest.mock import Mock, patch

import numpy as np
import pytest
Expand Down Expand Up @@ -40,11 +41,6 @@ def n_slides(request):
return request.param


@pytest.fixture(params=[10, 4096])
def file_length(request):
return request.param


@pytest.fixture(params=[32])
def highpass(request):
return request.param
Expand Down Expand Up @@ -85,7 +81,22 @@ def state_flag(request):
return request.param


def submit_mock(f, *args, **kwargs):
result = Mock()
result.result = Mock(return_value=f(*args, **kwargs))
result._state = FINISHED
return result


pool_mock = Mock()
pool_mock.submit = submit_mock


@patch("bbhnet.parallelize.AsyncExecutor.__enter__", return_value=pool_mock)
@patch("bbhnet.parallelize.AsyncExecutor.__exit__")
def test_timeslide_injections_no_segments(
mock1,
mock2,
logdir,
datadir,
prior,
Expand All @@ -94,7 +105,6 @@ def test_timeslide_injections_no_segments(
buffer_,
n_slides,
shifts,
file_length,
ifos,
minimum_frequency,
highpass,
Expand All @@ -115,23 +125,22 @@ def test_timeslide_injections_no_segments(
mock_datafind = patch("gwdatafind.find_urls", return_value=None)
with mock_datafind, mock_ts:
generate_timeslides(
start,
stop,
logdir,
datadir,
prior,
spacing,
jitter,
buffer_,
n_slides,
shifts,
ifos,
file_length,
minimum_frequency,
highpass,
sample_rate,
frame_type,
channel,
start=start,
stop=stop,
logdir=logdir,
datadir=datadir,
prior=prior,
spacing=spacing,
jitter=jitter,
buffer_=buffer_,
n_slides=n_slides,
shifts=shifts,
ifos=ifos,
minimum_frequency=minimum_frequency,
highpass=highpass,
sample_rate=sample_rate,
frame_type=frame_type,
channel=channel,
)

timeslides = datadir.iterdir()
Expand Down Expand Up @@ -161,7 +170,11 @@ def test_timeslide_injections_no_segments(
assert (injection_ts.path / "params.h5").exists()


@patch("bbhnet.parallelize.AsyncExecutor.__enter__", return_value=pool_mock)
@patch("bbhnet.parallelize.AsyncExecutor.__exit__")
def test_timeslide_injections_with_segments(
mock1,
mock2,
logdir,
datadir,
prior,
Expand All @@ -170,7 +183,6 @@ def test_timeslide_injections_with_segments(
buffer_,
n_slides,
shifts,
file_length,
ifos,
minimum_frequency,
highpass,
Expand Down Expand Up @@ -213,23 +225,22 @@ def fake_read(*args, **kwargs):

with mock_datafind, mock_ts, mock_segments:
generate_timeslides(
start,
stop,
logdir,
datadir,
prior,
spacing,
jitter,
buffer_,
n_slides,
shifts,
ifos,
file_length,
minimum_frequency,
highpass,
sample_rate,
frame_type,
channel,
start=start,
stop=stop,
logdir=logdir,
datadir=datadir,
prior=prior,
spacing=spacing,
jitter=jitter,
buffer_=buffer_,
n_slides=n_slides,
shifts=shifts,
ifos=ifos,
minimum_frequency=minimum_frequency,
highpass=highpass,
sample_rate=sample_rate,
frame_type=frame_type,
channel=channel,
state_flag=state_flag,
)

Expand Down
8 changes: 4 additions & 4 deletions projects/sandbox/datagen/tests/scripts/test_waveforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def prior(request):


def test_check_file_contents(
data_dir,
log_dir,
datadir,
logdir,
n_samples,
waveform_duration,
sample_rate,
Expand All @@ -47,8 +47,8 @@ def test_check_file_contents(
signal_file = generate_waveforms(
prior,
n_samples,
log_dir,
data_dir,
logdir,
datadir,
reference_frequency,
minimum_frequency,
sample_rate,
Expand Down
1 change: 0 additions & 1 deletion projects/sandbox/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,6 @@ ifos = "${base.ifos}"
channel = "${base.channel}"
frame_type = "${base.frame_type}"
state_flag = "${base.state_flag}"
file_length = 4096
min_segment_length = 1024

# timeslide parameters
Expand Down

0 comments on commit 640de50

Please sign in to comment.