diff --git a/eland/ml/transformers/xgboost.py b/eland/ml/transformers/xgboost.py index 9eb1d189..67a71b7d 100644 --- a/eland/ml/transformers/xgboost.py +++ b/eland/ml/transformers/xgboost.py @@ -27,7 +27,7 @@ import_optional_dependency("xgboost", on_version="warn") -from xgboost import Booster, XGBClassifier, XGBRegressor # type: ignore +from xgboost import Booster, XGBClassifier, XGBModel, XGBRegressor # type: ignore class XGBoostForestTransformer(ModelTransformer): @@ -125,8 +125,6 @@ def build_forest(self) -> List[Tree]: :return: A list of Tree objects """ - self.check_model_booster() - tree_table: pd.DataFrame = self._model.trees_to_dataframe() transformed_trees = [] curr_tree: Optional[Any] = None @@ -155,17 +153,21 @@ def determine_target_type(self) -> str: def is_objective_supported(self) -> bool: return False - def check_model_booster(self) -> None: + @staticmethod + def check_model_booster(model: XGBModel) -> None: # xgboost v1 made booster default to 'None' meaning 'gbtree' - if self._model.booster not in {"dart", "gbtree", None}: + booster = ( + model.get_booster().booster + if hasattr(model.get_booster(), "booster") + else model.booster + ) + if booster not in {"dart", "gbtree", None}: raise ValueError( f"booster must exist and be of type 'dart' or " - f"'gbtree', was {self._model.booster!r}" + f"'gbtree', was {booster!r}" ) def transform(self) -> Ensemble: - self.check_model_booster() - if not self.is_objective_supported(): raise ValueError(f"Unsupported objective '{self._objective}'") @@ -189,6 +191,7 @@ def __init__(self, model: XGBRegressor, feature_names: List[str]): super().__init__( model.get_booster(), feature_names, base_score, model.objective ) + XGBoostForestTransformer.check_model_booster(model) def determine_target_type(self) -> str: return "regression" @@ -226,6 +229,7 @@ def __init__( model.objective, classification_labels, ) + XGBoostForestTransformer.check_model_booster(model) if model.classes_ is None: n_estimators = model.get_params()["n_estimators"] num_trees = model.get_booster().trees_to_dataframe()["Tree"].max() + 1