Skip to content

Commit

Permalink
Pass bootstrap_kwargs to bootstrap call in get_moments_cov (#473)
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens authored Dec 29, 2023
1 parent 5deb5a2 commit 50c15c3
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/estimagic/estimation/msm_weighting.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ def func(data, **kwargs):
) # xxxx won't be necessary soon!
return out

cov_arr = bootstrap(data=data, outcome=func, outcome_kwargs=moment_kwargs).cov()
cov_arr = bootstrap(
data=data, outcome=func, outcome_kwargs=moment_kwargs, **bootstrap_kwargs
).cov()

if isinstance(cov_arr, pd.DataFrame):
cov_arr = cov_arr.to_numpy() # xxxx won't be necessary soon
Expand Down
21 changes: 21 additions & 0 deletions tests/estimation/test_msm_weighting.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,24 @@ def calc_moments(data, keys):
assert cov.shape == (3, 3)

assert cov[0, 0] > cov[1, 1] > cov[2, 2]


def test_get_moments_cov_passes_bootstrap_kwargs_to_bootstrap():
rng = get_rng(1234)
data = rng.normal(scale=[10, 5, 1], size=(100, 3))
data = pd.DataFrame(data=data)

def calc_moments(data, keys):
means = data.mean()
means.index = keys
return means.to_dict()

moment_kwargs = {"keys": ["a", "b", "c"]}

with pytest.raises(ValueError, match="a must be a positive integer unless no"):
get_moments_cov(
data=data,
calculate_moments=calc_moments,
moment_kwargs=moment_kwargs,
bootstrap_kwargs={"n_draws": -1},
)

0 comments on commit 50c15c3

Please sign in to comment.