Support multiple batches in gpu_hist (#5014)

* Initial external memory training support for GPU Hist tree method.
Rong Ou 2019-11-15 22:50:20 -08:00 committed by Jiaming Yuan
parent 97abcc7ee2
commit 0afcc55d98
15 changed files with 559 additions and 134 deletions


@ -166,6 +166,15 @@ struct BatchParam {
int max_bin; int max_bin;
/*! \brief Number of rows in a GPU batch, used for finding quantiles on GPU. */ /*! \brief Number of rows in a GPU batch, used for finding quantiles on GPU. */
int gpu_batch_nrows; int gpu_batch_nrows;
/*! \brief Page size for external memory mode. */
size_t gpu_page_size;
inline bool operator!=(const BatchParam& other) const {
return gpu_id != other.gpu_id ||
max_bin != other.max_bin ||
gpu_batch_nrows != other.gpu_batch_nrows ||
gpu_page_size != other.gpu_page_size;
}
}; };
/*! /*!


@ -21,6 +21,8 @@ struct GenericParameter : public XGBoostParameter<GenericParameter> {
int nthread; int nthread;
// primary device, -1 means no gpu. // primary device, -1 means no gpu.
int gpu_id; int gpu_id;
// gpu page size in external memory mode, 0 means using the default.
size_t gpu_page_size;
void CheckDeprecated() { void CheckDeprecated() {
if (this->n_gpus != 0) { if (this->n_gpus != 0) {
@ -49,6 +51,10 @@ struct GenericParameter : public XGBoostParameter<GenericParameter> {
.set_default(-1) .set_default(-1)
.set_lower_bound(-1) .set_lower_bound(-1)
.describe("The primary GPU device ordinal."); .describe("The primary GPU device ordinal.");
DMLC_DECLARE_FIELD(gpu_page_size)
.set_default(0)
.set_lower_bound(0)
.describe("GPU page size when running in external memory mode.");
DMLC_DECLARE_FIELD(n_gpus) DMLC_DECLARE_FIELD(n_gpus)
.set_default(0) .set_default(0)
.set_range(0, 1) .set_range(0, 1)


@@ -164,8 +164,9 @@ class GPUSketcher {
    auto counting = thrust::make_counting_iterator(size_t(0));
    using TransformT = thrust::transform_iterator<decltype(get_size), decltype(counting), size_t>;
    TransformT row_size_iter = TransformT(counting, get_size);
-    row_stride_ =
+    size_t batch_row_stride =
        thrust::reduce(row_size_iter, row_size_iter + n_rows_, 0, thrust::maximum<size_t>());
+    row_stride_ = std::max(row_stride_, batch_row_stride);
  }

  // This needs to be public because of the __device__ lambda.


@@ -69,6 +69,8 @@ EllpackPageImpl::EllpackPageImpl(DMatrix* dmat, const BatchParam& param) {
  monitor_.Init("ellpack_page");
  dh::safe_cuda(cudaSetDevice(param.gpu_id));
+  matrix.n_rows = dmat->Info().num_row_;
  monitor_.StartCuda("Quantiles");
  // Create the quantile sketches for the dmatrix and initialize HistogramCuts.
  common::HistogramCuts hmat;

@@ -206,7 +208,7 @@ void EllpackPageImpl::CreateHistIndices(int device,
// Return the number of rows contained in this page.
size_t EllpackPageImpl::Size() const {
-  return n_rows;
+  return matrix.n_rows;
}

// Clear the current page.
@@ -214,44 +216,50 @@ void EllpackPageImpl::Clear() {
  ba_.Clear();
  gidx_buffer = {};
  idx_buffer.clear();
-  n_rows = 0;
+  sparse_page_.Clear();
+  matrix.base_rowid = 0;
+  matrix.n_rows = 0;
+  device_initialized_ = false;
}

// Push a CSR page to the current page.
//
-// First compress the CSR page into ELLPACK, then the compressed buffer is copied to host and
-// appended to the existing host vector.
+// The CSR pages are accumulated in memory until they reach a certain size, then written out as
+// compressed ELLPACK.
void EllpackPageImpl::Push(int device, const SparsePage& batch) {
+  sparse_page_.Push(batch);
+  matrix.n_rows += batch.Size();
+}
+
+// Compress the accumulated SparsePage.
+void EllpackPageImpl::CompressSparsePage(int device) {
  monitor_.StartCuda("InitCompressedData");
-  InitCompressedData(device, batch.Size());
+  InitCompressedData(device, matrix.n_rows);
  monitor_.StopCuda("InitCompressedData");

  monitor_.StartCuda("BinningCompression");
-  DeviceHistogramBuilderState hist_builder_row_state(batch.Size());
-  hist_builder_row_state.BeginBatch(batch);
-  CreateHistIndices(device, batch, hist_builder_row_state.GetRowStateOnDevice());
+  DeviceHistogramBuilderState hist_builder_row_state(matrix.n_rows);
+  hist_builder_row_state.BeginBatch(sparse_page_);
+  CreateHistIndices(device, sparse_page_, hist_builder_row_state.GetRowStateOnDevice());
  hist_builder_row_state.EndBatch();
  monitor_.StopCuda("BinningCompression");

  monitor_.StartCuda("CopyDeviceToHost");
-  std::vector<common::CompressedByteT> buffer(gidx_buffer.size());
-  dh::CopyDeviceSpanToVector(&buffer, gidx_buffer);
-  int offset = 0;
-  if (!idx_buffer.empty()) {
-    offset = ::xgboost::common::detail::kPadding;
-  }
-  idx_buffer.reserve(idx_buffer.size() + buffer.size() - offset);
-  idx_buffer.insert(idx_buffer.end(), buffer.begin() + offset, buffer.end());
+  idx_buffer.resize(gidx_buffer.size());
+  dh::CopyDeviceSpanToVector(&idx_buffer, gidx_buffer);
  ba_.Clear();
  gidx_buffer = {};
  monitor_.StopCuda("CopyDeviceToHost");
-  n_rows += batch.Size();
}

// Return the memory cost for storing the compressed features.
size_t EllpackPageImpl::MemCostBytes() const {
-  return idx_buffer.size() * sizeof(common::CompressedByteT);
+  size_t num_symbols = matrix.info.n_bins + 1;
+  // Required buffer size for storing data matrix in ELLPack format.
+  size_t compressed_size_bytes = common::CompressedBufferWriter::CalculateBufferSize(
+      matrix.info.row_stride * matrix.n_rows, num_symbols);
+  return compressed_size_bytes;
}

// Copy the compressed features to GPU.
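With this change MemCostBytes() estimates the compressed ELLPACK footprint from the page shape instead of measuring the accumulated host buffer, and the page writer compares that estimate against the configured page size. A rough back-of-envelope of the estimate, assuming the compressed format packs about ceil(log2(num_symbols)) bits per entry (the exact arithmetic lives in common::CompressedBufferWriter::CalculateBufferSize):

import math

def approx_ellpack_page_bytes(row_stride, n_rows, n_bins):
    # One symbol per (row, slot) entry; n_bins + 1 reserves a symbol for "missing".
    num_symbols = n_bins + 1
    bits_per_symbol = max(1, math.ceil(math.log2(num_symbols)))
    return math.ceil(row_stride * n_rows * bits_per_symbol / 8)

# 1M rows with row_stride 50 and 256 bins comes out around 54 MiB per page.
print(approx_ellpack_page_bytes(50, 1_000_000, 256) / 2**20)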


@@ -78,13 +78,14 @@ struct EllpackInfo {
 * kernels.*/
struct EllpackMatrix {
  EllpackInfo info;
+  size_t base_rowid{};
+  size_t n_rows{};
  common::CompressedIterator<uint32_t> gidx_iter;
-  XGBOOST_DEVICE size_t BinCount() const { return info.gidx_fvalue_map.size(); }

  // Get a matrix element, uses binary search for look up Return NaN if missing
  // Given a row index and a feature index, returns the corresponding cut value
  __device__ bst_float GetElement(size_t ridx, size_t fidx) const {
+    ridx -= base_rowid;
    auto row_begin = info.row_stride * ridx;
    auto row_end = row_begin + info.row_stride;
    auto gidx = -1;
@@ -102,6 +103,11 @@ struct EllpackMatrix {
    }
    return info.gidx_fvalue_map[gidx];
  }
+
+  // Check if the row id is withing range of the current batch.
+  __device__ bool IsInRange(size_t row_id) const {
+    return row_id >= base_rowid && row_id < base_rowid + n_rows;
+  }
};
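base_rowid and IsInRange are what let device code keep working with global row indices while each page stores only its own slice of rows [base_rowid, base_rowid + n_rows). A small host-side sketch of that bookkeeping (illustrative only; the class and field names simply mirror the struct above):

class PageView:
    def __init__(self, base_rowid, n_rows):
        self.base_rowid = base_rowid
        self.n_rows = n_rows

    def is_in_range(self, row_id):
        return self.base_rowid <= row_id < self.base_rowid + self.n_rows

    def local_index(self, row_id):
        # Mirrors 'ridx -= base_rowid' in GetElement: global row -> offset within this page.
        return row_id - self.base_rowid

pages = [PageView(0, 4), PageView(4, 4)]            # two pages of 4 rows each
owner = next(p for p in pages if p.is_in_range(6))  # row 6 lives in the second page
assert owner.local_index(6) == 2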
// Instances of this type are created while creating the histogram bins for the
@@ -185,7 +191,6 @@ class EllpackPageImpl {
  /*! \brief global index of histogram, which is stored in ELLPack format. */
  common::Span<common::CompressedByteT> gidx_buffer;
  std::vector<common::CompressedByteT> idx_buffer;
-  size_t n_rows{};

  /*!
   * \brief Default constructor.
@@ -240,7 +245,7 @@ class EllpackPageImpl {
  /*! \brief Set the base row id for this page. */
  inline void SetBaseRowId(size_t row_id) {
-    base_rowid_ = row_id;
+    matrix.base_rowid = row_id;
  }

  /*! \brief clear the page. */
@@ -263,11 +268,17 @@ class EllpackPageImpl {
   */
  void InitDevice(int device, EllpackInfo info);

+  /*! \brief Compress the accumulated SparsePage into ELLPACK format.
+   *
+   * @param device The GPU device to use.
+   */
+  void CompressSparsePage(int device);
+
 private:
  common::Monitor monitor_;
  dh::BulkAllocator ba_;
-  size_t base_rowid_{};
  bool device_initialized_{false};
+  SparsePage sparse_page_{};
};

} // namespace xgboost


@@ -17,7 +17,8 @@ class EllpackPageRawFormat : public SparsePageFormat<EllpackPage> {
 public:
  bool Read(EllpackPage* page, dmlc::SeekStream* fi) override {
    auto* impl = page->Impl();
-    if (!fi->Read(&impl->n_rows)) return false;
+    impl->Clear();
+    if (!fi->Read(&impl->matrix.n_rows)) return false;
    return fi->Read(&impl->idx_buffer);
  }

@@ -25,13 +26,14 @@ class EllpackPageRawFormat : public SparsePageFormat<EllpackPage> {
            dmlc::SeekStream* fi,
            const std::vector<bst_uint>& sorted_index_set) override {
    auto* impl = page->Impl();
-    if (!fi->Read(&impl->n_rows)) return false;
+    impl->Clear();
+    if (!fi->Read(&impl->matrix.n_rows)) return false;
    return fi->Read(&page->Impl()->idx_buffer);
  }

  void Write(const EllpackPage& page, dmlc::Stream* fo) override {
    auto* impl = page.Impl();
-    fo->Write(impl->n_rows);
+    fo->Write(impl->matrix.n_rows);
    auto buffer = impl->idx_buffer;
    CHECK(!buffer.empty());
    fo->Write(buffer);


@@ -40,11 +40,13 @@ class EllpackPageSourceImpl : public DataSource<EllpackPage> {
  const std::string kPageType_{".ellpack.page"};

  int device_{-1};
+  size_t page_size_{DMatrix::kPageSize};
  common::Monitor monitor_;
  dh::BulkAllocator ba_;
  /*! \brief The EllpackInfo, with the underlying GPU memory shared by all pages. */
  EllpackInfo ellpack_info_;
  std::unique_ptr<SparsePageSource<EllpackPage>> source_;
+  std::string cache_info_;
};

EllpackPageSource::EllpackPageSource(DMatrix* dmat,
@@ -72,8 +74,12 @@ const EllpackPage& EllpackPageSource::Value() const {
// each CSR page, and write the accumulated ELLPACK pages to disk.
EllpackPageSourceImpl::EllpackPageSourceImpl(DMatrix* dmat,
                                             const std::string& cache_info,
-                                             const BatchParam& param) noexcept(false) {
-  device_ = param.gpu_id;
+                                             const BatchParam& param) noexcept(false)
+    : device_(param.gpu_id), cache_info_(cache_info) {
+  if (param.gpu_page_size > 0) {
+    page_size_ = param.gpu_page_size;
+  }
  monitor_.Init("ellpack_page_source");
  dh::safe_cuda(cudaSetDevice(device_));

@@ -92,10 +98,11 @@ EllpackPageSourceImpl::EllpackPageSourceImpl(DMatrix* dmat,
  WriteEllpackPages(dmat, cache_info);
  monitor_.StopCuda("WriteEllpackPages");

-  source_.reset(new SparsePageSource<EllpackPage>(cache_info, kPageType_));
+  source_.reset(new SparsePageSource<EllpackPage>(cache_info_, kPageType_));
}

void EllpackPageSourceImpl::BeforeFirst() {
+  source_.reset(new SparsePageSource<EllpackPage>(cache_info_, kPageType_));
  source_->BeforeFirst();
}

@@ -133,20 +140,23 @@ void EllpackPageSourceImpl::WriteEllpackPages(DMatrix* dmat, const std::string&
  for (const auto& batch : dmat->GetBatches<SparsePage>()) {
    impl->Push(device_, batch);

-    if (impl->MemCostBytes() >= DMatrix::kPageSize) {
-      bytes_write += impl->MemCostBytes();
+    size_t mem_cost_bytes = impl->MemCostBytes();
+    if (mem_cost_bytes >= page_size_) {
+      bytes_write += mem_cost_bytes;
+      impl->CompressSparsePage(device_);
      writer.PushWrite(std::move(page));
      writer.Alloc(&page);
      impl = page->Impl();
      impl->matrix.info = ellpack_info_;
      impl->Clear();
      double tdiff = dmlc::GetTime() - tstart;
-      LOG(INFO) << "Writing to " << cache_info << " in "
+      LOG(INFO) << "Writing " << kPageType_ << " to " << cache_info << " in "
                << ((bytes_write >> 20UL) / tdiff) << " MB/s, "
                << (bytes_write >> 20UL) << " written";
    }
  }

  if (impl->Size() != 0) {
+    impl->CompressSparsePage(device_);
    writer.PushWrite(std::move(page));
  }
}


@@ -81,10 +81,7 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(const BatchParam& par
  CHECK_GE(param.gpu_id, 0);
  CHECK_GE(param.max_bin, 2);
  // Lazily instantiate
-  if (!ellpack_source_ ||
-      batch_param_.gpu_id != param.gpu_id ||
-      batch_param_.max_bin != param.max_bin ||
-      batch_param_.gpu_batch_nrows != param.gpu_batch_nrows) {
+  if (!ellpack_source_ || batch_param_ != param) {
    ellpack_source_.reset(new EllpackPageSource(this, cache_info_, param));
    batch_param_ = param;
  }


@@ -33,6 +33,7 @@ class RowPartitioner {
 public:
  using RowIndexT = bst_uint;
  struct Segment;
+  static constexpr bst_node_t kIgnoredTreePosition = -1;

 private:
  int device_idx;
@@ -124,6 +125,7 @@ class RowPartitioner {
      idx += segment.begin;
      RowIndexT ridx = d_ridx[idx];
      bst_node_t new_position = op(ridx);  // new node id
+      if (new_position == kIgnoredTreePosition) return;
      KERNEL_CHECK(new_position == left_nidx || new_position == right_nidx);
      AtomicIncrement(d_left_count, new_position == left_nidx);
      d_position[idx] = new_position;
@@ -163,7 +165,9 @@ class RowPartitioner {
    dh::LaunchN(device_idx, position.Size(), [=] __device__(size_t idx) {
      auto position = d_position[idx];
      RowIndexT ridx = d_ridx[idx];
-      d_position[idx] = op(ridx, position);
+      bst_node_t new_position = op(ridx, position);
+      if (new_position == kIgnoredTreePosition) return;
+      d_position[idx] = new_position;
    });
  }


@@ -409,13 +409,16 @@ __global__ void SharedMemHistKernel(xgboost::EllpackMatrix matrix,
  extern __shared__ char smem[];
  GradientSumT* smem_arr = reinterpret_cast<GradientSumT*>(smem);  // NOLINT
  if (use_shared_memory_histograms) {
-    dh::BlockFill(smem_arr, matrix.BinCount(), GradientSumT());
+    dh::BlockFill(smem_arr, matrix.info.n_bins, GradientSumT());
    __syncthreads();
  }
  for (auto idx : dh::GridStrideRange(static_cast<size_t>(0), n_elements)) {
-    int ridx = d_ridx[idx / matrix.info.row_stride ];
-    int gidx =
-        matrix.gidx_iter[ridx * matrix.info.row_stride + idx % matrix.info.row_stride];
+    int ridx = d_ridx[idx / matrix.info.row_stride];
+    if (!matrix.IsInRange(ridx)) {
+      continue;
+    }
+    int gidx = matrix.gidx_iter[(ridx - matrix.base_rowid) * matrix.info.row_stride
+                                + idx % matrix.info.row_stride];
    if (gidx != matrix.info.n_bins) {
      // If we are not using shared memory, accumulate the values directly into
      // global memory
@@ -428,8 +431,7 @@ __global__ void SharedMemHistKernel(xgboost::EllpackMatrix matrix,
  if (use_shared_memory_histograms) {
    // Write shared memory back to global memory
    __syncthreads();
-    for (auto i :
-         dh::BlockStrideRange(static_cast<size_t>(0), matrix.BinCount())) {
+    for (auto i : dh::BlockStrideRange(static_cast<size_t>(0), matrix.info.n_bins)) {
      dh::AtomicAddGpair(d_node_hist + i, smem_arr[i]);
    }
  }
@@ -440,6 +442,7 @@
struct GPUHistMakerDevice {
  int device_id;
  EllpackPageImpl* page;
+  BatchParam batch_param;

  dh::BulkAllocator ba;
@@ -481,14 +484,16 @@ struct GPUHistMakerDevice {
                     bst_uint _n_rows,
                     TrainParam _param,
                     uint32_t column_sampler_seed,
-                     uint32_t n_features)
+                     uint32_t n_features,
+                     BatchParam _batch_param)
      : device_id(_device_id),
        page(_page),
        n_rows(_n_rows),
        param(std::move(_param)),
        prediction_cache_initialised(false),
        column_sampler(column_sampler_seed),
-        interaction_constraints(param, n_features) {
+        interaction_constraints(param, n_features),
+        batch_param(_batch_param) {
    monitor.Init(std::string("GPUHistMakerDevice") + std::to_string(device_id));
  }
@@ -626,6 +631,14 @@ struct GPUHistMakerDevice {
    return std::vector<DeviceSplitCandidate>(result_all.begin(), result_all.end());
  }

+  // Build gradient histograms for a given node across all the batches in the DMatrix.
+  void BuildHistBatches(int nidx, DMatrix* p_fmat) {
+    for (auto& batch : p_fmat->GetBatches<EllpackPage>(batch_param)) {
+      page = batch.Impl();
+      BuildHist(nidx);
+    }
+  }
+
  void BuildHist(int nidx) {
    hist.AllocateHistogram(nidx);
    auto d_node_hist = hist.GetNodeHistogram(nidx);
@@ -636,7 +649,7 @@ struct GPUHistMakerDevice {
    const size_t smem_size =
        use_shared_memory_histograms
-            ? sizeof(GradientSumT) * page->matrix.BinCount()
+            ? sizeof(GradientSumT) * page->matrix.info.n_bins
            : 0;
    uint32_t items_per_thread = 8;
    uint32_t block_threads = 256;
@@ -673,7 +686,10 @@ struct GPUHistMakerDevice {
    row_partitioner->UpdatePosition(
        nidx, split_node.LeftChild(), split_node.RightChild(),
-        [=] __device__(bst_uint ridx) {
+        [=] __device__(size_t ridx) {
+          if (!d_matrix.IsInRange(ridx)) {
+            return RowPartitioner::kIgnoredTreePosition;
+          }
          // given a row index, returns the node id it belongs to
          bst_float cut_value =
              d_matrix.GetElement(ridx, split_node.SplitIndex());
@@ -693,35 +709,42 @@
  }

  // After tree update is finished, update the position of all training
-  // instances to their final leaf This information is used later to update the
+  // instances to their final leaf. This information is used later to update the
  // prediction cache
-  void FinalisePosition(RegTree* p_tree) {
+  void FinalisePosition(RegTree* p_tree, DMatrix* p_fmat) {
    const auto d_nodes =
        temp_memory.GetSpan<RegTree::Node>(p_tree->GetNodes().size());
    dh::safe_cuda(cudaMemcpy(d_nodes.data(), p_tree->GetNodes().data(),
                             d_nodes.size() * sizeof(RegTree::Node),
                             cudaMemcpyHostToDevice));
-    auto d_matrix = page->matrix;
-    row_partitioner->FinalisePosition(
-        [=] __device__(bst_uint ridx, int position) {
-          auto node = d_nodes[position];
-          while (!node.IsLeaf()) {
-            bst_float element = d_matrix.GetElement(ridx, node.SplitIndex());
-            // Missing value
-            if (isnan(element)) {
-              position = node.DefaultChild();
-            } else {
-              if (element <= node.SplitCond()) {
-                position = node.LeftChild();
-              } else {
-                position = node.RightChild();
-              }
-            }
-            node = d_nodes[position];
-          }
-          return position;
-        });
+    for (auto& batch : p_fmat->GetBatches<EllpackPage>(batch_param)) {
+      page = batch.Impl();
+      auto d_matrix = page->matrix;
+      row_partitioner->FinalisePosition(
+          [=] __device__(size_t row_id, int position) {
+            if (!d_matrix.IsInRange(row_id)) {
+              return RowPartitioner::kIgnoredTreePosition;
+            }
+            auto node = d_nodes[position];
+            while (!node.IsLeaf()) {
+              bst_float element = d_matrix.GetElement(row_id, node.SplitIndex());
+              // Missing value
+              if (isnan(element)) {
+                position = node.DefaultChild();
+              } else {
+                if (element <= node.SplitCond()) {
+                  position = node.LeftChild();
+                } else {
+                  position = node.RightChild();
+                }
+              }
+              node = d_nodes[position];
+            }
+            return position;
+          });
+    }
  }

  void UpdatePredictionCache(bst_float* out_preds_d) {
@@ -764,7 +787,7 @@
    reducer->AllReduceSum(
        reinterpret_cast<typename GradientSumT::ValueT*>(d_node_hist),
        reinterpret_cast<typename GradientSumT::ValueT*>(d_node_hist),
-        page->matrix.BinCount() * (sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT)));
+        page->matrix.info.n_bins * (sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT)));
    reducer->Synchronize();

    monitor.StopCuda("AllReduce");
@ -773,12 +796,10 @@ struct GPUHistMakerDevice {
/** /**
* \brief Build GPU local histograms for the left and right child of some parent node * \brief Build GPU local histograms for the left and right child of some parent node
*/ */
void BuildHistLeftRight(const ExpandEntry &candidate, int nidx_left, void BuildHistLeftRight(const ExpandEntry &candidate, int nidx_left, int nidx_right) {
int nidx_right, dh::AllReducer* reducer) {
auto build_hist_nidx = nidx_left; auto build_hist_nidx = nidx_left;
auto subtraction_trick_nidx = nidx_right; auto subtraction_trick_nidx = nidx_right;
// Decide whether to build the left histogram or right histogram // Decide whether to build the left histogram or right histogram
// Use sum of Hessian as a heuristic to select node with fewest training instances // Use sum of Hessian as a heuristic to select node with fewest training instances
bool fewer_right = candidate.split.right_sum.GetHess() < candidate.split.left_sum.GetHess(); bool fewer_right = candidate.split.right_sum.GetHess() < candidate.split.left_sum.GetHess();
@ -787,22 +808,50 @@ struct GPUHistMakerDevice {
} }
this->BuildHist(build_hist_nidx); this->BuildHist(build_hist_nidx);
this->AllReduceHist(build_hist_nidx, reducer);
// Check whether we can use the subtraction trick to calculate the other // Check whether we can use the subtraction trick to calculate the other
bool do_subtraction_trick = this->CanDoSubtractionTrick( bool do_subtraction_trick = this->CanDoSubtractionTrick(
candidate.nid, build_hist_nidx, subtraction_trick_nidx); candidate.nid, build_hist_nidx, subtraction_trick_nidx);
if (!do_subtraction_trick) {
// Calculate other histogram manually
this->BuildHist(subtraction_trick_nidx);
}
}
/**
* \brief AllReduce GPU histograms for the left and right child of some parent node.
*/
void ReduceHistLeftRight(const ExpandEntry& candidate,
int nidx_left,
int nidx_right,
dh::AllReducer* reducer) {
auto build_hist_nidx = nidx_left;
auto subtraction_trick_nidx = nidx_right;
// Decide whether to build the left histogram or right histogram
// Use sum of Hessian as a heuristic to select node with fewest training instances
bool fewer_right = candidate.split.right_sum.GetHess() < candidate.split.left_sum.GetHess();
if (fewer_right) {
std::swap(build_hist_nidx, subtraction_trick_nidx);
}
this->AllReduceHist(build_hist_nidx, reducer);
// Check whether we can use the subtraction trick to calculate the other
bool do_subtraction_trick = this->CanDoSubtractionTrick(
candidate.nid, build_hist_nidx, subtraction_trick_nidx);
if (do_subtraction_trick) { if (do_subtraction_trick) {
// Calculate other histogram using subtraction trick // Calculate other histogram using subtraction trick
this->SubtractionTrick(candidate.nid, build_hist_nidx, this->SubtractionTrick(candidate.nid, build_hist_nidx,
subtraction_trick_nidx); subtraction_trick_nidx);
} else { } else {
// Calculate other histogram manually // Calculate other histogram manually
this->BuildHist(subtraction_trick_nidx);
this->AllReduceHist(subtraction_trick_nidx, reducer); this->AllReduceHist(subtraction_trick_nidx, reducer);
} }
} }
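Splitting the old BuildHistLeftRight into a per-batch build step plus a separate ReduceHistLeftRight keeps the AllReduce out of the batch loop: local histograms are accumulated over every ELLPACK page first and reduced across workers once. The subtraction trick itself is unchanged; schematically (a sketch, not the CUDA implementation):

def sibling_histogram(parent_hist, built_child_hist):
    # Every row of the parent ends up in exactly one child, so the histogram of
    # the child that was not built is the element-wise difference.
    return [p - c for p, c in zip(parent_hist, built_child_hist)]

parent = [10.0, 4.0, 7.0]
left = [6.0, 1.0, 2.0]   # built explicitly (the child with the smaller Hessian sum)
right = sibling_histogram(parent, left)
assert right == [4.0, 3.0, 5.0]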
void ApplySplit(const ExpandEntry& candidate, RegTree* p_tree) { void ApplySplit(const ExpandEntry& candidate, RegTree* p_tree) {
RegTree& tree = *p_tree; RegTree& tree = *p_tree;
@ -839,7 +888,7 @@ struct GPUHistMakerDevice {
tree[candidate.nid].RightChild()); tree[candidate.nid].RightChild());
} }
void InitRoot(RegTree* p_tree, HostDeviceVector<GradientPair>* gpair_all, void InitRoot(RegTree* p_tree, HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat,
dh::AllReducer* reducer, int64_t num_columns) { dh::AllReducer* reducer, int64_t num_columns) {
constexpr int kRootNIdx = 0; constexpr int kRootNIdx = 0;
@ -855,7 +904,7 @@ struct GPUHistMakerDevice {
node_sum_gradients_d.data(), sizeof(GradientPair), node_sum_gradients_d.data(), sizeof(GradientPair),
cudaMemcpyDeviceToHost)); cudaMemcpyDeviceToHost));
this->BuildHist(kRootNIdx); this->BuildHistBatches(kRootNIdx, p_fmat);
this->AllReduceHist(kRootNIdx, reducer); this->AllReduceHist(kRootNIdx, reducer);
// Remember root stats // Remember root stats
@ -882,7 +931,7 @@ struct GPUHistMakerDevice {
monitor.StopCuda("Reset"); monitor.StopCuda("Reset");
monitor.StartCuda("InitRoot"); monitor.StartCuda("InitRoot");
this->InitRoot(p_tree, gpair_all, reducer, p_fmat->Info().num_col_); this->InitRoot(p_tree, gpair_all, p_fmat, reducer, p_fmat->Info().num_col_);
monitor.StopCuda("InitRoot"); monitor.StopCuda("InitRoot");
auto timestamp = qexpand->size(); auto timestamp = qexpand->size();
@@ -901,15 +950,21 @@
      int left_child_nidx = tree[candidate.nid].LeftChild();
      int right_child_nidx = tree[candidate.nid].RightChild();
      // Only create child entries if needed
-      if (ExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx),
-                                    num_leaves)) {
-        monitor.StartCuda("UpdatePosition");
-        this->UpdatePosition(candidate.nid, (*p_tree)[candidate.nid]);
-        monitor.StopCuda("UpdatePosition");
-        monitor.StartCuda("BuildHist");
-        this->BuildHistLeftRight(candidate, left_child_nidx, right_child_nidx, reducer);
-        monitor.StopCuda("BuildHist");
+      if (ExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx), num_leaves)) {
+        for (auto& batch : p_fmat->GetBatches<EllpackPage>(batch_param)) {
+          page = batch.Impl();
+          monitor.StartCuda("UpdatePosition");
+          this->UpdatePosition(candidate.nid, (*p_tree)[candidate.nid]);
+          monitor.StopCuda("UpdatePosition");
+          monitor.StartCuda("BuildHist");
+          this->BuildHistLeftRight(candidate, left_child_nidx, right_child_nidx);
+          monitor.StopCuda("BuildHist");
+        }
+        monitor.StartCuda("ReduceHist");
+        this->ReduceHistLeftRight(candidate, left_child_nidx, right_child_nidx, reducer);
+        monitor.StopCuda("ReduceHist");

        monitor.StartCuda("EvaluateSplits");
        auto splits = this->EvaluateSplits({left_child_nidx, right_child_nidx},
@ -926,7 +981,7 @@ struct GPUHistMakerDevice {
} }
monitor.StartCuda("FinalisePosition"); monitor.StartCuda("FinalisePosition");
this->FinalisePosition(p_tree); this->FinalisePosition(p_tree, p_fmat);
monitor.StopCuda("FinalisePosition"); monitor.StopCuda("FinalisePosition");
} }
}; };
@@ -1016,21 +1071,21 @@ class GPUHistMakerSpecialised {
    uint32_t column_sampling_seed = common::GlobalRandom()();
    rabit::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0);

-    // TODO(rongou): support multiple Ellpack pages.
-    EllpackPageImpl* page{};
-    for (auto& batch : dmat->GetBatches<EllpackPage>({device_,
-                                                      param_.max_bin,
-                                                      hist_maker_param_.gpu_batch_nrows})) {
-      page = batch.Impl();
-    }
+    BatchParam batch_param{
+      device_,
+      param_.max_bin,
+      hist_maker_param_.gpu_batch_nrows,
+      generic_param_->gpu_page_size
+    };
+    auto page = (*dmat->GetBatches<EllpackPage>(batch_param).begin()).Impl();

    dh::safe_cuda(cudaSetDevice(device_));
    maker.reset(new GPUHistMakerDevice<GradientSumT>(device_,
                                                     page,
                                                     info_->num_row_,
                                                     param_,
                                                     column_sampling_seed,
-                                                     info_->num_col_));
+                                                     info_->num_col_,
+                                                     batch_param));

    monitor_.StartCuda("InitHistogram");
    dh::safe_cuda(cudaSetDevice(device_));


@ -0,0 +1,87 @@
"""Generate synthetic data in LibSVM format."""
import argparse
import io
import time
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
RNG = np.random.RandomState(2019)
def generate_data(args):
"""Generates the data."""
print("Generating dataset: {} rows * {} columns".format(args.rows, args.columns))
print("Sparsity {}".format(args.sparsity))
print("{}/{} train/test split".format(1.0 - args.test_size, args.test_size))
tmp = time.time()
n_informative = args.columns * 7 // 10
n_redundant = args.columns // 10
n_repeated = args.columns // 10
print("n_informative: {}, n_redundant: {}, n_repeated: {}".format(n_informative, n_redundant,
n_repeated))
x, y = make_classification(n_samples=args.rows, n_features=args.columns,
n_informative=n_informative, n_redundant=n_redundant,
n_repeated=n_repeated, shuffle=False, random_state=RNG)
print("Generate Time: {} seconds".format(time.time() - tmp))
tmp = time.time()
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=args.test_size,
random_state=RNG, shuffle=False)
print("Train/Test Split Time: {} seconds".format(time.time() - tmp))
tmp = time.time()
write_file('train.libsvm', x_train, y_train, args.sparsity)
print("Write Train Time: {} seconds".format(time.time() - tmp))
tmp = time.time()
write_file('test.libsvm', x_test, y_test, args.sparsity)
print("Write Test Time: {} seconds".format(time.time() - tmp))
def write_file(filename, x_data, y_data, sparsity):
with open(filename, 'w') as f:
for x, y in zip(x_data, y_data):
write_line(f, x, y, sparsity)
def write_line(f, x, y, sparsity):
with io.StringIO() as line:
line.write(str(y))
for i, col in enumerate(x):
if 0.0 < sparsity < 1.0:
if RNG.uniform(0, 1) > sparsity:
write_feature(line, i, col)
else:
write_feature(line, i, col)
line.write('\n')
f.write(line.getvalue())
def write_feature(line, index, feature):
line.write(' ')
line.write(str(index))
line.write(':')
line.write(str(feature))
def main():
"""The main function.
Defines and parses command line arguments and calls the generator.
"""
parser = argparse.ArgumentParser()
parser.add_argument('--rows', type=int, default=1000000)
parser.add_argument('--columns', type=int, default=50)
parser.add_argument('--sparsity', type=float, default=0.0)
parser.add_argument('--test_size', type=float, default=0.01)
args = parser.parse_args()
generate_data(args)
if __name__ == '__main__':
main()
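Assuming the script is saved as, say, generate_libsvm.py (its path is not shown here), running it with the defaults above writes train.libsvm and test.libsvm to the working directory; pointing an external-memory DMatrix at those files, as in the earlier usage sketch, exercises the multi-page ELLPACK path.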


@@ -2,10 +2,11 @@
#include <dmlc/filesystem.h>
#include "../helpers.h"
+#include "../../../src/common/compressed_iterator.h"

namespace xgboost {

-TEST(GPUSparsePageDMatrix, EllpackPage) {
+TEST(SparsePageDMatrix, EllpackPage) {
  dmlc::TemporaryDirectory tempdir;
  const std::string tmp_file = tempdir.path + "/simple.libsvm";
  CreateSimpleTestData(tmp_file);
@@ -23,4 +24,162 @@ TEST(GPUSparsePageDMatrix, EllpackPage) {
  delete dmat;
}
TEST(SparsePageDMatrix, MultipleEllpackPages) {
dmlc::TemporaryDirectory tmpdir;
std::string filename = tmpdir.path + "/big.libsvm";
std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(12, 64, filename);
// Loop over the batches and count the records
int64_t batch_count = 0;
int64_t row_count = 0;
for (const auto& batch : dmat->GetBatches<EllpackPage>({0, 256, 0, 7UL})) {
EXPECT_LT(batch.Size(), dmat->Info().num_row_);
batch_count++;
row_count += batch.Size();
}
EXPECT_GE(batch_count, 2);
EXPECT_EQ(row_count, dmat->Info().num_row_);
EXPECT_TRUE(FileExists(filename + ".cache.ellpack.page"));
}
TEST(SparsePageDMatrix, EllpackPageContent) {
constexpr size_t kRows = 6;
constexpr size_t kCols = 2;
constexpr size_t kPageSize = 1;
// Create an in-memory DMatrix.
std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true));
// Create a DMatrix with multiple batches.
dmlc::TemporaryDirectory tmpdir;
std::unique_ptr<DMatrix>
dmat_ext(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
BatchParam param{0, 2, 0, 0};
auto impl = (*dmat->GetBatches<EllpackPage>(param).begin()).Impl();
EXPECT_EQ(impl->matrix.base_rowid, 0);
EXPECT_EQ(impl->matrix.n_rows, kRows);
EXPECT_FALSE(impl->matrix.info.is_dense);
EXPECT_EQ(impl->matrix.info.row_stride, 2);
EXPECT_EQ(impl->matrix.info.n_bins, 4);
auto impl_ext = (*dmat_ext->GetBatches<EllpackPage>(param).begin()).Impl();
EXPECT_EQ(impl_ext->matrix.base_rowid, 0);
EXPECT_EQ(impl_ext->matrix.n_rows, kRows);
EXPECT_FALSE(impl_ext->matrix.info.is_dense);
EXPECT_EQ(impl_ext->matrix.info.row_stride, 2);
EXPECT_EQ(impl_ext->matrix.info.n_bins, 4);
std::vector<common::CompressedByteT> buffer(impl->gidx_buffer.size());
std::vector<common::CompressedByteT> buffer_ext(impl_ext->gidx_buffer.size());
dh::CopyDeviceSpanToVector(&buffer, impl->gidx_buffer);
dh::CopyDeviceSpanToVector(&buffer_ext, impl_ext->gidx_buffer);
EXPECT_EQ(buffer, buffer_ext);
}
struct ReadRowFunction {
EllpackMatrix matrix;
int row;
bst_float* row_data_d;
ReadRowFunction(EllpackMatrix matrix, int row, bst_float* row_data_d)
: matrix(std::move(matrix)), row(row), row_data_d(row_data_d) {}
__device__ void operator()(size_t col) {
auto value = matrix.GetElement(row, col);
if (isnan(value)) {
value = -1;
}
row_data_d[col] = value;
}
};
TEST(SparsePageDMatrix, MultipleEllpackPageContent) {
constexpr size_t kRows = 6;
constexpr size_t kCols = 2;
constexpr int kMaxBins = 256;
constexpr size_t kPageSize = 1;
// Create an in-memory DMatrix.
std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true));
// Create a DMatrix with multiple batches.
dmlc::TemporaryDirectory tmpdir;
std::unique_ptr<DMatrix>
dmat_ext(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
BatchParam param{0, kMaxBins, 0, kPageSize};
auto impl = (*dmat->GetBatches<EllpackPage>(param).begin()).Impl();
EXPECT_EQ(impl->matrix.base_rowid, 0);
EXPECT_EQ(impl->matrix.n_rows, kRows);
size_t current_row = 0;
thrust::device_vector<bst_float> row_d(kCols);
thrust::device_vector<bst_float> row_ext_d(kCols);
std::vector<bst_float> row(kCols);
std::vector<bst_float> row_ext(kCols);
for (auto& page : dmat_ext->GetBatches<EllpackPage>(param)) {
auto impl_ext = page.Impl();
EXPECT_EQ(impl_ext->matrix.base_rowid, current_row);
for (size_t i = 0; i < impl_ext->Size(); i++) {
dh::LaunchN(0, kCols, ReadRowFunction(impl->matrix, current_row, row_d.data().get()));
thrust::copy(row_d.begin(), row_d.end(), row.begin());
dh::LaunchN(0, kCols, ReadRowFunction(impl_ext->matrix, current_row, row_ext_d.data().get()));
thrust::copy(row_ext_d.begin(), row_ext_d.end(), row_ext.begin());
EXPECT_EQ(row, row_ext);
current_row++;
}
}
}
TEST(SparsePageDMatrix, EllpackPageMultipleLoops) {
constexpr size_t kRows = 1024;
constexpr size_t kCols = 16;
constexpr int kMaxBins = 256;
constexpr size_t kPageSize = 4096;
// Create an in-memory DMatrix.
std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true));
// Create a DMatrix with multiple batches.
dmlc::TemporaryDirectory tmpdir;
std::unique_ptr<DMatrix>
dmat_ext(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
BatchParam param{0, kMaxBins, 0, kPageSize};
auto impl = (*dmat->GetBatches<EllpackPage>(param).begin()).Impl();
size_t current_row = 0;
for (auto& page : dmat_ext->GetBatches<EllpackPage>(param)) {
auto impl_ext = page.Impl();
EXPECT_EQ(impl_ext->matrix.base_rowid, current_row);
current_row += impl_ext->matrix.n_rows;
}
current_row = 0;
thrust::device_vector<bst_float> row_d(kCols);
thrust::device_vector<bst_float> row_ext_d(kCols);
std::vector<bst_float> row(kCols);
std::vector<bst_float> row_ext(kCols);
for (auto& page : dmat_ext->GetBatches<EllpackPage>(param)) {
auto impl_ext = page.Impl();
EXPECT_EQ(impl_ext->matrix.base_rowid, current_row);
for (size_t i = 0; i < impl_ext->Size(); i++) {
dh::LaunchN(0, kCols, ReadRowFunction(impl->matrix, current_row, row_d.data().get()));
thrust::copy(row_d.begin(), row_d.end(), row.begin());
dh::LaunchN(0, kCols, ReadRowFunction(impl_ext->matrix, current_row, row_ext_d.data().get()));
thrust::copy(row_ext_d.begin(), row_ext_d.end(), row_ext.begin());
EXPECT_EQ(row, row_ext) << "for row " << current_row;
current_row++;
}
}
}
} // namespace xgboost } // namespace xgboost


@@ -217,17 +217,17 @@ std::unique_ptr<DMatrix> CreateSparsePageDMatrixWithRC(
  } else {
    gen.reset(new std::mt19937(rdev()));
  }
+  std::uniform_int_distribution<size_t> label(0, 1);
  std::uniform_int_distribution<size_t> dis(1, n_cols);
  for (size_t i = 0; i < n_rows; ++i) {
    // Make sure that all cols are slotted in the first few rows; randomly distribute the
    // rest
    std::stringstream row_data;
-    fo << i;
    size_t j = 0;
    if (rem_cols > 0) {
      for (; j < std::min(static_cast<size_t>(rem_cols), cols_per_row); ++j) {
-        row_data << " " << (col_idx+j) << ":" << (col_idx+j+1)*10;
+        row_data << label(*gen) << " " << (col_idx+j) << ":" << (col_idx+j+1)*10*i;
      }
      rem_cols -= cols_per_row;
    } else {
@@ -235,7 +235,7 @@ std::unique_ptr<DMatrix> CreateSparsePageDMatrixWithRC(
      size_t ncols = dis(*gen);
      for (; j < ncols; ++j) {
        size_t fid = (col_idx+j) % n_cols;
-        row_data << " " << fid << ":" << (fid+1)*10;
+        row_data << label(*gen) << " " << fid << ":" << (fid+1)*10*i;
      }
    }
    col_idx += j;


@ -56,22 +56,6 @@ TEST(GpuHist, DeviceHistogram) {
}; };
} }
namespace {
class HistogramCutsWrapper : public common::HistogramCuts {
public:
using SuperT = common::HistogramCuts;
void SetValues(std::vector<float> cuts) {
SuperT::cut_values_ = cuts;
}
void SetPtrs(std::vector<uint32_t> ptrs) {
SuperT::cut_ptrs_ = ptrs;
}
void SetMins(std::vector<float> mins) {
SuperT::min_vals_ = mins;
}
};
} // anonymous namespace
std::vector<GradientPairPrecise> GetHostHistGpair() { std::vector<GradientPairPrecise> GetHostHistGpair() {
// 24 bins, 3 bins for each feature (column). // 24 bins, 3 bins for each feature (column).
std::vector<GradientPairPrecise> hist_gpair = { std::vector<GradientPairPrecise> hist_gpair = {
@@ -98,7 +82,8 @@ void TestBuildHist(bool use_shared_memory_histograms) {
  };
  param.Init(args);
  auto page = BuildEllpackPage(kNRows, kNCols);
-  GPUHistMakerDevice<GradientSumT> maker(0, page.get(), kNRows, param, kNCols, kNCols);
+  BatchParam batch_param{};
+  GPUHistMakerDevice<GradientSumT> maker(0, page.get(), kNRows, param, kNCols, kNCols, batch_param);
  maker.InitHistogram();

  xgboost::SimpleLCG gen;
@@ -199,7 +184,9 @@ TEST(GpuHist, EvaluateSplits) {
  // Initialize GPUHistMakerDevice
  auto page = BuildEllpackPage(kNRows, kNCols);
-  GPUHistMakerDevice<GradientPairPrecise> maker(0, page.get(), kNRows, param, kNCols, kNCols);
+  BatchParam batch_param{};
+  GPUHistMakerDevice<GradientPairPrecise>
+      maker(0, page.get(), kNRows, param, kNCols, kNCols, batch_param);

  // Initialize GPUHistMakerDevice::node_sum_gradients
  maker.node_sum_gradients = {{6.4f, 12.8f}};
@@ -332,21 +319,25 @@ int32_t TestMinSplitLoss(DMatrix* dmat, float gamma, HostDeviceVector<GradientPa
  return n_nodes;
}

-TEST(GpuHist, MinSplitLoss) {
-  constexpr size_t kRows = 32;
-  constexpr size_t kCols = 16;
-  constexpr float kSparsity = 0.6;
-  auto dmat = CreateDMatrix(kRows, kCols, kSparsity, 3);
+HostDeviceVector<GradientPair> GenerateRandomGradients(const size_t n_rows) {
  xgboost::SimpleLCG gen;
  xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f);
-  std::vector<GradientPair> h_gpair(kRows);
+  std::vector<GradientPair> h_gpair(n_rows);
  for (auto &gpair : h_gpair) {
    bst_float grad = dist(&gen);
    bst_float hess = dist(&gen);
    gpair = GradientPair(grad, hess);
  }
  HostDeviceVector<GradientPair> gpair(h_gpair);
+  return gpair;
+}
+
+TEST(GpuHist, MinSplitLoss) {
+  constexpr size_t kRows = 32;
+  constexpr size_t kCols = 16;
+  constexpr float kSparsity = 0.6;
+  auto dmat = CreateDMatrix(kRows, kCols, kSparsity, 3);
+  auto gpair = GenerateRandomGradients(kRows);

  {
    int32_t n_nodes = TestMinSplitLoss((*dmat).get(), 0.01, &gpair);
@@ -363,5 +354,75 @@ TEST(GpuHist, MinSplitLoss) {
  delete dmat;
}
void UpdateTree(HostDeviceVector<GradientPair>* gpair,
DMatrix* dmat,
size_t gpu_page_size,
RegTree* tree,
HostDeviceVector<bst_float>* preds) {
constexpr size_t kMaxBin = 2;
if (gpu_page_size > 0) {
// Loop over the batches and count the records
int64_t batch_count = 0;
int64_t row_count = 0;
for (const auto& batch : dmat->GetBatches<EllpackPage>({0, kMaxBin, 0, gpu_page_size})) {
EXPECT_LT(batch.Size(), dmat->Info().num_row_);
batch_count++;
row_count += batch.Size();
}
EXPECT_GE(batch_count, 2);
EXPECT_EQ(row_count, dmat->Info().num_row_);
}
Args args{
{"max_depth", "2"},
{"max_bin", std::to_string(kMaxBin)},
{"min_child_weight", "0.0"},
{"reg_alpha", "0"},
{"reg_lambda", "0"}
};
tree::GPUHistMakerSpecialised<GradientPairPrecise> hist_maker;
GenericParameter generic_param(CreateEmptyGenericParam(0));
generic_param.gpu_page_size = gpu_page_size;
hist_maker.Configure(args, &generic_param);
hist_maker.Update(gpair, dmat, {tree});
hist_maker.UpdatePredictionCache(dmat, preds);
}
TEST(GpuHist, ExternalMemory) {
constexpr size_t kRows = 6;
constexpr size_t kCols = 2;
constexpr size_t kPageSize = 1;
// Create an in-memory DMatrix.
std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true));
// Create a DMatrix with multiple batches.
dmlc::TemporaryDirectory tmpdir;
std::unique_ptr<DMatrix>
dmat_ext(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
auto gpair = GenerateRandomGradients(kRows);
// Build a tree using the in-memory DMatrix.
RegTree tree;
HostDeviceVector<bst_float> preds(kRows, 0.0, 0);
UpdateTree(&gpair, dmat.get(), 0, &tree, &preds);
// Build another tree using multiple ELLPACK pages.
RegTree tree_ext;
HostDeviceVector<bst_float> preds_ext(kRows, 0.0, 0);
UpdateTree(&gpair, dmat_ext.get(), kPageSize, &tree_ext, &preds_ext);
// Make sure the predictions are the same.
auto preds_h = preds.ConstHostVector();
auto preds_ext_h = preds_ext.ConstHostVector();
for (int i = 0; i < kRows; i++) {
ASSERT_FLOAT_EQ(preds_h[i], preds_ext_h[i]);
}
}
} // namespace tree } // namespace tree
} // namespace xgboost } // namespace xgboost


@@ -19,17 +19,19 @@ def assert_gpu_results(cpu_results, gpu_results):
datasets = ["Boston", "Cancer", "Digits", "Sparse regression",
            "Sparse regression with weights", "Small weights regression"]

+test_param = parameter_combinations({
+    'gpu_id': [0],
+    'max_depth': [2, 8],
+    'max_leaves': [255, 4],
+    'max_bin': [2, 256],
+    'grow_policy': ['lossguide'],
+    'single_precision_histogram': [True],
+    'min_child_weight': [0],
+    'lambda': [0]})
+

class TestGPU(unittest.TestCase):
    def test_gpu_hist(self):
-        test_param = parameter_combinations({'gpu_id': [0],
-                                             'max_depth': [2, 8],
-                                             'max_leaves': [255, 4],
-                                             'max_bin': [2, 256],
-                                             'grow_policy': ['lossguide']})
-        test_param.append({'single_precision_histogram': True})
-        test_param.append({'min_child_weight': 0,
-                           'lambda': 0})
        for param in test_param:
            param['tree_method'] = 'gpu_hist'
            gpu_results = run_suite(param, select_datasets=datasets)
@ -38,6 +40,19 @@ class TestGPU(unittest.TestCase):
cpu_results = run_suite(param, select_datasets=datasets) cpu_results = run_suite(param, select_datasets=datasets)
assert_gpu_results(cpu_results, gpu_results) assert_gpu_results(cpu_results, gpu_results)
# NOTE(rongou): Because the `Boston` dataset is too small, this only tests external memory mode
# with a single page. To test multiple pages, set DMatrix::kPageSize to, say, 1024.
def test_external_memory(self):
for param in reversed(test_param):
param['tree_method'] = 'gpu_hist'
param['gpu_page_size'] = 1024
gpu_results = run_suite(param, select_datasets=["Boston"])
assert_results_non_increasing(gpu_results, 1e-2)
ext_mem_results = run_suite(param, select_datasets=["Boston External Memory"])
assert_results_non_increasing(ext_mem_results, 1e-2)
assert_gpu_results(gpu_results, ext_mem_results)
break
def test_with_empty_dmatrix(self): def test_with_empty_dmatrix(self):
# FIXME(trivialfis): This should be done with all updaters # FIXME(trivialfis): This should be done with all updaters
kRows = 0 kRows = 0