Fix monotone constraint with tuple input. (#7891)

This commit is contained in:
Jiaming Yuan 2022-05-13 04:00:03 +08:00 committed by GitHub
parent 94ca52b7b7
commit db80671d6b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 1 deletions

View File

@ -1405,10 +1405,12 @@ class Booster:
self.set_param(params_processed or {})
def _transform_monotone_constrains(
self, value: Union[Dict[str, int], str]
self, value: Union[Dict[str, int], str, Tuple[int, ...]]
) -> Union[Tuple[int, ...], str]:
if isinstance(value, str):
return value
if isinstance(value, tuple):
return value
constrained_features = set(value.keys())
feature_names = self.feature_names or []

View File

@ -93,6 +93,11 @@ class TestMonotoneConstraints:
constrained = xgb.train(params_for_constrained, training_dset)
assert is_correctly_constrained(constrained)
def test_monotone_constraints_tuple(self) -> None:
params_for_constrained = {"monotone_constraints": (1, -1)}
constrained = xgb.train(params_for_constrained, training_dset)
assert is_correctly_constrained(constrained)
@pytest.mark.parametrize('format', [dict, list])
def test_monotone_constraints_feature_names(self, format):