Fix multi-output with alternating strategies. (#9933)
--------- Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
parent
5f7b5a6921
commit
621348abb3
@ -394,3 +394,14 @@ def train_result(
|
||||
assert booster.feature_types == dmat.feature_types
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class ResetStrategy(xgb.callback.TrainingCallback):
|
||||
"""Callback for testing multi-output."""
|
||||
|
||||
def after_iteration(self, model: xgb.Booster, epoch: int, evals_log: dict) -> bool:
|
||||
if epoch % 2 == 0:
|
||||
model.set_param({"multi_strategy": "multi_output_tree"})
|
||||
else:
|
||||
model.set_param({"multi_strategy": "one_output_per_tree"})
|
||||
return False
|
||||
|
||||
@ -545,12 +545,12 @@ class QuantileHistMaker : public TreeUpdater {
|
||||
}
|
||||
|
||||
bool UpdatePredictionCache(const DMatrix *data, linalg::MatrixView<float> out_preds) override {
|
||||
if (p_impl_) {
|
||||
return p_impl_->UpdatePredictionCache(data, out_preds);
|
||||
} else if (p_mtimpl_) {
|
||||
if (out_preds.Shape(1) > 1) {
|
||||
CHECK(p_mtimpl_);
|
||||
return p_mtimpl_->UpdatePredictionCache(data, out_preds);
|
||||
} else {
|
||||
return false;
|
||||
CHECK(p_impl_);
|
||||
return p_impl_->UpdatePredictionCache(data, out_preds);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -22,6 +22,7 @@ class LintersPaths:
|
||||
"tests/python/test_dmatrix.py",
|
||||
"tests/python/test_dt.py",
|
||||
"tests/python/test_demos.py",
|
||||
"tests/python/test_multi_target.py",
|
||||
"tests/python/test_predict.py",
|
||||
"tests/python/test_quantile_dmatrix.py",
|
||||
"tests/python/test_tree_regularization.py",
|
||||
@ -79,6 +80,7 @@ class LintersPaths:
|
||||
"tests/python/test_dt.py",
|
||||
"tests/python/test_demos.py",
|
||||
"tests/python/test_data_iterator.py",
|
||||
"tests/python/test_multi_target.py",
|
||||
"tests/python-gpu/test_gpu_data_iterator.py",
|
||||
"tests/python-gpu/load_pickle.py",
|
||||
"tests/test_distributed/test_with_spark/test_data.py",
|
||||
|
||||
@ -8,6 +8,7 @@ import pytest
|
||||
|
||||
import xgboost as xgb
|
||||
from xgboost import testing as tm
|
||||
from xgboost.testing.updater import ResetStrategy
|
||||
|
||||
dpath = tm.data_dir(__file__)
|
||||
|
||||
@ -653,11 +654,6 @@ class TestModels:
|
||||
num_parallel_tree = 4
|
||||
num_boost_round = 16
|
||||
|
||||
class ResetStrategy(xgb.callback.TrainingCallback):
|
||||
def after_iteration(self, model, epoch: int, evals_log) -> bool:
|
||||
model.set_param({"multi_strategy": "multi_output_tree"})
|
||||
return False
|
||||
|
||||
booster = xgb.train(
|
||||
{
|
||||
"num_parallel_tree": num_parallel_tree,
|
||||
|
||||
105
tests/python/test_multi_target.py
Normal file
105
tests/python/test_multi_target.py
Normal file
@ -0,0 +1,105 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
from hypothesis import given, note, settings, strategies
|
||||
|
||||
import xgboost as xgb
|
||||
from xgboost import testing as tm
|
||||
from xgboost.testing.params import (
|
||||
exact_parameter_strategy,
|
||||
hist_cache_strategy,
|
||||
hist_multi_parameter_strategy,
|
||||
hist_parameter_strategy,
|
||||
)
|
||||
from xgboost.testing.updater import ResetStrategy, train_result
|
||||
|
||||
|
||||
class TestTreeMethodMulti:
|
||||
@given(
|
||||
exact_parameter_strategy, strategies.integers(1, 20), tm.multi_dataset_strategy
|
||||
)
|
||||
@settings(deadline=None, print_blob=True)
|
||||
def test_exact(self, param: dict, num_rounds: int, dataset: tm.TestDataset) -> None:
|
||||
if dataset.name.endswith("-l1"):
|
||||
return
|
||||
param["tree_method"] = "exact"
|
||||
param = dataset.set_params(param)
|
||||
result = train_result(param, dataset.get_dmat(), num_rounds)
|
||||
assert tm.non_increasing(result["train"][dataset.metric])
|
||||
|
||||
@given(
|
||||
exact_parameter_strategy,
|
||||
hist_parameter_strategy,
|
||||
hist_cache_strategy,
|
||||
strategies.integers(1, 20),
|
||||
tm.multi_dataset_strategy,
|
||||
)
|
||||
@settings(deadline=None, print_blob=True)
|
||||
def test_approx(
|
||||
self,
|
||||
param: Dict[str, Any],
|
||||
hist_param: Dict[str, Any],
|
||||
cache_param: Dict[str, Any],
|
||||
num_rounds: int,
|
||||
dataset: tm.TestDataset,
|
||||
) -> None:
|
||||
param["tree_method"] = "approx"
|
||||
param = dataset.set_params(param)
|
||||
param.update(hist_param)
|
||||
param.update(cache_param)
|
||||
result = train_result(param, dataset.get_dmat(), num_rounds)
|
||||
note(str(result))
|
||||
assert tm.non_increasing(result["train"][dataset.metric])
|
||||
|
||||
@given(
|
||||
exact_parameter_strategy,
|
||||
hist_multi_parameter_strategy,
|
||||
hist_cache_strategy,
|
||||
strategies.integers(1, 20),
|
||||
tm.multi_dataset_strategy,
|
||||
)
|
||||
@settings(deadline=None, print_blob=True)
|
||||
def test_hist(
|
||||
self,
|
||||
param: Dict[str, Any],
|
||||
hist_param: Dict[str, Any],
|
||||
cache_param: Dict[str, Any],
|
||||
num_rounds: int,
|
||||
dataset: tm.TestDataset,
|
||||
) -> None:
|
||||
if dataset.name.endswith("-l1"):
|
||||
return
|
||||
param["tree_method"] = "hist"
|
||||
param = dataset.set_params(param)
|
||||
param.update(hist_param)
|
||||
param.update(cache_param)
|
||||
result = train_result(param, dataset.get_dmat(), num_rounds)
|
||||
note(str(result))
|
||||
assert tm.non_increasing(result["train"][dataset.metric])
|
||||
|
||||
|
||||
def test_multiclass() -> None:
|
||||
X, y = tm.datasets.make_classification(
|
||||
128, n_features=12, n_informative=10, n_classes=4
|
||||
)
|
||||
clf = xgb.XGBClassifier(
|
||||
multi_strategy="multi_output_tree", callbacks=[ResetStrategy()], n_estimators=10
|
||||
)
|
||||
clf.fit(X, y, eval_set=[(X, y)])
|
||||
assert clf.objective == "multi:softprob"
|
||||
assert tm.non_increasing(clf.evals_result()["validation_0"]["mlogloss"])
|
||||
|
||||
proba = clf.predict_proba(X)
|
||||
assert proba.shape == (y.shape[0], 4)
|
||||
|
||||
|
||||
def test_multilabel() -> None:
|
||||
X, y = tm.datasets.make_multilabel_classification(128)
|
||||
clf = xgb.XGBClassifier(
|
||||
multi_strategy="multi_output_tree", callbacks=[ResetStrategy()], n_estimators=10
|
||||
)
|
||||
clf.fit(X, y, eval_set=[(X, y)])
|
||||
assert clf.objective == "binary:logistic"
|
||||
assert tm.non_increasing(clf.evals_result()["validation_0"]["logloss"])
|
||||
|
||||
proba = clf.predict_proba(X)
|
||||
assert proba.shape == y.shape
|
||||
@ -12,7 +12,6 @@ from xgboost.testing.params import (
|
||||
cat_parameter_strategy,
|
||||
exact_parameter_strategy,
|
||||
hist_cache_strategy,
|
||||
hist_multi_parameter_strategy,
|
||||
hist_parameter_strategy,
|
||||
)
|
||||
from xgboost.testing.updater import (
|
||||
@ -25,69 +24,6 @@ from xgboost.testing.updater import (
|
||||
)
|
||||
|
||||
|
||||
class TestTreeMethodMulti:
|
||||
@given(
|
||||
exact_parameter_strategy, strategies.integers(1, 20), tm.multi_dataset_strategy
|
||||
)
|
||||
@settings(deadline=None, print_blob=True)
|
||||
def test_exact(self, param: dict, num_rounds: int, dataset: tm.TestDataset) -> None:
|
||||
if dataset.name.endswith("-l1"):
|
||||
return
|
||||
param["tree_method"] = "exact"
|
||||
param = dataset.set_params(param)
|
||||
result = train_result(param, dataset.get_dmat(), num_rounds)
|
||||
assert tm.non_increasing(result["train"][dataset.metric])
|
||||
|
||||
@given(
|
||||
exact_parameter_strategy,
|
||||
hist_parameter_strategy,
|
||||
hist_cache_strategy,
|
||||
strategies.integers(1, 20),
|
||||
tm.multi_dataset_strategy,
|
||||
)
|
||||
@settings(deadline=None, print_blob=True)
|
||||
def test_approx(
|
||||
self, param: Dict[str, Any],
|
||||
hist_param: Dict[str, Any],
|
||||
cache_param: Dict[str, Any],
|
||||
num_rounds: int,
|
||||
dataset: tm.TestDataset,
|
||||
) -> None:
|
||||
param["tree_method"] = "approx"
|
||||
param = dataset.set_params(param)
|
||||
param.update(hist_param)
|
||||
param.update(cache_param)
|
||||
result = train_result(param, dataset.get_dmat(), num_rounds)
|
||||
note(str(result))
|
||||
assert tm.non_increasing(result["train"][dataset.metric])
|
||||
|
||||
@given(
|
||||
exact_parameter_strategy,
|
||||
hist_multi_parameter_strategy,
|
||||
hist_cache_strategy,
|
||||
strategies.integers(1, 20),
|
||||
tm.multi_dataset_strategy,
|
||||
)
|
||||
@settings(deadline=None, print_blob=True)
|
||||
def test_hist(
|
||||
self,
|
||||
param: Dict[str, Any],
|
||||
hist_param: Dict[str, Any],
|
||||
cache_param: Dict[str, Any],
|
||||
num_rounds: int,
|
||||
dataset: tm.TestDataset,
|
||||
) -> None:
|
||||
if dataset.name.endswith("-l1"):
|
||||
return
|
||||
param["tree_method"] = "hist"
|
||||
param = dataset.set_params(param)
|
||||
param.update(hist_param)
|
||||
param.update(cache_param)
|
||||
result = train_result(param, dataset.get_dmat(), num_rounds)
|
||||
note(str(result))
|
||||
assert tm.non_increasing(result["train"][dataset.metric])
|
||||
|
||||
|
||||
class TestTreeMethod:
|
||||
USE_ONEHOT = np.iinfo(np.int32).max
|
||||
USE_PART = 1
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user