[Breaking] Remove num roots. (#5059)
This commit is contained in:
@@ -609,9 +609,7 @@ std::string RegTree::DumpModel(const FeatureMap& fmap,
|
||||
std::unique_ptr<TreeGenerator> builder {
|
||||
TreeGenerator::Create(format, fmap, with_stats)
|
||||
};
|
||||
for (int32_t i = 0; i < param.num_roots; ++i) {
|
||||
builder->BuildTree(*this);
|
||||
}
|
||||
builder->BuildTree(*this);
|
||||
|
||||
std::string result = builder->Str();
|
||||
return result;
|
||||
@@ -628,8 +626,10 @@ void RegTree::LoadModel(dmlc::Stream* fi) {
|
||||
sizeof(RTreeNodeStat) * stats_.size());
|
||||
// chg deleted nodes
|
||||
deleted_nodes_.resize(0);
|
||||
for (int i = param.num_roots; i < param.num_nodes; ++i) {
|
||||
if (nodes_[i].IsDeleted()) deleted_nodes_.push_back(i);
|
||||
for (int i = 1; i < param.num_nodes; ++i) {
|
||||
if (nodes_[i].IsDeleted()) {
|
||||
deleted_nodes_.push_back(i);
|
||||
}
|
||||
}
|
||||
CHECK_EQ(static_cast<int>(deleted_nodes_.size()), param.num_deleted);
|
||||
}
|
||||
@@ -652,9 +652,7 @@ void RegTree::FillNodeMeanValues() {
|
||||
return;
|
||||
}
|
||||
this->node_mean_values_.resize(num_nodes);
|
||||
for (int root_id = 0; root_id < param.num_roots; ++root_id) {
|
||||
this->FillNodeMeanValue(root_id);
|
||||
}
|
||||
this->FillNodeMeanValue(0);
|
||||
}
|
||||
|
||||
bst_float RegTree::FillNodeMeanValue(int nid) {
|
||||
@@ -672,28 +670,27 @@ bst_float RegTree::FillNodeMeanValue(int nid) {
|
||||
}
|
||||
|
||||
void RegTree::CalculateContributionsApprox(const RegTree::FVec &feat,
|
||||
unsigned root_id,
|
||||
bst_float *out_contribs) const {
|
||||
CHECK_GT(this->node_mean_values_.size(), 0U);
|
||||
// this follows the idea of http://blog.datadive.net/interpreting-random-forests/
|
||||
unsigned split_index = 0;
|
||||
auto pid = static_cast<int>(root_id);
|
||||
// update bias value
|
||||
bst_float node_value = this->node_mean_values_[pid];
|
||||
bst_float node_value = this->node_mean_values_[0];
|
||||
out_contribs[feat.Size()] += node_value;
|
||||
if ((*this)[pid].IsLeaf()) {
|
||||
if ((*this)[0].IsLeaf()) {
|
||||
// nothing to do anymore
|
||||
return;
|
||||
}
|
||||
while (!(*this)[pid].IsLeaf()) {
|
||||
split_index = (*this)[pid].SplitIndex();
|
||||
pid = this->GetNext(pid, feat.Fvalue(split_index), feat.IsMissing(split_index));
|
||||
bst_float new_value = this->node_mean_values_[pid];
|
||||
bst_node_t nid = 0;
|
||||
while (!(*this)[nid].IsLeaf()) {
|
||||
split_index = (*this)[nid].SplitIndex();
|
||||
nid = this->GetNext(nid, feat.Fvalue(split_index), feat.IsMissing(split_index));
|
||||
bst_float new_value = this->node_mean_values_[nid];
|
||||
// update feature weight
|
||||
out_contribs[split_index] += new_value - node_value;
|
||||
node_value = new_value;
|
||||
}
|
||||
bst_float leaf_value = (*this)[pid].LeafValue();
|
||||
bst_float leaf_value = (*this)[nid].LeafValue();
|
||||
// update leaf feature weight
|
||||
out_contribs[split_index] += leaf_value - node_value;
|
||||
}
|
||||
@@ -868,21 +865,20 @@ void RegTree::TreeShap(const RegTree::FVec &feat, bst_float *phi,
|
||||
}
|
||||
|
||||
void RegTree::CalculateContributions(const RegTree::FVec &feat,
|
||||
unsigned root_id, bst_float *out_contribs,
|
||||
bst_float *out_contribs,
|
||||
int condition,
|
||||
unsigned condition_feature) const {
|
||||
// find the expected value of the tree's predictions
|
||||
if (condition == 0) {
|
||||
bst_float node_value = this->node_mean_values_[static_cast<int>(root_id)];
|
||||
bst_float node_value = this->node_mean_values_[0];
|
||||
out_contribs[feat.Size()] += node_value;
|
||||
}
|
||||
|
||||
// Preallocate space for the unique path data
|
||||
const int maxd = this->MaxDepth(root_id) + 2;
|
||||
auto *unique_path_data = new PathElement[(maxd * (maxd + 1)) / 2];
|
||||
const int maxd = this->MaxDepth(0) + 2;
|
||||
std::vector<PathElement> unique_path_data((maxd * (maxd + 1)) / 2);
|
||||
|
||||
TreeShap(feat, out_contribs, root_id, 0, unique_path_data,
|
||||
TreeShap(feat, out_contribs, 0, 0, unique_path_data.data(),
|
||||
1, 1, -1, condition, condition_feature, 1);
|
||||
delete[] unique_path_data;
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
@@ -128,21 +128,10 @@ class BaseMaker: public TreeUpdater {
|
||||
inline void InitData(const std::vector<GradientPair> &gpair,
|
||||
const DMatrix &fmat,
|
||||
const RegTree &tree) {
|
||||
CHECK_EQ(tree.param.num_nodes, tree.param.num_roots)
|
||||
<< "TreeMaker: can only grow new tree";
|
||||
const std::vector<unsigned> &root_index = fmat.Info().root_index_;
|
||||
{
|
||||
// setup position
|
||||
position_.resize(gpair.size());
|
||||
if (root_index.size() == 0) {
|
||||
std::fill(position_.begin(), position_.end(), 0);
|
||||
} else {
|
||||
for (size_t i = 0; i < position_.size(); ++i) {
|
||||
position_[i] = root_index[i];
|
||||
CHECK_LT(root_index[i], (unsigned)tree.param.num_roots)
|
||||
<< "root index exceed setting";
|
||||
}
|
||||
}
|
||||
std::fill(position_.begin(), position_.end(), 0);
|
||||
// mark delete for the deleted datas
|
||||
for (size_t i = 0; i < position_.size(); ++i) {
|
||||
if (gpair[i].GetHess() < 0.0f) position_[i] = ~position_[i];
|
||||
@@ -160,9 +149,7 @@ class BaseMaker: public TreeUpdater {
|
||||
{
|
||||
// expand query
|
||||
qexpand_.reserve(256); qexpand_.clear();
|
||||
for (int i = 0; i < tree.param.num_roots; ++i) {
|
||||
qexpand_.push_back(i);
|
||||
}
|
||||
qexpand_.push_back(0);
|
||||
this->UpdateNode2WorkIndex(tree);
|
||||
}
|
||||
this->interaction_constraints_.Configure(param_, fmat.Info().num_col_);
|
||||
|
||||
@@ -146,21 +146,11 @@ class ColMaker: public TreeUpdater {
|
||||
inline void InitData(const std::vector<GradientPair>& gpair,
|
||||
const DMatrix& fmat,
|
||||
const RegTree& tree) {
|
||||
CHECK_EQ(tree.param.num_nodes, tree.param.num_roots)
|
||||
<< "ColMaker: can only grow new tree";
|
||||
const std::vector<unsigned>& root_index = fmat.Info().root_index_;
|
||||
{
|
||||
// setup position
|
||||
position_.resize(gpair.size());
|
||||
CHECK_EQ(fmat.Info().num_row_, position_.size());
|
||||
if (root_index.size() == 0) {
|
||||
std::fill(position_.begin(), position_.end(), 0);
|
||||
} else {
|
||||
for (size_t ridx = 0; ridx < position_.size(); ++ridx) {
|
||||
position_[ridx] = root_index[ridx];
|
||||
CHECK_LT(root_index[ridx], (unsigned)tree.param.num_roots);
|
||||
}
|
||||
}
|
||||
std::fill(position_.begin(), position_.end(), 0);
|
||||
// mark delete for the deleted datas
|
||||
for (size_t ridx = 0; ridx < position_.size(); ++ridx) {
|
||||
if (gpair[ridx].GetHess() < 0.0f) position_[ridx] = ~position_[ridx];
|
||||
@@ -192,9 +182,7 @@ class ColMaker: public TreeUpdater {
|
||||
{
|
||||
// expand query
|
||||
qexpand_.reserve(256); qexpand_.clear();
|
||||
for (int i = 0; i < tree.param.num_roots; ++i) {
|
||||
qexpand_.push_back(i);
|
||||
}
|
||||
qexpand_.push_back(0);
|
||||
}
|
||||
}
|
||||
/*!
|
||||
|
||||
@@ -119,10 +119,7 @@ class HistMaker: public BaseMaker {
|
||||
this->InitData(gpair, *p_fmat, *p_tree);
|
||||
this->InitWorkSet(p_fmat, *p_tree, &selected_features_);
|
||||
// mark root node as fresh.
|
||||
for (int i = 0; i < p_tree->param.num_roots; ++i) {
|
||||
(*p_tree)[i].SetLeaf(0.0f, 0);
|
||||
}
|
||||
CHECK_EQ(p_tree->param.num_roots, 1) << "Support for num roots is removed.";
|
||||
(*p_tree)[0].SetLeaf(0.0f, 0);
|
||||
|
||||
for (int depth = 0; depth < param_.max_depth; ++depth) {
|
||||
// reset and propose candidate split
|
||||
|
||||
@@ -75,7 +75,7 @@ class TreePruner: public TreeUpdater {
|
||||
npruned = this->TryPruneLeaf(tree, nid, tree.GetDepth(nid), npruned);
|
||||
}
|
||||
}
|
||||
LOG(INFO) << "tree pruning end, " << tree.param.num_roots << " roots, "
|
||||
LOG(INFO) << "tree pruning end, "
|
||||
<< tree.NumExtraNodes() << " extra nodes, " << npruned
|
||||
<< " pruned nodes, max_depth=" << tree.MaxDepth();
|
||||
}
|
||||
|
||||
@@ -255,18 +255,15 @@ void QuantileHistMaker::Builder::ExpandWithLossGuide(
|
||||
unsigned timestamp = 0;
|
||||
int num_leaves = 0;
|
||||
|
||||
for (int nid = 0; nid < p_tree->param.num_roots; ++nid) {
|
||||
hist_.AddHistRow(nid);
|
||||
BuildHist(gpair_h, row_set_collection_[nid], gmat, gmatb, hist_[nid], true);
|
||||
hist_.AddHistRow(0);
|
||||
BuildHist(gpair_h, row_set_collection_[0], gmat, gmatb, hist_[0], true);
|
||||
|
||||
this->InitNewNode(nid, gmat, gpair_h, *p_fmat, *p_tree);
|
||||
this->InitNewNode(0, gmat, gpair_h, *p_fmat, *p_tree);
|
||||
|
||||
this->EvaluateSplit(nid, gmat, hist_, *p_fmat, *p_tree);
|
||||
qexpand_loss_guided_->push(ExpandEntry(nid, p_tree->GetDepth(nid),
|
||||
snode_[nid].best.loss_chg,
|
||||
timestamp++));
|
||||
++num_leaves;
|
||||
}
|
||||
this->EvaluateSplit(0, gmat, hist_, *p_fmat, *p_tree);
|
||||
qexpand_loss_guided_->push(ExpandEntry(0, p_tree->GetDepth(0),
|
||||
snode_[0].best.loss_chg, timestamp++));
|
||||
++num_leaves;
|
||||
|
||||
while (!qexpand_loss_guided_->empty()) {
|
||||
const ExpandEntry candidate = qexpand_loss_guided_->top();
|
||||
@@ -397,8 +394,6 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat,
|
||||
const std::vector<GradientPair>& gpair,
|
||||
const DMatrix& fmat,
|
||||
const RegTree& tree) {
|
||||
CHECK_EQ(tree.param.num_nodes, tree.param.num_roots)
|
||||
<< "ColMakerHist: can only grow new tree";
|
||||
CHECK((param_.max_depth > 0 || param_.max_leaves > 0))
|
||||
<< "max_depth or max_leaves cannot be both 0 (unlimited); "
|
||||
<< "at least one should be a positive quantity.";
|
||||
@@ -425,7 +420,6 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat,
|
||||
}
|
||||
hist_builder_.Init(this->nthread_, nbins);
|
||||
|
||||
CHECK_EQ(info.root_index_.size(), 0U);
|
||||
std::vector<size_t>& row_indices = row_set_collection_.row_indices_;
|
||||
row_indices.resize(info.num_row_);
|
||||
auto* p_row_indices = row_indices.data();
|
||||
|
||||
@@ -90,9 +90,7 @@ class TreeRefresher: public TreeUpdater {
|
||||
param_.learning_rate = lr / trees.size();
|
||||
int offset = 0;
|
||||
for (auto tree : trees) {
|
||||
for (int rid = 0; rid < tree->param.num_roots; ++rid) {
|
||||
this->Refresh(dmlc::BeginPtr(stemp[0]) + offset, rid, tree);
|
||||
}
|
||||
this->Refresh(dmlc::BeginPtr(stemp[0]) + offset, 0, tree);
|
||||
offset += tree->param.num_nodes;
|
||||
}
|
||||
// set learning rate back
|
||||
@@ -107,7 +105,7 @@ class TreeRefresher: public TreeUpdater {
|
||||
const bst_uint ridx,
|
||||
GradStats *gstats) {
|
||||
// start from groups that belongs to current data
|
||||
auto pid = static_cast<int>(info.GetRoot(ridx));
|
||||
auto pid = 0;
|
||||
gstats[pid].Add(gpair[ridx]);
|
||||
// tranverse tree
|
||||
while (!tree[pid].IsLeaf()) {
|
||||
|
||||
Reference in New Issue
Block a user