Skip to content

Commit

Permalink
One-sided runtime check of whether a dynamics instance satisfies as…
Browse files Browse the repository at this point in the history
…sumptions of a `solver` (#332)

* Operations and Handlers for ODE Dynamical Systems (#155)

* initial notebook

* initial scaffolding

* initial ODE implementation

* more progress

* untested first pass at point intervention

* first successful interruption pass

* point intervention working

* first pass of observation handler

* modified CI

* adds tests for noop interventions and interruptions, failing, fixes issue where remaining tspan was expected for unsplit tspans.

* converts sir model to reusable fixture, adds todo test for affective point intervention.

* adds test of affective single and nested point interventinos.

* reduces permutations of a point intervention test.

* added stop which seems to help...

* recomplicates nested point intervention test parameters to cover additional case.

* pseudocode refac to non-recursive approach

* adds almost correct non-recursive point interruption thing that hopefully makes dynamic interventions easier to implement, raj todo some things.

* untested code implementing slicing and concat

* fixed some typos still not tested

* finishes refactor from recursive approach and passes some tests

* passes current suite of point interruption tests.

* adds errors and tests to confirm they throw, one is broken still.

* fixed tests

* fixed observation, no tests yet

* python3.8 type test error fix? Can't reproduce locally so sort of shooting in the dark

* added basic observation structure

* observation desired interface

* addded observation functionality with sample

* added intervention to example

* some progress on dynamic events, but will refactor to abstract out a simulate_to_next_event effectful operation that keeps all the torchdiffeq idiosyncrasies in one place.

* sketches out signatures and a plan (in comments) to refactor interruptions to better support dynamic events and the abstracting away of solver idiosyncrasies

* tests for point observations

* added tiny epsilon to observation interruption

* added error checking for time being equal to any element in full timespan just in case

* lint

* restructures intervention handling to better abstract solver idiosyncrasies away, preps to implement abstracted torchdiffeq solving.

* passes existing tests of point interruptions, preps for torchdiffeq implementation supporting dynamic interruptions.

* adds test case to single static intervention ensuring intervention only changes values after its time.

* refactors utilities for constructing torchdiffeq event functions.

* breaks out todos for double pass solver execution for dynamic event handling.

* finishes double solver pass using event function, passes current test suite.

* gets dynamic interventions working (at least for a single runtime test case + manual output check)

* initial notebook

* added synthetic data simulation section

* makes first dynamic intervention test actually test something

* runs linter

* bit hacky way to perform inference but recovers true sir params

* puts intervention blocking functionality in the SEL loop, rather than leaving it up to interventions themselves.

* added some inference evaluation but need more metrics

* switching to manual generation of predictions on observation sites

* cleaned up notebook and finished initial evaluation for just demonstrating conditioning functionality

* moved tutorial to docs folder

* issue with intervention tests not passing

* refactors PointObservation to work with SEL refactor.

* includes interventions

* poisson likelihood breaking for some reason

* fixed poisson likelihood issue

* test with exit stack works too...

* changed observation method to be the same as the test that somehow passes

* failing test that reproduces notebook error

* reverting noteback back to old version now that we understand why test is breaking

* added dynamic intervention to example

* adds test of nested dynamic interventions, and noop dynamic interruptions.

* added some todos and fixed some typos

* modifies svi integration test to avoid issue arising when observation occurs at start of tspan, adds test asserting that an error is explicitly thrown for that case.

* runs linter

* gets linter mostly happy.

* fixes idiosyncratic mypy issues, including meh breakout of dispatchable getitem for trajectory.

* quick mypy fix

* flake8 changes

* added import statement to make notebook work

* reran lint

* removed scratch notebook

---------

Co-authored-by: Sam Witty <samawitty@gmail.com>
Co-authored-by: Andy Zane <azane@SmartSource-M-Machine.local>
Co-authored-by: Raj Agrawal <r.agrawal.mit@gmail.com>
Co-authored-by: Andy Zane <andy@basis.ai>

* runtime profiling for multiple point observations

* added cprofile breakdown

* added dynamic intervention to stack

* feature map version. going to switch to a neural net

* basic syntac for multi level SIR. Need to change condition_sir to own handler

* about to refactor code to use TrajectoryObservation

* refactored to use TrajectoryObservation

* inference steps run but not recovering params

* finally found bug that made inference fail. Going to clean up code now

* changes so far, still a bug I cant find since unable to recover model params

* uncomitted changes

* Update staging-dynamic with renaming from master (#213)

* Add top-level subpackage imports in indexed and interventional modules (#169)

* move cf handlers out of init

* add imports

* Replaces [@ref] Linked References w/ Manual Ref and Bib (#171)

* replaces linked refs with manual refs for now due to only highly complex and error prone alternatives available for automation

* adds sphinx bibtex to setup for tests

* Add broadcasting test cases (#168)

* Update code in deep SCM notebook (#134)

* start

* update code

* refs merge

* fix json

* grammatical fixes, re-applying these changes as the original branch got conflated with something in stash

* add plot description

* adds link to referenced pyro tutorial

---------

Co-authored-by: Andy Zane <andy@basis.ai>

* Update text and code of CEVAE notebook (#131)

* notebook and fixes

* notebook works

* temp commit that gives mae=.02 when removing no grad line when generating data

* seems to work, verbose loss output

* removed loss outputs and fixed some small typos

---------

Co-authored-by: Raj Agrawal <r.agrawal.mit@gmail.com>

* Added CONTRIBUTING.md file (#172)

* initial pass that makes small modifications to CONTRIBUTING.md file in pyro

* calls lint script instead

* fixed to call clean.sh script

* replaced with clean script

---------

Co-authored-by: Raj Agrawal <r.agrawal.mit@gmail.com>

* Create publish_docs.yml (#180)

* removed old files (#184)

Co-authored-by: Raj Agrawal <r.agrawal.mit@gmail.com>

* Added docstring for intervene following issue #113 (#182)

* added docstring for intervene following issue #113

* removed extra spacing

* adds docstring for DoMessenger

* removes generic type hint on _intervene_atom so that intervene signatures properly show up in docs.

---------

Co-authored-by: Raj Agrawal <r.agrawal.mit@gmail.com>
Co-authored-by: Andy Zane <andy@basis.ai>

* Add a draft logo to replace placeholder (#185)

* add a draft logo to replace placeholder

* add svg

* Update text and code of mediation analysis notebook (#132)

* notebook and fixes

* notebook works

* update notebook

* working

* revert cevae to master in this branch

* pull in changes

* fix typo

* adds (live, see comment) link to backdoor example, replacing inline TODO

* properly uses relative link for backdoor notebook.

---------

Co-authored-by: Andy Zane <andy@basis.ai>

* Causal Pyro README (#170)

* copied and pasted index.rst and auto converted to markdown

* reformatted, added table to tutorials and examples, still need to add cites

* fixed small spacing typos

* added status badge

* points readme documentation links to live docs.

* making the readme an rst that matches landing page

* converted tutorial links

* converted back to table format for links

* table still messed up

* gave up on tables, now a bulleted list for examples

* added documentation section

* small changes

* added back missing bullets

* fixed to queried_model

* review with Sam for cleaning up readme over a call

* fixed typo

* small edit

---------

Co-authored-by: Raj Agrawal <r.agrawal.mit@gmail.com>
Co-authored-by: Andy Zane <andy@basis.ai>

* links github rendered readme as getting started page of sphinx docs. (#191)

* Fix HTML Outline rendering (#188)

* fixed outline rendering

* fixes outline links in mediation notebook.

* fixes outline and links for backdoor notebook.

* fixes outline links in cevae notebook.

* fixes slc notebook outline links.

* adds outline back into deep scm notebook.

* address remaining reference issues, building now with now warnings

---------

Co-authored-by: Andy Zane <andy@basis.ai>

* Design Notes Concat (#174)

* first pass at stitching design notes together into cohesive contributor tutorial.

* cleans up local bibs, adds omega citation.

* updates mention of omega and predicate conditioning

* re-integrates the original point being made at the top of observations.rst

* renames things to be less redundant

* Change name of package throughout codebase (#192)

* code and tests

* scripts

* remove

* code misses

* docs

* Causal Pyro

* causal pyro

* lint

* Create GitHub Actions workflow for PyPI (#181)

* Update setup.py arguments (#193)

* Update setup.py args

* classifier

* Update publish_docs.yml

* trigger docs deployment

* configure docs

* Add requirements.txt for docs workflow (#194)

* pandoc

* Add observe operation and new condition handler (#175)

* Separate observe and condition

* Split up files and create observational handlers folder

* imports

* lint

* rename test

* add test about commutativity of do and condition

* doc

* union

* fix particle test case

* fix bug

* chirho

* Add path to conf.py for docs workflow (#195)

* Refactor counterfactuals to use observe and condition (#176)

* Separate observe and condition

* Split up files and create observational handlers folder

* imports

* lint

* rename test

* add test about commutativity of do and condition

* doc

* union

* Refactor counterfactuals to use observe

* appease mypy

* Vindex fixes particle errors

* update backdoor

* update slc

* fix particle test case

* add cf commutativity test

* fix bug

* revert slc handler order

* add predictive smoke test

* nit

* elbo

* reorder test

* Add a stronger infer_discrete test

* move notebooks to separate branch

* test

* chirho

* merge fail

* Update and re-run example notebooks with new condition (#178)

* Update and re-run backdoor and SLC notebooks

* deepscm

* cevae

* import

* mediation

* merge

* update notebooks

* merge

* merge 2

* toc

* populate autodoc

* tweak

* Restores (via cherry-pick) Notebook Link and Formatting Changes (#205)

* fixed outline rendering

* fixes outline links in mediation notebook.

* fixes outline and links for backdoor notebook.

* fixes outline links in cevae notebook.

* fixes slc notebook outline links.

* adds outline back into deep scm notebook.

* address remaining reference issues, building now with now warnings

---------

Co-authored-by: Sam Witty <samawitty@gmail.com>

---------

Co-authored-by: Andy Zane <andy@basis.ai>
Co-authored-by: Sam Witty <samawitty@gmail.com>

* Fixed symbolic link in tutorial (#206)

* fixed symbolic link in tutorial

* Revert changes to figure paths

* Reverted accidental newline from previous commit

* Fix image duplication in intro tutorial (#209)

* Fix image duplication

* fix the one other warning in the docs build

* Bump version to 0.1.0-alpha (#208)

* Bump prerelease version to 0.1.0-alpha

* Comply with pep440

* update paths

* imports

* notebook

* condition

* lint

---------

Co-authored-by: Andy Zane <azane@nnu.edu>
Co-authored-by: Andy Zane <andy@basis.ai>
Co-authored-by: Raj Agrawal <r.agrawal.mit@gmail.com>
Co-authored-by: Raj Agrawal <r.agrawal@csail.mit.edu>
Co-authored-by: Sam Witty <samawitty@gmail.com>

* Support scalar multi-world counterfactuals in dynamical systems (#214)

* Add top-level subpackage imports in indexed and interventional modules (#169)

* move cf handlers out of init

* add imports

* Replaces [@ref] Linked References w/ Manual Ref and Bib (#171)

* replaces linked refs with manual refs for now due to only highly complex and error prone alternatives available for automation

* adds sphinx bibtex to setup for tests

* Add broadcasting test cases (#168)

* Update code in deep SCM notebook (#134)

* start

* update code

* refs merge

* fix json

* grammatical fixes, re-applying these changes as the original branch got conflated with something in stash

* add plot description

* adds link to referenced pyro tutorial

---------

Co-authored-by: Andy Zane <andy@basis.ai>

* Update text and code of CEVAE notebook (#131)

* notebook and fixes

* notebook works

* temp commit that gives mae=.02 when removing no grad line when generating data

* seems to work, verbose loss output

* removed loss outputs and fixed some small typos

---------

Co-authored-by: Raj Agrawal <r.agrawal.mit@gmail.com>

* Added CONTRIBUTING.md file (#172)

* initial pass that makes small modifications to CONTRIBUTING.md file in pyro

* calls lint script instead

* fixed to call clean.sh script

* replaced with clean script

---------

Co-authored-by: Raj Agrawal <r.agrawal.mit@gmail.com>

* Create publish_docs.yml (#180)

* removed old files (#184)

Co-authored-by: Raj Agrawal <r.agrawal.mit@gmail.com>

* Added docstring for intervene following issue #113 (#182)

* added docstring for intervene following issue #113

* removed extra spacing

* adds docstring for DoMessenger

* removes generic type hint on _intervene_atom so that intervene signatures properly show up in docs.

---------

Co-authored-by: Raj Agrawal <r.agrawal.mit@gmail.com>
Co-authored-by: Andy Zane <andy@basis.ai>

* Add a draft logo to replace placeholder (#185)

* add a draft logo to replace placeholder

* add svg

* Update text and code of mediation analysis notebook (#132)

* notebook and fixes

* notebook works

* update notebook

* working

* revert cevae to master in this branch

* pull in changes

* fix typo

* adds (live, see comment) link to backdoor example, replacing inline TODO

* properly uses relative link for backdoor notebook.

---------

Co-authored-by: Andy Zane <andy@basis.ai>

* Causal Pyro README (#170)

* copied and pasted index.rst and auto converted to markdown

* reformatted, added table to tutorials and examples, still need to add cites

* fixed small spacing typos

* added status badge

* points readme documentation links to live docs.

* making the readme an rst that matches landing page

* converted tutorial links

* converted back to table format for links

* table still messed up

* gave up on tables, now a bulleted list for examples

* added documentation section

* small changes

* added back missing bullets

* fixed to queried_model

* review with Sam for cleaning up readme over a call

* fixed typo

* small edit

---------

Co-authored-by: Raj Agrawal <r.agrawal.mit@gmail.com>
Co-authored-by: Andy Zane <andy@basis.ai>

* links github rendered readme as getting started page of sphinx docs. (#191)

* Fix HTML Outline rendering (#188)

* fixed outline rendering

* fixes outline links in mediation notebook.

* fixes outline and links for backdoor notebook.

* fixes outline links in cevae notebook.

* fixes slc notebook outline links.

* adds outline back into deep scm notebook.

* address remaining reference issues, building now with now warnings

---------

Co-authored-by: Andy Zane <andy@basis.ai>

* Design Notes Concat (#174)

* first pass at stitching design notes together into cohesive contributor tutorial.

* cleans up local bibs, adds omega citation.

* updates mention of omega and predicate conditioning

* re-integrates the original point being made at the top of observations.rst

* renames things to be less redundant

* Change name of package throughout codebase (#192)

* code and tests

* scripts

* remove

* code misses

* docs

* Causal Pyro

* causal pyro

* lint

* Create GitHub Actions workflow for PyPI (#181)

* Update setup.py arguments (#193)

* Update setup.py args

* classifier

* Update publish_docs.yml

* trigger docs deployment

* configure docs

* Add requirements.txt for docs workflow (#194)

* pandoc

* Add observe operation and new condition handler (#175)

* Separate observe and condition

* Split up files and create observational handlers folder

* imports

* lint

* rename test

* add test about commutativity of do and condition

* doc

* union

* fix particle test case

* fix bug

* chirho

* Add path to conf.py for docs workflow (#195)

* Refactor counterfactuals to use observe and condition (#176)

* Separate observe and condition

* Split up files and create observational handlers folder

* imports

* lint

* rename test

* add test about commutativity of do and condition

* doc

* union

* Refactor counterfactuals to use observe

* appease mypy

* Vindex fixes particle errors

* update backdoor

* update slc

* fix particle test case

* add cf commutativity test

* fix bug

* revert slc handler order

* add predictive smoke test

* nit

* elbo

* reorder test

* Add a stronger infer_discrete test

* move notebooks to separate branch

* test

* chirho

* merge fail

* Update and re-run example notebooks with new condition (#178)

* Update and re-run backdoor and SLC notebooks

* deepscm

* cevae

* import

* mediation

* merge

* update notebooks

* merge

* merge 2

* toc

* populate autodoc

* tweak

* Restores (via cherry-pick) Notebook Link and Formatting Changes (#205)

* fixed outline rendering

* fixes outline links in mediation notebook.

* fixes outline and links for backdoor notebook.

* fixes outline links in cevae notebook.

* fixes slc notebook outline links.

* adds outline back into deep scm notebook.

* address remaining reference issues, building now with now warnings

---------

Co-authored-by: Sam Witty <samawitty@gmail.com>

---------

Co-authored-by: Andy Zane <andy@basis.ai>
Co-authored-by: Sam Witty <samawitty@gmail.com>

* Fixed symbolic link in tutorial (#206)

* fixed symbolic link in tutorial

* Revert changes to figure paths

* Reverted accidental newline from previous commit

* Fix image duplication in intro tutorial (#209)

* Fix image duplication

* fix the one other warning in the docs build

* Bump version to 0.1.0-alpha (#208)

* Bump prerelease version to 0.1.0-alpha

* Comply with pep440

* update paths

* imports

* notebook

* condition

* support dynamical counterfactuals on scalar variables

* replace tensor with stack

* type

* lint

---------

Co-authored-by: Andy Zane <azane@nnu.edu>
Co-authored-by: Andy Zane <andy@basis.ai>
Co-authored-by: Raj Agrawal <r.agrawal.mit@gmail.com>
Co-authored-by: Raj Agrawal <r.agrawal@csail.mit.edu>
Co-authored-by: Sam Witty <samawitty@gmail.com>

* Merging in Dynamic Demo to Simplify Branching for Parallel Work (#215)

* adds and tests a non-interrupting observation handler

* swaps out new handler in performance analysis of observation handler.

* recovering params when beta and gamma fixed across all regions

* adds batched handler for non interrupting observations, passes equivalency test of logprob and shape.

* inference working now, going to refactor and remove out covariates

* uncomitted changes

* adds auto broadcasting for the vectorized point observation.

* adds check to not call to_event on deterministic 'sample' operation

* working demo end-to-end

* finished pass

* finished demo

* added uncertainty over intervention assignments

* finished notebook run w/ new uncertain interventions

* undid some linting to make more readable

* lints

---------

Co-authored-by: Raj Agrawal <r.agrawal.mit@gmail.com>
Co-authored-by: Sam Witty <samawitty@gmail.com>

* Cleanup of Performance Improvements for ODE Conditioning (#218)

* removes non vectorized interrupting point observation

* cleans up tests of NonInterruptingPointObservation

* removes redundant test now that non vectorized non interrupting observation handler is removed.

* changes order of point observation in equivalence test to make sure non issue.

* lints

* Fix PyroModule state persistence in simulate (#220)

* Fixes for PyroModule

* remove t

* add unit test

* Test Composition of Dynamic Counterfactual and Observation (#219)

* removes non vectorized interrupting point observation

* cleans up tests of NonInterruptingPointObservation

* removes redundant test now that non vectorized non interrupting observation handler is removed.

* changes order of point observation in equivalence test to make sure non issue.

* lints

* adds file for handler composition tests.

* brokenish sketch of possible composition of counterfactual and dynamic point observation case.

* gets hacky poc for dynamical counterfactual query w post intervention noise working

* adds trajectory and state getitem support for broadcastable boolean masks, removes .to_event hack in vectorized non interrupting point observation, fails tests due to latter.

* modifies tests to reflect requirement of manual shape management

* adds shape and inference smoke test composing many dynamical handlers.

* lints

* adds comment explaining the handler blocking that induces the conditional model.

* Fix Gradient Propagation Through Dynamic Interventions (#227)

* fixes issue where gradients weren't propagating through dynamic interventions.

* lints

* Added counterfactual example to the demo (#223)

* removes non vectorized interrupting point observation

* cleans up tests of NonInterruptingPointObservation

* removes redundant test now that non vectorized non interrupting observation handler is removed.

* changes order of point observation in equivalence test to make sure non issue.

* lints

* adds file for handler composition tests.

* changed to condition only

* brokenish sketch of possible composition of counterfactual and dynamic point observation case.

* gets hacky poc for dynamical counterfactual query w post intervention noise working

* adding basic structure for counterfactual

* adds trajectory and state getitem support for broadcastable boolean masks, removes .to_event hack in vectorized non interrupting point observation, fails tests due to latter.

* counterfactual example seems to work, need to clean up notebook

* modifies tests to reflect requirement of manual shape management

* working cf demo

* different times for superspreader

* changed cf values

* adds shape and inference smoke test composing many dynamical handlers.

* lints

* adds comment explaining the handler blocking that induces the conditional model.

* adjusts counterfactual story to have partially shared noise between worlds based on a screening machine story, still needs updates to text and comments.

* gets initial conditions set correctly

* adds comment on cf guide

---------

Co-authored-by: Andy Zane <andy@basis.ai>
Co-authored-by: Raj Agrawal <r.agrawal.mit@gmail.com>

* Undo mistaken edit in staging-dynamic merge

* Undo mistaken edit during staging-dynamic merge 2

* undo mistaken edit during staging-dynamic merge - 3

* Remove duplicate internals docs from merge with staging-dynamic

* Fix error from staging-dynamic merge in docs

* undo linting error from merge

* fixed test to reflect PyroModule changes (#242)

* fixed test to reflect PyroModule changes

* lint

* Merge master into staging-dynamic (#284)

* Add an effect handler for inserting factor statements (#238)

* Add a Factors handler for inserting new factors

* docs and test

* lint

* Add a BiasedPreemptions handler (#239)

* Add BiasedPreemptions handler

* test and tiny refactoring

* updated copywrite (#254)

* pin sphinx version to 7.1.2 to fix docs building (#256)

* pin sphinx version to 7.1.2 fixes doc building locally

* remove pypandoc

* Update requirements.txt

* Update setup.py

* pin sphinx_rtd_theme==1.3.0 to fix docs deployment (#262)

* Fix inference in backdoor adjustment example (#258)

* Fix inference in backdoor adjustment example

* text and matching ci

* typo

* device

---------

Co-authored-by: Sam Witty <samawitty@gmail.com>

* Fixed typo (#267)

* changed headers for examples to be consistent (#271)

* Make tutorial run end to end (#274)

* Eli's data type and condition fixes

* convert to floats and convert to ints when indexing

* clean up

---------

Co-authored-by: Raj Agrawal <r.agrawal.mit@gmail.com>

* Single-Cell RNA Seq Example (#265)

* data processing code

* model doesnt compile yet

* fixed dimensions but module naming issue

* before switching to interaction version of model

* presentation w/ alvin

* cleaned up summary stats section and focus on one drug

* before switching to multiworld for ate

* eli shape corrections

* uncomitted changes w/ eli re. shapes

* eli's fix of counterfactual

* clean up

* rewording

* change of notation and reran notebook

* removed old sentence

* fixed typo

* removed some masking and redundancy

* made dropout a function of confounders too

* Ra aq scrna (#263)

* Add an effect handler for inserting factor statements (#238)

* Add a Factors handler for inserting new factors

* docs and test

* lint

* Add a BiasedPreemptions handler (#239)

* Add BiasedPreemptions handler

* test and tiny refactoring

* updated copywrite (#254)

* pin sphinx version to 7.1.2 to fix docs building (#256)

* pin sphinx version to 7.1.2 fixes doc building locally

* remove pypandoc

* Update requirements.txt

* Update setup.py

* pin sphinx_rtd_theme==1.3.0 to fix docs deployment (#262)

* 1. add preprocessing method for expression normalization
2. test more HDAC drugs
3. reproduce raj results

* add pseudobulk comparison

* add T>0 for theta_drug; clean up notebook for one drug with high IC50

* change T>0 to T for theta_drug; did a bunch of experiments for library size, not improved yet

* use corrected gene expression for ate estimation

* replace the likelihood to Poisson; it seems working for HDACi

---------

Co-authored-by: eb8680 <eb8680@users.noreply.github.com>
Co-authored-by: Sam Witty <samawitty@gmail.com>

* cleaned up notebook and fixed some bugs

* fixed some bugs and cleaned up references and some notation

* fixed drug typo and removed treatment mask

* query models, add latent factor per cell

* removed old file, did not finish running on all genes so moving to gcp

* cleaned up some more text

* target genes specified, kernel crashes on last fig

* fixed missing ref

* reducted num samples to fix memory issue

* fixed typo

* updated target genes w/ chip targets and fixed plotting bug

* added alvins descriptions + some rewriting

* fixed some typos

* added pca discussion

* forgot to save when committing

* fixed typo

---------

Co-authored-by: Raj Agrawal <r.agrawal.mit@gmail.com>
Co-authored-by: Alvin Qin <qinqian@users.noreply.github.com>
Co-authored-by: eb8680 <eb8680@users.noreply.github.com>
Co-authored-by: Sam Witty <samawitty@gmail.com>

* Document counterfactual operations and handlers (#275)

* counterfactual ops docs

* typo

* minor edits

* lint

* more linting

* docs for BaseCounterfactualMessenger

* docs for SingleWorldCounterfactual

* a bit more detail about marginal distributions

* lint and minor edits

* added example usage for counterfactual handlers

* docs for multiworld counterfactual

* added pass of TwinWorldCounterfactual documentation

* Add RNAseq example to readme and docs (#277)

* Add RNAseq example to readme and docs

* fix header level

* Fix docstring propagation in counterfactual ops (#278)

* Fix docstring propagation in counterfactual ops

* lint

* Fix TOC link for single RNA seq example notebook (#280)

* fix toc link

* fixed toc and equation rendering issues

---------

Co-authored-by: Raj Agrawal <r.agrawal.mit@gmail.com>

* Bump version to 0.1.0 (#279)

* Adding undo_split and a test thereof (+small lint) (#264)

* added undo_split and a test thereof (+small lint)

* moved undo_split, added a test

* unlinting counterfactual

* unlinting ops

* parametrize test WIP

* parametrizing test WIP

* reverting linting on test_internals.p

* isort fix on test_internals.py

* parametrizing test WIP

* linting WIP

* tests linted

* cleanup

* removed implicit Optional

* isort/lint test_internals.py

* reverting test_internals to orginal

* black linted no -l 120, to prevent github check failure

* make lint, make format

* format

* fix multi-antecedent case and format

* remove obsolete comments

* appease flake8

* add to sphinx

---------

Co-authored-by: Eli <eli@basis.ai>

* Replace Preemptions handler implementation (#250)

* Replace Preemptions implementation with BiasedPreemption

* format

* remove nondeterministic default behavior for preempt

* address comments

* Ru consequent differs (#268)

* added undo_split and a test thereof (+small lint)

* moved undo_split, added a test

* unlinting counterfactual

* unlinting ops

* parametrize test WIP

* parametrizing test WIP

* reverting linting on test_internals.p

* isort fix on test_internals.py

* parametrizing test WIP

* linting WIP

* tests linted

* cleanup

* removed implicit Optional

* isort/lint test_internals.py

* setting up  consequent differs for pr

* reverting test_internals.py to original

* reverting test_internals to orginal

* consequent_differs WIP

* consequent differs WIP

* consequent differs WIP

* consequent_differs in explanation, docstring

* added consequent_differs and a test for it

* black linted no -l 120, to prevent github check failure

* lint

* make lint, make format

* make lint, make format

* format

* fix multi-antecedent case and format

* remove obsolete comments

* appease flake8

* add to sphinx

* fix event_dim>0 in tests

---------

Co-authored-by: Eli <eli@basis.ai>

---------

Co-authored-by: eb8680 <eb8680@users.noreply.github.com>
Co-authored-by: Raj Agrawal <r.agrawal@csail.mit.edu>
Co-authored-by: Raj Agrawal <r.agrawal.mit@gmail.com>
Co-authored-by: Alvin Qin <qinqian@users.noreply.github.com>
Co-authored-by: rfl-urbaniak <rfl.urbaniak@gmail.com>
Co-authored-by: Eli <eli@basis.ai>

* Refactor dynamical module file structure to be consistent with other modules (#286)

* factored out generic Dynamics class from ODEDynamics and moved torchdiffeq specific implementations out of ODEDynamics class

* update dynamics test imports

* moved non user-facing objects and operations to internals.py

* split out operations, handlers, and internals

* moved around ODE operations and handlers

* lint

* lint again

* Decouple ODEDynamics class from TorchDiffEq (#290)

* factored out generic Dynamics class from ODEDynamics and moved torchdiffeq specific implementations out of ODEDynamics class

* update dynamics test imports

* moved non user-facing objects and operations to internals.py

* split out operations, handlers, and internals

* moved around ODE operations and handlers

* lint

* lint again

* separated effectful operations into hidden dispatch

* change kwargs in test to args

* added note

* separated out TorchDiffEq Backend

* tidy up tests

* remove comment

* lint

* lint

* added options to TorchDiffEqBackend

* moved handlers.py into folder

* add files to __init__.py for handlers

* lint

* seperated interruptions from simulator event loop

* added placeholder BackendHandler class

* added BackendHandler and edited tests accordingly

* lint

* lint again

* remove unnecessary type check

* renamed backend to solver throughout

* moved functionality out of __init__.py

* updated and reran notebook

* lingering test without solver kwarg. Probably isn't used anywhere?

* missing import

* Consolidate `Backend` and `SolverHandler` into a single `Solver` effect handler / dispatch mechanism (#292)

* replaced ODEBackend with ODESolver

* lingering renaming of Backend to Solver and removal of unnecessary Solver handler

* typo in notebook

* remove dupliace Solver

* dummy commit to trigger linting?

* fixed lint error

---------

Co-authored-by: Raj Agrawal <r.agrawal.mit@gmail.com>

* Decouple `SimulatorEventLoop` from concatenation of trajectories (#293)

* initial scaffolding

* remove unnecessary TypeVar

* changed function signature of . Will Break tests

* progress towards refactor

* marked where we left off

* got state-based SEL working

* added default if nodyn

* moved torch agnostic implementation out of torchdiffeq module

* separated searching from dynamic events from simulating. This will be useful when we log

* pulled out interruption logic from simulation

* replaced name of Point with Static and fixed linter

* renamed TrajectoryLogging to DynamicTrace

* got DynamicTrace working

* added more testing

* lint

* lint

* got noop interruptions tests working

* lint

* more progress on NonInterruptionPointObservationArray and tests

* partial lint and remove comment

* fixed test_solver

* got dynamic interventions tests working

* got static intervention tests working

* fixed test_static_observation

* lint

* got handler composition test working

* all tests pass

* lint

* lint

* finally fixed linter version

* revise notebook

* added missing odeint_kwargs from

* added missing odeint_kwargs from (#296)

* Remove `ODEDynamics` subclass and corresponding `simulate` indirection through `ode_simulate` (#298)

* added missing odeint_kwargs from

* removed ODEDynamics and resolved circular imports

* Migrate torchdiffeq dependency from "extras" to "install_requires" (#299)

* added missing odeint_kwargs from

* Migrate torchdiffeq dependency from extras to install_requires

* move dynamical dependencies to its own module and add to CI

* removed hanging space...

* Simplify State and Trajectory types (#301)

* Simplify State and Trajectory

* remove comment

* Remove unused unsqueeze function (#302)

* call _reset on __enter__ (#304)

* Clean up Trajectory.append (#303)

* generic append

* fix bug in reset

* revert

* space

* fix test

* remove staging-dynamic from CI tests for PR to master

* remove staging-dynamic from lint GitHub CI for PR to master

* Temporarily restore CI

* Add default behavior for `X.t` with `torchdiffeq` solver. (#307)

* add time to state in torchdiffeq _deriv method

* modified tests to include simple change with time

* lint

* Simplify SimulatorEventLoop logic (#309)

* Simplify interruptions

* remove shallow

* reoder

* types and lint

* types

* nit

* fix static logic

* Rename NonInterruptingPointObservationArray to StaticBatchObservation (#311)

* Rename NonInterruptingPointObservationArray

* lint

* Rename Dynamics to InPlaceDynamics (#310)

* Rename Dynamics to InPlaceDynamics

* type

* remove unused type variable

* union

* remove dynamics for now

* Remove obsolete test file (#314)

* Remove unnecessary kwargs from interruptions (#313)

* Move dynamical backend interface into one file (#312)

* consolidate backend interface in one file

* consolidate patterns

* rename indexed to _patterns

* rename to backend

* rename pattern

* rename pattern and lint

* import

* reorder file

* remove deleted file

* fix merge

* fix error

* lint

* dead code

* move append and rename _utils

* get_solver

* fix

* Remove some trivial dynamical test cases (#315)

* Remove trivial test cases

* nit

* expand

* reordered state and dstate in Dynamics (#316)

* Make Trajectory methods use indexed ops (#317)

* refactor trajectory.__getitem__ to use gather

* Remove trivial test cases

* nit

* expand

* add back len

* save

* Rename dynamical submodules and components (#319)

* refactor trajectory.__getitem__ to use gather

* Remove trivial test cases

* nit

* expand

* add back len

* save

* move interruptions into one file

* fix imports

* rename dynamical to event_loop

* rename dynamictrace to logtrajectory

* fix import

* remove circular import

* typing nits

* rename solver modules

* add missing files

* add back check

* first ops renaming

* ops file

* revert changes to staticbatchobservation

* rename trace->trajectory

* comments

* rename test file

* Remove var_order attribute from State interface (#322)

* Move State.var_order to internals

* utils

* Update _utils.py

* Fix generic types and arguments of LogTrajectory and StaticBatchObservation (#324)

* various fixes to logtrajectory

* nit

* post

* Remove usage of Trajectory.__len__  (#323)

* remove len

* fix merge

---------

Co-authored-by: Sam Witty <samawitty@gmail.com>

* Remove Trajectory.__getitem__ method (#325)

* remove len

* remove trajectory getitem

* remove getitem method

* fix almost all tests

* lint

* fix remaining tests

* Remove Trajectory.to_state method (#326)

* remove len

* remove trajectory getitem

* remove getitem method

* fix almost all tests

* lint

* fix remaining tests

* remove Trajectory.to_state

---------

Co-authored-by: Sam Witty <samawitty@gmail.com>

* Separate Observable from InPlaceDynamics interface (#330)

* Separate Observable and Dynamics

* trajectory

* remove inplacedynamics

* Remove Trajectory type (#327)

* remove len

* remove trajectory getitem

* remove getitem method

* fix almost all tests

* lint

* fix remaining tests

* remove Trajectory.to_state

* delete trajectory type

* remove typevar

* address comments

* fix merge error

---------

Co-authored-by: Sam Witty <samawitty@gmail.com>

* Move keys property of State into a helper function get_keys (#331)

* Move keys property of state into a function

* Fix merge

---------

Co-authored-by: Sam Witty <samawitty@gmail.com>

* dispatch for checking dynamics meets a particular solver

* torchdiffeq implementation of check_dynamics

* effect handler for checking calls to simulate at runtime

* test

* lint

* minor fix in NotImplementedError

* re lint

* fix get_keys

* add RuntimeCheckDynamics to init

* Remove stale skipped dynamical systems tests (#339)

* Remove stale skipped tests

* lint

* Add dynamical sphinx config (#342)

* Add dynamical sphinx config

* add notebook to index and fix rendering and indentation errors

* Use observe operation in dynamical system observation handlers (#340)

* change observation interface and fix tests

* restore stale test for merge

* lint and restore test for merge

* Update _utils.py

* Use functional interface for dynamical systems (#341)

* change observation interface and fix tests

* restore stale test for merge

* lint and restore test for merge

* Switch to functional interface for dynamical systems models

---------

Co-authored-by: Sam Witty <samawitty@gmail.com>

* replaced explicit State with Dict (#346)

* Merge master into `staging-dynamic` (#348)

* fix typo the the (#300)

* Adding  helper function for generating stochastic interventions with approximately uniform distributions.  (#294)

* added uniform_proposal to explanation.py

* added a test for uniform_proposal

* added _uniform_proposal_indep

* added _uniform_proposal_indep

* added test for uniform_proposal_indep, lint

* added _uniform_proposal_integer and a test, lint

* added random_intervention()

* added test for random_intervention()

* removed redundant logs

* revised uniform prop, random intervention, tests

* SearchForCause effect handler with tests thereof (#297)

* started part of cause

* part of cause in progress

* second test WIP

* fixed small typos in undo_split documentation

* first emulated second test success

* added SearchOfCause

* dealing with the second test WIP

* dealing with the layered test WIP

* two-layer test succeeds

* small lint

* lint after black update

* renamed the handler to SearchForCause

* renamed the handler in the test

* tweak tests

* simplify tests

* revert

* revert

---------

Co-authored-by: eb8680 <eb8680@users.noreply.github.com>
Co-authored-by: Eli <eli@elibingham.com>

* Clean up random_intervention tests and docs (#306)

* Clean up random_intervention tests

* docstring

* import

* imports

* fix docstring

* reorder files

* Update chirho/counterfactual/handlers/explanation.py

Co-authored-by: rfl-urbaniak <rfl.urbaniak@gmail.com>

---------

Co-authored-by: rfl-urbaniak <rfl.urbaniak@gmail.com>

* make ci tests parallel (#345)

---------

Co-authored-by: Zenna Tavares <zennatavares@gmail.com>
Co-authored-by: rfl-urbaniak <rfl.urbaniak@gmail.com>
Co-authored-by: eb8680 <eb8680@users.noreply.github.com>
Co-authored-by: Eli <eli@elibingham.com>

* Remove get_keys operation (#349)

* Remove get_keys operation

* missed one

* Revise dynamical systems notebook to work with recent refactoring changes (#352)

* early progress on notebook

* revised notebook to use .keys()

* working up to counterfactual

* counterfactual working

* counterfactual working

* notebook ran end-to-end

* remove unused imports

* Update lint.yml to remove staging-dynamic

* Update test.yml to remove staging-dynamic

* revised torchdiffeq check

* remove unused dispatch

* lint

* lint

* major refactor of dynamics check

* simplify runtime check

* rename

* remove kwargs

* lint

* undo accidental spacing

* another lingering spacing change

* spacing...

* wrestling with git changes...

* simplify with validate_dynamics op

* rename file

* kwargs

* add a pyro.setting to enable runtime validation of dynamics

* flag

* add setting and test

* address comments

* annotate

---------

Co-authored-by: Raj Agrawal <r.agrawal@csail.mit.edu>
Co-authored-by: Andy Zane <azane@SmartSource-M-Machine.local>
Co-authored-by: Raj Agrawal <r.agrawal.mit@gmail.com>
Co-authored-by: Andy Zane <andy@basis.ai>
Co-authored-by: eb8680 <eb8680@users.noreply.github.com>
Co-authored-by: Andy Zane <azane@nnu.edu>
Co-authored-by: Alvin Qin <qinqian@users.noreply.github.com>
Co-authored-by: rfl-urbaniak <rfl.urbaniak@gmail.com>
Co-authored-by: Eli <eli@basis.ai>
Co-authored-by: Eli <eli@elibingham.com>
Co-authored-by: Zenna Tavares <zennatavares@gmail.com>
  • Loading branch information
12 people committed Nov 29, 2023
1 parent f4d72c0 commit e70fcf8
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 2 deletions.
4 changes: 4 additions & 0 deletions chirho/dynamical/handlers/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from chirho.dynamical.internals.solver import (
Interruption,
apply_interruptions,
check_dynamics,
get_new_interruptions,
simulate_to_interruption,
)
Expand All @@ -28,6 +29,9 @@ class InterruptionEventLoop(Generic[T], pyro.poutine.messenger.Messenger):
def _pyro_simulate(msg) -> None:
dynamics, state, start_time, end_time = msg["args"]

if pyro.settings.get("validate_dynamics"):
check_dynamics(dynamics, state, start_time, end_time, **msg["kwargs"])

# local state
all_interruptions: List[Prioritized] = []
heapq.heappush(
Expand Down
14 changes: 13 additions & 1 deletion chirho/dynamical/handlers/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def _pyro_simulate_to_interruption(self, msg) -> None:

interruptions, dynamics, initial_state, start_time, end_time = msg["args"]
msg["kwargs"].update(self.odeint_kwargs)

msg["value"] = torchdiffeq_simulate_to_interruption(
interruptions,
dynamics,
Expand All @@ -58,3 +57,16 @@ def _pyro_simulate_to_interruption(self, msg) -> None:
**msg["kwargs"]
)
msg["done"] = True

def _pyro_check_dynamics(self, msg) -> None:
from chirho.dynamical.internals.backends.torchdiffeq import (
torchdiffeq_check_dynamics,
)

dynamics, initial_state, start_time, end_time = msg["args"]
msg["kwargs"].update(self.odeint_kwargs)

torchdiffeq_check_dynamics(
dynamics, initial_state, start_time, end_time, **msg["kwargs"]
)
msg["done"] = True
19 changes: 19 additions & 0 deletions chirho/dynamical/internals/backends/torchdiffeq.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import functools
from typing import Callable, List, Optional, Tuple, TypeVar

import pyro
import torch
import torchdiffeq

Expand All @@ -13,6 +14,24 @@
T = TypeVar("T")


class TorchdiffeqRuntimeCheck(pyro.poutine.messenger.Messenger):
def _pyro_sample(self, msg):
raise ValueError(
"TorchDiffEq only supports ODE models, and thus does not allow `pyro.sample` calls."
)


def torchdiffeq_check_dynamics(
dynamics: Dynamics[torch.Tensor],
initial_state: State[torch.Tensor],
start_time: torch.Tensor,
end_time: torch.Tensor,
**kwargs,
) -> None:
with TorchdiffeqRuntimeCheck():
dynamics(initial_state)


def _deriv(
dynamics: Dynamics[torch.Tensor],
var_order: Tuple[str, ...],
Expand Down
22 changes: 22 additions & 0 deletions chirho/dynamical/internals/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,25 @@ def apply_interruptions(
"""
# Default is to do nothing.
return dynamics, start_state


@pyro.poutine.runtime.effectful(type="check_dynamics")
def check_dynamics(
dynamics: Dynamics[T],
initial_state: State[T],
start_time: R,
end_time: R,
**kwargs,
) -> None:
"""
Validate a dynamical system.
"""
pass


DYNAMICS_VALIDATION_ENABLED: bool = False


@pyro.settings.register("validate_dynamics", __name__, "DYNAMICS_VALIDATION_ENABLED")
def _check_validate_dynamics_flag(value: bool) -> None:
assert isinstance(value, bool)
4 changes: 3 additions & 1 deletion chirho/dynamical/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ def simulate(
"""
Simulate a dynamical system.
"""
from chirho.dynamical.internals.solver import simulate_point
from chirho.dynamical.internals.solver import check_dynamics, simulate_point

with contextlib.nullcontext() if solver is None else solver:
if pyro.settings.get("validate_dynamics"):
check_dynamics(dynamics, initial_state, start_time, end_time, **kwargs)
return simulate_point(dynamics, initial_state, start_time, end_time, **kwargs)
68 changes: 68 additions & 0 deletions tests/dynamical/test_validate_dynamics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import logging

import pyro
import pytest
import torch

from chirho.dynamical.handlers.event_loop import InterruptionEventLoop
from chirho.dynamical.handlers.solver import TorchDiffEq
from chirho.dynamical.internals.solver import check_dynamics
from chirho.dynamical.ops import State, simulate

pyro.settings.set(module_local_params=True)

logger = logging.getLogger(__name__)

# Global variables for tests
init_state = State(S=torch.tensor(1.0))
start_time = torch.tensor(0.0)
end_time = torch.tensor(4.0)


def valid_diff(state: State) -> State:
return state


def invalid_diff(state: State) -> State:
pyro.sample("x", pyro.distributions.Normal(0.0, 1.0))
return State(S=(state["S"]))


def test_validate_dynamics_torchdiffeq():
with TorchDiffEq():
check_dynamics(
valid_diff,
init_state,
start_time,
end_time,
)

with pytest.raises(ValueError):
with TorchDiffEq():
check_dynamics(
invalid_diff,
init_state,
start_time,
end_time,
)


def test_validate_dynamics_setting_torchdiffeq():
with pyro.settings.context(validate_dynamics=False):
with InterruptionEventLoop(), TorchDiffEq():
simulate(
invalid_diff,
init_state,
start_time,
end_time,
)

with pyro.settings.context(validate_dynamics=True):
with pytest.raises(ValueError):
with InterruptionEventLoop(), TorchDiffEq():
simulate(
invalid_diff,
init_state,
start_time,
end_time,
)

0 comments on commit e70fcf8

Please sign in to comment.