From db80671d6b79715ce31209e8f17a61b8498ea86d Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Fri, 13 May 2022 04:00:03 +0800 Subject: [PATCH] Fix monotone constraint with tuple input. (#7891) --- python-package/xgboost/core.py | 4 +++- tests/python/test_monotone_constraints.py | 5 +++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 5972db02f..0a84feb96 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -1405,10 +1405,12 @@ class Booster: self.set_param(params_processed or {}) def _transform_monotone_constrains( - self, value: Union[Dict[str, int], str] + self, value: Union[Dict[str, int], str, Tuple[int, ...]] ) -> Union[Tuple[int, ...], str]: if isinstance(value, str): return value + if isinstance(value, tuple): + return value constrained_features = set(value.keys()) feature_names = self.feature_names or [] diff --git a/tests/python/test_monotone_constraints.py b/tests/python/test_monotone_constraints.py index c46569f6a..ae2c2917d 100644 --- a/tests/python/test_monotone_constraints.py +++ b/tests/python/test_monotone_constraints.py @@ -93,6 +93,11 @@ class TestMonotoneConstraints: constrained = xgb.train(params_for_constrained, training_dset) assert is_correctly_constrained(constrained) + def test_monotone_constraints_tuple(self) -> None: + params_for_constrained = {"monotone_constraints": (1, -1)} + constrained = xgb.train(params_for_constrained, training_dset) + assert is_correctly_constrained(constrained) + @pytest.mark.parametrize('format', [dict, list]) def test_monotone_constraints_feature_names(self, format):