Support configuring constraints by feature names (#6783)
Co-authored-by: fis <jm.yuan@outlook.com>
This commit is contained in:
parent
7e06c81894
commit
aa0d8f20c1
@ -1193,6 +1193,7 @@ class Booster(object):
|
|||||||
|
|
||||||
params = params or {}
|
params = params or {}
|
||||||
params = self._configure_metrics(params.copy())
|
params = self._configure_metrics(params.copy())
|
||||||
|
params = self._configure_constraints(params)
|
||||||
if isinstance(params, list):
|
if isinstance(params, list):
|
||||||
params.append(('validate_parameters', True))
|
params.append(('validate_parameters', True))
|
||||||
else:
|
else:
|
||||||
@ -1233,6 +1234,68 @@ class Booster(object):
|
|||||||
params += [('eval_metric', eval_metric)]
|
params += [('eval_metric', eval_metric)]
|
||||||
return params
|
return params
|
||||||
|
|
||||||
|
def _transform_monotone_constrains(
    self, value: Union[dict, str, tuple, list]
) -> str:
    """Convert a ``monotone_constraints`` parameter into its string form.

    Parameters
    ----------
    value :
        One of:

        * a string such as ``"(1,-1)"``, returned unchanged;
        * a positional ``tuple``/``list`` holding one constraint per
          feature, formatted as ``"(c0,c1,...)"``;
        * a mapping from feature name to constraint.  Features that are
          not mentioned get the constraint ``0``.

    Returns
    -------
    str
        The constraints formatted as ``"(c0,c1,...)"``, ordered like
        ``self.feature_names`` when a mapping was given.

    Raises
    ------
    ValueError
        If a mapping names a feature that is not in ``self.feature_names``
        (including the case where no feature names are available).
    """
    if isinstance(value, str):
        return value

    if isinstance(value, (tuple, list)):
        # Positional specification: there are no names to resolve, so only
        # normalise it to the same string layout the mapping branch emits.
        return '(' + ','.join(str(direction) for direction in value) + ')'

    # Read the attribute once; it is used both for validation and ordering.
    feature_names = self.feature_names or []

    constrained_features = set(value.keys())
    if not constrained_features.issubset(feature_names):
        raise ValueError('Constrained features are not a subset of '
                         'training data feature names')

    return '(' + ','.join(str(value.get(feature_name, 0))
                          for feature_name in feature_names) + ')'
|
||||||
|
|
||||||
|
def _transform_interaction_constraints(self, value: Union[list, str]) -> str:
    """Convert an ``interaction_constraints`` parameter into its string form.

    Parameters
    ----------
    value :
        Either a string such as ``"[[0, 1]]"`` (returned unchanged) or an
        iterable of groups, each group being an iterable of feature names
        taken from ``self.feature_names``.

    Returns
    -------
    str
        A nested-list string of feature indices, e.g. ``"[[0,1],[2,3]]"``.

    Raises
    ------
    ValueError
        If a group names a feature that is not in ``self.feature_names``
        (including the case where no feature names are available).
    """
    if isinstance(value, str):
        return value

    # Feature name -> positional index, already rendered as text.
    feature_idx_mapping = {
        name: str(idx) for idx, name in enumerate(self.feature_names or [])
    }

    try:
        groups = [
            "[" + ",".join(feature_idx_mapping[name] for name in constraint) + "]"
            for constraint in value
        ]
    except KeyError as e:
        raise ValueError(
            "Constrained features are not a subset of training data feature names"
        ) from e

    # Groups must be comma separated; plain concatenation would produce
    # "[[0,1][2,3]]", which is not a valid nested list.
    return "[" + ",".join(groups) + "]"
|
||||||
|
|
||||||
|
def _configure_constraints(self, params: Union[Dict, List]) -> Union[Dict, List]:
    """Rewrite constraint parameters given by feature name into string form.

    ``params`` may be a dict or a list of ``(name, value)`` pairs.  Truthy
    values stored under ``"monotone_constraints"`` and
    ``"interaction_constraints"`` are replaced, in place, by the output of
    the matching ``_transform_*`` method.  Any other container type is
    returned untouched.

    Returns
    -------
    The same ``params`` object that was passed in.
    """
    handlers = (
        ("monotone_constraints", self._transform_monotone_constrains),
        ("interaction_constraints", self._transform_interaction_constraints),
    )

    if isinstance(params, dict):
        for key, transform in handlers:
            raw = params.get(key)
            if raw:
                params[key] = transform(raw)
        return params

    if isinstance(params, list):
        for position, (key, raw) in enumerate(params):
            # Falsy values (None, empty containers, "") are left as they are.
            if not raw:
                continue
            for target, transform in handlers:
                if key == target:
                    params[position] = (key, transform(raw))
                    break

    return params
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
if hasattr(self, 'handle') and self.handle is not None:
|
if hasattr(self, 'handle') and self.handle is not None:
|
||||||
_check_call(_LIB.XGBoosterFree(self.handle))
|
_check_call(_LIB.XGBoosterFree(self.handle))
|
||||||
|
|||||||
@ -9,7 +9,7 @@ rng = np.random.RandomState(1994)
|
|||||||
|
|
||||||
|
|
||||||
class TestInteractionConstraints:
|
class TestInteractionConstraints:
|
||||||
def run_interaction_constraints(self, tree_method):
|
def run_interaction_constraints(self, tree_method, feature_names=None, interaction_constraints='[[0, 1]]'):
|
||||||
x1 = np.random.normal(loc=1.0, scale=1.0, size=1000)
|
x1 = np.random.normal(loc=1.0, scale=1.0, size=1000)
|
||||||
x2 = np.random.normal(loc=1.0, scale=1.0, size=1000)
|
x2 = np.random.normal(loc=1.0, scale=1.0, size=1000)
|
||||||
x3 = np.random.choice([1, 2, 3], size=1000, replace=True)
|
x3 = np.random.choice([1, 2, 3], size=1000, replace=True)
|
||||||
@ -17,13 +17,13 @@ class TestInteractionConstraints:
|
|||||||
+ np.random.normal(
|
+ np.random.normal(
|
||||||
loc=0.001, scale=1.0, size=1000) + 3 * np.sin(x1)
|
loc=0.001, scale=1.0, size=1000) + 3 * np.sin(x1)
|
||||||
X = np.column_stack((x1, x2, x3))
|
X = np.column_stack((x1, x2, x3))
|
||||||
dtrain = xgboost.DMatrix(X, label=y)
|
dtrain = xgboost.DMatrix(X, label=y, feature_names=feature_names)
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
'max_depth': 3,
|
'max_depth': 3,
|
||||||
'eta': 0.1,
|
'eta': 0.1,
|
||||||
'nthread': 2,
|
'nthread': 2,
|
||||||
'interaction_constraints': '[[0, 1]]',
|
'interaction_constraints': interaction_constraints,
|
||||||
'tree_method': tree_method
|
'tree_method': tree_method
|
||||||
}
|
}
|
||||||
num_boost_round = 12
|
num_boost_round = 12
|
||||||
@ -35,7 +35,7 @@ class TestInteractionConstraints:
|
|||||||
# by the same amount
|
# by the same amount
|
||||||
def f(x):
|
def f(x):
|
||||||
tmat = xgboost.DMatrix(
|
tmat = xgboost.DMatrix(
|
||||||
np.column_stack((x1, x2, np.repeat(x, 1000))))
|
np.column_stack((x1, x2, np.repeat(x, 1000))), feature_names=feature_names)
|
||||||
return bst.predict(tmat)
|
return bst.predict(tmat)
|
||||||
|
|
||||||
preds = [f(x) for x in [1, 2, 3]]
|
preds = [f(x) for x in [1, 2, 3]]
|
||||||
@ -57,6 +57,26 @@ class TestInteractionConstraints:
|
|||||||
def test_approx_interaction_constraints(self):
|
def test_approx_interaction_constraints(self):
|
||||||
self.run_interaction_constraints(tree_method='approx')
|
self.run_interaction_constraints(tree_method='approx')
|
||||||
|
|
||||||
|
def test_interaction_constraints_feature_names(self):
    """Exercise interaction constraints that are written with feature names.

    Two failure cases are expected to raise ``ValueError``: names used
    while the data carries no feature names, and a name that is not among
    the data's feature names.  The last call uses matching names and must
    complete without raising.
    """
    names = ['feature_0', 'feature_1', 'feature_2']

    # Named constraints, but no feature names supplied for the data.
    unnamed_data_constraints = [('feature_0', 'feature_1')]
    with pytest.raises(ValueError):
        self.run_interaction_constraints(
            tree_method='exact',
            interaction_constraints=unnamed_data_constraints)

    # 'feature_3' is not one of the supplied feature names.
    unknown_feature_constraints = [('feature_0', 'feature_3')]
    with pytest.raises(ValueError):
        self.run_interaction_constraints(
            tree_method='exact',
            feature_names=names,
            interaction_constraints=unknown_feature_constraints)

    # Every constrained feature exists: the regular check runs to the end.
    valid_constraints = [('feature_0', 'feature_1')]
    self.run_interaction_constraints(
        tree_method='exact',
        feature_names=names,
        interaction_constraints=valid_constraints)
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_sklearn())
|
@pytest.mark.skipif(**tm.no_sklearn())
|
||||||
def training_accuracy(self, tree_method):
|
def training_accuracy(self, tree_method):
|
||||||
from sklearn.metrics import accuracy_score
|
from sklearn.metrics import accuracy_score
|
||||||
|
|||||||
@ -14,7 +14,7 @@ def is_decreasing(y):
|
|||||||
return np.count_nonzero(np.diff(y) > 0.0) == 0
|
return np.count_nonzero(np.diff(y) > 0.0) == 0
|
||||||
|
|
||||||
|
|
||||||
def is_correctly_constrained(learner):
|
def is_correctly_constrained(learner, feature_names=None):
|
||||||
n = 100
|
n = 100
|
||||||
variable_x = np.linspace(0, 1, n).reshape((n, 1))
|
variable_x = np.linspace(0, 1, n).reshape((n, 1))
|
||||||
fixed_xs_values = np.linspace(0, 1, n)
|
fixed_xs_values = np.linspace(0, 1, n)
|
||||||
@ -22,13 +22,15 @@ def is_correctly_constrained(learner):
|
|||||||
for i in range(n):
|
for i in range(n):
|
||||||
fixed_x = fixed_xs_values[i] * np.ones((n, 1))
|
fixed_x = fixed_xs_values[i] * np.ones((n, 1))
|
||||||
monotonically_increasing_x = np.column_stack((variable_x, fixed_x))
|
monotonically_increasing_x = np.column_stack((variable_x, fixed_x))
|
||||||
monotonically_increasing_dset = xgb.DMatrix(monotonically_increasing_x)
|
monotonically_increasing_dset = xgb.DMatrix(monotonically_increasing_x,
|
||||||
|
feature_names=feature_names)
|
||||||
monotonically_increasing_y = learner.predict(
|
monotonically_increasing_y = learner.predict(
|
||||||
monotonically_increasing_dset
|
monotonically_increasing_dset
|
||||||
)
|
)
|
||||||
|
|
||||||
monotonically_decreasing_x = np.column_stack((fixed_x, variable_x))
|
monotonically_decreasing_x = np.column_stack((fixed_x, variable_x))
|
||||||
monotonically_decreasing_dset = xgb.DMatrix(monotonically_decreasing_x)
|
monotonically_decreasing_dset = xgb.DMatrix(monotonically_decreasing_x,
|
||||||
|
feature_names=feature_names)
|
||||||
monotonically_decreasing_y = learner.predict(
|
monotonically_decreasing_y = learner.predict(
|
||||||
monotonically_decreasing_dset
|
monotonically_decreasing_dset
|
||||||
)
|
)
|
||||||
@ -101,6 +103,38 @@ class TestMonotoneConstraints:
|
|||||||
|
|
||||||
assert is_correctly_constrained(constrained_hist_method)
|
assert is_correctly_constrained(constrained_hist_method)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('format', [dict, list])
def test_monotone_constraints_feature_names(self, format):
    """Check monotone constraints that are keyed by feature name.

    The parameters are passed either as a dict or as a list of pairs,
    depending on ``format``.  ``training_dset``, ``x`` and ``y`` are
    module-level objects defined outside this function.
    """
    settings = {
        'tree_method': 'hist', 'verbosity': 1,
        'grow_policy': 'lossguide',
        'monotone_constraints': {'feature_0': 1, 'feature_1': -1}
    }
    params = list(settings.items()) if format == list else settings

    # Training on `training_dset` with name-keyed constraints must fail.
    with pytest.raises(ValueError):
        xgb.train(params, training_dset)

    # 'feature_1' is constrained but absent from these feature names.
    mismatched_names = ['feature_0', 'feature_2']
    mismatched_dset = xgb.DMatrix(x, label=y, feature_names=mismatched_names)
    with pytest.raises(ValueError):
        xgb.train(params, mismatched_dset)

    # Matching names: training succeeds and monotonicity is then verified.
    feature_names = ['feature_0', 'feature_1']
    named_dset = xgb.DMatrix(x, label=y, feature_names=feature_names)
    constrained_learner = xgb.train(params, named_dset)

    assert is_correctly_constrained(constrained_learner, feature_names)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_sklearn())
|
@pytest.mark.skipif(**tm.no_sklearn())
|
||||||
def test_training_accuracy(self):
|
def test_training_accuracy(self):
|
||||||
from sklearn.metrics import accuracy_score
|
from sklearn.metrics import accuracy_score
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user