[Breaking] Remove num roots. (#5059)
This commit is contained in:
@@ -609,9 +609,7 @@ std::string RegTree::DumpModel(const FeatureMap& fmap,
|
||||
std::unique_ptr<TreeGenerator> builder {
|
||||
TreeGenerator::Create(format, fmap, with_stats)
|
||||
};
|
||||
for (int32_t i = 0; i < param.num_roots; ++i) {
|
||||
builder->BuildTree(*this);
|
||||
}
|
||||
builder->BuildTree(*this);
|
||||
|
||||
std::string result = builder->Str();
|
||||
return result;
|
||||
@@ -628,8 +626,10 @@ void RegTree::LoadModel(dmlc::Stream* fi) {
|
||||
sizeof(RTreeNodeStat) * stats_.size());
|
||||
// chg deleted nodes
|
||||
deleted_nodes_.resize(0);
|
||||
for (int i = param.num_roots; i < param.num_nodes; ++i) {
|
||||
if (nodes_[i].IsDeleted()) deleted_nodes_.push_back(i);
|
||||
for (int i = 1; i < param.num_nodes; ++i) {
|
||||
if (nodes_[i].IsDeleted()) {
|
||||
deleted_nodes_.push_back(i);
|
||||
}
|
||||
}
|
||||
CHECK_EQ(static_cast<int>(deleted_nodes_.size()), param.num_deleted);
|
||||
}
|
||||
@@ -652,9 +652,7 @@ void RegTree::FillNodeMeanValues() {
|
||||
return;
|
||||
}
|
||||
this->node_mean_values_.resize(num_nodes);
|
||||
for (int root_id = 0; root_id < param.num_roots; ++root_id) {
|
||||
this->FillNodeMeanValue(root_id);
|
||||
}
|
||||
this->FillNodeMeanValue(0);
|
||||
}
|
||||
|
||||
bst_float RegTree::FillNodeMeanValue(int nid) {
|
||||
@@ -672,28 +670,27 @@ bst_float RegTree::FillNodeMeanValue(int nid) {
|
||||
}
|
||||
|
||||
void RegTree::CalculateContributionsApprox(const RegTree::FVec &feat,
|
||||
unsigned root_id,
|
||||
bst_float *out_contribs) const {
|
||||
CHECK_GT(this->node_mean_values_.size(), 0U);
|
||||
// this follows the idea of http://blog.datadive.net/interpreting-random-forests/
|
||||
unsigned split_index = 0;
|
||||
auto pid = static_cast<int>(root_id);
|
||||
// update bias value
|
||||
bst_float node_value = this->node_mean_values_[pid];
|
||||
bst_float node_value = this->node_mean_values_[0];
|
||||
out_contribs[feat.Size()] += node_value;
|
||||
if ((*this)[pid].IsLeaf()) {
|
||||
if ((*this)[0].IsLeaf()) {
|
||||
// nothing to do anymore
|
||||
return;
|
||||
}
|
||||
while (!(*this)[pid].IsLeaf()) {
|
||||
split_index = (*this)[pid].SplitIndex();
|
||||
pid = this->GetNext(pid, feat.Fvalue(split_index), feat.IsMissing(split_index));
|
||||
bst_float new_value = this->node_mean_values_[pid];
|
||||
bst_node_t nid = 0;
|
||||
while (!(*this)[nid].IsLeaf()) {
|
||||
split_index = (*this)[nid].SplitIndex();
|
||||
nid = this->GetNext(nid, feat.Fvalue(split_index), feat.IsMissing(split_index));
|
||||
bst_float new_value = this->node_mean_values_[nid];
|
||||
// update feature weight
|
||||
out_contribs[split_index] += new_value - node_value;
|
||||
node_value = new_value;
|
||||
}
|
||||
bst_float leaf_value = (*this)[pid].LeafValue();
|
||||
bst_float leaf_value = (*this)[nid].LeafValue();
|
||||
// update leaf feature weight
|
||||
out_contribs[split_index] += leaf_value - node_value;
|
||||
}
|
||||
@@ -868,21 +865,20 @@ void RegTree::TreeShap(const RegTree::FVec &feat, bst_float *phi,
|
||||
}
|
||||
|
||||
void RegTree::CalculateContributions(const RegTree::FVec &feat,
|
||||
unsigned root_id, bst_float *out_contribs,
|
||||
bst_float *out_contribs,
|
||||
int condition,
|
||||
unsigned condition_feature) const {
|
||||
// find the expected value of the tree's predictions
|
||||
if (condition == 0) {
|
||||
bst_float node_value = this->node_mean_values_[static_cast<int>(root_id)];
|
||||
bst_float node_value = this->node_mean_values_[0];
|
||||
out_contribs[feat.Size()] += node_value;
|
||||
}
|
||||
|
||||
// Preallocate space for the unique path data
|
||||
const int maxd = this->MaxDepth(root_id) + 2;
|
||||
auto *unique_path_data = new PathElement[(maxd * (maxd + 1)) / 2];
|
||||
const int maxd = this->MaxDepth(0) + 2;
|
||||
std::vector<PathElement> unique_path_data((maxd * (maxd + 1)) / 2);
|
||||
|
||||
TreeShap(feat, out_contribs, root_id, 0, unique_path_data,
|
||||
TreeShap(feat, out_contribs, 0, 0, unique_path_data.data(),
|
||||
1, 1, -1, condition, condition_feature, 1);
|
||||
delete[] unique_path_data;
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user