dev: Transformations. #105
Comments
On a related note, it seems that the log-det-jacobian term is missing from the objective function that is passed to HMC.
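To make the missing term concrete, here is a minimal sketch (illustrative only, not GPJax code) of the correction being described: if HMC samples in the unconstrained space via a constraining map theta = f(x), the target log density must include the log-det-Jacobian of f.

```python
import jax.numpy as jnp

def softplus(x):
    # Example constraining transform: maps unconstrained reals to positives.
    return jnp.log1p(jnp.exp(x))

def softplus_log_det_jacobian(x):
    # d/dx softplus(x) = sigmoid(x), so log|J| = log sigmoid(x) = -softplus(-x).
    return -softplus(-x)

def unconstrained_log_density(x, constrained_log_density):
    # The density HMC should target in the unconstrained space:
    # log p(f(x)) + log|det J_f(x)|. Omitting the second term is the bug noted above.
    theta = softplus(x)
    return constrained_log_density(theta) + jnp.sum(softplus_log_det_jacobian(x))
```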
Thanks for spotting this @murphyk (and apologies for the slow reply)! We recently refactored transformations in #105. Transformations have been removed from objective functions such as the marginal log-likelihood: instead of e.g. `objective = posterior.mll(D, transformation, negative=True)`, we now do `objective = posterior.mll(D, negative=True)`, and the transformations are instead placed with the training loops. To remove clutter, we have removed the dictionary `constrainers`/`unconstrainers` notion and instead define transformations via bijectors. In particular:

- The `ParameterState` dataclass now comprises `params`, `trainables` and `bijectors`. So e.g. calling `parameter_state = gpx.initialise(posterior, key)` and then `parameter_state.unpack()` returns a tuple `params, trainables, bijectors`.
- The function `gpx.transform` has been removed. To transform parameters to the constrained space, we now do `gpx.constrain(params, bijectors)`, and likewise, to transform to the unconstrained space, we do `gpx.unconstrain(params, bijectors)`.
- Training loops in abstractions, e.g. `fit`, are now defined directly on a `ParameterState` dataclass (i.e. we don't pass the `params` dictionary, the `trainables` dictionary and the transformations through separately, as we used to do).

A sketch of the resulting workflow is below.
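For concreteness, here is a minimal sketch of the new workflow, pieced together from the description above (the `posterior`, `D` and `optimiser` objects are assumed to already exist, and exact signatures, e.g. of `fit`, may differ from the release):

```python
import jax.random as jr
import gpjax as gpx

key = jr.PRNGKey(123)

# initialise now returns a ParameterState comprising params, trainables and bijectors.
parameter_state = gpx.initialise(posterior, key)
params, trainables, bijectors = parameter_state.unpack()

# Objectives no longer take a transformation argument.
objective = posterior.mll(D, negative=True)

# Moving between the unconstrained and constrained spaces is now explicit.
unconstrained_params = gpx.unconstrain(params, bijectors)
constrained_params = gpx.constrain(unconstrained_params, bijectors)

# Training loops such as fit act directly on the ParameterState.
learned_state = gpx.fit(objective, parameter_state, optimiser)
```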
Perhaps we should not require parameter transformations (via `transform`) in objective functions - they should rest with model training. If I have an ELBO or the marginal log-likelihood, shouldn't I just be able to pass my parameters to it without transforming anything?

For example, for the marginal log-likelihood of GP regression, in the `fit` abstraction we currently have the objective function as `objective = posterior.mll(D, transformation, negative=True)`, and `abstractions.py` defines a loss function to train the parameters (that stops gradients).

Perhaps it would be nicer to define an objective with `objective = posterior.mll(D, negative=True)` (i.e. no transforms specified) and then have the transform in the training loop instead (see the sketch below). The training loop could even possibly take a bijector argument, and `abstractions.py` could manage forward and reverse transformations (`gpx.initialise` could return a dictionary of bijectors instead of the `constrainer` and `unconstrainer` convention).
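A rough sketch of what such a training loop could look like (purely illustrative: the `constrain`/`unconstrain` helpers and the per-leaf bijector dictionary are part of the proposal, not an existing API):

```python
import jax
import optax

def constrain(params, bijectors):
    # Hypothetical helper: apply each bijector's forward map leaf-wise.
    return jax.tree_util.tree_map(lambda p, b: b.forward(p), params, bijectors)

def unconstrain(params, bijectors):
    # Hypothetical helper: apply each bijector's inverse map leaf-wise.
    return jax.tree_util.tree_map(lambda p, b: b.inverse(p), params, bijectors)

def fit(objective, params, bijectors, optimiser, n_iters=100):
    # Optimise in the unconstrained space; the loop, not the objective,
    # manages the forward and reverse transformations.
    params = unconstrain(params, bijectors)
    opt_state = optimiser.init(params)

    def loss(p):
        # Map back to the constrained space before evaluating the
        # (transformation-free) objective.
        return objective(constrain(p, bijectors))

    for _ in range(n_iters):
        grads = jax.grad(loss)(params)
        updates, opt_state = optimiser.update(grads, opt_state)
        params = optax.apply_updates(params, updates)

    return constrain(params, bijectors)
```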