diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu index b207d8b31..eb2f0ca4a 100644 --- a/src/common/hist_util.cu +++ b/src/common/hist_util.cu @@ -27,9 +27,11 @@ namespace common { using WXQSketch = DenseCuts::WXQSketch; -__global__ void FindCutsK -(WXQSketch::Entry* __restrict__ cuts, const bst_float* __restrict__ data, - const float* __restrict__ cum_weights, int nsamples, int ncuts) { +__global__ void FindCutsK(WXQSketch::Entry* __restrict__ cuts, + const bst_float* __restrict__ data, + const float* __restrict__ cum_weights, + int nsamples, + int ncuts) { // ncuts < nsamples int icut = threadIdx.x + blockIdx.x * blockDim.x; if (icut >= ncuts) { @@ -42,7 +44,7 @@ __global__ void FindCutsK isample = nsamples - 1; } else { bst_float rank = cum_weights[nsamples - 1] / static_cast(ncuts - 1) - * static_cast(icut); + * static_cast(icut); // -1 is used because cum_weights is an inclusive sum isample = dh::UpperBound(cum_weights, nsamples, rank); isample = max(0, min(isample, nsamples - 1)); @@ -99,9 +101,8 @@ struct SketchContainer { std::vector col_locks_; // NOLINT static constexpr int kOmpNumColsParallelizeLimit = 1000; - SketchContainer(int max_bin, DMatrix *dmat) : - col_locks_(dmat->Info().num_col_) { - const MetaInfo &info = dmat->Info(); + SketchContainer(int max_bin, DMatrix* dmat) : col_locks_(dmat->Info().num_col_) { + const MetaInfo& info = dmat->Info(); // Initialize Sketches for this dmatrix sketches_.resize(info.num_col_); #pragma omp parallel for default(none) shared(info, max_bin) schedule(static) \ @@ -119,328 +120,339 @@ if (info.num_col_ > kOmpNumColsParallelizeLimit) // NOLINT }; // finds quantiles on the GPU -struct GPUSketcher { - // manage memory for a single GPU - class DeviceShard { - int device_; - bst_uint n_rows_; - int num_cols_{0}; - size_t n_cuts_{0}; - size_t gpu_batch_nrows_{0}; - bool has_weights_{false}; - size_t row_stride_{0}; - - const int max_bin_; - SketchContainer *sketch_container_; - dh::device_vector row_ptrs_{}; - dh::device_vector entries_{}; - dh::device_vector fvalues_{}; - dh::device_vector feature_weights_{}; - dh::device_vector fvalues_cur_{}; - dh::device_vector cuts_d_{}; - thrust::host_vector cuts_h_{}; - dh::device_vector weights_{}; - dh::device_vector weights2_{}; - std::vector n_cuts_cur_{}; - dh::device_vector num_elements_{}; - dh::device_vector tmp_storage_{}; - - public: - DeviceShard(int device, - bst_uint n_rows, - int max_bin, - SketchContainer* sketch_container) : - device_(device), - n_rows_(n_rows), - max_bin_(max_bin), - sketch_container_(sketch_container) { - } - - ~DeviceShard() { // NOLINT - dh::safe_cuda(cudaSetDevice(device_)); - } - - inline size_t GetRowStride() const { - return row_stride_; - } - - void Init(const SparsePage& row_batch, const MetaInfo& info, int gpu_batch_nrows) { - num_cols_ = info.num_col_; - has_weights_ = info.weights_.Size() > 0; - - // find the batch size - if (gpu_batch_nrows == 0) { - // By default, use no more than 1/16th of GPU memory - gpu_batch_nrows_ = dh::TotalMemory(device_) / - (16 * num_cols_ * sizeof(Entry)); - } else if (gpu_batch_nrows == -1) { - gpu_batch_nrows_ = n_rows_; - } else { - gpu_batch_nrows_ = gpu_batch_nrows; - } - if (gpu_batch_nrows_ > n_rows_) { - gpu_batch_nrows_ = n_rows_; - } - - constexpr int kFactor = 8; - double eps = 1.0 / (kFactor * max_bin_); - size_t dummy_nlevel; - WXQSketch::LimitSizeLevel(gpu_batch_nrows_, eps, &dummy_nlevel, &n_cuts_); - - // allocate necessary GPU buffers - dh::safe_cuda(cudaSetDevice(device_)); - - 
entries_.resize(gpu_batch_nrows_ * num_cols_); - fvalues_.resize(gpu_batch_nrows_ * num_cols_); - fvalues_cur_.resize(gpu_batch_nrows_); - cuts_d_.resize(n_cuts_ * num_cols_); - cuts_h_.resize(n_cuts_ * num_cols_); - weights_.resize(gpu_batch_nrows_); - weights2_.resize(gpu_batch_nrows_); - num_elements_.resize(1); - - if (has_weights_) { - feature_weights_.resize(gpu_batch_nrows_ * num_cols_); - } - n_cuts_cur_.resize(num_cols_); - - // allocate storage for CUB algorithms; the size is the maximum of the sizes - // required for various algorithm - size_t tmp_size = 0, cur_tmp_size = 0; - // size for sorting - if (has_weights_) { - cub::DeviceRadixSort::SortPairs - (nullptr, cur_tmp_size, fvalues_cur_.data().get(), - fvalues_.data().get(), weights_.data().get(), weights2_.data().get(), - gpu_batch_nrows_); - } else { - cub::DeviceRadixSort::SortKeys - (nullptr, cur_tmp_size, fvalues_cur_.data().get(), fvalues_.data().get(), - gpu_batch_nrows_); - } - tmp_size = std::max(tmp_size, cur_tmp_size); - // size for inclusive scan - if (has_weights_) { - cub::DeviceScan::InclusiveSum - (nullptr, cur_tmp_size, weights2_.begin(), weights_.begin(), gpu_batch_nrows_); - tmp_size = std::max(tmp_size, cur_tmp_size); - } - // size for reduction by key - cub::DeviceReduce::ReduceByKey - (nullptr, cur_tmp_size, fvalues_.begin(), - fvalues_cur_.begin(), weights_.begin(), weights2_.begin(), - num_elements_.begin(), thrust::maximum(), gpu_batch_nrows_); - tmp_size = std::max(tmp_size, cur_tmp_size); - // size for filtering - cub::DeviceSelect::If - (nullptr, cur_tmp_size, fvalues_.begin(), fvalues_cur_.begin(), - num_elements_.begin(), gpu_batch_nrows_, IsNotNaN()); - tmp_size = std::max(tmp_size, cur_tmp_size); - - tmp_storage_.resize(tmp_size); - } - - void FindColumnCuts(size_t batch_nrows, size_t icol) { - size_t tmp_size = tmp_storage_.size(); - // filter out NaNs in feature values - auto fvalues_begin = fvalues_.data() + icol * gpu_batch_nrows_; - cub::DeviceSelect::If - (tmp_storage_.data().get(), tmp_size, fvalues_begin, - fvalues_cur_.data(), num_elements_.begin(), batch_nrows, IsNotNaN()); - size_t nfvalues_cur = 0; - thrust::copy_n(num_elements_.begin(), 1, &nfvalues_cur); - - // compute cumulative weights using a prefix scan - if (has_weights_) { - // filter out NaNs in weights; - // since cub::DeviceSelect::If performs stable filtering, - // the weights are stored in the correct positions - auto feature_weights_begin = feature_weights_.data() + - icol * gpu_batch_nrows_; - cub::DeviceSelect::If - (tmp_storage_.data().get(), tmp_size, feature_weights_begin, - weights_.data().get(), num_elements_.begin(), batch_nrows, IsNotNaN()); - - // sort the values and weights - cub::DeviceRadixSort::SortPairs - (tmp_storage_.data().get(), tmp_size, fvalues_cur_.data().get(), - fvalues_begin.get(), weights_.data().get(), weights2_.data().get(), - nfvalues_cur); - - // sum the weights to get cumulative weight values - cub::DeviceScan::InclusiveSum - (tmp_storage_.data().get(), tmp_size, weights2_.begin(), - weights_.begin(), nfvalues_cur); - } else { - // sort the batch values - cub::DeviceRadixSort::SortKeys - (tmp_storage_.data().get(), tmp_size, - fvalues_cur_.data().get(), fvalues_begin.get(), nfvalues_cur); - - // fill in cumulative weights with counting iterator - thrust::copy_n(thrust::make_counting_iterator(1), nfvalues_cur, - weights_.begin()); - } - - // remove repeated items and sum the weights across them; - // non-negative weights are assumed - cub::DeviceReduce::ReduceByKey - 
(tmp_storage_.data().get(), tmp_size, fvalues_begin, - fvalues_cur_.begin(), weights_.begin(), weights2_.begin(), - num_elements_.begin(), thrust::maximum(), nfvalues_cur); - size_t n_unique = 0; - thrust::copy_n(num_elements_.begin(), 1, &n_unique); - - // extract cuts - n_cuts_cur_[icol] = std::min(n_cuts_, n_unique); - // if less elements than cuts: copy all elements with their weights - if (n_cuts_ > n_unique) { - float* weights2_ptr = weights2_.data().get(); - float* fvalues_ptr = fvalues_cur_.data().get(); - WXQSketch::Entry* cuts_ptr = cuts_d_.data().get() + icol * n_cuts_; - dh::LaunchN(device_, n_unique, [=]__device__(size_t i) { - bst_float rmax = weights2_ptr[i]; - bst_float rmin = i > 0 ? weights2_ptr[i - 1] : 0; - cuts_ptr[i] = WXQSketch::Entry(rmin, rmax, rmax - rmin, fvalues_ptr[i]); - }); - } else if (n_cuts_cur_[icol] > 0) { - // if more elements than cuts: use binary search on cumulative weights - int block = 256; - FindCutsK<<>> - (cuts_d_.data().get() + icol * n_cuts_, fvalues_cur_.data().get(), - weights2_.data().get(), n_unique, n_cuts_cur_[icol]); - dh::safe_cuda(cudaGetLastError()); // NOLINT - } - } - - void SketchBatch(const SparsePage& row_batch, const MetaInfo& info, - size_t gpu_batch) { - // compute start and end indices - size_t batch_row_begin = gpu_batch * gpu_batch_nrows_; - size_t batch_row_end = std::min((gpu_batch + 1) * gpu_batch_nrows_, - static_cast(n_rows_)); - size_t batch_nrows = batch_row_end - batch_row_begin; - - const auto& offset_vec = row_batch.offset.HostVector(); - const auto& data_vec = row_batch.data.HostVector(); - - size_t n_entries = offset_vec[batch_row_end] - offset_vec[batch_row_begin]; - // copy the batch to the GPU - dh::safe_cuda - (cudaMemcpyAsync(entries_.data().get(), - data_vec.data() + offset_vec[batch_row_begin], - n_entries * sizeof(Entry), cudaMemcpyDefault)); - // copy the weights if necessary - if (has_weights_) { - const auto& weights_vec = info.weights_.HostVector(); - dh::safe_cuda - (cudaMemcpyAsync(weights_.data().get(), - weights_vec.data() + batch_row_begin, - batch_nrows * sizeof(bst_float), cudaMemcpyDefault)); - } - - // unpack the features; also unpack weights if present - thrust::fill(fvalues_.begin(), fvalues_.end(), NAN); - if (has_weights_) { - thrust::fill(feature_weights_.begin(), feature_weights_.end(), NAN); - } - - dim3 block3(16, 64, 1); - // NOTE: This will typically support ~ 4M features - 64K*64 - dim3 grid3(common::DivRoundUp(batch_nrows, block3.x), - common::DivRoundUp(num_cols_, block3.y), 1); - UnpackFeaturesK<<>> - (fvalues_.data().get(), has_weights_ ? feature_weights_.data().get() : nullptr, - row_ptrs_.data().get() + batch_row_begin, - has_weights_ ? 
weights_.data().get() : nullptr, entries_.data().get(), - gpu_batch_nrows_, offset_vec[batch_row_begin], batch_nrows); - - for (int icol = 0; icol < num_cols_; ++icol) { - FindColumnCuts(batch_nrows, icol); - } - - // add cuts into sketches - thrust::copy(cuts_d_.begin(), cuts_d_.end(), cuts_h_.begin()); -#pragma omp parallel for default(none) schedule(static) \ -if (num_cols_ > SketchContainer::kOmpNumColsParallelizeLimit) // NOLINT - for (int icol = 0; icol < num_cols_; ++icol) { - WXQSketch::SummaryContainer summary; - summary.Reserve(n_cuts_); - summary.MakeFromSorted(&cuts_h_[n_cuts_ * icol], n_cuts_cur_[icol]); - - std::lock_guard lock(sketch_container_->col_locks_[icol]); - sketch_container_->sketches_[icol].PushSummary(summary); - } - } - - void ComputeRowStride() { - // Find the row stride for this batch - auto row_iter = row_ptrs_.begin(); - // Functor for finding the maximum row size for this batch - auto get_size = [=] __device__(size_t row) { - return row_iter[row + 1] - row_iter[row]; - }; // NOLINT - - auto counting = thrust::make_counting_iterator(size_t(0)); - using TransformT = thrust::transform_iterator; - TransformT row_size_iter = TransformT(counting, get_size); - row_stride_ = thrust::reduce(row_size_iter, row_size_iter + n_rows_, 0, - thrust::maximum()); - } - - void Sketch(const SparsePage& row_batch, const MetaInfo& info) { - // copy rows to the device - dh::safe_cuda(cudaSetDevice(device_)); - const auto& offset_vec = row_batch.offset.HostVector(); - row_ptrs_.resize(n_rows_ + 1); - thrust::copy(offset_vec.data(), offset_vec.data() + n_rows_ + 1, row_ptrs_.begin()); - size_t gpu_nbatches = common::DivRoundUp(n_rows_, gpu_batch_nrows_); - for (size_t gpu_batch = 0; gpu_batch < gpu_nbatches; ++gpu_batch) { - SketchBatch(row_batch, info, gpu_batch); - } - } - }; - - void SketchBatch(const SparsePage &batch, const MetaInfo &info) { - // create device shard - shard_.reset(new DeviceShard(device_, batch.Size(), max_bin_, sketch_container_.get())); - - // compute sketches for the shard - shard_->Init(batch, info, gpu_batch_nrows_); - shard_->Sketch(batch, info); - shard_->ComputeRowStride(); - - // compute row stride - row_stride_ = shard_->GetRowStride(); - } - +class GPUSketcher { + public: GPUSketcher(int device, int max_bin, int gpu_nrows) : device_(device), max_bin_(max_bin), gpu_batch_nrows_(gpu_nrows), row_stride_(0) {} + ~GPUSketcher() { // NOLINT + dh::safe_cuda(cudaSetDevice(device_)); + } + + void SketchBatch(const SparsePage &batch, const MetaInfo &info) { + n_rows_ = batch.Size(); + + Init(batch, info, gpu_batch_nrows_); + Sketch(batch, info); + ComputeRowStride(); + } + /* Builds the sketches on the GPU for the dmatrix and returns the row stride * for the entire dataset */ size_t Sketch(DMatrix *dmat, DenseCuts *hmat) { - const MetaInfo &info = dmat->Info(); + const MetaInfo& info = dmat->Info(); row_stride_ = 0; sketch_container_.reset(new SketchContainer(max_bin_, dmat)); - for (const auto &batch : dmat->GetBatches()) { + for (const auto& batch : dmat->GetBatches()) { this->SketchBatch(batch, info); } hmat->Init(&sketch_container_->sketches_, max_bin_); - return row_stride_; } + // This needs to be public because of the __device__ lambda. 
+ void ComputeRowStride() { + // Find the row stride for this batch + auto row_iter = row_ptrs_.begin(); + // Functor for finding the maximum row size for this batch + auto get_size = [=] __device__(size_t row) { + return row_iter[row + 1] - row_iter[row]; + }; // NOLINT + + auto counting = thrust::make_counting_iterator(size_t(0)); + using TransformT = thrust::transform_iterator; + TransformT row_size_iter = TransformT(counting, get_size); + row_stride_ = + thrust::reduce(row_size_iter, row_size_iter + n_rows_, 0, thrust::maximum()); + } + + // This needs to be public because of the __device__ lambda. + void FindColumnCuts(size_t batch_nrows, size_t icol) { + size_t tmp_size = tmp_storage_.size(); + // filter out NaNs in feature values + auto fvalues_begin = fvalues_.data() + icol * gpu_batch_nrows_; + cub::DeviceSelect::If(tmp_storage_.data().get(), + tmp_size, + fvalues_begin, + fvalues_cur_.data(), + num_elements_.begin(), + batch_nrows, + IsNotNaN()); + size_t nfvalues_cur = 0; + thrust::copy_n(num_elements_.begin(), 1, &nfvalues_cur); + + // compute cumulative weights using a prefix scan + if (has_weights_) { + // filter out NaNs in weights; + // since cub::DeviceSelect::If performs stable filtering, + // the weights are stored in the correct positions + auto feature_weights_begin = feature_weights_.data() + icol * gpu_batch_nrows_; + cub::DeviceSelect::If(tmp_storage_.data().get(), + tmp_size, + feature_weights_begin, + weights_.data().get(), + num_elements_.begin(), + batch_nrows, + IsNotNaN()); + + // sort the values and weights + cub::DeviceRadixSort::SortPairs(tmp_storage_.data().get(), + tmp_size, + fvalues_cur_.data().get(), + fvalues_begin.get(), + weights_.data().get(), + weights2_.data().get(), + nfvalues_cur); + + // sum the weights to get cumulative weight values + cub::DeviceScan::InclusiveSum(tmp_storage_.data().get(), + tmp_size, + weights2_.begin(), + weights_.begin(), + nfvalues_cur); + } else { + // sort the batch values + cub::DeviceRadixSort::SortKeys(tmp_storage_.data().get(), + tmp_size, + fvalues_cur_.data().get(), + fvalues_begin.get(), + nfvalues_cur); + + // fill in cumulative weights with counting iterator + thrust::copy_n(thrust::make_counting_iterator(1), nfvalues_cur, weights_.begin()); + } + + // remove repeated items and sum the weights across them; + // non-negative weights are assumed + cub::DeviceReduce::ReduceByKey(tmp_storage_.data().get(), + tmp_size, + fvalues_begin, + fvalues_cur_.begin(), + weights_.begin(), + weights2_.begin(), + num_elements_.begin(), + thrust::maximum(), + nfvalues_cur); + size_t n_unique = 0; + thrust::copy_n(num_elements_.begin(), 1, &n_unique); + + // extract cuts + n_cuts_cur_[icol] = std::min(n_cuts_, n_unique); + // if less elements than cuts: copy all elements with their weights + if (n_cuts_ > n_unique) { + float* weights2_ptr = weights2_.data().get(); + float* fvalues_ptr = fvalues_cur_.data().get(); + WXQSketch::Entry* cuts_ptr = cuts_d_.data().get() + icol * n_cuts_; + dh::LaunchN(device_, n_unique, [=]__device__(size_t i) { + bst_float rmax = weights2_ptr[i]; + bst_float rmin = i > 0 ? 
weights2_ptr[i - 1] : 0; + cuts_ptr[i] = WXQSketch::Entry(rmin, rmax, rmax - rmin, fvalues_ptr[i]); + }); + } else if (n_cuts_cur_[icol] > 0) { + // if more elements than cuts: use binary search on cumulative weights + int block = 256; + FindCutsK<<>>( + cuts_d_.data().get() + icol * n_cuts_, + fvalues_cur_.data().get(), + weights2_.data().get(), + n_unique, + n_cuts_cur_[icol]); + dh::safe_cuda(cudaGetLastError()); // NOLINT + } + } + private: - std::unique_ptr shard_; + void Init(const SparsePage& row_batch, const MetaInfo& info, int gpu_batch_nrows) { + num_cols_ = info.num_col_; + has_weights_ = info.weights_.Size() > 0; + + // find the batch size + if (gpu_batch_nrows == 0) { + // By default, use no more than 1/16th of GPU memory + gpu_batch_nrows_ = dh::TotalMemory(device_) / (16 * num_cols_ * sizeof(Entry)); + } else if (gpu_batch_nrows == -1) { + gpu_batch_nrows_ = n_rows_; + } else { + gpu_batch_nrows_ = gpu_batch_nrows; + } + if (gpu_batch_nrows_ > n_rows_) { + gpu_batch_nrows_ = n_rows_; + } + + constexpr int kFactor = 8; + double eps = 1.0 / (kFactor * max_bin_); + size_t dummy_nlevel; + WXQSketch::LimitSizeLevel(gpu_batch_nrows_, eps, &dummy_nlevel, &n_cuts_); + + // allocate necessary GPU buffers + dh::safe_cuda(cudaSetDevice(device_)); + + entries_.resize(gpu_batch_nrows_ * num_cols_); + fvalues_.resize(gpu_batch_nrows_ * num_cols_); + fvalues_cur_.resize(gpu_batch_nrows_); + cuts_d_.resize(n_cuts_ * num_cols_); + cuts_h_.resize(n_cuts_ * num_cols_); + weights_.resize(gpu_batch_nrows_); + weights2_.resize(gpu_batch_nrows_); + num_elements_.resize(1); + + if (has_weights_) { + feature_weights_.resize(gpu_batch_nrows_ * num_cols_); + } + n_cuts_cur_.resize(num_cols_); + + // allocate storage for CUB algorithms; the size is the maximum of the sizes + // required for various algorithm + size_t tmp_size = 0, cur_tmp_size = 0; + // size for sorting + if (has_weights_) { + cub::DeviceRadixSort::SortPairs(nullptr, + cur_tmp_size, + fvalues_cur_.data().get(), + fvalues_.data().get(), + weights_.data().get(), + weights2_.data().get(), + gpu_batch_nrows_); + } else { + cub::DeviceRadixSort::SortKeys(nullptr, + cur_tmp_size, + fvalues_cur_.data().get(), + fvalues_.data().get(), + gpu_batch_nrows_); + } + tmp_size = std::max(tmp_size, cur_tmp_size); + // size for inclusive scan + if (has_weights_) { + cub::DeviceScan::InclusiveSum(nullptr, + cur_tmp_size, + weights2_.begin(), + weights_.begin(), + gpu_batch_nrows_); + tmp_size = std::max(tmp_size, cur_tmp_size); + } + // size for reduction by key + cub::DeviceReduce::ReduceByKey(nullptr, + cur_tmp_size, + fvalues_.begin(), + fvalues_cur_.begin(), + weights_.begin(), + weights2_.begin(), + num_elements_.begin(), + thrust::maximum(), + gpu_batch_nrows_); + tmp_size = std::max(tmp_size, cur_tmp_size); + // size for filtering + cub::DeviceSelect::If(nullptr, + cur_tmp_size, + fvalues_.begin(), + fvalues_cur_.begin(), + num_elements_.begin(), + gpu_batch_nrows_, + IsNotNaN()); + tmp_size = std::max(tmp_size, cur_tmp_size); + + tmp_storage_.resize(tmp_size); + } + + void Sketch(const SparsePage& row_batch, const MetaInfo& info) { + // copy rows to the device + dh::safe_cuda(cudaSetDevice(device_)); + const auto& offset_vec = row_batch.offset.HostVector(); + row_ptrs_.resize(n_rows_ + 1); + thrust::copy(offset_vec.data(), offset_vec.data() + n_rows_ + 1, row_ptrs_.begin()); + size_t gpu_nbatches = common::DivRoundUp(n_rows_, gpu_batch_nrows_); + for (size_t gpu_batch = 0; gpu_batch < gpu_nbatches; ++gpu_batch) { + SketchBatch(row_batch, info, 
gpu_batch); + } + } + + void SketchBatch(const SparsePage& row_batch, const MetaInfo& info, size_t gpu_batch) { + // compute start and end indices + size_t batch_row_begin = gpu_batch * gpu_batch_nrows_; + size_t batch_row_end = std::min((gpu_batch + 1) * gpu_batch_nrows_, + static_cast(n_rows_)); + size_t batch_nrows = batch_row_end - batch_row_begin; + + const auto& offset_vec = row_batch.offset.HostVector(); + const auto& data_vec = row_batch.data.HostVector(); + + size_t n_entries = offset_vec[batch_row_end] - offset_vec[batch_row_begin]; + // copy the batch to the GPU + dh::safe_cuda(cudaMemcpyAsync(entries_.data().get(), + data_vec.data() + offset_vec[batch_row_begin], + n_entries * sizeof(Entry), + cudaMemcpyDefault)); + // copy the weights if necessary + if (has_weights_) { + const auto& weights_vec = info.weights_.HostVector(); + dh::safe_cuda(cudaMemcpyAsync(weights_.data().get(), + weights_vec.data() + batch_row_begin, + batch_nrows * sizeof(bst_float), + cudaMemcpyDefault)); + } + + // unpack the features; also unpack weights if present + thrust::fill(fvalues_.begin(), fvalues_.end(), NAN); + if (has_weights_) { + thrust::fill(feature_weights_.begin(), feature_weights_.end(), NAN); + } + + dim3 block3(16, 64, 1); + // NOTE: This will typically support ~ 4M features - 64K*64 + dim3 grid3(common::DivRoundUp(batch_nrows, block3.x), + common::DivRoundUp(num_cols_, block3.y), 1); + UnpackFeaturesK<<>>( + fvalues_.data().get(), + has_weights_ ? feature_weights_.data().get() : nullptr, + row_ptrs_.data().get() + batch_row_begin, + has_weights_ ? weights_.data().get() : nullptr, entries_.data().get(), + gpu_batch_nrows_, + offset_vec[batch_row_begin], + batch_nrows); + + for (int icol = 0; icol < num_cols_; ++icol) { + FindColumnCuts(batch_nrows, icol); + } + + // add cuts into sketches + thrust::copy(cuts_d_.begin(), cuts_d_.end(), cuts_h_.begin()); +#pragma omp parallel for default(none) schedule(static) \ +if (num_cols_ > SketchContainer::kOmpNumColsParallelizeLimit) // NOLINT + for (int icol = 0; icol < num_cols_; ++icol) { + WXQSketch::SummaryContainer summary; + summary.Reserve(n_cuts_); + summary.MakeFromSorted(&cuts_h_[n_cuts_ * icol], n_cuts_cur_[icol]); + + std::lock_guard lock(sketch_container_->col_locks_[icol]); + sketch_container_->sketches_[icol].PushSummary(summary); + } + } + const int device_; const int max_bin_; int gpu_batch_nrows_; size_t row_stride_; std::unique_ptr sketch_container_; + + bst_uint n_rows_{}; + int num_cols_{0}; + size_t n_cuts_{0}; + bool has_weights_{false}; + + dh::device_vector row_ptrs_{}; + dh::device_vector entries_{}; + dh::device_vector fvalues_{}; + dh::device_vector feature_weights_{}; + dh::device_vector fvalues_cur_{}; + dh::device_vector cuts_d_{}; + thrust::host_vector cuts_h_{}; + dh::device_vector weights_{}; + dh::device_vector weights2_{}; + std::vector n_cuts_cur_{}; + dh::device_vector num_elements_{}; + dh::device_vector tmp_storage_{}; }; size_t DeviceSketch(int device, diff --git a/src/common/host_device_vector.h b/src/common/host_device_vector.h index e54c50a24..e2d4a04f7 100644 --- a/src/common/host_device_vector.h +++ b/src/common/host_device_vector.h @@ -14,8 +14,8 @@ * Initialization/Allocation:
 * One can choose to initialize the vector on CPU or GPU during constructor.
 * (use the 'devices' argument) Or, can choose to use the 'Resize' method to
- * allocate/resize memory explicitly, and use the 'Shard' method
- * to specify the devices.
+ * allocate/resize memory explicitly, and use the 'SetDevice' method
+ * to specify the device.
 *
 * Accessing underlying data:
* Use 'HostVector' method to explicitly query for the underlying std::vector. diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu index affd9f22a..cfacec0d6 100644 --- a/src/data/ellpack_page.cu +++ b/src/data/ellpack_page.cu @@ -73,7 +73,7 @@ void EllpackPageImpl::Init(int device, int max_bin, int gpu_batch_nrows) { const auto& info = dmat_->Info(); auto is_dense = info.num_nonzero_ == info.num_row_ * info.num_col_; - // Init global data for each shard + // Init global data monitor_.StartCuda("InitCompressedData"); InitCompressedData(device, hmat, row_stride, is_dense); monitor_.StopCuda("InitCompressedData"); diff --git a/src/linear/updater_gpu_coordinate.cu b/src/linear/updater_gpu_coordinate.cu index 74bbd8e25..a249a8538 100644 --- a/src/linear/updater_gpu_coordinate.cu +++ b/src/linear/updater_gpu_coordinate.cu @@ -19,27 +19,39 @@ namespace linear { DMLC_REGISTRY_FILE_TAG(updater_gpu_coordinate); -class DeviceShard { - int device_id_; - dh::BulkAllocator ba_; - std::vector row_ptr_; - common::Span data_; - common::Span gpair_; - dh::CubMemory temp_; - size_t shard_size_; +/** + * \class GPUCoordinateUpdater + * + * \brief Coordinate descent algorithm that updates one feature per iteration + */ +class GPUCoordinateUpdater : public LinearUpdater { // NOLINT public: - DeviceShard(int device_id, - const SparsePage &batch, // column batch - bst_uint shard_size, - const LinearTrainParam ¶m, - const gbm::GBLinearModelParam &model_param) - : device_id_(device_id), - shard_size_(shard_size) { + ~GPUCoordinateUpdater() { // NOLINT + if (learner_param_->gpu_id >= 0) { + dh::safe_cuda(cudaSetDevice(learner_param_->gpu_id)); + } + } + + // set training parameter + void Configure(Args const& args) override { + tparam_.InitAllowUnknown(args); + selector_.reset(FeatureSelector::Create(tparam_.feature_selector)); + monitor_.Init("GPUCoordinateUpdater"); + } + + void LazyInitDevice(DMatrix *p_fmat, const gbm::GBLinearModelParam &model_param) { + if (learner_param_->gpu_id < 0) return; + + num_row_ = static_cast(p_fmat->Info().num_row_); + + CHECK(p_fmat->SingleColBlock()); + SparsePage const& batch = *(p_fmat->GetBatches().begin()); + if ( IsEmpty() ) { return; } - dh::safe_cuda(cudaSetDevice(device_id_)); + dh::safe_cuda(cudaSetDevice(learner_param_->gpu_id)); // The begin and end indices for the section of each column associated with - // this shard + // this device std::vector> column_segments; row_ptr_ = {0}; // iterate through columns @@ -53,13 +65,13 @@ class DeviceShard { xgboost::Entry(0, 0.0f), cmp); auto column_end = std::lower_bound(col.cbegin(), col.cend(), - xgboost::Entry(shard_size_, 0.0f), cmp); + xgboost::Entry(num_row_, 0.0f), cmp); column_segments.emplace_back( std::make_pair(column_begin - col.cbegin(), column_end - col.cbegin())); row_ptr_.push_back(row_ptr_.back() + (column_end - column_begin)); } - ba_.Allocate(device_id_, &data_, row_ptr_.back(), &gpair_, - shard_size_ * model_param.num_output_group); + ba_.Allocate(learner_param_->gpu_id, &data_, row_ptr_.back(), &gpair_, + num_row_ * model_param.num_output_group); for (size_t fidx = 0; fidx < batch.Size(); fidx++) { auto col = batch[fidx]; @@ -71,121 +83,18 @@ class DeviceShard { } } - ~DeviceShard() { // NOLINT - dh::safe_cuda(cudaSetDevice(device_id_)); - } - - bool IsEmpty() { - return shard_size_ == 0; - } - - void UpdateGpair(const std::vector &host_gpair, - const gbm::GBLinearModelParam &model_param) { - dh::safe_cuda(cudaMemcpyAsync( - gpair_.data(), - host_gpair.data(), - gpair_.size() * 
sizeof(GradientPair), cudaMemcpyHostToDevice)); - } - - GradientPair GetBiasGradient(int group_idx, int num_group) { - dh::safe_cuda(cudaSetDevice(device_id_)); - auto counting = thrust::make_counting_iterator(0ull); - auto f = [=] __device__(size_t idx) { - return idx * num_group + group_idx; - }; // NOLINT - thrust::transform_iterator skip( - counting, f); - auto perm = thrust::make_permutation_iterator(gpair_.data(), skip); - - return dh::SumReduction(temp_, perm, shard_size_); - } - - void UpdateBiasResidual(float dbias, int group_idx, int num_groups) { - if (dbias == 0.0f) return; - auto d_gpair = gpair_; - dh::LaunchN(device_id_, shard_size_, [=] __device__(size_t idx) { - auto &g = d_gpair[idx * num_groups + group_idx]; - g += GradientPair(g.GetHess() * dbias, 0); - }); - } - - GradientPair GetGradient(int group_idx, int num_group, int fidx) { - dh::safe_cuda(cudaSetDevice(device_id_)); - common::Span d_col = data_.subspan(row_ptr_[fidx]); - size_t col_size = row_ptr_[fidx + 1] - row_ptr_[fidx]; - common::Span d_gpair = gpair_; - auto counting = thrust::make_counting_iterator(0ull); - auto f = [=] __device__(size_t idx) { - auto entry = d_col[idx]; - auto g = d_gpair[entry.index * num_group + group_idx]; - return GradientPair(g.GetGrad() * entry.fvalue, - g.GetHess() * entry.fvalue * entry.fvalue); - }; // NOLINT - thrust::transform_iterator - multiply_iterator(counting, f); - return dh::SumReduction(temp_, multiply_iterator, col_size); - } - - void UpdateResidual(float dw, int group_idx, int num_groups, int fidx) { - common::Span d_gpair = gpair_; - common::Span d_col = data_.subspan(row_ptr_[fidx]); - size_t col_size = row_ptr_[fidx + 1] - row_ptr_[fidx]; - dh::LaunchN(device_id_, col_size, [=] __device__(size_t idx) { - auto entry = d_col[idx]; - auto &g = d_gpair[entry.index * num_groups + group_idx]; - g += GradientPair(g.GetHess() * dw * entry.fvalue, 0); - }); - } -}; - -/** - * \class GPUCoordinateUpdater - * - * \brief Coordinate descent algorithm that updates one feature per iteration - */ - -class GPUCoordinateUpdater : public LinearUpdater { // NOLINT - public: - // set training parameter - void Configure(Args const& args) override { - tparam_.InitAllowUnknown(args); - selector_.reset(FeatureSelector::Create(tparam_.feature_selector)); - monitor_.Init("GPUCoordinateUpdater"); - } - - void LazyInitShards(DMatrix *p_fmat, - const gbm::GBLinearModelParam &model_param) { - if (shard_) return; - - device_ = learner_param_->gpu_id; - - auto num_row = static_cast(p_fmat->Info().num_row_); - - // Partition input matrix into row segments - std::vector row_segments; - row_segments.push_back(0); - size_t shard_size = num_row; - row_segments.push_back(shard_size); - - CHECK(p_fmat->SingleColBlock()); - SparsePage const& batch = *(p_fmat->GetBatches().begin()); - - // Create device shard - shard_.reset(new DeviceShard(device_, batch, shard_size, tparam_, model_param)); - } - void Update(HostDeviceVector *in_gpair, DMatrix *p_fmat, gbm::GBLinearModel *model, double sum_instance_weight) override { tparam_.DenormalizePenalties(sum_instance_weight); - monitor_.Start("LazyInitShards"); - this->LazyInitShards(p_fmat, model->param); - monitor_.Stop("LazyInitShards"); + monitor_.Start("LazyInitDevice"); + this->LazyInitDevice(p_fmat, model->param); + monitor_.Stop("LazyInitDevice"); monitor_.Start("UpdateGpair"); auto &in_gpair_host = in_gpair->ConstHostVector(); // Update gpair - if (shard_) { - shard_->UpdateGpair(in_gpair_host, model->param); + if (learner_param_->gpu_id >= 0) { + 
this->UpdateGpair(in_gpair_host, model->param); } monitor_.Stop("UpdateGpair"); @@ -197,8 +106,7 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT tparam_.reg_alpha_denorm, tparam_.reg_lambda_denorm, coord_param_.top_k); monitor_.Start("UpdateFeature"); - for (auto group_idx = 0; group_idx < model->param.num_output_group; - ++group_idx) { + for (auto group_idx = 0; group_idx < model->param.num_output_group; ++group_idx) { for (auto i = 0U; i < model->param.num_feature; i++) { auto fidx = selector_->NextFeature( i, *model, group_idx, in_gpair->ConstHostVector(), p_fmat, @@ -214,8 +122,8 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT for (int group_idx = 0; group_idx < model->param.num_output_group; ++group_idx) { // Get gradient auto grad = GradientPair(0, 0); - if (shard_) { - grad = shard_->GetBiasGradient(group_idx, model->param.num_output_group); + if (learner_param_->gpu_id >= 0) { + grad = GetBiasGradient(group_idx, model->param.num_output_group); } auto dbias = static_cast( tparam_.learning_rate * @@ -223,8 +131,8 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT model->bias()[group_idx] += dbias; // Update residual - if (shard_) { - shard_->UpdateBiasResidual(dbias, group_idx, model->param.num_output_group); + if (learner_param_->gpu_id >= 0) { + UpdateBiasResidual(dbias, group_idx, model->param.num_output_group); } } } @@ -235,8 +143,8 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT bst_float &w = (*model)[fidx][group_idx]; // Get gradient auto grad = GradientPair(0, 0); - if (shard_) { - grad = shard_->GetGradient(group_idx, model->param.num_output_group, fidx); + if (learner_param_->gpu_id >= 0) { + grad = GetGradient(group_idx, model->param.num_output_group, fidx); } auto dw = static_cast(tparam_.learning_rate * CoordinateDelta(grad.GetGrad(), grad.GetHess(), @@ -244,20 +152,90 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT tparam_.reg_lambda_denorm)); w += dw; - if (shard_) { - shard_->UpdateResidual(dw, group_idx, model->param.num_output_group, fidx); + if (learner_param_->gpu_id >= 0) { + UpdateResidual(dw, group_idx, model->param.num_output_group, fidx); } } + // This needs to be public because of the __device__ lambda. + GradientPair GetBiasGradient(int group_idx, int num_group) { + dh::safe_cuda(cudaSetDevice(learner_param_->gpu_id)); + auto counting = thrust::make_counting_iterator(0ull); + auto f = [=] __device__(size_t idx) { + return idx * num_group + group_idx; + }; // NOLINT + thrust::transform_iterator skip( + counting, f); + auto perm = thrust::make_permutation_iterator(gpair_.data(), skip); + + return dh::SumReduction(temp_, perm, num_row_); + } + + // This needs to be public because of the __device__ lambda. + void UpdateBiasResidual(float dbias, int group_idx, int num_groups) { + if (dbias == 0.0f) return; + auto d_gpair = gpair_; + dh::LaunchN(learner_param_->gpu_id, num_row_, [=] __device__(size_t idx) { + auto &g = d_gpair[idx * num_groups + group_idx]; + g += GradientPair(g.GetHess() * dbias, 0); + }); + } + + // This needs to be public because of the __device__ lambda. 
+ GradientPair GetGradient(int group_idx, int num_group, int fidx) { + dh::safe_cuda(cudaSetDevice(learner_param_->gpu_id)); + common::Span d_col = data_.subspan(row_ptr_[fidx]); + size_t col_size = row_ptr_[fidx + 1] - row_ptr_[fidx]; + common::Span d_gpair = gpair_; + auto counting = thrust::make_counting_iterator(0ull); + auto f = [=] __device__(size_t idx) { + auto entry = d_col[idx]; + auto g = d_gpair[entry.index * num_group + group_idx]; + return GradientPair(g.GetGrad() * entry.fvalue, + g.GetHess() * entry.fvalue * entry.fvalue); + }; // NOLINT + thrust::transform_iterator + multiply_iterator(counting, f); + return dh::SumReduction(temp_, multiply_iterator, col_size); + } + + // This needs to be public because of the __device__ lambda. + void UpdateResidual(float dw, int group_idx, int num_groups, int fidx) { + common::Span d_gpair = gpair_; + common::Span d_col = data_.subspan(row_ptr_[fidx]); + size_t col_size = row_ptr_[fidx + 1] - row_ptr_[fidx]; + dh::LaunchN(learner_param_->gpu_id, col_size, [=] __device__(size_t idx) { + auto entry = d_col[idx]; + auto &g = d_gpair[entry.index * num_groups + group_idx]; + g += GradientPair(g.GetHess() * dw * entry.fvalue, 0); + }); + } + private: + bool IsEmpty() { + return num_row_ == 0; + } + + void UpdateGpair(const std::vector &host_gpair, + const gbm::GBLinearModelParam &model_param) { + dh::safe_cuda(cudaMemcpyAsync( + gpair_.data(), + host_gpair.data(), + gpair_.size() * sizeof(GradientPair), cudaMemcpyHostToDevice)); + } + // training parameter LinearTrainParam tparam_; CoordinateParam coord_param_; - int device_{}; std::unique_ptr selector_; common::Monitor monitor_; - std::unique_ptr shard_{nullptr}; + dh::BulkAllocator ba_; + std::vector row_ptr_; + common::Span data_; + common::Span gpair_; + dh::CubMemory temp_; + size_t num_row_; }; XGBOOST_REGISTER_LINEAR_UPDATER(GPUCoordinateUpdater, "gpu_coord_descent") diff --git a/src/objective/multiclass_obj.cu b/src/objective/multiclass_obj.cu index 80cf69410..c90d927e8 100644 --- a/src/objective/multiclass_obj.cu +++ b/src/objective/multiclass_obj.cu @@ -33,9 +33,7 @@ struct SoftmaxMultiClassParam : public dmlc::Parameter { .describe("Number of output class in the multi-class classification."); } }; -// TODO(trivialfis): Currently the sharding in softmax is less than ideal -// due to repeated copying data between CPU and GPUs. Maybe we just use single -// GPU? 
+ class SoftmaxMultiClassObj : public ObjFunction { public: explicit SoftmaxMultiClassObj(bool output_prob) diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index d50cbcd79..6a2fce52c 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -195,77 +195,52 @@ __global__ void PredictKernel(common::Span d_nodes, class GPUPredictor : public xgboost::Predictor { private: - struct DeviceShard { - DeviceShard() : device_{-1} {} + void InitModel(const gbm::GBTreeModel& model, + const thrust::host_vector& h_tree_segments, + const thrust::host_vector& h_nodes, + size_t tree_begin, size_t tree_end) { + dh::safe_cuda(cudaSetDevice(device_)); + nodes_.resize(h_nodes.size()); + dh::safe_cuda(cudaMemcpyAsync(nodes_.data().get(), h_nodes.data(), + sizeof(DevicePredictionNode) * h_nodes.size(), + cudaMemcpyHostToDevice)); + tree_segments_.resize(h_tree_segments.size()); + dh::safe_cuda(cudaMemcpyAsync(tree_segments_.data().get(), h_tree_segments.data(), + sizeof(size_t) * h_tree_segments.size(), + cudaMemcpyHostToDevice)); + tree_group_.resize(model.tree_info.size()); + dh::safe_cuda(cudaMemcpyAsync(tree_group_.data().get(), model.tree_info.data(), + sizeof(int) * model.tree_info.size(), + cudaMemcpyHostToDevice)); + this->tree_begin_ = tree_begin; + this->tree_end_ = tree_end; + this->num_group_ = model.param.num_output_group; + } - ~DeviceShard() { - if (device_ >= 0) { - dh::safe_cuda(cudaSetDevice(device_)); - } + void PredictInternal(const SparsePage& batch, + size_t num_features, + HostDeviceVector* predictions, + size_t batch_offset) { + dh::safe_cuda(cudaSetDevice(device_)); + const int BLOCK_THREADS = 128; + size_t num_rows = batch.Size(); + const int GRID_SIZE = static_cast(common::DivRoundUp(num_rows, BLOCK_THREADS)); + + int shared_memory_bytes = static_cast + (sizeof(float) * num_features * BLOCK_THREADS); + bool use_shared = true; + if (shared_memory_bytes > max_shared_memory_bytes_) { + shared_memory_bytes = 0; + use_shared = false; } + size_t entry_start = 0; - void Init(int device) { - this->device_ = device; - max_shared_memory_bytes_ = dh::MaxSharedMemory(this->device_); - } - - void InitModel(const gbm::GBTreeModel& model, - const thrust::host_vector& h_tree_segments, - const thrust::host_vector& h_nodes, - size_t tree_begin, size_t tree_end) { - dh::safe_cuda(cudaSetDevice(device_)); - nodes_.resize(h_nodes.size()); - dh::safe_cuda(cudaMemcpyAsync(nodes_.data().get(), h_nodes.data(), - sizeof(DevicePredictionNode) * h_nodes.size(), - cudaMemcpyHostToDevice)); - tree_segments_.resize(h_tree_segments.size()); - dh::safe_cuda(cudaMemcpyAsync(tree_segments_.data().get(), h_tree_segments.data(), - sizeof(size_t) * h_tree_segments.size(), - cudaMemcpyHostToDevice)); - tree_group_.resize(model.tree_info.size()); - dh::safe_cuda(cudaMemcpyAsync(tree_group_.data().get(), model.tree_info.data(), - sizeof(int) * model.tree_info.size(), - cudaMemcpyHostToDevice)); - this->tree_begin_ = tree_begin; - this->tree_end_ = tree_end; - this->num_group_ = model.param.num_output_group; - } - - void PredictInternal(const SparsePage& batch, - size_t num_features, - HostDeviceVector* predictions, - size_t batch_offset) { - dh::safe_cuda(cudaSetDevice(device_)); - const int BLOCK_THREADS = 128; - size_t num_rows = batch.Size(); - const int GRID_SIZE = static_cast(common::DivRoundUp(num_rows, BLOCK_THREADS)); - - int shared_memory_bytes = static_cast - (sizeof(float) * num_features * BLOCK_THREADS); - bool use_shared = true; - if (shared_memory_bytes > 
max_shared_memory_bytes_) { - shared_memory_bytes = 0; - use_shared = false; - } - size_t entry_start = 0; - - PredictKernel<<>> - (dh::ToSpan(nodes_), predictions->DeviceSpan().subspan(batch_offset), - dh::ToSpan(tree_segments_), dh::ToSpan(tree_group_), batch.offset.DeviceSpan(), - batch.data.DeviceSpan(), this->tree_begin_, this->tree_end_, num_features, num_rows, - entry_start, use_shared, this->num_group_); - } - - private: - int device_; - dh::device_vector nodes_; - dh::device_vector tree_segments_; - dh::device_vector tree_group_; - size_t max_shared_memory_bytes_; - size_t tree_begin_; - size_t tree_end_; - int num_group_; - }; + PredictKernel<<>> + (dh::ToSpan(nodes_), predictions->DeviceSpan().subspan(batch_offset), + dh::ToSpan(tree_segments_), dh::ToSpan(tree_group_), batch.offset.DeviceSpan(), + batch.data.DeviceSpan(), this->tree_begin_, this->tree_end_, num_features, num_rows, + entry_start, use_shared, this->num_group_); + } void InitModel(const gbm::GBTreeModel& model, size_t tree_begin, size_t tree_end) { CHECK_EQ(model.param.size_leaf_vector, 0); @@ -285,7 +260,7 @@ class GPUPredictor : public xgboost::Predictor { std::copy(src_nodes.begin(), src_nodes.end(), h_nodes.begin() + h_tree_segments[tree_idx - tree_begin]); } - shard_.InitModel(model, h_tree_segments, h_nodes, tree_begin, tree_end); + InitModel(model, h_tree_segments, h_nodes, tree_begin, tree_end); } void DevicePredictInternal(DMatrix* dmat, @@ -301,7 +276,7 @@ class GPUPredictor : public xgboost::Predictor { for (auto &batch : dmat->GetBatches()) { batch.offset.SetDevice(device_); batch.data.SetDevice(device_); - shard_.PredictInternal(batch, model.param.num_feature, out_preds, batch_offset); + PredictInternal(batch, model.param.num_feature, out_preds, batch_offset); batch_offset += batch.Size() * model.param.num_output_group; } @@ -309,14 +284,20 @@ class GPUPredictor : public xgboost::Predictor { } public: - GPUPredictor() : device_{-1} {}; + GPUPredictor() : device_{-1} {} + + ~GPUPredictor() override { + if (device_ >= 0) { + dh::safe_cuda(cudaSetDevice(device_)); + } + } void PredictBatch(DMatrix* dmat, HostDeviceVector* out_preds, const gbm::GBTreeModel& model, int tree_begin, unsigned ntree_limit = 0) override { int device = learner_param_->gpu_id; CHECK_GE(device, 0); - ConfigureShard(device); + ConfigureDevice(device); if (this->PredictFromCache(dmat, out_preds, model, ntree_limit)) { return; @@ -433,22 +414,29 @@ class GPUPredictor : public xgboost::Predictor { int device = learner_param_->gpu_id; if (device >= 0) { - ConfigureShard(device); + ConfigureDevice(device); } } private: - /*! \brief Reconfigure the shard when GPU is changed. */ - void ConfigureShard(int device) { + /*! \brief Reconfigure the device when GPU is changed. 
*/ + void ConfigureDevice(int device) { if (device_ == device) return; - device_ = device; - shard_.Init(device_); + if (device_ >= 0) { + max_shared_memory_bytes_ = dh::MaxSharedMemory(device_); + } } - DeviceShard shard_; int device_; common::Monitor monitor_; + dh::device_vector nodes_; + dh::device_vector tree_segments_; + dh::device_vector tree_group_; + size_t max_shared_memory_bytes_; + size_t tree_begin_; + size_t tree_end_; + int num_group_; }; XGBOOST_REGISTER_PREDICTOR(GPUPredictor, "gpu_predictor") diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 9f79b5b19..e0c3e6209 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -435,7 +435,7 @@ __global__ void SharedMemHistKernel(xgboost::ELLPackMatrix matrix, // Manage memory for a single GPU template -struct DeviceShard { +struct GPUHistMakerDevice { int device_id; EllpackPageImpl* page; @@ -474,12 +474,12 @@ struct DeviceShard { std::function>; std::unique_ptr qexpand; - DeviceShard(int _device_id, - EllpackPageImpl* _page, - bst_uint _n_rows, - TrainParam _param, - uint32_t column_sampler_seed, - uint32_t n_features) + GPUHistMakerDevice(int _device_id, + EllpackPageImpl* _page, + bst_uint _n_rows, + TrainParam _param, + uint32_t column_sampler_seed, + uint32_t n_features) : device_id(_device_id), page(_page), n_rows(_n_rows), @@ -487,12 +487,12 @@ struct DeviceShard { prediction_cache_initialised(false), column_sampler(column_sampler_seed), interaction_constraints(param, n_features) { - monitor.Init(std::string("DeviceShard") + std::to_string(device_id)); + monitor.Init(std::string("GPUHistMakerDevice") + std::to_string(device_id)); } void InitHistogram(); - ~DeviceShard() { // NOLINT + ~GPUHistMakerDevice() { // NOLINT dh::safe_cuda(cudaSetDevice(device_id)); for (auto& stream : streams) { dh::safe_cuda(cudaStreamDestroy(stream)); @@ -781,7 +781,7 @@ struct DeviceShard { auto left_node_rows = row_partitioner->GetRows(nidx_left).size(); auto right_node_rows = row_partitioner->GetRows(nidx_right).size(); // Decide whether to build the left histogram or right histogram - // Find the largest number of training instances on any given Shard + // Find the largest number of training instances on any given device // Assume this will be the bottleneck and avoid building this node if // possible std::vector max_reduce; @@ -939,7 +939,7 @@ struct DeviceShard { }; template -inline void DeviceShard::InitHistogram() { +inline void GPUHistMakerDevice::InitHistogram() { CHECK(!(param.max_leaves == 0 && param.max_depth == 0)) << "Max leaves and max depth cannot both be unconstrained for " "gpu_hist."; @@ -1026,19 +1026,17 @@ class GPUHistMakerSpecialised { page->Init(device_, param_.max_bin, hist_maker_param_.gpu_batch_nrows); } - // Create device shard dh::safe_cuda(cudaSetDevice(device_)); - shard_.reset(new DeviceShard(device_, - page, - info_->num_row_, - param_, - column_sampling_seed, - info_->num_col_)); + maker_.reset(new GPUHistMakerDevice(device_, + page, + info_->num_row_, + param_, + column_sampling_seed, + info_->num_col_)); - // Init global data for each shard monitor_.StartCuda("InitHistogram"); dh::safe_cuda(cudaSetDevice(device_)); - shard_->InitHistogram(); + maker_->InitHistogram(); monitor_.StopCuda("InitHistogram"); p_last_fmat_ = dmat; @@ -1077,18 +1075,17 @@ class GPUHistMakerSpecialised { monitor_.StopCuda("InitData"); gpair->SetDevice(device_); - shard_->UpdateTree(gpair, p_fmat, p_tree, &reducer_); + maker_->UpdateTree(gpair, p_fmat, p_tree, &reducer_); } bool 
UpdatePredictionCache( const DMatrix* data, HostDeviceVector* p_out_preds) { - if (shard_ == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) { + if (maker_ == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) { return false; } monitor_.StartCuda("UpdatePredictionCache"); p_out_preds->SetDevice(device_); - dh::safe_cuda(cudaSetDevice(shard_->device_id)); - shard_->UpdatePredictionCache(p_out_preds->DevicePointer()); + maker_->UpdatePredictionCache(p_out_preds->DevicePointer()); monitor_.StopCuda("UpdatePredictionCache"); return true; } @@ -1096,7 +1093,7 @@ class GPUHistMakerSpecialised { TrainParam param_; // NOLINT MetaInfo* info_{}; // NOLINT - std::unique_ptr> shard_; // NOLINT + std::unique_ptr> maker_; // NOLINT private: bool initialised_; diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index 5c92f4b9e..261a9b898 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -71,40 +71,6 @@ class HistogramCutsWrapper : public common::HistogramCuts { }; } // anonymous namespace - -template -void BuildGidx(DeviceShard* shard, int n_rows, int n_cols, - bst_float sparsity=0) { - auto dmat = CreateDMatrix(n_rows, n_cols, sparsity, 3); - const SparsePage& batch = *(*dmat)->GetBatches().begin(); - - HistogramCutsWrapper cmat; - cmat.SetPtrs({0, 3, 6, 9, 12, 15, 18, 21, 24}); - // 24 cut fields, 3 cut fields for each feature (column). - cmat.SetValues({0.30f, 0.67f, 1.64f, - 0.32f, 0.77f, 1.95f, - 0.29f, 0.70f, 1.80f, - 0.32f, 0.75f, 1.85f, - 0.18f, 0.59f, 1.69f, - 0.25f, 0.74f, 2.00f, - 0.26f, 0.74f, 1.98f, - 0.26f, 0.71f, 1.83f}); - cmat.SetMins({0.1f, 0.2f, 0.3f, 0.1f, 0.2f, 0.3f, 0.2f, 0.2f}); - - auto is_dense = (*dmat)->Info().num_nonzero_ == - (*dmat)->Info().num_row_ * (*dmat)->Info().num_col_; - size_t row_stride = 0; - const auto &offset_vec = batch.offset.ConstHostVector(); - for (size_t i = 1; i < offset_vec.size(); ++i) { - row_stride = std::max(row_stride, offset_vec[i] - offset_vec[i-1]); - } - shard->InitHistogram(cmat, row_stride, is_dense); - shard->CreateHistIndices( - batch, cmat, RowStateOnDevice(batch.Size(), batch.Size()), -1); - - delete dmat; -} - std::vector GetHostHistGpair() { // 24 bins, 3 bins for each feature (column). 
std::vector hist_gpair = { @@ -131,9 +97,9 @@ void TestBuildHist(bool use_shared_memory_histograms) { }; param.Init(args); auto page = BuildEllpackPage(kNRows, kNCols); - DeviceShard shard(0, page.get(), kNRows, param, kNCols, kNCols); - shard.InitHistogram(); - + GPUHistMakerDevice maker(0, page.get(), kNRows, param, kNCols, kNCols); + maker.InitHistogram(); + xgboost::SimpleLCG gen; xgboost::SimpleRealUniformDistribution dist(0.0f, 1.0f); std::vector h_gpair(kNRows); @@ -150,13 +116,13 @@ void TestBuildHist(bool use_shared_memory_histograms) { sizeof(common::CompressedByteT) * page->gidx_buffer.size(), cudaMemcpyDeviceToHost)); - shard.row_partitioner.reset(new RowPartitioner(0, kNRows)); - shard.hist.AllocateHistogram(0); - dh::CopyVectorToDeviceSpan(shard.gpair, h_gpair); + maker.row_partitioner.reset(new RowPartitioner(0, kNRows)); + maker.hist.AllocateHistogram(0); + dh::CopyVectorToDeviceSpan(maker.gpair, h_gpair); - shard.use_shared_memory_histograms = use_shared_memory_histograms; - shard.BuildHist(0); - DeviceHistogram d_hist = shard.hist; + maker.use_shared_memory_histograms = use_shared_memory_histograms; + maker.BuildHist(0); + DeviceHistogram d_hist = maker.hist; auto node_histogram = d_hist.GetNodeHistogram(0); // d_hist.data stored in float, not gradient pair @@ -230,30 +196,29 @@ TEST(GpuHist, EvaluateSplits) { int max_bins = 4; - // Initialize DeviceShard + // Initialize GPUHistMakerDevice auto page = BuildEllpackPage(kNRows, kNCols); - std::unique_ptr> shard{ - new DeviceShard(0, page.get(), kNRows, param, kNCols, kNCols)}; - // Initialize DeviceShard::node_sum_gradients - shard->node_sum_gradients = {{6.4f, 12.8f}}; + GPUHistMakerDevice maker(0, page.get(), kNRows, param, kNCols, kNCols); + // Initialize GPUHistMakerDevice::node_sum_gradients + maker.node_sum_gradients = {{6.4f, 12.8f}}; - // Initialize DeviceShard::cut + // Initialize GPUHistMakerDevice::cut auto cmat = GetHostCutMatrix(); // Copy cut matrix to device. - shard->ba.Allocate(0, - &(page->ellpack_matrix.feature_segments), cmat.Ptrs().size(), - &(page->ellpack_matrix.min_fvalue), cmat.MinValues().size(), - &(page->ellpack_matrix.gidx_fvalue_map), 24, - &(shard->monotone_constraints), kNCols); + maker.ba.Allocate(0, + &(page->ellpack_matrix.feature_segments), cmat.Ptrs().size(), + &(page->ellpack_matrix.min_fvalue), cmat.MinValues().size(), + &(page->ellpack_matrix.gidx_fvalue_map), 24, + &(maker.monotone_constraints), kNCols); dh::CopyVectorToDeviceSpan(page->ellpack_matrix.feature_segments, cmat.Ptrs()); dh::CopyVectorToDeviceSpan(page->ellpack_matrix.gidx_fvalue_map, cmat.Values()); - dh::CopyVectorToDeviceSpan(shard->monotone_constraints, param.monotone_constraints); + dh::CopyVectorToDeviceSpan(maker.monotone_constraints, param.monotone_constraints); dh::CopyVectorToDeviceSpan(page->ellpack_matrix.min_fvalue, cmat.MinValues()); - // Initialize DeviceShard::hist - shard->hist.Init(0, (max_bins - 1) * kNCols); - shard->hist.AllocateHistogram(0); + // Initialize GPUHistMakerDevice::hist + maker.hist.Init(0, (max_bins - 1) * kNCols); + maker.hist.AllocateHistogram(0); // Each row of hist_gpair represents gpairs for one feature. // Each entry represents a bin. 
std::vector hist_gpair = GetHostHistGpair(); @@ -263,27 +228,26 @@ TEST(GpuHist, EvaluateSplits) { hist.push_back(pair.GetHess()); } - ASSERT_EQ(shard->hist.Data().size(), hist.size()); + ASSERT_EQ(maker.hist.Data().size(), hist.size()); thrust::copy(hist.begin(), hist.end(), - shard->hist.Data().begin()); + maker.hist.Data().begin()); - shard->column_sampler.Init(kNCols, - param.colsample_bynode, - param.colsample_bylevel, - param.colsample_bytree, - false); + maker.column_sampler.Init(kNCols, + param.colsample_bynode, + param.colsample_bylevel, + param.colsample_bytree, + false); RegTree tree; MetaInfo info; info.num_row_ = kNRows; info.num_col_ = kNCols; - shard->node_value_constraints.resize(1); - shard->node_value_constraints[0].lower_bound = -1.0; - shard->node_value_constraints[0].upper_bound = 1.0; + maker.node_value_constraints.resize(1); + maker.node_value_constraints[0].lower_bound = -1.0; + maker.node_value_constraints[0].upper_bound = 1.0; - std::vector res = - shard->EvaluateSplits({ 0,0 }, tree, kNCols); + std::vector res = maker.EvaluateSplits({0, 0 }, tree, kNCols); ASSERT_EQ(res[0].findex, 7); ASSERT_EQ(res[1].findex, 7); @@ -316,18 +280,18 @@ void TestHistogramIndexImpl() { hist_maker_ext.Configure(training_params, &generic_param); hist_maker_ext.InitDataOnce(hist_maker_ext_dmat.get()); - // Extract the device shard from the histogram makers and from that its compressed + // Extract the device maker from the histogram makers and from that its compressed // histogram index - const auto &dev_shard = hist_maker.shard_; - std::vector h_gidx_buffer(dev_shard->page->gidx_buffer.size()); - dh::CopyDeviceSpanToVector(&h_gidx_buffer, dev_shard->page->gidx_buffer); + const auto &maker = hist_maker.maker_; + std::vector h_gidx_buffer(maker->page->gidx_buffer.size()); + dh::CopyDeviceSpanToVector(&h_gidx_buffer, maker->page->gidx_buffer); - const auto &dev_shard_ext = hist_maker_ext.shard_; - std::vector h_gidx_buffer_ext(dev_shard_ext->page->gidx_buffer.size()); - dh::CopyDeviceSpanToVector(&h_gidx_buffer_ext, dev_shard_ext->page->gidx_buffer); + const auto &maker_ext = hist_maker_ext.maker_; + std::vector h_gidx_buffer_ext(maker_ext->page->gidx_buffer.size()); + dh::CopyDeviceSpanToVector(&h_gidx_buffer_ext, maker_ext->page->gidx_buffer); - ASSERT_EQ(dev_shard->page->n_bins, dev_shard_ext->page->n_bins); - ASSERT_EQ(dev_shard->page->gidx_buffer.size(), dev_shard_ext->page->gidx_buffer.size()); + ASSERT_EQ(maker->page->n_bins, maker_ext->page->n_bins); + ASSERT_EQ(maker->page->gidx_buffer.size(), maker_ext->page->gidx_buffer.size()); ASSERT_EQ(h_gidx_buffer, h_gidx_buffer_ext); }
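
For reference, the FindCutsK kernel touched at the top of this patch picks each cut by converting the cut index into an evenly spaced rank over the total weight and then binary-searching the inclusive prefix sum of weights (dh::UpperBound on the device). The host-side sketch below is a minimal illustration of that selection rule, not code from the patch; the sample values, the float casts, and the use of std::upper_bound in place of dh::UpperBound are assumptions made for the example.

// Illustrative host-side sketch of the FindCutsK selection logic.
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> values{0.1f, 0.4f, 0.7f, 1.3f, 2.0f};   // sorted, unique feature values
  std::vector<float> cum_weights{1.f, 2.f, 3.f, 4.f, 5.f};   // inclusive prefix sum of weights
  const int nsamples = static_cast<int>(values.size());
  const int ncuts = 3;  // fewer cuts than samples, as the kernel assumes

  for (int icut = 0; icut < ncuts; ++icut) {
    int isample;
    if (icut == 0) {
      isample = 0;                 // first cut pinned to the smallest value
    } else if (icut == ncuts - 1) {
      isample = nsamples - 1;      // last cut pinned to the largest value
    } else {
      // evenly spaced rank over the total weight
      float rank = cum_weights[nsamples - 1] /
                   static_cast<float>(ncuts - 1) * static_cast<float>(icut);
      // std::upper_bound stands in for dh::UpperBound on the device
      isample = static_cast<int>(
          std::upper_bound(cum_weights.begin(), cum_weights.end(), rank) -
          cum_weights.begin());
      isample = std::max(0, std::min(isample, nsamples - 1));
    }
    std::printf("cut %d -> value %.2f\n", icut, values[isample]);
  }
  return 0;
}

Pinning the first and last cuts to the extreme samples and clamping isample mirrors how the kernel handles the boundary cut indices before indexing into the sorted value array.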