Fix multi-output with alternating strategies. (#9933)
Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
parent
5f7b5a6921
commit
621348abb3
@ -394,3 +394,14 @@ def train_result(
|
|||||||
assert booster.feature_types == dmat.feature_types
|
assert booster.feature_types == dmat.feature_types
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class ResetStrategy(xgb.callback.TrainingCallback):
    """Test callback that alternates ``multi_strategy`` between boosting rounds.

    After every even-numbered epoch the booster is switched to
    ``multi_output_tree``; after every odd-numbered epoch it is switched to
    ``one_output_per_tree``.
    """

    def after_iteration(self, model: xgb.Booster, epoch: int, evals_log: dict) -> bool:
        """Set the strategy for the next round based on the parity of ``epoch``.

        ``evals_log`` is accepted to satisfy the callback signature but is not
        read.  The method always returns ``False``.
        """
        # NOTE(review): returning False presumably tells xgboost to keep
        # training -- confirm against the TrainingCallback documentation.
        strategy = "multi_output_tree" if epoch % 2 == 0 else "one_output_per_tree"
        model.set_param({"multi_strategy": strategy})
        return False
|
||||||
|
|||||||
@ -545,12 +545,12 @@ class QuantileHistMaker : public TreeUpdater {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool UpdatePredictionCache(const DMatrix *data, linalg::MatrixView<float> out_preds) override {
|
bool UpdatePredictionCache(const DMatrix *data, linalg::MatrixView<float> out_preds) override {
|
||||||
if (p_impl_) {
|
if (out_preds.Shape(1) > 1) {
|
||||||
return p_impl_->UpdatePredictionCache(data, out_preds);
|
CHECK(p_mtimpl_);
|
||||||
} else if (p_mtimpl_) {
|
|
||||||
return p_mtimpl_->UpdatePredictionCache(data, out_preds);
|
return p_mtimpl_->UpdatePredictionCache(data, out_preds);
|
||||||
} else {
|
} else {
|
||||||
return false;
|
CHECK(p_impl_);
|
||||||
|
return p_impl_->UpdatePredictionCache(data, out_preds);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -22,6 +22,7 @@ class LintersPaths:
|
|||||||
"tests/python/test_dmatrix.py",
|
"tests/python/test_dmatrix.py",
|
||||||
"tests/python/test_dt.py",
|
"tests/python/test_dt.py",
|
||||||
"tests/python/test_demos.py",
|
"tests/python/test_demos.py",
|
||||||
|
"tests/python/test_multi_target.py",
|
||||||
"tests/python/test_predict.py",
|
"tests/python/test_predict.py",
|
||||||
"tests/python/test_quantile_dmatrix.py",
|
"tests/python/test_quantile_dmatrix.py",
|
||||||
"tests/python/test_tree_regularization.py",
|
"tests/python/test_tree_regularization.py",
|
||||||
@ -79,6 +80,7 @@ class LintersPaths:
|
|||||||
"tests/python/test_dt.py",
|
"tests/python/test_dt.py",
|
||||||
"tests/python/test_demos.py",
|
"tests/python/test_demos.py",
|
||||||
"tests/python/test_data_iterator.py",
|
"tests/python/test_data_iterator.py",
|
||||||
|
"tests/python/test_multi_target.py",
|
||||||
"tests/python-gpu/test_gpu_data_iterator.py",
|
"tests/python-gpu/test_gpu_data_iterator.py",
|
||||||
"tests/python-gpu/load_pickle.py",
|
"tests/python-gpu/load_pickle.py",
|
||||||
"tests/test_distributed/test_with_spark/test_data.py",
|
"tests/test_distributed/test_with_spark/test_data.py",
|
||||||
|
|||||||
@ -8,6 +8,7 @@ import pytest
|
|||||||
|
|
||||||
import xgboost as xgb
|
import xgboost as xgb
|
||||||
from xgboost import testing as tm
|
from xgboost import testing as tm
|
||||||
|
from xgboost.testing.updater import ResetStrategy
|
||||||
|
|
||||||
dpath = tm.data_dir(__file__)
|
dpath = tm.data_dir(__file__)
|
||||||
|
|
||||||
@ -653,11 +654,6 @@ class TestModels:
|
|||||||
num_parallel_tree = 4
|
num_parallel_tree = 4
|
||||||
num_boost_round = 16
|
num_boost_round = 16
|
||||||
|
|
||||||
class ResetStrategy(xgb.callback.TrainingCallback):
|
|
||||||
def after_iteration(self, model, epoch: int, evals_log) -> bool:
|
|
||||||
model.set_param({"multi_strategy": "multi_output_tree"})
|
|
||||||
return False
|
|
||||||
|
|
||||||
booster = xgb.train(
|
booster = xgb.train(
|
||||||
{
|
{
|
||||||
"num_parallel_tree": num_parallel_tree,
|
"num_parallel_tree": num_parallel_tree,
|
||||||
|
|||||||
105
tests/python/test_multi_target.py
Normal file
105
tests/python/test_multi_target.py
Normal file
@ -0,0 +1,105 @@
|
|||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
from hypothesis import given, note, settings, strategies
|
||||||
|
|
||||||
|
import xgboost as xgb
|
||||||
|
from xgboost import testing as tm
|
||||||
|
from xgboost.testing.params import (
|
||||||
|
exact_parameter_strategy,
|
||||||
|
hist_cache_strategy,
|
||||||
|
hist_multi_parameter_strategy,
|
||||||
|
hist_parameter_strategy,
|
||||||
|
)
|
||||||
|
from xgboost.testing.updater import ResetStrategy, train_result
|
||||||
|
|
||||||
|
|
||||||
|
class TestTreeMethodMulti:
    """Property-based training checks on multi-target datasets.

    Every test draws its parameters, round count and dataset from hypothesis
    strategies, trains with one ``tree_method`` and asserts that the training
    metric reported for the dataset is non-increasing.
    """

    @given(
        exact_parameter_strategy, strategies.integers(1, 20), tm.multi_dataset_strategy
    )
    @settings(deadline=None, print_blob=True)
    def test_exact(
        self, param: Dict[str, Any], num_rounds: int, dataset: tm.TestDataset
    ) -> None:
        """Train with ``tree_method="exact"``; datasets named ``*-l1`` are skipped."""
        if dataset.name.endswith("-l1"):
            return
        param["tree_method"] = "exact"
        param = dataset.set_params(param)
        result = train_result(param, dataset.get_dmat(), num_rounds)
        # Record the evaluation history in hypothesis' failure report, the same
        # way the approx and hist tests below do.
        note(str(result))
        assert tm.non_increasing(result["train"][dataset.metric])

    @given(
        exact_parameter_strategy,
        hist_parameter_strategy,
        hist_cache_strategy,
        strategies.integers(1, 20),
        tm.multi_dataset_strategy,
    )
    @settings(deadline=None, print_blob=True)
    def test_approx(
        self,
        param: Dict[str, Any],
        hist_param: Dict[str, Any],
        cache_param: Dict[str, Any],
        num_rounds: int,
        dataset: tm.TestDataset,
    ) -> None:
        """Train with ``tree_method="approx"`` plus the drawn hist/cache parameters."""
        param["tree_method"] = "approx"
        param = dataset.set_params(param)
        # Later updates win: hist and cache settings override earlier keys.
        param.update(hist_param)
        param.update(cache_param)
        result = train_result(param, dataset.get_dmat(), num_rounds)
        note(str(result))
        assert tm.non_increasing(result["train"][dataset.metric])

    @given(
        exact_parameter_strategy,
        hist_multi_parameter_strategy,
        hist_cache_strategy,
        strategies.integers(1, 20),
        tm.multi_dataset_strategy,
    )
    @settings(deadline=None, print_blob=True)
    def test_hist(
        self,
        param: Dict[str, Any],
        hist_param: Dict[str, Any],
        cache_param: Dict[str, Any],
        num_rounds: int,
        dataset: tm.TestDataset,
    ) -> None:
        """Train with ``tree_method="hist"``; datasets named ``*-l1`` are skipped."""
        if dataset.name.endswith("-l1"):
            return
        param["tree_method"] = "hist"
        param = dataset.set_params(param)
        # Later updates win: hist and cache settings override earlier keys.
        param.update(hist_param)
        param.update(cache_param)
        result = train_result(param, dataset.get_dmat(), num_rounds)
        note(str(result))
        assert tm.non_increasing(result["train"][dataset.metric])
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiclass() -> None:
    """Fit a multi-class classifier with ``ResetStrategy`` attached as a callback.

    Checks that the estimator reports the ``multi:softprob`` objective, that
    ``mlogloss`` on the evaluation set is non-increasing, and that
    ``predict_proba`` yields one column per class.
    """
    n_classes = 4
    # Seed the generator so the monotonic-loss assertion runs on the same data
    # every time.  NOTE(review): assumes ``tm.datasets`` is ``sklearn.datasets``,
    # whose generators accept ``random_state`` -- confirm.
    X, y = tm.datasets.make_classification(
        128, n_features=12, n_informative=10, n_classes=n_classes, random_state=1994
    )
    clf = xgb.XGBClassifier(
        multi_strategy="multi_output_tree", callbacks=[ResetStrategy()], n_estimators=10
    )
    # The training data doubles as the evaluation set.
    clf.fit(X, y, eval_set=[(X, y)])
    assert clf.objective == "multi:softprob"
    assert tm.non_increasing(clf.evals_result()["validation_0"]["mlogloss"])

    proba = clf.predict_proba(X)
    assert proba.shape == (y.shape[0], n_classes)
|
||||||
|
|
||||||
|
|
||||||
|
def test_multilabel() -> None:
    """Fit a multi-label classifier with ``ResetStrategy`` attached as a callback.

    Checks that the estimator reports the ``binary:logistic`` objective, that
    ``logloss`` on the evaluation set is non-increasing, and that
    ``predict_proba`` has the same shape as the label matrix.
    """
    # Seed the generator so the monotonic-loss assertion runs on the same data
    # every time.  NOTE(review): assumes ``tm.datasets`` is ``sklearn.datasets``,
    # whose generators accept ``random_state`` -- confirm.
    X, y = tm.datasets.make_multilabel_classification(128, random_state=1994)
    clf = xgb.XGBClassifier(
        multi_strategy="multi_output_tree", callbacks=[ResetStrategy()], n_estimators=10
    )
    # The training data doubles as the evaluation set.
    clf.fit(X, y, eval_set=[(X, y)])
    assert clf.objective == "binary:logistic"
    assert tm.non_increasing(clf.evals_result()["validation_0"]["logloss"])

    proba = clf.predict_proba(X)
    assert proba.shape == y.shape
|
||||||
@ -12,7 +12,6 @@ from xgboost.testing.params import (
|
|||||||
cat_parameter_strategy,
|
cat_parameter_strategy,
|
||||||
exact_parameter_strategy,
|
exact_parameter_strategy,
|
||||||
hist_cache_strategy,
|
hist_cache_strategy,
|
||||||
hist_multi_parameter_strategy,
|
|
||||||
hist_parameter_strategy,
|
hist_parameter_strategy,
|
||||||
)
|
)
|
||||||
from xgboost.testing.updater import (
|
from xgboost.testing.updater import (
|
||||||
@ -25,69 +24,6 @@ from xgboost.testing.updater import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestTreeMethodMulti:
|
|
||||||
@given(
|
|
||||||
exact_parameter_strategy, strategies.integers(1, 20), tm.multi_dataset_strategy
|
|
||||||
)
|
|
||||||
@settings(deadline=None, print_blob=True)
|
|
||||||
def test_exact(self, param: dict, num_rounds: int, dataset: tm.TestDataset) -> None:
|
|
||||||
if dataset.name.endswith("-l1"):
|
|
||||||
return
|
|
||||||
param["tree_method"] = "exact"
|
|
||||||
param = dataset.set_params(param)
|
|
||||||
result = train_result(param, dataset.get_dmat(), num_rounds)
|
|
||||||
assert tm.non_increasing(result["train"][dataset.metric])
|
|
||||||
|
|
||||||
@given(
|
|
||||||
exact_parameter_strategy,
|
|
||||||
hist_parameter_strategy,
|
|
||||||
hist_cache_strategy,
|
|
||||||
strategies.integers(1, 20),
|
|
||||||
tm.multi_dataset_strategy,
|
|
||||||
)
|
|
||||||
@settings(deadline=None, print_blob=True)
|
|
||||||
def test_approx(
|
|
||||||
self, param: Dict[str, Any],
|
|
||||||
hist_param: Dict[str, Any],
|
|
||||||
cache_param: Dict[str, Any],
|
|
||||||
num_rounds: int,
|
|
||||||
dataset: tm.TestDataset,
|
|
||||||
) -> None:
|
|
||||||
param["tree_method"] = "approx"
|
|
||||||
param = dataset.set_params(param)
|
|
||||||
param.update(hist_param)
|
|
||||||
param.update(cache_param)
|
|
||||||
result = train_result(param, dataset.get_dmat(), num_rounds)
|
|
||||||
note(str(result))
|
|
||||||
assert tm.non_increasing(result["train"][dataset.metric])
|
|
||||||
|
|
||||||
@given(
|
|
||||||
exact_parameter_strategy,
|
|
||||||
hist_multi_parameter_strategy,
|
|
||||||
hist_cache_strategy,
|
|
||||||
strategies.integers(1, 20),
|
|
||||||
tm.multi_dataset_strategy,
|
|
||||||
)
|
|
||||||
@settings(deadline=None, print_blob=True)
|
|
||||||
def test_hist(
|
|
||||||
self,
|
|
||||||
param: Dict[str, Any],
|
|
||||||
hist_param: Dict[str, Any],
|
|
||||||
cache_param: Dict[str, Any],
|
|
||||||
num_rounds: int,
|
|
||||||
dataset: tm.TestDataset,
|
|
||||||
) -> None:
|
|
||||||
if dataset.name.endswith("-l1"):
|
|
||||||
return
|
|
||||||
param["tree_method"] = "hist"
|
|
||||||
param = dataset.set_params(param)
|
|
||||||
param.update(hist_param)
|
|
||||||
param.update(cache_param)
|
|
||||||
result = train_result(param, dataset.get_dmat(), num_rounds)
|
|
||||||
note(str(result))
|
|
||||||
assert tm.non_increasing(result["train"][dataset.metric])
|
|
||||||
|
|
||||||
|
|
||||||
class TestTreeMethod:
|
class TestTreeMethod:
|
||||||
USE_ONEHOT = np.iinfo(np.int32).max
|
USE_ONEHOT = np.iinfo(np.int32).max
|
||||||
USE_PART = 1
|
USE_PART = 1
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user