diff --git a/Ska/engarchive/fetch.py b/Ska/engarchive/fetch.py index 4dae7433..0961c087 100644 --- a/Ska/engarchive/fetch.py +++ b/Ska/engarchive/fetch.py @@ -1424,7 +1424,10 @@ class MSIDset(collections.OrderedDict): """ MSID = MSID - def __init__(self, msids, start=LAUNCH_DATE, stop=None, filter_bad=False, stat=None): + def __init__(self, msids=None, start=LAUNCH_DATE, stop=None, filter_bad=False, stat=None): + if msids is None: + msids = [] + super(MSIDset, self).__init__() intervals = _get_table_intervals_as_list(start, check_overlaps=True) diff --git a/Ska/engarchive/tests/test_fetch.py b/Ska/engarchive/tests/test_fetch.py index 5710720d..6d78f63b 100644 --- a/Ska/engarchive/tests/test_fetch.py +++ b/Ska/engarchive/tests/test_fetch.py @@ -1,6 +1,5 @@ # Licensed under a 3-clause BSD style license - see LICENSE.rst -from __future__ import print_function, division, absolute_import - +import pickle from copy import deepcopy import numpy as np @@ -69,6 +68,35 @@ def test_filter_bad_times_list(): assert np.all(dates == DATES_EXPECT1) +@pytest.mark.parametrize('stat', [None, '5min']) +def test_pickle_MSID(stat): + """Test pickling of MSID objects""" + msid = 'aoattqt1' + start, stop = '2022:001:00:00:00', '2022:001:00:15:00' + dat = fetch.MSID(msid, start, stop, stat=stat) + dat2 = pickle.loads(pickle.dumps(dat)) + attrs = ('tstart', 'tstop', 'datestart', 'datestop', 'data_source', 'content') + for attr in attrs: + assert getattr(dat, attr) == getattr(dat2, attr) + for attr in ('times', 'vals'): + assert np.all(getattr(dat, attr) == getattr(dat2, attr)) + + +@pytest.mark.parametrize('stat', [None, '5min']) +def test_pickle_MSIDset(stat): + """Test pickling of MSIDset objects""" + msid = 'aoattqt1' + start, stop = '2022:001:00:00:00', '2022:001:00:15:00' + dat = fetch.MSIDset([msid], start, stop, stat=stat) + dat2 = pickle.loads(pickle.dumps(dat)) + attrs = ('tstart', 'tstop', 'datestart', 'datestop') + assert dat.keys() == dat2.keys() + for attr in attrs: + assert getattr(dat, attr) == getattr(dat2, attr) + for attr in ('times', 'vals') + attrs: + assert np.all(getattr(dat[msid], attr) == getattr(dat2[msid], attr)) + + def test_msidset_filter_bad_times_list(): dat = fetch.MSIDset(['aogyrct1'], '2008:291:12:00:00', '2008:298:12:00:00') dat.filter_bad_times(table=BAD_TIMES)