diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index 66693253f..57417ffdc 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -73,3 +73,4 @@ List of Contributors
 * [Gideon Whitehead](https://github.com/gaw89)
 * [Yi-Lin Juang](https://github.com/frankyjuang)
 * [Andrew Hannigan](https://github.com/andrewhannigan)
+* [Andy Adinets](https://github.com/canonizer)
diff --git a/include/xgboost/data.h b/include/xgboost/data.h
index cc7904583..418972a42 100644
--- a/include/xgboost/data.h
+++ b/include/xgboost/data.h
@@ -126,15 +126,15 @@ struct SparseBatch {
     /*! \brief feature value */
     bst_float fvalue;
     /*! \brief default constructor */
-    Entry() = default;
+    XGBOOST_DEVICE Entry() {}
     /*!
     * \brief constructor with index and value
     * \param index The feature or row index.
     * \param fvalue THe feature value.
     */
-    Entry(bst_uint index, bst_float fvalue) : index(index), fvalue(fvalue) {}
+    XGBOOST_DEVICE Entry(bst_uint index, bst_float fvalue) : index(index), fvalue(fvalue) {}
     /*! \brief reversely compare feature values */
-    inline static bool CmpValue(const Entry& a, const Entry& b) {
+    XGBOOST_DEVICE inline static bool CmpValue(const Entry& a, const Entry& b) {
       return a.fvalue < b.fvalue;
     }
   };
diff --git a/src/common/compressed_iterator.h b/src/common/compressed_iterator.h
index 512b75fbf..4b2ee45b6 100644
--- a/src/common/compressed_iterator.h
+++ b/src/common/compressed_iterator.h
@@ -8,6 +8,10 @@
 #include
 #include
 
+#ifdef __CUDACC__
+#include "device_helpers.cuh"
+#endif
+
 namespace xgboost {
 namespace common {
 
@@ -96,6 +100,23 @@ class CompressedBufferWriter {
       }
     }
   }
+
+#ifdef __CUDACC__
+  __device__ void AtomicWriteSymbol
+  (CompressedByteT* buffer, uint64_t symbol, size_t offset) {
+    size_t ibit_start = offset * symbol_bits_;
+    size_t ibit_end = (offset + 1) * symbol_bits_ - 1;
+    size_t ibyte_start = ibit_start / 8, ibyte_end = ibit_end / 8;
+
+    symbol <<= 7 - ibit_end % 8;
+    for (ptrdiff_t ibyte = ibyte_end; ibyte >= (ptrdiff_t)ibyte_start; --ibyte) {
+      dh::AtomicOrByte(reinterpret_cast<unsigned int*>(buffer + detail::kPadding),
+                       ibyte, symbol & 0xff);
+      symbol >>= 8;
+    }
+  }
+#endif
+
   template <typename IterT>
   void Write(CompressedByteT *buffer, IterT input_begin, IterT input_end) {
     uint64_t tmp = 0;
diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh
index 2a7034706..42cd52a41 100644
--- a/src/common/device_helpers.cuh
+++ b/src/common/device_helpers.cuh
@@ -122,6 +122,14 @@ inline size_t AvailableMemory(int device_idx) {
   return device_free;
 }
 
+inline size_t TotalMemory(int device_idx) {
+  size_t device_free = 0;
+  size_t device_total = 0;
+  safe_cuda(cudaSetDevice(device_idx));
+  dh::safe_cuda(cudaMemGetInfo(&device_free, &device_total));
+  return device_total;
+}
+
 /**
  * \fn  inline int max_shared_memory(int device_idx)
  *
@@ -155,6 +163,12 @@ inline void CheckComputeCapability() {
   }
 }
+
+DEV_INLINE void AtomicOrByte(unsigned int* __restrict__ buffer, size_t ibyte, unsigned char b) {
+  atomicOr(&buffer[ibyte / sizeof(unsigned int)], (unsigned int)b << (ibyte % (sizeof(unsigned int)) * 8));
+}
+
+
 /*
  * Range iterator
  */
 
diff --git a/src/tree/param.h b/src/tree/param.h
index dc7949b00..551503a84 100644
--- a/src/tree/param.h
+++ b/src/tree/param.h
@@ -183,6 +183,7 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
         .set_lower_bound(-1)
         .set_default(1)
         .describe("Number of GPUs to use for multi-gpu algorithms: -1=use all GPUs");
+
     // add alias of parameters
     DMLC_DECLARE_ALIAS(reg_lambda, lambda);
     DMLC_DECLARE_ALIAS(reg_alpha, alpha);
diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index 2f1eb5ae4..96e60a2c3 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -2,6 +2,9 @@
  * Copyright 2017 XGBoost contributors
  */
 #include
+#include
+#include
+#include
 #include
 #include
 #include
@@ -224,6 +227,53 @@ struct CalcWeightTrainParam {
       learning_rate(p.learning_rate) {}
 };
 
+// index of the first element in cuts greater than v, or n if none;
+// cuts are ordered, and binary search is used
+__device__ int upper_bound(const float* __restrict__ cuts, int n, float v) {
+  if (n == 0)
+    return 0;
+  if (cuts[n - 1] <= v)
+    return n;
+  if (cuts[0] > v)
+    return 0;
+  int left = 0, right = n - 1;
+  while (right - left > 1) {
+    int middle = left + (right - left) / 2;
+    if (cuts[middle] > v)
+      right = middle;
+    else
+      left = middle;
+  }
+  return right;
+}
+
+__global__ void compress_bin_ellpack_k
+(common::CompressedBufferWriter wr, common::CompressedByteT* __restrict__ buffer,
+ const size_t* __restrict__ row_ptrs,
+ const RowBatch::Entry* __restrict__ entries,
+ const float* __restrict__ cuts, const size_t* __restrict__ cut_rows,
+ size_t base_row, size_t n_rows, size_t row_ptr_begin, size_t row_stride,
+ unsigned int null_gidx_value) {
+  size_t irow = threadIdx.x + size_t(blockIdx.x) * blockDim.x;
+  int ifeature = threadIdx.y + blockIdx.y * blockDim.y;
+  if (irow >= n_rows || ifeature >= row_stride)
+    return;
+  int row_size = static_cast<int>(row_ptrs[irow + 1] - row_ptrs[irow]);
+  unsigned int bin = null_gidx_value;
+  if (ifeature < row_size) {
+    RowBatch::Entry entry = entries[row_ptrs[irow] - row_ptr_begin + ifeature];
+    int feature = entry.index;
+    float fvalue = entry.fvalue;
+    const float *feature_cuts = &cuts[cut_rows[feature]];
+    int ncuts = cut_rows[feature + 1] - cut_rows[feature];
+    bin = upper_bound(feature_cuts, ncuts, fvalue);
+    if (bin >= ncuts)
+      bin = ncuts - 1;
+    bin += cut_rows[feature];
+  }
+  wr.AtomicWriteSymbol(buffer, bin, (irow + base_row) * row_stride + ifeature);
+}
+
 // Manage memory for a single GPU
 struct DeviceShard {
   struct Segment {
@@ -271,74 +321,117 @@ struct DeviceShard {
   dh::CubMemory temp_memory;
 
   DeviceShard(int device_idx, int normalised_device_idx,
-              const common::GHistIndexMatrix& gmat, bst_uint row_begin,
-              bst_uint row_end, int n_bins, TrainParam param)
-      : device_idx(device_idx),
-        normalised_device_idx(normalised_device_idx),
-        row_begin_idx(row_begin),
-        row_end_idx(row_end),
-        n_rows(row_end - row_begin),
-        n_bins(n_bins),
-        null_gidx_value(n_bins),
-        param(param),
-        prediction_cache_initialised(false) {
-    // Convert to ELLPACK matrix representation
-    int max_elements_row = 0;
-    for (auto i = row_begin; i < row_end; i++) {
-      max_elements_row =
-          (std::max)(max_elements_row,
-                     static_cast<int>(gmat.row_ptr[i + 1] - gmat.row_ptr[i]));
-    }
-    row_stride = max_elements_row;
-    std::vector<int> ellpack_matrix(row_stride * n_rows, null_gidx_value);
+              bst_uint row_begin, bst_uint row_end, int n_bins, TrainParam param)
+      : device_idx(device_idx),
+        normalised_device_idx(normalised_device_idx),
+        row_begin_idx(row_begin),
+        row_end_idx(row_end),
+        n_rows(row_end - row_begin),
+        n_bins(n_bins),
+        null_gidx_value(n_bins),
+        param(param),
+        prediction_cache_initialised(false) {}
 
-    for (auto i = row_begin; i < row_end; i++) {
-      int row_count = 0;
-      for (auto j = gmat.row_ptr[i]; j < gmat.row_ptr[i + 1]; j++) {
-        ellpack_matrix[(i - row_begin) * row_stride + row_count] =
-            gmat.index[j];
-        row_count++;
-      }
-    }
+  void Init(const common::HistCutMatrix& hmat, const RowBatch& row_batch) {
+    // copy cuts to the GPU
+    dh::safe_cuda(cudaSetDevice(device_idx));
+    thrust::device_vector<float> cuts_d(hmat.cut);
+    thrust::device_vector<size_t> cut_row_ptrs_d(hmat.row_ptr);
 
-    // Allocate
+    // find the maximum row size
+    thrust::device_vector<size_t> row_ptr_d(
+        row_batch.ind_ptr + row_begin_idx, row_batch.ind_ptr + row_end_idx + 1);
+
+    auto row_iter = row_ptr_d.begin();
+    auto get_size = [=] __device__(size_t row) {
+      return row_iter[row + 1] - row_iter[row];
+    }; // NOLINT
+    auto counting = thrust::make_counting_iterator(size_t(0));
+    using TransformT = thrust::transform_iterator<decltype(get_size), decltype(counting)>;
+    TransformT row_size_iter = TransformT(counting, get_size);
+    row_stride = thrust::reduce(row_size_iter, row_size_iter + n_rows, 0,
+                                thrust::maximum<size_t>());
+
+    // allocate compressed bin data
     int num_symbols = n_bins + 1;
     size_t compressed_size_bytes =
-        common::CompressedBufferWriter::CalculateBufferSize(
-            ellpack_matrix.size(), num_symbols);
+        common::CompressedBufferWriter::CalculateBufferSize(row_stride * n_rows,
+                                                            num_symbols);
 
     CHECK(!(param.max_leaves == 0 && param.max_depth == 0))
         << "Max leaves and max depth cannot both be unconstrained for "
           "gpu_hist.";
+    ba.Allocate(device_idx, param.silent, &gidx_buffer, compressed_size_bytes);
+
+    gidx_buffer.Fill(0);
+
+    // bin and compress entries in batches of rows
+    // use no more than 1/16th of GPU memory per batch
+    size_t gpu_batch_nrows = dh::TotalMemory(device_idx) /
+        (16 * row_stride * sizeof(RowBatch::Entry));
+    if (gpu_batch_nrows > n_rows) {
+      gpu_batch_nrows = n_rows;
+    }
+    thrust::device_vector<RowBatch::Entry> entries_d(gpu_batch_nrows * row_stride);
+    size_t gpu_nbatches = dh::DivRoundUp(n_rows, gpu_batch_nrows);
+    for (size_t gpu_batch = 0; gpu_batch < gpu_nbatches; ++gpu_batch) {
+      size_t batch_row_begin = gpu_batch * gpu_batch_nrows;
+      size_t batch_row_end = (gpu_batch + 1) * gpu_batch_nrows;
+      if (batch_row_end > n_rows) {
+        batch_row_end = n_rows;
+      }
+      size_t batch_nrows = batch_row_end - batch_row_begin;
+      size_t n_entries =
+          row_batch.ind_ptr[row_begin_idx + batch_row_end] -
+          row_batch.ind_ptr[row_begin_idx + batch_row_begin];
+      dh::safe_cuda
+          (cudaMemcpy
+           (entries_d.data().get(),
+            &row_batch.data_ptr[row_batch.ind_ptr[row_begin_idx + batch_row_begin]],
+            n_entries * sizeof(RowBatch::Entry), cudaMemcpyDefault));
+      dim3 block3(32, 8, 1);
+      dim3 grid3(dh::DivRoundUp(n_rows, block3.x),
+                 dh::DivRoundUp(row_stride, block3.y), 1);
+      compress_bin_ellpack_k<<<grid3, block3>>>
+          (common::CompressedBufferWriter(num_symbols), gidx_buffer.Data(),
+           row_ptr_d.data().get() + batch_row_begin,
+           entries_d.data().get(), cuts_d.data().get(), cut_row_ptrs_d.data().get(),
+           batch_row_begin, batch_nrows,
+           row_batch.ind_ptr[row_begin_idx + batch_row_begin],
+           row_stride, null_gidx_value);
+
+      dh::safe_cuda(cudaGetLastError());
+      dh::safe_cuda(cudaDeviceSynchronize());
+    }
+
+    // free the memory that is no longer needed
+    row_ptr_d.resize(0);
+    row_ptr_d.shrink_to_fit();
+    entries_d.resize(0);
+    entries_d.shrink_to_fit();
+
+    gidx = common::CompressedIterator<uint32_t>(gidx_buffer.Data(), num_symbols);
+
+    // allocate the rest
     int max_nodes = param.max_leaves > 0 ?
         param.max_leaves * 2 : MaxNodesDepth(param.max_depth);
-    ba.Allocate(device_idx, param.silent, &gidx_buffer, compressed_size_bytes,
+    ba.Allocate(device_idx, param.silent, &gpair, n_rows,
                 &ridx, n_rows, &position, n_rows, &prediction_cache, n_rows,
                 &node_sum_gradients_d, max_nodes,
-                &feature_segments, gmat.cut->row_ptr.size(), &gidx_fvalue_map,
-                gmat.cut->cut.size(), &min_fvalue, gmat.cut->min_val.size(),
+                &feature_segments, hmat.row_ptr.size(), &gidx_fvalue_map,
+                hmat.cut.size(), &min_fvalue, hmat.min_val.size(),
                 &monotone_constraints, param.monotone_constraints.size());
-    gidx_fvalue_map = gmat.cut->cut;
-    min_fvalue = gmat.cut->min_val;
-    feature_segments = gmat.cut->row_ptr;
+    gidx_fvalue_map = hmat.cut;
+    min_fvalue = hmat.min_val;
+    feature_segments = hmat.row_ptr;
     monotone_constraints = param.monotone_constraints;
 
     node_sum_gradients.resize(max_nodes);
     ridx_segments.resize(max_nodes);
 
-    // Compress gidx
-    common::CompressedBufferWriter cbw(num_symbols);
-    std::vector<common::CompressedByteT> host_buffer(gidx_buffer.Size());
-    cbw.Write(host_buffer.data(), ellpack_matrix.begin(), ellpack_matrix.end());
-    gidx_buffer = host_buffer;
-    gidx =
-        common::CompressedIterator<uint32_t>(gidx_buffer.Data(), num_symbols);
-
-    common::CompressedIterator<uint32_t> ci_host(host_buffer.data(),
-                                                 num_symbols);
-
     // Init histogram
-    hist.Init(device_idx, max_nodes, gmat.cut->row_ptr.back(), param.silent);
+    hist.Init(device_idx, max_nodes, hmat.row_ptr.back(), param.silent);
 
     dh::safe_cuda(cudaMallocHost(&tmp_pinned, sizeof(int64_t)));
   }
@@ -579,8 +672,6 @@ class GPUHistMaker : public TreeUpdater {
     info_ = &dmat->Info();
     monitor_.Start("Quantiles", device_list_);
     hmat_.Init(dmat, param_.max_bin);
-    gmat_.cut = &hmat_;
-    gmat_.Init(dmat);
     monitor_.Stop("Quantiles", device_list_);
 
     n_bins_ = hmat_.row_ptr.back();
@@ -609,12 +700,22 @@
       row_begin = row_end;
     }
 
-    // Create device shards
-    dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr<DeviceShard>& shard) {
-      shard = std::unique_ptr<DeviceShard>(
-          new DeviceShard(device_list_[i], i, gmat_,
-                          row_segments[i], row_segments[i + 1], n_bins_, param_));
-    });
+    monitor_.Start("BinningCompression", device_list_);
+    {
+      dmlc::DataIter<RowBatch>* iter = dmat->RowIterator();
+      iter->BeforeFirst();
+      CHECK(iter->Next()) << "Empty batches are not supported";
+      const RowBatch& batch = iter->Value();
+      // Create device shards
+      dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr<DeviceShard>& shard) {
+        shard = std::unique_ptr<DeviceShard>
+            (new DeviceShard(device_list_[i], i,
+                             row_segments[i], row_segments[i + 1], n_bins_, param_));
+        shard->Init(hmat_, batch);
+      });
+      CHECK(!iter->Next()) << "External memory not supported";
+    }
+    monitor_.Stop("BinningCompression", device_list_);
 
     p_last_fmat_ = dmat;
     initialised_ = true;
diff --git a/tests/cpp/common/test_gpu_compressed_iterator.cu b/tests/cpp/common/test_gpu_compressed_iterator.cu
new file mode 100644
index 000000000..b462b78a5
--- /dev/null
+++ b/tests/cpp/common/test_gpu_compressed_iterator.cu
@@ -0,0 +1,73 @@
+#include "../../../src/common/compressed_iterator.h"
+#include "../../../src/common/device_helpers.cuh"
+#include "gtest/gtest.h"
+#include
+#include
+
+namespace xgboost {
+namespace common {
+
+struct WriteSymbolFunction {
+  CompressedBufferWriter cbw;
+  unsigned char* buffer_data_d;
+  int* input_data_d;
+  WriteSymbolFunction(CompressedBufferWriter cbw, unsigned char* buffer_data_d,
+                      int* input_data_d)
+    : cbw(cbw), buffer_data_d(buffer_data_d), input_data_d(input_data_d) {}
+
+  __device__ void operator()(size_t i) {
+    cbw.AtomicWriteSymbol(buffer_data_d, input_data_d[i], i);
+  }
+};
+
+struct ReadSymbolFunction {
+  CompressedIterator<int> ci;
+  int* output_data_d;
+  ReadSymbolFunction(CompressedIterator<int> ci, int* output_data_d)
+    : ci(ci), output_data_d(output_data_d) {}
+
+  __device__ void operator()(size_t i) {
+    output_data_d[i] = ci[i];
+  }
+};
+
+TEST(CompressedIterator, TestGPU) {
+  std::vector<int> test_cases = {1, 3, 426, 21, 64, 256, 100000, INT32_MAX};
+  int num_elements = 1000;
+  int repetitions = 1000;
+  srand(9);
+
+  for (auto alphabet_size : test_cases) {
+    for (int i = 0; i < repetitions; i++) {
+      std::vector<int> input(num_elements);
+      std::generate(input.begin(), input.end(),
+                    [=]() { return rand() % alphabet_size; });
+      CompressedBufferWriter cbw(alphabet_size);
+      thrust::device_vector<int> input_d(input);
+
+      thrust::device_vector<unsigned char> buffer_d(
+          CompressedBufferWriter::CalculateBufferSize(input.size(),
+                                                      alphabet_size));
+
+      // write the data on device
+      auto input_data_d = input_d.data().get();
+      auto buffer_data_d = buffer_d.data().get();
+      dh::LaunchN(0, input_d.size(),
+                  WriteSymbolFunction(cbw, buffer_data_d, input_data_d));
+
+      // read the data on device
+      CompressedIterator<int> ci(buffer_d.data().get(), alphabet_size);
+      thrust::device_vector<int> output_d(input.size());
+      auto output_data_d = output_d.data().get();
+      dh::LaunchN(0, output_d.size(), ReadSymbolFunction(ci, output_data_d));
+
+      std::vector<int> output(output_d.size());
+      thrust::copy(output_d.begin(), output_d.end(), output.begin());
+
+      ASSERT_TRUE(input == output);
+    }
+  }
+}
+
+} // namespace common
+} // namespace xgboost
diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc
index 4b172c12b..df336d326 100644
--- a/tests/cpp/helpers.cc
+++ b/tests/cpp/helpers.cc
@@ -18,11 +18,19 @@ long GetFileSize(const std::string filename) {
 }
 
 std::string CreateSimpleTestData() {
+  return CreateBigTestData(6);
+}
+
+std::string CreateBigTestData(size_t n_entries) {
   std::string tmp_file = TempFileName();
   std::ofstream fo;
   fo.open(tmp_file);
-  fo << "0 0:0 1:10 2:20\n";
-  fo << "1 0:0 3:30 4:40\n";
+  const size_t entries_per_row = 3;
+  size_t n_rows = (n_entries + entries_per_row - 1) / entries_per_row;
+  for (size_t i = 0; i < n_rows; ++i) {
+    const char* row = i % 2 == 0 ? " 0:0 1:10 2:20\n" : " 0:0 3:30 4:40\n";
+    fo << i << row;
+  }
   fo.close();
   return tmp_file;
 }
diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h
index 6846075c4..b3fcebfb3 100644
--- a/tests/cpp/helpers.h
+++ b/tests/cpp/helpers.h
@@ -23,6 +23,8 @@ long GetFileSize(const std::string filename);
 
 std::string CreateSimpleTestData();
 
+std::string CreateBigTestData(size_t n_entries);
+
 void CheckObjFunction(xgboost::ObjFunction * obj,
                       std::vector<xgboost::bst_float> preds,
                       std::vector<xgboost::bst_float> labels,
diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu
index cb5fcae56..e1e7c28ac 100644
--- a/tests/cpp/tree/test_gpu_hist.cu
+++ b/tests/cpp/tree/test_gpu_hist.cu
@@ -7,6 +7,7 @@
 #include "../helpers.h"
 #include "gtest/gtest.h"
 
+#include "../../../src/data/sparse_page_source.h"
 #include "../../../src/gbm/gbtree_model.h"
 #include "../../../src/tree/updater_gpu_hist.cu"
 
@@ -24,8 +25,14 @@ TEST(gpu_hist_experimental, TestSparseShard) {
   gmat.Init(dmat.get());
   TrainParam p;
   p.max_depth = 6;
-  DeviceShard shard(0, 0, gmat, 0, rows, hmat.row_ptr.back(),
-                    p);
+
+  dmlc::DataIter<RowBatch>* iter = dmat->RowIterator();
+  iter->BeforeFirst();
+  CHECK(iter->Next());
+  const RowBatch& batch = iter->Value();
+  DeviceShard shard(0, 0, 0, rows, hmat.row_ptr.back(), p);
+  shard.Init(hmat, batch);
+  CHECK(!iter->Next());
 
   ASSERT_LT(shard.row_stride, columns);
 
@@ -59,8 +66,15 @@ TEST(gpu_hist_experimental, TestDenseShard) {
   gmat.Init(dmat.get());
   TrainParam p;
   p.max_depth = 6;
-  DeviceShard shard(0, 0, gmat, 0, rows, hmat.row_ptr.back(),
-                    p);
+
+  dmlc::DataIter<RowBatch>* iter = dmat->RowIterator();
+  iter->BeforeFirst();
+  CHECK(iter->Next());
+  const RowBatch& batch = iter->Value();
+
+  DeviceShard shard(0, 0, 0, rows, hmat.row_ptr.back(), p);
+  shard.Init(hmat, batch);
+  CHECK(!iter->Next());
 
   ASSERT_EQ(shard.row_stride, columns);
 
@@ -75,4 +89,4 @@
 }
 
 } // namespace tree
-} // namespace xgboost
\ No newline at end of file
+} // namespace xgboost