Use hist as the default tree method. (#9320)

Jiaming Yuan 2023-06-27 23:04:24 +08:00 committed by GitHub
parent bc267dd729
commit f4798718c7
10 changed files with 138 additions and 178 deletions
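In practical terms, a booster trained without an explicit tree_method now uses the hist algorithm (updater grow_quantile_histmaker) instead of the old data-dependent exact/approx heuristic. The following is a minimal sketch of the user-facing change, not part of this commit; the data and parameter values are arbitrary:

    import numpy as np
    import xgboost as xgb

    rng = np.random.RandomState(1994)
    dtrain = xgb.DMatrix(rng.randn(256, 8), label=rng.randn(256))

    # No tree_method given: with this change training resolves to hist
    # rather than the old heuristic that picked exact or approx.
    bst = xgb.train({"objective": "reg:squarederror"}, dtrain, num_boost_round=10)

    # The previous single-machine default is still available on request.
    bst_exact = xgb.train(
        {"objective": "reg:squarederror", "tree_method": "exact"},
        dtrain,
        num_boost_round=10,
    )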

View File

@@ -85,9 +85,18 @@ test_that("dart prediction works", {
   rnorm(100)
   set.seed(1994)
-  booster_by_xgboost <- xgboost(data = d, label = y, max_depth = 2, booster = "dart",
-                                rate_drop = 0.5, one_drop = TRUE,
-                                eta = 1, nthread = 2, nrounds = nrounds, objective = "reg:squarederror")
+  booster_by_xgboost <- xgboost(
+    data = d,
+    label = y,
+    max_depth = 2,
+    booster = "dart",
+    rate_drop = 0.5,
+    one_drop = TRUE,
+    eta = 1,
+    nthread = 2,
+    nrounds = nrounds,
+    objective = "reg:squarederror"
+  )
   pred_by_xgboost_0 <- predict(booster_by_xgboost, newdata = d, ntreelimit = 0)
   pred_by_xgboost_1 <- predict(booster_by_xgboost, newdata = d, ntreelimit = nrounds)
   expect_true(all(matrix(pred_by_xgboost_0, byrow = TRUE) == matrix(pred_by_xgboost_1, byrow = TRUE)))
@@ -97,19 +106,19 @@ test_that("dart prediction works", {
   set.seed(1994)
   dtrain <- xgb.DMatrix(data = d, info = list(label = y))
-  booster_by_train <- xgb.train(params = list(
-    booster = "dart",
-    max_depth = 2,
-    eta = 1,
-    rate_drop = 0.5,
-    one_drop = TRUE,
-    nthread = 1,
-    tree_method = "exact",
-    objective = "reg:squarederror"
-  ),
-  data = dtrain,
-  nrounds = nrounds
-  )
+  booster_by_train <- xgb.train(
+    params = list(
+      booster = "dart",
+      max_depth = 2,
+      eta = 1,
+      rate_drop = 0.5,
+      one_drop = TRUE,
+      nthread = 1,
+      objective = "reg:squarederror"
+    ),
+    data = dtrain,
+    nrounds = nrounds
+  )
   pred_by_train_0 <- predict(booster_by_train, newdata = dtrain, ntreelimit = 0)
   pred_by_train_1 <- predict(booster_by_train, newdata = dtrain, ntreelimit = nrounds)
   pred_by_train_2 <- predict(booster_by_train, newdata = dtrain, training = TRUE)
@@ -399,7 +408,7 @@ test_that("colsample_bytree works", {
   xgb.importance(model = bst)
   # If colsample_bytree works properly, a variety of features should be used
   # in the 100 trees
-  expect_gte(nrow(xgb.importance(model = bst)), 30)
+  expect_gte(nrow(xgb.importance(model = bst)), 28)
 })

 test_that("Configuration works", {

View File

@@ -13,7 +13,10 @@ test_that("updating the model works", {
   watchlist <- list(train = dtrain, test = dtest)
   # no-subsampling
-  p1 <- list(objective = "binary:logistic", max_depth = 2, eta = 0.05, nthread = 2)
+  p1 <- list(
+    objective = "binary:logistic", max_depth = 2, eta = 0.05, nthread = 2,
+    updater = "grow_colmaker,prune"
+  )
   set.seed(11)
   bst1 <- xgb.train(p1, dtrain, nrounds = 10, watchlist, verbose = 0)
   tr1 <- xgb.model.dt.tree(model = bst1)

View File

@@ -39,7 +39,6 @@ namespace xgboost::gbm {
 DMLC_REGISTRY_FILE_TAG(gbtree);

 void GBTree::Configure(Args const& cfg) {
-  this->cfg_ = cfg;
   std::string updater_seq = tparam_.updater_seq;
   tparam_.UpdateAllowUnknown(cfg);
   tree_param_.UpdateAllowUnknown(cfg);
@@ -78,10 +77,9 @@ void GBTree::Configure(Args const& cfg) {
   monitor_.Init("GBTree");

-  specified_updater_ = std::any_of(cfg.cbegin(), cfg.cend(),
-                                   [](std::pair<std::string, std::string> const& arg) {
-                                     return arg.first == "updater";
-                                   });
+  specified_updater_ = std::any_of(
+      cfg.cbegin(), cfg.cend(),
+      [](std::pair<std::string, std::string> const& arg) { return arg.first == "updater"; });

   if (specified_updater_ && !showed_updater_warning_) {
     LOG(WARNING) << "DANGER AHEAD: You have manually specified `updater` "
@@ -93,12 +91,19 @@ void GBTree::Configure(Args const& cfg) {
     showed_updater_warning_ = true;
   }

+  if (model_.learner_model_param->IsVectorLeaf()) {
+    CHECK(tparam_.tree_method == TreeMethod::kHist || tparam_.tree_method == TreeMethod::kAuto)
+        << "Only the hist tree method is supported for building multi-target trees with vector "
+           "leaf.";
+  }
+
+  LOG(DEBUG) << "Using tree method: " << static_cast<int>(tparam_.tree_method);
   this->ConfigureUpdaters();
   if (updater_seq != tparam_.updater_seq) {
     updaters_.clear();
     this->InitUpdater(cfg);
   } else {
-    for (auto &up : updaters_) {
+    for (auto& up : updaters_) {
       up->Configure(cfg);
     }
   }
@@ -106,66 +111,6 @@ void GBTree::Configure(Args const& cfg) {
   configured_ = true;
 }

-// FIXME(trivialfis): This handles updaters. Because the choice of updaters depends on
-// whether external memory is used and how large is dataset. We can remove the dependency
-// on DMatrix once `hist` tree method can handle external memory so that we can make it
-// default.
-void GBTree::ConfigureWithKnownData(Args const& cfg, DMatrix* fmat) {
-  CHECK(this->configured_);
-  std::string updater_seq = tparam_.updater_seq;
-  CHECK(tparam_.GetInitialised());
-
-  tparam_.UpdateAllowUnknown(cfg);
-
-  this->PerformTreeMethodHeuristic(fmat);
-  this->ConfigureUpdaters();
-
-  // initialize the updaters only when needed.
-  if (updater_seq != tparam_.updater_seq) {
-    LOG(DEBUG) << "Using updaters: " << tparam_.updater_seq;
-    this->updaters_.clear();
-    this->InitUpdater(cfg);
-  }
-}
-
-void GBTree::PerformTreeMethodHeuristic(DMatrix* fmat) {
-  if (specified_updater_) {
-    // This method is disabled when `updater` parameter is explicitly
-    // set, since only experts are expected to do so.
-    return;
-  }
-  if (model_.learner_model_param->IsVectorLeaf()) {
-    CHECK(tparam_.tree_method == TreeMethod::kHist)
-        << "Only the hist tree method is supported for building multi-target trees with vector "
-           "leaf.";
-  }
-
-  // tparam_ is set before calling this function.
-  if (tparam_.tree_method != TreeMethod::kAuto) {
-    return;
-  }
-
-  if (collective::IsDistributed()) {
-    LOG(INFO) << "Tree method is automatically selected to be 'approx' "
-                 "for distributed training.";
-    tparam_.tree_method = TreeMethod::kApprox;
-  } else if (!fmat->SingleColBlock()) {
-    LOG(INFO) << "Tree method is automatically set to 'approx' "
-                 "since external-memory data matrix is used.";
-    tparam_.tree_method = TreeMethod::kApprox;
-  } else if (fmat->Info().num_row_ >= (4UL << 20UL)) {
-    /* Choose tree_method='approx' automatically for large data matrix */
-    LOG(INFO) << "Tree method is automatically selected to be "
-                 "'approx' for faster speed. To use old behavior "
-                 "(exact greedy algorithm on single machine), "
-                 "set tree_method to 'exact'.";
-    tparam_.tree_method = TreeMethod::kApprox;
-  } else {
-    tparam_.tree_method = TreeMethod::kExact;
-  }
-  LOG(DEBUG) << "Using tree method: " << static_cast<int>(tparam_.tree_method);
-}
-
 void GBTree::ConfigureUpdaters() {
   if (specified_updater_) {
     return;
@@ -173,31 +118,25 @@ void GBTree::ConfigureUpdaters() {
   // `updater` parameter was manually specified
   /* Choose updaters according to tree_method parameters */
   switch (tparam_.tree_method) {
-    case TreeMethod::kAuto:
-      // Use heuristic to choose between 'exact' and 'approx' This
-      // choice is carried out in PerformTreeMethodHeuristic() before
-      // calling this function.
-      break;
+    case TreeMethod::kAuto:  // Use hist as default in 2.0
+    case TreeMethod::kHist: {
+      tparam_.updater_seq = "grow_quantile_histmaker";
+      break;
+    }
     case TreeMethod::kApprox:
       tparam_.updater_seq = "grow_histmaker";
       break;
     case TreeMethod::kExact:
       tparam_.updater_seq = "grow_colmaker,prune";
       break;
-    case TreeMethod::kHist: {
-      LOG(INFO) << "Tree method is selected to be 'hist', which uses a single updater "
-                   "grow_quantile_histmaker.";
-      tparam_.updater_seq = "grow_quantile_histmaker";
-      break;
-    }
     case TreeMethod::kGPUHist: {
       common::AssertGPUSupport();
       tparam_.updater_seq = "grow_gpu_hist";
       break;
     }
     default:
-      LOG(FATAL) << "Unknown tree_method ("
-                 << static_cast<int>(tparam_.tree_method) << ") detected";
+      LOG(FATAL) << "Unknown tree_method (" << static_cast<int>(tparam_.tree_method)
+                 << ") detected";
   }
 }
@@ -253,7 +192,6 @@ void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector<GradientPair>* in_gpair,
                      PredictionCacheEntry* predt, ObjFunction const* obj) {
   TreesOneIter new_trees;
   bst_target_t const n_groups = model_.learner_model_param->OutputLength();
-  ConfigureWithKnownData(this->cfg_, p_fmat);
   monitor_.Start("BoostNewTrees");

   // Weird case that tree method is cpu-based but gpu_id is set. Ideally we should let
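For orientation, the switch in ConfigureUpdaters() above amounts to the following mapping from tree_method to updater sequence. This dict is a reader's summary, not XGBoost code or API:

    # Not XGBoost code: a summary of the C++ switch in ConfigureUpdaters().
    TREE_METHOD_TO_UPDATER = {
        "auto": "grow_quantile_histmaker",  # hist is now the default
        "hist": "grow_quantile_histmaker",
        "approx": "grow_histmaker",
        "exact": "grow_colmaker,prune",
        "gpu_hist": "grow_gpu_hist",
    }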

View File

@@ -56,9 +56,7 @@ DECLARE_FIELD_ENUM_CLASS(xgboost::TreeMethod);
 DECLARE_FIELD_ENUM_CLASS(xgboost::TreeProcessType);
 DECLARE_FIELD_ENUM_CLASS(xgboost::PredictorType);

-namespace xgboost {
-namespace gbm {
+namespace xgboost::gbm {

 /*! \brief training parameters */
 struct GBTreeTrainParam : public XGBoostParameter<GBTreeTrainParam> {
   /*! \brief tree updater sequence */
@@ -192,12 +190,8 @@ class GBTree : public GradientBooster {
       : GradientBooster{ctx}, model_(booster_config, ctx_) {}

   void Configure(const Args& cfg) override;
-  // Revise `tree_method` and `updater` parameters after seeing the training
-  // data matrix, only useful when tree_method is auto.
-  void PerformTreeMethodHeuristic(DMatrix* fmat);
   /*! \brief Map `tree_method` parameter to `updater` parameter */
   void ConfigureUpdaters();
-  void ConfigureWithKnownData(Args const& cfg, DMatrix* fmat);

   /**
    * \brief Optionally update the leaf value.
@@ -222,11 +216,7 @@ class GBTree : public GradientBooster {
     return tparam_;
   }

-  void Load(dmlc::Stream* fi) override {
-    model_.Load(fi);
-    this->cfg_.clear();
-  }
+  void Load(dmlc::Stream* fi) override { model_.Load(fi); }

   void Save(dmlc::Stream* fo) const override {
     model_.Save(fo);
   }
@@ -416,8 +406,6 @@ class GBTree : public GradientBooster {
   bool showed_updater_warning_ {false};
   bool specified_updater_ {false};
   bool configured_ {false};
-  // configurations for tree
-  Args cfg_;
   // the updaters that can be applied to each of tree
   std::vector<std::unique_ptr<TreeUpdater>> updaters_;
   // Predictors
@@ -431,7 +419,6 @@ class GBTree : public GradientBooster {
   common::Monitor monitor_;
 };

-} // namespace gbm
-} // namespace xgboost
+} // namespace xgboost::gbm
 #endif // XGBOOST_GBM_GBTREE_H_

View File

@@ -23,6 +23,7 @@ class LintersPaths:
         "tests/python/test_predict.py",
         "tests/python/test_quantile_dmatrix.py",
         "tests/python/test_tree_regularization.py",
+        "tests/python/test_shap.py",
         "tests/python-gpu/test_gpu_data_iterator.py",
         "tests/test_distributed/test_with_spark/",
         "tests/test_distributed/test_gpu_with_spark/",

View File

@@ -379,6 +379,8 @@ TEST(Learner, Seed) {
 TEST(Learner, ConstantSeed) {
   auto m = RandomDataGenerator{10, 10, 0}.GenerateDMatrix(true);
   std::unique_ptr<Learner> learner{Learner::Create({m})};
+  // Use exact as it doesn't initialize column sampler at construction, which alters the rng.
+  learner->SetParam("tree_method", "exact");
   learner->Configure();  // seed the global random
   std::uniform_real_distribution<float> dist;

View File

@@ -18,9 +18,8 @@ CLI_DEMO_DIR = os.path.join(DEMO_DIR, 'CLI')
 def test_basic_walkthrough():
     script = os.path.join(PYTHON_DEMO_DIR, 'basic_walkthrough.py')
     cmd = ['python', script]
-    subprocess.check_call(cmd)
-    os.remove('dump.nice.txt')
-    os.remove('dump.raw.txt')
+    with tempfile.TemporaryDirectory() as tmpdir:
+        subprocess.check_call(cmd, cwd=tmpdir)

 @pytest.mark.skipif(**tm.no_matplotlib())

View File

@@ -6,35 +6,34 @@ import scipy
 import scipy.special

 import xgboost as xgb
+from xgboost import testing as tm

-dpath = 'demo/data/'
-rng = np.random.RandomState(1994)

 class TestSHAP:
-
-    def test_feature_importances(self):
-        data = np.random.randn(100, 5)
+    def test_feature_importances(self) -> None:
+        rng = np.random.RandomState(1994)
+        data = rng.randn(100, 5)
         target = np.array([0, 1] * 50)
-        features = ['Feature1', 'Feature2', 'Feature3', 'Feature4', 'Feature5']
-        dm = xgb.DMatrix(data, label=target,
-                         feature_names=features)
-        params = {'objective': 'multi:softprob',
-                  'eval_metric': 'mlogloss',
-                  'eta': 0.3,
-                  'num_class': 3}
+        features = ["Feature1", "Feature2", "Feature3", "Feature4", "Feature5"]
+        dm = xgb.DMatrix(data, label=target, feature_names=features)
+        params = {
+            "objective": "multi:softprob",
+            "eval_metric": "mlogloss",
+            "eta": 0.3,
+            "num_class": 3,
+        }

         bst = xgb.train(params, dm, num_boost_round=10)

         # number of feature importances should == number of features
         scores1 = bst.get_score()
-        scores2 = bst.get_score(importance_type='weight')
-        scores3 = bst.get_score(importance_type='cover')
-        scores4 = bst.get_score(importance_type='gain')
-        scores5 = bst.get_score(importance_type='total_cover')
-        scores6 = bst.get_score(importance_type='total_gain')
+        scores2 = bst.get_score(importance_type="weight")
+        scores3 = bst.get_score(importance_type="cover")
+        scores4 = bst.get_score(importance_type="gain")
+        scores5 = bst.get_score(importance_type="total_cover")
+        scores6 = bst.get_score(importance_type="total_gain")
         assert len(scores1) == len(features)
         assert len(scores2) == len(features)
         assert len(scores3) == len(features)
@@ -46,12 +45,11 @@ class TestSHAP:
         fscores = bst.get_fscore()
         assert scores1 == fscores

-        dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train?format=libsvm')
-        dtest = xgb.DMatrix(dpath + 'agaricus.txt.test?format=libsvm')
+        dtrain, dtest = tm.load_agaricus(__file__)

-        def fn(max_depth, num_rounds):
+        def fn(max_depth: int, num_rounds: int) -> None:
             # train
-            params = {'max_depth': max_depth, 'eta': 1, 'verbosity': 0}
+            params = {"max_depth": max_depth, "eta": 1, "verbosity": 0}
             bst = xgb.train(params, dtrain, num_boost_round=num_rounds)

             # predict
@@ -82,7 +80,7 @@ class TestSHAP:
         assert out[0, 1] == 0.375
         assert out[0, 2] == 0.25

-        def parse_model(model):
+        def parse_model(model: xgb.Booster) -> list:
             trees = []
             r_exp = r"([0-9]+):\[f([0-9]+)<([0-9\.e-]+)\] yes=([0-9]+),no=([0-9]+).*cover=([0-9e\.]+)"
             r_exp_leaf = r"([0-9]+):leaf=([0-9\.e-]+),cover=([0-9e\.]+)"
@@ -93,7 +91,9 @@ class TestSHAP:
                 match = re.search(r_exp, line)
                 if match is not None:
                     ind = int(match.group(1))
+                    assert trees[-1] is not None
                     while ind >= len(trees[-1]):
+                        assert isinstance(trees[-1], list)
                         trees[-1].append(None)
                     trees[-1][ind] = {
                         "yes_ind": int(match.group(4)),
@@ -101,17 +101,16 @@ class TestSHAP:
                         "value": None,
                         "threshold": float(match.group(3)),
                         "feature_index": int(match.group(2)),
-                        "cover": float(match.group(6))
+                        "cover": float(match.group(6)),
                     }
                 else:
                     match = re.search(r_exp_leaf, line)
                     ind = int(match.group(1))
                     while ind >= len(trees[-1]):
                         trees[-1].append(None)
                     trees[-1][ind] = {
                         "value": float(match.group(2)),
-                        "cover": float(match.group(3))
+                        "cover": float(match.group(3)),
                     }
             return trees
@@ -121,7 +120,8 @@ class TestSHAP:
             else:
                 ind = tree[i]["feature_index"]
                 if z[ind] == 1:
-                    if x[ind] < tree[i]["threshold"]:
+                    # 1e-6 for numeric error from parsing text dump.
+                    if x[ind] + 1e-6 <= tree[i]["threshold"]:
                         return exp_value_rec(tree, z, x, tree[i]["yes_ind"])
                     else:
                         return exp_value_rec(tree, z, x, tree[i]["no_ind"])
@@ -136,10 +136,13 @@ class TestSHAP:
             return val

         def exp_value(trees, z, x):
+            "E[f(z)|Z_s = X_s]"
             return np.sum([exp_value_rec(tree, z, x) for tree in trees])

         def all_subsets(ss):
-            return itertools.chain(*map(lambda x: itertools.combinations(ss, x), range(0, len(ss) + 1)))
+            return itertools.chain(
+                *map(lambda x: itertools.combinations(ss, x), range(0, len(ss) + 1))
+            )

         def shap_value(trees, x, i, cond=None, cond_value=None):
             M = len(x)
@@ -196,7 +199,9 @@ class TestSHAP:
                     z[i] = 0
                     v01 = exp_value(trees, z, x)
                     z[j] = 0
-                    total += (v11 - v01 - v10 + v00) / (scipy.special.binom(M - 2, len(subset)) * (M - 1))
+                    total += (v11 - v01 - v10 + v00) / (
+                        scipy.special.binom(M - 2, len(subset)) * (M - 1)
+                    )
                     z[list(subset)] = 0
             return total
@@ -220,11 +225,10 @@ class TestSHAP:
         assert np.linalg.norm(brute_force - fast_method[0, :, :]) < 1e-4

         # test a random function
-        np.random.seed(0)
         M = 2
         N = 4
-        X = np.random.randn(N, M)
-        y = np.random.randn(N)
+        X = rng.randn(N, M)
+        y = rng.randn(N)
         param = {"max_depth": 2, "base_score": 0.0, "eta": 1.0, "lambda": 0}
         bst = xgb.train(param, xgb.DMatrix(X, label=y), 1)
         brute_force = shap_values(parse_model(bst), X[0, :])
@@ -236,11 +240,10 @@ class TestSHAP:
         assert np.linalg.norm(brute_force - fast_method[0, :, :]) < 1e-4

         # test another larger more complex random function
-        np.random.seed(0)
         M = 5
         N = 100
-        X = np.random.randn(N, M)
-        y = np.random.randn(N)
+        X = rng.randn(N, M)
+        y = rng.randn(N)
         base_score = 1.0
         param = {"max_depth": 5, "base_score": base_score, "eta": 0.1, "gamma": 2.0}
         bst = xgb.train(param, xgb.DMatrix(X, label=y), 10)

View File

@@ -1,6 +1,6 @@
 import json
 import os
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, cast

 import numpy as np
 import pytest
@@ -62,8 +62,8 @@ def test_aft_survival_toy_data(
     X = np.array([1, 2, 3, 4, 5]).reshape((-1, 1))
     dmat, y_lower, y_upper = toy_data

-    # "Accuracy" = the number of data points whose ranged label (y_lower, y_upper) includes
-    # the corresponding predicted label (y_pred)
+    # "Accuracy" = the number of data points whose ranged label (y_lower, y_upper)
+    # includes the corresponding predicted label (y_pred)
     acc_rec = []

     class Callback(xgb.callback.TrainingCallback):
@@ -71,21 +71,33 @@ def test_aft_survival_toy_data(
             super().__init__()

         def after_iteration(
-            self, model: xgb.Booster,
+            self,
+            model: xgb.Booster,
             epoch: int,
-            evals_log: xgb.callback.TrainingCallback.EvalsLog
+            evals_log: xgb.callback.TrainingCallback.EvalsLog,
         ):
             y_pred = model.predict(dmat)
-            acc = np.sum(np.logical_and(y_pred >= y_lower, y_pred <= y_upper)/len(X))
+            acc = np.sum(np.logical_and(y_pred >= y_lower, y_pred <= y_upper) / len(X))
             acc_rec.append(acc)
             return False

-    evals_result = {}
-    params = {'max_depth': 3, 'objective': 'survival:aft', 'min_child_weight': 0}
-    bst = xgb.train(params, dmat, 15, [(dmat, 'train')], evals_result=evals_result,
-                    callbacks=[Callback()])
+    evals_result: xgb.callback.TrainingCallback.EvalsLog = {}
+    params = {
+        "max_depth": 3,
+        "objective": "survival:aft",
+        "min_child_weight": 0,
+        "tree_method": "exact",
+    }
+    bst = xgb.train(
+        params,
+        dmat,
+        15,
+        [(dmat, "train")],
+        evals_result=evals_result,
+        callbacks=[Callback()],
+    )

-    nloglik_rec = evals_result['train']['aft-nloglik']
+    nloglik_rec = cast(List[float], evals_result["train"]["aft-nloglik"])
     # AFT metric (negative log likelihood) improve monotonically
     assert all(p >= q for p, q in zip(nloglik_rec, nloglik_rec[:1]))
     # "Accuracy" improve monotonically.
@@ -94,15 +106,17 @@ def test_aft_survival_toy_data(
     assert acc_rec[-1] == 1.0

     def gather_split_thresholds(tree):
-        if 'split_condition' in tree:
-            return (gather_split_thresholds(tree['children'][0])
-                    | gather_split_thresholds(tree['children'][1])
-                    | {tree['split_condition']})
+        if "split_condition" in tree:
+            return (
+                gather_split_thresholds(tree["children"][0])
+                | gather_split_thresholds(tree["children"][1])
+                | {tree["split_condition"]}
+            )
         return set()

     # Only 2.5, 3.5, and 4.5 are used as split thresholds.
-    model_json = [json.loads(e) for e in bst.get_dump(dump_format='json')]
-    for tree in model_json:
+    model_json = [json.loads(e) for e in bst.get_dump(dump_format="json")]
+    for i, tree in enumerate(model_json):
         assert gather_split_thresholds(tree).issubset({2.5, 3.5, 4.5})

View File

@@ -475,18 +475,22 @@ def test_rf_regression():
     run_housing_rf_regression("hist")


-def test_parameter_tuning():
+@pytest.mark.parametrize("tree_method", ["exact", "hist", "approx"])
+def test_parameter_tuning(tree_method: str) -> None:
     from sklearn.datasets import fetch_california_housing
     from sklearn.model_selection import GridSearchCV

     X, y = fetch_california_housing(return_X_y=True)
-    xgb_model = xgb.XGBRegressor(learning_rate=0.1)
-    clf = GridSearchCV(xgb_model, {'max_depth': [2, 4],
-                                   'n_estimators': [50, 200]},
-                       cv=2, verbose=1)
-    clf.fit(X, y)
-    assert clf.best_score_ < 0.7
-    assert clf.best_params_ == {'n_estimators': 200, 'max_depth': 4}
+    reg = xgb.XGBRegressor(learning_rate=0.1, tree_method=tree_method)
+    grid_cv = GridSearchCV(
+        reg, {"max_depth": [2, 4], "n_estimators": [50, 200]}, cv=2, verbose=1
+    )
+    grid_cv.fit(X, y)
+    assert grid_cv.best_score_ < 0.7
+    assert grid_cv.best_params_ == {
+        "n_estimators": 200,
+        "max_depth": 4 if tree_method == "exact" else 2,
+    }


 def test_regression_with_custom_objective():
@@ -750,7 +754,7 @@ def test_parameters_access():
         ]["tree_method"]
         return tm

-    assert get_tm(clf) == "exact"
+    assert get_tm(clf) == "auto"  # Kept as auto, immutable since 2.0

     clf = pickle.loads(pickle.dumps(clf))
@@ -758,7 +762,7 @@ def test_parameters_access():
     assert clf.n_estimators == 2
     assert clf.get_params()["tree_method"] is None
     assert clf.get_params()["n_estimators"] == 2
-    assert get_tm(clf) == "exact"  # preserved for pickle
+    assert get_tm(clf) == "auto"  # preserved for pickle

     clf = save_load(clf)