Cleanup booster param types. (#8756)

This commit is contained in:
Jiaming Yuan 2023-02-07 15:52:19 +08:00 committed by GitHub
parent 7b3d473593
commit c4802bfcd0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 12 additions and 21 deletions

View File

@ -25,7 +25,7 @@ DataType = Any
FeatureInfo = Sequence[str]
FeatureNames = FeatureInfo
FeatureTypes = FeatureInfo
BoosterParam = Union[List, Dict] # better be sequence
BoosterParam = Union[List, Dict[str, Any]] # better be sequence
ArrayLike = Any
PathLike = Union[str, os.PathLike]

View File

@ -1655,27 +1655,18 @@ class Booster:
def _configure_constraints(self, params: BoosterParam) -> BoosterParam:
if isinstance(params, dict):
value = params.get("monotone_constraints")
if value is not None:
params["monotone_constraints"] = self._transform_monotone_constrains(
value
)
# we must use list in the internal code as there can be multiple metrics
# with the same parameter name `eval_metric` (same key for dictionary).
params = list(params.items())
for idx, param in enumerate(params):
name, value = param
if value is None:
continue
value = params.get("interaction_constraints")
if value is not None:
params[
"interaction_constraints"
] = self._transform_interaction_constraints(value)
elif isinstance(params, list):
for idx, param in enumerate(params):
name, value = param
if not value:
continue
if name == "monotone_constraints":
params[idx] = (name, self._transform_monotone_constrains(value))
elif name == "interaction_constraints":
params[idx] = (name, self._transform_interaction_constraints(value))
if name == "monotone_constraints":
params[idx] = (name, self._transform_monotone_constrains(value))
elif name == "interaction_constraints":
params[idx] = (name, self._transform_interaction_constraints(value))
return params