diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu
index 2e94d2c5a..1ee1267aa 100644
--- a/src/common/hist_util.cu
+++ b/src/common/hist_util.cu
@@ -13,6 +13,8 @@
 #include
 #include
+#include
+#include
 
 #include "../tree/param.h"
 #include "./host_device_vector.h"
@@ -82,6 +84,36 @@ __global__ void UnpackFeaturesK
   }
 }
 
+/*!
+ * \brief A container that holds the device sketches across all
+ *  sparse page batches which are distributed to different devices.
+ *  As sketches are aggregated by column, the mutex guards
+ *  multiple devices pushing sketch summary for the same column
+ *  across distinct rows.
+ */
+struct SketchContainer {
+  std::vector sketches_;  // NOLINT
+  std::vector col_locks_;  // NOLINT
+  static constexpr int kOmpNumColsParallelizeLimit = 1000;
+
+  SketchContainer(const tree::TrainParam &param, DMatrix *dmat) :
+    col_locks_(dmat->Info().num_col_) {
+    const MetaInfo &info = dmat->Info();
+    // Initialize Sketches for this dmatrix
+    sketches_.resize(info.num_col_);
+#pragma omp parallel for schedule(static) if (info.num_col_ > kOmpNumColsParallelizeLimit)
+    for (int icol = 0; icol < info.num_col_; ++icol) {
+      sketches_[icol].Init(info.num_row_, 1.0 / (8 * param.max_bin));
+    }
+  }
+
+  // Prevent copying/assigning/moving this as its internals can't be assigned/copied/moved
+  SketchContainer(const SketchContainer &) = delete;
+  SketchContainer(const SketchContainer &&) = delete;
+  SketchContainer &operator=(const SketchContainer &) = delete;
+  SketchContainer &operator=(const SketchContainer &&) = delete;
+};
+
 // finds quantiles on the GPU
 struct GPUSketcher {
   // manage memory for a single GPU
@@ -94,11 +126,11 @@ struct GPUSketcher {
     size_t n_cuts_{0};
     size_t gpu_batch_nrows_{0};
     bool has_weights_{false};
+    size_t row_stride_{0};
 
     tree::TrainParam param_;
-    std::vector sketches_;
+    SketchContainer *sketch_container_;
     thrust::device_vector row_ptrs_;
-    std::vector summaries_;
     thrust::device_vector entries_;
     thrust::device_vector fvalues_;
     thrust::device_vector feature_weights_;
@@ -113,9 +145,13 @@ struct GPUSketcher {
 
    public:
     DeviceShard(int device, bst_uint row_begin, bst_uint row_end,
-                tree::TrainParam param) :
+                tree::TrainParam param, SketchContainer *sketch_container) :
       device_(device), row_begin_(row_begin), row_end_(row_end),
-      n_rows_(row_end - row_begin), param_(std::move(param)) {
+      n_rows_(row_end - row_begin), param_(std::move(param)), sketch_container_(sketch_container) {
+    }
+
+    inline size_t GetRowStride() const {
+      return row_stride_;
     }
 
     void Init(const SparsePage& row_batch, const MetaInfo& info, int gpu_batch_nrows) {
@@ -136,20 +172,10 @@ struct GPUSketcher {
         gpu_batch_nrows_ = n_rows_;
       }
 
-      // initialize sketches
-      sketches_.resize(num_cols_);
-      summaries_.resize(num_cols_);
       constexpr int kFactor = 8;
       double eps = 1.0 / (kFactor * param_.max_bin);
       size_t dummy_nlevel;
-      WXQSketch::LimitSizeLevel(row_batch.Size(), eps, &dummy_nlevel, &n_cuts_);
-      // double ncuts to be the same as the number of values
-      // in the temporary buffers of the sketches
-      n_cuts_ *= 2;
-      for (int icol = 0; icol < num_cols_; ++icol) {
-        sketches_[icol].Init(row_batch.Size(), eps);
-        summaries_[icol].Reserve(n_cuts_);
-      }
+      WXQSketch::LimitSizeLevel(gpu_batch_nrows_, eps, &dummy_nlevel, &n_cuts_);
 
       // allocate necessary GPU buffers
       dh::safe_cuda(cudaSetDevice(device_));
@@ -306,9 +332,12 @@ struct GPUSketcher {
 
       // unpack the features; also unpack weights if present
       thrust::fill(fvalues_.begin(), fvalues_.end(), NAN);
-      thrust::fill(feature_weights_.begin(), feature_weights_.end(), NAN);
+      if (has_weights_) {
+        thrust::fill(feature_weights_.begin(), feature_weights_.end(), NAN);
+      }
 
-      dim3 block3(64, 4, 1);
+      dim3 block3(16, 64, 1);
+      // NOTE: This will typically support ~ 4M features - 64K*64
       dim3 grid3(dh::DivRoundUp(batch_nrows, block3.x),
                  dh::DivRoundUp(num_cols_, block3.y), 1);
       UnpackFeaturesK<<>>
@@ -324,12 +353,34 @@ struct GPUSketcher {
 
       // add cuts into sketches
       thrust::copy(cuts_d_.begin(), cuts_d_.end(), cuts_h_.begin());
+#pragma omp parallel for schedule(static) \
+  if (num_cols_ > SketchContainer::kOmpNumColsParallelizeLimit)  // NOLINT
       for (int icol = 0; icol < num_cols_; ++icol) {
-        summaries_[icol].MakeFromSorted(&cuts_h_[n_cuts_ * icol], n_cuts_cur_[icol]);
-        sketches_[icol].PushSummary(summaries_[icol]);
+        WXQSketch::SummaryContainer summary;
+        summary.Reserve(n_cuts_);
+        summary.MakeFromSorted(&cuts_h_[n_cuts_ * icol], n_cuts_cur_[icol]);
+
+        std::lock_guard lock(sketch_container_->col_locks_[icol]);
+        sketch_container_->sketches_[icol].PushSummary(summary);
       }
     }
 
+    void ComputeRowStride() {
+      // Find the row stride for this batch
+      auto row_iter = row_ptrs_.begin();
+      // Functor for finding the maximum row size for this batch
+      auto get_size = [=] __device__(size_t row) {
+        return row_iter[row + 1] - row_iter[row];
+      };  // NOLINT
+
+      auto counting = thrust::make_counting_iterator(size_t(0));
+      using TransformT = thrust::transform_iterator;
+      TransformT row_size_iter = TransformT(counting, get_size);
+      row_stride_ = thrust::reduce(row_size_iter, row_size_iter + n_rows_, 0,
+                                   thrust::maximum());
+    }
+
     void Sketch(const SparsePage& row_batch, const MetaInfo& info) {
       // copy rows to the device
       dh::safe_cuda(cudaSetDevice(device_));
@@ -342,63 +393,71 @@ struct GPUSketcher {
         SketchBatch(row_batch, info, gpu_batch);
       }
     }
-
-    void GetSummary(WXQSketch::SummaryContainer *summary, size_t const icol) {
-      sketches_[icol].GetSummary(summary);
-    }
   };
 
-  void Sketch(const SparsePage& batch, const MetaInfo& info,
-              HistCutMatrix* hmat, int gpu_batch_nrows) {
+  void SketchBatch(const SparsePage &batch, const MetaInfo &info) {
+    GPUDistribution dist =
+      GPUDistribution::Block(GPUSet::All(learner_param_.gpu_id, learner_param_.n_gpus,
+                                         batch.Size()));
+
     // create device shards
-    shards_.resize(dist_.Devices().Size());
+    shards_.resize(dist.Devices().Size());
     dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr& shard) {
-      size_t start = dist_.ShardStart(info.num_row_, i);
-      size_t size = dist_.ShardSize(info.num_row_, i);
+      size_t start = dist.ShardStart(batch.Size(), i);
+      size_t size = dist.ShardSize(batch.Size(), i);
       shard = std::unique_ptr(
-        new DeviceShard(dist_.Devices().DeviceId(i),
-                        start, start + size, param_));
+        new DeviceShard(dist.Devices().DeviceId(i), start,
+                        start + size, param_, sketch_container_.get()));
     });
 
     // compute sketches for each shard
    dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr& shard) {
-      shard->Init(batch, info, gpu_batch_nrows);
+      shard->Init(batch, info, gpu_batch_nrows_);
       shard->Sketch(batch, info);
+      shard->ComputeRowStride();
    });
 
-    // merge the sketches from all shards
-    // TODO(canonizer): do it in a tree-like reduction
-    int num_cols = info.num_col_;
-    std::vector sketches(num_cols);
-    WXQSketch::SummaryContainer summary;
-    for (int icol = 0; icol < num_cols; ++icol) {
-      sketches[icol].Init(batch.Size(), 1.0 / (8 * param_.max_bin));
-      for (auto &shard : shards_) {
-        shard->GetSummary(&summary, icol);
-        sketches[icol].PushSummary(summary);
-      }
+    // compute row stride across all shards
+    for (const auto &shard : shards_) {
+      row_stride_ = std::max(row_stride_, shard->GetRowStride());
+    }
-
-    hmat->Init(&sketches, param_.max_bin);
   }
 
-  GPUSketcher(tree::TrainParam param, GPUSet const& devices) : param_(std::move(param)) {
-    dist_ = GPUDistribution::Block(devices);
+  GPUSketcher(const tree::TrainParam &param, const LearnerTrainParam &learner_param, int gpu_nrows)
+    : param_(param), learner_param_(learner_param), gpu_batch_nrows_(gpu_nrows), row_stride_(0) {
+  }
+
+  /* Builds the sketches on the GPU for the dmatrix and returns the row stride
+   * for the entire dataset */
+  size_t Sketch(DMatrix *dmat, HistCutMatrix *hmat) {
+    const MetaInfo &info = dmat->Info();
+
+    row_stride_ = 0;
+    sketch_container_.reset(new SketchContainer(param_, dmat));
+    for (const auto &batch : dmat->GetRowBatches()) {
+      this->SketchBatch(batch, info);
+    }
+
+    hmat->Init(&sketch_container_.get()->sketches_, param_.max_bin);
+
+    return row_stride_;
   }
 
 private:
  std::vector> shards_;
-  tree::TrainParam param_;
-  GPUDistribution dist_;
+  const tree::TrainParam &param_;
+  const LearnerTrainParam &learner_param_;
+  int gpu_batch_nrows_;
+  size_t row_stride_;
+  std::unique_ptr sketch_container_;
 };
 
-void DeviceSketch
-  (const SparsePage& batch, const MetaInfo& info,
-   const tree::TrainParam& param, HistCutMatrix* hmat, int gpu_batch_nrows,
-   GPUSet const& devices) {
-  GPUSketcher sketcher(param, devices);
-  sketcher.Sketch(batch, info, hmat, gpu_batch_nrows);
+size_t DeviceSketch
+  (const tree::TrainParam &param, const LearnerTrainParam &learner_param, int gpu_batch_nrows,
+   DMatrix *dmat, HistCutMatrix *hmat) {
+  GPUSketcher sketcher(param, learner_param, gpu_batch_nrows);
+  return sketcher.Sketch(dmat, hmat);
 }
 
 }  // namespace common
diff --git a/src/common/hist_util.h b/src/common/hist_util.h
index 5fb9b9c85..dc2b80bb8 100644
--- a/src/common/hist_util.h
+++ b/src/common/hist_util.h
@@ -8,6 +8,7 @@
 #define XGBOOST_COMMON_HIST_UTIL_H_
 
 #include
+#include
 #include
 #include
 #include "row_set.h"
@@ -84,9 +85,6 @@ struct SimpleArray {
   size_t n_ = 0;
 };
 
-
-
-
 /*! \brief Cut configuration for all the features. */
 struct HistCutMatrix {
   /*! \brief Unit pointer to rows by element position */
@@ -115,11 +113,13 @@ struct HistCutMatrix {
   Monitor monitor_;
 };
 
-/*! \brief Builds the cut matrix on the GPU */
-void DeviceSketch
-  (const SparsePage& batch, const MetaInfo& info,
-   const tree::TrainParam& param, HistCutMatrix* hmat, int gpu_batch_nrows,
-   GPUSet const& devices);
+/*! \brief Builds the cut matrix on the GPU.
+ *
+ * \return The row stride across the entire dataset.
+ */
+size_t DeviceSketch
+  (const tree::TrainParam& param, const LearnerTrainParam &learner_param, int gpu_batch_nrows,
+   DMatrix* dmat, HistCutMatrix* hmat);
 
 /*!
  * \brief A single row in global histogram index.
diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index a398d5f2c..0a4f0d7f8 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -1374,7 +1374,7 @@ inline void DeviceShard::CreateHistIndices(
 }
 
 template
-class GPUHistMakerSpecialised{
+class GPUHistMakerSpecialised {
  public:
  GPUHistMakerSpecialised() : initialised_{false}, p_last_fmat_{nullptr} {}
  void Init(const std::vector>& args,
@@ -1449,10 +1449,12 @@ class GPUHistMakerSpecialised{
     // Find the cuts.
     monitor_.StartCuda("Quantiles");
-    common::DeviceSketch(batch, *info_, param_, &hmat_, hist_maker_param_.gpu_batch_nrows,
-                         GPUSet::All(learner_param_->gpu_id, learner_param_->n_gpus));
+    // TODO(sriramch): The return value will be used when we add support for histogram
+    // index creation for multiple batches
+    common::DeviceSketch(param_, *learner_param_, hist_maker_param_.gpu_batch_nrows, dmat, &hmat_);
     n_bins_ = hmat_.row_ptr.back();
     monitor_.StopCuda("Quantiles");
+
     auto is_dense = info_->num_nonzero_ == info_->num_row_ * info_->num_col_;
 
     monitor_.StartCuda("BinningCompression");
@@ -1557,7 +1559,6 @@ class GPUHistMakerSpecialised{
   GPUHistMakerTrainParam hist_maker_param_;
   LearnerTrainParam const* learner_param_;
 
-  common::GHistIndexMatrix gmat_;
   dh::AllReducer reducer_;
 
diff --git a/tests/cpp/common/test_gpu_hist_util.cu b/tests/cpp/common/test_gpu_hist_util.cu
index b6bdc40cd..8a1b3ed59 100644
--- a/tests/cpp/common/test_gpu_hist_util.cu
+++ b/tests/cpp/common/test_gpu_hist_util.cu
@@ -1,50 +1,72 @@
-#include "gtest/gtest.h"
-#include "xgboost/c_api.h"
 #include
 #include
+
+#include "gtest/gtest.h"
+
 #include
 #include
 
-#include "../helpers.h"
+#include "xgboost/c_api.h"
+
 #include "../../../src/common/device_helpers.cuh"
 #include "../../../src/common/hist_util.h"
 
+#include "../helpers.h"
+
 namespace xgboost {
 namespace common {
 
-void TestDeviceSketch(const GPUSet& devices) {
+void TestDeviceSketch(const GPUSet& devices, bool use_external_memory) {
   // create the data
   int nrows = 10001;
-  std::vector test_data(nrows);
-  auto count_iter = thrust::make_counting_iterator(0);
-  // fill in reverse order
-  std::copy(count_iter, count_iter + nrows, test_data.rbegin());
+  std::shared_ptr *dmat = nullptr;
 
-  // create the DMatrix
-  DMatrixHandle dmat_handle;
-  XGDMatrixCreateFromMat(test_data.data(), nrows, 1, -1,
-                         &dmat_handle);
-  auto dmat = static_cast *>(dmat_handle);
+  size_t num_cols = 1;
+  if (use_external_memory) {
+    auto sp_dmat = CreateSparsePageDMatrix(nrows * 3, 128UL);  // 3 entries/row
+    dmat = new std::shared_ptr(std::move(sp_dmat));
+    num_cols = 5;
+  } else {
+    std::vector test_data(nrows);
+    auto count_iter = thrust::make_counting_iterator(0);
+    // fill in reverse order
+    std::copy(count_iter, count_iter + nrows, test_data.rbegin());
+
+    // create the DMatrix
+    DMatrixHandle dmat_handle;
+    XGDMatrixCreateFromMat(test_data.data(), nrows, 1, -1,
+                           &dmat_handle);
+    dmat = static_cast *>(dmat_handle);
+  }
 
-  // parameters for finding quantiles
   tree::TrainParam p;
   p.max_bin = 20;
-  // ensure that the exact quantiles are found
-  int gpu_batch_nrows = nrows * 10;
+  int gpu_batch_nrows = 0;
 
   // find quantiles on the CPU
   HistCutMatrix hmat_cpu;
   hmat_cpu.Init((*dmat).get(), p.max_bin);
 
   // find the cuts on the GPU
-  const SparsePage& batch = *(*dmat)->GetRowBatches().begin();
   HistCutMatrix hmat_gpu;
-  DeviceSketch(batch, (*dmat)->Info(), p, &hmat_gpu, gpu_batch_nrows, devices);
+  size_t row_stride = DeviceSketch(p, CreateEmptyGenericParam(0, devices.Size()), gpu_batch_nrows,
+                                   dmat->get(), &hmat_gpu);
+
+  // compare the row stride with the one obtained from the dmatrix
+  size_t expected_row_stride = 0;
+  for (const auto &batch : dmat->get()->GetRowBatches()) {
+    const auto &offset_vec = batch.offset.ConstHostVector();
+    for (int i = 1; i <= offset_vec.size() - 1; ++i) {
+      expected_row_stride = std::max(expected_row_stride, offset_vec[i] - offset_vec[i-1]);
+    }
+  }
+
+  ASSERT_EQ(expected_row_stride, row_stride);
 
   // compare the cuts
   double eps = 1e-2;
-  ASSERT_EQ(hmat_gpu.min_val.size(), 1);
-  ASSERT_EQ(hmat_gpu.row_ptr.size(), 2);
+  ASSERT_EQ(hmat_gpu.min_val.size(), num_cols);
+  ASSERT_EQ(hmat_gpu.row_ptr.size(), num_cols + 1);
   ASSERT_EQ(hmat_gpu.cut.size(), hmat_cpu.cut.size());
   ASSERT_LT(fabs(hmat_cpu.min_val[0] - hmat_gpu.min_val[0]), eps * nrows);
   for (int i = 0; i < hmat_gpu.cut.size(); ++i) {
@@ -55,14 +77,24 @@ void TestDeviceSketch(const GPUSet& devices) {
 }
 
 TEST(gpu_hist_util, DeviceSketch) {
-  TestDeviceSketch(GPUSet::Range(0, 1));
+  TestDeviceSketch(GPUSet::Range(0, 1), false);
+}
+
+TEST(gpu_hist_util, DeviceSketch_ExternalMemory) {
+  TestDeviceSketch(GPUSet::Range(0, 1), true);
 }
 
 #if defined(XGBOOST_USE_NCCL)
 TEST(gpu_hist_util, MGPU_DeviceSketch) {
   auto devices = GPUSet::AllVisible();
   CHECK_GT(devices.Size(), 1);
-  TestDeviceSketch(devices);
+  TestDeviceSketch(devices, false);
+}
+
+TEST(gpu_hist_util, MGPU_DeviceSketch_ExternalMemory) {
+  auto devices = GPUSet::AllVisible();
+  CHECK_GT(devices.Size(), 1);
+  TestDeviceSketch(devices, true);
 }
 #endif
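
Note (not part of the patch): the "row stride" returned by the new DeviceSketch() is the size of
the longest row across the dataset, i.e. the largest gap between consecutive CSR offsets. It is the
same quantity the test above computes as expected_row_stride and that DeviceShard::ComputeRowStride()
derives on the device with a transform/reduce. A minimal host-only C++ sketch of that idea, with
illustrative names that are not taken from this patch:

    #include <algorithm>
    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Row stride of a CSR batch: the maximum number of entries in any single row,
    // i.e. the largest difference between consecutive row offsets.
    std::size_t RowStride(const std::vector<std::size_t>& offsets) {
      std::size_t stride = 0;
      for (std::size_t i = 1; i < offsets.size(); ++i) {
        stride = std::max(stride, offsets[i] - offsets[i - 1]);
      }
      return stride;
    }

    int main() {
      std::vector<std::size_t> offsets{0, 2, 2, 7, 10};  // rows with 2, 0, 5 and 3 entries
      std::cout << RowStride(offsets) << "\n";           // prints 5
    }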