[EM] Reuse the quantile container. (#10761)
Use the push method to merge the quantiles instead of creating multiple containers. This reduces the memory usage by consistent pruning.
This commit is contained in:
parent
4fe67f10b4
commit
7510a87466
@ -17,7 +17,8 @@
|
|||||||
|
|
||||||
#include "common.h" // safe_cuda
|
#include "common.h" // safe_cuda
|
||||||
#include "cuda_context.cuh" // CUDAContext
|
#include "cuda_context.cuh" // CUDAContext
|
||||||
#include "device_helpers.cuh" // TemporaryArray,SegmentId,LaunchN,Iota,device_vector
|
#include "device_helpers.cuh" // TemporaryArray,SegmentId,LaunchN,Iota
|
||||||
|
#include "device_vector.cuh" // for device_vector
|
||||||
#include "xgboost/base.h" // XGBOOST_DEVICE
|
#include "xgboost/base.h" // XGBOOST_DEVICE
|
||||||
#include "xgboost/context.h" // Context
|
#include "xgboost/context.h" // Context
|
||||||
#include "xgboost/logging.h" // CHECK
|
#include "xgboost/logging.h" // CHECK
|
||||||
|
|||||||
@ -182,7 +182,8 @@ common::Span<thrust::tuple<uint64_t, uint64_t>> MergePath(
|
|||||||
merge_path.data(), [=] XGBOOST_DEVICE(Tuple const &t) -> Tuple {
|
merge_path.data(), [=] XGBOOST_DEVICE(Tuple const &t) -> Tuple {
|
||||||
auto ind = get_ind(t); // == 0 if element is from x
|
auto ind = get_ind(t); // == 0 if element is from x
|
||||||
// x_counter, y_counter
|
// x_counter, y_counter
|
||||||
return thrust::tuple<std::uint64_t, std::uint64_t>{!ind, ind};
|
return thrust::make_tuple(static_cast<std::uint64_t>(!ind),
|
||||||
|
static_cast<std::uint64_t>(ind));
|
||||||
});
|
});
|
||||||
|
|
||||||
// Compute the index for both x and y (which of the element in a and b are used in each
|
// Compute the index for both x and y (which of the element in a and b are used in each
|
||||||
|
|||||||
@ -6,7 +6,9 @@
|
|||||||
#include <vector> // for vector
|
#include <vector> // for vector
|
||||||
|
|
||||||
#include "../collective/allreduce.h" // for Allreduce
|
#include "../collective/allreduce.h" // for Allreduce
|
||||||
|
#include "../common/cuda_context.cuh" // for CUDAContext
|
||||||
#include "../common/cuda_rt_utils.h" // for AllVisibleGPUs
|
#include "../common/cuda_rt_utils.h" // for AllVisibleGPUs
|
||||||
|
#include "../common/cuda_rt_utils.h" // for xgboost_NVTX_FN_RANGE
|
||||||
#include "../common/device_vector.cuh" // for XGBCachingDeviceAllocator
|
#include "../common/device_vector.cuh" // for XGBCachingDeviceAllocator
|
||||||
#include "../common/hist_util.cuh" // for AdapterDeviceSketch
|
#include "../common/hist_util.cuh" // for AdapterDeviceSketch
|
||||||
#include "../common/quantile.cuh" // for SketchContainer
|
#include "../common/quantile.cuh" // for SketchContainer
|
||||||
@ -26,8 +28,10 @@ void MakeSketches(Context const* ctx,
|
|||||||
DMatrixProxy* proxy, std::shared_ptr<DMatrix> ref, BatchParam const& p,
|
DMatrixProxy* proxy, std::shared_ptr<DMatrix> ref, BatchParam const& p,
|
||||||
float missing, std::shared_ptr<common::HistogramCuts> cuts, MetaInfo const& info,
|
float missing, std::shared_ptr<common::HistogramCuts> cuts, MetaInfo const& info,
|
||||||
ExternalDataInfo* p_ext_info) {
|
ExternalDataInfo* p_ext_info) {
|
||||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
xgboost_NVTX_FN_RANGE();
|
||||||
std::vector<common::SketchContainer> sketch_containers;
|
|
||||||
|
CUDAContext const* cuctx = ctx->CUDACtx();
|
||||||
|
std::unique_ptr<common::SketchContainer> sketch;
|
||||||
auto& ext_info = *p_ext_info;
|
auto& ext_info = *p_ext_info;
|
||||||
|
|
||||||
do {
|
do {
|
||||||
@ -44,12 +48,14 @@ void MakeSketches(Context const* ctx,
|
|||||||
<< "Inconsistent number of columns.";
|
<< "Inconsistent number of columns.";
|
||||||
}
|
}
|
||||||
if (!ref) {
|
if (!ref) {
|
||||||
sketch_containers.emplace_back(proxy->Info().feature_types, p.max_bin, ext_info.n_features,
|
if (!sketch) {
|
||||||
data::BatchSamples(proxy), dh::GetDevice(ctx));
|
sketch = std::make_unique<common::SketchContainer>(
|
||||||
auto* p_sketch = &sketch_containers.back();
|
proxy->Info().feature_types, p.max_bin, ext_info.n_features, data::BatchSamples(proxy),
|
||||||
|
dh::GetDevice(ctx));
|
||||||
|
}
|
||||||
proxy->Info().weights_.SetDevice(dh::GetDevice(ctx));
|
proxy->Info().weights_.SetDevice(dh::GetDevice(ctx));
|
||||||
cuda_impl::Dispatch(proxy, [&](auto const& value) {
|
cuda_impl::Dispatch(proxy, [&](auto const& value) {
|
||||||
common::AdapterDeviceSketch(value, p.max_bin, proxy->Info(), missing, p_sketch);
|
common::AdapterDeviceSketch(value, p.max_bin, proxy->Info(), missing, sketch.get());
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
auto batch_rows = data::BatchSamples(proxy);
|
auto batch_rows = data::BatchSamples(proxy);
|
||||||
@ -60,7 +66,7 @@ void MakeSketches(Context const* ctx,
|
|||||||
std::max(ext_info.row_stride, cuda_impl::Dispatch(proxy, [=](auto const& value) {
|
std::max(ext_info.row_stride, cuda_impl::Dispatch(proxy, [=](auto const& value) {
|
||||||
return GetRowCounts(value, row_counts_span, dh::GetDevice(ctx), missing);
|
return GetRowCounts(value, row_counts_span, dh::GetDevice(ctx), missing);
|
||||||
}));
|
}));
|
||||||
ext_info.nnz += thrust::reduce(thrust::cuda::par(alloc), row_counts.begin(), row_counts.end());
|
ext_info.nnz += thrust::reduce(cuctx->CTP(), row_counts.begin(), row_counts.end());
|
||||||
ext_info.n_batches++;
|
ext_info.n_batches++;
|
||||||
ext_info.base_rows.push_back(batch_rows);
|
ext_info.base_rows.push_back(batch_rows);
|
||||||
} while (iter->Next());
|
} while (iter->Next());
|
||||||
@ -73,18 +79,7 @@ void MakeSketches(Context const* ctx,
|
|||||||
// Get reference
|
// Get reference
|
||||||
dh::safe_cuda(cudaSetDevice(dh::GetDevice(ctx).ordinal));
|
dh::safe_cuda(cudaSetDevice(dh::GetDevice(ctx).ordinal));
|
||||||
if (!ref) {
|
if (!ref) {
|
||||||
HostDeviceVector<FeatureType> ft;
|
sketch->MakeCuts(ctx, cuts.get(), info.IsColumnSplit());
|
||||||
common::SketchContainer final_sketch(
|
|
||||||
sketch_containers.empty() ? ft : sketch_containers.front().FeatureTypes(), p.max_bin,
|
|
||||||
ext_info.n_features, ext_info.accumulated_rows, dh::GetDevice(ctx));
|
|
||||||
for (auto const& sketch : sketch_containers) {
|
|
||||||
final_sketch.Merge(sketch.ColumnsPtr(), sketch.Data());
|
|
||||||
final_sketch.FixError();
|
|
||||||
}
|
|
||||||
sketch_containers.clear();
|
|
||||||
sketch_containers.shrink_to_fit();
|
|
||||||
|
|
||||||
final_sketch.MakeCuts(ctx, cuts.get(), info.IsColumnSplit());
|
|
||||||
} else {
|
} else {
|
||||||
GetCutsFromRef(ctx, ref, ext_info.n_features, p, cuts.get());
|
GetCutsFromRef(ctx, ref, ext_info.n_features, p, cuts.get());
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user