Cleanup configuration for constraints. (#7758)
This commit is contained in:
@@ -1392,50 +1392,46 @@ class Booster:
|
||||
raise TypeError('Unknown type:', model_file)
|
||||
|
||||
params = params or {}
|
||||
params = _configure_metrics(params.copy())
|
||||
params = self._configure_constraints(params)
|
||||
if isinstance(params, list):
|
||||
params.append(('validate_parameters', True))
|
||||
params_processed = _configure_metrics(params.copy())
|
||||
params_processed = self._configure_constraints(params_processed)
|
||||
if isinstance(params_processed, list):
|
||||
params_processed.append(("validate_parameters", True))
|
||||
else:
|
||||
params['validate_parameters'] = True
|
||||
params_processed["validate_parameters"] = True
|
||||
|
||||
self.set_param(params or {})
|
||||
if (params is not None) and ('booster' in params):
|
||||
self.booster = params['booster']
|
||||
else:
|
||||
self.booster = 'gbtree'
|
||||
self.set_param(params_processed or {})
|
||||
|
||||
def _transform_monotone_constrains(self, value: Union[Dict[str, int], str]) -> str:
|
||||
def _transform_monotone_constrains(
|
||||
self, value: Union[Dict[str, int], str]
|
||||
) -> Union[Tuple[int, ...], str]:
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
|
||||
constrained_features = set(value.keys())
|
||||
if not constrained_features.issubset(set(self.feature_names or [])):
|
||||
raise ValueError('Constrained features are not a subset of '
|
||||
'training data feature names')
|
||||
feature_names = self.feature_names or []
|
||||
if not constrained_features.issubset(set(feature_names)):
|
||||
raise ValueError(
|
||||
"Constrained features are not a subset of training data feature names"
|
||||
)
|
||||
|
||||
return '(' + ','.join([str(value.get(feature_name, 0))
|
||||
for feature_name in self.feature_names]) + ')'
|
||||
return tuple(value.get(name, 0) for name in feature_names)
|
||||
|
||||
def _transform_interaction_constraints(
|
||||
self, value: Union[List[Tuple[str]], str]
|
||||
) -> str:
|
||||
self, value: Union[Sequence[Sequence[str]], str]
|
||||
) -> Union[str, List[List[int]]]:
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
|
||||
feature_idx_mapping = {k: str(v) for v, k in enumerate(self.feature_names or [])}
|
||||
feature_idx_mapping = {
|
||||
name: idx for idx, name in enumerate(self.feature_names or [])
|
||||
}
|
||||
|
||||
try:
|
||||
s = "["
|
||||
result = []
|
||||
for constraint in value:
|
||||
s += (
|
||||
"["
|
||||
+ ",".join(
|
||||
[feature_idx_mapping[feature_name] for feature_name in constraint]
|
||||
)
|
||||
+ "],"
|
||||
result.append(
|
||||
[feature_idx_mapping[feature_name] for feature_name in constraint]
|
||||
)
|
||||
return s[:-1] + "]"
|
||||
return result
|
||||
except KeyError as e:
|
||||
raise ValueError(
|
||||
"Constrained features are not a subset of training data feature names"
|
||||
@@ -1444,17 +1440,16 @@ class Booster:
|
||||
def _configure_constraints(self, params: Union[List, Dict]) -> Union[List, Dict]:
|
||||
if isinstance(params, dict):
|
||||
value = params.get("monotone_constraints")
|
||||
if value:
|
||||
params[
|
||||
"monotone_constraints"
|
||||
] = self._transform_monotone_constrains(value)
|
||||
if value is not None:
|
||||
params["monotone_constraints"] = self._transform_monotone_constrains(
|
||||
value
|
||||
)
|
||||
|
||||
value = params.get("interaction_constraints")
|
||||
if value:
|
||||
if value is not None:
|
||||
params[
|
||||
"interaction_constraints"
|
||||
] = self._transform_interaction_constraints(value)
|
||||
|
||||
elif isinstance(params, list):
|
||||
for idx, param in enumerate(params):
|
||||
name, value = param
|
||||
@@ -2462,11 +2457,9 @@ class Booster:
|
||||
if not PANDAS_INSTALLED:
|
||||
raise ImportError(('pandas must be available to use this method.'
|
||||
'Install pandas before calling again.'))
|
||||
|
||||
if getattr(self, 'booster', None) is not None and self.booster not in {'gbtree', 'dart'}:
|
||||
raise ValueError(
|
||||
f"This method is not defined for Booster type {self.booster}"
|
||||
)
|
||||
booster = json.loads(self.save_config())["learner"]["gradient_booster"]["name"]
|
||||
if booster not in {"gbtree", "dart"}:
|
||||
raise ValueError(f"This method is not defined for Booster type {booster}")
|
||||
|
||||
tree_ids = []
|
||||
node_ids = []
|
||||
|
||||
Reference in New Issue
Block a user