Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates to support remote access from Windows #19

Merged
merged 4 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 36 additions & 10 deletions ska_tdb/tdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,16 @@
"""
import os
import re
import glob

import numpy as np
import six

from cheta.fetch import local_or_remote_function

__all__ = ['msids', 'tables', 'set_tdb_version', 'get_tdb_version',
'TableView', 'MsidView']


SKA = os.environ.get('SKA', os.path.join(os.sep, 'proj', 'sot', 'ska'))

# Set None values for module globals that are set in set_tdb_version
TDB_VERSIONS = None
TDB_VERSION = None
Expand Down Expand Up @@ -52,6 +51,35 @@
""".lower().split()


@local_or_remote_function("Get TDB Versions from remote server...")
def get_versions():
import os
from pathlib import Path # noqa
version_dirs = (Path(os.environ["SKA"]) / "data" / "Ska.tdb").glob("p0??")
tdb_versions = sorted(int(vdir.name[2:]) for vdir in version_dirs)
return tdb_versions


@local_or_remote_function("Get TDB Data Directory from remote server...")
def get_data_path(version):
from pathlib import Path # noqa
return str(Path(os.environ["SKA"]) / "data" / "Ska.tdb" / f"p{version:03d}")


@local_or_remote_function("Loading file from remote server...")
def _read_npy_file(item, data_dir):
filename = os.path.join(data_dir, item + '.npy')
tv = TableView(np.load(filename))
return filename, tv


@local_or_remote_function("Finding *.npy files on remote server...")
def _get_all_npy_files(data_dir):
import glob
files = glob.glob(os.path.join(data_dir, '*.npy'))
return files


def set_tdb_version(version=None):
"""
Set the version of the TDB which is used.
Expand All @@ -66,8 +94,7 @@ def set_tdb_version(version=None):
global DATA_DIR
global tables
global msids
version_dirs = glob.glob(os.path.join(SKA, 'data', 'Ska.tdb', 'p0??'))
TDB_VERSIONS = sorted([int(os.path.basename(vdir)[2:]) for vdir in version_dirs])
TDB_VERSIONS = get_versions()

if version is None:
if TDB_VERSIONS:
Expand All @@ -78,7 +105,7 @@ def set_tdb_version(version=None):
raise ValueError('TDB version must be one of the following: {}'.format(TDB_VERSIONS))

TDB_VERSION = version
DATA_DIR = os.path.join(SKA, 'data', 'Ska.tdb', 'p{:03d}'.format(TDB_VERSION))
DATA_DIR = get_data_path(version)
tables = TableDict()
msids = MsidView()

Expand All @@ -91,18 +118,17 @@ def get_tdb_version():


class TableDict(dict):

def __getitem__(self, item):
if item not in self:
try:
filename = os.path.join(DATA_DIR, item + '.npy')
self[item] = TableView(np.load(filename))
filename, self[item] = _read_npy_file(item, DATA_DIR)
except IOError:
raise KeyError("Table {} not in TDB files (no file {})".format(item, filename))
return dict.__getitem__(self, item)

def keys(self):
import glob
files = glob.glob(os.path.join(DATA_DIR, '*.npy'))
files = _get_all_npy_files(DATA_DIR)
return [os.path.basename(x)[:-4] for x in files]


Expand Down
10 changes: 10 additions & 0 deletions ska_tdb/tests/test_tdb.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
# Licensed under a 3-clause BSD style license - see LICENSE.rst
import os
import pytest

from .. import msids, tables, set_tdb_version, get_tdb_version
from ska_tdb.tdb import get_data_path

SKA_ACCESS_REMOTELY = os.environ.get("SKA_ACCESS_REMOTELY") == "True"
# Set to fixed version for regression testing
TDB_VERSION = 14
set_tdb_version(TDB_VERSION)
Expand All @@ -13,6 +18,11 @@
'DESCRIPTION', 'EHS_HEADER_FLAG')


@pytest.mark.skipif(not SKA_ACCESS_REMOTELY, reason="remote access not being tested")
def test_remote_access():
assert str(get_data_path(10)) == "/proj/sot/ska3/flight/data/Ska.tdb/p010"


def test_msids():
tephin = msids['tephin']
assert tephin.Tmsrment.colnames == tmsrment_colnames
Expand Down