Skip to content

Commit

Permalink
Merge pull request #365 from vhaasteren/pulsar_mods
Browse files Browse the repository at this point in the history
Modifications to the Pulsar class
  • Loading branch information
vhaasteren authored May 9, 2024
2 parents a4876a9 + 61ba269 commit 37e24d2
Show file tree
Hide file tree
Showing 2 changed files with 176 additions and 48 deletions.
149 changes: 103 additions & 46 deletions enterprise/pulsar.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
"""Class containing pulsar data from timing package [tempo2/PINT].
"""

import contextlib
import json
import logging
import os
import pickle
from io import StringIO

import numpy as np
from ephem import Ecliptic, Equatorial
Expand Down Expand Up @@ -159,20 +161,18 @@ def filter_data(self, start_time=None, end_time=None):

self.sort_data()

def drop_not_picklable(self):
"""Drop all attributes that cannot be pickled.
Derived classes should implement this if they have
any such attributes.
"""
pass

def to_pickle(self, outdir=None):
"""Save object to pickle file."""

# drop t2pulsar object
if hasattr(self, "t2pulsar"):
del self.t2pulsar
msg = "t2pulsar object cannot be pickled and has been removed."
logger.warning(msg)

if hasattr(self, "pint_toas"):
del self.pint_toas
del self.model
msg = "pint_toas and model objects cannot be pickled and have been removed."
logger.warning(msg)
self.drop_not_picklable()

if outdir is None:
outdir = os.getcwd()
Expand Down Expand Up @@ -315,7 +315,7 @@ def sunssb(self):

@property
def telescope(self):
"""Return telescope vector at all timestamps"""
"""Return telescope name at all timestamps"""
return self._telescope[self._isort]


Expand All @@ -327,7 +327,11 @@ def __init__(self, toas, model, sort=True, drop_pintpsr=True, planets=True):

if not drop_pintpsr:
self.model = model
self.parfile = model.as_parfile()
self.pint_toas = toas
with StringIO() as tim:
toas.write_TOA_file(tim)
self.timfile = tim.getvalue()

# these are TDB but not barycentered
# self._toas = np.array(toas.table["tdbld"], dtype="float64") * 86400
Expand All @@ -336,20 +340,17 @@ def __init__(self, toas, model, sort=True, drop_pintpsr=True, planets=True):
self._stoas = np.array(toas.get_mjds().value, dtype="float64") * 86400
self._residuals = np.array(resids(toas, model).time_resids.to(u.s), dtype="float64")
self._toaerrs = np.array(toas.get_errors().to(u.s), dtype="float64")
self._designmatrix = model.designmatrix(toas)[0]
self._designmatrix, self.fitpars, self.designmatrix_units = model.designmatrix(toas)
self._ssbfreqs = np.array(model.barycentric_radio_freq(toas), dtype="float64")
self._telescope = np.array(toas.get_obss())

# fitted parameters
self.fitpars = ["Offset"] + [par for par in model.params if not getattr(model, par).frozen]

# gather DM/DMX information if available
self._set_dm(model)

# set parameters
spars = [par for par in model.params]
self.setpars = [sp for sp in spars if sp not in self.fitpars]
self.setpars = [sp for sp in model.params if sp not in self.fitpars]

# FIXME: this can be done more cleanly using PINT
self._flags = {}
for ii, obsflags in enumerate(toas.get_flags()):
for jj, flag in enumerate(obsflags):
Expand All @@ -360,6 +361,7 @@ def __init__(self, toas, model, sort=True, drop_pintpsr=True, planets=True):

# convert flags to arrays
# TODO probably better way to do this
# -- PINT always stores flags as strings
for key, val in self._flags.items():
if isinstance(val[0], u.quantity.Quantity):
self._flags[key] = np.array([v.value for v in val])
Expand All @@ -384,6 +386,21 @@ def __init__(self, toas, model, sort=True, drop_pintpsr=True, planets=True):

self.sort_data()

def drop_pintpsr(self):
with contextlib.suppress(NameError):
del self.model
del self.parfile
del self.pint_toas
del self.timfile

def drop_not_picklable(self):
with contextlib.suppress(AttributeError):
del self.model
del self.pint_toas
logger.warning("pint_toas and model objects cannot be pickled and have been removed.")

return super().drop_not_picklable()

def _set_dm(self, model):
pars = [par for par in model.params if not getattr(model, par).frozen]

Expand Down Expand Up @@ -460,7 +477,15 @@ def _get_sunssb(self, toas, model):


class Tempo2Pulsar(BasePulsar):
def __init__(self, t2pulsar, sort=True, drop_t2pulsar=True, planets=True):
def __init__(
self,
t2pulsar,
sort=True,
drop_t2pulsar=True,
planets=True,
par_name=None,
tim_name=None,
):
self._sort = sort
self.t2pulsar = t2pulsar
self.planets = planets
Expand Down Expand Up @@ -507,6 +532,15 @@ def __init__(self, t2pulsar, sort=True, drop_t2pulsar=True, planets=True):
self.sort_data()

if drop_t2pulsar:
self.drop_tempopsr()
else:
if par_name is not None and os.path.exists(par_name):
self.parfile = open(par_name).read()
if tim_name is not None and os.path.exists(tim_name):
self.timfile = open(tim_name).read()

def drop_tempopsr(self):
with contextlib.suppress(NameError):
del self.t2pulsar

# gather DM/DMX information if available
Expand Down Expand Up @@ -569,7 +603,7 @@ def _get_sunssb(self, t2pulsar):
sunssb = None
if self.planets:
# for ii in range(1, 10):
# tag = 'DMASSPLANET' + str(ii)
# tag = 'DMASSPLANET' + str(ii)@pytest.mark.skipif(t2 is None, reason="TEMPO2/libstempo not available")
# self.t2pulsar[tag].val = 0.0
self.t2pulsar.formbats()
sunssb = np.zeros((len(self._toas), 6))
Expand All @@ -586,6 +620,12 @@ def _get_sunssb(self, t2pulsar):
# then replace them with pickleable objects that can be inflated
# to numpy arrays with SharedMemory storage

def drop_not_picklable(self):
with contextlib.suppress(AttributeError):
del self.t2pulsar
logger.warning("t2pulsar object cannot be pickled and has been removed.")
return super().drop_not_picklable()

_todeflate = ["_designmatrix", "_planetssb", "_sunssb", "_flags"]
_deflated = "pristine"

Expand Down Expand Up @@ -622,7 +662,9 @@ def Pulsar(*args, **kwargs):
sort = kwargs.get("sort", True)
drop_t2pulsar = kwargs.get("drop_t2pulsar", True)
drop_pintpsr = kwargs.get("drop_pintpsr", True)
timing_package = kwargs.get("timing_package", "tempo2")
timing_package = kwargs.get("timing_package", None)
if timing_package is not None:
timing_package = timing_package.lower()

if pint is not None:
toas = [x for x in args if isinstance(x, TOAs)]
Expand Down Expand Up @@ -650,31 +692,46 @@ def Pulsar(*args, **kwargs):
reltimfile = timfiletup[-1]
relparfile = os.path.relpath(parfile[0], dirname)

if timing_package is None:
if t2 is not None:
timing_package = "tempo2"
elif pint is not None: # pragma: no cover
timing_package = "pint"
else: # pragma: no cover
raise ValueError("No timing package available with which to load a pulsar")

# get current directory
cwd = os.getcwd()

# Change directory to the base directory of the tim-file to deal with
# INCLUDE statements in the tim-file
os.chdir(dirname)

if timing_package.lower() == "pint":
if (clk is not None) and (bipm_version is None):
bipm_version = clk.split("(")[1][:-1]
model, toas = get_model_and_toas(
relparfile,
reltimfile,
ephem=ephem,
bipm_version=bipm_version,
planets=planets,
)
os.chdir(cwd)
return PintPulsar(toas, model, sort=sort, drop_pintpsr=drop_pintpsr, planets=planets)

elif timing_package.lower() == "tempo2":
# hack to set maxobs
maxobs = get_maxobs(reltimfile) + 100
t2pulsar = t2.tempopulsar(relparfile, reltimfile, maxobs=maxobs, ephem=ephem, clk=clk)
try:
# Change directory to the base directory of the tim-file to deal with
# INCLUDE statements in the tim-file
os.chdir(dirname)
if timing_package.lower() == "tempo2":
if t2 is None: # pragma: no cover
raise ValueError("tempo2 requested but tempo2 is not available")
# hack to set maxobs
maxobs = get_maxobs(reltimfile) + 100
t2pulsar = t2.tempopulsar(relparfile, reltimfile, maxobs=maxobs, ephem=ephem, clk=clk)
return Tempo2Pulsar(
t2pulsar,
sort=sort,
drop_t2pulsar=drop_t2pulsar,
planets=planets,
par_name=relparfile,
tim_name=reltimfile,
)
elif timing_package.lower() == "pint":
if pint is None: # pragma: no cover
raise ValueError("PINT requested but PINT is not available")
if (clk is not None) and (bipm_version is None):
bipm_version = clk.split("(")[1][:-1]
model, toas = get_model_and_toas(
relparfile, reltimfile, ephem=ephem, bipm_version=bipm_version, planets=planets
)
os.chdir(cwd)
return PintPulsar(toas, model, sort=sort, drop_pintpsr=drop_pintpsr, planets=planets)
else:
raise ValueError(f"Unknown timing package {timing_package}")
finally:
os.chdir(cwd)
return Tempo2Pulsar(t2pulsar, sort=sort, drop_t2pulsar=drop_t2pulsar, planets=planets)

raise ValueError("Unknown arguments {}".format(args))
raise ValueError("Pulsar (par/tim) not specified in {args} or {kwargs}")
75 changes: 73 additions & 2 deletions tests/test_pulsar.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,49 @@
from pint.models import get_model_and_toas


class TestTimingPackageExceptions(unittest.TestCase):
def test_unkown_timing_package(self):
# initialize Pulsar class
with self.assertRaises(ValueError):
self.psr = Pulsar(
datadir + "/B1855+09_NANOGrav_9yv1.gls.par",
datadir + "/B1855+09_NANOGrav_9yv1.tim",
timing_package="foobar",
)

def test_clk_but_no_bipm(self):
self.psr = Pulsar(
datadir + "/B1855+09_NANOGrav_9yv1.gls.par",
datadir + "/B1855+09_NANOGrav_9yv1.tim",
clk="TT(BIPM2020)",
timing_package="pint",
)


class TestPulsar(unittest.TestCase):
@classmethod
def setUpClass(cls):
"""Setup the Pulsar object."""

# initialize Pulsar class
cls.psr = Pulsar(datadir + "/B1855+09_NANOGrav_9yv1.gls.par", datadir + "/B1855+09_NANOGrav_9yv1.tim")
cls.psr = Pulsar(
datadir + "/B1855+09_NANOGrav_9yv1.gls.par", datadir + "/B1855+09_NANOGrav_9yv1.tim", drop_t2pulsar=True
)

@classmethod
def tearDownClass(cls):
shutil.rmtree("pickle_dir", ignore_errors=True)

def test_droppsr(self):
self.psr_nodrop = Pulsar(
datadir + "/B1855+09_NANOGrav_9yv1.gls.par", datadir + "/B1855+09_NANOGrav_9yv1.tim", drop_t2pulsar=False
)

self.psr_nodrop.drop_tempopsr()

with self.assertRaises(AttributeError):
_ = self.psr.t2pulsar

def test_residuals(self):
"""Check Residual shape."""

Expand Down Expand Up @@ -195,13 +226,53 @@ def setUpClass(cls):

# initialize Pulsar class
cls.psr = Pulsar(
datadir + "/B1855+09_NANOGrav_9yv1.gls.par",
datadir + "/B1855+09_NANOGrav_9yv1.tim",
ephem="DE430",
drop_pintpsr=True,
timing_package="pint",
)

def test_droppsr(self):
self.psr_nodrop = Pulsar(
datadir + "/B1855+09_NANOGrav_9yv1.gls.par",
datadir + "/B1855+09_NANOGrav_9yv1.tim",
ephem="DE430",
drop_pintpsr=False,
timing_package="pint",
)

self.psr_nodrop.drop_pintpsr()

with self.assertRaises(AttributeError):
_ = self.psr_nodrop.model

with self.assertRaises(AttributeError):
_ = self.psr_nodrop.parfile

with self.assertRaises(AttributeError):
_ = self.psr_nodrop.pint_toas

with self.assertRaises(AttributeError):
_ = self.psr_nodrop.timfile

def test_drop_not_picklable(self):
self.psr_nodrop = Pulsar(
datadir + "/B1855+09_NANOGrav_9yv1.gls.par",
datadir + "/B1855+09_NANOGrav_9yv1.tim",
ephem="DE430",
drop_pintpsr=False,
timing_package="pint",
)

self.psr_nodrop.drop_not_picklable()

with self.assertRaises(AttributeError):
_ = self.psr_nodrop.model

with self.assertRaises(AttributeError):
_ = self.psr_nodrop.pint_toas

def test_deflate_inflate(self):
pass

Expand All @@ -225,7 +296,7 @@ def test_no_planet(self):
model, toas = get_model_and_toas(
datadir + "/J0030+0451_NANOGrav_9yv1.gls.par", datadir + "/J0030+0451_NANOGrav_9yv1.tim", planets=False
)
Pulsar(model, toas, planets=True)
Pulsar(model, toas, planets=True, drop_pintpsr=False)
msg = "obs_earth_pos is not in toas.table.colnames. Either "
msg += "`planet` flag is not True in `toas` or further Pint "
msg += "development to add additional planets is needed."
Expand Down

0 comments on commit 37e24d2

Please sign in to comment.