Skip to content

Commit

Permalink
Update tests, remove repeat code
Browse files Browse the repository at this point in the history
  • Loading branch information
emirkmo committed Mar 10, 2022
1 parent cce8148 commit 7d264ec
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 65 deletions.
46 changes: 1 addition & 45 deletions flows/load_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,48 +528,6 @@ def get_obstime(self):
'PairTel': PairTel, 'TJO': TJO}


def edge_mask(img, value=0):
"""
Create boolean mask of given value near edge of image.
Parameters:
img (ndarray): Image of
value (float): Value to detect near edge. Default=0.
Returns:
ndarray: Pixel mask with given values on the edge of image.
.. codeauthor:: Rasmus Handberg <rasmush@phys.au.dk>
"""

mask1 = (img == value)
mask = np.zeros_like(img, dtype='bool')

# Mask entire rows and columns which are only the value:
mask[np.all(mask1, axis=1), :] = True
mask[:, np.all(mask1, axis=0)] = True

# Detect "uneven" edges column-wise in image:
a = np.argmin(mask1, axis=0)
b = np.argmin(np.flipud(mask1), axis=0)
for col in range(img.shape[1]):
if mask1[0, col]:
mask[:a[col], col] = True
if mask1[-1, col]:
mask[-b[col]:, col] = True

# Detect "uneven" edges row-wise in image:
a = np.argmin(mask1, axis=1)
b = np.argmin(np.fliplr(mask1), axis=1)
for row in range(img.shape[0]):
if mask1[row, 0]:
mask[row, :a[row]] = True
if mask1[row, -1]:
mask[row, -b[row]:] = True

return mask


def load_image(filename: str, target_coord: typing.Union[coords.SkyCoord, typing.Tuple[float, float]] = None):
"""
Load FITS image using FlowsImage class and Instrument Classes.
Expand All @@ -581,10 +539,8 @@ def load_image(filename: str, target_coord: typing.Union[coords.SkyCoord, typing
for all other images it is ignored.
Returns:
FlowsImage: instance of FlowsImage with valuues populated based on instrument.
FlowsImage: instance of FlowsImage with values populated based on instrument.
.. codeauthor:: Emir Karamehmetoglu <emir.k@phys.au.dk>
.. codeauthor:: Rasmus Handberg <rasmush@phys.au.dk>
"""
ext = 0 # Default extension in HDUList, individual instruments may override this.
mask = None # Instrument can override, default is to only mask all non-finite values, override is additive.
Expand Down
40 changes: 20 additions & 20 deletions tests/test_load_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,15 @@
import os.path
import conftest # noqa: F401
from tendrils import api
from flows.load_image import load_image
from flows.load_image import load_image, FlowsImage, instruments

# Get list of all available filters:
ALL_FILTERS = set(api.get_filters().keys())

#--------------------------------------------------------------------------------------------------
@pytest.mark.parametrize('fpath,siteid', [
['SN2020aatc_K_20201213_495s.fits.gz', 13],
['ADP.2021-10-15T11_40_06.553.fits.gz', 2],
#['TJO2459406.56826_V_imc.fits.gz', 22],
#['lsc1m009-fa04-20210704-0044-e91_v1.fits.gz', 4],
#['SN2021rcp_59409.931159242_B.fits.gz', 22],
#['SN2021rhu_59465.86130221_B.fits.gz', 22],
#['20200613_SN2020lao_u_stacked_meandiff.fits.gz', 1],
#['2021aess_20220104_K.fits.gz', 5],
#['2021aess_B01_20220207v1.fits.gz', 5],
])
['ADP.2021-10-15T11_40_06.553.fits.gz', 2],])
def test_load_image(fpath, siteid):
# Get list of all available filters:
all_filters = set(api.get_filters().keys())

# The test input directory containing the test-images:
INPUT_DIR = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'input')
Expand All @@ -45,18 +36,27 @@ def test_load_image(fpath, siteid):
# Check the attributes of the image object:
assert isinstance(img.image, np.ndarray)
assert img.image.dtype in ('float32', 'float64')
assert isinstance(img.mask, np.ndarray)
assert img.mask.dtype == 'bool'
if img.mask is not None:
assert isinstance(img.mask, np.ndarray)
assert img.mask.dtype == 'bool'
assert isinstance(img.clean, np.ma.MaskedArray)
assert img.clean.dtype == img.image.dtype
assert isinstance(img.obstime, Time)

assert isinstance(img.exptime, float)
assert img.exptime > 0
assert img.exptime > 0.
assert isinstance(img.wcs, WCS)
assert isinstance(img.site, dict)
assert img.site['siteid'] == siteid

assert isinstance(img.photfilter, str)
assert img.photfilter in all_filters
assert img.photfilter in ALL_FILTERS


def test_instruments():
for instrument_name, instrument_class in instruments.items():
instrument = instrument_class()
# get site:
site = api.get_site(instrument.siteid)

assert site['siteid'] == instrument.siteid


# --------------------------------------------------------------------------------------------------
Expand Down

0 comments on commit 7d264ec

Please sign in to comment.