Skip to content

Commit

Permalink
Fix bug with config target type inference
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
  • Loading branch information
lebrice committed Oct 7, 2024
1 parent 2301c46 commit 0990cef
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 3 deletions.
2 changes: 1 addition & 1 deletion project/algorithms/jax_rl_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def create(
env: Environment[TEnvState, TEnvParams] | None = None,
env_params: TEnvParams | None = None,
hp: PPOHParams | None = None,
):
) -> JaxRLExample[TEnvState, TEnvParams]:
from brax.envs import _envs as brax_envs
from rejax.compat.brax2gymnax import create_brax

Expand Down
13 changes: 13 additions & 0 deletions project/algorithms/jax_rl_example_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@

from project.algorithms.callbacks.samples_per_second import MeasureSamplesPerSecondCallback
from project.trainers.jax_trainer import JaxTrainer, hparams_to_dict
from project.utils.testutils import run_for_all_configs_of_type

from .jax_rl_example import (
EvalMetrics,
Expand All @@ -55,9 +56,11 @@
make_actor,
render_episode,
)
from .testsuites.algorithm_tests import LearningAlgorithmTests

logger = getLogger(__name__)


@pytest.fixture(params=["Pendulum-v1"])
def env_id(request: pytest.FixtureRequest) -> str:
# env_id = "halfcheetah"
Expand Down Expand Up @@ -174,6 +177,7 @@ def _add_gitignore_if_needed(original_datadir: Path):
gitignore_file.parent.mkdir(exist_ok=True, parents=True)
gitignore_file.write_text("*.gif\n")


@pytest.mark.slow
@pytest.mark.timeout(35)
def test_train_ours(
Expand Down Expand Up @@ -359,11 +363,13 @@ def test_ours_with_vmap(
gif_path=figures_dir / "pure_jax_avg.gif",
)


## Pytorch-Lightning wrapper around this learner:

# Don't allow tests to run for more than 5 seconds.
# pytestmark = pytest.mark.timeout(5)


class PPOLightningModule(lightning.LightningModule):
"""Uses the same code as [project.algorithms.jax_rl_example.JaxRLExample][], but the training
loop is run with pytorch-lightning.
Expand Down Expand Up @@ -615,6 +621,13 @@ def log(
# )


# TODO: potentially just use the Lightning adapter for unit tests for now?
@pytest.mark.skip(reason="TODO: ests assume a LightningModule atm (.state_dict()), etc.")
@run_for_all_configs_of_type("algorithm", JaxRLExample)
class TestJaxRLExample(LearningAlgorithmTests[JaxRLExample]): # type: ignore
pass


@pytest.fixture
def lightning_trainer(max_epochs: int, tmp_path: Path):
return lightning.Trainer(
Expand Down
7 changes: 5 additions & 2 deletions project/utils/hydra_config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,18 +140,21 @@ def get_all_configs_in_group_of_type(
config_name: get_target_of_config(config_group, config_name)
for config_name in config_names
}

names_to_types: dict[str, type] = {}
for name, target in names_to_targets.items():
if inspect.isclass(target):
names_to_types[name] = target
continue

if (
inspect.isfunction(target)
(inspect.isfunction(target) or inspect.ismethod(target))
and (annotations := typing.get_type_hints(target))
and (return_type := annotations.get("return"))
and inspect.isclass(return_type)
and (inspect.isclass(return_type) or inspect.isclass(typing.get_origin(return_type)))
):
# Resolve generic aliases if present.
return_type = typing.get_origin(return_type) or return_type
logger.info(
f"Assuming that the function {target} creates objects of type {return_type} based "
f"on its return type annotation."
Expand Down

0 comments on commit 0990cef

Please sign in to comment.