diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 25b62a5a6..eac3497b8 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -1193,6 +1193,7 @@ class Booster(object): params = params or {} params = self._configure_metrics(params.copy()) + params = self._configure_constraints(params) if isinstance(params, list): params.append(('validate_parameters', True)) else: @@ -1233,6 +1234,68 @@ class Booster(object): params += [('eval_metric', eval_metric)] return params + def _transform_monotone_constrains(self, value: Union[dict, str]) -> str: + if isinstance(value, str): + return value + + constrained_features = set(value.keys()) + if not constrained_features.issubset(set(self.feature_names or [])): + raise ValueError('Constrained features are not a subset of ' + 'training data feature names') + + return '(' + ','.join([str(value.get(feature_name, 0)) + for feature_name in self.feature_names]) + ')' + + def _transform_interaction_constraints(self, value: Union[list, str]) -> str: + if isinstance(value, str): + return value + + feature_idx_mapping = {k: str(v) for v, k in enumerate(self.feature_names or [])} + + try: + s = "[" + for constraint in value: + s += ( + "[" + + ",".join( + [feature_idx_mapping[feature_name] for feature_name in constraint] + ) + + "]" + ) + return s + "]" + except KeyError as e: + # pylint: disable=raise-missing-from + raise ValueError( + "Constrained features are not a subset of training data feature names" + ) from e + + def _configure_constraints(self, params: Union[Dict, List]) -> Union[Dict, List]: + if isinstance(params, dict): + value = params.get("monotone_constraints") + if value: + params[ + "monotone_constraints" + ] = self._transform_monotone_constrains(value) + + value = params.get("interaction_constraints") + if value: + params[ + "interaction_constraints" + ] = self._transform_interaction_constraints(value) + + elif isinstance(params, list): + for idx, param in enumerate(params): + name, value = param + if not value: + continue + + if name == "monotone_constraints": + params[idx] = (name, self._transform_monotone_constrains(value)) + elif name == "interaction_constraints": + params[idx] = (name, self._transform_interaction_constraints(value)) + + return params + def __del__(self): if hasattr(self, 'handle') and self.handle is not None: _check_call(_LIB.XGBoosterFree(self.handle)) diff --git a/tests/python/test_interaction_constraints.py b/tests/python/test_interaction_constraints.py index cebda8159..6c3a442a6 100644 --- a/tests/python/test_interaction_constraints.py +++ b/tests/python/test_interaction_constraints.py @@ -9,7 +9,7 @@ rng = np.random.RandomState(1994) class TestInteractionConstraints: - def run_interaction_constraints(self, tree_method): + def run_interaction_constraints(self, tree_method, feature_names=None, interaction_constraints='[[0, 1]]'): x1 = np.random.normal(loc=1.0, scale=1.0, size=1000) x2 = np.random.normal(loc=1.0, scale=1.0, size=1000) x3 = np.random.choice([1, 2, 3], size=1000, replace=True) @@ -17,13 +17,13 @@ class TestInteractionConstraints: + np.random.normal( loc=0.001, scale=1.0, size=1000) + 3 * np.sin(x1) X = np.column_stack((x1, x2, x3)) - dtrain = xgboost.DMatrix(X, label=y) + dtrain = xgboost.DMatrix(X, label=y, feature_names=feature_names) params = { 'max_depth': 3, 'eta': 0.1, 'nthread': 2, - 'interaction_constraints': '[[0, 1]]', + 'interaction_constraints': interaction_constraints, 'tree_method': tree_method } num_boost_round = 12 @@ -35,7 +35,7 @@ class TestInteractionConstraints: # by the same amount def f(x): tmat = xgboost.DMatrix( - np.column_stack((x1, x2, np.repeat(x, 1000)))) + np.column_stack((x1, x2, np.repeat(x, 1000))), feature_names=feature_names) return bst.predict(tmat) preds = [f(x) for x in [1, 2, 3]] @@ -57,6 +57,26 @@ class TestInteractionConstraints: def test_approx_interaction_constraints(self): self.run_interaction_constraints(tree_method='approx') + def test_interaction_constraints_feature_names(self): + with pytest.raises(ValueError): + constraints = [('feature_0', 'feature_1')] + self.run_interaction_constraints(tree_method='exact', + interaction_constraints=constraints) + + with pytest.raises(ValueError): + constraints = [('feature_0', 'feature_3')] + feature_names = ['feature_0', 'feature_1', 'feature_2'] + self.run_interaction_constraints(tree_method='exact', + feature_names=feature_names, + interaction_constraints=constraints) + + + constraints = [('feature_0', 'feature_1')] + feature_names = ['feature_0', 'feature_1', 'feature_2'] + self.run_interaction_constraints(tree_method='exact', + feature_names=feature_names, + interaction_constraints=constraints) + @pytest.mark.skipif(**tm.no_sklearn()) def training_accuracy(self, tree_method): from sklearn.metrics import accuracy_score diff --git a/tests/python/test_monotone_constraints.py b/tests/python/test_monotone_constraints.py index 066324955..8e29a53fa 100644 --- a/tests/python/test_monotone_constraints.py +++ b/tests/python/test_monotone_constraints.py @@ -14,7 +14,7 @@ def is_decreasing(y): return np.count_nonzero(np.diff(y) > 0.0) == 0 -def is_correctly_constrained(learner): +def is_correctly_constrained(learner, feature_names=None): n = 100 variable_x = np.linspace(0, 1, n).reshape((n, 1)) fixed_xs_values = np.linspace(0, 1, n) @@ -22,13 +22,15 @@ def is_correctly_constrained(learner): for i in range(n): fixed_x = fixed_xs_values[i] * np.ones((n, 1)) monotonically_increasing_x = np.column_stack((variable_x, fixed_x)) - monotonically_increasing_dset = xgb.DMatrix(monotonically_increasing_x) + monotonically_increasing_dset = xgb.DMatrix(monotonically_increasing_x, + feature_names=feature_names) monotonically_increasing_y = learner.predict( monotonically_increasing_dset ) monotonically_decreasing_x = np.column_stack((fixed_x, variable_x)) - monotonically_decreasing_dset = xgb.DMatrix(monotonically_decreasing_x) + monotonically_decreasing_dset = xgb.DMatrix(monotonically_decreasing_x, + feature_names=feature_names) monotonically_decreasing_y = learner.predict( monotonically_decreasing_dset ) @@ -101,6 +103,38 @@ class TestMonotoneConstraints: assert is_correctly_constrained(constrained_hist_method) + @pytest.mark.parametrize('format', [dict, list]) + def test_monotone_constraints_feature_names(self, format): + + # next check monotonicity when initializing monotone_constraints by feature names + params = { + 'tree_method': 'hist', 'verbosity': 1, + 'grow_policy': 'lossguide', + 'monotone_constraints': {'feature_0': 1, 'feature_1': -1} + } + + if format == list: + params = list(params.items()) + + with pytest.raises(ValueError): + xgb.train(params, training_dset) + + feature_names =[ 'feature_0', 'feature_2'] + training_dset_w_feature_names = xgb.DMatrix(x, label=y, feature_names=feature_names) + + with pytest.raises(ValueError): + xgb.train(params, training_dset_w_feature_names) + + feature_names =[ 'feature_0', 'feature_1'] + training_dset_w_feature_names = xgb.DMatrix(x, label=y, feature_names=feature_names) + + constrained_learner = xgb.train( + params, training_dset_w_feature_names + ) + + assert is_correctly_constrained(constrained_learner, feature_names) + + @pytest.mark.skipif(**tm.no_sklearn()) def test_training_accuracy(self): from sklearn.metrics import accuracy_score