
dev: Transformations. #105

Closed
daniel-dodd opened this issue Aug 25, 2022 · 2 comments
Labels
enhancement New feature or request
daniel-dodd (Member):

Perhaps we should not require parameter transformations (via `transform`) in objective functions; they should rest with model training.

If I have an ELBO or the marginal log-likelihood, shouldn't I just be able to pass my parameters to it without transforming anything?

For example, for the marginal log-likelihood of GP regression, the `fit` abstraction currently takes the objective function as `objective = posterior.mll(D, transformation, negative=True)`, and abstractions.py defines a loss function to train the parameters (stopping gradients on non-trainable parameters):

```python
def loss(params):
    params = trainable_params(params, trainables)
    return objective(params)
```

Perhaps it would be nicer to define an objective with `objective = posterior.mll(D, negative=True)` (i.e. no transforms specified) and then apply the transform in the training loop instead, e.g.,

```python
def loss(params):
    params = trainable_params(params, trainables)
    params = transform(params, transformation)
    return objective(params)
```
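A self-contained sketch of this proposed loop, with plain-Python stand-ins for `trainable_params`, `transform`, and the objective (all hypothetical placeholders here, not the GPJax implementations):

```python
import math

def trainable_params(params, trainables):
    # Stand-in for gradient stopping: simply keep the trainable entries.
    return {k: v for k, v in params.items() if trainables.get(k, False)}

def transform(params, transformation):
    # Apply the same forward map to every parameter.
    return {k: transformation(v) for k, v in params.items()}

def objective(params):
    # Placeholder objective; the real one would be a (negative) mll.
    return sum(params.values())

trainables = {"lengthscale": True, "noise": False}

def loss(params):
    params = trainable_params(params, trainables)
    params = transform(params, math.exp)  # unconstrained -> constrained
    return objective(params)

value = loss({"lengthscale": 0.0, "noise": 5.0})
```

Here `noise` is filtered out as non-trainable, `lengthscale` is mapped through `exp`, and the objective sums what remains.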

The training loop could even take a bijector argument, with abstractions.py managing the forward and reverse transformations (`gpx.initialise` could return a dictionary of bijectors instead of the current constrainer/unconstrainer convention).
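A minimal sketch of how a bijector dictionary could manage both directions, using plain Python callables as stand-in bijectors (the names `constrain`/`unconstrain` follow the proposal; nothing here is actual GPJax API):

```python
import math

# Each parameter name maps to a (forward, inverse) pair of callables.
bijectors = {
    "lengthscale": (math.exp, math.log),  # unconstrained -> positive, and back
    "variance": (math.exp, math.log),
}

def constrain(params, bijectors):
    """Map unconstrained parameters into their constrained space."""
    return {k: bijectors[k][0](v) for k, v in params.items()}

def unconstrain(params, bijectors):
    """Map constrained parameters back to the unconstrained space."""
    return {k: bijectors[k][1](v) for k, v in params.items()}

params = {"lengthscale": 2.0, "variance": 0.5}
round_trip = constrain(unconstrain(params, bijectors), bijectors)
```

Because each entry stores both directions, the training loop can move between spaces without the separate constrainer and unconstrainer dictionaries.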

@daniel-dodd daniel-dodd added the enhancement New feature or request label Aug 25, 2022
@thomaspinder thomaspinder added this to the v0.5.0 milestone Sep 1, 2022

murphyk commented Sep 9, 2022

On a related note, it seems that the log-det-jacobian term is missing from the objective function that is passed to HMC.
(At least I could not see the string 'jacfwd' or 'jacrev' anywhere in the codebase :)
The importance of this term is illustrated in https://github.com/probml/pyprobml/blob/master/notebooks/book2/03/change_of_variable_hmc.ipynb
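To illustrate the term in question: if HMC runs on u = log(x) while the target density p(x) is defined on x > 0, the unconstrained log-density needs the log-det-Jacobian of the map x = exp(u). A minimal sketch with an Exponential(1) target (my own example, not GPJax code):

```python
import math

def log_p(x):
    # Example target: Exponential(rate=1), so log p(x) = -x for x > 0.
    return -x

def log_p_unconstrained(u):
    # Change of variables x = exp(u): log|dx/du| = log(exp(u)) = u.
    x = math.exp(u)
    return log_p(x) + u  # dropping "+ u" gives HMC the wrong target

corrected = log_p_unconstrained(0.0)  # x = 1, so log p(1) + 0 = -1
```

Without the `+ u` correction, HMC samples from a distribution whose density in x-space differs from p(x) by the Jacobian factor.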

daniel-dodd added a commit that referenced this issue Sep 17, 2022
The PR refactors transformations #105.

Transformations have been removed from objective functions, e.g. the marginal likelihood. So instead of `objective = posterior.mll(D, transformation, negative=True)`, we now do `objective = posterior.mll(D, negative=True)`. The transformations are instead placed with training loops.

To remove clutter, we drop the notion of constrainer and unconstrainer dictionaries and instead define transformations via bijectors. In particular:

The `ParameterState` dataclass now comprises `params`, `trainables` and `bijectors`. So e.g. calling `parameter_state = gpx.initialise(posterior, key)` and then `parameter_state.unpack()` returns a tuple `params, trainables, bijectors`.

The function `gpx.transform` has been removed. To transform parameters to the constrained space, we now do `gpx.constrain(params, bijectors)`; likewise, to transform to the unconstrained space, we do `gpx.unconstrain(params, bijectors)`.

Finally, training loops in abstractions, e.g. `fit`, now operate directly on a `ParameterState` dataclass (i.e., we no longer pass the params dictionary, trainables dictionary and transformations separately, as we used to do).
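The described workflow can be sketched with a hypothetical stand-in dataclass (the real `ParameterState` lives in GPJax; this only illustrates the `unpack()` pattern):

```python
from dataclasses import dataclass, field

@dataclass
class ParameterState:
    """Stand-in for the GPJax dataclass described above."""
    params: dict = field(default_factory=dict)
    trainables: dict = field(default_factory=dict)
    bijectors: dict = field(default_factory=dict)

    def unpack(self):
        # Return the three components as a tuple, mirroring the new API.
        return self.params, self.trainables, self.bijectors

state = ParameterState(
    params={"lengthscale": 1.0},
    trainables={"lengthscale": True},
    bijectors={"lengthscale": "softplus"},  # placeholder for a real bijector
)
params, trainables, bijectors = state.unpack()
```

Bundling the three dictionaries into one state object lets `fit` accept a single argument instead of three parallel ones.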
daniel-dodd (Member, Author):

Thanks for spotting this @murphyk (and apologies for the slow reply)! We recently refactored GPJax, first by removing transformations from all objects in v0.5.0. Our most recent release, v0.5.2, then removes priors from the "marginal likelihood", leaving it as a probability transition kernel only (#153). Prior transformations are now left to the user and could, for example, be handled via transformed.dist, as demonstrated in the tfp integration notebook example.
