Support configuring constraints by feature names (#6783)
Co-authored-by: fis <jm.yuan@outlook.com>
This commit is contained in:
@@ -1193,6 +1193,7 @@ class Booster(object):
|
||||
|
||||
params = params or {}
|
||||
params = self._configure_metrics(params.copy())
|
||||
params = self._configure_constraints(params)
|
||||
if isinstance(params, list):
|
||||
params.append(('validate_parameters', True))
|
||||
else:
|
||||
@@ -1233,6 +1234,68 @@ class Booster(object):
|
||||
params += [('eval_metric', eval_metric)]
|
||||
return params
|
||||
|
||||
def _transform_monotone_constrains(self, value: Union[dict, str]) -> str:
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
|
||||
constrained_features = set(value.keys())
|
||||
if not constrained_features.issubset(set(self.feature_names or [])):
|
||||
raise ValueError('Constrained features are not a subset of '
|
||||
'training data feature names')
|
||||
|
||||
return '(' + ','.join([str(value.get(feature_name, 0))
|
||||
for feature_name in self.feature_names]) + ')'
|
||||
|
||||
def _transform_interaction_constraints(self, value: Union[list, str]) -> str:
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
|
||||
feature_idx_mapping = {k: str(v) for v, k in enumerate(self.feature_names or [])}
|
||||
|
||||
try:
|
||||
s = "["
|
||||
for constraint in value:
|
||||
s += (
|
||||
"["
|
||||
+ ",".join(
|
||||
[feature_idx_mapping[feature_name] for feature_name in constraint]
|
||||
)
|
||||
+ "]"
|
||||
)
|
||||
return s + "]"
|
||||
except KeyError as e:
|
||||
# pylint: disable=raise-missing-from
|
||||
raise ValueError(
|
||||
"Constrained features are not a subset of training data feature names"
|
||||
) from e
|
||||
|
||||
def _configure_constraints(self, params: Union[Dict, List]) -> Union[Dict, List]:
|
||||
if isinstance(params, dict):
|
||||
value = params.get("monotone_constraints")
|
||||
if value:
|
||||
params[
|
||||
"monotone_constraints"
|
||||
] = self._transform_monotone_constrains(value)
|
||||
|
||||
value = params.get("interaction_constraints")
|
||||
if value:
|
||||
params[
|
||||
"interaction_constraints"
|
||||
] = self._transform_interaction_constraints(value)
|
||||
|
||||
elif isinstance(params, list):
|
||||
for idx, param in enumerate(params):
|
||||
name, value = param
|
||||
if not value:
|
||||
continue
|
||||
|
||||
if name == "monotone_constraints":
|
||||
params[idx] = (name, self._transform_monotone_constrains(value))
|
||||
elif name == "interaction_constraints":
|
||||
params[idx] = (name, self._transform_interaction_constraints(value))
|
||||
|
||||
return params
|
||||
|
||||
def __del__(self):
|
||||
if hasattr(self, 'handle') and self.handle is not None:
|
||||
_check_call(_LIB.XGBoosterFree(self.handle))
|
||||
|
||||
Reference in New Issue
Block a user