From 3f312e30db5973a7b19ac60c6a326a19c22255dd Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Thu, 28 Mar 2019 13:59:58 +1300 Subject: [PATCH] Retire DVec class in favour of c++20 style span for device memory. (#4293) --- src/common/device_helpers.cuh | 319 ++++++++---------------- src/linear/updater_gpu_coordinate.cu | 36 +-- src/tree/updater_gpu.cu | 120 +++++---- src/tree/updater_gpu_common.cuh | 17 +- src/tree/updater_gpu_hist.cu | 85 ++++--- tests/cpp/common/test_device_helpers.cu | 19 ++ tests/cpp/tree/test_gpu_hist.cu | 61 +++-- 7 files changed, 288 insertions(+), 369 deletions(-) diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index 325130c28..fdcc7b41b 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -227,179 +227,79 @@ inline void LaunchN(int device_idx, size_t n, L lambda) { LaunchN(device_idx, n, nullptr, lambda); } -/* - * Memory + +/** + * \brief A double buffer, useful for algorithms like sort. */ - -enum MemoryType { kDevice, kDeviceManaged }; - -template -class BulkAllocator; template -class DVec2; - -template -class DVec { - friend class DVec2; - - private: - T *ptr_; - size_t size_; - int device_idx_; - +class DoubleBuffer { public: - void ExternalAllocate(int device_idx, void *ptr, size_t size) { - if (!Empty()) { - throw std::runtime_error("Tried to allocate DVec but already allocated"); - } - ptr_ = static_cast(ptr); - size_ = size; - device_idx_ = device_idx; - safe_cuda(cudaSetDevice(device_idx_)); + cub::DoubleBuffer buff; + xgboost::common::Span a, b; + DoubleBuffer() = default; + + size_t Size() const { + CHECK_EQ(a.size(), b.size()); + return a.size(); + } + cub::DoubleBuffer &CubBuffer() { return buff; } + + T *Current() { return buff.Current(); } + xgboost::common::Span CurrentSpan() { + return xgboost::common::Span{ + buff.Current(), + static_cast::index_type>(Size())}; } - DVec() : ptr_(NULL), size_(0), device_idx_(-1) {} - size_t Size() const { return size_; } - int DeviceIdx() const { return device_idx_; } - bool Empty() const { return ptr_ == NULL || size_ == 0; } - - T *Data() { return ptr_; } - - const T *Data() const { return ptr_; } - - xgboost::common::Span GetSpan() const { - return xgboost::common::Span(ptr_, this->Size()); - } - - xgboost::common::Span GetSpan() { - return xgboost::common::Span(ptr_, this->Size()); - } - - std::vector AsVector() const { - std::vector h_vector(Size()); - safe_cuda(cudaSetDevice(device_idx_)); - safe_cuda(cudaMemcpy(h_vector.data(), ptr_, Size() * sizeof(T), - cudaMemcpyDeviceToHost)); - return h_vector; - } - - void Fill(T value) { - auto d_ptr = ptr_; - LaunchN(device_idx_, Size(), - [=] __device__(size_t idx) { d_ptr[idx] = value; }); - } - - void Print() { - auto h_vector = this->AsVector(); - for (auto e : h_vector) { - std::cout << e << " "; - } - std::cout << "\n"; - } - - thrust::device_ptr tbegin() { return thrust::device_pointer_cast(ptr_); } - - thrust::device_ptr tend() { - return thrust::device_pointer_cast(ptr_ + Size()); - } - - template - DVec &operator=(const std::vector &other) { - this->copy(other.begin(), other.end()); - return *this; - } - - DVec &operator=(DVec &other) { - if (other.Size() != Size()) { - throw std::runtime_error( - "Cannot copy assign DVec to DVec, sizes are different"); - } - safe_cuda(cudaSetDevice(this->DeviceIdx())); - if (other.DeviceIdx() == this->DeviceIdx()) { - dh::safe_cuda(cudaMemcpyAsync(this->Data(), other.Data(), - other.Size() * sizeof(T), - cudaMemcpyDeviceToDevice)); - } else { - std::cout << 
"deviceother: " << other.DeviceIdx() - << " devicethis: " << this->DeviceIdx() << std::endl; - std::cout << "size deviceother: " << other.Size() - << " devicethis: " << this->DeviceIdx() << std::endl; - throw std::runtime_error("Cannot copy to/from different devices"); - } - - return *this; - } - - template - void copy(IterT begin, IterT end) { - safe_cuda(cudaSetDevice(this->DeviceIdx())); - if (end - begin != Size()) { - LOG(FATAL) << "Cannot copy assign vector to DVec, sizes are different" << - " vector::Size(): " << end - begin << " DVec::Size(): " << Size(); - } - thrust::copy(begin, end, this->tbegin()); - } - - void copy(thrust::device_ptr begin, thrust::device_ptr end) { - safe_cuda(cudaSetDevice(this->DeviceIdx())); - if (end - begin != Size()) { - throw std::runtime_error( - "Cannot copy assign vector to dvec, sizes are different"); - } - safe_cuda(cudaMemcpyAsync(this->Data(), begin.get(), Size() * sizeof(T), - cudaMemcpyDefault)); - } + T *other() { return buff.Alternate(); } }; /** - * @class DVec2 device_helpers.cuh - * @brief wrapper for storing 2 DVec's which are needed for cub::DoubleBuffer + * \brief Copies device span to std::vector. + * + * \tparam T Generic type parameter. + * \param [in,out] dst Copy destination. + * \param src Copy source. Must be device memory. */ template -class DVec2 { - private: - DVec d1_, d2_; - cub::DoubleBuffer buff_; - int device_idx_; +void CopyDeviceSpanToVector(std::vector *dst, xgboost::common::Span src) { + CHECK_EQ(dst->size(), src.size()); + dh::safe_cuda(cudaMemcpyAsync(dst->data(), src.data(), dst->size() * sizeof(T), + cudaMemcpyDeviceToHost)); +} - public: - void ExternalAllocate(int device_idx, void *ptr1, void *ptr2, size_t size) { - if (!Empty()) { - throw std::runtime_error("Tried to allocate DVec2 but already allocated"); - } - device_idx_ = device_idx; - d1_.ExternalAllocate(device_idx_, ptr1, size); - d2_.ExternalAllocate(device_idx_, ptr2, size); - buff_.d_buffers[0] = static_cast(ptr1); - buff_.d_buffers[1] = static_cast(ptr2); - buff_.selector = 0; - } - DVec2() : d1_(), d2_(), buff_(), device_idx_(-1) {} +/** + * \brief Copies std::vector to device span. + * + * \tparam T Generic type parameter. + * \param dst Copy destination. Must be device memory. + * \param src Copy source. + */ +template +void CopyVectorToDeviceSpan(xgboost::common::Span dst ,const std::vector&src) +{ + CHECK_EQ(dst.size(), src.size()); + dh::safe_cuda(cudaMemcpyAsync(dst.data(), src.data(), dst.size() * sizeof(T), + cudaMemcpyHostToDevice)); +} - size_t Size() const { return d1_.Size(); } - int DeviceIdx() const { return device_idx_; } - bool Empty() const { return d1_.Empty() || d2_.Empty(); } - - cub::DoubleBuffer &buff() { return buff_; } - - DVec &D1() { return d1_; } - - DVec &D2() { return d2_; } - - T *Current() { return buff_.Current(); } - xgboost::common::Span CurrentSpan() { - return xgboost::common::Span{ - buff_.Current(), - static_cast::index_type>(Size())}; - } - - DVec &CurrentDVec() { return buff_.selector == 0 ? D1() : D2(); } - - T *other() { return buff_.Alternate(); } -}; +/** + * \brief Device to device memory copy from src to dst. Spans must be the same size. Use subspan to + * copy from a smaller array to a larger array. + * + * \tparam T Generic type parameter. + * \param dst Copy destination. Must be device memory. + * \param src Copy source. Must be device memory. 
+ */ +template +void CopyDeviceSpan(xgboost::common::Span dst, + xgboost::common::Span src) { + CHECK_EQ(dst.size(), src.size()); + dh::safe_cuda(cudaMemcpyAsync(dst.data(), src.data(), dst.size() * sizeof(T), + cudaMemcpyDeviceToDevice)); +} /*! \brief Helper for allocating large block of memory. */ -template class BulkAllocator { std::vector d_ptr_; std::vector size_; @@ -413,70 +313,73 @@ class BulkAllocator { } template - size_t GetSizeBytes(DVec *first_vec, size_t first_size) { + size_t GetSizeBytes(xgboost::common::Span *first_vec, size_t first_size) { return AlignRoundUp(first_size * sizeof(T)); } template - size_t GetSizeBytes(DVec *first_vec, size_t first_size, Args... args) { + size_t GetSizeBytes(xgboost::common::Span *first_vec, size_t first_size, Args... args) { return GetSizeBytes(first_vec, first_size) + GetSizeBytes(args...); } template - void AllocateDVec(int device_idx, char *ptr, DVec *first_vec, - size_t first_size) { - first_vec->ExternalAllocate(device_idx, static_cast(ptr), - first_size); + void AllocateSpan(int device_idx, char *ptr, xgboost::common::Span *first_vec, + size_t first_size) { + *first_vec = xgboost::common::Span(reinterpret_cast(ptr), first_size); } template - void AllocateDVec(int device_idx, char *ptr, DVec *first_vec, - size_t first_size, Args... args) { - AllocateDVec(device_idx, ptr, first_vec, first_size); + void AllocateSpan(int device_idx, char *ptr, xgboost::common::Span *first_vec, + size_t first_size, Args... args) { + AllocateSpan(device_idx, ptr, first_vec, first_size); ptr += AlignRoundUp(first_size * sizeof(T)); - AllocateDVec(device_idx, ptr, args...); + AllocateSpan(device_idx, ptr, args...); } - char *AllocateDevice(int device_idx, size_t bytes, MemoryType t) { + char *AllocateDevice(int device_idx, size_t bytes) { char *ptr; safe_cuda(cudaSetDevice(device_idx)); safe_cuda(cudaMalloc(&ptr, bytes)); return ptr; } + template - size_t GetSizeBytes(DVec2 *first_vec, size_t first_size) { + size_t GetSizeBytes(DoubleBuffer *first_vec, size_t first_size) { return 2 * AlignRoundUp(first_size * sizeof(T)); } template - size_t GetSizeBytes(DVec2 *first_vec, size_t first_size, Args... args) { + size_t GetSizeBytes(DoubleBuffer *first_vec, size_t first_size, Args... args) { return GetSizeBytes(first_vec, first_size) + GetSizeBytes(args...); } template - void AllocateDVec(int device_idx, char *ptr, DVec2 *first_vec, - size_t first_size) { - first_vec->ExternalAllocate( - device_idx, static_cast(ptr), - static_cast(ptr + AlignRoundUp(first_size * sizeof(T))), - first_size); + void AllocateSpan(int device_idx, char *ptr, DoubleBuffer *first_vec, + size_t first_size) { + auto ptr1 = reinterpret_cast(ptr); + auto ptr2 = ptr1 + first_size; + first_vec->a = xgboost::common::Span(ptr1, first_size); + first_vec->b = xgboost::common::Span(ptr2, first_size); + first_vec->buff.d_buffers[0] = ptr1; + first_vec->buff.d_buffers[1] = ptr2; + first_vec->buff.selector = 0; } template - void AllocateDVec(int device_idx, char *ptr, DVec2 *first_vec, + void AllocateSpan(int device_idx, char *ptr, DoubleBuffer *first_vec, size_t first_size, Args... 
args) { - AllocateDVec(device_idx, ptr, first_vec, first_size); + AllocateSpan(device_idx, ptr, first_vec, first_size); ptr += (AlignRoundUp(first_size * sizeof(T)) * 2); - AllocateDVec(device_idx, ptr, args...); + AllocateSpan(device_idx, ptr, args...); } public: BulkAllocator() = default; // prevent accidental copying, moving or assignment of this object - BulkAllocator(const BulkAllocator&) = delete; - BulkAllocator(BulkAllocator&&) = delete; - void operator=(const BulkAllocator&) = delete; - void operator=(BulkAllocator&&) = delete; + BulkAllocator(const BulkAllocator&) = delete; + BulkAllocator(BulkAllocator&&) = delete; + void operator=(const BulkAllocator&) = delete; + void operator=(BulkAllocator&&) = delete; ~BulkAllocator() { for (size_t i = 0; i < d_ptr_.size(); i++) { @@ -497,9 +400,9 @@ class BulkAllocator { void Allocate(int device_idx, Args... args) { size_t size = GetSizeBytes(args...); - char *ptr = AllocateDevice(device_idx, size, MemoryT); + char *ptr = AllocateDevice(device_idx, size); - AllocateDVec(device_idx, ptr, args...); + AllocateSpan(device_idx, ptr, args...); d_ptr_.push_back(ptr); size_.push_back(size); @@ -582,28 +485,6 @@ struct CubMemory { * Utility functions */ -template -void Print(const DVec &v, size_t max_items = 10) { - std::vector h = v.as_vector(); - for (size_t i = 0; i < std::min(max_items, h.size()); i++) { - std::cout << " " << h[i]; - } - std::cout << "\n"; -} - -/** - * @brief Helper macro to measure timing on GPU - * @param call the GPU call - * @param name name used to track later - * @param stream cuda stream where to measure time - */ -#define TIMEIT(call, name) \ - do { \ - dh::Timer t1234; \ - call; \ - t1234.printElapsed(name); \ - } while (0) - // Load balancing search template @@ -762,18 +643,18 @@ void TransformLbs(int device_idx, dh::CubMemory *temp_memory, OffsetT count, * @param offsets the segments */ template -void SegmentedSort(dh::CubMemory *tmp_mem, dh::DVec2 *keys, - dh::DVec2 *vals, int nVals, int nSegs, - const dh::DVec &offsets, int start = 0, +void SegmentedSort(dh::CubMemory *tmp_mem, dh::DoubleBuffer *keys, + dh::DoubleBuffer *vals, int nVals, int nSegs, + xgboost::common::Span offsets, int start = 0, int end = sizeof(T1) * 8) { size_t tmpSize; dh::safe_cuda(cub::DeviceSegmentedRadixSort::SortPairs( - NULL, tmpSize, keys->buff(), vals->buff(), nVals, nSegs, offsets.Data(), - offsets.Data() + 1, start, end)); + NULL, tmpSize, keys->CubBuffer(), vals->CubBuffer(), nVals, nSegs, + offsets.data(), offsets.data() + 1, start, end)); tmp_mem->LazyAllocate(tmpSize); dh::safe_cuda(cub::DeviceSegmentedRadixSort::SortPairs( - tmp_mem->d_temp_storage, tmpSize, keys->buff(), vals->buff(), nVals, - nSegs, offsets.Data(), offsets.Data() + 1, start, end)); + tmp_mem->d_temp_storage, tmpSize, keys->CubBuffer(), vals->CubBuffer(), + nVals, nSegs, offsets.data(), offsets.data() + 1, start, end)); } /** @@ -784,14 +665,14 @@ void SegmentedSort(dh::CubMemory *tmp_mem, dh::DVec2 *keys, * @param nVals number of elements in the input array */ template -void SumReduction(dh::CubMemory &tmp_mem, dh::DVec &in, dh::DVec &out, +void SumReduction(dh::CubMemory &tmp_mem, xgboost::common::Span in, xgboost::common::Span out, int nVals) { size_t tmpSize; dh::safe_cuda( - cub::DeviceReduce::Sum(NULL, tmpSize, in.Data(), out.Data(), nVals)); + cub::DeviceReduce::Sum(NULL, tmpSize, in.data(), out.data(), nVals)); tmp_mem.LazyAllocate(tmpSize); dh::safe_cuda(cub::DeviceReduce::Sum(tmp_mem.d_temp_storage, tmpSize, - in.Data(), out.Data(), nVals)); + 
in.data(), out.data(), nVals)); } /** diff --git a/src/linear/updater_gpu_coordinate.cu b/src/linear/updater_gpu_coordinate.cu index fa9146e18..0874cef50 100644 --- a/src/linear/updater_gpu_coordinate.cu +++ b/src/linear/updater_gpu_coordinate.cu @@ -19,18 +19,18 @@ namespace linear { DMLC_REGISTRY_FILE_TAG(updater_gpu_coordinate); -void RescaleIndices(size_t ridx_begin, dh::DVec *data) { - auto d_data = data->GetSpan(); - dh::LaunchN(data->DeviceIdx(), data->Size(), - [=] __device__(size_t idx) { d_data[idx].index -= ridx_begin; }); +void RescaleIndices(int device_idx, size_t ridx_begin, + common::Span data) { + dh::LaunchN(device_idx, data.size(), + [=] __device__(size_t idx) { data[idx].index -= ridx_begin; }); } class DeviceShard { int device_id_; - dh::BulkAllocator ba_; + dh::BulkAllocator ba_; std::vector row_ptr_; - dh::DVec data_; - dh::DVec gpair_; + common::Span data_; + common::Span gpair_; dh::CubMemory temp_; size_t ridx_begin_; size_t ridx_end_; @@ -73,12 +73,12 @@ class DeviceShard { auto col = batch[fidx]; auto seg = column_segments[fidx]; dh::safe_cuda(cudaMemcpy( - data_.GetSpan().subspan(row_ptr_[fidx]).data(), + data_.subspan(row_ptr_[fidx]).data(), col.data() + seg.first, sizeof(Entry) * (seg.second - seg.first), cudaMemcpyHostToDevice)); } // Rescale indices with respect to current shard - RescaleIndices(ridx_begin_, &data_); + RescaleIndices(device_id_, ridx_begin_, data_); } bool IsEmpty() { @@ -87,8 +87,10 @@ class DeviceShard { void UpdateGpair(const std::vector &host_gpair, const gbm::GBLinearModelParam &model_param) { - gpair_.copy(host_gpair.begin() + ridx_begin_ * model_param.num_output_group, - host_gpair.begin() + ridx_end_ * model_param.num_output_group); + dh::safe_cuda(cudaMemcpyAsync( + gpair_.data(), + host_gpair.data() + ridx_begin_ * model_param.num_output_group, + gpair_.size() * sizeof(GradientPair), cudaMemcpyHostToDevice)); } GradientPair GetBiasGradient(int group_idx, int num_group) { @@ -99,14 +101,14 @@ class DeviceShard { }; // NOLINT thrust::transform_iterator skip( counting, f); - auto perm = thrust::make_permutation_iterator(gpair_.tbegin(), skip); + auto perm = thrust::make_permutation_iterator(gpair_.data(), skip); return dh::SumReduction(temp_, perm, ridx_end_ - ridx_begin_); } void UpdateBiasResidual(float dbias, int group_idx, int num_groups) { if (dbias == 0.0f) return; - auto d_gpair = gpair_.GetSpan(); + auto d_gpair = gpair_; dh::LaunchN(device_id_, ridx_end_ - ridx_begin_, [=] __device__(size_t idx) { auto &g = d_gpair[idx * num_groups + group_idx]; g += GradientPair(g.GetHess() * dbias, 0); @@ -115,9 +117,9 @@ class DeviceShard { GradientPair GetGradient(int group_idx, int num_group, int fidx) { dh::safe_cuda(cudaSetDevice(device_id_)); - common::Span d_col = data_.GetSpan().subspan(row_ptr_[fidx]); + common::Span d_col = data_.subspan(row_ptr_[fidx]); size_t col_size = row_ptr_[fidx + 1] - row_ptr_[fidx]; - common::Span d_gpair = gpair_.GetSpan(); + common::Span d_gpair = gpair_; auto counting = thrust::make_counting_iterator(0ull); auto f = [=] __device__(size_t idx) { auto entry = d_col[idx]; @@ -131,8 +133,8 @@ class DeviceShard { } void UpdateResidual(float dw, int group_idx, int num_groups, int fidx) { - common::Span d_gpair = gpair_.GetSpan(); - common::Span d_col = data_.GetSpan().subspan(row_ptr_[fidx]); + common::Span d_gpair = gpair_; + common::Span d_col = data_.subspan(row_ptr_[fidx]); size_t col_size = row_ptr_[fidx + 1] - row_ptr_[fidx]; dh::LaunchN(device_id_, col_size, [=] __device__(size_t idx) { auto entry = 
d_col[idx]; diff --git a/src/tree/updater_gpu.cu b/src/tree/updater_gpu.cu index e9745b78e..3348ebf7c 100644 --- a/src/tree/updater_gpu.cu +++ b/src/tree/updater_gpu.cu @@ -545,21 +545,21 @@ class GPUMaker : public TreeUpdater { /** whether we have initialized memory already (so as not to repeat!) */ bool allocated_; /** feature values stored in column-major compressed format */ - dh::DVec2 vals_; - dh::DVec vals_cached_; + dh::DoubleBuffer vals_; + common::Span vals_cached_; /** corresponding instance id's of these featutre values */ - dh::DVec2 instIds_; - dh::DVec inst_ids_cached_; + dh::DoubleBuffer instIds_; + common::Span inst_ids_cached_; /** column offsets for these feature values */ - dh::DVec colOffsets_; - dh::DVec gradsInst_; - dh::DVec2 nodeAssigns_; - dh::DVec2 nodeLocations_; - dh::DVec nodes_; - dh::DVec node_assigns_per_inst_; - dh::DVec gradsums_; - dh::DVec gradscans_; - dh::DVec nodeSplits_; + common::Span colOffsets_; + common::Span gradsInst_; + dh::DoubleBuffer nodeAssigns_; + dh::DoubleBuffer nodeLocations_; + common::Span nodes_; + common::Span node_assigns_per_inst_; + common::Span gradsums_; + common::Span gradscans_; + common::Span nodeSplits_; int n_vals_; int n_rows_; int n_cols_; @@ -571,10 +571,10 @@ class GPUMaker : public TreeUpdater { GPUSet devices_; dh::CubMemory tmp_mem_; - dh::DVec tmpScanGradBuff_; - dh::DVec tmp_scan_key_buff_; - dh::DVec colIds_; - dh::BulkAllocator ba_; + common::Span tmpScanGradBuff_; + common::Span tmp_scan_key_buff_; + common::Span colIds_; + dh::BulkAllocator ba_; public: GPUMaker() : allocated_{false} {} @@ -615,8 +615,8 @@ class GPUMaker : public TreeUpdater { for (int i = 0; i < param_.max_depth; ++i) { if (i == 0) { // make sure to start on a fresh tree with sorted values! - vals_.CurrentDVec() = vals_cached_; - instIds_.CurrentDVec() = inst_ids_cached_; + dh::CopyDeviceSpan(vals_.CurrentSpan(), vals_cached_); + dh::CopyDeviceSpan(instIds_.CurrentSpan(), inst_ids_cached_); TransferGrads(gpair); } int nNodes = 1 << i; @@ -630,13 +630,13 @@ class GPUMaker : public TreeUpdater { } void Split2Node(int nNodes, NodeIdT nodeStart) { - auto d_nodes = nodes_.GetSpan(); - auto d_gradScans = gradscans_.GetSpan(); - auto d_gradsums = gradsums_.GetSpan(); + auto d_nodes = nodes_; + auto d_gradScans = gradscans_; + auto d_gradsums = gradsums_; auto d_nodeAssigns = nodeAssigns_.CurrentSpan(); - auto d_colIds = colIds_.GetSpan(); + auto d_colIds = colIds_; auto d_vals = vals_.Current(); - auto d_nodeSplits = nodeSplits_.Data(); + auto d_nodeSplits = nodeSplits_.data(); int nUniqKeys = nNodes; float min_split_loss = param_.min_split_loss; auto gpu_param = GPUTrainingParam(param_); @@ -679,13 +679,13 @@ class GPUMaker : public TreeUpdater { } void FindSplit(int level, NodeIdT nodeStart, int nNodes) { - ReduceScanByKey(gradsums_.GetSpan(), gradscans_.GetSpan(), gradsInst_.GetSpan(), + ReduceScanByKey(gradsums_, gradscans_, gradsInst_, instIds_.CurrentSpan(), nodeAssigns_.CurrentSpan(), n_vals_, nNodes, - n_cols_, tmpScanGradBuff_.GetSpan(), tmp_scan_key_buff_.GetSpan(), - colIds_.GetSpan(), nodeStart); - ArgMaxByKey(nodeSplits_.GetSpan(), gradscans_.GetSpan(), gradsums_.GetSpan(), - vals_.CurrentSpan(), colIds_.GetSpan(), nodeAssigns_.CurrentSpan(), - nodes_.GetSpan(), nNodes, nodeStart, n_vals_, param_, + n_cols_, tmpScanGradBuff_, tmp_scan_key_buff_, + colIds_, nodeStart); + ArgMaxByKey(nodeSplits_, gradscans_, gradsums_, + vals_.CurrentSpan(), colIds_, nodeAssigns_.CurrentSpan(), + nodes_, nNodes, nodeStart, n_vals_, param_, level <= 
kMaxAbkLevels ? kAbkSmem : kAbkGmem); Split2Node(nNodes, nodeStart); } @@ -707,7 +707,7 @@ class GPUMaker : public TreeUpdater { } std::vector fval; std::vector fId; - std::vector offset; + std::vector offset; ConvertToCsc(dmat, &fval, &fId, &offset); AllocateAllData(static_cast(offset.size())); TransferAndSortData(fval, fId, offset); @@ -715,7 +715,7 @@ class GPUMaker : public TreeUpdater { } void ConvertToCsc(DMatrix* dmat, std::vector* fval, - std::vector* fId, std::vector* offset) { + std::vector* fId, std::vector* offset) { const MetaInfo& info = dmat->Info(); CHECK(info.num_col_ < std::numeric_limits::max()); CHECK(info.num_row_ < std::numeric_limits::max()); @@ -735,7 +735,7 @@ class GPUMaker : public TreeUpdater { fval->push_back(e.fvalue); fId->push_back(inst_id); } - offset->push_back(fval->size()); + offset->push_back(static_cast(fval->size())); } } CHECK(fval->size() < std::numeric_limits::max()); @@ -744,19 +744,21 @@ class GPUMaker : public TreeUpdater { void TransferAndSortData(const std::vector& fval, const std::vector& fId, - const std::vector& offset) { - vals_.CurrentDVec() = fval; - instIds_.CurrentDVec() = fId; - colOffsets_ = offset; + const std::vector& offset) { + dh::CopyVectorToDeviceSpan(vals_.CurrentSpan(), fval); + dh::CopyVectorToDeviceSpan(instIds_.CurrentSpan(), fId); + dh::CopyVectorToDeviceSpan(colOffsets_, offset); dh::SegmentedSort(&tmp_mem_, &vals_, &instIds_, n_vals_, n_cols_, colOffsets_); - vals_cached_ = vals_.CurrentDVec(); - inst_ids_cached_ = instIds_.CurrentDVec(); - AssignColIds<<>>(colIds_.Data(), colOffsets_.Data()); + dh::CopyDeviceSpan(vals_cached_, vals_.CurrentSpan()); + dh::CopyDeviceSpan(inst_ids_cached_, instIds_.CurrentSpan()); + AssignColIds<<>>(colIds_.data(), colOffsets_.data()); } void TransferGrads(HostDeviceVector* gpair) { - gpair->GatherTo(gradsInst_.tbegin(), gradsInst_.tend()); + gpair->GatherTo( + thrust::device_pointer_cast(gradsInst_.data()), + thrust::device_pointer_cast(gradsInst_.data() + gradsInst_.size())); // evaluate the full-grad reduction for the root node dh::SumReduction(tmp_mem_, gradsInst_, gradsums_, n_rows_); } @@ -764,14 +766,22 @@ class GPUMaker : public TreeUpdater { void InitNodeData(int level, NodeIdT nodeStart, int nNodes) { // all instances belong to root node at the beginning! if (level == 0) { - nodes_.Fill(DeviceNodeStats()); - nodeAssigns_.CurrentDVec().Fill(0); - node_assigns_per_inst_.Fill(0); + thrust::fill(thrust::device_pointer_cast(nodes_.data()), + thrust::device_pointer_cast(nodes_.data() + nodes_.size()), + DeviceNodeStats()); + thrust::fill(thrust::device_pointer_cast(nodeAssigns_.Current()), + thrust::device_pointer_cast(nodeAssigns_.Current() + + nodeAssigns_.Size()), + 0); + thrust::fill(thrust::device_pointer_cast(node_assigns_per_inst_.data()), + thrust::device_pointer_cast(node_assigns_per_inst_.data() + + node_assigns_per_inst_.size()), + 0); // for root node, just update the gradient/score/weight/id info // before splitting it! 
Currently all data is on GPU, hence this // stupid little kernel - auto d_nodes = nodes_.Data(); - auto d_sums = gradsums_.Data(); + auto d_nodes = nodes_; + auto d_sums = gradsums_; auto gpu_params = GPUTrainingParam(param_); dh::LaunchN(param_.gpu_id, 1, [=] __device__(int idx) { d_nodes[0] = DeviceNodeStats(d_sums[0], 0, gpu_params); @@ -781,17 +791,17 @@ class GPUMaker : public TreeUpdater { const int ItemsPerThread = 4; // assign default node ids first int nBlks = dh::DivRoundUp(n_rows_, BlkDim); - FillDefaultNodeIds<<>>(node_assigns_per_inst_.Data(), - nodes_.Data(), n_rows_); + FillDefaultNodeIds<<>>(node_assigns_per_inst_.data(), + nodes_.data(), n_rows_); // evaluate the correct child indices of non-missing values next nBlks = dh::DivRoundUp(n_vals_, BlkDim * ItemsPerThread); AssignNodeIds<<>>( - node_assigns_per_inst_.Data(), nodeLocations_.Current(), - nodeAssigns_.Current(), instIds_.Current(), nodes_.Data(), - colOffsets_.Data(), vals_.Current(), n_vals_, n_cols_); + node_assigns_per_inst_.data(), nodeLocations_.Current(), + nodeAssigns_.Current(), instIds_.Current(), nodes_.data(), + colOffsets_.data(), vals_.Current(), n_vals_, n_cols_); // gather the node assignments across all other columns too dh::Gather(param_.gpu_id, nodeAssigns_.Current(), - node_assigns_per_inst_.Data(), instIds_.Current(), n_vals_); + node_assigns_per_inst_.data(), instIds_.Current(), n_vals_); SortKeys(level); } } @@ -804,14 +814,14 @@ class GPUMaker : public TreeUpdater { dh::Gather(param_.gpu_id, vals_.other(), vals_.Current(), instIds_.other(), instIds_.Current(), nodeLocations_.Current(), n_vals_); - vals_.buff().selector ^= 1; - instIds_.buff().selector ^= 1; + vals_.buff.selector ^= 1; + instIds_.buff.selector ^= 1; } void MarkLeaves() { const int BlkDim = 128; int nBlks = dh::DivRoundUp(maxNodes_, BlkDim); - MarkLeavesKernel<<>>(nodes_.Data(), maxNodes_); + MarkLeavesKernel<<>>(nodes_.data(), maxNodes_); } }; diff --git a/src/tree/updater_gpu_common.cuh b/src/tree/updater_gpu_common.cuh index 4c928c26e..25a744544 100644 --- a/src/tree/updater_gpu_common.cuh +++ b/src/tree/updater_gpu_common.cuh @@ -254,10 +254,13 @@ XGBOOST_DEVICE inline bool IsLeftChild(int nidx) { // Copy gpu dense representation of tree to xgboost sparse representation inline void Dense2SparseTree(RegTree* p_tree, - const dh::DVec& nodes, + common::Span nodes, const TrainParam& param) { RegTree& tree = *p_tree; - std::vector h_nodes = nodes.AsVector(); + std::vector h_nodes(nodes.size()); + dh::safe_cuda(cudaMemcpy(h_nodes.data(), nodes.data(), + nodes.size() * sizeof(DeviceNodeStats), + cudaMemcpyDeviceToHost)); int nid = 0; for (int gpu_nid = 0; gpu_nid < h_nodes.size(); gpu_nid++) { @@ -298,18 +301,16 @@ struct BernoulliRng { }; // Set gradient pair to 0 with p = 1 - subsample -inline void SubsampleGradientPair(dh::DVec* p_gpair, float subsample, - int offset = 0) { +inline void SubsampleGradientPair(int device_idx, + common::Span d_gpair, + float subsample, int offset = 0) { if (subsample == 1.0) { return; } - dh::DVec& gpair = *p_gpair; - - auto d_gpair = gpair.Data(); BernoulliRng rng(subsample, common::GlobalRandom()()); - dh::LaunchN(gpair.DeviceIdx(), gpair.Size(), [=] XGBOOST_DEVICE(int i) { + dh::LaunchN(device_idx, d_gpair.size(), [=] XGBOOST_DEVICE(int i) { if (!rng(i + offset)) { d_gpair[i] = GradientPair(); } diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index f1ddef018..75402ffab 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -601,7 +601,7 @@ 
struct DeviceShard { int n_bins; int device_id; - dh::BulkAllocator ba; + dh::BulkAllocator ba; ELLPackMatrix ellpack_matrix; @@ -610,27 +610,26 @@ struct DeviceShard { DeviceHistogram hist; /*! \brief row_ptr form HistCutMatrix. */ - dh::DVec feature_segments; + common::Span feature_segments; /*! \brief minimum value for each feature. */ - dh::DVec min_fvalue; + common::Span min_fvalue; /*! \brief Cut. */ - dh::DVec gidx_fvalue_map; + common::Span gidx_fvalue_map; /*! \brief global index of histogram, which is stored in ELLPack format. */ - dh::DVec gidx_buffer; + common::Span gidx_buffer; /*! \brief Row indices relative to this shard, necessary for sorting rows. */ - dh::DVec2 ridx; + dh::DoubleBuffer ridx; + dh::DoubleBuffer position; /*! \brief Gradient pair for each row. */ - dh::DVec gpair; + common::Span gpair; - dh::DVec2 position; - - dh::DVec monotone_constraints; - dh::DVec prediction_cache; + common::Span monotone_constraints; + common::Span prediction_cache; /*! \brief Sum gradient for each node. */ std::vector node_sum_gradients; - dh::DVec node_sum_gradients_d; + common::Span node_sum_gradients_d; /*! \brief row offset in SparsePage (the input data). */ thrust::device_vector row_ptrs; /*! \brief On-device feature set, only actually used on one of the devices */ @@ -718,7 +717,9 @@ struct DeviceShard { // Reset values for each update iteration void Reset(HostDeviceVector* dh_gpair) { dh::safe_cuda(cudaSetDevice(device_id)); - position.CurrentDVec().Fill(0); + thrust::fill( + thrust::device_pointer_cast(position.Current()), + thrust::device_pointer_cast(position.Current() + position.Size()), 0); std::fill(node_sum_gradients.begin(), node_sum_gradients.end(), GradientPair()); if (left_counts.size() < 256) { @@ -727,13 +728,16 @@ struct DeviceShard { dh::safe_cuda(cudaMemsetAsync(left_counts.data().get(), 0, sizeof(int64_t) * left_counts.size())); } - thrust::sequence(ridx.CurrentDVec().tbegin(), ridx.CurrentDVec().tend()); + thrust::sequence( + thrust::device_pointer_cast(ridx.CurrentSpan().data()), + thrust::device_pointer_cast(ridx.CurrentSpan().data() + ridx.Size())); std::fill(ridx_segments.begin(), ridx_segments.end(), Segment(0, 0)); ridx_segments.front() = Segment(0, ridx.Size()); - this->gpair.copy(dh_gpair->tcbegin(device_id), - dh_gpair->tcend(device_id)); - SubsampleGradientPair(&gpair, param.subsample, row_begin_idx); + dh::safe_cuda(cudaMemcpyAsync( + gpair.data(), dh_gpair->ConstDevicePointer(device_id), + gpair.size() * sizeof(GradientPair), cudaMemcpyHostToHost)); + SubsampleGradientPair(device_id, gpair, param.subsample, row_begin_idx); hist.Reset(); } @@ -788,7 +792,7 @@ struct DeviceShard { <<>>( hist.GetNodeHistogram(nidx), d_feature_set, node, ellpack_matrix, gpu_param, d_split_candidates, value_constraints[nidx], - monotone_constraints.GetSpan()); + monotone_constraints); // Reduce over features to find best feature auto d_result = d_result_all.subspan(i, 1); @@ -943,8 +947,8 @@ struct DeviceShard { void UpdatePredictionCache(bst_float* out_preds_d) { dh::safe_cuda(cudaSetDevice(device_id)); if (!prediction_cache_initialised) { - dh::safe_cuda(cudaMemcpyAsync(prediction_cache.Data(), out_preds_d, - prediction_cache.Size() * sizeof(bst_float), + dh::safe_cuda(cudaMemcpyAsync(prediction_cache.data(), out_preds_d, + prediction_cache.size() * sizeof(bst_float), cudaMemcpyDefault)); } prediction_cache_initialised = true; @@ -952,16 +956,16 @@ struct DeviceShard { CalcWeightTrainParam param_d(param); dh::safe_cuda( - cudaMemcpyAsync(node_sum_gradients_d.Data(), 
node_sum_gradients.data(), + cudaMemcpyAsync(node_sum_gradients_d.data(), node_sum_gradients.data(), sizeof(GradientPair) * node_sum_gradients.size(), cudaMemcpyHostToDevice)); auto d_position = position.Current(); auto d_ridx = ridx.Current(); - auto d_node_sum_gradients = node_sum_gradients_d.Data(); - auto d_prediction_cache = prediction_cache.Data(); + auto d_node_sum_gradients = node_sum_gradients_d.data(); + auto d_prediction_cache = prediction_cache.data(); dh::LaunchN( - device_id, prediction_cache.Size(), [=] __device__(int local_idx) { + device_id, prediction_cache.size(), [=] __device__(int local_idx) { int pos = d_position[local_idx]; bst_float weight = CalcWeight(param_d, d_node_sum_gradients[pos]); d_prediction_cache[d_ridx[local_idx]] += @@ -969,8 +973,8 @@ struct DeviceShard { }); dh::safe_cuda(cudaMemcpy( - out_preds_d, prediction_cache.Data(), - prediction_cache.Size() * sizeof(bst_float), cudaMemcpyDefault)); + out_preds_d, prediction_cache.data(), + prediction_cache.size() * sizeof(bst_float), cudaMemcpyDefault)); } }; @@ -981,7 +985,7 @@ struct SharedMemHistBuilder : public GPUHistBuilderBase { auto segment_begin = segment.begin; auto d_node_hist = shard->hist.GetNodeHistogram(nidx); auto d_ridx = shard->ridx.Current(); - auto d_gpair = shard->gpair.Data(); + auto d_gpair = shard->gpair.data(); auto n_elements = segment.Size() * shard->ellpack_matrix.row_stride; @@ -1006,7 +1010,7 @@ struct GlobalMemHistBuilder : public GPUHistBuilderBase { Segment segment = shard->ridx_segments[nidx]; auto d_node_hist = shard->hist.GetNodeHistogram(nidx).data(); bst_uint* d_ridx = shard->ridx.Current(); - GradientPair* d_gpair = shard->gpair.Data(); + GradientPair* d_gpair = shard->gpair.data(); size_t const n_elements = segment.Size() * shard->ellpack_matrix.row_stride; auto d_matrix = shard->ellpack_matrix; @@ -1043,10 +1047,11 @@ inline void DeviceShard::InitCompressedData( &gidx_fvalue_map, hmat.cut.size(), &min_fvalue, hmat.min_val.size(), &monotone_constraints, param.monotone_constraints.size()); - gidx_fvalue_map = hmat.cut; - min_fvalue = hmat.min_val; - feature_segments = hmat.row_ptr; - monotone_constraints = param.monotone_constraints; + + dh::CopyVectorToDeviceSpan(gidx_fvalue_map, hmat.cut); + dh::CopyVectorToDeviceSpan(min_fvalue, hmat.min_val); + dh::CopyVectorToDeviceSpan(feature_segments, hmat.row_ptr); + dh::CopyVectorToDeviceSpan(monotone_constraints, param.monotone_constraints); node_sum_gradients.resize(max_nodes); ridx_segments.resize(max_nodes); @@ -1063,14 +1068,16 @@ inline void DeviceShard::InitCompressedData( << "Max leaves and max depth cannot both be unconstrained for " "gpu_hist."; ba.Allocate(device_id, &gidx_buffer, compressed_size_bytes); - gidx_buffer.Fill(0); + thrust::fill( + thrust::device_pointer_cast(gidx_buffer.data()), + thrust::device_pointer_cast(gidx_buffer.data() + gidx_buffer.size()), 0); this->CreateHistIndices(row_batch, row_stride, null_gidx_value); ellpack_matrix.Init( - feature_segments.GetSpan(), min_fvalue.GetSpan(), - gidx_fvalue_map.GetSpan(), row_stride, - common::CompressedIterator(gidx_buffer.Data(), num_symbols), + feature_segments, min_fvalue, + gidx_fvalue_map, row_stride, + common::CompressedIterator(gidx_buffer.data(), num_symbols), is_dense, null_gidx_value); // check if we can use shared memory for building histograms @@ -1121,10 +1128,10 @@ inline void DeviceShard::CreateHistIndices( dh::DivRoundUp(row_stride, block3.y), 1); CompressBinEllpackKernel<<>> (common::CompressedBufferWriter(num_symbols), - gidx_buffer.Data(), 
+ gidx_buffer.data(), row_ptrs.data().get() + batch_row_begin, entries_d.data().get(), - gidx_fvalue_map.Data(), feature_segments.Data(), + gidx_fvalue_map.data(), feature_segments.data(), batch_row_begin, batch_nrows, row_ptrs[batch_row_begin], row_stride, null_gidx_value); @@ -1355,7 +1362,7 @@ class GPUHistMakerSpecialised{ [&](int i, std::unique_ptr>& shard) { dh::safe_cuda(cudaSetDevice(shard->device_id)); tmp_sums[i] = dh::SumReduction( - shard->temp_memory, shard->gpair.Data(), shard->gpair.Size()); + shard->temp_memory, shard->gpair.data(), shard->gpair.size()); }); GradientPair sum_gradient = diff --git a/tests/cpp/common/test_device_helpers.cu b/tests/cpp/common/test_device_helpers.cu index 8ae208cd4..426a213ab 100644 --- a/tests/cpp/common/test_device_helpers.cu +++ b/tests/cpp/common/test_device_helpers.cu @@ -7,6 +7,8 @@ #include "../../../src/common/device_helpers.cuh" #include "gtest/gtest.h" +using xgboost::common::Span; + struct Shard { int id; }; TEST(DeviceHelpers, Basic) { @@ -71,3 +73,20 @@ TEST(sumReduce, Test) { auto sum = dh::SumReduction(temp, dh::Raw(data), data.size()); ASSERT_NEAR(sum, 100.0f, 1e-5); } + +void TestAllocator() { + int n = 10; + Span a; + Span b; + Span c; + dh::BulkAllocator ba; + ba.Allocate(0, &a, n, &b, n, &c, n); + + // Should be no illegal memory accesses + dh::LaunchN(0, n, [=] __device__(size_t idx) { c[idx] = a[idx] + b[idx]; }); + + dh::safe_cuda(cudaDeviceSynchronize()); +} + +// Define the test in a function so we can use device lambda +TEST(bulkAllocator, Test) { TestAllocator(); } diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index 3616f85c8..d7e4b2654 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -56,8 +56,8 @@ TEST(GpuHist, BuildGidxDense) { DeviceShard shard(0, 0, kNRows, param); BuildGidx(&shard, kNRows, kNCols); - std::vector h_gidx_buffer; - h_gidx_buffer = shard.gidx_buffer.AsVector(); + std::vector h_gidx_buffer(shard.gidx_buffer.size()); + dh::CopyDeviceSpanToVector(&h_gidx_buffer, shard.gidx_buffer); common::CompressedIterator gidx(h_gidx_buffer.data(), 25); ASSERT_EQ(shard.ellpack_matrix.row_stride, kNCols); @@ -95,8 +95,8 @@ TEST(GpuHist, BuildGidxSparse) { DeviceShard shard(0, 0, kNRows, param); BuildGidx(&shard, kNRows, kNCols, 0.9f); - std::vector h_gidx_buffer; - h_gidx_buffer = shard.gidx_buffer.AsVector(); + std::vector h_gidx_buffer(shard.gidx_buffer.size()); + dh::CopyDeviceSpanToVector(&h_gidx_buffer, shard.gidx_buffer); common::CompressedIterator gidx(h_gidx_buffer.data(), 25); ASSERT_LE(shard.ellpack_matrix.row_stride, 3); @@ -149,17 +149,14 @@ void TestBuildHist(GPUHistBuilderBase& builder) { gpair = GradientPair(grad, hess); } - thrust::device_vector gpair (kNRows); - gpair = h_gpair; - int num_symbols = shard.n_bins + 1; thrust::host_vector h_gidx_buffer ( - shard.gidx_buffer.Size()); + shard.gidx_buffer.size()); - common::CompressedByteT* d_gidx_buffer_ptr = shard.gidx_buffer.Data(); + common::CompressedByteT* d_gidx_buffer_ptr = shard.gidx_buffer.data(); dh::safe_cuda(cudaMemcpy(h_gidx_buffer.data(), d_gidx_buffer_ptr, - sizeof(common::CompressedByteT) * shard.gidx_buffer.Size(), + sizeof(common::CompressedByteT) * shard.gidx_buffer.size(), cudaMemcpyDeviceToHost)); auto gidx = common::CompressedIterator(h_gidx_buffer.data(), num_symbols); @@ -167,9 +164,10 @@ void TestBuildHist(GPUHistBuilderBase& builder) { shard.ridx_segments.resize(1); shard.ridx_segments[0] = Segment(0, kNRows); shard.hist.AllocateHistogram(0); - 
shard.gpair.copy(gpair.begin(), gpair.end()); - thrust::sequence(shard.ridx.CurrentDVec().tbegin(), - shard.ridx.CurrentDVec().tend()); + dh::CopyVectorToDeviceSpan(shard.gpair, h_gpair); + thrust::sequence( + thrust::device_pointer_cast(shard.ridx.Current()), + thrust::device_pointer_cast(shard.ridx.Current() + shard.ridx.Size())); builder.Build(&shard, 0); DeviceHistogram d_hist = shard.hist; @@ -262,14 +260,14 @@ TEST(GpuHist, EvaluateSplits) { &(shard->min_fvalue), cmat.min_val.size(), &(shard->gidx_fvalue_map), 24, &(shard->monotone_constraints), kNCols); - shard->feature_segments.copy(cmat.row_ptr.begin(), cmat.row_ptr.end()); - shard->gidx_fvalue_map.copy(cmat.cut.begin(), cmat.cut.end()); - shard->monotone_constraints.copy(param.monotone_constraints.begin(), - param.monotone_constraints.end()); - shard->ellpack_matrix.feature_segments = shard->feature_segments.GetSpan(); - shard->ellpack_matrix.gidx_fvalue_map = shard->gidx_fvalue_map.GetSpan(); - shard->min_fvalue.copy(cmat.min_val.begin(), cmat.min_val.end()); - shard->ellpack_matrix.min_fvalue = shard->min_fvalue.GetSpan(); + dh::CopyVectorToDeviceSpan(shard->feature_segments, cmat.row_ptr); + dh::CopyVectorToDeviceSpan(shard->gidx_fvalue_map, cmat.cut); + dh::CopyVectorToDeviceSpan(shard->monotone_constraints, + param.monotone_constraints); + shard->ellpack_matrix.feature_segments = shard->feature_segments; + shard->ellpack_matrix.gidx_fvalue_map = shard->gidx_fvalue_map; + dh::CopyVectorToDeviceSpan(shard->min_fvalue, cmat.min_val); + shard->ellpack_matrix.min_fvalue = shard->min_fvalue; // Initialize DeviceShard::hist shard->hist.Init(0, (max_bins - 1) * kNCols); @@ -344,8 +342,9 @@ TEST(GpuHist, ApplySplit) { shard->ba.Allocate(0, &(shard->ridx), kNRows, &(shard->position), kNRows); shard->ellpack_matrix.row_stride = kNCols; - thrust::sequence(shard->ridx.CurrentDVec().tbegin(), - shard->ridx.CurrentDVec().tend()); + thrust::sequence( + thrust::device_pointer_cast(shard->ridx.Current()), + thrust::device_pointer_cast(shard->ridx.Current() + shard->ridx.Size())); // Initialize GPUHistMaker hist_maker.param_ = param; RegTree tree; @@ -378,12 +377,12 @@ TEST(GpuHist, ApplySplit) { &(shard->feature_segments), cmat.row_ptr.size(), &(shard->min_fvalue), cmat.min_val.size(), &(shard->gidx_fvalue_map), 24); - shard->feature_segments.copy(cmat.row_ptr.begin(), cmat.row_ptr.end()); - shard->gidx_fvalue_map.copy(cmat.cut.begin(), cmat.cut.end()); - shard->ellpack_matrix.feature_segments = shard->feature_segments.GetSpan(); - shard->ellpack_matrix.gidx_fvalue_map = shard->gidx_fvalue_map.GetSpan(); - shard->min_fvalue.copy(cmat.min_val.begin(), cmat.min_val.end()); - shard->ellpack_matrix.min_fvalue = shard->min_fvalue.GetSpan(); + dh::CopyVectorToDeviceSpan(shard->feature_segments, cmat.row_ptr); + dh::CopyVectorToDeviceSpan(shard->gidx_fvalue_map, cmat.cut); + shard->ellpack_matrix.feature_segments = shard->feature_segments; + shard->ellpack_matrix.gidx_fvalue_map = shard->gidx_fvalue_map; + dh::CopyVectorToDeviceSpan(shard->min_fvalue, cmat.min_val); + shard->ellpack_matrix.min_fvalue = shard->min_fvalue; shard->ellpack_matrix.is_dense = true; common::CompressedBufferWriter wr(num_symbols); @@ -394,10 +393,10 @@ TEST(GpuHist, ApplySplit) { std::vector h_gidx_compressed (compressed_size_bytes); wr.Write(h_gidx_compressed.data(), h_gidx.begin(), h_gidx.end()); - shard->gidx_buffer.copy(h_gidx_compressed.begin(), h_gidx_compressed.end()); + dh::CopyVectorToDeviceSpan(shard->gidx_buffer, h_gidx_compressed); 
shard->ellpack_matrix.gidx_iter = common::CompressedIterator( - shard->gidx_buffer.Data(), num_symbols); + shard->gidx_buffer.data(), num_symbols); hist_maker.info_ = &info; hist_maker.ApplySplit(candidate_entry, &tree);
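
A minimal sketch (not part of the patch) of how the Span-based helpers introduced above are intended to be used together: BulkAllocator carves non-owning xgboost::common::Span views out of one device allocation, CopyVectorToDeviceSpan / CopyDeviceSpanToVector move data between host vectors and device spans, and LaunchN runs a device lambda over the span. It mirrors the TestAllocator test added in tests/cpp/common/test_device_helpers.cu; the sizes, values, and the kernel body are illustrative only.

    // Sketch only: assumes device 0 is available and the xgboost device
    // helpers from src/common/device_helpers.cuh are in scope.
    #include <vector>
    #include "../../src/common/device_helpers.cuh"

    void SpanUsageSketch() {
      const int device_idx = 0;
      const size_t n = 100;

      // Non-owning device views; BulkAllocator owns the single allocation
      // and frees it when it goes out of scope.
      xgboost::common::Span<float> a, b;
      dh::BulkAllocator ba;
      ba.Allocate(device_idx, &a, n, &b, n);

      // Host -> device copy into the first span.
      std::vector<float> h_a(n, 1.0f);
      dh::CopyVectorToDeviceSpan(a, h_a);

      // Device lambda over the spans (spans are captured by value as views).
      dh::LaunchN(device_idx, n,
                  [=] __device__(size_t i) { b[i] = a[i] * 2.0f; });

      // Device -> host copy of the result; synchronize before reading it.
      std::vector<float> h_b(n);
      dh::CopyDeviceSpanToVector(&h_b, b);
      dh::safe_cuda(cudaDeviceSynchronize());
    }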