Move ellpack page construction into DMatrix (#4833)
This commit is contained in:
@@ -99,15 +99,15 @@ struct SketchContainer {
|
||||
std::vector<std::mutex> col_locks_; // NOLINT
|
||||
static constexpr int kOmpNumColsParallelizeLimit = 1000;
|
||||
|
||||
SketchContainer(const tree::TrainParam ¶m, DMatrix *dmat) :
|
||||
SketchContainer(int max_bin, DMatrix *dmat) :
|
||||
col_locks_(dmat->Info().num_col_) {
|
||||
const MetaInfo &info = dmat->Info();
|
||||
// Initialize Sketches for this dmatrix
|
||||
sketches_.resize(info.num_col_);
|
||||
#pragma omp parallel for default(none) shared(info, param) schedule(static) \
|
||||
#pragma omp parallel for default(none) shared(info, max_bin) schedule(static) \
|
||||
if (info.num_col_ > kOmpNumColsParallelizeLimit) // NOLINT
|
||||
for (int icol = 0; icol < info.num_col_; ++icol) { // NOLINT
|
||||
sketches_[icol].Init(info.num_row_, 1.0 / (8 * param.max_bin));
|
||||
sketches_[icol].Init(info.num_row_, 1.0 / (8 * max_bin));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -130,7 +130,7 @@ struct GPUSketcher {
|
||||
bool has_weights_{false};
|
||||
size_t row_stride_{0};
|
||||
|
||||
tree::TrainParam param_;
|
||||
const int max_bin_;
|
||||
SketchContainer *sketch_container_;
|
||||
dh::device_vector<size_t> row_ptrs_{};
|
||||
dh::device_vector<Entry> entries_{};
|
||||
@@ -148,11 +148,11 @@ struct GPUSketcher {
|
||||
public:
|
||||
DeviceShard(int device,
|
||||
bst_uint n_rows,
|
||||
tree::TrainParam param,
|
||||
int max_bin,
|
||||
SketchContainer* sketch_container) :
|
||||
device_(device),
|
||||
n_rows_(n_rows),
|
||||
param_(std::move(param)),
|
||||
max_bin_(max_bin),
|
||||
sketch_container_(sketch_container) {
|
||||
}
|
||||
|
||||
@@ -183,7 +183,7 @@ struct GPUSketcher {
|
||||
}
|
||||
|
||||
constexpr int kFactor = 8;
|
||||
double eps = 1.0 / (kFactor * param_.max_bin);
|
||||
double eps = 1.0 / (kFactor * max_bin_);
|
||||
size_t dummy_nlevel;
|
||||
WXQSketch::LimitSizeLevel(gpu_batch_nrows_, eps, &dummy_nlevel, &n_cuts_);
|
||||
|
||||
@@ -362,7 +362,7 @@ struct GPUSketcher {
|
||||
// add cuts into sketches
|
||||
thrust::copy(cuts_d_.begin(), cuts_d_.end(), cuts_h_.begin());
|
||||
#pragma omp parallel for default(none) schedule(static) \
|
||||
if (num_cols_ > SketchContainer::kOmpNumColsParallelizeLimit) // NOLINT
|
||||
if (num_cols_ > SketchContainer::kOmpNumColsParallelizeLimit) // NOLINT
|
||||
for (int icol = 0; icol < num_cols_; ++icol) {
|
||||
WXQSketch::SummaryContainer summary;
|
||||
summary.Reserve(n_cuts_);
|
||||
@@ -403,10 +403,8 @@ struct GPUSketcher {
|
||||
};
|
||||
|
||||
void SketchBatch(const SparsePage &batch, const MetaInfo &info) {
|
||||
auto device = generic_param_.gpu_id;
|
||||
|
||||
// create device shard
|
||||
shard_.reset(new DeviceShard(device, batch.Size(), param_, sketch_container_.get()));
|
||||
shard_.reset(new DeviceShard(device_, batch.Size(), max_bin_, sketch_container_.get()));
|
||||
|
||||
// compute sketches for the shard
|
||||
shard_->Init(batch, info, gpu_batch_nrows_);
|
||||
@@ -417,9 +415,8 @@ struct GPUSketcher {
|
||||
row_stride_ = shard_->GetRowStride();
|
||||
}
|
||||
|
||||
GPUSketcher(const tree::TrainParam ¶m, const GenericParameter &generic_param, int gpu_nrows)
|
||||
: param_(param), generic_param_(generic_param), gpu_batch_nrows_(gpu_nrows), row_stride_(0) {
|
||||
}
|
||||
GPUSketcher(int device, int max_bin, int gpu_nrows)
|
||||
: device_(device), max_bin_(max_bin), gpu_batch_nrows_(gpu_nrows), row_stride_(0) {}
|
||||
|
||||
/* Builds the sketches on the GPU for the dmatrix and returns the row stride
|
||||
* for the entire dataset */
|
||||
@@ -427,29 +424,31 @@ struct GPUSketcher {
|
||||
const MetaInfo &info = dmat->Info();
|
||||
|
||||
row_stride_ = 0;
|
||||
sketch_container_.reset(new SketchContainer(param_, dmat));
|
||||
sketch_container_.reset(new SketchContainer(max_bin_, dmat));
|
||||
for (const auto &batch : dmat->GetBatches<SparsePage>()) {
|
||||
this->SketchBatch(batch, info);
|
||||
}
|
||||
|
||||
hmat->Init(&sketch_container_->sketches_, param_.max_bin);
|
||||
hmat->Init(&sketch_container_->sketches_, max_bin_);
|
||||
|
||||
return row_stride_;
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<DeviceShard> shard_;
|
||||
const tree::TrainParam ¶m_;
|
||||
const GenericParameter &generic_param_;
|
||||
const int device_;
|
||||
const int max_bin_;
|
||||
int gpu_batch_nrows_;
|
||||
size_t row_stride_;
|
||||
std::unique_ptr<SketchContainer> sketch_container_;
|
||||
};
|
||||
|
||||
size_t DeviceSketch
|
||||
(const tree::TrainParam ¶m, const GenericParameter &learner_param, int gpu_batch_nrows,
|
||||
DMatrix *dmat, HistogramCuts *hmat) {
|
||||
GPUSketcher sketcher(param, learner_param, gpu_batch_nrows);
|
||||
size_t DeviceSketch(int device,
|
||||
int max_bin,
|
||||
int gpu_batch_nrows,
|
||||
DMatrix* dmat,
|
||||
HistogramCuts* hmat) {
|
||||
GPUSketcher sketcher(device, max_bin, gpu_batch_nrows);
|
||||
// We only need to return the result in HistogramCuts container, so it is safe to
|
||||
// use a pointer of local HistogramCutsDense
|
||||
DenseCuts dense_cuts(hmat);
|
||||
|
||||
Reference in New Issue
Block a user