Skip to content

Commit

Permalink
Improve auto-generation of refs with backticks (#67)
Browse files Browse the repository at this point in the history
* Improve auto-generation of refs from backticks

- Make it possible to specify some packages to include in a yaml-style
  mkdocs header. `backtick` refs in that doc file will then be able to
  reference objects in those modules (or objects in that list)

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Fix issue with refs in code blocks, add comment

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Fix tests to reflect added backticks in ref text

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

---------

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
  • Loading branch information
lebrice authored Oct 11, 2024
1 parent 5f7bfac commit a5acd0b
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 22 deletions.
6 changes: 1 addition & 5 deletions docs/features/jax.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,9 @@ There are lots of good reasons why you might want to let Lightning handle the tr
which are very well described [here](https://lightning.ai/docs/pytorch/stable/).

??? note "What about end-to-end training in Jax?"
This template doesn't include a way to do end-to-end, fully-jitted training in Jax, however, it _might_ be possible to do so in this way:

- add a new configuration in the `trainer` config group, with a `_target_` pointing to a
trainer-like object with a `fit`, `evaluate` and `test` method mimicking those of PyTorch-Lightning.
- add a new configuration in the `algorithm` config group pointing to a learning algorithm class that isn't a LightningModule.
See the [Jax RL Example (coming soon!)](https://github.com/mila-iqia/ResearchTemplate/pull/55)

If you want an example of how to do this, please make an issue (or like an existing issue) on GitHub.

## `JaxExample`: a LightningModule that uses Jax

Expand Down
111 changes: 99 additions & 12 deletions project/utils/autoref_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
"""

import functools
import inspect
import re
import types

import lightning
import torch
Expand All @@ -21,23 +23,28 @@
# Same as in the mkdocs_autorefs plugin.
logger = get_plugin_logger(__name__)

known_things = [
default_reference_sources = [
lightning.Trainer,
lightning.LightningModule,
lightning.LightningDataModule,
torch.nn.Module,
]
"""
IDEA: IF we see `Trainer`, and we know that that's the `lightning.Trainer`, then we
create the proper ref.
"""These are some "known objects" that can be referenced with backticks anywhere in the docs.
TODO: Ideally this would contain every object / class that we know of in this project.
Additionally, if there were modules in here, then any of their public members can also be
referenced.
"""
from mkdocstrings.plugin import MkdocstringsPlugin # noqa
from mkdocstrings_handlers.python.handler import PythonHandler # noqa


class CustomAutoRefPlugin(BasePlugin):
"""Small mkdocs plugin that converts backticks to refs when possible."""

def __init__(self):
    """Set up the plugin and pre-compute the objects that any page may reference."""
    super().__init__()
    # Flatten the expansion of every default source into a single list; `_expand` returns a
    # list for each source (several entries when the source is a module).
    self.default_reference_sources = [
        target for source in default_reference_sources for target in _expand(source)
    ]

def on_page_markdown(
self, markdown: str, /, *, page: Page, config: MkDocsConfig, files: Files
) -> str | None:
Expand All @@ -47,32 +54,77 @@ def on_page_markdown(
# - `package.foo.bar` -> [package.foo.bar][] (only if `package.foo.bar` is importable)
# - `baz` -> [baz][]

def _full_path(thing) -> str:
return thing.__module__ + "." + thing.__qualname__
# TODO: The idea here is to also make all the members of a module referenceable with
# backticks in the same module. The problem is that the "reference" page we create with
# mkdocstrings only contains a simple `::: project.path.to.module` and doesn't have any
# text, so we can't just replace the `backticks` with refs, since mkdocstrings hasn't yet
# processed the module into a page with the reference docs. This seems to be happening
# in a markdown extension (the `MkdocstringsExtension`).

# file = page.file.abs_src_path
# if file and "reference/project" in file:
# relative_path = file[file.index("reference/") :].removeprefix("reference/")
# module_path = relative_path.replace("/", ".").replace(".md", "")
# if module_path.endswith(".index"):
# module_path = module_path.removesuffix(".index")
# logger.error(
# f"file {relative_path} is the reference page for the python module {module_path}"
# )
# if "algorithms/example" in file:
# assert False, markdown
# additional_objects = _expand(module_path)
if referenced_packages := page.meta.get("additional_python_references", []):
additional_objects: list[object] = _get_referencable_objects_from_doc_page_header(
referenced_packages
)
else:
additional_objects = []

if additional_objects:
additional_objects = [
obj
for obj in additional_objects
if (
inspect.isfunction(obj)
or inspect.isclass(obj)
or inspect.ismodule(obj)
or inspect.ismethod(obj)
)
# and (hasattr(obj, "__name__") or hasattr(obj, "__qualname__"))
]

known_thing_names = [t.__name__ for t in known_things]
known_objects_for_this_module = self.default_reference_sources + additional_objects
known_object_names = [t.__name__ for t in known_objects_for_this_module]

new_markdown = []
# TODO: This changes things inside code blocks, which is not desired!
in_code_block = False

for line_index, line in enumerate(markdown.splitlines(keepends=True)):
# Can't convert `this` to `[this][]` in headers, otherwise they break.
if line.lstrip().startswith("#"):
new_markdown.append(line)
continue
if "```" in line:
in_code_block = not in_code_block
if in_code_block:
new_markdown.append(line)
continue

matches = re.findall(r"`([^`]+)`", line)

for match in matches:
thing_name = match
if "." not in thing_name and thing_name in known_thing_names:
thing = known_things[known_thing_names.index(thing_name)]
if thing_name in known_object_names:
# References like `JaxTrainer` (which are in a module that we're aware of).
thing = known_objects_for_this_module[known_object_names.index(thing_name)]
else:
thing = _try_import_thing(thing_name)

if thing is None:
logger.debug(f"Unable to import {thing_name}, leaving it as-is.")
continue

new_ref = f"[{thing_name}][{_full_path(thing)}]"
new_ref = f"[`{thing_name}`][{_full_path(thing)}]"
logger.info(
f"Replacing `{thing_name}` with {new_ref} in {page.file.abs_src_path}:{line_index}"
)
Expand All @@ -83,6 +135,41 @@ def _full_path(thing) -> str:
return "".join(new_markdown)


def _expand(obj: types.ModuleType | object) -> list[object]:
if not inspect.ismodule(obj):
# The ref is something else (a class, function, etc.)
return [obj]

# The ref is a package, so we import everything from it.
# equivalent of `from package import *`
if hasattr(obj, "__all__"):
return [getattr(obj, name) for name in obj.__all__]
else:
objects_in_global_scope = [v for k, v in vars(obj).items() if not k.startswith("_")]
# Don't consider any external modules that were imported in the global scope.
source_file = inspect.getsourcefile(obj)
# too obtuse, but whatever
return [
v
for v in objects_in_global_scope
if not (inspect.ismodule(v) and inspect.getsourcefile(v) != source_file)
]


def _get_referencable_objects_from_doc_page_header(doc_page_references: list[str]):
    """Import every path listed in a doc page header and collect the objects it exposes.

    Each entry of `doc_page_references` is imported with `import_object`, then passed to
    `_expand`; the results are concatenated in the order the entries were given.
    """
    return [
        exposed
        for reference_path in doc_page_references
        for exposed in _expand(import_object(reference_path))
    ]


def _full_path(thing) -> str:
    """Return the dotted path used as the reference target for `thing`.

    For a module this is its `__name__`; for anything else it is `__module__` joined with
    `__qualname__` (falling back to `__name__` when there is no `__qualname__`).
    """
    if inspect.ismodule(thing):
        return thing.__name__

    module_prefix = thing.__module__ + "."
    short_name = thing.__name__
    return module_prefix + getattr(thing, "__qualname__", short_name)


@functools.cache
def _try_import_thing(thing: str):
try:
Expand Down
36 changes: 31 additions & 5 deletions project/utils/autoref_plugin_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
(_header := "## Some header with a ref `lightning.Trainer`", _header),
(
"a backtick ref: `lightning.Trainer`",
"a backtick ref: [lightning.Trainer][lightning.pytorch.trainer.trainer.Trainer]",
"a backtick ref: [`lightning.Trainer`][lightning.pytorch.trainer.trainer.Trainer]",
),
("`torch.Tensor`", "[torch.Tensor][torch.Tensor]"),
("`torch.Tensor`", "[`torch.Tensor`][torch.Tensor]"),
(
"a proper full ref: "
+ (
Expand All @@ -28,14 +28,14 @@
(
"`jax.Array`",
# not sure if this will make a proper link in mkdocs though.
"[jax.Array][jax.Array]",
"[`jax.Array`][jax.Array]",
),
("`Trainer`", "[Trainer][lightning.pytorch.trainer.trainer.Trainer]"),
("`Trainer`", "[`Trainer`][lightning.pytorch.trainer.trainer.Trainer]"),
# since `Trainer` is in the `default_reference_sources` list, we add the proper ref.
],
)
def test_autoref_plugin(input: str, expected: str):
config = MkDocsConfig("mkdocs.yaml")
config: MkDocsConfig = MkDocsConfig("mkdocs.yaml") # type: ignore (weird!)
plugin = CustomAutoRefPlugin()
result = plugin.on_page_markdown(
input,
Expand All @@ -53,3 +53,29 @@ def test_autoref_plugin(input: str, expected: str):
files=Files([]),
)
assert result == expected


def test_ref_using_additional_python_references():
    """A name from a module listed in the page's `additional_python_references` becomes a ref."""
    config: MkDocsConfig = MkDocsConfig("mkdocs.yaml")  # type: ignore (weird!)
    autoref_plugin = CustomAutoRefPlugin()

    source_file = File(
        "test.md",
        src_dir="bob",
        dest_dir="bobo",
        use_directory_urls=False,
    )
    test_page = Page(title="Test", file=source_file, config=config)
    # Equivalent of declaring the module in the yaml-style header of the doc page.
    test_page.meta = {"additional_python_references": ["project.algorithms.example"]}

    converted = autoref_plugin.on_page_markdown(
        "`ExampleAlgorithm`",
        page=test_page,
        config=config,
        files=Files([]),
    )

    expected = "[`ExampleAlgorithm`][project.algorithms.example.ExampleAlgorithm]"
    assert converted == expected
7 changes: 7 additions & 0 deletions project/utils/hydra_config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,15 @@ def import_object(target_path: str):
assert not target_path.endswith(
".py"
), "expect a valid python path like 'module.submodule.object'"
if "." not in target_path:
return importlib.import_module(target_path)

parts = target_path.split(".")
try:
return importlib.import_module(name=parts[-1], package=".".join(parts[:-1]))
except (ModuleNotFoundError, AttributeError):
pass

for i in range(1, len(parts)):
module_name = ".".join(parts[:i])
obj_path = parts[i:]
Expand Down

0 comments on commit a5acd0b

Please sign in to comment.