Training with external memory - part 2 of 2 (#4526)

- When external memory support is enabled, the histogram indices are built incrementally for every sparse page.
- The entire set of input data is divided across multiple GPUs, and the relative row positions within each device are tracked while building the compressed histogram buffer.
- This was tested with a mortgage dataset containing ~670M rows before 4 x T4 GPUs could be saturated.
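
The row-partitioning scheme in the second bullet can be sketched in isolation as follows. This is a simplified, standalone illustration of the RowStateOnDevice/DeviceHistogramBuilderState logic added in this commit; the names DeviceRows, PartitionBatch and FinishBatch are illustrative only and not part of the committed code. For each incoming sparse page, every device claims as many of the page's remaining rows as it is still owed, records the offset of that slice within the page, and folds the per-page counts into its running totals once the page has been binned.

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

// Per-device row state: total rows owned, rows already consumed from earlier
// sparse pages, and the slice of the current page this device should process.
struct DeviceRows {
  std::size_t assigned;      // total rows assigned to this device
  std::size_t processed;     // rows consumed from previous pages
  std::size_t batch_rows;    // rows to take from the current page
  std::size_t batch_offset;  // offset of that slice within the page
};

// Distribute the rows of one sparse page across the devices in order.
void PartitionBatch(std::vector<DeviceRows>* devices, std::size_t page_rows) {
  std::size_t remaining = page_rows;
  std::size_t offset = 0;
  for (auto& d : *devices) {
    d.batch_rows = std::min(d.assigned - d.processed, remaining);
    d.batch_offset = offset;
    offset += d.batch_rows;
    remaining -= d.batch_rows;
  }
}

// After the page has been binned, fold the per-page counts into the totals.
void FinishBatch(std::vector<DeviceRows>* devices) {
  for (auto& d : *devices) {
    d.processed += d.batch_rows;
    d.batch_rows = d.batch_offset = 0;
  }
}

int main() {
  // Two devices owning 3 rows each, fed by sparse pages of 4 and then 2 rows.
  std::vector<DeviceRows> devices{{3, 0, 0, 0}, {3, 0, 0, 0}};
  for (std::size_t page_rows : {4, 2}) {
    PartitionBatch(&devices, page_rows);
    for (const auto& d : devices) {
      std::cout << d.batch_rows << " rows at page offset " << d.batch_offset << "\n";
    }
    FinishBatch(&devices);
  }
  return 0;
}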
commit a2042b685a (parent 4591039eba)
@@ -630,6 +630,37 @@ __forceinline__ __device__ void CountLeft(int64_t* d_count, int val,
#endif
}

// Instances of this type are created while creating the histogram bins for the
// entire dataset across multiple sparse page batches. This keeps track of the number
// of rows to process from a batch and the position from which to process on each device.
struct RowStateOnDevice {
// Number of rows assigned to this device
const size_t total_rows_assigned_to_device;
// Number of rows processed thus far
size_t total_rows_processed;
// Number of rows to process from the current sparse page batch
size_t rows_to_process_from_batch;
// Offset from the current sparse page batch to begin processing
size_t row_offset_in_current_batch;

explicit RowStateOnDevice(size_t total_rows)
: total_rows_assigned_to_device(total_rows), total_rows_processed(0),
rows_to_process_from_batch(0), row_offset_in_current_batch(0) {
}

explicit RowStateOnDevice(size_t total_rows, size_t batch_rows)
: total_rows_assigned_to_device(total_rows), total_rows_processed(0),
rows_to_process_from_batch(batch_rows), row_offset_in_current_batch(0) {
}

// Advance the row state by the number of rows processed
void Advance() {
total_rows_processed += rows_to_process_from_batch;
CHECK_LE(total_rows_processed, total_rows_assigned_to_device);
rows_to_process_from_batch = row_offset_in_current_batch = 0;
}
};

// Manage memory for a single GPU
template <typename GradientSumT>
struct DeviceShard {
@@ -666,8 +697,6 @@ struct DeviceShard {
/*! \brief Sum gradient for each node. */
std::vector<GradientPair> node_sum_gradients;
common::Span<GradientPair> node_sum_gradients_d;
/*! \brief row offset in SparsePage (the input data). */
dh::device_vector<size_t> row_ptrs;
/*! \brief On-device feature set, only actually used on one of the devices */
dh::device_vector<int> feature_set_d;
dh::device_vector<int64_t>
@@ -695,7 +724,6 @@ struct DeviceShard {
std::function<bool(ExpandEntry, ExpandEntry)>>;
std::unique_ptr<ExpandQueue> qexpand;

// TODO(canonizer): do add support multi-batch DMatrix here
DeviceShard(int _device_id, int shard_idx, bst_uint row_begin,
bst_uint row_end, TrainParam _param, uint32_t column_sampler_seed)
: device_id(_device_id),
@@ -710,32 +738,12 @@ struct DeviceShard {
monitor.Init(std::string("DeviceShard") + std::to_string(device_id));
}

/* Init row_ptrs and row_stride */
size_t InitRowPtrs(const SparsePage& row_batch) {
const auto& offset_vec = row_batch.offset.HostVector();
row_ptrs.resize(n_rows + 1);
thrust::copy(offset_vec.data() + row_begin_idx,
offset_vec.data() + row_end_idx + 1,
row_ptrs.begin());
auto row_iter = row_ptrs.begin();
// find the maximum row size for converting to ELLPack
auto get_size = [=] __device__(size_t row) {
return row_iter[row + 1] - row_iter[row];
}; // NOLINT

auto counting = thrust::make_counting_iterator(size_t(0));
using TransformT = thrust::transform_iterator<decltype(get_size),
decltype(counting), size_t>;
TransformT row_size_iter = TransformT(counting, get_size);
size_t row_stride = thrust::reduce(row_size_iter, row_size_iter + n_rows, 0,
thrust::maximum<size_t>());
return row_stride;
}

void InitCompressedData(
const common::HistCutMatrix& hmat, const SparsePage& row_batch, bool is_dense);
const common::HistCutMatrix& hmat, size_t row_stride, bool is_dense);

void CreateHistIndices(const SparsePage& row_batch, size_t row_stride, int null_gidx_value);
void CreateHistIndices(
const SparsePage &row_batch, const common::HistCutMatrix &hmat,
const RowStateOnDevice &device_row_state, int rows_per_batch);

~DeviceShard() {
dh::safe_cuda(cudaSetDevice(device_id));
@@ -1229,11 +1237,14 @@ struct DeviceShard {

template <typename GradientSumT>
inline void DeviceShard<GradientSumT>::InitCompressedData(
const common::HistCutMatrix& hmat, const SparsePage& row_batch, bool is_dense) {
size_t row_stride = this->InitRowPtrs(row_batch);
const common::HistCutMatrix &hmat, size_t row_stride, bool is_dense) {
n_bins = hmat.row_ptr.back();
int null_gidx_value = hmat.row_ptr.back();

CHECK(!(param.max_leaves == 0 && param.max_depth == 0))
<< "Max leaves and max depth cannot both be unconstrained for "
"gpu_hist.";

int max_nodes =
param.max_leaves > 0 ? param.max_leaves * 2 : MaxNodesDepth(param.max_depth);

@@ -1256,7 +1267,6 @@ inline void DeviceShard<GradientSumT>::InitCompressedData(
node_sum_gradients.resize(max_nodes);
ridx_segments.resize(max_nodes);

// allocate compressed bin data
int num_symbols = n_bins + 1;
// Required buffer size for storing data matrix in ELLPack format.
@@ -1264,16 +1274,11 @@ inline void DeviceShard<GradientSumT>::InitCompressedData(
common::CompressedBufferWriter::CalculateBufferSize(row_stride * n_rows,
num_symbols);

CHECK(!(param.max_leaves == 0 && param.max_depth == 0))
<< "Max leaves and max depth cannot both be unconstrained for "
"gpu_hist.";
ba.Allocate(device_id, &gidx_buffer, compressed_size_bytes);
thrust::fill(
thrust::device_pointer_cast(gidx_buffer.data()),
thrust::device_pointer_cast(gidx_buffer.data() + gidx_buffer.size()), 0);

this->CreateHistIndices(row_batch, row_stride, null_gidx_value);

ellpack_matrix.Init(
feature_segments, min_fvalue,
gidx_fvalue_map, row_stride,
@@ -1295,25 +1300,45 @@ inline void DeviceShard<GradientSumT>::InitCompressedData(

template <typename GradientSumT>
inline void DeviceShard<GradientSumT>::CreateHistIndices(
const SparsePage& row_batch, size_t row_stride, int null_gidx_value) {
const SparsePage &row_batch,
const common::HistCutMatrix &hmat,
const RowStateOnDevice &device_row_state,
int rows_per_batch) {
// Has any been allocated for me in this batch?
if (!device_row_state.rows_to_process_from_batch) return;

unsigned int null_gidx_value = hmat.row_ptr.back();
size_t row_stride = this->ellpack_matrix.row_stride;

const auto &offset_vec = row_batch.offset.ConstHostVector();
/*! \brief row offset in SparsePage (the input data). */
CHECK_LE(device_row_state.rows_to_process_from_batch, offset_vec.size());
dh::device_vector<size_t> row_ptrs(device_row_state.rows_to_process_from_batch+1);
thrust::copy(
offset_vec.data() + device_row_state.row_offset_in_current_batch,
offset_vec.data() + device_row_state.row_offset_in_current_batch +
device_row_state.rows_to_process_from_batch + 1,
row_ptrs.begin());

int num_symbols = n_bins + 1;
// bin and compress entries in batches of rows
size_t gpu_batch_nrows =
std::min
(dh::TotalMemory(device_id) / (16 * row_stride * sizeof(Entry)),
static_cast<size_t>(n_rows));
const std::vector<Entry>& data_vec = row_batch.data.HostVector();
size_t gpu_batch_nrows = std::min(
dh::TotalMemory(device_id) / (16 * row_stride * sizeof(Entry)),
static_cast<size_t>(device_row_state.rows_to_process_from_batch));
const std::vector<Entry>& data_vec = row_batch.data.ConstHostVector();

dh::device_vector<Entry> entries_d(gpu_batch_nrows * row_stride);
size_t gpu_nbatches = dh::DivRoundUp(n_rows, gpu_batch_nrows);
size_t gpu_nbatches = dh::DivRoundUp(device_row_state.rows_to_process_from_batch,
gpu_batch_nrows);

for (size_t gpu_batch = 0; gpu_batch < gpu_nbatches; ++gpu_batch) {
size_t batch_row_begin = gpu_batch * gpu_batch_nrows;
size_t batch_row_end = (gpu_batch + 1) * gpu_batch_nrows;
if (batch_row_end > n_rows) {
batch_row_end = n_rows;
if (batch_row_end > device_row_state.rows_to_process_from_batch) {
batch_row_end = device_row_state.rows_to_process_from_batch;
}
size_t batch_nrows = batch_row_end - batch_row_begin;

// number of entries in this batch.
size_t n_entries = row_ptrs[batch_row_end] - row_ptrs[batch_row_begin];
// copy data entries to device.
@@ -1322,17 +1347,20 @@ inline void DeviceShard<GradientSumT>::CreateHistIndices(
(entries_d.data().get(), data_vec.data() + row_ptrs[batch_row_begin],
n_entries * sizeof(Entry), cudaMemcpyDefault));
const dim3 block3(32, 8, 1); // 256 threads
const dim3 grid3(dh::DivRoundUp(n_rows, block3.x),
const dim3 grid3(dh::DivRoundUp(device_row_state.rows_to_process_from_batch, block3.x),
dh::DivRoundUp(row_stride, block3.y), 1);
CompressBinEllpackKernel<<<grid3, block3>>>
(common::CompressedBufferWriter(num_symbols),
gidx_buffer.data(),
row_ptrs.data().get() + batch_row_begin,
entries_d.data().get(),
gidx_fvalue_map.data(), feature_segments.data(),
batch_row_begin, batch_nrows,
gidx_fvalue_map.data(),
feature_segments.data(),
device_row_state.total_rows_processed + batch_row_begin,
batch_nrows,
row_ptrs[batch_row_begin],
row_stride, null_gidx_value);
row_stride,
null_gidx_value);
}

// free the memory that is no longer needed
@@ -1342,6 +1370,60 @@ inline void DeviceShard<GradientSumT>::CreateHistIndices(
entries_d.shrink_to_fit();
}

// An instance of this type keeps track, for each device, of the total number of rows to process,
// the rows processed thus far, the rows to process from the current sparse page batch and the
// offset within that batch at which to begin processing.
class DeviceHistogramBuilderState {
public:
template <typename GradientSumT>
explicit DeviceHistogramBuilderState(
const std::vector<std::unique_ptr<DeviceShard<GradientSumT>>> &shards) {
device_row_states_.reserve(shards.size());
for (const auto &shard : shards) {
device_row_states_.push_back(RowStateOnDevice(shard->n_rows));
}
}

const RowStateOnDevice &GetRowStateOnDevice(int idx) const {
return device_row_states_[idx];
}

// This method is invoked at the beginning of each sparse page batch. This distributes
// the rows in the sparse page to the different devices.
// TODO(sriramch): Think of a way to utilize *all* the GPUs to build the compressed bins.
void BeginBatch(const SparsePage &batch) {
size_t rem_rows = batch.Size();
size_t row_offset_in_current_batch = 0;
for (auto &device_row_state : device_row_states_) {
// Do we have any more rows left to process from this batch on this device?
if (device_row_state.total_rows_assigned_to_device > device_row_state.total_rows_processed) {
// There are still some rows that need to be assigned to this device
device_row_state.rows_to_process_from_batch =
std::min(
device_row_state.total_rows_assigned_to_device - device_row_state.total_rows_processed,
rem_rows);
} else {
// All rows have been assigned to this device
device_row_state.rows_to_process_from_batch = 0;
}

device_row_state.row_offset_in_current_batch = row_offset_in_current_batch;
row_offset_in_current_batch += device_row_state.rows_to_process_from_batch;
rem_rows -= device_row_state.rows_to_process_from_batch;
}
}

// This method is invoked after completion of each sparse page batch
void EndBatch() {
for (auto &rs : device_row_states_) {
rs.Advance();
}
}

private:
std::vector<RowStateOnDevice> device_row_states_;
};

template <typename GradientSumT>
class GPUHistMakerSpecialised {
public:
@@ -1397,9 +1479,6 @@ class GPUHistMakerSpecialised {

reducer_.Init(device_list_);

auto batch_iter = dmat->GetRowBatches().begin();
const SparsePage& batch = *batch_iter;

// Synchronise the column sampling seed
uint32_t column_sampling_seed = common::GlobalRandom()();
rabit::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0);
@@ -1418,26 +1497,43 @@ class GPUHistMakerSpecialised {
column_sampling_seed));
});

// Find the cuts.
monitor_.StartCuda("Quantiles");
// TODO(sriramch): The return value will be used when we add support for histogram
// index creation for multiple batches
common::DeviceSketch(param_, *learner_param_, hist_maker_param_.gpu_batch_nrows, dmat, &hmat_);
n_bins_ = hmat_.row_ptr.back();
// Create the quantile sketches for the dmatrix and initialize HistCutMatrix
size_t row_stride = common::DeviceSketch(param_, *learner_param_,
hist_maker_param_.gpu_batch_nrows,
dmat, &hmat_);
monitor_.StopCuda("Quantiles");

n_bins_ = hmat_.row_ptr.back();

auto is_dense = info_->num_nonzero_ == info_->num_row_ * info_->num_col_;

monitor_.StartCuda("BinningCompression");
// Init global data for each shard
monitor_.StartCuda("InitCompressedData");
dh::ExecuteIndexShards(
&shards_,
[&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
dh::safe_cuda(cudaSetDevice(shard->device_id));
shard->InitCompressedData(hmat_, batch, is_dense);
shard->InitCompressedData(hmat_, row_stride, is_dense);
});
monitor_.StopCuda("InitCompressedData");

monitor_.StartCuda("BinningCompression");
DeviceHistogramBuilderState hist_builder_row_state(shards_);
for (const auto &batch : dmat->GetRowBatches()) {
hist_builder_row_state.BeginBatch(batch);

dh::ExecuteIndexShards(
&shards_,
[&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
dh::safe_cuda(cudaSetDevice(shard->device_id));
shard->CreateHistIndices(batch, hmat_, hist_builder_row_state.GetRowStateOnDevice(idx),
hist_maker_param_.gpu_batch_nrows);
});

hist_builder_row_state.EndBatch();
}
monitor_.StopCuda("BinningCompression");
++batch_iter;
CHECK(batch_iter.AtEnd()) << "External memory not supported";

p_last_fmat_ = dmat;
initialised_ = true;
@@ -6,6 +6,7 @@
#include <random>
#include <cinttypes>
#include <dmlc/filesystem.h>
#include "../../src/data/simple_csr_source.h"

bool FileExists(const std::string& filename) {
struct stat st;
@@ -165,6 +166,71 @@ std::unique_ptr<DMatrix> CreateSparsePageDMatrix(size_t n_entries, size_t page_s
return dmat;
}

std::unique_ptr<DMatrix> CreateSparsePageDMatrixWithRC(size_t n_rows, size_t n_cols,
size_t page_size, bool deterministic) {
if (!n_rows || !n_cols) {
return nullptr;
}

// Create the svm file in a temp dir
dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/big.libsvm";

std::ofstream fo(tmp_file.c_str());
size_t cols_per_row = ((std::max(n_rows, n_cols) - 1) / std::min(n_rows, n_cols)) + 1;
int64_t rem_cols = n_cols;
size_t col_idx = 0;

// Random feature id generator
std::random_device rdev;
std::unique_ptr<std::mt19937> gen;
if (deterministic) {
// Seed it with a constant value for this configuration - without getting too fancy
// like ordered pairing functions and the like to make it truly unique
gen.reset(new std::mt19937(n_rows * n_cols));
} else {
gen.reset(new std::mt19937(rdev()));
}
std::uniform_int_distribution<size_t> dis(1, n_cols);

for (size_t i = 0; i < n_rows; ++i) {
// Make sure that all cols are slotted in the first few rows; randomly distribute the
// rest
std::stringstream row_data;
fo << i;
size_t j = 0;
if (rem_cols > 0) {
for (; j < std::min(static_cast<size_t>(rem_cols), cols_per_row); ++j) {
row_data << " " << (col_idx+j) << ":" << (col_idx+j+1)*10;
}
rem_cols -= cols_per_row;
} else {
// Take a random number of columns in [1, n_cols] and slot them here
size_t ncols = dis(*gen);
for (; j < ncols; ++j) {
size_t fid = (col_idx+j) % n_cols;
row_data << " " << fid << ":" << (fid+1)*10;
}
}
col_idx += j;

fo << row_data.str() << "\n";
}
fo.close();

std::unique_ptr<DMatrix> dmat(DMatrix::Load(
tmp_file + "#" + tmp_file + ".cache", true, false, "auto", page_size));
EXPECT_TRUE(FileExists(tmp_file + ".cache.row.page"));

if (!page_size) {
std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource);
source->CopyFrom(dmat.get());
return std::unique_ptr<DMatrix>(DMatrix::Create(std::move(source)));
} else {
return dmat;
}
}

gbm::GBTreeModel CreateTestModel() {
std::vector<std::unique_ptr<RegTree>> trees;
trees.push_back(std::unique_ptr<RegTree>(new RegTree));
@@ -165,6 +165,27 @@ std::shared_ptr<xgboost::DMatrix> *CreateDMatrix(int rows, int columns,

std::unique_ptr<DMatrix> CreateSparsePageDMatrix(size_t n_entries, size_t page_size);

/**
* \fn std::unique_ptr<DMatrix> CreateSparsePageDMatrixWithRC(size_t n_rows, size_t n_cols,
*                                                            size_t page_size, bool deterministic);
*
* \brief Creates a dmatrix with some records, each record containing a random number of
*        features in [1, n_cols].
*
* \param n_rows Number of records to create.
* \param n_cols Max number of features within each record.
* \param page_size Sparse page size for the pages within the dmatrix. If the page size is 0,
*        the entire dmatrix is resident in memory; else, multiple sparse pages of that size
*        are created and backed to disk, which have to be streamed in at the point of use.
* \param deterministic If true, the content of the dmatrix is constant for this configuration;
*        else, the content changes every time this method is invoked.
*
* \return The new dmatrix.
*/
std::unique_ptr<DMatrix> CreateSparsePageDMatrixWithRC(size_t n_rows, size_t n_cols,
size_t page_size, bool deterministic);

gbm::GBTreeModel CreateTestModel();

inline LearnerTrainParam CreateEmptyGenericParam(int gpu_id, int n_gpus) {
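
As a usage sketch for the helper declared above (this fragment mirrors the TestHistogramIndex test further below, which is where the values 1000, 10 and 128UL come from; it is illustrative only and not part of the commit):

// page_size == 0: the whole dmatrix stays resident in memory.
std::unique_ptr<DMatrix> in_memory_dmat(
    CreateSparsePageDMatrixWithRC(1000, 10, 0, true));

// Non-zero page_size: sparse pages are written to a disk cache and streamed
// back in on demand, i.e. the external memory path exercised by this commit.
std::unique_ptr<DMatrix> external_mem_dmat(
    CreateSparsePageDMatrixWithRC(1000, 10, 128UL, true));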
@@ -77,7 +77,14 @@ void BuildGidx(DeviceShard<GradientSumT>* shard, int n_rows, int n_cols,

auto is_dense = (*dmat)->Info().num_nonzero_ ==
(*dmat)->Info().num_row_ * (*dmat)->Info().num_col_;
shard->InitCompressedData(cmat, batch, is_dense);
size_t row_stride = 0;
const auto &offset_vec = batch.offset.ConstHostVector();
for (size_t i = 1; i < offset_vec.size(); ++i) {
row_stride = std::max(row_stride, offset_vec[i] - offset_vec[i-1]);
}
shard->InitCompressedData(cmat, row_stride, is_dense);
shard->CreateHistIndices(
batch, cmat, RowStateOnDevice(batch.Size(), batch.Size()), -1);

delete dmat;
}
@@ -469,5 +476,46 @@ TEST(GpuHist, SortPosition) {
TestSortPosition({2, 2, 2, 2}, 1, 2);
TestSortPosition({1, 2, 1, 2, 3}, 1, 2);
}

TEST(GpuHist, TestHistogramIndex) {
// Test if the compressed histogram index matches when using a sparse
// dmatrix with and without using external memory

int constexpr kNRows = 1000, kNCols = 10;

// Build 2 matrices and build a histogram maker with each
tree::GPUHistMakerSpecialised<GradientPairPrecise> hist_maker, hist_maker_ext;
std::unique_ptr<DMatrix> hist_maker_dmat(
CreateSparsePageDMatrixWithRC(kNRows, kNCols, 0, true));
std::unique_ptr<DMatrix> hist_maker_ext_dmat(
CreateSparsePageDMatrixWithRC(kNRows, kNCols, 128UL, true));

std::vector<std::pair<std::string, std::string>> training_params = {
{"max_depth", "1"},
{"max_leaves", "0"}
};

LearnerTrainParam learner_param(CreateEmptyGenericParam(0, 1));
hist_maker.Init(training_params, &learner_param);
hist_maker.InitDataOnce(hist_maker_dmat.get());
hist_maker_ext.Init(training_params, &learner_param);
hist_maker_ext.InitDataOnce(hist_maker_ext_dmat.get());

// Extract the device shards from the histogram makers and, from those, their
// compressed histogram indices
const auto &dev_shard = hist_maker.shards_[0];
std::vector<common::CompressedByteT> h_gidx_buffer(dev_shard->gidx_buffer.size());
dh::CopyDeviceSpanToVector(&h_gidx_buffer, dev_shard->gidx_buffer);

const auto &dev_shard_ext = hist_maker_ext.shards_[0];
std::vector<common::CompressedByteT> h_gidx_buffer_ext(dev_shard_ext->gidx_buffer.size());
dh::CopyDeviceSpanToVector(&h_gidx_buffer_ext, dev_shard_ext->gidx_buffer);

ASSERT_EQ(dev_shard->n_bins, dev_shard_ext->n_bins);
ASSERT_EQ(dev_shard->gidx_buffer.size(), dev_shard_ext->gidx_buffer.size());

ASSERT_EQ(h_gidx_buffer, h_gidx_buffer_ext);
}

} // namespace tree
} // namespace xgboost