[Breaking] Remove num roots. (#5059)

This commit is contained in:
Jiaming Yuan
2019-12-05 21:58:43 +08:00
committed by GitHub
parent f3d8536702
commit 64af1ecf86
23 changed files with 87 additions and 189 deletions

View File

@@ -609,9 +609,7 @@ std::string RegTree::DumpModel(const FeatureMap& fmap,
std::unique_ptr<TreeGenerator> builder {
TreeGenerator::Create(format, fmap, with_stats)
};
for (int32_t i = 0; i < param.num_roots; ++i) {
builder->BuildTree(*this);
}
builder->BuildTree(*this);
std::string result = builder->Str();
return result;
@@ -628,8 +626,10 @@ void RegTree::LoadModel(dmlc::Stream* fi) {
sizeof(RTreeNodeStat) * stats_.size());
// chg deleted nodes
deleted_nodes_.resize(0);
for (int i = param.num_roots; i < param.num_nodes; ++i) {
if (nodes_[i].IsDeleted()) deleted_nodes_.push_back(i);
for (int i = 1; i < param.num_nodes; ++i) {
if (nodes_[i].IsDeleted()) {
deleted_nodes_.push_back(i);
}
}
CHECK_EQ(static_cast<int>(deleted_nodes_.size()), param.num_deleted);
}
@@ -652,9 +652,7 @@ void RegTree::FillNodeMeanValues() {
return;
}
this->node_mean_values_.resize(num_nodes);
for (int root_id = 0; root_id < param.num_roots; ++root_id) {
this->FillNodeMeanValue(root_id);
}
this->FillNodeMeanValue(0);
}
bst_float RegTree::FillNodeMeanValue(int nid) {
@@ -672,28 +670,27 @@ bst_float RegTree::FillNodeMeanValue(int nid) {
}
void RegTree::CalculateContributionsApprox(const RegTree::FVec &feat,
unsigned root_id,
bst_float *out_contribs) const {
CHECK_GT(this->node_mean_values_.size(), 0U);
// this follows the idea of http://blog.datadive.net/interpreting-random-forests/
unsigned split_index = 0;
auto pid = static_cast<int>(root_id);
// update bias value
bst_float node_value = this->node_mean_values_[pid];
bst_float node_value = this->node_mean_values_[0];
out_contribs[feat.Size()] += node_value;
if ((*this)[pid].IsLeaf()) {
if ((*this)[0].IsLeaf()) {
// nothing to do anymore
return;
}
while (!(*this)[pid].IsLeaf()) {
split_index = (*this)[pid].SplitIndex();
pid = this->GetNext(pid, feat.Fvalue(split_index), feat.IsMissing(split_index));
bst_float new_value = this->node_mean_values_[pid];
bst_node_t nid = 0;
while (!(*this)[nid].IsLeaf()) {
split_index = (*this)[nid].SplitIndex();
nid = this->GetNext(nid, feat.Fvalue(split_index), feat.IsMissing(split_index));
bst_float new_value = this->node_mean_values_[nid];
// update feature weight
out_contribs[split_index] += new_value - node_value;
node_value = new_value;
}
bst_float leaf_value = (*this)[pid].LeafValue();
bst_float leaf_value = (*this)[nid].LeafValue();
// update leaf feature weight
out_contribs[split_index] += leaf_value - node_value;
}
@@ -868,21 +865,20 @@ void RegTree::TreeShap(const RegTree::FVec &feat, bst_float *phi,
}
void RegTree::CalculateContributions(const RegTree::FVec &feat,
unsigned root_id, bst_float *out_contribs,
bst_float *out_contribs,
int condition,
unsigned condition_feature) const {
// find the expected value of the tree's predictions
if (condition == 0) {
bst_float node_value = this->node_mean_values_[static_cast<int>(root_id)];
bst_float node_value = this->node_mean_values_[0];
out_contribs[feat.Size()] += node_value;
}
// Preallocate space for the unique path data
const int maxd = this->MaxDepth(root_id) + 2;
auto *unique_path_data = new PathElement[(maxd * (maxd + 1)) / 2];
const int maxd = this->MaxDepth(0) + 2;
std::vector<PathElement> unique_path_data((maxd * (maxd + 1)) / 2);
TreeShap(feat, out_contribs, root_id, 0, unique_path_data,
TreeShap(feat, out_contribs, 0, 0, unique_path_data.data(),
1, 1, -1, condition, condition_feature, 1);
delete[] unique_path_data;
}
} // namespace xgboost

View File

@@ -128,21 +128,10 @@ class BaseMaker: public TreeUpdater {
inline void InitData(const std::vector<GradientPair> &gpair,
const DMatrix &fmat,
const RegTree &tree) {
CHECK_EQ(tree.param.num_nodes, tree.param.num_roots)
<< "TreeMaker: can only grow new tree";
const std::vector<unsigned> &root_index = fmat.Info().root_index_;
{
// setup position
position_.resize(gpair.size());
if (root_index.size() == 0) {
std::fill(position_.begin(), position_.end(), 0);
} else {
for (size_t i = 0; i < position_.size(); ++i) {
position_[i] = root_index[i];
CHECK_LT(root_index[i], (unsigned)tree.param.num_roots)
<< "root index exceed setting";
}
}
std::fill(position_.begin(), position_.end(), 0);
// mark delete for the deleted data
for (size_t i = 0; i < position_.size(); ++i) {
if (gpair[i].GetHess() < 0.0f) position_[i] = ~position_[i];
@@ -160,9 +149,7 @@ class BaseMaker: public TreeUpdater {
{
// expand query
qexpand_.reserve(256); qexpand_.clear();
for (int i = 0; i < tree.param.num_roots; ++i) {
qexpand_.push_back(i);
}
qexpand_.push_back(0);
this->UpdateNode2WorkIndex(tree);
}
this->interaction_constraints_.Configure(param_, fmat.Info().num_col_);

View File

@@ -146,21 +146,11 @@ class ColMaker: public TreeUpdater {
inline void InitData(const std::vector<GradientPair>& gpair,
const DMatrix& fmat,
const RegTree& tree) {
CHECK_EQ(tree.param.num_nodes, tree.param.num_roots)
<< "ColMaker: can only grow new tree";
const std::vector<unsigned>& root_index = fmat.Info().root_index_;
{
// setup position
position_.resize(gpair.size());
CHECK_EQ(fmat.Info().num_row_, position_.size());
if (root_index.size() == 0) {
std::fill(position_.begin(), position_.end(), 0);
} else {
for (size_t ridx = 0; ridx < position_.size(); ++ridx) {
position_[ridx] = root_index[ridx];
CHECK_LT(root_index[ridx], (unsigned)tree.param.num_roots);
}
}
std::fill(position_.begin(), position_.end(), 0);
// mark delete for the deleted data
for (size_t ridx = 0; ridx < position_.size(); ++ridx) {
if (gpair[ridx].GetHess() < 0.0f) position_[ridx] = ~position_[ridx];
@@ -192,9 +182,7 @@ class ColMaker: public TreeUpdater {
{
// expand query
qexpand_.reserve(256); qexpand_.clear();
for (int i = 0; i < tree.param.num_roots; ++i) {
qexpand_.push_back(i);
}
qexpand_.push_back(0);
}
}
/*!

View File

@@ -119,10 +119,7 @@ class HistMaker: public BaseMaker {
this->InitData(gpair, *p_fmat, *p_tree);
this->InitWorkSet(p_fmat, *p_tree, &selected_features_);
// mark root node as fresh.
for (int i = 0; i < p_tree->param.num_roots; ++i) {
(*p_tree)[i].SetLeaf(0.0f, 0);
}
CHECK_EQ(p_tree->param.num_roots, 1) << "Support for num roots is removed.";
(*p_tree)[0].SetLeaf(0.0f, 0);
for (int depth = 0; depth < param_.max_depth; ++depth) {
// reset and propose candidate split

View File

@@ -75,7 +75,7 @@ class TreePruner: public TreeUpdater {
npruned = this->TryPruneLeaf(tree, nid, tree.GetDepth(nid), npruned);
}
}
LOG(INFO) << "tree pruning end, " << tree.param.num_roots << " roots, "
LOG(INFO) << "tree pruning end, "
<< tree.NumExtraNodes() << " extra nodes, " << npruned
<< " pruned nodes, max_depth=" << tree.MaxDepth();
}

View File

@@ -255,18 +255,15 @@ void QuantileHistMaker::Builder::ExpandWithLossGuide(
unsigned timestamp = 0;
int num_leaves = 0;
for (int nid = 0; nid < p_tree->param.num_roots; ++nid) {
hist_.AddHistRow(nid);
BuildHist(gpair_h, row_set_collection_[nid], gmat, gmatb, hist_[nid], true);
hist_.AddHistRow(0);
BuildHist(gpair_h, row_set_collection_[0], gmat, gmatb, hist_[0], true);
this->InitNewNode(nid, gmat, gpair_h, *p_fmat, *p_tree);
this->InitNewNode(0, gmat, gpair_h, *p_fmat, *p_tree);
this->EvaluateSplit(nid, gmat, hist_, *p_fmat, *p_tree);
qexpand_loss_guided_->push(ExpandEntry(nid, p_tree->GetDepth(nid),
snode_[nid].best.loss_chg,
timestamp++));
++num_leaves;
}
this->EvaluateSplit(0, gmat, hist_, *p_fmat, *p_tree);
qexpand_loss_guided_->push(ExpandEntry(0, p_tree->GetDepth(0),
snode_[0].best.loss_chg, timestamp++));
++num_leaves;
while (!qexpand_loss_guided_->empty()) {
const ExpandEntry candidate = qexpand_loss_guided_->top();
@@ -397,8 +394,6 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat,
const std::vector<GradientPair>& gpair,
const DMatrix& fmat,
const RegTree& tree) {
CHECK_EQ(tree.param.num_nodes, tree.param.num_roots)
<< "ColMakerHist: can only grow new tree";
CHECK((param_.max_depth > 0 || param_.max_leaves > 0))
<< "max_depth or max_leaves cannot be both 0 (unlimited); "
<< "at least one should be a positive quantity.";
@@ -425,7 +420,6 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat,
}
hist_builder_.Init(this->nthread_, nbins);
CHECK_EQ(info.root_index_.size(), 0U);
std::vector<size_t>& row_indices = row_set_collection_.row_indices_;
row_indices.resize(info.num_row_);
auto* p_row_indices = row_indices.data();

View File

@@ -90,9 +90,7 @@ class TreeRefresher: public TreeUpdater {
param_.learning_rate = lr / trees.size();
int offset = 0;
for (auto tree : trees) {
for (int rid = 0; rid < tree->param.num_roots; ++rid) {
this->Refresh(dmlc::BeginPtr(stemp[0]) + offset, rid, tree);
}
this->Refresh(dmlc::BeginPtr(stemp[0]) + offset, 0, tree);
offset += tree->param.num_nodes;
}
// set learning rate back
@@ -107,7 +105,7 @@ class TreeRefresher: public TreeUpdater {
const bst_uint ridx,
GradStats *gstats) {
// start from groups that belongs to current data
auto pid = static_cast<int>(info.GetRoot(ridx));
auto pid = 0;
gstats[pid].Add(gpair[ridx]);
// traverse tree
while (!tree[pid].IsLeaf()) {