Skip to content

Commit

Permalink
Specify available RSE in RucioRemoteBackend (#1435)
Browse files Browse the repository at this point in the history
  • Loading branch information
dachengx authored Oct 4, 2024
1 parent a412b43 commit 35917e7
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions straxen/storage/rucio_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ class RucioRemoteFrontend(strax.StorageFrontend):
local_did_cache = None
path = None

def __init__(self, download_heavy=False, staging_dir="./strax_data", *args, **kwargs):
def __init__(
self, download_heavy=False, staging_dir="./strax_data", rses_only=tuple(), *args, **kwargs
):
"""
:param download_heavy: option to allow downloading of heavy data through RucioRemoteBackend
:param args: Passed to strax.StorageFrontend
Expand All @@ -46,7 +48,7 @@ def __init__(self, download_heavy=False, staging_dir="./strax_data", *args, **kw

if HAVE_ADMIX:
self.backends = [
RucioRemoteBackend(staging_dir, download_heavy=download_heavy),
RucioRemoteBackend(staging_dir, download_heavy=download_heavy, rses_only=rses_only),
]
else:
self.log.warning(
Expand Down Expand Up @@ -91,7 +93,7 @@ class RucioRemoteBackend(strax.FileSytemBackend):
# for caching RSE locations
dset_cache: Dict[str, str] = {}

def __init__(self, staging_dir, download_heavy=False, **kwargs):
def __init__(self, staging_dir, download_heavy=False, rses_only=tuple(), **kwargs):
"""
:param staging_dir: Path (a string) where to save data. Must be
a writable location.
Expand All @@ -114,13 +116,19 @@ def __init__(self, staging_dir, download_heavy=False, **kwargs):
super().__init__(**kwargs)
self.staging_dir = staging_dir
self.download_heavy = download_heavy
self.rses_only = strax.to_str_tuple(rses_only)

def _get_rse(self, dset_did):
    """Choose the RSE from which ``dset_did`` should be fetched.

    The RSEs reported by ``admix.rucio.get_rses`` are restricted to the ones
    listed in ``self.rses_only`` whenever that tuple is non-empty; the
    remaining candidates are handed to ``admix.downloader.determine_rse``.

    :param dset_did: DID of the dataset to look up.
    :return: the value returned by ``admix.downloader.determine_rse`` for
        the (possibly filtered) candidate list.
    """
    rses = admix.rucio.get_rses(dset_did)
    if self.rses_only:
        allowed = set(self.rses_only)
        # Filter instead of intersecting two sets: a set intersection throws
        # away the order returned by get_rses and produces a hash-dependent
        # order, so the candidate list could differ between interpreter runs.
        # dict.fromkeys keeps the de-duplication the intersection provided.
        rses = list(dict.fromkeys(rse for rse in rses if rse in allowed))
        # NOTE(review): if no reported RSE is in rses_only, determine_rse is
        # called with an empty list -- confirm how admix handles that case.
    return admix.downloader.determine_rse(rses)

def _get_metadata(self, dset_did, **kwargs):
if dset_did in self.dset_cache:
rse = self.dset_cache[dset_did]
else:
rses = admix.rucio.get_rses(dset_did)
rse = admix.downloader.determine_rse(rses)
rse = self._get_rse(dset_did)
self.dset_cache[dset_did] = rse

metadata_did = strax.RUN_METADATA_PATTERN % dset_did
Expand Down Expand Up @@ -156,8 +164,7 @@ def _read_chunk(self, dset_did, chunk_info, dtype, compressor):
if dset_did in self.dset_cache:
rse = self.dset_cache[dset_did]
else:
rses = admix.rucio.get_rses(dset_did)
rse = admix.downloader.determine_rse(rses)
rse = self._get_rse(dset_did)
self.dset_cache[dset_did] = rse

downloaded = admix.download(chunk_did, rse=rse, location=self.staging_dir)
Expand Down

0 comments on commit 35917e7

Please sign in to comment.