Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding literals to feols and fepois api's #680

Merged
merged 11 commits into from
Nov 2, 2024
3 changes: 2 additions & 1 deletion pyfixest/did/estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pyfixest.did.did2s import DID2S, _did2s_estimate, _did2s_vcov
from pyfixest.did.lpdid import LPDID
from pyfixest.did.twfe import TWFE
from pyfixest.estimation.literals import VcovTypeOptions


def event_study(
Expand Down Expand Up @@ -268,7 +269,7 @@ def lpdid(
idname: str,
tname: str,
gname: str,
vcov: Optional[Union[str, dict[str, str]]] = None,
vcov: Optional[Union[VcovTypeOptions, dict[str, str]]] = None,
pre_window: Optional[int] = None,
post_window: Optional[int] = None,
never_treated: int = 0,
Expand Down
7 changes: 4 additions & 3 deletions pyfixest/did/lpdid.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pyfixest.did.did import DID
from pyfixest.estimation.estimation import feols
from pyfixest.estimation.feols_ import Feols
from pyfixest.estimation.literals import VcovTypeOptions
from pyfixest.report.visualize import _coefplot


Expand Down Expand Up @@ -51,7 +52,7 @@ def __init__(
xfml: str,
att: bool,
cluster: str,
vcov: Optional[Union[str, dict[str, str]]] = None,
vcov: Optional[Union[VcovTypeOptions, dict[str, str]]] = None,
pre_window: Optional[int] = None,
post_window: Optional[int] = None,
never_treated: Optional[int] = 0,
Expand Down Expand Up @@ -201,7 +202,7 @@ def _lpdid_estimate(
pre_window: int,
post_window: int,
att: bool = True,
vcov: Optional[Union[str, dict[str, str]]] = None,
vcov: Optional[Union[VcovTypeOptions, dict[str, str]]] = None,
xfml: Optional[str] = None,
) -> pd.DataFrame:
"""
Expand All @@ -219,7 +220,7 @@ def _lpdid_estimate(
Variable name for calendar period.
gname: str
unit-specific time of initial treatment.
vcov: str
vcov: VcovTypeOptions, dict[str, str], None
The name of the cluster variable. If None, then defaults to {"CRV1": idname}.
Either "iid", "hetero", or a dictionary, e.g. {"CRV1": idname} or
{"CRV3": "idname"}. You can pass anything that is accepted by the vcov
Expand Down
2 changes: 2 additions & 0 deletions pyfixest/estimation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from pyfixest.estimation import literals
from pyfixest.estimation.demean_ import (
demean,
)
Expand Down Expand Up @@ -40,4 +41,5 @@
"Fepois",
"Feiv",
"FixestMulti",
"literals",
]
34 changes: 20 additions & 14 deletions pyfixest/estimation/estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,32 @@
from pyfixest.estimation.feols_ import Feols
from pyfixest.estimation.fepois_ import Fepois
from pyfixest.estimation.FixestMulti_ import FixestMulti
from pyfixest.estimation.literals import (
FixedRmOptions,
SolverOptions,
VcovTypeOptions,
WeightsTypeOptions,
)
from pyfixest.utils.dev_utils import DataFrameType
from pyfixest.utils.utils import ssc


def feols(
fml: str,
data: DataFrameType, # type: ignore
vcov: Optional[Union[str, dict[str, str]]] = None,
vcov: Optional[Union[VcovTypeOptions, dict[str, str]]] = None,
weights: Union[None, str] = None,
ssc: dict[str, Union[str, bool]] = ssc(),
fixef_rm: str = "none",
fixef_rm: FixedRmOptions = "none",
fixef_tol=1e-08,
collin_tol: float = 1e-10,
drop_intercept: bool = False,
i_ref1=None,
copy_data: bool = True,
store_data: bool = True,
lean: bool = False,
weights_type: str = "aweights",
solver: str = "np.linalg.solve",
weights_type: WeightsTypeOptions = "aweights",
solver: SolverOptions = "np.linalg.solve",
use_compression: bool = False,
reps: int = 100,
seed: Optional[int] = None,
Expand All @@ -47,7 +53,7 @@ def feols(
data : DataFrameType
A pandas or polars dataframe containing the variables in the formula.

vcov : Union[str, dict[str, str]]
vcov : Union[VcovTypeOptions, dict[str, str]]
Type of variance-covariance matrix for inference. Options include "iid",
"hetero", "HC1", "HC2", "HC3", or a dictionary for CRV1/CRV3 inference.

Expand All @@ -59,7 +65,7 @@ def feols(
ssc : str
A ssc object specifying the small sample correction for inference.

fixef_rm : str
fixef_rm : FixedRmOptions
Specifies whether to drop singleton fixed effects.
Options: "none" (default), "singleton".

Expand Down Expand Up @@ -101,12 +107,12 @@ def feols(
to obtain the appropriate standard-errors at estimation time,
since obtaining different SEs won't be possible afterwards.

weights_type: str, optional
weights_type: WeightsTypeOptions, optional
Options include `aweights` or `fweights`. `aweights` implement analytic or
precision weights, while `fweights` implement frequency weights. For details
see this blog post: https://notstatschat.rbind.io/2020/08/04/weights-in-statistics/.

solver : str, optional.
solver : SolverOptions, optional.
The solver to use for the regression. Can be either "np.linalg.solve" or
"np.linalg.lstsq". Defaults to "np.linalg.solve".

Expand Down Expand Up @@ -412,15 +418,15 @@ class for multiple models specified via `fml`.
def fepois(
fml: str,
data: DataFrameType, # type: ignore
vcov: Optional[Union[str, dict[str, str]]] = None,
vcov: Optional[Union[VcovTypeOptions, dict[str, str]]] = None,
ssc: dict[str, Union[str, bool]] = ssc(),
fixef_rm: str = "none",
fixef_rm: FixedRmOptions = "none",
fixef_tol: float = 1e-08,
iwls_tol: float = 1e-08,
iwls_maxiter: int = 25,
collin_tol: float = 1e-10,
separation_check: Optional[list[str]] = ["fe"],
solver: str = "np.linalg.solve",
solver: SolverOptions = "np.linalg.solve",
drop_intercept: bool = False,
i_ref1=None,
copy_data: bool = True,
Expand Down Expand Up @@ -449,14 +455,14 @@ def fepois(
data : DataFrameType
A pandas or polars dataframe containing the variables in the formula.

vcov : Union[str, dict[str, str]]
vcov : Union[VcovTypeOptions, dict[str, str]]
Type of variance-covariance matrix for inference. Options include "iid",
"hetero", "HC1", "HC2", "HC3", or a dictionary for CRV1/CRV3 inference.

ssc : str
A ssc object specifying the small sample correction for inference.

fixef_rm : str
fixef_rm : FixedRmOptions
Specifies whether to drop singleton fixed effects.
Options: "none" (default), "singleton".

Expand All @@ -476,7 +482,7 @@ def fepois(
Methods to identify and drop separated observations.
Either "fe" or "ir". Executes "fe" by default.

solver : str, optional.
solver : SolverOptions, optional.
The solver to use for the regression. Can be either "np.linalg.solve" or
"np.linalg.lstsq". Defaults to "np.linalg.solve".

Expand Down
14 changes: 5 additions & 9 deletions pyfixest/estimation/feols_.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import gc
import warnings
from importlib import import_module
from typing import Literal, Optional, Union, get_args
from typing import Optional, Union

import numba as nb
import numpy as np
Expand All @@ -15,6 +15,7 @@
from pyfixest.errors import VcovTypeNotSupportedError
from pyfixest.estimation.demean_ import demean_model
from pyfixest.estimation.FormulaParser import FixestFormula
from pyfixest.estimation.literals import PredictionType, _validate_literal_argument
from pyfixest.estimation.model_matrix_fixest_ import model_matrix_fixest
from pyfixest.estimation.ritest import (
_decode_resampvar,
Expand All @@ -40,8 +41,6 @@
)
from pyfixest.utils.utils import get_ssc, simultaneous_crit_val

prediction_type = Literal["response", "link"]


class Feols:
"""
Expand Down Expand Up @@ -1557,7 +1556,7 @@ def predict(
newdata: Optional[DataFrameType] = None,
atol: float = 1e-6,
btol: float = 1e-6,
type: prediction_type = "link",
type: PredictionType = "link",
) -> np.ndarray:
"""
Predict values of the model on new data.
Expand Down Expand Up @@ -1597,11 +1596,8 @@ def predict(
raise NotImplementedError(
"The predict() method is currently not supported for IV models."
)
valid_types = get_args(prediction_type)
if type not in valid_types:
raise ValueError(
f"Invalid prediction type. Expecting one of {valid_types}. Got {type}"
)

_validate_literal_argument(type, PredictionType)

if newdata is None:
if type == "link" or self._method == "feols":
Expand Down
4 changes: 2 additions & 2 deletions pyfixest/estimation/feols_compressed_.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import polars as pl
from tqdm import tqdm

from pyfixest.estimation.feols_ import Feols, prediction_type
from pyfixest.estimation.feols_ import Feols, PredictionType
from pyfixest.estimation.FormulaParser import FixestFormula
from pyfixest.utils.dev_utils import DataFrameType

Expand Down Expand Up @@ -331,7 +331,7 @@ def predict(
newdata: Optional[DataFrameType] = None,
atol: float = 1e-6,
btol: float = 1e-6,
type: prediction_type = "link",
type: PredictionType = "link",
) -> np.ndarray:
"""
Compute predicted values.
Expand Down
4 changes: 2 additions & 2 deletions pyfixest/estimation/fepois_.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
NotImplementedError,
)
from pyfixest.estimation.demean_ import demean
from pyfixest.estimation.feols_ import Feols, prediction_type
from pyfixest.estimation.feols_ import Feols, PredictionType
from pyfixest.estimation.FormulaParser import FixestFormula
from pyfixest.utils.dev_utils import DataFrameType, _to_integer

Expand Down Expand Up @@ -350,7 +350,7 @@ def predict(
newdata: Optional[DataFrameType] = None,
atol: float = 1e-6,
btol: float = 1e-6,
type: prediction_type = "link",
type: PredictionType = "link",
) -> np.ndarray:
"""
Return predicted values from regression model.
Expand Down
40 changes: 40 additions & 0 deletions pyfixest/estimation/literals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from typing import Any, Literal, get_args

PredictionType = Literal["response", "link"]
VcovTypeOptions = Literal["iid", "hetero", "HC1", "HC2", "HC3"]
WeightsTypeOptions = Literal["aweights", "fweights"]
FixedRmOptions = Literal["singleton", "none"]
SolverOptions = Literal["np.linalg.solve", "np.linalg.lstsq"]


def _validate_literal_argument(arg: Any, literal: Any) -> None:
"""
Validate if the given argument matches one of the allowed literal types.

This function checks whether the provided `arg` is among the valid types
returned by `get_args(literal)`. If not, it raises a ValueError with an
appropriate error message.

Parameters
----------
arg : Any
The argument to validate.
literal : Any
A Literal type that defines the allowed values for `arg`.

Raises
------
TypeError
If `literal` does not have valid types.
ValueError
If `arg` is not one of the valid types defined by `literal`.
"""
valid_types = get_args(literal)
marcandre259 marked this conversation as resolved.
Show resolved Hide resolved

if len(valid_types) < 1:
raise TypeError(

Check warning on line 35 in pyfixest/estimation/literals.py

View check run for this annotation

Codecov / codecov/patch

pyfixest/estimation/literals.py#L35

Added line #L35 was not covered by tests
f"{literal} must be a Literal[...] type argument with least one type"
)

if arg not in valid_types:
raise ValueError(f"Invalid argument. Expecting one of {valid_types}. Got {arg}")

Check warning on line 40 in pyfixest/estimation/literals.py

View check run for this annotation

Codecov / codecov/patch

pyfixest/estimation/literals.py#L40

Added line #L40 was not covered by tests
Loading