Skip to content

Commit

Permalink
Introduce Imager2 for simple images with just source and lens and no …
Browse files Browse the repository at this point in the history
…metadata output
  • Loading branch information
jiwoncpark committed Jan 2, 2021
1 parent 74c2093 commit 575e21b
Showing 1 changed file with 118 additions and 1 deletion.
119 changes: 118 additions & 1 deletion baobab/sim_utils/image_utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,132 @@
import copy
import sys
import numpy as np
# Lenstronomy modules
from lenstronomy.LensModel.lens_model import LensModel
from lenstronomy.LightModel.light_model import LightModel
from lenstronomy.ImSim.image_model import ImageModel
from baobab.sim_utils import mag_to_amp_extended, mag_to_amp_point, get_lensed_total_flux, get_unlensed_total_flux_numerical
from lenstronomy.LensModel.Solver.lens_equation_solver import LensEquationSolver
from lenstronomy.SimulationAPI.data_api import DataAPI
from lenstronomy.PointSource.point_source import PointSource

from baobab.sim_utils import psf_utils


__all__ = ['Imager']
__all__ = ['Imager', 'Imager2']

class Imager2:
"""Dev-mode class, more flexible than Imager.
Note
----
Accompanying `generate` script doesn't exist yet.
"""
def __init__(self, lens_model_list, src_model_list,
n_pix, pixel_scale,
psf_type, psf_kernel_size=None, which_psf_maps=None,
kwargs_numerics={'supersampling_factor': 1}):
# Define models
self.lens_model = LensModel(lens_model_list=lens_model_list)
self.src_model = LightModel(light_model_list=src_model_list)
#self.ps_model = ps_model
#self.lens_light_model = lens_light_model
# Set detector specs
self.n_pix = n_pix
self.pixel_scale = pixel_scale
self.psf_type = psf_type
self.psf_kernel_size = psf_kernel_size
self.which_psf_maps = which_psf_maps
self.kwargs_numerics = kwargs_numerics
# Initialize kwargs (must be set using setter)
self._survey = None
self._lens_kwargs = None
self._src_kwargs = None
#self._ps_kwargs = None
#self._lens_light_kwargs = None

@property
def survey_kwargs(self):
"""Ordered dict containing detector information. Length is number of
bandpasses. Should be set before the model kwargs.
"""
return self._survey_kwargs

@survey_kwargs.setter
def survey_kwargs(self, survey_kwargs):
survey_name = survey_kwargs['survey_name']
bandpass_list = survey_kwargs['bandpass_list']
coadd_years = survey_kwargs.get('coadd_years')
override_obs_kwargs = survey_kwargs.get('override_obs_kwargs', {})
override_camera_kwargs = survey_kwargs.get('override_camera_kwargs', {})

import lenstronomy.SimulationAPI.ObservationConfig as ObsConfig
from importlib import import_module
sys.path.insert(0, ObsConfig.__path__[0])
SurveyClass = getattr(import_module(survey_name), survey_name)
self._data_api = [] # init
self._image_model = [] # init
for bp in bandpass_list:
survey_obj = SurveyClass(band=bp,
psf_type=self.psf_type,
coadd_years=coadd_years)
# Override as specified in survey_kwargs
survey_obj.camera.update(override_camera_kwargs)
survey_obj.obs.update(override_obs_kwargs)
# This is what we'll actually use
kwargs_detector = survey_obj.kwargs_single_band()
data_api = DataAPI(self.n_pix, **kwargs_detector)
psf_model = psf_utils.get_PSF_model(self.psf_type,
self.pixel_scale,
seeing=kwargs_detector['seeing'],
kernel_size=self.psf_kernel_size,
which_psf_maps=self.which_psf_maps)
image_model_bp = ImageModel(data_api.data_class,
psf_model,
self.lens_model,
self.src_model,
None,
None,
kwargs_numerics=self.kwargs_numerics)
self._data_api.append(data_api)
self._image_model.append(image_model_bp)

@property
def lens_kwargs(self):
return self._lens_kwargs

@lens_kwargs.setter
def lens_kwargs(self, lens_kwargs):
self._lens_kwargs = lens_kwargs

@property
def src_kwargs(self):
return self._src_kwargs

@src_kwargs.setter
def src_kwargs(self, src_kwargs):
for i, data_api_bp in enumerate(self._data_api):
# Convert magnitude to amp recognized by the profile
if 'magnitude' in src_kwargs[i]:
src_kwargs[i] = mag_to_amp_extended([src_kwargs[i]],
self.src_model,
data_api_bp)[0]
self._src_kwargs = src_kwargs

def generate_image(self):
n_filters = len(self._image_model)
img_canvas = np.empty([n_filters, self.n_pix, self.n_pix])
for i, image_model_bp in enumerate(self._image_model):
img = image_model_bp.image(self.lens_kwargs,
self.src_kwargs,
None, None,
lens_light_add=False,
point_source_add=False)
img = np.maximum(0.0, img) # safeguard against negative pixel values
img_canvas[i, :, :] = img
return img_canvas

class Imager:
"""Deterministic utility class for imaging the objects on a pixel grid
Expand Down

0 comments on commit 575e21b

Please sign in to comment.