diff --git a/include/xgboost/data.h b/include/xgboost/data.h
index f949c338a..c663d8f3b 100644
--- a/include/xgboost/data.h
+++ b/include/xgboost/data.h
@@ -166,6 +166,15 @@ struct BatchParam {
   int max_bin;
   /*! \brief Number of rows in a GPU batch, used for finding quantiles on GPU. */
   int gpu_batch_nrows;
+  /*! \brief Page size for external memory mode. */
+  size_t gpu_page_size;
+
+  inline bool operator!=(const BatchParam& other) const {
+    return gpu_id != other.gpu_id ||
+           max_bin != other.max_bin ||
+           gpu_batch_nrows != other.gpu_batch_nrows ||
+           gpu_page_size != other.gpu_page_size;
+  }
 };
 
 /*!
diff --git a/include/xgboost/generic_parameters.h b/include/xgboost/generic_parameters.h
index 4b2673ece..b250c020c 100644
--- a/include/xgboost/generic_parameters.h
+++ b/include/xgboost/generic_parameters.h
@@ -21,6 +21,8 @@ struct GenericParameter : public XGBoostParameter<GenericParameter> {
   int nthread;
   // primary device, -1 means no gpu.
   int gpu_id;
+  // gpu page size in external memory mode, 0 means using the default.
+  size_t gpu_page_size;
 
   void CheckDeprecated() {
     if (this->n_gpus != 0) {
@@ -49,6 +51,10 @@ struct GenericParameter : public XGBoostParameter<GenericParameter> {
         .set_default(-1)
         .set_lower_bound(-1)
         .describe("The primary GPU device ordinal.");
+    DMLC_DECLARE_FIELD(gpu_page_size)
+        .set_default(0)
+        .set_lower_bound(0)
+        .describe("GPU page size when running in external memory mode.");
     DMLC_DECLARE_FIELD(n_gpus)
         .set_default(0)
         .set_range(0, 1)
diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu
index 6d52e03e1..b23777458 100644
--- a/src/common/hist_util.cu
+++ b/src/common/hist_util.cu
@@ -164,8 +164,9 @@ class GPUSketcher {
     auto counting = thrust::make_counting_iterator(size_t(0));
     using TransformT = thrust::transform_iterator;
     TransformT row_size_iter = TransformT(counting, get_size);
-    row_stride_ =
+    size_t batch_row_stride =
         thrust::reduce(row_size_iter, row_size_iter + n_rows_, 0, thrust::maximum<size_t>());
+    row_stride_ = std::max(row_stride_, batch_row_stride);
   }
 
   // This needs to be public because of the __device__ lambda.
diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu
index b2a081e7e..a58c74e6f 100644
--- a/src/data/ellpack_page.cu
+++ b/src/data/ellpack_page.cu
@@ -69,6 +69,8 @@ EllpackPageImpl::EllpackPageImpl(DMatrix* dmat, const BatchParam& param) {
   monitor_.Init("ellpack_page");
   dh::safe_cuda(cudaSetDevice(param.gpu_id));
 
+  matrix.n_rows = dmat->Info().num_row_;
+
   monitor_.StartCuda("Quantiles");
   // Create the quantile sketches for the dmatrix and initialize HistogramCuts.
   common::HistogramCuts hmat;
@@ -206,7 +208,7 @@ void EllpackPageImpl::CreateHistIndices(int device,
 
 // Return the number of rows contained in this page.
 size_t EllpackPageImpl::Size() const {
-  return n_rows;
+  return matrix.n_rows;
 }
 
 // Clear the current page.
@@ -214,44 +216,50 @@ void EllpackPageImpl::Clear() {
   ba_.Clear();
   gidx_buffer = {};
   idx_buffer.clear();
-  n_rows = 0;
+  sparse_page_.Clear();
+  matrix.base_rowid = 0;
+  matrix.n_rows = 0;
+  device_initialized_ = false;
 }
 
 // Push a CSR page to the current page.
 //
-// First compress the CSR page into ELLPACK, then the compressed buffer is copied to host and
-// appended to the existing host vector.
+// The CSR pages are accumulated in memory until they reach a certain size, then written out as
+// compressed ELLPACK.
 void EllpackPageImpl::Push(int device, const SparsePage& batch) {
+  sparse_page_.Push(batch);
+  matrix.n_rows += batch.Size();
+}
+
+// Compress the accumulated SparsePage.
+void EllpackPageImpl::CompressSparsePage(int device) {
   monitor_.StartCuda("InitCompressedData");
-  InitCompressedData(device, batch.Size());
+  InitCompressedData(device, matrix.n_rows);
   monitor_.StopCuda("InitCompressedData");
 
   monitor_.StartCuda("BinningCompression");
-  DeviceHistogramBuilderState hist_builder_row_state(batch.Size());
-  hist_builder_row_state.BeginBatch(batch);
-  CreateHistIndices(device, batch, hist_builder_row_state.GetRowStateOnDevice());
+  DeviceHistogramBuilderState hist_builder_row_state(matrix.n_rows);
+  hist_builder_row_state.BeginBatch(sparse_page_);
+  CreateHistIndices(device, sparse_page_, hist_builder_row_state.GetRowStateOnDevice());
   hist_builder_row_state.EndBatch();
   monitor_.StopCuda("BinningCompression");
 
   monitor_.StartCuda("CopyDeviceToHost");
-  std::vector<common::CompressedByteT> buffer(gidx_buffer.size());
-  dh::CopyDeviceSpanToVector(&buffer, gidx_buffer);
-  int offset = 0;
-  if (!idx_buffer.empty()) {
-    offset = ::xgboost::common::detail::kPadding;
-  }
-  idx_buffer.reserve(idx_buffer.size() + buffer.size() - offset);
-  idx_buffer.insert(idx_buffer.end(), buffer.begin() + offset, buffer.end());
+  idx_buffer.resize(gidx_buffer.size());
+  dh::CopyDeviceSpanToVector(&idx_buffer, gidx_buffer);
   ba_.Clear();
   gidx_buffer = {};
   monitor_.StopCuda("CopyDeviceToHost");
-
-  n_rows += batch.Size();
 }
 
 // Return the memory cost for storing the compressed features.
 size_t EllpackPageImpl::MemCostBytes() const {
-  return idx_buffer.size() * sizeof(common::CompressedByteT);
+  size_t num_symbols = matrix.info.n_bins + 1;
+
+  // Required buffer size for storing data matrix in ELLPack format.
+  size_t compressed_size_bytes = common::CompressedBufferWriter::CalculateBufferSize(
+      matrix.info.row_stride * matrix.n_rows, num_symbols);
+  return compressed_size_bytes;
 }
 
 // Copy the compressed features to GPU.
diff --git a/src/data/ellpack_page.cuh b/src/data/ellpack_page.cuh
index 1b38fcfa6..47fd98910 100644
--- a/src/data/ellpack_page.cuh
+++ b/src/data/ellpack_page.cuh
@@ -78,13 +78,14 @@ struct EllpackInfo {
  * kernels.*/
 struct EllpackMatrix {
   EllpackInfo info;
+  size_t base_rowid{};
+  size_t n_rows{};
   common::CompressedIterator<uint32_t> gidx_iter;
 
-  XGBOOST_DEVICE size_t BinCount() const { return info.gidx_fvalue_map.size(); }
-
   // Get a matrix element, uses binary search for look up Return NaN if missing
   // Given a row index and a feature index, returns the corresponding cut value
   __device__ bst_float GetElement(size_t ridx, size_t fidx) const {
+    ridx -= base_rowid;
     auto row_begin = info.row_stride * ridx;
     auto row_end = row_begin + info.row_stride;
     auto gidx = -1;
@@ -102,6 +103,11 @@ struct EllpackMatrix {
     }
     return info.gidx_fvalue_map[gidx];
   }
+
+  // Check if the row id is within the range of the current batch.
+  __device__ bool IsInRange(size_t row_id) const {
+    return row_id >= base_rowid && row_id < base_rowid + n_rows;
+  }
 };
 
 // Instances of this type are created while creating the histogram bins for the
@@ -185,7 +191,6 @@ class EllpackPageImpl {
   /*! \brief global index of histogram, which is stored in ELLPack format. */
   common::Span<common::CompressedByteT> gidx_buffer;
   std::vector<common::CompressedByteT> idx_buffer;
-  size_t n_rows{};
 
   /*!
    * \brief Default constructor.
@@ -240,7 +245,7 @@ class EllpackPageImpl {
 
   /*! \brief Set the base row id for this page. */
   inline void SetBaseRowId(size_t row_id) {
-    base_rowid_ = row_id;
+    matrix.base_rowid = row_id;
  }
 
   /*! \brief clear the page. */
@@ -263,11 +268,17 @@ class EllpackPageImpl {
    */
   void InitDevice(int device, EllpackInfo info);
 
+  /*! \brief Compress the accumulated SparsePage into ELLPACK format.
+   *
+   * @param device The GPU device to use.
+   */
+  void CompressSparsePage(int device);
+
  private:
   common::Monitor monitor_;
   dh::BulkAllocator ba_;
-  size_t base_rowid_{};
   bool device_initialized_{false};
+  SparsePage sparse_page_{};
 };
 
 }  // namespace xgboost
diff --git a/src/data/ellpack_page_raw_format.cu b/src/data/ellpack_page_raw_format.cu
index fc8dcde62..7760b13dc 100644
--- a/src/data/ellpack_page_raw_format.cu
+++ b/src/data/ellpack_page_raw_format.cu
@@ -17,7 +17,8 @@ class EllpackPageRawFormat : public SparsePageFormat<EllpackPage> {
  public:
   bool Read(EllpackPage* page, dmlc::SeekStream* fi) override {
     auto* impl = page->Impl();
-    if (!fi->Read(&impl->n_rows)) return false;
+    impl->Clear();
+    if (!fi->Read(&impl->matrix.n_rows)) return false;
     return fi->Read(&impl->idx_buffer);
   }
 
@@ -25,13 +26,14 @@ class EllpackPageRawFormat : public SparsePageFormat<EllpackPage> {
             dmlc::SeekStream* fi,
             const std::vector<bst_uint>& sorted_index_set) override {
     auto* impl = page->Impl();
-    if (!fi->Read(&impl->n_rows)) return false;
+    impl->Clear();
+    if (!fi->Read(&impl->matrix.n_rows)) return false;
     return fi->Read(&page->Impl()->idx_buffer);
   }
 
   void Write(const EllpackPage& page, dmlc::Stream* fo) override {
     auto* impl = page.Impl();
-    fo->Write(impl->n_rows);
+    fo->Write(impl->matrix.n_rows);
     auto buffer = impl->idx_buffer;
     CHECK(!buffer.empty());
     fo->Write(buffer);
diff --git a/src/data/ellpack_page_source.cu b/src/data/ellpack_page_source.cu
index e2456d9a4..dfa548e16 100644
--- a/src/data/ellpack_page_source.cu
+++ b/src/data/ellpack_page_source.cu
@@ -40,11 +40,13 @@ class EllpackPageSourceImpl : public DataSource<EllpackPage> {
   const std::string kPageType_{".ellpack.page"};
 
   int device_{-1};
+  size_t page_size_{DMatrix::kPageSize};
   common::Monitor monitor_;
   dh::BulkAllocator ba_;
   /*! \brief The EllpackInfo, with the underlying GPU memory shared by all pages. */
   EllpackInfo ellpack_info_;
   std::unique_ptr<SparsePageSource<EllpackPage>> source_;
+  std::string cache_info_;
 };
 
 EllpackPageSource::EllpackPageSource(DMatrix* dmat,
@@ -72,8 +74,12 @@ const EllpackPage& EllpackPageSource::Value() const {
 // each CSR page, and write the accumulated ELLPACK pages to disk.
 EllpackPageSourceImpl::EllpackPageSourceImpl(DMatrix* dmat,
                                              const std::string& cache_info,
-                                             const BatchParam& param) noexcept(false) {
-  device_ = param.gpu_id;
+                                             const BatchParam& param) noexcept(false)
+    : device_(param.gpu_id), cache_info_(cache_info) {
+
+  if (param.gpu_page_size > 0) {
+    page_size_ = param.gpu_page_size;
+  }
 
   monitor_.Init("ellpack_page_source");
   dh::safe_cuda(cudaSetDevice(device_));
@@ -92,10 +98,11 @@ EllpackPageSourceImpl::EllpackPageSourceImpl(DMatrix* dmat,
   WriteEllpackPages(dmat, cache_info);
   monitor_.StopCuda("WriteEllpackPages");
 
-  source_.reset(new SparsePageSource<EllpackPage>(cache_info, kPageType_));
+  source_.reset(new SparsePageSource<EllpackPage>(cache_info_, kPageType_));
 }
 
 void EllpackPageSourceImpl::BeforeFirst() {
+  source_.reset(new SparsePageSource<EllpackPage>(cache_info_, kPageType_));
   source_->BeforeFirst();
 }
 
@@ -133,20 +140,23 @@ void EllpackPageSourceImpl::WriteEllpackPages(DMatrix* dmat, const std::string&
   for (const auto& batch : dmat->GetBatches<SparsePage>()) {
     impl->Push(device_, batch);
 
-    if (impl->MemCostBytes() >= DMatrix::kPageSize) {
-      bytes_write += impl->MemCostBytes();
+    size_t mem_cost_bytes = impl->MemCostBytes();
+    if (mem_cost_bytes >= page_size_) {
+      bytes_write += mem_cost_bytes;
+      impl->CompressSparsePage(device_);
       writer.PushWrite(std::move(page));
       writer.Alloc(&page);
       impl = page->Impl();
       impl->matrix.info = ellpack_info_;
       impl->Clear();
       double tdiff = dmlc::GetTime() - tstart;
-      LOG(INFO) << "Writing to " << cache_info << " in "
+      LOG(INFO) << "Writing " << kPageType_ << " to " << cache_info << " in "
                 << ((bytes_write >> 20UL) / tdiff) << " MB/s, "
                 << (bytes_write >> 20UL) << " written";
     }
   }
   if (impl->Size() != 0) {
+    impl->CompressSparsePage(device_);
     writer.PushWrite(std::move(page));
   }
 }
diff --git a/src/data/sparse_page_dmatrix.cc b/src/data/sparse_page_dmatrix.cc
index 4d90067a2..67712e26e 100644
--- a/src/data/sparse_page_dmatrix.cc
+++ b/src/data/sparse_page_dmatrix.cc
@@ -81,10 +81,7 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(const BatchParam& param) {
   CHECK_GE(param.gpu_id, 0);
   CHECK_GE(param.max_bin, 2);
   // Lazily instantiate
-  if (!ellpack_source_ ||
-      batch_param_.gpu_id != param.gpu_id ||
-      batch_param_.max_bin != param.max_bin ||
-      batch_param_.gpu_batch_nrows != param.gpu_batch_nrows) {
+  if (!ellpack_source_ || batch_param_ != param) {
     ellpack_source_.reset(new EllpackPageSource(this, cache_info_, param));
     batch_param_ = param;
   }
diff --git a/src/tree/gpu_hist/row_partitioner.cuh b/src/tree/gpu_hist/row_partitioner.cuh
index 4b6dcfb60..2d9500faf 100644
--- a/src/tree/gpu_hist/row_partitioner.cuh
+++ b/src/tree/gpu_hist/row_partitioner.cuh
@@ -33,6 +33,7 @@ class RowPartitioner {
  public:
   using RowIndexT = bst_uint;
   struct Segment;
+  static constexpr bst_node_t kIgnoredTreePosition = -1;
 
  private:
   int device_idx;
@@ -124,6 +125,7 @@ class RowPartitioner {
       idx += segment.begin;
       RowIndexT ridx = d_ridx[idx];
       bst_node_t new_position = op(ridx);  // new node id
+      if (new_position == kIgnoredTreePosition) return;
       KERNEL_CHECK(new_position == left_nidx || new_position == right_nidx);
       AtomicIncrement(d_left_count, new_position == left_nidx);
       d_position[idx] = new_position;
@@ -163,7 +165,9 @@ class RowPartitioner {
     dh::LaunchN(device_idx, position.Size(), [=] __device__(size_t idx) {
       auto position = d_position[idx];
       RowIndexT ridx = d_ridx[idx];
-      d_position[idx] = op(ridx, position);
+      bst_node_t new_position = op(ridx, position);
+      if (new_position == kIgnoredTreePosition) return;
+      d_position[idx] = new_position;
     });
   }
 
diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index 628e3efca..c06f64656 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -409,13 +409,16 @@ __global__ void SharedMemHistKernel(xgboost::EllpackMatrix matrix,
   extern __shared__ char smem[];
   GradientSumT* smem_arr = reinterpret_cast<GradientSumT*>(smem);  // NOLINT
   if (use_shared_memory_histograms) {
-    dh::BlockFill(smem_arr, matrix.BinCount(), GradientSumT());
+    dh::BlockFill(smem_arr, matrix.info.n_bins, GradientSumT());
     __syncthreads();
   }
   for (auto idx : dh::GridStrideRange(static_cast<size_t>(0), n_elements)) {
-    int ridx = d_ridx[idx / matrix.info.row_stride ];
-    int gidx =
-        matrix.gidx_iter[ridx * matrix.info.row_stride + idx % matrix.info.row_stride];
+    int ridx = d_ridx[idx / matrix.info.row_stride];
+    if (!matrix.IsInRange(ridx)) {
+      continue;
+    }
+    int gidx = matrix.gidx_iter[(ridx - matrix.base_rowid) * matrix.info.row_stride +
+                                idx % matrix.info.row_stride];
     if (gidx != matrix.info.n_bins) {
       // If we are not using shared memory, accumulate the values directly into
       // global memory
@@ -428,8 +431,7 @@ __global__ void SharedMemHistKernel(xgboost::EllpackMatrix matrix,
   if (use_shared_memory_histograms) {
     // Write shared memory back to global memory
     __syncthreads();
-    for (auto i :
-         dh::BlockStrideRange(static_cast<size_t>(0), matrix.BinCount())) {
+    for (auto i : dh::BlockStrideRange(static_cast<size_t>(0), matrix.info.n_bins)) {
       dh::AtomicAddGpair(d_node_hist + i, smem_arr[i]);
     }
   }
@@ -440,6 +442,7 @@ template <typename GradientSumT>
 struct GPUHistMakerDevice {
   int device_id;
   EllpackPageImpl* page;
+  BatchParam batch_param;
 
   dh::BulkAllocator ba;
@@ -481,14 +484,16 @@ struct GPUHistMakerDevice {
                      bst_uint _n_rows,
                      TrainParam _param,
                      uint32_t column_sampler_seed,
-                     uint32_t n_features)
+                     uint32_t n_features,
+                     BatchParam _batch_param)
       : device_id(_device_id),
        page(_page),
        n_rows(_n_rows),
        param(std::move(_param)),
        prediction_cache_initialised(false),
        column_sampler(column_sampler_seed),
-        interaction_constraints(param, n_features) {
+        interaction_constraints(param, n_features),
+        batch_param(_batch_param) {
     monitor.Init(std::string("GPUHistMakerDevice") + std::to_string(device_id));
   }
@@ -626,6 +631,14 @@ struct GPUHistMakerDevice {
     return std::vector<DeviceSplitCandidate>(result_all.begin(), result_all.end());
   }
 
+  // Build gradient histograms for a given node across all the batches in the DMatrix.
+  void BuildHistBatches(int nidx, DMatrix* p_fmat) {
+    for (auto& batch : p_fmat->GetBatches<EllpackPage>(batch_param)) {
+      page = batch.Impl();
+      BuildHist(nidx);
+    }
+  }
+
   void BuildHist(int nidx) {
     hist.AllocateHistogram(nidx);
     auto d_node_hist = hist.GetNodeHistogram(nidx);
@@ -636,7 +649,7 @@ struct GPUHistMakerDevice {
 
     const size_t smem_size =
         use_shared_memory_histograms
-            ? sizeof(GradientSumT) * page->matrix.BinCount()
+            ? sizeof(GradientSumT) * page->matrix.info.n_bins
            : 0;
     uint32_t items_per_thread = 8;
     uint32_t block_threads = 256;
@@ -673,7 +686,10 @@ struct GPUHistMakerDevice {
 
     row_partitioner->UpdatePosition(
         nidx, split_node.LeftChild(), split_node.RightChild(),
-        [=] __device__(bst_uint ridx) {
+        [=] __device__(size_t ridx) {
+          if (!d_matrix.IsInRange(ridx)) {
+            return RowPartitioner::kIgnoredTreePosition;
+          }
          // given a row index, returns the node id it belongs to
          bst_float cut_value =
              d_matrix.GetElement(ridx, split_node.SplitIndex());
@@ -693,35 +709,42 @@ struct GPUHistMakerDevice {
   }
 
   // After tree update is finished, update the position of all training
-  // instances to their final leaf This information is used later to update the
+  // instances to their final leaf. This information is used later to update the
   // prediction cache
-  void FinalisePosition(RegTree* p_tree) {
+  void FinalisePosition(RegTree* p_tree, DMatrix* p_fmat) {
     const auto d_nodes =
         temp_memory.GetSpan<RegTree::Node>(p_tree->GetNodes().size());
     dh::safe_cuda(cudaMemcpy(d_nodes.data(), p_tree->GetNodes().data(),
                              d_nodes.size() * sizeof(RegTree::Node),
                              cudaMemcpyHostToDevice));
-    auto d_matrix = page->matrix;
-    row_partitioner->FinalisePosition(
-        [=] __device__(bst_uint ridx, int position) {
-          auto node = d_nodes[position];
-          while (!node.IsLeaf()) {
-            bst_float element = d_matrix.GetElement(ridx, node.SplitIndex());
-            // Missing value
-            if (isnan(element)) {
-              position = node.DefaultChild();
-            } else {
-              if (element <= node.SplitCond()) {
-                position = node.LeftChild();
-              } else {
-                position = node.RightChild();
-              }
+
+    for (auto& batch : p_fmat->GetBatches<EllpackPage>(batch_param)) {
+      page = batch.Impl();
+      auto d_matrix = page->matrix;
+      row_partitioner->FinalisePosition(
+          [=] __device__(size_t row_id, int position) {
+            if (!d_matrix.IsInRange(row_id)) {
+              return RowPartitioner::kIgnoredTreePosition;
             }
-            node = d_nodes[position];
-          }
-          return position;
-        });
+            auto node = d_nodes[position];
+
+            while (!node.IsLeaf()) {
+              bst_float element = d_matrix.GetElement(row_id, node.SplitIndex());
+              // Missing value
+              if (isnan(element)) {
+                position = node.DefaultChild();
+              } else {
+                if (element <= node.SplitCond()) {
+                  position = node.LeftChild();
+                } else {
+                  position = node.RightChild();
+                }
+              }
+              node = d_nodes[position];
+            }
+            return position;
+          });
+    }
   }
 
   void UpdatePredictionCache(bst_float* out_preds_d) {
@@ -764,7 +787,7 @@ struct GPUHistMakerDevice {
     reducer->AllReduceSum(
         reinterpret_cast<typename GradientSumT::ValueT*>(d_node_hist),
        reinterpret_cast<typename GradientSumT::ValueT*>(d_node_hist),
-        page->matrix.BinCount() * (sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT)));
+        page->matrix.info.n_bins * (sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT)));
     reducer->Synchronize();
 
     monitor.StopCuda("AllReduce");
@@ -773,12 +796,10 @@ struct GPUHistMakerDevice {
   /**
    * \brief Build GPU local histograms for the left and right child of some parent node
    */
-  void BuildHistLeftRight(const ExpandEntry &candidate, int nidx_left,
-                          int nidx_right, dh::AllReducer* reducer) {
+  void BuildHistLeftRight(const ExpandEntry &candidate, int nidx_left, int nidx_right) {
     auto build_hist_nidx = nidx_left;
     auto subtraction_trick_nidx = nidx_right;
-
     // Decide whether to build the left histogram or right histogram
     // Use sum of Hessian as a heuristic to select node with fewest training instances
     bool fewer_right = candidate.split.right_sum.GetHess() < candidate.split.left_sum.GetHess();
@@ -787,22 +808,50 @@ struct GPUHistMakerDevice {
     }
 
     this->BuildHist(build_hist_nidx);
-    this->AllReduceHist(build_hist_nidx, reducer);
 
     // Check whether we can use the subtraction trick to calculate the other
     bool do_subtraction_trick = this->CanDoSubtractionTrick(
         candidate.nid, build_hist_nidx, subtraction_trick_nidx);
 
+    if (!do_subtraction_trick) {
+      // Calculate other histogram manually
+      this->BuildHist(subtraction_trick_nidx);
+    }
+  }
+
+  /**
+   * \brief AllReduce GPU histograms for the left and right child of some parent node.
+   */
+  void ReduceHistLeftRight(const ExpandEntry& candidate,
+                           int nidx_left,
+                           int nidx_right,
+                           dh::AllReducer* reducer) {
+    auto build_hist_nidx = nidx_left;
+    auto subtraction_trick_nidx = nidx_right;
+
+    // Decide whether to build the left histogram or right histogram
+    // Use sum of Hessian as a heuristic to select node with fewest training instances
+    bool fewer_right = candidate.split.right_sum.GetHess() < candidate.split.left_sum.GetHess();
+    if (fewer_right) {
+      std::swap(build_hist_nidx, subtraction_trick_nidx);
+    }
+
+    this->AllReduceHist(build_hist_nidx, reducer);
+
+    // Check whether we can use the subtraction trick to calculate the other
+    bool do_subtraction_trick = this->CanDoSubtractionTrick(
+        candidate.nid, build_hist_nidx, subtraction_trick_nidx);
+
     if (do_subtraction_trick) {
       // Calculate other histogram using subtraction trick
       this->SubtractionTrick(candidate.nid, build_hist_nidx,
                              subtraction_trick_nidx);
     } else {
       // Calculate other histogram manually
-      this->BuildHist(subtraction_trick_nidx);
       this->AllReduceHist(subtraction_trick_nidx, reducer);
     }
   }
 
+
   void ApplySplit(const ExpandEntry& candidate, RegTree* p_tree) {
     RegTree& tree = *p_tree;
 
@@ -839,7 +888,7 @@ struct GPUHistMakerDevice {
                       tree[candidate.nid].RightChild());
   }
 
-  void InitRoot(RegTree* p_tree, HostDeviceVector<GradientPair>* gpair_all,
+  void InitRoot(RegTree* p_tree, HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat,
                 dh::AllReducer* reducer, int64_t num_columns) {
     constexpr int kRootNIdx = 0;
 
@@ -855,7 +904,7 @@ struct GPUHistMakerDevice {
                              node_sum_gradients_d.data(), sizeof(GradientPair),
                              cudaMemcpyDeviceToHost));
 
-    this->BuildHist(kRootNIdx);
+    this->BuildHistBatches(kRootNIdx, p_fmat);
     this->AllReduceHist(kRootNIdx, reducer);
 
     // Remember root stats
@@ -882,7 +931,7 @@ struct GPUHistMakerDevice {
     monitor.StopCuda("Reset");
 
     monitor.StartCuda("InitRoot");
-    this->InitRoot(p_tree, gpair_all, reducer, p_fmat->Info().num_col_);
+    this->InitRoot(p_tree, gpair_all, p_fmat, reducer, p_fmat->Info().num_col_);
     monitor.StopCuda("InitRoot");
 
     auto timestamp = qexpand->size();
@@ -901,15 +950,21 @@ struct GPUHistMakerDevice {
       int left_child_nidx = tree[candidate.nid].LeftChild();
       int right_child_nidx = tree[candidate.nid].RightChild();
       // Only create child entries if needed
-      if (ExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx),
-                                    num_leaves)) {
-        monitor.StartCuda("UpdatePosition");
-        this->UpdatePosition(candidate.nid, (*p_tree)[candidate.nid]);
-        monitor.StopCuda("UpdatePosition");
+      if (ExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx), num_leaves)) {
+        for (auto& batch : p_fmat->GetBatches<EllpackPage>(batch_param)) {
+          page = batch.Impl();
 
-        monitor.StartCuda("BuildHist");
-        this->BuildHistLeftRight(candidate, left_child_nidx, right_child_nidx, reducer);
-        monitor.StopCuda("BuildHist");
+          monitor.StartCuda("UpdatePosition");
+          this->UpdatePosition(candidate.nid, (*p_tree)[candidate.nid]);
+          monitor.StopCuda("UpdatePosition");
+
+          monitor.StartCuda("BuildHist");
+          this->BuildHistLeftRight(candidate, left_child_nidx, right_child_nidx);
+          monitor.StopCuda("BuildHist");
+        }
+        monitor.StartCuda("ReduceHist");
+        this->ReduceHistLeftRight(candidate, left_child_nidx, right_child_nidx, reducer);
+        monitor.StopCuda("ReduceHist");
 
         monitor.StartCuda("EvaluateSplits");
         auto splits = this->EvaluateSplits({left_child_nidx, right_child_nidx},
@@ -926,7 +981,7 @@ struct GPUHistMakerDevice {
     }
 
     monitor.StartCuda("FinalisePosition");
-    this->FinalisePosition(p_tree);
+    this->FinalisePosition(p_tree, p_fmat);
monitor.StopCuda("FinalisePosition"); } }; @@ -1016,21 +1071,21 @@ class GPUHistMakerSpecialised { uint32_t column_sampling_seed = common::GlobalRandom()(); rabit::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0); - // TODO(rongou): support multiple Ellpack pages. - EllpackPageImpl* page{}; - for (auto& batch : dmat->GetBatches({device_, - param_.max_bin, - hist_maker_param_.gpu_batch_nrows})) { - page = batch.Impl(); - } - + BatchParam batch_param{ + device_, + param_.max_bin, + hist_maker_param_.gpu_batch_nrows, + generic_param_->gpu_page_size + }; + auto page = (*dmat->GetBatches(batch_param).begin()).Impl(); dh::safe_cuda(cudaSetDevice(device_)); maker.reset(new GPUHistMakerDevice(device_, page, info_->num_row_, param_, column_sampling_seed, - info_->num_col_)); + info_->num_col_, + batch_param)); monitor_.StartCuda("InitHistogram"); dh::safe_cuda(cudaSetDevice(device_)); diff --git a/tests/benchmark/generate_libsvm.py b/tests/benchmark/generate_libsvm.py new file mode 100644 index 000000000..b0ec27318 --- /dev/null +++ b/tests/benchmark/generate_libsvm.py @@ -0,0 +1,87 @@ +"""Generate synthetic data in LibSVM format.""" + +import argparse +import io +import time + +import numpy as np +from sklearn.datasets import make_classification +from sklearn.model_selection import train_test_split + +RNG = np.random.RandomState(2019) + + +def generate_data(args): + """Generates the data.""" + print("Generating dataset: {} rows * {} columns".format(args.rows, args.columns)) + print("Sparsity {}".format(args.sparsity)) + print("{}/{} train/test split".format(1.0 - args.test_size, args.test_size)) + + tmp = time.time() + n_informative = args.columns * 7 // 10 + n_redundant = args.columns // 10 + n_repeated = args.columns // 10 + print("n_informative: {}, n_redundant: {}, n_repeated: {}".format(n_informative, n_redundant, + n_repeated)) + x, y = make_classification(n_samples=args.rows, n_features=args.columns, + n_informative=n_informative, n_redundant=n_redundant, + n_repeated=n_repeated, shuffle=False, random_state=RNG) + print("Generate Time: {} seconds".format(time.time() - tmp)) + + tmp = time.time() + x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=args.test_size, + random_state=RNG, shuffle=False) + print("Train/Test Split Time: {} seconds".format(time.time() - tmp)) + + tmp = time.time() + write_file('train.libsvm', x_train, y_train, args.sparsity) + print("Write Train Time: {} seconds".format(time.time() - tmp)) + + tmp = time.time() + write_file('test.libsvm', x_test, y_test, args.sparsity) + print("Write Test Time: {} seconds".format(time.time() - tmp)) + + +def write_file(filename, x_data, y_data, sparsity): + with open(filename, 'w') as f: + for x, y in zip(x_data, y_data): + write_line(f, x, y, sparsity) + + +def write_line(f, x, y, sparsity): + with io.StringIO() as line: + line.write(str(y)) + for i, col in enumerate(x): + if 0.0 < sparsity < 1.0: + if RNG.uniform(0, 1) > sparsity: + write_feature(line, i, col) + else: + write_feature(line, i, col) + line.write('\n') + f.write(line.getvalue()) + + +def write_feature(line, index, feature): + line.write(' ') + line.write(str(index)) + line.write(':') + line.write(str(feature)) + + +def main(): + """The main function. + + Defines and parses command line arguments and calls the generator. 
+ """ + parser = argparse.ArgumentParser() + parser.add_argument('--rows', type=int, default=1000000) + parser.add_argument('--columns', type=int, default=50) + parser.add_argument('--sparsity', type=float, default=0.0) + parser.add_argument('--test_size', type=float, default=0.01) + args = parser.parse_args() + + generate_data(args) + + +if __name__ == '__main__': + main() diff --git a/tests/cpp/data/test_sparse_page_dmatrix.cu b/tests/cpp/data/test_sparse_page_dmatrix.cu index c95d86817..920abec8d 100644 --- a/tests/cpp/data/test_sparse_page_dmatrix.cu +++ b/tests/cpp/data/test_sparse_page_dmatrix.cu @@ -2,10 +2,11 @@ #include #include "../helpers.h" +#include "../../../src/common/compressed_iterator.h" namespace xgboost { -TEST(GPUSparsePageDMatrix, EllpackPage) { +TEST(SparsePageDMatrix, EllpackPage) { dmlc::TemporaryDirectory tempdir; const std::string tmp_file = tempdir.path + "/simple.libsvm"; CreateSimpleTestData(tmp_file); @@ -23,4 +24,162 @@ TEST(GPUSparsePageDMatrix, EllpackPage) { delete dmat; } +TEST(SparsePageDMatrix, MultipleEllpackPages) { + dmlc::TemporaryDirectory tmpdir; + std::string filename = tmpdir.path + "/big.libsvm"; + std::unique_ptr dmat = CreateSparsePageDMatrix(12, 64, filename); + + // Loop over the batches and count the records + int64_t batch_count = 0; + int64_t row_count = 0; + for (const auto& batch : dmat->GetBatches({0, 256, 0, 7UL})) { + EXPECT_LT(batch.Size(), dmat->Info().num_row_); + batch_count++; + row_count += batch.Size(); + } + EXPECT_GE(batch_count, 2); + EXPECT_EQ(row_count, dmat->Info().num_row_); + + EXPECT_TRUE(FileExists(filename + ".cache.ellpack.page")); +} + +TEST(SparsePageDMatrix, EllpackPageContent) { + constexpr size_t kRows = 6; + constexpr size_t kCols = 2; + constexpr size_t kPageSize = 1; + + // Create an in-memory DMatrix. + std::unique_ptr dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true)); + + // Create a DMatrix with multiple batches. 
+  dmlc::TemporaryDirectory tmpdir;
+  std::unique_ptr<DMatrix>
+      dmat_ext(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
+
+  BatchParam param{0, 2, 0, 0};
+  auto impl = (*dmat->GetBatches<EllpackPage>(param).begin()).Impl();
+  EXPECT_EQ(impl->matrix.base_rowid, 0);
+  EXPECT_EQ(impl->matrix.n_rows, kRows);
+  EXPECT_FALSE(impl->matrix.info.is_dense);
+  EXPECT_EQ(impl->matrix.info.row_stride, 2);
+  EXPECT_EQ(impl->matrix.info.n_bins, 4);
+
+  auto impl_ext = (*dmat_ext->GetBatches<EllpackPage>(param).begin()).Impl();
+  EXPECT_EQ(impl_ext->matrix.base_rowid, 0);
+  EXPECT_EQ(impl_ext->matrix.n_rows, kRows);
+  EXPECT_FALSE(impl_ext->matrix.info.is_dense);
+  EXPECT_EQ(impl_ext->matrix.info.row_stride, 2);
+  EXPECT_EQ(impl_ext->matrix.info.n_bins, 4);
+
+  std::vector<common::CompressedByteT> buffer(impl->gidx_buffer.size());
+  std::vector<common::CompressedByteT> buffer_ext(impl_ext->gidx_buffer.size());
+  dh::CopyDeviceSpanToVector(&buffer, impl->gidx_buffer);
+  dh::CopyDeviceSpanToVector(&buffer_ext, impl_ext->gidx_buffer);
+  EXPECT_EQ(buffer, buffer_ext);
+}
+
+struct ReadRowFunction {
+  EllpackMatrix matrix;
+  int row;
+  bst_float* row_data_d;
+  ReadRowFunction(EllpackMatrix matrix, int row, bst_float* row_data_d)
+      : matrix(std::move(matrix)), row(row), row_data_d(row_data_d) {}
+
+  __device__ void operator()(size_t col) {
+    auto value = matrix.GetElement(row, col);
+    if (isnan(value)) {
+      value = -1;
+    }
+    row_data_d[col] = value;
+  }
+};
+
+TEST(SparsePageDMatrix, MultipleEllpackPageContent) {
+  constexpr size_t kRows = 6;
+  constexpr size_t kCols = 2;
+  constexpr int kMaxBins = 256;
+  constexpr size_t kPageSize = 1;
+
+  // Create an in-memory DMatrix.
+  std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true));
+
+  // Create a DMatrix with multiple batches.
+  dmlc::TemporaryDirectory tmpdir;
+  std::unique_ptr<DMatrix>
+      dmat_ext(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
+
+  BatchParam param{0, kMaxBins, 0, kPageSize};
+  auto impl = (*dmat->GetBatches<EllpackPage>(param).begin()).Impl();
+  EXPECT_EQ(impl->matrix.base_rowid, 0);
+  EXPECT_EQ(impl->matrix.n_rows, kRows);
+
+  size_t current_row = 0;
+  thrust::device_vector<bst_float> row_d(kCols);
+  thrust::device_vector<bst_float> row_ext_d(kCols);
+  std::vector<bst_float> row(kCols);
+  std::vector<bst_float> row_ext(kCols);
+  for (auto& page : dmat_ext->GetBatches<EllpackPage>(param)) {
+    auto impl_ext = page.Impl();
+    EXPECT_EQ(impl_ext->matrix.base_rowid, current_row);
+
+    for (size_t i = 0; i < impl_ext->Size(); i++) {
+      dh::LaunchN(0, kCols, ReadRowFunction(impl->matrix, current_row, row_d.data().get()));
+      thrust::copy(row_d.begin(), row_d.end(), row.begin());
+
+      dh::LaunchN(0, kCols, ReadRowFunction(impl_ext->matrix, current_row, row_ext_d.data().get()));
+      thrust::copy(row_ext_d.begin(), row_ext_d.end(), row_ext.begin());
+
+      EXPECT_EQ(row, row_ext);
+      current_row++;
+    }
+  }
+}
+
+TEST(SparsePageDMatrix, EllpackPageMultipleLoops) {
+  constexpr size_t kRows = 1024;
+  constexpr size_t kCols = 16;
+  constexpr int kMaxBins = 256;
+  constexpr size_t kPageSize = 4096;
+
+  // Create an in-memory DMatrix.
+  std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true));
+
+  // Create a DMatrix with multiple batches.
+  dmlc::TemporaryDirectory tmpdir;
+  std::unique_ptr<DMatrix>
+      dmat_ext(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
+
+  BatchParam param{0, kMaxBins, 0, kPageSize};
+  auto impl = (*dmat->GetBatches<EllpackPage>(param).begin()).Impl();
+
+  size_t current_row = 0;
+  for (auto& page : dmat_ext->GetBatches<EllpackPage>(param)) {
+    auto impl_ext = page.Impl();
+    EXPECT_EQ(impl_ext->matrix.base_rowid, current_row);
+    current_row += impl_ext->matrix.n_rows;
+  }
+
+  current_row = 0;
+  thrust::device_vector<bst_float> row_d(kCols);
+  thrust::device_vector<bst_float> row_ext_d(kCols);
+  std::vector<bst_float> row(kCols);
+  std::vector<bst_float> row_ext(kCols);
+  for (auto& page : dmat_ext->GetBatches<EllpackPage>(param)) {
+    auto impl_ext = page.Impl();
+    EXPECT_EQ(impl_ext->matrix.base_rowid, current_row);
+
+    for (size_t i = 0; i < impl_ext->Size(); i++) {
+      dh::LaunchN(0, kCols, ReadRowFunction(impl->matrix, current_row, row_d.data().get()));
+      thrust::copy(row_d.begin(), row_d.end(), row.begin());
+
+      dh::LaunchN(0, kCols, ReadRowFunction(impl_ext->matrix, current_row, row_ext_d.data().get()));
+      thrust::copy(row_ext_d.begin(), row_ext_d.end(), row_ext.begin());
+
+      EXPECT_EQ(row, row_ext) << "for row " << current_row;
+
+      current_row++;
+    }
+  }
+}
+
 }  // namespace xgboost
diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc
index f6878339d..ed6c5fa40 100644
--- a/tests/cpp/helpers.cc
+++ b/tests/cpp/helpers.cc
@@ -217,17 +217,17 @@ std::unique_ptr<DMatrix> CreateSparsePageDMatrixWithRC(
   } else {
     gen.reset(new std::mt19937(rdev()));
   }
+  std::uniform_int_distribution<size_t> label(0, 1);
   std::uniform_int_distribution<size_t> dis(1, n_cols);
   for (size_t i = 0; i < n_rows; ++i) {
     // Make sure that all cols are slotted in the first few rows; randomly distribute the
     // rest
     std::stringstream row_data;
-    fo << i;
     size_t j = 0;
     if (rem_cols > 0) {
       for (; j < std::min(static_cast<size_t>(rem_cols), cols_per_row); ++j) {
-        row_data << " " << (col_idx+j) << ":" << (col_idx+j+1)*10;
+        row_data << label(*gen) << " " << (col_idx+j) << ":" << (col_idx+j+1)*10*i;
       }
       rem_cols -= cols_per_row;
     } else {
@@ -235,7 +235,7 @@ std::unique_ptr<DMatrix> CreateSparsePageDMatrixWithRC(
       size_t ncols = dis(*gen);
       for (; j < ncols; ++j) {
         size_t fid = (col_idx+j) % n_cols;
-        row_data << " " << fid << ":" << (fid+1)*10;
+        row_data << label(*gen) << " " << fid << ":" << (fid+1)*10*i;
       }
     }
     col_idx += j;
diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu
index 85afa9a6a..38d455c68 100644
--- a/tests/cpp/tree/test_gpu_hist.cu
+++ b/tests/cpp/tree/test_gpu_hist.cu
@@ -56,22 +56,6 @@ TEST(GpuHist, DeviceHistogram) {
   };
 }
 
-namespace {
-class HistogramCutsWrapper : public common::HistogramCuts {
- public:
-  using SuperT = common::HistogramCuts;
-  void SetValues(std::vector<float> cuts) {
-    SuperT::cut_values_ = cuts;
-  }
-  void SetPtrs(std::vector<uint32_t> ptrs) {
-    SuperT::cut_ptrs_ = ptrs;
-  }
-  void SetMins(std::vector<float> mins) {
-    SuperT::min_vals_ = mins;
-  }
-};
-}  // anonymous namespace
-
 std::vector<GradientPairPrecise> GetHostHistGpair() {
   // 24 bins, 3 bins for each feature (column).
   std::vector<GradientPairPrecise> hist_gpair = {
@@ -98,7 +82,8 @@ void TestBuildHist(bool use_shared_memory_histograms) {
   };
   param.Init(args);
   auto page = BuildEllpackPage(kNRows, kNCols);
-  GPUHistMakerDevice<GradientSumT> maker(0, page.get(), kNRows, param, kNCols, kNCols);
+  BatchParam batch_param{};
+  GPUHistMakerDevice<GradientSumT> maker(0, page.get(), kNRows, param, kNCols, kNCols, batch_param);
   maker.InitHistogram();
 
   xgboost::SimpleLCG gen;
@@ -199,7 +184,9 @@ TEST(GpuHist, EvaluateSplits) {
 
   // Initialize GPUHistMakerDevice
   auto page = BuildEllpackPage(kNRows, kNCols);
-  GPUHistMakerDevice<GradientPairPrecise> maker(0, page.get(), kNRows, param, kNCols, kNCols);
+  BatchParam batch_param{};
+  GPUHistMakerDevice<GradientPairPrecise>
+      maker(0, page.get(), kNRows, param, kNCols, kNCols, batch_param);
 
   // Initialize GPUHistMakerDevice::node_sum_gradients
   maker.node_sum_gradients = {{6.4f, 12.8f}};
@@ -332,21 +319,25 @@ int32_t TestMinSplitLoss(DMatrix* dmat, float gamma, HostDeviceVector
 HostDeviceVector<GradientPair> GenerateRandomGradients(const size_t n_rows) {
   xgboost::SimpleLCG gen;
   xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f);
-  std::vector<GradientPair> h_gpair(kRows);
+  std::vector<GradientPair> h_gpair(n_rows);
   for (auto &gpair : h_gpair) {
     bst_float grad = dist(&gen);
     bst_float hess = dist(&gen);
     gpair = GradientPair(grad, hess);
   }
   HostDeviceVector<GradientPair> gpair(h_gpair);
+  return gpair;
+}
+
+TEST(GpuHist, MinSplitLoss) {
+  constexpr size_t kRows = 32;
+  constexpr size_t kCols = 16;
+  constexpr float kSparsity = 0.6;
+  auto dmat = CreateDMatrix(kRows, kCols, kSparsity, 3);
+  auto gpair = GenerateRandomGradients(kRows);
 
   {
     int32_t n_nodes = TestMinSplitLoss((*dmat).get(), 0.01, &gpair);
@@ -363,5 +354,75 @@ TEST(GpuHist, MinSplitLoss) {
   delete dmat;
 }
 
+void UpdateTree(HostDeviceVector<GradientPair>* gpair,
+                DMatrix* dmat,
+                size_t gpu_page_size,
+                RegTree* tree,
+                HostDeviceVector<bst_float>* preds) {
+  constexpr size_t kMaxBin = 2;
+
+  if (gpu_page_size > 0) {
+    // Loop over the batches and count the records
+    int64_t batch_count = 0;
+    int64_t row_count = 0;
+    for (const auto& batch : dmat->GetBatches<EllpackPage>({0, kMaxBin, 0, gpu_page_size})) {
+      EXPECT_LT(batch.Size(), dmat->Info().num_row_);
+      batch_count++;
+      row_count += batch.Size();
+    }
+    EXPECT_GE(batch_count, 2);
+    EXPECT_EQ(row_count, dmat->Info().num_row_);
+  }
+
+  Args args{
+    {"max_depth", "2"},
+    {"max_bin", std::to_string(kMaxBin)},
+    {"min_child_weight", "0.0"},
+    {"reg_alpha", "0"},
+    {"reg_lambda", "0"}
+  };
+
+  tree::GPUHistMakerSpecialised<GradientPairPrecise> hist_maker;
+  GenericParameter generic_param(CreateEmptyGenericParam(0));
+  generic_param.gpu_page_size = gpu_page_size;
+  hist_maker.Configure(args, &generic_param);
+
+  hist_maker.Update(gpair, dmat, {tree});
+  hist_maker.UpdatePredictionCache(dmat, preds);
+}
+
+TEST(GpuHist, ExternalMemory) {
+  constexpr size_t kRows = 6;
+  constexpr size_t kCols = 2;
+  constexpr size_t kPageSize = 1;
+
+  // Create an in-memory DMatrix.
+  std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true));
+
+  // Create a DMatrix with multiple batches.
+  dmlc::TemporaryDirectory tmpdir;
+  std::unique_ptr<DMatrix>
+      dmat_ext(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
+
+  auto gpair = GenerateRandomGradients(kRows);
+
+  // Build a tree using the in-memory DMatrix.
+  RegTree tree;
+  HostDeviceVector<bst_float> preds(kRows, 0.0, 0);
+  UpdateTree(&gpair, dmat.get(), 0, &tree, &preds);
+
+  // Build another tree using multiple ELLPACK pages.
+  RegTree tree_ext;
+  HostDeviceVector<bst_float> preds_ext(kRows, 0.0, 0);
+  UpdateTree(&gpair, dmat_ext.get(), kPageSize, &tree_ext, &preds_ext);
+
+  // Make sure the predictions are the same.
+  auto preds_h = preds.ConstHostVector();
+  auto preds_ext_h = preds_ext.ConstHostVector();
+  for (int i = 0; i < kRows; i++) {
+    ASSERT_FLOAT_EQ(preds_h[i], preds_ext_h[i]);
+  }
+}
+
 }  // namespace tree
 }  // namespace xgboost
diff --git a/tests/python-gpu/test_gpu_updaters.py b/tests/python-gpu/test_gpu_updaters.py
index 91b9e50d2..62467bfae 100644
--- a/tests/python-gpu/test_gpu_updaters.py
+++ b/tests/python-gpu/test_gpu_updaters.py
@@ -19,17 +19,19 @@ def assert_gpu_results(cpu_results, gpu_results):
 datasets = ["Boston", "Cancer", "Digits", "Sparse regression",
             "Sparse regression with weights", "Small weights regression"]
 
+test_param = parameter_combinations({
+    'gpu_id': [0],
+    'max_depth': [2, 8],
+    'max_leaves': [255, 4],
+    'max_bin': [2, 256],
+    'grow_policy': ['lossguide'],
+    'single_precision_histogram': [True],
+    'min_child_weight': [0],
+    'lambda': [0]})
+
 
 class TestGPU(unittest.TestCase):
     def test_gpu_hist(self):
-        test_param = parameter_combinations({'gpu_id': [0],
-                                             'max_depth': [2, 8],
-                                             'max_leaves': [255, 4],
-                                             'max_bin': [2, 256],
-                                             'grow_policy': ['lossguide']})
-        test_param.append({'single_precision_histogram': True})
-        test_param.append({'min_child_weight': 0,
-                           'lambda': 0})
         for param in test_param:
             param['tree_method'] = 'gpu_hist'
             gpu_results = run_suite(param, select_datasets=datasets)
@@ -38,6 +40,19 @@ class TestGPU(unittest.TestCase):
             cpu_results = run_suite(param, select_datasets=datasets)
             assert_gpu_results(cpu_results, gpu_results)
 
+    # NOTE(rongou): Because the `Boston` dataset is too small, this only tests external memory mode
+    # with a single page. To test multiple pages, set DMatrix::kPageSize to, say, 1024.
+    def test_external_memory(self):
+        for param in reversed(test_param):
+            param['tree_method'] = 'gpu_hist'
+            param['gpu_page_size'] = 1024
+            gpu_results = run_suite(param, select_datasets=["Boston"])
+            assert_results_non_increasing(gpu_results, 1e-2)
+            ext_mem_results = run_suite(param, select_datasets=["Boston External Memory"])
+            assert_results_non_increasing(ext_mem_results, 1e-2)
+            assert_gpu_results(gpu_results, ext_mem_results)
+            break
+
     def test_with_empty_dmatrix(self):
         # FIXME(trivialfis): This should be done with all updaters
         kRows = 0
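
For reference, a minimal end-to-end sketch of how the pieces added in this patch are meant to be used from the Python API. This is illustrative only and not part of the patch: the file paths, row/column counts, and hyper-parameter values are assumptions. `gpu_page_size` is the generic parameter introduced above (0 keeps the default page size), and the `#dtrain.cache` suffix is XGBoost's standard external-memory cache prefix.

# Illustrative sketch: generate synthetic LibSVM data with the new benchmark
# script, then train gpu_hist in external memory mode with a custom page size.
import subprocess

import xgboost as xgb

# Generate train.libsvm / test.libsvm in the current directory
# (rows/columns/sparsity values here are made up).
subprocess.run(
    ['python', 'tests/benchmark/generate_libsvm.py',
     '--rows', '1000000', '--columns', '50', '--sparsity', '0.2'],
    check=True)

# The '#dtrain.cache' suffix turns on external memory mode; compressed ELLPACK
# pages are written to disk and streamed through the GPU during training.
dtrain = xgb.DMatrix('train.libsvm#dtrain.cache')

params = {
    'objective': 'binary:logistic',
    'tree_method': 'gpu_hist',
    'gpu_id': 0,
    'max_depth': 6,
    'max_bin': 256,
    # New in this change: compressed ELLPACK page size in bytes;
    # 0 falls back to the default DMatrix::kPageSize.
    'gpu_page_size': 1 << 20,
}

bst = xgb.train(params, dtrain, num_boost_round=10)

Smaller `gpu_page_size` values force more, smaller ELLPACK pages, which is how the C++ and Python tests above exercise the multi-page code paths; leaving it at 0 uses the default page size.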