Skip to content

Commit

Permalink
Fix 2sls marginal effect
Browse files Browse the repository at this point in the history
  • Loading branch information
kbattocchi committed Jun 10, 2020
1 parent 223096b commit 50c5613
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 9 deletions.
13 changes: 11 additions & 2 deletions econml/tests/test_two_stage_least_squares.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ def make_random(d):
sz = (n, d) if d >= 0 else (n,)
return np.random.normal(size=sz)

for d_t in [1, 2]:
for d_t in [-1, 1, 2]:
n_t = d_t if d_t > 0 else 1
for d_y in [1, 2]:
for d_y in [-1, 1, 2]:
for d_x in [1, 5]:
for d_z in [1, 2]:
d_w = 1
Expand All @@ -80,9 +80,18 @@ def make_random(d):
dt_featurizer=DPolynomialFeatures())

est.fit(Y, T, X, W, Z)

eff = est.effect(X)
marg_eff = est.marginal_effect(T, X)

effect_shape = (n,) + ((d_y,) if d_y > 0 else ())
marginal_effect_shape = ((n if d_x else 1,) +
((d_y,) if d_y > 0 else ()) +
((d_t,) if d_t > 0 else()))

self.assertEqual(shape(marg_eff), marginal_effect_shape)
self.assertEqual(shape(eff), effect_shape)

def test_marg_eff(self):
X = np.random.normal(size=(5000, 2))
Z = np.random.normal(size=(5000, 2))
Expand Down
23 changes: 16 additions & 7 deletions econml/two_stage_least_squares.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,11 +235,12 @@ def fit(self, Y, T, X, W, Z, inference=None):
W = np.empty((shape(Y)[0], 0))
assert shape(Y)[0] == shape(T)[0] == shape(X)[0] == shape(W)[0] == shape(Z)[0]

# make T 2D if if was a vector
if ndim(T) == 1:
T = reshape(T, (-1, 1))

# store number of columns of W so that we can create correctly shaped zero array in effect and marginal effect
self._d_w = shape(W)[1]
# store number of columns of T so that we can pass scalars to effect
# TODO: support vector T and Y
self._d_t = shape(T)[1]

# two stage approximation
# first, get basis expansions of T, X, and Z
Expand Down Expand Up @@ -285,9 +286,13 @@ def effect(self, X=None, T0=0, T1=1):
"""
if ndim(T0) == 0:
T0 = np.full((1 if X is None else shape(X)[0], self._d_t), T0)
T0 = np.full((1 if X is None else shape(X)[0],) + self._d_t, T0)
if ndim(T1) == 0:
T1 = np.full((1 if X is None else shape(X)[0], self._d_t), T1)
T1 = np.full((1 if X is None else shape(X)[0],) + self._d_t, T1)
if ndim(T0) == 1:
T0 = reshape(T0, (-1, 1))
if ndim(T1) == 1:
T1 = reshape(T1, (-1, 1))
if X is None:
X = np.empty((shape(T0)[0], 0))
assert shape(T0) == shape(T1)
Expand Down Expand Up @@ -329,7 +334,7 @@ def marginal_effect(self, T, X=None):

ft_X = self._x_featurizer.transform(X)
n = shape(T)[0]
dT = self._dt_featurizer.transform(T)
dT = self._dt_featurizer.transform(T if ndim(T) == 2 else reshape(T, (-1, 1)))
W = np.zeros((size(T), self._d_w))
# dT should be an n×dₜ×fₜ array (but if T was a vector, or if there is only one feature,
# dT may be only 2-dimensional)
Expand All @@ -342,4 +347,8 @@ def marginal_effect(self, T, X=None):
features = transpose(features, [0, 1, 3, 2]) # swap last two dims to match cross_product
features = reshape(features, (size(T), -1))
output = self._model_Y.predict(_add_zeros(np.hstack([W, features])))
return reshape(output, shape(T) + (shape(output)[-1],))
output = reshape(output, shape(T) + shape(output)[1:])
if ndim(output) == 3:
return transpose(output, (0, 2, 1)) # transpose trailing T and Y dims
else:
return output

0 comments on commit 50c5613

Please sign in to comment.