diff --git a/pyproject.toml b/pyproject.toml index 91398a3e..0a37aa51 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -118,6 +118,9 @@ addopts = [ "--color=yes", "--disable-pytest-warnings", ] +markers = [ + "online: run tests that require internet connection", +] filterwarnings = ["error::FutureWarning"] # todo: "error::DeprecationWarning" xfail_strict = true diff --git a/tests/unittests/conftest.py b/tests/unittests/conftest.py new file mode 100644 index 00000000..3aac31cf --- /dev/null +++ b/tests/unittests/conftest.py @@ -0,0 +1,30 @@ +# Copyright The Lightning AI team. + +import os +import shutil +import tempfile +from pathlib import Path + +import pytest + +from unittests import _PATH_ROOT + +_PATH_DOCS = os.path.join(_PATH_ROOT, "docs", "source") + + +@pytest.fixture(scope="session") +def temp_docs(): + """Create a dummy documentation folder.""" + # create a folder for docs + docs_folder = Path(tempfile.mkdtemp()) + # copy all real docs from _PATH_DOCS to local temp_docs + for root, _, files in os.walk(_PATH_DOCS): + for file in files: + fpath = os.path.join(root, file) + temp_path = docs_folder / os.path.relpath(fpath, _PATH_DOCS) + temp_path.parent.mkdir(exist_ok=True, parents=True) + with open(fpath, "rb") as fopen: + temp_path.write_bytes(fopen.read()) + yield str(docs_folder) + # remove the folder + shutil.rmtree(docs_folder.parent, ignore_errors=True) diff --git a/tests/unittests/docs/test_retriever.py b/tests/unittests/docs/test_retriever.py index 3c56bbc1..2b503e1f 100644 --- a/tests/unittests/docs/test_retriever.py +++ b/tests/unittests/docs/test_retriever.py @@ -1,30 +1,44 @@ import os.path import shutil +import pytest from lightning_utilities.docs import fetch_external_assets -from unittests import _PATH_ROOT - -def test_retriever_s3(): - path_docs = os.path.join(_PATH_ROOT, "docs", "source") - path_index = os.path.join(path_docs, "index.rst") - path_page = os.path.join(path_docs, "any", "extra", "page.rst") +@pytest.mark.online() +def test_retriever_s3(temp_docs): + # take the index page + path_index = os.path.join(temp_docs, "index.rst") + # copy it to another location to test depth + path_page = os.path.join(temp_docs, "any", "extra", "page.rst") os.makedirs(os.path.dirname(path_page), exist_ok=True) shutil.copy(path_index, path_page) - fetch_external_assets(docs_folder=path_docs) + def _get_line_with_figure(path_rst: str) -> str: + with open(path_rst, encoding="UTF-8") as fopen: + lines = fopen.readlines() + # find the first line with figure reference + return next(ln for ln in lines if ln.startswith(".. figure::")) - with open(path_index, encoding="UTF-8") as fopen: - body = fopen.read() + # validate the initial expectations + line = _get_line_with_figure(path_index) # that the image exists~ - assert "Lightning.gif" in body - # but it is not sourced from S3 - assert ".s3." not in body + assert "Lightning.gif" in line + # and it is sourced in S3 + assert ".s3." in line + + fetch_external_assets(docs_folder=temp_docs) - with open(path_page, encoding="UTF-8") as fopen: - body = fopen.read() + # validate the final state of index page + line = _get_line_with_figure(path_index) # that the image exists~ - assert "Lightning.gif" in body - # check the proper depth - assert os.path.sep.join(["..", ".."]) in body + assert os.path.join("fetched-s3-assets", "Lightning.gif") in line + # but it is not sourced from S3 + assert ".s3." not in line + + # validate the final state of additional page + line = _get_line_with_figure(path_page) + # that the image exists in the proper depth + assert os.path.join("..", "..", "fetched-s3-assets", "Lightning.gif") in line + # but it is not sourced from S3 + assert ".s3." not in line