diff --git a/README.md b/README.md index 90ea7d8b..80a82d1d 100644 --- a/README.md +++ b/README.md @@ -69,7 +69,9 @@ pred = la(x, link_approx='probit') torch.save(la.state_dict(), 'state_dict.bin') # Load serialized Laplace -la2 = Laplace(model, 'regression', subset_of_weights='all', hessian_structure='full') +la2 = Laplace(model, 'classification', + subset_of_weights='all', + hessian_structure='diag') la2.load_state_dict(torch.load('state_dict.bin')) ```