Skip to content

Commit

Permalink
use memmap when reading in model grids
Browse files Browse the repository at this point in the history
  • Loading branch information
keflavich committed Jun 30, 2022
1 parent ec8722e commit d7ef406
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 11 deletions.
7 changes: 5 additions & 2 deletions sedfitter/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class Fitter(object):

def __init__(self, filter_names, apertures, model_dir,
extinction_law=None, av_range=None, distance_range=None,
remove_resolved=False):
remove_resolved=False, use_memmap=True):

validate_array('apertures', apertures, domain='positive', ndim=1, physical_type='angle')
validate_array('distance_range', distance_range, domain='positive', ndim=1, shape=(2,), physical_type='length')
Expand All @@ -82,7 +82,10 @@ def __init__(self, filter_names, apertures, model_dir,
self.filters.append(filt)

# Read in models
self.models = Models.read(model_dir, self.filters, distance_range=distance_range, remove_resolved=remove_resolved)
self.models = Models.read(model_dir, self.filters,
distance_range=distance_range,
remove_resolved=remove_resolved,
use_memmap=use_memmap)

# Add wavelength to filters
for i, f in enumerate(self.filters):
Expand Down
32 changes: 25 additions & 7 deletions sedfitter/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ def log_fluxes_mJy(self):
return values

@classmethod
def read(cls, directory, filters, distance_range=None, remove_resolved=False):
def read(cls, directory, filters, distance_range=None,
remove_resolved=False, use_memmap=True):
modpar = parfile.read("%s/models.conf" % directory, 'conf')
if modpar.get('version', 1) == 1:
return cls._read_version_1(directory, filters,
Expand All @@ -143,7 +144,8 @@ def read(cls, directory, filters, distance_range=None, remove_resolved=False):
else:
return cls._read_version_2(directory, filters,
distance_range=distance_range,
remove_resolved=remove_resolved)
remove_resolved=remove_resolved,
use_memmap=use_memmap)

@classmethod
def _read_version_1(cls, directory, filters, distance_range=None, remove_resolved=None):
Expand Down Expand Up @@ -231,7 +233,7 @@ def _read_version_1(cls, directory, filters, distance_range=None, remove_resolve
return m

@classmethod
def _read_version_2(cls, directory, filters, distance_range=None, remove_resolved=None):
def _read_version_2(cls, directory, filters, distance_range=None, remove_resolved=None, use_memmap=True):

m = cls()

Expand Down Expand Up @@ -268,16 +270,32 @@ def _read_version_2(cls, directory, filters, distance_range=None, remove_resolve

# Start off by reading in main flux cube
from .sed.cube import SEDCube
cube = SEDCube.read(os.path.join(directory, 'flux.fits'))
cube = SEDCube.read(os.path.join(directory, 'flux.fits'), memmap=use_memmap)

if use_memmap:
from tempfile import mkdtemp
import os.path as path
mffilename = path.join(mkdtemp(), 'model_fluxes.dat')

# Initialize model flux array and array to indicate whether models are
# extended
if m.n_distances is None:
model_fluxes = np.zeros((cube.n_models, len(filters))) * u.mJy
shape = (cube.n_models, len(filters))
if use_memmap:
model_fluxes = np.memmap(mffilename, dtype='float32', mode='w+', shape=shape) * u.mJy
else:
model_fluxes = np.zeros(shape) * u.mJy
extended = None
else:
model_fluxes = np.zeros((cube.n_models, m.n_distances, len(filters))) * u.mJy
extended = np.zeros((cube.n_models, m.n_distances, len(filters)), dtype=bool)
shape = (cube.n_models, m.n_distances, len(filters))
if use_memmap:
model_fluxes = np.memmap(mffilename, dtype='float32', mode='w+', shape=shape) * u.mJy
extfilename = path.join(mkdtemp(), 'extended.dat')
extended = np.memmap(extfilename, shape=shape, dtype=bool, mode='w+')
else:
model_fluxes = np.zeros(shape) * u.mJy
extended = np.zeros(shape, dtype=bool)
print(f"Data shape={shape}. use_memmap={use_memmap}")

# Define empty wavelength array
m.wavelengths = np.zeros(len(filters)) * u.micron
Expand Down
4 changes: 2 additions & 2 deletions sedfitter/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,10 +317,10 @@ def plot(input_fits, output_dir=None, select_format=("N", 1), plot_max=None,

if flux.ndim > 1:
for j in range(flux.shape[1]):
lines.append(np.column_stack([s.wav, flux[:, j]]))
lines.append(np.column_stack([s.wav.value if hasattr(s.wav, 'value') else s.wav, flux[:, j]]))
colors.append(color[color_type][j])
else:
lines.append(np.column_stack([s.wav, flux]))
lines.append(np.column_stack([s.wav.value if hasattr(s.wav, 'value') else s.wav, flux]))
colors.append(color[color_type])

if show_convolved:
Expand Down

0 comments on commit d7ef406

Please sign in to comment.