GPU binning and compression. (#3319)
* GPU binning and compression. - binning and index compression are done inside the DeviceShard constructor - in case of a DMatrix with multiple row batches, it is first converted into a single row batch
This commit is contained in:
parent
3f7696ff53
commit
286dccb8e8
@ -73,3 +73,4 @@ List of Contributors
|
||||
* [Gideon Whitehead](https://github.com/gaw89)
|
||||
* [Yi-Lin Juang](https://github.com/frankyjuang)
|
||||
* [Andrew Hannigan](https://github.com/andrewhannigan)
|
||||
* [Andy Adinets](https://github.com/canonizer)
|
||||
|
||||
@ -126,15 +126,15 @@ struct SparseBatch {
|
||||
/*! \brief feature value */
|
||||
bst_float fvalue;
|
||||
/*! \brief default constructor */
|
||||
Entry() = default;
|
||||
XGBOOST_DEVICE Entry() {}
|
||||
/*!
|
||||
* \brief constructor with index and value
|
||||
* \param index The feature or row index.
|
||||
* \param fvalue THe feature value.
|
||||
*/
|
||||
Entry(bst_uint index, bst_float fvalue) : index(index), fvalue(fvalue) {}
|
||||
XGBOOST_DEVICE Entry(bst_uint index, bst_float fvalue) : index(index), fvalue(fvalue) {}
|
||||
/*! \brief reversely compare feature values */
|
||||
inline static bool CmpValue(const Entry& a, const Entry& b) {
|
||||
XGBOOST_DEVICE inline static bool CmpValue(const Entry& a, const Entry& b) {
|
||||
return a.fvalue < b.fvalue;
|
||||
}
|
||||
};
|
||||
|
||||
@ -8,6 +8,10 @@
|
||||
#include <cstddef>
|
||||
#include <algorithm>
|
||||
|
||||
#ifdef __CUDACC__
|
||||
#include "device_helpers.cuh"
|
||||
#endif
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
|
||||
@ -96,6 +100,23 @@ class CompressedBufferWriter {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef __CUDACC__
|
||||
__device__ void AtomicWriteSymbol
|
||||
(CompressedByteT* buffer, uint64_t symbol, size_t offset) {
|
||||
size_t ibit_start = offset * symbol_bits_;
|
||||
size_t ibit_end = (offset + 1) * symbol_bits_ - 1;
|
||||
size_t ibyte_start = ibit_start / 8, ibyte_end = ibit_end / 8;
|
||||
|
||||
symbol <<= 7 - ibit_end % 8;
|
||||
for (ptrdiff_t ibyte = ibyte_end; ibyte >= (ptrdiff_t)ibyte_start; --ibyte) {
|
||||
dh::AtomicOrByte(reinterpret_cast<unsigned int*>(buffer + detail::kPadding),
|
||||
ibyte, symbol & 0xff);
|
||||
symbol >>= 8;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename IterT>
|
||||
void Write(CompressedByteT *buffer, IterT input_begin, IterT input_end) {
|
||||
uint64_t tmp = 0;
|
||||
|
||||
@ -122,6 +122,14 @@ inline size_t AvailableMemory(int device_idx) {
|
||||
return device_free;
|
||||
}
|
||||
|
||||
inline size_t TotalMemory(int device_idx) {
|
||||
size_t device_free = 0;
|
||||
size_t device_total = 0;
|
||||
safe_cuda(cudaSetDevice(device_idx));
|
||||
dh::safe_cuda(cudaMemGetInfo(&device_free, &device_total));
|
||||
return device_total;
|
||||
}
|
||||
|
||||
/**
|
||||
* \fn inline int max_shared_memory(int device_idx)
|
||||
*
|
||||
@ -155,6 +163,12 @@ inline void CheckComputeCapability() {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
DEV_INLINE void AtomicOrByte(unsigned int* __restrict__ buffer, size_t ibyte, unsigned char b) {
|
||||
atomicOr(&buffer[ibyte / sizeof(unsigned int)], (unsigned int)b << (ibyte % (sizeof(unsigned int)) * 8));
|
||||
}
|
||||
|
||||
|
||||
/*
|
||||
* Range iterator
|
||||
*/
|
||||
|
||||
@ -183,6 +183,7 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
|
||||
.set_lower_bound(-1)
|
||||
.set_default(1)
|
||||
.describe("Number of GPUs to use for multi-gpu algorithms: -1=use all GPUs");
|
||||
|
||||
// add alias of parameters
|
||||
DMLC_DECLARE_ALIAS(reg_lambda, lambda);
|
||||
DMLC_DECLARE_ALIAS(reg_alpha, alpha);
|
||||
|
||||
@ -2,6 +2,9 @@
|
||||
* Copyright 2017 XGBoost contributors
|
||||
*/
|
||||
#include <thrust/execution_policy.h>
|
||||
#include <thrust/functional.h>
|
||||
#include <thrust/iterator/counting_iterator.h>
|
||||
#include <thrust/iterator/transform_iterator.h>
|
||||
#include <thrust/reduce.h>
|
||||
#include <thrust/sequence.h>
|
||||
#include <xgboost/tree_updater.h>
|
||||
@ -224,6 +227,53 @@ struct CalcWeightTrainParam {
|
||||
learning_rate(p.learning_rate) {}
|
||||
};
|
||||
|
||||
// index of the first element in cuts greater than v, or n if none;
|
||||
// cuts are ordered, and binary search is used
|
||||
__device__ int upper_bound(const float* __restrict__ cuts, int n, float v) {
|
||||
if (n == 0)
|
||||
return 0;
|
||||
if (cuts[n - 1] <= v)
|
||||
return n;
|
||||
if (cuts[0] > v)
|
||||
return 0;
|
||||
int left = 0, right = n - 1;
|
||||
while (right - left > 1) {
|
||||
int middle = left + (right - left) / 2;
|
||||
if (cuts[middle] > v)
|
||||
right = middle;
|
||||
else
|
||||
left = middle;
|
||||
}
|
||||
return right;
|
||||
}
|
||||
|
||||
__global__ void compress_bin_ellpack_k
|
||||
(common::CompressedBufferWriter wr, common::CompressedByteT* __restrict__ buffer,
|
||||
const size_t* __restrict__ row_ptrs,
|
||||
const RowBatch::Entry* __restrict__ entries,
|
||||
const float* __restrict__ cuts, const size_t* __restrict__ cut_rows,
|
||||
size_t base_row, size_t n_rows, size_t row_ptr_begin, size_t row_stride,
|
||||
unsigned int null_gidx_value) {
|
||||
size_t irow = threadIdx.x + size_t(blockIdx.x) * blockDim.x;
|
||||
int ifeature = threadIdx.y + blockIdx.y * blockDim.y;
|
||||
if (irow >= n_rows || ifeature >= row_stride)
|
||||
return;
|
||||
int row_size = static_cast<int>(row_ptrs[irow + 1] - row_ptrs[irow]);
|
||||
unsigned int bin = null_gidx_value;
|
||||
if (ifeature < row_size) {
|
||||
RowBatch::Entry entry = entries[row_ptrs[irow] - row_ptr_begin + ifeature];
|
||||
int feature = entry.index;
|
||||
float fvalue = entry.fvalue;
|
||||
const float *feature_cuts = &cuts[cut_rows[feature]];
|
||||
int ncuts = cut_rows[feature + 1] - cut_rows[feature];
|
||||
bin = upper_bound(feature_cuts, ncuts, fvalue);
|
||||
if (bin >= ncuts)
|
||||
bin = ncuts - 1;
|
||||
bin += cut_rows[feature];
|
||||
}
|
||||
wr.AtomicWriteSymbol(buffer, bin, (irow + base_row) * row_stride + ifeature);
|
||||
}
|
||||
|
||||
// Manage memory for a single GPU
|
||||
struct DeviceShard {
|
||||
struct Segment {
|
||||
@ -271,74 +321,117 @@ struct DeviceShard {
|
||||
dh::CubMemory temp_memory;
|
||||
|
||||
DeviceShard(int device_idx, int normalised_device_idx,
|
||||
const common::GHistIndexMatrix& gmat, bst_uint row_begin,
|
||||
bst_uint row_end, int n_bins, TrainParam param)
|
||||
: device_idx(device_idx),
|
||||
normalised_device_idx(normalised_device_idx),
|
||||
row_begin_idx(row_begin),
|
||||
row_end_idx(row_end),
|
||||
n_rows(row_end - row_begin),
|
||||
n_bins(n_bins),
|
||||
null_gidx_value(n_bins),
|
||||
param(param),
|
||||
prediction_cache_initialised(false) {
|
||||
// Convert to ELLPACK matrix representation
|
||||
int max_elements_row = 0;
|
||||
for (auto i = row_begin; i < row_end; i++) {
|
||||
max_elements_row =
|
||||
(std::max)(max_elements_row,
|
||||
static_cast<int>(gmat.row_ptr[i + 1] - gmat.row_ptr[i]));
|
||||
}
|
||||
row_stride = max_elements_row;
|
||||
std::vector<int> ellpack_matrix(row_stride * n_rows, null_gidx_value);
|
||||
bst_uint row_begin, bst_uint row_end, int n_bins, TrainParam param)
|
||||
: device_idx(device_idx),
|
||||
normalised_device_idx(normalised_device_idx),
|
||||
row_begin_idx(row_begin),
|
||||
row_end_idx(row_end),
|
||||
n_rows(row_end - row_begin),
|
||||
n_bins(n_bins),
|
||||
null_gidx_value(n_bins),
|
||||
param(param),
|
||||
prediction_cache_initialised(false) {}
|
||||
|
||||
for (auto i = row_begin; i < row_end; i++) {
|
||||
int row_count = 0;
|
||||
for (auto j = gmat.row_ptr[i]; j < gmat.row_ptr[i + 1]; j++) {
|
||||
ellpack_matrix[(i - row_begin) * row_stride + row_count] =
|
||||
gmat.index[j];
|
||||
row_count++;
|
||||
}
|
||||
}
|
||||
void Init(const common::HistCutMatrix& hmat, const RowBatch& row_batch) {
|
||||
// copy cuts to the GPU
|
||||
dh::safe_cuda(cudaSetDevice(device_idx));
|
||||
thrust::device_vector<float> cuts_d(hmat.cut);
|
||||
thrust::device_vector<size_t> cut_row_ptrs_d(hmat.row_ptr);
|
||||
|
||||
// Allocate
|
||||
// find the maximum row size
|
||||
thrust::device_vector<size_t> row_ptr_d(
|
||||
row_batch.ind_ptr + row_begin_idx, row_batch.ind_ptr + row_end_idx + 1);
|
||||
|
||||
auto row_iter = row_ptr_d.begin();
|
||||
auto get_size = [=] __device__(size_t row) {
|
||||
return row_iter[row + 1] - row_iter[row];
|
||||
}; // NOLINT
|
||||
auto counting = thrust::make_counting_iterator(size_t(0));
|
||||
using TransformT = thrust::transform_iterator<decltype(get_size),
|
||||
decltype(counting), size_t>;
|
||||
TransformT row_size_iter = TransformT(counting, get_size);
|
||||
row_stride = thrust::reduce(row_size_iter, row_size_iter + n_rows, 0,
|
||||
thrust::maximum<size_t>());
|
||||
|
||||
// allocate compressed bin data
|
||||
int num_symbols = n_bins + 1;
|
||||
size_t compressed_size_bytes =
|
||||
common::CompressedBufferWriter::CalculateBufferSize(
|
||||
ellpack_matrix.size(), num_symbols);
|
||||
common::CompressedBufferWriter::CalculateBufferSize(row_stride * n_rows,
|
||||
num_symbols);
|
||||
|
||||
CHECK(!(param.max_leaves == 0 && param.max_depth == 0))
|
||||
<< "Max leaves and max depth cannot both be unconstrained for "
|
||||
"gpu_hist.";
|
||||
ba.Allocate(device_idx, param.silent, &gidx_buffer, compressed_size_bytes);
|
||||
|
||||
gidx_buffer.Fill(0);
|
||||
|
||||
// bin and compress entries in batches of rows
|
||||
// use no more than 1/16th of GPU memory per batch
|
||||
size_t gpu_batch_nrows = dh::TotalMemory(device_idx) /
|
||||
(16 * row_stride * sizeof(RowBatch::Entry));
|
||||
if (gpu_batch_nrows > n_rows) {
|
||||
gpu_batch_nrows = n_rows;
|
||||
}
|
||||
thrust::device_vector<RowBatch::Entry> entries_d(gpu_batch_nrows * row_stride);
|
||||
size_t gpu_nbatches = dh::DivRoundUp(n_rows, gpu_batch_nrows);
|
||||
for (size_t gpu_batch = 0; gpu_batch < gpu_nbatches; ++gpu_batch) {
|
||||
size_t batch_row_begin = gpu_batch * gpu_batch_nrows;
|
||||
size_t batch_row_end = (gpu_batch + 1) * gpu_batch_nrows;
|
||||
if (batch_row_end > n_rows) {
|
||||
batch_row_end = n_rows;
|
||||
}
|
||||
size_t batch_nrows = batch_row_end - batch_row_begin;
|
||||
size_t n_entries =
|
||||
row_batch.ind_ptr[row_begin_idx + batch_row_end] -
|
||||
row_batch.ind_ptr[row_begin_idx + batch_row_begin];
|
||||
dh::safe_cuda
|
||||
(cudaMemcpy
|
||||
(entries_d.data().get(),
|
||||
&row_batch.data_ptr[row_batch.ind_ptr[row_begin_idx + batch_row_begin]],
|
||||
n_entries * sizeof(RowBatch::Entry), cudaMemcpyDefault));
|
||||
dim3 block3(32, 8, 1);
|
||||
dim3 grid3(dh::DivRoundUp(n_rows, block3.x),
|
||||
dh::DivRoundUp(row_stride, block3.y), 1);
|
||||
compress_bin_ellpack_k<<<grid3, block3>>>
|
||||
(common::CompressedBufferWriter(num_symbols), gidx_buffer.Data(),
|
||||
row_ptr_d.data().get() + batch_row_begin,
|
||||
entries_d.data().get(), cuts_d.data().get(), cut_row_ptrs_d.data().get(),
|
||||
batch_row_begin, batch_nrows,
|
||||
row_batch.ind_ptr[row_begin_idx + batch_row_begin],
|
||||
row_stride, null_gidx_value);
|
||||
|
||||
dh::safe_cuda(cudaGetLastError());
|
||||
dh::safe_cuda(cudaDeviceSynchronize());
|
||||
}
|
||||
|
||||
// free the memory that is no longer needed
|
||||
row_ptr_d.resize(0);
|
||||
row_ptr_d.shrink_to_fit();
|
||||
entries_d.resize(0);
|
||||
entries_d.shrink_to_fit();
|
||||
|
||||
gidx = common::CompressedIterator<uint32_t>(gidx_buffer.Data(), num_symbols);
|
||||
|
||||
// allocate the rest
|
||||
int max_nodes =
|
||||
param.max_leaves > 0 ? param.max_leaves * 2 : MaxNodesDepth(param.max_depth);
|
||||
ba.Allocate(device_idx, param.silent, &gidx_buffer, compressed_size_bytes,
|
||||
ba.Allocate(device_idx, param.silent,
|
||||
&gpair, n_rows, &ridx, n_rows, &position, n_rows,
|
||||
&prediction_cache, n_rows, &node_sum_gradients_d, max_nodes,
|
||||
&feature_segments, gmat.cut->row_ptr.size(), &gidx_fvalue_map,
|
||||
gmat.cut->cut.size(), &min_fvalue, gmat.cut->min_val.size(),
|
||||
&feature_segments, hmat.row_ptr.size(), &gidx_fvalue_map,
|
||||
hmat.cut.size(), &min_fvalue, hmat.min_val.size(),
|
||||
&monotone_constraints, param.monotone_constraints.size());
|
||||
gidx_fvalue_map = gmat.cut->cut;
|
||||
min_fvalue = gmat.cut->min_val;
|
||||
feature_segments = gmat.cut->row_ptr;
|
||||
gidx_fvalue_map = hmat.cut;
|
||||
min_fvalue = hmat.min_val;
|
||||
feature_segments = hmat.row_ptr;
|
||||
monotone_constraints = param.monotone_constraints;
|
||||
|
||||
node_sum_gradients.resize(max_nodes);
|
||||
ridx_segments.resize(max_nodes);
|
||||
|
||||
// Compress gidx
|
||||
common::CompressedBufferWriter cbw(num_symbols);
|
||||
std::vector<common::CompressedByteT> host_buffer(gidx_buffer.Size());
|
||||
cbw.Write(host_buffer.data(), ellpack_matrix.begin(), ellpack_matrix.end());
|
||||
gidx_buffer = host_buffer;
|
||||
gidx =
|
||||
common::CompressedIterator<uint32_t>(gidx_buffer.Data(), num_symbols);
|
||||
|
||||
common::CompressedIterator<uint32_t> ci_host(host_buffer.data(),
|
||||
num_symbols);
|
||||
|
||||
// Init histogram
|
||||
hist.Init(device_idx, max_nodes, gmat.cut->row_ptr.back(), param.silent);
|
||||
hist.Init(device_idx, max_nodes, hmat.row_ptr.back(), param.silent);
|
||||
|
||||
dh::safe_cuda(cudaMallocHost(&tmp_pinned, sizeof(int64_t)));
|
||||
}
|
||||
@ -579,8 +672,6 @@ class GPUHistMaker : public TreeUpdater {
|
||||
info_ = &dmat->Info();
|
||||
monitor_.Start("Quantiles", device_list_);
|
||||
hmat_.Init(dmat, param_.max_bin);
|
||||
gmat_.cut = &hmat_;
|
||||
gmat_.Init(dmat);
|
||||
monitor_.Stop("Quantiles", device_list_);
|
||||
n_bins_ = hmat_.row_ptr.back();
|
||||
|
||||
@ -609,12 +700,22 @@ class GPUHistMaker : public TreeUpdater {
|
||||
row_begin = row_end;
|
||||
}
|
||||
|
||||
// Create device shards
|
||||
dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr<DeviceShard>& shard) {
|
||||
shard = std::unique_ptr<DeviceShard>(
|
||||
new DeviceShard(device_list_[i], i, gmat_,
|
||||
row_segments[i], row_segments[i + 1], n_bins_, param_));
|
||||
});
|
||||
monitor_.Start("BinningCompression", device_list_);
|
||||
{
|
||||
dmlc::DataIter<RowBatch>* iter = dmat->RowIterator();
|
||||
iter->BeforeFirst();
|
||||
CHECK(iter->Next()) << "Empty batches are not supported";
|
||||
const RowBatch& batch = iter->Value();
|
||||
// Create device shards
|
||||
dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr<DeviceShard>& shard) {
|
||||
shard = std::unique_ptr<DeviceShard>
|
||||
(new DeviceShard(device_list_[i], i,
|
||||
row_segments[i], row_segments[i + 1], n_bins_, param_));
|
||||
shard->Init(hmat_, batch);
|
||||
});
|
||||
CHECK(!iter->Next()) << "External memory not supported";
|
||||
}
|
||||
monitor_.Stop("BinningCompression", device_list_);
|
||||
|
||||
p_last_fmat_ = dmat;
|
||||
initialised_ = true;
|
||||
|
||||
73
tests/cpp/common/test_gpu_compressed_iterator.cu
Normal file
73
tests/cpp/common/test_gpu_compressed_iterator.cu
Normal file
@ -0,0 +1,73 @@
|
||||
#include "../../../src/common/compressed_iterator.h"
|
||||
#include "../../../src/common/device_helpers.cuh"
|
||||
#include "gtest/gtest.h"
|
||||
#include <algorithm>
|
||||
#include <thrust/device_vector.h>
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
|
||||
struct WriteSymbolFunction {
|
||||
CompressedBufferWriter cbw;
|
||||
unsigned char* buffer_data_d;
|
||||
int* input_data_d;
|
||||
WriteSymbolFunction(CompressedBufferWriter cbw, unsigned char* buffer_data_d,
|
||||
int* input_data_d)
|
||||
: cbw(cbw), buffer_data_d(buffer_data_d), input_data_d(input_data_d) {}
|
||||
|
||||
__device__ void operator()(size_t i) {
|
||||
cbw.AtomicWriteSymbol(buffer_data_d, input_data_d[i], i);
|
||||
}
|
||||
};
|
||||
|
||||
struct ReadSymbolFunction {
|
||||
CompressedIterator<int> ci;
|
||||
int* output_data_d;
|
||||
ReadSymbolFunction(CompressedIterator<int> ci, int* output_data_d)
|
||||
: ci(ci), output_data_d(output_data_d) {}
|
||||
|
||||
__device__ void operator()(size_t i) {
|
||||
output_data_d[i] = ci[i];
|
||||
}
|
||||
};
|
||||
|
||||
TEST(CompressedIterator, TestGPU) {
|
||||
std::vector<int> test_cases = {1, 3, 426, 21, 64, 256, 100000, INT32_MAX};
|
||||
int num_elements = 1000;
|
||||
int repetitions = 1000;
|
||||
srand(9);
|
||||
|
||||
for (auto alphabet_size : test_cases) {
|
||||
for (int i = 0; i < repetitions; i++) {
|
||||
std::vector<int> input(num_elements);
|
||||
std::generate(input.begin(), input.end(),
|
||||
[=]() { return rand() % alphabet_size; });
|
||||
CompressedBufferWriter cbw(alphabet_size);
|
||||
thrust::device_vector<int> input_d(input);
|
||||
|
||||
thrust::device_vector<unsigned char> buffer_d(
|
||||
CompressedBufferWriter::CalculateBufferSize(input.size(),
|
||||
alphabet_size));
|
||||
|
||||
// write the data on device
|
||||
auto input_data_d = input_d.data().get();
|
||||
auto buffer_data_d = buffer_d.data().get();
|
||||
dh::LaunchN(0, input_d.size(),
|
||||
WriteSymbolFunction(cbw, buffer_data_d, input_data_d));
|
||||
|
||||
// read the data on device
|
||||
CompressedIterator<int> ci(buffer_d.data().get(), alphabet_size);
|
||||
thrust::device_vector<int> output_d(input.size());
|
||||
auto output_data_d = output_d.data().get();
|
||||
dh::LaunchN(0, output_d.size(), ReadSymbolFunction(ci, output_data_d));
|
||||
|
||||
std::vector<int> output(output_d.size());
|
||||
thrust::copy(output_d.begin(), output_d.end(), output.begin());
|
||||
|
||||
ASSERT_TRUE(input == output);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
@ -18,11 +18,19 @@ long GetFileSize(const std::string filename) {
|
||||
}
|
||||
|
||||
std::string CreateSimpleTestData() {
|
||||
return CreateBigTestData(6);
|
||||
}
|
||||
|
||||
std::string CreateBigTestData(size_t n_entries) {
|
||||
std::string tmp_file = TempFileName();
|
||||
std::ofstream fo;
|
||||
fo.open(tmp_file);
|
||||
fo << "0 0:0 1:10 2:20\n";
|
||||
fo << "1 0:0 3:30 4:40\n";
|
||||
const size_t entries_per_row = 3;
|
||||
size_t n_rows = (n_entries + entries_per_row - 1) / entries_per_row;
|
||||
for (size_t i = 0; i < n_rows; ++i) {
|
||||
const char* row = i % 2 == 0 ? " 0:0 1:10 2:20\n" : " 0:0 3:30 4:40\n";
|
||||
fo << i << row;
|
||||
}
|
||||
fo.close();
|
||||
return tmp_file;
|
||||
}
|
||||
|
||||
@ -23,6 +23,8 @@ long GetFileSize(const std::string filename);
|
||||
|
||||
std::string CreateSimpleTestData();
|
||||
|
||||
std::string CreateBigTestData(size_t n_entries);
|
||||
|
||||
void CheckObjFunction(xgboost::ObjFunction * obj,
|
||||
std::vector<xgboost::bst_float> preds,
|
||||
std::vector<xgboost::bst_float> labels,
|
||||
|
||||
@ -7,6 +7,7 @@
|
||||
#include "../helpers.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "../../../src/data/sparse_page_source.h"
|
||||
#include "../../../src/gbm/gbtree_model.h"
|
||||
#include "../../../src/tree/updater_gpu_hist.cu"
|
||||
|
||||
@ -24,8 +25,14 @@ TEST(gpu_hist_experimental, TestSparseShard) {
|
||||
gmat.Init(dmat.get());
|
||||
TrainParam p;
|
||||
p.max_depth = 6;
|
||||
DeviceShard shard(0, 0, gmat, 0, rows, hmat.row_ptr.back(),
|
||||
p);
|
||||
|
||||
dmlc::DataIter<RowBatch>* iter = dmat->RowIterator();
|
||||
iter->BeforeFirst();
|
||||
CHECK(iter->Next());
|
||||
const RowBatch& batch = iter->Value();
|
||||
DeviceShard shard(0, 0, 0, rows, hmat.row_ptr.back(), p);
|
||||
shard.Init(hmat, batch);
|
||||
CHECK(!iter->Next());
|
||||
|
||||
ASSERT_LT(shard.row_stride, columns);
|
||||
|
||||
@ -59,8 +66,15 @@ TEST(gpu_hist_experimental, TestDenseShard) {
|
||||
gmat.Init(dmat.get());
|
||||
TrainParam p;
|
||||
p.max_depth = 6;
|
||||
DeviceShard shard(0, 0, gmat, 0, rows, hmat.row_ptr.back(),
|
||||
p);
|
||||
|
||||
dmlc::DataIter<RowBatch>* iter = dmat->RowIterator();
|
||||
iter->BeforeFirst();
|
||||
CHECK(iter->Next());
|
||||
const RowBatch& batch = iter->Value();
|
||||
|
||||
DeviceShard shard(0, 0, 0, rows, hmat.row_ptr.back(), p);
|
||||
shard.Init(hmat, batch);
|
||||
CHECK(!iter->Next());
|
||||
|
||||
ASSERT_EQ(shard.row_stride, columns);
|
||||
|
||||
@ -75,4 +89,4 @@ TEST(gpu_hist_experimental, TestDenseShard) {
|
||||
}
|
||||
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user