58 lines
2.2 KiB
Plaintext
58 lines
2.2 KiB
Plaintext
/**
 * Copyright 2021-2023 by XGBoost contributors
 */
|
|
#include <memory> // for unique_ptr
|
|
|
|
#include "../common/hist_util.cuh"
|
|
#include "../common/hist_util.h" // for HistogramCuts
|
|
#include "batch_utils.h" // for CheckEmpty, RegenGHist
|
|
#include "ellpack_page.cuh"
|
|
#include "sparse_page_dmatrix.h"
|
|
#include "xgboost/context.h" // for Context
|
|
#include "xgboost/data.h" // for BatchParam
|
|
|
|
namespace xgboost::data {
|
|
/**
 * @brief Return the ELLPACK page batches of this SparsePageDMatrix.
 *
 * The page source stored in `ellpack_page_source_` is (re)built when the
 * ".ellpack.page" cache entry is not marked as written, or when
 * `detail::RegenGHist(batch_param_, param)` returns true.  Otherwise the existing
 * source is reset and reused.
 *
 * @param ctx   Context of the caller; `ctx->IsCUDA()` must hold.
 * @param param Batch parameters.  When `param.Initialized()` is true, `max_bin`
 *              must be at least 2.
 *
 * @return A batch set whose begin iterator is backed by `ellpack_page_source_`.
 */
BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const* ctx,
                                                           const BatchParam& param) {
  CHECK(ctx->IsCUDA());
  if (param.Initialized()) {
    CHECK_GE(param.max_bin, 2);
  }
  detail::CheckEmpty(batch_param_, param);
  auto id = MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_);
  this->InitializeSparsePage(ctx);

  if (!cache_info_.at(id)->written || detail::RegenGHist(batch_param_, param)) {
    // Reinitialize the cache.  The key returned by the second MakeCache call is
    // discarded; the entry is looked up with `id` again below.
    cache_info_.erase(id);
    MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_);

    // Build the histogram cuts, using the hessian-aware sketch when hessian values
    // are supplied with the batch parameter.
    std::unique_ptr<common::HistogramCuts> cuts;
    if (!param.hess.empty()) {
      cuts = std::make_unique<common::HistogramCuts>(
          common::DeviceSketchWithHessian(ctx, this, param.max_bin, param.hess));
    } else {
      cuts =
          std::make_unique<common::HistogramCuts>(common::DeviceSketch(ctx, this, param.max_bin));
    }
    this->InitializeSparsePage(ctx);  // reset after use.

    size_t row_stride = GetRowStride(this);
    this->InitializeSparsePage(ctx);  // reset after use.
    CHECK_NE(row_stride, 0);
    batch_param_ = param;

    auto ft = this->info_.feature_types.ConstDeviceSpan();
    ellpack_page_source_.reset();  // make sure resource is released before making new ones.
    ellpack_page_source_ = std::make_shared<EllpackPageSource>(
        this->missing_, ctx->Threads(), this->Info().num_col_, this->n_batches_, cache_info_.at(id),
        param, std::move(cuts), this->IsDense(), row_stride, ft, sparse_page_source_, ctx->gpu_id);
  } else {
    CHECK(sparse_page_source_);
    // Guard the pointer that is actually dereferenced below, so that a missing
    // source is reported as a check failure instead of a null dereference.
    CHECK(ellpack_page_source_);
    ellpack_page_source_->Reset();
  }

  auto begin_iter = BatchIterator<EllpackPage>(ellpack_page_source_);
  return BatchSet<EllpackPage>(BatchIterator<EllpackPage>(begin_iter));
}
|
|
} // namespace xgboost::data
|