diff --git a/docs/source/deepscm.ipynb b/docs/source/deepscm.ipynb index 42c3d9d61..0dfea1b57 100644 --- a/docs/source/deepscm.ipynb +++ b/docs/source/deepscm.ipynb @@ -5,52 +5,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Deep Structural Causal Models Example" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Outline\n", - "\n", - "[Setup](#setup)\n", - "\n", - "[Overview: Counterfacutal estimation with normalizing flows](#overview-counterfactual-estimation-with-normalizing-flows)\n", - "- [Task: Counterfactual Inference](#task-counterfactual-inference)\n", - "- [Challenge: Holding exogeous noise fixed with tractable likelihoods](#challenge-holding-exogenous-noise-fixed-with-tractable-likelihoods)\n", - "- [Assumptions: All confounders observed. Unique mapping from structural functions to joint probability distributions.](#assumptions-all-confounders-observed-unique-mapping-from-structural-functions-to-joint-probability-distributions)\n", - "- [Intuition: Deep invertible neural networks using Normalizing Flows](#intuition-deep-invertible-neural-networks-using-normalizing-flows)\n", - "- [Caveat: Strong assumptions and identifiability](#caveat-strong-assumptions-and-identifiability)\n", - "\n", - "[Example: Morpho-MNIST](#example-morpho-mnist)\n", - "- [Variables](#variables)\n", - "- [Motivation](#motivation)\n", - "\n", - "[Causal Probabilistic Program](#causal-probabilistic-program)\n", - "- [Model Description](#model-description)\n", - "- [Maximum Likelihood Inference](#maximum-likelihood-inference)\n", - "- [Informal Predictive Check: Visualizing Samples](#informal-predictive-check-visualizing-samples)\n", - "\n", - "[Causal Query: counterfactual data generationg](#causal-query-counterfactual-data-generation)\n", - "\n", - "[Results](#results)\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Setup\n", - "\n", - "Here, we install the necessary Pytorch, Pyro, and Causal Pyro dependencies for this example." + "# Example: Deep structural causal model counterfactuals" ] }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -62,7 +22,6 @@ } ], "source": [ - "%reload_ext tensorboard\n", "%reload_ext autoreload\n", "%autoreload 2\n", "%pdb off\n", @@ -87,6 +46,7 @@ "import pyro.distributions as dist\n", "\n", "from causal_pyro.counterfactual.handlers import MultiWorldCounterfactual\n", + "from causal_pyro.indexed.ops import IndexSet, gather, indices_of\n", "from causal_pyro.interventional.handlers import do\n", "\n", "pyro.clear_param_store()\n", @@ -109,15 +69,15 @@ "source": [ "### **Task:** Counterfactual inference\n", "\n", - "With the exception of the [mediation](mediation.ipynb) analysis example, previous examples have focussed on the (conditional) average treatment effects. These estimands answer questions of the form: \"what is the average difference in outcomes for all individuals in the population if they were forced to take treatment $T=1$ relative to if they were forced to take treatment $T=0$. In some settings, however, we are interested in answering retrospective, or \"counterfactual\", questions about individuals. These questions take the form: \"for individual $i$ who's attributes were $X_i$, $T_i$, and $Y_i$, what would $Y_i$ had been if they were forced to take treatment $T_i$ instead?\" This question is different for two reasons: (i) it refers to an individual, rather than a population, and (ii) we are conditioning on what actually happened (i.e. the factual conditions) for that individual. \n", + "With the exception of the [mediation](mediation.ipynb) analysis example, previous examples have focussed on the (conditional) average treatment effects. These estimands answer questions of the form: \"what is the average difference in outcomes for all individuals in the population if they were forced to take treatment $T=1$ relative to if they were forced to take treatment $T=0$. In some settings, however, we are interested in answering retrospective, or \"counterfactual\", questions about individuals. These questions take the form: \"for individual $i$ who's attributes were $X_i$, $T_i$, and $Y_i$, what would $Y_i$ have been if they were forced to take treatment $T_i$ instead?\" This question is different for two reasons: (i) it refers to an individual, rather than a population, and (ii) we are conditioning on what actually happened (i.e. the factual conditions) for that individual.\n", "\n", "Methodologically, this means that we'll need to be more careful about all of the external, or \"exogenous\" variables that are often ignored when making causal inferences from data. As a somewhat contrived example, we might ordinarily model the probabilistic causal relationships between random variables representing \"how high I throw my hat in the air\", which we'll call $T$ and \"how far away my hat lands\", which we'll call $Y$, using the following probabilistic relationship.\n", "\n", "$$Y_i \\sim uniform(0, 2 * T_i)$$\n", "\n", - "Here, our uncertainty in how far the hand lands is determined entirely by how windy it is at that time; if no wind is present then the hat will land exactly at our feet and if it's particularly windy the hat twice as many feet away from us as how high we threw it in the air.\n", + "Here, our uncertainty in how far the hat lands is determined entirely by how windy it is at that time; if no wind is present then the hat will land exactly at our feet and if it's particularly windy the hat will land twice as many feet away from us as how high we threw it in the air.\n", "\n", - "In this setting, \"interventional\" questions like those we saw in the [tutorial](tutorial_i.ipynb) can be answered by simply replacing $T_i$ in the above distribution with its intervention assignment, and then sampling from the newly formed distribution. However, when we ask a counterfactual question like; \"given that I threw the hat up 1 foot and it landed at my 0.5 feet away from me, how far away would the hat have landed if I threw it up 2 feet instead?\" we can no longer just sample from the conditional distribution, as this ignores our knowledge of what actually happened factual world. In this case, the answer to the counterfactual question would be that the hat would still land at our feet, because we already know that it wasn't windy.\n", + "In this setting, \"interventional\" questions like those we saw in the [tutorial](tutorial_i.ipynb) can be answered by simply replacing $T_i$ in the above distribution with its intervention assignment, and then sampling from the newly formed distribution. However, when we ask a counterfactual question — for example, \"given that I threw the hat up 1 foot and it landed at my feet, 0.5 feet away from me, how far away would the hat have landed if I threw it up 2 feet instead?\" — we can no longer just sample from the interventional distribution, as this ignores our knowledge of what actually happened in the factual world. In this case, the answer to the counterfactual question would be that the hat would still land at our feet, because we already know that it wasn't windy.\n", "\n", "To answer these kinds of counterfactual questions, causal models must instead be written explicitly in \"structural form\", i.e. collections of deterministic functions of exogenous noise. For example, if we were to make some additional assumptions we could alternatively rewrite our earlier hat throwing model as the following, where $W_i$ describes the amount of windiness:\n", "\n", @@ -137,13 +97,11 @@ "2. **Action** - $(Y_i|do(T_i = 2)) = W_i * 2$\n", "3. **Prediction** - $(Y_i|do(T_i=2), W_i=0) = 1$\n", "\n", - "**Meta Note:** I'm avoidng counterfactual notation here, but maybe it's just more clear to be explicit.\n", - "\n", - "As we'll see later, Causal Pyro combines all three of these steps into a joint inference process using a generalization of what is known as a \"twin-world\" representation for causal inference. In general counterfactual questions do not have fully deterministic answers because; (i) exogenous noise can often not be inferred exactly, and (ii) structural functions themselves may contain uncertainty parameters.\n", + "\n", "\n", - "For an excellent overview and discussion of the challenges in answering counterfactual questions, see Bareinboim et al.'s (2022) work.\n", + "As we'll see later, Causal Pyro combines all three of these steps into a joint inference process using a generalization of what is known as a \"twin-world\" representation for causal inference. In general, counterfactual questions do not have fully deterministic answers because; (i) exogenous noise can often not be inferred exactly, and (ii) structural functions themselves may contain uncertainty parameters.\n", "\n", - "TODO: add citations." + "For an excellent overview and discussion of the challenges in answering counterfactual questions, see Bareinboim et al.'s (2022) work.\n" ] }, { @@ -153,7 +111,7 @@ "source": [ "### **Challenge:** Holding exogenous noise fixed with tractable likelihoods\n", "\n", - "In our simplified example above, we assumed that model parameters (and thus structural functions) were known aprior. In practice this is hardly ever the case, even with the stronger assumptions necessary for answering counterfactual questions. Instead, we would like to learn model parameters within a function class that fit observational data, and then later use those learned parameters to answer counterfactual questions. In particular, we'd like these models to permit a broad class of structural functions, such as Gaussian processes or neural networks.\n", + "In our simplified example above, we assumed that model parameters (and thus structural functions) were known aprior. In practice this is hardly ever the case, even with the stronger assumptions necessary for answering counterfactual questions. Instead, we would like to learn model parameters within a function class that fits observational data, and then later use those learned parameters to answer counterfactual questions. In particular, we'd like these models to permit a broad class of structural functions, such as Gaussian processes or neural networks.\n", "\n", "Unfortunately, one challenge with using these kinds of high-capacity function approximations for counterfactual inference is that they are not often invertible, making it difficult to infer values of exogenous noise for any particular data instance.\n", "\n", @@ -167,9 +125,9 @@ "source": [ "### **Assumptions:** All confounders observed. Unique mapping from structural functions to joint probability distributions.\n", "\n", - "Like many of the examples thusfar, in this example we will assume that all confounders between endogenous variables are observed. See the [backdoor](backdoor.ipynb) example for a more in-depth description of this assumption.\n", + "Like many of the examples thus far, in this example we will assume that all confounders between endogenous variables are observed. See the [backdoor](backdoor.ipynb) example for a more in-depth description of this assumption.\n", "\n", - "Additionally, estimating counterfactual quantities requires additional assumptions. Just as many interventional distributions can map to the same observational distribution (see the [tutorial](tutorial.ipynb)), so too can many counterfactual distributions map to the same interventional distribution. Above we chose a single reparameterization of the conditional probability distribution $P(Y_i|T_i)$ in terms of structural functions, but that was just one particular choice, and other choices can often result in different counterfactual conclusions. In general, to disambiguate between multiple plausible structural causal models we must either assume a family structural causal models a priori, either by specifying a parameteric family as we do here, or by making more declarative assumptions about structural functions (e.g. structural functions are monotonic \\[Pearl 2009\\]). Importantly, the use of Normalizing flows as the parametric family in this case is both an innovation **and** an assumption, implicitly restricting the space of structural functions apriori." + "Additionally, estimating counterfactual quantities requires additional assumptions. Just as many interventional distributions can map to the same observational distribution (see the [tutorial](tutorial.ipynb)), so too can many counterfactual distributions map to the same interventional distribution. Above we chose a single reparameterization of the conditional probability distribution $P(Y_i|T_i)$ in terms of structural functions, but that was just one particular choice, and other choices can often result in different counterfactual conclusions. In general, to disambiguate between multiple plausible structural causal models we must either assume a family structural causal models a priori, either by specifying a parameteric family as we do here, or by making more declarative assumptions about structural functions (e.g. structural functions are monotonic \\[Pearl 2009\\]). Importantly, the use of Normalizing flows as the parametric family in this case is both an innovation **and** an assumption, implicitly restricting the space of structural functions a priori." ] }, { @@ -197,7 +155,7 @@ "\n", "In fact, many causal inference researchers dismiss this kind of unit-level counterfactual estimation as entirely implausible (see Rubin's [\"fundamental theorem of causal inference\"](https://en.wikipedia.org/wiki/Rubin_causal_model#:~:text=would%20be%20masked.-,Conclusion,effect%20on%20a%20single%20unit.)).\n", "\n", - "As computational tool-builders, our goal is only to provide users with the means of deriving causal conclusions from their causal assumptions, regardless of how strong those assumptions may be." + "As builders of computational tools, our goal is only to provide users with the means of deriving causal conclusions from their causal assumptions, regardless of how strong those assumptions may be." ] }, { @@ -207,9 +165,9 @@ "source": [ "## Example: Morpho-MNIST\n", "\n", - "In this notebook, we demonstrate the use of \"Deep SCMs\" on a semi-synthetic example derived from the standard MNIST benchmark. Specifically, we use the same dataset that was generated as a part of the empirical evaluation from the Deep SCM paper on which this example is based. In this notebook we filter the dataset to only include images for the digit 5 to reduce the computational overhead of neural network training.\n", + "We consider a synthetic dataset based on MNIST, where the image of each digit ($X$) depends on stroke thickness ($T$) and brightness ($I$) of the image and the thickness depends on brightness as well.\n", "\n", - "In summary, in constructing this dataset the authors generate synthetic scalar values for \"thickness\" and \"intensity\" and then transform real MNIST images using the image-transformation techniques described in the Morpho-MNIST paper \\[Castro et al. 2019\\]. See Section 4 in the Deep SCM paper \\[Pawlowski et al. 2020\\] for additional detail." + "In constructing this dataset, the authors generate synthetic scalar values for \"thickness\" and \"intensity\" and then transform real MNIST images using the image-transformation techniques described in the Morpho-MNIST paper \\[Castro et al. 2019\\]. See Section 4 in the Deep SCM paper \\[Pawlowski et al. 2020\\] for additional detail." ] }, { @@ -234,14 +192,26 @@ "While this specific example is somewhat contrived, it does demonstrate the utility of incorporating neural networks as components of probabilistic causal models, and subsequently reasoning about counterfactual outcomes in high-dimensional non-linear settings. These derived counterfactual outcomes can help provide explanations to experts, such as doctors looking at brain images, and help them to better understand their domain." ] }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "### Dataset\n", + "\n", + "Here, we load and process the dataset as described thus far. The raw data is available here, at the Morpho-MNIST [repository](https://github.com/dccastro/Morpho-MNIST#datasets).\n" + ] + }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [ { "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAcoAAAGFCAYAAAB9krNlAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAL2UlEQVR4nO3cWYyedRnG4XeWLnQKLWGVpS2lC5sIYUkJAWIQLCqpEiGCiAgaFBcSAoIQExe2xCaIRhqRE5cgpAKRRaCgggq2VKAsZZPVQMGUUlIg3Wbm9cQTbLnTJ+04M8x1Hd9886UH34//ydPRtm3bAAAb1TnYXwAAhjKhBIBAKAEgEEoACIQSAAKhBIBAKAEgEEoACLo3dXhM54kD+T1gs9zdP3+wvwLvw28HQ9mm/HZ4UQJAIJQAEAglAARCCQCBUAJAIJQAEAglAARCCQCBUAJAIJQAEAglAARCCQCBUAJAIJQAEAglAARCCQCBUAJAIJQAEAglAARCCQCBUAJAIJQAEAglAATdg/0FPig6umv/lB1jxpT2aw/fu7QfvXx1af/akRNK+/HL+mr7+YtKe4ChwosSAAKhBIBAKAEgEEoACIQSAAKhBIBAKAEgEEoACIQSAAKhBIBAKAEgGL63Xju7SvMXLz20tN/9kFdL+0un3lza7z16XWk/ofP+0n7qPWeU9u3q9aX9dkv7S3tgeOjs6Snt1xyxT2n/0oltad/RVdvPPPup0n5TeFECQCCUABAIJQAEQgkAgVACQCCUABAIJQAEQgkAgVACQCCUABAIJQAEw/fWa1u7NdpOXl3anz3p3tJ+l+7a58+65tul/Zpda7dYZ3zt4dK+6e+r7YFB0b3brqX9kz/cubT/zVG/KO0PH1u7Q/3YujWl/emPfbG07+je8lnzogSAQCgBIBBKAAiEEgACoQSAQCgBIBBKAAiEEgACoQSAQCgBIBBKAAiG8a3XtjSfesqS0v7Kz55c2h/8nYdK+0k/eKC0Bz6YunfdpbQ/8LZ/lfanbLWwtD/t92eX9lNurd2hHr342dJ+h7efKe0H4mq1FyUABEIJAIFQAkAglAAQCCUABEIJAIFQAkAglAAQCCUABEIJAIFQAkAwfG+9DrCeGx8s7e8/fXppv+qS7Ur7PX9cu4/Y98aK0h7YMt4847DS/vbvzy3tZ192Xmn/8PVPl/bTVtZuw1b1D+inDwwvSgAIhBIAAqEEgEAoASAQSgAIhBIAAqEEgEAoASAQSgAIhBIAAqEEgMCt1/fTtqX5Dl9YXtp3X9dX2h973/Ol/Z2nHl7at48sLe2BjVv18XdL+1FNR2k/4YX1pX3fypWlPRvyogSAQCgBIBBKAAiEEgACoQSAQCgBIBBKAAiEEgACoQSAQCgBIBBKAAjcet1CqvcUtzmutv/prz9a2l/zu1+V9j/a66DSvl2/rrSHAdNRu5X60iWzap9fO/vcTLtwWWl/34IdS/utL3qltF97V2nORnhRAkAglAAQCCUABEIJAIFQAkAglAAQCCUABEIJAIFQAkAglAAQCCUABG69biGdB+xT2v/z/DGl/Z+P+Glp/4d3Zpb2be/60h6Giq7tty/t7z9tbmn/p9W7lPZXHXx0aX/cuNrd5wv+tkdpP6V5vbRnQ16UABAIJQAEQgkAgVACQCCUABAIJQAEQgkAgVACQCCUABAIJQAEQgkAwYi59dq1/Xal/VOX7FnaL/nUVaX9st62tD/6+vNL++mXP13aN23t3iQMFX3Ll5f2nznn3NL+1Tm1O8jdo/tK+wPnnVPaT7lsUWnP5vOiBIBAKAEgEEoACIQSAAKhBIBAKAEgEEoACIQSAAKhBIBAKAEgEEoACIbOrddDP1yav3phf2l/7yHXlPajOmr/D3HA7d8q7fe+6PnSfuqKv5f2tWuTMHKMu6l2K3X6TQP0RRg2vCgBIBBKAAiEEgACoQSAQCgBIBBKAAiEEgACoQSAQCgBIBBKAAiEEgCCIXPr9e2pPaX92mc7SvtjF5xX2n/ormWl/YwXF5f2brECDA9elAAQCCUABEIJAIFQAkAglAAQCCUABEIJAIFQAkAglAAQCCUABEIJAMGQufW69fULi/sB+iL/1TuwHw/AMOFFCQCBUAJAIJQAEAglAARCCQCBUAJAIJQAEAglAARCCQCBUAJAIJQAEHS0bdsO9pcAgKHKixIAAqEEgEAoASAQSgAIhBIAAqEEgEAoASAQSgAIhBIAAqEEgEAoASAQSgAIhBIAAqEEgEAoASAQSgAIhBIAAqEEgEAoASAQSgAIhBIAAqEEgEAoASAQSgAIhBIAAqEEgEAoASAQSgAIhBIAAqEEgEAoASAQSgAIhBIAgu5NHR7TeeJAfg/YbHf3zx/sr8BG+O1gKNuU3w0vSgAIhBIAAqEEgEAoASAQSgAIhBIAAqEEgEAoASAQSgAIhBIAAqEEgEAoASAQSgAIhBIAAqEEgEAoASAQSgAIhBIAAqEEgEAoASAQSgAIhBIAAqEEgKB7sL8AQ0fn2LGlfbvftNK+65XlpX3v6/8u7QEGghclAARCCQCBUAJAIJQAEAglAARCCQCBUAJAIJQAEAglAARCCQCBUAJA4NbrIOns6Sn/N2+esH9p/8axa0r7nx12XWk/e9zC0n7qgjNL++mnu/UKm6v/qANL+9fOWVfa33nwz0v73brHl/Z73PqV0n7GWYtL+03hRQkAgVACQCCUABAIJQAEQgkAgVACQCCUABAIJQAEQgkAgVACQCCUABC49bqFdE/evbQ//o6Hy3/jqxPvL+1veXdcaX/Jc58s7S++aafSfq/5T5f2faU1jAzV35qDf/JgaX/ry/uV9kf/8vzSfsJzpXmzzz0vl/a9tY/fJF6UABAIJQAEQgkAgVACQCCUABAIJQAEQgkAgVACQCCUABAIJQAEQgkAgVuv76Oju/ZPs/MNK0v7KaPfKO2bpmmO+MZZpf24m2s3HnvaF2r7prZ3u5WRYNUps0r7iU+uKu17lzxZ2i/8yKjSfofmmeJ+YA3E7dYqL0oACIQSAAKhBIBAKAEgEEoACIQSAAKhBIBAKAEgEEoACIQSAAKhBIBg5Nx67egozZ+7/JDS/s5J80r7petWl/ZN0zRvzuwq7ce1bflvAO/VceC+pf38K+aW9l+efnRpz/+fFyUABEIJAIFQAkAglAAQCCUABEIJAIFQAkAglAAQCCUABEIJAIFQAkAwYm69rjhzVmm/5OQrS/v9rjq3tF+zf/3W66zjnyztl19e/hPA/+gbP7q036lrq9J+7IKJpf1bV0wq7cfcsbi0Z0NelAAQCCUABEIJAIFQAkAglAAQCCUABEIJAIFQAkAglAAQCCUABEIJAMGIufW6ww1PlPYnPXBaaT/ptadK+08/8Exp3zRNM+/ZI0v7HZu3yn8DeK/Ovz5S2h8/5bDS/p05U0r7q+ddVdpfMOdLpX3/o7XfspHAixIAAqEEgEAoASAQSgAIhBIAAqEEgEAoASAQSgAIhBIAAqEEgEAoASAYMrden587q7Tvm9hb2s+ct7q0f3n2hNL+4tP+WNpv1/VOad80TbPT97pK+7b8F4DN1a5fV9r33PhgaX/jdw8q7VfN2Ka0H/9oaT4ieFECQCCUABAIJQAEQgkAgVACQCCUABAIJQAEQgkAgVACQCCUABAIJQAEQ+bWa9/4/tL+xU9cW/v842qfv3ht7VLq52/5emk/89LnS/umaZp2+dLyfwO819ufq92VfmNO7U70qMd7SvvdP/Zyab/nmEWl/aJblpT2bkRvyIsSAAKhBIBAKAEgEEoACIQSAAKhBIBAKAEgEEoACIQSAAKhBIBAKAEgGDK3Xmec/VBpP/vaU0v7jt7ardd26XOl/bT1C0v7vtIa2FLePOHd0v6b+/6ltL9h24NK+2W3TS7tf3v1itK+XbumtGdDXpQAEAglAARCCQCBUAJAIJQAEAglAARCCQCBUAJAIJQAEAglAARCCQDBkLn12vTXrp+2/3iiti+tgQ+qySc9Xtrf1mxb2vc0Lwzovna1mi3BixIAAqEEgEAoASAQSgAIhBIAAqEEgEAoASAQSgAIhBIAAqEEgEAoASAQSgAIhBIAAqEEgEAoASAQSgAIhBIAAqEEgEAoASAQSgAIhBIAAqEEgEAoASDoaNu2HewvAQBDlRclAARCCQCBUAJAIJQAEAglAARCCQCBUAJAIJQAEAglAAT/Ae4XiDgR6o+iAAAAAElFTkSuQmCC", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAcoAAAGFCAYAAAB9krNlAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAL+0lEQVR4nO3ce6zXdR3H8d+56CGkvOMdFRREUYnUpHQDCyvKWabONu2mabqVGaVL25za/YLLNktNTHTNRn+YNsnMnM0QMdGFRiKaTsELCCgRiOf8fv3R+sOBr523nNM5Rx6Pv19wvnPuPPn8825rtVqtBgCwWe0D/QEAMJgJJQAEQgkAgVACQCCUABAIJQAEQgkAgVACQNDZ2+G09lP68ztgi9zVnDPQn8Cb8LuDwaw3vzu8KAEgEEoACIQSAAKhBIBAKAEgEEoACIQSAAKhBIBAKAEgEEoACIQSAAKhBIBAKAEgEEoACIQSAAKhBIBAKAEgEEoACIQSAAKhBIBAKAEgEEoACIQSAILOgf6At4u2rq7Svn348NJ+7dSxpf2wl14r7ZcfW/ueEc+1Svsdbrq/tAcYLLwoASAQSgAIhBIAAqEEgEAoASAQSgAIhBIAAqEEgEAoASAQSgAIhBIAgiF767Wts/bpT/zwiNL+PUc+Udpftvftpf0+nbV/o4xov6e0P3je6aX9hnUbSvud/z5k/9cBgo4dti/t1x53UGm//KSNpX1HR7O0H3P2U6V9b3hRAkAglAAQCCUABEIJAIFQAkAglAAQCCUABEIJAIFQAkAglAAQCCUABEP2YGer2SrtdzhgVWl/3h6126q7dtS+Z9JNF5T2zX3Xl/YHfPrR0r7V3V3aAwOjc79Rpf2S7+xU2s+ZfE1pP7Hr3tL+ydf/Vdqf+XjtbnVb17alfW94UQJAIJQAEAglAARCCQCBUAJAIJQAEAglAARCCQCBUAJAIJQAEAglAARD9tZro9lTmu9ywpLS/sLPnlPaf3LGH0v7/S+eX9o3WrVbsrU1MFA6R+9X2k+5bVFpf07X86X9J+Z+qbQfNbc0b2x33xOlfdfqp0v7Whl6x4sSAAKhBIBAKAEgEEoACIQSAAKhBIBAKAEgEEoACIQSAAKhBIBAKAEgGLq3XvvZjrMXlPb3fubA0n7plaNK+4N+9Gxp3/3cstIe6Bsrzp1c2t9zyczS/pirZpT2+9xQu606dkXtd19Vf9xi7W9elAAQCCUABEIJAIFQAkAglAAQCCUABEIJAIFQAkAglAAQCCUABEIJAIFbr2+mWbtI2Pap2n78Lc+U9qffPb+0n/X5E0v7tr88UtoDm9fzoTWlfUejrbTfaXF3ad+zYkVpz6a8KAEgEEoACIQSAAKhBIBAKAEgEEoACIQSAAKhBIBAKAEgEEoACIQSAAK3XvtI9wsv1v7A1Np9xyt+M720n33ztaX9NydMLe2b69aV9tBv2jtK8yd/cGRp3yo+J8ZdWLut+te5w0v70ZcsLu2X316asxlelAAQCCUABEIJAIFQAkAglAAQCCUABEIJAIFQAkAglAAQCCUABEIJAIFbr32k9f6Jpf3yGa+X9vccWbvdOm/DbqV9c/2G0h4Gi849dy/tHzrtytL+gQ3vKu1nTjy+tJ887LXS/qz7DintxzTml/ZsyosSAAKhBIBAKAEgEEoACIQSAAKhBIBAKAEgEEoACIQSAAKhBIBAKAEg2GpuvXbuvVdp/4/vjSztF025prRf1dxY2h9924zSfvy3nintG80XansYJLqfW1baf+RrF5T2Kz/+79J+2LDaHefDZn25tD/gsgdL+1ZpzeZ4UQJAIJQAEAglAARCCQCBUAJAIJQAEAglAARCCQCBUAJAIJQAEAglAASD5tZrz9RJpf36i9aU9ndO+FVp3178N8Rhfz67tB/39ZdK+wOXPVDad5fWsPV45y3zi/t++pC3yO3W/z8vSgAIhBIAAqEEgEAoASAQSgAIhBIAAqEEgEAoASAQSgAIhBIAAqEEgGDQ3Hp9dd+u0v7lxSNL+/f94aul/d53vlzaj3nskdLeLVaAocGLEgACoQSAQCgBIBBKAAiEEgACoQSAQCgBIBBKAAiEEgACoQSAQCgBIBg0t153/OX9tX0/fcf/9PTz3w/A0OBFCQCBUAJAIJQAEAglAARCCQCBUAJAIJQAEAglAARCCQCBUAJAIJQAELS1Wq3WQH8EAAxWXpQAEAglAARCCQCBUAJAIJQAEAglAARCCQCBUAJAIJQAEAglAARCCQCBUAJAIJQAEAglAARCCQCBUAJAIJQAEAglAARCCQCBUAJAIJQAEAglAARCCQCBUAJAIJQAEAglAARCCQCBUAJAIJQAEAglAARCCQCBUAJAIJQAEHT2djit/ZT+/A7YYnc15wz0J7AZfncwmPXm94YXJQAEQgkAgVACQCCUABAIJQAEQgkAgVACQCCUABAIJQAEQgkAgVACQCCUABAIJQAEQgkAgVACQCCUABAIJQAEQgkAgVACQCCUABAIJQAEQgkAgVACQNA50B/A0NWx446lfXPt2tK+1d1d2gN9oL2jNO/cc/fa399Re5+1Vr9S2ve8+mpp3xtelAAQCCUABEIJAIFQAkAglAAQCCUABEIJAIFQAkAglAAQCCUABEIJAIFbr0NI27sPKe2fOuVdpf0Hpz1c2l+91z2l/fifn1faj7p8XmkPW4O2ztqv7SU/PqK0nzn95tL+hOEPlvYdbbX32a3rRpT210w8vLTvDS9KAAiEEgACoQSAQCgBIBBKAAiEEgACoQSAQCgBIBBKAAiEEgACoQSAwK3XPtLW1VXat+7Ytfwzbj9odmn/uWc+UNrPXXhoaX/8+eNK+1HzHyjtgU09e+FRpf3ik39S2v9u3c6l/VELTyvtVy3fvrTf/rFtSvvd1vf97xkvSgAIhBIAAqEEgEAoASAQSgAIhBIAAqEEgEAoASAQSgAIhBIAAqEEgMCt1z6ydNbBpf2pOz1U/hknTj21tO9Z8mRpP7bxYGkPW4P2YcNK+6duHFvadywaUdrvd93S0v6kG08s7buXLS/td2ksKe6HHi9KAAiEEgACoQSAQCgBIBBKAAiEEgACoQSAQCgBIBBKAAiEEgACoQSAwK3XN7HmjMml/dKpPyvtL11xSGnfaDQarxxeu5I4onjrFdjU0xdNKu0fP/bq0n7yrV8s7XtefKm0Z8t5UQJAIJQAEAglAARCCQCBUAJAIJQAEAglAARCCQCBUAJAIJQAEAglAARbz63Xow4tzX96+VWl/ZSzzivtn5/8Fv7Tn7yuNB8xp/4jgDfqHt7q179/1ndnlvbTjzu/tB937sOlfau7u7TfGnhRAkAglAAQCCUABEIJAIFQAkAglAAQCCUABEIJAIFQAkAglAAQCCUABFvPrdcFi0rzSydMKe271i8s7cdftEtp32g0Gose3r/8Z4AtM/obC0r7Kfd+obRffsbG0v6fH72utD/4otod6n2+Pa+03xp4UQJAIJQAEAglAARCCQCBUAJAIJQAEAglAARCCQCBUAJAIJQAEAglAASD5tbrynMml/avj2gr7fe+7tHS/tVp40v7PS9YWtqP7FpT2jcajcbGK5aU9j3lnwBsqa47Hiztxyyr/a5Zecy60n7DyGZpz6a8KAEgEEoACIQSAAKhBIBAKAEgEEoACIQSAAKhBIBAKAEgEEoACIQSAIJBc+v19e1qt1v/NuPq0n7lV2r3Eb+/Ym1pf/e1R5f2r1y7oLRvNBqNRnNV/c/A2117R2m+8rdjSvvzD/xTaX/5wo+V9r947+zSfnjbNqX9qN+7+rylvCgBIBBKAAiEEgACoQSAQCgBIBBKAAiEEgACoQSAQCgBIBBKAAiEEgCCQXPrdY+Z80r76TdMrf2AVrM071nzSmm/a+P+0h7oG+3bDS/tr59wU2m/b2ftVurYo68v7c+Yf2Zpv8evty3t3zH3LdyV5g28KAEgEEoACIQSAAKhBIBAKAEgEEoACIQSAAKhBIBAKAEgEEoACIQSAIJBc+u1qmf16oH+BGAQaK5dW9pfPOnD/fQl/1X9ntHdj/TPh9BnvCgBIBBKAAiEEgACoQSAQCgBIBBKAAiEEgACoQSAQCgBIBBKAAiEEgCCIXvrFeCtcCeaKi9KAAiEEgACoQSAQCgBIBBKAAiEEgACoQSAQCgBIBBKAAiEEgACoQSAoK3VarUG+iMAYLDyogSAQCgBIBBKAAiEEgACoQSAQCgBIBBKAAiEEgACoQSA4D+7NojZwRwliQAAAABJRU5ErkJggg==", "text/plain": [ "
" ] @@ -288,28 +258,16 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Causal Probabilsitic Program" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Model Description\n", - "\n", - "Just as in our previous examples, here we can encode our causal assumptions as a probabilsitic program in Pyro. Unlike the previous examples, however, in this example we need to be careful to define our model such that each endogenous variable is a deterministic transformation of exogenous noise. Thankfully, Pyro already provides support for defining custom pushforward distributions using `TransformModule`s. This means that later we can (i) tractably condition on endogenous variables using normalizing flows, and (ii) apply interventions on endogenous variables that don't implicitly change exogenous noise.\n", - "\n", - "As the model's implementation is somewhat more involved due to its use of neural network components, we'll start by defining the neural network components in isolation and then later composing them into a causal model. By the end of this section we'll have a Pyro program defining a causal generative model over morphological transformations of MNIST, containing endogenous variables representing the thickness ($T$) and intensity ($I$) of the image, a well as the resulting image itself ($X$).\n", - "\n", - "**Note:** In this example we perform maximum likelihood inference over neural network weights, and thus do not define priors.\n", + "## Model: deep structural causal model\n", "\n", - "First, we define a collection of base classes abstractly representing the transformations from exogenous noise to endogenous variables. Additional detail and background on these transformation modules can be found in the Pyro documentation for `TransformModule`s [here](https://docs.pyro.ai/en/stable/distributions.html?highlight=TransformModule#transformmodule)." + "The following code models morphological transformations of MNIST,\n", + "defining a causal generative model over digits that contains endogenous\n", + "variables to control the width $t$ and intensity $i$ of the stroke:" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -342,41 +300,7 @@ " def __init__(self, transforms: List[dist.transforms.Transform]):\n", " super().__init__([\n", " ConstantParamTransformModule(t) if not isinstance(t, torch.nn.Module) else t for t in transforms\n", - " ])\n", - " \n", - "\n", - "class InverseConditionalTransformModule(dist.conditional.ConditionalTransformModule):\n", - " \n", - " def __init__(self, transform: dist.conditional.ConditionalTransform):\n", - " super().__init__()\n", - " self._transform = transform\n", - " \n", - " @property\n", - " def inv(self) -> dist.conditional.ConditionalTransform:\n", - " return self._transform\n", - " \n", - " def condition(self, context: torch.Tensor):\n", - " return self._transform.condition(context).inv\n", - "\n", - "\n", - "class ConditionalComposeTransformModule(dist.conditional.ConditionalTransformModule):\n", - " def __init__(self, transforms: List):\n", - " self.transforms = [\n", - " dist.conditional.ConstantConditionalTransform(t)\n", - " if not isinstance(t, dist.conditional.ConditionalTransform)\n", - " else t\n", - " for t in transforms\n", - " ]\n", - " super().__init__()\n", - " # for parameter storage... TODO is this necessary?\n", - " self._transforms_module = torch.nn.ModuleList([t for t in transforms if isinstance(t, torch.nn.Module)])\n", - " \n", - " @property\n", - " def inv(self):\n", - " return InverseConditionalTransformModule(self)\n", - "\n", - " def condition(self, context: torch.Tensor):\n", - " return ComposeTransformModule([t.condition(context) for t in self.transforms]).with_cache(1)" + " ])" ] }, { @@ -384,7 +308,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Next, we use these abstractions to construct transformations for each of our two scalar values in the causal model, stroke thickness ($T$) and lighting intensity ($I$). " + "We model stroke thickness with a learnable univariate spline transformation of a standard normal distribution (defined later):" ] }, { @@ -400,9 +324,24 @@ " dist.transforms.Spline(thickness_size, bound=1.),\n", " dist.transforms.AffineTransform(loc=bias, scale=weight),\n", " dist.transforms.biject_to(dist.constraints.positive),\n", - " ])\n", - "\n", - "class IntensityTransform(ConditionalComposeTransformModule):\n", + " ])" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We use a similar approach to model the *conditional* distribution of stroke intensity *given* stroke thickness. The parameters of the univariate spline transformation are not left free, but are themselves a learnable function of thickness. We parameterize this second learnable function with a small multilayer perceptron:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "class IntensityTransform(dist.conditional.ConditionalComposeTransformModule):\n", " def __init__(\n", " self,\n", " intensity_size: int,\n", @@ -437,18 +376,36 @@ ] }, { + "attachments": {}, "cell_type": "markdown", - "metadata": {}, + "metadata": { + "tags": [] + }, "source": [ - "The transformation for the images is somewhat involved. Much of the neural network architecture is taken from [this PyTorch tutorial](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial11/NF_image_modeling.html) on normalizing flows, which readers are encouraged to peruse for further background on normalizing flows in general and this architecture in particular." + "The transformation for the images is somewhat more involved. We will define it in two parts: an expressive unconditional transformation and a smaller conditional transformation, both of which are themselves composed of repeated blocks of autoregressive normalizing flows. Unlike the transforms for thickness and intensity, which were defined as maps from gaussian noise to data, our image transformation will be defined in the reverse direction, as a mapping from data to noise; the causal probabilistic program we write later on will define a distribution on images using the inverse of our definition here.\n", + "\n", + "Much of the neural network architecture in the first, unconditional transformation is taken from [this excellent PyTorch tutorial](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial11/NF_image_modeling.html) on generative modelling of images with normalizing flows, which readers are encouraged to peruse for further background on normalizing flows in general and this architecture in particular. The only significant difference between the definitions of the unconditional image transform and its components in this notebook and the PyTorch tutorial above is that this notebook implements a `torch.distributions.Transform` interface for `MaskedAffineCouplingLayer`, making it fully compatible with Pyro models and inference algorithms. The code has also been tweaked to improve compatibility with Pyro's (ab)use of broadcasting in PyTorch." ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ + "class DequantizationTransform(ComposeTransformModule):\n", + " def __init__(self, alpha: float = 1e-5):\n", + " layers = [\n", + " dist.transforms.IndependentTransform(\n", + " dist.transforms.ComposeTransform([\n", + " dist.transforms.AffineTransform(0., 1. / 256),\n", + " dist.transforms.AffineTransform(alpha, (1 - alpha)),\n", + " dist.transforms.SigmoidTransform().inv,\n", + " ]), 3)\n", + " ]\n", + " super().__init__(layers)\n", + "\n", + "\n", "class ConcatELU(torch.nn.Module):\n", " \"\"\"\n", " Activation function that applies ELU in both direction (inverted and plain).\n", @@ -536,7 +493,7 @@ " return self.nn(x)\n", "\n", "\n", - "class MaskedAffineCoupling(dist.torch_transform.TransformModule):\n", + "class MaskedAffineCouplingLayer(dist.torch_transform.TransformModule):\n", " bijective = True\n", " domain = dist.constraints.independent(dist.constraints.real, 3)\n", " codomain = dist.constraints.independent(dist.constraints.real, 3)\n", @@ -587,7 +544,7 @@ " return (x + t) * torch.exp(s)\n", "\n", "\n", - "class ImageTransform(ConditionalComposeTransformModule):\n", + "class UnconditionalImageTransform(ComposeTransformModule):\n", " def __init__(\n", " self,\n", " im_size: int,\n", @@ -598,11 +555,8 @@ " layers_per_block: int,\n", " hidden_channels: int,\n", " *,\n", - " num_cond_blocks: int = 1,\n", " alpha: float = 1e-5,\n", - " bn_momentum: float = 0.05,\n", " ln_momentum: float = 1e-5,\n", - " nonlinearity = torch.nn.ReLU(),\n", " ):\n", " self.im_size = im_size\n", " self.input_channels = input_channels\n", @@ -610,26 +564,14 @@ " self.num_blocks = num_blocks\n", " self.layers_per_block = layers_per_block\n", " \n", - " self.num_cond_blocks = num_cond_blocks\n", - " \n", - " self.flat_input_size = input_channels * im_size * im_size\n", - " \n", " layers = []\n", - "\n", + " \n", " # dequantization\n", - " layers += [\n", - " dist.transforms.IndependentTransform(\n", - " dist.transforms.ComposeTransform([\n", - " dist.transforms.AffineTransform(0., 1. / 256),\n", - " dist.transforms.AffineTransform(alpha, (1 - alpha)),\n", - " dist.transforms.SigmoidTransform().inv,\n", - " ]), 3)\n", - " ]\n", + " layers += [DequantizationTransform(alpha=alpha)]\n", " \n", - " # image flow with convolutional blocks\n", " for i in range(num_blocks):\n", " layers += [\n", - " MaskedAffineCoupling(\n", + " MaskedAffineCouplingLayer(\n", " GatedConvNet(input_channels, hidden_channels, layers_per_block, eps=ln_momentum),\n", " self.create_checkerboard_mask(im_size, im_size, invert=(i%2==1)),\n", " input_channels,\n", @@ -638,7 +580,49 @@ " ),\n", " ]\n", " \n", - " # conditioning on context\n", + " super().__init__(layers)\n", + "\n", + " @staticmethod\n", + " def create_checkerboard_mask(h: int, w: int, invert=False):\n", + " x, y = torch.arange(h, dtype=torch.int32), torch.arange(w, dtype=torch.int32)\n", + " xx, yy = torch.meshgrid(x, y, indexing='ij')\n", + " mask = torch.fmod(xx + yy, 2).to(torch.float32).view(1, 1, h, w)\n", + " return mask if not invert else (1. - mask) " + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The unconditional generative flow above is compact and expressive enough to learn a high-quality approximation to the high-dimensional data distribution. However, to represent conditional distributions over images, we must compose it with a second transformation that is conditionally invertible given an arbitrary context vector.\n", + "\n", + "Pyro comes with a number of conditional transforms that are suitable for this task. In this example, we will use a series of `ConditionalAffineAutoRegressive` transforms because of their relative simplicity, speed, and stability during training. Detailed explanations of their internals are beyond the scope of this notebook; readers seeking more background information about these transformations should consult the Pyro documentation and source code for `ConditionalAffineAutoRegressive`, `ConditionalAutoRegressiveNN` and related functionality." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class ConditionalImageTransform(dist.conditional.ConditionalComposeTransformModule):\n", + " def __init__(\n", + " self,\n", + " im_size: int,\n", + " input_channels: int,\n", + " thickness_size: int,\n", + " intensity_size: int,\n", + " num_cond_blocks: int,\n", + " *,\n", + " nonlinearity = torch.nn.ReLU(),\n", + " ):\n", + " self.im_size = im_size\n", + " self.input_channels = input_channels\n", + " self.num_cond_blocks = num_cond_blocks\n", + " self.flat_input_size = input_channels * im_size * im_size\n", + " \n", + " layers = []\n", " layers += [dist.transforms.ReshapeTransform((input_channels, im_size, im_size), (self.flat_input_size,))]\n", " for i in range(self.num_cond_blocks):\n", " layers += [\n", @@ -653,18 +637,77 @@ " ),\n", " ] \n", " layers += [dist.transforms.ReshapeTransform((self.flat_input_size,), (input_channels, im_size, im_size))]\n", + " super().__init__(layers)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Having defined `UnconditionalImageTransform` and `ConditionalImageTransform`, we can compose them into the full conditionally invertible transformation we will be using in our causal model:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class ImageTransform(dist.conditional.ConditionalComposeTransformModule):\n", + " def __init__(\n", + " self,\n", + " im_size: int,\n", + " input_channels: int,\n", + " thickness_size: int,\n", + " intensity_size: int,\n", + " num_blocks: int,\n", + " layers_per_block: int,\n", + " hidden_channels: int,\n", + " *,\n", + " num_cond_blocks: int = 1,\n", + " alpha: float = 1e-5,\n", + " ln_momentum: float = 1e-5,\n", + " nonlinearity = torch.nn.ReLU(),\n", + " ):\n", + " self.im_size = im_size\n", + " self.input_channels = input_channels\n", + " self.hidden_channels = hidden_channels\n", + " self.num_blocks = num_blocks\n", + " self.layers_per_block = layers_per_block\n", + " self.num_cond_blocks = num_cond_blocks\n", + " self.flat_input_size = input_channels * im_size * im_size\n", " \n", - " super().__init__(layers)\n", + " layers = []\n", + "\n", + " # unconditional image flow: dequantization followed by convolutional blocks\n", + " layers += [UnconditionalImageTransform(\n", + " im_size=im_size,\n", + " input_channels=input_channels,\n", + " thickness_size=thickness_size,\n", + " intensity_size=intensity_size,\n", + " num_blocks=num_blocks,\n", + " layers_per_block=layers_per_block,\n", + " hidden_channels=hidden_channels,\n", + " alpha=alpha,\n", + " ln_momentum=ln_momentum,\n", + " )]\n", " \n", - " @staticmethod\n", - " def create_checkerboard_mask(h: int, w: int, invert=False):\n", - " x, y = torch.arange(h, dtype=torch.int32), torch.arange(w, dtype=torch.int32)\n", - " xx, yy = torch.meshgrid(x, y, indexing='ij')\n", - " mask = torch.fmod(xx + yy, 2).to(torch.float32).view(1, 1, h, w)\n", - " return mask if not invert else (1. - mask)" + " # conditioning on context with conditional autoregressive flows\n", + " layers += [ConditionalImageTransform(\n", + " im_size=im_size,\n", + " input_channels=input_channels,\n", + " thickness_size=thickness_size,\n", + " intensity_size=intensity_size,\n", + " num_cond_blocks=num_cond_blocks,\n", + " nonlinearity=nonlinearity,\n", + " )]\n", + " \n", + " super().__init__(layers)" ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -838,38 +881,12 @@ " layers_per_block=3,\n", " hidden_channels=16,\n", " nonlinearity=torch.nn.ELU(),\n", - ")\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "This implementation is significantly more involved than previous examples, as we're now defining causal models with neural networks. However, after defining the neural-network transformations, the resulting causal relationships between variables is remarkably simple. We can see this in the rendering of the causal probabilsitic program below." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ + ")\n", "\n", "model = DeepSCM(thickness_transform, intensity_transform, image_transform)\n", "pyro.render_model(model, render_distributions=True)" ] }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Maximum Likelihood Inference\n", - "\n", - "Next, we can implement a `ConditionedDeepSCM` that wraps our original `DeepSCM` model in a `pyro.plate` and a `pyro.condition` context, representing the fact that we observe a collection of annotated images. " - ] - }, { "cell_type": "code", "execution_count": 7, @@ -965,14 +982,6 @@ "pyro.render_model(conditioned_model, model_args=(thickness[:2], intensity[:2], images[:2]), render_distributions=True)" ] }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Similar to the [CEVAE](cevae.ipynb) tutorial, first we'll update our neural network weights using maximum likelihood inference over the full conditioned dataset, and then later perform our causal inferences by conditioning on a subset of variables. The following code uses a custom implementation of SVI using `pytorch_lightning`. TODO: explain in a sentence why we do this vs. Pyro's original SVI." - ] - }, { "cell_type": "code", "execution_count": 8, @@ -1042,16 +1051,6 @@ " trainer.fit(model=lightning_svi, train_dataloaders=dataloader)" ] }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Informal Predictive Check: Visualizing Samples\n", - "\n", - "Before we move on to counterfactual inference, let's first inspect how well our model and inference represent the generative process over MNIST images. Similar to other examples, we can perform an informal assessment by simply generating samples from our model with learned neural network parameters, inspecting the resulting samples qualitatively." - ] - }, { "cell_type": "code", "execution_count": 9, @@ -1080,23 +1079,8 @@ ], "source": [ "predictive = pyro.infer.Predictive(model, guide=lambda *args: None, num_samples=1000, parallel=True).to(device=torch.device(\"cpu\"))\n", - "samples = predictive()\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "First let's look at the marginal distributions of stroke thickness and light intensity for MNIST images from our dataset and generated from our model." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ + "samples = predictive()\n", + "\n", "fig = plt.figure()\n", "fig.add_subplot(1, 2, 1)\n", "plt.hist(thickness[..., 0], bins=50, alpha=0.5, density=True, label=\"observed\")\n", @@ -1104,28 +1088,12 @@ "plt.title(\"Thickness\")\n", "plt.legend()\n", "\n", - "\n", "fig.add_subplot(1, 2, 2)\n", "plt.hist(intensity[..., 0], bins=50, alpha=0.5, density=True, label=\"observed\")\n", "plt.hist(samples[\"I\"][..., 0].squeeze(), bins=50, alpha=0.5, density=True, label=\"sampled\")\n", "plt.title(\"Intensity\")\n", - "plt.legend()" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Next, we can take a look at a small collection of individual images samples from our model. We can clearly see that our model has learned to generate reasonable looking images for the number 5." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ + "plt.legend()\n", + "\n", "fig = plt.figure()\n", "plt.title(\"Images\")\n", "plt.axis(\"off\")\n", @@ -1265,12 +1233,12 @@ " \n", " def forward(self, x_obs: torch.Tensor, i_act: Optional[torch.Tensor], t_act: Optional[torch.Tensor]):\n", " assert i_act is not None or t_act is not None\n", - " with MultiWorldCounterfactual(first_available_dim=-2), \\\n", + " with MultiWorldCounterfactual(), \\\n", " pyro.plate(\"observations\", size=x_obs.shape[0], dim=-1), \\\n", " do(actions={\"I\": i_act}) if i_act is not None else contextlib.nullcontext(), \\\n", " do(actions={\"T\": t_act}) if t_act is not None else contextlib.nullcontext(), \\\n", " pyro.condition(data={\"X\": x_obs}):\n", - " return self.model()\n", + " return gather(self.model(), IndexSet(I={1}, T={1}), event_dim=3)\n", "\n", "cf_model = CounterfactualDeepSCM(model)\n", "pyro.render_model(cf_model, model_args=(images[:1], intensity[:1], None))" @@ -1281,27 +1249,22 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Results" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now that we have a transformed causal model representing answers to our counterfactual query, we can sample from it to generate samples from our counterfactual distribution. \n", + "Counterfactuals cannot be identified from observational data in general\n", + "without further assumptions: learning parameters $\\theta$ that match\n", + "observed data does not guarantee that the counterfactual distribution\n", + "will match that of the true causal model. However, as discussed in e.g. the\n", + "original paper [Pawlowski et al. (2020)] in the context of modeling MRI\n", + "images, there are a number of valid practical reasons one might wish to\n", + "compute counterfactuals with a model anyway, such as explanation or expert evaluation.\n", "\n", - "**Note:** The process of inverting the normalizing flow networks happens automatically inside the ... \n", + "It just so happens that the data generating process for our images of handwritten digits has enough additional structure (specifically, one-dimensional covariates with monotonic mechanisms) that recovering approximations of the causal mechanisms from observational data may not be impossible in theory.\n", + "In the following pair of plots, we experimentally interrogate our trained model's causal knowledge by visualizing counterfactual images sampled from the model.\n", "\n", - "TODO: Eli, could you say more about this note that I started. I'm actually not totally sure where the exogenous noise terms are computed and then propagated forward. I think it is worth touching on a little bit, as a lot of the computation is obscured behind the transformations in this part of the example." - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "TODO: Eli, can you pick up here. I'm not following the difference between these two image sets, nor the rows vs. columns." + "The leftmost entry of each row is an image from the training dataset, followed by several counterfactual sample images drawn from the model given the original image and an intervention on one of the two covariates, with the other covariate held fixed at its observed value.\n", + "\n", + "The intervened values (intensity in the first plot, thickness in the second) for the samples in each row are monotonically increasing from left to right, starting from their ground truth observed values for the training image in that row. Successive samples in a row are shown alongside their pixelwise differences with their neighbors.\n", + "\n", + "In the first plot below, we can see stroke intensity increasing monotonically from left to right while the stroke thickness remains roughly constant." ] }, { @@ -1350,6 +1313,14 @@ " plt.axis(\"off\")" ] }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In the second plot, we can see the images' stroke thickness increasing monotonically from left to right in each row, while their stroke intensity remains roughly constant." + ] + }, { "cell_type": "code", "execution_count": 15, @@ -1397,7 +1368,19 @@ ] }, { + "attachments": {}, "cell_type": "markdown", + "metadata": {}, + "source": [ + "These qualitative results suggest that our deep generative model has indeed managed to approximately disentangle and recover the causal mechanisms of stroke thickness and intensity." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, "source": [ "## References\n", "\n", @@ -1414,10 +1397,7 @@ "Tavares, Zenna, James Koppel, Xin Zhang, and Armando Solar-Lezama. “A Language for Counterfactual Generative Models.” MIT Technical Report, 2020. http://www.jameskoppel.com/publication/omega/.\n", "\n", "\n" - ], - "metadata": { - "collapsed": false - } + ] } ], "metadata": { @@ -1436,7 +1416,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.15" + "version": "3.10.9" }, "vscode": { "interpreter": { @@ -1445,5 +1425,5 @@ } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 }