Fix monotone constraint with tuple input. (#7891)
This commit is contained in:
parent
94ca52b7b7
commit
db80671d6b
@ -1405,10 +1405,12 @@ class Booster:
|
|||||||
self.set_param(params_processed or {})
|
self.set_param(params_processed or {})
|
||||||
|
|
||||||
def _transform_monotone_constrains(
|
def _transform_monotone_constrains(
|
||||||
self, value: Union[Dict[str, int], str]
|
self, value: Union[Dict[str, int], str, Tuple[int, ...]]
|
||||||
) -> Union[Tuple[int, ...], str]:
|
) -> Union[Tuple[int, ...], str]:
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
return value
|
return value
|
||||||
|
if isinstance(value, tuple):
|
||||||
|
return value
|
||||||
|
|
||||||
constrained_features = set(value.keys())
|
constrained_features = set(value.keys())
|
||||||
feature_names = self.feature_names or []
|
feature_names = self.feature_names or []
|
||||||
|
|||||||
@ -93,6 +93,11 @@ class TestMonotoneConstraints:
|
|||||||
constrained = xgb.train(params_for_constrained, training_dset)
|
constrained = xgb.train(params_for_constrained, training_dset)
|
||||||
assert is_correctly_constrained(constrained)
|
assert is_correctly_constrained(constrained)
|
||||||
|
|
||||||
|
def test_monotone_constraints_tuple(self) -> None:
|
||||||
|
params_for_constrained = {"monotone_constraints": (1, -1)}
|
||||||
|
constrained = xgb.train(params_for_constrained, training_dset)
|
||||||
|
assert is_correctly_constrained(constrained)
|
||||||
|
|
||||||
@pytest.mark.parametrize('format', [dict, list])
|
@pytest.mark.parametrize('format', [dict, list])
|
||||||
def test_monotone_constraints_feature_names(self, format):
|
def test_monotone_constraints_feature_names(self, format):
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user