Use hist as the default tree method. (#9320)

parent: bc267dd729
commit: f4798718c7
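For orientation before the per-file hunks: the practical effect of this change is that an unspecified tree_method (i.e. "auto") now resolves to the hist updater instead of the old exact/approx heuristic. A minimal Python sketch of how one might observe this (illustrative only: the data and parameter values are made up, and the saved-config path follows the one used by test_parameters_access further down):

import json

import numpy as np
import xgboost as xgb

# Hypothetical check: train without setting `tree_method` and inspect the
# configuration the learner actually resolved.
rng = np.random.RandomState(1994)
X, y = rng.randn(256, 8), rng.randn(256)
booster = xgb.train({"max_depth": 2}, xgb.DMatrix(X, label=y), num_boost_round=4)

config = json.loads(booster.save_config())
# After this commit the resolved updater sequence is expected to be the hist one
# ("grow_quantile_histmaker") rather than "grow_colmaker,prune" (exact).
print(config["learner"]["gradient_booster"]["gbtree_train_param"])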
@@ -85,9 +85,18 @@ test_that("dart prediction works", {
   rnorm(100)
 
   set.seed(1994)
-  booster_by_xgboost <- xgboost(data = d, label = y, max_depth = 2, booster = "dart",
-                                rate_drop = 0.5, one_drop = TRUE,
-                                eta = 1, nthread = 2, nrounds = nrounds, objective = "reg:squarederror")
+  booster_by_xgboost <- xgboost(
+    data = d,
+    label = y,
+    max_depth = 2,
+    booster = "dart",
+    rate_drop = 0.5,
+    one_drop = TRUE,
+    eta = 1,
+    nthread = 2,
+    nrounds = nrounds,
+    objective = "reg:squarederror"
+  )
   pred_by_xgboost_0 <- predict(booster_by_xgboost, newdata = d, ntreelimit = 0)
   pred_by_xgboost_1 <- predict(booster_by_xgboost, newdata = d, ntreelimit = nrounds)
   expect_true(all(matrix(pred_by_xgboost_0, byrow = TRUE) == matrix(pred_by_xgboost_1, byrow = TRUE)))
@@ -97,19 +106,19 @@ test_that("dart prediction works", {
 
   set.seed(1994)
   dtrain <- xgb.DMatrix(data = d, info = list(label = y))
-  booster_by_train <- xgb.train(params = list(
-    booster = "dart",
-    max_depth = 2,
-    eta = 1,
-    rate_drop = 0.5,
-    one_drop = TRUE,
-    nthread = 1,
-    tree_method = "exact",
-    objective = "reg:squarederror"
-  ),
-  data = dtrain,
-  nrounds = nrounds
-  )
+  booster_by_train <- xgb.train(
+    params = list(
+      booster = "dart",
+      max_depth = 2,
+      eta = 1,
+      rate_drop = 0.5,
+      one_drop = TRUE,
+      nthread = 1,
+      objective = "reg:squarederror"
+    ),
+    data = dtrain,
+    nrounds = nrounds
+  )
   pred_by_train_0 <- predict(booster_by_train, newdata = dtrain, ntreelimit = 0)
   pred_by_train_1 <- predict(booster_by_train, newdata = dtrain, ntreelimit = nrounds)
   pred_by_train_2 <- predict(booster_by_train, newdata = dtrain, training = TRUE)
@@ -399,7 +408,7 @@ test_that("colsample_bytree works", {
   xgb.importance(model = bst)
   # If colsample_bytree works properly, a variety of features should be used
   # in the 100 trees
-  expect_gte(nrow(xgb.importance(model = bst)), 30)
+  expect_gte(nrow(xgb.importance(model = bst)), 28)
 })
 
 test_that("Configuration works", {
@@ -13,7 +13,10 @@ test_that("updating the model works", {
   watchlist <- list(train = dtrain, test = dtest)
 
   # no-subsampling
-  p1 <- list(objective = "binary:logistic", max_depth = 2, eta = 0.05, nthread = 2)
+  p1 <- list(
+    objective = "binary:logistic", max_depth = 2, eta = 0.05, nthread = 2,
+    updater = "grow_colmaker,prune"
+  )
   set.seed(11)
   bst1 <- xgb.train(p1, dtrain, nrounds = 10, watchlist, verbose = 0)
   tr1 <- xgb.model.dt.tree(model = bst1)
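The hunk above pins the update test to the old column-maker/prune updater, since "auto" no longer implies exact. A hedged Python equivalent (a sketch with made-up data; either form pins the pre-2.0 behaviour explicitly):

import numpy as np
import xgboost as xgb

rng = np.random.RandomState(11)
X = rng.rand(100, 4)
y = rng.randint(0, 2, size=100)
dtrain = xgb.DMatrix(X, label=y)

# Pin the tree method itself ...
exact_params = {"objective": "binary:logistic", "max_depth": 2, "tree_method": "exact"}
# ... or pin the updater sequence directly, mirroring the R test above.
# Note: setting `updater` triggers the "DANGER AHEAD" warning shown in gbtree.cc below.
updater_params = {"objective": "binary:logistic", "max_depth": 2, "updater": "grow_colmaker,prune"}

bst1 = xgb.train(exact_params, dtrain, num_boost_round=10)
bst2 = xgb.train(updater_params, dtrain, num_boost_round=10)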
@@ -39,7 +39,6 @@ namespace xgboost::gbm {
 DMLC_REGISTRY_FILE_TAG(gbtree);
 
 void GBTree::Configure(Args const& cfg) {
-  this->cfg_ = cfg;
   std::string updater_seq = tparam_.updater_seq;
   tparam_.UpdateAllowUnknown(cfg);
   tree_param_.UpdateAllowUnknown(cfg);
@@ -78,10 +77,9 @@ void GBTree::Configure(Args const& cfg) {
 
   monitor_.Init("GBTree");
 
-  specified_updater_ = std::any_of(cfg.cbegin(), cfg.cend(),
-                                   [](std::pair<std::string, std::string> const& arg) {
-                                     return arg.first == "updater";
-                                   });
+  specified_updater_ = std::any_of(
+      cfg.cbegin(), cfg.cend(),
+      [](std::pair<std::string, std::string> const& arg) { return arg.first == "updater"; });
 
   if (specified_updater_ && !showed_updater_warning_) {
     LOG(WARNING) << "DANGER AHEAD: You have manually specified `updater` "
@@ -93,12 +91,19 @@ void GBTree::Configure(Args const& cfg) {
     showed_updater_warning_ = true;
   }
+
+  if (model_.learner_model_param->IsVectorLeaf()) {
+    CHECK(tparam_.tree_method == TreeMethod::kHist || tparam_.tree_method == TreeMethod::kAuto)
+        << "Only the hist tree method is supported for building multi-target trees with vector "
+           "leaf.";
+  }
+  LOG(DEBUG) << "Using tree method: " << static_cast<int>(tparam_.tree_method);
   this->ConfigureUpdaters();
 
   if (updater_seq != tparam_.updater_seq) {
     updaters_.clear();
     this->InitUpdater(cfg);
   } else {
-    for (auto &up : updaters_) {
+    for (auto& up : updaters_) {
       up->Configure(cfg);
     }
   }
@@ -106,66 +111,6 @@ void GBTree::Configure(Args const& cfg) {
   configured_ = true;
 }
 
-// FIXME(trivialfis): This handles updaters. Because the choice of updaters depends on
-// whether external memory is used and how large is dataset. We can remove the dependency
-// on DMatrix once `hist` tree method can handle external memory so that we can make it
-// default.
-void GBTree::ConfigureWithKnownData(Args const& cfg, DMatrix* fmat) {
-  CHECK(this->configured_);
-  std::string updater_seq = tparam_.updater_seq;
-  CHECK(tparam_.GetInitialised());
-
-  tparam_.UpdateAllowUnknown(cfg);
-
-  this->PerformTreeMethodHeuristic(fmat);
-  this->ConfigureUpdaters();
-
-  // initialize the updaters only when needed.
-  if (updater_seq != tparam_.updater_seq) {
-    LOG(DEBUG) << "Using updaters: " << tparam_.updater_seq;
-    this->updaters_.clear();
-    this->InitUpdater(cfg);
-  }
-}
-
-void GBTree::PerformTreeMethodHeuristic(DMatrix* fmat) {
-  if (specified_updater_) {
-    // This method is disabled when `updater` parameter is explicitly
-    // set, since only experts are expected to do so.
-    return;
-  }
-  if (model_.learner_model_param->IsVectorLeaf()) {
-    CHECK(tparam_.tree_method == TreeMethod::kHist)
-        << "Only the hist tree method is supported for building multi-target trees with vector "
-           "leaf.";
-  }
-
-  // tparam_ is set before calling this function.
-  if (tparam_.tree_method != TreeMethod::kAuto) {
-    return;
-  }
-
-  if (collective::IsDistributed()) {
-    LOG(INFO) << "Tree method is automatically selected to be 'approx' "
-                 "for distributed training.";
-    tparam_.tree_method = TreeMethod::kApprox;
-  } else if (!fmat->SingleColBlock()) {
-    LOG(INFO) << "Tree method is automatically set to 'approx' "
-                 "since external-memory data matrix is used.";
-    tparam_.tree_method = TreeMethod::kApprox;
-  } else if (fmat->Info().num_row_ >= (4UL << 20UL)) {
-    /* Choose tree_method='approx' automatically for large data matrix */
-    LOG(INFO) << "Tree method is automatically selected to be "
-                 "'approx' for faster speed. To use old behavior "
-                 "(exact greedy algorithm on single machine), "
-                 "set tree_method to 'exact'.";
-    tparam_.tree_method = TreeMethod::kApprox;
-  } else {
-    tparam_.tree_method = TreeMethod::kExact;
-  }
-  LOG(DEBUG) << "Using tree method: " << static_cast<int>(tparam_.tree_method);
-}
-
 void GBTree::ConfigureUpdaters() {
   if (specified_updater_) {
     return;
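The PerformTreeMethodHeuristic body deleted above is summarised here as a small Python paraphrase, for readers comparing old and new behaviour (this mirrors the removed C++; it is not an API):

def legacy_auto_tree_method(is_distributed: bool, single_col_block: bool, num_rows: int) -> str:
    """Pre-2.0 resolution of tree_method="auto", paraphrased from the removed C++."""
    if is_distributed:
        return "approx"          # distributed training
    if not single_col_block:
        return "approx"          # external-memory DMatrix
    if num_rows >= (4 << 20):
        return "approx"          # large data matrix (>= ~4.2M rows)
    return "exact"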
@@ -173,31 +118,25 @@ void GBTree::ConfigureUpdaters() {
   // `updater` parameter was manually specified
   /* Choose updaters according to tree_method parameters */
   switch (tparam_.tree_method) {
-    case TreeMethod::kAuto:
-      // Use heuristic to choose between 'exact' and 'approx' This
-      // choice is carried out in PerformTreeMethodHeuristic() before
-      // calling this function.
+    case TreeMethod::kAuto:  // Use hist as default in 2.0
+    case TreeMethod::kHist: {
+      tparam_.updater_seq = "grow_quantile_histmaker";
       break;
+    }
     case TreeMethod::kApprox:
       tparam_.updater_seq = "grow_histmaker";
       break;
     case TreeMethod::kExact:
       tparam_.updater_seq = "grow_colmaker,prune";
       break;
-    case TreeMethod::kHist: {
-      LOG(INFO) << "Tree method is selected to be 'hist', which uses a single updater "
-                   "grow_quantile_histmaker.";
-      tparam_.updater_seq = "grow_quantile_histmaker";
-      break;
-    }
     case TreeMethod::kGPUHist: {
       common::AssertGPUSupport();
       tparam_.updater_seq = "grow_gpu_hist";
       break;
     }
     default:
-      LOG(FATAL) << "Unknown tree_method ("
-                 << static_cast<int>(tparam_.tree_method) << ") detected";
+      LOG(FATAL) << "Unknown tree_method (" << static_cast<int>(tparam_.tree_method)
+                 << ") detected";
   }
 }
 
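To summarise the switch above, the tree_method-to-updater mapping after this commit amounts to the following (hand-written summary of the C++ code, shown as a Python dict for readability):

# tree_method -> updater sequence, per ConfigureUpdaters() above.
UPDATER_SEQ = {
    "auto": "grow_quantile_histmaker",  # hist is the default from 2.0 on
    "hist": "grow_quantile_histmaker",
    "approx": "grow_histmaker",
    "exact": "grow_colmaker,prune",
    "gpu_hist": "grow_gpu_hist",        # requires a GPU-enabled build
}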
@@ -253,7 +192,6 @@ void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector<GradientPair>* in_gpair,
                      PredictionCacheEntry* predt, ObjFunction const* obj) {
   TreesOneIter new_trees;
   bst_target_t const n_groups = model_.learner_model_param->OutputLength();
-  ConfigureWithKnownData(this->cfg_, p_fmat);
   monitor_.Start("BoostNewTrees");
 
   // Weird case that tree method is cpu-based but gpu_id is set. Ideally we should let
@@ -56,9 +56,7 @@ DECLARE_FIELD_ENUM_CLASS(xgboost::TreeMethod);
 DECLARE_FIELD_ENUM_CLASS(xgboost::TreeProcessType);
 DECLARE_FIELD_ENUM_CLASS(xgboost::PredictorType);
 
-namespace xgboost {
-namespace gbm {
-
+namespace xgboost::gbm {
 /*! \brief training parameters */
 struct GBTreeTrainParam : public XGBoostParameter<GBTreeTrainParam> {
   /*! \brief tree updater sequence */
@@ -192,12 +190,8 @@ class GBTree : public GradientBooster {
       : GradientBooster{ctx}, model_(booster_config, ctx_) {}
 
   void Configure(const Args& cfg) override;
-  // Revise `tree_method` and `updater` parameters after seeing the training
-  // data matrix, only useful when tree_method is auto.
-  void PerformTreeMethodHeuristic(DMatrix* fmat);
   /*! \brief Map `tree_method` parameter to `updater` parameter */
   void ConfigureUpdaters();
-  void ConfigureWithKnownData(Args const& cfg, DMatrix* fmat);
 
   /**
    * \brief Optionally update the leaf value.
@@ -222,11 +216,7 @@ class GBTree : public GradientBooster {
     return tparam_;
   }
 
-  void Load(dmlc::Stream* fi) override {
-    model_.Load(fi);
-    this->cfg_.clear();
-  }
-
+  void Load(dmlc::Stream* fi) override { model_.Load(fi); }
   void Save(dmlc::Stream* fo) const override {
     model_.Save(fo);
   }
@@ -416,8 +406,6 @@ class GBTree : public GradientBooster {
   bool showed_updater_warning_ {false};
   bool specified_updater_ {false};
   bool configured_ {false};
-  // configurations for tree
-  Args cfg_;
   // the updaters that can be applied to each of tree
   std::vector<std::unique_ptr<TreeUpdater>> updaters_;
   // Predictors
@@ -431,7 +419,6 @@ class GBTree : public GradientBooster {
   common::Monitor monitor_;
 };
 
-}  // namespace gbm
-}  // namespace xgboost
+}  // namespace xgboost::gbm
 
 #endif  // XGBOOST_GBM_GBTREE_H_
@@ -23,6 +23,7 @@ class LintersPaths:
         "tests/python/test_predict.py",
         "tests/python/test_quantile_dmatrix.py",
         "tests/python/test_tree_regularization.py",
+        "tests/python/test_shap.py",
         "tests/python-gpu/test_gpu_data_iterator.py",
         "tests/test_distributed/test_with_spark/",
         "tests/test_distributed/test_gpu_with_spark/",
@@ -379,6 +379,8 @@ TEST(Learner, Seed) {
 TEST(Learner, ConstantSeed) {
   auto m = RandomDataGenerator{10, 10, 0}.GenerateDMatrix(true);
   std::unique_ptr<Learner> learner{Learner::Create({m})};
+  // Use exact as it doesn't initialize column sampler at construction, which alters the rng.
+  learner->SetParam("tree_method", "exact");
   learner->Configure();  // seed the global random
 
   std::uniform_real_distribution<float> dist;
@@ -18,9 +18,8 @@ CLI_DEMO_DIR = os.path.join(DEMO_DIR, 'CLI')
 def test_basic_walkthrough():
     script = os.path.join(PYTHON_DEMO_DIR, 'basic_walkthrough.py')
     cmd = ['python', script]
-    subprocess.check_call(cmd)
-    os.remove('dump.nice.txt')
-    os.remove('dump.raw.txt')
+    with tempfile.TemporaryDirectory() as tmpdir:
+        subprocess.check_call(cmd, cwd=tmpdir)
 
 
 @pytest.mark.skipif(**tm.no_matplotlib())
@@ -6,35 +6,34 @@ import scipy
 import scipy.special
 
 import xgboost as xgb
-
-dpath = 'demo/data/'
-rng = np.random.RandomState(1994)
+from xgboost import testing as tm
 
 
 class TestSHAP:
-
-    def test_feature_importances(self):
-        data = np.random.randn(100, 5)
+    def test_feature_importances(self) -> None:
+        rng = np.random.RandomState(1994)
+        data = rng.randn(100, 5)
         target = np.array([0, 1] * 50)
 
-        features = ['Feature1', 'Feature2', 'Feature3', 'Feature4', 'Feature5']
+        features = ["Feature1", "Feature2", "Feature3", "Feature4", "Feature5"]
 
-        dm = xgb.DMatrix(data, label=target,
-                         feature_names=features)
-        params = {'objective': 'multi:softprob',
-                  'eval_metric': 'mlogloss',
-                  'eta': 0.3,
-                  'num_class': 3}
+        dm = xgb.DMatrix(data, label=target, feature_names=features)
+        params = {
+            "objective": "multi:softprob",
+            "eval_metric": "mlogloss",
+            "eta": 0.3,
+            "num_class": 3,
+        }
 
         bst = xgb.train(params, dm, num_boost_round=10)
 
         # number of feature importances should == number of features
         scores1 = bst.get_score()
-        scores2 = bst.get_score(importance_type='weight')
-        scores3 = bst.get_score(importance_type='cover')
-        scores4 = bst.get_score(importance_type='gain')
-        scores5 = bst.get_score(importance_type='total_cover')
-        scores6 = bst.get_score(importance_type='total_gain')
+        scores2 = bst.get_score(importance_type="weight")
+        scores3 = bst.get_score(importance_type="cover")
+        scores4 = bst.get_score(importance_type="gain")
+        scores5 = bst.get_score(importance_type="total_cover")
+        scores6 = bst.get_score(importance_type="total_gain")
         assert len(scores1) == len(features)
         assert len(scores2) == len(features)
         assert len(scores3) == len(features)
@@ -46,12 +45,11 @@ class TestSHAP:
         fscores = bst.get_fscore()
         assert scores1 == fscores
 
-        dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train?format=libsvm')
-        dtest = xgb.DMatrix(dpath + 'agaricus.txt.test?format=libsvm')
+        dtrain, dtest = tm.load_agaricus(__file__)
 
-        def fn(max_depth, num_rounds):
+        def fn(max_depth: int, num_rounds: int) -> None:
             # train
-            params = {'max_depth': max_depth, 'eta': 1, 'verbosity': 0}
+            params = {"max_depth": max_depth, "eta": 1, "verbosity": 0}
             bst = xgb.train(params, dtrain, num_boost_round=num_rounds)
 
             # predict
@@ -82,7 +80,7 @@ class TestSHAP:
         assert out[0, 1] == 0.375
         assert out[0, 2] == 0.25
 
-        def parse_model(model):
+        def parse_model(model: xgb.Booster) -> list:
             trees = []
             r_exp = r"([0-9]+):\[f([0-9]+)<([0-9\.e-]+)\] yes=([0-9]+),no=([0-9]+).*cover=([0-9e\.]+)"
             r_exp_leaf = r"([0-9]+):leaf=([0-9\.e-]+),cover=([0-9e\.]+)"
@@ -93,7 +91,9 @@ class TestSHAP:
                 match = re.search(r_exp, line)
                 if match is not None:
                     ind = int(match.group(1))
+                    assert trees[-1] is not None
                     while ind >= len(trees[-1]):
+                        assert isinstance(trees[-1], list)
                         trees[-1].append(None)
                     trees[-1][ind] = {
                         "yes_ind": int(match.group(4)),
@@ -101,17 +101,16 @@ class TestSHAP:
                         "value": None,
                         "threshold": float(match.group(3)),
                         "feature_index": int(match.group(2)),
-                        "cover": float(match.group(6))
+                        "cover": float(match.group(6)),
                     }
                 else:
-
                     match = re.search(r_exp_leaf, line)
                     ind = int(match.group(1))
                     while ind >= len(trees[-1]):
                         trees[-1].append(None)
                     trees[-1][ind] = {
                         "value": float(match.group(2)),
-                        "cover": float(match.group(3))
+                        "cover": float(match.group(3)),
                     }
             return trees
@@ -121,7 +120,8 @@ class TestSHAP:
             else:
                 ind = tree[i]["feature_index"]
                 if z[ind] == 1:
-                    if x[ind] < tree[i]["threshold"]:
+                    # 1e-6 for numeric error from parsing text dump.
+                    if x[ind] + 1e-6 <= tree[i]["threshold"]:
                         return exp_value_rec(tree, z, x, tree[i]["yes_ind"])
                     else:
                         return exp_value_rec(tree, z, x, tree[i]["no_ind"])
@@ -136,10 +136,13 @@ class TestSHAP:
             return val
 
         def exp_value(trees, z, x):
             "E[f(z)|Z_s = X_s]"
             return np.sum([exp_value_rec(tree, z, x) for tree in trees])
 
         def all_subsets(ss):
-            return itertools.chain(*map(lambda x: itertools.combinations(ss, x), range(0, len(ss) + 1)))
+            return itertools.chain(
+                *map(lambda x: itertools.combinations(ss, x), range(0, len(ss) + 1))
+            )
 
         def shap_value(trees, x, i, cond=None, cond_value=None):
             M = len(x)
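For reference, the quantity the brute-force shap_value defined above (and exercised in the following hunks) estimates is the standard Shapley value; stated here for readability, with E[.] the conditional expectation that exp_value computes and F the feature set:

\phi_i(x) = \sum_{S \subseteq F \setminus \{i\}} \frac{1}{M \binom{M-1}{|S|}}
            \Big( E\big[f(z) \mid z_{S \cup \{i\}} = x_{S \cup \{i\}}\big]
                  - E\big[f(z) \mid z_S = x_S\big] \Big), \qquad M = |F|.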
@@ -196,7 +199,9 @@ class TestSHAP:
                 z[i] = 0
                 v01 = exp_value(trees, z, x)
                 z[j] = 0
-                total += (v11 - v01 - v10 + v00) / (scipy.special.binom(M - 2, len(subset)) * (M - 1))
+                total += (v11 - v01 - v10 + v00) / (
+                    scipy.special.binom(M - 2, len(subset)) * (M - 1)
+                )
                 z[list(subset)] = 0
             return total
@@ -220,11 +225,10 @@ class TestSHAP:
         assert np.linalg.norm(brute_force - fast_method[0, :, :]) < 1e-4
 
         # test a random function
-        np.random.seed(0)
         M = 2
         N = 4
-        X = np.random.randn(N, M)
-        y = np.random.randn(N)
+        X = rng.randn(N, M)
+        y = rng.randn(N)
         param = {"max_depth": 2, "base_score": 0.0, "eta": 1.0, "lambda": 0}
         bst = xgb.train(param, xgb.DMatrix(X, label=y), 1)
         brute_force = shap_values(parse_model(bst), X[0, :])
@@ -236,11 +240,10 @@ class TestSHAP:
         assert np.linalg.norm(brute_force - fast_method[0, :, :]) < 1e-4
 
         # test another larger more complex random function
-        np.random.seed(0)
         M = 5
         N = 100
-        X = np.random.randn(N, M)
-        y = np.random.randn(N)
+        X = rng.randn(N, M)
+        y = rng.randn(N)
         base_score = 1.0
         param = {"max_depth": 5, "base_score": base_score, "eta": 0.1, "gamma": 2.0}
         bst = xgb.train(param, xgb.DMatrix(X, label=y), 10)
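The fast_method arrays compared against the brute force come from XGBoost's built-in SHAP support; a hedged, self-contained usage sketch (data and parameter values are made up):

import numpy as np
import xgboost as xgb

rng = np.random.RandomState(0)
X, y = rng.randn(100, 5), rng.randn(100)
dtrain = xgb.DMatrix(X, label=y)
bst = xgb.train({"max_depth": 3, "eta": 0.1}, dtrain, num_boost_round=10)

contribs = bst.predict(dtrain, pred_contribs=True)          # per-feature SHAP values + bias column
interactions = bst.predict(dtrain, pred_interactions=True)  # pairwise SHAP interaction values
margin = bst.predict(dtrain, output_margin=True)
# Each row of `contribs` sums to the corresponding margin prediction.
assert np.allclose(contribs.sum(axis=1), margin, atol=1e-5)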
@@ -1,6 +1,6 @@
 import json
 import os
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, cast
 
 import numpy as np
 import pytest
|
||||
X = np.array([1, 2, 3, 4, 5]).reshape((-1, 1))
|
||||
dmat, y_lower, y_upper = toy_data
|
||||
|
||||
# "Accuracy" = the number of data points whose ranged label (y_lower, y_upper) includes
|
||||
# the corresponding predicted label (y_pred)
|
||||
# "Accuracy" = the number of data points whose ranged label (y_lower, y_upper)
|
||||
# includes the corresponding predicted label (y_pred)
|
||||
acc_rec = []
|
||||
|
||||
class Callback(xgb.callback.TrainingCallback):
|
||||
@@ -71,21 +71,33 @@ def test_aft_survival_toy_data(
             super().__init__()
 
         def after_iteration(
-            self, model: xgb.Booster,
+            self,
+            model: xgb.Booster,
             epoch: int,
-            evals_log: xgb.callback.TrainingCallback.EvalsLog
+            evals_log: xgb.callback.TrainingCallback.EvalsLog,
         ):
             y_pred = model.predict(dmat)
-            acc = np.sum(np.logical_and(y_pred >= y_lower, y_pred <= y_upper)/len(X))
+            acc = np.sum(np.logical_and(y_pred >= y_lower, y_pred <= y_upper) / len(X))
             acc_rec.append(acc)
             return False
 
-    evals_result = {}
-    params = {'max_depth': 3, 'objective': 'survival:aft', 'min_child_weight': 0}
-    bst = xgb.train(params, dmat, 15, [(dmat, 'train')], evals_result=evals_result,
-                    callbacks=[Callback()])
+    evals_result: xgb.callback.TrainingCallback.EvalsLog = {}
+    params = {
+        "max_depth": 3,
+        "objective": "survival:aft",
+        "min_child_weight": 0,
+        "tree_method": "exact",
+    }
+    bst = xgb.train(
+        params,
+        dmat,
+        15,
+        [(dmat, "train")],
+        evals_result=evals_result,
+        callbacks=[Callback()],
+    )
 
-    nloglik_rec = evals_result['train']['aft-nloglik']
+    nloglik_rec = cast(List[float], evals_result["train"]["aft-nloglik"])
     # AFT metric (negative log likelihood) improve monotonically
     assert all(p >= q for p, q in zip(nloglik_rec, nloglik_rec[:1]))
     # "Accuracy" improve monotonically.
@@ -94,15 +106,17 @@ def test_aft_survival_toy_data(
     assert acc_rec[-1] == 1.0
 
     def gather_split_thresholds(tree):
-        if 'split_condition' in tree:
-            return (gather_split_thresholds(tree['children'][0])
-                    | gather_split_thresholds(tree['children'][1])
-                    | {tree['split_condition']})
+        if "split_condition" in tree:
+            return (
+                gather_split_thresholds(tree["children"][0])
+                | gather_split_thresholds(tree["children"][1])
+                | {tree["split_condition"]}
+            )
         return set()
 
     # Only 2.5, 3.5, and 4.5 are used as split thresholds.
-    model_json = [json.loads(e) for e in bst.get_dump(dump_format='json')]
-    for tree in model_json:
+    model_json = [json.loads(e) for e in bst.get_dump(dump_format="json")]
+    for i, tree in enumerate(model_json):
         assert gather_split_thresholds(tree).issubset({2.5, 3.5, 4.5})
 
 
@@ -475,18 +475,22 @@ def test_rf_regression():
     run_housing_rf_regression("hist")
 
 
-def test_parameter_tuning():
+@pytest.mark.parametrize("tree_method", ["exact", "hist", "approx"])
+def test_parameter_tuning(tree_method: str) -> None:
    from sklearn.datasets import fetch_california_housing
    from sklearn.model_selection import GridSearchCV
 
     X, y = fetch_california_housing(return_X_y=True)
-    xgb_model = xgb.XGBRegressor(learning_rate=0.1)
-    clf = GridSearchCV(xgb_model, {'max_depth': [2, 4],
-                                   'n_estimators': [50, 200]},
-                       cv=2, verbose=1)
-    clf.fit(X, y)
-    assert clf.best_score_ < 0.7
-    assert clf.best_params_ == {'n_estimators': 200, 'max_depth': 4}
+    reg = xgb.XGBRegressor(learning_rate=0.1, tree_method=tree_method)
+    grid_cv = GridSearchCV(
+        reg, {"max_depth": [2, 4], "n_estimators": [50, 200]}, cv=2, verbose=1
+    )
+    grid_cv.fit(X, y)
+    assert grid_cv.best_score_ < 0.7
+    assert grid_cv.best_params_ == {
+        "n_estimators": 200,
+        "max_depth": 4 if tree_method == "exact" else 2,
+    }
 
 
 def test_regression_with_custom_objective():
@@ -750,7 +754,7 @@ def test_parameters_access():
         ]["tree_method"]
         return tm
 
-    assert get_tm(clf) == "exact"
+    assert get_tm(clf) == "auto"  # Kept as auto, immutable since 2.0
 
     clf = pickle.loads(pickle.dumps(clf))
 
@@ -758,7 +762,7 @@ def test_parameters_access():
     assert clf.n_estimators == 2
     assert clf.get_params()["tree_method"] is None
     assert clf.get_params()["n_estimators"] == 2
-    assert get_tm(clf) == "exact"  # preserved for pickle
+    assert get_tm(clf) == "auto"  # preserved for pickle
 
     clf = save_load(clf)
 