Fix multi-output with alternating strategies. (#9933)

---------

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
Author: Jiaming Yuan, 2024-01-04 16:41:13 +08:00 (committed by GitHub)
parent 5f7b5a6921
commit 621348abb3
6 changed files with 123 additions and 73 deletions
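
What the fix enables, end to end: a single Booster can now flip multi_strategy between boosting rounds (for example from a callback) without the prediction cache being routed to the wrong updater. Below is a minimal sketch of that scenario, modeled on the ResetStrategy callback this commit adds; the AlternateStrategy name and the synthetic data are illustrative only, not part of the commit.

import numpy as np
import xgboost as xgb


class AlternateStrategy(xgb.callback.TrainingCallback):
    """Flip the tree-building strategy on every boosting round."""

    def after_iteration(self, model: xgb.Booster, epoch: int, evals_log: dict) -> bool:
        if epoch % 2 == 0:
            model.set_param({"multi_strategy": "multi_output_tree"})
        else:
            model.set_param({"multi_strategy": "one_output_per_tree"})
        return False  # returning False continues training


rng = np.random.default_rng(0)
X = rng.normal(size=(256, 8))
y = rng.normal(size=(256, 3))  # three regression targets

# Before this fix, alternating strategies could hand the prediction-cache
# update to the wrong updater implementation.
booster = xgb.train(
    {"tree_method": "hist", "multi_strategy": "multi_output_tree"},
    xgb.DMatrix(X, label=y),
    num_boost_round=8,
    callbacks=[AlternateStrategy()],
)
print(booster.predict(xgb.DMatrix(X)).shape)  # (256, 3)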

python-package/xgboost/testing/updater.py

@@ -394,3 +394,14 @@ def train_result(
     assert booster.feature_types == dmat.feature_types
     return result
+
+
+class ResetStrategy(xgb.callback.TrainingCallback):
+    """Callback for testing multi-output."""
+
+    def after_iteration(self, model: xgb.Booster, epoch: int, evals_log: dict) -> bool:
+        if epoch % 2 == 0:
+            model.set_param({"multi_strategy": "multi_output_tree"})
+        else:
+            model.set_param({"multi_strategy": "one_output_per_tree"})
+        return False
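
Design note: alternating on even/odd epochs guarantees that both the single-target and the multi-target code paths run within one training session on the same Booster, which is exactly the condition needed to reproduce the cache-dispatch bug addressed in the updater change below.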

src/tree/updater_quantile_hist.cc

@@ -545,12 +545,12 @@ class QuantileHistMaker : public TreeUpdater {
   }
 
   bool UpdatePredictionCache(const DMatrix *data, linalg::MatrixView<float> out_preds) override {
-    if (p_impl_) {
-      return p_impl_->UpdatePredictionCache(data, out_preds);
-    } else if (p_mtimpl_) {
+    if (out_preds.Shape(1) > 1) {
+      CHECK(p_mtimpl_);
       return p_mtimpl_->UpdatePredictionCache(data, out_preds);
     } else {
-      return false;
+      CHECK(p_impl_);
+      return p_impl_->UpdatePredictionCache(data, out_preds);
     }
   }
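
The fix in a sentence: UpdatePredictionCache previously picked an implementation by pointer order, preferring p_impl_ whenever it existed. Once strategies alternate, both p_impl_ and p_mtimpl_ can be alive on the same updater, so the old check could hand a multi-target cache update to the single-target implementation, or fall through to return false. Dispatching on the width of out_preds selects the updater from the shape of the cached predictions instead, and the CHECKs assert that the matching implementation was actually constructed.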

tests/ci_build/lint_python.py

@@ -22,6 +22,7 @@ class LintersPaths:
         "tests/python/test_dmatrix.py",
         "tests/python/test_dt.py",
         "tests/python/test_demos.py",
+        "tests/python/test_multi_target.py",
         "tests/python/test_predict.py",
         "tests/python/test_quantile_dmatrix.py",
         "tests/python/test_tree_regularization.py",
@@ -79,6 +80,7 @@ class LintersPaths:
         "tests/python/test_dt.py",
         "tests/python/test_demos.py",
         "tests/python/test_data_iterator.py",
+        "tests/python/test_multi_target.py",
         "tests/python-gpu/test_gpu_data_iterator.py",
         "tests/python-gpu/load_pickle.py",
         "tests/test_distributed/test_with_spark/test_data.py",

tests/python/test_basic_models.py

@@ -8,6 +8,7 @@ import pytest
 import xgboost as xgb
 from xgboost import testing as tm
+from xgboost.testing.updater import ResetStrategy
 
 dpath = tm.data_dir(__file__)
@@ -653,11 +654,6 @@ class TestModels:
         num_parallel_tree = 4
         num_boost_round = 16
 
-        class ResetStrategy(xgb.callback.TrainingCallback):
-            def after_iteration(self, model, epoch: int, evals_log) -> bool:
-                model.set_param({"multi_strategy": "multi_output_tree"})
-                return False
-
         booster = xgb.train(
             {
                 "num_parallel_tree": num_parallel_tree,

tests/python/test_multi_target.py (new file)

@@ -0,0 +1,105 @@
+from typing import Any, Dict
+
+from hypothesis import given, note, settings, strategies
+
+import xgboost as xgb
+from xgboost import testing as tm
+from xgboost.testing.params import (
+    exact_parameter_strategy,
+    hist_cache_strategy,
+    hist_multi_parameter_strategy,
+    hist_parameter_strategy,
+)
+from xgboost.testing.updater import ResetStrategy, train_result
+
+
+class TestTreeMethodMulti:
+    @given(
+        exact_parameter_strategy, strategies.integers(1, 20), tm.multi_dataset_strategy
+    )
+    @settings(deadline=None, print_blob=True)
+    def test_exact(self, param: dict, num_rounds: int, dataset: tm.TestDataset) -> None:
+        if dataset.name.endswith("-l1"):
+            return
+        param["tree_method"] = "exact"
+        param = dataset.set_params(param)
+        result = train_result(param, dataset.get_dmat(), num_rounds)
+        assert tm.non_increasing(result["train"][dataset.metric])
+
+    @given(
+        exact_parameter_strategy,
+        hist_parameter_strategy,
+        hist_cache_strategy,
+        strategies.integers(1, 20),
+        tm.multi_dataset_strategy,
+    )
+    @settings(deadline=None, print_blob=True)
+    def test_approx(
+        self,
+        param: Dict[str, Any],
+        hist_param: Dict[str, Any],
+        cache_param: Dict[str, Any],
+        num_rounds: int,
+        dataset: tm.TestDataset,
+    ) -> None:
+        param["tree_method"] = "approx"
+        param = dataset.set_params(param)
+        param.update(hist_param)
+        param.update(cache_param)
+        result = train_result(param, dataset.get_dmat(), num_rounds)
+        note(str(result))
+        assert tm.non_increasing(result["train"][dataset.metric])
+
+    @given(
+        exact_parameter_strategy,
+        hist_multi_parameter_strategy,
+        hist_cache_strategy,
+        strategies.integers(1, 20),
+        tm.multi_dataset_strategy,
+    )
+    @settings(deadline=None, print_blob=True)
+    def test_hist(
+        self,
+        param: Dict[str, Any],
+        hist_param: Dict[str, Any],
+        cache_param: Dict[str, Any],
+        num_rounds: int,
+        dataset: tm.TestDataset,
+    ) -> None:
+        if dataset.name.endswith("-l1"):
+            return
+        param["tree_method"] = "hist"
+        param = dataset.set_params(param)
+        param.update(hist_param)
+        param.update(cache_param)
+        result = train_result(param, dataset.get_dmat(), num_rounds)
+        note(str(result))
+        assert tm.non_increasing(result["train"][dataset.metric])
+
+
+def test_multiclass() -> None:
+    X, y = tm.datasets.make_classification(
+        128, n_features=12, n_informative=10, n_classes=4
+    )
+    clf = xgb.XGBClassifier(
+        multi_strategy="multi_output_tree", callbacks=[ResetStrategy()], n_estimators=10
+    )
+    clf.fit(X, y, eval_set=[(X, y)])
+    assert clf.objective == "multi:softprob"
+    assert tm.non_increasing(clf.evals_result()["validation_0"]["mlogloss"])
+    proba = clf.predict_proba(X)
+    assert proba.shape == (y.shape[0], 4)
+
+
+def test_multilabel() -> None:
+    X, y = tm.datasets.make_multilabel_classification(128)
+    clf = xgb.XGBClassifier(
+        multi_strategy="multi_output_tree", callbacks=[ResetStrategy()], n_estimators=10
+    )
+    clf.fit(X, y, eval_set=[(X, y)])
+    assert clf.objective == "binary:logistic"
+    assert tm.non_increasing(clf.evals_result()["validation_0"]["logloss"])
+    proba = clf.predict_proba(X)
+    assert proba.shape == y.shape
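
Note that TestTreeMethodMulti above is the same suite removed from the updater tests in the next file; the move is content-identical. The two classifier tests are new: they attach ResetStrategy so the strategy alternates mid-fit, then assert that the eval metric stays non-increasing and that predict_proba keeps the expected shape.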

tests/python/test_updaters.py

@@ -12,7 +12,6 @@ from xgboost.testing.params import (
     cat_parameter_strategy,
     exact_parameter_strategy,
     hist_cache_strategy,
-    hist_multi_parameter_strategy,
     hist_parameter_strategy,
 )
 from xgboost.testing.updater import (
@@ -25,69 +24,6 @@ from xgboost.testing.updater import (
 )
-
-
-class TestTreeMethodMulti:
-    @given(
-        exact_parameter_strategy, strategies.integers(1, 20), tm.multi_dataset_strategy
-    )
-    @settings(deadline=None, print_blob=True)
-    def test_exact(self, param: dict, num_rounds: int, dataset: tm.TestDataset) -> None:
-        if dataset.name.endswith("-l1"):
-            return
-        param["tree_method"] = "exact"
-        param = dataset.set_params(param)
-        result = train_result(param, dataset.get_dmat(), num_rounds)
-        assert tm.non_increasing(result["train"][dataset.metric])
-
-    @given(
-        exact_parameter_strategy,
-        hist_parameter_strategy,
-        hist_cache_strategy,
-        strategies.integers(1, 20),
-        tm.multi_dataset_strategy,
-    )
-    @settings(deadline=None, print_blob=True)
-    def test_approx(
-        self, param: Dict[str, Any],
-        hist_param: Dict[str, Any],
-        cache_param: Dict[str, Any],
-        num_rounds: int,
-        dataset: tm.TestDataset,
-    ) -> None:
-        param["tree_method"] = "approx"
-        param = dataset.set_params(param)
-        param.update(hist_param)
-        param.update(cache_param)
-        result = train_result(param, dataset.get_dmat(), num_rounds)
-        note(str(result))
-        assert tm.non_increasing(result["train"][dataset.metric])
-
-    @given(
-        exact_parameter_strategy,
-        hist_multi_parameter_strategy,
-        hist_cache_strategy,
-        strategies.integers(1, 20),
-        tm.multi_dataset_strategy,
-    )
-    @settings(deadline=None, print_blob=True)
-    def test_hist(
-        self,
-        param: Dict[str, Any],
-        hist_param: Dict[str, Any],
-        cache_param: Dict[str, Any],
-        num_rounds: int,
-        dataset: tm.TestDataset,
-    ) -> None:
-        if dataset.name.endswith("-l1"):
-            return
-        param["tree_method"] = "hist"
-        param = dataset.set_params(param)
-        param.update(hist_param)
-        param.update(cache_param)
-        result = train_result(param, dataset.get_dmat(), num_rounds)
-        note(str(result))
-        assert tm.non_increasing(result["train"][dataset.metric])
-
-
 class TestTreeMethod:
     USE_ONEHOT = np.iinfo(np.int32).max
     USE_PART = 1