Gaussian Process with parametrized mean #49
-
Hi @jecampagne, happy to see that you like how JAXNS works with GPs. In fact, I plan to expand that module significantly to address exotic models. You can do Bayesian inference over any parameters, even in a hierarchical manner. For example:
```python
from jax import numpy as jnp
from jax.scipy.linalg import solve_triangular

with PriorChain() as prior_chain:
    # your coordinates of data
    X = DeltaPrior('X', X)
    # measurement-uncertainty prior
    uncert = HalfLaplacePrior('uncert', 1.)
    # kernel hyper-parameter priors
    l = UniformPrior('l', 0., 2.)
    sigma = UniformPrior('sigma', 0., 2.)
    cov = GaussianProcessKernelPrior('K', kernel, X, l, sigma, tracked=False)
    # mean prior: piecewise-linear interpolation through num_nodes ordered nodes
    mean_xp = ForcedIdentifiabilityPrior('mu_xp', num_nodes, x_min, x_max)
    mean_fp = UniformPrior('mu_fp', y_min * jnp.ones(num_nodes), y_max * jnp.ones(num_nodes))
    mean = X[:, 0].interp(mean_xp, mean_fp, name='mu')

def log_normal(x, mean, cov):
    # Multivariate-normal log-density via a Cholesky factorisation of cov.
    L = jnp.linalg.cholesky(cov)
    # sum(log(diag(L))) == 0.5 * log|cov|
    log_det = jnp.sum(jnp.log(jnp.diag(L)))
    dx = x - mean
    # Whiten the residual; then maha = dx^T cov^{-1} dx.
    dx = solve_triangular(L, dx, lower=True)
    maha = dx @ dx
    return -0.5 * x.size * jnp.log(2. * jnp.pi) - log_det - 0.5 * maha

def log_likelihood(K, mu, uncert):
    """
    P(Y|params) = N[Y; mu, K + uncert^2 * I]
    """
    data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
    return log_normal(Y_obs, mu, K + data_cov)
```
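As a quick sanity check (not part of the original post), `log_normal` can be compared against `jax.scipy.stats.multivariate_normal.logpdf` on a random SPD covariance; the setup names here are illustrative only:

```python
import jax.numpy as jnp
from jax import random
from jax.scipy.stats import multivariate_normal

n = 5
A = random.normal(random.PRNGKey(0), (n, n))
cov = A @ A.T + n * jnp.eye(n)     # a well-conditioned SPD covariance
mu = jnp.zeros(n)
x = random.normal(random.PRNGKey(1), (n,))

print(log_normal(x, mu, cov))                  # hand-rolled version above
print(multivariate_normal.logpdf(x, mu, cov))  # should agree closely
```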
-
Hi, now I cannot figure out how to merge the whole set of data into a single GP, since the parameters of the parametrized mean model are different for each setup.
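One possible direction (a suggestion of mine, not something from this thread): keep a single joint likelihood that shares the kernel hyper-parameters across all series while giving each series its own mean parameters, and sum the per-series Gaussian log-likelihoods. A rough JAX sketch, where `Xs`, `Ys`, `mean_params`, `mean_fn` and `rbf` are all hypothetical names:

```python
import jax.numpy as jnp

def rbf(x1, x2, l, sigma):
    # Plain RBF kernel between two sets of 1-D inputs.
    d2 = (x1[:, None] - x2[None, :]) ** 2
    return sigma ** 2 * jnp.exp(-0.5 * d2 / l ** 2)

def joint_log_likelihood(Xs, Ys, mean_params, l, sigma, uncert):
    # Xs, Ys: per-series inputs/targets (lengths may differ);
    # mean_params: one parameter set per series;
    # l, sigma, uncert: shared by all series.
    total = 0.0
    for X, Y, theta in zip(Xs, Ys, mean_params):
        K = rbf(X, X, l, sigma) + uncert ** 2 * jnp.eye(X.shape[0])
        mu = mean_fn(X, theta)  # the parametrized mean, assumed defined
        total = total + log_normal(Y, mu, K)
    return total
```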
-
@jecampagne anything left to help on this?
-
Hi Joshua,
I have just seen your new development branch and run my old example. Now I realise that your nice library also deals with Gaussian Processes. Along this line, I have posted this result to the gpax repo (look here). In fact, the point developed there was to use a parametrised `mean_fn` with priors on its parameters, which fixes the mean of the MVN of the GP (i.e. the kernel is the RBF or whatever you want). In particular, in my case `mean_fn` is a piecewise parametrized function, with one piece for t<0 and another piece for t>0 (a minimal sketch is given below).
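A minimal sketch of such a piecewise mean (my reading of the setup, not code from the gpax post): one linear piece for t<0 and another for t>=0, with `jnp.where` keeping everything JIT- and grad-friendly. The parameter layout is an assumption:

```python
import jax.numpy as jnp

def mean_fn(t, theta):
    # theta = (a_neg, b_neg, a_pos, b_pos): slope and intercept of each piece.
    a_neg, b_neg, a_pos, b_pos = theta
    return jnp.where(t < 0., a_neg * t + b_neg, a_pos * t + b_pos)
```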
Now I have two questions for you:
1. Is there a way to use such a parametrized mean with the `GaussProc` class in your framework?
2. I have a few time series like the one in the `gpax` post, and I fit a GP to each one, with the possibility of predicting the minimum for t>0. Of course, if I had plenty of time series I would probably have tried to set up an LSTM/GRU recurrent network, for instance, but here the scarcity of time series makes that solution impracticable.
Any idea is welcome.
Thanks