Skip to content

Commit

Permalink
some cleanup on the testing files
Browse files Browse the repository at this point in the history
  • Loading branch information
MamadouSDiallo committed Jan 4, 2024
1 parent 8b4332d commit b6b398e
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 31 deletions.
3 changes: 0 additions & 3 deletions src/samplics/apis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1 @@
from samplics.apis.sae.area_eblup import fit_eblup, predict_eblup


__all__ = ["fit_eblup", "predict_eblup"]
2 changes: 1 addition & 1 deletion src/samplics/apis/sae/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from samplics.apis.sae.area_eblup import _log_likelihood, fit_eblup, predict_eblup
from samplics.apis.sae.area_eblup import _log_likelihood
from samplics.types.errors import (
CertaintyError,
DimensionError,
Expand Down
28 changes: 6 additions & 22 deletions src/samplics/apis/sae/area_eblup.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@


# Fitting a EBLUP model
def fit_eblup(
def _fit_eblup(
y: DirectEst,
x: AuxVars,
method: FitMethod,
Expand Down Expand Up @@ -317,9 +317,7 @@ def _fixed_coefficients(
) -> tuple[np.ndarray, np.ndarray]:
y_vec = y.to_numpy(keep_vars="est").flatten()
if intercept:
x_mat = np.insert(
x.to_numpy(drop_vars=["__record_id", "__domain"]), 0, 1, axis=1
) # add the intercept
x_mat = np.insert(x.to_numpy(drop_vars=["__record_id", "__domain"]), 0, 1, axis=1) # add the intercept
else:
x_mat = x.to_numpy(drop_vars=["__record_id", "__domain"])

Expand Down Expand Up @@ -434,7 +432,6 @@ def _log_likelihood(
intercept=intercept,
)
case FitMethod.reml:
# breakpoint()
loglike = _log_likelihood_reml(
method=method,
y=y,
Expand All @@ -450,7 +447,7 @@ def _log_likelihood(
return loglike


def predict_eblup(
def _predict_eblup(
x: AuxVars,
fit_eblup: GlmmFitStats,
y: DirectEst,
Expand Down Expand Up @@ -480,9 +477,7 @@ def predict_eblup(
b_const=b_const,
)

return EblupEst(
pred=est, fit_stats=fit_eblup, domain=None, mse=mse, mse_boot=None, mse_jkn=None
)
return EblupEst(pred=est, fit_stats=fit_eblup, domain=None, mse=mse, mse_boot=None, mse_jkn=None)


def _eblup_estimates(
Expand All @@ -496,16 +491,7 @@ def _eblup_estimates(
sigma2_v_cov: Number,
intercept: bool,
b_const: dict,
) -> tuple[
np.ndarray,
np.ndarray,
np.ndarray,
np.ndarray,
np.ndarray,
np.ndarray,
np.ndarray,
np.ndarray,
]:
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray,]:
m = len(yhat)
b_const_vec = np.array(list(b_const.values()))
v_i = np.array(list(sigma2_e.values())) + sigma2_v * (b_const_vec**2)
Expand All @@ -515,9 +501,7 @@ def _eblup_estimates(
b = (G @ np.transpose(Z)) @ v_inv

if intercept:
x = np.insert(
auxvars.to_numpy(drop_vars=["__record_id", "__domain"]), 0, 1, axis=1
) # add the intercept
x = np.insert(auxvars.to_numpy(drop_vars=["__record_id", "__domain"]), 0, 1, axis=1) # add the intercept
else:
x = auxvars.to_numpy(drop_vars=["__record_id", "__domain"])

Expand Down
2 changes: 1 addition & 1 deletion src/samplics/types/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __init__(
__domains = None
if domain is not None:
__domain = numpy_array(domain).tolist()
auxdata_dict = x.insert_at_idx(0, pl.Series(__domain).alias("__domain")).partition_by(
auxdata_dict = x.insert_column(0, pl.Series(__domain).alias("__domain")).partition_by(
"__domain", as_dict=True
)

Expand Down
8 changes: 4 additions & 4 deletions tests/apis/sae/test_area_eblup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import polars as pl
import pytest

from samplics.apis.sae import fit_eblup, predict_eblup
from samplics.apis.sae import _fit_eblup, _predict_eblup

# from samplics.apis.sae import _log_likelihood, fit_eblup, predict_eblup
from samplics.types import AuxVars, DirectEst, FitMethod, Mse
Expand Down Expand Up @@ -49,13 +49,13 @@
auxvars = AuxVars(x=x, domain=area)

# Fit the linear mixed model
fit_ml = fit_eblup(y=yhat, x=auxvars, method=FitMethod.ml)
fit_reml = fit_eblup(y=yhat, x=auxvars, method=FitMethod.reml)
fit_ml = _fit_eblup(y=yhat, x=auxvars, method=FitMethod.ml)
fit_reml = _fit_eblup(y=yhat, x=auxvars, method=FitMethod.reml)
# fit_fh = fit_eblup(y=yhat, x=auxvars, method=FitMethod.fh)
# breakpoint()

# Predict the small area estimates
est_milk_reml = predict_eblup(x=auxvars, fit_eblup=fit_reml, y=yhat, mse=Mse.taylor)
est_milk_reml = _predict_eblup(x=auxvars, fit_eblup=fit_reml, y=yhat, mse=Mse.taylor)

# est_milk_reml.fit_stats.log_llike

Expand Down

0 comments on commit b6b398e

Please sign in to comment.