Reduce device synchronisation (#5631)

* Reduce device synchronisation
* Initialise pinned memory

parent 9910265064 · commit fcf57823b6
@@ -503,6 +503,15 @@ struct PinnedMemory {
     return xgboost::common::Span<T>(static_cast<T *>(temp_storage), size);
   }
 
+  template <typename T>
+  xgboost::common::Span<T> GetSpan(size_t size, T init) {
+    auto result = this->GetSpan<T>(size);
+    for (auto &e : result) {
+      e = init;
+    }
+    return result;
+  }
+
   void Free() {
     if (temp_storage != nullptr) {
       safe_cuda(cudaFreeHost(temp_storage));
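
The new overload is the "Initialise pinned memory" half of this commit: it hands out a page-locked span whose elements all start from a known value, so slots that no kernel write ever reaches still read as a valid default rather than stale bytes from an earlier reuse of the same allocation. A minimal usage sketch (the `counts` name is illustrative, not from the diff):

    dh::PinnedMemory pinned;
    // Two page-locked int64_t slots, both set to 0 up front.
    auto counts = pinned.GetSpan<int64_t>(2, 0);
    // An async device-to-host copy can now land in counts[0] while
    // counts[1], if never written, still safely reads as 0.
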

src/tree/gpu_hist/driver.cuh (new file, 120 lines)
@@ -0,0 +1,120 @@
+/*!
+ * Copyright 2020 by XGBoost Contributors
+ */
+#ifndef DRIVER_CUH_
+#define DRIVER_CUH_
+#include <xgboost/span.h>
+#include <queue>
+#include "../param.h"
+#include "evaluate_splits.cuh"
+
+namespace xgboost {
+namespace tree {
+struct ExpandEntry {
+  int nid;
+  int depth;
+  DeviceSplitCandidate split;
+  ExpandEntry() = default;
+  XGBOOST_DEVICE ExpandEntry(int nid, int depth, DeviceSplitCandidate split)
+      : nid(nid), depth(depth), split(std::move(split)) {}
+  bool IsValid(const TrainParam& param, int num_leaves) const {
+    if (split.loss_chg <= kRtEps) return false;
+    if (split.left_sum.GetHess() == 0 || split.right_sum.GetHess() == 0) {
+      return false;
+    }
+    if (split.loss_chg < param.min_split_loss) {
+      return false;
+    }
+    if (param.max_depth > 0 && depth == param.max_depth) {
+      return false;
+    }
+    if (param.max_leaves > 0 && num_leaves == param.max_leaves) {
+      return false;
+    }
+    return true;
+  }
+
+  static bool ChildIsValid(const TrainParam& param, int depth, int num_leaves) {
+    if (param.max_depth > 0 && depth >= param.max_depth) return false;
+    if (param.max_leaves > 0 && num_leaves >= param.max_leaves) return false;
+    return true;
+  }
+
+  friend std::ostream& operator<<(std::ostream& os, const ExpandEntry& e) {
+    os << "ExpandEntry: \n";
+    os << "nidx: " << e.nid << "\n";
+    os << "depth: " << e.depth << "\n";
+    os << "loss: " << e.split.loss_chg << "\n";
+    os << "left_sum: " << e.split.left_sum << "\n";
+    os << "right_sum: " << e.split.right_sum << "\n";
+    return os;
+  }
+};
+
+inline bool DepthWise(const ExpandEntry& lhs, const ExpandEntry& rhs) {
+  return lhs.depth > rhs.depth;  // favor small depth
+}
+
+inline bool LossGuide(const ExpandEntry& lhs, const ExpandEntry& rhs) {
+  if (lhs.split.loss_chg == rhs.split.loss_chg) {
+    return lhs.nid > rhs.nid;  // favor small nid
+  } else {
+    return lhs.split.loss_chg < rhs.split.loss_chg;  // favor large loss_chg
+  }
+}
+
+// Drives execution of tree building on device
+class Driver {
+  using ExpandQueue =
+      std::priority_queue<ExpandEntry, std::vector<ExpandEntry>,
+                          std::function<bool(ExpandEntry, ExpandEntry)>>;
+
+ public:
+  explicit Driver(TrainParam::TreeGrowPolicy policy)
+      : policy_(policy),
+        queue_(policy == TrainParam::kDepthWise ? DepthWise : LossGuide) {}
+  template <typename EntryIterT>
+  void Push(EntryIterT begin, EntryIterT end) {
+    for (auto it = begin; it != end; ++it) {
+      const ExpandEntry& e = *it;
+      if (e.split.loss_chg > kRtEps) {
+        queue_.push(e);
+      }
+    }
+  }
+  void Push(const std::vector<ExpandEntry> &entries) {
+    this->Push(entries.begin(), entries.end());
+  }
+  // Return the set of nodes to be expanded
+  // This set has no dependencies between entries so they may be expanded in
+  // parallel or asynchronously
+  std::vector<ExpandEntry> Pop() {
+    if (queue_.empty()) return {};
+    // Return a single entry for loss guided mode
+    if (policy_ == TrainParam::kLossGuide) {
+      ExpandEntry e = queue_.top();
+      queue_.pop();
+      return {e};
+    }
+    // Return nodes on same level for depth wise
+    std::vector<ExpandEntry> result;
+    ExpandEntry e = queue_.top();
+    int level = e.depth;
+    while (e.depth == level && !queue_.empty()) {
+      queue_.pop();
+      result.emplace_back(e);
+      if (!queue_.empty()) {
+        e = queue_.top();
+      }
+    }
+    return result;
+  }
+
+ private:
+  TrainParam::TreeGrowPolicy policy_;
+  ExpandQueue queue_;
+};
+}  // namespace tree
+}  // namespace xgboost
+
+#endif  // DRIVER_CUH_
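
The comment on Pop() is the contract that makes the synchronisation savings possible: a whole batch can be expanded before the next host/device round-trip. A rough sketch of the intended driving loop, where `MakeChildEntries` and `root_entry` are hypothetical stand-ins for the per-candidate work and the initial split:

    Driver driver(TrainParam::kDepthWise);
    driver.Push({root_entry});  // seed with the root split
    for (auto expand_set = driver.Pop(); !expand_set.empty();
         expand_set = driver.Pop()) {
      std::vector<ExpandEntry> next;
      for (const auto& candidate : expand_set) {
        // Evaluate both children of the applied split; Push() silently
        // drops entries whose loss_chg <= kRtEps.
        MakeChildEntries(candidate, &next);
      }
      driver.Push(next);
    }
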
@@ -61,6 +61,7 @@ class RowPartitioner {
   dh::caching_device_vector<int64_t>
       left_counts_;  // Useful to keep a bunch of zeroed memory for sort position
   std::vector<cudaStream_t> streams_;
+  dh::PinnedMemory pinned_;
 
  public:
   RowPartitioner(int device_idx, size_t num_rows);
@@ -129,12 +130,12 @@ class RowPartitioner {
       d_position[idx] = new_position;
     });
     // Overlap device to host memory copy (left_count) with sort
-    int64_t left_count;
+    int64_t &left_count = pinned_.GetSpan<int64_t>(1)[0];
     dh::safe_cuda(cudaMemcpyAsync(&left_count, d_left_count, sizeof(int64_t),
                                   cudaMemcpyDeviceToHost, streams_[0]));
 
     SortPositionAndCopy(segment, left_nidx, right_nidx, d_left_count,
                         streams_[1]);
 
     dh::safe_cuda(cudaStreamSynchronize(streams_[0]));
     CHECK_LE(left_count, segment.Size());
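
Moving `left_count` from the stack into `pinned_` is what makes the cudaMemcpyAsync above genuinely overlap with the sort: a device-to-host cudaMemcpyAsync into pageable host memory is permitted to block the calling thread, whereas a page-locked destination lets it return immediately. Schematically, in plain CUDA (d_src and stream assumed to exist; not diff code):

    int64_t *page_locked = nullptr;
    cudaMallocHost(&page_locked, sizeof(int64_t));   // pinned host allocation
    cudaMemcpyAsync(page_locked, d_src, sizeof(int64_t),
                    cudaMemcpyDeviceToHost, stream); // returns immediately
    // ... queue independent work on other streams here ...
    cudaStreamSynchronize(stream);                   // *page_locked valid now
    cudaFreeHost(page_locked);
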

@@ -30,6 +30,7 @@
 #include "gpu_hist/row_partitioner.cuh"
 #include "gpu_hist/histogram.cuh"
 #include "gpu_hist/evaluate_splits.cuh"
+#include "gpu_hist/driver.cuh"
 
 namespace xgboost {
 namespace tree {
@@ -57,58 +58,6 @@ struct GPUHistMakerTrainParam
 DMLC_REGISTER_PARAMETER(GPUHistMakerTrainParam);
 #endif  // !defined(GTEST_TEST)
 
-struct ExpandEntry {
-  int nid;
-  int depth;
-  DeviceSplitCandidate split;
-  uint64_t timestamp;
-  ExpandEntry() = default;
-  ExpandEntry(int nid, int depth, DeviceSplitCandidate split,
-              uint64_t timestamp)
-      : nid(nid), depth(depth), split(std::move(split)), timestamp(timestamp) {}
-  bool IsValid(const TrainParam& param, int num_leaves) const {
-    if (split.loss_chg <= kRtEps) return false;
-    if (split.left_sum.GetHess() == 0 || split.right_sum.GetHess() == 0) {
-      return false;
-    }
-    if (split.loss_chg < param.min_split_loss) { return false; }
-    if (param.max_depth > 0 && depth == param.max_depth) {return false; }
-    if (param.max_leaves > 0 && num_leaves == param.max_leaves) { return false; }
-    return true;
-  }
-
-  static bool ChildIsValid(const TrainParam& param, int depth, int num_leaves) {
-    if (param.max_depth > 0 && depth >= param.max_depth) return false;
-    if (param.max_leaves > 0 && num_leaves >= param.max_leaves) return false;
-    return true;
-  }
-
-  friend std::ostream& operator<<(std::ostream& os, const ExpandEntry& e) {
-    os << "ExpandEntry: \n";
-    os << "nidx: " << e.nid << "\n";
-    os << "depth: " << e.depth << "\n";
-    os << "loss: " << e.split.loss_chg << "\n";
-    os << "left_sum: " << e.split.left_sum << "\n";
-    os << "right_sum: " << e.split.right_sum << "\n";
-    return os;
-  }
-};
-
-inline static bool DepthWise(const ExpandEntry& lhs, const ExpandEntry& rhs) {
-  if (lhs.depth == rhs.depth) {
-    return lhs.timestamp > rhs.timestamp;  // favor small timestamp
-  } else {
-    return lhs.depth > rhs.depth;  // favor small depth
-  }
-}
-inline static bool LossGuide(const ExpandEntry& lhs, const ExpandEntry& rhs) {
-  if (lhs.split.loss_chg == rhs.split.loss_chg) {
-    return lhs.timestamp > rhs.timestamp;  // favor small timestamp
-  } else {
-    return lhs.split.loss_chg < rhs.split.loss_chg;  // favor large loss_chg
-  }
-}
-
 /**
  * \struct  DeviceHistogram
  *
@@ -243,6 +192,8 @@ struct GPUHistMakerDevice {
 
   GradientSumT histogram_rounding;
 
+  dh::PinnedMemory pinned;
+
   std::vector<cudaStream_t> streams{};
 
   common::Monitor monitor;
@@ -250,11 +201,6 @@ struct GPUHistMakerDevice {
   common::ColumnSampler column_sampler;
   FeatureInteractionConstraintDevice interaction_constraints;
 
-  using ExpandQueue =
-      std::priority_queue<ExpandEntry, std::vector<ExpandEntry>,
-                          std::function<bool(ExpandEntry, ExpandEntry)>>;
-  std::unique_ptr<ExpandQueue> qexpand;
-
   std::unique_ptr<GradientBasedSampler> sampler;
 
   GPUHistMakerDevice(int _device_id,
@@ -314,11 +260,6 @@ struct GPUHistMakerDevice {
   // Note that the column sampler must be passed by value because it is not
   // thread safe
   void Reset(HostDeviceVector<GradientPair>* dh_gpair, DMatrix* dmat, int64_t num_columns) {
-    if (param.grow_policy == TrainParam::kLossGuide) {
-      qexpand.reset(new ExpandQueue(LossGuide));
-    } else {
-      qexpand.reset(new ExpandQueue(DepthWise));
-    }
     this->column_sampler.Init(num_columns, param.colsample_bynode,
                               param.colsample_bylevel, param.colsample_bytree);
     dh::safe_cuda(cudaSetDevice(device_id));
@@ -370,9 +311,9 @@ struct GPUHistMakerDevice {
     return result.front();
   }
 
-  std::vector<DeviceSplitCandidate> EvaluateLeftRightSplits(
-      ExpandEntry candidate, int left_nidx, int right_nidx,
-      const RegTree& tree) {
+  void EvaluateLeftRightSplits(
+      ExpandEntry candidate, int left_nidx, int right_nidx, const RegTree& tree,
+      common::Span<ExpandEntry> pinned_candidates_out) {
     dh::TemporaryArray<DeviceSplitCandidate> splits_out(2);
     GPUTrainingParam gpu_param(param);
     auto left_sampled_features =
@@ -412,12 +353,19 @@ struct GPUHistMakerDevice {
                      hist.GetNodeHistogram(right_nidx),
                      node_value_constraints[right_nidx],
                      dh::ToSpan(monotone_constraints)};
-    EvaluateSplits(dh::ToSpan(splits_out), left, right);
-    std::vector<DeviceSplitCandidate> result(2);
-    dh::safe_cuda(cudaMemcpy(result.data(), splits_out.data().get(),
-                             sizeof(DeviceSplitCandidate) * splits_out.size(),
-                             cudaMemcpyDeviceToHost));
-    return result;
+    auto d_splits_out = dh::ToSpan(splits_out);
+    EvaluateSplits(d_splits_out, left, right);
+    dh::TemporaryArray<ExpandEntry> entries(2);
+    auto d_entries = entries.data().get();
+    dh::LaunchN(device_id, 1, [=] __device__(size_t idx) {
+      d_entries[0] =
+          ExpandEntry(left_nidx, candidate.depth + 1, d_splits_out[0]);
+      d_entries[1] =
+          ExpandEntry(right_nidx, candidate.depth + 1, d_splits_out[1]);
+    });
+    dh::safe_cuda(cudaMemcpyAsync(
+        pinned_candidates_out.data(), entries.data().get(),
+        sizeof(ExpandEntry) * entries.size(), cudaMemcpyDeviceToHost));
   }
 
   void BuildHist(int nidx) {
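
Note the new contract here: the function returns without blocking. The two child entries are assembled by a single-thread kernel and cudaMemcpyAsync lands them in caller-owned pinned memory, so nothing is valid on the host until the caller synchronises (UpdateTree below does this once per batch). A hedged single-candidate sketch, with `maker` as an illustrative instance name:

    auto out = maker.pinned.GetSpan<ExpandEntry>(2, ExpandEntry());
    maker.EvaluateLeftRightSplits(candidate, left_nidx, right_nidx, tree, out);
    // out[0] / out[1] are still in flight at this point.
    dh::safe_cuda(cudaDeviceSynchronize());
    // Only now do out[0] (left child) and out[1] (right child) hold results.
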
@@ -637,7 +585,7 @@ struct GPUHistMakerDevice {
                     tree[candidate.nid].RightChild());
   }
 
-  void InitRoot(RegTree* p_tree, dh::AllReducer* reducer) {
+  ExpandEntry InitRoot(RegTree* p_tree, dh::AllReducer* reducer) {
     constexpr bst_node_t kRootNIdx = 0;
     dh::XGBCachingDeviceAllocator<char> alloc;
     GradientPair root_sum = thrust::reduce(
@@ -662,28 +610,32 @@ struct GPUHistMakerDevice {
 
     // Generate first split
     auto split = this->EvaluateRootSplit(root_sum);
-    qexpand->push(
-        ExpandEntry(kRootNIdx, p_tree->GetDepth(kRootNIdx), split, 0));
+    return ExpandEntry(kRootNIdx, p_tree->GetDepth(kRootNIdx), split);
   }
 
   void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat,
                   RegTree* p_tree, dh::AllReducer* reducer) {
     auto& tree = *p_tree;
+    Driver driver(static_cast<TrainParam::TreeGrowPolicy>(param.grow_policy));
 
     monitor.Start("Reset");
     this->Reset(gpair_all, p_fmat, p_fmat->Info().num_col_);
     monitor.Stop("Reset");
 
     monitor.Start("InitRoot");
-    this->InitRoot(p_tree, reducer);
+    driver.Push({ this->InitRoot(p_tree, reducer) });
     monitor.Stop("InitRoot");
 
-    auto timestamp = qexpand->size();
     auto num_leaves = 1;
 
-    while (!qexpand->empty()) {
-      ExpandEntry candidate = qexpand->top();
-      qexpand->pop();
+    // The set of leaves that can be expanded asynchronously
+    auto expand_set = driver.Pop();
+    while (!expand_set.empty()) {
+      auto new_candidates =
+          pinned.GetSpan<ExpandEntry>(expand_set.size() * 2, ExpandEntry());
+
+      for (auto i = 0ull; i < expand_set.size(); i++) {
+        auto candidate = expand_set.at(i);
       if (!candidate.IsValid(param, num_leaves)) {
         continue;
       }
@@ -705,19 +657,20 @@ struct GPUHistMakerDevice {
         monitor.Stop("BuildHist");
 
         monitor.Start("EvaluateSplits");
-        auto splits = this->EvaluateLeftRightSplits(candidate, left_child_nidx,
-                                                    right_child_nidx,
-                                                    *p_tree);
+        this->EvaluateLeftRightSplits(candidate, left_child_nidx,
+                                      right_child_nidx, *p_tree,
+                                      new_candidates.subspan(i * 2, 2));
         monitor.Stop("EvaluateSplits");
-        qexpand->push(ExpandEntry(left_child_nidx,
-                                  tree.GetDepth(left_child_nidx), splits.at(0),
-                                  timestamp++));
-        qexpand->push(ExpandEntry(right_child_nidx,
-                                  tree.GetDepth(right_child_nidx),
-                                  splits.at(1), timestamp++));
+      } else {
+        // Set default
+        new_candidates[i * 2] = ExpandEntry();
+        new_candidates[i * 2 + 1] = ExpandEntry();
       }
     }
+    dh::safe_cuda(cudaDeviceSynchronize());
+    driver.Push(new_candidates.begin(), new_candidates.end());
+    expand_set = driver.Pop();
+  }
 
     monitor.Start("FinalisePosition");
     this->FinalisePosition(p_tree, p_fmat);
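
This loop is where the device synchronisation actually drops. Each wave queues histogram builds, split evaluations, and one async copy per candidate into the shared pinned `new_candidates` span, then pays a single cudaDeviceSynchronize(). Skipped candidates keep the default-initialised ExpandEntry, whose zero loss_chg fails the `loss_chg > kRtEps` filter in Driver::Push(), so no extra host-side bookkeeping is needed. Roughly, per wave of n candidates:

    // n x { BuildHist; EvaluateLeftRightSplits }  -> queued, non-blocking
    // n x cudaMemcpyAsync into the pinned span    -> overlapped with kernels
    // 1 x cudaDeviceSynchronize()                 -> the only host stall
    // driver.Push(new_candidates)                 -> filters default entries
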
@@ -264,7 +264,7 @@ TEST_F(SerializationTest, CPUCoordDescent) {
 }
 
 #if defined(XGBOOST_USE_CUDA)
-TEST_F(SerializationTest, GPUHist) {
+TEST_F(SerializationTest, GpuHist) {
   TestLearnerSerialization({{"booster", "gbtree"},
                             {"seed", "0"},
                             {"enable_experimental_json_serialization", "1"},
@@ -441,7 +441,7 @@ TEST_F(LogitSerializationTest, CPUCoordDescent) {
 }
 
 #if defined(XGBOOST_USE_CUDA)
-TEST_F(LogitSerializationTest, GPUHist) {
+TEST_F(LogitSerializationTest, GpuHist) {
   TestLearnerSerialization({{"booster", "gbtree"},
                             {"objective", "binary:logistic"},
                             {"seed", "0"},
@@ -596,7 +596,7 @@ TEST_F(MultiClassesSerializationTest, CPUCoordDescent) {
 }
 
 #if defined(XGBOOST_USE_CUDA)
-TEST_F(MultiClassesSerializationTest, GPUHist) {
+TEST_F(MultiClassesSerializationTest, GpuHist) {
   TestLearnerSerialization({{"booster", "gbtree"},
                             {"num_class", std::to_string(kClasses)},
                             {"seed", "0"},

tests/cpp/tree/gpu_hist/test_driver.cu (new file, 59 lines)
@@ -0,0 +1,59 @@
+#include <gtest/gtest.h>
+#include "../../../../src/tree/gpu_hist/driver.cuh"
+
+namespace xgboost {
+namespace tree {
+
+TEST(GpuHist, DriverDepthWise) {
+  Driver driver(TrainParam::kDepthWise);
+  EXPECT_TRUE(driver.Pop().empty());
+  DeviceSplitCandidate split;
+  split.loss_chg = 1.0f;
+  ExpandEntry root(0, 0, split);
+  driver.Push({root});
+  EXPECT_EQ(driver.Pop().front().nid, 0);
+  driver.Push({ExpandEntry{1, 1, split}});
+  driver.Push({ExpandEntry{2, 1, split}});
+  driver.Push({ExpandEntry{3, 2, split}});
+  // Should return entries from level 1
+  auto res = driver.Pop();
+  EXPECT_EQ(res.size(), 2);
+  for (auto &e : res) {
+    EXPECT_EQ(e.depth, 1);
+  }
+  res = driver.Pop();
+  EXPECT_EQ(res[0].depth, 2);
+  EXPECT_TRUE(driver.Pop().empty());
+}
+
+TEST(GpuHist, DriverLossGuided) {
+  DeviceSplitCandidate high_gain;
+  high_gain.loss_chg = 5.0f;
+  DeviceSplitCandidate low_gain;
+  low_gain.loss_chg = 1.0f;
+
+  Driver driver(TrainParam::kLossGuide);
+  EXPECT_TRUE(driver.Pop().empty());
+  ExpandEntry root(0, 0, high_gain);
+  driver.Push({root});
+  EXPECT_EQ(driver.Pop().front().nid, 0);
+  // Select high gain first
+  driver.Push({ExpandEntry{1, 1, low_gain}});
+  driver.Push({ExpandEntry{2, 2, high_gain}});
+  auto res = driver.Pop();
+  EXPECT_EQ(res.size(), 1);
+  EXPECT_EQ(res[0].nid, 2);
+  res = driver.Pop();
+  EXPECT_EQ(res.size(), 1);
+  EXPECT_EQ(res[0].nid, 1);
+
+  // If equal gain, use nid
+  driver.Push({ExpandEntry{2, 1, low_gain}});
+  driver.Push({ExpandEntry{1, 1, low_gain}});
+  res = driver.Pop();
+  EXPECT_EQ(res[0].nid, 1);
+  res = driver.Pop();
+  EXPECT_EQ(res[0].nid, 2);
+}
+}  // namespace tree
+}  // namespace xgboost
@@ -40,7 +40,7 @@ class UpdaterTreeStatTest : public ::testing::Test {
 };
 
 #if defined(XGBOOST_USE_CUDA)
-TEST_F(UpdaterTreeStatTest, GPUHist) {
+TEST_F(UpdaterTreeStatTest, GpuHist) {
   this->RunTest("grow_gpu_hist");
 }
 #endif  // defined(XGBOOST_USE_CUDA)