
Differentiable version of Hi-Fi Mocks #1

Open
EiffL opened this issue May 5, 2023 · 3 comments

Comments

@EiffL

EiffL commented May 5, 2023

Here is the beginning of an example of how to implement this using JAX:
https://colab.research.google.com/drive/1uneGJL4ewmV-Oyn-9k1H22gQnFPn_yB2?usp=sharing

@EiffL
Author

EiffL commented May 5, 2023

And just for completeness, here is the notebook that can run HMC on a lognormal model:
https://colab.research.google.com/drive/1U6HymNm0mJD-Kj07YkN7FLFfp_mZCfGW?usp=sharing

@EiffL
Author

EiffL commented May 5, 2023

@andrejobuljen I had a little bit of time on the train and tried to code up a full example... with moderate success ^^


The code is in this new notebook

I only implemented terms up to dG2, not d3; I don't know how much that matters.
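For reference, here is a minimal sketch of how the operator fields could be built from a linear density field on a periodic mesh. It assumes that dG2 and d3 denote the usual tidal operator G2 = sum_ij (∂i∂jΦ)² − δ² (with ∇²Φ = δ) and the cubic operator δ³; the function name `lagrangian_bias_fields` is hypothetical, and the exact definitions in the Hi-Fi mocks code (for example whether d3 is mean-subtracted) may differ.

```python
import jax.numpy as jnp

def lagrangian_bias_fields(delta):
    """Hypothetical helper: build (d1, d2, dG2, d3) from a real n^3 field `delta`.

    Assumed definitions:
      d1  = delta with its mean removed
      d2  = d1^2 - <d1^2>
      dG2 = sum_ij t_ij^2 - (tr t)^2, with t_ij = d_i d_j Phi and nabla^2 Phi = d1
      d3  = d1^3
    """
    n = delta.shape[0]
    k1d = 2 * jnp.pi * jnp.fft.fftfreq(n)  # unitless: radians per grid cell
    kx, ky, kz = jnp.meshgrid(k1d, k1d, k1d, indexing="ij")
    ks = (kx, ky, kz)
    k2 = kx**2 + ky**2 + kz**2
    k2 = k2.at[0, 0, 0].set(1.0)  # avoid 0/0 at k = 0 (that mode is zeroed below)
    delta_k = jnp.fft.fftn(delta).at[0, 0, 0].set(0.0)  # remove the mean
    d1 = jnp.fft.ifftn(delta_k).real
    # In Fourier space t_ij(k) = k_i k_j / k^2 * delta(k)
    tidal = [[jnp.fft.ifftn(ks[i] * ks[j] / k2 * delta_k).real for j in range(3)]
             for i in range(3)]
    trace = tidal[0][0] + tidal[1][1] + tidal[2][2]  # equals d1
    dG2 = sum(tidal[i][j] ** 2 for i in range(3) for j in range(3)) - trace**2
    d2 = d1**2 - jnp.mean(d1**2)
    d3 = d1**3
    return d1, d2, dG2, d3
```

A quick sanity check: for a plane wave along one axis the tidal tensor has a single non-zero component equal to δ, so dG2 should vanish.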

I also saw that you had an orthogonalization step, which I didn't implement because we didn't discuss it earlier today. That probably matters?
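One common way to orthogonalize operator fields, sketched here as an assumption rather than as the step used in the Hi-Fi mocks code, is a Gram-Schmidt procedure applied separately in each |k| bin, where each projection coefficient is a ratio of a binned cross-spectrum to a binned auto-spectrum. This fits the remark that the step only needs cross-Pk between fields. The helper names `k_bin_index`, `binned_cross` and `orthogonalize` are hypothetical.

```python
import jax
import jax.numpy as jnp

def k_bin_index(n, nbins):
    """Assign each mode of an n^3 FFT grid to one of `nbins` linear |k| bins
    (unitless k, radians per cell, from 0 up to the corner value sqrt(3)*pi)."""
    k1d = 2 * jnp.pi * jnp.fft.fftfreq(n)
    kx, ky, kz = jnp.meshgrid(k1d, k1d, k1d, indexing="ij")
    kmag = jnp.sqrt(kx**2 + ky**2 + kz**2)
    edges = jnp.linspace(0.0, jnp.sqrt(3.0) * jnp.pi, nbins + 1)
    return jnp.clip(jnp.digitize(kmag, edges) - 1, 0, nbins - 1)

def binned_cross(a_k, b_k, idx, nbins):
    """Sum of Re(a_k * conj(b_k)) over the modes in each k bin."""
    prod = jnp.real(a_k * jnp.conj(b_k)).ravel()
    return jax.ops.segment_sum(prod, idx.ravel(), num_segments=nbins)

def orthogonalize(fields, nbins=16):
    """Per-k-bin (modified) Gram-Schmidt: each field has its projection onto
    all previously orthogonalized fields removed, bin by bin."""
    n = fields[0].shape[0]
    idx = k_bin_index(n, nbins)
    ortho_k = []
    for f in fields:
        o_k = jnp.fft.fftn(f)
        for q_k in ortho_k:
            cross = binned_cross(o_k, q_k, idx, nbins)
            auto = binned_cross(q_k, q_k, idx, nbins)
            # Guard against empty bins (auto == 0)
            ratio = jnp.where(auto > 0, cross / jnp.where(auto > 0, auto, 1.0), 0.0)
            o_k = o_k - ratio[idx] * q_k
        ortho_k.append(o_k)
    # The coefficients depend only on |k|, so the results remain real fields
    return [jnp.fft.ifftn(o_k).real for o_k in ortho_k]
```

The order of `fields` matters: the first field is left unchanged and later ones are made orthogonal to it.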

Finally, I am not 100% sure about the conversion from my unitless convention for scales to h/Mpc, which is needed to compute the filters (I'm about 80% sure). So if everything looks OK but the results still don't make sense, it could come from there.
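A small way to check the conversion, assuming the unitless convention is k in radians per grid cell (k = 2π·fftfreq(n)) and the box size is in Mpc/h: dividing by the cell size gives h/Mpc, which should agree with `fftfreq` evaluated with the physical cell spacing. The function name and the mesh/box values below are hypothetical.

```python
import jax.numpy as jnp

def unitless_k_to_hmpc(k_unitless, nmesh, box_size):
    """Hypothetical helper: convert k from radians per grid cell to h/Mpc.

    Assumes k_unitless = 2*pi*fftfreq(nmesh) and box_size in Mpc/h,
    so one cell spans box_size / nmesh Mpc/h.
    """
    cell_size = box_size / nmesh
    return k_unitless / cell_size

n, L = 256, 200.0  # hypothetical mesh size and box size (Mpc/h)
k_unitless = 2 * jnp.pi * jnp.fft.fftfreq(n)
k_hmpc = unitless_k_to_hmpc(k_unitless, n, L)
# Independent cross-check using the physical sample spacing L / n
k_ref = 2 * jnp.pi * jnp.fft.fftfreq(n, d=L / n)
```

With this convention the fundamental mode is 2π/L and the Nyquist frequency is π·n/L, both in h/Mpc.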

@andrejobuljen
Owner

@EiffL great thanks!

Yes, sorry, I also realised later that I forgot to mention the orthogonalisation step when we discussed yesterday. For that step I'd only need to measure various cross-Pk between fields. I tried running the auto-Pk but ran into the following problem:

  • If I run: p1 = power_spectrum(shifted1, kmin=jnp.pi/256, dk=2.*jnp.pi/256, boxsize=box_size[0]), I get the error "TypeError: 'float' object is not iterable", caused by boxsize.

  • If I run: p1 = power_spectrum(shifted1, kmin=jnp.pi/256, dk=2.*jnp.pi/256, boxsize=box_size), I get the error "AttributeError: 'list' object has no attribute 'prod'", again caused by boxsize.

  • So I replaced P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32') with P = ((Psum / Nsum)[1:-1] * jnp.array(boxsize).prod()).astype('float32'), ran it with boxsize=box_size, and it worked.
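Both errors come from the function assuming that boxsize is array-like: a plain Python float cannot be iterated over, and a Python list has no `.prod()` method. A minimal sketch of an alternative to patching the single line is to normalize boxsize once at the top of the function, so that scalars, lists and arrays are all accepted; `normalize_boxsize` is a hypothetical helper, not part of the existing code.

```python
import jax.numpy as jnp

def normalize_boxsize(boxsize, ndim=3):
    """Hypothetical helper: accept a scalar, list/tuple, or array box size and
    return a float32 array of shape (ndim,), so both iteration and .prod() work."""
    return jnp.broadcast_to(jnp.asarray(boxsize, dtype=jnp.float32), (ndim,))
```

Calling `boxsize = normalize_boxsize(boxsize)` as the first line of the power-spectrum function would make `boxsize=box_size[0]` (a cubic box) and `boxsize=box_size` behave the same for a cube.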

Do you perhaps have a cross-Pk function written in JAX somewhere?
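As a starting point, here is a sketch of an isotropically binned cross power spectrum in pure JAX (passing the same field twice gives the auto spectrum). It assumes both fields are real overdensities on the same n³ mesh with box_size in Mpc/h, uses linear bins from 0 to the smallest per-axis Nyquist frequency, and excludes the k = 0 mode; `cross_power_spectrum` is a hypothetical name and its binning choices are assumptions, not those of the existing power_spectrum function.

```python
import jax
import jax.numpy as jnp

def cross_power_spectrum(field_a, field_b, box_size, nbins=16):
    """Binned cross power spectrum of two real fields on the same n^3 mesh.

    box_size: scalar or length-3 sequence (assumed Mpc/h).
    Returns (k_mean [h/Mpc], P_ab [(Mpc/h)^3], n_modes) per bin.
    """
    n = field_a.shape[0]
    box = jnp.broadcast_to(jnp.asarray(box_size, dtype=jnp.float32), (3,))
    volume = box.prod()
    # Wavevector components in h/Mpc along each axis
    ks = [2 * jnp.pi * jnp.fft.fftfreq(n) * n / box[i] for i in range(3)]
    kx, ky, kz = jnp.meshgrid(*ks, indexing="ij")
    kmag = jnp.sqrt(kx**2 + ky**2 + kz**2).ravel()
    k_nyq = jnp.pi * n / box.max()  # smallest per-axis Nyquist frequency
    edges = jnp.linspace(0.0, k_nyq, nbins + 1)
    # digitize: 0 = below edges[0], nbins + 1 = at/above edges[-1]; both dropped
    idx = jnp.digitize(kmag, edges)
    idx = idx.at[0].set(0)  # raveled index 0 is the k = 0 (mean) mode
    # Discrete-to-continuum normalization: P = V / N^2 * Re(A_k conj(B_k)), N = n^3
    a_k = jnp.fft.fftn(field_a)
    b_k = jnp.fft.fftn(field_b)
    power = volume / float(n) ** 6 * jnp.real(a_k * jnp.conj(b_k)).ravel()
    nseg = nbins + 2
    p_sum = jax.ops.segment_sum(power, idx, num_segments=nseg)[1:-1]
    k_sum = jax.ops.segment_sum(kmag, idx, num_segments=nseg)[1:-1]
    counts = jax.ops.segment_sum(jnp.ones_like(kmag), idx, num_segments=nseg)[1:-1]
    safe = jnp.maximum(counts, 1.0)  # empty bins return 0
    return k_sum / safe, p_sum / safe, counts
```

With this normalization, unit-variance white noise should give a flat spectrum equal to the cell volume (L/n)³, which is a convenient check before comparing against nbodykit.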

Also, thanks for fixing the weights in cic_paint! Yesterday I did something similar, adding
if weight is not None: kernel = kernel * weight[..., jnp.newaxis]
right below the line
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
but I will just use your new version.
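For context, here is a self-contained sketch of a weighted cloud-in-cell deposit that places the weight multiplication right after the kernel product, as described above. It is modelled on those two lines only; it is not the notebook's or any library's cic_paint, and it assumes positions in grid units on a periodic mesh.

```python
import jax.numpy as jnp

def cic_paint(mesh, positions, weight=None):
    """Sketch of a cloud-in-cell deposit onto a periodic 3D mesh.

    positions: (npart, 3) in grid units; weight: optional (npart,) weights.
    """
    shape = jnp.array(mesh.shape)
    # Offsets of the 8 neighbouring cells in the CIC stencil
    offsets = jnp.array([[i, j, k] for i in (0, 1) for j in (0, 1) for k in (0, 1)])
    base = jnp.floor(positions)                                  # (npart, 3)
    neighbours = base[:, None, :] + offsets[None, :, :]          # (npart, 8, 3)
    kernel = 1.0 - jnp.abs(positions[:, None, :] - neighbours)   # (npart, 8, 3)
    kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]    # (npart, 8)
    if weight is not None:
        kernel = kernel * weight[..., jnp.newaxis]               # scale by weight
    idx = jnp.mod(neighbours.astype(jnp.int32), shape)           # periodic wrap
    # Scatter-add accumulates contributions that land on the same cell
    return mesh.at[idx[..., 0], idx[..., 1], idx[..., 2]].add(kernel)
```

Because the kernel of each particle sums to one, the painted mesh sums to the total weight, which is an easy check that the weights are applied correctly.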

Btw, do you know how to match the JAX seed to numpy.random.seed when generating initial conditions? In nbodykit it was enough to use the same random seed that was used in the TNG simulation to get the same ICs/phases. In JAX this doesn't seem to work: when I use the same seed I get a different field. This may not be super important; it would just let me test the code more quickly...
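Some background on why the seeds don't match: `jax.random` uses a counter-based generator (Threefry by default), while `np.random.seed` seeds NumPy's legacy Mersenne Twister, so the same integer seed produces unrelated streams and there is no seed mapping between them. A sketch of one workaround is to draw the white noise outside JAX with NumPy and convert it with `jnp.asarray`, keeping everything downstream in JAX. Whether this reproduces the TNG phases depends on how those were generated (noise ordering, real vs Fourier space), which this sketch does not attempt; `white_noise_numpy` is a hypothetical helper.

```python
import numpy as np
import jax.numpy as jnp

def white_noise_numpy(seed, n):
    """Hypothetical helper: real-space unit Gaussian white noise drawn with
    NumPy's legacy RandomState (same stream as np.random.seed + np.random.normal),
    returned as a float32 JAX array."""
    rng = np.random.RandomState(seed)
    return jnp.asarray(rng.normal(size=(n, n, n)).astype(np.float32))
```

Since the noise is a fixed input rather than a random draw inside the traced function, gradients with respect to cosmological or bias parameters are unaffected.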

Cheers,
Andrej
