Remove public access to tree model param. (#8902)
* Make tree model param a private member.
* Number of features and targets are immutable after construction. This reduces the number of places where configuration can happen.
parent 5ba3509dd3
commit 9bade7203a
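In practice, call sites that previously reached into the public `tree.param` member now pass the model shape to the constructor and read it back through accessors. Below is a minimal sketch of the migration, assuming the public XGBoost headers; the helper function name and the concrete values are illustrative and not part of this commit:

#include <xgboost/tree_model.h>  // xgboost::RegTree

// Hypothetical helper, for illustration only.
void BuildAndInspect() {
  xgboost::bst_target_t const n_targets = 1;     // illustrative shape
  xgboost::bst_feature_t const n_features = 16;

  // Before this commit:
  //   RegTree tree;
  //   tree.param.num_feature = n_features;  // public member, mutable anywhere
  // After: the shape is fixed at construction time.
  xgboost::RegTree tree{n_targets, n_features};

  // Reads go through the new accessors instead of tree.param.*
  auto n_feat = tree.NumFeatures();    // was tree.param.num_feature
  auto n_nodes = tree.NumNodes();      // was tree.param.num_nodes
  auto n_tgt = tree.NumTargets();      // was tree.param.size_leaf_vector
  (void)n_feat;
  (void)n_nodes;
  (void)n_tgt;
}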
@@ -178,51 +178,33 @@ class RegTree : public Model {
 }

 /*! \brief index of left child */
-XGBOOST_DEVICE [[nodiscard]] int LeftChild() const {
-return this->cleft_;
-}
+[[nodiscard]] XGBOOST_DEVICE int LeftChild() const { return this->cleft_; }
 /*! \brief index of right child */
-XGBOOST_DEVICE [[nodiscard]] int RightChild() const {
-return this->cright_;
-}
+[[nodiscard]] XGBOOST_DEVICE int RightChild() const { return this->cright_; }
 /*! \brief index of default child when feature is missing */
-XGBOOST_DEVICE [[nodiscard]] int DefaultChild() const {
+[[nodiscard]] XGBOOST_DEVICE int DefaultChild() const {
 return this->DefaultLeft() ? this->LeftChild() : this->RightChild();
 }
 /*! \brief feature index of split condition */
-XGBOOST_DEVICE [[nodiscard]] unsigned SplitIndex() const {
+[[nodiscard]] XGBOOST_DEVICE unsigned SplitIndex() const {
 return sindex_ & ((1U << 31) - 1U);
 }
 /*! \brief when feature is unknown, whether goes to left child */
-XGBOOST_DEVICE [[nodiscard]] bool DefaultLeft() const {
-return (sindex_ >> 31) != 0;
-}
+[[nodiscard]] XGBOOST_DEVICE bool DefaultLeft() const { return (sindex_ >> 31) != 0; }
 /*! \brief whether current node is leaf node */
-XGBOOST_DEVICE [[nodiscard]] bool IsLeaf() const {
-return cleft_ == kInvalidNodeId;
-}
+[[nodiscard]] XGBOOST_DEVICE bool IsLeaf() const { return cleft_ == kInvalidNodeId; }
 /*! \return get leaf value of leaf node */
-XGBOOST_DEVICE [[nodiscard]] float LeafValue() const {
-return (this->info_).leaf_value;
-}
+[[nodiscard]] XGBOOST_DEVICE float LeafValue() const { return (this->info_).leaf_value; }
 /*! \return get split condition of the node */
-XGBOOST_DEVICE [[nodiscard]] SplitCondT SplitCond() const {
-return (this->info_).split_cond;
-}
+[[nodiscard]] XGBOOST_DEVICE SplitCondT SplitCond() const { return (this->info_).split_cond; }
 /*! \brief get parent of the node */
-XGBOOST_DEVICE [[nodiscard]] int Parent() const {
-return parent_ & ((1U << 31) - 1);
-}
+[[nodiscard]] XGBOOST_DEVICE int Parent() const { return parent_ & ((1U << 31) - 1); }
 /*! \brief whether current node is left child */
-XGBOOST_DEVICE [[nodiscard]] bool IsLeftChild() const {
-return (parent_ & (1U << 31)) != 0;
-}
+[[nodiscard]] XGBOOST_DEVICE bool IsLeftChild() const { return (parent_ & (1U << 31)) != 0; }
 /*! \brief whether this node is deleted */
-XGBOOST_DEVICE [[nodiscard]] bool IsDeleted() const {
-return sindex_ == kDeletedNodeMarker;
-}
+[[nodiscard]] XGBOOST_DEVICE bool IsDeleted() const { return sindex_ == kDeletedNodeMarker; }
 /*! \brief whether current node is root */
-XGBOOST_DEVICE [[nodiscard]] bool IsRoot() const { return parent_ == kInvalidNodeId; }
+[[nodiscard]] XGBOOST_DEVICE bool IsRoot() const { return parent_ == kInvalidNodeId; }
 /*!
 * \brief set the left child
 * \param nid node id to right child
@@ -337,15 +319,13 @@ class RegTree : public Model {
 this->ChangeToLeaf(rid, value);
 }

-/*! \brief model parameter */
-TreeParam param;
 RegTree() {
-param.Init(Args{});
-nodes_.resize(param.num_nodes);
-stats_.resize(param.num_nodes);
-split_types_.resize(param.num_nodes, FeatureType::kNumerical);
-split_categories_segments_.resize(param.num_nodes);
-for (int i = 0; i < param.num_nodes; i++) {
+param_.Init(Args{});
+nodes_.resize(param_.num_nodes);
+stats_.resize(param_.num_nodes);
+split_types_.resize(param_.num_nodes, FeatureType::kNumerical);
+split_categories_segments_.resize(param_.num_nodes);
+for (int i = 0; i < param_.num_nodes; i++) {
 nodes_[i].SetLeaf(0.0f);
 nodes_[i].SetParent(kInvalidNodeId);
 }
@@ -354,10 +334,10 @@ class RegTree : public Model {
 * \brief Constructor that initializes the tree model with shape.
 */
 explicit RegTree(bst_target_t n_targets, bst_feature_t n_features) : RegTree{} {
-param.num_feature = n_features;
-param.size_leaf_vector = n_targets;
+param_.num_feature = n_features;
+param_.size_leaf_vector = n_targets;
 if (n_targets > 1) {
-this->p_mt_tree_.reset(new MultiTargetTree{&param});
+this->p_mt_tree_.reset(new MultiTargetTree{&param_});
 }
 }

@@ -401,7 +381,7 @@ class RegTree : public Model {

 bool operator==(const RegTree& b) const {
 return nodes_ == b.nodes_ && stats_ == b.stats_ &&
-deleted_nodes_ == b.deleted_nodes_ && param == b.param;
+deleted_nodes_ == b.deleted_nodes_ && param_ == b.param_;
 }
 /* \brief Iterate through all nodes in this tree.
 *
@@ -459,7 +439,9 @@ class RegTree : public Model {
 bst_float loss_change, float sum_hess, float left_sum,
 float right_sum,
 bst_node_t leaf_right_child = kInvalidNodeId);
-
+/**
+* \brief Expands a leaf node into two additional leaf nodes for a multi-target tree.
+*/
 void ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split_cond, bool default_left,
 linalg::VectorView<float const> base_weight,
 linalg::VectorView<float const> left_weight,
@@ -485,19 +467,48 @@ class RegTree : public Model {
 bst_float base_weight, bst_float left_leaf_weight,
 bst_float right_leaf_weight, bst_float loss_change, float sum_hess,
 float left_sum, float right_sum);
-
-[[nodiscard]] bool HasCategoricalSplit() const {
-return !split_categories_.empty();
-}
+/**
+* \brief Whether this tree has categorical split.
+*/
+[[nodiscard]] bool HasCategoricalSplit() const { return !split_categories_.empty(); }
 /**
 * \brief Whether this is a multi-target tree.
 */
 [[nodiscard]] bool IsMultiTarget() const { return static_cast<bool>(p_mt_tree_); }
-[[nodiscard]] bst_target_t NumTargets() const { return param.size_leaf_vector; }
+/**
+* \brief The size of leaf weight.
+*/
+[[nodiscard]] bst_target_t NumTargets() const { return param_.size_leaf_vector; }
+/**
+* \brief Get the underlying implementaiton of multi-target tree.
+*/
 [[nodiscard]] auto GetMultiTargetTree() const {
 CHECK(IsMultiTarget());
 return p_mt_tree_.get();
 }
+/**
+* \brief Get the number of features.
+*/
+[[nodiscard]] bst_feature_t NumFeatures() const noexcept { return param_.num_feature; }
+/**
+* \brief Get the total number of nodes including deleted ones in this tree.
+*/
+[[nodiscard]] bst_node_t NumNodes() const noexcept { return param_.num_nodes; }
+/**
+* \brief Get the total number of valid nodes in this tree.
+*/
+[[nodiscard]] bst_node_t NumValidNodes() const noexcept {
+return param_.num_nodes - param_.num_deleted;
+}
+/**
+* \brief number of extra nodes besides the root
+*/
+[[nodiscard]] bst_node_t NumExtraNodes() const noexcept {
+return param_.num_nodes - 1 - param_.num_deleted;
+}
+/* \brief Count number of leaves in tree. */
+[[nodiscard]] bst_node_t GetNumLeaves() const;
+[[nodiscard]] bst_node_t GetNumSplitNodes() const;

 /*!
 * \brief get current depth
@@ -514,6 +525,9 @@ class RegTree : public Model {
 }
 return depth;
 }
+/**
+* \brief Set the leaf weight for a multi-target tree.
+*/
 void SetLeaf(bst_node_t nidx, linalg::VectorView<float const> weight) {
 CHECK(IsMultiTarget());
 return this->p_mt_tree_->SetLeaf(nidx, weight);
@@ -525,25 +539,13 @@ class RegTree : public Model {
 */
 [[nodiscard]] int MaxDepth(int nid) const {
 if (nodes_[nid].IsLeaf()) return 0;
-return std::max(MaxDepth(nodes_[nid].LeftChild())+1,
-MaxDepth(nodes_[nid].RightChild())+1);
+return std::max(MaxDepth(nodes_[nid].LeftChild()) + 1, MaxDepth(nodes_[nid].RightChild()) + 1);
 }

 /*!
 * \brief get maximum depth
 */
-int MaxDepth() {
-return MaxDepth(0);
-}
-
-/*! \brief number of extra nodes besides the root */
-[[nodiscard]] int NumExtraNodes() const {
-return param.num_nodes - 1 - param.num_deleted;
-}
-
-/* \brief Count number of leaves in tree. */
-[[nodiscard]] bst_node_t GetNumLeaves() const;
-[[nodiscard]] bst_node_t GetNumSplitNodes() const;
+int MaxDepth() { return MaxDepth(0); }

 /*!
 * \brief dense feature vector that can be taken by RegTree
@@ -735,6 +737,8 @@ class RegTree : public Model {
 template <bool typed>
 void LoadCategoricalSplit(Json const& in);
 void SaveCategoricalSplit(Json* p_out) const;
+/*! \brief model parameter */
+TreeParam param_;
 // vector of nodes
 std::vector<Node> nodes_;
 // free node space, used during training process
@@ -752,20 +756,20 @@ class RegTree : public Model {
 // allocate a new node,
 // !!!!!! NOTE: may cause BUG here, nodes.resize
 bst_node_t AllocNode() {
-if (param.num_deleted != 0) {
+if (param_.num_deleted != 0) {
 int nid = deleted_nodes_.back();
 deleted_nodes_.pop_back();
 nodes_[nid].Reuse();
---param.num_deleted;
+--param_.num_deleted;
 return nid;
 }
-int nd = param.num_nodes++;
-CHECK_LT(param.num_nodes, std::numeric_limits<int>::max())
+int nd = param_.num_nodes++;
+CHECK_LT(param_.num_nodes, std::numeric_limits<int>::max())
 << "number of nodes in the tree exceed 2^31";
-nodes_.resize(param.num_nodes);
-stats_.resize(param.num_nodes);
-split_types_.resize(param.num_nodes, FeatureType::kNumerical);
-split_categories_segments_.resize(param.num_nodes);
+nodes_.resize(param_.num_nodes);
+stats_.resize(param_.num_nodes);
+split_types_.resize(param_.num_nodes, FeatureType::kNumerical);
+split_categories_segments_.resize(param_.num_nodes);
 return nd;
 }
 // delete a tree node, keep the parent field to allow trace back
@@ -780,7 +784,7 @@ class RegTree : public Model {

 deleted_nodes_.push_back(nid);
 nodes_[nid].MarkDelete();
-++param.num_deleted;
+++param_.num_deleted;
 }
 };

@@ -360,8 +360,8 @@ void GBTree::BoostNewTrees(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fma
 << "Set `process_type` to `update` if you want to update existing "
 "trees.";
 // create new tree
-std::unique_ptr<RegTree> ptr(new RegTree());
-ptr->param.UpdateAllowUnknown(this->cfg_);
+std::unique_ptr<RegTree> ptr(new RegTree{this->model_.learner_model_param->LeafLength(),
+this->model_.learner_model_param->num_feature});
 new_trees.push_back(ptr.get());
 ret->push_back(std::move(ptr));
 } else if (tparam_.process_type == TreeProcessType::kUpdate) {
@@ -775,8 +775,6 @@ class LearnerConfiguration : public Learner {
 }
 CHECK_NE(mparam_.num_feature, 0)
 << "0 feature is supplied. Are you using raw Booster interface?";
-// Remove these once binary IO is gone.
-cfg_["num_feature"] = common::ToString(mparam_.num_feature);
 }

 void ConfigureGBM(LearnerTrainParam const& old, Args const& args) {
@@ -275,7 +275,7 @@ float FillNodeMeanValues(RegTree const *tree, bst_node_t nidx, std::vector<float
 }

 void FillNodeMeanValues(RegTree const* tree, std::vector<float>* mean_values) {
-size_t num_nodes = tree->param.num_nodes;
+size_t num_nodes = tree->NumNodes();
 if (mean_values->size() == num_nodes) {
 return;
 }
@@ -815,9 +815,9 @@ void RegTree::ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split
 linalg::VectorView<float const> left_weight,
 linalg::VectorView<float const> right_weight) {
 CHECK(IsMultiTarget());
-CHECK_LT(split_index, this->param.num_feature);
+CHECK_LT(split_index, this->param_.num_feature);
 CHECK(this->p_mt_tree_);
-CHECK_GT(param.size_leaf_vector, 1);
+CHECK_GT(param_.size_leaf_vector, 1);

 this->p_mt_tree_->Expand(nidx, split_index, split_cond, default_left, base_weight, left_weight,
 right_weight);
@@ -826,7 +826,7 @@ void RegTree::ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split
 split_categories_segments_.resize(this->Size());
 this->split_types_.at(nidx) = FeatureType::kNumerical;

-this->param.num_nodes = this->p_mt_tree_->Size();
+this->param_.num_nodes = this->p_mt_tree_->Size();
 }

 void RegTree::ExpandCategorical(bst_node_t nid, bst_feature_t split_index,
@@ -850,13 +850,13 @@ void RegTree::ExpandCategorical(bst_node_t nid, bst_feature_t split_index,
 }

 void RegTree::Load(dmlc::Stream* fi) {
-CHECK_EQ(fi->Read(&param, sizeof(TreeParam)), sizeof(TreeParam));
+CHECK_EQ(fi->Read(&param_, sizeof(TreeParam)), sizeof(TreeParam));
 if (!DMLC_IO_NO_ENDIAN_SWAP) {
-param = param.ByteSwap();
+param_ = param_.ByteSwap();
 }
-nodes_.resize(param.num_nodes);
-stats_.resize(param.num_nodes);
-CHECK_NE(param.num_nodes, 0);
+nodes_.resize(param_.num_nodes);
+stats_.resize(param_.num_nodes);
+CHECK_NE(param_.num_nodes, 0);
 CHECK_EQ(fi->Read(dmlc::BeginPtr(nodes_), sizeof(Node) * nodes_.size()),
 sizeof(Node) * nodes_.size());
 if (!DMLC_IO_NO_ENDIAN_SWAP) {
@@ -873,29 +873,29 @@ void RegTree::Load(dmlc::Stream* fi) {
 }
 // chg deleted nodes
 deleted_nodes_.resize(0);
-for (int i = 1; i < param.num_nodes; ++i) {
+for (int i = 1; i < param_.num_nodes; ++i) {
 if (nodes_[i].IsDeleted()) {
 deleted_nodes_.push_back(i);
 }
 }
-CHECK_EQ(static_cast<int>(deleted_nodes_.size()), param.num_deleted);
+CHECK_EQ(static_cast<int>(deleted_nodes_.size()), param_.num_deleted);

-split_types_.resize(param.num_nodes, FeatureType::kNumerical);
-split_categories_segments_.resize(param.num_nodes);
+split_types_.resize(param_.num_nodes, FeatureType::kNumerical);
+split_categories_segments_.resize(param_.num_nodes);
 }

 void RegTree::Save(dmlc::Stream* fo) const {
-CHECK_EQ(param.num_nodes, static_cast<int>(nodes_.size()));
-CHECK_EQ(param.num_nodes, static_cast<int>(stats_.size()));
-CHECK_EQ(param.deprecated_num_roots, 1);
-CHECK_NE(param.num_nodes, 0);
+CHECK_EQ(param_.num_nodes, static_cast<int>(nodes_.size()));
+CHECK_EQ(param_.num_nodes, static_cast<int>(stats_.size()));
+CHECK_EQ(param_.deprecated_num_roots, 1);
+CHECK_NE(param_.num_nodes, 0);
 CHECK(!HasCategoricalSplit())
 << "Please use JSON/UBJSON for saving models with categorical splits.";

 if (DMLC_IO_NO_ENDIAN_SWAP) {
-fo->Write(&param, sizeof(TreeParam));
+fo->Write(&param_, sizeof(TreeParam));
 } else {
-TreeParam x = param.ByteSwap();
+TreeParam x = param_.ByteSwap();
 fo->Write(&x, sizeof(x));
 }

@@ -1081,7 +1081,7 @@ void RegTree::LoadModel(Json const& in) {
 bool typed = IsA<I32Array>(in[tf::kParent]);
 auto const& in_obj = get<Object const>(in);
 // basic properties
-FromJson(in["tree_param"], &param);
+FromJson(in["tree_param"], &param_);
 // categorical splits
 bool has_cat = in_obj.find("split_type") != in_obj.cend();
 if (has_cat) {
@@ -1092,55 +1092,55 @@ void RegTree::LoadModel(Json const& in) {
 }
 }
 // multi-target
-if (param.size_leaf_vector > 1) {
-this->p_mt_tree_.reset(new MultiTargetTree{&param});
+if (param_.size_leaf_vector > 1) {
+this->p_mt_tree_.reset(new MultiTargetTree{&param_});
 this->GetMultiTargetTree()->LoadModel(in);
 return;
 }

 bool feature_is_64 = IsA<I64Array>(in["split_indices"]);
 if (typed && feature_is_64) {
-LoadModelImpl<true, true>(in, param, &stats_, &nodes_);
+LoadModelImpl<true, true>(in, param_, &stats_, &nodes_);
 } else if (typed && !feature_is_64) {
-LoadModelImpl<true, false>(in, param, &stats_, &nodes_);
+LoadModelImpl<true, false>(in, param_, &stats_, &nodes_);
 } else if (!typed && feature_is_64) {
-LoadModelImpl<false, true>(in, param, &stats_, &nodes_);
+LoadModelImpl<false, true>(in, param_, &stats_, &nodes_);
 } else {
-LoadModelImpl<false, false>(in, param, &stats_, &nodes_);
+LoadModelImpl<false, false>(in, param_, &stats_, &nodes_);
 }

 if (!has_cat) {
-this->split_categories_segments_.resize(this->param.num_nodes);
-this->split_types_.resize(this->param.num_nodes);
+this->split_categories_segments_.resize(this->param_.num_nodes);
+this->split_types_.resize(this->param_.num_nodes);
 std::fill(split_types_.begin(), split_types_.end(), FeatureType::kNumerical);
 }

 deleted_nodes_.clear();
-for (bst_node_t i = 1; i < param.num_nodes; ++i) {
+for (bst_node_t i = 1; i < param_.num_nodes; ++i) {
 if (nodes_[i].IsDeleted()) {
 deleted_nodes_.push_back(i);
 }
 }
 // easier access to [] operator
 auto& self = *this;
-for (auto nid = 1; nid < param.num_nodes; ++nid) {
+for (auto nid = 1; nid < param_.num_nodes; ++nid) {
 auto parent = self[nid].Parent();
 CHECK_NE(parent, RegTree::kInvalidNodeId);
 self[nid].SetParent(self[nid].Parent(), self[parent].LeftChild() == nid);
 }
-CHECK_EQ(static_cast<bst_node_t>(deleted_nodes_.size()), param.num_deleted);
-CHECK_EQ(this->split_categories_segments_.size(), param.num_nodes);
+CHECK_EQ(static_cast<bst_node_t>(deleted_nodes_.size()), param_.num_deleted);
+CHECK_EQ(this->split_categories_segments_.size(), param_.num_nodes);
 }

 void RegTree::SaveModel(Json* p_out) const {
 auto& out = *p_out;
 // basic properties
-out["tree_param"] = ToJson(param);
+out["tree_param"] = ToJson(param_);
 // categorical splits
 this->SaveCategoricalSplit(p_out);
 // multi-target
 if (this->IsMultiTarget()) {
-CHECK_GT(param.size_leaf_vector, 1);
+CHECK_GT(param_.size_leaf_vector, 1);
 this->GetMultiTargetTree()->SaveModel(p_out);
 return;
 }
@@ -1150,11 +1150,11 @@ void RegTree::SaveModel(Json* p_out) const {
 * pruner, and this pruner can be used inside another updater so leaf are not necessary
 * at the end of node array.
 */
-CHECK_EQ(param.num_nodes, static_cast<int>(nodes_.size()));
-CHECK_EQ(param.num_nodes, static_cast<int>(stats_.size()));
+CHECK_EQ(param_.num_nodes, static_cast<int>(nodes_.size()));
+CHECK_EQ(param_.num_nodes, static_cast<int>(stats_.size()));

-CHECK_EQ(get<String>(out["tree_param"]["num_nodes"]), std::to_string(param.num_nodes));
-auto n_nodes = param.num_nodes;
+CHECK_EQ(get<String>(out["tree_param"]["num_nodes"]), std::to_string(param_.num_nodes));
+auto n_nodes = param_.num_nodes;

 // stats
 F32Array loss_changes(n_nodes);
@@ -1168,7 +1168,7 @@ void RegTree::SaveModel(Json* p_out) const {

 F32Array conds(n_nodes);
 U8Array default_left(n_nodes);
-CHECK_EQ(this->split_types_.size(), param.num_nodes);
+CHECK_EQ(this->split_types_.size(), param_.num_nodes);

 namespace tf = tree_field;

@@ -1189,7 +1189,7 @@ void RegTree::SaveModel(Json* p_out) const {
 default_left.Set(i, static_cast<uint8_t>(!!n.DefaultLeft()));
 }
 };
-if (this->param.num_feature > static_cast<bst_feature_t>(std::numeric_limits<int32_t>::max())) {
+if (this->param_.num_feature > static_cast<bst_feature_t>(std::numeric_limits<int32_t>::max())) {
 I64Array indices_64(n_nodes);
 save_tree(&indices_64);
 out[tf::kSplitIdx] = std::move(indices_64);
@@ -190,7 +190,7 @@ class ColMaker: public TreeUpdater {
 (*p_tree)[nid].SetLeaf(snode_[nid].weight * param_.learning_rate);
 }
 // remember auxiliary statistics in the tree node
-for (int nid = 0; nid < p_tree->param.num_nodes; ++nid) {
+for (int nid = 0; nid < p_tree->NumNodes(); ++nid) {
 p_tree->Stat(nid).loss_chg = snode_[nid].best.loss_chg;
 p_tree->Stat(nid).base_weight = snode_[nid].weight;
 p_tree->Stat(nid).sum_hess = static_cast<float>(snode_[nid].stats.sum_hess);
@@ -255,9 +255,9 @@ class ColMaker: public TreeUpdater {
 {
 // setup statistics space for each tree node
 for (auto& i : stemp_) {
-i.resize(tree.param.num_nodes, ThreadEntry());
+i.resize(tree.NumNodes(), ThreadEntry());
 }
-snode_.resize(tree.param.num_nodes, NodeEntry());
+snode_.resize(tree.NumNodes(), NodeEntry());
 }
 const MetaInfo& info = fmat.Info();
 // setup position
@@ -72,7 +72,7 @@ class TreePruner : public TreeUpdater {
 void DoPrune(TrainParam const* param, RegTree* p_tree) {
 auto& tree = *p_tree;
 bst_node_t npruned = 0;
-for (int nid = 0; nid < tree.param.num_nodes; ++nid) {
+for (int nid = 0; nid < tree.NumNodes(); ++nid) {
 if (tree[nid].IsLeaf() && !tree[nid].IsDeleted()) {
 npruned = this->TryPruneLeaf(param, p_tree, nid, tree.GetDepth(nid), npruned);
 }
@@ -50,11 +50,11 @@ class TreeRefresher : public TreeUpdater {
 int tid = omp_get_thread_num();
 int num_nodes = 0;
 for (auto tree : trees) {
-num_nodes += tree->param.num_nodes;
+num_nodes += tree->NumNodes();
 }
 stemp[tid].resize(num_nodes, GradStats());
 std::fill(stemp[tid].begin(), stemp[tid].end(), GradStats());
-fvec_temp[tid].Init(trees[0]->param.num_feature);
+fvec_temp[tid].Init(trees[0]->NumFeatures());
 });
 }
 exc.Rethrow();
@@ -77,7 +77,7 @@ class TreeRefresher : public TreeUpdater {
 for (auto tree : trees) {
 AddStats(*tree, feats, gpair_h, info, ridx,
 dmlc::BeginPtr(stemp[tid]) + offset);
-offset += tree->param.num_nodes;
+offset += tree->NumNodes();
 }
 feats.Drop(inst);
 });
@@ -96,7 +96,7 @@ class TreeRefresher : public TreeUpdater {
 int offset = 0;
 for (auto tree : trees) {
 this->Refresh(param, dmlc::BeginPtr(stemp[0]) + offset, 0, tree);
-offset += tree->param.num_nodes;
+offset += tree->NumNodes();
 }
 }

@@ -40,8 +40,7 @@ TEST(GrowHistMaker, InteractionConstraint)
 ObjInfo task{ObjInfo::kRegression};
 {
 // With constraints
-RegTree tree;
-tree.param.num_feature = kCols;
+RegTree tree{1, kCols};

 std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_histmaker", &ctx, &task)};
 TrainParam param;
@@ -58,8 +57,7 @@ TEST(GrowHistMaker, InteractionConstraint)
 }
 {
 // Without constraints
-RegTree tree;
-tree.param.num_feature = kCols;
+RegTree tree{1u, kCols};

 std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_histmaker", &ctx, &task)};
 std::vector<HostDeviceVector<bst_node_t>> position(1);
@@ -76,7 +74,7 @@ TEST(GrowHistMaker, InteractionConstraint)
 }

 namespace {
-void TestColumnSplit(int32_t rows, int32_t cols, RegTree const& expected_tree) {
+void TestColumnSplit(int32_t rows, bst_feature_t cols, RegTree const& expected_tree) {
 auto p_dmat = GenerateDMatrix(rows, cols);
 auto p_gradients = GenerateGradients(rows);
 Context ctx;
@@ -87,8 +85,7 @@ void TestColumnSplit(int32_t rows, int32_t cols, RegTree const& expected_tree) {
 std::unique_ptr<DMatrix> sliced{
 p_dmat->SliceCol(collective::GetWorldSize(), collective::GetRank())};

-RegTree tree;
-tree.param.num_feature = cols;
+RegTree tree{1u, cols};
 TrainParam param;
 param.Init(Args{});
 updater->Update(&param, p_gradients.get(), sliced.get(), position, {&tree});
@@ -107,8 +104,7 @@ TEST(GrowHistMaker, ColumnSplit) {
 auto constexpr kRows = 32;
 auto constexpr kCols = 16;

-RegTree expected_tree;
-expected_tree.param.num_feature = kCols;
+RegTree expected_tree{1u, kCols};
 ObjInfo task{ObjInfo::kRegression};
 {
 auto p_dmat = GenerateDMatrix(kRows, kCols);
@@ -17,8 +17,8 @@ TEST(MultiTargetTree, JsonIO) {
 linalg::Vector<float> right_weight{{3.0f, 4.0f, 5.0f}, {3ul}, Context::kCpuId};
 tree.ExpandNode(RegTree::kRoot, /*split_idx=*/1, 0.5f, true, base_weight.HostView(),
 left_weight.HostView(), right_weight.HostView());
-ASSERT_EQ(tree.param.num_nodes, 3);
-ASSERT_EQ(tree.param.size_leaf_vector, 3);
+ASSERT_EQ(tree.NumNodes(), 3);
+ASSERT_EQ(tree.NumTargets(), 3);
 ASSERT_EQ(tree.GetMultiTargetTree()->Size(), 3);
 ASSERT_EQ(tree.Size(), 3);

@@ -26,20 +26,19 @@ TEST(MultiTargetTree, JsonIO) {
 tree.SaveModel(&jtree);

 auto check_jtree = [](Json jtree, RegTree const& tree) {
-ASSERT_EQ(get<String const>(jtree["tree_param"]["num_nodes"]),
-std::to_string(tree.param.num_nodes));
+ASSERT_EQ(get<String const>(jtree["tree_param"]["num_nodes"]), std::to_string(tree.NumNodes()));
 ASSERT_EQ(get<F32Array const>(jtree["base_weights"]).size(),
-tree.param.num_nodes * tree.param.size_leaf_vector);
-ASSERT_EQ(get<I32Array const>(jtree["parents"]).size(), tree.param.num_nodes);
-ASSERT_EQ(get<I32Array const>(jtree["left_children"]).size(), tree.param.num_nodes);
-ASSERT_EQ(get<I32Array const>(jtree["right_children"]).size(), tree.param.num_nodes);
+tree.NumNodes() * tree.NumTargets());
+ASSERT_EQ(get<I32Array const>(jtree["parents"]).size(), tree.NumNodes());
+ASSERT_EQ(get<I32Array const>(jtree["left_children"]).size(), tree.NumNodes());
+ASSERT_EQ(get<I32Array const>(jtree["right_children"]).size(), tree.NumNodes());
 };
 check_jtree(jtree, tree);

 RegTree loaded;
 loaded.LoadModel(jtree);
 ASSERT_TRUE(loaded.IsMultiTarget());
-ASSERT_EQ(loaded.param.num_nodes, 3);
+ASSERT_EQ(loaded.NumNodes(), 3);

 Json jtree1{Object{}};
 loaded.SaveModel(&jtree1);
@@ -32,8 +32,7 @@ TEST(Updater, Prune) {
 auto ctx = CreateEmptyGenericParam(GPUIDX);

 // prepare tree
-RegTree tree = RegTree();
-tree.param.UpdateAllowUnknown(cfg);
+RegTree tree = RegTree{1u, kCols};
 std::vector<RegTree*> trees {&tree};
 // prepare pruner
 TrainParam param;
@@ -28,9 +28,8 @@ TEST(Updater, Refresh) {
 {"num_feature", std::to_string(kCols)},
 {"reg_lambda", "1"}};

-RegTree tree = RegTree();
+RegTree tree = RegTree{1u, kCols};
 auto ctx = CreateEmptyGenericParam(GPUIDX);
-tree.param.UpdateAllowUnknown(cfg);
 std::vector<RegTree*> trees{&tree};

 ObjInfo task{ObjInfo::kRegression};
@@ -11,9 +11,8 @@
 namespace xgboost {
 TEST(Tree, ModelShape) {
 bst_feature_t n_features = std::numeric_limits<uint32_t>::max();
-RegTree tree;
-tree.param.UpdateAllowUnknown(Args{{"num_feature", std::to_string(n_features)}});
-ASSERT_EQ(tree.param.num_feature, n_features);
+RegTree tree{1u, n_features};
+ASSERT_EQ(tree.NumFeatures(), n_features);

 dmlc::TemporaryDirectory tempdir;
 const std::string tmp_file = tempdir.path + "/tree.model";
@@ -27,7 +26,7 @@ TEST(Tree, ModelShape) {
 RegTree new_tree;
 std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(tmp_file.c_str(), "r"));
 new_tree.Load(fi.get());
-ASSERT_EQ(new_tree.param.num_feature, n_features);
+ASSERT_EQ(new_tree.NumFeatures(), n_features);
 }
 {
 // json
@@ -39,7 +38,7 @@ TEST(Tree, ModelShape) {

 auto j_loaded = Json::Load(StringView{dumped.data(), dumped.size()});
 new_tree.LoadModel(j_loaded);
-ASSERT_EQ(new_tree.param.num_feature, n_features);
+ASSERT_EQ(new_tree.NumFeatures(), n_features);
 }
 {
 // ubjson
@@ -51,7 +50,7 @@ TEST(Tree, ModelShape) {

 auto j_loaded = Json::Load(StringView{dumped.data(), dumped.size()}, std::ios::binary);
 new_tree.LoadModel(j_loaded);
-ASSERT_EQ(new_tree.param.num_feature, n_features);
+ASSERT_EQ(new_tree.NumFeatures(), n_features);
 }
 }

@@ -488,8 +487,7 @@ TEST(Tree, JsonIO) {

 RegTree loaded_tree;
 loaded_tree.LoadModel(j_tree);
-ASSERT_EQ(loaded_tree.param.num_nodes, 3);
-
+ASSERT_EQ(loaded_tree.NumNodes(), 3);
 ASSERT_TRUE(loaded_tree == tree);

 auto left = tree[0].LeftChild();
@@ -37,8 +37,7 @@ class UpdaterTreeStatTest : public ::testing::Test {
 : CreateEmptyGenericParam(Context::kCpuId));
 auto up = std::unique_ptr<TreeUpdater>{TreeUpdater::Create(updater, &ctx, &task)};
 up->Configure(Args{});
-RegTree tree;
-tree.param.num_feature = kCols;
+RegTree tree{1u, kCols};
 std::vector<HostDeviceVector<bst_node_t>> position(1);
 up->Update(&param, &gpairs_, p_dmat_.get(), position, {&tree});

@@ -95,16 +94,14 @@ class UpdaterEtaTest : public ::testing::Test {
 param1.Init(Args{{"eta", "1.0"}});

 for (size_t iter = 0; iter < 4; ++iter) {
-RegTree tree_0;
+RegTree tree_0{1u, kCols};
 {
-tree_0.param.num_feature = kCols;
 std::vector<HostDeviceVector<bst_node_t>> position(1);
 up_0->Update(&param0, &gpairs_, p_dmat_.get(), position, {&tree_0});
 }

-RegTree tree_1;
+RegTree tree_1{1u, kCols};
 {
-tree_1.param.num_feature = kCols;
 std::vector<HostDeviceVector<bst_node_t>> position(1);
 up_1->Update(&param1, &gpairs_, p_dmat_.get(), position, {&tree_1});
 }