GPU binning and compression. (#3319)

* GPU binning and compression.

- binning and index compression are done inside the DeviceShard constructor
- in case of a DMatrix with multiple row batches, it is first converted into a single row batch
This commit is contained in:
Andy Adinets 2018-06-05 07:15:13 +02:00 committed by Rory Mitchell
parent 3f7696ff53
commit 286dccb8e8
10 changed files with 302 additions and 67 deletions

View File

@ -73,3 +73,4 @@ List of Contributors
* [Gideon Whitehead](https://github.com/gaw89)
* [Yi-Lin Juang](https://github.com/frankyjuang)
* [Andrew Hannigan](https://github.com/andrewhannigan)
* [Andy Adinets](https://github.com/canonizer)

View File

@ -126,15 +126,15 @@ struct SparseBatch {
/*! \brief feature value */
bst_float fvalue;
/*! \brief default constructor */
Entry() = default;
XGBOOST_DEVICE Entry() {}
/*!
* \brief constructor with index and value
* \param index The feature or row index.
* \param fvalue THe feature value.
*/
Entry(bst_uint index, bst_float fvalue) : index(index), fvalue(fvalue) {}
XGBOOST_DEVICE Entry(bst_uint index, bst_float fvalue) : index(index), fvalue(fvalue) {}
/*! \brief reversely compare feature values */
inline static bool CmpValue(const Entry& a, const Entry& b) {
XGBOOST_DEVICE inline static bool CmpValue(const Entry& a, const Entry& b) {
return a.fvalue < b.fvalue;
}
};

View File

@ -8,6 +8,10 @@
#include <cstddef>
#include <algorithm>
#ifdef __CUDACC__
#include "device_helpers.cuh"
#endif
namespace xgboost {
namespace common {
@ -96,6 +100,23 @@ class CompressedBufferWriter {
}
}
}
#ifdef __CUDACC__
__device__ void AtomicWriteSymbol
(CompressedByteT* buffer, uint64_t symbol, size_t offset) {
size_t ibit_start = offset * symbol_bits_;
size_t ibit_end = (offset + 1) * symbol_bits_ - 1;
size_t ibyte_start = ibit_start / 8, ibyte_end = ibit_end / 8;
symbol <<= 7 - ibit_end % 8;
for (ptrdiff_t ibyte = ibyte_end; ibyte >= (ptrdiff_t)ibyte_start; --ibyte) {
dh::AtomicOrByte(reinterpret_cast<unsigned int*>(buffer + detail::kPadding),
ibyte, symbol & 0xff);
symbol >>= 8;
}
}
#endif
template <typename IterT>
void Write(CompressedByteT *buffer, IterT input_begin, IterT input_end) {
uint64_t tmp = 0;

View File

@ -122,6 +122,14 @@ inline size_t AvailableMemory(int device_idx) {
return device_free;
}
inline size_t TotalMemory(int device_idx) {
size_t device_free = 0;
size_t device_total = 0;
safe_cuda(cudaSetDevice(device_idx));
dh::safe_cuda(cudaMemGetInfo(&device_free, &device_total));
return device_total;
}
/**
* \fn inline int max_shared_memory(int device_idx)
*
@ -155,6 +163,12 @@ inline void CheckComputeCapability() {
}
}
DEV_INLINE void AtomicOrByte(unsigned int* __restrict__ buffer, size_t ibyte, unsigned char b) {
atomicOr(&buffer[ibyte / sizeof(unsigned int)], (unsigned int)b << (ibyte % (sizeof(unsigned int)) * 8));
}
/*
* Range iterator
*/

View File

@ -183,6 +183,7 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
.set_lower_bound(-1)
.set_default(1)
.describe("Number of GPUs to use for multi-gpu algorithms: -1=use all GPUs");
// add alias of parameters
DMLC_DECLARE_ALIAS(reg_lambda, lambda);
DMLC_DECLARE_ALIAS(reg_alpha, alpha);

View File

@ -2,6 +2,9 @@
* Copyright 2017 XGBoost contributors
*/
#include <thrust/execution_policy.h>
#include <thrust/functional.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/reduce.h>
#include <thrust/sequence.h>
#include <xgboost/tree_updater.h>
@ -224,6 +227,53 @@ struct CalcWeightTrainParam {
learning_rate(p.learning_rate) {}
};
// index of the first element in cuts greater than v, or n if none;
// cuts are ordered, and binary search is used
__device__ int upper_bound(const float* __restrict__ cuts, int n, float v) {
if (n == 0)
return 0;
if (cuts[n - 1] <= v)
return n;
if (cuts[0] > v)
return 0;
int left = 0, right = n - 1;
while (right - left > 1) {
int middle = left + (right - left) / 2;
if (cuts[middle] > v)
right = middle;
else
left = middle;
}
return right;
}
__global__ void compress_bin_ellpack_k
(common::CompressedBufferWriter wr, common::CompressedByteT* __restrict__ buffer,
const size_t* __restrict__ row_ptrs,
const RowBatch::Entry* __restrict__ entries,
const float* __restrict__ cuts, const size_t* __restrict__ cut_rows,
size_t base_row, size_t n_rows, size_t row_ptr_begin, size_t row_stride,
unsigned int null_gidx_value) {
size_t irow = threadIdx.x + size_t(blockIdx.x) * blockDim.x;
int ifeature = threadIdx.y + blockIdx.y * blockDim.y;
if (irow >= n_rows || ifeature >= row_stride)
return;
int row_size = static_cast<int>(row_ptrs[irow + 1] - row_ptrs[irow]);
unsigned int bin = null_gidx_value;
if (ifeature < row_size) {
RowBatch::Entry entry = entries[row_ptrs[irow] - row_ptr_begin + ifeature];
int feature = entry.index;
float fvalue = entry.fvalue;
const float *feature_cuts = &cuts[cut_rows[feature]];
int ncuts = cut_rows[feature + 1] - cut_rows[feature];
bin = upper_bound(feature_cuts, ncuts, fvalue);
if (bin >= ncuts)
bin = ncuts - 1;
bin += cut_rows[feature];
}
wr.AtomicWriteSymbol(buffer, bin, (irow + base_row) * row_stride + ifeature);
}
// Manage memory for a single GPU
struct DeviceShard {
struct Segment {
@ -271,74 +321,117 @@ struct DeviceShard {
dh::CubMemory temp_memory;
DeviceShard(int device_idx, int normalised_device_idx,
const common::GHistIndexMatrix& gmat, bst_uint row_begin,
bst_uint row_end, int n_bins, TrainParam param)
: device_idx(device_idx),
normalised_device_idx(normalised_device_idx),
row_begin_idx(row_begin),
row_end_idx(row_end),
n_rows(row_end - row_begin),
n_bins(n_bins),
null_gidx_value(n_bins),
param(param),
prediction_cache_initialised(false) {
// Convert to ELLPACK matrix representation
int max_elements_row = 0;
for (auto i = row_begin; i < row_end; i++) {
max_elements_row =
(std::max)(max_elements_row,
static_cast<int>(gmat.row_ptr[i + 1] - gmat.row_ptr[i]));
}
row_stride = max_elements_row;
std::vector<int> ellpack_matrix(row_stride * n_rows, null_gidx_value);
bst_uint row_begin, bst_uint row_end, int n_bins, TrainParam param)
: device_idx(device_idx),
normalised_device_idx(normalised_device_idx),
row_begin_idx(row_begin),
row_end_idx(row_end),
n_rows(row_end - row_begin),
n_bins(n_bins),
null_gidx_value(n_bins),
param(param),
prediction_cache_initialised(false) {}
for (auto i = row_begin; i < row_end; i++) {
int row_count = 0;
for (auto j = gmat.row_ptr[i]; j < gmat.row_ptr[i + 1]; j++) {
ellpack_matrix[(i - row_begin) * row_stride + row_count] =
gmat.index[j];
row_count++;
}
}
void Init(const common::HistCutMatrix& hmat, const RowBatch& row_batch) {
// copy cuts to the GPU
dh::safe_cuda(cudaSetDevice(device_idx));
thrust::device_vector<float> cuts_d(hmat.cut);
thrust::device_vector<size_t> cut_row_ptrs_d(hmat.row_ptr);
// Allocate
// find the maximum row size
thrust::device_vector<size_t> row_ptr_d(
row_batch.ind_ptr + row_begin_idx, row_batch.ind_ptr + row_end_idx + 1);
auto row_iter = row_ptr_d.begin();
auto get_size = [=] __device__(size_t row) {
return row_iter[row + 1] - row_iter[row];
}; // NOLINT
auto counting = thrust::make_counting_iterator(size_t(0));
using TransformT = thrust::transform_iterator<decltype(get_size),
decltype(counting), size_t>;
TransformT row_size_iter = TransformT(counting, get_size);
row_stride = thrust::reduce(row_size_iter, row_size_iter + n_rows, 0,
thrust::maximum<size_t>());
// allocate compressed bin data
int num_symbols = n_bins + 1;
size_t compressed_size_bytes =
common::CompressedBufferWriter::CalculateBufferSize(
ellpack_matrix.size(), num_symbols);
common::CompressedBufferWriter::CalculateBufferSize(row_stride * n_rows,
num_symbols);
CHECK(!(param.max_leaves == 0 && param.max_depth == 0))
<< "Max leaves and max depth cannot both be unconstrained for "
"gpu_hist.";
ba.Allocate(device_idx, param.silent, &gidx_buffer, compressed_size_bytes);
gidx_buffer.Fill(0);
// bin and compress entries in batches of rows
// use no more than 1/16th of GPU memory per batch
size_t gpu_batch_nrows = dh::TotalMemory(device_idx) /
(16 * row_stride * sizeof(RowBatch::Entry));
if (gpu_batch_nrows > n_rows) {
gpu_batch_nrows = n_rows;
}
thrust::device_vector<RowBatch::Entry> entries_d(gpu_batch_nrows * row_stride);
size_t gpu_nbatches = dh::DivRoundUp(n_rows, gpu_batch_nrows);
for (size_t gpu_batch = 0; gpu_batch < gpu_nbatches; ++gpu_batch) {
size_t batch_row_begin = gpu_batch * gpu_batch_nrows;
size_t batch_row_end = (gpu_batch + 1) * gpu_batch_nrows;
if (batch_row_end > n_rows) {
batch_row_end = n_rows;
}
size_t batch_nrows = batch_row_end - batch_row_begin;
size_t n_entries =
row_batch.ind_ptr[row_begin_idx + batch_row_end] -
row_batch.ind_ptr[row_begin_idx + batch_row_begin];
dh::safe_cuda
(cudaMemcpy
(entries_d.data().get(),
&row_batch.data_ptr[row_batch.ind_ptr[row_begin_idx + batch_row_begin]],
n_entries * sizeof(RowBatch::Entry), cudaMemcpyDefault));
dim3 block3(32, 8, 1);
dim3 grid3(dh::DivRoundUp(n_rows, block3.x),
dh::DivRoundUp(row_stride, block3.y), 1);
compress_bin_ellpack_k<<<grid3, block3>>>
(common::CompressedBufferWriter(num_symbols), gidx_buffer.Data(),
row_ptr_d.data().get() + batch_row_begin,
entries_d.data().get(), cuts_d.data().get(), cut_row_ptrs_d.data().get(),
batch_row_begin, batch_nrows,
row_batch.ind_ptr[row_begin_idx + batch_row_begin],
row_stride, null_gidx_value);
dh::safe_cuda(cudaGetLastError());
dh::safe_cuda(cudaDeviceSynchronize());
}
// free the memory that is no longer needed
row_ptr_d.resize(0);
row_ptr_d.shrink_to_fit();
entries_d.resize(0);
entries_d.shrink_to_fit();
gidx = common::CompressedIterator<uint32_t>(gidx_buffer.Data(), num_symbols);
// allocate the rest
int max_nodes =
param.max_leaves > 0 ? param.max_leaves * 2 : MaxNodesDepth(param.max_depth);
ba.Allocate(device_idx, param.silent, &gidx_buffer, compressed_size_bytes,
ba.Allocate(device_idx, param.silent,
&gpair, n_rows, &ridx, n_rows, &position, n_rows,
&prediction_cache, n_rows, &node_sum_gradients_d, max_nodes,
&feature_segments, gmat.cut->row_ptr.size(), &gidx_fvalue_map,
gmat.cut->cut.size(), &min_fvalue, gmat.cut->min_val.size(),
&feature_segments, hmat.row_ptr.size(), &gidx_fvalue_map,
hmat.cut.size(), &min_fvalue, hmat.min_val.size(),
&monotone_constraints, param.monotone_constraints.size());
gidx_fvalue_map = gmat.cut->cut;
min_fvalue = gmat.cut->min_val;
feature_segments = gmat.cut->row_ptr;
gidx_fvalue_map = hmat.cut;
min_fvalue = hmat.min_val;
feature_segments = hmat.row_ptr;
monotone_constraints = param.monotone_constraints;
node_sum_gradients.resize(max_nodes);
ridx_segments.resize(max_nodes);
// Compress gidx
common::CompressedBufferWriter cbw(num_symbols);
std::vector<common::CompressedByteT> host_buffer(gidx_buffer.Size());
cbw.Write(host_buffer.data(), ellpack_matrix.begin(), ellpack_matrix.end());
gidx_buffer = host_buffer;
gidx =
common::CompressedIterator<uint32_t>(gidx_buffer.Data(), num_symbols);
common::CompressedIterator<uint32_t> ci_host(host_buffer.data(),
num_symbols);
// Init histogram
hist.Init(device_idx, max_nodes, gmat.cut->row_ptr.back(), param.silent);
hist.Init(device_idx, max_nodes, hmat.row_ptr.back(), param.silent);
dh::safe_cuda(cudaMallocHost(&tmp_pinned, sizeof(int64_t)));
}
@ -579,8 +672,6 @@ class GPUHistMaker : public TreeUpdater {
info_ = &dmat->Info();
monitor_.Start("Quantiles", device_list_);
hmat_.Init(dmat, param_.max_bin);
gmat_.cut = &hmat_;
gmat_.Init(dmat);
monitor_.Stop("Quantiles", device_list_);
n_bins_ = hmat_.row_ptr.back();
@ -609,12 +700,22 @@ class GPUHistMaker : public TreeUpdater {
row_begin = row_end;
}
// Create device shards
dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr<DeviceShard>& shard) {
shard = std::unique_ptr<DeviceShard>(
new DeviceShard(device_list_[i], i, gmat_,
row_segments[i], row_segments[i + 1], n_bins_, param_));
});
monitor_.Start("BinningCompression", device_list_);
{
dmlc::DataIter<RowBatch>* iter = dmat->RowIterator();
iter->BeforeFirst();
CHECK(iter->Next()) << "Empty batches are not supported";
const RowBatch& batch = iter->Value();
// Create device shards
dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr<DeviceShard>& shard) {
shard = std::unique_ptr<DeviceShard>
(new DeviceShard(device_list_[i], i,
row_segments[i], row_segments[i + 1], n_bins_, param_));
shard->Init(hmat_, batch);
});
CHECK(!iter->Next()) << "External memory not supported";
}
monitor_.Stop("BinningCompression", device_list_);
p_last_fmat_ = dmat;
initialised_ = true;

View File

@ -0,0 +1,73 @@
#include "../../../src/common/compressed_iterator.h"
#include "../../../src/common/device_helpers.cuh"
#include "gtest/gtest.h"
#include <algorithm>
#include <thrust/device_vector.h>
namespace xgboost {
namespace common {
struct WriteSymbolFunction {
CompressedBufferWriter cbw;
unsigned char* buffer_data_d;
int* input_data_d;
WriteSymbolFunction(CompressedBufferWriter cbw, unsigned char* buffer_data_d,
int* input_data_d)
: cbw(cbw), buffer_data_d(buffer_data_d), input_data_d(input_data_d) {}
__device__ void operator()(size_t i) {
cbw.AtomicWriteSymbol(buffer_data_d, input_data_d[i], i);
}
};
struct ReadSymbolFunction {
CompressedIterator<int> ci;
int* output_data_d;
ReadSymbolFunction(CompressedIterator<int> ci, int* output_data_d)
: ci(ci), output_data_d(output_data_d) {}
__device__ void operator()(size_t i) {
output_data_d[i] = ci[i];
}
};
TEST(CompressedIterator, TestGPU) {
std::vector<int> test_cases = {1, 3, 426, 21, 64, 256, 100000, INT32_MAX};
int num_elements = 1000;
int repetitions = 1000;
srand(9);
for (auto alphabet_size : test_cases) {
for (int i = 0; i < repetitions; i++) {
std::vector<int> input(num_elements);
std::generate(input.begin(), input.end(),
[=]() { return rand() % alphabet_size; });
CompressedBufferWriter cbw(alphabet_size);
thrust::device_vector<int> input_d(input);
thrust::device_vector<unsigned char> buffer_d(
CompressedBufferWriter::CalculateBufferSize(input.size(),
alphabet_size));
// write the data on device
auto input_data_d = input_d.data().get();
auto buffer_data_d = buffer_d.data().get();
dh::LaunchN(0, input_d.size(),
WriteSymbolFunction(cbw, buffer_data_d, input_data_d));
// read the data on device
CompressedIterator<int> ci(buffer_d.data().get(), alphabet_size);
thrust::device_vector<int> output_d(input.size());
auto output_data_d = output_d.data().get();
dh::LaunchN(0, output_d.size(), ReadSymbolFunction(ci, output_data_d));
std::vector<int> output(output_d.size());
thrust::copy(output_d.begin(), output_d.end(), output.begin());
ASSERT_TRUE(input == output);
}
}
}
} // namespace common
} // namespace xgboost

View File

@ -18,11 +18,19 @@ long GetFileSize(const std::string filename) {
}
std::string CreateSimpleTestData() {
return CreateBigTestData(6);
}
std::string CreateBigTestData(size_t n_entries) {
std::string tmp_file = TempFileName();
std::ofstream fo;
fo.open(tmp_file);
fo << "0 0:0 1:10 2:20\n";
fo << "1 0:0 3:30 4:40\n";
const size_t entries_per_row = 3;
size_t n_rows = (n_entries + entries_per_row - 1) / entries_per_row;
for (size_t i = 0; i < n_rows; ++i) {
const char* row = i % 2 == 0 ? " 0:0 1:10 2:20\n" : " 0:0 3:30 4:40\n";
fo << i << row;
}
fo.close();
return tmp_file;
}

View File

@ -23,6 +23,8 @@ long GetFileSize(const std::string filename);
std::string CreateSimpleTestData();
std::string CreateBigTestData(size_t n_entries);
void CheckObjFunction(xgboost::ObjFunction * obj,
std::vector<xgboost::bst_float> preds,
std::vector<xgboost::bst_float> labels,

View File

@ -7,6 +7,7 @@
#include "../helpers.h"
#include "gtest/gtest.h"
#include "../../../src/data/sparse_page_source.h"
#include "../../../src/gbm/gbtree_model.h"
#include "../../../src/tree/updater_gpu_hist.cu"
@ -24,8 +25,14 @@ TEST(gpu_hist_experimental, TestSparseShard) {
gmat.Init(dmat.get());
TrainParam p;
p.max_depth = 6;
DeviceShard shard(0, 0, gmat, 0, rows, hmat.row_ptr.back(),
p);
dmlc::DataIter<RowBatch>* iter = dmat->RowIterator();
iter->BeforeFirst();
CHECK(iter->Next());
const RowBatch& batch = iter->Value();
DeviceShard shard(0, 0, 0, rows, hmat.row_ptr.back(), p);
shard.Init(hmat, batch);
CHECK(!iter->Next());
ASSERT_LT(shard.row_stride, columns);
@ -59,8 +66,15 @@ TEST(gpu_hist_experimental, TestDenseShard) {
gmat.Init(dmat.get());
TrainParam p;
p.max_depth = 6;
DeviceShard shard(0, 0, gmat, 0, rows, hmat.row_ptr.back(),
p);
dmlc::DataIter<RowBatch>* iter = dmat->RowIterator();
iter->BeforeFirst();
CHECK(iter->Next());
const RowBatch& batch = iter->Value();
DeviceShard shard(0, 0, 0, rows, hmat.row_ptr.back(), p);
shard.Init(hmat, batch);
CHECK(!iter->Next());
ASSERT_EQ(shard.row_stride, columns);
@ -75,4 +89,4 @@ TEST(gpu_hist_experimental, TestDenseShard) {
}
} // namespace tree
} // namespace xgboost
} // namespace xgboost