diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index c4884442d..42587b624 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -630,6 +630,37 @@ __forceinline__ __device__ void CountLeft(int64_t* d_count, int val,
 #endif
 }
 
+// Instances of this type are created while creating the histogram bins for the
+// entire dataset across multiple sparse page batches. This keeps track of the number
+// of rows to process from a batch and the position from which to process on each device.
+struct RowStateOnDevice {
+  // Number of rows assigned to this device
+  const size_t total_rows_assigned_to_device;
+  // Number of rows processed thus far
+  size_t total_rows_processed;
+  // Number of rows to process from the current sparse page batch
+  size_t rows_to_process_from_batch;
+  // Offset from the current sparse page batch to begin processing
+  size_t row_offset_in_current_batch;
+
+  explicit RowStateOnDevice(size_t total_rows)
+    : total_rows_assigned_to_device(total_rows), total_rows_processed(0),
+      rows_to_process_from_batch(0), row_offset_in_current_batch(0) {
+  }
+
+  explicit RowStateOnDevice(size_t total_rows, size_t batch_rows)
+    : total_rows_assigned_to_device(total_rows), total_rows_processed(0),
+      rows_to_process_from_batch(batch_rows), row_offset_in_current_batch(0) {
+  }
+
+  // Advance the row state by the number of rows processed
+  void Advance() {
+    total_rows_processed += rows_to_process_from_batch;
+    CHECK_LE(total_rows_processed, total_rows_assigned_to_device);
+    rows_to_process_from_batch = row_offset_in_current_batch = 0;
+  }
+};
+
 // Manage memory for a single GPU
 template <typename GradientSumT>
 struct DeviceShard {
@@ -666,8 +697,6 @@ struct DeviceShard {
   /*! \brief Sum gradient for each node. */
   std::vector<GradientPair> node_sum_gradients;
   common::Span<GradientPair> node_sum_gradients_d;
-  /*! \brief row offset in SparsePage (the input data). */
-  dh::device_vector<size_t> row_ptrs;
   /*! \brief On-device feature set, only actually used on one of the devices */
   dh::device_vector<int> feature_set_d;
   dh::device_vector<int64_t>
@@ -695,7 +724,6 @@ struct DeviceShard {
                           std::function<bool(ExpandEntry, ExpandEntry)>>;
   std::unique_ptr<ExpandQueue> qexpand;
 
-  // TODO(canonizer): do add support multi-batch DMatrix here
   DeviceShard(int _device_id, int shard_idx, bst_uint row_begin,
               bst_uint row_end, TrainParam _param, uint32_t column_sampler_seed)
       : device_id(_device_id),
@@ -710,32 +738,12 @@ struct DeviceShard {
     monitor.Init(std::string("DeviceShard") + std::to_string(device_id));
   }
 
-  /* Init row_ptrs and row_stride */
-  size_t InitRowPtrs(const SparsePage& row_batch) {
-    const auto& offset_vec = row_batch.offset.HostVector();
-    row_ptrs.resize(n_rows + 1);
-    thrust::copy(offset_vec.data() + row_begin_idx,
-                 offset_vec.data() + row_end_idx + 1,
-                 row_ptrs.begin());
-    auto row_iter = row_ptrs.begin();
-    // find the maximum row size for converting to ELLPack
-    auto get_size = [=] __device__(size_t row) {
-      return row_iter[row + 1] - row_iter[row];
-    }; // NOLINT
-
-    auto counting = thrust::make_counting_iterator(size_t(0));
-    using TransformT = thrust::transform_iterator<decltype(get_size), decltype(counting), size_t>;
-    TransformT row_size_iter = TransformT(counting, get_size);
-    size_t row_stride = thrust::reduce(row_size_iter, row_size_iter + n_rows, 0,
-                                       thrust::maximum<size_t>());
-    return row_stride;
-  }
-
   void InitCompressedData(
-      const common::HistCutMatrix& hmat, const SparsePage& row_batch, bool is_dense);
+      const common::HistCutMatrix& hmat, size_t row_stride, bool is_dense);
 
-  void CreateHistIndices(const SparsePage& row_batch, size_t row_stride, int null_gidx_value);
+  void CreateHistIndices(
+      const SparsePage &row_batch, const common::HistCutMatrix &hmat,
+      const RowStateOnDevice &device_row_state, int rows_per_batch);
 
   ~DeviceShard() {
     dh::safe_cuda(cudaSetDevice(device_id));
@@ -1229,11 +1237,14 @@ struct DeviceShard {
 
 template <typename GradientSumT>
 inline void DeviceShard<GradientSumT>::InitCompressedData(
-    const common::HistCutMatrix& hmat, const SparsePage& row_batch, bool is_dense) {
-  size_t row_stride = this->InitRowPtrs(row_batch);
+    const common::HistCutMatrix &hmat, size_t row_stride, bool is_dense) {
   n_bins = hmat.row_ptr.back();
   int null_gidx_value = hmat.row_ptr.back();
 
+  CHECK(!(param.max_leaves == 0 && param.max_depth == 0))
+      << "Max leaves and max depth cannot both be unconstrained for "
+         "gpu_hist.";
+
   int max_nodes =
       param.max_leaves > 0 ? param.max_leaves * 2 : MaxNodesDepth(param.max_depth);
@@ -1256,7 +1267,6 @@ inline void DeviceShard<GradientSumT>::InitCompressedData(
   node_sum_gradients.resize(max_nodes);
   ridx_segments.resize(max_nodes);
 
-  // allocate compressed bin data
   int num_symbols = n_bins + 1;
   // Required buffer size for storing data matrix in ELLPack format.
@@ -1264,16 +1274,11 @@ inline void DeviceShard<GradientSumT>::InitCompressedData(
       common::CompressedBufferWriter::CalculateBufferSize(row_stride * n_rows,
                                                           num_symbols);
 
-  CHECK(!(param.max_leaves == 0 && param.max_depth == 0))
-      << "Max leaves and max depth cannot both be unconstrained for "
-         "gpu_hist.";
   ba.Allocate(device_id, &gidx_buffer, compressed_size_bytes);
   thrust::fill(
       thrust::device_pointer_cast(gidx_buffer.data()),
       thrust::device_pointer_cast(gidx_buffer.data() + gidx_buffer.size()), 0);
 
-  this->CreateHistIndices(row_batch, row_stride, null_gidx_value);
-
   ellpack_matrix.Init(
       feature_segments, min_fvalue,
       gidx_fvalue_map, row_stride,
@@ -1295,25 +1300,45 @@ inline void DeviceShard<GradientSumT>::InitCompressedData(
 
 template <typename GradientSumT>
 inline void DeviceShard<GradientSumT>::CreateHistIndices(
-    const SparsePage& row_batch, size_t row_stride, int null_gidx_value) {
+    const SparsePage &row_batch,
+    const common::HistCutMatrix &hmat,
+    const RowStateOnDevice &device_row_state,
+    int rows_per_batch) {
+  // Have any rows from this batch been assigned to this device?
+  if (!device_row_state.rows_to_process_from_batch) return;
+
+  unsigned int null_gidx_value = hmat.row_ptr.back();
+  size_t row_stride = this->ellpack_matrix.row_stride;
+
+  const auto &offset_vec = row_batch.offset.ConstHostVector();
+  CHECK_LE(device_row_state.rows_to_process_from_batch, offset_vec.size());
+  // Row offsets into this device's slice of the SparsePage (the input data).
+  dh::device_vector<size_t> row_ptrs(device_row_state.rows_to_process_from_batch+1);
+  thrust::copy(
+    offset_vec.data() + device_row_state.row_offset_in_current_batch,
+    offset_vec.data() + device_row_state.row_offset_in_current_batch +
+      device_row_state.rows_to_process_from_batch + 1,
+    row_ptrs.begin());
+
   int num_symbols = n_bins + 1;
 
   // bin and compress entries in batches of rows
-  size_t gpu_batch_nrows =
-      std::min
-      (dh::TotalMemory(device_id) / (16 * row_stride * sizeof(Entry)),
-       static_cast<size_t>(n_rows));
-  const std::vector<Entry>& data_vec = row_batch.data.HostVector();
+  size_t gpu_batch_nrows = std::min(
+    dh::TotalMemory(device_id) / (16 * row_stride * sizeof(Entry)),
+    static_cast<size_t>(device_row_state.rows_to_process_from_batch));
+  const std::vector<Entry>& data_vec = row_batch.data.ConstHostVector();
   dh::device_vector<Entry> entries_d(gpu_batch_nrows * row_stride);
 
-  size_t gpu_nbatches = dh::DivRoundUp(n_rows, gpu_batch_nrows);
+  size_t gpu_nbatches = dh::DivRoundUp(device_row_state.rows_to_process_from_batch,
+                                       gpu_batch_nrows);
 
   for (size_t gpu_batch = 0; gpu_batch < gpu_nbatches; ++gpu_batch) {
     size_t batch_row_begin = gpu_batch * gpu_batch_nrows;
     size_t batch_row_end = (gpu_batch + 1) * gpu_batch_nrows;
-    if (batch_row_end > n_rows) {
-      batch_row_end = n_rows;
+    if (batch_row_end > device_row_state.rows_to_process_from_batch) {
+      batch_row_end = device_row_state.rows_to_process_from_batch;
     }
     size_t batch_nrows = batch_row_end - batch_row_begin;
+    // Number of entries in this batch.
    size_t n_entries = row_ptrs[batch_row_end] - row_ptrs[batch_row_begin];
     // copy data entries to device.
@@ -1322,17 +1347,20 @@ inline void DeviceShard<GradientSumT>::CreateHistIndices(
          (entries_d.data().get(), data_vec.data() + row_ptrs[batch_row_begin],
           n_entries * sizeof(Entry), cudaMemcpyDefault));
     const dim3 block3(32, 8, 1);  // 256 threads
-    const dim3 grid3(dh::DivRoundUp(n_rows, block3.x),
+    const dim3 grid3(dh::DivRoundUp(device_row_state.rows_to_process_from_batch, block3.x),
                      dh::DivRoundUp(row_stride, block3.y), 1);
     CompressBinEllpackKernel<<<grid3, block3>>>
         (common::CompressedBufferWriter(num_symbols),
          gidx_buffer.data(),
          row_ptrs.data().get() + batch_row_begin,
          entries_d.data().get(),
-         gidx_fvalue_map.data(), feature_segments.data(),
-         batch_row_begin, batch_nrows,
+         gidx_fvalue_map.data(),
+         feature_segments.data(),
+         device_row_state.total_rows_processed + batch_row_begin,
+         batch_nrows,
          row_ptrs[batch_row_begin],
-         row_stride, null_gidx_value);
+         row_stride,
+         null_gidx_value);
   }
 
   // free the memory that is no longer needed
@@ -1342,6 +1370,60 @@ inline void DeviceShard<GradientSumT>::CreateHistIndices(
   entries_d.shrink_to_fit();
 }
 
+// Keeps track, for every device, of the total number of rows assigned to it, the rows
+// processed thus far, and the number of rows and the starting offset to process from the
+// current sparse page batch.
+class DeviceHistogramBuilderState {
+ public:
+  template <typename GradientSumT>
+  explicit DeviceHistogramBuilderState(
+    const std::vector<std::unique_ptr<DeviceShard<GradientSumT>>> &shards) {
+    device_row_states_.reserve(shards.size());
+    for (const auto &shard : shards) {
+      device_row_states_.push_back(RowStateOnDevice(shard->n_rows));
+    }
+  }
+
+  const RowStateOnDevice &GetRowStateOnDevice(int idx) const {
+    return device_row_states_[idx];
+  }
+
+  // Invoked at the beginning of each sparse page batch; distributes the rows in the
+  // sparse page across the different devices.
+  // TODO(sriramch): Think of a way to utilize *all* the GPUs to build the compressed bins.
+  void BeginBatch(const SparsePage &batch) {
+    size_t rem_rows = batch.Size();
+    size_t row_offset_in_current_batch = 0;
+    for (auto &device_row_state : device_row_states_) {
+      // Is there anything left to assign to this device from this batch?
+      if (device_row_state.total_rows_assigned_to_device > device_row_state.total_rows_processed) {
+        // There are still some rows that need to be assigned to this device
+        device_row_state.rows_to_process_from_batch =
+          std::min(
+            device_row_state.total_rows_assigned_to_device - device_row_state.total_rows_processed,
+            rem_rows);
+      } else {
+        // All rows have been assigned to this device
+        device_row_state.rows_to_process_from_batch = 0;
+      }
+
+      device_row_state.row_offset_in_current_batch = row_offset_in_current_batch;
+      row_offset_in_current_batch += device_row_state.rows_to_process_from_batch;
+      rem_rows -= device_row_state.rows_to_process_from_batch;
+    }
+  }
+
+  // Invoked after each sparse page batch has been processed
+  void EndBatch() {
+    for (auto &rs : device_row_states_) {
+      rs.Advance();
+    }
+  }
+
+ private:
+  std::vector<RowStateOnDevice> device_row_states_;
+};
+
 template <typename GradientSumT>
 class GPUHistMakerSpecialised {
  public:
@@ -1397,9 +1479,6 @@ class GPUHistMakerSpecialised {
 
     reducer_.Init(device_list_);
 
-    auto batch_iter = dmat->GetRowBatches().begin();
-    const SparsePage& batch = *batch_iter;
-
     // Synchronise the column sampling seed
     uint32_t column_sampling_seed = common::GlobalRandom()();
     rabit::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0);
@@ -1418,26 +1497,43 @@ class GPUHistMakerSpecialised {
                                                           column_sampling_seed));
         });
 
-    // Find the cuts.
monitor_.StartCuda("Quantiles"); - // TODO(sriramch): The return value will be used when we add support for histogram - // index creation for multiple batches - common::DeviceSketch(param_, *learner_param_, hist_maker_param_.gpu_batch_nrows, dmat, &hmat_); - n_bins_ = hmat_.row_ptr.back(); + // Create the quantile sketches for the dmatrix and initialize HistCutMatrix + size_t row_stride = common::DeviceSketch(param_, *learner_param_, + hist_maker_param_.gpu_batch_nrows, + dmat, &hmat_); monitor_.StopCuda("Quantiles"); + n_bins_ = hmat_.row_ptr.back(); + auto is_dense = info_->num_nonzero_ == info_->num_row_ * info_->num_col_; - monitor_.StartCuda("BinningCompression"); + // Init global data for each shard + monitor_.StartCuda("InitCompressedData"); dh::ExecuteIndexShards( &shards_, [&](int idx, std::unique_ptr>& shard) { dh::safe_cuda(cudaSetDevice(shard->device_id)); - shard->InitCompressedData(hmat_, batch, is_dense); + shard->InitCompressedData(hmat_, row_stride, is_dense); }); + monitor_.StopCuda("InitCompressedData"); + + monitor_.StartCuda("BinningCompression"); + DeviceHistogramBuilderState hist_builder_row_state(shards_); + for (const auto &batch : dmat->GetRowBatches()) { + hist_builder_row_state.BeginBatch(batch); + + dh::ExecuteIndexShards( + &shards_, + [&](int idx, std::unique_ptr>& shard) { + dh::safe_cuda(cudaSetDevice(shard->device_id)); + shard->CreateHistIndices(batch, hmat_, hist_builder_row_state.GetRowStateOnDevice(idx), + hist_maker_param_.gpu_batch_nrows); + }); + + hist_builder_row_state.EndBatch(); + } monitor_.StopCuda("BinningCompression"); - ++batch_iter; - CHECK(batch_iter.AtEnd()) << "External memory not supported"; p_last_fmat_ = dmat; initialised_ = true; diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc index 56d31fce3..514f39010 100644 --- a/tests/cpp/helpers.cc +++ b/tests/cpp/helpers.cc @@ -6,6 +6,7 @@ #include #include #include +#include "../../src/data/simple_csr_source.h" bool FileExists(const std::string& filename) { struct stat st; @@ -165,6 +166,71 @@ std::unique_ptr CreateSparsePageDMatrix(size_t n_entries, size_t page_s return dmat; } +std::unique_ptr CreateSparsePageDMatrixWithRC(size_t n_rows, size_t n_cols, + size_t page_size, bool deterministic) { + if (!n_rows || !n_cols) { + return nullptr; + } + + // Create the svm file in a temp dir + dmlc::TemporaryDirectory tempdir; + const std::string tmp_file = tempdir.path + "/big.libsvm"; + + std::ofstream fo(tmp_file.c_str()); + size_t cols_per_row = ((std::max(n_rows, n_cols) - 1) / std::min(n_rows, n_cols)) + 1; + int64_t rem_cols = n_cols; + size_t col_idx = 0; + + // Random feature id generator + std::random_device rdev; + std::unique_ptr gen; + if (deterministic) { + // Seed it with a constant value for this configuration - without getting too fancy + // like ordered pairing functions and its likes to make it truely unique + gen.reset(new std::mt19937(n_rows * n_cols)); + } else { + gen.reset(new std::mt19937(rdev())); + } + std::uniform_int_distribution dis(1, n_cols); + + for (size_t i = 0; i < n_rows; ++i) { + // Make sure that all cols are slotted in the first few rows; randomly distribute the + // rest + std::stringstream row_data; + fo << i; + size_t j = 0; + if (rem_cols > 0) { + for (; j < std::min(static_cast(rem_cols), cols_per_row); ++j) { + row_data << " " << (col_idx+j) << ":" << (col_idx+j+1)*10; + } + rem_cols -= cols_per_row; + } else { + // Take some random number of colums in [1, n_cols] and slot them here + size_t ncols = dis(*gen); + for (; j < ncols; ++j) { + size_t 
+        size_t fid = (col_idx+j) % n_cols;
+        row_data << " " << fid << ":" << (fid+1)*10;
+      }
+    }
+    col_idx += j;
+
+    fo << row_data.str() << "\n";
+  }
+  fo.close();
+
+  std::unique_ptr<DMatrix> dmat(DMatrix::Load(
+    tmp_file + "#" + tmp_file + ".cache", true, false, "auto", page_size));
+  EXPECT_TRUE(FileExists(tmp_file + ".cache.row.page"));
+
+  if (!page_size) {
+    std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource);
+    source->CopyFrom(dmat.get());
+    return std::unique_ptr<DMatrix>(DMatrix::Create(std::move(source)));
+  } else {
+    return dmat;
+  }
+}
+
 gbm::GBTreeModel CreateTestModel() {
   std::vector<std::unique_ptr<RegTree>> trees;
   trees.push_back(std::unique_ptr<RegTree>(new RegTree));
diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h
index 9c829af51..0c3c8e535 100644
--- a/tests/cpp/helpers.h
+++ b/tests/cpp/helpers.h
@@ -165,6 +165,27 @@ std::shared_ptr<xgboost::DMatrix> *CreateDMatrix(int rows, int columns,
 
 std::unique_ptr<DMatrix> CreateSparsePageDMatrix(size_t n_entries, size_t page_size);
 
+/**
+ * \fn std::unique_ptr<DMatrix> CreateSparsePageDMatrixWithRC(size_t n_rows, size_t n_cols,
+ *                                                            size_t page_size, bool deterministic);
+ *
+ * \brief Creates a DMatrix with n_rows records, each containing a random number of
+ *        features in [1, n_cols].
+ *
+ * \param n_rows        Number of records to create.
+ * \param n_cols        Maximum number of features within each record.
+ * \param page_size     Sparse page size for the pages within the DMatrix. If the page size is 0,
+ *                      the entire DMatrix is resident in memory; else, multiple sparse pages
+ *                      of the given page size are created and backed to disk, and have to be
+ *                      streamed in at the point of use.
+ * \param deterministic If true, the DMatrix content is identical for a given configuration;
+ *                      else, the content changes every time this method is invoked.
+ *
+ * \return The new DMatrix.
+ */
+std::unique_ptr<DMatrix> CreateSparsePageDMatrixWithRC(size_t n_rows, size_t n_cols,
+                                                       size_t page_size, bool deterministic);
+
 gbm::GBTreeModel CreateTestModel();
 
 inline LearnerTrainParam CreateEmptyGenericParam(int gpu_id, int n_gpus) {
diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu
index 6a59859e3..e7c48eaf2 100644
--- a/tests/cpp/tree/test_gpu_hist.cu
+++ b/tests/cpp/tree/test_gpu_hist.cu
@@ -77,7 +77,14 @@ void BuildGidx(DeviceShard<GradientPairPrecise>* shard, int n_rows, int n_cols,
   auto is_dense = (*dmat)->Info().num_nonzero_ ==
                   (*dmat)->Info().num_row_ * (*dmat)->Info().num_col_;
-  shard->InitCompressedData(cmat, batch, is_dense);
+  size_t row_stride = 0;
+  const auto &offset_vec = batch.offset.ConstHostVector();
+  for (size_t i = 1; i < offset_vec.size(); ++i) {
+    row_stride = std::max(row_stride, offset_vec[i] - offset_vec[i-1]);
+  }
+  shard->InitCompressedData(cmat, row_stride, is_dense);
+  shard->CreateHistIndices(
+    batch, cmat, RowStateOnDevice(batch.Size(), batch.Size()), -1);
 
   delete dmat;
 }
@@ -469,5 +476,46 @@ TEST(GpuHist, SortPosition) {
   TestSortPosition({2, 2, 2, 2}, 1, 2);
   TestSortPosition({1, 2, 1, 2, 3}, 1, 2);
 }
+
+TEST(GpuHist, TestHistogramIndex) {
+  // Test that the compressed histogram index is identical for a sparse
+  // dmatrix built with and without external memory
+
+  int constexpr kNRows = 1000, kNCols = 10;
+
+  // Build two matrices and a histogram maker for each of them
+  tree::GPUHistMakerSpecialised<GradientPairPrecise> hist_maker, hist_maker_ext;
+  std::unique_ptr<DMatrix> hist_maker_dmat(
+    CreateSparsePageDMatrixWithRC(kNRows, kNCols, 0, true));
+  std::unique_ptr<DMatrix> hist_maker_ext_dmat(
+    CreateSparsePageDMatrixWithRC(kNRows, kNCols, 128UL, true));
+
+  std::vector<std::pair<std::string, std::string>> training_params = {
+    {"max_depth", "1"},
+    {"max_leaves", "0"}
+  };
+
+  LearnerTrainParam learner_param(CreateEmptyGenericParam(0, 1));
+  hist_maker.Init(training_params, &learner_param);
+  hist_maker.InitDataOnce(hist_maker_dmat.get());
+  hist_maker_ext.Init(training_params, &learner_param);
+  hist_maker_ext.InitDataOnce(hist_maker_ext_dmat.get());
+
+  // Extract the device shard from each histogram maker, and from it the compressed
+  // histogram index
+  const auto &dev_shard = hist_maker.shards_[0];
+  std::vector<common::CompressedByteT> h_gidx_buffer(dev_shard->gidx_buffer.size());
+  dh::CopyDeviceSpanToVector(&h_gidx_buffer, dev_shard->gidx_buffer);
+
+  const auto &dev_shard_ext = hist_maker_ext.shards_[0];
+  std::vector<common::CompressedByteT> h_gidx_buffer_ext(dev_shard_ext->gidx_buffer.size());
+  dh::CopyDeviceSpanToVector(&h_gidx_buffer_ext, dev_shard_ext->gidx_buffer);
+
+  ASSERT_EQ(dev_shard->n_bins, dev_shard_ext->n_bins);
+  ASSERT_EQ(dev_shard->gidx_buffer.size(), dev_shard_ext->gidx_buffer.size());
+
+  ASSERT_EQ(h_gidx_buffer, h_gidx_buffer_ext);
+}
+
 }  // namespace tree
 }  // namespace xgboost
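
The row bookkeeping introduced in this patch is easiest to follow end to end. Below is a minimal, self-contained C++ sketch of the same BeginBatch()/Advance() accounting, using a plain DeviceRows struct and free functions in place of the real RowStateOnDevice, DeviceHistogramBuilderState and SparsePage types; those stand-in names and the shard/batch sizes in main() are invented for this illustration and are not part of the patch.

    // Standalone illustration of the BeginBatch()/Advance() bookkeeping above.
    // "DeviceRows" mirrors RowStateOnDevice; the numbers are made up for the example.
    #include <algorithm>
    #include <cassert>
    #include <cstddef>
    #include <iostream>
    #include <vector>

    struct DeviceRows {
      size_t total_rows_assigned_to_device;
      size_t total_rows_processed{0};
      size_t rows_to_process_from_batch{0};
      size_t row_offset_in_current_batch{0};
    };

    // Mirrors DeviceHistogramBuilderState::BeginBatch(): walk the devices in order and
    // hand each one the next contiguous slice of the incoming batch.
    void BeginBatch(std::vector<DeviceRows>* devices, size_t batch_size) {
      size_t rem_rows = batch_size;
      size_t offset = 0;
      for (auto& d : *devices) {
        size_t still_needed = d.total_rows_assigned_to_device - d.total_rows_processed;
        d.rows_to_process_from_batch = std::min(still_needed, rem_rows);
        d.row_offset_in_current_batch = offset;
        offset += d.rows_to_process_from_batch;
        rem_rows -= d.rows_to_process_from_batch;
      }
    }

    // Mirrors RowStateOnDevice::Advance(), which EndBatch() calls for every device.
    void EndBatch(std::vector<DeviceRows>* devices) {
      for (auto& d : *devices) {
        d.total_rows_processed += d.rows_to_process_from_batch;
        assert(d.total_rows_processed <= d.total_rows_assigned_to_device);
        d.rows_to_process_from_batch = d.row_offset_in_current_batch = 0;
      }
    }

    int main() {
      // Two devices sharding 1000 rows 600/400; three external-memory batches of 450, 450 and 100 rows.
      std::vector<DeviceRows> devices{{600}, {400}};
      for (size_t batch_size : {450u, 450u, 100u}) {
        BeginBatch(&devices, batch_size);
        for (size_t i = 0; i < devices.size(); ++i) {
          std::cout << "device " << i << ": offset " << devices[i].row_offset_in_current_batch
                    << ", rows " << devices[i].rows_to_process_from_batch << "\n";
        }
        EndBatch(&devices);
      }
      return 0;
    }

Running it shows the property the kernel launch relies on: each device's slice of a batch is contiguous (row_offset_in_current_batch), and total_rows_processed + batch_row_begin is what turns a within-batch row index into the device-global output row passed to CompressBinEllpackKernel.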
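
Since InitCompressedData() no longer sees a SparsePage, the ELLPACK row_stride has to come from elsewhere: common::DeviceSketch() now returns it for the whole DMatrix, and the updated BuildGidx() test derives it from a page's offset vector. The sketch below shows that derivation plus a rough estimate of the compressed buffer footprint. The helper names RowStride and ApproxCompressedBytes are invented here, and the byte count is only an approximation of what common::CompressedBufferWriter::CalculateBufferSize() computes (the real function also accounts for padding); the offsets and bin count in main() are made up.

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <iostream>
    #include <vector>

    // row_stride is the length of the longest row; ELLPACK pads every row to this width.
    size_t RowStride(const std::vector<size_t>& offsets) {
      size_t stride = 0;
      for (size_t i = 1; i < offsets.size(); ++i) {
        stride = std::max(stride, offsets[i] - offsets[i - 1]);
      }
      return stride;
    }

    // Approximate footprint: row_stride * n_rows symbols, each stored in
    // ceil(log2(num_symbols)) bits, where num_symbols = n_bins + 1 (the extra
    // symbol is the null/missing gidx value used for the ELLPACK padding).
    size_t ApproxCompressedBytes(size_t n_rows, size_t row_stride, size_t num_symbols) {
      size_t bits_per_symbol = static_cast<size_t>(std::ceil(std::log2(num_symbols)));
      size_t total_bits = n_rows * row_stride * bits_per_symbol;
      return (total_bits + 7) / 8;
    }

    int main() {
      // CSR-style offsets for 4 rows with 3, 1, 4 and 2 entries.
      std::vector<size_t> offsets{0, 3, 4, 8, 10};
      size_t row_stride = RowStride(offsets);
      size_t n_bins = 255;  // e.g. 255 cut points across all features
      std::cout << "row_stride = " << row_stride << ", approx buffer = "
                << ApproxCompressedBytes(4, row_stride, n_bins + 1) << " bytes\n";
      return 0;
    }

With 255 bins there are 256 symbols, so each gidx fits in 8 bits and this 4-row, stride-4 ELLPACK comes out at roughly 16 bytes; computing row_stride once per DMatrix is what lets every batch write into the same preallocated gidx_buffer.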