diff --git a/pyfixest/did/estimation.py b/pyfixest/did/estimation.py index 0333aa1d..eceabc89 100644 --- a/pyfixest/did/estimation.py +++ b/pyfixest/did/estimation.py @@ -5,6 +5,7 @@ from pyfixest.did.did2s import DID2S, _did2s_estimate, _did2s_vcov from pyfixest.did.lpdid import LPDID from pyfixest.did.twfe import TWFE +from pyfixest.estimation.literals import VcovTypeOptions def event_study( @@ -268,7 +269,7 @@ def lpdid( idname: str, tname: str, gname: str, - vcov: Optional[Union[str, dict[str, str]]] = None, + vcov: Optional[Union[VcovTypeOptions, dict[str, str]]] = None, pre_window: Optional[int] = None, post_window: Optional[int] = None, never_treated: int = 0, diff --git a/pyfixest/did/lpdid.py b/pyfixest/did/lpdid.py index f6d633d1..79954a13 100644 --- a/pyfixest/did/lpdid.py +++ b/pyfixest/did/lpdid.py @@ -6,6 +6,7 @@ from pyfixest.did.did import DID from pyfixest.estimation.estimation import feols from pyfixest.estimation.feols_ import Feols +from pyfixest.estimation.literals import VcovTypeOptions from pyfixest.report.visualize import _coefplot @@ -51,7 +52,7 @@ def __init__( xfml: str, att: bool, cluster: str, - vcov: Optional[Union[str, dict[str, str]]] = None, + vcov: Optional[Union[VcovTypeOptions, dict[str, str]]] = None, pre_window: Optional[int] = None, post_window: Optional[int] = None, never_treated: Optional[int] = 0, @@ -201,7 +202,7 @@ def _lpdid_estimate( pre_window: int, post_window: int, att: bool = True, - vcov: Optional[Union[str, dict[str, str]]] = None, + vcov: Optional[Union[VcovTypeOptions, dict[str, str]]] = None, xfml: Optional[str] = None, ) -> pd.DataFrame: """ @@ -219,7 +220,7 @@ def _lpdid_estimate( Variable name for calendar period. gname: str unit-specific time of initial treatment. - vcov: str + vcov: VcovTypeOptions, dict[str, str], None The name of the cluster variable. If None, then defaults to {"CRV1": idname}. Either "iid", "hetero", or a dictionary, e.g. {"CRV1": idname} or {"CRV3": "idname"}. You can pass anything that is accepted by the vcov diff --git a/pyfixest/estimation/__init__.py b/pyfixest/estimation/__init__.py index 5ef22587..4a8328ae 100644 --- a/pyfixest/estimation/__init__.py +++ b/pyfixest/estimation/__init__.py @@ -1,3 +1,4 @@ +from pyfixest.estimation import literals from pyfixest.estimation.demean_ import ( demean, ) @@ -40,4 +41,5 @@ "Fepois", "Feiv", "FixestMulti", + "literals", ] diff --git a/pyfixest/estimation/estimation.py b/pyfixest/estimation/estimation.py index af8876d7..0afbb5c8 100644 --- a/pyfixest/estimation/estimation.py +++ b/pyfixest/estimation/estimation.py @@ -6,6 +6,12 @@ from pyfixest.estimation.feols_ import Feols from pyfixest.estimation.fepois_ import Fepois from pyfixest.estimation.FixestMulti_ import FixestMulti +from pyfixest.estimation.literals import ( + FixedRmOptions, + SolverOptions, + VcovTypeOptions, + WeightsTypeOptions, +) from pyfixest.utils.dev_utils import DataFrameType from pyfixest.utils.utils import ssc @@ -13,10 +19,10 @@ def feols( fml: str, data: DataFrameType, # type: ignore - vcov: Optional[Union[str, dict[str, str]]] = None, + vcov: Optional[Union[VcovTypeOptions, dict[str, str]]] = None, weights: Union[None, str] = None, ssc: dict[str, Union[str, bool]] = ssc(), - fixef_rm: str = "none", + fixef_rm: FixedRmOptions = "none", fixef_tol=1e-08, collin_tol: float = 1e-10, drop_intercept: bool = False, @@ -24,8 +30,8 @@ def feols( copy_data: bool = True, store_data: bool = True, lean: bool = False, - weights_type: str = "aweights", - solver: str = "np.linalg.solve", + weights_type: WeightsTypeOptions = "aweights", + solver: SolverOptions = "np.linalg.solve", use_compression: bool = False, reps: int = 100, seed: Optional[int] = None, @@ -47,7 +53,7 @@ def feols( data : DataFrameType A pandas or polars dataframe containing the variables in the formula. - vcov : Union[str, dict[str, str]] + vcov : Union[VcovTypeOptions, dict[str, str]] Type of variance-covariance matrix for inference. Options include "iid", "hetero", "HC1", "HC2", "HC3", or a dictionary for CRV1/CRV3 inference. @@ -59,7 +65,7 @@ def feols( ssc : str A ssc object specifying the small sample correction for inference. - fixef_rm : str + fixef_rm : FixedRmOptions Specifies whether to drop singleton fixed effects. Options: "none" (default), "singleton". @@ -101,12 +107,12 @@ def feols( to obtain the appropriate standard-errors at estimation time, since obtaining different SEs won't be possible afterwards. - weights_type: str, optional + weights_type: WeightsTypeOptions, optional Options include `aweights` or `fweights`. `aweights` implement analytic or precision weights, while `fweights` implement frequency weights. For details see this blog post: https://notstatschat.rbind.io/2020/08/04/weights-in-statistics/. - solver : str, optional. + solver : SolverOptions, optional. The solver to use for the regression. Can be either "np.linalg.solve" or "np.linalg.lstsq". Defaults to "np.linalg.solve". @@ -412,15 +418,15 @@ class for multiple models specified via `fml`. def fepois( fml: str, data: DataFrameType, # type: ignore - vcov: Optional[Union[str, dict[str, str]]] = None, + vcov: Optional[Union[VcovTypeOptions, dict[str, str]]] = None, ssc: dict[str, Union[str, bool]] = ssc(), - fixef_rm: str = "none", + fixef_rm: FixedRmOptions = "none", fixef_tol: float = 1e-08, iwls_tol: float = 1e-08, iwls_maxiter: int = 25, collin_tol: float = 1e-10, separation_check: Optional[list[str]] = ["fe"], - solver: str = "np.linalg.solve", + solver: SolverOptions = "np.linalg.solve", drop_intercept: bool = False, i_ref1=None, copy_data: bool = True, @@ -449,14 +455,14 @@ def fepois( data : DataFrameType A pandas or polars dataframe containing the variables in the formula. - vcov : Union[str, dict[str, str]] + vcov : Union[VcovTypeOptions, dict[str, str]] Type of variance-covariance matrix for inference. Options include "iid", "hetero", "HC1", "HC2", "HC3", or a dictionary for CRV1/CRV3 inference. ssc : str A ssc object specifying the small sample correction for inference. - fixef_rm : str + fixef_rm : FixedRmOptions Specifies whether to drop singleton fixed effects. Options: "none" (default), "singleton". @@ -476,7 +482,7 @@ def fepois( Methods to identify and drop separated observations. Either "fe" or "ir". Executes "fe" by default. - solver : str, optional. + solver : SolverOptions, optional. The solver to use for the regression. Can be either "np.linalg.solve" or "np.linalg.lstsq". Defaults to "np.linalg.solve". diff --git a/pyfixest/estimation/feols_.py b/pyfixest/estimation/feols_.py index de2fc4aa..17907d40 100644 --- a/pyfixest/estimation/feols_.py +++ b/pyfixest/estimation/feols_.py @@ -2,7 +2,7 @@ import gc import warnings from importlib import import_module -from typing import Literal, Optional, Union, get_args +from typing import Optional, Union import numba as nb import numpy as np @@ -15,6 +15,7 @@ from pyfixest.errors import VcovTypeNotSupportedError from pyfixest.estimation.demean_ import demean_model from pyfixest.estimation.FormulaParser import FixestFormula +from pyfixest.estimation.literals import PredictionType, _validate_literal_argument from pyfixest.estimation.model_matrix_fixest_ import model_matrix_fixest from pyfixest.estimation.ritest import ( _decode_resampvar, @@ -40,8 +41,6 @@ ) from pyfixest.utils.utils import get_ssc, simultaneous_crit_val -prediction_type = Literal["response", "link"] - class Feols: """ @@ -1557,7 +1556,7 @@ def predict( newdata: Optional[DataFrameType] = None, atol: float = 1e-6, btol: float = 1e-6, - type: prediction_type = "link", + type: PredictionType = "link", ) -> np.ndarray: """ Predict values of the model on new data. @@ -1597,11 +1596,8 @@ def predict( raise NotImplementedError( "The predict() method is currently not supported for IV models." ) - valid_types = get_args(prediction_type) - if type not in valid_types: - raise ValueError( - f"Invalid prediction type. Expecting one of {valid_types}. Got {type}" - ) + + _validate_literal_argument(type, PredictionType) if newdata is None: if type == "link" or self._method == "feols": diff --git a/pyfixest/estimation/feols_compressed_.py b/pyfixest/estimation/feols_compressed_.py index 88f2b805..a4bac1ef 100644 --- a/pyfixest/estimation/feols_compressed_.py +++ b/pyfixest/estimation/feols_compressed_.py @@ -6,7 +6,7 @@ import polars as pl from tqdm import tqdm -from pyfixest.estimation.feols_ import Feols, prediction_type +from pyfixest.estimation.feols_ import Feols, PredictionType from pyfixest.estimation.FormulaParser import FixestFormula from pyfixest.utils.dev_utils import DataFrameType @@ -331,7 +331,7 @@ def predict( newdata: Optional[DataFrameType] = None, atol: float = 1e-6, btol: float = 1e-6, - type: prediction_type = "link", + type: PredictionType = "link", ) -> np.ndarray: """ Compute predicted values. diff --git a/pyfixest/estimation/fepois_.py b/pyfixest/estimation/fepois_.py index 393fcd1e..0ea65797 100644 --- a/pyfixest/estimation/fepois_.py +++ b/pyfixest/estimation/fepois_.py @@ -10,7 +10,7 @@ NotImplementedError, ) from pyfixest.estimation.demean_ import demean -from pyfixest.estimation.feols_ import Feols, prediction_type +from pyfixest.estimation.feols_ import Feols, PredictionType from pyfixest.estimation.FormulaParser import FixestFormula from pyfixest.utils.dev_utils import DataFrameType, _to_integer @@ -350,7 +350,7 @@ def predict( newdata: Optional[DataFrameType] = None, atol: float = 1e-6, btol: float = 1e-6, - type: prediction_type = "link", + type: PredictionType = "link", ) -> np.ndarray: """ Return predicted values from regression model. diff --git a/pyfixest/estimation/literals.py b/pyfixest/estimation/literals.py new file mode 100644 index 00000000..eec5d436 --- /dev/null +++ b/pyfixest/estimation/literals.py @@ -0,0 +1,40 @@ +from typing import Any, Literal, get_args + +PredictionType = Literal["response", "link"] +VcovTypeOptions = Literal["iid", "hetero", "HC1", "HC2", "HC3"] +WeightsTypeOptions = Literal["aweights", "fweights"] +FixedRmOptions = Literal["singleton", "none"] +SolverOptions = Literal["np.linalg.solve", "np.linalg.lstsq"] + + +def _validate_literal_argument(arg: Any, literal: Any) -> None: + """ + Validate if the given argument matches one of the allowed literal types. + + This function checks whether the provided `arg` is among the valid types + returned by `get_args(literal)`. If not, it raises a ValueError with an + appropriate error message. + + Parameters + ---------- + arg : Any + The argument to validate. + literal : Any + A Literal type that defines the allowed values for `arg`. + + Raises + ------ + TypeError + If `literal` does not have valid types. + ValueError + If `arg` is not one of the valid types defined by `literal`. + """ + valid_types = get_args(literal) + + if len(valid_types) < 1: + raise TypeError( + f"{literal} must be a Literal[...] type argument with least one type" + ) + + if arg not in valid_types: + raise ValueError(f"Invalid argument. Expecting one of {valid_types}. Got {arg}")