Transformations #109
Conversation
Codecov Report

```diff
@@             Coverage Diff              @@
##           v0.5_update     #109   +/-  ##
===============================================
+ Coverage      98.95%     99.05%   +0.09%
===============================================
  Files             13         13
  Lines            961        954       -7
===============================================
- Hits             951        945       -6
+ Misses            10          9       -1
```
gpjax/parameters.py (outdated)
```diff
+def constrain(params: tp.Dict, bijectors: tp.Dict) -> tp.Dict:
+    """Transform the parameters to the constrained space for corresponding bijectors.
 
     Args:
-        params (tp.Dict): The parameter set for which transformations should be derived from.
+        params (tp.Dict): The parameters that are to be transformed.
-        transform_map (tp.Dict): The corresponding dictionary of transforms that should be applied to the parameter set.
+        bijectors (tp.Dict): The corresponding dictionary of transforms that should be applied to the parameter set.
-        foward (bool): Whether the parameters should be constrained (foward=True) or unconstrained (foward=False).
 
     Returns:
-        tp.Tuple[tp.Dict, tp.Dict]: A pair of dictionaries. The first dictionary maps each parameter to a function that constrains the parameter. The second dictionary maps each parameter to a function that unconstrains the parameter.
+        tp.Dict: A transformed parameter set. The dictionary is equal in structure to the input params dictionary.
     """
-    def forward(bijector):
-        return bijector.forward
-
-    def inverse(bijector):
-        return bijector.inverse
-
-    bijectors = build_bijectors(params)
-
-    constrainers = jax.tree_util.tree_map(lambda _: forward, deepcopy(params))
-    unconstrainers = jax.tree_util.tree_map(lambda _: inverse, deepcopy(params))
-
-    constrainers = jax.tree_util.tree_map(lambda f, b: f(b), constrainers, bijectors)
-    unconstrainers = jax.tree_util.tree_map(
-        lambda f, b: f(b), unconstrainers, bijectors
-    )
-
-    return constrainers, unconstrainers
+    map = lambda param, trans: trans.forward(param)
+    return jax.tree_util.tree_map(map, params, bijectors)
 
 
-def transform(params: tp.Dict, transform_map: tp.Dict) -> tp.Dict:
-    """Transform the parameters according to the constraining or unconstraining function dictionary.
+def unconstrain(params: tp.Dict, bijectors: tp.Dict) -> tp.Dict:
```
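As a minimal sketch of how these behave (assuming `unconstrain` mirrors `constrain` with `trans.inverse`): any object exposing `forward`/`inverse` methods can stand in for a bijector, since both functions simply `tree_map` those methods over the parameter pytree. The `Softplus` class below is a hand-rolled stand-in, not GPJax's actual bijector type.

```python
import jax
import jax.numpy as jnp


class Softplus:
    # Hand-rolled stand-in bijector; only the forward/inverse
    # interface assumed by constrain/unconstrain is needed.
    def forward(self, x):
        return jnp.log1p(jnp.exp(x))  # unconstrained -> positive

    def inverse(self, y):
        return jnp.log(jnp.expm1(y))  # positive -> unconstrained


params = {"lengthscale": jnp.array(1.0), "variance": jnp.array(2.0)}
bijectors = {"lengthscale": Softplus(), "variance": Softplus()}

unconstrained = unconstrain(params, bijectors)  # applies .inverse leaf-wise
restored = constrain(unconstrained, bijectors)  # applies .forward leaf-wise
# restored matches params up to floating-point error
```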
Should these be a function of `params` and `bijectors`, or simply of `ParameterState`?
That's a nice suggestion.
We could have utility functions `constrain_state(parameter_state: ParameterState)` and `unconstrain_state(parameter_state: ParameterState)` that have the form:

```python
def constrain_state(parameter_state: ParameterState):
    return constrain(parameter_state.params, parameter_state.bijectors)
```
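Presumably the symmetric helper would be a one-liner as well (a sketch, assuming `unconstrain` shares `constrain`'s signature):

```python
def unconstrain_state(parameter_state: ParameterState):
    return unconstrain(parameter_state.params, parameter_state.bijectors)
```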
Could we replace this:

```python
def loss(params):
    params = trainable_params(params, trainables)
    params = constrain(params, bijectors)
    return objective(params)
```

with something like `evaluate(objective, parameter_state)`?
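For concreteness, a hypothetical `evaluate` along those lines might look as follows (a sketch only; neither the helper nor its exact unpacking is defined in this PR):

```python
def evaluate(objective, parameter_state):
    # Unpack the state, mask out non-trainable parameters, constrain,
    # then evaluate the objective, mirroring the loss above.
    params, trainables, bijectors = parameter_state.unpack()
    params = trainable_params(params, trainables)
    params = constrain(params, bijectors)
    return objective(params)
```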
I guess we need to decide: within GPJax, do we pass params around as a dictionary or as a `ParameterState`?
Absolutely. Doing things via `ParameterState` is clean and user-friendly: it is easy to see the transformations and which parameters are currently trainable. Working with three unpacked dictionaries might cause more pain points if you have more than one model on the go (having a parameter state for each one would keep things tidy), and the training abstractions would look cleaner. On the other hand, working with three unpacked dictionaries is more direct with regard to the operations being applied.
I wonder whether this is the point at which the package should split: GPJax works directly with the dictionaries, and an abstracted package, such as TuneGP, works with the `ParameterState` object to give a more user-friendly interface.
@thomaspinder The MCMC needs sorting out in the classification and TFP notebooks (see built docs). Looking at the old docs, I realise that previously we didn't transform the parameters to the unconstrained space before passing them to the training loop, and now we do. The trace plots look bad. Moreover, if we do adopt a …
I've approved everything here. I'm hesitant about merging straight into master, though. Can you rebase your natgrads branch with this, or do you require it to be on master?
All notebooks are updated, except the TensorFlow Probability notebook and the MCMC section of the classification notebook.
This reverts commit 7d9ed4d.
Force-pushed from eb853c9 to 5a01e63.
The PR refactors transformations (#105).

Transformations have been removed from objective functions, e.g. the marginal likelihood. So instead of, e.g., `objective = posterior.mll(D, transformation, negative=True)`, we now do `objective = posterior.mll(D, negative=True)`. The transformations are instead placed within training loops.

To remove clutter, we remove the `constrainers` and `unconstrainers` dictionary notion, and instead define transformations via bijectors. In particular:

- The `ParameterState` dataclass now comprises `params`, `trainables` and `bijectors`. So, e.g., calling `parameter_state = gpx.initialise(posterior, key)` and then `parameter_state.unpack()` returns the tuple `params, trainables, bijectors`.
- The function `gpx.transform` has been removed. To transform parameters to the constrained space, we now do `gpx.constrain(params, bijectors)`; likewise, to transform to the unconstrained space, we do `gpx.unconstrain(params, bijectors)`.
- Finally, training loops in abstractions, e.g. `fit`, are now defined directly on a `ParameterState` dataclass (i.e., we no longer pass the params dictionary, trainables dictionary and transformations through separately). A sketch of the resulting workflow follows below.
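Putting the pieces together, a training-loop sketch under the new API might look like this. It assumes a `posterior` and dataset `D` have already been built; `fit`'s exact argument names and the optimiser choice are illustrative, not taken verbatim from the PR.

```python
import jax.random as jr
import optax as ox
import gpjax as gpx

# Initialise a ParameterState bundling params, trainables and bijectors.
parameter_state = gpx.initialise(posterior, jr.PRNGKey(42))
params, trainables, bijectors = parameter_state.unpack()

# Objectives no longer take a transformation argument.
objective = posterior.mll(D, negative=True)

# Training abstractions now consume the ParameterState directly.
optimiser = ox.adam(learning_rate=0.01)
inference_state = gpx.fit(objective, parameter_state, optimiser, n_iters=500)

# Or, done by hand: optimise in the unconstrained space and map back.
params = gpx.unconstrain(params, bijectors)
# ... gradient steps on the unconstrained params ...
params = gpx.constrain(params, bijectors)
```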