Further optimisations for gpu_hist. (#4283)
- Fuse the final update-position functions into a single, more efficient kernel
- Refactor gpu_hist with a more explicit ELLPACK matrix representation
This commit is contained in:
parent 5aa42b5f11
commit 6d5b34d824
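The ELLPACK refactor in this diff replaces the loose row_stride / gidx / d_cut members of DeviceShard with a single ELLPackMatrix view whose GetElement(ridx, fidx) performs the bin lookup now shared by UpdatePosition and FinalisePosition. The following is a rough host-side sketch of that lookup idea only, with simplified, hypothetical names; a linear scan stands in for the binary search the real kernel uses, and none of this is the code in the commit itself.

#include <cmath>
#include <cstdint>
#include <iostream>
#include <limits>
#include <vector>

// Hypothetical stand-in for the device-side ELLPackMatrix: every row stores
// exactly row_stride bin ids; feature f owns the bin range
// [feature_segments[f], feature_segments[f + 1]).
struct EllpackSketch {
  std::vector<uint32_t> feature_segments;  // bin range boundaries per feature
  std::vector<float> gidx_fvalue_map;      // cut value for each bin
  std::vector<uint32_t> gidx;              // row-major bin ids, padded rows
  size_t row_stride = 0;

  // Return the cut value for (row, feature), or NaN when the feature is
  // missing in that row (the padding bin never falls inside a feature range).
  float GetElement(size_t ridx, size_t fidx) const {
    uint32_t begin = feature_segments[fidx], end = feature_segments[fidx + 1];
    for (size_t i = ridx * row_stride; i < (ridx + 1) * row_stride; ++i) {
      if (gidx[i] >= begin && gidx[i] < end) return gidx_fvalue_map[gidx[i]];
    }
    return std::numeric_limits<float>::quiet_NaN();
  }
};

int main() {
  // Feature 0 owns bins [0, 2), feature 1 owns bins [2, 4); bin 4 is padding.
  EllpackSketch m;
  m.feature_segments = {0, 2, 4};
  m.gidx_fvalue_map = {0.1f, 0.9f, 10.f, 20.f};
  m.row_stride = 2;
  m.gidx = {1, 2,   // row 0: both features present
            0, 4};  // row 1: feature 1 missing (padded)
  std::cout << m.GetElement(0, 1) << ' '                // prints 10
            << std::isnan(m.GetElement(1, 1)) << '\n';  // prints 1 (missing)
}

With the element accessor in one place, both the per-split UpdatePosition kernel and the fused FinalisePosition kernel in the diff below reduce to: read the split feature, compare against the split condition, and fall back to the default child when the value is NaN.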
@@ -93,65 +93,63 @@ class RegTree {
"Node: 64 bit align");
}
/*! \brief index of left child */
int LeftChild() const {
XGBOOST_DEVICE int LeftChild() const {
return this->cleft_;
}
/*! \brief index of right child */
int RightChild() const {
XGBOOST_DEVICE int RightChild() const {
return this->cright_;
}
/*! \brief index of default child when feature is missing */
int DefaultChild() const {
XGBOOST_DEVICE int DefaultChild() const {
return this->DefaultLeft() ? this->LeftChild() : this->RightChild();
}
/*! \brief feature index of split condition */
unsigned SplitIndex() const {
XGBOOST_DEVICE unsigned SplitIndex() const {
return sindex_ & ((1U << 31) - 1U);
}
/*! \brief when feature is unknown, whether goes to left child */
bool DefaultLeft() const {
XGBOOST_DEVICE bool DefaultLeft() const {
return (sindex_ >> 31) != 0;
}
/*! \brief whether current node is leaf node */
bool IsLeaf() const {
XGBOOST_DEVICE bool IsLeaf() const {
return cleft_ == -1;
}
/*! \return get leaf value of leaf node */
bst_float LeafValue() const {
XGBOOST_DEVICE bst_float LeafValue() const {
return (this->info_).leaf_value;
}
/*! \return get split condition of the node */
SplitCondT SplitCond() const {
XGBOOST_DEVICE SplitCondT SplitCond() const {
return (this->info_).split_cond;
}
/*! \brief get parent of the node */
int Parent() const {
XGBOOST_DEVICE int Parent() const {
return parent_ & ((1U << 31) - 1);
}
/*! \brief whether current node is left child */
bool IsLeftChild() const {
XGBOOST_DEVICE bool IsLeftChild() const {
return (parent_ & (1U << 31)) != 0;
}
/*! \brief whether this node is deleted */
bool IsDeleted() const {
XGBOOST_DEVICE bool IsDeleted() const {
return sindex_ == std::numeric_limits<unsigned>::max();
}
/*! \brief whether current node is root */
bool IsRoot() const {
return parent_ == -1;
}
XGBOOST_DEVICE bool IsRoot() const { return parent_ == -1; }
/*!
* \brief set the left child
* \param nid node id to right child
*/
void SetLeftChild(int nid) {
XGBOOST_DEVICE void SetLeftChild(int nid) {
this->cleft_ = nid;
}
/*!
* \brief set the right child
* \param nid node id to right child
*/
void SetRightChild(int nid) {
XGBOOST_DEVICE void SetRightChild(int nid) {
this->cright_ = nid;
}
/*!

@@ -160,7 +158,7 @@ class RegTree {
* \param split_cond split condition
* \param default_left the default direction when feature is unknown
*/
void SetSplit(unsigned split_index, SplitCondT split_cond,
XGBOOST_DEVICE void SetSplit(unsigned split_index, SplitCondT split_cond,
bool default_left = false) {
if (default_left) split_index |= (1U << 31);
this->sindex_ = split_index;

@@ -172,17 +170,17 @@ class RegTree {
* \param right right index, could be used to store
* additional information
*/
void SetLeaf(bst_float value, int right = -1) {
XGBOOST_DEVICE void SetLeaf(bst_float value, int right = -1) {
(this->info_).leaf_value = value;
this->cleft_ = -1;
this->cright_ = right;
}
/*! \brief mark that this node is deleted */
void MarkDelete() {
XGBOOST_DEVICE void MarkDelete() {
this->sindex_ = std::numeric_limits<unsigned>::max();
}
// set parent
void SetParent(int pidx, bool is_left_child = true) {
XGBOOST_DEVICE void SetParent(int pidx, bool is_left_child = true) {
if (is_left_child) pidx |= (1U << 31);
this->parent_ = pidx;
}

@@ -303,6 +303,7 @@ class GPUPredictor : public xgboost::Predictor {
const gbm::GBTreeModel& model, size_t tree_begin,
size_t tree_end) {
if (tree_end - tree_begin == 0) { return; }
monitor_.StartCuda("DevicePredictInternal");

CHECK_EQ(model.param.size_leaf_vector, 0);
// Copy decision trees to device

@@ -337,6 +338,7 @@ class GPUPredictor : public xgboost::Predictor {
});
i_batch++;
}
monitor_.StopCuda("DevicePredictInternal");
}

public:

@@ -388,9 +390,11 @@ class GPUPredictor : public xgboost::Predictor {
if (it != cache_.end()) {
const HostDeviceVector<bst_float>& y = it->second.predictions;
if (y.Size() != 0) {
monitor_.StartCuda("PredictFromCache");
out_preds->Reshard(y.Distribution());
out_preds->Resize(y.Size());
out_preds->Copy(y);
monitor_.StopCuda("PredictFromCache");
return true;
}
}

@@ -481,6 +485,7 @@ class GPUPredictor : public xgboost::Predictor {
std::unique_ptr<Predictor> cpu_predictor_;
std::vector<DeviceShard> shards_;
GPUSet devices_;
common::Monitor monitor_;
};

XGBOOST_REGISTER_PREDICTOR(GPUPredictor, "gpu_predictor")
@@ -50,6 +50,133 @@ struct GPUHistMakerTrainParam

DMLC_REGISTER_PARAMETER(GPUHistMakerTrainParam);

struct ExpandEntry {
int nid;
int depth;
DeviceSplitCandidate split;
uint64_t timestamp;
ExpandEntry() = default;
ExpandEntry(int nid, int depth, DeviceSplitCandidate split,
uint64_t timestamp)
: nid(nid), depth(depth), split(std::move(split)), timestamp(timestamp) {}
bool IsValid(const TrainParam& param, int num_leaves) const {
if (split.loss_chg <= kRtEps) return false;
if (split.left_sum.GetHess() == 0 || split.right_sum.GetHess() == 0) {
return false;
}
if (param.max_depth > 0 && depth == param.max_depth) return false;
if (param.max_leaves > 0 && num_leaves == param.max_leaves) return false;
return true;
}

static bool ChildIsValid(const TrainParam& param, int depth, int num_leaves) {
if (param.max_depth > 0 && depth >= param.max_depth) return false;
if (param.max_leaves > 0 && num_leaves >= param.max_leaves) return false;
return true;
}

friend std::ostream& operator<<(std::ostream& os, const ExpandEntry& e) {
os << "ExpandEntry: \n";
os << "nidx: " << e.nid << "\n";
os << "depth: " << e.depth << "\n";
os << "loss: " << e.split.loss_chg << "\n";
os << "left_sum: " << e.split.left_sum << "\n";
os << "right_sum: " << e.split.right_sum << "\n";
return os;
}
};

inline static bool DepthWise(ExpandEntry lhs, ExpandEntry rhs) {
if (lhs.depth == rhs.depth) {
return lhs.timestamp > rhs.timestamp;  // favor small timestamp
} else {
return lhs.depth > rhs.depth;  // favor small depth
}
}
inline static bool LossGuide(ExpandEntry lhs, ExpandEntry rhs) {
if (lhs.split.loss_chg == rhs.split.loss_chg) {
return lhs.timestamp > rhs.timestamp;  // favor small timestamp
} else {
return lhs.split.loss_chg < rhs.split.loss_chg;  // favor large loss_chg
}
}

// Find a gidx value for a given feature otherwise return -1 if not found
__device__ int BinarySearchRow(bst_uint begin, bst_uint end,
common::CompressedIterator<uint32_t> data,
int const fidx_begin, int const fidx_end) {
bst_uint previous_middle = UINT32_MAX;
while (end != begin) {
auto middle = begin + (end - begin) / 2;
if (middle == previous_middle) {
break;
}
previous_middle = middle;

auto gidx = data[middle];

if (gidx >= fidx_begin && gidx < fidx_end) {
return gidx;
} else if (gidx < fidx_begin) {
begin = middle;
} else {
end = middle;
}
}
// Value is missing
return -1;
}

/** \brief Struct for accessing and manipulating an ellpack matrix on the
* device. Does not own underlying memory and may be trivially copied into
* kernels.*/
struct ELLPackMatrix {
common::Span<uint32_t> feature_segments;
/*! \brief minimum value for each feature. */
common::Span<bst_float> min_fvalue;
/*! \brief Cut. */
common::Span<bst_float> gidx_fvalue_map;
/*! \brief row length for ELLPack. */
size_t row_stride{0};
common::CompressedIterator<uint32_t> gidx_iter;
bool is_dense;
int null_gidx_value;

XGBOOST_DEVICE size_t BinCount() const { return gidx_fvalue_map.size(); }

// Get a matrix element, uses binary search for look up
// Return NaN if missing
__device__ bst_float GetElement(size_t ridx, size_t fidx) const {
auto row_begin = row_stride * ridx;
auto row_end = row_begin + row_stride;
auto gidx = -1;
if (is_dense) {
gidx = gidx_iter[row_begin + fidx];
} else {
gidx =
BinarySearchRow(row_begin, row_end, gidx_iter, feature_segments[fidx],
feature_segments[fidx + 1]);
}
if (gidx == -1) {
return nan("");
}
return gidx_fvalue_map[gidx];
}
void Init(common::Span<uint32_t> feature_segments,
common::Span<bst_float> min_fvalue,
common::Span<bst_float> gidx_fvalue_map, size_t row_stride,
common::CompressedIterator<uint32_t> gidx_iter, bool is_dense,
int null_gidx_value) {
this->feature_segments = feature_segments;
this->min_fvalue = min_fvalue;
this->gidx_fvalue_map = gidx_fvalue_map;
this->row_stride = row_stride;
this->gidx_iter = gidx_iter;
this->is_dense = is_dense;
this->null_gidx_value = null_gidx_value;
}
};

// With constraints
template <typename GradientPairT>
XGBOOST_DEVICE float inline LossChangeMissing(
@@ -111,19 +238,17 @@ __device__ GradientSumT ReduceFeature(common::Span<const GradientSumT> feature_h
template <int BLOCK_THREADS, typename ReduceT, typename ScanT,
typename MaxReduceT, typename TempStorageT, typename GradientSumT>
__device__ void EvaluateFeature(
int fidx,
common::Span<const GradientSumT> node_histogram,
common::Span<const uint32_t> feature_segments, // cut.row_ptr
float min_fvalue, // cut.min_value
common::Span<const float> gidx_fvalue_map, // cut.cut
int fidx, common::Span<const GradientSumT> node_histogram,
const ELLPackMatrix& matrix,
DeviceSplitCandidate* best_split, // shared memory storing best split
const DeviceNodeStats& node, const GPUTrainingParam& param,
TempStorageT* temp_storage, // temp memory for cub operations
int constraint, // monotonic_constraints
const ValueConstraint& value_constraint) {
// Use pointer from cut to indicate begin and end of bins for each feature.
uint32_t gidx_begin = feature_segments[fidx]; // begining bin
uint32_t gidx_end = feature_segments[fidx + 1]; // end bin for i^th feature
uint32_t gidx_begin = matrix.feature_segments[fidx]; // begining bin
uint32_t gidx_end =
matrix.feature_segments[fidx + 1]; // end bin for i^th feature

// Sum histogram bins for current feature
GradientSumT const feature_sum = ReduceFeature<BLOCK_THREADS, ReduceT>(

@@ -168,16 +293,17 @@ __device__ void EvaluateFeature(

// Best thread updates split
if (threadIdx.x == block_max.key) {
int gidx = scan_begin + threadIdx.x;
float fvalue =
gidx == gidx_begin ? min_fvalue : gidx_fvalue_map[gidx - 1];
int split_gidx = (scan_begin + threadIdx.x) - 1;
float fvalue;
if (split_gidx < static_cast<int>(gidx_begin)) {
fvalue = matrix.min_fvalue[fidx];
} else {
fvalue = matrix.gidx_fvalue_map[split_gidx];
}
GradientSumT left = missing_left ? bin + missing : bin;
GradientSumT right = parent_sum - left;
best_split->Update(gain, missing_left ? kLeftDir : kRightDir,
fvalue, fidx,
GradientPair(left),
GradientPair(right),
param);
best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue,
fidx, GradientPair(left), GradientPair(right), param);
}
__syncthreads();
}

@@ -189,10 +315,7 @@ __global__ void EvaluateSplitKernel(
node_histogram, // histogram for gradients
common::Span<const int> feature_set, // Selected features
DeviceNodeStats node,
common::Span<const uint32_t>
d_feature_segments, // row_ptr form HistCutMatrix
common::Span<const float> d_fidx_min_map, // min_value
common::Span<const float> d_gidx_fvalue_map, // cut
ELLPackMatrix matrix,
GPUTrainingParam gpu_param,
common::Span<DeviceSplitCandidate> split_candidates, // resulting split
ValueConstraint value_constraint,

@@ -226,10 +349,8 @@ __global__ void EvaluateSplitKernel(
int fidx = feature_set[blockIdx.x];
int constraint = d_monotonic_constraints[fidx];
EvaluateFeature<BLOCK_THREADS, SumReduceT, BlockScanT, MaxReduceT>(
fidx, node_histogram,
d_feature_segments, d_fidx_min_map[fidx], d_gidx_fvalue_map,
&best_split, node, gpu_param, &temp_storage, constraint,
value_constraint);
fidx, node_histogram, matrix, &best_split, node, gpu_param, &temp_storage,
constraint, value_constraint);

__syncthreads();

@@ -239,32 +360,6 @@ __global__ void EvaluateSplitKernel(
}
}

// Find a gidx value for a given feature otherwise return -1 if not found
template <typename GidxIterT>
__device__ int BinarySearchRow(bst_uint begin, bst_uint end, GidxIterT data,
int const fidx_begin, int const fidx_end) {
bst_uint previous_middle = UINT32_MAX;
while (end != begin) {
auto middle = begin + (end - begin) / 2;
if (middle == previous_middle) {
break;
}
previous_middle = middle;

auto gidx = data[middle];

if (gidx >= fidx_begin && gidx < fidx_end) {
return gidx;
} else if (gidx < fidx_begin) {
begin = middle;
} else {
end = middle;
}
}
// Value is missing
return -1;
}

/**
* \struct DeviceHistogram
*

@@ -290,7 +385,6 @@ class DeviceHistogram {
}

void Reset() {
dh::safe_cuda(cudaSetDevice(device_id_));
dh::safe_cuda(cudaMemsetAsync(
data_.data().get(), 0,
data_.size() * sizeof(typename decltype(data_)::value_type)));

@@ -397,27 +491,27 @@ __global__ void CompressBinEllpackKernel(
}

template <typename GradientSumT>
__global__ void SharedMemHistKernel(size_t row_stride, const bst_uint* d_ridx,
common::CompressedIterator<uint32_t> d_gidx,
int null_gidx_value,
__global__ void SharedMemHistKernel(ELLPackMatrix matrix, const bst_uint* d_ridx,
GradientSumT* d_node_hist,
const GradientPair* d_gpair,
size_t segment_begin, size_t n_elements) {
extern __shared__ char smem[];
GradientSumT* smem_arr = reinterpret_cast<GradientSumT*>(smem); // NOLINT
for (auto i : dh::BlockStrideRange(0, null_gidx_value)) {
for (auto i :
dh::BlockStrideRange(static_cast<size_t>(0), matrix.BinCount())) {
smem_arr[i] = GradientSumT();
}
__syncthreads();
for (auto idx : dh::GridStrideRange(static_cast<size_t>(0), n_elements)) {
int ridx = d_ridx[idx / row_stride + segment_begin];
int gidx = d_gidx[ridx * row_stride + idx % row_stride];
if (gidx != null_gidx_value) {
int ridx = d_ridx[idx / matrix.row_stride + segment_begin];
int gidx = matrix.gidx_iter[ridx * matrix.row_stride + idx % matrix.row_stride];
if (gidx != matrix.null_gidx_value) {
AtomicAddGpair(smem_arr + gidx, d_gpair[ridx]);
}
}
__syncthreads();
for (auto i : dh::BlockStrideRange(0, null_gidx_value)) {
for (auto i :
dh::BlockStrideRange(static_cast<size_t>(0), matrix.BinCount())) {
AtomicAddGpair(d_node_hist + i, smem_arr[i]);
}
}
@@ -509,32 +603,26 @@ struct DeviceShard {

dh::BulkAllocator<dh::MemoryType::kDevice> ba;

/*! \brief HistCutMatrix stored in device. */
struct DeviceHistCutMatrix {
ELLPackMatrix ellpack_matrix;

/*! \brief Range of rows for each node. */
std::vector<Segment> ridx_segments;
DeviceHistogram<GradientSumT> hist;

/*! \brief row_ptr form HistCutMatrix. */
dh::DVec<uint32_t> feature_segments;
/*! \brief minimum value for each feature. */
dh::DVec<bst_float> min_fvalue;
/*! \brief Cut. */
dh::DVec<bst_float> gidx_fvalue_map;
} d_cut;

/*! \brief Range of rows for each node. */
std::vector<Segment> ridx_segments;
DeviceHistogram<GradientSumT> hist;

/*! \brief row length for ELLPack. */
size_t row_stride;
common::CompressedIterator<uint32_t> gidx;
/*! \brief global index of histogram, which is stored in ELLPack format. */
dh::DVec<common::CompressedByteT> gidx_buffer;

/*! \brief Row indices relative to this shard, necessary for sorting rows. */
dh::DVec2<bst_uint> ridx;
/*! \brief Gradient pair for each row. */
dh::DVec<GradientPair> gpair;

/*! \brief The last histogram index. */
int null_gidx_value;

dh::DVec2<int> position;

dh::DVec<int> monotone_constraints;

@@ -543,8 +631,6 @@ struct DeviceShard {
/*! \brief Sum gradient for each node. */
std::vector<GradientPair> node_sum_gradients;
dh::DVec<GradientPair> node_sum_gradients_d;
/*! \brief global index of histogram, which is stored in ELLPack format. */
dh::DVec<common::CompressedByteT> gidx_buffer;
/*! \brief row offset in SparsePage (the input data). */
thrust::device_vector<size_t> row_ptrs;
/*! \brief On-device feature set, only actually used on one of the devices */

@@ -572,16 +658,13 @@ struct DeviceShard {
: device_id(_device_id),
row_begin_idx(row_begin),
row_end_idx(row_end),
row_stride(0),
n_rows(row_end - row_begin),
n_bins{0},
null_gidx_value(0),
n_bins(0),
param(std::move(_param)),
prediction_cache_initialised(false) {}

/* Init row_ptrs and row_stride */
void InitRowPtrs(const SparsePage& row_batch) {
dh::safe_cuda(cudaSetDevice(device_id));
size_t InitRowPtrs(const SparsePage& row_batch) {
const auto& offset_vec = row_batch.offset.HostVector();
row_ptrs.resize(n_rows + 1);
thrust::copy(offset_vec.data() + row_begin_idx,

@@ -597,20 +680,15 @@ struct DeviceShard {
using TransformT = thrust::transform_iterator<decltype(get_size),
decltype(counting), size_t>;
TransformT row_size_iter = TransformT(counting, get_size);
row_stride = thrust::reduce(row_size_iter, row_size_iter + n_rows, 0,
size_t row_stride = thrust::reduce(row_size_iter, row_size_iter + n_rows, 0,
thrust::maximum<size_t>());
return row_stride;
}

/*
Init:
n_bins, null_gidx_value, gidx_buffer, row_ptrs, gidx, gidx_fvalue_map,
min_fvalue, feature_segments, node_sum_gradients, ridx_segments,
hist
*/
void InitCompressedData(
const common::HistCutMatrix& hmat, const SparsePage& row_batch);
const common::HistCutMatrix& hmat, const SparsePage& row_batch, bool is_dense);

void CreateHistIndices(const SparsePage& row_batch);
void CreateHistIndices(const SparsePage& row_batch, size_t row_stride, int null_gidx_value);

~DeviceShard() {
dh::safe_cuda(cudaSetDevice(device_id));

@@ -708,10 +786,9 @@ struct DeviceShard {
int constexpr kBlockThreads = 256;
EvaluateSplitKernel<kBlockThreads, GradientSumT>
<<<uint32_t(d_feature_set.size()), kBlockThreads, 0, streams[i]>>>(
hist.GetNodeHistogram(nidx), d_feature_set, node,
d_cut.feature_segments.GetSpan(), d_cut.min_fvalue.GetSpan(),
d_cut.gidx_fvalue_map.GetSpan(), gpu_param, d_split_candidates,
value_constraints[nidx], monotone_constraints.GetSpan());
hist.GetNodeHistogram(nidx), d_feature_set, node, ellpack_matrix,
gpu_param, d_split_candidates, value_constraints[nidx],
monotone_constraints.GetSpan());

// Reduce over features to find best feature
auto d_result = d_result_all.subspan(i, 1);
@@ -756,47 +833,35 @@ struct DeviceShard {
hist.HistogramExists(nidx_parent);
}

void UpdatePosition(int nidx, int left_nidx, int right_nidx, int fidx,
int64_t split_gidx, bool default_dir_left, bool is_dense,
int fidx_begin, // cut.row_ptr[fidx]
int fidx_end) { // cut.row_ptr[fidx + 1]
dh::safe_cuda(cudaSetDevice(device_id));
void UpdatePosition(int nidx, RegTree::Node split_node) {
CHECK(!split_node.IsLeaf()) <<"Node must not be leaf";
Segment segment = ridx_segments[nidx];
bst_uint* d_ridx = ridx.Current();
int* d_position = position.Current();
common::CompressedIterator<uint32_t> d_gidx = gidx;
size_t row_stride = this->row_stride;
if (left_counts.size() <= nidx) {
left_counts.resize((nidx * 2) + 1);
}
int64_t* d_left_count = left_counts.data().get() + nidx;
auto d_matrix = this->ellpack_matrix;
// Launch 1 thread for each row
dh::LaunchN<1, 128>(
device_id, segment.Size(), [=] __device__(bst_uint idx) {
idx += segment.begin;
bst_uint ridx = d_ridx[idx];
auto row_begin = row_stride * ridx;
auto row_end = row_begin + row_stride;
auto gidx = -1;
if (is_dense) {
// FIXME: Maybe just search the cuts again.
gidx = d_gidx[row_begin + fidx];
bst_float element = d_matrix.GetElement(ridx, split_node.SplitIndex());
// Missing value
int new_position = 0;
if (isnan(element)) {
new_position = split_node.DefaultChild();
} else {
gidx = BinarySearchRow(row_begin, row_end, d_gidx, fidx_begin,
fidx_end);
}

// belong to left node or right node.
int position;
if (gidx >= 0) {
// Feature is found
position = gidx <= split_gidx ? left_nidx : right_nidx;
if (element <= split_node.SplitCond()) {
new_position = split_node.LeftChild();
} else {
// Feature is missing
position = default_dir_left ? left_nidx : right_nidx;
new_position = split_node.RightChild();
}
CountLeft(d_left_count, position, left_nidx);
d_position[idx] = position;
}
CountLeft(d_left_count, new_position, split_node.LeftChild());
d_position[idx] = new_position;
});

// Overlap device to host memory copy (left_count) with sort

@@ -805,16 +870,16 @@ struct DeviceShard {
dh::safe_cuda(cudaMemcpyAsync(tmp_pinned.data(), d_left_count, sizeof(int64_t),
cudaMemcpyDeviceToHost, streams[0]));

SortPositionAndCopy(segment, left_nidx, right_nidx, d_left_count,
SortPositionAndCopy(segment, split_node.LeftChild(), split_node.RightChild(), d_left_count,
streams[1]);

dh::safe_cuda(cudaStreamSynchronize(streams[0]));
int64_t left_count = tmp_pinned[0];
CHECK_LE(left_count, segment.Size());
CHECK_GE(left_count, 0);
ridx_segments[left_nidx] =
ridx_segments[split_node.LeftChild()] =
Segment(segment.begin, segment.begin + left_count);
ridx_segments[right_nidx] =
ridx_segments[split_node.RightChild()] =
Segment(segment.begin + left_count, segment.end);
}

@@ -840,6 +905,41 @@ struct DeviceShard {
});
}

// After tree update is finished, update the position of all training
// instances to their final leaf This information is used later to update the
// prediction cache
void FinalisePosition(RegTree* p_tree) {
const auto d_nodes =
temp_memory.GetSpan<RegTree::Node>(p_tree->GetNodes().size());
dh::safe_cuda(cudaMemcpy(d_nodes.data(), p_tree->GetNodes().data(),
d_nodes.size() * sizeof(RegTree::Node),
cudaMemcpyHostToDevice));
auto d_position = position.Current();
const auto d_ridx = ridx.Current();
auto d_matrix = this->ellpack_matrix;
dh::LaunchN(device_id, position.Size(), [=] __device__(size_t idx) {
auto position = d_position[idx];
auto node = d_nodes[position];
bst_uint ridx = d_ridx[idx];

while (!node.IsLeaf()) {
bst_float element = d_matrix.GetElement(ridx, node.SplitIndex());
// Missing value
if (isnan(element)) {
position = node.DefaultChild();
} else {
if (element <= node.SplitCond()) {
position = node.LeftChild();
} else {
position = node.RightChild();
}
}
node = d_nodes[position];
}
d_position[idx] = position;
});
}

void UpdatePredictionCache(bst_float* out_preds_d) {
dh::safe_cuda(cudaSetDevice(device_id));
if (!prediction_cache_initialised) {
@@ -880,14 +980,12 @@ struct SharedMemHistBuilder : public GPUHistBuilderBase<GradientSumT> {
auto segment = shard->ridx_segments[nidx];
auto segment_begin = segment.begin;
auto d_node_hist = shard->hist.GetNodeHistogram(nidx);
auto d_gidx = shard->gidx;
auto d_ridx = shard->ridx.Current();
auto d_gpair = shard->gpair.Data();

int null_gidx_value = shard->null_gidx_value;
auto n_elements = segment.Size() * shard->row_stride;
auto n_elements = segment.Size() * shard->ellpack_matrix.row_stride;

const size_t smem_size = sizeof(GradientSumT) * shard->null_gidx_value;
const size_t smem_size = sizeof(GradientSumT) * shard->ellpack_matrix.BinCount();
const int items_per_thread = 8;
const int block_threads = 256;
const int grid_size =

@@ -896,9 +994,8 @@ struct SharedMemHistBuilder : public GPUHistBuilderBase<GradientSumT> {
if (grid_size <= 0) {
return;
}
dh::safe_cuda(cudaSetDevice(shard->device_id));
SharedMemHistKernel<<<grid_size, block_threads, smem_size>>>
(shard->row_stride, d_ridx, d_gidx, null_gidx_value, d_node_hist.data(), d_gpair,
SharedMemHistKernel<<<grid_size, block_threads, smem_size>>>(
shard->ellpack_matrix, d_ridx, d_node_hist.data(), d_gpair,
segment_begin, n_elements);
}
};

@@ -908,20 +1005,18 @@ struct GlobalMemHistBuilder : public GPUHistBuilderBase<GradientSumT> {
void Build(DeviceShard<GradientSumT>* shard, int nidx) override {
Segment segment = shard->ridx_segments[nidx];
auto d_node_hist = shard->hist.GetNodeHistogram(nidx).data();
common::CompressedIterator<uint32_t> d_gidx = shard->gidx;
bst_uint* d_ridx = shard->ridx.Current();
GradientPair* d_gpair = shard->gpair.Data();

size_t const n_elements = segment.Size() * shard->row_stride;
size_t const row_stride = shard->row_stride;
int const null_gidx_value = shard->null_gidx_value;
size_t const n_elements = segment.Size() * shard->ellpack_matrix.row_stride;
auto d_matrix = shard->ellpack_matrix;

dh::LaunchN(shard->device_id, n_elements, [=] __device__(size_t idx) {
int ridx = d_ridx[(idx / row_stride) + segment.begin];
int ridx = d_ridx[(idx / d_matrix.row_stride) + segment.begin];
// lookup the index (bin) of histogram.
int gidx = d_gidx[ridx * row_stride + idx % row_stride];
int gidx = d_matrix.gidx_iter[ridx * d_matrix.row_stride + idx % d_matrix.row_stride];

if (gidx != null_gidx_value) {
if (gidx != d_matrix.null_gidx_value) {
AtomicAddGpair(d_node_hist + gidx, d_gpair[ridx]);
}
});

@@ -930,10 +1025,10 @@ struct GlobalMemHistBuilder : public GPUHistBuilderBase<GradientSumT> {

template <typename GradientSumT>
inline void DeviceShard<GradientSumT>::InitCompressedData(
const common::HistCutMatrix& hmat, const SparsePage& row_batch) {
dh::safe_cuda(cudaSetDevice(device_id));
n_bins = hmat.NumBins();
null_gidx_value = hmat.NumBins();
const common::HistCutMatrix& hmat, const SparsePage& row_batch, bool is_dense) {
size_t row_stride = this->InitRowPtrs(row_batch);
n_bins = hmat.row_ptr.back();
int null_gidx_value = hmat.row_ptr.back();

int max_nodes =
param.max_leaves > 0 ? param.max_leaves * 2 : MaxNodesDepth(param.max_depth);

@@ -944,13 +1039,13 @@ inline void DeviceShard<GradientSumT>::InitCompressedData(
&position, n_rows,
&prediction_cache, n_rows,
&node_sum_gradients_d, max_nodes,
&d_cut.feature_segments, hmat.row_ptr.size(),
&d_cut.gidx_fvalue_map, hmat.cut.size(),
&d_cut.min_fvalue, hmat.min_val.size(),
&feature_segments, hmat.row_ptr.size(),
&gidx_fvalue_map, hmat.cut.size(),
&min_fvalue, hmat.min_val.size(),
&monotone_constraints, param.monotone_constraints.size());
d_cut.gidx_fvalue_map = hmat.cut;
d_cut.min_fvalue = hmat.min_val;
d_cut.feature_segments = hmat.row_ptr;
gidx_fvalue_map = hmat.cut;
min_fvalue = hmat.min_val;
feature_segments = hmat.row_ptr;
monotone_constraints = param.monotone_constraints;

node_sum_gradients.resize(max_nodes);

@@ -970,15 +1065,18 @@ inline void DeviceShard<GradientSumT>::InitCompressedData(
ba.Allocate(device_id, &gidx_buffer, compressed_size_bytes);
gidx_buffer.Fill(0);

int nbits = common::detail::SymbolBits(num_symbols);
this->CreateHistIndices(row_batch, row_stride, null_gidx_value);

CreateHistIndices(row_batch);

gidx = common::CompressedIterator<uint32_t>(gidx_buffer.Data(), num_symbols);
ellpack_matrix.Init(
feature_segments.GetSpan(), min_fvalue.GetSpan(),
gidx_fvalue_map.GetSpan(), row_stride,
common::CompressedIterator<uint32_t>(gidx_buffer.Data(), num_symbols),
is_dense, null_gidx_value);

// check if we can use shared memory for building histograms
// (assuming atleast we need 2 CTAs per SM to maintain decent latency hiding)
auto histogram_size = sizeof(GradientSumT) * null_gidx_value;
// (assuming atleast we need 2 CTAs per SM to maintain decent latency
// hiding)
auto histogram_size = sizeof(GradientSumT) * hmat.row_ptr.back();
auto max_smem = dh::MaxSharedMemory(device_id);
if (histogram_size <= max_smem) {
hist_builder.reset(new SharedMemHistBuilder<GradientSumT>);

@@ -990,9 +1088,9 @@ inline void DeviceShard<GradientSumT>::InitCompressedData(
hist.Init(device_id, hmat.NumBins());
}

template <typename GradientSumT>
inline void DeviceShard<GradientSumT>::CreateHistIndices(const SparsePage& row_batch) {
inline void DeviceShard<GradientSumT>::CreateHistIndices(
const SparsePage& row_batch, size_t row_stride, int null_gidx_value) {
int num_symbols = n_bins + 1;
// bin and compress entries in batches of rows
size_t gpu_batch_nrows =

@@ -1026,7 +1124,7 @@ inline void DeviceShard<GradientSumT>::CreateHistIndices(const SparsePage& row_b
gidx_buffer.Data(),
row_ptrs.data().get() + batch_row_begin,
entries_d.data().get(),
d_cut.gidx_fvalue_map.Data(), d_cut.feature_segments.Data(),
gidx_fvalue_map.Data(), feature_segments.Data(),
batch_row_begin, batch_nrows,
row_ptrs[batch_row_begin],
row_stride, null_gidx_value);
@@ -1039,12 +1137,9 @@ inline void DeviceShard<GradientSumT>::CreateHistIndices(const SparsePage& row_b
entries_d.shrink_to_fit();
}

template <typename GradientSumT>
class GPUHistMakerSpecialised{
public:
struct ExpandEntry;

GPUHistMakerSpecialised() : initialised_{false}, p_last_fmat_{nullptr} {}
void Init(
const std::vector<std::pair<std::string, std::string>>& args) {

@@ -1111,7 +1206,6 @@ class GPUHistMakerSpecialised{
shard = std::unique_ptr<DeviceShard<GradientSumT>>(
new DeviceShard<GradientSumT>(dist_.Devices().DeviceId(i), start,
start + size, param_));
shard->InitRowPtrs(batch);
});

// Find the cuts.

@@ -1119,13 +1213,14 @@ class GPUHistMakerSpecialised{
common::DeviceSketch(batch, *info_, param_, &hmat_, hist_maker_param_.gpu_batch_nrows);
n_bins_ = hmat_.row_ptr.back();
monitor_.StopCuda("Quantiles");
auto is_dense = info_->num_nonzero_ == info_->num_row_ * info_->num_col_;

monitor_.StartCuda("BinningCompression");
dh::ExecuteIndexShards(
&shards_,
[&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
dh::safe_cuda(cudaSetDevice(shard->device_id));
shard->InitCompressedData(hmat_, batch);
shard->InitCompressedData(hmat_, batch, is_dense);
});
monitor_.StopCuda("BinningCompression");
++batch_iter;

@@ -1300,32 +1395,19 @@ class GPUHistMakerSpecialised{
}

void UpdatePosition(const ExpandEntry& candidate, RegTree* p_tree) {
int nidx = candidate.nid;
int left_nidx = (*p_tree)[nidx].LeftChild();
int right_nidx = (*p_tree)[nidx].RightChild();

// convert floating-point split_pt into corresponding bin_id
// split_cond = -1 indicates that split_pt is less than all known cut points
int64_t split_gidx = -1;
int64_t fidx = candidate.split.findex;
bool default_dir_left = candidate.split.dir == kLeftDir;
uint32_t fidx_begin = hmat_.row_ptr[fidx];
uint32_t fidx_end = hmat_.row_ptr[fidx + 1];
// split_gidx = i where i is the i^th bin containing split value.
for (auto i = fidx_begin; i < fidx_end; ++i) {
if (candidate.split.fvalue == hmat_.cut[i]) {
split_gidx = static_cast<int64_t>(i);
}
}
auto is_dense = info_->num_nonzero_ == info_->num_row_ * info_->num_col_;

dh::ExecuteIndexShards(
&shards_,
[&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
dh::safe_cuda(cudaSetDevice(shard->device_id));
shard->UpdatePosition(nidx, left_nidx, right_nidx, fidx, split_gidx,
default_dir_left, is_dense, fidx_begin,
fidx_end);
shard->UpdatePosition(candidate.nid,
p_tree->GetNodes()[candidate.nid]);
});
}
void FinalisePosition(RegTree* p_tree) {
dh::ExecuteIndexShards(
&shards_,
[&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
shard->FinalisePosition(p_tree);
});
}

@@ -1380,20 +1462,22 @@ class GPUHistMakerSpecialised{
while (!qexpand_->empty()) {
ExpandEntry candidate = qexpand_->top();
qexpand_->pop();
if (!candidate.IsValid(param_, num_leaves)) continue;
if (!candidate.IsValid(param_, num_leaves)) {
continue;
}

this->ApplySplit(candidate, p_tree);
monitor_.StartCuda("UpdatePosition");
this->UpdatePosition(candidate, p_tree);
monitor_.StopCuda("UpdatePosition");
num_leaves++;

int left_child_nidx = tree[candidate.nid].LeftChild();
int right_child_nidx = tree[candidate.nid].RightChild();

// Only create child entries if needed
if (ExpandEntry::ChildIsValid(param_, tree.GetDepth(left_child_nidx),
num_leaves)) {
monitor_.StartCuda("UpdatePosition");
this->UpdatePosition(candidate, p_tree);
monitor_.StopCuda("UpdatePosition");

monitor_.StartCuda("BuildHist");
this->BuildHistLeftRight(candidate.nid, left_child_nidx,
right_child_nidx);

@@ -1411,6 +1495,10 @@ class GPUHistMakerSpecialised{
monitor_.StopCuda("EvaluateSplits");
}
}

monitor_.StartCuda("FinalisePosition");
this->FinalisePosition(p_tree);
monitor_.StopCuda("FinalisePosition");
}

bool UpdatePredictionCache(
@@ -1431,64 +1519,6 @@ class GPUHistMakerSpecialised{
return true;
}

struct ExpandEntry {
int nid;
int depth;
DeviceSplitCandidate split;
uint64_t timestamp;
ExpandEntry(int _nid, int _depth, const DeviceSplitCandidate _split,
uint64_t _timestamp) :
nid{_nid}, depth{_depth}, split(std::move(_split)),
timestamp{_timestamp} {}
bool IsValid(const TrainParam& param, int num_leaves) const {
if (split.loss_chg <= kRtEps) {
return false;
}
if (split.left_sum.GetHess() == 0 || split.right_sum.GetHess() == 0) {
return false;
}
if (param.max_depth > 0 && depth == param.max_depth) {
return false;
}
if (param.max_leaves > 0 && num_leaves == param.max_leaves) {
return false;
}
return true;
}

static bool ChildIsValid(const TrainParam& param, int depth,
int num_leaves) {
if (param.max_depth > 0 && depth >= param.max_depth) return false;
if (param.max_leaves > 0 && num_leaves >= param.max_leaves) return false;
return true;
}

friend std::ostream& operator<<(std::ostream& os, const ExpandEntry& e) {
os << "ExpandEntry: \n";
os << "nidx: " << e.nid << "\n";
os << "depth: " << e.depth << "\n";
os << "loss: " << e.split.loss_chg << "\n";
os << "left_sum: " << e.split.left_sum << "\n";
os << "right_sum: " << e.split.right_sum << "\n";
return os;
}
};

inline static bool DepthWise(ExpandEntry lhs, ExpandEntry rhs) {
if (lhs.depth == rhs.depth) {
return lhs.timestamp > rhs.timestamp;  // favor small timestamp
} else {
return lhs.depth > rhs.depth;  // favor small depth
}
}
inline static bool LossGuide(ExpandEntry lhs, ExpandEntry rhs) {
if (lhs.split.loss_chg == rhs.split.loss_chg) {
return lhs.timestamp > rhs.timestamp;  // favor small timestamp
} else {
return lhs.split.loss_chg < rhs.split.loss_chg;  // favor large loss_chg
}
}

TrainParam param_;  // NOLINT
common::HistCutMatrix hmat_;  // NOLINT
MetaInfo* info_;  // NOLINT

@@ -1507,7 +1537,8 @@ class GPUHistMakerSpecialised{
GPUHistMakerTrainParam hist_maker_param_;
common::GHistIndexMatrix gmat_;

using ExpandQueue = std::priority_queue<ExpandEntry, std::vector<ExpandEntry>,
using ExpandQueue =
std::priority_queue<ExpandEntry, std::vector<ExpandEntry>,
std::function<bool(ExpandEntry, ExpandEntry)>>;
std::unique_ptr<ExpandQueue> qexpand_;
dh::AllReducer reducer_;
@@ -26,7 +26,7 @@

bool FileExists(const std::string& filename);

long GetFileSize(const std::string& filename);
int64_t GetFileSize(const std::string& filename);

void CreateSimpleTestData(const std::string& filename);

@@ -39,8 +39,9 @@ void BuildGidx(DeviceShard<GradientSumT>* shard, int n_rows, int n_cols,
0.26f, 0.74f, 1.98f,
0.26f, 0.71f, 1.83f};

shard->InitRowPtrs(batch);
shard->InitCompressedData(cmat, batch);
auto is_dense = (*dmat)->Info().num_nonzero_ ==
(*dmat)->Info().num_row_ * (*dmat)->Info().num_col_;
shard->InitCompressedData(cmat, batch, is_dense);

delete dmat;
}

@@ -59,7 +60,7 @@ TEST(GpuHist, BuildGidxDense) {
h_gidx_buffer = shard.gidx_buffer.AsVector();
common::CompressedIterator<uint32_t> gidx(h_gidx_buffer.data(), 25);

ASSERT_EQ(shard.row_stride, kNCols);
ASSERT_EQ(shard.ellpack_matrix.row_stride, kNCols);

std::vector<uint32_t> solution = {
0, 3, 8, 9, 14, 17, 20, 21,

@@ -98,7 +99,7 @@ TEST(GpuHist, BuildGidxSparse) {
h_gidx_buffer = shard.gidx_buffer.AsVector();
common::CompressedIterator<uint32_t> gidx(h_gidx_buffer.data(), 25);

ASSERT_LE(shard.row_stride, 3);
ASSERT_LE(shard.ellpack_matrix.row_stride, 3);

// row_stride = 3, 16 rows, 48 entries for ELLPack
std::vector<uint32_t> solution = {

@@ -106,7 +107,7 @@ TEST(GpuHist, BuildGidxSparse) {
24, 24, 24, 24, 24, 5, 24, 24, 0, 16, 24, 15, 24, 24, 24, 24,
24, 7, 14, 16, 4, 24, 24, 24, 24, 24, 9, 24, 24, 1, 24, 24
};
for (size_t i = 0; i < kNRows * shard.row_stride; ++i) {
for (size_t i = 0; i < kNRows * shard.ellpack_matrix.row_stride; ++i) {
ASSERT_EQ(solution[i], gidx[i]);
}
}

@@ -256,16 +257,19 @@ TEST(GpuHist, EvaluateSplits) {
common::HistCutMatrix cmat = GetHostCutMatrix();

// Copy cut matrix to device.
DeviceShard<GradientPairPrecise>::DeviceHistCutMatrix cut;
shard->ba.Allocate(0,
&(shard->d_cut.feature_segments), cmat.row_ptr.size(),
&(shard->d_cut.min_fvalue), cmat.min_val.size(),
&(shard->d_cut.gidx_fvalue_map), 24,
&(shard->feature_segments), cmat.row_ptr.size(),
&(shard->min_fvalue), cmat.min_val.size(),
&(shard->gidx_fvalue_map), 24,
&(shard->monotone_constraints), kNCols);
shard->d_cut.feature_segments.copy(cmat.row_ptr.begin(), cmat.row_ptr.end());
shard->d_cut.gidx_fvalue_map.copy(cmat.cut.begin(), cmat.cut.end());
shard->feature_segments.copy(cmat.row_ptr.begin(), cmat.row_ptr.end());
shard->gidx_fvalue_map.copy(cmat.cut.begin(), cmat.cut.end());
shard->monotone_constraints.copy(param.monotone_constraints.begin(),
param.monotone_constraints.end());
shard->ellpack_matrix.feature_segments = shard->feature_segments.GetSpan();
shard->ellpack_matrix.gidx_fvalue_map = shard->gidx_fvalue_map.GetSpan();
shard->min_fvalue.copy(cmat.min_val.begin(), cmat.min_val.end());
shard->ellpack_matrix.min_fvalue = shard->min_fvalue.GetSpan();

// Initialize DeviceShard::hist
shard->hist.Init(0, (max_bins - 1) * kNCols);
@@ -339,7 +343,7 @@ TEST(GpuHist, ApplySplit) {
shard->ridx_segments[0] = Segment(0, kNRows);
shard->ba.Allocate(0, &(shard->ridx), kNRows,
&(shard->position), kNRows);
shard->row_stride = kNCols;
shard->ellpack_matrix.row_stride = kNCols;
thrust::sequence(shard->ridx.CurrentDVec().tbegin(),
shard->ridx.CurrentDVec().tend());
// Initialize GPUHistMaker

@@ -351,11 +355,9 @@ TEST(GpuHist, ApplySplit) {
0.59, 4,  // fvalue has to be equal to one of the cut field
GradientPair(8.2, 2.8), GradientPair(6.3, 3.6),
GPUTrainingParam(param));
GPUHistMakerSpecialised<GradientPairPrecise>::ExpandEntry candidate_entry {0, 0, candidate, 0};
ExpandEntry candidate_entry {0, 0, candidate, 0};
candidate_entry.nid = kNId;

auto const& nodes = tree.GetNodes();

// Used to get bin_id in update position.
common::HistCutMatrix cmat = GetHostCutMatrix();
hist_maker.hmat_ = cmat;

@@ -370,19 +372,31 @@ TEST(GpuHist, ApplySplit) {
int row_stride = kNCols;
int num_symbols = n_bins + 1;
size_t compressed_size_bytes =
common::CompressedBufferWriter::CalculateBufferSize(
row_stride * kNRows, num_symbols);
shard->ba.Allocate(0, &(shard->gidx_buffer), compressed_size_bytes);
common::CompressedBufferWriter::CalculateBufferSize(row_stride * kNRows,
num_symbols);
shard->ba.Allocate(0, &(shard->gidx_buffer), compressed_size_bytes,
&(shard->feature_segments), cmat.row_ptr.size(),
&(shard->min_fvalue), cmat.min_val.size(),
&(shard->gidx_fvalue_map), 24);
shard->feature_segments.copy(cmat.row_ptr.begin(), cmat.row_ptr.end());
shard->gidx_fvalue_map.copy(cmat.cut.begin(), cmat.cut.end());
shard->ellpack_matrix.feature_segments = shard->feature_segments.GetSpan();
shard->ellpack_matrix.gidx_fvalue_map = shard->gidx_fvalue_map.GetSpan();
shard->min_fvalue.copy(cmat.min_val.begin(), cmat.min_val.end());
shard->ellpack_matrix.min_fvalue = shard->min_fvalue.GetSpan();
shard->ellpack_matrix.is_dense = true;

common::CompressedBufferWriter wr(num_symbols);
std::vector<int> h_gidx (kNRows * row_stride);
std::iota(h_gidx.begin(), h_gidx.end(), 0);
// gidx 14 should go right, 12 goes left
std::vector<int> h_gidx (kNRows * row_stride, 14);
h_gidx[4] = 12;
h_gidx[12] = 12;
std::vector<common::CompressedByteT> h_gidx_compressed (compressed_size_bytes);

wr.Write(h_gidx_compressed.data(), h_gidx.begin(), h_gidx.end());
shard->gidx_buffer.copy(h_gidx_compressed.begin(), h_gidx_compressed.end());

shard->gidx = common::CompressedIterator<uint32_t>(
shard->ellpack_matrix.gidx_iter = common::CompressedIterator<uint32_t>(
shard->gidx_buffer.Data(), num_symbols);

hist_maker.info_ = &info;

@@ -395,8 +409,8 @@ TEST(GpuHist, ApplySplit) {
int right_nidx = tree[kNId].RightChild();

ASSERT_EQ(shard->ridx_segments[left_nidx].begin, 0);
ASSERT_EQ(shard->ridx_segments[left_nidx].end, 6);
ASSERT_EQ(shard->ridx_segments[right_nidx].begin, 6);
ASSERT_EQ(shard->ridx_segments[left_nidx].end, 2);
ASSERT_EQ(shard->ridx_segments[right_nidx].begin, 2);
ASSERT_EQ(shard->ridx_segments[right_nidx].end, 16);
}

@@ -417,7 +431,7 @@ void TestSortPosition(const std::vector<int>& position_in, int left_idx,
common::Span<int>(position_out.data().get(), position_out.size()),
common::Span<bst_uint>(ridx.data().get(), ridx.size()),
common::Span<bst_uint>(ridx_out.data().get(), ridx_out.size()), left_idx,
right_idx, d_left_count.data().get());
right_idx, d_left_count.data().get(), nullptr);
thrust::host_vector<int> position_result = position_out;
thrust::host_vector<int> ridx_result = ridx_out;