[Breaking] Remove num roots. (#5059)

This commit is contained in:
Jiaming Yuan
2019-12-05 21:58:43 +08:00
committed by GitHub
parent f3d8536702
commit 64af1ecf86
23 changed files with 87 additions and 189 deletions

View File

@@ -609,9 +609,7 @@ std::string RegTree::DumpModel(const FeatureMap& fmap,
std::unique_ptr<TreeGenerator> builder {
TreeGenerator::Create(format, fmap, with_stats)
};
for (int32_t i = 0; i < param.num_roots; ++i) {
builder->BuildTree(*this);
}
builder->BuildTree(*this);
std::string result = builder->Str();
return result;
@@ -628,8 +626,10 @@ void RegTree::LoadModel(dmlc::Stream* fi) {
sizeof(RTreeNodeStat) * stats_.size());
// chg deleted nodes
deleted_nodes_.resize(0);
for (int i = param.num_roots; i < param.num_nodes; ++i) {
if (nodes_[i].IsDeleted()) deleted_nodes_.push_back(i);
for (int i = 1; i < param.num_nodes; ++i) {
if (nodes_[i].IsDeleted()) {
deleted_nodes_.push_back(i);
}
}
CHECK_EQ(static_cast<int>(deleted_nodes_.size()), param.num_deleted);
}
@@ -652,9 +652,7 @@ void RegTree::FillNodeMeanValues() {
return;
}
this->node_mean_values_.resize(num_nodes);
for (int root_id = 0; root_id < param.num_roots; ++root_id) {
this->FillNodeMeanValue(root_id);
}
this->FillNodeMeanValue(0);
}
bst_float RegTree::FillNodeMeanValue(int nid) {
@@ -672,28 +670,27 @@ bst_float RegTree::FillNodeMeanValue(int nid) {
}
// NOTE(review): this text is a rendered commit diff whose +/- markers were stripped, so
// removed and added lines sit back to back. The lines that mention `root_id` / `pid`
// appear to be the pre-commit version and the lines that use node 0 / `nid` appear to be
// their replacements (the commit is titled "Remove num roots") -- confirm against the
// real diff before treating any single line as current code.
//
// Approximate per-feature contributions for one input row.
//   feat         - feature vector; queried through Size(), Fvalue() and IsMissing().
//   out_contribs - accumulator that is indexed by split feature and, for the bias term,
//                  at position feat.Size(); values are added to, never overwritten.
// The mean value of the starting node is added to the bias slot. The tree is then
// walked towards a leaf via GetNext(), and every change in node mean value along the
// way is credited to the feature that was split on.
void RegTree::CalculateContributionsApprox(const RegTree::FVec &feat,
unsigned root_id,
bst_float *out_contribs) const {
// the per-node mean values must be filled before this is called
CHECK_GT(this->node_mean_values_.size(), 0U);
// this follows the idea of http://blog.datadive.net/interpreting-random-forests/
unsigned split_index = 0;
auto pid = static_cast<int>(root_id);
// update bias value
bst_float node_value = this->node_mean_values_[pid];
bst_float node_value = this->node_mean_values_[0];
out_contribs[feat.Size()] += node_value;
if ((*this)[pid].IsLeaf()) {
if ((*this)[0].IsLeaf()) {
// nothing to do anymore: the starting node is already a leaf, so only the bias is set
return;
}
while (!(*this)[pid].IsLeaf()) {
split_index = (*this)[pid].SplitIndex();
pid = this->GetNext(pid, feat.Fvalue(split_index), feat.IsMissing(split_index));
bst_float new_value = this->node_mean_values_[pid];
// the variant below performs the same walk, starting from node 0 and tracking `nid`
bst_node_t nid = 0;
while (!(*this)[nid].IsLeaf()) {
split_index = (*this)[nid].SplitIndex();
nid = this->GetNext(nid, feat.Fvalue(split_index), feat.IsMissing(split_index));
bst_float new_value = this->node_mean_values_[nid];
// update feature weight: credit the change in mean value to the feature just split on
out_contribs[split_index] += new_value - node_value;
node_value = new_value;
}
bst_float leaf_value = (*this)[pid].LeafValue();
bst_float leaf_value = (*this)[nid].LeafValue();
// update leaf feature weight: the last split feature also receives the difference
// between the leaf value and the last node mean value
out_contribs[split_index] += leaf_value - node_value;
}
@@ -868,21 +865,20 @@ void RegTree::TreeShap(const RegTree::FVec &feat, bst_float *phi,
}
// NOTE(review): this text is a rendered commit diff whose +/- markers were stripped, so
// removed and added lines sit back to back. The lines that mention `root_id`, `new` and
// `delete[]` appear to be the pre-commit version; the lines that use node 0 and
// std::vector appear to be their replacements -- confirm against the real diff.
//
// Adds this tree's contributions for one input row into out_contribs by delegating to
// TreeShap().
//   feat              - feature vector; only Size() is queried here, the rest is done
//                       inside TreeShap().
//   out_contribs      - accumulator; the bias term lives at position feat.Size().
//   condition         - forwarded to TreeShap(); the bias term is only added when it
//                       is 0.
//   condition_feature - forwarded unchanged to TreeShap().
void RegTree::CalculateContributions(const RegTree::FVec &feat,
unsigned root_id, bst_float *out_contribs,
bst_float *out_contribs,
int condition,
unsigned condition_feature) const {
// find the expected value of the tree's predictions
if (condition == 0) {
bst_float node_value = this->node_mean_values_[static_cast<int>(root_id)];
bst_float node_value = this->node_mean_values_[0];
out_contribs[feat.Size()] += node_value;
}
// Preallocate space for the unique path data: maxd * (maxd + 1) / 2 elements, where
// maxd is the tree depth reported by MaxDepth() plus 2
const int maxd = this->MaxDepth(root_id) + 2;
auto *unique_path_data = new PathElement[(maxd * (maxd + 1)) / 2];
const int maxd = this->MaxDepth(0) + 2;
// the std::vector form owns the buffer, so it needs no matching delete[]
std::vector<PathElement> unique_path_data((maxd * (maxd + 1)) / 2);
TreeShap(feat, out_contribs, root_id, 0, unique_path_data,
TreeShap(feat, out_contribs, 0, 0, unique_path_data.data(),
1, 1, -1, condition, condition_feature, 1);
// manual release that pairs with the `new PathElement[...]` line above
delete[] unique_path_data;
}
} // namespace xgboost