diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 5972db02f669..0a84feb96b11 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -1405,10 +1405,12 @@ def __init__( self.set_param(params_processed or {}) def _transform_monotone_constrains( - self, value: Union[Dict[str, int], str] + self, value: Union[Dict[str, int], str, Tuple[int, ...]] ) -> Union[Tuple[int, ...], str]: if isinstance(value, str): return value + if isinstance(value, tuple): + return value constrained_features = set(value.keys()) feature_names = self.feature_names or [] diff --git a/tests/python/test_monotone_constraints.py b/tests/python/test_monotone_constraints.py index c46569f6a238..ae2c2917d7ff 100644 --- a/tests/python/test_monotone_constraints.py +++ b/tests/python/test_monotone_constraints.py @@ -93,6 +93,11 @@ def test_monotone_constraints(self, tree_method: str, policy: str) -> None: constrained = xgb.train(params_for_constrained, training_dset) assert is_correctly_constrained(constrained) + def test_monotone_constraints_tuple(self) -> None: + params_for_constrained = {"monotone_constraints": (1, -1)} + constrained = xgb.train(params_for_constrained, training_dset) + assert is_correctly_constrained(constrained) + @pytest.mark.parametrize('format', [dict, list]) def test_monotone_constraints_feature_names(self, format):