[EM] Reuse the quantile container. (#10761)
Use the push method to merge the quantiles instead of creating multiple containers. This reduces the memory usage by consistent pruning.
This commit is contained in:
parent
4fe67f10b4
commit
7510a87466
@ -17,7 +17,8 @@
|
||||
|
||||
#include "common.h" // safe_cuda
|
||||
#include "cuda_context.cuh" // CUDAContext
|
||||
#include "device_helpers.cuh" // TemporaryArray,SegmentId,LaunchN,Iota,device_vector
|
||||
#include "device_helpers.cuh" // TemporaryArray,SegmentId,LaunchN,Iota
|
||||
#include "device_vector.cuh" // for device_vector
|
||||
#include "xgboost/base.h" // XGBOOST_DEVICE
|
||||
#include "xgboost/context.h" // Context
|
||||
#include "xgboost/logging.h" // CHECK
|
||||
|
||||
@ -182,7 +182,8 @@ common::Span<thrust::tuple<uint64_t, uint64_t>> MergePath(
|
||||
merge_path.data(), [=] XGBOOST_DEVICE(Tuple const &t) -> Tuple {
|
||||
auto ind = get_ind(t); // == 0 if element is from x
|
||||
// x_counter, y_counter
|
||||
return thrust::tuple<std::uint64_t, std::uint64_t>{!ind, ind};
|
||||
return thrust::make_tuple(static_cast<std::uint64_t>(!ind),
|
||||
static_cast<std::uint64_t>(ind));
|
||||
});
|
||||
|
||||
// Compute the index for both x and y (which of the element in a and b are used in each
|
||||
|
||||
@ -6,7 +6,9 @@
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../collective/allreduce.h" // for Allreduce
|
||||
#include "../common/cuda_context.cuh" // for CUDAContext
|
||||
#include "../common/cuda_rt_utils.h" // for AllVisibleGPUs
|
||||
#include "../common/cuda_rt_utils.h" // for xgboost_NVTX_FN_RANGE
|
||||
#include "../common/device_vector.cuh" // for XGBCachingDeviceAllocator
|
||||
#include "../common/hist_util.cuh" // for AdapterDeviceSketch
|
||||
#include "../common/quantile.cuh" // for SketchContainer
|
||||
@ -26,8 +28,10 @@ void MakeSketches(Context const* ctx,
|
||||
DMatrixProxy* proxy, std::shared_ptr<DMatrix> ref, BatchParam const& p,
|
||||
float missing, std::shared_ptr<common::HistogramCuts> cuts, MetaInfo const& info,
|
||||
ExternalDataInfo* p_ext_info) {
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
std::vector<common::SketchContainer> sketch_containers;
|
||||
xgboost_NVTX_FN_RANGE();
|
||||
|
||||
CUDAContext const* cuctx = ctx->CUDACtx();
|
||||
std::unique_ptr<common::SketchContainer> sketch;
|
||||
auto& ext_info = *p_ext_info;
|
||||
|
||||
do {
|
||||
@ -44,12 +48,14 @@ void MakeSketches(Context const* ctx,
|
||||
<< "Inconsistent number of columns.";
|
||||
}
|
||||
if (!ref) {
|
||||
sketch_containers.emplace_back(proxy->Info().feature_types, p.max_bin, ext_info.n_features,
|
||||
data::BatchSamples(proxy), dh::GetDevice(ctx));
|
||||
auto* p_sketch = &sketch_containers.back();
|
||||
if (!sketch) {
|
||||
sketch = std::make_unique<common::SketchContainer>(
|
||||
proxy->Info().feature_types, p.max_bin, ext_info.n_features, data::BatchSamples(proxy),
|
||||
dh::GetDevice(ctx));
|
||||
}
|
||||
proxy->Info().weights_.SetDevice(dh::GetDevice(ctx));
|
||||
cuda_impl::Dispatch(proxy, [&](auto const& value) {
|
||||
common::AdapterDeviceSketch(value, p.max_bin, proxy->Info(), missing, p_sketch);
|
||||
common::AdapterDeviceSketch(value, p.max_bin, proxy->Info(), missing, sketch.get());
|
||||
});
|
||||
}
|
||||
auto batch_rows = data::BatchSamples(proxy);
|
||||
@ -60,7 +66,7 @@ void MakeSketches(Context const* ctx,
|
||||
std::max(ext_info.row_stride, cuda_impl::Dispatch(proxy, [=](auto const& value) {
|
||||
return GetRowCounts(value, row_counts_span, dh::GetDevice(ctx), missing);
|
||||
}));
|
||||
ext_info.nnz += thrust::reduce(thrust::cuda::par(alloc), row_counts.begin(), row_counts.end());
|
||||
ext_info.nnz += thrust::reduce(cuctx->CTP(), row_counts.begin(), row_counts.end());
|
||||
ext_info.n_batches++;
|
||||
ext_info.base_rows.push_back(batch_rows);
|
||||
} while (iter->Next());
|
||||
@ -73,18 +79,7 @@ void MakeSketches(Context const* ctx,
|
||||
// Get reference
|
||||
dh::safe_cuda(cudaSetDevice(dh::GetDevice(ctx).ordinal));
|
||||
if (!ref) {
|
||||
HostDeviceVector<FeatureType> ft;
|
||||
common::SketchContainer final_sketch(
|
||||
sketch_containers.empty() ? ft : sketch_containers.front().FeatureTypes(), p.max_bin,
|
||||
ext_info.n_features, ext_info.accumulated_rows, dh::GetDevice(ctx));
|
||||
for (auto const& sketch : sketch_containers) {
|
||||
final_sketch.Merge(sketch.ColumnsPtr(), sketch.Data());
|
||||
final_sketch.FixError();
|
||||
}
|
||||
sketch_containers.clear();
|
||||
sketch_containers.shrink_to_fit();
|
||||
|
||||
final_sketch.MakeCuts(ctx, cuts.get(), info.IsColumnSplit());
|
||||
sketch->MakeCuts(ctx, cuts.get(), info.IsColumnSplit());
|
||||
} else {
|
||||
GetCutsFromRef(ctx, ref, ext_info.n_features, p, cuts.get());
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user