Merge lossguide and depthwise strategies for CPU hist (#7007)

* Fix the Java/Scala tests: max_depth is also a valid parameter for lossguide.

Co-authored-by: Kirill Shvets <kirill.shvets@intel.com>

Parent: ee4f51a631
Commit: 57c732655e
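
The diff below is long, so a word of orientation first. The commit replaces the separate depth-wise and loss-guided expansion loops of the CPU hist updater with the same Push/Pop driver pattern the GPU updater already used, moving the driver into a shared, templated src/tree/driver.h. The following minimal, self-contained C++ sketch shows that pattern; the Entry type, the gain values, and the stopping rule are illustrative assumptions, not code from this commit.

#include <functional>
#include <iostream>
#include <queue>
#include <vector>

struct Entry {
  int nid{0};       // node id; smaller ids were created earlier (shallower)
  int depth{0};
  float loss_chg{0.0f};
};

// Depth-wise ordering, as in driver.h: prefer the smallest node id.
inline bool DepthWiseOrder(const Entry& lhs, const Entry& rhs) {
  return lhs.nid > rhs.nid;
}

int main() {
  std::priority_queue<Entry, std::vector<Entry>,
                      std::function<bool(Entry, Entry)>> queue(DepthWiseOrder);
  queue.push(Entry{0, 0, 1.0f});  // root
  while (!queue.empty()) {
    // Depth-wise Pop(): drain every entry on the current level so the whole
    // level can be expanded together; loss-guide would pop a single entry.
    std::vector<Entry> level;
    const int depth = queue.top().depth;
    while (!queue.empty() && queue.top().depth == depth) {
      level.push_back(queue.top());
      queue.pop();
    }
    for (const Entry& e : level) {
      std::cout << "expanding node " << e.nid << " at depth " << e.depth << "\n";
      if (e.depth < 2) {  // toy stopping rule standing in for IsValid()
        queue.push(Entry{2 * e.nid + 1, e.depth + 1, 0.5f});
        queue.push(Entry{2 * e.nid + 2, e.depth + 1, 0.5f});
      }
    }
  }
  return 0;
}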
@@ -387,7 +387,7 @@ public class BoosterImplTest {
     DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
     Map<String, Object> paramMap = new HashMap<String, Object>() {
       {
-        put("max_depth", 0);
+        put("max_depth", 3);
         put("silent", 1);
         put("objective", "binary:logistic");
         put("tree_method", "hist");
@@ -408,7 +408,7 @@ public class BoosterImplTest {
     DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
     Map<String, Object> paramMap = new HashMap<String, Object>() {
       {
-        put("max_depth", 0);
+        put("max_depth", 3);
         put("silent", 1);
         put("objective", "binary:logistic");
         put("tree_method", "hist");

@@ -159,7 +159,7 @@ class ScalaBoosterImplSuite extends FunSuite {
   test("test with quantile histo lossguide") {
     val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
     val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
-    val paramMap = List("max_depth" -> "0", "silent" -> "0",
+    val paramMap = List("max_depth" -> "3", "silent" -> "0",
       "objective" -> "binary:logistic", "tree_method" -> "hist",
       "grow_policy" -> "lossguide", "max_leaves" -> "8", "eval_metric" -> "auc").toMap
     trainBoosterWithQuantileHisto(trainMat, Map("training" -> trainMat, "test" -> testMat),
@@ -169,7 +169,7 @@ class ScalaBoosterImplSuite extends FunSuite {
   test("test with quantile histo lossguide with max bin") {
     val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
     val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
-    val paramMap = List("max_depth" -> "0", "silent" -> "0",
+    val paramMap = List("max_depth" -> "3", "silent" -> "0",
      "objective" -> "binary:logistic", "tree_method" -> "hist",
      "grow_policy" -> "lossguide", "max_leaves" -> "8", "max_bin" -> "16",
      "eval_metric" -> "auc").toMap
src/tree/driver.h (new file, 92 lines)

@@ -0,0 +1,92 @@
/*!
 * Copyright 2021 by XGBoost Contributors
 */
#ifndef XGBOOST_TREE_DRIVER_H_
#define XGBOOST_TREE_DRIVER_H_
#include <xgboost/span.h>
#include <queue>
#include <vector>
#include "./param.h"

namespace xgboost {
namespace tree {

template <typename ExpandEntryT>
inline bool DepthWise(const ExpandEntryT& lhs, const ExpandEntryT& rhs) {
  return lhs.GetNodeId() > rhs.GetNodeId();  // favor small depth
}

template <typename ExpandEntryT>
inline bool LossGuide(const ExpandEntryT& lhs, const ExpandEntryT& rhs) {
  if (lhs.GetLossChange() == rhs.GetLossChange()) {
    return lhs.GetNodeId() > rhs.GetNodeId();  // favor small timestamp
  } else {
    return lhs.GetLossChange() < rhs.GetLossChange();  // favor large loss_chg
  }
}

// Drives execution of tree building on device
template <typename ExpandEntryT>
class Driver {
  using ExpandQueue =
      std::priority_queue<ExpandEntryT, std::vector<ExpandEntryT>,
                          std::function<bool(ExpandEntryT, ExpandEntryT)>>;

 public:
  explicit Driver(TrainParam::TreeGrowPolicy policy)
      : policy_(policy),
        queue_(policy == TrainParam::kDepthWise ? DepthWise<ExpandEntryT> :
                                                  LossGuide<ExpandEntryT>) {}
  template <typename EntryIterT>
  void Push(EntryIterT begin, EntryIterT end) {
    for (auto it = begin; it != end; ++it) {
      const ExpandEntryT& e = *it;
      if (e.split.loss_chg > kRtEps) {
        queue_.push(e);
      }
    }
  }
  void Push(const std::vector<ExpandEntryT> &entries) {
    this->Push(entries.begin(), entries.end());
  }
  void Push(const ExpandEntryT e) {
    queue_.push(e);
  }

  bool IsEmpty() {
    return queue_.empty();
  }

  // Return the set of nodes to be expanded
  // This set has no dependencies between entries so they may be expanded in
  // parallel or asynchronously
  std::vector<ExpandEntryT> Pop() {
    if (queue_.empty()) return {};
    // Return a single entry for loss guided mode
    if (policy_ == TrainParam::kLossGuide) {
      ExpandEntryT e = queue_.top();
      queue_.pop();
      return {e};
    }
    // Return nodes on same level for depth wise
    std::vector<ExpandEntryT> result;
    ExpandEntryT e = queue_.top();
    int level = e.depth;
    while (e.depth == level && !queue_.empty()) {
      queue_.pop();
      result.emplace_back(e);
      if (!queue_.empty()) {
        e = queue_.top();
      }
    }
    return result;
  }

 private:
  TrainParam::TreeGrowPolicy policy_;
  ExpandQueue queue_;
};
}  // namespace tree
}  // namespace xgboost

#endif  // XGBOOST_TREE_DRIVER_H_
src/tree/gpu_hist/driver.cuh (deleted file, 127 lines)

@@ -1,127 +0,0 @@
/*!
 * Copyright 2020 by XGBoost Contributors
 */
#ifndef DRIVER_CUH_
#define DRIVER_CUH_
#include <xgboost/span.h>
#include <queue>
#include "../param.h"
#include "evaluate_splits.cuh"

namespace xgboost {
namespace tree {
struct ExpandEntry {
  int nid;
  int depth;
  DeviceSplitCandidate split;

  float base_weight { std::numeric_limits<float>::quiet_NaN() };
  float left_weight { std::numeric_limits<float>::quiet_NaN() };
  float right_weight { std::numeric_limits<float>::quiet_NaN() };

  ExpandEntry() = default;
  XGBOOST_DEVICE ExpandEntry(int nid, int depth, DeviceSplitCandidate split,
                             float base, float left, float right)
      : nid(nid), depth(depth), split(std::move(split)), base_weight{base},
        left_weight{left}, right_weight{right} {}
  bool IsValid(const TrainParam& param, int num_leaves) const {
    if (split.loss_chg <= kRtEps) return false;
    if (split.left_sum.GetHess() == 0 || split.right_sum.GetHess() == 0) {
      return false;
    }
    if (split.loss_chg < param.min_split_loss) {
      return false;
    }
    if (param.max_depth > 0 && depth == param.max_depth) {
      return false;
    }
    if (param.max_leaves > 0 && num_leaves == param.max_leaves) {
      return false;
    }
    return true;
  }

  static bool ChildIsValid(const TrainParam& param, int depth, int num_leaves) {
    if (param.max_depth > 0 && depth >= param.max_depth) return false;
    if (param.max_leaves > 0 && num_leaves >= param.max_leaves) return false;
    return true;
  }

  friend std::ostream& operator<<(std::ostream& os, const ExpandEntry& e) {
    os << "ExpandEntry: \n";
    os << "nidx: " << e.nid << "\n";
    os << "depth: " << e.depth << "\n";
    os << "loss: " << e.split.loss_chg << "\n";
    os << "left_sum: " << e.split.left_sum << "\n";
    os << "right_sum: " << e.split.right_sum << "\n";
    return os;
  }
};

inline bool DepthWise(const ExpandEntry& lhs, const ExpandEntry& rhs) {
  return lhs.depth > rhs.depth;  // favor small depth
}

inline bool LossGuide(const ExpandEntry& lhs, const ExpandEntry& rhs) {
  if (lhs.split.loss_chg == rhs.split.loss_chg) {
    return lhs.nid > rhs.nid;  // favor small timestamp
  } else {
    return lhs.split.loss_chg < rhs.split.loss_chg;  // favor large loss_chg
  }
}

// Drives execution of tree building on device
class Driver {
  using ExpandQueue =
      std::priority_queue<ExpandEntry, std::vector<ExpandEntry>,
                          std::function<bool(ExpandEntry, ExpandEntry)>>;

 public:
  explicit Driver(TrainParam::TreeGrowPolicy policy)
      : policy_(policy),
        queue_(policy == TrainParam::kDepthWise ? DepthWise : LossGuide) {}
  template <typename EntryIterT>
  void Push(EntryIterT begin, EntryIterT end) {
    for (auto it = begin; it != end; ++it) {
      const ExpandEntry& e = *it;
      if (e.split.loss_chg > kRtEps) {
        queue_.push(e);
      }
    }
  }
  void Push(const std::vector<ExpandEntry> &entries) {
    this->Push(entries.begin(), entries.end());
  }
  // Return the set of nodes to be expanded
  // This set has no dependencies between entries so they may be expanded in
  // parallel or asynchronously
  std::vector<ExpandEntry> Pop() {
    if (queue_.empty()) return {};
    // Return a single entry for loss guided mode
    if (policy_ == TrainParam::kLossGuide) {
      ExpandEntry e = queue_.top();
      queue_.pop();
      return {e};
    }
    // Return nodes on same level for depth wise
    std::vector<ExpandEntry> result;
    ExpandEntry e = queue_.top();
    int level = e.depth;
    while (e.depth == level && !queue_.empty()) {
      queue_.pop();
      result.emplace_back(e);
      if (!queue_.empty()) {
        e = queue_.top();
      }
    }
    return result;
  }

 private:
  TrainParam::TreeGrowPolicy policy_;
  ExpandQueue queue_;
};
}  // namespace tree
}  // namespace xgboost

#endif  // DRIVER_CUH_
src/tree/gpu_hist/expand_entry.cuh (new file, 76 lines)

@@ -0,0 +1,76 @@
/*!
 * Copyright 2020 by XGBoost Contributors
 */
#ifndef EXPAND_ENTRY_CUH_
#define EXPAND_ENTRY_CUH_
#include <xgboost/span.h>
#include "../param.h"
#include "evaluate_splits.cuh"

namespace xgboost {
namespace tree {

struct GPUExpandEntry {
  int nid;
  int depth;
  DeviceSplitCandidate split;

  float base_weight { std::numeric_limits<float>::quiet_NaN() };
  float left_weight { std::numeric_limits<float>::quiet_NaN() };
  float right_weight { std::numeric_limits<float>::quiet_NaN() };

  GPUExpandEntry() = default;
  XGBOOST_DEVICE GPUExpandEntry(int nid, int depth, DeviceSplitCandidate split,
                                float base, float left, float right)
      : nid(nid), depth(depth), split(std::move(split)), base_weight{base},
        left_weight{left}, right_weight{right} {}
  bool IsValid(const TrainParam& param, int num_leaves) const {
    if (split.loss_chg <= kRtEps) return false;
    if (split.left_sum.GetHess() == 0 || split.right_sum.GetHess() == 0) {
      return false;
    }
    if (split.loss_chg < param.min_split_loss) {
      return false;
    }
    if (param.max_depth > 0 && depth == param.max_depth) {
      return false;
    }
    if (param.max_leaves > 0 && num_leaves == param.max_leaves) {
      return false;
    }
    return true;
  }

  static bool ChildIsValid(const TrainParam& param, int depth, int num_leaves) {
    if (param.max_depth > 0 && depth >= param.max_depth) return false;
    if (param.max_leaves > 0 && num_leaves >= param.max_leaves) return false;
    return true;
  }

  bst_float GetLossChange() const {
    return split.loss_chg;
  }

  int GetNodeId() const {
    return nid;
  }

  int GetDepth() const {
    return depth;
  }

  friend std::ostream& operator<<(std::ostream& os, const GPUExpandEntry& e) {
    os << "GPUExpandEntry: \n";
    os << "nidx: " << e.nid << "\n";
    os << "depth: " << e.depth << "\n";
    os << "loss: " << e.split.loss_chg << "\n";
    os << "left_sum: " << e.split.left_sum << "\n";
    os << "right_sum: " << e.split.right_sum << "\n";
    return os;
  }
};

}  // namespace tree
}  // namespace xgboost

#endif  // EXPAND_ENTRY_CUH_
src/tree/updater_gpu_hist.cu

@@ -25,6 +25,7 @@
 #include "../data/ellpack_page.cuh"

 #include "param.h"
+#include "driver.h"
 #include "updater_gpu_common.cuh"
 #include "split_evaluator.h"
 #include "constraints.cuh"
@@ -33,7 +34,7 @@
 #include "gpu_hist/row_partitioner.cuh"
 #include "gpu_hist/histogram.cuh"
 #include "gpu_hist/evaluate_splits.cuh"
-#include "gpu_hist/driver.cuh"
+#include "gpu_hist/expand_entry.cuh"

 namespace xgboost {
 namespace tree {
@@ -321,8 +322,8 @@ struct GPUHistMakerDevice {
   }

   void EvaluateLeftRightSplits(
-      ExpandEntry candidate, int left_nidx, int right_nidx, const RegTree& tree,
-      common::Span<ExpandEntry> pinned_candidates_out) {
+      GPUExpandEntry candidate, int left_nidx, int right_nidx, const RegTree& tree,
+      common::Span<GPUExpandEntry> pinned_candidates_out) {
     dh::TemporaryArray<DeviceSplitCandidate> splits_out(2);
     GPUTrainingParam gpu_param(param);
     auto left_sampled_features =
@@ -363,7 +364,7 @@ struct GPUHistMakerDevice {
                               hist.GetNodeHistogram(right_nidx)};
     auto d_splits_out = dh::ToSpan(splits_out);
     EvaluateSplits(d_splits_out, tree_evaluator.GetEvaluator<GPUTrainingParam>(), left, right);
-    dh::TemporaryArray<ExpandEntry> entries(2);
+    dh::TemporaryArray<GPUExpandEntry> entries(2);
     auto evaluator = tree_evaluator.GetEvaluator<GPUTrainingParam>();
     auto d_entries = entries.data().get();
     dh::LaunchN(device_id, 2, [=] __device__(size_t idx) {
@@ -378,12 +379,12 @@ struct GPUHistMakerDevice {
           nidx, gpu_param, GradStats{split.right_sum});

       d_entries[idx] =
-          ExpandEntry{nidx, candidate.depth + 1, d_splits_out[idx],
+          GPUExpandEntry{nidx, candidate.depth + 1, d_splits_out[idx],
                       base_weight, left_weight, right_weight};
     });
     dh::safe_cuda(cudaMemcpyAsync(
         pinned_candidates_out.data(), entries.data().get(),
-        sizeof(ExpandEntry) * entries.size(), cudaMemcpyDeviceToHost));
+        sizeof(GPUExpandEntry) * entries.size(), cudaMemcpyDeviceToHost));
   }

   void BuildHist(int nidx) {
@@ -569,7 +570,7 @@ struct GPUHistMakerDevice {
   /**
    * \brief Build GPU local histograms for the left and right child of some parent node
    */
-  void BuildHistLeftRight(const ExpandEntry &candidate, int nidx_left,
+  void BuildHistLeftRight(const GPUExpandEntry &candidate, int nidx_left,
                           int nidx_right, dh::AllReducer* reducer) {
     auto build_hist_nidx = nidx_left;
     auto subtraction_trick_nidx = nidx_right;
@@ -599,7 +600,7 @@ struct GPUHistMakerDevice {
     }
   }

-  void ApplySplit(const ExpandEntry& candidate, RegTree* p_tree) {
+  void ApplySplit(const GPUExpandEntry& candidate, RegTree* p_tree) {
     RegTree& tree = *p_tree;
     auto evaluator = tree_evaluator.GetEvaluator();
     auto parent_sum = candidate.split.left_sum + candidate.split.right_sum;
@@ -647,7 +648,7 @@ struct GPUHistMakerDevice {
                        tree[candidate.nid].RightChild());
   }

-  ExpandEntry InitRoot(RegTree* p_tree, dh::AllReducer* reducer) {
+  GPUExpandEntry InitRoot(RegTree* p_tree, dh::AllReducer* reducer) {
     constexpr bst_node_t kRootNIdx = 0;
     dh::XGBCachingDeviceAllocator<char> alloc;
     GradientPair root_sum = dh::Reduce(
@@ -670,7 +671,7 @@ struct GPUHistMakerDevice {

     // Generate first split
     auto split = this->EvaluateRootSplit(root_sum);
-    dh::TemporaryArray<ExpandEntry> entries(1);
+    dh::TemporaryArray<GPUExpandEntry> entries(1);
     auto d_entries = entries.data().get();
     auto evaluator = tree_evaluator.GetEvaluator<GPUTrainingParam>();
     GPUTrainingParam gpu_param(param);
@@ -681,20 +682,20 @@ struct GPUHistMakerDevice {
       float right_weight = evaluator.CalcWeight(
           kRootNIdx, gpu_param, GradStats{split.right_sum});
       d_entries[0] =
-          ExpandEntry(kRootNIdx, depth, split,
+          GPUExpandEntry(kRootNIdx, depth, split,
                       weight, left_weight, right_weight);
     });
-    ExpandEntry root_entry;
+    GPUExpandEntry root_entry;
     dh::safe_cuda(cudaMemcpyAsync(
         &root_entry, entries.data().get(),
-        sizeof(ExpandEntry) * entries.size(), cudaMemcpyDeviceToHost));
+        sizeof(GPUExpandEntry) * entries.size(), cudaMemcpyDeviceToHost));
     return root_entry;
   }

   void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat,
                   RegTree* p_tree, dh::AllReducer* reducer) {
     auto& tree = *p_tree;
-    Driver driver(static_cast<TrainParam::TreeGrowPolicy>(param.grow_policy));
+    Driver<GPUExpandEntry> driver(static_cast<TrainParam::TreeGrowPolicy>(param.grow_policy));

     monitor.Start("Reset");
     this->Reset(gpair_all, p_fmat, p_fmat->Info().num_col_);
@@ -710,7 +711,7 @@ struct GPUHistMakerDevice {
     auto expand_set = driver.Pop();
     while (!expand_set.empty()) {
       auto new_candidates =
-          pinned.GetSpan<ExpandEntry>(expand_set.size() * 2, ExpandEntry());
+          pinned.GetSpan<GPUExpandEntry>(expand_set.size() * 2, GPUExpandEntry());

       for (auto i = 0ull; i < expand_set.size(); i++) {
         auto candidate = expand_set.at(i);
@@ -724,7 +725,7 @@ struct GPUHistMakerDevice {
         int left_child_nidx = tree[candidate.nid].LeftChild();
         int right_child_nidx = tree[candidate.nid].RightChild();
         // Only create child entries if needed
-        if (ExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx),
+        if (GPUExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx),
                                       num_leaves)) {
           monitor.Start("UpdatePosition");
           this->UpdatePosition(candidate.nid, p_tree);
@@ -741,8 +742,8 @@ struct GPUHistMakerDevice {
           monitor.Stop("EvaluateSplits");
         } else {
           // Set default
-          new_candidates[i * 2] = ExpandEntry();
-          new_candidates[i * 2 + 1] = ExpandEntry();
+          new_candidates[i * 2] = GPUExpandEntry();
+          new_candidates[i * 2 + 1] = GPUExpandEntry();
         }
       }
       dh::safe_cuda(cudaDeviceSynchronize());
src/tree/updater_quantile_hist.cc

@@ -140,10 +140,11 @@ void BatchHistSynchronizer<GradientSumT>::SyncHistograms(BuilderT *builder,
     // Merging histograms from each thread into one
     builder->hist_buffer_.ReduceHist(node, r.begin(), r.end());

-    if (!(*p_tree)[entry.nid].IsRoot() && entry.sibling_nid > -1) {
+    if (!(*p_tree)[entry.nid].IsRoot()) {
       const size_t parent_id = (*p_tree)[entry.nid].Parent();
+      const int subtraction_node_id = builder->nodes_for_subtraction_trick_[node].nid;
       auto parent_hist = builder->hist_[parent_id];
-      auto sibling_hist = builder->hist_[entry.sibling_nid];
+      auto sibling_hist = builder->hist_[subtraction_node_id];
       SubtractionHist(sibling_hist, parent_hist, this_hist, r.begin(), r.end());
     }
   });
@@ -169,13 +170,14 @@ void DistributedHistSynchronizer<GradientSumT>::SyncHistograms(BuilderT* builder
     auto this_local = builder->hist_local_worker_[entry.nid];
     CopyHist(this_local, this_hist, r.begin(), r.end());

-    if (!(*p_tree)[entry.nid].IsRoot() && entry.sibling_nid > -1) {
+    if (!(*p_tree)[entry.nid].IsRoot()) {
       const size_t parent_id = (*p_tree)[entry.nid].Parent();
+      const int subtraction_node_id = builder->nodes_for_subtraction_trick_[node].nid;
       auto parent_hist = builder->hist_local_worker_[parent_id];
-      auto sibling_hist = builder->hist_[entry.sibling_nid];
+      auto sibling_hist = builder->hist_[subtraction_node_id];
       SubtractionHist(sibling_hist, parent_hist, this_hist, r.begin(), r.end());
       // Store possible parent node
-      auto sibling_local = builder->hist_local_worker_[entry.sibling_nid];
+      auto sibling_local = builder->hist_local_worker_[subtraction_node_id];
       CopyHist(sibling_local, sibling_hist, r.begin(), r.end());
     }
   });
@@ -186,12 +188,14 @@ void DistributedHistSynchronizer<GradientSumT>::SyncHistograms(BuilderT* builder

   builder->builder_monitor_.Stop("SyncHistogramsAllreduce");

-  ParallelSubtractionHist(builder, space, builder->nodes_for_explicit_hist_build_, p_tree);
+  ParallelSubtractionHist(builder, space, builder->nodes_for_explicit_hist_build_,
+                          builder->nodes_for_subtraction_trick_, p_tree);

   common::BlockedSpace2d space2(builder->nodes_for_subtraction_trick_.size(), [&](size_t) {
     return nbins;
   }, 1024);
-  ParallelSubtractionHist(builder, space2, builder->nodes_for_subtraction_trick_, p_tree);
+  ParallelSubtractionHist(builder, space2, builder->nodes_for_subtraction_trick_,
+                          builder->nodes_for_explicit_hist_build_, p_tree);
   builder->builder_monitor_.Stop("SyncHistograms");
 }

@@ -199,16 +203,18 @@ template <typename GradientSumT>
 void DistributedHistSynchronizer<GradientSumT>::ParallelSubtractionHist(
     BuilderT* builder,
     const common::BlockedSpace2d& space,
-    const std::vector<ExpandEntryT>& nodes,
+    const std::vector<CPUExpandEntry>& nodes,
+    const std::vector<CPUExpandEntry>& subtraction_nodes,
     const RegTree * p_tree) {
   common::ParallelFor2d(space, builder->nthread_, [&](size_t node, common::Range1d r) {
     const auto& entry = nodes[node];
     if (!((*p_tree)[entry.nid].IsLeftChild())) {
       auto this_hist = builder->hist_[entry.nid];

-      if (!(*p_tree)[entry.nid].IsRoot() && entry.sibling_nid > -1) {
+      if (!(*p_tree)[entry.nid].IsRoot()) {
+        const int subtraction_node_id = subtraction_nodes[node].nid;
         auto parent_hist = builder->hist_[(*p_tree)[entry.nid].Parent()];
-        auto sibling_hist = builder->hist_[entry.sibling_nid];
+        auto sibling_hist = builder->hist_[subtraction_node_id];
         SubtractionHist(this_hist, parent_hist, sibling_hist, r.begin(), r.end());
       }
     }
@@ -287,18 +293,19 @@ void QuantileHistMaker::Builder<GradientSumT>::SetHistRowsAdder(
 }

 template <typename GradientSumT>
-void QuantileHistMaker::Builder<GradientSumT>::BuildHistogramsLossGuide(
-    ExpandEntry entry, const GHistIndexMatrix &gmat,
-    const GHistIndexBlockMatrix &gmatb, RegTree *p_tree,
-    const std::vector<GradientPair> &gpair_h) {
+void QuantileHistMaker::Builder<GradientSumT>::InitRoot(
+    const GHistIndexMatrix &gmat,
+    const GHistIndexBlockMatrix &gmatb,
+    const DMatrix& fmat,
+    RegTree *p_tree,
+    const std::vector<GradientPair> &gpair_h,
+    int *num_leaves, std::vector<CPUExpandEntry> *expand) {

+  CPUExpandEntry node(CPUExpandEntry::kRootNid, p_tree->GetDepth(0), 0.0f);

   nodes_for_explicit_hist_build_.clear();
   nodes_for_subtraction_trick_.clear();
-  nodes_for_explicit_hist_build_.push_back(entry);

-  if (entry.sibling_nid > -1) {
-    nodes_for_subtraction_trick_.emplace_back(entry.sibling_nid, entry.nid,
-        p_tree->GetDepth(entry.sibling_nid), 0.0f, 0);
-  }
+  nodes_for_explicit_hist_build_.push_back(node);

   int starting_index = std::numeric_limits<int>::max();
   int sync_count = 0;
@@ -306,6 +313,13 @@ void QuantileHistMaker::Builder<GradientSumT>::BuildHistogramsLossGuide(
   hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, p_tree);
   BuildLocalHistograms(gmat, gmatb, p_tree, gpair_h);
   hist_synchronizer_->SyncHistograms(this, starting_index, sync_count, p_tree);

+  this->InitNewNode(CPUExpandEntry::kRootNid, gmat, gpair_h, fmat, *p_tree);

+  this->EvaluateSplits({node}, gmat, hist_, *p_tree);
+  node.loss_chg = snode_[CPUExpandEntry::kRootNid].best.loss_chg;
+  expand->push_back(node);
+  ++(*num_leaves);
 }

 template<typename GradientSumT>
@@ -347,48 +361,17 @@ void QuantileHistMaker::Builder<GradientSumT>::BuildLocalHistograms(
   builder_monitor_.Stop("BuildLocalHistograms");
 }

-template<typename GradientSumT>
-void QuantileHistMaker::Builder<GradientSumT>::BuildNodeStats(
-    const GHistIndexMatrix &gmat,
-    DMatrix *p_fmat,
-    RegTree *p_tree,
-    const std::vector<GradientPair> &gpair_h) {
-  builder_monitor_.Start("BuildNodeStats");
-  for (auto const& entry : qexpand_depth_wise_) {
-    int nid = entry.nid;
-    this->InitNewNode(nid, gmat, gpair_h, *p_fmat, *p_tree);
-    // add constraints
-    if (!(*p_tree)[nid].IsLeftChild() && !(*p_tree)[nid].IsRoot()) {
-      // it's a right child
-      auto parent_id = (*p_tree)[nid].Parent();
-      auto left_sibling_id = (*p_tree)[parent_id].LeftChild();
-      auto parent_split_feature_id = snode_[parent_id].best.SplitIndex();
-      tree_evaluator_.AddSplit(
-          parent_id, left_sibling_id, nid, parent_split_feature_id,
-          snode_[left_sibling_id].weight, snode_[nid].weight);
-      interaction_constraints_.Split(parent_id, parent_split_feature_id,
-                                     left_sibling_id, nid);
-    }
-  }
-  builder_monitor_.Stop("BuildNodeStats");
-}

 template<typename GradientSumT>
 void QuantileHistMaker::Builder<GradientSumT>::AddSplitsToTree(
-    const GHistIndexMatrix &gmat,
+    const std::vector<CPUExpandEntry>& expand,
     RegTree *p_tree,
     int *num_leaves,
-    int depth,
-    unsigned *timestamp,
-    std::vector<ExpandEntry>* nodes_for_apply_split,
-    std::vector<ExpandEntry>* temp_qexpand_depth) {
+    std::vector<CPUExpandEntry>* nodes_for_apply_split) {
   auto evaluator = tree_evaluator_.GetEvaluator();
-  for (auto const& entry : qexpand_depth_wise_) {
+  for (auto const& entry : expand) {
     int nid = entry.nid;

-    if (snode_[nid].best.loss_chg < kRtEps ||
-        (param_.max_depth > 0 && depth == param_.max_depth) ||
-        (param_.max_leaves > 0 && (*num_leaves) == param_.max_leaves)) {
+    if (entry.IsValid(param_, *num_leaves)) {
       (*p_tree)[nid].SetLeaf(snode_[nid].weight * param_.learning_rate);
     } else {
       nodes_for_apply_split->push_back(entry);
@@ -402,36 +385,12 @@ void QuantileHistMaker::Builder<GradientSumT>::AddSplitsToTree(
                          e.best.DefaultLeft(), e.weight, left_leaf_weight,
                          right_leaf_weight, e.best.loss_chg, e.stats.GetHess(),
                          e.best.left_sum.GetHess(), e.best.right_sum.GetHess());

-      int left_id = (*p_tree)[nid].LeftChild();
-      int right_id = (*p_tree)[nid].RightChild();
-      temp_qexpand_depth->push_back(ExpandEntry(left_id, right_id,
-          p_tree->GetDepth(left_id), 0.0, (*timestamp)++));
-      temp_qexpand_depth->push_back(ExpandEntry(right_id, left_id,
-          p_tree->GetDepth(right_id), 0.0, (*timestamp)++));
       // - 1 parent + 2 new children
       (*num_leaves)++;
     }
   }
 }

-template<typename GradientSumT>
-void QuantileHistMaker::Builder<GradientSumT>::EvaluateAndApplySplits(
-    const GHistIndexMatrix &gmat,
-    const ColumnMatrix &column_matrix,
-    RegTree *p_tree,
-    int *num_leaves,
-    int depth,
-    unsigned *timestamp,
-    std::vector<ExpandEntry> *temp_qexpand_depth) {
-  EvaluateSplits(qexpand_depth_wise_, gmat, hist_, *p_tree);

-  std::vector<ExpandEntry> nodes_for_apply_split;
-  AddSplitsToTree(gmat, p_tree, num_leaves, depth, timestamp,
-                  &nodes_for_apply_split, temp_qexpand_depth);
-  ApplySplit(nodes_for_apply_split, gmat, column_matrix, hist_, p_tree);
-}

 // Split nodes to 2 sets depending on amount of rows in each node
 // Histograms for small nodes will be built explicitly
 // Histograms for big nodes will be built by 'Subtraction Trick'
@@ -440,148 +399,109 @@ void QuantileHistMaker::Builder<GradientSumT>::EvaluateAndApplySplits(
 // This ensures that the workers operate on the same set of tree nodes.
 template <typename GradientSumT>
 void QuantileHistMaker::Builder<GradientSumT>::SplitSiblings(
-    const std::vector<ExpandEntry> &nodes,
-    std::vector<ExpandEntry> *small_siblings,
-    std::vector<ExpandEntry> *big_siblings, RegTree *p_tree) {
+    const std::vector<CPUExpandEntry> &nodes_for_apply_split,
+    std::vector<CPUExpandEntry> *nodes_to_evaluate, RegTree *p_tree) {
   builder_monitor_.Start("SplitSiblings");
-  for (auto const& entry : nodes) {
+  for (auto const& entry : nodes_for_apply_split) {
     int nid = entry.nid;
-    RegTree::Node &node = (*p_tree)[nid];
-    if (node.IsRoot()) {
-      small_siblings->push_back(entry);
-    } else {
-      const int32_t left_id = (*p_tree)[node.Parent()].LeftChild();
-      const int32_t right_id = (*p_tree)[node.Parent()].RightChild();

-      if (nid == left_id && row_set_collection_[left_id ].Size() <
-                            row_set_collection_[right_id].Size()) {
-        small_siblings->push_back(entry);
-      } else if (nid == right_id && row_set_collection_[right_id].Size() <=
-                                    row_set_collection_[left_id ].Size()) {
-        small_siblings->push_back(entry);
-      } else {
-        big_siblings->push_back(entry);
-      }
-    }
-  }
-  builder_monitor_.Stop("SplitSiblings");
-}
-template<typename GradientSumT>
-void QuantileHistMaker::Builder<GradientSumT>::ExpandWithDepthWise(
-    const GHistIndexMatrix &gmat,
-    const GHistIndexBlockMatrix &gmatb,
-    const ColumnMatrix &column_matrix,
-    DMatrix *p_fmat,
-    RegTree *p_tree,
-    const std::vector<GradientPair> &gpair_h) {
-  unsigned timestamp = 0;
-  int num_leaves = 0;

-  // in depth_wise growing, we feed loss_chg with 0.0 since it is not used anyway
-  qexpand_depth_wise_.emplace_back(ExpandEntry(ExpandEntry::kRootNid, ExpandEntry::kEmptyNid,
-      p_tree->GetDepth(ExpandEntry::kRootNid), 0.0, timestamp++));
-  ++num_leaves;
-  for (int depth = 0; depth < param_.max_depth + 1; depth++) {
-    int starting_index = std::numeric_limits<int>::max();
-    int sync_count = 0;
-    std::vector<ExpandEntry> temp_qexpand_depth;
-    SplitSiblings(qexpand_depth_wise_, &nodes_for_explicit_hist_build_,
-                  &nodes_for_subtraction_trick_, p_tree);
-    hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, p_tree);
-    BuildLocalHistograms(gmat, gmatb, p_tree, gpair_h);
-    hist_synchronizer_->SyncHistograms(this, starting_index, sync_count, p_tree);
-    BuildNodeStats(gmat, p_fmat, p_tree, gpair_h);

-    EvaluateAndApplySplits(gmat, column_matrix, p_tree, &num_leaves, depth, &timestamp,
-                           &temp_qexpand_depth);

-    // clean up
-    qexpand_depth_wise_.clear();
-    nodes_for_subtraction_trick_.clear();
-    nodes_for_explicit_hist_build_.clear();
-    if (temp_qexpand_depth.empty()) {
-      break;
-    } else {
-      qexpand_depth_wise_ = temp_qexpand_depth;
-      temp_qexpand_depth.clear();
-    }
-  }
-}
-template<typename GradientSumT>
-void QuantileHistMaker::Builder<GradientSumT>::ExpandWithLossGuide(
-    const GHistIndexMatrix& gmat,
-    const GHistIndexBlockMatrix& gmatb,
-    const ColumnMatrix& column_matrix,
-    DMatrix* p_fmat,
-    RegTree* p_tree,
-    const std::vector<GradientPair>& gpair_h) {
-  builder_monitor_.Start("ExpandWithLossGuide");
-  unsigned timestamp = 0;
-  int num_leaves = 0;

-  ExpandEntry node(ExpandEntry::kRootNid, ExpandEntry::kEmptyNid,
-      p_tree->GetDepth(0), 0.0f, timestamp++);
-  BuildHistogramsLossGuide(node, gmat, gmatb, p_tree, gpair_h);

-  this->InitNewNode(ExpandEntry::kRootNid, gmat, gpair_h, *p_fmat, *p_tree);

-  this->EvaluateSplits({node}, gmat, hist_, *p_tree);
-  node.loss_chg = snode_[ExpandEntry::kRootNid].best.loss_chg;

-  qexpand_loss_guided_->push(node);
-  ++num_leaves;

-  while (!qexpand_loss_guided_->empty()) {
-    const ExpandEntry candidate = qexpand_loss_guided_->top();
-    const int nid = candidate.nid;
-    qexpand_loss_guided_->pop();
-    if (candidate.IsValid(param_, num_leaves)) {
-      (*p_tree)[nid].SetLeaf(snode_[nid].weight * param_.learning_rate);
-    } else {
-      auto evaluator = tree_evaluator_.GetEvaluator();
-      NodeEntry& e = snode_[nid];
-      bst_float left_leaf_weight =
-          evaluator.CalcWeight(nid, param_, GradStats{e.best.left_sum}) * param_.learning_rate;
-      bst_float right_leaf_weight =
-          evaluator.CalcWeight(nid, param_, GradStats{e.best.right_sum}) * param_.learning_rate;
-      p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value,
-                         e.best.DefaultLeft(), e.weight, left_leaf_weight,
-                         right_leaf_weight, e.best.loss_chg, e.stats.GetHess(),
-                         e.best.left_sum.GetHess(), e.best.right_sum.GetHess());
-      this->ApplySplit({candidate}, gmat, column_matrix, hist_, p_tree);

     const int cleft = (*p_tree)[nid].LeftChild();
     const int cright = (*p_tree)[nid].RightChild();

-    ExpandEntry left_node(cleft, cright, p_tree->GetDepth(cleft),
-                          0.0f, timestamp++);
-    ExpandEntry right_node(cright, cleft, p_tree->GetDepth(cright),
-                           0.0f, timestamp++);

+    const CPUExpandEntry left_node = CPUExpandEntry(cleft, p_tree->GetDepth(cleft), 0.0);
+    const CPUExpandEntry right_node = CPUExpandEntry(cright, p_tree->GetDepth(cright), 0.0);
+    nodes_to_evaluate->push_back(left_node);
+    nodes_to_evaluate->push_back(right_node);
     if (row_set_collection_[cleft].Size() < row_set_collection_[cright].Size()) {
-      BuildHistogramsLossGuide(left_node, gmat, gmatb, p_tree, gpair_h);
+      nodes_for_explicit_hist_build_.push_back(left_node);
+      nodes_for_subtraction_trick_.push_back(right_node);
     } else {
-      BuildHistogramsLossGuide(right_node, gmat, gmatb, p_tree, gpair_h);
+      nodes_for_explicit_hist_build_.push_back(right_node);
+      nodes_for_subtraction_trick_.push_back(left_node);
     }
   }
+  CHECK_EQ(nodes_for_subtraction_trick_.size(), nodes_for_explicit_hist_build_.size());
+  builder_monitor_.Stop("SplitSiblings");
+}

-    this->InitNewNode(cleft, gmat, gpair_h, *p_fmat, *p_tree);
-    this->InitNewNode(cright, gmat, gpair_h, *p_fmat, *p_tree);
+template <typename GradientSumT>
+void QuantileHistMaker::Builder<GradientSumT>::BuildNodeStats(
+    const GHistIndexMatrix &gmat,
+    const DMatrix& fmat,
+    const std::vector<GradientPair> &gpair_h,
+    const std::vector<CPUExpandEntry>& nodes_for_apply_split, RegTree *p_tree) {
+  for (auto const& candidate : nodes_for_apply_split) {
+    const int nid = candidate.nid;
+    const int cleft = (*p_tree)[nid].LeftChild();
+    const int cright = (*p_tree)[nid].RightChild();

+    InitNewNode(cleft, gmat, gpair_h, fmat, *p_tree);
+    InitNewNode(cright, gmat, gpair_h, fmat, *p_tree);
     bst_uint featureid = snode_[nid].best.SplitIndex();
     tree_evaluator_.AddSplit(nid, cleft, cright, featureid,
                              snode_[cleft].weight, snode_[cright].weight);
     interaction_constraints_.Split(nid, featureid, cleft, cright);
   }
 }

+template<typename GradientSumT>
+void QuantileHistMaker::Builder<GradientSumT>::ExpandTree(
+    const GHistIndexMatrix& gmat,
+    const GHistIndexBlockMatrix& gmatb,
+    const ColumnMatrix& column_matrix,
+    DMatrix* p_fmat,
+    RegTree* p_tree,
+    const std::vector<GradientPair>& gpair_h) {
+  builder_monitor_.Start("ExpandTree");
+  int num_leaves = 0;

+  Driver<CPUExpandEntry> driver(static_cast<TrainParam::TreeGrowPolicy>(param_.grow_policy));
+  std::vector<CPUExpandEntry> expand;
+  InitRoot(gmat, gmatb, *p_fmat, p_tree, gpair_h, &num_leaves, &expand);
+  driver.Push(expand[0]);

+  int depth = 0;
+  while (!driver.IsEmpty()) {
+    expand = driver.Pop();
+    depth = expand[0].depth + 1;
+    std::vector<CPUExpandEntry> nodes_for_apply_split;
+    std::vector<CPUExpandEntry> nodes_to_evaluate;
+    nodes_for_explicit_hist_build_.clear();
+    nodes_for_subtraction_trick_.clear();

+    AddSplitsToTree(expand, p_tree, &num_leaves, &nodes_for_apply_split);

+    if (nodes_for_apply_split.size() != 0) {
+      ApplySplit(nodes_for_apply_split, gmat, column_matrix, hist_, p_tree);
+      SplitSiblings(nodes_for_apply_split, &nodes_to_evaluate, p_tree);

+      int starting_index = std::numeric_limits<int>::max();
+      int sync_count = 0;
+      hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, p_tree);
+      if (depth < param_.max_depth) {
+        BuildLocalHistograms(gmat, gmatb, p_tree, gpair_h);
+        hist_synchronizer_->SyncHistograms(this, starting_index, sync_count, p_tree);
+      }

+      BuildNodeStats(gmat, *p_fmat, gpair_h, nodes_for_apply_split, p_tree);
+      EvaluateSplits(nodes_to_evaluate, gmat, hist_, *p_tree);

+      for (size_t i = 0; i < nodes_for_apply_split.size(); ++i) {
+        const CPUExpandEntry candidate = nodes_for_apply_split[i];
+        const int nid = candidate.nid;
+        const int cleft = (*p_tree)[nid].LeftChild();
+        const int cright = (*p_tree)[nid].RightChild();
+        CPUExpandEntry left_node = nodes_to_evaluate[i*2 + 0];
+        CPUExpandEntry right_node = nodes_to_evaluate[i*2 + 1];

-        this->EvaluateSplits({left_node, right_node}, gmat, hist_, *p_tree);
         left_node.loss_chg = snode_[cleft].best.loss_chg;
         right_node.loss_chg = snode_[cright].best.loss_chg;

-        qexpand_loss_guided_->push(left_node);
-        qexpand_loss_guided_->push(right_node);

         ++num_leaves;  // give two and take one, as parent is no longer a leaf
+        driver.Push(left_node);
+        driver.Push(right_node);
       }
     }
-  builder_monitor_.Stop("ExpandWithLossGuide");
-}
+  }
+  builder_monitor_.Stop("ExpandTree");
+}

@@ -604,11 +524,8 @@ void QuantileHistMaker::Builder<GradientSumT>::Update(
   p_last_fmat_mutable_ = p_fmat;

   this->InitData(gmat, *p_fmat, *p_tree, gpair_ptr);
-  if (param_.grow_policy == TrainParam::kLossGuide) {
-    ExpandWithLossGuide(gmat, gmatb, column_matrix, p_fmat, p_tree, *gpair_ptr);
-  } else {
-    ExpandWithDepthWise(gmat, gmatb, column_matrix, p_fmat, p_tree, *gpair_ptr);
-  }

+  ExpandTree(gmat, gmatb, column_matrix, p_fmat, p_tree, *gpair_ptr);

   for (int nid = 0; nid < p_tree->param.num_nodes; ++nid) {
     p_tree->Stat(nid).loss_chg = snode_[nid].best.loss_chg;
@@ -871,13 +788,7 @@ void QuantileHistMaker::Builder<GradientSumT>::InitData(const GHistIndexMatrix&
     snode_.reserve(256);
     snode_.clear();
   }
-  {
-    if (param_.grow_policy == TrainParam::kLossGuide) {
-      qexpand_loss_guided_.reset(new ExpandQueue(LossGuide));
-    } else {
-      qexpand_depth_wise_.clear();
-    }
-  }

   builder_monitor_.Stop("InitData");
 }

@@ -898,7 +809,7 @@ bool QuantileHistMaker::Builder<GradientSumT>::SplitContainsMissingValues(
 // nodes_set - set of nodes to be processed in parallel
 template<typename GradientSumT>
 void QuantileHistMaker::Builder<GradientSumT>::EvaluateSplits(
-    const std::vector<ExpandEntry>& nodes_set,
+    const std::vector<CPUExpandEntry>& nodes_set,
     const GHistIndexMatrix& gmat,
    const HistCollection<GradientSumT>& hist,
     const RegTree& tree) {
@@ -1123,7 +1034,7 @@ void QuantileHistMaker::Builder<GradientSumT>::PartitionKernel(

 template <typename GradientSumT>
 void QuantileHistMaker::Builder<GradientSumT>::FindSplitConditions(
-    const std::vector<ExpandEntry>& nodes,
+    const std::vector<CPUExpandEntry>& nodes,
     const RegTree& tree,
     const GHistIndexMatrix& gmat,
     std::vector<int32_t>* split_conditions) {
@@ -1151,7 +1062,7 @@ void QuantileHistMaker::Builder<GradientSumT>::FindSplitConditions(
 }
 template <typename GradientSumT>
 void QuantileHistMaker::Builder<GradientSumT>::AddSplitsToRowSet(
-    const std::vector<ExpandEntry>& nodes,
+    const std::vector<CPUExpandEntry>& nodes,
     RegTree* p_tree) {
   const size_t n_nodes = nodes.size();
   for (unsigned int i = 0; i < n_nodes; ++i) {
@@ -1165,7 +1076,7 @@ void QuantileHistMaker::Builder<GradientSumT>::AddSplitsToRowSet(
 }

 template <typename GradientSumT>
-void QuantileHistMaker::Builder<GradientSumT>::ApplySplit(const std::vector<ExpandEntry> nodes,
+void QuantileHistMaker::Builder<GradientSumT>::ApplySplit(const std::vector<CPUExpandEntry> nodes,
                                                           const GHistIndexMatrix& gmat,
                                                           const ColumnMatrix& column_matrix,
                                                           const HistCollection<GradientSumT>& hist,
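The SyncHistograms and ParallelSubtractionHist changes above all revolve around the "Subtraction Trick" used throughout this file: a parent's histogram is the bin-wise sum of its two children's, so after the smaller child's histogram is built explicitly, the sibling's can be derived by a subtraction instead of a second pass over the rows. A minimal sketch of that identity, assuming a flat one-value-per-bin histogram rather than the real GHistRow type:

#include <cstddef>
#include <vector>

// sibling[i] = parent[i] - built[i] for every bin in [begin, end),
// mirroring what SubtractionHist does for the node picked by SplitSiblings.
void SubtractionHistSketch(std::vector<double>* sibling,
                           const std::vector<double>& parent,
                           const std::vector<double>& built,
                           std::size_t begin, std::size_t end) {
  sibling->resize(parent.size());
  for (std::size_t i = begin; i < end; ++i) {
    (*sibling)[i] = parent[i] - built[i];
  }
}

This is also why the commit can drop the sibling_nid field from the entry type: the pairing between explicitly built nodes and subtraction nodes is now carried by the two parallel vectors passed into ParallelSubtractionHist.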
src/tree/updater_quantile_hist.h

@@ -22,6 +22,7 @@
 #include "xgboost/json.h"
 #include "constraints.h"
 #include "./param.h"
+#include "./driver.h"
 #include "./split_evaluator.h"
 #include "../common/random.h"
 #include "../common/timer.h"
@@ -148,6 +149,36 @@ struct CPUHistMakerTrainParam
   }
 };

+/* tree growing policies */
+struct CPUExpandEntry {
+  static const int kRootNid = 0;
+  static const int kEmptyNid = -1;
+  int nid;
+  int depth;
+  bst_float loss_chg;
+  CPUExpandEntry(int nid, int depth, bst_float loss_chg)
+      : nid(nid), depth(depth), loss_chg(loss_chg) {}
+
+  bool IsValid(TrainParam const &param, int32_t num_leaves) const {
+    bool ret = loss_chg <= kRtEps ||
+               (param.max_depth > 0 && this->depth == param.max_depth) ||
+               (param.max_leaves > 0 && num_leaves == param.max_leaves);
+    return ret;
+  }
+
+  bst_float GetLossChange() const {
+    return loss_chg;
+  }
+
+  int GetNodeId() const {
+    return nid;
+  }
+
+  int GetDepth() const {
+    return depth;
+  }
+};
+
 /*! \brief construct a tree using quantized feature values */
 class QuantileHistMaker: public TreeUpdater {
  public:
@@ -299,28 +330,6 @@ class QuantileHistMaker: public TreeUpdater {
   friend class BatchHistRowsAdder<GradientSumT>;
   friend class DistributedHistRowsAdder<GradientSumT>;

-  /* tree growing policies */
-  struct ExpandEntry {
-    static const int kRootNid = 0;
-    static const int kEmptyNid = -1;
-    int nid;
-    int sibling_nid;
-    int depth;
-    bst_float loss_chg;
-    unsigned timestamp;
-    ExpandEntry(int nid, int sibling_nid, int depth, bst_float loss_chg,
-                unsigned tstmp)
-        : nid(nid), sibling_nid(sibling_nid), depth(depth),
-          loss_chg(loss_chg), timestamp(tstmp) {}
-
-    bool IsValid(TrainParam const &param, int32_t num_leaves) const {
-      bool ret = loss_chg <= kRtEps ||
-                 (param.max_depth > 0 && this->depth == param.max_depth) ||
-                 (param.max_leaves > 0 && num_leaves == param.max_leaves);
-      return ret;
-    }
-  };

   // initialize temp data structure
   void InitData(const GHistIndexMatrix& gmat,
                 const DMatrix& fmat,
@@ -333,12 +342,12 @@ class QuantileHistMaker: public TreeUpdater {
                 std::vector<GradientPair>* gpair,
                 std::vector<size_t>* row_indices);

-  void EvaluateSplits(const std::vector<ExpandEntry>& nodes_set,
+  void EvaluateSplits(const std::vector<CPUExpandEntry>& nodes_set,
                       const GHistIndexMatrix& gmat,
                       const HistCollection<GradientSumT>& hist,
                       const RegTree& tree);

-  void ApplySplit(std::vector<ExpandEntry> nodes,
+  void ApplySplit(std::vector<CPUExpandEntry> nodes,
                   const GHistIndexMatrix& gmat,
                   const ColumnMatrix& column_matrix,
                   const HistCollection<GradientSumT>& hist,
@@ -349,10 +358,10 @@ class QuantileHistMaker: public TreeUpdater {
                        const int32_t split_cond,
                        const ColumnMatrix& column_matrix, const RegTree& tree);

-  void AddSplitsToRowSet(const std::vector<ExpandEntry>& nodes, RegTree* p_tree);
+  void AddSplitsToRowSet(const std::vector<CPUExpandEntry>& nodes, RegTree* p_tree);


-  void FindSplitConditions(const std::vector<ExpandEntry>& nodes, const RegTree& tree,
+  void FindSplitConditions(const std::vector<CPUExpandEntry>& nodes, const RegTree& tree,
                            const GHistIndexMatrix& gmat, std::vector<int32_t>* split_conditions);

   void InitNewNode(int nid,
@@ -365,8 +374,7 @@ class QuantileHistMaker: public TreeUpdater {
   // Returns the sum of gradients corresponding to the data points that contains a non-missing
   // value for the particular feature fid.
   template <int d_step>
-  GradStats EnumerateSplit(
-      const GHistIndexMatrix &gmat, const GHistRowT &hist,
+  GradStats EnumerateSplit(const GHistIndexMatrix &gmat, const GHistRowT &hist,
                            const NodeEntry &snode, SplitEntry *p_best, bst_uint fid,
                            bst_uint nodeID,
                            TreeEvaluator::SplitEvaluator<TrainParam> const &evaluator) const;
@@ -377,73 +385,42 @@ class QuantileHistMaker: public TreeUpdater {
   // else - there are missing values
   bool SplitContainsMissingValues(const GradStats e, const NodeEntry& snode);

-  void ExpandWithDepthWise(const GHistIndexMatrix &gmat,
-                           const GHistIndexBlockMatrix &gmatb,
-                           const ColumnMatrix &column_matrix,
-                           DMatrix *p_fmat,
-                           RegTree *p_tree,
-                           const std::vector<GradientPair> &gpair_h);

   void BuildLocalHistograms(const GHistIndexMatrix &gmat,
                             const GHistIndexBlockMatrix &gmatb,
                             RegTree *p_tree,
                             const std::vector<GradientPair> &gpair_h);

-  void BuildHistogramsLossGuide(
-      ExpandEntry entry,
-      const GHistIndexMatrix &gmat,
+  void InitRoot(const GHistIndexMatrix &gmat,
                 const GHistIndexBlockMatrix &gmatb,
+                const DMatrix& fmat,
                 RegTree *p_tree,
-                const std::vector<GradientPair> &gpair_h);
+                const std::vector<GradientPair> &gpair_h,
+                int *num_leaves, std::vector<CPUExpandEntry> *expand);

   // Split nodes to 2 sets depending on amount of rows in each node
   // Histograms for small nodes will be built explicitly
   // Histograms for big nodes will be built by 'Subtraction Trick'
-  void SplitSiblings(const std::vector<ExpandEntry>& nodes,
-                     std::vector<ExpandEntry>* small_siblings,
-                     std::vector<ExpandEntry>* big_siblings,
+  void SplitSiblings(const std::vector<CPUExpandEntry>& nodes,
+                     std::vector<CPUExpandEntry>* nodes_to_evaluate,
                      RegTree *p_tree);

-  void ParallelSubtractionHist(const common::BlockedSpace2d& space,
-                               const std::vector<ExpandEntry>& nodes,
-                               const RegTree * p_tree);
+  void AddSplitsToTree(const std::vector<CPUExpandEntry>& expand,
+                       RegTree *p_tree,
+                       int *num_leaves,
+                       std::vector<CPUExpandEntry>* nodes_for_apply_split);

   void BuildNodeStats(const GHistIndexMatrix &gmat,
-                      DMatrix *p_fmat,
-                      RegTree *p_tree,
-                      const std::vector<GradientPair> &gpair_h);
+                      const DMatrix& fmat,
+                      const std::vector<GradientPair> &gpair_h,
+                      const std::vector<CPUExpandEntry>& nodes_for_apply_split, RegTree *p_tree);

-  void EvaluateAndApplySplits(const GHistIndexMatrix &gmat,
-                              const ColumnMatrix &column_matrix,
-                              RegTree *p_tree,
-                              int *num_leaves,
-                              int depth,
-                              unsigned *timestamp,
-                              std::vector<ExpandEntry> *temp_qexpand_depth);

-  void AddSplitsToTree(
-      const GHistIndexMatrix &gmat,
-      RegTree *p_tree,
-      int *num_leaves,
-      int depth,
-      unsigned *timestamp,
-      std::vector<ExpandEntry>* nodes_for_apply_split,
-      std::vector<ExpandEntry>* temp_qexpand_depth);

-  void ExpandWithLossGuide(const GHistIndexMatrix& gmat,
+  void ExpandTree(const GHistIndexMatrix& gmat,
                   const GHistIndexBlockMatrix& gmatb,
                   const ColumnMatrix& column_matrix,
                   DMatrix* p_fmat,
                   RegTree* p_tree,
                   const std::vector<GradientPair>& gpair_h);

-  inline static bool LossGuide(ExpandEntry lhs, ExpandEntry rhs) {
-    if (lhs.loss_chg == rhs.loss_chg) {
-      return lhs.timestamp > rhs.timestamp;  // favor small timestamp
-    } else {
-      return lhs.loss_chg < rhs.loss_chg;  // favor large loss_chg
-    }
-  }
   // --data fields--
   const size_t n_trees_;
   const TrainParam& param_;
@@ -484,16 +461,14 @@ class QuantileHistMaker: public TreeUpdater {
   DMatrix* p_last_fmat_mutable_;

   using ExpandQueue =
-      std::priority_queue<ExpandEntry, std::vector<ExpandEntry>,
-                          std::function<bool(ExpandEntry, ExpandEntry)>>;
+      std::priority_queue<CPUExpandEntry, std::vector<CPUExpandEntry>,
+                          std::function<bool(CPUExpandEntry, CPUExpandEntry)>>;

-  std::unique_ptr<ExpandQueue> qexpand_loss_guided_;
-  std::vector<ExpandEntry> qexpand_depth_wise_;
   // key is the node id which should be calculated by Subtraction Trick, value is the node which
   // provides the evidence for subtraction
-  std::vector<ExpandEntry> nodes_for_subtraction_trick_;
+  std::vector<CPUExpandEntry> nodes_for_subtraction_trick_;
   // list of nodes whose histograms would be built explicitly.
-  std::vector<ExpandEntry> nodes_for_explicit_hist_build_;
+  std::vector<CPUExpandEntry> nodes_for_explicit_hist_build_;

   enum class DataLayout { kDenseDataZeroBased, kDenseDataOneBased, kSparseData };
   DataLayout data_layout_;
@@ -549,14 +524,14 @@ template <typename GradientSumT>
 class DistributedHistSynchronizer: public HistSynchronizer<GradientSumT> {
  public:
   using BuilderT = QuantileHistMaker::Builder<GradientSumT>;
-  using ExpandEntryT = typename BuilderT::ExpandEntry;

   void SyncHistograms(BuilderT* builder, int starting_index,
                       int sync_count, RegTree *p_tree) override;

   void ParallelSubtractionHist(BuilderT* builder,
                                const common::BlockedSpace2d& space,
-                               const std::vector<ExpandEntryT>& nodes,
+                               const std::vector<CPUExpandEntry>& nodes,
+                               const std::vector<CPUExpandEntry>& subtraction_nodes,
                                const RegTree * p_tree);
 };
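Before the test diffs, one behavior they pin down is the tie-breaking in the shared LossGuide comparator: candidates with larger loss change come first, and on equal loss change the smaller node id (the earlier-created node) wins. A small, self-contained sketch of that ordering with an assumed Candidate type (not code from this commit):

#include <cassert>
#include <functional>
#include <queue>
#include <vector>

struct Candidate {
  int nid;
  float loss_chg;
};

// Same ordering as LossGuide in src/tree/driver.h: larger gain first,
// smaller node id on ties.
inline bool LossGuideOrder(const Candidate& lhs, const Candidate& rhs) {
  if (lhs.loss_chg == rhs.loss_chg) {
    return lhs.nid > rhs.nid;            // favor the earlier (smaller-id) node
  }
  return lhs.loss_chg < rhs.loss_chg;    // favor larger loss_chg
}

int main() {
  std::priority_queue<Candidate, std::vector<Candidate>,
                      std::function<bool(Candidate, Candidate)>> q(LossGuideOrder);
  q.push(Candidate{1, 0.5f});
  q.push(Candidate{2, 0.5f});  // equal gain: node 1 should come out first
  q.push(Candidate{3, 2.0f});  // highest gain: comes out first overall
  assert(q.top().nid == 3); q.pop();
  assert(q.top().nid == 1); q.pop();
  assert(q.top().nid == 2); q.pop();
  return 0;
}

This is exactly the ordering the DriverLossGuided test below exercises with its equal-gain pushes.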
tests/cpp/tree/gpu_hist/test_driver.cu

@@ -1,20 +1,21 @@
 #include <gtest/gtest.h>
-#include "../../../../src/tree/gpu_hist/driver.cuh"
+#include "../../../../src/tree/driver.h"
+#include "../../../../src/tree/gpu_hist/expand_entry.cuh"

 namespace xgboost {
 namespace tree {

 TEST(GpuHist, DriverDepthWise) {
-  Driver driver(TrainParam::kDepthWise);
+  Driver<GPUExpandEntry> driver(TrainParam::kDepthWise);
   EXPECT_TRUE(driver.Pop().empty());
   DeviceSplitCandidate split;
   split.loss_chg = 1.0f;
-  ExpandEntry root(0, 0, split, .0f, .0f, .0f);
+  GPUExpandEntry root(0, 0, split, .0f, .0f, .0f);
   driver.Push({root});
   EXPECT_EQ(driver.Pop().front().nid, 0);
-  driver.Push({ExpandEntry{1, 1, split, .0f, .0f, .0f}});
-  driver.Push({ExpandEntry{2, 1, split, .0f, .0f, .0f}});
-  driver.Push({ExpandEntry{3, 2, split, .0f, .0f, .0f}});
+  driver.Push({GPUExpandEntry{1, 1, split, .0f, .0f, .0f}});
+  driver.Push({GPUExpandEntry{2, 1, split, .0f, .0f, .0f}});
+  driver.Push({GPUExpandEntry{3, 2, split, .0f, .0f, .0f}});
   // Should return entries from level 1
   auto res = driver.Pop();
   EXPECT_EQ(res.size(), 2);
@@ -32,14 +33,14 @@ TEST(GpuHist, DriverLossGuided) {
   DeviceSplitCandidate low_gain;
   low_gain.loss_chg = 1.0f;

-  Driver driver(TrainParam::kLossGuide);
+  Driver<GPUExpandEntry> driver(TrainParam::kLossGuide);
   EXPECT_TRUE(driver.Pop().empty());
-  ExpandEntry root(0, 0, high_gain, .0f, .0f, .0f);
+  GPUExpandEntry root(0, 0, high_gain, .0f, .0f, .0f);
   driver.Push({root});
   EXPECT_EQ(driver.Pop().front().nid, 0);
   // Select high gain first
-  driver.Push({ExpandEntry{1, 1, low_gain, .0f, .0f, .0f}});
-  driver.Push({ExpandEntry{2, 2, high_gain, .0f, .0f, .0f}});
+  driver.Push({GPUExpandEntry{1, 1, low_gain, .0f, .0f, .0f}});
+  driver.Push({GPUExpandEntry{2, 2, high_gain, .0f, .0f, .0f}});
   auto res = driver.Pop();
   EXPECT_EQ(res.size(), 1);
   EXPECT_EQ(res[0].nid, 2);
@@ -48,8 +49,8 @@ TEST(GpuHist, DriverLossGuided) {
   EXPECT_EQ(res[0].nid, 1);

   // If equal gain, use nid
-  driver.Push({ExpandEntry{2, 1, low_gain, .0f, .0f, .0f}});
-  driver.Push({ExpandEntry{1, 1, low_gain, .0f, .0f, .0f}});
+  driver.Push({GPUExpandEntry{2, 1, low_gain, .0f, .0f, .0f}});
+  driver.Push({GPUExpandEntry{1, 1, low_gain, .0f, .0f, .0f}});
   res = driver.Pop();
   EXPECT_EQ(res[0].nid, 1);
   res = driver.Pop();

tests/cpp/tree/test_gpu_hist.cu

@@ -132,7 +132,7 @@ TEST(GpuHist, BuildHistSharedMem) {

 TEST(GpuHist, ApplySplit) {
   RegTree tree;
-  ExpandEntry candidate;
+  GPUExpandEntry candidate;
   candidate.nid = 0;
   candidate.left_weight = 1.0f;
   candidate.right_weight = 2.0f;

tests/cpp/tree/test_quantile_hist.cc

@@ -24,7 +24,6 @@ class QuantileHistMock : public QuantileHistMaker {
 template <typename GradientSumT>
 struct BuilderMock : public QuantileHistMaker::Builder<GradientSumT> {
   using RealImpl = QuantileHistMaker::Builder<GradientSumT>;
-  using ExpandEntryT = typename RealImpl::ExpandEntry;
   using GHistRowT = typename RealImpl::GHistRowT;

   BuilderMock(const TrainParam& param,
@@ -169,19 +168,19 @@ class QuantileHistMock : public QuantileHistMaker {
     tree->ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
     tree->ExpandNode((*tree)[0].LeftChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
     tree->ExpandNode((*tree)[0].RightChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
-    this->nodes_for_explicit_hist_build_.emplace_back(3, 4, tree->GetDepth(3), 0.0f, 0);
-    this->nodes_for_explicit_hist_build_.emplace_back(4, 3, tree->GetDepth(4), 0.0f, 0);
-    this->nodes_for_subtraction_trick_.emplace_back(5, 6, tree->GetDepth(5), 0.0f, 0);
-    this->nodes_for_subtraction_trick_.emplace_back(6, 5, tree->GetDepth(6), 0.0f, 0);
+    this->nodes_for_explicit_hist_build_.emplace_back(3, tree->GetDepth(3), 0.0f);
+    this->nodes_for_explicit_hist_build_.emplace_back(4, tree->GetDepth(4), 0.0f);
+    this->nodes_for_subtraction_trick_.emplace_back(5, tree->GetDepth(5), 0.0f);
+    this->nodes_for_subtraction_trick_.emplace_back(6, tree->GetDepth(6), 0.0f);

     this->hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree);
     ASSERT_EQ(sync_count, 2);
     ASSERT_EQ(starting_index, 3);

-    for (const ExpandEntryT& node : this->nodes_for_explicit_hist_build_) {
+    for (const CPUExpandEntry& node : this->nodes_for_explicit_hist_build_) {
       ASSERT_EQ(this->hist_.RowExists(node.nid), true);
     }
-    for (const ExpandEntryT& node : this->nodes_for_subtraction_trick_) {
+    for (const CPUExpandEntry& node : this->nodes_for_subtraction_trick_) {
       ASSERT_EQ(this->hist_.RowExists(node.nid), true);
     }
   }
@@ -199,7 +198,7 @@ class QuantileHistMock : public QuantileHistMaker {
     this->nodes_for_explicit_hist_build_.clear();
     this->nodes_for_subtraction_trick_.clear();
     // level 0
-    this->nodes_for_explicit_hist_build_.emplace_back(0, -1, tree->GetDepth(0), 0.0f, 0);
+    this->nodes_for_explicit_hist_build_.emplace_back(0, tree->GetDepth(0), 0.0f);
     this->hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree);
     tree->ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0);

@@ -207,11 +206,9 @@ class QuantileHistMock : public QuantileHistMaker {
     this->nodes_for_subtraction_trick_.clear();
     // level 1
     this->nodes_for_explicit_hist_build_.emplace_back((*tree)[0].LeftChild(),
-                                                      (*tree)[0].RightChild(),
-                                                      tree->GetDepth(1), 0.0f, 0);
+                                                      tree->GetDepth(1), 0.0f);
     this->nodes_for_subtraction_trick_.emplace_back((*tree)[0].RightChild(),
-                                                    (*tree)[0].LeftChild(),
-                                                    tree->GetDepth(2), 0.0f, 0);
+                                                    tree->GetDepth(2), 0.0f);
     this->hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree);
     tree->ExpandNode((*tree)[0].LeftChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
     tree->ExpandNode((*tree)[0].RightChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
@@ -219,10 +216,10 @@ class QuantileHistMock : public QuantileHistMaker {
     this->nodes_for_explicit_hist_build_.clear();
     this->nodes_for_subtraction_trick_.clear();
     // level 2
-    this->nodes_for_explicit_hist_build_.emplace_back(3, 4, tree->GetDepth(3), 0.0f, 0);
-    this->nodes_for_subtraction_trick_.emplace_back(4, 3, tree->GetDepth(4), 0.0f, 0);
-    this->nodes_for_explicit_hist_build_.emplace_back(5, 6, tree->GetDepth(5), 0.0f, 0);
-    this->nodes_for_subtraction_trick_.emplace_back(6, 5, tree->GetDepth(6), 0.0f, 0);
+    this->nodes_for_explicit_hist_build_.emplace_back(3, tree->GetDepth(3), 0.0f);
+    this->nodes_for_subtraction_trick_.emplace_back(4, tree->GetDepth(4), 0.0f);
+    this->nodes_for_explicit_hist_build_.emplace_back(5, tree->GetDepth(5), 0.0f);
+    this->nodes_for_subtraction_trick_.emplace_back(6, tree->GetDepth(6), 0.0f);
     this->hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree);

     const size_t n_nodes = this->nodes_for_explicit_hist_build_.size();
@@ -278,21 +275,27 @@ class QuantileHistMock : public QuantileHistMaker {
         ASSERT_EQ(p_parent[i], p_left[i] + p_right[i]);
       }
     };
-    for (const ExpandEntryT& node : this->nodes_for_explicit_hist_build_) {
+    size_t node_id = 0;
+    for (const CPUExpandEntry& node : this->nodes_for_explicit_hist_build_) {
       auto this_hist = this->hist_[node.nid];
       const size_t parent_id = (*tree)[node.nid].Parent();
+      const size_t subtraction_node_id = this->nodes_for_subtraction_trick_[node_id].nid;
       auto parent_hist = this->hist_[parent_id];
-      auto sibling_hist = this->hist_[node.sibling_nid];
+      auto sibling_hist = this->hist_[subtraction_node_id];

       check_hist(parent_hist, this_hist, sibling_hist, 0, nbins);
+      ++node_id;
     }
-    for (const ExpandEntryT& node : this->nodes_for_subtraction_trick_) {
+    node_id = 0;
+    for (const CPUExpandEntry& node : this->nodes_for_subtraction_trick_) {
       auto this_hist = this->hist_[node.nid];
       const size_t parent_id = (*tree)[node.nid].Parent();
+      const size_t subtraction_node_id = this->nodes_for_explicit_hist_build_[node_id].nid;
       auto parent_hist = this->hist_[parent_id];
-      auto sibling_hist = this->hist_[node.sibling_nid];
+      auto sibling_hist = this->hist_[subtraction_node_id];

       check_hist(parent_hist, this_hist, sibling_hist, 0, nbins);
+      ++node_id;
     }
   }

@@ -408,10 +411,9 @@ class QuantileHistMock : public QuantileHistMaker {
     }

     /* Now compare against result given by EvaluateSplit() */
-    typename RealImpl::ExpandEntry node(RealImpl::ExpandEntry::kRootNid,
-                                        RealImpl::ExpandEntry::kEmptyNid,
+    CPUExpandEntry node(CPUExpandEntry::kRootNid,
                         tree.GetDepth(0),
-                        this->snode_[0].best.loss_chg, 0);
+                        this->snode_[0].best.loss_chg);
     RealImpl::EvaluateSplits({node}, gmat, this->hist_, tree);
     ASSERT_EQ(this->snode_[0].best.SplitIndex(), best_split_feature);
     ASSERT_EQ(this->snode_[0].best.split_value, gmat.cut.Values()[best_split_threshold]);