Combine TreeModel and RegTree (#3995)
This commit is contained in:
parent
c055a32609
commit
84c99f86f4
File diff suppressed because it is too large
Load Diff
@ -272,7 +272,6 @@ class GBTree : public GradientBooster {
|
|||||||
// create new tree
|
// create new tree
|
||||||
std::unique_ptr<RegTree> ptr(new RegTree());
|
std::unique_ptr<RegTree> ptr(new RegTree());
|
||||||
ptr->param.InitAllowUnknown(this->cfg_);
|
ptr->param.InitAllowUnknown(this->cfg_);
|
||||||
ptr->InitModel();
|
|
||||||
new_trees.push_back(ptr.get());
|
new_trees.push_back(ptr.get());
|
||||||
ret->push_back(std::move(ptr));
|
ret->push_back(std::move(ptr));
|
||||||
} else if (tparam_.process_type == kUpdate) {
|
} else if (tparam_.process_type == kUpdate) {
|
||||||
|
|||||||
@ -169,4 +169,240 @@ std::string RegTree::DumpModel(const FeatureMap& fmap,
|
|||||||
}
|
}
|
||||||
return fo.str();
|
return fo.str();
|
||||||
}
|
}
|
||||||
|
void RegTree::FillNodeMeanValues() {
|
||||||
|
size_t num_nodes = this->param.num_nodes;
|
||||||
|
if (this->node_mean_values_.size() == num_nodes) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
this->node_mean_values_.resize(num_nodes);
|
||||||
|
for (int root_id = 0; root_id < param.num_roots; ++root_id) {
|
||||||
|
this->FillNodeMeanValue(root_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bst_float RegTree::FillNodeMeanValue(int nid) {
|
||||||
|
bst_float result;
|
||||||
|
auto& node = (*this)[nid];
|
||||||
|
if (node.IsLeaf()) {
|
||||||
|
result = node.LeafValue();
|
||||||
|
} else {
|
||||||
|
result = this->FillNodeMeanValue(node.LeftChild()) * this->Stat(node.LeftChild()).sum_hess;
|
||||||
|
result += this->FillNodeMeanValue(node.RightChild()) * this->Stat(node.RightChild()).sum_hess;
|
||||||
|
result /= this->Stat(nid).sum_hess;
|
||||||
|
}
|
||||||
|
this->node_mean_values_[nid] = result;
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
void RegTree::CalculateContributionsApprox(const RegTree::FVec &feat,
|
||||||
|
unsigned root_id,
|
||||||
|
bst_float *out_contribs) const {
|
||||||
|
CHECK_GT(this->node_mean_values_.size(), 0U);
|
||||||
|
// this follows the idea of http://blog.datadive.net/interpreting-random-forests/
|
||||||
|
unsigned split_index = 0;
|
||||||
|
auto pid = static_cast<int>(root_id);
|
||||||
|
// update bias value
|
||||||
|
bst_float node_value = this->node_mean_values_[pid];
|
||||||
|
out_contribs[feat.Size()] += node_value;
|
||||||
|
if ((*this)[pid].IsLeaf()) {
|
||||||
|
// nothing to do anymore
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
while (!(*this)[pid].IsLeaf()) {
|
||||||
|
split_index = (*this)[pid].SplitIndex();
|
||||||
|
pid = this->GetNext(pid, feat.Fvalue(split_index), feat.IsMissing(split_index));
|
||||||
|
bst_float new_value = this->node_mean_values_[pid];
|
||||||
|
// update feature weight
|
||||||
|
out_contribs[split_index] += new_value - node_value;
|
||||||
|
node_value = new_value;
|
||||||
|
}
|
||||||
|
bst_float leaf_value = (*this)[pid].LeafValue();
|
||||||
|
// update leaf feature weight
|
||||||
|
out_contribs[split_index] += leaf_value - node_value;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Used by TreeShap
|
||||||
|
// data we keep about our decision path
|
||||||
|
// note that pweight is included for convenience and is not tied with the other attributes
|
||||||
|
// the pweight of the i'th path element is the permuation weight of paths with i-1 ones in them
|
||||||
|
struct PathElement {
|
||||||
|
int feature_index;
|
||||||
|
bst_float zero_fraction;
|
||||||
|
bst_float one_fraction;
|
||||||
|
bst_float pweight;
|
||||||
|
PathElement() = default;
|
||||||
|
PathElement(int i, bst_float z, bst_float o, bst_float w) :
|
||||||
|
feature_index(i), zero_fraction(z), one_fraction(o), pweight(w) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
// extend our decision path with a fraction of one and zero extensions
|
||||||
|
void ExtendPath(PathElement *unique_path, unsigned unique_depth,
|
||||||
|
bst_float zero_fraction, bst_float one_fraction,
|
||||||
|
int feature_index) {
|
||||||
|
unique_path[unique_depth].feature_index = feature_index;
|
||||||
|
unique_path[unique_depth].zero_fraction = zero_fraction;
|
||||||
|
unique_path[unique_depth].one_fraction = one_fraction;
|
||||||
|
unique_path[unique_depth].pweight = (unique_depth == 0 ? 1.0f : 0.0f);
|
||||||
|
for (int i = unique_depth - 1; i >= 0; i--) {
|
||||||
|
unique_path[i+1].pweight += one_fraction * unique_path[i].pweight * (i + 1)
|
||||||
|
/ static_cast<bst_float>(unique_depth + 1);
|
||||||
|
unique_path[i].pweight = zero_fraction * unique_path[i].pweight * (unique_depth - i)
|
||||||
|
/ static_cast<bst_float>(unique_depth + 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// undo a previous extension of the decision path
|
||||||
|
void UnwindPath(PathElement *unique_path, unsigned unique_depth,
|
||||||
|
unsigned path_index) {
|
||||||
|
const bst_float one_fraction = unique_path[path_index].one_fraction;
|
||||||
|
const bst_float zero_fraction = unique_path[path_index].zero_fraction;
|
||||||
|
bst_float next_one_portion = unique_path[unique_depth].pweight;
|
||||||
|
|
||||||
|
for (int i = unique_depth - 1; i >= 0; --i) {
|
||||||
|
if (one_fraction != 0) {
|
||||||
|
const bst_float tmp = unique_path[i].pweight;
|
||||||
|
unique_path[i].pweight = next_one_portion * (unique_depth + 1)
|
||||||
|
/ static_cast<bst_float>((i + 1) * one_fraction);
|
||||||
|
next_one_portion = tmp - unique_path[i].pweight * zero_fraction * (unique_depth - i)
|
||||||
|
/ static_cast<bst_float>(unique_depth + 1);
|
||||||
|
} else {
|
||||||
|
unique_path[i].pweight = (unique_path[i].pweight * (unique_depth + 1))
|
||||||
|
/ static_cast<bst_float>(zero_fraction * (unique_depth - i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto i = path_index; i < unique_depth; ++i) {
|
||||||
|
unique_path[i].feature_index = unique_path[i+1].feature_index;
|
||||||
|
unique_path[i].zero_fraction = unique_path[i+1].zero_fraction;
|
||||||
|
unique_path[i].one_fraction = unique_path[i+1].one_fraction;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// determine what the total permuation weight would be if
|
||||||
|
// we unwound a previous extension in the decision path
|
||||||
|
bst_float UnwoundPathSum(const PathElement *unique_path, unsigned unique_depth,
|
||||||
|
unsigned path_index) {
|
||||||
|
const bst_float one_fraction = unique_path[path_index].one_fraction;
|
||||||
|
const bst_float zero_fraction = unique_path[path_index].zero_fraction;
|
||||||
|
bst_float next_one_portion = unique_path[unique_depth].pweight;
|
||||||
|
bst_float total = 0;
|
||||||
|
for (int i = unique_depth - 1; i >= 0; --i) {
|
||||||
|
if (one_fraction != 0) {
|
||||||
|
const bst_float tmp = next_one_portion * (unique_depth + 1)
|
||||||
|
/ static_cast<bst_float>((i + 1) * one_fraction);
|
||||||
|
total += tmp;
|
||||||
|
next_one_portion = unique_path[i].pweight - tmp * zero_fraction * ((unique_depth - i)
|
||||||
|
/ static_cast<bst_float>(unique_depth + 1));
|
||||||
|
} else {
|
||||||
|
total += (unique_path[i].pweight / zero_fraction) / ((unique_depth - i)
|
||||||
|
/ static_cast<bst_float>(unique_depth + 1));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return total;
|
||||||
|
}
|
||||||
|
|
||||||
|
// recursive computation of SHAP values for a decision tree
|
||||||
|
void RegTree::TreeShap(const RegTree::FVec &feat, bst_float *phi,
|
||||||
|
unsigned node_index, unsigned unique_depth,
|
||||||
|
PathElement *parent_unique_path,
|
||||||
|
bst_float parent_zero_fraction,
|
||||||
|
bst_float parent_one_fraction, int parent_feature_index,
|
||||||
|
int condition, unsigned condition_feature,
|
||||||
|
bst_float condition_fraction) const {
|
||||||
|
const auto node = (*this)[node_index];
|
||||||
|
|
||||||
|
// stop if we have no weight coming down to us
|
||||||
|
if (condition_fraction == 0) return;
|
||||||
|
|
||||||
|
// extend the unique path
|
||||||
|
PathElement *unique_path = parent_unique_path + unique_depth + 1;
|
||||||
|
std::copy(parent_unique_path, parent_unique_path + unique_depth + 1, unique_path);
|
||||||
|
|
||||||
|
if (condition == 0 || condition_feature != static_cast<unsigned>(parent_feature_index)) {
|
||||||
|
ExtendPath(unique_path, unique_depth, parent_zero_fraction,
|
||||||
|
parent_one_fraction, parent_feature_index);
|
||||||
|
}
|
||||||
|
const unsigned split_index = node.SplitIndex();
|
||||||
|
|
||||||
|
// leaf node
|
||||||
|
if (node.IsLeaf()) {
|
||||||
|
for (unsigned i = 1; i <= unique_depth; ++i) {
|
||||||
|
const bst_float w = UnwoundPathSum(unique_path, unique_depth, i);
|
||||||
|
const PathElement &el = unique_path[i];
|
||||||
|
phi[el.feature_index] += w * (el.one_fraction - el.zero_fraction)
|
||||||
|
* node.LeafValue() * condition_fraction;
|
||||||
|
}
|
||||||
|
|
||||||
|
// internal node
|
||||||
|
} else {
|
||||||
|
// find which branch is "hot" (meaning x would follow it)
|
||||||
|
unsigned hot_index = 0;
|
||||||
|
if (feat.IsMissing(split_index)) {
|
||||||
|
hot_index = node.DefaultChild();
|
||||||
|
} else if (feat.Fvalue(split_index) < node.SplitCond()) {
|
||||||
|
hot_index = node.LeftChild();
|
||||||
|
} else {
|
||||||
|
hot_index = node.RightChild();
|
||||||
|
}
|
||||||
|
const unsigned cold_index = (static_cast<int>(hot_index) == node.LeftChild() ?
|
||||||
|
node.RightChild() : node.LeftChild());
|
||||||
|
const bst_float w = this->Stat(node_index).sum_hess;
|
||||||
|
const bst_float hot_zero_fraction = this->Stat(hot_index).sum_hess / w;
|
||||||
|
const bst_float cold_zero_fraction = this->Stat(cold_index).sum_hess / w;
|
||||||
|
bst_float incoming_zero_fraction = 1;
|
||||||
|
bst_float incoming_one_fraction = 1;
|
||||||
|
|
||||||
|
// see if we have already split on this feature,
|
||||||
|
// if so we undo that split so we can redo it for this node
|
||||||
|
unsigned path_index = 0;
|
||||||
|
for (; path_index <= unique_depth; ++path_index) {
|
||||||
|
if (static_cast<unsigned>(unique_path[path_index].feature_index) == split_index) break;
|
||||||
|
}
|
||||||
|
if (path_index != unique_depth + 1) {
|
||||||
|
incoming_zero_fraction = unique_path[path_index].zero_fraction;
|
||||||
|
incoming_one_fraction = unique_path[path_index].one_fraction;
|
||||||
|
UnwindPath(unique_path, unique_depth, path_index);
|
||||||
|
unique_depth -= 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// divide up the condition_fraction among the recursive calls
|
||||||
|
bst_float hot_condition_fraction = condition_fraction;
|
||||||
|
bst_float cold_condition_fraction = condition_fraction;
|
||||||
|
if (condition > 0 && split_index == condition_feature) {
|
||||||
|
cold_condition_fraction = 0;
|
||||||
|
unique_depth -= 1;
|
||||||
|
} else if (condition < 0 && split_index == condition_feature) {
|
||||||
|
hot_condition_fraction *= hot_zero_fraction;
|
||||||
|
cold_condition_fraction *= cold_zero_fraction;
|
||||||
|
unique_depth -= 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
TreeShap(feat, phi, hot_index, unique_depth + 1, unique_path,
|
||||||
|
hot_zero_fraction * incoming_zero_fraction, incoming_one_fraction,
|
||||||
|
split_index, condition, condition_feature, hot_condition_fraction);
|
||||||
|
|
||||||
|
TreeShap(feat, phi, cold_index, unique_depth + 1, unique_path,
|
||||||
|
cold_zero_fraction * incoming_zero_fraction, 0,
|
||||||
|
split_index, condition, condition_feature, cold_condition_fraction);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void RegTree::CalculateContributions(const RegTree::FVec &feat,
|
||||||
|
unsigned root_id, bst_float *out_contribs,
|
||||||
|
int condition,
|
||||||
|
unsigned condition_feature) const {
|
||||||
|
// find the expected value of the tree's predictions
|
||||||
|
if (condition == 0) {
|
||||||
|
bst_float node_value = this->node_mean_values_[static_cast<int>(root_id)];
|
||||||
|
out_contribs[feat.Size()] += node_value;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Preallocate space for the unique path data
|
||||||
|
const int maxd = this->MaxDepth(root_id) + 2;
|
||||||
|
auto *unique_path_data = new PathElement[(maxd * (maxd + 1)) / 2];
|
||||||
|
|
||||||
|
TreeShap(feat, out_contribs, root_id, 0, unique_path_data,
|
||||||
|
1, 1, -1, condition, condition_feature, 1);
|
||||||
|
delete[] unique_path_data;
|
||||||
|
}
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -48,7 +48,7 @@ class TreePruner: public TreeUpdater {
|
|||||||
inline int TryPruneLeaf(RegTree &tree, int nid, int depth, int npruned) { // NOLINT(*)
|
inline int TryPruneLeaf(RegTree &tree, int nid, int depth, int npruned) { // NOLINT(*)
|
||||||
if (tree[nid].IsRoot()) return npruned;
|
if (tree[nid].IsRoot()) return npruned;
|
||||||
int pid = tree[nid].Parent();
|
int pid = tree[nid].Parent();
|
||||||
RegTree::NodeStat &s = tree.Stat(pid);
|
RTreeNodeStat &s = tree.Stat(pid);
|
||||||
++s.leaf_child_cnt;
|
++s.leaf_child_cnt;
|
||||||
if (s.leaf_child_cnt >= 2 && param_.NeedPrune(s.loss_chg, depth - 1)) {
|
if (s.leaf_child_cnt >= 2 && param_.NeedPrune(s.loss_chg, depth - 1)) {
|
||||||
// need to be pruned
|
// need to be pruned
|
||||||
|
|||||||
@ -10,7 +10,6 @@ TEST(cpu_predictor, Test) {
|
|||||||
|
|
||||||
std::vector<std::unique_ptr<RegTree>> trees;
|
std::vector<std::unique_ptr<RegTree>> trees;
|
||||||
trees.push_back(std::unique_ptr<RegTree>(new RegTree));
|
trees.push_back(std::unique_ptr<RegTree>(new RegTree));
|
||||||
trees.back()->InitModel();
|
|
||||||
(*trees.back())[0].SetLeaf(1.5f);
|
(*trees.back())[0].SetLeaf(1.5f);
|
||||||
(*trees.back()).Stat(0).sum_hess = 1.0f;
|
(*trees.back()).Stat(0).sum_hess = 1.0f;
|
||||||
gbm::GBTreeModel model(0.5);
|
gbm::GBTreeModel model(0.5);
|
||||||
|
|||||||
@ -35,7 +35,6 @@ TEST(gpu_predictor, Test) {
|
|||||||
|
|
||||||
std::vector<std::unique_ptr<RegTree>> trees;
|
std::vector<std::unique_ptr<RegTree>> trees;
|
||||||
trees.push_back(std::unique_ptr<RegTree>(new RegTree()));
|
trees.push_back(std::unique_ptr<RegTree>(new RegTree()));
|
||||||
trees.back()->InitModel();
|
|
||||||
(*trees.back())[0].SetLeaf(1.5f);
|
(*trees.back())[0].SetLeaf(1.5f);
|
||||||
(*trees.back()).Stat(0).sum_hess = 1.0f;
|
(*trees.back()).Stat(0).sum_hess = 1.0f;
|
||||||
gbm::GBTreeModel model(0.5);
|
gbm::GBTreeModel model(0.5);
|
||||||
@ -181,7 +180,6 @@ TEST(gpu_predictor, MGPU_Test) {
|
|||||||
|
|
||||||
std::vector<std::unique_ptr<RegTree>> trees;
|
std::vector<std::unique_ptr<RegTree>> trees;
|
||||||
trees.push_back(std::unique_ptr<RegTree>(new RegTree()));
|
trees.push_back(std::unique_ptr<RegTree>(new RegTree()));
|
||||||
trees.back()->InitModel();
|
|
||||||
(*trees.back())[0].SetLeaf(1.5f);
|
(*trees.back())[0].SetLeaf(1.5f);
|
||||||
(*trees.back()).Stat(0).sum_hess = 1.0f;
|
(*trees.back()).Stat(0).sum_hess = 1.0f;
|
||||||
gbm::GBTreeModel model(0.5);
|
gbm::GBTreeModel model(0.5);
|
||||||
|
|||||||
@ -291,8 +291,6 @@ TEST(GpuHist, EvaluateSplits) {
|
|||||||
false);
|
false);
|
||||||
|
|
||||||
RegTree tree;
|
RegTree tree;
|
||||||
tree.InitModel();
|
|
||||||
|
|
||||||
MetaInfo info;
|
MetaInfo info;
|
||||||
info.num_row_ = n_rows;
|
info.num_row_ = n_rows;
|
||||||
info.num_col_ = n_cols;
|
info.num_col_ = n_cols;
|
||||||
@ -339,7 +337,6 @@ TEST(GpuHist, ApplySplit) {
|
|||||||
// Initialize GPUHistMaker
|
// Initialize GPUHistMaker
|
||||||
hist_maker.param_ = param;
|
hist_maker.param_ = param;
|
||||||
RegTree tree;
|
RegTree tree;
|
||||||
tree.InitModel();
|
|
||||||
|
|
||||||
DeviceSplitCandidate candidate;
|
DeviceSplitCandidate candidate;
|
||||||
candidate.Update(2, kLeftDir,
|
candidate.Update(2, kLeftDir,
|
||||||
|
|||||||
@ -31,7 +31,6 @@ TEST(Updater, Prune) {
|
|||||||
|
|
||||||
// prepare tree
|
// prepare tree
|
||||||
RegTree tree = RegTree();
|
RegTree tree = RegTree();
|
||||||
tree.InitModel();
|
|
||||||
tree.param.InitAllowUnknown(cfg);
|
tree.param.InitAllowUnknown(cfg);
|
||||||
std::vector<RegTree*> trees {&tree};
|
std::vector<RegTree*> trees {&tree};
|
||||||
// prepare pruner
|
// prepare pruner
|
||||||
|
|||||||
@ -122,7 +122,6 @@ class QuantileHistMock : public QuantileHistMaker {
|
|||||||
gmat.Init((*dmat).get(), max_bins);
|
gmat.Init((*dmat).get(), max_bins);
|
||||||
|
|
||||||
RegTree tree = RegTree();
|
RegTree tree = RegTree();
|
||||||
tree.InitModel();
|
|
||||||
tree.param.InitAllowUnknown(cfg);
|
tree.param.InitAllowUnknown(cfg);
|
||||||
|
|
||||||
std::vector<GradientPair> gpair =
|
std::vector<GradientPair> gpair =
|
||||||
@ -134,7 +133,6 @@ class QuantileHistMock : public QuantileHistMaker {
|
|||||||
|
|
||||||
void TestBuildHist() {
|
void TestBuildHist() {
|
||||||
RegTree tree = RegTree();
|
RegTree tree = RegTree();
|
||||||
tree.InitModel();
|
|
||||||
tree.param.InitAllowUnknown(cfg);
|
tree.param.InitAllowUnknown(cfg);
|
||||||
|
|
||||||
size_t constexpr max_bins = 4;
|
size_t constexpr max_bins = 4;
|
||||||
@ -146,7 +144,6 @@ class QuantileHistMock : public QuantileHistMaker {
|
|||||||
|
|
||||||
void TestEvaluateSplit() {
|
void TestEvaluateSplit() {
|
||||||
RegTree tree = RegTree();
|
RegTree tree = RegTree();
|
||||||
tree.InitModel();
|
|
||||||
tree.param.InitAllowUnknown(cfg);
|
tree.param.InitAllowUnknown(cfg);
|
||||||
|
|
||||||
builder_->TestEvaluateSplit(gmatb_, tree);
|
builder_->TestEvaluateSplit(gmatb_, tree);
|
||||||
|
|||||||
@ -25,7 +25,6 @@ TEST(Updater, Refresh) {
|
|||||||
{"reg_lambda", "1"}};
|
{"reg_lambda", "1"}};
|
||||||
|
|
||||||
RegTree tree = RegTree();
|
RegTree tree = RegTree();
|
||||||
tree.InitModel();
|
|
||||||
tree.param.InitAllowUnknown(cfg);
|
tree.param.InitAllowUnknown(cfg);
|
||||||
std::vector<RegTree*> trees {&tree};
|
std::vector<RegTree*> trees {&tree};
|
||||||
std::unique_ptr<TreeUpdater> refresher(TreeUpdater::Create("refresh"));
|
std::unique_ptr<TreeUpdater> refresher(TreeUpdater::Create("refresh"));
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user