
clean up equinox mlp implementation #144

Merged: 6 commits into main from clean_up_mlp on Feb 29, 2024

Conversation

grfrederic (Collaborator)

getting back in the saddle :)

@pawel-czyz (Member)

I appreciate the effort, but I don't think it'll work without further adjustments, for the following reason: jax.nn.relu is a parameter-free function, and Optax fails to initialise an optimizer state for pytrees containing such leaves...

src/bmi/estimators/neural/_estimators.py:167: in estimate
    return self.estimate_with_info(x, y).mi_estimate
src/bmi/estimators/neural/_estimators.py:144: in estimate_with_info
    training_log, new_critic = basic_training(
src/bmi/estimators/neural/_basic_training.py:56: in basic_training
    opt_state = optimizer.init(critic)
../../micromamba/envs/bmi/lib/python3.10/site-packages/optax/_src/combine.py:50: in init_fn
    return tuple(fn(params) for fn in init_fns)
../../micromamba/envs/bmi/lib/python3.10/site-packages/optax/_src/combine.py:50: in <genexpr>
    return tuple(fn(params) for fn in init_fns)
../../micromamba/envs/bmi/lib/python3.10/site-packages/optax/_src/transform.py:353: in init_fn
    mu = jax.tree_util.tree_map(  # First moment
../../micromamba/envs/bmi/lib/python3.10/site-packages/jax/_src/tree_util.py:312: in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
../../micromamba/envs/bmi/lib/python3.10/site-packages/jax/_src/tree_util.py:312: in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
../../micromamba/envs/bmi/lib/python3.10/site-packages/optax/_src/transform.py:354: in <lambda>
    lambda t: jnp.zeros_like(t, dtype=mu_dtype), params)
../../micromamba/envs/bmi/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:2248: in zeros_like
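The root cause in the traceback can be reproduced in isolation: Optax's `init` maps `jnp.zeros_like` over every pytree leaf, and a function leaf such as `jax.nn.relu` is not a valid array argument. A minimal sketch (not the project's code, just the failure mode):

```python
import jax
import jax.numpy as jnp

# Optax's adam init calls jnp.zeros_like on each pytree leaf to build the
# first-moment accumulator. A function leaf like jax.nn.relu is not
# array-like, so the call raises a TypeError.
raised = False
try:
    jnp.zeros_like(jax.nn.relu)
except TypeError:
    raised = True

print("jnp.zeros_like on a function leaf raised TypeError:", raised)
```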

One possible solution in Equinox is filtering, but I worry that it'll make the rest of the codebase a bit more complex. What do you think?

@grfrederic grfrederic merged commit d1f94da into main Feb 29, 2024
2 checks passed
@grfrederic grfrederic deleted the clean_up_mlp branch February 29, 2024 15:30