
Transformations #109

Merged: 13 commits merged into v0.5_update on Sep 17, 2022

Conversation

daniel-dodd
Member

This PR refactors transformations (#105).

Transformations have been removed from objective functions, e.g. the marginal log-likelihood. So instead of objective = posterior.mll(D, transformation, negative=True), we now write objective = posterior.mll(D, negative=True); the transformations are instead handled within the training loops.

To remove clutter, we remove the notion of constrainers and unconstrainers dictionaries, and instead define transformations via bijectors. In particular:

  • The ParameterState dataclass now comprises params, trainables and bijectors. So e.g. calling parameter_state = gpx.initialise(posterior, key) and then parameter_state.unpack() returns a tuple params, trainables, bijectors.

  • The function gpx.transform has been removed. To transform parameters to the constrained space, we now do gpx.constrain(params, bijectors); likewise, to transform to the unconstrained space, we do gpx.unconstrain(params, bijectors). (See the sketch after this list.)
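
For concreteness, the new round trip looks something like this (a minimal sketch; the posterior is assumed to have been built via the usual GPJax prior-times-likelihood workflow, and D is a gpx.Dataset):

import gpjax as gpx
import jax.random as jr

key = jr.PRNGKey(123)

# Initialise the parameter state and unpack its three components.
parameter_state = gpx.initialise(posterior, key)
params, trainables, bijectors = parameter_state.unpack()

# Map the parameters to the unconstrained space for optimisation...
unconstrained_params = gpx.unconstrain(params, bijectors)

# ...and back to the constrained space once optimisation is done.
constrained_params = gpx.constrain(unconstrained_params, bijectors)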

Finally, training loops in abstractions, e.g. fit, now operate directly on a ParameterState dataclass (i.e., we no longer pass the params dictionary, trainables dictionary and the transformations through separately, as we used to do).
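
As a sketch, a training loop under the new convention might then read as follows (the optimiser choice and the fit keyword names here are illustrative assumptions, not the definitive signature):

import optax

# Build the objective without transformations, per the refactor above.
objective = posterior.mll(D, negative=True)
optimiser = optax.adam(learning_rate=0.01)

# fit consumes the ParameterState directly instead of three unpacked dictionaries.
inference_state = gpx.fit(objective, parameter_state, optimiser, n_iters=500)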

@codecov

codecov bot commented Aug 31, 2022

Codecov Report

Merging #109 (eb853c9) into v0.5_update (246addf) will increase coverage by 0.09%.
The diff coverage is 100.00%.

❗ Current head eb853c9 differs from pull request most recent head ce79de4. Consider uploading reports for the commit ce79de4 to get more accurate results.

@@               Coverage Diff               @@
##           v0.5_update     #109      +/-   ##
===============================================
+ Coverage        98.95%   99.05%   +0.09%     
===============================================
  Files               13       13              
  Lines              961      954       -7     
===============================================
- Hits               951      945       -6     
+ Misses              10        9       -1     
Flag        Coverage Δ
unittests   99.05% <100.00%> (+0.09%) ⬆️

Flags with carried forward coverage won't be shown.

Impacted Files                  Coverage Δ
gpjax/__init__.py               100.00% <100.00%> (ø)
gpjax/abstractions.py           100.00% <100.00%> (+1.16%) ⬆️
gpjax/gps.py                    100.00% <100.00%> (ø)
gpjax/parameters.py             95.68% <100.00%> (-0.35%) ⬇️
gpjax/variational_inference.py  97.56% <100.00%> (-0.09%) ⬇️
gpjax/kernels.py                98.68% <0.00%> (ø)
gpjax/likelihoods.py            100.00% <0.00%> (ø)
... and 1 more


@daniel-dodd daniel-dodd marked this pull request as ready for review August 31, 2022 15:03
Comment on lines 142 to 159
+def constrain(params: tp.Dict, bijectors: tp.Dict) -> tp.Dict:
+    """Transform the parameters to the constrained space for corresponding bijectors.
+
+    Args:
-        params (tp.Dict): The parameter set for which transformations should be derived from.
-        transform_map (tp.Dict): The corresponding dictionary of transforms that should be applied to the parameter set.
-        foward (bool): Whether the parameters should be constrained (foward=True) or unconstrained (foward=False).
+        params (tp.Dict): The parameters that are to be transformed.
+
+    Returns:
-        tp.Tuple[tp.Dict, tp.Dict]: A pair of dictionaries. The first dictionary maps each parameter to a function that constrains the parameter. The second dictionary maps each parameter to a function that unconstrains the parameter.
+        tp.Dict: A transformed parameter set. The dictionary is equal in structure to the input params dictionary.
+    """
+    map = lambda param, trans: trans.forward(param)
+
+    return jax.tree_util.tree_map(map, params, bijectors)
-
-    def forward(bijector):
-        return bijector.forward
-
-    def inverse(bijector):
-        return bijector.inverse
-
-    bijectors = build_bijectors(params)
-
-    constrainers = jax.tree_util.tree_map(lambda _: forward, deepcopy(params))
-    unconstrainers = jax.tree_util.tree_map(lambda _: inverse, deepcopy(params))
-
-    constrainers = jax.tree_util.tree_map(lambda f, b: f(b), constrainers, bijectors)
-    unconstrainers = jax.tree_util.tree_map(
-        lambda f, b: f(b), unconstrainers, bijectors
-    )
-
-    return constrainers, unconstrainers
-
-
-def transform(params: tp.Dict, transform_map: tp.Dict) -> tp.Dict:
-    """Transform the parameters according to the constraining or unconstraining function dictionary.
+
+
+def unconstrain(params: tp.Dict, bijectors: tp.Dict) -> tp.Dict:
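
For intuition, the new constrain is a leaf-wise tree_map of each bijector's forward transform over the parameter pytree. A minimal standalone illustration, assuming TensorFlow Probability's JAX bijectors and made-up parameter values:

import jax
import tensorflow_probability.substrates.jax.bijectors as tfb

# Unconstrained parameter leaves and a structurally matching pytree of bijectors.
params = {"lengthscale": -0.5, "variance": 1.2}
bijectors = {"lengthscale": tfb.Softplus(), "variance": tfb.Softplus()}

# Apply each bijector's forward map to its corresponding parameter leaf.
constrained = jax.tree_util.tree_map(lambda p, b: b.forward(p), params, bijectors)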
Collaborator

Should these be a function of params and bijectors, or simply ParameterState?

Member Author

That's a nice suggestion.

Collaborator

We could have utility functions constrain_state(parameter_state: ParameterState) and unconstrain_state(parameter_state: ParameterState) that have the form:

def constrain_state(parameter_state: ParameterState):
    return constrain(parameter_state.params, parameter_state.bijectors)
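
The unconstrain_state counterpart would follow the same pattern (a sketch, mirroring the function above):

def unconstrain_state(parameter_state: ParameterState):
    return unconstrain(parameter_state.params, parameter_state.bijectors)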

Member Author

Could replace this:

def loss(params):
    params = trainable_params(params, trainables)
    params = constrain(params, bijectors)
    return objective(params)

with something like the following?

evaluate(objective, parameter_state)
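
Such a helper might look like this (a hypothetical sketch assembled from the loss body above; evaluate is a proposed name, not an existing API):

def evaluate(objective, parameter_state):
    params, trainables, bijectors = parameter_state.unpack()
    params = trainable_params(params, trainables)
    params = constrain(params, bijectors)
    return objective(params)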

Collaborator

I guess we need to decide whether, within GPJax, we pass around params as a dictionary or as a ParameterState.

Member Author

Absolutely. Doing things via ParameterState is clean and user-friendly: it is easy to see the transformations and which parameters are currently trainable. Working with three unpacked dictionaries might cause more pain points if you have more than one model on the go (having a parameter state for each would keep things tidy), and the training abstractions would look cleaner. On the other hand, working with three unpacked dictionaries is more direct with regard to the operations being applied.

Collaborator

I wonder whether this is the point at which the package should split: GPJax works directly with the dictionaries, and an abstracted package, such as TuneGP, works with the ParameterState object to give a more user-friendly interface.

@thomaspinder thomaspinder added this to the v0.5.0 milestone Sep 1, 2022
@thomaspinder thomaspinder added the enhancement New feature or request label Sep 1, 2022
@daniel-dodd
Member Author

@thomaspinder The MCMC needs sorting out in the classification and TFP notebooks (see the built docs). Looking at the old docs, I realise that previously we didn't transform the parameters to the unconstrained space before passing them to the training loop, and now we do; the trace plots look bad. Moreover, if we do adopt a ParameterState convention, it would be good to clean up the code in these notebooks.

@thomaspinder
Collaborator

I've approved everything here. I'm hesitant about merging straight into master, though. Can you rebase your natgrads branch with this, or do you require it to be on master?

@daniel-dodd daniel-dodd changed the base branch from master to v0.5_update September 16, 2022 12:53
@daniel-dodd daniel-dodd merged commit 7773eef into v0.5_update Sep 17, 2022
@thomaspinder thomaspinder deleted the Transformations branch October 16, 2022 18:48