Skip to content

Commit

Permalink
Allow to pass trainable inducing inputs to AbstractVariationalGaussian
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanocortinovis committed Nov 1, 2024
1 parent 8f14cbd commit 3152cde
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions gpjax/variational_families.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,17 @@ class AbstractVariationalGaussian(AbstractVariationalFamily[L]):
def __init__(
self,
posterior: AbstractPosterior[P, L],
inducing_inputs: Float[Array, "N D"],
inducing_inputs: tp.Union[
Float[Array, "N D"],
Real,
Static,
],
jitter: ScalarFloat = 1e-6,
):
self.inducing_inputs = Static(inducing_inputs)
if not isinstance(inducing_inputs, (Real, Static)):
inducing_inputs = Real(inducing_inputs)

self.inducing_inputs = inducing_inputs
self.jitter = jitter

super().__init__(posterior)
Expand Down

0 comments on commit 3152cde

Please sign in to comment.