[GPU-Plugin] Fix gpu_hist to allow matrices with more than just 2^{32} elements. Also fixed CPU hist algorithm. (#2518)

PSEUDOTENSOR / Jonathan McKinney 2017-07-17 16:19:27 -07:00 committed by Rory Mitchell
parent c85bf9859e
commit ca7fc9fda3
11 changed files with 413 additions and 283 deletions

View File

@ -441,7 +441,7 @@ class bulk_allocator {
public:
~bulk_allocator() {
for (int i = 0; i < d_ptr.size(); i++) {
for (size_t i = 0; i < d_ptr.size(); i++) {
if (!(d_ptr[i] == nullptr)) {
safe_cuda(cudaSetDevice(_device_idx[i]));
safe_cuda(cudaFree(d_ptr[i]));
@ -522,7 +522,7 @@ inline size_t available_memory(int device_idx) {
template <typename T>
void print(const thrust::device_vector<T> &v, size_t max_items = 10) {
thrust::host_vector<T> h = v;
for (int i = 0; i < std::min(max_items, h.size()); i++) {
for (size_t i = 0; i < std::min(max_items, h.size()); i++) {
std::cout << " " << h[i];
}
std::cout << "\n";
@ -531,7 +531,7 @@ void print(const thrust::device_vector<T> &v, size_t max_items = 10) {
template <typename T>
void print(const dvec<T> &v, size_t max_items = 10) {
std::vector<T> h = v.as_vector();
for (int i = 0; i < std::min(max_items, h.size()); i++) {
for (size_t i = 0; i < std::min(max_items, h.size()); i++) {
std::cout << " " << h[i];
}
std::cout << "\n";
@ -539,10 +539,10 @@ void print(const dvec<T> &v, size_t max_items = 10) {
template <typename T>
void print(char *label, const thrust::device_vector<T> &v,
const char *format = "%d ", int max = 10) {
const char *format = "%d ", size_t max = 10) {
thrust::host_vector<T> h_v = v;
std::cout << label << ":\n";
for (int i = 0; i < std::min(static_cast<int>(h_v.size()), max); i++) {
for (size_t i = 0; i < std::min(static_cast<size_t>(h_v.size()), max); i++) {
printf(format, h_v[i]);
}
std::cout << "\n";
@ -593,6 +593,7 @@ __global__ void launch_n_kernel(int device_idx, size_t begin, size_t end,
template <int ITEMS_PER_THREAD = 8, int BLOCK_THREADS = 256, typename L>
inline void launch_n(int device_idx, size_t n, L lambda) {
safe_cuda(cudaSetDevice(device_idx));
// TODO: Template on n so GRID_SIZE always fits into int.
const int GRID_SIZE = div_round_up(n, ITEMS_PER_THREAD * BLOCK_THREADS);
#if defined(__CUDACC__)
launch_n_kernel<<<GRID_SIZE, BLOCK_THREADS>>>(static_cast<size_t>(0), n,
@ -607,6 +608,7 @@ inline void multi_launch_n(size_t n, int n_devices, L lambda) {
CHECK_LE(n_devices, n_visible_devices()) << "Number of devices requested "
"needs to be less than equal to "
"number of visible devices.";
// TODO: Template on n so GRID_SIZE always fits into int.
const int GRID_SIZE = div_round_up(n, ITEMS_PER_THREAD * BLOCK_THREADS);
#if defined(__CUDACC__)
n_devices = n_devices > n ? n : n_devices;
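
For context, a minimal standalone sketch (not part of the commit; div_round_up is reimplemented here from its use above) of the grid-size arithmetic behind the new TODO: the element count n needs 64 bits, while the resulting GRID_SIZE still fits an int for any n below 2^42 with these defaults.

#include <cstddef>
#include <cstdio>

constexpr std::size_t ITEMS_PER_THREAD = 8;  // launch_n defaults from above
constexpr std::size_t BLOCK_THREADS = 256;

// Assumed equivalent of the dh helper used by launch_n/multi_launch_n.
inline std::size_t div_round_up(std::size_t a, std::size_t b) {
  return (a + b - 1) / b;
}

int main() {
  const std::size_t n = (1ULL << 32) + 7;  // more than 2^32 elements
  // 64-bit arithmetic gives 2^21 + 1 blocks; a 32-bit element count could not
  // even represent n, which is the class of bug this commit fixes.
  const std::size_t grid_size = div_round_up(n, ITEMS_PER_THREAD * BLOCK_THREADS);
  std::printf("grid blocks: %zu\n", grid_size);  // still fits an int, per the TODO
  return 0;
}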

View File

@ -19,8 +19,8 @@ namespace xgboost {
namespace tree {
void DeviceGMat::Init(int device_idx, const common::GHistIndexMatrix& gmat,
bst_uint element_begin, bst_uint element_end,
bst_uint row_begin, bst_uint row_end, int n_bins) {
bst_ulong element_begin, bst_ulong element_end,
bst_ulong row_begin, bst_ulong row_end, int n_bins) {
dh::safe_cuda(cudaSetDevice(device_idx));
CHECK(gidx_buffer.size()) << "gidx_buffer must be externally allocated";
CHECK_EQ(row_ptr.size(), (row_end - row_begin) + 1)
@ -31,15 +31,15 @@ void DeviceGMat::Init(int device_idx, const common::GHistIndexMatrix& gmat,
cbw.Write(host_buffer.data(), gmat.index.begin() + element_begin,
gmat.index.begin() + element_end);
gidx_buffer = host_buffer;
gidx = common::CompressedIterator<int>(gidx_buffer.data(), n_bins);
gidx = common::CompressedIterator<uint32_t>(gidx_buffer.data(), n_bins);
// row_ptr
thrust::copy(gmat.row_ptr.data() + row_begin,
gmat.row_ptr.data() + row_end + 1, row_ptr.tbegin());
// normalise row_ptr
bst_uint start = gmat.row_ptr[row_begin];
size_t start = gmat.row_ptr[row_begin];
thrust::transform(row_ptr.tbegin(), row_ptr.tend(), row_ptr.tbegin(),
[=] __device__(int val) { return val - start; });
[=] __device__(size_t val) { return val - start; });
}
void DeviceHist::Init(int n_bins_in) {
@ -193,7 +193,7 @@ void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
device_row_segments.push_back(0);
device_element_segments.push_back(0);
bst_uint offset = 0;
size_t shard_size = std::ceil(static_cast<double>(num_rows) / n_devices);
bst_uint shard_size = std::ceil(static_cast<double>(num_rows) / n_devices);
for (int d_idx = 0; d_idx < n_devices; d_idx++) {
int device_idx = dList[d_idx];
offset += shard_size;
@ -261,7 +261,7 @@ void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
int device_idx = dList[d_idx];
bst_uint num_rows_segment =
device_row_segments[d_idx + 1] - device_row_segments[d_idx];
bst_uint num_elements_segment =
bst_ulong num_elements_segment =
device_element_segments[d_idx + 1] - device_element_segments[d_idx];
ba.allocate(
device_idx, &(hist_vec[d_idx].data),
@ -360,7 +360,7 @@ void GPUHistBuilder::BuildHist(int depth) {
auto hist_builder = hist_vec[d_idx].GetBuilder();
dh::TransformLbs(
device_idx, &temp_memory[d_idx], end - begin, d_row_ptr,
row_end - row_begin, [=] __device__(int local_idx, int local_ridx) {
row_end - row_begin, [=] __device__(size_t local_idx, int local_ridx) {
int nidx = d_position[local_ridx]; // OPTMARK: latency
if (!is_active(nidx, depth)) return;
@ -606,7 +606,11 @@ __global__ void find_split_kernel(
}
#define MIN_BLOCK_THREADS 32
#define MAX_BLOCK_THREADS 1024 // hard-coded maximum block size
#define CHUNK_BLOCK_THREADS 32
// MAX_BLOCK_THREADS of 1024 is the hard-coded maximum block size:
// CUDA compute capability 3.5 and above caps the maximum number of
// threads per block at 1024
#define MAX_BLOCK_THREADS 1024
void GPUHistBuilder::FindSplit(int depth) {
// Specialised based on max_bins
@ -622,7 +626,7 @@ void GPUHistBuilder::FindSplitSpecialize(int depth) {
if (param.max_bin <= BLOCK_THREADS) {
LaunchFindSplit<BLOCK_THREADS>(depth);
} else {
this->FindSplitSpecialize<BLOCK_THREADS + 32>(depth);
this->FindSplitSpecialize<BLOCK_THREADS + CHUNK_BLOCK_THREADS>(depth);
}
}
@ -885,7 +889,7 @@ void GPUHistBuilder::UpdatePositionDense(int depth) {
size_t begin = device_row_segments[d_idx];
size_t end = device_row_segments[d_idx + 1];
dh::launch_n(device_idx, end - begin, [=] __device__(int local_idx) {
dh::launch_n(device_idx, end - begin, [=] __device__(size_t local_idx) {
int pos = d_position[local_idx];
if (!is_active(pos, depth)) {
return;
@ -896,7 +900,8 @@ void GPUHistBuilder::UpdatePositionDense(int depth) {
return;
}
int gidx = d_gidx[local_idx * n_columns + node.split.findex];
int gidx = d_gidx[local_idx *
static_cast<size_t>(n_columns) + static_cast<size_t>(node.split.findex)];
float fvalue = d_gidx_fvalue_map[gidx];
@ -955,7 +960,7 @@ void GPUHistBuilder::UpdatePositionSparse(int depth) {
dh::TransformLbs(
device_idx, &temp_memory[d_idx], element_end - element_begin, d_row_ptr,
row_end - row_begin, [=] __device__(int local_idx, int local_ridx) {
row_end - row_begin, [=] __device__(size_t local_idx, int local_ridx) {
int pos = d_position[local_ridx];
if (!is_active(pos, depth)) {
return;
@ -1065,25 +1070,13 @@ void GPUHistBuilder::Update(const std::vector<bst_gpair>& gpair,
this->InitData(gpair, *p_fmat, *p_tree);
this->InitFirstNode(gpair);
this->ColSampleTree();
// long long int elapsed=0;
for (int depth = 0; depth < param.max_depth; depth++) {
this->ColSampleLevel();
// dh::Timer time;
this->BuildHist(depth);
// elapsed+=time.elapsed();
// printf("depth=%d\n",depth);
// time.printElapsed("BH Time");
// dh::Timer timesplit;
this->FindSplit(depth);
// timesplit.printElapsed("FS Time");
// dh::Timer timeupdatepos;
this->UpdatePosition(depth);
// timeupdatepos.printElapsed("UP Time");
}
// printf("Total BuildHist Time=%lld\n",elapsed);
// done with multi-GPU, pass back result from master to tree on host
int master_device = dList[0];
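
The UpdatePositionDense hunk above casts each factor before multiplying; a small sketch (hypothetical values, not the commit's code) of why that matters: two 32-bit ints overflow inside the multiply itself, before the wide destination type can help.

#include <cstddef>
#include <cstdio>

int main() {
  const int local_idx = 70000000;  // hypothetical row index on one device
  const int n_columns = 31;        // column count, matching the test script below
  // 70000000 * 31 = 2170000000 > INT_MAX: as int arithmetic this overflows
  // (undefined behaviour) even when the result is stored into a size_t.
  // const std::size_t bad = local_idx * n_columns;  // overflows before widening
  const std::size_t good =
      static_cast<std::size_t>(local_idx) * static_cast<std::size_t>(n_columns);
  std::printf("good = %zu\n", good);  // 2170000000
  return 0;
}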

View File

@ -18,10 +18,10 @@ namespace tree {
struct DeviceGMat {
dh::dvec<common::compressed_byte_t> gidx_buffer;
common::CompressedIterator<int > gidx;
dh::dvec<int> row_ptr;
common::CompressedIterator<uint32_t> gidx;
dh::dvec<size_t> row_ptr;
void Init(int device_idx, const common::GHistIndexMatrix &gmat,
bst_uint begin, bst_uint end, bst_uint row_begin, bst_uint row_end,int n_bins);
bst_ulong element_begin, bst_ulong element_end, bst_ulong row_begin, bst_ulong row_end,int n_bins);
};
struct HistBuilder {
@ -99,7 +99,7 @@ class GPUHistBuilder {
// below vectors are for each devices used
std::vector<int> dList;
std::vector<int> device_row_segments;
std::vector<int> device_element_segments;
std::vector<size_t> device_element_segments;
std::vector<dh::CubMemory> temp_memory;
std::vector<DeviceHist> hist_vec;
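
Why device_element_segments widens to size_t while device_row_segments stays int: per-device row counts fit comfortably in 32 bits, but cumulative element offsets (rows times columns) do not. A sketch using the commented-out "very large" case from the test script below:

#include <cstdint>
#include <cstdio>

int main() {
  const std::uint64_t rows = 152360032;  // the "very large" multi-GPU case below
  const std::uint64_t cols = 31;
  const std::uint64_t elements = rows * cols;  // 4723160992 > 2^32
  std::printf("elements = %llu, fits in uint32? %s\n",
              static_cast<unsigned long long>(elements),
              elements <= UINT32_MAX ? "yes" : "no");  // prints "no"
  return 0;
}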

View File

@ -0,0 +1,134 @@
from __future__ import print_function
#pylint: skip-file
import sys
sys.path.append("../../tests/python")
import xgboost as xgb
import testing as tm
import numpy as np
import unittest
from sklearn.datasets import make_classification
def eprint(*args, **kwargs):
print(*args, file=sys.stderr, **kwargs) ; sys.stderr.flush()
print(*args, file=sys.stdout, **kwargs) ; sys.stdout.flush()
eprint("Testing Big Data (this may take a while)")
rng = np.random.RandomState(1994)
# "realistic" size based upon http://stat-computing.org/dataexpo/2009/ , which has been processed to one-hot encode categoricalsxsy
cols = 31
# reduced to fit onto 1 gpu but still be large
rows2 = 5000 # medium
#rows2 = 4032 # fake large for testing
rows1 = 42360032 # large
#rows2 = 152360032 # can do this for multi-gpu test (very large)
rowslist = [rows1, rows2]
class TestGPU(unittest.TestCase):
def test_large(self):
eprint("Starting test for large data")
tm._skip_if_no_sklearn()
from sklearn.datasets import load_digits
try:
from sklearn.model_selection import train_test_split
except ImportError:
from sklearn.cross_validation import train_test_split
for rows in rowslist:
eprint("Creating train data rows=%d cols=%d" % (rows,cols))
X, y = make_classification(rows, n_features=cols, random_state=7)
rowstest = int(rows*0.2)
eprint("Creating test data rows=%d cols=%d" % (rowstest,cols))
# note the new random state: if it matched the train random state, exact methods could memorize and do very well on test even for random data, while hist cannot
Xtest, ytest = make_classification(rowstest, n_features=cols, random_state=8)
eprint("Starting DMatrix(X,y)")
ag_dtrain = xgb.DMatrix(X,y)
eprint("Starting DMatrix(Xtest,ytest)")
ag_dtest = xgb.DMatrix(Xtest,ytest)
max_depth=6
max_bin=1024
# regression test --- hist must be the same as exact on all-categorical data
ag_param = {'max_depth': max_depth,
'tree_method': 'exact',
#'nthread': 1,
'eta': 1,
'silent': 0,
'objective': 'binary:logistic',
'eval_metric': 'auc'}
ag_paramb = {'max_depth': max_depth,
'tree_method': 'hist',
#'nthread': 1,
'eta': 1,
'silent': 0,
'objective': 'binary:logistic',
'eval_metric': 'auc'}
ag_param2 = {'max_depth': max_depth,
'tree_method': 'gpu_hist',
'eta': 1,
'silent': 0,
'n_gpus': 1,
'objective': 'binary:logistic',
'max_bin': max_bin,
'eval_metric': 'auc'}
ag_param3 = {'max_depth': max_depth,
'tree_method': 'gpu_hist',
'eta': 1,
'silent': 0,
'n_gpus': -1,
'objective': 'binary:logistic',
'max_bin': max_bin,
'eval_metric': 'auc'}
#ag_param4 = {'max_depth': max_depth,
# 'tree_method': 'gpu_exact',
# 'eta': 1,
# 'silent': 0,
# 'n_gpus': 1,
# 'objective': 'binary:logistic',
# 'max_bin': max_bin,
# 'eval_metric': 'auc'}
ag_res = {}
ag_resb = {}
ag_res2 = {}
ag_res3 = {}
#ag_res4 = {}
num_rounds = 1
eprint("normal updater")
xgb.train(ag_param, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
evals_result=ag_res)
eprint("hist updater")
xgb.train(ag_paramb, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
evals_result=ag_resb)
eprint("gpu_hist updater 1 gpu")
xgb.train(ag_param2, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
evals_result=ag_res2)
eprint("gpu_hist updater all gpus")
xgb.train(ag_param3, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
evals_result=ag_res3)
#eprint("gpu_exact updater")
#xgb.train(ag_param4, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
# evals_result=ag_res4)
assert np.fabs(ag_res['train']['auc'][0] - ag_resb['train']['auc'][0])<0.001
assert np.fabs(ag_res['train']['auc'][0] - ag_res2['train']['auc'][0])<0.001
assert np.fabs(ag_res['train']['auc'][0] - ag_res3['train']['auc'][0])<0.001
#assert np.fabs(ag_res['train']['auc'][0] - ag_res4['train']['auc'][0])<0.001
assert np.fabs(ag_res['test']['auc'][0] - ag_resb['test']['auc'][0])<0.01
assert np.fabs(ag_res['test']['auc'][0] - ag_res2['test']['auc'][0])<0.01
assert np.fabs(ag_res['test']['auc'][0] - ag_res3['test']['auc'][0])<0.01
#assert np.fabs(ag_res['test']['auc'][0] - ag_res4['test']['auc'][0])<0.01

View File

@ -57,7 +57,7 @@ class Column {
ColumnType type;
const T* index;
uint32_t index_base;
const uint32_t* row_ind;
const size_t* row_ind;
size_t len;
};
@ -66,8 +66,8 @@ class Column {
class ColumnMatrix {
public:
// get number of features
inline uint32_t GetNumFeature() const {
return type_.size();
inline bst_uint GetNumFeature() const {
return static_cast<bst_uint>(type_.size());
}
// construct column matrix from GHistIndexMatrix
@ -78,8 +78,8 @@ class ColumnMatrix {
slot of internal buffer. */
packing_factor_ = sizeof(uint32_t) / static_cast<size_t>(this->dtype);
const uint32_t nfeature = gmat.cut->row_ptr.size() - 1;
const omp_ulong nrow = static_cast<omp_ulong>(gmat.row_ptr.size() - 1);
const bst_uint nfeature = static_cast<bst_uint>(gmat.cut->row_ptr.size() - 1);
const size_t nrow = gmat.row_ptr.size() - 1;
// identify type of each column
feature_counts_.resize(nfeature);
@ -90,13 +90,13 @@ class ColumnMatrix {
XGBOOST_TYPE_SWITCH(this->dtype, {
max_val = static_cast<uint32_t>(std::numeric_limits<DType>::max());
});
for (uint32_t fid = 0; fid < nfeature; ++fid) {
for (bst_uint fid = 0; fid < nfeature; ++fid) {
CHECK_LE(gmat.cut->row_ptr[fid + 1] - gmat.cut->row_ptr[fid], max_val);
}
gmat.GetFeatureCounts(&feature_counts_[0]);
// classify features
for (uint32_t fid = 0; fid < nfeature; ++fid) {
for (bst_uint fid = 0; fid < nfeature; ++fid) {
if (static_cast<double>(feature_counts_[fid])
< param.sparse_threshold * nrow) {
type_[fid] = kSparseColumn;
@ -108,13 +108,13 @@ class ColumnMatrix {
// want to compute storage boundary for each feature
// using variants of prefix sum scan
boundary_.resize(nfeature);
bst_uint accum_index_ = 0;
bst_uint accum_row_ind_ = 0;
for (uint32_t fid = 0; fid < nfeature; ++fid) {
size_t accum_index_ = 0;
size_t accum_row_ind_ = 0;
for (bst_uint fid = 0; fid < nfeature; ++fid) {
boundary_[fid].index_begin = accum_index_;
boundary_[fid].row_ind_begin = accum_row_ind_;
if (type_[fid] == kDenseColumn) {
accum_index_ += nrow;
accum_index_ += static_cast<size_t>(nrow);
} else {
accum_index_ += feature_counts_[fid];
accum_row_ind_ += feature_counts_[fid];
@ -129,14 +129,14 @@ class ColumnMatrix {
// store least bin id for each feature
index_base_.resize(nfeature);
for (uint32_t fid = 0; fid < nfeature; ++fid) {
for (bst_uint fid = 0; fid < nfeature; ++fid) {
index_base_[fid] = gmat.cut->row_ptr[fid];
}
// fill index_ for dense columns
for (uint32_t fid = 0; fid < nfeature; ++fid) {
// pre-fill index_ for dense columns
for (bst_uint fid = 0; fid < nfeature; ++fid) {
if (type_[fid] == kDenseColumn) {
const uint32_t ibegin = boundary_[fid].index_begin;
const size_t ibegin = boundary_[fid].index_begin;
XGBOOST_TYPE_SWITCH(this->dtype, {
const size_t block_offset = ibegin / packing_factor_;
const size_t elem_offset = ibegin % packing_factor_;
@ -150,15 +150,15 @@ class ColumnMatrix {
// loop over all rows and fill column entries
// num_nonzeros[fid] = number of nonzeros accumulated so far for feature fid
std::vector<uint32_t> num_nonzeros;
std::vector<size_t> num_nonzeros;
num_nonzeros.resize(nfeature);
std::fill(num_nonzeros.begin(), num_nonzeros.end(), 0);
for (uint32_t rid = 0; rid < nrow; ++rid) {
const size_t ibegin = static_cast<size_t>(gmat.row_ptr[rid]);
const size_t iend = static_cast<size_t>(gmat.row_ptr[rid + 1]);
for (size_t rid = 0; rid < nrow; ++rid) {
const size_t ibegin = gmat.row_ptr[rid];
const size_t iend = gmat.row_ptr[rid + 1];
size_t fid = 0;
for (size_t i = ibegin; i < iend; ++i) {
const size_t bin_id = gmat.index[i];
const uint32_t bin_id = gmat.index[i];
while (bin_id >= gmat.cut->row_ptr[fid + 1]) {
++fid;
}
@ -167,14 +167,14 @@ class ColumnMatrix {
const size_t block_offset = boundary_[fid].index_begin / packing_factor_;
const size_t elem_offset = boundary_[fid].index_begin % packing_factor_;
DType* begin = reinterpret_cast<DType*>(&index_[block_offset]) + elem_offset;
begin[rid] = bin_id - index_base_[fid];
begin[rid] = static_cast<DType>(bin_id - index_base_[fid]);
});
} else {
XGBOOST_TYPE_SWITCH(this->dtype, {
const size_t block_offset = boundary_[fid].index_begin / packing_factor_;
const size_t elem_offset = boundary_[fid].index_begin % packing_factor_;
DType* begin = reinterpret_cast<DType*>(&index_[block_offset]) + elem_offset;
begin[num_nonzeros[fid]] = bin_id - index_base_[fid];
begin[num_nonzeros[fid]] = static_cast<DType>(bin_id - index_base_[fid]);
});
row_ind_[boundary_[fid].row_ind_begin + num_nonzeros[fid]] = rid;
++num_nonzeros[fid];
@ -213,16 +213,16 @@ class ColumnMatrix {
// indicate where each column's index and row_ind is stored.
// index_begin and index_end are logical offsets, so they should be converted to
// actual offsets by scaling with packing_factor_
unsigned index_begin;
unsigned index_end;
unsigned row_ind_begin;
unsigned row_ind_end;
size_t index_begin;
size_t index_end;
size_t row_ind_begin;
size_t row_ind_end;
};
std::vector<bst_uint> feature_counts_;
std::vector<size_t> feature_counts_;
std::vector<ColumnType> type_;
std::vector<uint32_t> index_; // index_: may store smaller integers; needs padding
std::vector<uint32_t> row_ind_;
std::vector<size_t> row_ind_;
std::vector<ColumnBoundary> boundary_;
size_t packing_factor_; // how many integers are stored in each slot of index_
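
A standalone sketch of the block_offset / elem_offset addressing used throughout this file; DType is fixed to uint8_t here for illustration, whereas the real code selects it via XGBOOST_TYPE_SWITCH:

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

using DType = std::uint8_t;  // assumed narrow storage type for this sketch

int main() {
  // packing_factor = how many DType entries fit into one uint32_t slot (4 here).
  const std::size_t packing_factor = sizeof(std::uint32_t) / sizeof(DType);
  std::vector<std::uint32_t> index(16, 0);  // stand-in for the packed index_ buffer

  const std::size_t ibegin = 10;  // hypothetical logical offset of a column
  const std::size_t block_offset = ibegin / packing_factor;  // which uint32_t slot
  const std::size_t elem_offset = ibegin % packing_factor;   // entry within the slot
  DType* begin = reinterpret_cast<DType*>(&index[block_offset]) + elem_offset;
  begin[0] = 42;  // write the column's first entry, as the fill loops above do
  std::printf("slot %zu, entry %zu, value %u\n", block_offset, elem_offset,
              static_cast<unsigned>(begin[0]));
  return 0;
}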

View File

@ -46,11 +46,11 @@ static int SymbolBits(int num_symbols) {
class CompressedBufferWriter {
private:
int symbol_bits_;
size_t symbol_bits_;
size_t offset_;
public:
explicit CompressedBufferWriter(int num_symbols) : offset_(0) {
explicit CompressedBufferWriter(size_t num_symbols) : offset_(0) {
symbol_bits_ = detail::SymbolBits(num_symbols);
}
@ -70,9 +70,9 @@ class CompressedBufferWriter {
* \return The calculated buffer size.
*/
static size_t CalculateBufferSize(int num_elements, int num_symbols) {
static size_t CalculateBufferSize(size_t num_elements, size_t num_symbols) {
const int bits_per_byte = 8;
int compressed_size = std::ceil(
size_t compressed_size = std::ceil(
static_cast<double>(detail::SymbolBits(num_symbols) * num_elements) /
bits_per_byte);
return compressed_size + detail::padding;
@ -82,10 +82,10 @@ class CompressedBufferWriter {
void WriteSymbol(compressed_byte_t *buffer, T symbol, size_t offset) {
const int bits_per_byte = 8;
for (int i = 0; i < symbol_bits_; i++) {
for (size_t i = 0; i < symbol_bits_; i++) {
size_t byte_idx = ((offset + 1) * symbol_bits_ - (i + 1)) / bits_per_byte;
byte_idx += detail::padding;
int bit_idx =
size_t bit_idx =
((bits_per_byte + i) - ((offset + 1) * symbol_bits_)) % bits_per_byte;
if (detail::CheckBit(symbol, i)) {
@ -100,14 +100,14 @@ class CompressedBufferWriter {
uint64_t tmp = 0;
int stored_bits = 0;
const int max_stored_bits = 64 - symbol_bits_;
int buffer_position = detail::padding;
const int num_symbols = input_end - input_begin;
for (int i = 0; i < num_symbols; i++) {
size_t buffer_position = detail::padding;
const size_t num_symbols = input_end - input_begin;
for (size_t i = 0; i < num_symbols; i++) {
typename std::iterator_traits<iter_t>::value_type symbol = input_begin[i];
if (stored_bits > max_stored_bits) {
// Eject only full bytes
int tmp_bytes = stored_bits / 8;
for (int j = 0; j < tmp_bytes; j++) {
size_t tmp_bytes = stored_bits / 8;
for (size_t j = 0; j < tmp_bytes; j++) {
buffer[buffer_position] = tmp >> (stored_bits - (j + 1) * 8);
buffer_position++;
}
@ -121,8 +121,8 @@ class CompressedBufferWriter {
}
// Eject all bytes
int tmp_bytes = std::ceil(static_cast<float>(stored_bits) / 8);
for (int j = 0; j < tmp_bytes; j++) {
size_t tmp_bytes = std::ceil(static_cast<float>(stored_bits) / 8);
for (size_t j = 0; j < tmp_bytes; j++) {
int shift_bits = stored_bits - (j + 1) * 8;
if (shift_bits >= 0) {
buffer[buffer_position] = tmp >> shift_bits;
@ -159,7 +159,7 @@ class CompressedIterator {
/// iterator can point to
private:
compressed_byte_t *buffer_;
int symbol_bits_;
size_t symbol_bits_;
size_t offset_;
public:
@ -189,7 +189,7 @@ class CompressedIterator {
return static_cast<T>(tmp & mask);
}
XGBOOST_DEVICE reference operator[](int idx) const {
XGBOOST_DEVICE reference operator[](size_t idx) const {
self_type offset = (*this);
offset.offset_ += idx;
return *offset;
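
A sketch of the size arithmetic behind CalculateBufferSize; symbol_bits below is an assumed stand-in for detail::SymbolBits, and a 64-bit size_t is assumed. With int, bits times elements overflows long before 2^32 elements, which is why the signatures move to size_t:

#include <cmath>
#include <cstddef>
#include <cstdio>

// Assumed stand-in for detail::SymbolBits: bits needed to encode num_symbols values.
static std::size_t symbol_bits(std::size_t num_symbols) {
  return static_cast<std::size_t>(
      std::ceil(std::log2(static_cast<double>(num_symbols))));
}

int main() {
  const std::size_t num_elements = (1ULL << 32) + 5;  // more than 2^32 entries
  const std::size_t num_symbols = 1024;               // e.g. max_bin = 1024
  const std::size_t bits = symbol_bits(num_symbols);  // 10 bits per symbol
  const std::size_t bytes = (bits * num_elements + 7) / 8;  // round up to bytes
  std::printf("%zu bits/symbol over %zu elements -> %zu bytes\n",
              bits, num_elements, bytes);  // roughly 5.4 GB
  return 0;
}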

View File

@ -16,7 +16,7 @@
namespace xgboost {
namespace common {
void HistCutMatrix::Init(DMatrix* p_fmat, size_t max_num_bins) {
void HistCutMatrix::Init(DMatrix* p_fmat, uint32_t max_num_bins) {
typedef common::WXQuantileSketch<bst_float, bst_float> WXQSketch;
const MetaInfo& info = p_fmat->info();
@ -44,7 +44,7 @@ void HistCutMatrix::Init(DMatrix* p_fmat, size_t max_num_bins) {
unsigned begin = std::min(nstep * tid, ncol);
unsigned end = std::min(nstep * (tid + 1), ncol);
for (size_t i = 0; i < batch.size; ++i) { // NOLINT(*)
bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i);
size_t ridx = batch.base_rowid + i;
RowBatch::Inst inst = batch[i];
for (bst_uint j = 0; j < inst.length; ++j) {
if (inst[j].index >= begin && inst[j].index < end) {
@ -108,7 +108,7 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat) {
dmlc::DataIter<RowBatch>* iter = p_fmat->RowIterator();
const int nthread = omp_get_max_threads();
const unsigned nbins = cut->row_ptr.back();
const uint32_t nbins = cut->row_ptr.back();
hit_count.resize(nbins, 0);
hit_count_tloc_.resize(nthread * nbins, 0);
@ -116,7 +116,7 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat) {
row_ptr.push_back(0);
while (iter->Next()) {
const RowBatch& batch = iter->Value();
size_t rbegin = row_ptr.size() - 1;
const size_t rbegin = row_ptr.size() - 1;
for (size_t i = 0; i < batch.size; ++i) {
row_ptr.push_back(batch[i].length + row_ptr.back());
}
@ -140,7 +140,7 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat) {
CHECK(cbegin != cend);
auto it = std::upper_bound(cbegin, cend, inst[j].fvalue);
if (it == cend) it = cend - 1;
unsigned idx = static_cast<unsigned>(it - cut->cut.begin());
uint32_t idx = static_cast<uint32_t>(it - cut->cut.begin());
index[ibegin + j] = idx;
++hit_count_tloc_[tid * nbins + idx];
}
@ -148,7 +148,7 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat) {
}
#pragma omp parallel for num_threads(nthread) schedule(static)
for (omp_ulong idx = 0; idx < nbins; ++idx) {
for (bst_omp_uint idx = 0; idx < nbins; ++idx) {
for (int tid = 0; tid < nthread; ++tid) {
hit_count[idx] += hit_count_tloc_[tid * nbins + idx];
}
@ -157,10 +157,10 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat) {
}
template <typename T>
static unsigned GetConflictCount(const std::vector<bool>& mark,
const Column<T>& column,
unsigned max_cnt) {
unsigned ret = 0;
static size_t GetConflictCount(const std::vector<bool>& mark,
const Column<T>& column,
size_t max_cnt) {
size_t ret = 0;
if (column.type == xgboost::common::kDenseColumn) {
for (size_t i = 0; i < column.len; ++i) {
if (column.index[i] != std::numeric_limits<T>::max() && mark[i]) {
@ -203,9 +203,9 @@ MarkUsed(std::vector<bool>* p_mark, const Column<T>& column) {
template <typename T>
inline std::vector<std::vector<unsigned>>
FindGroups_(const std::vector<unsigned>& feature_list,
const std::vector<bst_uint>& feature_nnz,
const std::vector<size_t>& feature_nnz,
const ColumnMatrix& colmat,
unsigned nrow,
size_t nrow,
const FastHistParam& param) {
/* Goal: Bundle features together that have little or no "overlap", i.e.
only a few data points should have nonzero values for
@ -214,10 +214,10 @@ FindGroups_(const std::vector<unsigned>& feature_list,
std::vector<std::vector<unsigned>> groups;
std::vector<std::vector<bool>> conflict_marks;
std::vector<unsigned> group_nnz;
std::vector<unsigned> group_conflict_cnt;
const unsigned max_conflict_cnt
= static_cast<unsigned>(param.max_conflict_rate * nrow);
std::vector<size_t> group_nnz;
std::vector<size_t> group_conflict_cnt;
const size_t max_conflict_cnt
= static_cast<size_t>(param.max_conflict_rate * nrow);
for (auto fid : feature_list) {
const Column<T>& column = colmat.GetColumn<T>(fid);
@ -239,8 +239,8 @@ FindGroups_(const std::vector<unsigned>& feature_list,
// examine each candidate group: is it okay to insert fid?
for (auto gid : search_groups) {
const unsigned rest_max_cnt = max_conflict_cnt - group_conflict_cnt[gid];
const unsigned cnt = GetConflictCount(conflict_marks[gid], column, rest_max_cnt);
const size_t rest_max_cnt = max_conflict_cnt - group_conflict_cnt[gid];
const size_t cnt = GetConflictCount(conflict_marks[gid], column, rest_max_cnt);
if (cnt <= rest_max_cnt) {
need_new_group = false;
groups[gid].push_back(fid);
@ -267,9 +267,9 @@ FindGroups_(const std::vector<unsigned>& feature_list,
inline std::vector<std::vector<unsigned>>
FindGroups(const std::vector<unsigned>& feature_list,
const std::vector<bst_uint>& feature_nnz,
const std::vector<size_t>& feature_nnz,
const ColumnMatrix& colmat,
unsigned nrow,
size_t nrow,
const FastHistParam& param) {
XGBOOST_TYPE_SWITCH(colmat.dtype, {
return FindGroups_<DType>(feature_list, feature_nnz, colmat, nrow, param);
@ -288,11 +288,11 @@ FastFeatureGrouping(const GHistIndexMatrix& gmat,
std::iota(feature_list.begin(), feature_list.end(), 0);
// sort features by nonzero counts, descending order
std::vector<bst_uint> feature_nnz(nfeature);
std::vector<size_t> feature_nnz(nfeature);
std::vector<unsigned> features_by_nnz(feature_list);
gmat.GetFeatureCounts(&feature_nnz[0]);
std::sort(features_by_nnz.begin(), features_by_nnz.end(),
[&feature_nnz](int a, int b) {
[&feature_nnz](unsigned a, unsigned b) {
return feature_nnz[a] > feature_nnz[b];
});
@ -307,7 +307,7 @@ FastFeatureGrouping(const GHistIndexMatrix& gmat,
if (group.size() <= 1 || group.size() >= 5) {
ret.push_back(group); // keep singleton groups and large (5+) groups
} else {
unsigned nnz = 0;
size_t nnz = 0;
for (auto fid : group) {
nnz += feature_nnz[fid];
}
@ -338,37 +338,37 @@ void GHistIndexBlockMatrix::Init(const GHistIndexMatrix& gmat,
cut = gmat.cut;
const size_t nrow = gmat.row_ptr.size() - 1;
const size_t nbins = gmat.cut->row_ptr.back();
const uint32_t nbins = gmat.cut->row_ptr.back();
/* step 1: form feature groups */
auto groups = FastFeatureGrouping(gmat, colmat, param);
const size_t nblock = groups.size();
const uint32_t nblock = static_cast<uint32_t>(groups.size());
/* step 2: build a new CSR matrix for each feature group */
std::vector<unsigned> bin2block(nbins); // lookup table [bin id] => [block id]
for (size_t group_id = 0; group_id < nblock; ++group_id) {
std::vector<uint32_t> bin2block(nbins); // lookup table [bin id] => [block id]
for (uint32_t group_id = 0; group_id < nblock; ++group_id) {
for (auto& fid : groups[group_id]) {
const unsigned bin_begin = gmat.cut->row_ptr[fid];
const unsigned bin_end = gmat.cut->row_ptr[fid + 1];
for (unsigned bin_id = bin_begin; bin_id < bin_end; ++bin_id) {
const uint32_t bin_begin = gmat.cut->row_ptr[fid];
const uint32_t bin_end = gmat.cut->row_ptr[fid + 1];
for (uint32_t bin_id = bin_begin; bin_id < bin_end; ++bin_id) {
bin2block[bin_id] = group_id;
}
}
}
std::vector<std::vector<unsigned>> index_temp(nblock);
std::vector<std::vector<unsigned>> row_ptr_temp(nblock);
for (size_t block_id = 0; block_id < nblock; ++block_id) {
std::vector<std::vector<uint32_t>> index_temp(nblock);
std::vector<std::vector<size_t>> row_ptr_temp(nblock);
for (uint32_t block_id = 0; block_id < nblock; ++block_id) {
row_ptr_temp[block_id].push_back(0);
}
for (size_t rid = 0; rid < nrow; ++rid) {
const size_t ibegin = static_cast<size_t>(gmat.row_ptr[rid]);
const size_t iend = static_cast<size_t>(gmat.row_ptr[rid + 1]);
const size_t ibegin = gmat.row_ptr[rid];
const size_t iend = gmat.row_ptr[rid + 1];
for (size_t j = ibegin; j < iend; ++j) {
const size_t bin_id = gmat.index[j];
const size_t block_id = bin2block[bin_id];
const uint32_t bin_id = gmat.index[j];
const uint32_t block_id = bin2block[bin_id];
index_temp[block_id].push_back(bin_id);
}
for (size_t block_id = 0; block_id < nblock; ++block_id) {
for (uint32_t block_id = 0; block_id < nblock; ++block_id) {
row_ptr_temp[block_id].push_back(index_temp[block_id].size());
}
}
@ -378,7 +378,7 @@ void GHistIndexBlockMatrix::Init(const GHistIndexMatrix& gmat,
std::vector<size_t> row_ptr_blk_ptr;
index_blk_ptr.push_back(0);
row_ptr_blk_ptr.push_back(0);
for (size_t block_id = 0; block_id < nblock; ++block_id) {
for (uint32_t block_id = 0; block_id < nblock; ++block_id) {
index.insert(index.end(), index_temp[block_id].begin(), index_temp[block_id].end());
row_ptr.insert(row_ptr.end(), row_ptr_temp[block_id].begin(), row_ptr_temp[block_id].end());
index_blk_ptr.push_back(index.size());
@ -386,7 +386,7 @@ void GHistIndexBlockMatrix::Init(const GHistIndexMatrix& gmat,
}
// save shortcut for each block
for (size_t block_id = 0; block_id < nblock; ++block_id) {
for (uint32_t block_id = 0; block_id < nblock; ++block_id) {
Block blk;
blk.index_begin = &index[index_blk_ptr[block_id]];
blk.row_ptr_begin = &row_ptr[row_ptr_blk_ptr[block_id]];
@ -406,14 +406,14 @@ void GHistBuilder::BuildHist(const std::vector<bst_gpair>& gpair,
const int K = 8; // loop unrolling factor
const bst_omp_uint nthread = static_cast<bst_omp_uint>(this->nthread_);
const bst_omp_uint nrows = row_indices.end - row_indices.begin;
const bst_omp_uint rest = nrows % K;
const size_t nrows = row_indices.end - row_indices.begin;
const size_t rest = nrows % K;
#pragma omp parallel for num_threads(nthread) schedule(guided)
for (bst_omp_uint i = 0; i < nrows - rest; i += K) {
const bst_omp_uint tid = omp_get_thread_num();
const size_t off = tid * nbins_;
bst_uint rid[K];
size_t rid[K];
size_t ibegin[K];
size_t iend[K];
bst_gpair stat[K];
@ -421,32 +421,32 @@ void GHistBuilder::BuildHist(const std::vector<bst_gpair>& gpair,
rid[k] = row_indices.begin[i + k];
}
for (int k = 0; k < K; ++k) {
ibegin[k] = static_cast<size_t>(gmat.row_ptr[rid[k]]);
iend[k] = static_cast<size_t>(gmat.row_ptr[rid[k] + 1]);
ibegin[k] = gmat.row_ptr[rid[k]];
iend[k] = gmat.row_ptr[rid[k] + 1];
}
for (int k = 0; k < K; ++k) {
stat[k] = gpair[rid[k]];
}
for (int k = 0; k < K; ++k) {
for (size_t j = ibegin[k]; j < iend[k]; ++j) {
const size_t bin = gmat.index[j];
const uint32_t bin = gmat.index[j];
data_[off + bin].Add(stat[k]);
}
}
}
for (bst_omp_uint i = nrows - rest; i < nrows; ++i) {
const bst_uint rid = row_indices.begin[i];
const size_t ibegin = static_cast<size_t>(gmat.row_ptr[rid]);
const size_t iend = static_cast<size_t>(gmat.row_ptr[rid + 1]);
const size_t rid = row_indices.begin[i];
const size_t ibegin = gmat.row_ptr[rid];
const size_t iend = gmat.row_ptr[rid + 1];
const bst_gpair stat = gpair[rid];
for (size_t j = ibegin; j < iend; ++j) {
const size_t bin = gmat.index[j];
const uint32_t bin = gmat.index[j];
data_[bin].Add(stat);
}
}
/* reduction */
const bst_omp_uint nbins = static_cast<bst_omp_uint>(nbins_);
const uint32_t nbins = nbins_;
#pragma omp parallel for num_threads(nthread) schedule(static)
for (bst_omp_uint bin_id = 0; bin_id < nbins; ++bin_id) {
for (bst_omp_uint tid = 0; tid < nthread; ++tid) {
@ -462,16 +462,16 @@ void GHistBuilder::BuildBlockHist(const std::vector<bst_gpair>& gpair,
GHistRow hist) {
const int K = 8; // loop unrolling factor
const bst_omp_uint nthread = static_cast<bst_omp_uint>(this->nthread_);
const bst_omp_uint nblock = gmatb.GetNumBlock();
const bst_omp_uint nrows = row_indices.end - row_indices.begin;
const bst_omp_uint rest = nrows % K;
const uint32_t nblock = gmatb.GetNumBlock();
const size_t nrows = row_indices.end - row_indices.begin;
const size_t rest = nrows % K;
#pragma omp parallel for num_threads(nthread) schedule(guided)
for (bst_omp_uint bid = 0; bid < nblock; ++bid) {
auto gmat = gmatb[bid];
for (bst_omp_uint i = 0; i < nrows - rest; i += K) {
bst_uint rid[K];
for (size_t i = 0; i < nrows - rest; i += K) {
size_t rid[K];
size_t ibegin[K];
size_t iend[K];
bst_gpair stat[K];
@ -479,26 +479,26 @@ void GHistBuilder::BuildBlockHist(const std::vector<bst_gpair>& gpair,
rid[k] = row_indices.begin[i + k];
}
for (int k = 0; k < K; ++k) {
ibegin[k] = static_cast<size_t>(gmat.row_ptr[rid[k]]);
iend[k] = static_cast<size_t>(gmat.row_ptr[rid[k] + 1]);
ibegin[k] = gmat.row_ptr[rid[k]];
iend[k] = gmat.row_ptr[rid[k] + 1];
}
for (int k = 0; k < K; ++k) {
stat[k] = gpair[rid[k]];
}
for (int k = 0; k < K; ++k) {
for (size_t j = ibegin[k]; j < iend[k]; ++j) {
const size_t bin = gmat.index[j];
const uint32_t bin = gmat.index[j];
hist.begin[bin].Add(stat[k]);
}
}
}
for (bst_omp_uint i = nrows - rest; i < nrows; ++i) {
const bst_uint rid = row_indices.begin[i];
const size_t ibegin = static_cast<size_t>(gmat.row_ptr[rid]);
const size_t iend = static_cast<size_t>(gmat.row_ptr[rid + 1]);
const size_t rid = row_indices.begin[i];
const size_t ibegin = gmat.row_ptr[rid];
const size_t iend = gmat.row_ptr[rid + 1];
const bst_gpair stat = gpair[rid];
for (size_t j = ibegin; j < iend; ++j) {
const size_t bin = gmat.index[j];
const uint32_t bin = gmat.index[j];
hist.begin[bin].Add(stat);
}
}
@ -507,9 +507,9 @@ void GHistBuilder::BuildBlockHist(const std::vector<bst_gpair>& gpair,
void GHistBuilder::SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent) {
const bst_omp_uint nthread = static_cast<bst_omp_uint>(this->nthread_);
const bst_omp_uint nbins = static_cast<bst_omp_uint>(nbins_);
const uint32_t nbins = nbins_;
const int K = 8; // loop unrolling factor
const bst_omp_uint rest = nbins % K;
const uint32_t rest = nbins % K;
#pragma omp parallel for num_threads(nthread) schedule(static)
for (bst_omp_uint bin_id = 0; bin_id < nbins - rest; bin_id += K) {
GHistEntry pb[K];
@ -524,7 +524,7 @@ void GHistBuilder::SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow pa
self.begin[bin_id + k].SetSubtract(pb[k], sb[k]);
}
}
for (bst_omp_uint bin_id = nbins - rest; bin_id < nbins; ++bin_id) {
for (uint32_t bin_id = nbins - rest; bin_id < nbins; ++bin_id) {
self.begin[bin_id].SetSubtract(parent.begin[bin_id], sibling.begin[bin_id]);
}
}
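
SubtractionTrick above fills one child's histogram as the parent's histogram minus the already-built sibling; a plain-struct sketch of that identity (Entry is a stand-in for GHistEntry, without the unrolling):

#include <cstddef>
#include <vector>

struct Entry { double sum_grad = 0.0, sum_hess = 0.0; };  // stand-in for GHistEntry

// self = parent - sibling, bin by bin; this holds because every row of the
// parent node went to exactly one of its two children.
void SubtractionTrick(std::vector<Entry>* self,
                      const std::vector<Entry>& sibling,
                      const std::vector<Entry>& parent) {
  const std::size_t nbins = parent.size();
  for (std::size_t bin = 0; bin < nbins; ++bin) {
    (*self)[bin].sum_grad = parent[bin].sum_grad - sibling[bin].sum_grad;
    (*self)[bin].sum_hess = parent[bin].sum_hess - sibling[bin].sum_hess;
  }
}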

View File

@ -56,30 +56,30 @@ struct HistCutUnit {
/*! \brief the index pointer of each histunit */
const bst_float* cut;
/*! \brief number of cutting point, containing the maximum point */
size_t size;
uint32_t size;
// default constructor
HistCutUnit() {}
// constructor
HistCutUnit(const bst_float* cut, unsigned size)
HistCutUnit(const bst_float* cut, uint32_t size)
: cut(cut), size(size) {}
};
/*! \brief cut configuration for all the features */
struct HistCutMatrix {
/*! \brief actual unit pointer */
std::vector<unsigned> row_ptr;
/*! \brief unit pointer to rows by element position */
std::vector<uint32_t> row_ptr;
/*! \brief minimum value of each feature */
std::vector<bst_float> min_val;
/*! \brief the cut field */
std::vector<bst_float> cut;
/*! \brief Get histogram bound for fid */
inline HistCutUnit operator[](unsigned fid) const {
inline HistCutUnit operator[](bst_uint fid) const {
return HistCutUnit(dmlc::BeginPtr(cut) + row_ptr[fid],
row_ptr[fid + 1] - row_ptr[fid]);
}
// create histogram cut matrix given statistics from data
// using approximate quantile sketch approach
void Init(DMatrix* p_fmat, size_t max_num_bins);
void Init(DMatrix* p_fmat, uint32_t max_num_bins);
};
@ -89,11 +89,11 @@ struct HistCutMatrix {
*/
struct GHistIndexRow {
/*! \brief The index of the histogram */
const unsigned* index;
const uint32_t* index;
/*! \brief The size of the histogram */
unsigned size;
size_t size;
GHistIndexRow() {}
GHistIndexRow(const unsigned* index, unsigned size)
GHistIndexRow(const uint32_t* index, size_t size)
: index(index), size(size) {}
};
@ -103,21 +103,21 @@ struct GHistIndexRow {
* This is a global histogram index.
*/
struct GHistIndexMatrix {
/*! \brief row pointer */
std::vector<unsigned> row_ptr;
/*! \brief row pointer to rows by element position */
std::vector<size_t> row_ptr;
/*! \brief The index data */
std::vector<unsigned> index;
std::vector<uint32_t> index;
/*! \brief hit count of each index */
std::vector<unsigned> hit_count;
std::vector<size_t> hit_count;
/*! \brief The corresponding cuts */
const HistCutMatrix* cut;
// Create a global histogram matrix, given cut
void Init(DMatrix* p_fmat);
// get i-th row
inline GHistIndexRow operator[](bst_uint i) const {
inline GHistIndexRow operator[](size_t i) const {
return GHistIndexRow(&index[0] + row_ptr[i], row_ptr[i + 1] - row_ptr[i]);
}
inline void GetFeatureCounts(bst_uint* counts) const {
inline void GetFeatureCounts(size_t* counts) const {
const unsigned nfeature = cut->row_ptr.size() - 1;
for (unsigned fid = 0; fid < nfeature; ++fid) {
const unsigned ibegin = cut->row_ptr[fid];
@ -129,18 +129,18 @@ struct GHistIndexMatrix {
}
private:
std::vector<unsigned> hit_count_tloc_;
std::vector<size_t> hit_count_tloc_;
};
struct GHistIndexBlock {
const unsigned* row_ptr;
const unsigned* index;
const size_t* row_ptr;
const uint32_t* index;
inline GHistIndexBlock(const unsigned* row_ptr, const unsigned* index)
inline GHistIndexBlock(const size_t* row_ptr, const uint32_t* index)
: row_ptr(row_ptr), index(index) {}
// get i-th row
inline GHistIndexRow operator[](bst_uint i) const {
inline GHistIndexRow operator[](size_t i) const {
return GHistIndexRow(&index[0] + row_ptr[i], row_ptr[i + 1] - row_ptr[i]);
}
};
@ -153,23 +153,23 @@ class GHistIndexBlockMatrix {
const ColumnMatrix& colmat,
const FastHistParam& param);
inline GHistIndexBlock operator[](bst_uint i) const {
inline GHistIndexBlock operator[](size_t i) const {
return GHistIndexBlock(blocks[i].row_ptr_begin, blocks[i].index_begin);
}
inline unsigned GetNumBlock() const {
inline size_t GetNumBlock() const {
return blocks.size();
}
private:
std::vector<unsigned> row_ptr;
std::vector<unsigned> index;
std::vector<size_t> row_ptr;
std::vector<uint32_t> index;
const HistCutMatrix* cut;
struct Block {
const unsigned* row_ptr_begin;
const unsigned* row_ptr_end;
const unsigned* index_begin;
const unsigned* index_end;
const size_t* row_ptr_begin;
const size_t* row_ptr_end;
const uint32_t* index_begin;
const uint32_t* index_end;
};
std::vector<Block> blocks;
};
@ -184,10 +184,10 @@ struct GHistRow {
/*! \brief base pointer to first entry */
GHistEntry* begin;
/*! \brief number of entries */
unsigned size;
uint32_t size;
GHistRow() {}
GHistRow(GHistEntry* begin, unsigned size)
GHistRow(GHistEntry* begin, uint32_t size)
: begin(begin), size(size) {}
};
@ -198,19 +198,19 @@ class HistCollection {
public:
// access histogram for i-th node
inline GHistRow operator[](bst_uint nid) const {
const size_t kMax = std::numeric_limits<size_t>::max();
const uint32_t kMax = std::numeric_limits<uint32_t>::max();
CHECK_NE(row_ptr_[nid], kMax);
return GHistRow(const_cast<GHistEntry*>(dmlc::BeginPtr(data_) + row_ptr_[nid]), nbins_);
}
// have we computed a histogram for i-th node?
inline bool RowExists(bst_uint nid) const {
const size_t kMax = std::numeric_limits<size_t>::max();
const uint32_t kMax = std::numeric_limits<uint32_t>::max();
return (nid < row_ptr_.size() && row_ptr_[nid] != kMax);
}
// initialize histogram collection
inline void Init(size_t nbins) {
inline void Init(uint32_t nbins) {
nbins_ = nbins;
row_ptr_.clear();
data_.clear();
@ -218,7 +218,7 @@ class HistCollection {
// create an empty histogram for i-th node
inline void AddHistRow(bst_uint nid) {
const size_t kMax = std::numeric_limits<size_t>::max();
const uint32_t kMax = std::numeric_limits<uint32_t>::max();
if (nid >= row_ptr_.size()) {
row_ptr_.resize(nid + 1, kMax);
}
@ -230,12 +230,12 @@ class HistCollection {
private:
/*! \brief number of all bins over all features */
size_t nbins_;
uint32_t nbins_;
std::vector<GHistEntry> data_;
/*! \brief row_ptr_[nid] locates the histogram for node nid */
std::vector<size_t> row_ptr_;
std::vector<uint32_t> row_ptr_;
};
/*!
@ -244,7 +244,7 @@ class HistCollection {
class GHistBuilder {
public:
// initialize builder
inline void Init(size_t nthread, size_t nbins) {
inline void Init(size_t nthread, uint32_t nbins) {
nthread_ = nthread;
nbins_ = nbins;
}
@ -268,7 +268,7 @@ class GHistBuilder {
/*! \brief number of threads for parallel computation */
size_t nthread_;
/*! \brief number of all bins over all features */
size_t nbins_;
uint32_t nbins_;
std::vector<GHistEntry> data_;
};
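
The HistCollection hunks above swap the size_t sentinel for a uint32_t one; a simplified sketch of the pattern (hypothetical class, GHistEntry storage omitted), which also shows the new implicit cap of 2^32 - 1 total bins:

#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

class HistRowIndex {  // simplified stand-in for HistCollection's bookkeeping
 public:
  static constexpr std::uint32_t kMax = std::numeric_limits<std::uint32_t>::max();
  bool RowExists(std::size_t nid) const {
    return nid < row_ptr_.size() && row_ptr_[nid] != kMax;  // kMax = "not built"
  }
  void AddHistRow(std::size_t nid, std::uint32_t nbins) {
    if (nid >= row_ptr_.size()) row_ptr_.resize(nid + 1, kMax);
    row_ptr_[nid] = next_;  // first bin of node nid in the flat data_ array
    next_ += nbins;
  }
 private:
  std::uint32_t next_ = 0;  // uint32_t offsets cap total bins at 2^32 - 1
  std::vector<std::uint32_t> row_ptr_;
};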

View File

@ -21,14 +21,14 @@ class RowSetCollection {
* rows (instances) associated with a particular node in a decision
* tree. */
struct Elem {
const bst_uint* begin;
const bst_uint* end;
const size_t* begin;
const size_t* end;
int node_id;
// id of node associated with this instance set; -1 means uninitialized
Elem(void)
: begin(nullptr), end(nullptr), node_id(-1) {}
Elem(const bst_uint* begin,
const bst_uint* end,
Elem(const size_t* begin,
const size_t* end,
int node_id)
: begin(begin), end(end), node_id(node_id) {}
@ -38,8 +38,8 @@ class RowSetCollection {
};
/* \brief specifies how to split a rowset into two */
struct Split {
std::vector<bst_uint> left;
std::vector<bst_uint> right;
std::vector<size_t> left;
std::vector<size_t> right;
};
inline std::vector<Elem>::const_iterator begin() const {
@ -65,8 +65,8 @@ class RowSetCollection {
// initialize node id 0->everything
inline void Init() {
CHECK_EQ(elem_of_each_node_.size(), 0U);
const bst_uint* begin = dmlc::BeginPtr(row_indices_);
const bst_uint* end = dmlc::BeginPtr(row_indices_) + row_indices_.size();
const size_t* begin = dmlc::BeginPtr(row_indices_);
const size_t* end = dmlc::BeginPtr(row_indices_) + row_indices_.size();
elem_of_each_node_.emplace_back(Elem(begin, end, 0));
}
// split rowset into two
@ -77,16 +77,15 @@ class RowSetCollection {
const Elem e = elem_of_each_node_[node_id];
const unsigned nthread = row_split_tloc.size();
CHECK(e.begin != nullptr);
bst_uint* all_begin = dmlc::BeginPtr(row_indices_);
bst_uint* begin = all_begin + (e.begin - all_begin);
size_t* all_begin = dmlc::BeginPtr(row_indices_);
size_t* begin = all_begin + (e.begin - all_begin);
bst_uint* it = begin;
// TODO(hcho3): parallelize this section
size_t* it = begin;
for (bst_omp_uint tid = 0; tid < nthread; ++tid) {
std::copy(row_split_tloc[tid].left.begin(), row_split_tloc[tid].left.end(), it);
it += row_split_tloc[tid].left.size();
}
bst_uint* split_pt = it;
size_t* split_pt = it;
for (bst_omp_uint tid = 0; tid < nthread; ++tid) {
std::copy(row_split_tloc[tid].right.begin(), row_split_tloc[tid].right.end(), it);
it += row_split_tloc[tid].right.size();
@ -105,7 +104,7 @@ class RowSetCollection {
}
// stores the row indices in the set
std::vector<bst_uint> row_indices_;
std::vector<size_t> row_indices_;
private:
// vector: node_id -> elements
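
AddSplit above stitches the per-thread left/right lists back into the flat size_t row buffer; a sketch of that merge (ThreadSplit mirrors the Split struct; all names here are hypothetical):

#include <algorithm>
#include <cstddef>
#include <vector>

struct ThreadSplit { std::vector<std::size_t> left, right; };  // like Split above

// Copy every thread's left rows first, then every thread's right rows; the
// returned pointer is the boundary between the two child nodes.
std::size_t* MergeSplits(const std::vector<ThreadSplit>& tloc, std::size_t* out) {
  for (const ThreadSplit& t : tloc)
    out = std::copy(t.left.begin(), t.left.end(), out);
  std::size_t* split_pt = out;
  for (const ThreadSplit& t : tloc)
    out = std::copy(t.right.begin(), t.right.end(), out);
  return split_pt;
}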

View File

@ -61,7 +61,7 @@ class FastHistMaker: public TreeUpdater {
TStats::CheckInfo(dmat->info());
if (is_gmat_initialized_ == false) {
double tstart = dmlc::GetTime();
hmat_.Init(dmat, param.max_bin);
hmat_.Init(dmat, static_cast<uint32_t>(param.max_bin));
gmat_.cut = &hmat_;
gmat_.Init(dmat);
column_matrix_.Init(gmat_, fhparam);
@ -111,23 +111,6 @@ class FastHistMaker: public TreeUpdater {
bool is_gmat_initialized_;
// data structure
/*! \brief per thread x per node entry to store tmp data */
struct ThreadEntry {
/*! \brief statistics of data */
TStats stats;
/*! \brief extra statistics of data */
TStats stats_extra;
/*! \brief last feature value scanned */
float last_fvalue;
/*! \brief first feature value scanned */
float first_fvalue;
/*! \brief current best solution */
SplitEntry best;
// constructor
explicit ThreadEntry(const TrainParam& param)
: stats(param), stats_extra(param) {
}
};
struct NodeEntry {
/*! \brief statics for node entry */
TStats stats;
@ -340,7 +323,7 @@ class FastHistMaker: public TreeUpdater {
}
leaf_value = (*p_last_tree_)[nid].leaf_value();
for (const bst_uint* it = rowset.begin; it < rowset.end; ++it) {
for (const size_t* it = rowset.begin; it < rowset.end; ++it) {
out_preds[*it] += leaf_value;
}
}
@ -372,7 +355,7 @@ class FastHistMaker: public TreeUpdater {
// clear local prediction cache
leaf_value_cache_.clear();
// initialize histogram collection
size_t nbins = gmat.cut->row_ptr.back();
uint32_t nbins = gmat.cut->row_ptr.back();
hist_.Init(nbins);
// initialize histogram builder
@ -383,18 +366,18 @@ class FastHistMaker: public TreeUpdater {
hist_builder_.Init(this->nthread, nbins);
CHECK_EQ(info.root_index.size(), 0U);
std::vector<bst_uint>& row_indices = row_set_collection_.row_indices_;
std::vector<size_t>& row_indices = row_set_collection_.row_indices_;
// mark subsample and build list of member rows
if (param.subsample < 1.0f) {
std::bernoulli_distribution coin_flip(param.subsample);
auto& rnd = common::GlobalRandom();
for (bst_uint i = 0; i < info.num_row; ++i) {
for (size_t i = 0; i < info.num_row; ++i) {
if (gpair[i].hess >= 0.0f && coin_flip(rnd)) {
row_indices.push_back(i);
}
}
} else {
for (bst_uint i = 0; i < info.num_row; ++i) {
for (size_t i = 0; i < info.num_row; ++i) {
if (gpair[i].hess >= 0.0f) {
row_indices.push_back(i);
}
@ -405,11 +388,11 @@ class FastHistMaker: public TreeUpdater {
{
/* determine layout of data */
const auto nrow = info.num_row;
const auto ncol = info.num_col;
const auto nnz = info.num_nonzero;
const size_t nrow = info.num_row;
const size_t ncol = info.num_col;
const size_t nnz = info.num_nonzero;
// number of discrete bins for feature 0
const unsigned nbins_f0 = gmat.cut->row_ptr[1] - gmat.cut->row_ptr[0];
const uint32_t nbins_f0 = gmat.cut->row_ptr[1] - gmat.cut->row_ptr[0];
if (nrow * ncol == nnz) {
// dense data with zero-based indexing
data_layout_ = kDenseDataZeroBased;
@ -427,19 +410,19 @@ class FastHistMaker: public TreeUpdater {
// store a pointer to training data
p_last_fmat_ = &fmat;
// initialize feature index
unsigned ncol = static_cast<unsigned>(info.num_col);
bst_uint ncol = static_cast<bst_uint>(info.num_col);
feat_index.clear();
if (data_layout_ == kDenseDataOneBased) {
for (unsigned i = 1; i < ncol; ++i) {
for (bst_uint i = 1; i < ncol; ++i) {
feat_index.push_back(i);
}
} else {
for (unsigned i = 0; i < ncol; ++i) {
for (bst_uint i = 0; i < ncol; ++i) {
feat_index.push_back(i);
}
}
unsigned n = std::max(static_cast<unsigned>(1),
static_cast<unsigned>(param.colsample_bytree * feat_index.size()));
bst_uint n = std::max(static_cast<bst_uint>(1),
static_cast<bst_uint>(param.colsample_bytree * feat_index.size()));
std::shuffle(feat_index.begin(), feat_index.end(), common::GlobalRandom());
CHECK_GT(param.colsample_bytree, 0U)
<< "colsample_bytree cannot be zero.";
@ -450,11 +433,11 @@ class FastHistMaker: public TreeUpdater {
choose the column that has a least positive number of discrete bins.
For dense data (with no missing value),
the sum of gradient histogram is equal to snode[nid] */
const std::vector<unsigned>& row_ptr = gmat.cut->row_ptr;
const size_t nfeature = row_ptr.size() - 1;
size_t min_nbins_per_feature = 0;
for (size_t i = 0; i < nfeature; ++i) {
const unsigned nbins = row_ptr[i + 1] - row_ptr[i];
const std::vector<uint32_t>& row_ptr = gmat.cut->row_ptr;
const bst_uint nfeature = static_cast<bst_uint>(row_ptr.size() - 1);
uint32_t min_nbins_per_feature = 0;
for (bst_uint i = 0; i < nfeature; ++i) {
const uint32_t nbins = row_ptr[i + 1] - row_ptr[i];
if (nbins > 0) {
if (min_nbins_per_feature == 0 || min_nbins_per_feature > nbins) {
min_nbins_per_feature = nbins;
@ -485,7 +468,7 @@ class FastHistMaker: public TreeUpdater {
const std::vector<bst_uint>& feat_set) {
// start enumeration
const MetaInfo& info = fmat.info();
const bst_omp_uint nfeature = feat_set.size();
const bst_uint nfeature = static_cast<bst_uint>(feat_set.size());
const bst_omp_uint nthread = static_cast<bst_omp_uint>(this->nthread);
best_split_tloc_.resize(nthread);
#pragma omp parallel for schedule(static) num_threads(nthread)
@ -547,13 +530,17 @@ class FastHistMaker: public TreeUpdater {
const bool default_left = (*p_tree)[nid].default_left();
const bst_uint fid = (*p_tree)[nid].split_index();
const bst_float split_pt = (*p_tree)[nid].split_cond();
const bst_uint lower_bound = gmat.cut->row_ptr[fid];
const bst_uint upper_bound = gmat.cut->row_ptr[fid + 1];
bst_int split_cond = -1;
const uint32_t lower_bound = gmat.cut->row_ptr[fid];
const uint32_t upper_bound = gmat.cut->row_ptr[fid + 1];
int32_t split_cond = -1;
// convert floating-point split_pt into corresponding bin_id
// split_cond = -1 indicates that split_pt is less than all known cut points
for (unsigned i = gmat.cut->row_ptr[fid]; i < gmat.cut->row_ptr[fid + 1]; ++i) {
if (split_pt == gmat.cut->cut[i]) split_cond = static_cast<bst_int>(i);
CHECK_LT(upper_bound,
static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
for (uint32_t i = lower_bound; i < upper_bound; ++i) {
if (split_pt == gmat.cut->cut[i]) {
split_cond = static_cast<int32_t>(i);
}
}
const auto& rowset = row_set_collection_[nid];
@ -580,15 +567,15 @@ class FastHistMaker: public TreeUpdater {
bool default_left) {
std::vector<RowSetCollection::Split>& row_split_tloc = *p_row_split_tloc;
const int K = 8; // loop unrolling factor
const bst_omp_uint nrows = rowset.end - rowset.begin;
const bst_omp_uint rest = nrows % K;
const size_t nrows = rowset.end - rowset.begin;
const size_t rest = nrows % K;
#pragma omp parallel for num_threads(nthread) schedule(static)
for (bst_omp_uint i = 0; i < nrows - rest; i += K) {
const bst_uint tid = omp_get_thread_num();
auto& left = row_split_tloc[tid].left;
auto& right = row_split_tloc[tid].right;
bst_uint rid[K];
size_t rid[K];
T rbin[K];
for (int k = 0; k < K; ++k) {
rid[k] = rowset.begin[i + k];
@ -604,7 +591,9 @@ class FastHistMaker: public TreeUpdater {
right.push_back(rid[k]);
}
} else {
if (static_cast<bst_int>(rbin[k] + column.index_base) <= split_cond) {
CHECK_LT(rbin[k] + column.index_base,
static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
if (static_cast<int32_t>(rbin[k] + column.index_base) <= split_cond) {
left.push_back(rid[k]);
} else {
right.push_back(rid[k]);
@ -612,10 +601,10 @@ class FastHistMaker: public TreeUpdater {
}
}
}
for (bst_omp_uint i = nrows - rest; i < nrows; ++i) {
for (size_t i = nrows - rest; i < nrows; ++i) {
auto& left = row_split_tloc[nthread-1].left;
auto& right = row_split_tloc[nthread-1].right;
const bst_uint rid = rowset.begin[i];
const size_t rid = rowset.begin[i];
const T rbin = column.index[rid];
if (rbin == std::numeric_limits<T>::max()) { // missing value
if (default_left) {
@ -624,7 +613,9 @@ class FastHistMaker: public TreeUpdater {
right.push_back(rid);
}
} else {
if (static_cast<bst_int>(rbin + column.index_base) <= split_cond) {
CHECK_LT(rbin + column.index_base,
static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
if (static_cast<int32_t>(rbin + column.index_base) <= split_cond) {
left.push_back(rid);
} else {
right.push_back(rid);
@ -642,13 +633,13 @@ class FastHistMaker: public TreeUpdater {
bool default_left) {
std::vector<RowSetCollection::Split>& row_split_tloc = *p_row_split_tloc;
const int K = 8; // loop unrolling factor
const bst_omp_uint nrows = rowset.end - rowset.begin;
const bst_omp_uint rest = nrows % K;
const size_t nrows = rowset.end - rowset.begin;
const size_t rest = nrows % K;
#pragma omp parallel for num_threads(nthread) schedule(static)
for (bst_omp_uint i = 0; i < nrows - rest; i += K) {
bst_uint rid[K];
size_t rid[K];
GHistIndexRow row[K];
const unsigned* p[K];
const uint32_t* p[K];
bst_uint tid = omp_get_thread_num();
auto& left = row_split_tloc[tid].left;
auto& right = row_split_tloc[tid].right;
@ -663,7 +654,9 @@ class FastHistMaker: public TreeUpdater {
}
for (int k = 0; k < K; ++k) {
if (p[k] != row[k].index + row[k].size && *p[k] < upper_bound) {
if (static_cast<bst_int>(*p[k]) <= split_cond) {
CHECK_LT(*p[k],
static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
if (static_cast<int32_t>(*p[k]) <= split_cond) {
left.push_back(rid[k]);
} else {
right.push_back(rid[k]);
@ -677,14 +670,15 @@ class FastHistMaker: public TreeUpdater {
}
}
}
for (bst_omp_uint i = nrows - rest; i < nrows; ++i) {
const bst_uint rid = rowset.begin[i];
for (size_t i = nrows - rest; i < nrows; ++i) {
const size_t rid = rowset.begin[i];
const auto row = gmat[rid];
const auto p = std::lower_bound(row.index, row.index + row.size, lower_bound);
auto& left = row_split_tloc[0].left;
auto& right = row_split_tloc[0].right;
if (p != row.index + row.size && *p < upper_bound) {
if (static_cast<bst_int>(*p) <= split_cond) {
CHECK_LT(*p, static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
if (static_cast<int32_t>(*p) <= split_cond) {
left.push_back(rid);
} else {
right.push_back(rid);
@ -709,26 +703,26 @@ class FastHistMaker: public TreeUpdater {
bst_int split_cond,
bool default_left) {
std::vector<RowSetCollection::Split>& row_split_tloc = *p_row_split_tloc;
const bst_omp_uint nrows = rowset.end - rowset.begin;
const size_t nrows = rowset.end - rowset.begin;
#pragma omp parallel num_threads(nthread)
{
const bst_uint tid = omp_get_thread_num();
const bst_omp_uint ibegin = tid * nrows / nthread;
const bst_omp_uint iend = (tid + 1) * nrows / nthread;
const size_t tid = static_cast<size_t>(omp_get_thread_num());
const size_t ibegin = tid * nrows / nthread;
const size_t iend = (tid + 1) * nrows / nthread;
if (ibegin < iend) { // ensure that [ibegin, iend) is nonempty range
// search first nonzero row with index >= rowset[ibegin]
const uint32_t* p = std::lower_bound(column.row_ind,
column.row_ind + column.len,
rowset.begin[ibegin]);
const size_t* p = std::lower_bound(column.row_ind,
column.row_ind + column.len,
rowset.begin[ibegin]);
auto& left = row_split_tloc[tid].left;
auto& right = row_split_tloc[tid].right;
if (p != column.row_ind + column.len && *p <= rowset.begin[iend - 1]) {
bst_omp_uint cursor = p - column.row_ind;
size_t cursor = p - column.row_ind;
for (bst_omp_uint i = ibegin; i < iend; ++i) {
const bst_uint rid = rowset.begin[i];
for (size_t i = ibegin; i < iend; ++i) {
const size_t rid = rowset.begin[i];
while (cursor < column.len
&& column.row_ind[cursor] < rid
&& column.row_ind[cursor] <= rowset.begin[iend - 1]) {
@ -736,7 +730,9 @@ class FastHistMaker: public TreeUpdater {
}
if (cursor < column.len && column.row_ind[cursor] == rid) {
const T rbin = column.index[cursor];
if (static_cast<bst_int>(rbin + column.index_base) <= split_cond) {
CHECK_LT(rbin + column.index_base,
static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
if (static_cast<int32_t>(rbin + column.index_base) <= split_cond) {
left.push_back(rid);
} else {
right.push_back(rid);
@ -753,13 +749,13 @@ class FastHistMaker: public TreeUpdater {
}
} else { // all rows in [ibegin, iend) have missing values
if (default_left) {
for (bst_omp_uint i = ibegin; i < iend; ++i) {
const bst_uint rid = rowset.begin[i];
for (size_t i = ibegin; i < iend; ++i) {
const size_t rid = rowset.begin[i];
left.push_back(rid);
}
} else {
for (bst_omp_uint i = ibegin; i < iend; ++i) {
const bst_uint rid = rowset.begin[i];
for (size_t i = ibegin; i < iend; ++i) {
const size_t rid = rowset.begin[i];
right.push_back(rid);
}
}
@ -786,17 +782,17 @@ class FastHistMaker: public TreeUpdater {
For dense data (with no missing value),
the sum of gradient histogram is equal to snode[nid] */
GHistRow hist = hist_[nid];
const std::vector<unsigned>& row_ptr = gmat.cut->row_ptr;
const std::vector<uint32_t>& row_ptr = gmat.cut->row_ptr;
const size_t ibegin = row_ptr[fid_least_bins_];
const size_t iend = row_ptr[fid_least_bins_ + 1];
for (size_t i = ibegin; i < iend; ++i) {
const uint32_t ibegin = row_ptr[fid_least_bins_];
const uint32_t iend = row_ptr[fid_least_bins_ + 1];
for (uint32_t i = ibegin; i < iend; ++i) {
const GHistEntry et = hist.begin[i];
stats.Add(et.sum_grad, et.sum_hess);
}
} else {
const RowSetCollection::Elem e = row_set_collection_[nid];
for (const bst_uint* it = e.begin; it < e.end; ++it) {
for (const size_t* it = e.begin; it < e.end; ++it) {
stats.Add(gpair[*it]);
}
}
@ -831,7 +827,7 @@ class FastHistMaker: public TreeUpdater {
CHECK(d_step == +1 || d_step == -1);
// aliases
const std::vector<unsigned>& cut_ptr = gmat.cut->row_ptr;
const std::vector<uint32_t>& cut_ptr = gmat.cut->row_ptr;
const std::vector<bst_float>& cut_val = gmat.cut->cut;
// statistics on both sides of split
@ -841,20 +837,25 @@ class FastHistMaker: public TreeUpdater {
SplitEntry best;
// bin boundaries
CHECK_LE(cut_ptr[fid],
static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
CHECK_LE(cut_ptr[fid + 1],
static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
// imin: index (offset) of the minimum value for feature fid
// need this for backward enumeration
const int imin = cut_ptr[fid];
const int32_t imin = static_cast<int32_t>(cut_ptr[fid]);
// ibegin, iend: smallest/largest cut points for feature fid
int ibegin, iend;
// use int to allow for value -1
int32_t ibegin, iend;
if (d_step > 0) {
ibegin = cut_ptr[fid];
iend = cut_ptr[fid + 1];
ibegin = static_cast<int32_t>(cut_ptr[fid]);
iend = static_cast<int32_t>(cut_ptr[fid + 1]);
} else {
ibegin = cut_ptr[fid + 1] - 1;
iend = cut_ptr[fid] - 1;
ibegin = static_cast<int32_t>(cut_ptr[fid + 1]) - 1;
iend = static_cast<int32_t>(cut_ptr[fid]) - 1;
}
for (int i = ibegin; i != iend; i += d_step) {
for (int32_t i = ibegin; i != iend; i += d_step) {
// start working
// try to find a split
e.Add(hist.begin[i].sum_grad, hist.begin[i].sum_hess);
@ -930,7 +931,7 @@ class FastHistMaker: public TreeUpdater {
HistCollection hist_;
/*! \brief feature with least # of bins. to be used for dense specialization
of InitNewNode() */
size_t fid_least_bins_;
uint32_t fid_least_bins_;
/*! \brief local prediction cache; maps node id to leaf value */
std::vector<float> leaf_value_cache_;
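
The split-condition hunk above maps a floating-point split value back to a bin id while guarding the uint32-to-int32 casts; a standalone sketch of that lookup (the function name is hypothetical):

#include <cassert>
#include <cstdint>
#include <limits>
#include <vector>

// Returns the bin id whose cut value equals split_pt, or -1 when split_pt is
// below all known cut points for this feature (mirroring the hunk above).
std::int32_t FindSplitCond(const std::vector<float>& cut,
                           std::uint32_t lower_bound, std::uint32_t upper_bound,
                           float split_pt) {
  assert(upper_bound <
         static_cast<std::uint32_t>(std::numeric_limits<std::int32_t>::max()));
  std::int32_t split_cond = -1;
  for (std::uint32_t i = lower_bound; i < upper_bound; ++i) {
    if (split_pt == cut[i]) split_cond = static_cast<std::int32_t>(i);
  }
  return split_cond;
}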

View File

@ -1,5 +1,6 @@
#include "../../../src/common/compressed_iterator.h"
#include "gtest/gtest.h"
#include <algorithm>
namespace xgboost {
namespace common {
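// The excerpt cuts off here. As a hedged sketch (signatures inferred from the
// hunks above, not checked against the full header; assumes <vector> is pulled
// in by the includes), a writer/iterator round trip in this test might look like:
inline bool RoundTrips(const std::vector<int> &input, size_t num_symbols) {
  std::vector<compressed_byte_t> buffer(
      CompressedBufferWriter::CalculateBufferSize(input.size(), num_symbols));
  CompressedBufferWriter cbw(num_symbols);
  cbw.Write(buffer.data(), input.begin(), input.end());  // compress the symbols
  CompressedIterator<int> ci(buffer.data(), num_symbols);
  for (size_t i = 0; i < input.size(); ++i) {
    if (ci[i] != input[i]) return false;  // decompress and compare
  }
  return true;
}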