diff --git a/gpjax/variational_families.py b/gpjax/variational_families.py index f1c40bdcc..88c5f7b3c 100644 --- a/gpjax/variational_families.py +++ b/gpjax/variational_families.py @@ -108,10 +108,17 @@ class AbstractVariationalGaussian(AbstractVariationalFamily[L]): def __init__( self, posterior: AbstractPosterior[P, L], - inducing_inputs: Float[Array, "N D"], + inducing_inputs: tp.Union[ + Float[Array, "N D"], + Real, + Static, + ], jitter: ScalarFloat = 1e-6, ): - self.inducing_inputs = Static(inducing_inputs) + if not isinstance(inducing_inputs, (Real, Static)): + inducing_inputs = Real(inducing_inputs) + + self.inducing_inputs = inducing_inputs self.jitter = jitter super().__init__(posterior)