Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* added robust folder * uncommited scratch work for log prob * untested variational log prob * uncomitted changes * uncomitted changes * pair coding w/ eli * added tests w/ Eli * eif * linting * moving test autograd to internals and deleted old utils file * sketch influence implementation * fix more args * ops file * file * format * lint * clean up influence and tests * make tests more generic * guess max plate nesting * linearize * rename file * tensor flatten * predictive eif * jvp type * reorganize files * shrink test case * move guess_max_plate_nesting * move cg solver to linearze * type alias * test_ops * basic cg tests * remove failing test case * format * move paramdict up * remove obsolete test files * add empty handlers * add chirho.robust to docs * fix memory leak in tests * make typing compatible with python 3.8 * typing_extensions * add branch to ci * predictive * remove imprecise annotation * Added more tests for `linearize` and `make_empirical_fisher_vp` (#405) * initial test against analytic fisher vp (pair coded w/ sam) * linting * added check against analytic ate * added vmap and grad smoke tests * added missing init * linting and consolidated fisher tests to one file * fixed types * fixing linting errors * trying to fix type error for python 3.8 * fixing test errors * added patch to test to prevent from failing when denom is small * composition issue * removed missing import * fixed failing test with seeding * addressing Eli's comments * Add upper bound on number of CG steps (#404) * upper bound on cg_iters * address comment * fixed test for non-symmetric matrix (#437) * Make `NMCLogPredictiveLikelihood` seeded (#408) * initial test against analytic fisher vp (pair coded w/ sam) * linting * added check against analytic ate * added vmap and grad smoke tests * added missing init * linting and consolidated fisher tests to one file * fixed types * fixing linting errors * trying to fix type error for python 3.8 * fixing test errors * added patch to test to prevent from failing when denom is small * composition issue * seeded NMC implementation * linting * removed missing import * changed to eli's seedmessenger suggestion * added failing edge case * explicitly add max plate argument * added warning message * fixed linting error and test failure case from too many cg iters * eli's contextlib seeding strategy * removed seedmessenger from test * randomness should be shared across calls * switched back to different * Use Hessian formulation of Fisher information in `make_empirical_fisher_vp` (#430) * hessian vector product formulation for fisher * ignoring small type error * fixed linting error * Add new `SimpleModel` and `SimpleGuide` (#440) * initial test against analytic fisher vp (pair coded w/ sam) * linting * added check against analytic ate * added vmap and grad smoke tests * added missing init * linting and consolidated fisher tests to one file * fixed types * fixing linting errors * trying to fix type error for python 3.8 * fixing test errors * added patch to test to prevent from failing when denom is small * composition issue * seeded NMC implementation * linting * removed missing import * changed to eli's seedmessenger suggestion * added failing edge case * explicitly add max plate argument * added warning message * fixed linting error and test failure case from too many cg iters * eli's contextlib seeding strategy * removed seedmessenger from test * randomness should be shared across calls * uncomitted change before branch switch * switched back to different * added revised simple model and guide * added multiple link functions in test * linting * Batching in `linearize` and `influence` (#465) * batching in linearize and influence * addressing eli's review * added optimization for pointwise false case * fixing lint error * batched cg (#466) * One step correction implemented (#467) * one step correction * increased tolerance * fixing lint issue * Replace some `torch.vmap` usage with a hand-vectorized `BatchedNMCLogPredictiveLikelihood` (#473) * sketch batched nmc lpd * nits * fix type * format * comment * comment * comment * typo * typo * add condition to help guarantee idempotence * simplify edge case * simplify plate_name * simplify batchedobservation logic * factorize * simplify batched * reorder * comment * remove plate_names * types * formatting and type * move unbind to utils * remove max_plate_nesting arg from get_traces * comment * nit * move get_importance_traces to utils * fix types * generic obs type * lint * format * handle observe in batchedobservations * event dim * move batching handlers to utils * replace 2/3 vmaps, tests pass * remove dead code * format * name args * lint * shuffle code * try an extra optimization in batchedlatents * add another optimization * undo changes to test * remove inplace adds * add performance test showing speedup * document internal helpers * batch latents test * move batch handlers to predictive * add bind_leftmost_dim, document PredictiveFunctional and PredictiveModel * use bind_leftmost_dim in log prob * Added documentation for `chirho.robust` (#470) * documentation * documentation clean up w/ eli * fix lint issue * Make functional argument to influence_fn required (#487) * Make functional argument required * estimator * docstring * Remove guide argument from `influence_fn` and `linearize` (#489) * Make functional argument required * estimator * docstring * Remove guide, make tests pass * rename internals.predictive to internals.nmc * expose handlers.predictive * expose handlers.predictive * docstrings * fix doc build * fix equation * docstring import --------- Co-authored-by: Sam Witty <samawitty@gmail.com> * Make influence_fn a higher-order Functional (#492) * make influence a functional * fix test * multiple arguments * doc * docstring * docstring * Add full corrected one step estimator (#476) * added scaffolding to one step estimator * kept signature the same as one_step_correction * lint * refactored test to include multiple estimators * typo * revise error * added dict handling * remove assert * more informative error message * replace dispatch with pytree flatten and unflatten * revert arg for influence_function_estimator * docs and lint * lingering influence_fn * fixed missing return * rename * lint * add *model to appease the linter * add abstractions and simple temp scratch to test with squared unit normal functional with perturbation. * removes old scratch notebook * gets squared density running under abstraction that couples functionals and models * gets quad and mc approximations to match, vectorization hacky. * adds plotting and comparative to analytic. * adds scratch experiment comparing squared density analytic vs fd approx across various epsilon lambdas * fixes dataset splitting, breaks analytic eif * unfixes an incorrect fix, working now. * refactors finite difference machinery to fit experimental specs. * switches to existing rng seed context manager. * reverts back to what turns out to be a slightly different seeding context. --------- Co-authored-by: Raj Agrawal <r.agrawal.mit@gmail.com> Co-authored-by: Eli <eli@elibingham.com> Co-authored-by: Sam Witty <samawitty@gmail.com> Co-authored-by: Raj Agrawal <r.agrawal@csail.mit.edu> Co-authored-by: eb8680 <eb8680@users.noreply.github.com>
- Loading branch information