Module laplace.baselaplace
Classes
-------

class BaseLaplace(model: nn.Module, likelihood: Likelihood | str, sigma_noise: float | torch.Tensor = 1.0, prior_precision: float | torch.Tensor = 1.0, prior_mean: float | torch.Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', backend: type[CurvatureInterface] | None = None, backend_kwargs: dict[str, Any] | None = None, asdl_fisher_kwargs: dict[str, Any] | None = None)
Base class for all Laplace approximations in this library.
Parameters
----------
model : torch.nn.Module
likelihood : Likelihood or str in {'classification', 'regression', 'reward_modeling'}
    determines the log likelihood Hessian approximation.
    In the case of 'reward_modeling', it fits Laplace using the classification
    likelihood, then does prediction as in the regression likelihood. The model
    needs to be defined accordingly: the forward pass during training takes
    x.shape == (batch_size, 2, dim) with y.shape == (batch_size,), while during
    evaluation x.shape == (batch_size, dim).
    Note that 'reward_modeling' only supports KronLaplace and DiagLaplace.
sigma_noise : torch.Tensor or float, default=1
    observation noise for the regression setting; must be 1 for classification.
prior_precision : torch.Tensor or float, default=1
    prior precision of a Gaussian prior (= weight decay);
    can be scalar, per-layer, or diagonal in the most general case.
prior_mean : torch.Tensor or float, default=0
    prior mean of a Gaussian prior, useful for continual learning.
temperature : float, default=1
    temperature of the likelihood; a lower temperature leads to a more
    concentrated posterior and vice versa.
enable_backprop : bool, default=False
    whether to enable backprop to the input x through the Laplace predictive.
    Useful for e.g. Bayesian optimization.
dict_key_x : str, default='input_ids'
    the dictionary key under which the input tensor x is stored. Only has an
    effect when the model takes a MutableMapping as the input. Useful for
    Huggingface LLM models.
dict_key_y : str, default='labels'
    the dictionary key under which the target tensor y is stored. Only has an
    effect when the model takes a MutableMapping as the input. Useful for
    Huggingface LLM models.
backend : subclasses of CurvatureInterface
    backend for access to curvature/Hessian approximations. Defaults to
    CurvlinopsGGN if None.
backend_kwargs : dict, default=None
    arguments passed to the backend on initialization, for example to set the
    number of MC samples for stochastic approximations.
asdl_fisher_kwargs : dict, default=None
    arguments passed to the ASDL backend specifically on initialization.
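For orientation, a minimal end-to-end sketch of this interface (the toy model, data, and the top-level Laplace factory are illustrative assumptions; only the fit, optimize_prior_precision, and predictive calls are documented in this section):

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset
    from laplace import Laplace  # convenience factory dispatching to the subclasses below

    # Toy regression model and data; any (pre-)trained nn.Module works.
    model = nn.Sequential(nn.Linear(2, 32), nn.Tanh(), nn.Linear(32, 1))
    X, y = torch.randn(256, 2), torch.randn(256, 1)
    train_loader = DataLoader(TensorDataset(X, y), batch_size=32)

    # Post-hoc Laplace over all weights with a Kronecker-factored Hessian.
    la = Laplace(model, 'regression', subset_of_weights='all', hessian_structure='kron')
    la.fit(train_loader)
    la.optimize_prior_precision(pred_type='glm', method='marglik')

    # Posterior predictive mean and variance at new inputs.
    f_mu, f_var = la(torch.randn(4, 2), pred_type='glm')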
Subclasses

Instance variables

prop backend : CurvatureInterface
prop log_likelihood : torch.Tensor

Compute the log likelihood on the training data after .fit() has been called.
The log likelihood is computed on-demand based on the loss and, for example,
the observation noise, which makes it differentiable in the latter for
iterative updates.

Returns
-------
log_likelihood : torch.Tensor
prop prior_precision_diag : torch.Tensor

Obtain the diagonal prior precision p_0 constructed from either a scalar,
layer-wise, or diagonal prior precision.

Returns
-------
prior_precision_diag : torch.Tensor

prop prior_mean : torch.Tensor

prop prior_precision : torch.Tensor

prop sigma_noise : torch.Tensor
Methods

def fit(self, train_loader: DataLoader) -> None

def log_marginal_likelihood(self, prior_precision: torch.Tensor | None = None, sigma_noise: torch.Tensor | None = None)

def predictive(self, x: torch.Tensor, pred_type: PredType | str, link_approx: LinkApprox | str, n_samples: int)

def optimize_prior_precision(self, pred_type: PredType | str, method: TuningMethod | str = TuningMethod.MARGLIK, n_steps: int = 100, lr: float = 0.1, init_prior_prec: float | torch.Tensor = 1.0, prior_structure: PriorStructure | str = PriorStructure.DIAG, val_loader: DataLoader | None = None, loss: torchmetrics.Metric | Callable[[torch.Tensor], torch.Tensor | float] | None = None, log_prior_prec_min: float = -4, log_prior_prec_max: float = 4, grid_size: int = 100, link_approx: LinkApprox | str = LinkApprox.PROBIT, n_samples: int = 100, verbose: bool = False, progress_bar: bool = False)
Optimize the prior precision post-hoc using the method specified by the user.

Parameters
----------
pred_type : PredType or str in {'glm', 'nn'}
    type of posterior predictive, linearized GLM predictive or neural network
    sampling predictive. The GLM predictive is consistent with the curvature
    approximations used here.
method : TuningMethod or str in {'marglik', 'gridsearch'}, default=TuningMethod.MARGLIK
    specifies how the prior precision should be optimized.
n_steps : int, default=100
    the number of gradient descent steps to take.
lr : float, default=1e-1
    the learning rate to use for gradient descent.
init_prior_prec : float or tensor, default=1.0
    initial prior precision before the first optimization step.
prior_structure : PriorStructure or str in {'scalar', 'layerwise', 'diag'}, default=PriorStructure.DIAG
    if init_prior_prec is scalar, the prior precision is optimized with this
    structure; otherwise, the structure of init_prior_prec is maintained.
val_loader : torch.data.utils.DataLoader, default=None
    DataLoader for the validation set; each iterate is a training batch (X, y).
loss : callable or torchmetrics.Metric, default=None
    loss function to use for CV. If callable, the loss is computed offline
    (memory intensive). If a torchmetrics.Metric, a running loss is computed
    (efficient). The default depends on the likelihood: RunningNLLMetric() for
    classification and reward modeling, running MeanSquaredError() for
    regression.
log_prior_prec_min : float, default=-4
    lower bound of the gridsearch interval.
log_prior_prec_max : float, default=4
    upper bound of the gridsearch interval.
grid_size : int, default=100
    number of values to consider inside the gridsearch interval.
link_approx : LinkApprox or str in {'mc', 'probit', 'bridge'}, default=LinkApprox.PROBIT
    how to approximate the classification link function for the 'glm'
    predictive. For pred_type='nn', only 'mc' is possible.
n_samples : int, default=100
    number of samples for link_approx='mc'.
verbose : bool, default=False
    if true, the optimized prior precision will be printed
    (can be a large tensor if the prior has a diagonal covariance).
progress_bar : bool, default=False
    whether to show a progress bar; updated at every batch-Hessian
    computation. Useful for very large models and large amounts of data, esp.
    when subset_of_weights='all'.
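For example, tuning by grid search on held-out data (a sketch; assumes a fitted instance la and a validation loader val_loader of (X, y) batches):

    # Grid search over the log prior precision interval documented above.
    la.optimize_prior_precision(
        pred_type='glm',
        method='gridsearch',
        val_loader=val_loader,   # assumed to exist
        log_prior_prec_min=-4,
        log_prior_prec_max=4,
        grid_size=100,
        link_approx='probit',
    )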
class ParametricLaplace(model: nn.Module, likelihood: Likelihood | str, sigma_noise: float | torch.Tensor = 1.0, prior_precision: float | torch.Tensor = 1.0, prior_mean: float | torch.Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', backend: type[CurvatureInterface] | None = None, backend_kwargs: dict[str, Any] | None = None, asdl_fisher_kwargs: dict[str, Any] | None = None)

Parametric Laplace class.

Subclasses need to specify how the Hessian approximation is initialized, how
to add up curvature over training data, how to sample from the Laplace
approximation, and how to compute the functional variance.

A Laplace approximation is represented by a MAP estimate, given by the model
parameters, and a posterior precision or covariance specifying a Gaussian
distribution \mathcal{N}(\theta_{MAP}, P^{-1}). The goal of this class is to
compute the posterior precision P, which sums as

    P = \sum_{n=1}^N \nabla^2_\theta \log p(\mathcal{D}_n \mid \theta)
        \vert_{\theta_{MAP}} + \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}}.

Every subclass implements different approximations to the log likelihood
Hessians, for example, a diagonal one. The prior is assumed to be Gaussian and
therefore we have a simple form for
\nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}} = P_0.
In particular, we assume a scalar, layer-wise, or diagonal prior precision so
that in all cases P_0 = \textrm{diag}(p_0), and the structure of p_0 can be
varied.
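As a concrete sketch of this composition in the diagonal case (tensor names here are illustrative assumptions, not the class's actual attributes):

    # Diagonal P = accumulated log likelihood Hessian plus diagonal prior P_0.
    # H_lik   : (n_params,) diagonal Hessian approximation summed over the data
    # p0_diag : (n_params,) diagonal prior precision p_0
    posterior_precision = H_lik + p0_diag            # diagonal of P
    posterior_variance = 1.0 / posterior_precision   # diagonal of P^{-1}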
Ancestors
    BaseLaplace
Subclasses

Instance variables
prop scatter : torch.Tensor

Computes the scatter, a term of the log marginal likelihood that corresponds
to L2 regularization:
scatter = (\theta_{MAP} - \mu_0)^\top P_0 (\theta_{MAP} - \mu_0).

Returns
-------
scatter : torch.Tensor
prop log_det_prior_precision : torch.Tensor

Compute the log determinant of the prior precision, \log \det P_0.

Returns
-------
log_det : torch.Tensor

prop log_det_posterior_precision : torch.Tensor

Compute the log determinant of the posterior precision, \log \det P, which
depends on the structure the subclass uses for the Hessian approximation.

Returns
-------
log_det : torch.Tensor

prop log_det_ratio : torch.Tensor

Compute the log determinant ratio, a part of the log marginal likelihood:

    \log \frac{\det P}{\det P_0} = \log \det P - \log \det P_0

Returns
-------
log_det_ratio : torch.Tensor

prop posterior_precision : torch.Tensor

Compute or return the posterior precision P.

Returns
-------
posterior_prec : torch.Tensor
Methods

def fit(self, train_loader: DataLoader, override: bool = True, progress_bar: bool = False) -> None

Fit the local Laplace approximation at the parameters of the model.

Parameters
----------
train_loader : torch.data.utils.DataLoader
    each iterate is a training batch, either (X, y) tensors or a dict-like
    object containing keys as expressed by self.dict_key_x and
    self.dict_key_y. train_loader.dataset needs to be set to access N, the
    size of the data set.
override : bool, default=True
    whether to initialize H, loss, and n_data again; setting to False is
    useful for online learning settings to accumulate a sequential posterior
    approximation.
progress_bar : bool, default=False
    whether to show a progress bar; updated at every batch-Hessian
    computation. Useful for very large models and large amounts of data, esp.
    when subset_of_weights='all'.
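A sketch of the online setting that override enables (the two task loaders are assumed to exist):

    la.fit(task1_loader)                  # posterior from the first task
    la.fit(task2_loader, override=False)  # second task's curvature is accumulated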
def square_norm(self, value) -> torch.Tensor

Compute the square norm under the posterior precision with value - self.mean
as \Delta:

    \Delta^\top P \Delta

Returns
-------
square_form
def log_prob(self, value: torch.Tensor, normalized: bool = True) -> torch.Tensor

Compute the log probability under the (current) Laplace approximation.

Parameters
----------
value : torch.Tensor
normalized : bool, default=True
    whether to return the log of a properly normalized Gaussian or just the
    terms that depend on value.

Returns
-------
log_prob : torch.Tensor
def log_marginal_likelihood(self, prior_precision: torch.Tensor | None = None, sigma_noise: torch.Tensor | None = None)

Compute the Laplace approximation to the log marginal likelihood, subject to
the specific Hessian approximations that subclasses implement. Requires that
the Laplace approximation has been fit before. The resulting torch.Tensor is
differentiable in prior_precision and sigma_noise if these have gradients
enabled. By passing prior_precision or sigma_noise, the current value is
overwritten. This is useful for iterating on the log marginal likelihood.

Parameters
----------
prior_precision : torch.Tensor, optional
    prior precision if should be changed from current prior_precision value
sigma_noise : torch.Tensor, optional
    observation noise standard deviation if should be changed

Returns
-------
log_marglik : torch.Tensor
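Because the result is differentiable, hyperparameters can be tuned by gradient descent on the negative log marginal likelihood; a minimal sketch (assumes a fitted regression instance la):

    import torch

    # Optimize prior precision and observation noise in log space so the
    # constrained quantities stay positive.
    log_prior, log_sigma = torch.ones(1, requires_grad=True), torch.ones(1, requires_grad=True)
    hyper_optimizer = torch.optim.Adam([log_prior, log_sigma], lr=1e-1)
    for _ in range(100):
        hyper_optimizer.zero_grad()
        neg_marglik = -la.log_marginal_likelihood(log_prior.exp(), log_sigma.exp())
        neg_marglik.backward()
        hyper_optimizer.step()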
def predictive_samples(self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], pred_type: PredType | str = PredType.GLM, n_samples: int = 100, diagonal_output: bool = False, generator: torch.Generator | None = None)

Sample from the posterior predictive on input data x. Can be used, for
example, for Thompson sampling.

Parameters
----------
x : torch.Tensor or MutableMapping
    input data (batch_size, input_shape)
pred_type : {'glm', 'nn'}, default='glm'
    type of posterior predictive, linearized GLM predictive or neural network
    sampling predictive. The GLM predictive is consistent with the curvature
    approximations used here.
n_samples : int
    number of samples
diagonal_output : bool
    whether to use a diagonalized glm posterior predictive on the outputs.
    Only applies when pred_type='glm'.
generator : torch.Generator, optional
    random number generator to control the samples (if sampling is used)

Returns
-------
samples : torch.Tensor
    samples (n_samples, batch_size, output_shape)
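A Thompson-sampling-flavored usage sketch (the candidate pool X_cand is an assumption):

    # One joint function draw over the candidates; pick the argmax arm.
    fs = la.predictive_samples(X_cand, pred_type='glm', n_samples=1)  # (1, batch, out)
    best_idx = fs[0].squeeze(-1).argmax()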
def functional_variance(self, Js: torch.Tensor) -> torch.Tensor

Compute the functional variance for the 'glm' predictive:
f_var[i] = Js[i] @ P.inv() @ Js[i].T, which is an output x output predictive
covariance matrix. Mathematically, we have for a single Jacobian
\mathcal{J} = \nabla_\theta f(x;\theta) \vert_{\theta_{MAP}}
the output covariance matrix \mathcal{J} P^{-1} \mathcal{J}^\top.

Parameters
----------
Js : torch.Tensor
    Jacobians of model output wrt parameters (batch, outputs, parameters)

Returns
-------
f_var : torch.Tensor
    output covariance (batch, outputs, outputs)
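A shape-level sketch of this computation (a dense inverse precision P_inv is assumed here; the actual subclasses exploit their Hessian structure instead of materializing it):

    # f_var[b] = Js[b] @ P_inv @ Js[b].T for every batch element.
    # Js: (batch, outputs, parameters), P_inv: (parameters, parameters)
    f_var = torch.einsum('bop,pq,brq->bor', Js, P_inv, Js)  # (batch, outputs, outputs)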
def functional_covariance(self, Js: torch.Tensor) -> torch.Tensor

Compute the functional covariance for the 'glm' predictive:
f_cov = Js @ P.inv() @ Js.T, which is a (batch*outputs) x (batch*outputs)
predictive covariance matrix.

This emulates the GP posterior covariance
\mathcal{N}([f(x_1), \ldots, f(x_m)], \mathrm{Cov}[f(x_1), \ldots, f(x_m)]).
Useful for joint predictions, such as in batched Bayesian optimization.

Parameters
----------
Js : torch.Tensor
    Jacobians of model output wrt parameters (batch*outputs, parameters)

Returns
-------
f_cov : torch.Tensor
    output covariance (batch*outputs, batch*outputs)
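The joint analogue of the per-point sketch above (again with an assumed dense P_inv):

    # Flatten batch and output dimensions, then form the joint covariance.
    B, O, P_dim = Js.shape
    Js_flat = Js.reshape(B * O, P_dim)   # (batch*outputs, parameters)
    f_cov = Js_flat @ P_inv @ Js_flat.T  # (batch*outputs, batch*outputs)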
def sample(self, n_samples: int = 100, generator: torch.Generator | None = None)

Sample from the Laplace posterior approximation, i.e.,
\theta \sim \mathcal{N}(\theta_{MAP}, P^{-1}).

Parameters
----------
n_samples : int, default=100
    number of samples
generator : torch.Generator, optional
    random number generator to control the samples

Returns
-------
samples : torch.Tensor
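For reproducible draws (usage sketch):

    g = torch.Generator().manual_seed(0)
    thetas = la.sample(n_samples=10, generator=g)  # (10, n_params) weight samples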
def state_dict(self) -> dict[str, Any]

def load_state_dict(self, state_dict: dict[str, Any]) -> None

Inherited members
class FunctionalLaplace(model: nn.Module, likelihood: Likelihood | str, n_subset: int, sigma_noise: float | torch.Tensor = 1.0, prior_precision: float | torch.Tensor = 1.0, prior_mean: float | torch.Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, dict_key_x='input_ids', dict_key_y='labels', backend: type[CurvatureInterface] | None = laplace.curvature.backpack.BackPackGGN, backend_kwargs: dict[str, Any] | None = None, independent_outputs: bool = False, seed: int = 0)

Applying the GGN (generalized Gauss-Newton) approximation for the Hessian in
the Laplace approximation of the posterior turns the underlying probabilistic
model from a BNN into a GLM (generalized linear model). This GLM (in weight
space) is equivalent to a GP (in function space); see "Approximate Inference
Turns Deep Networks into Gaussian Processes" (Khan et al., 2019).

This class implements the (approximate) GP inference through which we obtain
the desired quantities (posterior predictive, marginal log-likelihood). See
"Improving predictions of Bayesian neural nets via local linearization"
(Immer et al., 2021) for more details.

Note that for likelihood='classification', we approximate L_{NN} with a
diagonal matrix (L_{NN} is a block-diagonal matrix, where blocks represent
Hessians of the per-data-point log likelihood w.r.t. the neural network
output f; see Appendix A.2.1 for the exact definition). We resort to such an
approximation because of the (possible) errors found in the Laplace
approximation for multiclass GP classification in Chapter 3.5 of the R&W 2006
GP book. Alternatively, one could also resort to one-vs-one or one-vs-rest
implementations for multiclass classification; however, that is not (yet)
supported here.

Parameters
----------
n_subset : int
    number of data points for Subset-of-Data (SOD) approximate GP inference.
independent_outputs : bool
    the GP kernel here is a product of Jacobians, which results in a C x C
    matrix where C is the output dimension. If independent_outputs=True, only
    the diagonal of the GP kernel is used. This is (somewhat) equivalent to
    assuming independent GPs across output channels.

See BaseLaplace class for the full interface.
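A minimal usage sketch (assumes FunctionalLaplace is importable from the top-level laplace package and reuses the toy regression setup from the BaseLaplace example; the la(x) predictive call is inherited from BaseLaplace):

    from laplace import FunctionalLaplace

    # GP-equivalent Laplace with a 64-point subset for SOD inference.
    la = FunctionalLaplace(model, 'regression', n_subset=64)
    la.fit(train_loader)
    f_mu, f_var = la(torch.randn(4, 2))  # GP posterior predictive mean/variance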
Ancestors
    BaseLaplace

Subclasses

Instance variables
prop gp_kernel_prior_variance

prop log_det_ratio : torch.Tensor

Computes the log determinant term in the GP marginal likelihood.

For classification we use eq. (3.44) from Chapter 3.5 of the R&W 2006 GP book
(note that we always use the diagonal approximation D of the Hessian of the
log likelihood w.r.t. f):

    log determinant term := \log | I + D^{1/2} K D^{1/2} |

For regression, we use the "standard" GP marginal likelihood:

    log determinant term := \log | K + \sigma^2 I |

prop scatter : torch.Tensor

Compute the scatter term in the GP log marginal likelihood.

For classification we use eq. (3.44) from Chapter 3.5 of the R&W 2006 GP book
with \hat{f} = f:

    scatter term := f K^{-1} f^\top

For regression, we use the "standard" GP marginal likelihood:

    scatter term := (y - m) K^{-1} (y - m)^\top,

where m is the mean of the GP prior, which in our case corresponds to
m := f + J (\theta - \theta_{MAP}).

prop prior_precision
Methods

def fit(self, train_loader: DataLoader | MutableMapping, progress_bar: bool = False)

Fit the Laplace approximation of a GP posterior.

Parameters
----------
train_loader : torch.data.utils.DataLoader
    train_loader.dataset needs to be set to access N, the size of the data
    set, and train_loader.batch_size needs to be set to access the batch size.
progress_bar : bool
    whether to show a progress bar during the fitting process.
def predictive_samples(self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], pred_type: PredType | str = PredType.GLM, n_samples: int = 100, diagonal_output: bool = False, generator: torch.Generator | None = None)

Sample from the posterior predictive on input data x. Can be used, for
example, for Thompson sampling.

Parameters
----------
x : torch.Tensor or MutableMapping
    input data (batch_size, input_shape)
pred_type : {'glm'}, default='glm'
    type of posterior predictive, linearized GLM predictive.
n_samples : int
    number of samples
diagonal_output : bool
    whether to use a diagonalized glm posterior predictive on the outputs.
    Only applies when pred_type='glm'.
generator : torch.Generator, optional
    random number generator to control the samples (if sampling is used)

Returns
-------
samples : torch.Tensor
    samples (n_samples, batch_size, output_shape)
def functional_variance(self, Js_star: torch.Tensor) -> torch.Tensor

GP posterior variance:

    k_{**} - K_{*M} (K_{MM} + L_{MM}^{-1})^{-1} K_{M*}

Parameters
----------
Js_star : torch.Tensor of shape (N*, C, P)
    Jacobians of test data points

Returns
-------
f_var : torch.Tensor of shape (N*, C, C)
    Contains the posterior variances of N* testing points.
def functional_covariance(self, Js_star: torch.Tensor) -> torch.Tensor

GP posterior covariance:

    k_{**} - K_{*M} (K_{MM} + L_{MM}^{-1})^{-1} K_{M*}

Parameters
----------
Js_star : torch.Tensor of shape (N*, C, P)
    Jacobians of test data points

Returns
-------
f_cov : torch.Tensor of shape (N* x C, N* x C)
    Contains the posterior covariances of N* testing points.
def optimize_prior_precision(self, pred_type: PredType | str = PredType.GP, method: TuningMethod | str = TuningMethod.MARGLIK, n_steps: int = 100, lr: float = 0.1, init_prior_prec: float | torch.Tensor = 1.0, prior_structure: PriorStructure | str = PriorStructure.SCALAR, val_loader: DataLoader | None = None, loss: torchmetrics.Metric | Callable[[torch.Tensor], torch.Tensor | float] | None = None, log_prior_prec_min: float = -4, log_prior_prec_max: float = 4, grid_size: int = 100, link_approx: LinkApprox | str = LinkApprox.PROBIT, n_samples: int = 100, verbose: bool = False, progress_bar: bool = False)

optimize_prior_precision_base from BaseLaplace with pred_type='gp'.
def log_marginal_likelihood(self, prior_precision: torch.Tensor | None = None, sigma_noise: torch.Tensor | None = None)

Compute the Laplace approximation to the log marginal likelihood. Requires
that the Laplace approximation has been fit before. The resulting
torch.Tensor is differentiable in prior_precision and sigma_noise if these
have gradients enabled. By passing prior_precision or sigma_noise, the
current value is overwritten. This is useful for iterating on the log
marginal likelihood.

Parameters
----------
prior_precision : torch.Tensor, optional
    prior precision if should be changed from current prior_precision value
sigma_noise : torch.Tensor, optional
    observation noise standard deviation if should be changed

Returns
-------
log_marglik : torch.Tensor
def state_dict(self) -> dict

def load_state_dict(self, state_dict: dict)

Inherited members
    BaseLaplace
class FullLaplace(model: nn.Module, likelihood: Likelihood | str, sigma_noise: float | torch.Tensor = 1.0, prior_precision: float | torch.Tensor = 1.0, prior_mean: float | torch.Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', backend: type[CurvatureInterface] | None = None, backend_kwargs: dict[str, Any] | None = None)

Laplace approximation with a full, i.e., dense, log likelihood Hessian
approximation and hence posterior precision. Based on the chosen backend
parameter, the full approximation can be, for example, a generalized
Gauss-Newton matrix. Mathematically, we have P \in \mathbb{R}^{P \times P}.
See BaseLaplace for the full interface.

Ancestors
    ParametricLaplace, BaseLaplace

Subclasses

Instance variables
prop posterior_scale : torch.Tensor

Posterior scale (square root of the covariance), i.e., P^{-1/2}.

Returns
-------
scale : torch.tensor
    (parameters, parameters)

prop posterior_covariance : torch.Tensor

Posterior covariance, i.e., P^{-1}.

Returns
-------
covariance : torch.tensor
    (parameters, parameters)

prop posterior_precision : torch.Tensor

Posterior precision P.

Returns
-------
precision : torch.tensor
    (parameters, parameters)

Inherited members
class KronLaplace(model: nn.Module, likelihood: Likelihood | str, sigma_noise: float | torch.Tensor = 1.0, prior_precision: float | torch.Tensor = 1.0, prior_mean: float | torch.Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', backend: type[CurvatureInterface] | None = None, damping: bool = False, backend_kwargs: dict[str, Any] | None = None, asdl_fisher_kwargs: dict[str, Any] | None = None)

Laplace approximation with a Kronecker factored log likelihood Hessian
approximation and hence posterior precision. Mathematically, we have for each
parameter group, e.g., torch.nn.Module, that P \approx Q \otimes H.
See BaseLaplace for the full interface, and see Kron and KronDecomposed for
the structure of the Kronecker factors. Kron is used to aggregate factors by
summing up, and KronDecomposed is used to add the prior, a Hessian factor
(e.g. temperature), and to compute posterior covariances, the marginal
likelihood, etc. Damping can be enabled by setting damping=True.

Ancestors
    ParametricLaplace, BaseLaplace

Subclasses

Instance variables
prop posterior_precision : KronDecomposed

prop prior_precision : torch.Tensor

Methods

def state_dict(self) -> dict[str, Any]

def load_state_dict(self, state_dict: dict[str, Any])

Inherited members
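A quick construction sketch (reuses the toy model and loader from the BaseLaplace example; KronLaplace is importable from the top-level laplace package):

    from laplace import KronLaplace

    la = KronLaplace(model, 'regression', damping=True)  # damped Kronecker factors
    la.fit(train_loader)
    print(la.log_marginal_likelihood())  # uses the KronDecomposed posterior precision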
class DiagLaplace(model: nn.Module, likelihood: Likelihood | str, sigma_noise: float | torch.Tensor = 1.0, prior_precision: float | torch.Tensor = 1.0, prior_mean: float | torch.Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', backend: type[CurvatureInterface] | None = None, backend_kwargs: dict[str, Any] | None = None, asdl_fisher_kwargs: dict[str, Any] | None = None)

Laplace approximation with a diagonal log likelihood Hessian approximation
and hence posterior precision. Mathematically, we have
P \approx \textrm{diag}(P). See BaseLaplace for the full interface.

Ancestors
    ParametricLaplace, BaseLaplace

Subclasses

Instance variables
prop posterior_precision : torch.Tensor

Diagonal posterior precision p.

Returns
-------
precision : torch.tensor
    (parameters)

prop posterior_scale : torch.Tensor

Diagonal posterior scale \sqrt{p^{-1}}.

Returns
-------
scale : torch.tensor
    (parameters)

prop posterior_variance : torch.Tensor

Diagonal posterior variance p^{-1}.

Returns
-------
variance : torch.tensor
    (parameters)

Inherited members
class LowRankLaplace(model: nn.Module, likelihood: Likelihood | str, backend: type[CurvatureInterface] = laplace.curvature.curvature.CurvatureInterface, sigma_noise: float | torch.Tensor = 1, prior_precision: float | torch.Tensor = 1, prior_mean: float | torch.Tensor = 0, temperature: float = 1, enable_backprop: bool = False, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', backend_kwargs: dict[str, Any] | None = None)

Laplace approximation with a low-rank log likelihood Hessian (approximation).
The low-rank matrix is represented by an eigendecomposition (vecs, values).
Based on the chosen backend, either a true Hessian or, for example, a GGN
approximation could be used. The posterior precision is computed as

    P = V diag(l) V^T + P_0.

To sample, compute the functional variance, and compute the log determinant,
algebraic tricks are used to reduce the cost of inversion to that of a
K \times K matrix if we have a rank of K.

Note that only the AsdfghjklHessian backend is supported. Install it via:
pip install git+https://git@github.com/wiseodd/asdl@asdfghjkl

See BaseLaplace for the full interface.

Ancestors
    ParametricLaplace, BaseLaplace
Instance variables

prop V : torch.Tensor

prop Kinv : torch.Tensor

prop posterior_precision : tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor]

Return the correctly scaled posterior precision that would be constructed as
H[0] @ diag(H[1]) @ H[0].T + self.prior_precision_diag.

Returns
-------
H : tuple(eigenvectors, eigenvalues)
    scaled self.H with temperature and loss factors.
prior_precision_diag : torch.Tensor
    diagonal prior precision of shape (parameters,) to be added to H.

Inherited members