Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pyspark] add parameters in the ctor of all estimators. #9202

Merged
merged 13 commits into from
May 28, 2023
20 changes: 14 additions & 6 deletions python-package/xgboost/spark/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,11 +337,9 @@ def _validate_params(self) -> None:

if self.getOrDefault(self.features_cols):
if not self.getOrDefault(self.use_gpu):
raise ValueError("features_cols param requires enabling use_gpu.")

get_logger(self.__class__.__name__).warning(
"If features_cols param set, then features_col param is ignored."
)
raise ValueError(
"features_col param with list value requires enabling use_gpu."
)

if self.getOrDefault("objective") is not None:
if not isinstance(self.getOrDefault("objective"), str):
Expand Down Expand Up @@ -547,6 +545,8 @@ def _get_unwrapped_vec_cols(feature_col: Column) -> List[Column]:


class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
_input_kwargs: Dict[str, Any]

def __init__(self) -> None:
super().__init__()
self._set_xgb_params_default()
Expand Down Expand Up @@ -576,6 +576,11 @@ def setParams(
raise ValueError("Invalid param name: 'arbitrary_params_dict'.")

for k, v in kwargs.items():
# Users are not allowed to set features_cols directly.
if k == self.features_cols.name:
raise ValueError(
f"Unsupported param '{k}' please use features_col instead."
)
if k in _inverse_pyspark_param_alias_map:
raise ValueError(
f"Please use param name {_inverse_pyspark_param_alias_map[k]} instead."
Expand All @@ -591,7 +596,10 @@ def setParams(
k = real_k

if self.hasParam(k):
self._set(**{str(k): v})
if k == "features_col" and isinstance(v, list):
self._set(**{"features_cols": v})
else:
self._set(**{str(k): v})
else:
if (
k in _unsupported_xgb_params
Expand Down
Loading