- training with external memory part 1 of 2 (#4486)

* - training with external memory part 1 of 2
   - this pr focuses on computing the quantiles using multiple gpus on a
     dataset that uses the external cache capabilities
   - there will be a follow-up pr soon after this that will support creation
     of histogram indices on large datasets as well
   - both of these changes are required to support training with external memory
   - the sparse pages in dmatrix are taken in batches and the cut matrices
     are incrementally built
   - also snuck in some (perf) changes related to sketch aggregation across multiple
     features and multiple sparse page batches. instead of aggregating summaries
     inside each device and merging them later, they are aggregated in place while
     a device is working on different rows of the same feature (see the sketch below)
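
The in-place aggregation amounts to keeping one sketch and one lock per column in a shared container, with each device shard pushing its per-batch summaries under that column's lock. Below is a minimal standalone sketch of the pattern; it uses plain std::thread and a ToySketch/ToySketchContainer stand-in rather than the real WXQSketch/SketchContainer types, so all names are illustrative only.

#include <cstdio>
#include <mutex>
#include <thread>
#include <vector>

// Stand-in for a per-feature quantile sketch; the real code uses
// HistCutMatrix::WXQSketch and PushSummary().
struct ToySketch {
  std::vector<float> values;
  void Push(const std::vector<float> &summary) {
    values.insert(values.end(), summary.begin(), summary.end());
  }
};

// Shared container: one sketch and one mutex per column, so shards
// (devices) working on different row batches of the same feature can
// aggregate in place instead of keeping per-device copies.
struct ToySketchContainer {
  std::vector<ToySketch> sketches;
  std::vector<std::mutex> col_locks;
  explicit ToySketchContainer(size_t num_cols)
      : sketches(num_cols), col_locks(num_cols) {}
};

int main() {
  const size_t kCols = 4, kShards = 3;
  ToySketchContainer container(kCols);

  std::vector<std::thread> shards;
  for (size_t shard = 0; shard < kShards; ++shard) {
    shards.emplace_back([&container, shard] {
      for (size_t col = 0; col < kCols; ++col) {
        // Each shard builds a local summary for its rows of this column...
        std::vector<float> summary = {static_cast<float>(shard),
                                      static_cast<float>(col)};
        // ...then pushes it into the shared per-column sketch under the
        // column's lock, mirroring the col_locks_ guard in SketchContainer.
        std::lock_guard<std::mutex> lock(container.col_locks[col]);
        container.sketches[col].Push(summary);
      }
    });
  }
  for (auto &t : shards) t.join();

  for (size_t col = 0; col < kCols; ++col) {
    std::printf("column %zu aggregated %zu summary values\n", col,
                container.sketches[col].values.size());
  }
  return 0;
}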
sriramch 2019-05-29 13:18:34 -07:00 committed by Rory Mitchell
parent 6e16900711
commit fed665ae8a
4 changed files with 180 additions and 88 deletions

View File

@@ -13,6 +13,8 @@
#include <utility>
#include <vector>
#include <memory>
#include <mutex>
#include "../tree/param.h"
#include "./host_device_vector.h"
@@ -82,6 +84,36 @@ __global__ void UnpackFeaturesK
}
}
/*!
* \brief A container that holds the device sketches across all
* sparse page batches, which are distributed to different devices.
* As sketches are aggregated by column, the mutex guards against
* multiple devices pushing sketch summaries for the same column
* across distinct rows.
*/
struct SketchContainer {
std::vector<HistCutMatrix::WXQSketch> sketches_; // NOLINT
std::vector<std::mutex> col_locks_; // NOLINT
static constexpr int kOmpNumColsParallelizeLimit = 1000;
SketchContainer(const tree::TrainParam &param, DMatrix *dmat) :
col_locks_(dmat->Info().num_col_) {
const MetaInfo &info = dmat->Info();
// Initialize Sketches for this dmatrix
sketches_.resize(info.num_col_);
#pragma omp parallel for schedule(static) if (info.num_col_ > kOmpNumColsParallelizeLimit)
for (int icol = 0; icol < info.num_col_; ++icol) {
sketches_[icol].Init(info.num_row_, 1.0 / (8 * param.max_bin));
}
}
// Prevent copying/assigning/moving this as its internals can't be assigned/copied/moved
SketchContainer(const SketchContainer &) = delete;
SketchContainer(const SketchContainer &&) = delete;
SketchContainer &operator=(const SketchContainer &) = delete;
SketchContainer &operator=(const SketchContainer &&) = delete;
};
// finds quantiles on the GPU
struct GPUSketcher {
// manage memory for a single GPU
@@ -94,11 +126,11 @@ struct GPUSketcher {
size_t n_cuts_{0};
size_t gpu_batch_nrows_{0};
bool has_weights_{false};
size_t row_stride_{0};
tree::TrainParam param_;
std::vector<WXQSketch> sketches_;
SketchContainer *sketch_container_;
thrust::device_vector<size_t> row_ptrs_;
std::vector<WXQSketch::SummaryContainer> summaries_;
thrust::device_vector<Entry> entries_;
thrust::device_vector<bst_float> fvalues_;
thrust::device_vector<bst_float> feature_weights_;
@@ -113,9 +145,13 @@ struct GPUSketcher {
public:
DeviceShard(int device, bst_uint row_begin, bst_uint row_end,
tree::TrainParam param) :
tree::TrainParam param, SketchContainer *sketch_container) :
device_(device), row_begin_(row_begin), row_end_(row_end),
n_rows_(row_end - row_begin), param_(std::move(param)) {
n_rows_(row_end - row_begin), param_(std::move(param)), sketch_container_(sketch_container) {
}
inline size_t GetRowStride() const {
return row_stride_;
}
void Init(const SparsePage& row_batch, const MetaInfo& info, int gpu_batch_nrows) {
@@ -136,20 +172,10 @@ struct GPUSketcher {
gpu_batch_nrows_ = n_rows_;
}
// initialize sketches
sketches_.resize(num_cols_);
summaries_.resize(num_cols_);
constexpr int kFactor = 8;
double eps = 1.0 / (kFactor * param_.max_bin);
size_t dummy_nlevel;
WXQSketch::LimitSizeLevel(row_batch.Size(), eps, &dummy_nlevel, &n_cuts_);
// double ncuts to be the same as the number of values
// in the temporary buffers of the sketches
n_cuts_ *= 2;
for (int icol = 0; icol < num_cols_; ++icol) {
sketches_[icol].Init(row_batch.Size(), eps);
summaries_[icol].Reserve(n_cuts_);
}
WXQSketch::LimitSizeLevel(gpu_batch_nrows_, eps, &dummy_nlevel, &n_cuts_);
// allocate necessary GPU buffers
dh::safe_cuda(cudaSetDevice(device_));
@@ -306,9 +332,12 @@ struct GPUSketcher {
// unpack the features; also unpack weights if present
thrust::fill(fvalues_.begin(), fvalues_.end(), NAN);
thrust::fill(feature_weights_.begin(), feature_weights_.end(), NAN);
if (has_weights_) {
thrust::fill(feature_weights_.begin(), feature_weights_.end(), NAN);
}
dim3 block3(64, 4, 1);
dim3 block3(16, 64, 1);
// NOTE: This will typically support ~ 4M features - 64K*64
dim3 grid3(dh::DivRoundUp(batch_nrows, block3.x),
dh::DivRoundUp(num_cols_, block3.y), 1);
UnpackFeaturesK<<<grid3, block3>>>
@@ -324,12 +353,34 @@ struct GPUSketcher {
// add cuts into sketches
thrust::copy(cuts_d_.begin(), cuts_d_.end(), cuts_h_.begin());
#pragma omp parallel for schedule(static) \
if (num_cols_ > SketchContainer::kOmpNumColsParallelizeLimit) // NOLINT
for (int icol = 0; icol < num_cols_; ++icol) {
summaries_[icol].MakeFromSorted(&cuts_h_[n_cuts_ * icol], n_cuts_cur_[icol]);
sketches_[icol].PushSummary(summaries_[icol]);
WXQSketch::SummaryContainer summary;
summary.Reserve(n_cuts_);
summary.MakeFromSorted(&cuts_h_[n_cuts_ * icol], n_cuts_cur_[icol]);
std::lock_guard<std::mutex> lock(sketch_container_->col_locks_[icol]);
sketch_container_->sketches_[icol].PushSummary(summary);
}
}
void ComputeRowStride() {
// Find the row stride for this batch
auto row_iter = row_ptrs_.begin();
// Functor for finding the maximum row size for this batch
auto get_size = [=] __device__(size_t row) {
return row_iter[row + 1] - row_iter[row];
}; // NOLINT
auto counting = thrust::make_counting_iterator(size_t(0));
using TransformT = thrust::transform_iterator<decltype(get_size),
decltype(counting), size_t>;
TransformT row_size_iter = TransformT(counting, get_size);
row_stride_ = thrust::reduce(row_size_iter, row_size_iter + n_rows_, 0,
thrust::maximum<size_t>());
}
void Sketch(const SparsePage& row_batch, const MetaInfo& info) {
// copy rows to the device
dh::safe_cuda(cudaSetDevice(device_));
@@ -342,63 +393,71 @@ struct GPUSketcher {
SketchBatch(row_batch, info, gpu_batch);
}
}
void GetSummary(WXQSketch::SummaryContainer *summary, size_t const icol) {
sketches_[icol].GetSummary(summary);
}
};
void Sketch(const SparsePage& batch, const MetaInfo& info,
HistCutMatrix* hmat, int gpu_batch_nrows) {
void SketchBatch(const SparsePage &batch, const MetaInfo &info) {
GPUDistribution dist =
GPUDistribution::Block(GPUSet::All(learner_param_.gpu_id, learner_param_.n_gpus,
batch.Size()));
// create device shards
shards_.resize(dist_.Devices().Size());
shards_.resize(dist.Devices().Size());
dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr<DeviceShard>& shard) {
size_t start = dist_.ShardStart(info.num_row_, i);
size_t size = dist_.ShardSize(info.num_row_, i);
size_t start = dist.ShardStart(batch.Size(), i);
size_t size = dist.ShardSize(batch.Size(), i);
shard = std::unique_ptr<DeviceShard>(
new DeviceShard(dist_.Devices().DeviceId(i),
start, start + size, param_));
new DeviceShard(dist.Devices().DeviceId(i), start,
start + size, param_, sketch_container_.get()));
});
// compute sketches for each shard
dh::ExecuteIndexShards(&shards_,
[&](int idx, std::unique_ptr<DeviceShard>& shard) {
shard->Init(batch, info, gpu_batch_nrows);
shard->Init(batch, info, gpu_batch_nrows_);
shard->Sketch(batch, info);
shard->ComputeRowStride();
});
// merge the sketches from all shards
// TODO(canonizer): do it in a tree-like reduction
int num_cols = info.num_col_;
std::vector<WXQSketch> sketches(num_cols);
WXQSketch::SummaryContainer summary;
for (int icol = 0; icol < num_cols; ++icol) {
sketches[icol].Init(batch.Size(), 1.0 / (8 * param_.max_bin));
for (auto &shard : shards_) {
shard->GetSummary(&summary, icol);
sketches[icol].PushSummary(summary);
}
// compute row stride across all shards
for (const auto &shard : shards_) {
row_stride_ = std::max(row_stride_, shard->GetRowStride());
}
hmat->Init(&sketches, param_.max_bin);
}
GPUSketcher(tree::TrainParam param, GPUSet const& devices) : param_(std::move(param)) {
dist_ = GPUDistribution::Block(devices);
GPUSketcher(const tree::TrainParam &param, const LearnerTrainParam &learner_param, int gpu_nrows)
: param_(param), learner_param_(learner_param), gpu_batch_nrows_(gpu_nrows), row_stride_(0) {
}
/* Builds the sketches on the GPU for the dmatrix and returns the row stride
* for the entire dataset */
size_t Sketch(DMatrix *dmat, HistCutMatrix *hmat) {
const MetaInfo &info = dmat->Info();
row_stride_ = 0;
sketch_container_.reset(new SketchContainer(param_, dmat));
for (const auto &batch : dmat->GetRowBatches()) {
this->SketchBatch(batch, info);
}
hmat->Init(&sketch_container_.get()->sketches_, param_.max_bin);
return row_stride_;
}
private:
std::vector<std::unique_ptr<DeviceShard>> shards_;
tree::TrainParam param_;
GPUDistribution dist_;
const tree::TrainParam &param_;
const LearnerTrainParam &learner_param_;
int gpu_batch_nrows_;
size_t row_stride_;
std::unique_ptr<SketchContainer> sketch_container_;
};
void DeviceSketch
(const SparsePage& batch, const MetaInfo& info,
const tree::TrainParam& param, HistCutMatrix* hmat, int gpu_batch_nrows,
GPUSet const& devices) {
GPUSketcher sketcher(param, devices);
sketcher.Sketch(batch, info, hmat, gpu_batch_nrows);
size_t DeviceSketch
(const tree::TrainParam &param, const LearnerTrainParam &learner_param, int gpu_batch_nrows,
DMatrix *dmat, HistCutMatrix *hmat) {
GPUSketcher sketcher(param, learner_param, gpu_batch_nrows);
return sketcher.Sketch(dmat, hmat);
}
} // namespace common

View File

@@ -8,6 +8,7 @@
#define XGBOOST_COMMON_HIST_UTIL_H_
#include <xgboost/data.h>
#include <xgboost/generic_parameters.h>
#include <limits>
#include <vector>
#include "row_set.h"
@@ -84,9 +85,6 @@ struct SimpleArray {
size_t n_ = 0;
};
/*! \brief Cut configuration for all the features. */
struct HistCutMatrix {
/*! \brief Unit pointer to rows by element position */
@@ -115,11 +113,13 @@ struct HistCutMatrix {
Monitor monitor_;
};
/*! \brief Builds the cut matrix on the GPU */
void DeviceSketch
(const SparsePage& batch, const MetaInfo& info,
const tree::TrainParam& param, HistCutMatrix* hmat, int gpu_batch_nrows,
GPUSet const& devices);
/*! \brief Builds the cut matrix on the GPU.
*
* \return The row stride across the entire dataset.
*/
size_t DeviceSketch
(const tree::TrainParam& param, const LearnerTrainParam &learner_param, int gpu_batch_nrows,
DMatrix* dmat, HistCutMatrix* hmat);
/*!
* \brief A single row in global histogram index.

View File

@@ -1374,7 +1374,7 @@ inline void DeviceShard<GradientSumT>::CreateHistIndices(
}
template <typename GradientSumT>
class GPUHistMakerSpecialised{
class GPUHistMakerSpecialised {
public:
GPUHistMakerSpecialised() : initialised_{false}, p_last_fmat_{nullptr} {}
void Init(const std::vector<std::pair<std::string, std::string>>& args,
@@ -1449,10 +1449,12 @@ class GPUHistMakerSpecialised{
// Find the cuts.
monitor_.StartCuda("Quantiles");
common::DeviceSketch(batch, *info_, param_, &hmat_, hist_maker_param_.gpu_batch_nrows,
GPUSet::All(learner_param_->gpu_id, learner_param_->n_gpus));
// TODO(sriramch): The return value will be used when we add support for histogram
// index creation for multiple batches
common::DeviceSketch(param_, *learner_param_, hist_maker_param_.gpu_batch_nrows, dmat, &hmat_);
n_bins_ = hmat_.row_ptr.back();
monitor_.StopCuda("Quantiles");
auto is_dense = info_->num_nonzero_ == info_->num_row_ * info_->num_col_;
monitor_.StartCuda("BinningCompression");
@@ -1557,7 +1559,6 @@ class GPUHistMakerSpecialised{
GPUHistMakerTrainParam hist_maker_param_;
LearnerTrainParam const* learner_param_;
common::GHistIndexMatrix gmat_;
dh::AllReducer reducer_;

View File

@@ -1,50 +1,72 @@
#include "gtest/gtest.h"
#include "xgboost/c_api.h"
#include <algorithm>
#include <cmath>
#include "gtest/gtest.h"
#include <thrust/device_vector.h>
#include <thrust/iterator/counting_iterator.h>
#include "../helpers.h"
#include "xgboost/c_api.h"
#include "../../../src/common/device_helpers.cuh"
#include "../../../src/common/hist_util.h"
#include "../helpers.h"
namespace xgboost {
namespace common {
void TestDeviceSketch(const GPUSet& devices) {
void TestDeviceSketch(const GPUSet& devices, bool use_external_memory) {
// create the data
int nrows = 10001;
std::vector<float> test_data(nrows);
auto count_iter = thrust::make_counting_iterator(0);
// fill in reverse order
std::copy(count_iter, count_iter + nrows, test_data.rbegin());
std::shared_ptr<xgboost::DMatrix> *dmat = nullptr;
// create the DMatrix
DMatrixHandle dmat_handle;
XGDMatrixCreateFromMat(test_data.data(), nrows, 1, -1,
&dmat_handle);
auto dmat = static_cast<std::shared_ptr<xgboost::DMatrix> *>(dmat_handle);
size_t num_cols = 1;
if (use_external_memory) {
auto sp_dmat = CreateSparsePageDMatrix(nrows * 3, 128UL); // 3 entries/row
dmat = new std::shared_ptr<xgboost::DMatrix>(std::move(sp_dmat));
num_cols = 5;
} else {
std::vector<float> test_data(nrows);
auto count_iter = thrust::make_counting_iterator(0);
// fill in reverse order
std::copy(count_iter, count_iter + nrows, test_data.rbegin());
// create the DMatrix
DMatrixHandle dmat_handle;
XGDMatrixCreateFromMat(test_data.data(), nrows, 1, -1,
&dmat_handle);
dmat = static_cast<std::shared_ptr<xgboost::DMatrix> *>(dmat_handle);
}
// parameters for finding quantiles
tree::TrainParam p;
p.max_bin = 20;
// ensure that the exact quantiles are found
int gpu_batch_nrows = nrows * 10;
int gpu_batch_nrows = 0;
// find quantiles on the CPU
HistCutMatrix hmat_cpu;
hmat_cpu.Init((*dmat).get(), p.max_bin);
// find the cuts on the GPU
const SparsePage& batch = *(*dmat)->GetRowBatches().begin();
HistCutMatrix hmat_gpu;
DeviceSketch(batch, (*dmat)->Info(), p, &hmat_gpu, gpu_batch_nrows, devices);
size_t row_stride = DeviceSketch(p, CreateEmptyGenericParam(0, devices.Size()), gpu_batch_nrows,
dmat->get(), &hmat_gpu);
// compare the row stride with the one obtained from the dmatrix
size_t expected_row_stride = 0;
for (const auto &batch : dmat->get()->GetRowBatches()) {
const auto &offset_vec = batch.offset.ConstHostVector();
for (int i = 1; i <= offset_vec.size() -1; ++i) {
expected_row_stride = std::max(expected_row_stride, offset_vec[i] - offset_vec[i-1]);
}
}
ASSERT_EQ(expected_row_stride, row_stride);
// compare the cuts
double eps = 1e-2;
ASSERT_EQ(hmat_gpu.min_val.size(), 1);
ASSERT_EQ(hmat_gpu.row_ptr.size(), 2);
ASSERT_EQ(hmat_gpu.min_val.size(), num_cols);
ASSERT_EQ(hmat_gpu.row_ptr.size(), num_cols + 1);
ASSERT_EQ(hmat_gpu.cut.size(), hmat_cpu.cut.size());
ASSERT_LT(fabs(hmat_cpu.min_val[0] - hmat_gpu.min_val[0]), eps * nrows);
for (int i = 0; i < hmat_gpu.cut.size(); ++i) {
@@ -55,14 +77,24 @@ void TestDeviceSketch(const GPUSet& devices) {
}
TEST(gpu_hist_util, DeviceSketch) {
TestDeviceSketch(GPUSet::Range(0, 1));
TestDeviceSketch(GPUSet::Range(0, 1), false);
}
TEST(gpu_hist_util, DeviceSketch_ExternalMemory) {
TestDeviceSketch(GPUSet::Range(0, 1), true);
}
#if defined(XGBOOST_USE_NCCL)
TEST(gpu_hist_util, MGPU_DeviceSketch) {
auto devices = GPUSet::AllVisible();
CHECK_GT(devices.Size(), 1);
TestDeviceSketch(devices);
TestDeviceSketch(devices, false);
}
TEST(gpu_hist_util, MGPU_DeviceSketch_ExternalMemory) {
auto devices = GPUSet::AllVisible();
CHECK_GT(devices.Size(), 1);
TestDeviceSketch(devices, true);
}
#endif
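
For reference, the expected_row_stride check in the test above reduces to taking the maximum difference between consecutive CSR offsets of each batch (the device-side ComputeRowStride does the same thing with a transform iterator and thrust::reduce). A small standalone sketch of just that reduction, with made-up offsets purely for illustration:

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

// Given CSR row offsets for one sparse page, the row stride is the
// maximum number of stored entries in any single row.
size_t RowStride(const std::vector<size_t> &offsets) {
  size_t stride = 0;
  for (size_t i = 1; i < offsets.size(); ++i) {
    stride = std::max(stride, offsets[i] - offsets[i - 1]);
  }
  return stride;
}

int main() {
  // Synthetic offsets: rows with 2, 0, 3 and 1 entries -> stride 3.
  std::vector<size_t> offsets = {0, 2, 2, 5, 6};
  std::printf("row stride = %zu\n", RowStride(offsets));
  return 0;
}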