From 6d5b34d82486cd1d0480c548f5d1953834659bd6 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Sun, 24 Mar 2019 17:17:22 +1300 Subject: [PATCH] Further optimisations for gpu_hist. (#4283) - Fuse final update position functions into a single more efficient kernel - Refactor gpu_hist with a more explicit ellpack matrix representation --- include/xgboost/tree_model.h | 38 ++- src/predictor/gpu_predictor.cu | 5 + src/tree/updater_gpu_hist.cu | 535 +++++++++++++++++--------------- tests/cpp/helpers.h | 2 +- tests/cpp/tree/test_gpu_hist.cu | 62 ++-- 5 files changed, 345 insertions(+), 297 deletions(-) diff --git a/include/xgboost/tree_model.h b/include/xgboost/tree_model.h index 0e0eefea6..31dc15093 100644 --- a/include/xgboost/tree_model.h +++ b/include/xgboost/tree_model.h @@ -93,65 +93,63 @@ class RegTree { "Node: 64 bit align"); } /*! \brief index of left child */ - int LeftChild() const { + XGBOOST_DEVICE int LeftChild() const { return this->cleft_; } /*! \brief index of right child */ - int RightChild() const { + XGBOOST_DEVICE int RightChild() const { return this->cright_; } /*! \brief index of default child when feature is missing */ - int DefaultChild() const { + XGBOOST_DEVICE int DefaultChild() const { return this->DefaultLeft() ? this->LeftChild() : this->RightChild(); } /*! \brief feature index of split condition */ - unsigned SplitIndex() const { + XGBOOST_DEVICE unsigned SplitIndex() const { return sindex_ & ((1U << 31) - 1U); } /*! \brief when feature is unknown, whether goes to left child */ - bool DefaultLeft() const { + XGBOOST_DEVICE bool DefaultLeft() const { return (sindex_ >> 31) != 0; } /*! \brief whether current node is leaf node */ - bool IsLeaf() const { + XGBOOST_DEVICE bool IsLeaf() const { return cleft_ == -1; } /*! \return get leaf value of leaf node */ - bst_float LeafValue() const { + XGBOOST_DEVICE bst_float LeafValue() const { return (this->info_).leaf_value; } /*! \return get split condition of the node */ - SplitCondT SplitCond() const { + XGBOOST_DEVICE SplitCondT SplitCond() const { return (this->info_).split_cond; } /*! \brief get parent of the node */ - int Parent() const { + XGBOOST_DEVICE int Parent() const { return parent_ & ((1U << 31) - 1); } /*! \brief whether current node is left child */ - bool IsLeftChild() const { + XGBOOST_DEVICE bool IsLeftChild() const { return (parent_ & (1U << 31)) != 0; } /*! \brief whether this node is deleted */ - bool IsDeleted() const { + XGBOOST_DEVICE bool IsDeleted() const { return sindex_ == std::numeric_limits::max(); } /*! \brief whether current node is root */ - bool IsRoot() const { - return parent_ == -1; - } + XGBOOST_DEVICE bool IsRoot() const { return parent_ == -1; } /*! * \brief set the left child * \param nid node id to right child */ - void SetLeftChild(int nid) { + XGBOOST_DEVICE void SetLeftChild(int nid) { this->cleft_ = nid; } /*! * \brief set the right child * \param nid node id to right child */ - void SetRightChild(int nid) { + XGBOOST_DEVICE void SetRightChild(int nid) { this->cright_ = nid; } /*! 
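The accessors above gain the XGBOOST_DEVICE qualifier so the same RegTree::Node code compiles for both host and device; the nodes are later read directly inside CUDA kernels (see FinalisePosition further down in this patch). For reference, such a qualifier macro usually follows the pattern below. This is a minimal sketch only: the real XGBOOST_DEVICE is defined elsewhere in xgboost and NodeSketch is an illustrative stand-in, not the actual node type.

    // Typical host/device qualifier macro (illustrative, not xgboost's definition).
    #if defined(__CUDACC__)
    #define XGBOOST_DEVICE __host__ __device__
    #else
    #define XGBOOST_DEVICE
    #endif

    // With the qualifier in place, the same accessor is callable from host code
    // and from CUDA kernels without duplicating the node layout.
    struct NodeSketch {
      int cleft_{-1};
      XGBOOST_DEVICE bool IsLeaf() const { return cleft_ == -1; }
    };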
@@ -160,7 +158,7 @@ class RegTree { * \param split_cond split condition * \param default_left the default direction when feature is unknown */ - void SetSplit(unsigned split_index, SplitCondT split_cond, + XGBOOST_DEVICE void SetSplit(unsigned split_index, SplitCondT split_cond, bool default_left = false) { if (default_left) split_index |= (1U << 31); this->sindex_ = split_index; @@ -172,17 +170,17 @@ class RegTree { * \param right right index, could be used to store * additional information */ - void SetLeaf(bst_float value, int right = -1) { + XGBOOST_DEVICE void SetLeaf(bst_float value, int right = -1) { (this->info_).leaf_value = value; this->cleft_ = -1; this->cright_ = right; } /*! \brief mark that this node is deleted */ - void MarkDelete() { + XGBOOST_DEVICE void MarkDelete() { this->sindex_ = std::numeric_limits::max(); } // set parent - void SetParent(int pidx, bool is_left_child = true) { + XGBOOST_DEVICE void SetParent(int pidx, bool is_left_child = true) { if (is_left_child) pidx |= (1U << 31); this->parent_ = pidx; } diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 3b235c12c..39b5c87cf 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -303,6 +303,7 @@ class GPUPredictor : public xgboost::Predictor { const gbm::GBTreeModel& model, size_t tree_begin, size_t tree_end) { if (tree_end - tree_begin == 0) { return; } + monitor_.StartCuda("DevicePredictInternal"); CHECK_EQ(model.param.size_leaf_vector, 0); // Copy decision trees to device @@ -337,6 +338,7 @@ class GPUPredictor : public xgboost::Predictor { }); i_batch++; } + monitor_.StopCuda("DevicePredictInternal"); } public: @@ -388,9 +390,11 @@ class GPUPredictor : public xgboost::Predictor { if (it != cache_.end()) { const HostDeviceVector& y = it->second.predictions; if (y.Size() != 0) { + monitor_.StartCuda("PredictFromCache"); out_preds->Reshard(y.Distribution()); out_preds->Resize(y.Size()); out_preds->Copy(y); + monitor_.StopCuda("PredictFromCache"); return true; } } @@ -481,6 +485,7 @@ class GPUPredictor : public xgboost::Predictor { std::unique_ptr cpu_predictor_; std::vector shards_; GPUSet devices_; + common::Monitor monitor_; }; XGBOOST_REGISTER_PREDICTOR(GPUPredictor, "gpu_predictor") diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index eeb211265..f1ddef018 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -50,6 +50,133 @@ struct GPUHistMakerTrainParam DMLC_REGISTER_PARAMETER(GPUHistMakerTrainParam); +struct ExpandEntry { + int nid; + int depth; + DeviceSplitCandidate split; + uint64_t timestamp; + ExpandEntry() = default; + ExpandEntry(int nid, int depth, DeviceSplitCandidate split, + uint64_t timestamp) + : nid(nid), depth(depth), split(std::move(split)), timestamp(timestamp) {} + bool IsValid(const TrainParam& param, int num_leaves) const { + if (split.loss_chg <= kRtEps) return false; + if (split.left_sum.GetHess() == 0 || split.right_sum.GetHess() == 0) { + return false; + } + if (param.max_depth > 0 && depth == param.max_depth) return false; + if (param.max_leaves > 0 && num_leaves == param.max_leaves) return false; + return true; + } + + static bool ChildIsValid(const TrainParam& param, int depth, int num_leaves) { + if (param.max_depth > 0 && depth >= param.max_depth) return false; + if (param.max_leaves > 0 && num_leaves >= param.max_leaves) return false; + return true; + } + + friend std::ostream& operator<<(std::ostream& os, const ExpandEntry& e) { + os << "ExpandEntry: \n"; + os << 
"nidx: " << e.nid << "\n"; + os << "depth: " << e.depth << "\n"; + os << "loss: " << e.split.loss_chg << "\n"; + os << "left_sum: " << e.split.left_sum << "\n"; + os << "right_sum: " << e.split.right_sum << "\n"; + return os; + } +}; + +inline static bool DepthWise(ExpandEntry lhs, ExpandEntry rhs) { + if (lhs.depth == rhs.depth) { + return lhs.timestamp > rhs.timestamp; // favor small timestamp + } else { + return lhs.depth > rhs.depth; // favor small depth + } +} +inline static bool LossGuide(ExpandEntry lhs, ExpandEntry rhs) { + if (lhs.split.loss_chg == rhs.split.loss_chg) { + return lhs.timestamp > rhs.timestamp; // favor small timestamp + } else { + return lhs.split.loss_chg < rhs.split.loss_chg; // favor large loss_chg + } +} + +// Find a gidx value for a given feature otherwise return -1 if not found +__device__ int BinarySearchRow(bst_uint begin, bst_uint end, + common::CompressedIterator data, + int const fidx_begin, int const fidx_end) { + bst_uint previous_middle = UINT32_MAX; + while (end != begin) { + auto middle = begin + (end - begin) / 2; + if (middle == previous_middle) { + break; + } + previous_middle = middle; + + auto gidx = data[middle]; + + if (gidx >= fidx_begin && gidx < fidx_end) { + return gidx; + } else if (gidx < fidx_begin) { + begin = middle; + } else { + end = middle; + } + } + // Value is missing + return -1; +} + +/** \brief Struct for accessing and manipulating an ellpack matrix on the + * device. Does not own underlying memory and may be trivially copied into + * kernels.*/ +struct ELLPackMatrix { + common::Span feature_segments; + /*! \brief minimum value for each feature. */ + common::Span min_fvalue; + /*! \brief Cut. */ + common::Span gidx_fvalue_map; + /*! \brief row length for ELLPack. */ + size_t row_stride{0}; + common::CompressedIterator gidx_iter; + bool is_dense; + int null_gidx_value; + + XGBOOST_DEVICE size_t BinCount() const { return gidx_fvalue_map.size(); } + + // Get a matrix element, uses binary search for look up + // Return NaN if missing + __device__ bst_float GetElement(size_t ridx, size_t fidx) const { + auto row_begin = row_stride * ridx; + auto row_end = row_begin + row_stride; + auto gidx = -1; + if (is_dense) { + gidx = gidx_iter[row_begin + fidx]; + } else { + gidx = + BinarySearchRow(row_begin, row_end, gidx_iter, feature_segments[fidx], + feature_segments[fidx + 1]); + } + if (gidx == -1) { + return nan(""); + } + return gidx_fvalue_map[gidx]; + } + void Init(common::Span feature_segments, + common::Span min_fvalue, + common::Span gidx_fvalue_map, size_t row_stride, + common::CompressedIterator gidx_iter, bool is_dense, + int null_gidx_value) { + this->feature_segments = feature_segments; + this->min_fvalue = min_fvalue; + this->gidx_fvalue_map = gidx_fvalue_map; + this->row_stride = row_stride; + this->gidx_iter = gidx_iter; + this->is_dense = is_dense; + this->null_gidx_value = null_gidx_value; + } +}; + // With constraints template XGBOOST_DEVICE float inline LossChangeMissing( @@ -111,19 +238,17 @@ __device__ GradientSumT ReduceFeature(common::Span feature_h template __device__ void EvaluateFeature( - int fidx, - common::Span node_histogram, - common::Span feature_segments, // cut.row_ptr - float min_fvalue, // cut.min_value - common::Span gidx_fvalue_map, // cut.cut + int fidx, common::Span node_histogram, + const ELLPackMatrix& matrix, DeviceSplitCandidate* best_split, // shared memory storing best split const DeviceNodeStats& node, const GPUTrainingParam& param, TempStorageT* temp_storage, // temp memory for cub 
operations int constraint, // monotonic_constraints const ValueConstraint& value_constraint) { // Use pointer from cut to indicate begin and end of bins for each feature. - uint32_t gidx_begin = feature_segments[fidx]; // begining bin - uint32_t gidx_end = feature_segments[fidx + 1]; // end bin for i^th feature + uint32_t gidx_begin = matrix.feature_segments[fidx]; // begining bin + uint32_t gidx_end = + matrix.feature_segments[fidx + 1]; // end bin for i^th feature // Sum histogram bins for current feature GradientSumT const feature_sum = ReduceFeature( @@ -168,16 +293,17 @@ __device__ void EvaluateFeature( // Best thread updates split if (threadIdx.x == block_max.key) { - int gidx = scan_begin + threadIdx.x; - float fvalue = - gidx == gidx_begin ? min_fvalue : gidx_fvalue_map[gidx - 1]; + int split_gidx = (scan_begin + threadIdx.x) - 1; + float fvalue; + if (split_gidx < static_cast(gidx_begin)) { + fvalue = matrix.min_fvalue[fidx]; + } else { + fvalue = matrix.gidx_fvalue_map[split_gidx]; + } GradientSumT left = missing_left ? bin + missing : bin; GradientSumT right = parent_sum - left; - best_split->Update(gain, missing_left ? kLeftDir : kRightDir, - fvalue, fidx, - GradientPair(left), - GradientPair(right), - param); + best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue, + fidx, GradientPair(left), GradientPair(right), param); } __syncthreads(); } @@ -189,10 +315,7 @@ __global__ void EvaluateSplitKernel( node_histogram, // histogram for gradients common::Span feature_set, // Selected features DeviceNodeStats node, - common::Span - d_feature_segments, // row_ptr form HistCutMatrix - common::Span d_fidx_min_map, // min_value - common::Span d_gidx_fvalue_map, // cut + ELLPackMatrix matrix, GPUTrainingParam gpu_param, common::Span split_candidates, // resulting split ValueConstraint value_constraint, @@ -226,10 +349,8 @@ __global__ void EvaluateSplitKernel( int fidx = feature_set[blockIdx.x]; int constraint = d_monotonic_constraints[fidx]; EvaluateFeature( - fidx, node_histogram, - d_feature_segments, d_fidx_min_map[fidx], d_gidx_fvalue_map, - &best_split, node, gpu_param, &temp_storage, constraint, - value_constraint); + fidx, node_histogram, matrix, &best_split, node, gpu_param, &temp_storage, + constraint, value_constraint); __syncthreads(); @@ -239,32 +360,6 @@ __global__ void EvaluateSplitKernel( } } -// Find a gidx value for a given feature otherwise return -1 if not found -template -__device__ int BinarySearchRow(bst_uint begin, bst_uint end, GidxIterT data, - int const fidx_begin, int const fidx_end) { - bst_uint previous_middle = UINT32_MAX; - while (end != begin) { - auto middle = begin + (end - begin) / 2; - if (middle == previous_middle) { - break; - } - previous_middle = middle; - - auto gidx = data[middle]; - - if (gidx >= fidx_begin && gidx < fidx_end) { - return gidx; - } else if (gidx < fidx_begin) { - begin = middle; - } else { - end = middle; - } - } - // Value is missing - return -1; -} - /** * \struct DeviceHistogram * @@ -290,7 +385,6 @@ class DeviceHistogram { } void Reset() { - dh::safe_cuda(cudaSetDevice(device_id_)); dh::safe_cuda(cudaMemsetAsync( data_.data().get(), 0, data_.size() * sizeof(typename decltype(data_)::value_type))); @@ -397,27 +491,27 @@ __global__ void CompressBinEllpackKernel( } template -__global__ void SharedMemHistKernel(size_t row_stride, const bst_uint* d_ridx, - common::CompressedIterator d_gidx, - int null_gidx_value, +__global__ void SharedMemHistKernel(ELLPackMatrix matrix, const bst_uint* d_ridx, GradientSumT* 
d_node_hist, const GradientPair* d_gpair, size_t segment_begin, size_t n_elements) { extern __shared__ char smem[]; GradientSumT* smem_arr = reinterpret_cast(smem); // NOLINT - for (auto i : dh::BlockStrideRange(0, null_gidx_value)) { + for (auto i : + dh::BlockStrideRange(static_cast(0), matrix.BinCount())) { smem_arr[i] = GradientSumT(); } __syncthreads(); for (auto idx : dh::GridStrideRange(static_cast(0), n_elements)) { - int ridx = d_ridx[idx / row_stride + segment_begin]; - int gidx = d_gidx[ridx * row_stride + idx % row_stride]; - if (gidx != null_gidx_value) { + int ridx = d_ridx[idx / matrix.row_stride + segment_begin]; + int gidx = matrix.gidx_iter[ridx * matrix.row_stride + idx % matrix.row_stride]; + if (gidx != matrix.null_gidx_value) { AtomicAddGpair(smem_arr + gidx, d_gpair[ridx]); } } __syncthreads(); - for (auto i : dh::BlockStrideRange(0, null_gidx_value)) { + for (auto i : + dh::BlockStrideRange(static_cast(0), matrix.BinCount())) { AtomicAddGpair(d_node_hist + i, smem_arr[i]); } } @@ -509,32 +603,26 @@ struct DeviceShard { dh::BulkAllocator ba; - /*! \brief HistCutMatrix stored in device. */ - struct DeviceHistCutMatrix { - /*! \brief row_ptr form HistCutMatrix. */ - dh::DVec feature_segments; - /*! \brief minimum value for each feature. */ - dh::DVec min_fvalue; - /*! \brief Cut. */ - dh::DVec gidx_fvalue_map; - } d_cut; + ELLPackMatrix ellpack_matrix; /*! \brief Range of rows for each node. */ std::vector ridx_segments; DeviceHistogram hist; - /*! \brief row length for ELLPack. */ - size_t row_stride; - common::CompressedIterator gidx; + /*! \brief row_ptr form HistCutMatrix. */ + dh::DVec feature_segments; + /*! \brief minimum value for each feature. */ + dh::DVec min_fvalue; + /*! \brief Cut. */ + dh::DVec gidx_fvalue_map; + /*! \brief global index of histogram, which is stored in ELLPack format. */ + dh::DVec gidx_buffer; /*! \brief Row indices relative to this shard, necessary for sorting rows. */ dh::DVec2 ridx; /*! \brief Gradient pair for each row. */ dh::DVec gpair; - /*! \brief The last histogram index. */ - int null_gidx_value; - dh::DVec2 position; dh::DVec monotone_constraints; @@ -543,8 +631,6 @@ struct DeviceShard { /*! \brief Sum gradient for each node. */ std::vector node_sum_gradients; dh::DVec node_sum_gradients_d; - /*! \brief global index of histogram, which is stored in ELLPack format. */ - dh::DVec gidx_buffer; /*! \brief row offset in SparsePage (the input data). */ thrust::device_vector row_ptrs; /*! 
\brief On-device feature set, only actually used on one of the devices */ @@ -572,16 +658,13 @@ struct DeviceShard { : device_id(_device_id), row_begin_idx(row_begin), row_end_idx(row_end), - row_stride(0), n_rows(row_end - row_begin), - n_bins{0}, - null_gidx_value(0), + n_bins(0), param(std::move(_param)), prediction_cache_initialised(false) {} /* Init row_ptrs and row_stride */ - void InitRowPtrs(const SparsePage& row_batch) { - dh::safe_cuda(cudaSetDevice(device_id)); + size_t InitRowPtrs(const SparsePage& row_batch) { const auto& offset_vec = row_batch.offset.HostVector(); row_ptrs.resize(n_rows + 1); thrust::copy(offset_vec.data() + row_begin_idx, @@ -595,22 +678,17 @@ struct DeviceShard { auto counting = thrust::make_counting_iterator(size_t(0)); using TransformT = thrust::transform_iterator; + decltype(counting), size_t>; TransformT row_size_iter = TransformT(counting, get_size); - row_stride = thrust::reduce(row_size_iter, row_size_iter + n_rows, 0, - thrust::maximum()); + size_t row_stride = thrust::reduce(row_size_iter, row_size_iter + n_rows, 0, + thrust::maximum()); + return row_stride; } - /* - Init: - n_bins, null_gidx_value, gidx_buffer, row_ptrs, gidx, gidx_fvalue_map, - min_fvalue, feature_segments, node_sum_gradients, ridx_segments, - hist - */ void InitCompressedData( - const common::HistCutMatrix& hmat, const SparsePage& row_batch); + const common::HistCutMatrix& hmat, const SparsePage& row_batch, bool is_dense); - void CreateHistIndices(const SparsePage& row_batch); + void CreateHistIndices(const SparsePage& row_batch, size_t row_stride, int null_gidx_value); ~DeviceShard() { dh::safe_cuda(cudaSetDevice(device_id)); @@ -708,10 +786,9 @@ struct DeviceShard { int constexpr kBlockThreads = 256; EvaluateSplitKernel <<>>( - hist.GetNodeHistogram(nidx), d_feature_set, node, - d_cut.feature_segments.GetSpan(), d_cut.min_fvalue.GetSpan(), - d_cut.gidx_fvalue_map.GetSpan(), gpu_param, d_split_candidates, - value_constraints[nidx], monotone_constraints.GetSpan()); + hist.GetNodeHistogram(nidx), d_feature_set, node, ellpack_matrix, + gpu_param, d_split_candidates, value_constraints[nidx], + monotone_constraints.GetSpan()); // Reduce over features to find best feature auto d_result = d_result_all.subspan(i, 1); @@ -756,47 +833,35 @@ struct DeviceShard { hist.HistogramExists(nidx_parent); } - void UpdatePosition(int nidx, int left_nidx, int right_nidx, int fidx, - int64_t split_gidx, bool default_dir_left, bool is_dense, - int fidx_begin, // cut.row_ptr[fidx] - int fidx_end) { // cut.row_ptr[fidx + 1] - dh::safe_cuda(cudaSetDevice(device_id)); + void UpdatePosition(int nidx, RegTree::Node split_node) { + CHECK(!split_node.IsLeaf()) <<"Node must not be leaf"; Segment segment = ridx_segments[nidx]; bst_uint* d_ridx = ridx.Current(); int* d_position = position.Current(); - common::CompressedIterator d_gidx = gidx; - size_t row_stride = this->row_stride; if (left_counts.size() <= nidx) { left_counts.resize((nidx * 2) + 1); } int64_t* d_left_count = left_counts.data().get() + nidx; + auto d_matrix = this->ellpack_matrix; // Launch 1 thread for each row dh::LaunchN<1, 128>( device_id, segment.Size(), [=] __device__(bst_uint idx) { idx += segment.begin; bst_uint ridx = d_ridx[idx]; - auto row_begin = row_stride * ridx; - auto row_end = row_begin + row_stride; - auto gidx = -1; - if (is_dense) { - // FIXME: Maybe just search the cuts again. 
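UpdatePosition no longer receives a precomputed split_gidx; it takes the RegTree::Node of the split and routes each row by reading the feature value back out of the ELLPACK matrix. The per-row decision reduces to: a missing value follows the default direction, otherwise the value is compared against the split condition. A host-side sketch of that routing, assuming the xgboost headers are available and using a plain function in place of the device lambda:

    #include <cmath>
    #include <xgboost/tree_model.h>

    // Host-side sketch of the per-row routing decision shared by UpdatePosition
    // and FinalisePosition (illustrative only; the kernels inline this logic).
    inline int RoutePosition(const xgboost::RegTree::Node& node, float element) {
      if (std::isnan(element)) {
        return node.DefaultChild();  // missing value: follow the default direction
      }
      return element <= node.SplitCond() ? node.LeftChild()    // value goes left
                                         : node.RightChild();  // value goes right
    }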
- gidx = d_gidx[row_begin + fidx]; + bst_float element = d_matrix.GetElement(ridx, split_node.SplitIndex()); + // Missing value + int new_position = 0; + if (isnan(element)) { + new_position = split_node.DefaultChild(); } else { - gidx = BinarySearchRow(row_begin, row_end, d_gidx, fidx_begin, - fidx_end); + if (element <= split_node.SplitCond()) { + new_position = split_node.LeftChild(); + } else { + new_position = split_node.RightChild(); + } } - - // belong to left node or right node. - int position; - if (gidx >= 0) { - // Feature is found - position = gidx <= split_gidx ? left_nidx : right_nidx; - } else { - // Feature is missing - position = default_dir_left ? left_nidx : right_nidx; - } - CountLeft(d_left_count, position, left_nidx); - d_position[idx] = position; + CountLeft(d_left_count, new_position, split_node.LeftChild()); + d_position[idx] = new_position; }); // Overlap device to host memory copy (left_count) with sort @@ -805,16 +870,16 @@ struct DeviceShard { dh::safe_cuda(cudaMemcpyAsync(tmp_pinned.data(), d_left_count, sizeof(int64_t), cudaMemcpyDeviceToHost, streams[0])); - SortPositionAndCopy(segment, left_nidx, right_nidx, d_left_count, + SortPositionAndCopy(segment, split_node.LeftChild(), split_node.RightChild(), d_left_count, streams[1]); dh::safe_cuda(cudaStreamSynchronize(streams[0])); int64_t left_count = tmp_pinned[0]; CHECK_LE(left_count, segment.Size()); CHECK_GE(left_count, 0); - ridx_segments[left_nidx] = + ridx_segments[split_node.LeftChild()] = Segment(segment.begin, segment.begin + left_count); - ridx_segments[right_nidx] = + ridx_segments[split_node.RightChild()] = Segment(segment.begin + left_count, segment.end); } @@ -840,6 +905,41 @@ struct DeviceShard { }); } + // After tree update is finished, update the position of all training + // instances to their final leaf This information is used later to update the + // prediction cache + void FinalisePosition(RegTree* p_tree) { + const auto d_nodes = + temp_memory.GetSpan(p_tree->GetNodes().size()); + dh::safe_cuda(cudaMemcpy(d_nodes.data(), p_tree->GetNodes().data(), + d_nodes.size() * sizeof(RegTree::Node), + cudaMemcpyHostToDevice)); + auto d_position = position.Current(); + const auto d_ridx = ridx.Current(); + auto d_matrix = this->ellpack_matrix; + dh::LaunchN(device_id, position.Size(), [=] __device__(size_t idx) { + auto position = d_position[idx]; + auto node = d_nodes[position]; + bst_uint ridx = d_ridx[idx]; + + while (!node.IsLeaf()) { + bst_float element = d_matrix.GetElement(ridx, node.SplitIndex()); + // Missing value + if (isnan(element)) { + position = node.DefaultChild(); + } else { + if (element <= node.SplitCond()) { + position = node.LeftChild(); + } else { + position = node.RightChild(); + } + } + node = d_nodes[position]; + } + d_position[idx] = position; + }); + } + void UpdatePredictionCache(bst_float* out_preds_d) { dh::safe_cuda(cudaSetDevice(device_id)); if (!prediction_cache_initialised) { @@ -880,14 +980,12 @@ struct SharedMemHistBuilder : public GPUHistBuilderBase { auto segment = shard->ridx_segments[nidx]; auto segment_begin = segment.begin; auto d_node_hist = shard->hist.GetNodeHistogram(nidx); - auto d_gidx = shard->gidx; auto d_ridx = shard->ridx.Current(); auto d_gpair = shard->gpair.Data(); - int null_gidx_value = shard->null_gidx_value; - auto n_elements = segment.Size() * shard->row_stride; + auto n_elements = segment.Size() * shard->ellpack_matrix.row_stride; - const size_t smem_size = sizeof(GradientSumT) * shard->null_gidx_value; + const size_t smem_size = 
sizeof(GradientSumT) * shard->ellpack_matrix.BinCount(); const int items_per_thread = 8; const int block_threads = 256; const int grid_size = @@ -896,10 +994,9 @@ struct SharedMemHistBuilder : public GPUHistBuilderBase { if (grid_size <= 0) { return; } - dh::safe_cuda(cudaSetDevice(shard->device_id)); - SharedMemHistKernel<<>> - (shard->row_stride, d_ridx, d_gidx, null_gidx_value, d_node_hist.data(), d_gpair, - segment_begin, n_elements); + SharedMemHistKernel<<>>( + shard->ellpack_matrix, d_ridx, d_node_hist.data(), d_gpair, + segment_begin, n_elements); } }; @@ -908,20 +1005,18 @@ struct GlobalMemHistBuilder : public GPUHistBuilderBase { void Build(DeviceShard* shard, int nidx) override { Segment segment = shard->ridx_segments[nidx]; auto d_node_hist = shard->hist.GetNodeHistogram(nidx).data(); - common::CompressedIterator d_gidx = shard->gidx; bst_uint* d_ridx = shard->ridx.Current(); GradientPair* d_gpair = shard->gpair.Data(); - size_t const n_elements = segment.Size() * shard->row_stride; - size_t const row_stride = shard->row_stride; - int const null_gidx_value = shard->null_gidx_value; + size_t const n_elements = segment.Size() * shard->ellpack_matrix.row_stride; + auto d_matrix = shard->ellpack_matrix; dh::LaunchN(shard->device_id, n_elements, [=] __device__(size_t idx) { - int ridx = d_ridx[(idx / row_stride) + segment.begin]; + int ridx = d_ridx[(idx / d_matrix.row_stride) + segment.begin]; // lookup the index (bin) of histogram. - int gidx = d_gidx[ridx * row_stride + idx % row_stride]; + int gidx = d_matrix.gidx_iter[ridx * d_matrix.row_stride + idx % d_matrix.row_stride]; - if (gidx != null_gidx_value) { + if (gidx != d_matrix.null_gidx_value) { AtomicAddGpair(d_node_hist + gidx, d_gpair[ridx]); } }); @@ -930,10 +1025,10 @@ struct GlobalMemHistBuilder : public GPUHistBuilderBase { template inline void DeviceShard::InitCompressedData( - const common::HistCutMatrix& hmat, const SparsePage& row_batch) { - dh::safe_cuda(cudaSetDevice(device_id)); - n_bins = hmat.NumBins(); - null_gidx_value = hmat.NumBins(); + const common::HistCutMatrix& hmat, const SparsePage& row_batch, bool is_dense) { + size_t row_stride = this->InitRowPtrs(row_batch); + n_bins = hmat.row_ptr.back(); + int null_gidx_value = hmat.row_ptr.back(); int max_nodes = param.max_leaves > 0 ? 
param.max_leaves * 2 : MaxNodesDepth(param.max_depth); @@ -944,13 +1039,13 @@ inline void DeviceShard::InitCompressedData( &position, n_rows, &prediction_cache, n_rows, &node_sum_gradients_d, max_nodes, - &d_cut.feature_segments, hmat.row_ptr.size(), - &d_cut.gidx_fvalue_map, hmat.cut.size(), - &d_cut.min_fvalue, hmat.min_val.size(), + &feature_segments, hmat.row_ptr.size(), + &gidx_fvalue_map, hmat.cut.size(), + &min_fvalue, hmat.min_val.size(), &monotone_constraints, param.monotone_constraints.size()); - d_cut.gidx_fvalue_map = hmat.cut; - d_cut.min_fvalue = hmat.min_val; - d_cut.feature_segments = hmat.row_ptr; + gidx_fvalue_map = hmat.cut; + min_fvalue = hmat.min_val; + feature_segments = hmat.row_ptr; monotone_constraints = param.monotone_constraints; node_sum_gradients.resize(max_nodes); @@ -970,15 +1065,18 @@ inline void DeviceShard::InitCompressedData( ba.Allocate(device_id, &gidx_buffer, compressed_size_bytes); gidx_buffer.Fill(0); - int nbits = common::detail::SymbolBits(num_symbols); + this->CreateHistIndices(row_batch, row_stride, null_gidx_value); - CreateHistIndices(row_batch); - - gidx = common::CompressedIterator(gidx_buffer.Data(), num_symbols); + ellpack_matrix.Init( + feature_segments.GetSpan(), min_fvalue.GetSpan(), + gidx_fvalue_map.GetSpan(), row_stride, + common::CompressedIterator(gidx_buffer.Data(), num_symbols), + is_dense, null_gidx_value); // check if we can use shared memory for building histograms - // (assuming atleast we need 2 CTAs per SM to maintain decent latency hiding) - auto histogram_size = sizeof(GradientSumT) * null_gidx_value; + // (assuming atleast we need 2 CTAs per SM to maintain decent latency + // hiding) + auto histogram_size = sizeof(GradientSumT) * hmat.row_ptr.back(); auto max_smem = dh::MaxSharedMemory(device_id); if (histogram_size <= max_smem) { hist_builder.reset(new SharedMemHistBuilder); @@ -990,9 +1088,9 @@ inline void DeviceShard::InitCompressedData( hist.Init(device_id, hmat.NumBins()); } - template -inline void DeviceShard::CreateHistIndices(const SparsePage& row_batch) { +inline void DeviceShard::CreateHistIndices( + const SparsePage& row_batch, size_t row_stride, int null_gidx_value) { int num_symbols = n_bins + 1; // bin and compress entries in batches of rows size_t gpu_batch_nrows = @@ -1026,7 +1124,7 @@ inline void DeviceShard::CreateHistIndices(const SparsePage& row_b gidx_buffer.Data(), row_ptrs.data().get() + batch_row_begin, entries_d.data().get(), - d_cut.gidx_fvalue_map.Data(), d_cut.feature_segments.Data(), + gidx_fvalue_map.Data(), feature_segments.Data(), batch_row_begin, batch_nrows, row_ptrs[batch_row_begin], row_stride, null_gidx_value); @@ -1039,12 +1137,9 @@ inline void DeviceShard::CreateHistIndices(const SparsePage& row_b entries_d.shrink_to_fit(); } - template class GPUHistMakerSpecialised{ public: - struct ExpandEntry; - GPUHistMakerSpecialised() : initialised_{false}, p_last_fmat_{nullptr} {} void Init( const std::vector>& args) { @@ -1111,7 +1206,6 @@ class GPUHistMakerSpecialised{ shard = std::unique_ptr>( new DeviceShard(dist_.Devices().DeviceId(i), start, start + size, param_)); - shard->InitRowPtrs(batch); }); // Find the cuts. 
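Whether histograms are built in shared memory (SharedMemHistBuilder) or global memory (GlobalMemHistBuilder) is decided once per shard by comparing the histogram footprint, sizeof(GradientSumT) times the number of bins, against the device's shared-memory-per-block limit. A minimal sketch of that check using the CUDA runtime directly; dh::MaxSharedMemory wraps the same query, and the function name here is illustrative:

    #include <cstddef>
    #include <cuda_runtime.h>

    // Returns true if a per-block shared-memory histogram of n_bins entries fits
    // on the given device (illustrative version of the check in InitCompressedData).
    template <typename GradientSumT>
    bool CanUseSharedMemHist(int device_id, size_t n_bins) {
      int max_smem = 0;
      cudaDeviceGetAttribute(&max_smem, cudaDevAttrMaxSharedMemoryPerBlock, device_id);
      size_t histogram_size = sizeof(GradientSumT) * n_bins;
      return histogram_size <= static_cast<size_t>(max_smem);
    }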
@@ -1119,13 +1213,14 @@ class GPUHistMakerSpecialised{ common::DeviceSketch(batch, *info_, param_, &hmat_, hist_maker_param_.gpu_batch_nrows); n_bins_ = hmat_.row_ptr.back(); monitor_.StopCuda("Quantiles"); + auto is_dense = info_->num_nonzero_ == info_->num_row_ * info_->num_col_; monitor_.StartCuda("BinningCompression"); dh::ExecuteIndexShards( &shards_, [&](int idx, std::unique_ptr>& shard) { dh::safe_cuda(cudaSetDevice(shard->device_id)); - shard->InitCompressedData(hmat_, batch); + shard->InitCompressedData(hmat_, batch, is_dense); }); monitor_.StopCuda("BinningCompression"); ++batch_iter; @@ -1300,32 +1395,19 @@ class GPUHistMakerSpecialised{ } void UpdatePosition(const ExpandEntry& candidate, RegTree* p_tree) { - int nidx = candidate.nid; - int left_nidx = (*p_tree)[nidx].LeftChild(); - int right_nidx = (*p_tree)[nidx].RightChild(); - - // convert floating-point split_pt into corresponding bin_id - // split_cond = -1 indicates that split_pt is less than all known cut points - int64_t split_gidx = -1; - int64_t fidx = candidate.split.findex; - bool default_dir_left = candidate.split.dir == kLeftDir; - uint32_t fidx_begin = hmat_.row_ptr[fidx]; - uint32_t fidx_end = hmat_.row_ptr[fidx + 1]; - // split_gidx = i where i is the i^th bin containing split value. - for (auto i = fidx_begin; i < fidx_end; ++i) { - if (candidate.split.fvalue == hmat_.cut[i]) { - split_gidx = static_cast(i); - } - } - auto is_dense = info_->num_nonzero_ == info_->num_row_ * info_->num_col_; - dh::ExecuteIndexShards( &shards_, [&](int idx, std::unique_ptr>& shard) { dh::safe_cuda(cudaSetDevice(shard->device_id)); - shard->UpdatePosition(nidx, left_nidx, right_nidx, fidx, split_gidx, - default_dir_left, is_dense, fidx_begin, - fidx_end); + shard->UpdatePosition(candidate.nid, + p_tree->GetNodes()[candidate.nid]); + }); + } + void FinalisePosition(RegTree* p_tree) { + dh::ExecuteIndexShards( + &shards_, + [&](int idx, std::unique_ptr>& shard) { + shard->FinalisePosition(p_tree); }); } @@ -1380,20 +1462,22 @@ class GPUHistMakerSpecialised{ while (!qexpand_->empty()) { ExpandEntry candidate = qexpand_->top(); qexpand_->pop(); - if (!candidate.IsValid(param_, num_leaves)) continue; + if (!candidate.IsValid(param_, num_leaves)) { + continue; + } this->ApplySplit(candidate, p_tree); - monitor_.StartCuda("UpdatePosition"); - this->UpdatePosition(candidate, p_tree); - monitor_.StopCuda("UpdatePosition"); num_leaves++; int left_child_nidx = tree[candidate.nid].LeftChild(); int right_child_nidx = tree[candidate.nid].RightChild(); - // Only create child entries if needed if (ExpandEntry::ChildIsValid(param_, tree.GetDepth(left_child_nidx), num_leaves)) { + monitor_.StartCuda("UpdatePosition"); + this->UpdatePosition(candidate, p_tree); + monitor_.StopCuda("UpdatePosition"); + monitor_.StartCuda("BuildHist"); this->BuildHistLeftRight(candidate.nid, left_child_nidx, right_child_nidx); @@ -1407,10 +1491,14 @@ class GPUHistMakerSpecialised{ timestamp++)); qexpand_->push(ExpandEntry(right_child_nidx, tree.GetDepth(right_child_nidx), - splits.at(1), timestamp++)); + splits.at(1), timestamp++)); monitor_.StopCuda("EvaluateSplits"); } } + + monitor_.StartCuda("FinalisePosition"); + this->FinalisePosition(p_tree); + monitor_.StopCuda("FinalisePosition"); } bool UpdatePredictionCache( @@ -1431,64 +1519,6 @@ class GPUHistMakerSpecialised{ return true; } - struct ExpandEntry { - int nid; - int depth; - DeviceSplitCandidate split; - uint64_t timestamp; - ExpandEntry(int _nid, int _depth, const DeviceSplitCandidate _split, - uint64_t 
_timestamp) : - nid{_nid}, depth{_depth}, split(std::move(_split)), - timestamp{_timestamp} {} - bool IsValid(const TrainParam& param, int num_leaves) const { - if (split.loss_chg <= kRtEps) { - return false; - } - if (split.left_sum.GetHess() == 0 || split.right_sum.GetHess() == 0) { - return false; - } - if (param.max_depth > 0 && depth == param.max_depth) { - return false; - } - if (param.max_leaves > 0 && num_leaves == param.max_leaves) { - return false; - } - return true; - } - - static bool ChildIsValid(const TrainParam& param, int depth, - int num_leaves) { - if (param.max_depth > 0 && depth >= param.max_depth) return false; - if (param.max_leaves > 0 && num_leaves >= param.max_leaves) return false; - return true; - } - - friend std::ostream& operator<<(std::ostream& os, const ExpandEntry& e) { - os << "ExpandEntry: \n"; - os << "nidx: " << e.nid << "\n"; - os << "depth: " << e.depth << "\n"; - os << "loss: " << e.split.loss_chg << "\n"; - os << "left_sum: " << e.split.left_sum << "\n"; - os << "right_sum: " << e.split.right_sum << "\n"; - return os; - } - }; - - inline static bool DepthWise(ExpandEntry lhs, ExpandEntry rhs) { - if (lhs.depth == rhs.depth) { - return lhs.timestamp > rhs.timestamp; // favor small timestamp - } else { - return lhs.depth > rhs.depth; // favor small depth - } - } - inline static bool LossGuide(ExpandEntry lhs, ExpandEntry rhs) { - if (lhs.split.loss_chg == rhs.split.loss_chg) { - return lhs.timestamp > rhs.timestamp; // favor small timestamp - } else { - return lhs.split.loss_chg < rhs.split.loss_chg; // favor large loss_chg - } - } - TrainParam param_; // NOLINT common::HistCutMatrix hmat_; // NOLINT MetaInfo* info_; // NOLINT @@ -1507,8 +1537,9 @@ class GPUHistMakerSpecialised{ GPUHistMakerTrainParam hist_maker_param_; common::GHistIndexMatrix gmat_; - using ExpandQueue = std::priority_queue, - std::function>; + using ExpandQueue = + std::priority_queue, + std::function>; std::unique_ptr qexpand_; dh::AllReducer reducer_; diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h index 9d64114e3..0b279a943 100644 --- a/tests/cpp/helpers.h +++ b/tests/cpp/helpers.h @@ -26,7 +26,7 @@ bool FileExists(const std::string& filename); -long GetFileSize(const std::string& filename); +int64_t GetFileSize(const std::string& filename); void CreateSimpleTestData(const std::string& filename); diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index 12ef32917..3616f85c8 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -39,8 +39,9 @@ void BuildGidx(DeviceShard* shard, int n_rows, int n_cols, 0.26f, 0.74f, 1.98f, 0.26f, 0.71f, 1.83f}; - shard->InitRowPtrs(batch); - shard->InitCompressedData(cmat, batch); + auto is_dense = (*dmat)->Info().num_nonzero_ == + (*dmat)->Info().num_row_ * (*dmat)->Info().num_col_; + shard->InitCompressedData(cmat, batch, is_dense); delete dmat; } @@ -59,7 +60,7 @@ TEST(GpuHist, BuildGidxDense) { h_gidx_buffer = shard.gidx_buffer.AsVector(); common::CompressedIterator gidx(h_gidx_buffer.data(), 25); - ASSERT_EQ(shard.row_stride, kNCols); + ASSERT_EQ(shard.ellpack_matrix.row_stride, kNCols); std::vector solution = { 0, 3, 8, 9, 14, 17, 20, 21, @@ -98,7 +99,7 @@ TEST(GpuHist, BuildGidxSparse) { h_gidx_buffer = shard.gidx_buffer.AsVector(); common::CompressedIterator gidx(h_gidx_buffer.data(), 25); - ASSERT_LE(shard.row_stride, 3); + ASSERT_LE(shard.ellpack_matrix.row_stride, 3); // row_stride = 3, 16 rows, 48 entries for ELLPack std::vector solution = { @@ -106,7 +107,7 @@ 
TEST(GpuHist, BuildGidxSparse) { 24, 24, 24, 24, 24, 5, 24, 24, 0, 16, 24, 15, 24, 24, 24, 24, 24, 7, 14, 16, 4, 24, 24, 24, 24, 24, 9, 24, 24, 1, 24, 24 }; - for (size_t i = 0; i < kNRows * shard.row_stride; ++i) { + for (size_t i = 0; i < kNRows * shard.ellpack_matrix.row_stride; ++i) { ASSERT_EQ(solution[i], gidx[i]); } } @@ -256,16 +257,19 @@ TEST(GpuHist, EvaluateSplits) { common::HistCutMatrix cmat = GetHostCutMatrix(); // Copy cut matrix to device. - DeviceShard::DeviceHistCutMatrix cut; shard->ba.Allocate(0, - &(shard->d_cut.feature_segments), cmat.row_ptr.size(), - &(shard->d_cut.min_fvalue), cmat.min_val.size(), - &(shard->d_cut.gidx_fvalue_map), 24, + &(shard->feature_segments), cmat.row_ptr.size(), + &(shard->min_fvalue), cmat.min_val.size(), + &(shard->gidx_fvalue_map), 24, &(shard->monotone_constraints), kNCols); - shard->d_cut.feature_segments.copy(cmat.row_ptr.begin(), cmat.row_ptr.end()); - shard->d_cut.gidx_fvalue_map.copy(cmat.cut.begin(), cmat.cut.end()); + shard->feature_segments.copy(cmat.row_ptr.begin(), cmat.row_ptr.end()); + shard->gidx_fvalue_map.copy(cmat.cut.begin(), cmat.cut.end()); shard->monotone_constraints.copy(param.monotone_constraints.begin(), param.monotone_constraints.end()); + shard->ellpack_matrix.feature_segments = shard->feature_segments.GetSpan(); + shard->ellpack_matrix.gidx_fvalue_map = shard->gidx_fvalue_map.GetSpan(); + shard->min_fvalue.copy(cmat.min_val.begin(), cmat.min_val.end()); + shard->ellpack_matrix.min_fvalue = shard->min_fvalue.GetSpan(); // Initialize DeviceShard::hist shard->hist.Init(0, (max_bins - 1) * kNCols); @@ -339,7 +343,7 @@ TEST(GpuHist, ApplySplit) { shard->ridx_segments[0] = Segment(0, kNRows); shard->ba.Allocate(0, &(shard->ridx), kNRows, &(shard->position), kNRows); - shard->row_stride = kNCols; + shard->ellpack_matrix.row_stride = kNCols; thrust::sequence(shard->ridx.CurrentDVec().tbegin(), shard->ridx.CurrentDVec().tend()); // Initialize GPUHistMaker @@ -351,11 +355,9 @@ TEST(GpuHist, ApplySplit) { 0.59, 4, // fvalue has to be equal to one of the cut field GradientPair(8.2, 2.8), GradientPair(6.3, 3.6), GPUTrainingParam(param)); - GPUHistMakerSpecialised::ExpandEntry candidate_entry {0, 0, candidate, 0}; + ExpandEntry candidate_entry {0, 0, candidate, 0}; candidate_entry.nid = kNId; - auto const& nodes = tree.GetNodes(); - // Used to get bin_id in update position. 
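The tests above populate ellpack_matrix by hand: feature_segments comes from HistCutMatrix::row_ptr, gidx_fvalue_map from the cut values, and min_fvalue from the per-feature minima. The lookup these spans support is the one ELLPackMatrix::GetElement performs on device: find the bin id stored for a feature within a row (each feature owns a contiguous range of bin ids), then map that bin back to a cut value. A host-side sketch with plain vectors standing in for the spans and the compressed iterator:

    #include <cmath>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Host-side analogue of BinarySearchRow / ELLPackMatrix::GetElement (illustrative).
    inline int FindBinForFeature(const std::vector<int>& gidx, size_t begin, size_t end,
                                 int fidx_begin, int fidx_end) {
      size_t previous_middle = SIZE_MAX;
      while (end != begin) {
        size_t middle = begin + (end - begin) / 2;
        if (middle == previous_middle) break;  // interval no longer shrinks
        previous_middle = middle;
        int g = gidx[middle];
        if (g >= fidx_begin && g < fidx_end) return g;  // bin belongs to this feature
        if (g < fidx_begin) begin = middle;             // feature lies in the upper half
        else end = middle;                              // feature lies in the lower half
      }
      return -1;  // value is missing for this feature
    }

    inline float GetElementSketch(const std::vector<int>& gidx, size_t row_stride,
                                  size_t ridx, size_t fidx,
                                  const std::vector<uint32_t>& feature_segments,
                                  const std::vector<float>& gidx_fvalue_map) {
      size_t row_begin = row_stride * ridx;
      int g = FindBinForFeature(gidx, row_begin, row_begin + row_stride,
                                feature_segments[fidx], feature_segments[fidx + 1]);
      return g == -1 ? std::nanf("") : gidx_fvalue_map[g];  // NaN marks a missing value
    }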
common::HistCutMatrix cmat = GetHostCutMatrix(); hist_maker.hmat_ = cmat; @@ -370,19 +372,31 @@ TEST(GpuHist, ApplySplit) { int row_stride = kNCols; int num_symbols = n_bins + 1; size_t compressed_size_bytes = - common::CompressedBufferWriter::CalculateBufferSize( - row_stride * kNRows, num_symbols); - shard->ba.Allocate(0, &(shard->gidx_buffer), compressed_size_bytes); + common::CompressedBufferWriter::CalculateBufferSize(row_stride * kNRows, + num_symbols); + shard->ba.Allocate(0, &(shard->gidx_buffer), compressed_size_bytes, + &(shard->feature_segments), cmat.row_ptr.size(), + &(shard->min_fvalue), cmat.min_val.size(), + &(shard->gidx_fvalue_map), 24); + shard->feature_segments.copy(cmat.row_ptr.begin(), cmat.row_ptr.end()); + shard->gidx_fvalue_map.copy(cmat.cut.begin(), cmat.cut.end()); + shard->ellpack_matrix.feature_segments = shard->feature_segments.GetSpan(); + shard->ellpack_matrix.gidx_fvalue_map = shard->gidx_fvalue_map.GetSpan(); + shard->min_fvalue.copy(cmat.min_val.begin(), cmat.min_val.end()); + shard->ellpack_matrix.min_fvalue = shard->min_fvalue.GetSpan(); + shard->ellpack_matrix.is_dense = true; common::CompressedBufferWriter wr(num_symbols); - std::vector h_gidx (kNRows * row_stride); - std::iota(h_gidx.begin(), h_gidx.end(), 0); + // gidx 14 should go right, 12 goes left + std::vector h_gidx (kNRows * row_stride, 14); + h_gidx[4] = 12; + h_gidx[12] = 12; std::vector h_gidx_compressed (compressed_size_bytes); wr.Write(h_gidx_compressed.data(), h_gidx.begin(), h_gidx.end()); shard->gidx_buffer.copy(h_gidx_compressed.begin(), h_gidx_compressed.end()); - shard->gidx = common::CompressedIterator( + shard->ellpack_matrix.gidx_iter = common::CompressedIterator( shard->gidx_buffer.Data(), num_symbols); hist_maker.info_ = &info; @@ -395,8 +409,8 @@ TEST(GpuHist, ApplySplit) { int right_nidx = tree[kNId].RightChild(); ASSERT_EQ(shard->ridx_segments[left_nidx].begin, 0); - ASSERT_EQ(shard->ridx_segments[left_nidx].end, 6); - ASSERT_EQ(shard->ridx_segments[right_nidx].begin, 6); + ASSERT_EQ(shard->ridx_segments[left_nidx].end, 2); + ASSERT_EQ(shard->ridx_segments[right_nidx].begin, 2); ASSERT_EQ(shard->ridx_segments[right_nidx].end, 16); } @@ -417,7 +431,7 @@ void TestSortPosition(const std::vector& position_in, int left_idx, common::Span(position_out.data().get(), position_out.size()), common::Span(ridx.data().get(), ridx.size()), common::Span(ridx_out.data().get(), ridx_out.size()), left_idx, - right_idx, d_left_count.data().get()); + right_idx, d_left_count.data().get(), nullptr); thrust::host_vector position_result = position_out; thrust::host_vector ridx_result = ridx_out;
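TestSortPosition at the end exercises the primitive UpdatePosition relies on: once every row of a node's segment has been assigned a left or right child, the row indices are partitioned so that all left rows precede all right rows, and the left count is recorded so the two child segments can be formed. A host-side sketch of that contract using std::stable_partition; this is illustrative only, the device implementation works differently and the ordering within each side may not match:

    #include <algorithm>
    #include <cstdint>
    #include <utility>
    #include <vector>

    // Host-side sketch of the SortPosition contract: reorder the parallel
    // (position, row index) entries of one node segment so rows assigned to the
    // left child come first, and report how many went left.
    inline int64_t SortPositionSketch(std::vector<std::pair<int, unsigned>>* entries,
                                      int left_nidx) {
      auto goes_left = [&](const std::pair<int, unsigned>& e) {
        return e.first == left_nidx;
      };
      auto split = std::stable_partition(entries->begin(), entries->end(), goes_left);
      return split - entries->begin();  // left_count: size of the left child's segment
    }

The left child then owns entries [0, left_count) and the right child owns entries [left_count, segment size), mirroring how ridx_segments is updated once the left count has been copied back to the host.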