From fed665ae8a98fb64818b60e75bebda6c4d32e770 Mon Sep 17 00:00:00 2001 From: sriramch <33358417+sriramch@users.noreply.github.com> Date: Wed, 29 May 2019 13:18:34 -0700 Subject: [PATCH] - training with external memory part 1 of 2 (#4486) * - training with external memory part 1 of 2 - this pr focuses on computing the quantiles using multiple gpus on a dataset that uses the external cache capabilities - there will a follow-up pr soon after this that will support creation of histogram indices on large dataset as well - both of these changes are required to support training with external memory - the sparse pages in dmatrix are taken in batches and the the cut matrices are incrementally built - also snuck in some (perf) changes related to sketches aggregation amongst multiple features across multiple sparse page batches. instead of aggregating the summary inside each device and merged later, it is aggregated in-place when the device is working on different rows but the same feature --- src/common/hist_util.cu | 167 +++++++++++++++++-------- src/common/hist_util.h | 16 +-- src/tree/updater_gpu_hist.cu | 9 +- tests/cpp/common/test_gpu_hist_util.cu | 76 +++++++---- 4 files changed, 180 insertions(+), 88 deletions(-) diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu index 2e94d2c5a..1ee1267aa 100644 --- a/src/common/hist_util.cu +++ b/src/common/hist_util.cu @@ -13,6 +13,8 @@ #include #include +#include +#include #include "../tree/param.h" #include "./host_device_vector.h" @@ -82,6 +84,36 @@ __global__ void UnpackFeaturesK } } +/*! + * \brief A container that holds the device sketches across all + * sparse page batches which are distributed to different devices. + * As sketches are aggregated by column, the mutex guards + * multiple devices pushing sketch summary for the same column + * across distinct rows. + */ +struct SketchContainer { + std::vector sketches_; // NOLINT + std::vector col_locks_; // NOLINT + static constexpr int kOmpNumColsParallelizeLimit = 1000; + + SketchContainer(const tree::TrainParam ¶m, DMatrix *dmat) : + col_locks_(dmat->Info().num_col_) { + const MetaInfo &info = dmat->Info(); + // Initialize Sketches for this dmatrix + sketches_.resize(info.num_col_); +#pragma omp parallel for schedule(static) if (info.num_col_ > kOmpNumColsParallelizeLimit) + for (int icol = 0; icol < info.num_col_; ++icol) { + sketches_[icol].Init(info.num_row_, 1.0 / (8 * param.max_bin)); + } + } + + // Prevent copying/assigning/moving this as its internals can't be assigned/copied/moved + SketchContainer(const SketchContainer &) = delete; + SketchContainer(const SketchContainer &&) = delete; + SketchContainer &operator=(const SketchContainer &) = delete; + SketchContainer &operator=(const SketchContainer &&) = delete; +}; + // finds quantiles on the GPU struct GPUSketcher { // manage memory for a single GPU @@ -94,11 +126,11 @@ struct GPUSketcher { size_t n_cuts_{0}; size_t gpu_batch_nrows_{0}; bool has_weights_{false}; + size_t row_stride_{0}; tree::TrainParam param_; - std::vector sketches_; + SketchContainer *sketch_container_; thrust::device_vector row_ptrs_; - std::vector summaries_; thrust::device_vector entries_; thrust::device_vector fvalues_; thrust::device_vector feature_weights_; @@ -113,9 +145,13 @@ struct GPUSketcher { public: DeviceShard(int device, bst_uint row_begin, bst_uint row_end, - tree::TrainParam param) : + tree::TrainParam param, SketchContainer *sketch_container) : device_(device), row_begin_(row_begin), row_end_(row_end), - n_rows_(row_end - row_begin), param_(std::move(param)) { + n_rows_(row_end - row_begin), param_(std::move(param)), sketch_container_(sketch_container) { + } + + inline size_t GetRowStride() const { + return row_stride_; } void Init(const SparsePage& row_batch, const MetaInfo& info, int gpu_batch_nrows) { @@ -136,20 +172,10 @@ struct GPUSketcher { gpu_batch_nrows_ = n_rows_; } - // initialize sketches - sketches_.resize(num_cols_); - summaries_.resize(num_cols_); constexpr int kFactor = 8; double eps = 1.0 / (kFactor * param_.max_bin); size_t dummy_nlevel; - WXQSketch::LimitSizeLevel(row_batch.Size(), eps, &dummy_nlevel, &n_cuts_); - // double ncuts to be the same as the number of values - // in the temporary buffers of the sketches - n_cuts_ *= 2; - for (int icol = 0; icol < num_cols_; ++icol) { - sketches_[icol].Init(row_batch.Size(), eps); - summaries_[icol].Reserve(n_cuts_); - } + WXQSketch::LimitSizeLevel(gpu_batch_nrows_, eps, &dummy_nlevel, &n_cuts_); // allocate necessary GPU buffers dh::safe_cuda(cudaSetDevice(device_)); @@ -306,9 +332,12 @@ struct GPUSketcher { // unpack the features; also unpack weights if present thrust::fill(fvalues_.begin(), fvalues_.end(), NAN); - thrust::fill(feature_weights_.begin(), feature_weights_.end(), NAN); + if (has_weights_) { + thrust::fill(feature_weights_.begin(), feature_weights_.end(), NAN); + } - dim3 block3(64, 4, 1); + dim3 block3(16, 64, 1); + // NOTE: This will typically support ~ 4M features - 64K*64 dim3 grid3(dh::DivRoundUp(batch_nrows, block3.x), dh::DivRoundUp(num_cols_, block3.y), 1); UnpackFeaturesK<<>> @@ -324,12 +353,34 @@ struct GPUSketcher { // add cuts into sketches thrust::copy(cuts_d_.begin(), cuts_d_.end(), cuts_h_.begin()); +#pragma omp parallel for schedule(static) \ + if (num_cols_ > SketchContainer::kOmpNumColsParallelizeLimit) // NOLINT for (int icol = 0; icol < num_cols_; ++icol) { - summaries_[icol].MakeFromSorted(&cuts_h_[n_cuts_ * icol], n_cuts_cur_[icol]); - sketches_[icol].PushSummary(summaries_[icol]); + WXQSketch::SummaryContainer summary; + summary.Reserve(n_cuts_); + summary.MakeFromSorted(&cuts_h_[n_cuts_ * icol], n_cuts_cur_[icol]); + + std::lock_guard lock(sketch_container_->col_locks_[icol]); + sketch_container_->sketches_[icol].PushSummary(summary); } } + void ComputeRowStride() { + // Find the row stride for this batch + auto row_iter = row_ptrs_.begin(); + // Functor for finding the maximum row size for this batch + auto get_size = [=] __device__(size_t row) { + return row_iter[row + 1] - row_iter[row]; + }; // NOLINT + + auto counting = thrust::make_counting_iterator(size_t(0)); + using TransformT = thrust::transform_iterator; + TransformT row_size_iter = TransformT(counting, get_size); + row_stride_ = thrust::reduce(row_size_iter, row_size_iter + n_rows_, 0, + thrust::maximum()); + } + void Sketch(const SparsePage& row_batch, const MetaInfo& info) { // copy rows to the device dh::safe_cuda(cudaSetDevice(device_)); @@ -342,63 +393,71 @@ struct GPUSketcher { SketchBatch(row_batch, info, gpu_batch); } } - - void GetSummary(WXQSketch::SummaryContainer *summary, size_t const icol) { - sketches_[icol].GetSummary(summary); - } }; - void Sketch(const SparsePage& batch, const MetaInfo& info, - HistCutMatrix* hmat, int gpu_batch_nrows) { + void SketchBatch(const SparsePage &batch, const MetaInfo &info) { + GPUDistribution dist = + GPUDistribution::Block(GPUSet::All(learner_param_.gpu_id, learner_param_.n_gpus, + batch.Size())); + // create device shards - shards_.resize(dist_.Devices().Size()); + shards_.resize(dist.Devices().Size()); dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr& shard) { - size_t start = dist_.ShardStart(info.num_row_, i); - size_t size = dist_.ShardSize(info.num_row_, i); + size_t start = dist.ShardStart(batch.Size(), i); + size_t size = dist.ShardSize(batch.Size(), i); shard = std::unique_ptr( - new DeviceShard(dist_.Devices().DeviceId(i), - start, start + size, param_)); + new DeviceShard(dist.Devices().DeviceId(i), start, + start + size, param_, sketch_container_.get())); }); // compute sketches for each shard dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr& shard) { - shard->Init(batch, info, gpu_batch_nrows); + shard->Init(batch, info, gpu_batch_nrows_); shard->Sketch(batch, info); + shard->ComputeRowStride(); }); - // merge the sketches from all shards - // TODO(canonizer): do it in a tree-like reduction - int num_cols = info.num_col_; - std::vector sketches(num_cols); - WXQSketch::SummaryContainer summary; - for (int icol = 0; icol < num_cols; ++icol) { - sketches[icol].Init(batch.Size(), 1.0 / (8 * param_.max_bin)); - for (auto &shard : shards_) { - shard->GetSummary(&summary, icol); - sketches[icol].PushSummary(summary); - } + // compute row stride across all shards + for (const auto &shard : shards_) { + row_stride_ = std::max(row_stride_, shard->GetRowStride()); } - - hmat->Init(&sketches, param_.max_bin); } - GPUSketcher(tree::TrainParam param, GPUSet const& devices) : param_(std::move(param)) { - dist_ = GPUDistribution::Block(devices); + GPUSketcher(const tree::TrainParam ¶m, const LearnerTrainParam &learner_param, int gpu_nrows) + : param_(param), learner_param_(learner_param), gpu_batch_nrows_(gpu_nrows), row_stride_(0) { + } + + /* Builds the sketches on the GPU for the dmatrix and returns the row stride + * for the entire dataset */ + size_t Sketch(DMatrix *dmat, HistCutMatrix *hmat) { + const MetaInfo &info = dmat->Info(); + + row_stride_ = 0; + sketch_container_.reset(new SketchContainer(param_, dmat)); + for (const auto &batch : dmat->GetRowBatches()) { + this->SketchBatch(batch, info); + } + + hmat->Init(&sketch_container_.get()->sketches_, param_.max_bin); + + return row_stride_; } private: std::vector> shards_; - tree::TrainParam param_; - GPUDistribution dist_; + const tree::TrainParam ¶m_; + const LearnerTrainParam &learner_param_; + int gpu_batch_nrows_; + size_t row_stride_; + std::unique_ptr sketch_container_; }; -void DeviceSketch - (const SparsePage& batch, const MetaInfo& info, - const tree::TrainParam& param, HistCutMatrix* hmat, int gpu_batch_nrows, - GPUSet const& devices) { - GPUSketcher sketcher(param, devices); - sketcher.Sketch(batch, info, hmat, gpu_batch_nrows); +size_t DeviceSketch + (const tree::TrainParam ¶m, const LearnerTrainParam &learner_param, int gpu_batch_nrows, + DMatrix *dmat, HistCutMatrix *hmat) { + GPUSketcher sketcher(param, learner_param, gpu_batch_nrows); + return sketcher.Sketch(dmat, hmat); } } // namespace common diff --git a/src/common/hist_util.h b/src/common/hist_util.h index 5fb9b9c85..dc2b80bb8 100644 --- a/src/common/hist_util.h +++ b/src/common/hist_util.h @@ -8,6 +8,7 @@ #define XGBOOST_COMMON_HIST_UTIL_H_ #include +#include #include #include #include "row_set.h" @@ -84,9 +85,6 @@ struct SimpleArray { size_t n_ = 0; }; - - - /*! \brief Cut configuration for all the features. */ struct HistCutMatrix { /*! \brief Unit pointer to rows by element position */ @@ -115,11 +113,13 @@ struct HistCutMatrix { Monitor monitor_; }; -/*! \brief Builds the cut matrix on the GPU */ -void DeviceSketch - (const SparsePage& batch, const MetaInfo& info, - const tree::TrainParam& param, HistCutMatrix* hmat, int gpu_batch_nrows, - GPUSet const& devices); +/*! \brief Builds the cut matrix on the GPU. + * + * \return The row stride across the entire dataset. + */ +size_t DeviceSketch + (const tree::TrainParam& param, const LearnerTrainParam &learner_param, int gpu_batch_nrows, + DMatrix* dmat, HistCutMatrix* hmat); /*! * \brief A single row in global histogram index. diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index a398d5f2c..0a4f0d7f8 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -1374,7 +1374,7 @@ inline void DeviceShard::CreateHistIndices( } template -class GPUHistMakerSpecialised{ +class GPUHistMakerSpecialised { public: GPUHistMakerSpecialised() : initialised_{false}, p_last_fmat_{nullptr} {} void Init(const std::vector>& args, @@ -1449,10 +1449,12 @@ class GPUHistMakerSpecialised{ // Find the cuts. monitor_.StartCuda("Quantiles"); - common::DeviceSketch(batch, *info_, param_, &hmat_, hist_maker_param_.gpu_batch_nrows, - GPUSet::All(learner_param_->gpu_id, learner_param_->n_gpus)); + // TODO(sriramch): The return value will be used when we add support for histogram + // index creation for multiple batches + common::DeviceSketch(param_, *learner_param_, hist_maker_param_.gpu_batch_nrows, dmat, &hmat_); n_bins_ = hmat_.row_ptr.back(); monitor_.StopCuda("Quantiles"); + auto is_dense = info_->num_nonzero_ == info_->num_row_ * info_->num_col_; monitor_.StartCuda("BinningCompression"); @@ -1557,7 +1559,6 @@ class GPUHistMakerSpecialised{ GPUHistMakerTrainParam hist_maker_param_; LearnerTrainParam const* learner_param_; - common::GHistIndexMatrix gmat_; dh::AllReducer reducer_; diff --git a/tests/cpp/common/test_gpu_hist_util.cu b/tests/cpp/common/test_gpu_hist_util.cu index b6bdc40cd..8a1b3ed59 100644 --- a/tests/cpp/common/test_gpu_hist_util.cu +++ b/tests/cpp/common/test_gpu_hist_util.cu @@ -1,50 +1,72 @@ -#include "gtest/gtest.h" -#include "xgboost/c_api.h" #include #include + +#include "gtest/gtest.h" + #include #include -#include "../helpers.h" +#include "xgboost/c_api.h" + #include "../../../src/common/device_helpers.cuh" #include "../../../src/common/hist_util.h" +#include "../helpers.h" + namespace xgboost { namespace common { -void TestDeviceSketch(const GPUSet& devices) { +void TestDeviceSketch(const GPUSet& devices, bool use_external_memory) { // create the data int nrows = 10001; - std::vector test_data(nrows); - auto count_iter = thrust::make_counting_iterator(0); - // fill in reverse order - std::copy(count_iter, count_iter + nrows, test_data.rbegin()); + std::shared_ptr *dmat = nullptr; - // create the DMatrix - DMatrixHandle dmat_handle; - XGDMatrixCreateFromMat(test_data.data(), nrows, 1, -1, - &dmat_handle); - auto dmat = static_cast *>(dmat_handle); + size_t num_cols = 1; + if (use_external_memory) { + auto sp_dmat = CreateSparsePageDMatrix(nrows * 3, 128UL); // 3 entries/row + dmat = new std::shared_ptr(std::move(sp_dmat)); + num_cols = 5; + } else { + std::vector test_data(nrows); + auto count_iter = thrust::make_counting_iterator(0); + // fill in reverse order + std::copy(count_iter, count_iter + nrows, test_data.rbegin()); + + // create the DMatrix + DMatrixHandle dmat_handle; + XGDMatrixCreateFromMat(test_data.data(), nrows, 1, -1, + &dmat_handle); + dmat = static_cast *>(dmat_handle); + } - // parameters for finding quantiles tree::TrainParam p; p.max_bin = 20; - // ensure that the exact quantiles are found - int gpu_batch_nrows = nrows * 10; + int gpu_batch_nrows = 0; // find quantiles on the CPU HistCutMatrix hmat_cpu; hmat_cpu.Init((*dmat).get(), p.max_bin); // find the cuts on the GPU - const SparsePage& batch = *(*dmat)->GetRowBatches().begin(); HistCutMatrix hmat_gpu; - DeviceSketch(batch, (*dmat)->Info(), p, &hmat_gpu, gpu_batch_nrows, devices); + size_t row_stride = DeviceSketch(p, CreateEmptyGenericParam(0, devices.Size()), gpu_batch_nrows, + dmat->get(), &hmat_gpu); + + // compare the row stride with the one obtained from the dmatrix + size_t expected_row_stride = 0; + for (const auto &batch : dmat->get()->GetRowBatches()) { + const auto &offset_vec = batch.offset.ConstHostVector(); + for (int i = 1; i <= offset_vec.size() -1; ++i) { + expected_row_stride = std::max(expected_row_stride, offset_vec[i] - offset_vec[i-1]); + } + } + + ASSERT_EQ(expected_row_stride, row_stride); // compare the cuts double eps = 1e-2; - ASSERT_EQ(hmat_gpu.min_val.size(), 1); - ASSERT_EQ(hmat_gpu.row_ptr.size(), 2); + ASSERT_EQ(hmat_gpu.min_val.size(), num_cols); + ASSERT_EQ(hmat_gpu.row_ptr.size(), num_cols + 1); ASSERT_EQ(hmat_gpu.cut.size(), hmat_cpu.cut.size()); ASSERT_LT(fabs(hmat_cpu.min_val[0] - hmat_gpu.min_val[0]), eps * nrows); for (int i = 0; i < hmat_gpu.cut.size(); ++i) { @@ -55,14 +77,24 @@ void TestDeviceSketch(const GPUSet& devices) { } TEST(gpu_hist_util, DeviceSketch) { - TestDeviceSketch(GPUSet::Range(0, 1)); + TestDeviceSketch(GPUSet::Range(0, 1), false); +} + +TEST(gpu_hist_util, DeviceSketch_ExternalMemory) { + TestDeviceSketch(GPUSet::Range(0, 1), true); } #if defined(XGBOOST_USE_NCCL) TEST(gpu_hist_util, MGPU_DeviceSketch) { auto devices = GPUSet::AllVisible(); CHECK_GT(devices.Size(), 1); - TestDeviceSketch(devices); + TestDeviceSketch(devices, false); +} + +TEST(gpu_hist_util, MGPU_DeviceSketch_ExternalMemory) { + auto devices = GPUSet::AllVisible(); + CHECK_GT(devices.Size(), 1); + TestDeviceSketch(devices, true); } #endif