Support learning rate for zero-hessian objectives. (#8866)
This commit is contained in:
parent
173096a6a7
commit
228a46e8ad
@ -53,9 +53,8 @@ def quantile_loss(args: argparse.Namespace) -> None:
|
|||||||
"tree_method": "hist",
|
"tree_method": "hist",
|
||||||
"quantile_alpha": alpha,
|
"quantile_alpha": alpha,
|
||||||
# Let's try not to overfit.
|
# Let's try not to overfit.
|
||||||
"learning_rate": 0.01,
|
"learning_rate": 0.04,
|
||||||
"max_depth": 3,
|
"max_depth": 5,
|
||||||
"min_child_weight": 16.0,
|
|
||||||
},
|
},
|
||||||
Xy,
|
Xy,
|
||||||
num_boost_round=32,
|
num_boost_round=32,
|
||||||
@ -80,9 +79,8 @@ def quantile_loss(args: argparse.Namespace) -> None:
|
|||||||
"objective": "reg:squarederror",
|
"objective": "reg:squarederror",
|
||||||
"tree_method": "hist",
|
"tree_method": "hist",
|
||||||
# Let's try not to overfit.
|
# Let's try not to overfit.
|
||||||
"learning_rate": 0.01,
|
"learning_rate": 0.04,
|
||||||
"max_depth": 3,
|
"max_depth": 5,
|
||||||
"min_child_weight": 16.0,
|
|
||||||
},
|
},
|
||||||
Xy,
|
Xy,
|
||||||
num_boost_round=32,
|
num_boost_round=32,
|
||||||
|
|||||||
@ -116,12 +116,13 @@ class ObjFunction : public Configurable {
|
|||||||
*
|
*
|
||||||
* \param position The leaf index for each rows.
|
* \param position The leaf index for each rows.
|
||||||
* \param info MetaInfo providing labels and weights.
|
* \param info MetaInfo providing labels and weights.
|
||||||
|
* \param learning_rate The learning rate for current iteration.
|
||||||
* \param prediction Model prediction after transformation.
|
* \param prediction Model prediction after transformation.
|
||||||
* \param group_idx The group index for this tree, 0 when it's not multi-target or multi-class.
|
* \param group_idx The group index for this tree, 0 when it's not multi-target or multi-class.
|
||||||
* \param p_tree Tree that needs to be updated.
|
* \param p_tree Tree that needs to be updated.
|
||||||
*/
|
*/
|
||||||
virtual void UpdateTreeLeaf(HostDeviceVector<bst_node_t> const& /*position*/,
|
virtual void UpdateTreeLeaf(HostDeviceVector<bst_node_t> const& /*position*/,
|
||||||
MetaInfo const& /*info*/,
|
MetaInfo const& /*info*/, float /*learning_rate*/,
|
||||||
HostDeviceVector<float> const& /*prediction*/,
|
HostDeviceVector<float> const& /*prediction*/,
|
||||||
std::int32_t /*group_idx*/, RegTree* /*p_tree*/) const {}
|
std::int32_t /*group_idx*/, RegTree* /*p_tree*/) const {}
|
||||||
|
|
||||||
|
|||||||
@ -24,6 +24,9 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
|
namespace tree {
|
||||||
|
struct TrainParam;
|
||||||
|
}
|
||||||
|
|
||||||
class Json;
|
class Json;
|
||||||
struct Context;
|
struct Context;
|
||||||
@ -56,8 +59,10 @@ class TreeUpdater : public Configurable {
|
|||||||
* tree can be used.
|
* tree can be used.
|
||||||
*/
|
*/
|
||||||
virtual bool HasNodePosition() const { return false; }
|
virtual bool HasNodePosition() const { return false; }
|
||||||
/*!
|
/**
|
||||||
* \brief perform update to the tree models
|
* \brief perform update to the tree models
|
||||||
|
*
|
||||||
|
* \param param Hyper-parameter for constructing trees.
|
||||||
* \param gpair the gradient pair statistics of the data
|
* \param gpair the gradient pair statistics of the data
|
||||||
* \param data The data matrix passed to the updater.
|
* \param data The data matrix passed to the updater.
|
||||||
* \param out_position The leaf index for each row. The index is negated if that row is
|
* \param out_position The leaf index for each row. The index is negated if that row is
|
||||||
@ -67,8 +72,8 @@ class TreeUpdater : public Configurable {
|
|||||||
* but maybe different random seeds, usually one tree is passed in at a time,
|
* but maybe different random seeds, usually one tree is passed in at a time,
|
||||||
* there can be multiple trees when we train random forest style model
|
* there can be multiple trees when we train random forest style model
|
||||||
*/
|
*/
|
||||||
virtual void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* data,
|
virtual void Update(tree::TrainParam const* param, HostDeviceVector<GradientPair>* gpair,
|
||||||
common::Span<HostDeviceVector<bst_node_t>> out_position,
|
DMatrix* data, common::Span<HostDeviceVector<bst_node_t>> out_position,
|
||||||
const std::vector<RegTree*>& out_trees) = 0;
|
const std::vector<RegTree*>& out_trees) = 0;
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
|
|||||||
@ -32,15 +32,14 @@
|
|||||||
#include "xgboost/string_view.h"
|
#include "xgboost/string_view.h"
|
||||||
#include "xgboost/tree_updater.h"
|
#include "xgboost/tree_updater.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost::gbm {
|
||||||
namespace gbm {
|
|
||||||
|
|
||||||
DMLC_REGISTRY_FILE_TAG(gbtree);
|
DMLC_REGISTRY_FILE_TAG(gbtree);
|
||||||
|
|
||||||
void GBTree::Configure(const Args& cfg) {
|
void GBTree::Configure(Args const& cfg) {
|
||||||
this->cfg_ = cfg;
|
this->cfg_ = cfg;
|
||||||
std::string updater_seq = tparam_.updater_seq;
|
std::string updater_seq = tparam_.updater_seq;
|
||||||
tparam_.UpdateAllowUnknown(cfg);
|
tparam_.UpdateAllowUnknown(cfg);
|
||||||
|
tree_param_.UpdateAllowUnknown(cfg);
|
||||||
|
|
||||||
model_.Configure(cfg);
|
model_.Configure(cfg);
|
||||||
|
|
||||||
@ -235,9 +234,11 @@ void GBTree::UpdateTreeLeaf(DMatrix const* p_fmat, HostDeviceVector<float> const
|
|||||||
CHECK_EQ(model_.param.num_parallel_tree, trees.size());
|
CHECK_EQ(model_.param.num_parallel_tree, trees.size());
|
||||||
CHECK_EQ(model_.param.num_parallel_tree, 1)
|
CHECK_EQ(model_.param.num_parallel_tree, 1)
|
||||||
<< "Boosting random forest is not supported for current objective.";
|
<< "Boosting random forest is not supported for current objective.";
|
||||||
|
CHECK_EQ(trees.size(), model_.param.num_parallel_tree);
|
||||||
for (std::size_t tree_idx = 0; tree_idx < trees.size(); ++tree_idx) {
|
for (std::size_t tree_idx = 0; tree_idx < trees.size(); ++tree_idx) {
|
||||||
auto const& position = node_position.at(tree_idx);
|
auto const& position = node_position.at(tree_idx);
|
||||||
obj->UpdateTreeLeaf(position, p_fmat->Info(), predictions, group_idx, trees[tree_idx].get());
|
obj->UpdateTreeLeaf(position, p_fmat->Info(), tree_param_.learning_rate / trees.size(),
|
||||||
|
predictions, group_idx, trees[tree_idx].get());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -388,9 +389,15 @@ void GBTree::BoostNewTrees(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fma
|
|||||||
|
|
||||||
CHECK(out_position);
|
CHECK(out_position);
|
||||||
out_position->resize(new_trees.size());
|
out_position->resize(new_trees.size());
|
||||||
|
|
||||||
|
// Rescale learning rate according to the size of trees
|
||||||
|
auto lr = tree_param_.learning_rate;
|
||||||
|
tree_param_.learning_rate /= static_cast<float>(new_trees.size());
|
||||||
for (auto& up : updaters_) {
|
for (auto& up : updaters_) {
|
||||||
up->Update(gpair, p_fmat, common::Span<HostDeviceVector<bst_node_t>>{*out_position}, new_trees);
|
up->Update(&tree_param_, gpair, p_fmat,
|
||||||
|
common::Span<HostDeviceVector<bst_node_t>>{*out_position}, new_trees);
|
||||||
}
|
}
|
||||||
|
tree_param_.learning_rate = lr;
|
||||||
}
|
}
|
||||||
|
|
||||||
void GBTree::CommitModel(std::vector<std::vector<std::unique_ptr<RegTree>>>&& new_trees) {
|
void GBTree::CommitModel(std::vector<std::vector<std::unique_ptr<RegTree>>>&& new_trees) {
|
||||||
@ -404,6 +411,8 @@ void GBTree::CommitModel(std::vector<std::vector<std::unique_ptr<RegTree>>>&& ne
|
|||||||
void GBTree::LoadConfig(Json const& in) {
|
void GBTree::LoadConfig(Json const& in) {
|
||||||
CHECK_EQ(get<String>(in["name"]), "gbtree");
|
CHECK_EQ(get<String>(in["name"]), "gbtree");
|
||||||
FromJson(in["gbtree_train_param"], &tparam_);
|
FromJson(in["gbtree_train_param"], &tparam_);
|
||||||
|
FromJson(in["tree_train_param"], &tree_param_);
|
||||||
|
|
||||||
// Process type cannot be kUpdate from loaded model
|
// Process type cannot be kUpdate from loaded model
|
||||||
// This would cause all trees to be pushed to trees_to_update
|
// This would cause all trees to be pushed to trees_to_update
|
||||||
// e.g. updating a model, then saving and loading it would result in an empty model
|
// e.g. updating a model, then saving and loading it would result in an empty model
|
||||||
@ -451,6 +460,7 @@ void GBTree::SaveConfig(Json* p_out) const {
|
|||||||
auto& out = *p_out;
|
auto& out = *p_out;
|
||||||
out["name"] = String("gbtree");
|
out["name"] = String("gbtree");
|
||||||
out["gbtree_train_param"] = ToJson(tparam_);
|
out["gbtree_train_param"] = ToJson(tparam_);
|
||||||
|
out["tree_train_param"] = ToJson(tree_param_);
|
||||||
|
|
||||||
// Process type cannot be kUpdate from loaded model
|
// Process type cannot be kUpdate from loaded model
|
||||||
// This would cause all trees to be pushed to trees_to_update
|
// This would cause all trees to be pushed to trees_to_update
|
||||||
@ -1058,5 +1068,4 @@ XGBOOST_REGISTER_GBM(Dart, "dart")
|
|||||||
GBTree* p = new Dart(booster_config, ctx);
|
GBTree* p = new Dart(booster_config, ctx);
|
||||||
return p;
|
return p;
|
||||||
});
|
});
|
||||||
} // namespace gbm
|
} // namespace xgboost::gbm
|
||||||
} // namespace xgboost
|
|
||||||
|
|||||||
@ -20,6 +20,7 @@
|
|||||||
|
|
||||||
#include "../common/common.h"
|
#include "../common/common.h"
|
||||||
#include "../common/timer.h"
|
#include "../common/timer.h"
|
||||||
|
#include "../tree/param.h" // TrainParam
|
||||||
#include "gbtree_model.h"
|
#include "gbtree_model.h"
|
||||||
#include "xgboost/base.h"
|
#include "xgboost/base.h"
|
||||||
#include "xgboost/data.h"
|
#include "xgboost/data.h"
|
||||||
@ -405,8 +406,8 @@ class GBTree : public GradientBooster {
|
|||||||
p_fmat, out_contribs, model_, tree_end, nullptr, approximate);
|
p_fmat, out_contribs, model_, tree_end, nullptr, approximate);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::string> DumpModel(const FeatureMap& fmap, bool with_stats,
|
[[nodiscard]] std::vector<std::string> DumpModel(const FeatureMap& fmap, bool with_stats,
|
||||||
std::string format) const override {
|
std::string format) const override {
|
||||||
return model_.DumpModel(fmap, with_stats, this->ctx_->Threads(), format);
|
return model_.DumpModel(fmap, with_stats, this->ctx_->Threads(), format);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -428,6 +429,8 @@ class GBTree : public GradientBooster {
|
|||||||
GBTreeModel model_;
|
GBTreeModel model_;
|
||||||
// training parameter
|
// training parameter
|
||||||
GBTreeTrainParam tparam_;
|
GBTreeTrainParam tparam_;
|
||||||
|
// Tree training parameter
|
||||||
|
tree::TrainParam tree_param_;
|
||||||
// ----training fields----
|
// ----training fields----
|
||||||
bool showed_updater_warning_ {false};
|
bool showed_updater_warning_ {false};
|
||||||
bool specified_updater_ {false};
|
bool specified_updater_ {false};
|
||||||
|
|||||||
@ -76,7 +76,7 @@ void EncodeTreeLeafHost(Context const* ctx, RegTree const& tree,
|
|||||||
}
|
}
|
||||||
|
|
||||||
void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& position,
|
void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& position,
|
||||||
std::int32_t group_idx, MetaInfo const& info,
|
std::int32_t group_idx, MetaInfo const& info, float learning_rate,
|
||||||
HostDeviceVector<float> const& predt, float alpha, RegTree* p_tree) {
|
HostDeviceVector<float> const& predt, float alpha, RegTree* p_tree) {
|
||||||
auto& tree = *p_tree;
|
auto& tree = *p_tree;
|
||||||
|
|
||||||
@ -87,7 +87,7 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& posit
|
|||||||
size_t n_leaf = nidx.size();
|
size_t n_leaf = nidx.size();
|
||||||
if (nptr.empty()) {
|
if (nptr.empty()) {
|
||||||
std::vector<float> quantiles;
|
std::vector<float> quantiles;
|
||||||
UpdateLeafValues(&quantiles, nidx, p_tree);
|
UpdateLeafValues(&quantiles, nidx, learning_rate, p_tree);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -133,12 +133,13 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& posit
|
|||||||
quantiles.at(k) = q;
|
quantiles.at(k) = q;
|
||||||
});
|
});
|
||||||
|
|
||||||
UpdateLeafValues(&quantiles, nidx, p_tree);
|
UpdateLeafValues(&quantiles, nidx, learning_rate, p_tree);
|
||||||
}
|
}
|
||||||
|
|
||||||
#if !defined(XGBOOST_USE_CUDA)
|
#if !defined(XGBOOST_USE_CUDA)
|
||||||
void UpdateTreeLeafDevice(Context const*, common::Span<bst_node_t const>, std::int32_t,
|
void UpdateTreeLeafDevice(Context const*, common::Span<bst_node_t const>, std::int32_t,
|
||||||
MetaInfo const&, HostDeviceVector<float> const&, float, RegTree*) {
|
MetaInfo const&, float learning_rate, HostDeviceVector<float> const&,
|
||||||
|
float, RegTree*) {
|
||||||
common::AssertGPUSupport();
|
common::AssertGPUSupport();
|
||||||
}
|
}
|
||||||
#endif // !defined(XGBOOST_USE_CUDA)
|
#endif // !defined(XGBOOST_USE_CUDA)
|
||||||
|
|||||||
@ -140,7 +140,7 @@ void EncodeTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
|
|||||||
}
|
}
|
||||||
|
|
||||||
void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> position,
|
void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> position,
|
||||||
std::int32_t group_idx, MetaInfo const& info,
|
std::int32_t group_idx, MetaInfo const& info, float learning_rate,
|
||||||
HostDeviceVector<float> const& predt, float alpha, RegTree* p_tree) {
|
HostDeviceVector<float> const& predt, float alpha, RegTree* p_tree) {
|
||||||
dh::safe_cuda(cudaSetDevice(ctx->gpu_id));
|
dh::safe_cuda(cudaSetDevice(ctx->gpu_id));
|
||||||
dh::device_vector<size_t> ridx;
|
dh::device_vector<size_t> ridx;
|
||||||
@ -151,7 +151,7 @@ void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
|
|||||||
|
|
||||||
if (nptr.Empty()) {
|
if (nptr.Empty()) {
|
||||||
std::vector<float> quantiles;
|
std::vector<float> quantiles;
|
||||||
UpdateLeafValues(&quantiles, nidx.ConstHostVector(), p_tree);
|
UpdateLeafValues(&quantiles, nidx.ConstHostVector(), learning_rate, p_tree);
|
||||||
}
|
}
|
||||||
|
|
||||||
HostDeviceVector<float> quantiles;
|
HostDeviceVector<float> quantiles;
|
||||||
@ -186,7 +186,7 @@ void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
|
|||||||
w_it + d_weights.size(), &quantiles);
|
w_it + d_weights.size(), &quantiles);
|
||||||
}
|
}
|
||||||
|
|
||||||
UpdateLeafValues(&quantiles.HostVector(), nidx.ConstHostVector(), p_tree);
|
UpdateLeafValues(&quantiles.HostVector(), nidx.ConstHostVector(), learning_rate, p_tree);
|
||||||
}
|
}
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
} // namespace obj
|
} // namespace obj
|
||||||
|
|||||||
@ -36,7 +36,7 @@ inline void FillMissingLeaf(std::vector<bst_node_t> const& maybe_missing,
|
|||||||
}
|
}
|
||||||
|
|
||||||
inline void UpdateLeafValues(std::vector<float>* p_quantiles, std::vector<bst_node_t> const& nidx,
|
inline void UpdateLeafValues(std::vector<float>* p_quantiles, std::vector<bst_node_t> const& nidx,
|
||||||
RegTree* p_tree) {
|
float learning_rate, RegTree* p_tree) {
|
||||||
auto& tree = *p_tree;
|
auto& tree = *p_tree;
|
||||||
auto& quantiles = *p_quantiles;
|
auto& quantiles = *p_quantiles;
|
||||||
auto const& h_node_idx = nidx;
|
auto const& h_node_idx = nidx;
|
||||||
@ -71,7 +71,7 @@ inline void UpdateLeafValues(std::vector<float>* p_quantiles, std::vector<bst_no
|
|||||||
auto nidx = h_node_idx[i];
|
auto nidx = h_node_idx[i];
|
||||||
auto q = quantiles[i];
|
auto q = quantiles[i];
|
||||||
CHECK(tree[nidx].IsLeaf());
|
CHECK(tree[nidx].IsLeaf());
|
||||||
tree[nidx].SetLeaf(q);
|
tree[nidx].SetLeaf(q * learning_rate);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -85,24 +85,24 @@ inline std::size_t IdxY(MetaInfo const& info, bst_group_t group_idx) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> position,
|
void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> position,
|
||||||
std::int32_t group_idx, MetaInfo const& info,
|
std::int32_t group_idx, MetaInfo const& info, float learning_rate,
|
||||||
HostDeviceVector<float> const& predt, float alpha, RegTree* p_tree);
|
HostDeviceVector<float> const& predt, float alpha, RegTree* p_tree);
|
||||||
|
|
||||||
void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& position,
|
void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& position,
|
||||||
std::int32_t group_idx, MetaInfo const& info,
|
std::int32_t group_idx, MetaInfo const& info, float learning_rate,
|
||||||
HostDeviceVector<float> const& predt, float alpha, RegTree* p_tree);
|
HostDeviceVector<float> const& predt, float alpha, RegTree* p_tree);
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
inline void UpdateTreeLeaf(Context const* ctx, HostDeviceVector<bst_node_t> const& position,
|
inline void UpdateTreeLeaf(Context const* ctx, HostDeviceVector<bst_node_t> const& position,
|
||||||
std::int32_t group_idx, MetaInfo const& info,
|
std::int32_t group_idx, MetaInfo const& info, float learning_rate,
|
||||||
HostDeviceVector<float> const& predt, float alpha, RegTree* p_tree) {
|
HostDeviceVector<float> const& predt, float alpha, RegTree* p_tree) {
|
||||||
if (ctx->IsCPU()) {
|
if (ctx->IsCPU()) {
|
||||||
detail::UpdateTreeLeafHost(ctx, position.ConstHostVector(), group_idx, info, predt, alpha,
|
detail::UpdateTreeLeafHost(ctx, position.ConstHostVector(), group_idx, info, learning_rate,
|
||||||
p_tree);
|
predt, alpha, p_tree);
|
||||||
} else {
|
} else {
|
||||||
position.SetDevice(ctx->gpu_id);
|
position.SetDevice(ctx->gpu_id);
|
||||||
detail::UpdateTreeLeafDevice(ctx, position.ConstDeviceSpan(), group_idx, info, predt, alpha,
|
detail::UpdateTreeLeafDevice(ctx, position.ConstDeviceSpan(), group_idx, info, learning_rate,
|
||||||
p_tree);
|
predt, alpha, p_tree);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} // namespace obj
|
} // namespace obj
|
||||||
|
|||||||
@ -183,10 +183,11 @@ class QuantileRegression : public ObjFunction {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void UpdateTreeLeaf(HostDeviceVector<bst_node_t> const& position, MetaInfo const& info,
|
void UpdateTreeLeaf(HostDeviceVector<bst_node_t> const& position, MetaInfo const& info,
|
||||||
HostDeviceVector<float> const& prediction, std::int32_t group_idx,
|
float learning_rate, HostDeviceVector<float> const& prediction,
|
||||||
RegTree* p_tree) const override {
|
std::int32_t group_idx, RegTree* p_tree) const override {
|
||||||
auto alpha = param_.quantile_alpha[group_idx];
|
auto alpha = param_.quantile_alpha[group_idx];
|
||||||
::xgboost::obj::UpdateTreeLeaf(ctx_, position, group_idx, info, prediction, alpha, p_tree);
|
::xgboost::obj::UpdateTreeLeaf(ctx_, position, group_idx, info, learning_rate, prediction,
|
||||||
|
alpha, p_tree);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Configure(Args const& args) override {
|
void Configure(Args const& args) override {
|
||||||
|
|||||||
@ -742,9 +742,10 @@ class MeanAbsoluteError : public ObjFunction {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void UpdateTreeLeaf(HostDeviceVector<bst_node_t> const& position, MetaInfo const& info,
|
void UpdateTreeLeaf(HostDeviceVector<bst_node_t> const& position, MetaInfo const& info,
|
||||||
HostDeviceVector<float> const& prediction, std::int32_t group_idx,
|
float learning_rate, HostDeviceVector<float> const& prediction,
|
||||||
RegTree* p_tree) const override {
|
std::int32_t group_idx, RegTree* p_tree) const override {
|
||||||
::xgboost::obj::UpdateTreeLeaf(ctx_, position, group_idx, info, prediction, 0.5, p_tree);
|
::xgboost::obj::UpdateTreeLeaf(ctx_, position, group_idx, info, learning_rate, prediction, 0.5,
|
||||||
|
p_tree);
|
||||||
}
|
}
|
||||||
|
|
||||||
const char* DefaultEvalMetric() const override { return "mae"; }
|
const char* DefaultEvalMetric() const override { return "mae"; }
|
||||||
|
|||||||
@ -17,13 +17,11 @@
|
|||||||
#include "../../common/random.h"
|
#include "../../common/random.h"
|
||||||
#include "../../data/gradient_index.h"
|
#include "../../data/gradient_index.h"
|
||||||
#include "../constraints.h"
|
#include "../constraints.h"
|
||||||
#include "../param.h"
|
#include "../param.h" // for TrainParam
|
||||||
#include "../split_evaluator.h"
|
#include "../split_evaluator.h"
|
||||||
#include "xgboost/context.h"
|
#include "xgboost/context.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost::tree {
|
||||||
namespace tree {
|
|
||||||
|
|
||||||
template <typename ExpandEntry>
|
template <typename ExpandEntry>
|
||||||
class HistEvaluator {
|
class HistEvaluator {
|
||||||
private:
|
private:
|
||||||
@ -36,7 +34,7 @@ class HistEvaluator {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
Context const* ctx_;
|
Context const* ctx_;
|
||||||
TrainParam param_;
|
TrainParam const* param_;
|
||||||
std::shared_ptr<common::ColumnSampler> column_sampler_;
|
std::shared_ptr<common::ColumnSampler> column_sampler_;
|
||||||
TreeEvaluator tree_evaluator_;
|
TreeEvaluator tree_evaluator_;
|
||||||
bool is_col_split_{false};
|
bool is_col_split_{false};
|
||||||
@ -55,8 +53,9 @@ class HistEvaluator {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool IsValid(GradStats const &left, GradStats const &right) const {
|
[[nodiscard]] bool IsValid(GradStats const &left, GradStats const &right) const {
|
||||||
return left.GetHess() >= param_.min_child_weight && right.GetHess() >= param_.min_child_weight;
|
return left.GetHess() >= param_->min_child_weight &&
|
||||||
|
right.GetHess() >= param_->min_child_weight;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -95,9 +94,10 @@ class HistEvaluator {
|
|||||||
right_sum = GradStats{hist[i]};
|
right_sum = GradStats{hist[i]};
|
||||||
left_sum.SetSubstract(parent.stats, right_sum);
|
left_sum.SetSubstract(parent.stats, right_sum);
|
||||||
if (IsValid(left_sum, right_sum)) {
|
if (IsValid(left_sum, right_sum)) {
|
||||||
auto missing_left_chg = static_cast<float>(
|
auto missing_left_chg =
|
||||||
evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{left_sum}, GradStats{right_sum}) -
|
static_cast<float>(evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{left_sum},
|
||||||
parent.root_gain);
|
GradStats{right_sum}) -
|
||||||
|
parent.root_gain);
|
||||||
best.Update(missing_left_chg, fidx, split_pt, true, true, left_sum, right_sum);
|
best.Update(missing_left_chg, fidx, split_pt, true, true, left_sum, right_sum);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -105,9 +105,10 @@ class HistEvaluator {
|
|||||||
right_sum.Add(missing);
|
right_sum.Add(missing);
|
||||||
left_sum.SetSubstract(parent.stats, right_sum);
|
left_sum.SetSubstract(parent.stats, right_sum);
|
||||||
if (IsValid(left_sum, right_sum)) {
|
if (IsValid(left_sum, right_sum)) {
|
||||||
auto missing_right_chg = static_cast<float>(
|
auto missing_right_chg =
|
||||||
evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{left_sum}, GradStats{right_sum}) -
|
static_cast<float>(evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{left_sum},
|
||||||
parent.root_gain);
|
GradStats{right_sum}) -
|
||||||
|
parent.root_gain);
|
||||||
best.Update(missing_right_chg, fidx, split_pt, false, true, left_sum, right_sum);
|
best.Update(missing_right_chg, fidx, split_pt, false, true, left_sum, right_sum);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -152,7 +153,7 @@ class HistEvaluator {
|
|||||||
bst_bin_t f_begin = cut_ptr[fidx];
|
bst_bin_t f_begin = cut_ptr[fidx];
|
||||||
bst_bin_t f_end = cut_ptr[fidx + 1];
|
bst_bin_t f_end = cut_ptr[fidx + 1];
|
||||||
bst_bin_t n_bins_feature{f_end - f_begin};
|
bst_bin_t n_bins_feature{f_end - f_begin};
|
||||||
auto n_bins = std::min(param_.max_cat_threshold, n_bins_feature);
|
auto n_bins = std::min(param_->max_cat_threshold, n_bins_feature);
|
||||||
|
|
||||||
// statistics on both sides of split
|
// statistics on both sides of split
|
||||||
GradStats left_sum;
|
GradStats left_sum;
|
||||||
@ -181,9 +182,9 @@ class HistEvaluator {
|
|||||||
right_sum.SetSubstract(parent.stats, left_sum); // missing on right
|
right_sum.SetSubstract(parent.stats, left_sum); // missing on right
|
||||||
}
|
}
|
||||||
if (IsValid(left_sum, right_sum)) {
|
if (IsValid(left_sum, right_sum)) {
|
||||||
auto loss_chg =
|
auto loss_chg = evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{left_sum},
|
||||||
evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{left_sum}, GradStats{right_sum}) -
|
GradStats{right_sum}) -
|
||||||
parent.root_gain;
|
parent.root_gain;
|
||||||
// We don't have a numeric split point, nan here is a dummy split.
|
// We don't have a numeric split point, nan here is a dummy split.
|
||||||
if (best.Update(loss_chg, fidx, std::numeric_limits<float>::quiet_NaN(), d_step == 1, true,
|
if (best.Update(loss_chg, fidx, std::numeric_limits<float>::quiet_NaN(), d_step == 1, true,
|
||||||
left_sum, right_sum)) {
|
left_sum, right_sum)) {
|
||||||
@ -256,7 +257,7 @@ class HistEvaluator {
|
|||||||
if (d_step > 0) {
|
if (d_step > 0) {
|
||||||
// forward enumeration: split at right bound of each bin
|
// forward enumeration: split at right bound of each bin
|
||||||
loss_chg =
|
loss_chg =
|
||||||
static_cast<float>(evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{left_sum},
|
static_cast<float>(evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{left_sum},
|
||||||
GradStats{right_sum}) -
|
GradStats{right_sum}) -
|
||||||
parent.root_gain);
|
parent.root_gain);
|
||||||
split_pt = cut_val[i]; // not used for partition based
|
split_pt = cut_val[i]; // not used for partition based
|
||||||
@ -264,7 +265,7 @@ class HistEvaluator {
|
|||||||
} else {
|
} else {
|
||||||
// backward enumeration: split at left bound of each bin
|
// backward enumeration: split at left bound of each bin
|
||||||
loss_chg =
|
loss_chg =
|
||||||
static_cast<float>(evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{right_sum},
|
static_cast<float>(evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{right_sum},
|
||||||
GradStats{left_sum}) -
|
GradStats{left_sum}) -
|
||||||
parent.root_gain);
|
parent.root_gain);
|
||||||
if (i == imin) {
|
if (i == imin) {
|
||||||
@ -326,7 +327,7 @@ class HistEvaluator {
|
|||||||
}
|
}
|
||||||
if (is_cat) {
|
if (is_cat) {
|
||||||
auto n_bins = cut_ptrs.at(fidx + 1) - cut_ptrs[fidx];
|
auto n_bins = cut_ptrs.at(fidx + 1) - cut_ptrs[fidx];
|
||||||
if (common::UseOneHot(n_bins, param_.max_cat_to_onehot)) {
|
if (common::UseOneHot(n_bins, param_->max_cat_to_onehot)) {
|
||||||
EnumerateOneHot(cut, histogram, fidx, nidx, evaluator, best);
|
EnumerateOneHot(cut, histogram, fidx, nidx, evaluator, best);
|
||||||
} else {
|
} else {
|
||||||
std::vector<size_t> sorted_idx(n_bins);
|
std::vector<size_t> sorted_idx(n_bins);
|
||||||
@ -334,8 +335,8 @@ class HistEvaluator {
|
|||||||
auto feat_hist = histogram.subspan(cut_ptrs[fidx], n_bins);
|
auto feat_hist = histogram.subspan(cut_ptrs[fidx], n_bins);
|
||||||
// Sort the histogram to get contiguous partitions.
|
// Sort the histogram to get contiguous partitions.
|
||||||
std::stable_sort(sorted_idx.begin(), sorted_idx.end(), [&](size_t l, size_t r) {
|
std::stable_sort(sorted_idx.begin(), sorted_idx.end(), [&](size_t l, size_t r) {
|
||||||
auto ret = evaluator.CalcWeightCat(param_, feat_hist[l]) <
|
auto ret = evaluator.CalcWeightCat(*param_, feat_hist[l]) <
|
||||||
evaluator.CalcWeightCat(param_, feat_hist[r]);
|
evaluator.CalcWeightCat(*param_, feat_hist[r]);
|
||||||
return ret;
|
return ret;
|
||||||
});
|
});
|
||||||
EnumeratePart<+1>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
|
EnumeratePart<+1>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
|
||||||
@ -382,24 +383,22 @@ class HistEvaluator {
|
|||||||
|
|
||||||
GradStats parent_sum = candidate.split.left_sum;
|
GradStats parent_sum = candidate.split.left_sum;
|
||||||
parent_sum.Add(candidate.split.right_sum);
|
parent_sum.Add(candidate.split.right_sum);
|
||||||
auto base_weight =
|
auto base_weight = evaluator.CalcWeight(candidate.nid, *param_, GradStats{parent_sum});
|
||||||
evaluator.CalcWeight(candidate.nid, param_, GradStats{parent_sum});
|
|
||||||
|
|
||||||
auto left_weight =
|
auto left_weight =
|
||||||
evaluator.CalcWeight(candidate.nid, param_, GradStats{candidate.split.left_sum});
|
evaluator.CalcWeight(candidate.nid, *param_, GradStats{candidate.split.left_sum});
|
||||||
auto right_weight =
|
auto right_weight =
|
||||||
evaluator.CalcWeight(candidate.nid, param_, GradStats{candidate.split.right_sum});
|
evaluator.CalcWeight(candidate.nid, *param_, GradStats{candidate.split.right_sum});
|
||||||
|
|
||||||
if (candidate.split.is_cat) {
|
if (candidate.split.is_cat) {
|
||||||
tree.ExpandCategorical(
|
tree.ExpandCategorical(
|
||||||
candidate.nid, candidate.split.SplitIndex(), candidate.split.cat_bits,
|
candidate.nid, candidate.split.SplitIndex(), candidate.split.cat_bits,
|
||||||
candidate.split.DefaultLeft(), base_weight, left_weight * param_.learning_rate,
|
candidate.split.DefaultLeft(), base_weight, left_weight * param_->learning_rate,
|
||||||
right_weight * param_.learning_rate, candidate.split.loss_chg, parent_sum.GetHess(),
|
right_weight * param_->learning_rate, candidate.split.loss_chg, parent_sum.GetHess(),
|
||||||
candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess());
|
candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess());
|
||||||
} else {
|
} else {
|
||||||
tree.ExpandNode(candidate.nid, candidate.split.SplitIndex(), candidate.split.split_value,
|
tree.ExpandNode(candidate.nid, candidate.split.SplitIndex(), candidate.split.split_value,
|
||||||
candidate.split.DefaultLeft(), base_weight,
|
candidate.split.DefaultLeft(), base_weight,
|
||||||
left_weight * param_.learning_rate, right_weight * param_.learning_rate,
|
left_weight * param_->learning_rate, right_weight * param_->learning_rate,
|
||||||
candidate.split.loss_chg, parent_sum.GetHess(),
|
candidate.split.loss_chg, parent_sum.GetHess(),
|
||||||
candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess());
|
candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess());
|
||||||
}
|
}
|
||||||
@ -415,11 +414,11 @@ class HistEvaluator {
|
|||||||
max_node = std::max(candidate.nid, max_node);
|
max_node = std::max(candidate.nid, max_node);
|
||||||
snode_.resize(tree.GetNodes().size());
|
snode_.resize(tree.GetNodes().size());
|
||||||
snode_.at(left_child).stats = candidate.split.left_sum;
|
snode_.at(left_child).stats = candidate.split.left_sum;
|
||||||
snode_.at(left_child).root_gain = evaluator.CalcGain(
|
snode_.at(left_child).root_gain =
|
||||||
candidate.nid, param_, GradStats{candidate.split.left_sum});
|
evaluator.CalcGain(candidate.nid, *param_, GradStats{candidate.split.left_sum});
|
||||||
snode_.at(right_child).stats = candidate.split.right_sum;
|
snode_.at(right_child).stats = candidate.split.right_sum;
|
||||||
snode_.at(right_child).root_gain = evaluator.CalcGain(
|
snode_.at(right_child).root_gain =
|
||||||
candidate.nid, param_, GradStats{candidate.split.right_sum});
|
evaluator.CalcGain(candidate.nid, *param_, GradStats{candidate.split.right_sum});
|
||||||
|
|
||||||
interaction_constraints_.Split(candidate.nid,
|
interaction_constraints_.Split(candidate.nid,
|
||||||
tree[candidate.nid].SplitIndex(), left_child,
|
tree[candidate.nid].SplitIndex(), left_child,
|
||||||
@ -429,31 +428,31 @@ class HistEvaluator {
|
|||||||
auto Evaluator() const { return tree_evaluator_.GetEvaluator(); }
|
auto Evaluator() const { return tree_evaluator_.GetEvaluator(); }
|
||||||
auto const& Stats() const { return snode_; }
|
auto const& Stats() const { return snode_; }
|
||||||
|
|
||||||
float InitRoot(GradStats const& root_sum) {
|
float InitRoot(GradStats const &root_sum) {
|
||||||
snode_.resize(1);
|
snode_.resize(1);
|
||||||
auto root_evaluator = tree_evaluator_.GetEvaluator();
|
auto root_evaluator = tree_evaluator_.GetEvaluator();
|
||||||
|
|
||||||
snode_[0].stats = GradStats{root_sum.GetGrad(), root_sum.GetHess()};
|
snode_[0].stats = GradStats{root_sum.GetGrad(), root_sum.GetHess()};
|
||||||
snode_[0].root_gain = root_evaluator.CalcGain(RegTree::kRoot, param_,
|
snode_[0].root_gain =
|
||||||
GradStats{snode_[0].stats});
|
root_evaluator.CalcGain(RegTree::kRoot, *param_, GradStats{snode_[0].stats});
|
||||||
auto weight = root_evaluator.CalcWeight(RegTree::kRoot, param_,
|
auto weight = root_evaluator.CalcWeight(RegTree::kRoot, *param_, GradStats{snode_[0].stats});
|
||||||
GradStats{snode_[0].stats});
|
|
||||||
return weight;
|
return weight;
|
||||||
}
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
// The column sampler must be constructed by caller since we need to preserve the rng
|
// The column sampler must be constructed by caller since we need to preserve the rng
|
||||||
// for the entire training session.
|
// for the entire training session.
|
||||||
explicit HistEvaluator(Context const* ctx, TrainParam const ¶m, MetaInfo const &info,
|
explicit HistEvaluator(Context const *ctx, TrainParam const *param, MetaInfo const &info,
|
||||||
std::shared_ptr<common::ColumnSampler> sampler)
|
std::shared_ptr<common::ColumnSampler> sampler)
|
||||||
: ctx_{ctx}, param_{param},
|
: ctx_{ctx},
|
||||||
|
param_{param},
|
||||||
column_sampler_{std::move(sampler)},
|
column_sampler_{std::move(sampler)},
|
||||||
tree_evaluator_{param, static_cast<bst_feature_t>(info.num_col_), Context::kCpuId},
|
tree_evaluator_{*param, static_cast<bst_feature_t>(info.num_col_), Context::kCpuId},
|
||||||
is_col_split_{info.data_split_mode == DataSplitMode::kCol} {
|
is_col_split_{info.data_split_mode == DataSplitMode::kCol} {
|
||||||
interaction_constraints_.Configure(param, info.num_col_);
|
interaction_constraints_.Configure(*param, info.num_col_);
|
||||||
column_sampler_->Init(ctx, info.num_col_, info.feature_weights.HostVector(),
|
column_sampler_->Init(ctx, info.num_col_, info.feature_weights.HostVector(),
|
||||||
param_.colsample_bynode, param_.colsample_bylevel,
|
param_->colsample_bynode, param_->colsample_bylevel,
|
||||||
param_.colsample_bytree);
|
param_->colsample_bytree);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -488,6 +487,5 @@ void UpdatePredictionCacheImpl(Context const *ctx, RegTree const *p_last_tree,
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} // namespace tree
|
} // namespace xgboost::tree
|
||||||
} // namespace xgboost
|
|
||||||
#endif // XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_
|
#endif // XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_
|
||||||
|
|||||||
@ -23,8 +23,7 @@
|
|||||||
#include "xgboost/tree_model.h"
|
#include "xgboost/tree_model.h"
|
||||||
#include "xgboost/tree_updater.h"
|
#include "xgboost/tree_updater.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost::tree {
|
||||||
namespace tree {
|
|
||||||
|
|
||||||
DMLC_REGISTRY_FILE_TAG(updater_approx);
|
DMLC_REGISTRY_FILE_TAG(updater_approx);
|
||||||
|
|
||||||
@ -41,7 +40,7 @@ auto BatchSpec(TrainParam const &p, common::Span<float> hess) {
|
|||||||
|
|
||||||
class GloablApproxBuilder {
|
class GloablApproxBuilder {
|
||||||
protected:
|
protected:
|
||||||
TrainParam param_;
|
TrainParam const* param_;
|
||||||
std::shared_ptr<common::ColumnSampler> col_sampler_;
|
std::shared_ptr<common::ColumnSampler> col_sampler_;
|
||||||
HistEvaluator<CPUExpandEntry> evaluator_;
|
HistEvaluator<CPUExpandEntry> evaluator_;
|
||||||
HistogramBuilder<CPUExpandEntry> histogram_builder_;
|
HistogramBuilder<CPUExpandEntry> histogram_builder_;
|
||||||
@ -64,7 +63,7 @@ class GloablApproxBuilder {
|
|||||||
bst_bin_t n_total_bins = 0;
|
bst_bin_t n_total_bins = 0;
|
||||||
partitioner_.clear();
|
partitioner_.clear();
|
||||||
// Generating the GHistIndexMatrix is quite slow, is there a way to speed it up?
|
// Generating the GHistIndexMatrix is quite slow, is there a way to speed it up?
|
||||||
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(BatchSpec(param_, hess, task_))) {
|
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(BatchSpec(*param_, hess, task_))) {
|
||||||
if (n_total_bins == 0) {
|
if (n_total_bins == 0) {
|
||||||
n_total_bins = page.cut.TotalBins();
|
n_total_bins = page.cut.TotalBins();
|
||||||
feature_values_ = page.cut;
|
feature_values_ = page.cut;
|
||||||
@ -75,7 +74,7 @@ class GloablApproxBuilder {
|
|||||||
n_batches_++;
|
n_batches_++;
|
||||||
}
|
}
|
||||||
|
|
||||||
histogram_builder_.Reset(n_total_bins, BatchSpec(param_, hess), ctx_->Threads(), n_batches_,
|
histogram_builder_.Reset(n_total_bins, BatchSpec(*param_, hess), ctx_->Threads(), n_batches_,
|
||||||
collective::IsDistributed(), p_fmat->IsColumnSplit());
|
collective::IsDistributed(), p_fmat->IsColumnSplit());
|
||||||
monitor_->Stop(__func__);
|
monitor_->Stop(__func__);
|
||||||
}
|
}
|
||||||
@ -96,7 +95,7 @@ class GloablApproxBuilder {
|
|||||||
std::vector<CPUExpandEntry> nodes{best};
|
std::vector<CPUExpandEntry> nodes{best};
|
||||||
size_t i = 0;
|
size_t i = 0;
|
||||||
auto space = ConstructHistSpace(partitioner_, nodes);
|
auto space = ConstructHistSpace(partitioner_, nodes);
|
||||||
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(BatchSpec(param_, hess))) {
|
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(BatchSpec(*param_, hess))) {
|
||||||
histogram_builder_.BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(), nodes,
|
histogram_builder_.BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(), nodes,
|
||||||
{}, gpair);
|
{}, gpair);
|
||||||
i++;
|
i++;
|
||||||
@ -105,7 +104,7 @@ class GloablApproxBuilder {
|
|||||||
auto weight = evaluator_.InitRoot(root_sum);
|
auto weight = evaluator_.InitRoot(root_sum);
|
||||||
p_tree->Stat(RegTree::kRoot).sum_hess = root_sum.GetHess();
|
p_tree->Stat(RegTree::kRoot).sum_hess = root_sum.GetHess();
|
||||||
p_tree->Stat(RegTree::kRoot).base_weight = weight;
|
p_tree->Stat(RegTree::kRoot).base_weight = weight;
|
||||||
(*p_tree)[RegTree::kRoot].SetLeaf(param_.learning_rate * weight);
|
(*p_tree)[RegTree::kRoot].SetLeaf(param_->learning_rate * weight);
|
||||||
|
|
||||||
auto const &histograms = histogram_builder_.Histogram();
|
auto const &histograms = histogram_builder_.Histogram();
|
||||||
auto ft = p_fmat->Info().feature_types.ConstHostSpan();
|
auto ft = p_fmat->Info().feature_types.ConstHostSpan();
|
||||||
@ -147,7 +146,7 @@ class GloablApproxBuilder {
|
|||||||
|
|
||||||
size_t i = 0;
|
size_t i = 0;
|
||||||
auto space = ConstructHistSpace(partitioner_, nodes_to_build);
|
auto space = ConstructHistSpace(partitioner_, nodes_to_build);
|
||||||
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(BatchSpec(param_, hess))) {
|
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(BatchSpec(*param_, hess))) {
|
||||||
histogram_builder_.BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(),
|
histogram_builder_.BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(),
|
||||||
nodes_to_build, nodes_to_sub, gpair);
|
nodes_to_build, nodes_to_sub, gpair);
|
||||||
i++;
|
i++;
|
||||||
@ -168,10 +167,10 @@ class GloablApproxBuilder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
explicit GloablApproxBuilder(TrainParam param, MetaInfo const &info, Context const *ctx,
|
explicit GloablApproxBuilder(TrainParam const *param, MetaInfo const &info, Context const *ctx,
|
||||||
std::shared_ptr<common::ColumnSampler> column_sampler, ObjInfo task,
|
std::shared_ptr<common::ColumnSampler> column_sampler, ObjInfo task,
|
||||||
common::Monitor *monitor)
|
common::Monitor *monitor)
|
||||||
: param_{std::move(param)},
|
: param_{param},
|
||||||
col_sampler_{std::move(column_sampler)},
|
col_sampler_{std::move(column_sampler)},
|
||||||
evaluator_{ctx, param_, info, col_sampler_},
|
evaluator_{ctx, param_, info, col_sampler_},
|
||||||
ctx_{ctx},
|
ctx_{ctx},
|
||||||
@ -183,7 +182,7 @@ class GloablApproxBuilder {
|
|||||||
p_last_tree_ = p_tree;
|
p_last_tree_ = p_tree;
|
||||||
this->InitData(p_fmat, hess);
|
this->InitData(p_fmat, hess);
|
||||||
|
|
||||||
Driver<CPUExpandEntry> driver(param_);
|
Driver<CPUExpandEntry> driver(*param_);
|
||||||
auto &tree = *p_tree;
|
auto &tree = *p_tree;
|
||||||
driver.Push({this->InitRoot(p_fmat, gpair, hess, p_tree)});
|
driver.Push({this->InitRoot(p_fmat, gpair, hess, p_tree)});
|
||||||
auto expand_set = driver.Pop();
|
auto expand_set = driver.Pop();
|
||||||
@ -213,7 +212,7 @@ class GloablApproxBuilder {
|
|||||||
|
|
||||||
monitor_->Start("UpdatePosition");
|
monitor_->Start("UpdatePosition");
|
||||||
size_t page_id = 0;
|
size_t page_id = 0;
|
||||||
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(BatchSpec(param_, hess))) {
|
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(BatchSpec(*param_, hess))) {
|
||||||
partitioner_.at(page_id).UpdatePosition(ctx_, page, applied, p_tree);
|
partitioner_.at(page_id).UpdatePosition(ctx_, page, applied, p_tree);
|
||||||
page_id++;
|
page_id++;
|
||||||
}
|
}
|
||||||
@ -250,7 +249,6 @@ class GloablApproxBuilder {
|
|||||||
* iteration.
|
* iteration.
|
||||||
*/
|
*/
|
||||||
class GlobalApproxUpdater : public TreeUpdater {
|
class GlobalApproxUpdater : public TreeUpdater {
|
||||||
TrainParam param_;
|
|
||||||
common::Monitor monitor_;
|
common::Monitor monitor_;
|
||||||
// specializations for different histogram precision.
|
// specializations for different histogram precision.
|
||||||
std::unique_ptr<GloablApproxBuilder> pimpl_;
|
std::unique_ptr<GloablApproxBuilder> pimpl_;
|
||||||
@ -265,15 +263,9 @@ class GlobalApproxUpdater : public TreeUpdater {
|
|||||||
monitor_.Init(__func__);
|
monitor_.Init(__func__);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Configure(const Args &args) override { param_.UpdateAllowUnknown(args); }
|
void Configure(Args const &) override {}
|
||||||
void LoadConfig(Json const &in) override {
|
void LoadConfig(Json const &) override {}
|
||||||
auto const &config = get<Object const>(in);
|
void SaveConfig(Json *) const override {}
|
||||||
FromJson(config.at("train_param"), &this->param_);
|
|
||||||
}
|
|
||||||
void SaveConfig(Json *p_out) const override {
|
|
||||||
auto &out = *p_out;
|
|
||||||
out["train_param"] = ToJson(param_);
|
|
||||||
}
|
|
||||||
|
|
||||||
void InitData(TrainParam const ¶m, HostDeviceVector<GradientPair> const *gpair,
|
void InitData(TrainParam const ¶m, HostDeviceVector<GradientPair> const *gpair,
|
||||||
linalg::Matrix<GradientPair> *sampled) {
|
linalg::Matrix<GradientPair> *sampled) {
|
||||||
@ -283,20 +275,17 @@ class GlobalApproxUpdater : public TreeUpdater {
|
|||||||
SampleGradient(ctx_, param, sampled->HostView());
|
SampleGradient(ctx_, param, sampled->HostView());
|
||||||
}
|
}
|
||||||
|
|
||||||
char const *Name() const override { return "grow_histmaker"; }
|
[[nodiscard]] char const *Name() const override { return "grow_histmaker"; }
|
||||||
|
|
||||||
void Update(HostDeviceVector<GradientPair> *gpair, DMatrix *m,
|
void Update(TrainParam const *param, HostDeviceVector<GradientPair> *gpair, DMatrix *m,
|
||||||
common::Span<HostDeviceVector<bst_node_t>> out_position,
|
common::Span<HostDeviceVector<bst_node_t>> out_position,
|
||||||
const std::vector<RegTree *> &trees) override {
|
const std::vector<RegTree *> &trees) override {
|
||||||
float lr = param_.learning_rate;
|
pimpl_ = std::make_unique<GloablApproxBuilder>(param, m->Info(), ctx_, column_sampler_, task_,
|
||||||
param_.learning_rate = lr / trees.size();
|
|
||||||
|
|
||||||
pimpl_ = std::make_unique<GloablApproxBuilder>(param_, m->Info(), ctx_, column_sampler_, task_,
|
|
||||||
&monitor_);
|
&monitor_);
|
||||||
|
|
||||||
linalg::Matrix<GradientPair> h_gpair;
|
linalg::Matrix<GradientPair> h_gpair;
|
||||||
// Obtain the hessian values for weighted sketching
|
// Obtain the hessian values for weighted sketching
|
||||||
InitData(param_, gpair, &h_gpair);
|
InitData(*param, gpair, &h_gpair);
|
||||||
std::vector<float> hess(h_gpair.Size());
|
std::vector<float> hess(h_gpair.Size());
|
||||||
auto const &s_gpair = h_gpair.Data()->ConstHostVector();
|
auto const &s_gpair = h_gpair.Data()->ConstHostVector();
|
||||||
std::transform(s_gpair.begin(), s_gpair.end(), hess.begin(),
|
std::transform(s_gpair.begin(), s_gpair.end(), hess.begin(),
|
||||||
@ -304,12 +293,11 @@ class GlobalApproxUpdater : public TreeUpdater {
|
|||||||
|
|
||||||
cached_ = m;
|
cached_ = m;
|
||||||
|
|
||||||
size_t t_idx = 0;
|
std::size_t t_idx = 0;
|
||||||
for (auto p_tree : trees) {
|
for (auto p_tree : trees) {
|
||||||
this->pimpl_->UpdateTree(m, s_gpair, hess, p_tree, &out_position[t_idx]);
|
this->pimpl_->UpdateTree(m, s_gpair, hess, p_tree, &out_position[t_idx]);
|
||||||
++t_idx;
|
++t_idx;
|
||||||
}
|
}
|
||||||
param_.learning_rate = lr;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool UpdatePredictionCache(const DMatrix *data, linalg::VectorView<float> out_preds) override {
|
bool UpdatePredictionCache(const DMatrix *data, linalg::VectorView<float> out_preds) override {
|
||||||
@ -320,7 +308,7 @@ class GlobalApproxUpdater : public TreeUpdater {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool HasNodePosition() const override { return true; }
|
[[nodiscard]] bool HasNodePosition() const override { return true; }
|
||||||
};
|
};
|
||||||
|
|
||||||
DMLC_REGISTRY_FILE_TAG(grow_histmaker);
|
DMLC_REGISTRY_FILE_TAG(grow_histmaker);
|
||||||
@ -330,5 +318,4 @@ XGBOOST_REGISTER_TREE_UPDATER(GlobalHistMaker, "grow_histmaker")
|
|||||||
"Tree constructor that uses approximate histogram construction "
|
"Tree constructor that uses approximate histogram construction "
|
||||||
"for each node.")
|
"for each node.")
|
||||||
.set_body([](Context const *ctx, ObjInfo task) { return new GlobalApproxUpdater(ctx, task); });
|
.set_body([](Context const *ctx, ObjInfo task) { return new GlobalApproxUpdater(ctx, task); });
|
||||||
} // namespace tree
|
} // namespace xgboost::tree
|
||||||
} // namespace xgboost
|
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2014-2022 by XGBoost Contributors
|
* Copyright 2014-2023 by XGBoost Contributors
|
||||||
* \file updater_colmaker.cc
|
* \file updater_colmaker.cc
|
||||||
* \brief use columnwise update to construct a tree
|
* \brief use columnwise update to construct a tree
|
||||||
* \author Tianqi Chen
|
* \author Tianqi Chen
|
||||||
@ -17,8 +17,7 @@
|
|||||||
#include "../common/random.h"
|
#include "../common/random.h"
|
||||||
#include "split_evaluator.h"
|
#include "split_evaluator.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost::tree {
|
||||||
namespace tree {
|
|
||||||
|
|
||||||
DMLC_REGISTRY_FILE_TAG(updater_colmaker);
|
DMLC_REGISTRY_FILE_TAG(updater_colmaker);
|
||||||
|
|
||||||
@ -57,18 +56,15 @@ class ColMaker: public TreeUpdater {
|
|||||||
public:
|
public:
|
||||||
explicit ColMaker(Context const *ctx) : TreeUpdater(ctx) {}
|
explicit ColMaker(Context const *ctx) : TreeUpdater(ctx) {}
|
||||||
void Configure(const Args &args) override {
|
void Configure(const Args &args) override {
|
||||||
param_.UpdateAllowUnknown(args);
|
|
||||||
colmaker_param_.UpdateAllowUnknown(args);
|
colmaker_param_.UpdateAllowUnknown(args);
|
||||||
}
|
}
|
||||||
|
|
||||||
void LoadConfig(Json const& in) override {
|
void LoadConfig(Json const& in) override {
|
||||||
auto const& config = get<Object const>(in);
|
auto const& config = get<Object const>(in);
|
||||||
FromJson(config.at("train_param"), &this->param_);
|
|
||||||
FromJson(config.at("colmaker_train_param"), &this->colmaker_param_);
|
FromJson(config.at("colmaker_train_param"), &this->colmaker_param_);
|
||||||
}
|
}
|
||||||
void SaveConfig(Json* p_out) const override {
|
void SaveConfig(Json *p_out) const override {
|
||||||
auto& out = *p_out;
|
auto &out = *p_out;
|
||||||
out["train_param"] = ToJson(param_);
|
|
||||||
out["colmaker_train_param"] = ToJson(colmaker_param_);
|
out["colmaker_train_param"] = ToJson(colmaker_param_);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -95,7 +91,7 @@ class ColMaker: public TreeUpdater {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Update(HostDeviceVector<GradientPair> *gpair, DMatrix *dmat,
|
void Update(TrainParam const *param, HostDeviceVector<GradientPair> *gpair, DMatrix *dmat,
|
||||||
common::Span<HostDeviceVector<bst_node_t>> /*out_position*/,
|
common::Span<HostDeviceVector<bst_node_t>> /*out_position*/,
|
||||||
const std::vector<RegTree *> &trees) override {
|
const std::vector<RegTree *> &trees) override {
|
||||||
if (collective::IsDistributed()) {
|
if (collective::IsDistributed()) {
|
||||||
@ -108,22 +104,16 @@ class ColMaker: public TreeUpdater {
|
|||||||
}
|
}
|
||||||
this->LazyGetColumnDensity(dmat);
|
this->LazyGetColumnDensity(dmat);
|
||||||
// rescale learning rate according to size of trees
|
// rescale learning rate according to size of trees
|
||||||
float lr = param_.learning_rate;
|
interaction_constraints_.Configure(*param, dmat->Info().num_row_);
|
||||||
param_.learning_rate = lr / trees.size();
|
|
||||||
interaction_constraints_.Configure(param_, dmat->Info().num_row_);
|
|
||||||
// build tree
|
// build tree
|
||||||
for (auto tree : trees) {
|
for (auto tree : trees) {
|
||||||
CHECK(ctx_);
|
CHECK(ctx_);
|
||||||
Builder builder(param_, colmaker_param_, interaction_constraints_, ctx_,
|
Builder builder(*param, colmaker_param_, interaction_constraints_, ctx_, column_densities_);
|
||||||
column_densities_);
|
|
||||||
builder.Update(gpair->ConstHostVector(), dmat, tree);
|
builder.Update(gpair->ConstHostVector(), dmat, tree);
|
||||||
}
|
}
|
||||||
param_.learning_rate = lr;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
// training parameter
|
|
||||||
TrainParam param_;
|
|
||||||
ColMakerTrainParam colmaker_param_;
|
ColMakerTrainParam colmaker_param_;
|
||||||
// SplitEvaluator that will be cloned for each Builder
|
// SplitEvaluator that will be cloned for each Builder
|
||||||
std::vector<float> column_densities_;
|
std::vector<float> column_densities_;
|
||||||
@ -614,5 +604,4 @@ class ColMaker: public TreeUpdater {
|
|||||||
XGBOOST_REGISTER_TREE_UPDATER(ColMaker, "grow_colmaker")
|
XGBOOST_REGISTER_TREE_UPDATER(ColMaker, "grow_colmaker")
|
||||||
.describe("Grow tree with parallelization over columns.")
|
.describe("Grow tree with parallelization over columns.")
|
||||||
.set_body([](Context const *ctx, ObjInfo) { return new ColMaker(ctx); });
|
.set_body([](Context const *ctx, ObjInfo) { return new ColMaker(ctx); });
|
||||||
} // namespace tree
|
} // namespace xgboost::tree
|
||||||
} // namespace xgboost
|
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2017-2022 XGBoost contributors
|
* Copyright 2017-2023 by XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#include <thrust/copy.h>
|
#include <thrust/copy.h>
|
||||||
#include <thrust/reduce.h>
|
#include <thrust/reduce.h>
|
||||||
@ -756,7 +756,6 @@ class GPUHistMaker : public TreeUpdater {
|
|||||||
void Configure(const Args& args) override {
|
void Configure(const Args& args) override {
|
||||||
// Used in test to count how many configurations are performed
|
// Used in test to count how many configurations are performed
|
||||||
LOG(DEBUG) << "[GPU Hist]: Configure";
|
LOG(DEBUG) << "[GPU Hist]: Configure";
|
||||||
param_.UpdateAllowUnknown(args);
|
|
||||||
hist_maker_param_.UpdateAllowUnknown(args);
|
hist_maker_param_.UpdateAllowUnknown(args);
|
||||||
dh::CheckComputeCapability();
|
dh::CheckComputeCapability();
|
||||||
initialised_ = false;
|
initialised_ = false;
|
||||||
@ -768,32 +767,26 @@ class GPUHistMaker : public TreeUpdater {
|
|||||||
auto const& config = get<Object const>(in);
|
auto const& config = get<Object const>(in);
|
||||||
FromJson(config.at("gpu_hist_train_param"), &this->hist_maker_param_);
|
FromJson(config.at("gpu_hist_train_param"), &this->hist_maker_param_);
|
||||||
initialised_ = false;
|
initialised_ = false;
|
||||||
FromJson(config.at("train_param"), ¶m_);
|
|
||||||
}
|
}
|
||||||
void SaveConfig(Json* p_out) const override {
|
void SaveConfig(Json* p_out) const override {
|
||||||
auto& out = *p_out;
|
auto& out = *p_out;
|
||||||
out["gpu_hist_train_param"] = ToJson(hist_maker_param_);
|
out["gpu_hist_train_param"] = ToJson(hist_maker_param_);
|
||||||
out["train_param"] = ToJson(param_);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
~GPUHistMaker() { // NOLINT
|
~GPUHistMaker() { // NOLINT
|
||||||
dh::GlobalMemoryLogger().Log();
|
dh::GlobalMemoryLogger().Log();
|
||||||
}
|
}
|
||||||
|
|
||||||
void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
|
void Update(TrainParam const* param, HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
|
||||||
common::Span<HostDeviceVector<bst_node_t>> out_position,
|
common::Span<HostDeviceVector<bst_node_t>> out_position,
|
||||||
const std::vector<RegTree*>& trees) override {
|
const std::vector<RegTree*>& trees) override {
|
||||||
monitor_.Start("Update");
|
monitor_.Start("Update");
|
||||||
|
|
||||||
// rescale learning rate according to size of trees
|
|
||||||
float lr = param_.learning_rate;
|
|
||||||
param_.learning_rate = lr / trees.size();
|
|
||||||
|
|
||||||
// build tree
|
// build tree
|
||||||
try {
|
try {
|
||||||
size_t t_idx{0};
|
size_t t_idx{0};
|
||||||
for (xgboost::RegTree* tree : trees) {
|
for (xgboost::RegTree* tree : trees) {
|
||||||
this->UpdateTree(gpair, dmat, tree, &out_position[t_idx]);
|
this->UpdateTree(param, gpair, dmat, tree, &out_position[t_idx]);
|
||||||
|
|
||||||
if (hist_maker_param_.debug_synchronize) {
|
if (hist_maker_param_.debug_synchronize) {
|
||||||
this->CheckTreesSynchronized(tree);
|
this->CheckTreesSynchronized(tree);
|
||||||
@ -804,12 +797,10 @@ class GPUHistMaker : public TreeUpdater {
|
|||||||
} catch (const std::exception& e) {
|
} catch (const std::exception& e) {
|
||||||
LOG(FATAL) << "Exception in gpu_hist: " << e.what() << std::endl;
|
LOG(FATAL) << "Exception in gpu_hist: " << e.what() << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
param_.learning_rate = lr;
|
|
||||||
monitor_.Stop("Update");
|
monitor_.Stop("Update");
|
||||||
}
|
}
|
||||||
|
|
||||||
void InitDataOnce(DMatrix* dmat) {
|
void InitDataOnce(TrainParam const* param, DMatrix* dmat) {
|
||||||
CHECK_GE(ctx_->gpu_id, 0) << "Must have at least one device";
|
CHECK_GE(ctx_->gpu_id, 0) << "Must have at least one device";
|
||||||
info_ = &dmat->Info();
|
info_ = &dmat->Info();
|
||||||
|
|
||||||
@ -818,24 +809,24 @@ class GPUHistMaker : public TreeUpdater {
|
|||||||
collective::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0);
|
collective::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0);
|
||||||
|
|
||||||
BatchParam batch_param{
|
BatchParam batch_param{
|
||||||
ctx_->gpu_id,
|
ctx_->gpu_id,
|
||||||
param_.max_bin,
|
param->max_bin,
|
||||||
};
|
};
|
||||||
auto page = (*dmat->GetBatches<EllpackPage>(batch_param).begin()).Impl();
|
auto page = (*dmat->GetBatches<EllpackPage>(batch_param).begin()).Impl();
|
||||||
dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
|
dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
|
||||||
info_->feature_types.SetDevice(ctx_->gpu_id);
|
info_->feature_types.SetDevice(ctx_->gpu_id);
|
||||||
maker.reset(new GPUHistMakerDevice<GradientSumT>(
|
maker.reset(new GPUHistMakerDevice<GradientSumT>(
|
||||||
ctx_, page, info_->feature_types.ConstDeviceSpan(), info_->num_row_, param_,
|
ctx_, page, info_->feature_types.ConstDeviceSpan(), info_->num_row_, *param,
|
||||||
column_sampling_seed, info_->num_col_, batch_param));
|
column_sampling_seed, info_->num_col_, batch_param));
|
||||||
|
|
||||||
p_last_fmat_ = dmat;
|
p_last_fmat_ = dmat;
|
||||||
initialised_ = true;
|
initialised_ = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void InitData(DMatrix* dmat, RegTree const* p_tree) {
|
void InitData(TrainParam const* param, DMatrix* dmat, RegTree const* p_tree) {
|
||||||
if (!initialised_) {
|
if (!initialised_) {
|
||||||
monitor_.Start("InitDataOnce");
|
monitor_.Start("InitDataOnce");
|
||||||
this->InitDataOnce(dmat);
|
this->InitDataOnce(param, dmat);
|
||||||
monitor_.Stop("InitDataOnce");
|
monitor_.Stop("InitDataOnce");
|
||||||
}
|
}
|
||||||
p_last_tree_ = p_tree;
|
p_last_tree_ = p_tree;
|
||||||
@ -856,10 +847,10 @@ class GPUHistMaker : public TreeUpdater {
|
|||||||
CHECK(*local_tree == reference_tree);
|
CHECK(*local_tree == reference_tree);
|
||||||
}
|
}
|
||||||
|
|
||||||
void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat, RegTree* p_tree,
|
void UpdateTree(TrainParam const* param, HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat,
|
||||||
HostDeviceVector<bst_node_t>* p_out_position) {
|
RegTree* p_tree, HostDeviceVector<bst_node_t>* p_out_position) {
|
||||||
monitor_.Start("InitData");
|
monitor_.Start("InitData");
|
||||||
this->InitData(p_fmat, p_tree);
|
this->InitData(param, p_fmat, p_tree);
|
||||||
monitor_.Stop("InitData");
|
monitor_.Stop("InitData");
|
||||||
|
|
||||||
gpair->SetDevice(ctx_->gpu_id);
|
gpair->SetDevice(ctx_->gpu_id);
|
||||||
@ -878,7 +869,6 @@ class GPUHistMaker : public TreeUpdater {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
TrainParam param_; // NOLINT
|
|
||||||
MetaInfo* info_{}; // NOLINT
|
MetaInfo* info_{}; // NOLINT
|
||||||
|
|
||||||
std::unique_ptr<GPUHistMakerDevice<GradientSumT>> maker; // NOLINT
|
std::unique_ptr<GPUHistMakerDevice<GradientSumT>> maker; // NOLINT
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2014-2022 by XGBoost Contributors
|
* Copyright 2014-2023 by XGBoost Contributors
|
||||||
* \file updater_prune.cc
|
* \file updater_prune.cc
|
||||||
* \brief prune a tree given the statistics
|
* \brief prune a tree given the statistics
|
||||||
* \author Tianqi Chen
|
* \author Tianqi Chen
|
||||||
@ -8,13 +8,11 @@
|
|||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
|
#include "../common/timer.h"
|
||||||
|
#include "./param.h"
|
||||||
#include "xgboost/base.h"
|
#include "xgboost/base.h"
|
||||||
#include "xgboost/json.h"
|
#include "xgboost/json.h"
|
||||||
#include "./param.h"
|
namespace xgboost::tree {
|
||||||
#include "../common/timer.h"
|
|
||||||
namespace xgboost {
|
|
||||||
namespace tree {
|
|
||||||
|
|
||||||
DMLC_REGISTRY_FILE_TAG(updater_prune);
|
DMLC_REGISTRY_FILE_TAG(updater_prune);
|
||||||
|
|
||||||
/*! \brief pruner that prunes a tree after growing finishes */
|
/*! \brief pruner that prunes a tree after growing finishes */
|
||||||
@ -24,47 +22,31 @@ class TreePruner : public TreeUpdater {
|
|||||||
syncher_.reset(TreeUpdater::Create("sync", ctx_, task));
|
syncher_.reset(TreeUpdater::Create("sync", ctx_, task));
|
||||||
pruner_monitor_.Init("TreePruner");
|
pruner_monitor_.Init("TreePruner");
|
||||||
}
|
}
|
||||||
char const* Name() const override {
|
[[nodiscard]] char const* Name() const override { return "prune"; }
|
||||||
return "prune";
|
|
||||||
}
|
|
||||||
|
|
||||||
// set training parameter
|
// set training parameter
|
||||||
void Configure(const Args& args) override {
|
void Configure(const Args& args) override { syncher_->Configure(args); }
|
||||||
param_.UpdateAllowUnknown(args);
|
|
||||||
syncher_->Configure(args);
|
|
||||||
}
|
|
||||||
|
|
||||||
void LoadConfig(Json const& in) override {
|
void LoadConfig(Json const&) override {}
|
||||||
auto const& config = get<Object const>(in);
|
void SaveConfig(Json*) const override {}
|
||||||
FromJson(config.at("train_param"), &this->param_);
|
[[nodiscard]] bool CanModifyTree() const override { return true; }
|
||||||
}
|
|
||||||
void SaveConfig(Json* p_out) const override {
|
|
||||||
auto& out = *p_out;
|
|
||||||
out["train_param"] = ToJson(param_);
|
|
||||||
}
|
|
||||||
bool CanModifyTree() const override {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// update the tree, do pruning
|
// update the tree, do pruning
|
||||||
void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat,
|
void Update(TrainParam const* param, HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat,
|
||||||
common::Span<HostDeviceVector<bst_node_t>> out_position,
|
common::Span<HostDeviceVector<bst_node_t>> out_position,
|
||||||
const std::vector<RegTree*>& trees) override {
|
const std::vector<RegTree*>& trees) override {
|
||||||
pruner_monitor_.Start("PrunerUpdate");
|
pruner_monitor_.Start("PrunerUpdate");
|
||||||
// rescale learning rate according to size of trees
|
|
||||||
float lr = param_.learning_rate;
|
|
||||||
param_.learning_rate = lr / trees.size();
|
|
||||||
for (auto tree : trees) {
|
for (auto tree : trees) {
|
||||||
this->DoPrune(tree);
|
this->DoPrune(param, tree);
|
||||||
}
|
}
|
||||||
param_.learning_rate = lr;
|
syncher_->Update(param, gpair, p_fmat, out_position, trees);
|
||||||
syncher_->Update(gpair, p_fmat, out_position, trees);
|
|
||||||
pruner_monitor_.Stop("PrunerUpdate");
|
pruner_monitor_.Stop("PrunerUpdate");
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// try to prune off current leaf
|
// try to prune off current leaf
|
||||||
bst_node_t TryPruneLeaf(RegTree &tree, int nid, int depth, int npruned) { // NOLINT(*)
|
bst_node_t TryPruneLeaf(TrainParam const* param, RegTree* p_tree, int nid, int depth,
|
||||||
|
int npruned) {
|
||||||
|
auto& tree = *p_tree;
|
||||||
CHECK(tree[nid].IsLeaf());
|
CHECK(tree[nid].IsLeaf());
|
||||||
if (tree[nid].IsRoot()) {
|
if (tree[nid].IsRoot()) {
|
||||||
return npruned;
|
return npruned;
|
||||||
@ -77,22 +59,22 @@ class TreePruner : public TreeUpdater {
|
|||||||
auto right = tree[pid].RightChild();
|
auto right = tree[pid].RightChild();
|
||||||
bool balanced = tree[left].IsLeaf() &&
|
bool balanced = tree[left].IsLeaf() &&
|
||||||
right != RegTree::kInvalidNodeId && tree[right].IsLeaf();
|
right != RegTree::kInvalidNodeId && tree[right].IsLeaf();
|
||||||
if (balanced && param_.NeedPrune(s.loss_chg, depth)) {
|
if (balanced && param->NeedPrune(s.loss_chg, depth)) {
|
||||||
// need to be pruned
|
// need to be pruned
|
||||||
tree.ChangeToLeaf(pid, param_.learning_rate * s.base_weight);
|
tree.ChangeToLeaf(pid, param->learning_rate * s.base_weight);
|
||||||
// tail recursion
|
// tail recursion
|
||||||
return this->TryPruneLeaf(tree, pid, depth - 1, npruned + 2);
|
return this->TryPruneLeaf(param, p_tree, pid, depth - 1, npruned + 2);
|
||||||
} else {
|
} else {
|
||||||
return npruned;
|
return npruned;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
/*! \brief do pruning of a tree */
|
/*! \brief do pruning of a tree */
|
||||||
void DoPrune(RegTree* p_tree) {
|
void DoPrune(TrainParam const* param, RegTree* p_tree) {
|
||||||
auto& tree = *p_tree;
|
auto& tree = *p_tree;
|
||||||
bst_node_t npruned = 0;
|
bst_node_t npruned = 0;
|
||||||
for (int nid = 0; nid < tree.param.num_nodes; ++nid) {
|
for (int nid = 0; nid < tree.param.num_nodes; ++nid) {
|
||||||
if (tree[nid].IsLeaf() && !tree[nid].IsDeleted()) {
|
if (tree[nid].IsLeaf() && !tree[nid].IsDeleted()) {
|
||||||
npruned = this->TryPruneLeaf(tree, nid, tree.GetDepth(nid), npruned);
|
npruned = this->TryPruneLeaf(param, p_tree, nid, tree.GetDepth(nid), npruned);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
LOG(INFO) << "tree pruning end, "
|
LOG(INFO) << "tree pruning end, "
|
||||||
@ -103,13 +85,10 @@ class TreePruner : public TreeUpdater {
|
|||||||
private:
|
private:
|
||||||
// synchronizer
|
// synchronizer
|
||||||
std::unique_ptr<TreeUpdater> syncher_;
|
std::unique_ptr<TreeUpdater> syncher_;
|
||||||
// training parameter
|
|
||||||
TrainParam param_;
|
|
||||||
common::Monitor pruner_monitor_;
|
common::Monitor pruner_monitor_;
|
||||||
};
|
};
|
||||||
|
|
||||||
XGBOOST_REGISTER_TREE_UPDATER(TreePruner, "prune")
|
XGBOOST_REGISTER_TREE_UPDATER(TreePruner, "prune")
|
||||||
.describe("Pruner that prune the tree according to statistics.")
|
.describe("Pruner that prune the tree according to statistics.")
|
||||||
.set_body([](Context const* ctx, ObjInfo task) { return new TreePruner(ctx, task); });
|
.set_body([](Context const* ctx, ObjInfo task) { return new TreePruner(ctx, task); });
|
||||||
} // namespace tree
|
} // namespace xgboost::tree
|
||||||
} // namespace xgboost
|
|
||||||
|
|||||||
@ -28,21 +28,14 @@ namespace tree {
|
|||||||
|
|
||||||
DMLC_REGISTRY_FILE_TAG(updater_quantile_hist);
|
DMLC_REGISTRY_FILE_TAG(updater_quantile_hist);
|
||||||
|
|
||||||
void QuantileHistMaker::Configure(const Args &args) {
|
void QuantileHistMaker::Update(TrainParam const *param, HostDeviceVector<GradientPair> *gpair,
|
||||||
param_.UpdateAllowUnknown(args);
|
DMatrix *dmat,
|
||||||
}
|
|
||||||
|
|
||||||
void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair, DMatrix *dmat,
|
|
||||||
common::Span<HostDeviceVector<bst_node_t>> out_position,
|
common::Span<HostDeviceVector<bst_node_t>> out_position,
|
||||||
const std::vector<RegTree *> &trees) {
|
const std::vector<RegTree *> &trees) {
|
||||||
// rescale learning rate according to size of trees
|
|
||||||
float lr = param_.learning_rate;
|
|
||||||
param_.learning_rate = lr / trees.size();
|
|
||||||
|
|
||||||
// build tree
|
// build tree
|
||||||
const size_t n_trees = trees.size();
|
const size_t n_trees = trees.size();
|
||||||
if (!pimpl_) {
|
if (!pimpl_) {
|
||||||
pimpl_.reset(new Builder(n_trees, param_, dmat, task_, ctx_));
|
pimpl_.reset(new Builder(n_trees, param, dmat, task_, ctx_));
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t t_idx{0};
|
size_t t_idx{0};
|
||||||
@ -51,8 +44,6 @@ void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair, DMatrix *d
|
|||||||
this->pimpl_->UpdateTree(gpair, dmat, p_tree, &t_row_position);
|
this->pimpl_->UpdateTree(gpair, dmat, p_tree, &t_row_position);
|
||||||
++t_idx;
|
++t_idx;
|
||||||
}
|
}
|
||||||
|
|
||||||
param_.learning_rate = lr;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool QuantileHistMaker::UpdatePredictionCache(const DMatrix *data,
|
bool QuantileHistMaker::UpdatePredictionCache(const DMatrix *data,
|
||||||
@ -107,7 +98,7 @@ CPUExpandEntry QuantileHistMaker::Builder::InitRoot(
|
|||||||
auto weight = evaluator_->InitRoot(GradStats{grad_stat});
|
auto weight = evaluator_->InitRoot(GradStats{grad_stat});
|
||||||
p_tree->Stat(RegTree::kRoot).sum_hess = grad_stat.GetHess();
|
p_tree->Stat(RegTree::kRoot).sum_hess = grad_stat.GetHess();
|
||||||
p_tree->Stat(RegTree::kRoot).base_weight = weight;
|
p_tree->Stat(RegTree::kRoot).base_weight = weight;
|
||||||
(*p_tree)[RegTree::kRoot].SetLeaf(param_.learning_rate * weight);
|
(*p_tree)[RegTree::kRoot].SetLeaf(param_->learning_rate * weight);
|
||||||
|
|
||||||
std::vector<CPUExpandEntry> entries{node};
|
std::vector<CPUExpandEntry> entries{node};
|
||||||
monitor_->Start("EvaluateSplits");
|
monitor_->Start("EvaluateSplits");
|
||||||
@ -173,7 +164,7 @@ void QuantileHistMaker::Builder::ExpandTree(DMatrix *p_fmat, RegTree *p_tree,
|
|||||||
HostDeviceVector<bst_node_t> *p_out_position) {
|
HostDeviceVector<bst_node_t> *p_out_position) {
|
||||||
monitor_->Start(__func__);
|
monitor_->Start(__func__);
|
||||||
|
|
||||||
Driver<CPUExpandEntry> driver(param_);
|
Driver<CPUExpandEntry> driver(*param_);
|
||||||
driver.Push(this->InitRoot(p_fmat, p_tree, gpair_h));
|
driver.Push(this->InitRoot(p_fmat, p_tree, gpair_h));
|
||||||
auto const &tree = *p_tree;
|
auto const &tree = *p_tree;
|
||||||
auto expand_set = driver.Pop();
|
auto expand_set = driver.Pop();
|
||||||
@ -285,7 +276,7 @@ void QuantileHistMaker::Builder::InitData(DMatrix *fmat, const RegTree &tree,
|
|||||||
|
|
||||||
auto m_gpair =
|
auto m_gpair =
|
||||||
linalg::MakeTensorView(*gpair, {gpair->size(), static_cast<std::size_t>(1)}, ctx_->gpu_id);
|
linalg::MakeTensorView(*gpair, {gpair->size(), static_cast<std::size_t>(1)}, ctx_->gpu_id);
|
||||||
SampleGradient(ctx_, param_, m_gpair);
|
SampleGradient(ctx_, *param_, m_gpair);
|
||||||
}
|
}
|
||||||
|
|
||||||
// store a pointer to the tree
|
// store a pointer to the tree
|
||||||
|
|||||||
@ -35,49 +35,36 @@
|
|||||||
#include "../common/partition_builder.h"
|
#include "../common/partition_builder.h"
|
||||||
#include "../common/column_matrix.h"
|
#include "../common/column_matrix.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost::tree {
|
||||||
namespace tree {
|
inline BatchParam HistBatch(TrainParam const* param) {
|
||||||
inline BatchParam HistBatch(TrainParam const& param) {
|
return {param->max_bin, param->sparse_threshold};
|
||||||
return {param.max_bin, param.sparse_threshold};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/*! \brief construct a tree using quantized feature values */
|
/*! \brief construct a tree using quantized feature values */
|
||||||
class QuantileHistMaker: public TreeUpdater {
|
class QuantileHistMaker: public TreeUpdater {
|
||||||
public:
|
public:
|
||||||
explicit QuantileHistMaker(Context const* ctx, ObjInfo task) : TreeUpdater(ctx), task_{task} {}
|
explicit QuantileHistMaker(Context const* ctx, ObjInfo task) : TreeUpdater(ctx), task_{task} {}
|
||||||
void Configure(const Args& args) override;
|
void Configure(const Args&) override {}
|
||||||
|
|
||||||
void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
|
void Update(TrainParam const* param, HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
|
||||||
common::Span<HostDeviceVector<bst_node_t>> out_position,
|
common::Span<HostDeviceVector<bst_node_t>> out_position,
|
||||||
const std::vector<RegTree*>& trees) override;
|
const std::vector<RegTree*>& trees) override;
|
||||||
|
|
||||||
bool UpdatePredictionCache(const DMatrix *data,
|
bool UpdatePredictionCache(const DMatrix *data,
|
||||||
linalg::VectorView<float> out_preds) override;
|
linalg::VectorView<float> out_preds) override;
|
||||||
|
|
||||||
void LoadConfig(Json const& in) override {
|
void LoadConfig(Json const&) override {}
|
||||||
auto const& config = get<Object const>(in);
|
void SaveConfig(Json*) const override {}
|
||||||
FromJson(config.at("train_param"), &this->param_);
|
|
||||||
}
|
|
||||||
void SaveConfig(Json* p_out) const override {
|
|
||||||
auto& out = *p_out;
|
|
||||||
out["train_param"] = ToJson(param_);
|
|
||||||
}
|
|
||||||
|
|
||||||
char const* Name() const override {
|
[[nodiscard]] char const* Name() const override { return "grow_quantile_histmaker"; }
|
||||||
return "grow_quantile_histmaker";
|
[[nodiscard]] bool HasNodePosition() const override { return true; }
|
||||||
}
|
|
||||||
|
|
||||||
bool HasNodePosition() const override { return true; }
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
// training parameter
|
|
||||||
TrainParam param_;
|
|
||||||
|
|
||||||
// actual builder that runs the algorithm
|
// actual builder that runs the algorithm
|
||||||
struct Builder {
|
struct Builder {
|
||||||
public:
|
public:
|
||||||
// constructor
|
// constructor
|
||||||
explicit Builder(const size_t n_trees, const TrainParam& param, DMatrix const* fmat,
|
explicit Builder(const size_t n_trees, TrainParam const* param, DMatrix const* fmat,
|
||||||
ObjInfo task, Context const* ctx)
|
ObjInfo task, Context const* ctx)
|
||||||
: n_trees_(n_trees),
|
: n_trees_(n_trees),
|
||||||
param_(param),
|
param_(param),
|
||||||
@ -115,7 +102,7 @@ class QuantileHistMaker: public TreeUpdater {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
const size_t n_trees_;
|
const size_t n_trees_;
|
||||||
const TrainParam& param_;
|
TrainParam const* param_;
|
||||||
std::shared_ptr<common::ColumnSampler> column_sampler_{
|
std::shared_ptr<common::ColumnSampler> column_sampler_{
|
||||||
std::make_shared<common::ColumnSampler>()};
|
std::make_shared<common::ColumnSampler>()};
|
||||||
|
|
||||||
@ -140,7 +127,6 @@ class QuantileHistMaker: public TreeUpdater {
|
|||||||
std::unique_ptr<Builder> pimpl_;
|
std::unique_ptr<Builder> pimpl_;
|
||||||
ObjInfo task_;
|
ObjInfo task_;
|
||||||
};
|
};
|
||||||
} // namespace tree
|
} // namespace xgboost::tree
|
||||||
} // namespace xgboost
|
|
||||||
|
|
||||||
#endif // XGBOOST_TREE_UPDATER_QUANTILE_HIST_H_
|
#endif // XGBOOST_TREE_UPDATER_QUANTILE_HIST_H_
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2014-2022 by XGBoost Contributors
|
* Copyright 2014-2023 by XGBoost Contributors
|
||||||
* \file updater_refresh.cc
|
* \file updater_refresh.cc
|
||||||
* \brief refresh the statistics and leaf value on the tree on the dataset
|
* \brief refresh the statistics and leaf value on the tree on the dataset
|
||||||
* \author Tianqi Chen
|
* \author Tianqi Chen
|
||||||
@ -16,8 +16,7 @@
|
|||||||
#include "./param.h"
|
#include "./param.h"
|
||||||
#include "xgboost/json.h"
|
#include "xgboost/json.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost::tree {
|
||||||
namespace tree {
|
|
||||||
|
|
||||||
DMLC_REGISTRY_FILE_TAG(updater_refresh);
|
DMLC_REGISTRY_FILE_TAG(updater_refresh);
|
||||||
|
|
||||||
@ -25,23 +24,14 @@ DMLC_REGISTRY_FILE_TAG(updater_refresh);
|
|||||||
class TreeRefresher : public TreeUpdater {
|
class TreeRefresher : public TreeUpdater {
|
||||||
public:
|
public:
|
||||||
explicit TreeRefresher(Context const *ctx) : TreeUpdater(ctx) {}
|
explicit TreeRefresher(Context const *ctx) : TreeUpdater(ctx) {}
|
||||||
void Configure(const Args &args) override { param_.UpdateAllowUnknown(args); }
|
void Configure(const Args &) override {}
|
||||||
void LoadConfig(Json const& in) override {
|
void LoadConfig(Json const &) override {}
|
||||||
auto const& config = get<Object const>(in);
|
void SaveConfig(Json *) const override {}
|
||||||
FromJson(config.at("train_param"), &this->param_);
|
|
||||||
}
|
[[nodiscard]] char const *Name() const override { return "refresh"; }
|
||||||
void SaveConfig(Json* p_out) const override {
|
[[nodiscard]] bool CanModifyTree() const override { return true; }
|
||||||
auto& out = *p_out;
|
|
||||||
out["train_param"] = ToJson(param_);
|
|
||||||
}
|
|
||||||
char const* Name() const override {
|
|
||||||
return "refresh";
|
|
||||||
}
|
|
||||||
bool CanModifyTree() const override {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
// update the tree, do pruning
|
// update the tree, do pruning
|
||||||
void Update(HostDeviceVector<GradientPair> *gpair, DMatrix *p_fmat,
|
void Update(TrainParam const *param, HostDeviceVector<GradientPair> *gpair, DMatrix *p_fmat,
|
||||||
common::Span<HostDeviceVector<bst_node_t>> /*out_position*/,
|
common::Span<HostDeviceVector<bst_node_t>> /*out_position*/,
|
||||||
const std::vector<RegTree *> &trees) override {
|
const std::vector<RegTree *> &trees) override {
|
||||||
if (trees.size() == 0) return;
|
if (trees.size() == 0) return;
|
||||||
@ -103,16 +93,11 @@ class TreeRefresher : public TreeUpdater {
|
|||||||
lazy_get_stats();
|
lazy_get_stats();
|
||||||
collective::Allreduce<collective::Operation::kSum>(&dmlc::BeginPtr(stemp[0])->sum_grad,
|
collective::Allreduce<collective::Operation::kSum>(&dmlc::BeginPtr(stemp[0])->sum_grad,
|
||||||
stemp[0].size() * 2);
|
stemp[0].size() * 2);
|
||||||
// rescale learning rate according to size of trees
|
|
||||||
float lr = param_.learning_rate;
|
|
||||||
param_.learning_rate = lr / trees.size();
|
|
||||||
int offset = 0;
|
int offset = 0;
|
||||||
for (auto tree : trees) {
|
for (auto tree : trees) {
|
||||||
this->Refresh(dmlc::BeginPtr(stemp[0]) + offset, 0, tree);
|
this->Refresh(param, dmlc::BeginPtr(stemp[0]) + offset, 0, tree);
|
||||||
offset += tree->param.num_nodes;
|
offset += tree->param.num_nodes;
|
||||||
}
|
}
|
||||||
// set learning rate back
|
|
||||||
param_.learning_rate = lr;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -135,31 +120,27 @@ class TreeRefresher : public TreeUpdater {
|
|||||||
gstats[pid].Add(gpair[ridx]);
|
gstats[pid].Add(gpair[ridx]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
inline void Refresh(const GradStats *gstats,
|
inline void Refresh(TrainParam const *param, const GradStats *gstats, int nid, RegTree *p_tree) {
|
||||||
int nid, RegTree *p_tree) {
|
|
||||||
RegTree &tree = *p_tree;
|
RegTree &tree = *p_tree;
|
||||||
tree.Stat(nid).base_weight =
|
tree.Stat(nid).base_weight =
|
||||||
static_cast<bst_float>(CalcWeight(param_, gstats[nid]));
|
static_cast<bst_float>(CalcWeight(*param, gstats[nid]));
|
||||||
tree.Stat(nid).sum_hess = static_cast<bst_float>(gstats[nid].sum_hess);
|
tree.Stat(nid).sum_hess = static_cast<bst_float>(gstats[nid].sum_hess);
|
||||||
if (tree[nid].IsLeaf()) {
|
if (tree[nid].IsLeaf()) {
|
||||||
if (param_.refresh_leaf) {
|
if (param->refresh_leaf) {
|
||||||
tree[nid].SetLeaf(tree.Stat(nid).base_weight * param_.learning_rate);
|
tree[nid].SetLeaf(tree.Stat(nid).base_weight * param->learning_rate);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
tree.Stat(nid).loss_chg = static_cast<bst_float>(
|
tree.Stat(nid).loss_chg =
|
||||||
xgboost::tree::CalcGain(param_, gstats[tree[nid].LeftChild()]) +
|
static_cast<bst_float>(xgboost::tree::CalcGain(*param, gstats[tree[nid].LeftChild()]) +
|
||||||
xgboost::tree::CalcGain(param_, gstats[tree[nid].RightChild()]) -
|
xgboost::tree::CalcGain(*param, gstats[tree[nid].RightChild()]) -
|
||||||
xgboost::tree::CalcGain(param_, gstats[nid]));
|
xgboost::tree::CalcGain(*param, gstats[nid]));
|
||||||
this->Refresh(gstats, tree[nid].LeftChild(), p_tree);
|
this->Refresh(param, gstats, tree[nid].LeftChild(), p_tree);
|
||||||
this->Refresh(gstats, tree[nid].RightChild(), p_tree);
|
this->Refresh(param, gstats, tree[nid].RightChild(), p_tree);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// training parameter
|
|
||||||
TrainParam param_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
XGBOOST_REGISTER_TREE_UPDATER(TreeRefresher, "refresh")
|
XGBOOST_REGISTER_TREE_UPDATER(TreeRefresher, "refresh")
|
||||||
.describe("Refresher that refreshes the weight and statistics according to data.")
|
.describe("Refresher that refreshes the weight and statistics according to data.")
|
||||||
.set_body([](Context const *ctx, ObjInfo) { return new TreeRefresher(ctx); });
|
.set_body([](Context const *ctx, ObjInfo) { return new TreeRefresher(ctx); });
|
||||||
} // namespace tree
|
} // namespace xgboost::tree
|
||||||
} // namespace xgboost
|
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2014-2019 by Contributors
|
* Copyright 2014-2013 by XBGoost Contributors
|
||||||
* \file updater_sync.cc
|
* \file updater_sync.cc
|
||||||
* \brief synchronize the tree in all distributed nodes
|
* \brief synchronize the tree in all distributed nodes
|
||||||
*/
|
*/
|
||||||
@ -13,8 +13,7 @@
|
|||||||
#include "../common/io.h"
|
#include "../common/io.h"
|
||||||
#include "xgboost/json.h"
|
#include "xgboost/json.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost::tree {
|
||||||
namespace tree {
|
|
||||||
|
|
||||||
DMLC_REGISTRY_FILE_TAG(updater_sync);
|
DMLC_REGISTRY_FILE_TAG(updater_sync);
|
||||||
|
|
||||||
@ -30,11 +29,9 @@ class TreeSyncher : public TreeUpdater {
|
|||||||
void LoadConfig(Json const&) override {}
|
void LoadConfig(Json const&) override {}
|
||||||
void SaveConfig(Json*) const override {}
|
void SaveConfig(Json*) const override {}
|
||||||
|
|
||||||
char const* Name() const override {
|
[[nodiscard]] char const* Name() const override { return "prune"; }
|
||||||
return "prune";
|
|
||||||
}
|
|
||||||
|
|
||||||
void Update(HostDeviceVector<GradientPair>*, DMatrix*,
|
void Update(TrainParam const*, HostDeviceVector<GradientPair>*, DMatrix*,
|
||||||
common::Span<HostDeviceVector<bst_node_t>> /*out_position*/,
|
common::Span<HostDeviceVector<bst_node_t>> /*out_position*/,
|
||||||
const std::vector<RegTree*>& trees) override {
|
const std::vector<RegTree*>& trees) override {
|
||||||
if (collective::GetWorldSize() == 1) return;
|
if (collective::GetWorldSize() == 1) return;
|
||||||
@ -57,5 +54,4 @@ class TreeSyncher : public TreeUpdater {
|
|||||||
XGBOOST_REGISTER_TREE_UPDATER(TreeSyncher, "sync")
|
XGBOOST_REGISTER_TREE_UPDATER(TreeSyncher, "sync")
|
||||||
.describe("Syncher that synchronize the tree in all distributed nodes.")
|
.describe("Syncher that synchronize the tree in all distributed nodes.")
|
||||||
.set_body([](Context const* ctx, ObjInfo) { return new TreeSyncher(ctx); });
|
.set_body([](Context const* ctx, ObjInfo) { return new TreeSyncher(ctx); });
|
||||||
} // namespace tree
|
} // namespace xgboost::tree
|
||||||
} // namespace xgboost
|
|
||||||
|
|||||||
@ -6,8 +6,9 @@
|
|||||||
#include <xgboost/json.h>
|
#include <xgboost/json.h>
|
||||||
#include <xgboost/objective.h>
|
#include <xgboost/objective.h>
|
||||||
|
|
||||||
#include "../../../src/common/linalg_op.h" // begin,end
|
#include "../../../src/common/linalg_op.h" // for begin, end
|
||||||
#include "../../../src/objective/adaptive.h"
|
#include "../../../src/objective/adaptive.h"
|
||||||
|
#include "../../../src/tree/param.h" // for TrainParam
|
||||||
#include "../helpers.h"
|
#include "../helpers.h"
|
||||||
#include "xgboost/base.h"
|
#include "xgboost/base.h"
|
||||||
#include "xgboost/data.h"
|
#include "xgboost/data.h"
|
||||||
@ -408,9 +409,13 @@ TEST(Objective, DeclareUnifiedTest(AbsoluteError)) {
|
|||||||
h_predt[i] = labels[i] + i;
|
h_predt[i] = labels[i] + i;
|
||||||
}
|
}
|
||||||
|
|
||||||
obj->UpdateTreeLeaf(position, info, predt, 0, &tree);
|
tree::TrainParam param;
|
||||||
ASSERT_EQ(tree[1].LeafValue(), -1);
|
param.Init(Args{});
|
||||||
ASSERT_EQ(tree[2].LeafValue(), -4);
|
auto lr = param.learning_rate;
|
||||||
|
|
||||||
|
obj->UpdateTreeLeaf(position, info, param.learning_rate, predt, 0, &tree);
|
||||||
|
ASSERT_EQ(tree[1].LeafValue(), -1.0f * lr);
|
||||||
|
ASSERT_EQ(tree[2].LeafValue(), -4.0f * lr);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(Objective, DeclareUnifiedTest(AbsoluteErrorLeaf)) {
|
TEST(Objective, DeclareUnifiedTest(AbsoluteErrorLeaf)) {
|
||||||
@ -457,11 +462,16 @@ TEST(Objective, DeclareUnifiedTest(AbsoluteErrorLeaf)) {
|
|||||||
ASSERT_EQ(tree.GetNumLeaves(), 4);
|
ASSERT_EQ(tree.GetNumLeaves(), 4);
|
||||||
|
|
||||||
auto empty_leaf = tree[4].LeafValue();
|
auto empty_leaf = tree[4].LeafValue();
|
||||||
obj->UpdateTreeLeaf(position, info, predt, t, &tree);
|
|
||||||
ASSERT_EQ(tree[3].LeafValue(), -5);
|
tree::TrainParam param;
|
||||||
ASSERT_EQ(tree[4].LeafValue(), empty_leaf);
|
param.Init(Args{});
|
||||||
ASSERT_EQ(tree[5].LeafValue(), -10);
|
auto lr = param.learning_rate;
|
||||||
ASSERT_EQ(tree[6].LeafValue(), -14);
|
|
||||||
|
obj->UpdateTreeLeaf(position, info, lr, predt, t, &tree);
|
||||||
|
ASSERT_EQ(tree[3].LeafValue(), -5.0f * lr);
|
||||||
|
ASSERT_EQ(tree[4].LeafValue(), empty_leaf * lr);
|
||||||
|
ASSERT_EQ(tree[5].LeafValue(), -10.0f * lr);
|
||||||
|
ASSERT_EQ(tree[6].LeafValue(), -14.0f * lr);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -24,7 +24,7 @@ void TestEvaluateSplits(bool force_read_by_column) {
|
|||||||
|
|
||||||
auto dmat = RandomDataGenerator(kRows, kCols, 0).Seed(3).GenerateDMatrix();
|
auto dmat = RandomDataGenerator(kRows, kCols, 0).Seed(3).GenerateDMatrix();
|
||||||
|
|
||||||
auto evaluator = HistEvaluator<CPUExpandEntry>{&ctx, param, dmat->Info(), sampler};
|
auto evaluator = HistEvaluator<CPUExpandEntry>{&ctx, ¶m, dmat->Info(), sampler};
|
||||||
common::HistCollection hist;
|
common::HistCollection hist;
|
||||||
std::vector<GradientPair> row_gpairs = {
|
std::vector<GradientPair> row_gpairs = {
|
||||||
{1.23f, 0.24f}, {0.24f, 0.25f}, {0.26f, 0.27f}, {2.27f, 0.28f},
|
{1.23f, 0.24f}, {0.24f, 0.25f}, {0.26f, 0.27f}, {2.27f, 0.28f},
|
||||||
@ -96,7 +96,7 @@ TEST(HistEvaluator, Apply) {
|
|||||||
param.UpdateAllowUnknown(Args{{"min_child_weight", "0"}, {"reg_lambda", "0.0"}});
|
param.UpdateAllowUnknown(Args{{"min_child_weight", "0"}, {"reg_lambda", "0.0"}});
|
||||||
auto dmat = RandomDataGenerator(kNRows, kNCols, 0).Seed(3).GenerateDMatrix();
|
auto dmat = RandomDataGenerator(kNRows, kNCols, 0).Seed(3).GenerateDMatrix();
|
||||||
auto sampler = std::make_shared<common::ColumnSampler>();
|
auto sampler = std::make_shared<common::ColumnSampler>();
|
||||||
auto evaluator_ = HistEvaluator<CPUExpandEntry>{&ctx, param, dmat->Info(), sampler};
|
auto evaluator_ = HistEvaluator<CPUExpandEntry>{&ctx, ¶m, dmat->Info(), sampler};
|
||||||
|
|
||||||
CPUExpandEntry entry{0, 0, 10.0f};
|
CPUExpandEntry entry{0, 0, 10.0f};
|
||||||
entry.split.left_sum = GradStats{0.4, 0.6f};
|
entry.split.left_sum = GradStats{0.4, 0.6f};
|
||||||
@ -123,7 +123,7 @@ TEST_F(TestPartitionBasedSplit, CPUHist) {
|
|||||||
// check the evaluator is returning the optimal split
|
// check the evaluator is returning the optimal split
|
||||||
std::vector<FeatureType> ft{FeatureType::kCategorical};
|
std::vector<FeatureType> ft{FeatureType::kCategorical};
|
||||||
auto sampler = std::make_shared<common::ColumnSampler>();
|
auto sampler = std::make_shared<common::ColumnSampler>();
|
||||||
HistEvaluator<CPUExpandEntry> evaluator{&ctx, param_, info_, sampler};
|
HistEvaluator<CPUExpandEntry> evaluator{&ctx, ¶m_, info_, sampler};
|
||||||
evaluator.InitRoot(GradStats{total_gpair_});
|
evaluator.InitRoot(GradStats{total_gpair_});
|
||||||
RegTree tree;
|
RegTree tree;
|
||||||
std::vector<CPUExpandEntry> entries(1);
|
std::vector<CPUExpandEntry> entries(1);
|
||||||
@ -153,7 +153,7 @@ auto CompareOneHotAndPartition(bool onehot) {
|
|||||||
RandomDataGenerator(kRows, kCols, 0).Seed(3).Type(ft).MaxCategory(n_cats).GenerateDMatrix();
|
RandomDataGenerator(kRows, kCols, 0).Seed(3).Type(ft).MaxCategory(n_cats).GenerateDMatrix();
|
||||||
|
|
||||||
auto sampler = std::make_shared<common::ColumnSampler>();
|
auto sampler = std::make_shared<common::ColumnSampler>();
|
||||||
auto evaluator = HistEvaluator<CPUExpandEntry>{&ctx, param, dmat->Info(), sampler};
|
auto evaluator = HistEvaluator<CPUExpandEntry>{&ctx, ¶m, dmat->Info(), sampler};
|
||||||
std::vector<CPUExpandEntry> entries(1);
|
std::vector<CPUExpandEntry> entries(1);
|
||||||
|
|
||||||
for (auto const &gmat : dmat->GetBatches<GHistIndexMatrix>({32, param.sparse_threshold})) {
|
for (auto const &gmat : dmat->GetBatches<GHistIndexMatrix>({32, param.sparse_threshold})) {
|
||||||
@ -204,7 +204,7 @@ TEST_F(TestCategoricalSplitWithMissing, HistEvaluator) {
|
|||||||
info.num_col_ = 1;
|
info.num_col_ = 1;
|
||||||
info.feature_types = {FeatureType::kCategorical};
|
info.feature_types = {FeatureType::kCategorical};
|
||||||
Context ctx;
|
Context ctx;
|
||||||
auto evaluator = HistEvaluator<CPUExpandEntry>{&ctx, param_, info, sampler};
|
auto evaluator = HistEvaluator<CPUExpandEntry>{&ctx, ¶m_, info, sampler};
|
||||||
evaluator.InitRoot(GradStats{parent_sum_});
|
evaluator.InitRoot(GradStats{parent_sum_});
|
||||||
|
|
||||||
std::vector<CPUExpandEntry> entries(1);
|
std::vector<CPUExpandEntry> entries(1);
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2022 by XGBoost Contributors
|
* Copyright 2022-2023 by XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include <xgboost/data.h>
|
#include <xgboost/data.h>
|
||||||
@ -12,8 +12,7 @@
|
|||||||
#include "../../../src/tree/split_evaluator.h"
|
#include "../../../src/tree/split_evaluator.h"
|
||||||
#include "../helpers.h"
|
#include "../helpers.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost::tree {
|
||||||
namespace tree {
|
|
||||||
/**
|
/**
|
||||||
* \brief Enumerate all possible partitions for categorical split.
|
* \brief Enumerate all possible partitions for categorical split.
|
||||||
*/
|
*/
|
||||||
@ -151,5 +150,4 @@ class TestCategoricalSplitWithMissing : public testing::Test {
|
|||||||
ASSERT_EQ(right_sum.GetHess(), parent_sum_.GetHess() - left_sum.GetHess());
|
ASSERT_EQ(right_sum.GetHess(), parent_sum_.GetHess() - left_sum.GetHess());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // namespace tree
|
} // namespace xgboost::tree
|
||||||
} // namespace xgboost
|
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2017-2022 XGBoost contributors
|
* Copyright 2017-2023 by XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include <thrust/device_vector.h>
|
#include <thrust/device_vector.h>
|
||||||
@ -13,6 +13,7 @@
|
|||||||
#include "../../../src/common/common.h"
|
#include "../../../src/common/common.h"
|
||||||
#include "../../../src/data/sparse_page_source.h"
|
#include "../../../src/data/sparse_page_source.h"
|
||||||
#include "../../../src/tree/constraints.cuh"
|
#include "../../../src/tree/constraints.cuh"
|
||||||
|
#include "../../../src/tree/param.h" // for TrainParam
|
||||||
#include "../../../src/tree/updater_gpu_common.cuh"
|
#include "../../../src/tree/updater_gpu_common.cuh"
|
||||||
#include "../../../src/tree/updater_gpu_hist.cu"
|
#include "../../../src/tree/updater_gpu_hist.cu"
|
||||||
#include "../filesystem.h" // dmlc::TemporaryDirectory
|
#include "../filesystem.h" // dmlc::TemporaryDirectory
|
||||||
@ -21,8 +22,7 @@
|
|||||||
#include "xgboost/context.h"
|
#include "xgboost/context.h"
|
||||||
#include "xgboost/json.h"
|
#include "xgboost/json.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost::tree {
|
||||||
namespace tree {
|
|
||||||
TEST(GpuHist, DeviceHistogram) {
|
TEST(GpuHist, DeviceHistogram) {
|
||||||
// Ensures that node allocates correctly after reaching `kStopGrowingSize`.
|
// Ensures that node allocates correctly after reaching `kStopGrowingSize`.
|
||||||
dh::safe_cuda(cudaSetDevice(0));
|
dh::safe_cuda(cudaSetDevice(0));
|
||||||
@ -83,11 +83,12 @@ void TestBuildHist(bool use_shared_memory_histograms) {
|
|||||||
int const kNRows = 16, kNCols = 8;
|
int const kNRows = 16, kNCols = 8;
|
||||||
|
|
||||||
TrainParam param;
|
TrainParam param;
|
||||||
std::vector<std::pair<std::string, std::string>> args {
|
Args args{
|
||||||
{"max_depth", "6"},
|
{"max_depth", "6"},
|
||||||
{"max_leaves", "0"},
|
{"max_leaves", "0"},
|
||||||
};
|
};
|
||||||
param.Init(args);
|
param.Init(args);
|
||||||
|
|
||||||
auto page = BuildEllpackPage(kNRows, kNCols);
|
auto page = BuildEllpackPage(kNRows, kNCols);
|
||||||
BatchParam batch_param{};
|
BatchParam batch_param{};
|
||||||
Context ctx{CreateEmptyGenericParam(0)};
|
Context ctx{CreateEmptyGenericParam(0)};
|
||||||
@ -168,7 +169,6 @@ void TestHistogramIndexImpl() {
|
|||||||
int constexpr kNRows = 1000, kNCols = 10;
|
int constexpr kNRows = 1000, kNCols = 10;
|
||||||
|
|
||||||
// Build 2 matrices and build a histogram maker with that
|
// Build 2 matrices and build a histogram maker with that
|
||||||
|
|
||||||
Context ctx(CreateEmptyGenericParam(0));
|
Context ctx(CreateEmptyGenericParam(0));
|
||||||
tree::GPUHistMaker hist_maker{&ctx, ObjInfo{ObjInfo::kRegression}},
|
tree::GPUHistMaker hist_maker{&ctx, ObjInfo{ObjInfo::kRegression}},
|
||||||
hist_maker_ext{&ctx, ObjInfo{ObjInfo::kRegression}};
|
hist_maker_ext{&ctx, ObjInfo{ObjInfo::kRegression}};
|
||||||
@ -179,15 +179,14 @@ void TestHistogramIndexImpl() {
|
|||||||
std::unique_ptr<DMatrix> hist_maker_ext_dmat(
|
std::unique_ptr<DMatrix> hist_maker_ext_dmat(
|
||||||
CreateSparsePageDMatrixWithRC(kNRows, kNCols, 128UL, true, tempdir));
|
CreateSparsePageDMatrixWithRC(kNRows, kNCols, 128UL, true, tempdir));
|
||||||
|
|
||||||
std::vector<std::pair<std::string, std::string>> training_params = {
|
Args training_params = {{"max_depth", "10"}, {"max_leaves", "0"}};
|
||||||
{"max_depth", "10"},
|
TrainParam param;
|
||||||
{"max_leaves", "0"}
|
param.UpdateAllowUnknown(training_params);
|
||||||
};
|
|
||||||
|
|
||||||
hist_maker.Configure(training_params);
|
hist_maker.Configure(training_params);
|
||||||
hist_maker.InitDataOnce(hist_maker_dmat.get());
|
hist_maker.InitDataOnce(¶m, hist_maker_dmat.get());
|
||||||
hist_maker_ext.Configure(training_params);
|
hist_maker_ext.Configure(training_params);
|
||||||
hist_maker_ext.InitDataOnce(hist_maker_ext_dmat.get());
|
hist_maker_ext.InitDataOnce(¶m, hist_maker_ext_dmat.get());
|
||||||
|
|
||||||
// Extract the device maker from the histogram makers and from that its compressed
|
// Extract the device maker from the histogram makers and from that its compressed
|
||||||
// histogram index
|
// histogram index
|
||||||
@ -237,13 +236,15 @@ void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
|
|||||||
{"subsample", std::to_string(subsample)},
|
{"subsample", std::to_string(subsample)},
|
||||||
{"sampling_method", sampling_method},
|
{"sampling_method", sampling_method},
|
||||||
};
|
};
|
||||||
|
TrainParam param;
|
||||||
|
param.UpdateAllowUnknown(args);
|
||||||
|
|
||||||
Context ctx(CreateEmptyGenericParam(0));
|
Context ctx(CreateEmptyGenericParam(0));
|
||||||
tree::GPUHistMaker hist_maker{&ctx,ObjInfo{ObjInfo::kRegression}};
|
tree::GPUHistMaker hist_maker{&ctx,ObjInfo{ObjInfo::kRegression}};
|
||||||
hist_maker.Configure(args);
|
|
||||||
|
|
||||||
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
||||||
hist_maker.Update(gpair, dmat, common::Span<HostDeviceVector<bst_node_t>>{position}, {tree});
|
hist_maker.Update(¶m, gpair, dmat, common::Span<HostDeviceVector<bst_node_t>>{position},
|
||||||
|
{tree});
|
||||||
auto cache = linalg::VectorView<float>{preds->DeviceSpan(), {preds->Size()}, 0};
|
auto cache = linalg::VectorView<float>{preds->DeviceSpan(), {preds->Size()}, 0};
|
||||||
hist_maker.UpdatePredictionCache(dmat, cache);
|
hist_maker.UpdatePredictionCache(dmat, cache);
|
||||||
}
|
}
|
||||||
@ -391,13 +392,11 @@ TEST(GpuHist, ConfigIO) {
|
|||||||
Json j_updater { Object() };
|
Json j_updater { Object() };
|
||||||
updater->SaveConfig(&j_updater);
|
updater->SaveConfig(&j_updater);
|
||||||
ASSERT_TRUE(IsA<Object>(j_updater["gpu_hist_train_param"]));
|
ASSERT_TRUE(IsA<Object>(j_updater["gpu_hist_train_param"]));
|
||||||
ASSERT_TRUE(IsA<Object>(j_updater["train_param"]));
|
|
||||||
updater->LoadConfig(j_updater);
|
updater->LoadConfig(j_updater);
|
||||||
|
|
||||||
Json j_updater_roundtrip { Object() };
|
Json j_updater_roundtrip { Object() };
|
||||||
updater->SaveConfig(&j_updater_roundtrip);
|
updater->SaveConfig(&j_updater_roundtrip);
|
||||||
ASSERT_TRUE(IsA<Object>(j_updater_roundtrip["gpu_hist_train_param"]));
|
ASSERT_TRUE(IsA<Object>(j_updater_roundtrip["gpu_hist_train_param"]));
|
||||||
ASSERT_TRUE(IsA<Object>(j_updater_roundtrip["train_param"]));
|
|
||||||
|
|
||||||
ASSERT_EQ(j_updater, j_updater_roundtrip);
|
ASSERT_EQ(j_updater, j_updater_roundtrip);
|
||||||
}
|
}
|
||||||
@ -414,5 +413,4 @@ TEST(GpuHist, MaxDepth) {
|
|||||||
|
|
||||||
ASSERT_THROW({learner->UpdateOneIter(0, p_mat);}, dmlc::Error);
|
ASSERT_THROW({learner->UpdateOneIter(0, p_mat);}, dmlc::Error);
|
||||||
}
|
}
|
||||||
} // namespace tree
|
} // namespace xgboost::tree
|
||||||
} // namespace xgboost
|
|
||||||
|
|||||||
@ -5,11 +5,10 @@
|
|||||||
#include <xgboost/tree_model.h>
|
#include <xgboost/tree_model.h>
|
||||||
#include <xgboost/tree_updater.h>
|
#include <xgboost/tree_updater.h>
|
||||||
|
|
||||||
|
#include "../../../src/tree/param.h" // for TrainParam
|
||||||
#include "../helpers.h"
|
#include "../helpers.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost::tree {
|
||||||
namespace tree {
|
|
||||||
|
|
||||||
std::shared_ptr<DMatrix> GenerateDMatrix(std::size_t rows, std::size_t cols){
|
std::shared_ptr<DMatrix> GenerateDMatrix(std::size_t rows, std::size_t cols){
|
||||||
return RandomDataGenerator{rows, cols, 0.6f}.Seed(3).GenerateDMatrix();
|
return RandomDataGenerator{rows, cols, 0.6f}.Seed(3).GenerateDMatrix();
|
||||||
}
|
}
|
||||||
@ -45,11 +44,11 @@ TEST(GrowHistMaker, InteractionConstraint)
|
|||||||
|
|
||||||
std::unique_ptr<TreeUpdater> updater{
|
std::unique_ptr<TreeUpdater> updater{
|
||||||
TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})};
|
TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})};
|
||||||
updater->Configure(Args{
|
TrainParam param;
|
||||||
{"interaction_constraints", "[[0, 1]]"},
|
param.UpdateAllowUnknown(
|
||||||
{"num_feature", std::to_string(kCols)}});
|
Args{{"interaction_constraints", "[[0, 1]]"}, {"num_feature", std::to_string(kCols)}});
|
||||||
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
||||||
updater->Update(p_gradients.get(), p_dmat.get(), position, {&tree});
|
updater->Update(¶m, p_gradients.get(), p_dmat.get(), position, {&tree});
|
||||||
|
|
||||||
ASSERT_EQ(tree.NumExtraNodes(), 4);
|
ASSERT_EQ(tree.NumExtraNodes(), 4);
|
||||||
ASSERT_EQ(tree[0].SplitIndex(), 1);
|
ASSERT_EQ(tree[0].SplitIndex(), 1);
|
||||||
@ -64,9 +63,10 @@ TEST(GrowHistMaker, InteractionConstraint)
|
|||||||
|
|
||||||
std::unique_ptr<TreeUpdater> updater{
|
std::unique_ptr<TreeUpdater> updater{
|
||||||
TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})};
|
TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})};
|
||||||
updater->Configure(Args{{"num_feature", std::to_string(kCols)}});
|
|
||||||
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
||||||
updater->Update(p_gradients.get(), p_dmat.get(), position, {&tree});
|
TrainParam param;
|
||||||
|
param.Init(Args{});
|
||||||
|
updater->Update(¶m, p_gradients.get(), p_dmat.get(), position, {&tree});
|
||||||
|
|
||||||
ASSERT_EQ(tree.NumExtraNodes(), 10);
|
ASSERT_EQ(tree.NumExtraNodes(), 10);
|
||||||
ASSERT_EQ(tree[0].SplitIndex(), 1);
|
ASSERT_EQ(tree[0].SplitIndex(), 1);
|
||||||
@ -83,7 +83,6 @@ void TestColumnSplit(int32_t rows, int32_t cols, RegTree const& expected_tree) {
|
|||||||
Context ctx;
|
Context ctx;
|
||||||
std::unique_ptr<TreeUpdater> updater{
|
std::unique_ptr<TreeUpdater> updater{
|
||||||
TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})};
|
TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})};
|
||||||
updater->Configure(Args{{"num_feature", std::to_string(cols)}});
|
|
||||||
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
||||||
|
|
||||||
std::unique_ptr<DMatrix> sliced{
|
std::unique_ptr<DMatrix> sliced{
|
||||||
@ -91,7 +90,9 @@ void TestColumnSplit(int32_t rows, int32_t cols, RegTree const& expected_tree) {
|
|||||||
|
|
||||||
RegTree tree;
|
RegTree tree;
|
||||||
tree.param.num_feature = cols;
|
tree.param.num_feature = cols;
|
||||||
updater->Update(p_gradients.get(), sliced.get(), position, {&tree});
|
TrainParam param;
|
||||||
|
param.Init(Args{});
|
||||||
|
updater->Update(¶m, p_gradients.get(), sliced.get(), position, {&tree});
|
||||||
|
|
||||||
EXPECT_EQ(tree.NumExtraNodes(), 10);
|
EXPECT_EQ(tree.NumExtraNodes(), 10);
|
||||||
EXPECT_EQ(tree[0].SplitIndex(), 1);
|
EXPECT_EQ(tree[0].SplitIndex(), 1);
|
||||||
@ -115,14 +116,13 @@ TEST(GrowHistMaker, ColumnSplit) {
|
|||||||
Context ctx;
|
Context ctx;
|
||||||
std::unique_ptr<TreeUpdater> updater{
|
std::unique_ptr<TreeUpdater> updater{
|
||||||
TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})};
|
TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})};
|
||||||
updater->Configure(Args{{"num_feature", std::to_string(kCols)}});
|
|
||||||
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
||||||
updater->Update(p_gradients.get(), p_dmat.get(), position, {&expected_tree});
|
TrainParam param;
|
||||||
|
param.Init(Args{});
|
||||||
|
updater->Update(¶m, p_gradients.get(), p_dmat.get(), position, {&expected_tree});
|
||||||
}
|
}
|
||||||
|
|
||||||
auto constexpr kWorldSize = 2;
|
auto constexpr kWorldSize = 2;
|
||||||
RunWithInMemoryCommunicator(kWorldSize, TestColumnSplit, kRows, kCols, std::cref(expected_tree));
|
RunWithInMemoryCommunicator(kWorldSize, TestColumnSplit, kRows, kCols, std::cref(expected_tree));
|
||||||
}
|
}
|
||||||
|
} // namespace xgboost::tree
|
||||||
} // namespace tree
|
|
||||||
} // namespace xgboost
|
|
||||||
|
|||||||
@ -7,6 +7,7 @@
|
|||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
|
#include "../../../src/tree/param.h" // for TrainParam
|
||||||
#include "../helpers.h"
|
#include "../helpers.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
@ -75,9 +76,11 @@ class TestPredictionCache : public ::testing::Test {
|
|||||||
RegTree tree;
|
RegTree tree;
|
||||||
std::vector<RegTree *> trees{&tree};
|
std::vector<RegTree *> trees{&tree};
|
||||||
auto gpair = GenerateRandomGradients(n_samples_);
|
auto gpair = GenerateRandomGradients(n_samples_);
|
||||||
updater->Configure(Args{{"max_bin", "64"}});
|
tree::TrainParam param;
|
||||||
|
param.UpdateAllowUnknown(Args{{"max_bin", "64"}});
|
||||||
|
|
||||||
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
||||||
updater->Update(&gpair, Xy_.get(), position, trees);
|
updater->Update(¶m, &gpair, Xy_.get(), position, trees);
|
||||||
HostDeviceVector<float> out_prediction_cached;
|
HostDeviceVector<float> out_prediction_cached;
|
||||||
out_prediction_cached.SetDevice(ctx.gpu_id);
|
out_prediction_cached.SetDevice(ctx.gpu_id);
|
||||||
out_prediction_cached.Resize(n_samples_);
|
out_prediction_cached.Resize(n_samples_);
|
||||||
|
|||||||
@ -1,20 +1,20 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2018-2019 by Contributors
|
* Copyright 2018-2023 by XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
|
#include <gtest/gtest.h>
|
||||||
#include <xgboost/data.h>
|
#include <xgboost/data.h>
|
||||||
#include <xgboost/host_device_vector.h>
|
#include <xgboost/host_device_vector.h>
|
||||||
#include <xgboost/tree_updater.h>
|
|
||||||
#include <xgboost/learner.h>
|
#include <xgboost/learner.h>
|
||||||
#include <gtest/gtest.h>
|
#include <xgboost/tree_updater.h>
|
||||||
#include <vector>
|
|
||||||
#include <string>
|
|
||||||
#include <memory>
|
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "../../../src/tree/param.h" // for TrainParam
|
||||||
#include "../helpers.h"
|
#include "../helpers.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost::tree {
|
||||||
namespace tree {
|
|
||||||
|
|
||||||
TEST(Updater, Prune) {
|
TEST(Updater, Prune) {
|
||||||
int constexpr kCols = 16;
|
int constexpr kCols = 16;
|
||||||
|
|
||||||
@ -36,28 +36,30 @@ TEST(Updater, Prune) {
|
|||||||
tree.param.UpdateAllowUnknown(cfg);
|
tree.param.UpdateAllowUnknown(cfg);
|
||||||
std::vector<RegTree*> trees {&tree};
|
std::vector<RegTree*> trees {&tree};
|
||||||
// prepare pruner
|
// prepare pruner
|
||||||
|
TrainParam param;
|
||||||
|
param.UpdateAllowUnknown(cfg);
|
||||||
|
|
||||||
std::unique_ptr<TreeUpdater> pruner(
|
std::unique_ptr<TreeUpdater> pruner(
|
||||||
TreeUpdater::Create("prune", &ctx, ObjInfo{ObjInfo::kRegression}));
|
TreeUpdater::Create("prune", &ctx, ObjInfo{ObjInfo::kRegression}));
|
||||||
pruner->Configure(cfg);
|
|
||||||
|
|
||||||
// loss_chg < min_split_loss;
|
// loss_chg < min_split_loss;
|
||||||
std::vector<HostDeviceVector<bst_node_t>> position(trees.size());
|
std::vector<HostDeviceVector<bst_node_t>> position(trees.size());
|
||||||
tree.ExpandNode(0, 0, 0, true, 0.0f, 0.3f, 0.4f, 0.0f, 0.0f,
|
tree.ExpandNode(0, 0, 0, true, 0.0f, 0.3f, 0.4f, 0.0f, 0.0f,
|
||||||
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
|
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
|
||||||
pruner->Update(&gpair, p_dmat.get(), position, trees);
|
pruner->Update(¶m, &gpair, p_dmat.get(), position, trees);
|
||||||
|
|
||||||
ASSERT_EQ(tree.NumExtraNodes(), 0);
|
ASSERT_EQ(tree.NumExtraNodes(), 0);
|
||||||
|
|
||||||
// loss_chg > min_split_loss;
|
// loss_chg > min_split_loss;
|
||||||
tree.ExpandNode(0, 0, 0, true, 0.0f, 0.3f, 0.4f, 11.0f, 0.0f,
|
tree.ExpandNode(0, 0, 0, true, 0.0f, 0.3f, 0.4f, 11.0f, 0.0f,
|
||||||
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
|
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
|
||||||
pruner->Update(&gpair, p_dmat.get(), position, trees);
|
pruner->Update(¶m, &gpair, p_dmat.get(), position, trees);
|
||||||
|
|
||||||
ASSERT_EQ(tree.NumExtraNodes(), 2);
|
ASSERT_EQ(tree.NumExtraNodes(), 2);
|
||||||
|
|
||||||
// loss_chg == min_split_loss;
|
// loss_chg == min_split_loss;
|
||||||
tree.Stat(0).loss_chg = 10;
|
tree.Stat(0).loss_chg = 10;
|
||||||
pruner->Update(&gpair, p_dmat.get(), position, trees);
|
pruner->Update(¶m, &gpair, p_dmat.get(), position, trees);
|
||||||
|
|
||||||
ASSERT_EQ(tree.NumExtraNodes(), 2);
|
ASSERT_EQ(tree.NumExtraNodes(), 2);
|
||||||
|
|
||||||
@ -71,10 +73,10 @@ TEST(Updater, Prune) {
|
|||||||
0, 0.5f, true, 0.3, 0.4, 0.5,
|
0, 0.5f, true, 0.3, 0.4, 0.5,
|
||||||
/*loss_chg=*/19.0f, 0.0f,
|
/*loss_chg=*/19.0f, 0.0f,
|
||||||
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
|
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
|
||||||
cfg.emplace_back("max_depth", "1");
|
|
||||||
pruner->Configure(cfg);
|
|
||||||
pruner->Update(&gpair, p_dmat.get(), position, trees);
|
|
||||||
|
|
||||||
|
cfg.emplace_back("max_depth", "1");
|
||||||
|
param.UpdateAllowUnknown(cfg);
|
||||||
|
pruner->Update(¶m, &gpair, p_dmat.get(), position, trees);
|
||||||
ASSERT_EQ(tree.NumExtraNodes(), 2);
|
ASSERT_EQ(tree.NumExtraNodes(), 2);
|
||||||
|
|
||||||
tree.ExpandNode(tree[0].LeftChild(),
|
tree.ExpandNode(tree[0].LeftChild(),
|
||||||
@ -82,9 +84,9 @@ TEST(Updater, Prune) {
|
|||||||
/*loss_chg=*/18.0f, 0.0f,
|
/*loss_chg=*/18.0f, 0.0f,
|
||||||
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
|
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
|
||||||
cfg.emplace_back("min_split_loss", "0");
|
cfg.emplace_back("min_split_loss", "0");
|
||||||
pruner->Configure(cfg);
|
param.UpdateAllowUnknown(cfg);
|
||||||
pruner->Update(&gpair, p_dmat.get(), position, trees);
|
|
||||||
|
pruner->Update(¶m, &gpair, p_dmat.get(), position, trees);
|
||||||
ASSERT_EQ(tree.NumExtraNodes(), 2);
|
ASSERT_EQ(tree.NumExtraNodes(), 2);
|
||||||
}
|
}
|
||||||
} // namespace tree
|
} // namespace xgboost::tree
|
||||||
} // namespace xgboost
|
|
||||||
|
|||||||
@ -1,14 +1,15 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2018-2019 by Contributors
|
* Copyright 2018-2013 by XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
|
#include <gtest/gtest.h>
|
||||||
#include <xgboost/host_device_vector.h>
|
#include <xgboost/host_device_vector.h>
|
||||||
#include <xgboost/tree_updater.h>
|
#include <xgboost/tree_updater.h>
|
||||||
#include <gtest/gtest.h>
|
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
#include <string>
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "../../../src/tree/param.h" // for TrainParam
|
||||||
#include "../helpers.h"
|
#include "../helpers.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
@ -43,9 +44,11 @@ TEST(Updater, Refresh) {
|
|||||||
tree.Stat(cleft).base_weight = 1.2;
|
tree.Stat(cleft).base_weight = 1.2;
|
||||||
tree.Stat(cright).base_weight = 1.3;
|
tree.Stat(cright).base_weight = 1.3;
|
||||||
|
|
||||||
refresher->Configure(cfg);
|
|
||||||
std::vector<HostDeviceVector<bst_node_t>> position;
|
std::vector<HostDeviceVector<bst_node_t>> position;
|
||||||
refresher->Update(&gpair, p_dmat.get(), position, trees);
|
tree::TrainParam param;
|
||||||
|
param.UpdateAllowUnknown(cfg);
|
||||||
|
|
||||||
|
refresher->Update(¶m, &gpair, p_dmat.get(), position, trees);
|
||||||
|
|
||||||
bst_float constexpr kEps = 1e-6;
|
bst_float constexpr kEps = 1e-6;
|
||||||
ASSERT_NEAR(-0.183392, tree[cright].LeafValue(), kEps);
|
ASSERT_NEAR(-0.183392, tree[cright].LeafValue(), kEps);
|
||||||
|
|||||||
@ -1,7 +1,11 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020-2023 by XGBoost Contributors
|
||||||
|
*/
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include <xgboost/tree_model.h>
|
#include <xgboost/tree_model.h>
|
||||||
#include <xgboost/tree_updater.h>
|
#include <xgboost/tree_updater.h>
|
||||||
|
|
||||||
|
#include "../../../src/tree/param.h" // for TrainParam
|
||||||
#include "../helpers.h"
|
#include "../helpers.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
@ -21,6 +25,9 @@ class UpdaterTreeStatTest : public ::testing::Test {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void RunTest(std::string updater) {
|
void RunTest(std::string updater) {
|
||||||
|
tree::TrainParam param;
|
||||||
|
param.Init(Args{});
|
||||||
|
|
||||||
Context ctx(updater == "grow_gpu_hist" ? CreateEmptyGenericParam(0)
|
Context ctx(updater == "grow_gpu_hist" ? CreateEmptyGenericParam(0)
|
||||||
: CreateEmptyGenericParam(Context::kCpuId));
|
: CreateEmptyGenericParam(Context::kCpuId));
|
||||||
auto up = std::unique_ptr<TreeUpdater>{
|
auto up = std::unique_ptr<TreeUpdater>{
|
||||||
@ -29,7 +36,7 @@ class UpdaterTreeStatTest : public ::testing::Test {
|
|||||||
RegTree tree;
|
RegTree tree;
|
||||||
tree.param.num_feature = kCols;
|
tree.param.num_feature = kCols;
|
||||||
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
||||||
up->Update(&gpairs_, p_dmat_.get(), position, {&tree});
|
up->Update(¶m, &gpairs_, p_dmat_.get(), position, {&tree});
|
||||||
|
|
||||||
tree.WalkTree([&tree](bst_node_t nidx) {
|
tree.WalkTree([&tree](bst_node_t nidx) {
|
||||||
if (tree[nidx].IsLeaf()) {
|
if (tree[nidx].IsLeaf()) {
|
||||||
@ -69,28 +76,33 @@ class UpdaterEtaTest : public ::testing::Test {
|
|||||||
void RunTest(std::string updater) {
|
void RunTest(std::string updater) {
|
||||||
Context ctx(updater == "grow_gpu_hist" ? CreateEmptyGenericParam(0)
|
Context ctx(updater == "grow_gpu_hist" ? CreateEmptyGenericParam(0)
|
||||||
: CreateEmptyGenericParam(Context::kCpuId));
|
: CreateEmptyGenericParam(Context::kCpuId));
|
||||||
|
|
||||||
float eta = 0.4;
|
float eta = 0.4;
|
||||||
auto up_0 = std::unique_ptr<TreeUpdater>{
|
auto up_0 = std::unique_ptr<TreeUpdater>{
|
||||||
TreeUpdater::Create(updater, &ctx, ObjInfo{ObjInfo::kClassification})};
|
TreeUpdater::Create(updater, &ctx, ObjInfo{ObjInfo::kClassification})};
|
||||||
up_0->Configure(Args{{"eta", std::to_string(eta)}});
|
up_0->Configure(Args{});
|
||||||
|
tree::TrainParam param0;
|
||||||
|
param0.Init(Args{{"eta", std::to_string(eta)}});
|
||||||
|
|
||||||
auto up_1 = std::unique_ptr<TreeUpdater>{
|
auto up_1 = std::unique_ptr<TreeUpdater>{
|
||||||
TreeUpdater::Create(updater, &ctx, ObjInfo{ObjInfo::kClassification})};
|
TreeUpdater::Create(updater, &ctx, ObjInfo{ObjInfo::kClassification})};
|
||||||
up_1->Configure(Args{{"eta", "1.0"}});
|
up_1->Configure(Args{{"eta", "1.0"}});
|
||||||
|
tree::TrainParam param1;
|
||||||
|
param1.Init(Args{{"eta", "1.0"}});
|
||||||
|
|
||||||
for (size_t iter = 0; iter < 4; ++iter) {
|
for (size_t iter = 0; iter < 4; ++iter) {
|
||||||
RegTree tree_0;
|
RegTree tree_0;
|
||||||
{
|
{
|
||||||
tree_0.param.num_feature = kCols;
|
tree_0.param.num_feature = kCols;
|
||||||
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
||||||
up_0->Update(&gpairs_, p_dmat_.get(), position, {&tree_0});
|
up_0->Update(¶m0, &gpairs_, p_dmat_.get(), position, {&tree_0});
|
||||||
}
|
}
|
||||||
|
|
||||||
RegTree tree_1;
|
RegTree tree_1;
|
||||||
{
|
{
|
||||||
tree_1.param.num_feature = kCols;
|
tree_1.param.num_feature = kCols;
|
||||||
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
||||||
up_1->Update(&gpairs_, p_dmat_.get(), position, {&tree_1});
|
up_1->Update(¶m1, &gpairs_, p_dmat_.get(), position, {&tree_1});
|
||||||
}
|
}
|
||||||
tree_0.WalkTree([&](bst_node_t nidx) {
|
tree_0.WalkTree([&](bst_node_t nidx) {
|
||||||
if (tree_0[nidx].IsLeaf()) {
|
if (tree_0[nidx].IsLeaf()) {
|
||||||
@ -139,17 +151,18 @@ class TestMinSplitLoss : public ::testing::Test {
|
|||||||
|
|
||||||
// test gamma
|
// test gamma
|
||||||
{"gamma", std::to_string(gamma)}};
|
{"gamma", std::to_string(gamma)}};
|
||||||
|
tree::TrainParam param;
|
||||||
|
param.UpdateAllowUnknown(args);
|
||||||
|
|
||||||
Context ctx(updater == "grow_gpu_hist" ? CreateEmptyGenericParam(0)
|
Context ctx(updater == "grow_gpu_hist" ? CreateEmptyGenericParam(0)
|
||||||
: CreateEmptyGenericParam(Context::kCpuId));
|
: CreateEmptyGenericParam(Context::kCpuId));
|
||||||
std::cout << ctx.gpu_id << std::endl;
|
|
||||||
auto up = std::unique_ptr<TreeUpdater>{
|
auto up = std::unique_ptr<TreeUpdater>{
|
||||||
TreeUpdater::Create(updater, &ctx, ObjInfo{ObjInfo::kRegression})};
|
TreeUpdater::Create(updater, &ctx, ObjInfo{ObjInfo::kRegression})};
|
||||||
up->Configure(args);
|
up->Configure({});
|
||||||
|
|
||||||
RegTree tree;
|
RegTree tree;
|
||||||
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
||||||
up->Update(&gpair_, dmat_.get(), position, {&tree});
|
up->Update(¶m, &gpair_, dmat_.get(), position, {&tree});
|
||||||
|
|
||||||
auto n_nodes = tree.NumExtraNodes();
|
auto n_nodes = tree.NumExtraNodes();
|
||||||
return n_nodes;
|
return n_nodes;
|
||||||
|
|||||||
@ -42,9 +42,15 @@ class TestGPUBasicModels:
|
|||||||
def test_custom_objective(self):
|
def test_custom_objective(self):
|
||||||
self.cpu_test_bm.run_custom_objective("gpu_hist")
|
self.cpu_test_bm.run_custom_objective("gpu_hist")
|
||||||
|
|
||||||
def test_eta_decay_gpu_hist(self):
|
def test_eta_decay(self):
|
||||||
self.cpu_test_cb.run_eta_decay('gpu_hist')
|
self.cpu_test_cb.run_eta_decay('gpu_hist')
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"objective", ["binary:logistic", "reg:absoluteerror", "reg:quantileerror"]
|
||||||
|
)
|
||||||
|
def test_eta_decay_leaf_output(self, objective) -> None:
|
||||||
|
self.cpu_test_cb.run_eta_decay_leaf_output("gpu_hist", objective)
|
||||||
|
|
||||||
def test_deterministic_gpu_hist(self):
|
def test_deterministic_gpu_hist(self):
|
||||||
kRows = 1000
|
kRows = 1000
|
||||||
kCols = 64
|
kCols = 64
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
@ -355,47 +356,125 @@ class TestCallbacks:
|
|||||||
with warning_check:
|
with warning_check:
|
||||||
xgb.cv(param, dtrain, num_round, callbacks=[scheduler(eta_decay)])
|
xgb.cv(param, dtrain, num_round, callbacks=[scheduler(eta_decay)])
|
||||||
|
|
||||||
@pytest.mark.parametrize("tree_method", ["hist", "approx", "exact"])
|
def run_eta_decay_leaf_output(self, tree_method: str, objective: str) -> None:
|
||||||
|
# check decay has effect on leaf output.
|
||||||
|
num_round = 4
|
||||||
|
scheduler = xgb.callback.LearningRateScheduler
|
||||||
|
|
||||||
|
dpath = tm.data_dir(__file__)
|
||||||
|
dtrain = xgb.DMatrix(os.path.join(dpath, "agaricus.txt.train"))
|
||||||
|
dtest = xgb.DMatrix(os.path.join(dpath, "agaricus.txt.test"))
|
||||||
|
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
|
||||||
|
|
||||||
|
param = {
|
||||||
|
"max_depth": 2,
|
||||||
|
"objective": objective,
|
||||||
|
"eval_metric": "error",
|
||||||
|
"tree_method": tree_method,
|
||||||
|
}
|
||||||
|
if objective == "reg:quantileerror":
|
||||||
|
param["quantile_alpha"] = 0.3
|
||||||
|
|
||||||
|
def eta_decay_0(i):
|
||||||
|
return num_round / (i + 1)
|
||||||
|
|
||||||
|
bst0 = xgb.train(
|
||||||
|
param,
|
||||||
|
dtrain,
|
||||||
|
num_round,
|
||||||
|
watchlist,
|
||||||
|
callbacks=[scheduler(eta_decay_0)],
|
||||||
|
)
|
||||||
|
|
||||||
|
def eta_decay_1(i: int) -> float:
|
||||||
|
if i > 1:
|
||||||
|
return 5.0
|
||||||
|
return num_round / (i + 1)
|
||||||
|
|
||||||
|
bst1 = xgb.train(
|
||||||
|
param,
|
||||||
|
dtrain,
|
||||||
|
num_round,
|
||||||
|
watchlist,
|
||||||
|
callbacks=[scheduler(eta_decay_1)],
|
||||||
|
)
|
||||||
|
bst_json0 = bst0.save_raw(raw_format="json")
|
||||||
|
bst_json1 = bst1.save_raw(raw_format="json")
|
||||||
|
|
||||||
|
j0 = json.loads(bst_json0)
|
||||||
|
j1 = json.loads(bst_json1)
|
||||||
|
|
||||||
|
tree_2th_0 = j0["learner"]["gradient_booster"]["model"]["trees"][2]
|
||||||
|
tree_2th_1 = j1["learner"]["gradient_booster"]["model"]["trees"][2]
|
||||||
|
assert tree_2th_0["base_weights"] == tree_2th_1["base_weights"]
|
||||||
|
assert tree_2th_0["split_conditions"] == tree_2th_1["split_conditions"]
|
||||||
|
|
||||||
|
tree_3th_0 = j0["learner"]["gradient_booster"]["model"]["trees"][3]
|
||||||
|
tree_3th_1 = j1["learner"]["gradient_booster"]["model"]["trees"][3]
|
||||||
|
assert tree_3th_0["base_weights"] != tree_3th_1["base_weights"]
|
||||||
|
assert tree_3th_0["split_conditions"] != tree_3th_1["split_conditions"]
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("tree_method", ["hist", "approx", "approx"])
|
||||||
def test_eta_decay(self, tree_method):
|
def test_eta_decay(self, tree_method):
|
||||||
self.run_eta_decay(tree_method)
|
self.run_eta_decay(tree_method)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"tree_method,objective",
|
||||||
|
[
|
||||||
|
("hist", "binary:logistic"),
|
||||||
|
("hist", "reg:absoluteerror"),
|
||||||
|
("hist", "reg:quantileerror"),
|
||||||
|
("approx", "binary:logistic"),
|
||||||
|
("approx", "reg:absoluteerror"),
|
||||||
|
("approx", "reg:quantileerror"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_eta_decay_leaf_output(self, tree_method: str, objective: str) -> None:
|
||||||
|
self.run_eta_decay_leaf_output(tree_method, objective)
|
||||||
|
|
||||||
def test_check_point(self):
|
def test_check_point(self):
|
||||||
from sklearn.datasets import load_breast_cancer
|
from sklearn.datasets import load_breast_cancer
|
||||||
|
|
||||||
X, y = load_breast_cancer(return_X_y=True)
|
X, y = load_breast_cancer(return_X_y=True)
|
||||||
m = xgb.DMatrix(X, y)
|
m = xgb.DMatrix(X, y)
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
check_point = xgb.callback.TrainingCheckPoint(directory=tmpdir,
|
check_point = xgb.callback.TrainingCheckPoint(
|
||||||
iterations=1,
|
directory=tmpdir, iterations=1, name="model"
|
||||||
name='model')
|
)
|
||||||
xgb.train({'objective': 'binary:logistic'}, m,
|
xgb.train(
|
||||||
num_boost_round=10,
|
{"objective": "binary:logistic"},
|
||||||
verbose_eval=False,
|
m,
|
||||||
callbacks=[check_point])
|
num_boost_round=10,
|
||||||
|
verbose_eval=False,
|
||||||
|
callbacks=[check_point],
|
||||||
|
)
|
||||||
for i in range(1, 10):
|
for i in range(1, 10):
|
||||||
assert os.path.exists(
|
assert os.path.exists(os.path.join(tmpdir, "model_" + str(i) + ".json"))
|
||||||
os.path.join(tmpdir, 'model_' + str(i) + '.json'))
|
|
||||||
|
|
||||||
check_point = xgb.callback.TrainingCheckPoint(directory=tmpdir,
|
check_point = xgb.callback.TrainingCheckPoint(
|
||||||
iterations=1,
|
directory=tmpdir, iterations=1, as_pickle=True, name="model"
|
||||||
as_pickle=True,
|
)
|
||||||
name='model')
|
xgb.train(
|
||||||
xgb.train({'objective': 'binary:logistic'}, m,
|
{"objective": "binary:logistic"},
|
||||||
num_boost_round=10,
|
m,
|
||||||
verbose_eval=False,
|
num_boost_round=10,
|
||||||
callbacks=[check_point])
|
verbose_eval=False,
|
||||||
|
callbacks=[check_point],
|
||||||
|
)
|
||||||
for i in range(1, 10):
|
for i in range(1, 10):
|
||||||
assert os.path.exists(
|
assert os.path.exists(os.path.join(tmpdir, "model_" + str(i) + ".pkl"))
|
||||||
os.path.join(tmpdir, 'model_' + str(i) + '.pkl'))
|
|
||||||
|
|
||||||
def test_callback_list(self):
|
def test_callback_list(self):
|
||||||
X, y = tm.get_california_housing()
|
X, y = tm.get_california_housing()
|
||||||
m = xgb.DMatrix(X, y)
|
m = xgb.DMatrix(X, y)
|
||||||
callbacks = [xgb.callback.EarlyStopping(rounds=10)]
|
callbacks = [xgb.callback.EarlyStopping(rounds=10)]
|
||||||
for i in range(4):
|
for i in range(4):
|
||||||
xgb.train({'objective': 'reg:squarederror',
|
xgb.train(
|
||||||
'eval_metric': 'rmse'}, m,
|
{"objective": "reg:squarederror", "eval_metric": "rmse"},
|
||||||
evals=[(m, 'Train')],
|
m,
|
||||||
num_boost_round=1,
|
evals=[(m, "Train")],
|
||||||
verbose_eval=True,
|
num_boost_round=1,
|
||||||
callbacks=callbacks)
|
verbose_eval=True,
|
||||||
|
callbacks=callbacks,
|
||||||
|
)
|
||||||
assert len(callbacks) == 1
|
assert len(callbacks) == 1
|
||||||
|
|||||||
@ -51,11 +51,8 @@ class TestPickling:
|
|||||||
|
|
||||||
def test_model_pickling_json(self):
|
def test_model_pickling_json(self):
|
||||||
def check(config):
|
def check(config):
|
||||||
updater = config["learner"]["gradient_booster"]["updater"]
|
tree_param = config["learner"]["gradient_booster"]["tree_train_param"]
|
||||||
if params["tree_method"] == "exact":
|
subsample = tree_param["subsample"]
|
||||||
subsample = updater["grow_colmaker"]["train_param"]["subsample"]
|
|
||||||
else:
|
|
||||||
subsample = updater["grow_quantile_histmaker"]["train_param"]["subsample"]
|
|
||||||
assert float(subsample) == 0.5
|
assert float(subsample) == 0.5
|
||||||
|
|
||||||
params = {"nthread": 8, "tree_method": "hist", "subsample": 0.5}
|
params = {"nthread": 8, "tree_method": "hist", "subsample": 0.5}
|
||||||
|
|||||||
@ -447,7 +447,8 @@ class TestTreeMethod:
|
|||||||
{
|
{
|
||||||
"tree_method": tree_method,
|
"tree_method": tree_method,
|
||||||
"objective": "reg:absoluteerror",
|
"objective": "reg:absoluteerror",
|
||||||
"subsample": 0.8
|
"subsample": 0.8,
|
||||||
|
"eta": 1.0,
|
||||||
},
|
},
|
||||||
Xy,
|
Xy,
|
||||||
num_boost_round=10,
|
num_boost_round=10,
|
||||||
|
|||||||
@ -1018,14 +1018,18 @@ def test_XGBClassifier_resume():
|
|||||||
|
|
||||||
|
|
||||||
def test_constraint_parameters():
|
def test_constraint_parameters():
|
||||||
reg = xgb.XGBRegressor(interaction_constraints='[[0, 1], [2, 3, 4]]')
|
reg = xgb.XGBRegressor(interaction_constraints="[[0, 1], [2, 3, 4]]")
|
||||||
X = np.random.randn(10, 10)
|
X = np.random.randn(10, 10)
|
||||||
y = np.random.randn(10)
|
y = np.random.randn(10)
|
||||||
reg.fit(X, y)
|
reg.fit(X, y)
|
||||||
|
|
||||||
config = json.loads(reg.get_booster().save_config())
|
config = json.loads(reg.get_booster().save_config())
|
||||||
assert config['learner']['gradient_booster']['updater']['grow_colmaker'][
|
assert (
|
||||||
'train_param']['interaction_constraints'] == '[[0, 1], [2, 3, 4]]'
|
config["learner"]["gradient_booster"]["tree_train_param"][
|
||||||
|
"interaction_constraints"
|
||||||
|
]
|
||||||
|
== "[[0, 1], [2, 3, 4]]"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_parameter_validation():
|
def test_parameter_validation():
|
||||||
|
|||||||
@ -422,10 +422,10 @@ class XgboostLocalClusterTestCase(SparkLocalClusterTestCase):
|
|||||||
self.assertTrue(hasattr(classifier, "max_depth"))
|
self.assertTrue(hasattr(classifier, "max_depth"))
|
||||||
self.assertEqual(classifier.getOrDefault(classifier.max_depth), 7)
|
self.assertEqual(classifier.getOrDefault(classifier.max_depth), 7)
|
||||||
booster_config = json.loads(model.get_booster().save_config())
|
booster_config = json.loads(model.get_booster().save_config())
|
||||||
max_depth = booster_config["learner"]["gradient_booster"]["updater"][
|
max_depth = booster_config["learner"]["gradient_booster"]["tree_train_param"][
|
||||||
"grow_histmaker"
|
"max_depth"
|
||||||
]["train_param"]["max_depth"]
|
]
|
||||||
self.assertEqual(int(max_depth), 7)
|
assert int(max_depth) == 7
|
||||||
|
|
||||||
def test_repartition(self):
|
def test_repartition(self):
|
||||||
# The following test case has a few partitioned datasets that are either
|
# The following test case has a few partitioned datasets that are either
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user