Skip to content

Commit

Permalink
Fix auto-schema generation for nested config entries with _target_ (#48)
Browse files Browse the repository at this point in the history
* Debugging issues in auto-schema

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* WIP

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Fix bugs in auto_schema, simplify the code

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

---------

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
  • Loading branch information
lebrice authored Sep 9, 2024
1 parent 5770ff0 commit e6925d9
Show file tree
Hide file tree
Showing 7 changed files with 326 additions and 243 deletions.
151 changes: 65 additions & 86 deletions project/utils/auto_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,18 @@ def _get_schema_file_path(config_file: Path, schemas_dir: Path):
return schema_file


def _all_subentries_with_target(config: dict) -> dict[tuple[str, ...], dict]:
"""Iterator that yields all the nested config entries that have a _target_."""
entries = {}
if "_target_" in config:
entries[()] = config
for key, value in config.items():
if isinstance(value, dict):
for subkey, subvalue in _all_subentries_with_target(value).items():
entries[(key, *subkey)] = subvalue
return entries


def create_schema_for_config(
config: dict | DictConfig, config_file: Path, configs_dir: Path | None
) -> Schema | ObjectSchema:
Expand All @@ -553,6 +565,12 @@ def create_schema_for_config(
- Only the top-level config (`config`) can have a `defaults: list[str]` key.
- Should ideally load the defaults and merge this schema on top of them.
"""

_config_dict = (
OmegaConf.to_container(config, resolve=False) if isinstance(config, DictConfig) else config
)
assert isinstance(_config_dict, dict)

schema = copy.deepcopy(HYDRA_CONFIG_SCHEMA)
pretty_path = config_file.relative_to(configs_dir) if configs_dir else config_file
schema["title"] = f"Auto-generated schema for {pretty_path}"
Expand All @@ -569,99 +587,49 @@ def create_schema_for_config(
configs_dir=configs_dir,
)

if target_name := config.get("_target_"):
# There's a '_target_' key at the top level in the config file.
target = hydra.utils.get_object(target_name)
schema["description"] = f"Based on the signature of {target}."
if "properties" not in schema:
schema["properties"] = {}
assert "properties" in schema and isinstance(schema["properties"], dict)
schema["properties"]["_target_"] = PropertySchema(
type="string",
title="Target",
const=target_name,
# pattern=r"", # todo: Use a pattern to match python module import strings.
description=(
f"Target to instantiate, in this case: `{target_name}`\n"
# f"* Source: <file://{relative_to_cwd(inspect.getfile(target))}>\n"
# f"* Config file: <file://{config_file}>\n"
f"See the Hydra docs for '_target_': https://hydra.cc/docs/advanced/instantiate_objects/overview/\n"
),
)

nested_value_schema_from_target_signature = _get_schema_from_target(config)
# logger.debug(f"Schema from signature of {target}: {schema_from_target_signature}")

schema = _merge_dicts(
nested_value_schema_from_target_signature, # type: ignore
schema, # type: ignore
conflict_handler=overwrite,
)
# Config file that contains entries that may or may not have a _target_.
schema["additionalProperties"] = "_target_" not in config

return schema
for keys, value in _all_subentries_with_target(_config_dict).items():
is_top_level: bool = not keys

# Config file that contains entries that may or may not have a _target_.
schema["additionalProperties"] = True

def all_subentries_with_target(config: dict) -> dict[tuple[str, ...], dict]:
"""Iterator that yields all the nested config entries that have a _target_."""
entries = {}
for key, value in config.items():
if isinstance(value, dict) and "_target_" in value.keys():
entries[(key,)] = value
elif isinstance(value, dict):
for subkey, subvalue in all_subentries_with_target(value).items():
entries[(key, *subkey)] = subvalue
return entries
logger.debug(f"Handling key {'.'.join(keys)} in config at path {config_file}")

_config_dict = (
OmegaConf.to_container(config, resolve=False) if isinstance(config, DictConfig) else config
)
assert isinstance(_config_dict, dict)
for keys, value in all_subentries_with_target(_config_dict).items():
# Go over all the values in the config. If any of them have a `_target_`, then we can
# add a schema at that entry.
assert "_target_" in value
target = hydra.utils.get_object(value["_target_"])
nested_value_schema = _get_schema_from_target(value)

# try:
nested_value_schema_from_target_signature = _get_schema_from_target(value)
# except omegaconf.errors.InterpolationToMissingValueError:
# logger.warning(
# f"Unable to get the schema for {value['_target_']} at key {'.'.join(keys)} "
# f"in file {config_file}."
# )
# continue

if "$defs" in nested_value_schema_from_target_signature:
if "$defs" in nested_value_schema:
# note: a nested schema can't carry its own `$defs` key, so move its definitions into the top-level schema's `$defs`.
schema.setdefault("$defs", {}).update( # type: ignore
nested_value_schema_from_target_signature.pop("$defs")
nested_value_schema.pop("$defs")
)
assert "properties" in nested_value_schema

logger.debug(
f"Getting schema from target {value['_target_']} at key {'.'.join(keys)} in "
f"{config_file}."
)

assert "properties" in nested_value_schema_from_target_signature
if is_top_level:
schema = _merge_dicts(schema, nested_value_schema, conflict_handler=overwrite)
continue

parent_keys, last_key = keys[:-1], keys[-1]
where_to_set: Schema | ObjectSchema = schema
for key in parent_keys:
where_to_set = where_to_set.setdefault("properties", {}).setdefault(key, {}) # type: ignore
where_to_set = where_to_set.setdefault("properties", {}).setdefault(
key, {"type": "object"}
) # type: ignore
if "_target_" not in where_to_set:
where_to_set["additionalProperties"] = True

logger.debug(f"Using schema from nested value at keys {keys}: {nested_value_schema}")

if "properties" not in where_to_set:
where_to_set["properties"] = {last_key: nested_value_schema_from_target_signature} # type: ignore
where_to_set["properties"] = {last_key: nested_value_schema} # type: ignore
elif last_key not in where_to_set["properties"]:
assert isinstance(last_key, str)
where_to_set["properties"][last_key] = nested_value_schema_from_target_signature # type: ignore
where_to_set["properties"][last_key] = nested_value_schema # type: ignore
else:
where_to_set["properties"] = _merge_dicts( # type: ignore
where_to_set["properties"],
nested_value_schema_from_target_signature, # type: ignore
{last_key: nested_value_schema}, # type: ignore
conflict_handler=overwrite,
)
raise NotImplementedError("todo: use merge_dicts here")

return schema

Expand Down Expand Up @@ -956,25 +924,17 @@ def _get_schema_from_target(config: dict | DictConfig) -> ObjectSchema | Schema:
by_alias=False,
)
json_schema = typing.cast(Schema, json_schema)

# if "$defs" in json_schema:
# for key, class_schema in list(json_schema["$defs"].items()):
# logger.debug(f"Before resolving {key}: {json_schema}")
# json_schema["$defs"].pop(key)
# class_schema["title"] = key
# expanded = json.dumps(json_schema).replace(
# json.dumps({"$ref": f"#/$defs/{key}"}), json.dumps(class_schema)
# )
# logger.debug(f"Expanded: {expanded}")
# json_schema = json.loads(expanded)
# logger.debug(f"After resolving {key}: {json_schema}")
# assert False, json_schema
assert "properties" in json_schema
except pydantic.PydanticSchemaGenerationError as e:
raise NotImplementedError(f"Unable to get the schema with pydantic: {e}")

assert "properties" in json_schema

# Add a description
json_schema["description"] = f"Based on the signature of {target}.\n" + json_schema.get(
"description", ""
)

docs_to_search: list[dp.Docstring] = []

if inspect.isclass(target):
Expand Down Expand Up @@ -1006,6 +966,25 @@ def _get_schema_from_target(config: dict | DictConfig) -> ObjectSchema | Schema:

if config.get("_partial_"):
json_schema["required"] = []
# Add some info on the target.
if "_target_" not in json_schema["properties"]:
json_schema["properties"]["_target_"] = {}
else:
assert isinstance(json_schema["properties"]["_target_"], dict)
json_schema["properties"]["_target_"].update(
PropertySchema(
type="string",
title="Target",
const=config["_target_"],
# pattern=r"", # todo: Use a pattern to match python module import strings.
description=(
f"Target to instantiate, in this case: `{target}`\n"
# f"* Source: <file://{relative_to_cwd(inspect.getfile(target))}>\n"
# f"* Config file: <file://{config_file}>\n"
f"See the Hydra docs for '_target_': https://hydra.cc/docs/advanced/instantiate_objects/overview/\n"
),
)
)

# if the target takes **kwargs, then we don't restrict additional properties.
json_schema["additionalProperties"] = inspect.getfullargspec(target).varkw is not None
Expand Down
26 changes: 19 additions & 7 deletions project/utils/auto_schema_test/nested.json
Original file line number Diff line number Diff line change
Expand Up @@ -70,40 +70,52 @@
"description": "Whether instantiating this config should recursively instantiate children configs.\nSee: https://hydra.cc/docs/advanced/instantiate_objects/overview/#recursive-instantiation"
},
"a": {
"type": "object",
"additionalProperties": true,
"properties": {
"b": {
"type": "object",
"additionalProperties": true,
"properties": {
"foo": {
"properties": {
"_target_": {
"default": "project.utils.auto_schema_test.Foo",
"title": " Target ",
"default": "project.utils.auto_schema_test.Bar",
"title": "Target",
"type": "string",
"description": "The _target_ parameter of the Foo."
"description": "Target to instantiate, in this case: `<class 'project.utils.auto_schema_test.Bar'>`\nSee the Hydra docs for '_target_': https://hydra.cc/docs/advanced/instantiate_objects/overview/\n",
"const": "project.utils.auto_schema_test.Bar"
},
"_recursive_": {
"default": false,
"title": " Recursive ",
"type": "boolean",
"description": "The _recursive_ parameter of the Foo."
"description": "The _recursive_ parameter of the Bar."
},
"_convert_": {
"default": "all",
"title": " Convert ",
"type": "string",
"description": "The _convert_ parameter of the Foo."
"description": "The _convert_ parameter of the Bar."
},
"bar": {
"title": "Bar",
"type": "string",
"description": "Description of the `bar` argument."
},
"baz": {
"title": "Baz",
"type": "integer",
"description": "description of the `baz` argument from the cls docstring instead of the init docstring."
}
},
"required": [
"bar"
"bar",
"baz"
],
"title": "Foo",
"title": "Bar",
"type": "object",
"description": "Based on the signature of <class 'project.utils.auto_schema_test.Bar'>.\n",
"additionalProperties": false
}
}
Expand Down
5 changes: 3 additions & 2 deletions project/utils/auto_schema_test/nested.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
a:
b:
foo:
_target_: project.utils.auto_schema_test.Foo
bar: "bob"
_target_: project.utils.auto_schema_test.Bar
bar: "boo"
baz: 123
72 changes: 36 additions & 36 deletions project/utils/auto_schema_test/partial.json
Original file line number Diff line number Diff line change
@@ -1,35 +1,7 @@
{
"title": "Foo",
"description": "Based on the signature of <class 'project.utils.auto_schema_test.Foo'>.\n",
"properties": {
"_target_": {
"default": "project.utils.auto_schema_test.Foo",
"title": "Target",
"type": "string",
"description": "Target to instantiate, in this case: `project.utils.auto_schema_test.Foo`\nSee the Hydra docs for '_target_': https://hydra.cc/docs/advanced/instantiate_objects/overview/\n",
"const": "project.utils.auto_schema_test.Foo"
},
"_recursive_": {
"default": false,
"title": "Recursive",
"type": "boolean",
"description": "Whether instantiating this config should recursively instantiate children configs.\nSee: https://hydra.cc/docs/advanced/instantiate_objects/overview/#recursive-instantiation"
},
"_convert_": {
"default": "all",
"title": "Convert",
"type": "string",
"description": "See https://hydra.cc/docs/advanced/instantiate_objects/overview/#parameter-conversion-strategies",
"enum": [
"none",
"partial",
"object",
"all"
]
},
"bar": {
"title": "Bar",
"type": "string",
"description": "Description of the `bar` argument."
},
"defaults": {
"title": "Hydra defaults",
"description": "Hydra defaults for this config. See https://hydra.cc/docs/advanced/defaults_list/",
Expand Down Expand Up @@ -71,17 +43,42 @@
},
"uniqueItems": true
},
"_target_": {
"type": "string",
"title": "Target",
"description": "Target to instantiate, in this case: `<class 'project.utils.auto_schema_test.Foo'>`\nSee the Hydra docs for '_target_': https://hydra.cc/docs/advanced/instantiate_objects/overview/\n",
"default": "project.utils.auto_schema_test.Foo",
"const": "project.utils.auto_schema_test.Foo"
},
"_convert_": {
"type": "string",
"enum": [
"none",
"partial",
"object",
"all"
],
"title": " Convert ",
"description": "The _convert_ parameter of the Foo.",
"default": "all"
},
"_partial_": {
"type": "boolean",
"title": "Partial",
"description": "Whether this config calls the target function when instantiated, or creates a `functools.partial` that will call the target.\nSee: https://hydra.cc/docs/advanced/instantiate_objects/overview"
},
"_recursive_": {
"type": "boolean",
"title": " Recursive ",
"description": "The _recursive_ parameter of the Foo.",
"default": false
},
"bar": {
"title": "Bar",
"type": "string",
"description": "Description of the `bar` argument."
}
},
"required": [],
"title": "Auto-generated schema for partial.yaml",
"type": "object",
"additionalProperties": false,
"description": "Based on the signature of <class 'project.utils.auto_schema_test.Foo'>.",
"dependentRequired": {
"_convert_": [
"_target_"
Expand All @@ -95,5 +92,8 @@
"_recursive_": [
"_target_"
]
}
},
"additionalProperties": false,
"required": [],
"type": "object"
}
Loading

0 comments on commit e6925d9

Please sign in to comment.