diff --git a/straxen/scada.py b/straxen/scada.py index dc12916ed..2252e4839 100644 --- a/straxen/scada.py +++ b/straxen/scada.py @@ -18,16 +18,13 @@ from configparser import NoOptionError import sys -if any('jupyter' in arg for arg in sys.argv): - # In some cases we are not using any notebooks, - # Taken from 44952863 on stack overflow thanks! - from tqdm import tqdm_notebook as tqdm -else: - from tqdm import tqdm - export, __all__ = strax.exporter() +# Fancy tqdm style in notebooks +tqdm = strax.utils.tqdm + + @export class SCADAInterface: @@ -153,12 +150,12 @@ def get_scada_values(self, self._get_token() # Now loop over specified parameters and get the values for those. - iterator = enumerate(parameters.items()) - if self._use_progress_bar: - # wrap using progress bar - iterator = tqdm(iterator, total=len(parameters), desc='Load parameters') - - for ind, (k, p) in iterator: + for ind, (k, p) in tqdm( + enumerate(parameters.items()), + total=len(parameters), + desc='Load parameters', + disable=not self._use_progress_bar, + ): try: temp_df = self._query_single_parameter(start, end, k, p, @@ -175,7 +172,7 @@ def get_scada_values(self, f' {p} does not match the previous timestamps.') except ValueError as e: warnings.warn(f'Was not able to load parameters for "{k}". The reason was: "{e}".' - f'Continue without {k}.') + f'Continue without {k}.') temp_df = pd.DataFrame(columns=(k,)) if ind: diff --git a/tests/test_scada.py b/tests/test_scada.py index c5e1c15a3..f215ffed9 100644 --- a/tests/test_scada.py +++ b/tests/test_scada.py @@ -1,3 +1,5 @@ +import warnings +import pytz import numpy as np import straxen import unittest @@ -14,6 +16,71 @@ def setUp(self): self.start += 10**6 self.end = self.start + 5*10**9 + def test_wrong_querries(self): + parameters = {'SomeParameter': 'XE1T.CTPC.Board06.Chan011.VMon'} + + with self.assertRaises(ValueError): + # Runid but no context + df = self.sc.get_scada_values(parameters, + run_id='1', + every_nth_value=1, + query_type_lab=False, ) + + with self.assertRaises(ValueError): + # No time range specified + df = self.sc.get_scada_values(parameters, + every_nth_value=1, + query_type_lab=False, ) + + with self.assertRaises(ValueError): + # Start larger end + df = self.sc.get_scada_values(parameters, + start=2, + end=1, + every_nth_value=1, + query_type_lab=False, ) + + with self.assertRaises(ValueError): + # Start and/or end not in ns unix time + df = self.sc.get_scada_values(parameters, + start=1, + end=2, + every_nth_value=1, + query_type_lab=False, ) + + def test_pmt_names(self): + """ + Tests different query options for pmt list. + """ + pmts_dict = self.sc.find_pmt_names(pmts=12, current=True) + assert 'PMT12_HV' in pmts_dict.keys() + assert 'PMT12_I' in pmts_dict.keys() + assert pmts_dict['PMT12_HV'] == 'XE1T.CTPC.BOARD04.CHAN003.VMON' + + pmts_dict = self.sc.find_pmt_names(pmts=(12, 13)) + assert 'PMT12_HV' in pmts_dict.keys() + assert 'PMT13_HV' in pmts_dict.keys() + + with self.assertRaises(ValueError): + self.sc.find_pmt_names(pmts=12, current=False, hv=False) + + def test_token_expires(self): + self.sc.token_expires_in() + + def test_convert_timezone(self): + parameters = {'SomeParameter': 'XE1T.CTPC.Board06.Chan011.VMon'} + df = self.sc.get_scada_values(parameters, + start=self.start, + end=self.end, + every_nth_value=1, + query_type_lab=False, ) + + df_strax = straxen.convert_time_zone(df, tz='strax') + assert df_strax.index.dtype.type is np.int64 + + df_etc = straxen.convert_time_zone(df, tz='Etc/GMT+0') + assert df_etc.index.dtype.tz is pytz.timezone('Etc/GMT+0') + def test_query_sc_values(self): """ Unity test for the SCADAInterface. Query a fixed range and check if @@ -40,6 +107,13 @@ def test_query_sc_values(self): query_type_lab=False,) assert np.all(np.isclose(df[:4], 2.079859)), 'First four values deviate from queried values.' assert np.all(np.isclose(df[4:], 2.117820)), 'Last two values deviate from queried values.' + print('Testing interpolation option:') + self.sc.get_scada_values(parameters, + start=self.start, + end=self.end, + fill_gaps='interpolation', + every_nth_value=1, + query_type_lab=False,) print('Testing down sampling and averaging option:') parameters = {'SomeParameter': 'XE1T.CRY_TE101_TCRYOBOTT_AI.PI'}