Skip to content

Commit

Permalink
adding raking
Browse files Browse the repository at this point in the history
  • Loading branch information
MamadouSDiallo committed May 5, 2024
1 parent 05947d5 commit 058d85c
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 13 deletions.
34 changes: 24 additions & 10 deletions src/samplics/weighting/adjustment.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ def poststratify(
raise AssertionError("control or factor must be specified.")

if isinstance(control, dict):
# breakpoint()
if (np.unique(domain) != np.unique(list(control.keys()))).any():
raise ValueError("control dictionary keys do not match domain values.")

Expand Down Expand Up @@ -396,6 +397,7 @@ def rake(
ll_bound: Optional[Union[DictStrNum, Number]] = None,
up_bound: Optional[Union[DictStrNum, Number]] = None,
tol: float = 1e-4,
ctrl_tol: float = 1e-4,
max_iter: int = 100,
display_iter: bool = False,
) -> np.ndarray:
Expand All @@ -413,49 +415,58 @@ def rake(
print(f"\nIteration {iter + 1}")

if iter == 0:
rk_wgt = samp_weight
wgt_prev = samp_weight

for margin in margins:
domain = formats.numpy_array(margins[margin])
if control is not None:
wgt = self.poststratify(samp_weight=wgt_prev, control=control[margin], domain=domain)
rk_wgt = self.poststratify(samp_weight=rk_wgt, control=control[margin], domain=domain)
elif factor is not None:
wgt = self.poststratify(samp_weight=wgt_prev, factor=factor[margin], domain=domain)
rk_wgt = self.poststratify(samp_weight=rk_wgt, factor=factor[margin], domain=domain)
else:
raise AssertionError("control or factor must be specified!")

wgt_prev = wgt

sum_wgt = {}
sum_prev_wgt = {}
for margin in margins:
domain = formats.numpy_array(margins[margin])
sum_wgt_domain = {}
sum_prev_wgt_domain = {}
for d in control[margin]:
sum_wgt_domain[d] = np.sum(wgt[domain == d])
sum_wgt_domain[d] = np.sum(rk_wgt[domain == d])
sum_prev_wgt_domain[d] = np.sum(wgt_prev[domain == d])
sum_wgt[margin] = sum_wgt_domain
sum_prev_wgt[margin] = sum_prev_wgt_domain

# diff = {}
max_diff = 0
max_ctrl_diff = 0
for margin in margins:
if display_iter:
print(f" Margin: {margin}")

diff_margin = {}
diff_ctrl_margin = {}
for d in control[margin]:
diff_margin[d] = np.abs(control[margin][d] - sum_wgt[margin][d])
diff_margin[d] = np.abs(sum_wgt[margin][d] - sum_prev_wgt[margin][d]) / sum_prev_wgt[margin][d]
diff_ctrl_margin[d] = np.abs(sum_wgt[margin][d] - control[margin][d]) / control[margin][d]
if display_iter:
print(f" Difference for '{d}': {diff_margin[d]}")
print(f" Difference against previous value for '{d}': {diff_margin[d]}")
print(f" Difference against control value for '{d}': {diff_ctrl_margin[d]}")

# diff[margin] = diff_margin
max_diff = max(max_diff, max(diff_margin.values()))
max_ctrl_diff = max(max_ctrl_diff, max(diff_ctrl_margin.values()))

obs_tol = max_diff
obs_ctrl_tol = max_ctrl_diff

if obs_tol <= tol:
if obs_tol <= tol and obs_ctrl_tol <= ctrl_tol:
converged = True

if ll_bound is not None or up_bound is not None:
wgt_ratios = wgt / samp_weight
wgt_ratios = rk_wgt / samp_weight
min_ratio = np.min(wgt_ratios)
max_ratio = np.max(wgt_ratios)

Expand All @@ -474,9 +485,12 @@ def rake(
else:
bounded = True

wgt_prev = rk_wgt
iter += 1

return wgt
self.adj_method = "raking"

return rk_wgt

@staticmethod
def _calib_covariates(
Expand Down
43 changes: 40 additions & 3 deletions tests/weighting/test_adjustment.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,40 @@
from samplics.weighting import SampleWeight


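# Stata example (commented out; reads a local nhis_sam.csv and compares rake() output with the file's existing rake_wt column)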

# nhis_sam = pl.read_csv("~/Downloads/nhis_sam.csv").with_columns(
# pl.when(pl.col("hisp") == 4).then(pl.lit(3)).otherwise(pl.col("hisp")).alias("hisp")
# )

# age_grp = {
# "<18": 5991,
# "18-24": 2014,
# "25-44": 6124,
# "45-64": 5011,
# "65+": 2448,
# }
# hisp_race = {1: 5031, 2: 12637, 3: 3920}
# control = {"age_grp": age_grp, "hisp": hisp_race}

# # breakpoint()

# ll = 0.8
# ul = 1.2

# margins = {
# "age_grp": nhis_sam["age_grp"].to_list(),
# "hisp": nhis_sam["hisp"].to_list(),
# }

# nhis_sam_rk = SampleWeight()

# nhis_sam = nhis_sam.with_columns(
# rake_wt_2=nhis_sam_rk.rake(
# samp_weight=nhis_sam["wt"], control=control, margins=margins, display_iter=True, tol=1e-6
# )
# ).with_columns(diff=pl.col("rake_wt_2") - pl.col("rake_wt"))

# synthetic data for testing

wgt = np.random.uniform(0, 1, 1000)
Expand Down Expand Up @@ -275,9 +309,12 @@ def test_ps_wgt_with_class():

sample_wgt_rk_not_bound = SampleWeight()

rk_wgt_not_bound = sample_wgt_rk_not_bound.rake(
samp_weight=income_sample2["design_wgt"], control=control, margins=margins, display_iter=True
)
# rk_wgt_not_bound = sample_wgt_rk_not_bound.rake(
# samp_weight=income_sample2["design_wgt"], control=control, margins=margins, display_iter=True, tol=1e-4
# )



# breakpoint()

# age_grp = {"<18": 21588, age}

0 comments on commit 058d85c

Please sign in to comment.