diff --git a/straxen/storage/rucio_remote.py b/straxen/storage/rucio_remote.py index 508f77ff4..39c504fbb 100644 --- a/straxen/storage/rucio_remote.py +++ b/straxen/storage/rucio_remote.py @@ -33,7 +33,9 @@ class RucioRemoteFrontend(strax.StorageFrontend): local_did_cache = None path = None - def __init__(self, download_heavy=False, staging_dir="./strax_data", *args, **kwargs): + def __init__( + self, download_heavy=False, staging_dir="./strax_data", rses_only=tuple(), *args, **kwargs + ): """ :param download_heavy: option to allow downloading of heavy data through RucioRemoteBackend :param args: Passed to strax.StorageFrontend @@ -46,7 +48,7 @@ def __init__(self, download_heavy=False, staging_dir="./strax_data", *args, **kw if HAVE_ADMIX: self.backends = [ - RucioRemoteBackend(staging_dir, download_heavy=download_heavy), + RucioRemoteBackend(staging_dir, download_heavy=download_heavy, rses_only=rses_only), ] else: self.log.warning( @@ -91,7 +93,7 @@ class RucioRemoteBackend(strax.FileSytemBackend): # for caching RSE locations dset_cache: Dict[str, str] = {} - def __init__(self, staging_dir, download_heavy=False, **kwargs): + def __init__(self, staging_dir, download_heavy=False, rses_only=tuple(), **kwargs): """ :param staging_dir: Path (a string) where to save data. Must be a writable location. @@ -114,13 +116,19 @@ def __init__(self, staging_dir, download_heavy=False, **kwargs): super().__init__(**kwargs) self.staging_dir = staging_dir self.download_heavy = download_heavy + self.rses_only = strax.to_str_tuple(rses_only) + + def _get_rse(self, dset_did): + rses = admix.rucio.get_rses(dset_did) + rses = list(set(rses) & set(self.rses_only)) if self.rses_only else rses + rse = admix.downloader.determine_rse(rses) + return rse def _get_metadata(self, dset_did, **kwargs): if dset_did in self.dset_cache: rse = self.dset_cache[dset_did] else: - rses = admix.rucio.get_rses(dset_did) - rse = admix.downloader.determine_rse(rses) + rse = self._get_rse(dset_did) self.dset_cache[dset_did] = rse metadata_did = strax.RUN_METADATA_PATTERN % dset_did @@ -156,8 +164,7 @@ def _read_chunk(self, dset_did, chunk_info, dtype, compressor): if dset_did in self.dset_cache: rse = self.dset_cache[dset_did] else: - rses = admix.rucio.get_rses(dset_did) - rse = admix.downloader.determine_rse(rses) + rse = self._get_rse(dset_did) self.dset_cache[dset_did] = rse downloaded = admix.download(chunk_did, rse=rse, location=self.staging_dir)