Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Slp converter #7

Merged
merged 6 commits into from
Jun 22, 2024
Merged

Conversation

apasarkar
Copy link
Contributor

Key changes:

  • Fixes bug with smoother code
  • On CPU, provides an updated jax.lax.scan implementation that can do filtering on blocks of keypoints, parallelized at all stages (over keypoints, over time points, etc.)
  • Includes a GPU parallel scan implementation for the Kalman filter -- again fully "jitted". This is used to do parameter estimation for the smoothing parameter. Code computes the Kalman filter + nonnegative log likelihood very fast.
  • Optax (jax) optimizer differentiates through all of the above implementations, allowing fast MLE computation of the smoothing parameter. Note: Can in principle also compute the observation noise data with minor modifications.

@apasarkar apasarkar marked this pull request as draft June 22, 2024 21:46
@keeminlee keeminlee marked this pull request as ready for review June 22, 2024 23:06
@keeminlee keeminlee merged commit 598fd56 into paninski-lab:slp-converter Jun 22, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants