Skip to content

Commit

Permalink
refactor image.py masking.
Browse files Browse the repository at this point in the history
  • Loading branch information
emirkmo committed Feb 8, 2023
1 parent 0b79d64 commit 2ab39d7
Showing 1 changed file with 27 additions and 12 deletions.
39 changes: 27 additions & 12 deletions flows/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

from enum import Enum
import numpy as np
from numpy.typing import NDArray
from dataclasses import dataclass
import warnings
from typing import Union
from astropy.time import Time
from astropy.wcs import WCS, FITSFixedWarning
from typing import Tuple, Dict, Any, Optional
from typing import Tuple, Dict, Any, Optional, TypeGuard
from .utilities import create_logger
logger = create_logger()

Expand Down Expand Up @@ -43,38 +44,52 @@ class FlowsImage:
subclean: Optional[np.ma.MaskedArray] = None
error: Optional[np.ma.MaskedArray] = None

def __post_init__(self):
def __post_init__(self) -> None:
self.shape = self.image.shape
self.wcs = self.create_wcs()
# Make empty mask
if self.mask is None:
self.mask = np.zeros_like(self.image, dtype='bool')
self.check_finite()
# Create mask
self.initialize_mask()

def initialize_mask(self) -> None:
self.update_mask(self.mask)

def check_finite(self):
self.mask |= ~np.isfinite(self.image)
def check_finite(self) -> None:
if self.ensure_mask(self.mask):
self.mask |= ~np.isfinite(self.image)

def mask_non_linear(self) -> None:
if self.peakmax is None:
return
if self.ensure_mask(self.mask):
self.mask |= self.image >= self.peakmax

def ensure_mask(self, mask: Optional[np.ndarray]) -> TypeGuard[NDArray[np.bool_]]:
if mask is None:
self.mask = np.zeros_like(self.image, dtype='bool')
return True

def update_mask(self, mask):
def update_mask(self, mask) -> None:
self.mask = mask
self.check_finite()
self.mask_non_linear()

def create_wcs(self) -> WCS:
with warnings.catch_warnings():
warnings.simplefilter('ignore', category=FITSFixedWarning)
return WCS(header=self.header, relax=True)

def create_masked_image(self):
def create_masked_image(self) -> None:
"""Warning: this is destructive and will overwrite image data setting masked values to NaN"""
self.image[self.mask] = np.NaN
self.clean = np.ma.masked_array(data=self.image, mask=self.mask, copy=False)

def set_edge_rows_to_value(self, y: Tuple[float] = None, value: Union[int, float, np.float64] = 0):
def set_edge_rows_to_value(self, y: Tuple[float] = None, value: Union[int, float, np.float64] = 0) -> None:
if y is None:
pass
for row in y:
self.image[row] = value

def set_edge_columns_to_value(self, x: Tuple[float] = None, value: Union[int, float, np.float64] = 0):
def set_edge_columns_to_value(self, x: Tuple[float] = None, value: Union[int, float, np.float64] = 0) -> None:
if x is None:
pass
for col in x:
Expand Down

0 comments on commit 2ab39d7

Please sign in to comment.