From a5acd0bf71950a8fac6ba30d4eeded792b18c9fd Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Fri, 11 Oct 2024 13:44:49 -0400 Subject: [PATCH] Improve auto-generation of refs with backticks (#67) * Improve auto-generation of refs from backticks - Make it possible to specify some packages to include in a yaml-style mkdocs header. `backtick` refs in that doc file will then be able to reference objects in those modules (or objects in that list) Signed-off-by: Fabrice Normandin * Fix issue with refs in code blocks, add comment Signed-off-by: Fabrice Normandin * Fix tests to reflect added backticks in ref text Signed-off-by: Fabrice Normandin --------- Signed-off-by: Fabrice Normandin --- docs/features/jax.md | 6 +- project/utils/autoref_plugin.py | 111 ++++++++++++++++++++++++--- project/utils/autoref_plugin_test.py | 36 +++++++-- project/utils/hydra_config_utils.py | 7 ++ 4 files changed, 138 insertions(+), 22 deletions(-) diff --git a/docs/features/jax.md b/docs/features/jax.md index 83f242d6..3938129d 100644 --- a/docs/features/jax.md +++ b/docs/features/jax.md @@ -10,13 +10,9 @@ There are lots of good reasons why you might want to let Lightning handle the tr which are very well described [here](https://lightning.ai/docs/pytorch/stable/). ??? note "What about end-to-end training in Jax?" - This template doesn't include a way to do end-to-end, fully-jitted training in Jax, however, it _might_ be possible to do so in this way: - - add a new configuration in the `trainer` config group, with a `_target_` pointing to a - trainer-like object with a `fit`, `evaluate` and `test` method mimicking those of PyTorch-Lightning. - - add a new configuration in the `algorithm` config group pointing to a learning algorithm class that isn't a LightningModule. + See the [Jax RL Example (coming soon!)](https://github.com/mila-iqia/ResearchTemplate/pull/55) - If you want an example of how to do this, please make an issue (or like an existing issue) on GitHub. 
## `JaxExample`: a LightningModule that uses Jax diff --git a/project/utils/autoref_plugin.py b/project/utils/autoref_plugin.py index b4f97145..0abbd948 100644 --- a/project/utils/autoref_plugin.py +++ b/project/utils/autoref_plugin.py @@ -3,7 +3,9 @@ """ import functools +import inspect import re +import types import lightning import torch @@ -21,23 +23,28 @@ # Same as in the mkdocs_autorefs plugin. logger = get_plugin_logger(__name__) -known_things = [ +default_reference_sources = [ lightning.Trainer, lightning.LightningModule, lightning.LightningDataModule, torch.nn.Module, ] -""" -IDEA: IF we see `Trainer`, and we know that that's the `lightning.Trainer`, then we -create the proper ref. +"""These are some "known objects" that can be referenced with backticks anywhere in the docs. -TODO: Ideally this would contain every object / class that we know of in this project. +Additionally, if there were modules in here, then any of their public members can also be +referenced. """ +from mkdocstrings.plugin import MkdocstringsPlugin # noqa +from mkdocstrings_handlers.python.handler import PythonHandler # noqa class CustomAutoRefPlugin(BasePlugin): """Small mkdocs plugin that converts backticks to refs when possible.""" + def __init__(self): + super().__init__() + self.default_reference_sources = sum(map(_expand, default_reference_sources), []) + def on_page_markdown( self, markdown: str, /, *, page: Page, config: MkDocsConfig, files: Files ) -> str | None: @@ -47,24 +54,69 @@ def on_page_markdown( # - `package.foo.bar` -> [package.foo.bar][] (only if `package.foo.bar` is importable) # - `baz` -> [baz][] - def _full_path(thing) -> str: - return thing.__module__ + "." + thing.__qualname__ + # TODO: The idea here is to also make all the members of a module referenceable with + # backticks in the same module. 
The problem is that the "reference" page we create with + # mkdocstrings only contains a simple `::: project.path.to.module` and doesn't have any + # text, so we can't just replace the `backticks` with refs, since mkdocstrings hasn't yet + # processed the module into a page with the reference docs. This seems to be happening + # in a markdown extension (the `MkdocstringsExtension`). + + # file = page.file.abs_src_path + # if file and "reference/project" in file: + # relative_path = file[file.index("reference/") :].removeprefix("reference/") + # module_path = relative_path.replace("/", ".").replace(".md", "") + # if module_path.endswith(".index"): + # module_path = module_path.removesuffix(".index") + # logger.error( + # f"file {relative_path} is the reference page for the python module {module_path}" + # ) + # if "algorithms/example" in file: + # assert False, markdown + # additional_objects = _expand(module_path) + if referenced_packages := page.meta.get("additional_python_references", []): + additional_objects: list[object] = _get_referencable_objects_from_doc_page_header( + referenced_packages + ) + else: + additional_objects = [] + + if additional_objects: + additional_objects = [ + obj + for obj in additional_objects + if ( + inspect.isfunction(obj) + or inspect.isclass(obj) + or inspect.ismodule(obj) + or inspect.ismethod(obj) + ) + # and (hasattr(obj, "__name__") or hasattr(obj, "__qualname__")) + ] - known_thing_names = [t.__name__ for t in known_things] + known_objects_for_this_module = self.default_reference_sources + additional_objects + known_object_names = [t.__name__ for t in known_objects_for_this_module] new_markdown = [] + # NOTE: Lines inside fenced code blocks are skipped so that their contents are not changed. + in_code_block = False + for line_index, line in enumerate(markdown.splitlines(keepends=True)): # Can't convert `this` to `[this][]` in headers, otherwise they break. 
if line.lstrip().startswith("#"): new_markdown.append(line) continue + if "```" in line: + in_code_block = not in_code_block + if in_code_block: + new_markdown.append(line) + continue matches = re.findall(r"`([^`]+)`", line) - for match in matches: thing_name = match - if "." not in thing_name and thing_name in known_thing_names: - thing = known_things[known_thing_names.index(thing_name)] + if thing_name in known_object_names: + # References like `JaxTrainer` (which are in a module that we're aware of). + thing = known_objects_for_this_module[known_object_names.index(thing_name)] else: thing = _try_import_thing(thing_name) @@ -72,7 +124,7 @@ def _full_path(thing) -> str: logger.debug(f"Unable to import {thing_name}, leaving it as-is.") continue - new_ref = f"[{thing_name}][{_full_path(thing)}]" + new_ref = f"[`{thing_name}`][{_full_path(thing)}]" logger.info( f"Replacing `{thing_name}` with {new_ref} in {page.file.abs_src_path}:{line_index}" ) @@ -83,6 +135,41 @@ def _full_path(thing) -> str: return "".join(new_markdown) +def _expand(obj: types.ModuleType | object) -> list[object]: + if not inspect.ismodule(obj): + # The ref is something else (a class, function, etc.) + return [obj] + + # The ref is a package, so we import everything from it. + # equivalent of `from package import *` + if hasattr(obj, "__all__"): + return [getattr(obj, name) for name in obj.__all__] + else: + objects_in_global_scope = [v for k, v in vars(obj).items() if not k.startswith("_")] + # Don't consider any external modules that were imported in the global scope. 
+ source_file = inspect.getsourcefile(obj) + # Keep everything except modules whose source file differs from this module's. + return [ + v + for v in objects_in_global_scope + if not (inspect.ismodule(v) and inspect.getsourcefile(v) != source_file) + ] + + +def _get_referencable_objects_from_doc_page_header(doc_page_references: list[str]): + additional_objects: list[object] = [] + for package in doc_page_references: + additional_ref_source = import_object(package) + additional_objects.extend(_expand(additional_ref_source)) + return additional_objects + + +def _full_path(thing) -> str: + if inspect.ismodule(thing): + return thing.__name__ + return thing.__module__ + "." + getattr(thing, "__qualname__", thing.__name__) + + @functools.cache def _try_import_thing(thing: str): try: diff --git a/project/utils/autoref_plugin_test.py b/project/utils/autoref_plugin_test.py index 614db34d..1cc4e382 100644 --- a/project/utils/autoref_plugin_test.py +++ b/project/utils/autoref_plugin_test.py @@ -12,9 +12,9 @@ (_header := "## Some header with a ref `lightning.Trainer`", _header), ( "a backtick ref: `lightning.Trainer`", - "a backtick ref: [lightning.Trainer][lightning.pytorch.trainer.trainer.Trainer]", + "a backtick ref: [`lightning.Trainer`][lightning.pytorch.trainer.trainer.Trainer]", ), - ("`torch.Tensor`", "[torch.Tensor][torch.Tensor]"), + ("`torch.Tensor`", "[`torch.Tensor`][torch.Tensor]"), ( "a proper full ref: " + ( @@ -28,14 +28,14 @@ ( "`jax.Array`", # not sure if this will make a proper link in mkdocs though. - "[jax.Array][jax.Array]", + "[`jax.Array`][jax.Array]", ), - ("`Trainer`", "[Trainer][lightning.pytorch.trainer.trainer.Trainer]"), + ("`Trainer`", "[`Trainer`][lightning.pytorch.trainer.trainer.Trainer]"), # since `Trainer` is in the `known_things` list, we add the proper ref. ], ) def test_autoref_plugin(input: str, expected: str): - config = MkDocsConfig("mkdocs.yaml") + config: MkDocsConfig = MkDocsConfig("mkdocs.yaml") # type: ignore (weird!) 
plugin = CustomAutoRefPlugin() result = plugin.on_page_markdown( input, @@ -53,3 +53,29 @@ def test_autoref_plugin(input: str, expected: str): files=Files([]), ) assert result == expected + + +def test_ref_using_additional_python_references(): + mkdocs_config: MkDocsConfig = MkDocsConfig("mkdocs.yaml") # type: ignore (weird!) + + plugin = CustomAutoRefPlugin() + + page = Page( + title="Test", + file=File( + "test.md", + src_dir="bob", + dest_dir="bobo", + use_directory_urls=False, + ), + config=mkdocs_config, + ) + page.meta = {"additional_python_references": ["project.algorithms.example"]} + + result = plugin.on_page_markdown( + "`ExampleAlgorithm`", + page=page, + config=mkdocs_config, + files=Files([]), + ) + assert result == "[`ExampleAlgorithm`][project.algorithms.example.ExampleAlgorithm]" diff --git a/project/utils/hydra_config_utils.py b/project/utils/hydra_config_utils.py index a6093656..5feccef5 100644 --- a/project/utils/hydra_config_utils.py +++ b/project/utils/hydra_config_utils.py @@ -111,8 +111,15 @@ def import_object(target_path: str): assert not target_path.endswith( ".py" ), "expect a valid python path like 'module.submodule.object'" + if "." not in target_path: + return importlib.import_module(target_path) parts = target_path.split(".") + try: + return importlib.import_module(name=parts[-1], package=".".join(parts[:-1])) + except (ModuleNotFoundError, AttributeError): + pass + for i in range(1, len(parts)): module_name = ".".join(parts[:i]) obj_path = parts[i:]