Skip to content

Commit

Permalink
[ML] Fix XGBoost model import for xgboost>=1.6
Browse files Browse the repository at this point in the history
  • Loading branch information
benwtrent authored Apr 20, 2022
1 parent cb839a9 commit 8294224
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions eland/ml/transformers/xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

import_optional_dependency("xgboost", on_version="warn")

from xgboost import Booster, XGBClassifier, XGBRegressor # type: ignore
from xgboost import Booster, XGBClassifier, XGBModel, XGBRegressor # type: ignore


class XGBoostForestTransformer(ModelTransformer):
Expand Down Expand Up @@ -125,8 +125,6 @@ def build_forest(self) -> List[Tree]:
:return: A list of Tree objects
"""
self.check_model_booster()

tree_table: pd.DataFrame = self._model.trees_to_dataframe()
transformed_trees = []
curr_tree: Optional[Any] = None
Expand Down Expand Up @@ -155,17 +153,21 @@ def determine_target_type(self) -> str:
def is_objective_supported(self) -> bool:
return False

def check_model_booster(self) -> None:
@staticmethod
def check_model_booster(model: XGBModel) -> None:
# xgboost v1 made booster default to 'None' meaning 'gbtree'
if self._model.booster not in {"dart", "gbtree", None}:
booster = (
model.get_booster().booster
if hasattr(model.get_booster(), "booster")
else model.booster
)
if booster not in {"dart", "gbtree", None}:
raise ValueError(
f"booster must exist and be of type 'dart' or "
f"'gbtree', was {self._model.booster!r}"
f"'gbtree', was {booster!r}"
)

def transform(self) -> Ensemble:
self.check_model_booster()

if not self.is_objective_supported():
raise ValueError(f"Unsupported objective '{self._objective}'")

Expand All @@ -189,6 +191,7 @@ def __init__(self, model: XGBRegressor, feature_names: List[str]):
super().__init__(
model.get_booster(), feature_names, base_score, model.objective
)
XGBoostForestTransformer.check_model_booster(model)

def determine_target_type(self) -> str:
return "regression"
Expand Down Expand Up @@ -226,6 +229,7 @@ def __init__(
model.objective,
classification_labels,
)
XGBoostForestTransformer.check_model_booster(model)
if model.classes_ is None:
n_estimators = model.get_params()["n_estimators"]
num_trees = model.get_booster().trees_to_dataframe()["Tree"].max() + 1
Expand Down

0 comments on commit 8294224

Please sign in to comment.