xgboost/src/data/ellpack_page.cu
Jiaming Yuan 7663de956c
Run training with empty DMatrix. (#4990)
This makes GPU Hist robust in a distributed environment, as some workers might not
be associated with any data in either training or evaluation.

* Disable rabit mock test for now: See #5012.

* Disable dask-cudf test at prediction for now: See #5003.

* Launch dask jobs for all workers, even if some of them might not have any data.
* Check 0 rows in elementwise evaluation metrics.

   Using AUC and AUC-PR still throws an error.  See #4663 for a robust fix.

* Add tests for edge cases.
* Add `LaunchKernel` wrapper handling zero-sized grids (a minimal sketch follows this list).
* Move some parts of allreducer into a cu file.
* Don't validate feature names when the booster is empty.
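
  A minimal sketch of the zero-grid-safe launch wrapper referenced above (illustrative only;
  the actual `dh::LaunchKernel` used later in this file may differ in signature and error
  handling):

      template <typename Kernel, typename... Args>
      void LaunchKernelSketch(dim3 grid, dim3 block, Kernel kernel, Args... args) {
        // A grid with any zero dimension has no work to do; launching it would be reported
        // as an invalid configuration error, so return early instead.
        if (grid.x == 0 || grid.y == 0 || grid.z == 0) {
          return;
        }
        kernel<<<grid, block>>>(args...);
      }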

* Sync the number of columns in DMatrix.

  num_feature is required to be the same across all workers in data split mode.
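
  A minimal sketch of the idea, assuming rabit has been initialized (the actual call site in
  the PR may differ; `info` is an illustrative MetaInfo instance):

      // Workers without data report 0 columns and adopt the maximum across all workers.
      uint64_t num_col = info.num_col_;
      rabit::Allreduce<rabit::op::Max>(&num_col, 1);
      info.num_col_ = num_col;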

* Filtering in the dask interface now by default syncs all boosters that are not empty,
  instead of using rank 0.

* Fix Jenkins' GPU tests.

* Install dask-cuda from source in Jenkins' test.

  Now all tests are actually running.

* Restore GPU Hist tree synchronization test.

* Check UUID of running devices.

  The check is only performed on CUDA version >= 10.x, as 9.x doesn't have the UUID field.
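
  A minimal sketch of such a check (illustrative; not the exact helper added by the PR):

      #include <cstring>
      #include <cuda_runtime.h>

      // Compare device UUIDs to tell whether two ordinals refer to distinct physical GPUs.
      // cudaDeviceProp::uuid only exists from CUDA 10.0 onwards, so older toolkits skip it.
      bool DistinctDevicesSketch(int a, int b) {
      #if CUDART_VERSION >= 10000
        cudaDeviceProp prop_a, prop_b;
        cudaGetDeviceProperties(&prop_a, a);
        cudaGetDeviceProperties(&prop_b, b);
        return std::memcmp(&prop_a.uuid, &prop_b.uuid, sizeof(cudaUUID_t)) != 0;
      #else
        return true;  // UUID is unavailable on CUDA 9.x; assume the devices are distinct.
      #endif
      }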

* Fix CMake policy and project variables.

  Use xgboost_SOURCE_DIR uniformly and add a policy for CMake >= 3.13.

* Fix copying data to the CPU.

* Fix race condition in cpu predictor.

* Fix duplicated DMatrix construction.

* Don't download extra nccl in CI script.
2019-11-06 16:13:13 +08:00


/*!
 * Copyright 2019 XGBoost contributors
 */
#include <xgboost/data.h>

#include "./ellpack_page.cuh"
#include "../common/hist_util.h"
#include "../common/random.h"

namespace xgboost {

EllpackPage::EllpackPage() : impl_{new EllpackPageImpl()} {}

EllpackPage::EllpackPage(DMatrix* dmat, const BatchParam& param)
    : impl_{new EllpackPageImpl(dmat, param)} {}

EllpackPage::~EllpackPage() = default;

size_t EllpackPage::Size() const {
  return impl_->Size();
}

void EllpackPage::SetBaseRowId(size_t row_id) {
  impl_->SetBaseRowId(row_id);
}
// Bin each input data entry, store the bin indices in compressed form.
__global__ void CompressBinEllpackKernel(
    common::CompressedBufferWriter wr,
    common::CompressedByteT* __restrict__ buffer,  // gidx_buffer
    const size_t* __restrict__ row_ptrs,           // row offset of input data
    const Entry* __restrict__ entries,             // One batch of input data
    const float* __restrict__ cuts,                // HistogramCuts::cut
    const uint32_t* __restrict__ cut_rows,         // HistogramCuts::row_ptrs
    size_t base_row,                               // batch_row_begin
    size_t n_rows,
    size_t row_stride,
    unsigned int null_gidx_value) {
  size_t irow = threadIdx.x + blockIdx.x * blockDim.x;
  int ifeature = threadIdx.y + blockIdx.y * blockDim.y;
  if (irow >= n_rows || ifeature >= row_stride) {
    return;
  }
  int row_length = static_cast<int>(row_ptrs[irow + 1] - row_ptrs[irow]);
  unsigned int bin = null_gidx_value;
  if (ifeature < row_length) {
    Entry entry = entries[row_ptrs[irow] - row_ptrs[0] + ifeature];
    int feature = entry.index;
    float fvalue = entry.fvalue;
    // {feature_cuts, ncuts} forms the array of cuts of `feature'.
    const float* feature_cuts = &cuts[cut_rows[feature]];
    int ncuts = cut_rows[feature + 1] - cut_rows[feature];
    // Assign the bin for the current entry: the smallest bin such that
    // fvalue < feature_cuts[bin].
    bin = dh::UpperBound(feature_cuts, ncuts, fvalue);
    if (bin >= ncuts) {
      bin = ncuts - 1;
    }
    // Add the number of bins in previous features.
    bin += cut_rows[feature];
  }
  // Write to the gidx buffer.
  wr.AtomicWriteSymbol(buffer, bin, (irow + base_row) * row_stride + ifeature);
}
// Construct an ELLPACK matrix in memory.
EllpackPageImpl::EllpackPageImpl(DMatrix* dmat, const BatchParam& param) {
  monitor_.Init("ellpack_page");
  dh::safe_cuda(cudaSetDevice(param.gpu_id));

  monitor_.StartCuda("Quantiles");
  // Create the quantile sketches for the dmatrix and initialize HistogramCuts.
  common::HistogramCuts hmat;
  size_t row_stride =
      common::DeviceSketch(param.gpu_id, param.max_bin, param.gpu_batch_nrows, dmat, &hmat);
  monitor_.StopCuda("Quantiles");

  monitor_.StartCuda("InitEllpackInfo");
  InitInfo(param.gpu_id, dmat->IsDense(), row_stride, hmat);
  monitor_.StopCuda("InitEllpackInfo");

  monitor_.StartCuda("InitCompressedData");
  InitCompressedData(param.gpu_id, dmat->Info().num_row_);
  monitor_.StopCuda("InitCompressedData");

  monitor_.StartCuda("BinningCompression");
  DeviceHistogramBuilderState hist_builder_row_state(dmat->Info().num_row_);
  for (const auto& batch : dmat->GetBatches<SparsePage>()) {
    hist_builder_row_state.BeginBatch(batch);
    CreateHistIndices(param.gpu_id, batch, hist_builder_row_state.GetRowStateOnDevice());
    hist_builder_row_state.EndBatch();
  }
  monitor_.StopCuda("BinningCompression");
}
// Construct an EllpackInfo based on histogram cuts of features.
EllpackInfo::EllpackInfo(int device,
                         bool is_dense,
                         size_t row_stride,
                         const common::HistogramCuts& hmat,
                         dh::BulkAllocator* ba)
    : is_dense(is_dense), row_stride(row_stride), n_bins(hmat.Ptrs().back()) {
  ba->Allocate(device,
               &feature_segments, hmat.Ptrs().size(),
               &gidx_fvalue_map, hmat.Values().size(),
               &min_fvalue, hmat.MinValues().size());
  dh::CopyVectorToDeviceSpan(gidx_fvalue_map, hmat.Values());
  dh::CopyVectorToDeviceSpan(min_fvalue, hmat.MinValues());
  dh::CopyVectorToDeviceSpan(feature_segments, hmat.Ptrs());
}
// Initialize the EllpackInfo for this page.
void EllpackPageImpl::InitInfo(int device,
                               bool is_dense,
                               size_t row_stride,
                               const common::HistogramCuts& hmat) {
  matrix.info = EllpackInfo(device, is_dense, row_stride, hmat, &ba_);
}
// Initialize the buffer to store compressed features.
void EllpackPageImpl::InitCompressedData(int device, size_t num_rows) {
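  // One symbol per histogram bin, plus one extra symbol (null_gidx_value) for missing
  // entries within a row.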
  size_t num_symbols = matrix.info.n_bins + 1;

  // Required buffer size for storing data matrix in ELLPack format.
  size_t compressed_size_bytes = common::CompressedBufferWriter::CalculateBufferSize(
      matrix.info.row_stride * num_rows, num_symbols);
  ba_.Allocate(device, &gidx_buffer, compressed_size_bytes);
  thrust::fill(
      thrust::device_pointer_cast(gidx_buffer.data()),
      thrust::device_pointer_cast(gidx_buffer.data() + gidx_buffer.size()), 0);

  matrix.gidx_iter = common::CompressedIterator<uint32_t>(gidx_buffer.data(), num_symbols);
}
// Compress a CSR page into ELLPACK.
void EllpackPageImpl::CreateHistIndices(int device,
                                        const SparsePage& row_batch,
                                        const RowStateOnDevice& device_row_state) {
  // Have any rows in this batch been allocated to this device?
  if (!device_row_state.rows_to_process_from_batch) return;

  unsigned int null_gidx_value = matrix.info.n_bins;
  size_t row_stride = matrix.info.row_stride;
  const auto& offset_vec = row_batch.offset.ConstHostVector();
  int num_symbols = matrix.info.n_bins + 1;

  // Bin and compress entries in batches of rows.
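  // Cap the number of rows per sub-batch so that the entries staged on the device use at
  // most a fraction (roughly 1/16, assuming row_stride entries per row) of device memory.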
  size_t gpu_batch_nrows = std::min(
      dh::TotalMemory(device) / (16 * row_stride * sizeof(Entry)),
      static_cast<size_t>(device_row_state.rows_to_process_from_batch));
  const std::vector<Entry>& data_vec = row_batch.data.ConstHostVector();

  size_t gpu_nbatches = common::DivRoundUp(device_row_state.rows_to_process_from_batch,
                                           gpu_batch_nrows);
  for (size_t gpu_batch = 0; gpu_batch < gpu_nbatches; ++gpu_batch) {
    size_t batch_row_begin = gpu_batch * gpu_batch_nrows;
    size_t batch_row_end = (gpu_batch + 1) * gpu_batch_nrows;
    if (batch_row_end > device_row_state.rows_to_process_from_batch) {
      batch_row_end = device_row_state.rows_to_process_from_batch;
    }
    size_t batch_nrows = batch_row_end - batch_row_begin;
    const auto ent_cnt_begin =
        offset_vec[device_row_state.row_offset_in_current_batch + batch_row_begin];
    const auto ent_cnt_end =
        offset_vec[device_row_state.row_offset_in_current_batch + batch_row_end];

    /*! \brief row offset in SparsePage (the input data). */
    dh::device_vector<size_t> row_ptrs(batch_nrows + 1);
    thrust::copy(
        offset_vec.data() + device_row_state.row_offset_in_current_batch + batch_row_begin,
        offset_vec.data() + device_row_state.row_offset_in_current_batch + batch_row_end + 1,
        row_ptrs.begin());

    // Number of entries in this batch.
    size_t n_entries = ent_cnt_end - ent_cnt_begin;
    dh::device_vector<Entry> entries_d(n_entries);
    // Copy data entries to device.
    dh::safe_cuda(cudaMemcpy(entries_d.data().get(),
                             data_vec.data() + ent_cnt_begin,
                             n_entries * sizeof(Entry),
                             cudaMemcpyDefault));

    const dim3 block3(32, 8, 1);  // 256 threads
    const dim3 grid3(common::DivRoundUp(batch_nrows, block3.x),
                     common::DivRoundUp(row_stride, block3.y),
                     1);
    dh::LaunchKernel {grid3, block3} (
        CompressBinEllpackKernel,
        common::CompressedBufferWriter(num_symbols),
        gidx_buffer.data(),
        row_ptrs.data().get(),
        entries_d.data().get(),
        matrix.info.gidx_fvalue_map.data(),
        matrix.info.feature_segments.data(),
        device_row_state.total_rows_processed + batch_row_begin,
        batch_nrows,
        row_stride,
        null_gidx_value);
  }
}
// Return the number of rows contained in this page.
size_t EllpackPageImpl::Size() const {
  return n_rows;
}

// Clear the current page.
void EllpackPageImpl::Clear() {
  ba_.Clear();
  gidx_buffer = {};
  idx_buffer.clear();
  n_rows = 0;
}
// Push a CSR page to the current page.
//
// The CSR page is first compressed into ELLPACK on the device; the compressed buffer is then
// copied to the host and appended to the existing host vector.
void EllpackPageImpl::Push(int device, const SparsePage& batch) {
  monitor_.StartCuda("InitCompressedData");
  InitCompressedData(device, batch.Size());
  monitor_.StopCuda("InitCompressedData");

  monitor_.StartCuda("BinningCompression");
  DeviceHistogramBuilderState hist_builder_row_state(batch.Size());
  hist_builder_row_state.BeginBatch(batch);
  CreateHistIndices(device, batch, hist_builder_row_state.GetRowStateOnDevice());
  hist_builder_row_state.EndBatch();
  monitor_.StopCuda("BinningCompression");

  monitor_.StartCuda("CopyDeviceToHost");
  std::vector<common::CompressedByteT> buffer(gidx_buffer.size());
  dh::CopyDeviceSpanToVector(&buffer, gidx_buffer);
  int offset = 0;
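  // After the first batch, skip the leading padding bytes added by CompressedBufferWriter so
  // they are not duplicated in the concatenated host buffer.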
  if (!idx_buffer.empty()) {
    offset = ::xgboost::common::detail::kPadding;
  }
  idx_buffer.reserve(idx_buffer.size() + buffer.size() - offset);
  idx_buffer.insert(idx_buffer.end(), buffer.begin() + offset, buffer.end());
  ba_.Clear();
  gidx_buffer = {};
  monitor_.StopCuda("CopyDeviceToHost");

  n_rows += batch.Size();
}
// Return the memory cost for storing the compressed features.
size_t EllpackPageImpl::MemCostBytes() const {
  return idx_buffer.size() * sizeof(common::CompressedByteT);
}

// Copy the compressed features to GPU.
void EllpackPageImpl::InitDevice(int device, EllpackInfo info) {
  if (device_initialized_) return;

  monitor_.StartCuda("CopyPageToDevice");
  dh::safe_cuda(cudaSetDevice(device));
  gidx_buffer = {};
  ba_.Allocate(device, &gidx_buffer, idx_buffer.size());
  dh::CopyVectorToDeviceSpan(gidx_buffer, idx_buffer);

  matrix.info = info;
  matrix.gidx_iter = common::CompressedIterator<uint32_t>(gidx_buffer.data(), info.n_bins + 1);
  monitor_.StopCuda("CopyPageToDevice");

  device_initialized_ = true;
}

}  // namespace xgboost