Skip to content

Commit

Permalink
adding raking
Browse files Browse the repository at this point in the history
  • Loading branch information
MamadouSDiallo committed May 5, 2024
1 parent 18c34a4 commit 05947d5
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 23 deletions.
99 changes: 79 additions & 20 deletions src/samplics/weighting/adjustment.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,32 +390,91 @@ def _raked_wgt(
def rake(
self,
samp_weight: Array,
control: DictStrNum,
margins: DictStrNum,
control: Optional[DictStrNum] = None,
factor: Optional[DictStrNum] = None,
ll_bound: Optional[Union[DictStrNum, Number]] = None,
up_bound: Optional[Union[DictStrNum, Number]] = None,
tol: float = 1e-4,
max_iter: int = 100,
display_iter: bool = False,
) -> np.ndarray:

# control_wgts = {}
# results_wgts = {}exit
# for margin in margins:
# levels = (
# pl.DataFrame(margins[margin]).filter(pl.col("column_0").is_not_null())["column_0"].unique().to_numpy()
# ) # NoneType and Object types are problematic for np.unique()
# ctls = np.zeros(levels.shape[0])
# wgts = np.zeros(levels.shape[0])
# for i, level in enumerate(levels):
# ctls[i] = control[margin][level]
# wgts[i] = np.add.reduce(samp_weight, where=np.array(margins[margin]) == level)

# control_wgts[margin] = np.column_stack((levels, ctls))
# results_wgts[margin] = np.column_stack((levels, wgts))

wgt_prev = samp_weight
for margin in margins:
wgt = self.normalize(samp_weight=wgt_prev, control=control[margin], domain=margins[margin])
wgt_prev = wgt
samp_weight = formats.numpy_array(samp_weight)

obs_tol = tol + 1
iter = 0

converged = False
bounded = True

while not (converged and bounded) and iter < max_iter:
if display_iter:
print(f"\nIteration {iter + 1}")

if iter == 0:
wgt_prev = samp_weight

for margin in margins:
domain = formats.numpy_array(margins[margin])
if control is not None:
wgt = self.poststratify(samp_weight=wgt_prev, control=control[margin], domain=domain)
elif factor is not None:
wgt = self.poststratify(samp_weight=wgt_prev, factor=factor[margin], domain=domain)
else:
raise AssertionError("control or factor must be specified!")

wgt_prev = wgt

sum_wgt = {}
for margin in margins:
domain = formats.numpy_array(margins[margin])
sum_wgt_domain = {}
for d in control[margin]:
sum_wgt_domain[d] = np.sum(wgt[domain == d])
sum_wgt[margin] = sum_wgt_domain

# diff = {}
max_diff = 0
for margin in margins:
if display_iter:
print(f" Margin: {margin}")

diff_margin = {}
for d in control[margin]:
diff_margin[d] = np.abs(control[margin][d] - sum_wgt[margin][d])
if display_iter:
print(f" Difference for '{d}': {diff_margin[d]}")

# diff[margin] = diff_margin
max_diff = max(max_diff, max(diff_margin.values()))

obs_tol = max_diff

if obs_tol <= tol:
converged = True

if ll_bound is not None or up_bound is not None:
wgt_ratios = wgt / samp_weight
min_ratio = np.min(wgt_ratios)
max_ratio = np.max(wgt_ratios)

if (
(
ll_bound is not None
and up_bound is not None
and (ll_bound <= min_ratio and max_ratio <= up_bound)
)
or (ll_bound is not None and ll_bound <= min_ratio)
or (up_bound is not None and up_bound >= max_ratio)
):
bounded = True
else:
bounded = False
else:
bounded = True

iter += 1

return wgt

Expand Down
10 changes: 7 additions & 3 deletions tests/weighting/test_adjustment.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,14 +266,18 @@ def test_ps_wgt_with_class():
inc_grp = {"low": 600_000, "middle": 1_400_000, "high": 640_000}
control = {"educ_level": educ_grp, "income_level": inc_grp}

income_sample2 = income_sample.filter(pl.col("educ_level").is_not_null(), pl.col("income_level").is_not_null())

margins = {
"educ_level": income_sample.filter(pl.col("design_wgt").is_not_null())["educ_level"].to_list(),
"income_level": income_sample.filter(pl.col("design_wgt").is_not_null())["income_level"].to_list(),
"educ_level": income_sample2["educ_level"].to_list(),
"income_level": income_sample2["income_level"].to_list(),
}

sample_wgt_rk_not_bound = SampleWeight()

rk_wgt_not_bound = sample_wgt_rk_not_bound.rake(samp_weight=design_wgt, control=control, margins=margins)
rk_wgt_not_bound = sample_wgt_rk_not_bound.rake(
samp_weight=income_sample2["design_wgt"], control=control, margins=margins, display_iter=True
)
# breakpoint()

# age_grp = {"<18": 21588, age}

0 comments on commit 05947d5

Please sign in to comment.