Skip to content

Commit

Permalink
feat(registry): make registry iterable (yields protocol names)
Browse files Browse the repository at this point in the history
  • Loading branch information
hbredin committed Dec 15, 2023
1 parent bac3225 commit da5794b
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 35 deletions.
5 changes: 5 additions & 0 deletions doc/source/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
Changelog
#########

develop
~~~~~~~

- feat(registry): make registry iterable (yields protocol names)

Version 5.0.1 (2023-04-21)
~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
85 changes: 50 additions & 35 deletions pyannote/database/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,13 @@
from .database import Database
import yaml


# controls what to do in case of protocol name conflict
class LoadingMode(Enum):
    """Strategy applied when a loaded protocol name conflicts with an existing one.

    Each member must be defined exactly once: `Enum` raises `TypeError`
    when a member name is reused inside the class body.
    """

    # override existing protocol
    OVERRIDE = 0
    # keep existing protocol
    KEEP = 1
    # raise an error
    ERROR = 2


# To ease the understanding of future me, all comments inside Registry codebase
# assume the existence of the following database.yml files.
Expand All @@ -52,7 +54,7 @@ class LoadingMode(Enum):
# Content of /path/to/first/database.yml
# ======================================
# Databases:
# DatabaseA:
# DatabaseA:
# - relative/path/A/trn/{uri}.wav
# - relative/path/A/dev/{uri}.wav
# - relative/path/A/tst/{uri}.wav
Expand Down Expand Up @@ -95,7 +97,7 @@ class LoadingMode(Enum):
# DatabaseC:
# SpeakerDiarization:
# Protocol:
# ...
# ...


class Registry:
Expand All @@ -109,13 +111,12 @@ class Registry:
"""

def __init__(self) -> None:

# Mapping of database.yml paths to their config in a dictionary
# Example after loading both database.yml:
# {"/path/to/first/database.yml": {
# "Databases":{
# "DatabaseA": ["relative/path/A/trn/{uri}.wav", "relative/path/A/dev/{uri}.wav", "relative/path/A/tst/{uri}.wav"],
# "DatabaseB": "/absolute/path/B/{uri}.wav"
# "DatabaseB": "/absolute/path/B/{uri}.wav"
# },
# "Protocols":{
# "DatabaseA":{
Expand All @@ -134,7 +135,7 @@ def __init__(self) -> None:
# "/path/to/second/database.yml": {
# "Databases":{
# "DatabaseC": "/absolute/path/C/{uri}.wav",
# "DatabaseB": "/absolute/path/B/{uri}.wav"
# "DatabaseB": "/absolute/path/B/{uri}.wav"
# },
# "Protocols":{
# "DatabaseB":{"SpeakerDiarization": {"Protocol": {...}}},
Expand All @@ -144,7 +145,6 @@ def __init__(self) -> None:
# }
self.configs: Dict[Path, Dict] = dict()


# Content of the "Database" root item (= where to find file content)
# Example after loading both database.yml:
# {
Expand Down Expand Up @@ -176,7 +176,7 @@ def load_database(
Parameters
----------
path : str or Path
Path to YAML configuration file.
Path to YAML configuration file.
mode : LoadingMode, optional
Controls how to handle conflicts in protocol names.
Defaults to overriding the existing protocol.
Expand Down Expand Up @@ -210,7 +210,7 @@ def _load_database_helper(
# make path absolute
database_yml = Path(database_yml).expanduser().resolve()

# stop here if configuration file is already being loaded
# stop here if configuration file is already being loaded
# (possibly because of circular requirements)
if database_yml in loading:
return
Expand All @@ -221,9 +221,9 @@ def _load_database_helper(
# load configuration
with open(database_yml, "r") as f:
config = yaml.load(f, Loader=yaml.SafeLoader)

# load every requirement
requirements = config.pop('Requirements', list())
requirements = config.pop("Requirements", list())
if not isinstance(requirements, list):
requirements = [requirements]
for requirement_yaml in requirements:
Expand All @@ -244,9 +244,7 @@ def _load_database_helper(

# load protocols of each database
for db_name, db_entries in protocols.items():
self._load_protocols(
db_name, db_entries, database_yml, mode=mode
)
self._load_protocols(db_name, db_entries, database_yml, mode=mode)

# process "Databases" section
databases = config.get("Databases", dict())
Expand All @@ -265,7 +263,6 @@ def _load_database_helper(
# save configuration for later reloading of meta-protocols
self.configs[database_yml] = config


def get_database(self, database_name, **kwargs) -> Database:
"""Get database by name
Expand All @@ -284,7 +281,6 @@ def get_database(self, database_name, **kwargs) -> Database:
database = self.databases[database_name]

except KeyError:

if database_name == "X":
msg = (
"Could not find any meta-protocol. Please refer to "
Expand All @@ -302,7 +298,9 @@ def get_database(self, database_name, **kwargs) -> Database:

return database(**kwargs)

def get_protocol(self, name, preprocessors: Optional[Preprocessors] = None) -> Protocol:
def get_protocol(
self, name, preprocessors: Optional[Preprocessors] = None
) -> Protocol:
"""Get protocol by full name
Parameters
Expand All @@ -329,6 +327,14 @@ def get_protocol(self, name, preprocessors: Optional[Preprocessors] = None) -> P
protocol.name = name
return protocol

# iterate over all protocols by name
def __iter__(self):
    """Iterate over the full names of all protocols known to the registry.

    Databases are visited in the iteration order of `self.databases`;
    each one is instantiated through `self.get_database` and queried
    for its tasks, then for the protocols of each task.

    Yields
    ------
    name : str
        Protocol name formatted as "{database}.{task}.{protocol}".
    """
    for db_name in self.databases:
        db = self.get_database(db_name)
        for task in db.get_tasks():
            # delegate to a generator over this task's protocols
            yield from (
                f"{db_name}.{task}.{protocol}"
                for protocol in db.get_protocols(task)
            )

def _load_protocols(
self,
db_name,
Expand Down Expand Up @@ -367,7 +373,9 @@ def _load_protocols(
# If needed, merge old protocols dict with the new one (according to current override rules)
if db_name in self.databases:
old_protocols = self.databases[db_name]._protocols
_merge_protocols_inplace(protocols, old_protocols, mode, db_name, database_yml)
_merge_protocols_inplace(
protocols, old_protocols, mode, db_name, database_yml
)

# create database class on-the-fly
protocol_list = [
Expand All @@ -389,13 +397,14 @@ def _reload_meta_protocols(self):
for db_yml, config in self.configs.items():
databases = config.get("Protocols", dict())
if "X" in databases:
self._load_protocols("X", databases["X"], db_yml, mode=LoadingMode.OVERRIDE)

self._load_protocols(
"X", databases["X"], db_yml, mode=LoadingMode.OVERRIDE
)


def _env_config_paths() -> List[Path]:
"""Parse PYANNOTE_DATABASE_CONFIG environment variable
PYANNOTE_DATABASE_CONFIG may contain multiple paths separated by ";".
Returns
Expand All @@ -413,10 +422,11 @@ def _env_config_paths() -> List[Path]:
paths.append(path)
return paths


def _find_default_ymls() -> List[Path]:
"""Get paths to default YAML configuration files
* $HOME/.pyannote/database.yml
* $HOME/.pyannote/database.yml
* $CWD/database.yml
* PYANNOTE_DATABASE_CONFIG environment variable
Expand All @@ -431,21 +441,23 @@ def _find_default_ymls() -> List[Path]:
home_db_yml = Path("~/.pyannote/database.yml").expanduser()
if home_db_yml.is_file():
paths.append(home_db_yml)

cwd_db_yml = Path.cwd() / "database.yml"
if cwd_db_yml.is_file():
paths.append(cwd_db_yml)

paths += _env_config_paths()

return paths


def _merge_protocols_inplace(
new_protocols: Dict[Tuple[Text, Text], Type],
old_protocols: Dict[Tuple[Text, Text], Type],
mode: LoadingMode,
db_name: str,
database_yml: str):
new_protocols: Dict[Tuple[Text, Text], Type],
old_protocols: Dict[Tuple[Text, Text], Type],
mode: LoadingMode,
db_name: str,
database_yml: str,
):
"""Merge new and old protocols inplace into the passed new_protocol.
Warning: merging order might be counterintuitive: the "KEEP" strategy keeps elements from the OLD protocol
Expand All @@ -471,7 +483,6 @@ def _merge_protocols_inplace(

# for all previously defined protocol (in old_protocols)
for p_id, old_p in old_protocols.items():

# if this protocol is redefined
if p_id in new_protocols:
t_name, p_name = p_id
Expand All @@ -480,13 +491,16 @@ def _merge_protocols_inplace(
# raise an error
if mode == LoadingMode.ERROR:
raise RuntimeError(
f"Cannot load {realname} protocol from '{database_yml}' as it already exists.")
f"Cannot load {realname} protocol from '{database_yml}' as it already exists."
)

# keep the new protocol
elif mode == LoadingMode.OVERRIDE:
warnings.warn(f"Replacing existing {realname} protocol by the one defined in '{database_yml}'.")
warnings.warn(
f"Replacing existing {realname} protocol by the one defined in '{database_yml}'."
)
pass

# keep the old protocol
elif mode == LoadingMode.KEEP:
warnings.warn(
Expand All @@ -498,9 +512,10 @@ def _merge_protocols_inplace(
else:
new_protocols[p_id] = old_p


# initialize the registry singleton (module-level, shared by every importer)
registry = Registry()

# load all database yaml files found at startup;
# each file is loaded exactly once, since loading it a second time would
# re-merge its protocols and emit spurious "Replacing existing ..." warnings
for yml in _find_default_ymls():
    registry.load_database(yml)

0 comments on commit da5794b

Please sign in to comment.