Skip to content

Commit

Permalink
Merge pull request #139 from renecotyfanboy/speedup-instrument
Browse files Browse the repository at this point in the history
Speedup data reading
  • Loading branch information
renecotyfanboy authored Apr 4, 2024
2 parents da5e942 + 25b3708 commit d9ed6fc
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 115 deletions.
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "jaxspec"
version = "0.0.3"
version = "0.0.4"
description = "jaxspec is a bayesian spectral fitting library for X-ray astronomy."
authors = ["sdupourque <sdupourque@irap.omp.eu>"]
license = "MIT"
Expand All @@ -17,7 +17,7 @@ numpyro = ">=0.13.2,<0.15.0"
dm-haiku = ">=0.0.11,<0.0.13"
networkx = "^3.1"
matplotlib = "^3.8.0"
arviz = "^0.17.0"
arviz = "^0.17.1"
chainconsumer = "^1.0.0"
simpleeval = "^0.9.13"
cmasher = "^1.6.3"
Expand All @@ -27,6 +27,7 @@ tinygp = "^0.3.0"
seaborn = "^0.13.1"
mkdocstrings = "^0.24.0"
sparse = "^0.15.1"
scipy = "<1.13"


[tool.poetry.group.docs.dependencies]
Expand Down
172 changes: 101 additions & 71 deletions src/jaxspec/data/obsconf.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
import numpy as np
import xarray as xr
import sparse
import scipy
from .instrument import Instrument
from .observation import Observation


def densify_xarray(xarray):
    """
    Build a new DataArray holding the dense version of the input's data.

    Parameters:
        xarray: A DataArray whose `.data` exposes a `todense()` method.

    Returns:
        A DataArray built from `xarray.data.todense()`, reusing the dims, coords, attrs and name of the input.
    """
    dense_values = xarray.data.todense()

    return xr.DataArray(
        dense_values,
        dims=xarray.dims,
        coords=xarray.coords,
        attrs=xarray.attrs,
        name=xarray.name,
    )


class ObsConfiguration(xr.Dataset):
"""
Class to store the data of a folding model, which is the link between the unfolded and folded spectra.
Expand Down Expand Up @@ -56,8 +53,8 @@ def out_energies(self):

out_energies = np.stack(
(
np.asarray(self.coords["e_min_folded"].data.todense(), dtype=np.float64),
np.asarray(self.coords["e_max_folded"].data.todense(), dtype=np.float64),
np.asarray(self.coords["e_min_folded"].data, dtype=np.float64),
np.asarray(self.coords["e_max_folded"].data, dtype=np.float64),
)
)

Expand All @@ -67,92 +64,125 @@ def out_energies(self):
def from_pha_file(
cls, pha_path, rmf_path=None, arf_path=None, bkg_path=None, low_energy: float = 1e-20, high_energy: float = 1e20
):
from .util import data_loader
from .util import data_path_finder

pha, arf, rmf, bkg, metadata = data_loader(pha_path, arf_path=arf_path, rmf_path=rmf_path, bkg_path=bkg_path)
arf_path_default, rmf_path_default, bkg_path_default = data_path_finder(pha_path)

instrument = Instrument.from_matrix(
rmf.sparse_matrix,
arf.specresp if arf is not None else np.ones_like(rmf.energ_lo),
rmf.energ_lo,
rmf.energ_hi,
rmf.e_min,
rmf.e_max,
)
arf_path = arf_path_default if arf_path is None else arf_path
rmf_path = rmf_path_default if rmf_path is None else rmf_path
bkg_path = bkg_path_default if bkg_path is None else bkg_path

if bkg is not None:
backratio = np.where(bkg.backscal > 0.0, pha.backscal / np.where(bkg.backscal > 0, bkg.backscal, 1.0), 0.0)
else:
backratio = np.ones_like(pha.counts)

observation = Observation.from_matrix(
pha.counts,
pha.grouping,
pha.channel,
pha.quality,
pha.exposure,
background=bkg.counts if bkg is not None else None,
backratio=backratio,
attributes=metadata,
)
instrument = Instrument.from_ogip_file(rmf_path, arf_path=arf_path)
observation = Observation.from_pha_file(pha_path, bkg_path=bkg_path)

return cls.from_instrument(instrument, observation, low_energy=low_energy, high_energy=high_energy)

@classmethod
def from_instrument(
cls, instrument: Instrument, observation: Observation, low_energy: float = 1e-20, high_energy: float = 1e20
):
# Exclude the bins flagged with bad quality
quality_filter = observation.quality == 0
grouping = observation.grouping * quality_filter
# First we unpack all the xarray data to classical np array for efficiency
# We also exclude the bins that are flagged with bad quality on the instrument
quality_filter = observation.quality.data == 0
grouping = scipy.sparse.csr_array(observation.grouping.data.to_scipy_sparse()) * quality_filter
e_min_channel = instrument.coords["e_min_channel"].data
e_max_channel = instrument.coords["e_max_channel"].data
e_min_unfolded = instrument.coords["e_min_unfolded"].data
e_max_unfolded = instrument.coords["e_max_unfolded"].data
redistribution = scipy.sparse.csr_array(instrument.redistribution.data.to_scipy_sparse())
area = instrument.area.data
exposure = observation.exposure.data

# Computing the lower and upper energies of the bins after grouping
# This is just a trick to compute it without 10 lines of code
e_min = (xr.where(grouping > 0, grouping, np.nan) * instrument.coords["e_min_channel"]).min(
skipna=True, dim="instrument_channel"
)
grouping_nan = observation.grouping.data * quality_filter
grouping_nan.fill_value = np.nan
e_min = sparse.nanmin(grouping_nan * e_min_channel, axis=1).todense()
e_max = sparse.nanmax(grouping_nan * e_max_channel, axis=1).todense()

e_max = (xr.where(grouping > 0, grouping, np.nan) * instrument.coords["e_max_channel"]).max(
skipna=True, dim="instrument_channel"
)

transfer_matrix = grouping @ (instrument.redistribution * instrument.area * observation.exposure)
transfer_matrix = transfer_matrix.assign_coords({"e_min_folded": e_min, "e_max_folded": e_max})
# Compute the transfer matrix
transfer_matrix = grouping @ (redistribution * area * exposure)

# Exclude bins out of the considered energy range, and bins without contribution from the RMF
row_idx = densify_xarray(((e_min > low_energy) & (e_max < high_energy)) * (grouping.sum(dim="instrument_channel") > 0))

col_idx = densify_xarray(
(instrument.coords["e_min_unfolded"] > 0) * (instrument.redistribution.sum(dim="instrument_channel") > 0)
)

# The transfer matrix is converted locally to csr format to allow FAST slicing
transfer_matrix_scipy = transfer_matrix.data.to_scipy_sparse().tocsr()
transfer_matrix_reduced = transfer_matrix_scipy[row_idx.data][:, col_idx.data]
transfer_matrix_reduced = sparse.COO.from_scipy_sparse(transfer_matrix_reduced)
row_idx = (e_min > low_energy) & (e_max < high_energy) & (grouping.sum(axis=1) > 0)
col_idx = (e_min_unfolded > 0) & (redistribution.sum(axis=0) > 0)

# A dummy zero matrix is put so that the slicing in xarray is fast
transfer_matrix.data = sparse.zeros_like(transfer_matrix.data)
transfer_matrix = transfer_matrix[row_idx][:, col_idx]

# The reduced transfer matrix is put back in the xarray
transfer_matrix.data = transfer_matrix_reduced

folded_counts = observation.folded_counts.copy().where(row_idx, drop=True)
# Apply this reduction to all the relevant arrays
transfer_matrix = sparse.COO.from_scipy_sparse(transfer_matrix[row_idx][:, col_idx])
folded_counts = observation.folded_counts.data[row_idx]
folded_backratio = observation.folded_backratio.data[row_idx]
area = instrument.area.data[col_idx]
e_min_folded = e_min[row_idx]
e_max_folded = e_max[row_idx]
e_min_unfolded = e_min_unfolded[col_idx]
e_max_unfolded = e_max_unfolded[col_idx]

if observation.folded_background is not None:
folded_background = observation.folded_background.copy().where(row_idx, drop=True)

folded_background = observation.folded_background.data[row_idx]
else:
folded_background = None
folded_background = np.zeros_like(folded_counts)

data_dict = {
"transfer_matrix": (
["folded_channel", "unfolded_channel"],
transfer_matrix,
{
"description": "Transfer matrix to use to fold the incoming spectrum. It is built and restricted using the grouping, redistribution matrix, effective area, quality flags and energy bands defined by the user."
},
),
"area": (
["unfolded_channel"],
area,
{"description": "Effective area with the same restrictions as the transfer matrix.", "units": "cm^2"},
),
"exposure": ([], exposure, {"description": "Total exposure", "unit": "s"}),
"folded_counts": (
["folded_channel"],
folded_counts,
{
"description": "Folded counts after grouping, with the same restrictions as the transfer matrix.",
"unit": "photons",
},
),
"folded_backratio": (
["folded_channel"],
folded_backratio,
{"description": "Background scaling after grouping, with the same restrictions as the transfer matrix."},
),
"folded_background": (
["folded_channel"],
folded_background,
{
"description": "Folded background counts after grouping, with the same restrictions as the transfer matrix.",
"unit": "photons",
},
),
}

return cls(
{
"transfer_matrix": transfer_matrix,
"area": instrument.area.copy().where(col_idx, drop=True),
"exposure": observation.exposure,
"folded_backratio": observation.folded_backratio.copy().where(row_idx, drop=True),
"folded_counts": folded_counts,
"folded_background": folded_background,
}
data_dict,
coords={
"e_min_folded": (
["folded_channel"],
e_min_folded,
{"description": "Low energy of folded channel"},
),
"e_max_folded": (
["folded_channel"],
e_max_folded,
{"description": "High energy of folded channel"},
),
"e_min_unfolded": (
["unfolded_channel"],
e_min_unfolded,
{"description": "Low energy of unfolded channel"},
),
"e_max_unfolded": (
["unfolded_channel"],
e_max_unfolded,
{"description": "High energy of unfolded channel"},
),
},
attrs=observation.attrs | instrument.attrs,
)
30 changes: 24 additions & 6 deletions src/jaxspec/data/observation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import numpy as np
import xarray as xr
from .ogip import DataPHA


class Observation(xr.Dataset):
Expand Down Expand Up @@ -95,11 +95,7 @@ def from_matrix(
)

@classmethod
def from_pha_file(cls, pha_file: str | os.PathLike, **kwargs):
from .util import data_loader

pha, arf, rmf, bkg, metadata = data_loader(pha_file)

def from_ogip_container(cls, pha: DataPHA, bkg: DataPHA | None = None, **metadata):
if bkg is not None:
backratio = np.nan_to_num((pha.backscal * pha.exposure * pha.areascal) / (bkg.backscal * bkg.exposure * bkg.areascal))
else:
Expand All @@ -116,6 +112,28 @@ def from_pha_file(cls, pha_file: str | os.PathLike, **kwargs):
attributes=metadata,
)

@classmethod
def from_pha_file(cls, pha_path: str, bkg_path: str | None = None, **metadata):
    """
    Build an observation from a PHA file and, when available, its background file.

    Parameters:
        pha_path: The PHA file path.
        bkg_path: The background file path. When None, the background path returned by
            `data_path_finder(pha_path)` is used instead (which may itself be None).
        **metadata: Extra attributes forwarded to `from_ogip_container`. The keys
            `observation_file`, `background_file`, `response_matrix_file` and
            `ancillary_response_file` are always overwritten with the paths resolved here.

    Returns:
        The result of `cls.from_ogip_container(pha, bkg=bkg, **metadata)`.
    """
    from .util import data_path_finder

    # The ARF/RMF paths are only recorded in the metadata; they are not opened here.
    arf_path, rmf_path, bkg_path_default = data_path_finder(pha_path)

    if bkg_path is None:
        bkg_path = bkg_path_default

    pha = DataPHA.from_file(pha_path)
    bkg = DataPHA.from_file(bkg_path) if bkg_path is not None else None

    # `metadata` comes from **kwargs, so it is always a dict (possibly empty) and never None:
    # no None-guard is needed before updating it.
    metadata.update(
        observation_file=pha_path,
        background_file=bkg_path,
        response_matrix_file=rmf_path,
        ancillary_response_file=arf_path,
    )

    return cls.from_ogip_container(pha, bkg=bkg, **metadata)

def plot_counts(self, **kwargs):
"""
Plot the counts
Expand Down
48 changes: 17 additions & 31 deletions src/jaxspec/data/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from pathlib import Path
from numpy.typing import ArrayLike
from collections.abc import Mapping
from typing import TypeVar
from typing import TypeVar, Tuple
from astropy.io import fits

from .ogip import DataPHA, DataARF, DataRMF
from . import Observation, Instrument, ObsConfiguration
from ..model.abc import SpectralModel
from ..fit import CountForwardModel
Expand Down Expand Up @@ -201,46 +201,32 @@ def obs_model(p):
return fakeits[0] if len(fakeits) == 1 else fakeits


def data_loader(pha_path: str, arf_path=None, rmf_path=None, bkg_path=None):
def data_path_finder(pha_path: str) -> Tuple[str | None, str | None, str | None]:
"""
This function is a convenience function that allows to load PHA, ARF and RMF data
from a given PHA file, using either the ARF/RMF/BKG filenames in the header or the
specified filenames overwritten by the user.
This function tries its best to find the ARF, RMF and BKG files associated with a given PHA file.
Parameters:
pha_path: The PHA file path.
Returns:
arf_path: The ARF file path.
rmf_path: The RMF file path.
bkg_path: The BKG file path.
"""

pha = DataPHA.from_file(pha_path)
directory = str(Path(pha_path).parent)

if arf_path is None:
if pha.ancrfile != "none" and pha.ancrfile != "":
arf_path = find_file_or_compressed_in_dir(pha.ancrfile, directory)

if rmf_path is None:
if pha.respfile != "none" and pha.respfile != "":
rmf_path = find_file_or_compressed_in_dir(pha.respfile, directory)

if bkg_path is None:
if pha.backfile.lower() != "none" and pha.backfile != "":
bkg_path = find_file_or_compressed_in_dir(pha.backfile, directory)
def find_path(file_name: str, directory: str) -> str | None:
if file_name.lower() != "none" and file_name != "":
return find_file_or_compressed_in_dir(file_name, directory)
else:
return None

arf = DataARF.from_file(arf_path) if arf_path is not None else None
rmf = DataRMF.from_file(rmf_path) if rmf_path is not None else None
bkg = DataPHA.from_file(bkg_path) if bkg_path is not None else None
header = fits.getheader(pha_path, "SPECTRUM")
directory = str(Path(pha_path).parent)

metadata = {
"observation_file": pha_path,
"background_file": bkg_path,
"response_matrix_file": rmf_path,
"ancillary_response_file": arf_path,
}
arf_path = find_path(header.get("ANCRFILE", "none"), directory)
rmf_path = find_path(header.get("RESPFILE", "none"), directory)
bkg_path = find_path(header.get("BACKFILE", "none"), directory)

return pha, arf, rmf, bkg, metadata
return arf_path, rmf_path, bkg_path


def find_file_or_compressed_in_dir(path: str | Path, directory: str | Path) -> str:
Expand Down
9 changes: 4 additions & 5 deletions src/jaxspec/model/additive.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,18 +364,17 @@ class Diskbb(AdditiveComponent):
`Diskpbb` with $p=0.75$
??? abstract "Parameters"
* $T_{\text{in}}$ : Temperature at inner disk radius $\left[ \mathrm{keV}\right]$
* $\text{norm}$ : $\cos i(r_{\text{in}}/d)^{2}$,
where $r_{\text{in}}$ is "an apparent" inner disk radius $\left[\text{km}\right]$,
* $d$ the distance to the source in units of $10 \text{kpc}$,
* $i$ the angle of the disk ($i=0$ is face-on)
* $T_{\text{in}}$ : Temperature at inner disk radius $\left[ \mathrm{keV}\right]$
$d$ the distance to the source in units of $10 \text{kpc}$, $i$ the angle of the disk ($i=0$ is face-on)
"""

def continuum(self, energy):
norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
p = 0.75
tin = hk.get_parameter("Tin", [], init=HaikuConstant(1))
tout = 0.0
tin = hk.get_parameter("Tin", [], init=HaikuConstant(1))
norm = hk.get_parameter("norm", [], init=HaikuConstant(1))

# Tout is set to 0 as it is evaluated at R=infinity
def integrand(kT, e, tin, p):
Expand Down
Loading

0 comments on commit d9ed6fc

Please sign in to comment.