diff --git a/src/common/algorithm.cuh b/src/common/algorithm.cuh index 5f0986d5b..137832def 100644 --- a/src/common/algorithm.cuh +++ b/src/common/algorithm.cuh @@ -17,7 +17,8 @@ #include "common.h" // safe_cuda #include "cuda_context.cuh" // CUDAContext -#include "device_helpers.cuh" // TemporaryArray,SegmentId,LaunchN,Iota,device_vector +#include "device_helpers.cuh" // TemporaryArray,SegmentId,LaunchN,Iota +#include "device_vector.cuh" // for device_vector #include "xgboost/base.h" // XGBOOST_DEVICE #include "xgboost/context.h" // Context #include "xgboost/logging.h" // CHECK diff --git a/src/common/quantile.cu b/src/common/quantile.cu index c6c665258..d807bd7af 100644 --- a/src/common/quantile.cu +++ b/src/common/quantile.cu @@ -182,7 +182,8 @@ common::Span> MergePath( merge_path.data(), [=] XGBOOST_DEVICE(Tuple const &t) -> Tuple { auto ind = get_ind(t); // == 0 if element is from x // x_counter, y_counter - return thrust::tuple{!ind, ind}; + return thrust::make_tuple(static_cast(!ind), + static_cast(ind)); }); // Compute the index for both x and y (which of the element in a and b are used in each diff --git a/src/data/quantile_dmatrix.cu b/src/data/quantile_dmatrix.cu index 47ccadd4e..f90ca882f 100644 --- a/src/data/quantile_dmatrix.cu +++ b/src/data/quantile_dmatrix.cu @@ -6,7 +6,9 @@ #include // for vector #include "../collective/allreduce.h" // for Allreduce +#include "../common/cuda_context.cuh" // for CUDAContext #include "../common/cuda_rt_utils.h" // for AllVisibleGPUs +#include "../common/cuda_rt_utils.h" // for xgboost_NVTX_FN_RANGE #include "../common/device_vector.cuh" // for XGBCachingDeviceAllocator #include "../common/hist_util.cuh" // for AdapterDeviceSketch #include "../common/quantile.cuh" // for SketchContainer @@ -26,8 +28,10 @@ void MakeSketches(Context const* ctx, DMatrixProxy* proxy, std::shared_ptr ref, BatchParam const& p, float missing, std::shared_ptr cuts, MetaInfo const& info, ExternalDataInfo* p_ext_info) { - dh::XGBCachingDeviceAllocator alloc; - std::vector sketch_containers; + xgboost_NVTX_FN_RANGE(); + + CUDAContext const* cuctx = ctx->CUDACtx(); + std::unique_ptr sketch; auto& ext_info = *p_ext_info; do { @@ -44,12 +48,14 @@ void MakeSketches(Context const* ctx, << "Inconsistent number of columns."; } if (!ref) { - sketch_containers.emplace_back(proxy->Info().feature_types, p.max_bin, ext_info.n_features, - data::BatchSamples(proxy), dh::GetDevice(ctx)); - auto* p_sketch = &sketch_containers.back(); + if (!sketch) { + sketch = std::make_unique( + proxy->Info().feature_types, p.max_bin, ext_info.n_features, data::BatchSamples(proxy), + dh::GetDevice(ctx)); + } proxy->Info().weights_.SetDevice(dh::GetDevice(ctx)); cuda_impl::Dispatch(proxy, [&](auto const& value) { - common::AdapterDeviceSketch(value, p.max_bin, proxy->Info(), missing, p_sketch); + common::AdapterDeviceSketch(value, p.max_bin, proxy->Info(), missing, sketch.get()); }); } auto batch_rows = data::BatchSamples(proxy); @@ -60,7 +66,7 @@ void MakeSketches(Context const* ctx, std::max(ext_info.row_stride, cuda_impl::Dispatch(proxy, [=](auto const& value) { return GetRowCounts(value, row_counts_span, dh::GetDevice(ctx), missing); })); - ext_info.nnz += thrust::reduce(thrust::cuda::par(alloc), row_counts.begin(), row_counts.end()); + ext_info.nnz += thrust::reduce(cuctx->CTP(), row_counts.begin(), row_counts.end()); ext_info.n_batches++; ext_info.base_rows.push_back(batch_rows); } while (iter->Next()); @@ -73,18 +79,7 @@ void MakeSketches(Context const* ctx, // Get reference dh::safe_cuda(cudaSetDevice(dh::GetDevice(ctx).ordinal)); if (!ref) { - HostDeviceVector ft; - common::SketchContainer final_sketch( - sketch_containers.empty() ? ft : sketch_containers.front().FeatureTypes(), p.max_bin, - ext_info.n_features, ext_info.accumulated_rows, dh::GetDevice(ctx)); - for (auto const& sketch : sketch_containers) { - final_sketch.Merge(sketch.ColumnsPtr(), sketch.Data()); - final_sketch.FixError(); - } - sketch_containers.clear(); - sketch_containers.shrink_to_fit(); - - final_sketch.MakeCuts(ctx, cuts.get(), info.IsColumnSplit()); + sketch->MakeCuts(ctx, cuts.get(), info.IsColumnSplit()); } else { GetCutsFromRef(ctx, ref, ext_info.n_features, p, cuts.get()); }