- training with external memory - part 2 of 2 (#4526)

* - training with external memory - part 2 of 2
   - when external memory support is enabled, histogram indices are built
     incrementally for every sparse page
   - the entire set of input data is divided across multiple GPUs, and the relative
     row positions within each device are tracked when building the compressed
     histogram buffer (see the sketch below)
   - this was tested using a mortgage dataset containing ~670M rows before 4 T4 GPUs
     could be saturated
sriramch 2019-06-11 14:52:56 -07:00 committed by Rory Mitchell
parent 4591039eba
commit a2042b685a
4 changed files with 292 additions and 61 deletions
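Below is a minimal sketch of the row-accounting idea behind the new RowStateOnDevice / DeviceHistogramBuilderState types in this commit: each device owns a fixed, contiguous range of rows, and every incoming sparse page batch is carved up in order across the devices. This is an illustration only, not the committed code; the names DeviceRows and SplitBatch are made up here.

#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

// Per-device row bookkeeping (simplified analogue of RowStateOnDevice).
struct DeviceRows {
  std::size_t assigned;   // total rows this device is responsible for
  std::size_t processed;  // rows already consumed from earlier batches
};

// For the current sparse page batch, compute per device the (offset, count)
// slice it should compress into its histogram index buffer, then advance the
// processed counters (simplified analogue of BeginBatch/EndBatch).
std::vector<std::pair<std::size_t, std::size_t>> SplitBatch(
    std::vector<DeviceRows>* devices, std::size_t batch_rows) {
  std::vector<std::pair<std::size_t, std::size_t>> slices;
  std::size_t offset = 0;
  std::size_t remaining = batch_rows;
  for (auto& d : *devices) {
    std::size_t take = std::min(d.assigned - d.processed, remaining);
    slices.emplace_back(offset, take);
    d.processed += take;
    offset += take;
    remaining -= take;
  }
  return slices;
}

For example, with two devices owning 600 and 400 rows, a first 500-row page yields slices (0, 500) and (500, 0), and the next 500-row page yields (0, 100) and (100, 400).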

View File

@@ -630,6 +630,37 @@ __forceinline__ __device__ void CountLeft(int64_t* d_count, int val,
#endif
}
// Instances of this type are created while creating the histogram bins for the
// entire dataset across multiple sparse page batches. This keeps track of the number
// of rows to process from a batch and the position from which to process on each device.
struct RowStateOnDevice {
// Number of rows assigned to this device
const size_t total_rows_assigned_to_device;
// Number of rows processed thus far
size_t total_rows_processed;
// Number of rows to process from the current sparse page batch
size_t rows_to_process_from_batch;
// Offset from the current sparse page batch to begin processing
size_t row_offset_in_current_batch;
explicit RowStateOnDevice(size_t total_rows)
: total_rows_assigned_to_device(total_rows), total_rows_processed(0),
rows_to_process_from_batch(0), row_offset_in_current_batch(0) {
}
explicit RowStateOnDevice(size_t total_rows, size_t batch_rows)
: total_rows_assigned_to_device(total_rows), total_rows_processed(0),
rows_to_process_from_batch(batch_rows), row_offset_in_current_batch(0) {
}
// Advance the row state by the number of rows processed
void Advance() {
total_rows_processed += rows_to_process_from_batch;
CHECK_LE(total_rows_processed, total_rows_assigned_to_device);
rows_to_process_from_batch = row_offset_in_current_batch = 0;
}
};
// Manage memory for a single GPU
template <typename GradientSumT>
struct DeviceShard {
@@ -666,8 +697,6 @@ struct DeviceShard {
/*! \brief Sum gradient for each node. */
std::vector<GradientPair> node_sum_gradients;
common::Span<GradientPair> node_sum_gradients_d;
/*! \brief row offset in SparsePage (the input data). */
dh::device_vector<size_t> row_ptrs;
/*! \brief On-device feature set, only actually used on one of the devices */
dh::device_vector<int> feature_set_d;
dh::device_vector<int64_t>
@@ -695,7 +724,6 @@ struct DeviceShard {
std::function<bool(ExpandEntry, ExpandEntry)>>;
std::unique_ptr<ExpandQueue> qexpand;
// TODO(canonizer): do add support multi-batch DMatrix here
DeviceShard(int _device_id, int shard_idx, bst_uint row_begin,
bst_uint row_end, TrainParam _param, uint32_t column_sampler_seed)
: device_id(_device_id),
@@ -710,32 +738,12 @@ struct DeviceShard {
monitor.Init(std::string("DeviceShard") + std::to_string(device_id));
}
/* Init row_ptrs and row_stride */
size_t InitRowPtrs(const SparsePage& row_batch) {
const auto& offset_vec = row_batch.offset.HostVector();
row_ptrs.resize(n_rows + 1);
thrust::copy(offset_vec.data() + row_begin_idx,
offset_vec.data() + row_end_idx + 1,
row_ptrs.begin());
auto row_iter = row_ptrs.begin();
// find the maximum row size for converting to ELLPack
auto get_size = [=] __device__(size_t row) {
return row_iter[row + 1] - row_iter[row];
}; // NOLINT
auto counting = thrust::make_counting_iterator(size_t(0));
using TransformT = thrust::transform_iterator<decltype(get_size),
decltype(counting), size_t>;
TransformT row_size_iter = TransformT(counting, get_size);
size_t row_stride = thrust::reduce(row_size_iter, row_size_iter + n_rows, 0,
thrust::maximum<size_t>());
return row_stride;
}
void InitCompressedData(
const common::HistCutMatrix& hmat, const SparsePage& row_batch, bool is_dense);
const common::HistCutMatrix& hmat, size_t row_stride, bool is_dense);
void CreateHistIndices(const SparsePage& row_batch, size_t row_stride, int null_gidx_value);
void CreateHistIndices(
const SparsePage &row_batch, const common::HistCutMatrix &hmat,
const RowStateOnDevice &device_row_state, int rows_per_batch);
~DeviceShard() {
dh::safe_cuda(cudaSetDevice(device_id));
@@ -1229,11 +1237,14 @@ struct DeviceShard {
template <typename GradientSumT>
inline void DeviceShard<GradientSumT>::InitCompressedData(
const common::HistCutMatrix& hmat, const SparsePage& row_batch, bool is_dense) {
size_t row_stride = this->InitRowPtrs(row_batch);
const common::HistCutMatrix &hmat, size_t row_stride, bool is_dense) {
n_bins = hmat.row_ptr.back();
int null_gidx_value = hmat.row_ptr.back();
CHECK(!(param.max_leaves == 0 && param.max_depth == 0))
<< "Max leaves and max depth cannot both be unconstrained for "
"gpu_hist.";
int max_nodes =
param.max_leaves > 0 ? param.max_leaves * 2 : MaxNodesDepth(param.max_depth);
@@ -1256,7 +1267,6 @@ inline void DeviceShard<GradientSumT>::InitCompressedData(
node_sum_gradients.resize(max_nodes);
ridx_segments.resize(max_nodes);
// allocate compressed bin data
int num_symbols = n_bins + 1;
// Required buffer size for storing data matrix in ELLPack format.
@@ -1264,16 +1274,11 @@ inline void DeviceShard<GradientSumT>::InitCompressedData(
common::CompressedBufferWriter::CalculateBufferSize(row_stride * n_rows,
num_symbols);
CHECK(!(param.max_leaves == 0 && param.max_depth == 0))
<< "Max leaves and max depth cannot both be unconstrained for "
"gpu_hist.";
ba.Allocate(device_id, &gidx_buffer, compressed_size_bytes);
thrust::fill(
thrust::device_pointer_cast(gidx_buffer.data()),
thrust::device_pointer_cast(gidx_buffer.data() + gidx_buffer.size()), 0);
this->CreateHistIndices(row_batch, row_stride, null_gidx_value);
ellpack_matrix.Init(
feature_segments, min_fvalue,
gidx_fvalue_map, row_stride,
@@ -1295,25 +1300,45 @@ inline void DeviceShard<GradientSumT>::InitCompressedData(
template <typename GradientSumT>
inline void DeviceShard<GradientSumT>::CreateHistIndices(
const SparsePage& row_batch, size_t row_stride, int null_gidx_value) {
const SparsePage &row_batch,
const common::HistCutMatrix &hmat,
const RowStateOnDevice &device_row_state,
int rows_per_batch) {
// Have any rows of this batch been assigned to this device?
if (!device_row_state.rows_to_process_from_batch) return;
unsigned int null_gidx_value = hmat.row_ptr.back();
size_t row_stride = this->ellpack_matrix.row_stride;
const auto &offset_vec = row_batch.offset.ConstHostVector();
/*! \brief row offset in SparsePage (the input data). */
CHECK_LE(device_row_state.rows_to_process_from_batch, offset_vec.size());
dh::device_vector<size_t> row_ptrs(device_row_state.rows_to_process_from_batch+1);
thrust::copy(
offset_vec.data() + device_row_state.row_offset_in_current_batch,
offset_vec.data() + device_row_state.row_offset_in_current_batch +
device_row_state.rows_to_process_from_batch + 1,
row_ptrs.begin());
int num_symbols = n_bins + 1;
// bin and compress entries in batches of rows
size_t gpu_batch_nrows =
std::min
(dh::TotalMemory(device_id) / (16 * row_stride * sizeof(Entry)),
static_cast<size_t>(n_rows));
const std::vector<Entry>& data_vec = row_batch.data.HostVector();
size_t gpu_batch_nrows = std::min(
dh::TotalMemory(device_id) / (16 * row_stride * sizeof(Entry)),
static_cast<size_t>(device_row_state.rows_to_process_from_batch));
const std::vector<Entry>& data_vec = row_batch.data.ConstHostVector();
dh::device_vector<Entry> entries_d(gpu_batch_nrows * row_stride);
size_t gpu_nbatches = dh::DivRoundUp(n_rows, gpu_batch_nrows);
size_t gpu_nbatches = dh::DivRoundUp(device_row_state.rows_to_process_from_batch,
gpu_batch_nrows);
for (size_t gpu_batch = 0; gpu_batch < gpu_nbatches; ++gpu_batch) {
size_t batch_row_begin = gpu_batch * gpu_batch_nrows;
size_t batch_row_end = (gpu_batch + 1) * gpu_batch_nrows;
if (batch_row_end > n_rows) {
batch_row_end = n_rows;
if (batch_row_end > device_row_state.rows_to_process_from_batch) {
batch_row_end = device_row_state.rows_to_process_from_batch;
}
size_t batch_nrows = batch_row_end - batch_row_begin;
// number of entries in this batch.
size_t n_entries = row_ptrs[batch_row_end] - row_ptrs[batch_row_begin];
// copy data entries to device.
@@ -1322,17 +1347,20 @@ inline void DeviceShard<GradientSumT>::CreateHistIndices(
(entries_d.data().get(), data_vec.data() + row_ptrs[batch_row_begin],
n_entries * sizeof(Entry), cudaMemcpyDefault));
const dim3 block3(32, 8, 1); // 256 threads
const dim3 grid3(dh::DivRoundUp(n_rows, block3.x),
const dim3 grid3(dh::DivRoundUp(device_row_state.rows_to_process_from_batch, block3.x),
dh::DivRoundUp(row_stride, block3.y), 1);
CompressBinEllpackKernel<<<grid3, block3>>>
(common::CompressedBufferWriter(num_symbols),
gidx_buffer.data(),
row_ptrs.data().get() + batch_row_begin,
entries_d.data().get(),
gidx_fvalue_map.data(), feature_segments.data(),
batch_row_begin, batch_nrows,
gidx_fvalue_map.data(),
feature_segments.data(),
device_row_state.total_rows_processed + batch_row_begin,
batch_nrows,
row_ptrs[batch_row_begin],
row_stride, null_gidx_value);
row_stride,
null_gidx_value);
}
// free the memory that is no longer needed
@@ -1342,6 +1370,60 @@ inline void DeviceShard<GradientSumT>::CreateHistIndices(
entries_d.shrink_to_fit();
}
// An instance of this type tracks, for each device, the total number of rows to process,
// the rows processed thus far, the number of rows to process from the current sparse page
// batch, and the offset within that batch at which to begin processing
class DeviceHistogramBuilderState {
public:
template <typename GradientSumT>
explicit DeviceHistogramBuilderState(
const std::vector<std::unique_ptr<DeviceShard<GradientSumT>>> &shards) {
device_row_states_.reserve(shards.size());
for (const auto &shard : shards) {
device_row_states_.push_back(RowStateOnDevice(shard->n_rows));
}
}
const RowStateOnDevice &GetRowStateOnDevice(int idx) const {
return device_row_states_[idx];
}
// This method is invoked at the beginning of each sparse page batch. This distributes
// the rows in the sparse page to the different devices.
// TODO(sriramch): Think of a way to utilize *all* the GPUs to build the compressed bins.
void BeginBatch(const SparsePage &batch) {
size_t rem_rows = batch.Size();
size_t row_offset_in_current_batch = 0;
for (auto &device_row_state : device_row_states_) {
// Do we have any more rows to process from this batch on this device?
if (device_row_state.total_rows_assigned_to_device > device_row_state.total_rows_processed) {
// There are still some rows that need to be assigned to this device
device_row_state.rows_to_process_from_batch =
std::min(
device_row_state.total_rows_assigned_to_device - device_row_state.total_rows_processed,
rem_rows);
} else {
// All rows have been assigned to this device
device_row_state.rows_to_process_from_batch = 0;
}
device_row_state.row_offset_in_current_batch = row_offset_in_current_batch;
row_offset_in_current_batch += device_row_state.rows_to_process_from_batch;
rem_rows -= device_row_state.rows_to_process_from_batch;
}
}
// This method is invoked after completion of each sparse page batch
void EndBatch() {
for (auto &rs : device_row_states_) {
rs.Advance();
}
}
private:
std::vector<RowStateOnDevice> device_row_states_;
};
template <typename GradientSumT>
class GPUHistMakerSpecialised {
public:
@@ -1397,9 +1479,6 @@ class GPUHistMakerSpecialised {
reducer_.Init(device_list_);
auto batch_iter = dmat->GetRowBatches().begin();
const SparsePage& batch = *batch_iter;
// Synchronise the column sampling seed
uint32_t column_sampling_seed = common::GlobalRandom()();
rabit::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0);
@@ -1418,26 +1497,43 @@ class GPUHistMakerSpecialised {
column_sampling_seed));
});
// Find the cuts.
monitor_.StartCuda("Quantiles");
// TODO(sriramch): The return value will be used when we add support for histogram
// index creation for multiple batches
common::DeviceSketch(param_, *learner_param_, hist_maker_param_.gpu_batch_nrows, dmat, &hmat_);
n_bins_ = hmat_.row_ptr.back();
// Create the quantile sketches for the dmatrix and initialize HistCutMatrix
size_t row_stride = common::DeviceSketch(param_, *learner_param_,
hist_maker_param_.gpu_batch_nrows,
dmat, &hmat_);
monitor_.StopCuda("Quantiles");
n_bins_ = hmat_.row_ptr.back();
auto is_dense = info_->num_nonzero_ == info_->num_row_ * info_->num_col_;
monitor_.StartCuda("BinningCompression");
// Init global data for each shard
monitor_.StartCuda("InitCompressedData");
dh::ExecuteIndexShards(
&shards_,
[&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
dh::safe_cuda(cudaSetDevice(shard->device_id));
shard->InitCompressedData(hmat_, batch, is_dense);
shard->InitCompressedData(hmat_, row_stride, is_dense);
});
monitor_.StopCuda("InitCompressedData");
monitor_.StartCuda("BinningCompression");
DeviceHistogramBuilderState hist_builder_row_state(shards_);
for (const auto &batch : dmat->GetRowBatches()) {
hist_builder_row_state.BeginBatch(batch);
dh::ExecuteIndexShards(
&shards_,
[&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
dh::safe_cuda(cudaSetDevice(shard->device_id));
shard->CreateHistIndices(batch, hmat_, hist_builder_row_state.GetRowStateOnDevice(idx),
hist_maker_param_.gpu_batch_nrows);
});
hist_builder_row_state.EndBatch();
}
monitor_.StopCuda("BinningCompression");
++batch_iter;
CHECK(batch_iter.AtEnd()) << "External memory not supported";
p_last_fmat_ = dmat;
initialised_ = true;

View File

@@ -6,6 +6,7 @@
#include <random>
#include <cinttypes>
#include <dmlc/filesystem.h>
#include "../../src/data/simple_csr_source.h"
bool FileExists(const std::string& filename) {
struct stat st;
@@ -165,6 +166,71 @@ std::unique_ptr<DMatrix> CreateSparsePageDMatrix(size_t n_entries, size_t page_size)
return dmat;
}
std::unique_ptr<DMatrix> CreateSparsePageDMatrixWithRC(size_t n_rows, size_t n_cols,
size_t page_size, bool deterministic) {
if (!n_rows || !n_cols) {
return nullptr;
}
// Create the svm file in a temp dir
dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/big.libsvm";
std::ofstream fo(tmp_file.c_str());
size_t cols_per_row = ((std::max(n_rows, n_cols) - 1) / std::min(n_rows, n_cols)) + 1;
int64_t rem_cols = n_cols;
size_t col_idx = 0;
// Random feature id generator
std::random_device rdev;
std::unique_ptr<std::mt19937> gen;
if (deterministic) {
// Seed it with a constant value for this configuration - without getting too fancy
// with ordered pairing functions and the like to make it truly unique
gen.reset(new std::mt19937(n_rows * n_cols));
} else {
gen.reset(new std::mt19937(rdev()));
}
std::uniform_int_distribution<size_t> dis(1, n_cols);
for (size_t i = 0; i < n_rows; ++i) {
// Make sure that all cols are slotted in the first few rows; randomly distribute the
// rest
std::stringstream row_data;
fo << i;
size_t j = 0;
if (rem_cols > 0) {
for (; j < std::min(static_cast<size_t>(rem_cols), cols_per_row); ++j) {
row_data << " " << (col_idx+j) << ":" << (col_idx+j+1)*10;
}
rem_cols -= cols_per_row;
} else {
// Take some random number of columns in [1, n_cols] and slot them here
size_t ncols = dis(*gen);
for (; j < ncols; ++j) {
size_t fid = (col_idx+j) % n_cols;
row_data << " " << fid << ":" << (fid+1)*10;
}
}
col_idx += j;
fo << row_data.str() << "\n";
}
fo.close();
std::unique_ptr<DMatrix> dmat(DMatrix::Load(
tmp_file + "#" + tmp_file + ".cache", true, false, "auto", page_size));
EXPECT_TRUE(FileExists(tmp_file + ".cache.row.page"));
if (!page_size) {
std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource);
source->CopyFrom(dmat.get());
return std::unique_ptr<DMatrix>(DMatrix::Create(std::move(source)));
} else {
return dmat;
}
}
gbm::GBTreeModel CreateTestModel() {
std::vector<std::unique_ptr<RegTree>> trees;
trees.push_back(std::unique_ptr<RegTree>(new RegTree));

View File

@@ -165,6 +165,27 @@ std::shared_ptr<xgboost::DMatrix> *CreateDMatrix(int rows, int columns,
std::unique_ptr<DMatrix> CreateSparsePageDMatrix(size_t n_entries, size_t page_size);
/**
* \fn std::unique_ptr<DMatrix> CreateSparsePageDMatrixWithRC(size_t n_rows, size_t n_cols,
* size_t page_size);
*
* \brief Creates a dmatrix with n_rows records, each containing a random number of
* features in [1, n_cols]
*
* \param n_rows Number of records to create.
* \param n_cols Max number of features within that record.
* \param page_size Sparse page size for the pages within the dmatrix. If page_size is 0,
* the entire dmatrix is resident in memory; otherwise, multiple sparse pages
* of that size are created and backed to disk, which then have to be
* streamed in at the point of use.
* \param deterministic If true, the content of the dmatrix is constant for a given configuration;
* otherwise, the content changes every time this method is invoked
*
* \return The new dmatrix.
*/
std::unique_ptr<DMatrix> CreateSparsePageDMatrixWithRC(size_t n_rows, size_t n_cols,
size_t page_size, bool deterministic);
gbm::GBTreeModel CreateTestModel();
inline LearnerTrainParam CreateEmptyGenericParam(int gpu_id, int n_gpus) {
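As a usage note for the helper declared above: the sketch below mirrors the calls made in the new GpuHist test further down; the sizes are illustrative and the variable names are arbitrary.

// page_size == 0: the whole dmatrix stays resident in memory.
std::unique_ptr<DMatrix> in_core(CreateSparsePageDMatrixWithRC(1000, 10, 0, true));
// Non-zero page_size: the same deterministic content, backed by external-memory sparse pages.
std::unique_ptr<DMatrix> ext_mem(CreateSparsePageDMatrixWithRC(1000, 10, 128UL, true));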

View File

@@ -77,7 +77,14 @@ void BuildGidx(DeviceShard<GradientSumT>* shard, int n_rows, int n_cols,
auto is_dense = (*dmat)->Info().num_nonzero_ ==
(*dmat)->Info().num_row_ * (*dmat)->Info().num_col_;
shard->InitCompressedData(cmat, batch, is_dense);
size_t row_stride = 0;
const auto &offset_vec = batch.offset.ConstHostVector();
for (size_t i = 1; i < offset_vec.size(); ++i) {
row_stride = std::max(row_stride, offset_vec[i] - offset_vec[i-1]);
}
shard->InitCompressedData(cmat, row_stride, is_dense);
shard->CreateHistIndices(
batch, cmat, RowStateOnDevice(batch.Size(), batch.Size()), -1);
delete dmat;
}
@@ -469,5 +476,46 @@ TEST(GpuHist, SortPosition) {
TestSortPosition({2, 2, 2, 2}, 1, 2);
TestSortPosition({1, 2, 1, 2, 3}, 1, 2);
}
TEST(GpuHist, TestHistogramIndex) {
// Test if the compressed histogram index matches when using a sparse
// dmatrix with and without using external memory
int constexpr kNRows = 1000, kNCols = 10;
// Build 2 matrices and a histogram maker for each
tree::GPUHistMakerSpecialised<GradientPairPrecise> hist_maker, hist_maker_ext;
std::unique_ptr<DMatrix> hist_maker_dmat(
CreateSparsePageDMatrixWithRC(kNRows, kNCols, 0, true));
std::unique_ptr<DMatrix> hist_maker_ext_dmat(
CreateSparsePageDMatrixWithRC(kNRows, kNCols, 128UL, true));
std::vector<std::pair<std::string, std::string>> training_params = {
{"max_depth", "1"},
{"max_leaves", "0"}
};
LearnerTrainParam learner_param(CreateEmptyGenericParam(0, 1));
hist_maker.Init(training_params, &learner_param);
hist_maker.InitDataOnce(hist_maker_dmat.get());
hist_maker_ext.Init(training_params, &learner_param);
hist_maker_ext.InitDataOnce(hist_maker_ext_dmat.get());
// Extract the device shards from the histogram makers and, from those, their compressed
// histogram indices
const auto &dev_shard = hist_maker.shards_[0];
std::vector<common::CompressedByteT> h_gidx_buffer(dev_shard->gidx_buffer.size());
dh::CopyDeviceSpanToVector(&h_gidx_buffer, dev_shard->gidx_buffer);
const auto &dev_shard_ext = hist_maker_ext.shards_[0];
std::vector<common::CompressedByteT> h_gidx_buffer_ext(dev_shard_ext->gidx_buffer.size());
dh::CopyDeviceSpanToVector(&h_gidx_buffer_ext, dev_shard_ext->gidx_buffer);
ASSERT_EQ(dev_shard->n_bins, dev_shard_ext->n_bins);
ASSERT_EQ(dev_shard->gidx_buffer.size(), dev_shard_ext->gidx_buffer.size());
ASSERT_EQ(h_gidx_buffer, h_gidx_buffer_ext);
}
} // namespace tree
} // namespace xgboost