diff --git a/plugin/updater_gpu/src/device_helpers.cuh b/plugin/updater_gpu/src/device_helpers.cuh
index 705cbcf43..d4be2e346 100644
--- a/plugin/updater_gpu/src/device_helpers.cuh
+++ b/plugin/updater_gpu/src/device_helpers.cuh
@@ -441,7 +441,7 @@ class bulk_allocator {
  public:
   ~bulk_allocator() {
-    for (int i = 0; i < d_ptr.size(); i++) {
+    for (size_t i = 0; i < d_ptr.size(); i++) {
       if (!(d_ptr[i] == nullptr)) {
         safe_cuda(cudaSetDevice(_device_idx[i]));
         safe_cuda(cudaFree(d_ptr[i]));
@@ -522,7 +522,7 @@ inline size_t available_memory(int device_idx) {
 template <typename T>
 void print(const thrust::device_vector<T> &v, size_t max_items = 10) {
   thrust::host_vector<T> h = v;
-  for (int i = 0; i < std::min(max_items, h.size()); i++) {
+  for (size_t i = 0; i < std::min(max_items, h.size()); i++) {
     std::cout << " " << h[i];
   }
   std::cout << "\n";
@@ -531,7 +531,7 @@ void print(const thrust::device_vector<T> &v, size_t max_items = 10) {
 template <typename T>
 void print(const dvec<T> &v, size_t max_items = 10) {
   std::vector<T> h = v.as_vector();
-  for (int i = 0; i < std::min(max_items, h.size()); i++) {
+  for (size_t i = 0; i < std::min(max_items, h.size()); i++) {
     std::cout << " " << h[i];
   }
   std::cout << "\n";
@@ -539,10 +539,10 @@ void print(const dvec<T> &v, size_t max_items = 10) {
 
 template <typename T>
 void print(char *label, const thrust::device_vector<T> &v,
-           const char *format = "%d ", int max = 10) {
+           const char *format = "%d ", size_t max = 10) {
   thrust::host_vector<T> h_v = v;
   std::cout << label << ":\n";
-  for (int i = 0; i < std::min(static_cast<int>(h_v.size()), max); i++) {
+  for (size_t i = 0; i < std::min(static_cast<size_t>(h_v.size()), max); i++) {
     printf(format, h_v[i]);
   }
   std::cout << "\n";
@@ -593,6 +593,7 @@ __global__ void launch_n_kernel(int device_idx, size_t begin, size_t end,
 template <int ITEMS_PER_THREAD = 8, int BLOCK_THREADS = 256, typename L>
 inline void launch_n(int device_idx, size_t n, L lambda) {
   safe_cuda(cudaSetDevice(device_idx));
+  // TODO: Template on n so GRID_SIZE always fits into int.
   const int GRID_SIZE = div_round_up(n, ITEMS_PER_THREAD * BLOCK_THREADS);
 #if defined(__CUDACC__)
   launch_n_kernel<<<GRID_SIZE, BLOCK_THREADS>>>(static_cast<size_t>(0), n,
@@ -607,6 +608,7 @@ inline void multi_launch_n(size_t n, int n_devices, L lambda) {
   CHECK_LE(n_devices, n_visible_devices()) << "Number of devices requested "
                                               "needs to be less than equal to "
                                               "number of visible devices.";
+  // TODO: Template on n so GRID_SIZE always fits into int.
  const int GRID_SIZE = div_round_up(n, ITEMS_PER_THREAD * BLOCK_THREADS);
 #if defined(__CUDACC__)
   n_devices = n_devices > n ? n : n_devices;
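Note on the two TODOs just added: `launch_n` still computes `GRID_SIZE` as an `int` from a `size_t` element count, so a large enough `n` can overflow the launch configuration before the kernel ever runs. A minimal sketch of one way to bound it, assuming the kernel loops with a grid-size stride so a clamped grid still covers every element (`SafeGridSize` is a hypothetical helper, not part of `dh::`):

```cpp
#include <algorithm>
#include <cstddef>
#include <limits>

// Hypothetical helper: compute a launch grid that covers n items without
// ever letting the block count overflow int. The kernel is then expected
// to stride by gridDim.x * blockDim.x so a clamped grid still visits all
// n elements, with all indexing kept in size_t.
inline int SafeGridSize(size_t n, int items_per_thread, int block_threads) {
  const size_t per_block =
      static_cast<size_t>(items_per_thread) * block_threads;
  const size_t blocks = (n + per_block - 1) / per_block;  // div_round_up
  const size_t max_grid =
      static_cast<size_t>(std::numeric_limits<int>::max());
  return static_cast<int>(std::min(blocks, max_grid));
}
```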
diff --git a/plugin/updater_gpu/src/gpu_hist_builder.cu b/plugin/updater_gpu/src/gpu_hist_builder.cu
index ffb861f06..8e9575227 100644
--- a/plugin/updater_gpu/src/gpu_hist_builder.cu
+++ b/plugin/updater_gpu/src/gpu_hist_builder.cu
@@ -19,8 +19,8 @@ namespace xgboost {
 namespace tree {
 
 void DeviceGMat::Init(int device_idx, const common::GHistIndexMatrix& gmat,
-                      bst_uint element_begin, bst_uint element_end,
-                      bst_uint row_begin, bst_uint row_end, int n_bins) {
+                      bst_ulong element_begin, bst_ulong element_end,
+                      bst_ulong row_begin, bst_ulong row_end, int n_bins) {
   dh::safe_cuda(cudaSetDevice(device_idx));
   CHECK(gidx_buffer.size()) << "gidx_buffer must be externally allocated";
   CHECK_EQ(row_ptr.size(), (row_end - row_begin) + 1)
@@ -31,15 +31,15 @@ void DeviceGMat::Init(int device_idx, const common::GHistIndexMatrix& gmat,
   cbw.Write(host_buffer.data(), gmat.index.begin() + element_begin,
             gmat.index.begin() + element_end);
   gidx_buffer = host_buffer;
-  gidx = common::CompressedIterator<int>(gidx_buffer.data(), n_bins);
+  gidx = common::CompressedIterator<uint32_t>(gidx_buffer.data(), n_bins);
 
   // row_ptr
   thrust::copy(gmat.row_ptr.data() + row_begin,
                gmat.row_ptr.data() + row_end + 1, row_ptr.tbegin());
   // normalise row_ptr
-  bst_uint start = gmat.row_ptr[row_begin];
+  size_t start = gmat.row_ptr[row_begin];
   thrust::transform(row_ptr.tbegin(), row_ptr.tend(), row_ptr.tbegin(),
-                    [=] __device__(int val) { return val - start; });
+                    [=] __device__(size_t val) { return val - start; });
 }
 
 void DeviceHist::Init(int n_bins_in) {
@@ -193,7 +193,7 @@ void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
   device_row_segments.push_back(0);
   device_element_segments.push_back(0);
   bst_uint offset = 0;
-  size_t shard_size = std::ceil(static_cast<double>(num_rows) / n_devices);
+  bst_uint shard_size = std::ceil(static_cast<double>(num_rows) / n_devices);
   for (int d_idx = 0; d_idx < n_devices; d_idx++) {
     int device_idx = dList[d_idx];
     offset += shard_size;
@@ -261,7 +261,7 @@ void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
     int device_idx = dList[d_idx];
     bst_uint num_rows_segment =
         device_row_segments[d_idx + 1] - device_row_segments[d_idx];
-    bst_uint num_elements_segment =
+    bst_ulong num_elements_segment =
         device_element_segments[d_idx + 1] - device_element_segments[d_idx];
     ba.allocate(
         device_idx, &(hist_vec[d_idx].data),
@@ -360,7 +360,7 @@ void GPUHistBuilder::BuildHist(int depth) {
     auto hist_builder = hist_vec[d_idx].GetBuilder();
     dh::TransformLbs(
         device_idx, &temp_memory[d_idx], end - begin, d_row_ptr,
-        row_end - row_begin, [=] __device__(int local_idx, int local_ridx) {
+        row_end - row_begin, [=] __device__(size_t local_idx, int local_ridx) {
           int nidx = d_position[local_ridx];  // OPTMARK: latency
           if (!is_active(nidx, depth)) return;
@@ -606,7 +606,11 @@ __global__ void find_split_kernel(
 }
 
 #define MIN_BLOCK_THREADS 32
-#define MAX_BLOCK_THREADS 1024  // hard-coded maximum block size
+#define CHUNK_BLOCK_THREADS 32
+// MAX_BLOCK_THREADS of 1024 is the hard-coded maximum block size:
+// the build targets CUDA compute capability 3.5 and above, for which
+// the maximum number of threads per block is 1024.
+#define MAX_BLOCK_THREADS 1024
 
 void GPUHistBuilder::FindSplit(int depth) {
   // Specialised based on max_bins
@@ -622,7 +626,7 @@ void GPUHistBuilder::FindSplitSpecialize(int depth) {
   if (param.max_bin <= BLOCK_THREADS) {
     LaunchFindSplit<BLOCK_THREADS>(depth);
   } else {
-    this->FindSplitSpecialize<BLOCK_THREADS + 32>(depth);
+    this->FindSplitSpecialize<BLOCK_THREADS + CHUNK_BLOCK_THREADS>(depth);
   }
 }
@@ -885,7 +889,7 @@ void GPUHistBuilder::UpdatePositionDense(int depth) {
     size_t begin = device_row_segments[d_idx];
     size_t end = device_row_segments[d_idx + 1];
 
-    dh::launch_n(device_idx, end - begin, [=] __device__(int local_idx) {
+    dh::launch_n(device_idx, end - begin, [=] __device__(size_t local_idx) {
       int pos = d_position[local_idx];
       if (!is_active(pos, depth)) {
         return;
@@ -896,7 +900,8 @@ void GPUHistBuilder::UpdatePositionDense(int depth) {
         return;
       }
 
-      int gidx = d_gidx[local_idx * n_columns + node.split.findex];
+      int gidx = d_gidx[local_idx * static_cast<size_t>(n_columns) +
+                        static_cast<size_t>(node.split.findex)];
 
       float fvalue = d_gidx_fvalue_map[gidx];
@@ -955,7 +960,7 @@ void GPUHistBuilder::UpdatePositionSparse(int depth) {
     dh::TransformLbs(
         device_idx, &temp_memory[d_idx], element_end - element_begin, d_row_ptr,
-        row_end - row_begin, [=] __device__(int local_idx, int local_ridx) {
+        row_end - row_begin, [=] __device__(size_t local_idx, int local_ridx) {
           int pos = d_position[local_ridx];
           if (!is_active(pos, depth)) {
             return;
@@ -1065,25 +1070,13 @@ void GPUHistBuilder::Update(const std::vector<bst_gpair>& gpair,
   this->InitData(gpair, *p_fmat, *p_tree);
   this->InitFirstNode(gpair);
   this->ColSampleTree();
-  // long long int elapsed=0;
   for (int depth = 0; depth < param.max_depth; depth++) {
     this->ColSampleLevel();
-    // dh::Timer time;
     this->BuildHist(depth);
-    // elapsed+=time.elapsed();
-    // printf("depth=%d\n",depth);
-    // time.printElapsed("BH Time");
-
-    // dh::Timer timesplit;
     this->FindSplit(depth);
-    // timesplit.printElapsed("FS Time");
-
-    // dh::Timer timeupdatepos;
     this->UpdatePosition(depth);
-    // timeupdatepos.printElapsed("UP Time");
   }
-  // printf("Total BuildHist Time=%lld\n",elapsed);
 
   // done with multi-GPU, pass back result from master to tree on host
   int master_device = dList[0];
diff --git a/plugin/updater_gpu/src/gpu_hist_builder.cuh b/plugin/updater_gpu/src/gpu_hist_builder.cuh
index 1264faef4..bcaf49d39 100644
--- a/plugin/updater_gpu/src/gpu_hist_builder.cuh
+++ b/plugin/updater_gpu/src/gpu_hist_builder.cuh
@@ -18,10 +18,10 @@ namespace tree {
 
 struct DeviceGMat {
   dh::dvec<common::compressed_byte_t> gidx_buffer;
-  common::CompressedIterator<int> gidx;
-  dh::dvec<int> row_ptr;
+  common::CompressedIterator<uint32_t> gidx;
+  dh::dvec<size_t> row_ptr;
   void Init(int device_idx, const common::GHistIndexMatrix &gmat,
-            bst_uint begin, bst_uint end, bst_uint row_begin, bst_uint row_end, int n_bins);
+            bst_ulong element_begin, bst_ulong element_end, bst_ulong row_begin, bst_ulong row_end, int n_bins);
 };
 
 struct HistBuilder {
@@ -99,7 +99,7 @@ class GPUHistBuilder {
   // below vectors are for each devices used
   std::vector<int> dList;
   std::vector<bst_uint> device_row_segments;
-  std::vector<bst_uint> device_element_segments;
+  std::vector<bst_ulong> device_element_segments;
 
   std::vector<dh::CubMemory> temp_memory;
   std::vector<DeviceHist> hist_vec;
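The `bst_uint` to `bst_ulong` widening of element offsets above is exactly what the new large-data test below exercises. A self-contained check of the arithmetic, using the row and column counts that `test_large.py` actually uses:

```cpp
#include <cstdint>
#include <iostream>

// Element counts from test_large.py: the 42,360,032-row "large" case still
// fits in 32 bits after multiplying by 31 columns, but the commented-out
// multi-GPU size does not -- which is why element offsets are now bst_ulong
// (64-bit) rather than bst_uint (32-bit).
int main() {
  const uint64_t cols = 31;
  const uint64_t large = 42360032ULL * cols;        // 1,313,160,992 < 2^32
  const uint64_t very_large = 152360032ULL * cols;  // 4,723,160,992 > 2^32
  std::cout << large << " overflows uint32_t: " << (large > UINT32_MAX)
            << "\n";
  std::cout << very_large << " overflows uint32_t: "
            << (very_large > UINT32_MAX) << "\n";
  return 0;
}
```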
diff --git a/plugin/updater_gpu/test/python/test_large.py b/plugin/updater_gpu/test/python/test_large.py
new file mode 100644
index 000000000..c3becc041
--- /dev/null
+++ b/plugin/updater_gpu/test/python/test_large.py
@@ -0,0 +1,134 @@
+from __future__ import print_function
+#pylint: skip-file
+import sys
+sys.path.append("../../tests/python")
+import xgboost as xgb
+import testing as tm
+import numpy as np
+import unittest
+from sklearn.datasets import make_classification
+
+
+def eprint(*args, **kwargs):
+    print(*args, file=sys.stderr, **kwargs); sys.stderr.flush()
+    print(*args, file=sys.stdout, **kwargs); sys.stdout.flush()
+
+eprint("Testing Big Data (this may take a while)")
+
+rng = np.random.RandomState(1994)
+
+# "realistic" size based upon http://stat-computing.org/dataexpo/2009/ ,
+# which has been processed to one-hot encode categoricals
+cols = 31
+# reduced to fit onto 1 gpu but still be large
+rows2 = 5000  # medium
+#rows2 = 4032  # fake large for testing
+rows1 = 42360032  # large
+#rows2 = 152360032  # can do this for multi-gpu test (very large)
+rowslist = [rows1, rows2]
+
+
+class TestGPU(unittest.TestCase):
+    def test_large(self):
+        eprint("Starting test for large data")
+        tm._skip_if_no_sklearn()
+        from sklearn.datasets import load_digits
+        try:
+            from sklearn.model_selection import train_test_split
+        except:
+            from sklearn.cross_validation import train_test_split
+
+        for rows in rowslist:
+            eprint("Creating train data rows=%d cols=%d" % (rows, cols))
+            X, y = make_classification(rows, n_features=cols, random_state=7)
+            rowstest = int(rows * 0.2)
+            eprint("Creating test data rows=%d cols=%d" % (rowstest, cols))
+            # note the new random state: with the same state as the training
+            # data, exact methods can memorize and do very well on the test
+            # set even for random data, while hist cannot
+            Xtest, ytest = make_classification(rowstest, n_features=cols, random_state=8)
+
+            eprint("Starting DMatrix(X,y)")
+            ag_dtrain = xgb.DMatrix(X, y)
+            eprint("Starting DMatrix(Xtest,ytest)")
+            ag_dtest = xgb.DMatrix(Xtest, ytest)
+
+            max_depth = 6
+            max_bin = 1024
+
+            # regression test --- hist must be same as exact on all-categorical data
+            ag_param = {'max_depth': max_depth,
+                        'tree_method': 'exact',
+                        #'nthread': 1,
+                        'eta': 1,
+                        'silent': 0,
+                        'objective': 'binary:logistic',
+                        'eval_metric': 'auc'}
+            ag_paramb = {'max_depth': max_depth,
+                         'tree_method': 'hist',
+                         #'nthread': 1,
+                         'eta': 1,
+                         'silent': 0,
+                         'objective': 'binary:logistic',
+                         'eval_metric': 'auc'}
+            ag_param2 = {'max_depth': max_depth,
+                         'tree_method': 'gpu_hist',
+                         'eta': 1,
+                         'silent': 0,
+                         'n_gpus': 1,
+                         'objective': 'binary:logistic',
+                         'max_bin': max_bin,
+                         'eval_metric': 'auc'}
+            ag_param3 = {'max_depth': max_depth,
+                         'tree_method': 'gpu_hist',
+                         'eta': 1,
+                         'silent': 0,
+                         'n_gpus': -1,
+                         'objective': 'binary:logistic',
+                         'max_bin': max_bin,
+                         'eval_metric': 'auc'}
+            #ag_param4 = {'max_depth': max_depth,
+            #             'tree_method': 'gpu_exact',
+            #             'eta': 1,
+            #             'silent': 0,
+            #             'n_gpus': 1,
+            #             'objective': 'binary:logistic',
+            #             'max_bin': max_bin,
+            #             'eval_metric': 'auc'}
+            ag_res = {}
+            ag_resb = {}
+            ag_res2 = {}
+            ag_res3 = {}
+            #ag_res4 = {}
+
+            num_rounds = 1
+
+            eprint("normal updater")
+            xgb.train(ag_param, ag_dtrain, num_rounds,
+                      [(ag_dtrain, 'train'), (ag_dtest, 'test')],
+                      evals_result=ag_res)
+            eprint("hist updater")
+            xgb.train(ag_paramb, ag_dtrain, num_rounds,
+                      [(ag_dtrain, 'train'), (ag_dtest, 'test')],
+                      evals_result=ag_resb)
+            eprint("gpu_hist updater 1 gpu")
+            xgb.train(ag_param2, ag_dtrain, num_rounds,
+                      [(ag_dtrain, 'train'), (ag_dtest, 'test')],
+                      evals_result=ag_res2)
+            eprint("gpu_hist updater all gpus")
+            xgb.train(ag_param3, ag_dtrain, num_rounds,
+                      [(ag_dtrain, 'train'), (ag_dtest, 'test')],
+                      evals_result=ag_res3)
+            #eprint("gpu_exact updater")
+            #xgb.train(ag_param4, ag_dtrain, num_rounds,
+            #          [(ag_dtrain, 'train'), (ag_dtest, 'test')],
+            #          evals_result=ag_res4)
+
+            assert np.fabs(ag_res['train']['auc'][0] - ag_resb['train']['auc'][0]) < 0.001
+            assert np.fabs(ag_res['train']['auc'][0] - ag_res2['train']['auc'][0]) < 0.001
+            assert np.fabs(ag_res['train']['auc'][0] - ag_res3['train']['auc'][0]) < 0.001
+            #assert np.fabs(ag_res['train']['auc'][0] - ag_res4['train']['auc'][0]) < 0.001
+
+            assert np.fabs(ag_res['test']['auc'][0] - ag_resb['test']['auc'][0]) < 0.01
+            assert np.fabs(ag_res['test']['auc'][0] - ag_res2['test']['auc'][0]) < 0.01
+            assert np.fabs(ag_res['test']['auc'][0] - ag_res3['test']['auc'][0]) < 0.01
+            #assert np.fabs(ag_res['test']['auc'][0] - ag_res4['test']['auc'][0]) < 0.01
+
+
+
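Before the `column_matrix.h` hunks below: the heart of `ColumnMatrix::Init` is the per-feature dense-versus-sparse decision. A minimal stand-alone sketch of that rule, with simplified names (`sparse_threshold` is the `FastHistParam` knob; everything else here is illustrative):

```cpp
#include <cstddef>
#include <vector>

// A column is stored densely when enough rows have a nonzero entry for it;
// otherwise only the nonzero values plus their row indices are kept. This
// mirrors the classification loop in ColumnMatrix::Init.
enum ColumnType { kDenseColumn, kSparseColumn };

std::vector<ColumnType> ClassifyColumns(
    const std::vector<size_t>& feature_counts, size_t nrow,
    double sparse_threshold) {
  std::vector<ColumnType> type(feature_counts.size());
  for (size_t fid = 0; fid < feature_counts.size(); ++fid) {
    type[fid] =
        (static_cast<double>(feature_counts[fid]) < sparse_threshold * nrow)
            ? kSparseColumn
            : kDenseColumn;
  }
  return type;
}
```

Note how the counts are `size_t` here, matching the widened `feature_counts_` in the diff that follows.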
diff --git a/src/common/column_matrix.h b/src/common/column_matrix.h
index 5d3fea87e..cbf3a368b 100644
--- a/src/common/column_matrix.h
+++ b/src/common/column_matrix.h
@@ -57,7 +57,7 @@ class Column {
   ColumnType type;
   const T* index;
   uint32_t index_base;
-  const uint32_t* row_ind;
+  const size_t* row_ind;
   size_t len;
 };
 
@@ -66,8 +66,8 @@ class Column {
 class ColumnMatrix {
  public:
   // get number of features
-  inline uint32_t GetNumFeature() const {
-    return type_.size();
+  inline bst_uint GetNumFeature() const {
+    return static_cast<bst_uint>(type_.size());
   }
 
   // construct column matrix from GHistIndexMatrix
@@ -78,8 +78,8 @@ class ColumnMatrix {
        slot of internal buffer. */
     packing_factor_ = sizeof(uint32_t) / static_cast<size_t>(this->dtype);
-    const uint32_t nfeature = gmat.cut->row_ptr.size() - 1;
-    const omp_ulong nrow = static_cast<omp_ulong>(gmat.row_ptr.size() - 1);
+    const bst_uint nfeature = static_cast<bst_uint>(gmat.cut->row_ptr.size() - 1);
+    const size_t nrow = gmat.row_ptr.size() - 1;
 
     // identify type of each column
     feature_counts_.resize(nfeature);
@@ -90,13 +90,13 @@ class ColumnMatrix {
     XGBOOST_TYPE_SWITCH(this->dtype, {
       max_val = static_cast<size_t>(std::numeric_limits<DType>::max());
     });
-    for (uint32_t fid = 0; fid < nfeature; ++fid) {
+    for (bst_uint fid = 0; fid < nfeature; ++fid) {
       CHECK_LE(gmat.cut->row_ptr[fid + 1] - gmat.cut->row_ptr[fid], max_val);
     }
 
     gmat.GetFeatureCounts(&feature_counts_[0]);
     // classify features
-    for (uint32_t fid = 0; fid < nfeature; ++fid) {
+    for (bst_uint fid = 0; fid < nfeature; ++fid) {
       if (static_cast<double>(feature_counts_[fid])
             < param.sparse_threshold * nrow) {
         type_[fid] = kSparseColumn;
@@ -108,13 +108,13 @@ class ColumnMatrix {
     // want to compute storage boundary for each feature
     // using variants of prefix sum scan
     boundary_.resize(nfeature);
-    bst_uint accum_index_ = 0;
-    bst_uint accum_row_ind_ = 0;
-    for (uint32_t fid = 0; fid < nfeature; ++fid) {
+    size_t accum_index_ = 0;
+    size_t accum_row_ind_ = 0;
+    for (bst_uint fid = 0; fid < nfeature; ++fid) {
       boundary_[fid].index_begin = accum_index_;
       boundary_[fid].row_ind_begin = accum_row_ind_;
       if (type_[fid] == kDenseColumn) {
-        accum_index_ += nrow;
+        accum_index_ += static_cast<size_t>(nrow);
       } else {
         accum_index_ += feature_counts_[fid];
         accum_row_ind_ += feature_counts_[fid];
@@ -129,14 +129,14 @@ class ColumnMatrix {
 
     // store least bin id for each feature
     index_base_.resize(nfeature);
-    for (uint32_t fid = 0; fid < nfeature; ++fid) {
+    for (bst_uint fid = 0; fid < nfeature; ++fid) {
       index_base_[fid] = gmat.cut->row_ptr[fid];
     }
 
-    // fill index_ for dense columns
-    for (uint32_t fid = 0; fid < nfeature; ++fid) {
+    // pre-fill index_ for dense columns
+    for (bst_uint fid = 0; fid < nfeature; ++fid) {
       if (type_[fid] == kDenseColumn) {
-        const uint32_t ibegin = boundary_[fid].index_begin;
+        const size_t ibegin = boundary_[fid].index_begin;
         XGBOOST_TYPE_SWITCH(this->dtype, {
           const size_t block_offset = ibegin / packing_factor_;
           const size_t elem_offset = ibegin % packing_factor_;
@@ -150,15 +150,15 @@ class ColumnMatrix {
 
     // loop over all rows and fill column entries
     // num_nonzeros[fid] = how many nonzeros have this feature accumulated so far?
-    std::vector<bst_uint> num_nonzeros;
+    std::vector<size_t> num_nonzeros;
     num_nonzeros.resize(nfeature);
     std::fill(num_nonzeros.begin(), num_nonzeros.end(), 0);
-    for (uint32_t rid = 0; rid < nrow; ++rid) {
-      const size_t ibegin = static_cast<size_t>(gmat.row_ptr[rid]);
-      const size_t iend = static_cast<size_t>(gmat.row_ptr[rid + 1]);
+    for (size_t rid = 0; rid < nrow; ++rid) {
+      const size_t ibegin = gmat.row_ptr[rid];
+      const size_t iend = gmat.row_ptr[rid + 1];
       size_t fid = 0;
       for (size_t i = ibegin; i < iend; ++i) {
-        const size_t bin_id = gmat.index[i];
+        const uint32_t bin_id = gmat.index[i];
         while (bin_id >= gmat.cut->row_ptr[fid + 1]) {
           ++fid;
         }
@@ -167,14 +167,14 @@ class ColumnMatrix {
           const size_t block_offset = boundary_[fid].index_begin / packing_factor_;
           const size_t elem_offset = boundary_[fid].index_begin % packing_factor_;
           DType* begin = reinterpret_cast<DType*>(&index_[block_offset]) + elem_offset;
-          begin[rid] = bin_id - index_base_[fid];
+          begin[rid] = static_cast<DType>(bin_id - index_base_[fid]);
         });
       } else {
         XGBOOST_TYPE_SWITCH(this->dtype, {
           const size_t block_offset = boundary_[fid].index_begin / packing_factor_;
           const size_t elem_offset = boundary_[fid].index_begin % packing_factor_;
           DType* begin = reinterpret_cast<DType*>(&index_[block_offset]) + elem_offset;
-          begin[num_nonzeros[fid]] = bin_id - index_base_[fid];
+          begin[num_nonzeros[fid]] = static_cast<DType>(bin_id - index_base_[fid]);
         });
         row_ind_[boundary_[fid].row_ind_begin + num_nonzeros[fid]] = rid;
         ++num_nonzeros[fid];
@@ -213,16 +213,16 @@ class ColumnMatrix {
     // indicate where each column's index and row_ind is stored.
     // index_begin and index_end are logical offsets, so they should be converted to
    // actual offsets by scaling with packing_factor_
-    unsigned index_begin;
-    unsigned index_end;
-    unsigned row_ind_begin;
-    unsigned row_ind_end;
+    size_t index_begin;
+    size_t index_end;
+    size_t row_ind_begin;
+    size_t row_ind_end;
   };
 
-  std::vector<bst_uint> feature_counts_;
+  std::vector<size_t> feature_counts_;
   std::vector<ColumnType> type_;
   std::vector<uint32_t> index_;  // index_: may store smaller integers; needs padding
-  std::vector<uint32_t> row_ind_;
+  std::vector<size_t> row_ind_;
   std::vector<ColumnBoundary> boundary_;
 
   size_t packing_factor_;  // how many integers are stored in each slot of index_
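The `compressed_iterator.h` changes below widen the bit-packing arithmetic. As a worked instance of what `CalculateBufferSize` computes — each symbol takes ceil(log2(num_symbols)) bits, rounded up to whole bytes, plus a small header (the padding constant is assumed here; the real value lives in `detail::padding`):

```cpp
#include <cmath>
#include <cstddef>
#include <iostream>

int main() {
  const size_t num_elements = 1000000;
  const size_t num_symbols = 256;  // e.g. max_bin = 256
  // SymbolBits(256) == 8, per the assertion in test_compressed_iterator.cc.
  const size_t symbol_bits =
      static_cast<size_t>(std::ceil(std::log2(num_symbols)));
  const size_t padding = 4;  // assumed; stands in for detail::padding
  const size_t bytes = static_cast<size_t>(std::ceil(
      static_cast<double>(symbol_bits * num_elements) / 8)) + padding;
  // 8 bits x 1e6 symbols / 8 = 1,000,000 bytes, plus padding = 1,000,004.
  std::cout << bytes << " bytes for " << num_elements << " symbols\n";
  return 0;
}
```

With `size_t num_elements`, the product `symbol_bits * num_elements` no longer overflows for billion-element matrices, which is the point of the change.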
diff --git a/src/common/compressed_iterator.h b/src/common/compressed_iterator.h
index 794c93398..c40d6c897 100644
--- a/src/common/compressed_iterator.h
+++ b/src/common/compressed_iterator.h
@@ -46,11 +46,11 @@ static int SymbolBits(int num_symbols) {
 
 class CompressedBufferWriter {
  private:
-  int symbol_bits_;
+  size_t symbol_bits_;
   size_t offset_;
 
  public:
-  explicit CompressedBufferWriter(int num_symbols) : offset_(0) {
+  explicit CompressedBufferWriter(size_t num_symbols) : offset_(0) {
     symbol_bits_ = detail::SymbolBits(num_symbols);
   }
 
@@ -70,9 +70,9 @@ class CompressedBufferWriter {
    * \return The calculated buffer size.
    */
-  static size_t CalculateBufferSize(int num_elements, int num_symbols) {
+  static size_t CalculateBufferSize(size_t num_elements, size_t num_symbols) {
     const int bits_per_byte = 8;
-    int compressed_size = std::ceil(
+    size_t compressed_size = std::ceil(
         static_cast<double>(detail::SymbolBits(num_symbols) * num_elements) /
         bits_per_byte);
     return compressed_size + detail::padding;
@@ -82,10 +82,10 @@ class CompressedBufferWriter {
 
   template <typename T>
   void WriteSymbol(compressed_byte_t *buffer, T symbol, size_t offset) {
     const int bits_per_byte = 8;
 
-    for (int i = 0; i < symbol_bits_; i++) {
+    for (size_t i = 0; i < symbol_bits_; i++) {
       size_t byte_idx = ((offset + 1) * symbol_bits_ - (i + 1)) / bits_per_byte;
       byte_idx += detail::padding;
-      int bit_idx =
+      size_t bit_idx =
           ((bits_per_byte + i) - ((offset + 1) * symbol_bits_)) % bits_per_byte;
 
       if (detail::CheckBit(symbol, i)) {
@@ -100,14 +100,14 @@ class CompressedBufferWriter {
     uint64_t tmp = 0;
     int stored_bits = 0;
     const int max_stored_bits = 64 - symbol_bits_;
-    int buffer_position = detail::padding;
-    const int num_symbols = input_end - input_begin;
-    for (int i = 0; i < num_symbols; i++) {
+    size_t buffer_position = detail::padding;
+    const size_t num_symbols = input_end - input_begin;
+    for (size_t i = 0; i < num_symbols; i++) {
       typename std::iterator_traits<IterT>::value_type symbol = input_begin[i];
       if (stored_bits > max_stored_bits) {
         // Eject only full bytes
-        int tmp_bytes = stored_bits / 8;
-        for (int j = 0; j < tmp_bytes; j++) {
+        size_t tmp_bytes = stored_bits / 8;
+        for (size_t j = 0; j < tmp_bytes; j++) {
           buffer[buffer_position] = tmp >> (stored_bits - (j + 1) * 8);
           buffer_position++;
         }
@@ -121,8 +121,8 @@ class CompressedBufferWriter {
     }
 
     // Eject all bytes
-    int tmp_bytes = std::ceil(static_cast<float>(stored_bits) / 8);
-    for (int j = 0; j < tmp_bytes; j++) {
+    size_t tmp_bytes = std::ceil(static_cast<float>(stored_bits) / 8);
+    for (size_t j = 0; j < tmp_bytes; j++) {
       int shift_bits = stored_bits - (j + 1) * 8;
       if (shift_bits >= 0) {
         buffer[buffer_position] = tmp >> shift_bits;
@@ -159,7 +159,7 @@ class CompressedIterator {
   /// iterator can point to
  private:
   compressed_byte_t *buffer_;
-  int symbol_bits_;
+  size_t symbol_bits_;
   size_t offset_;
 
  public:
@@ -189,7 +189,7 @@ class CompressedIterator {
     return static_cast<T>(tmp & mask);
   }
 
-  XGBOOST_DEVICE reference operator[](int idx) const {
+  XGBOOST_DEVICE reference operator[](size_t idx) const {
     self_type offset = (*this);
     offset.offset_ += idx;
     return *offset;
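Before the `hist_util.cc` hunks: the convention these changes settle on is that element offsets (CSR row pointers) become `size_t`, since they count stored elements across all rows, while bin ids stay 32-bit, since they are bounded by the total number of cut points. A simplified sketch of that division of responsibilities (a stand-in struct, not the real `GHistIndexMatrix`):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// row_ptr[i]..row_ptr[i+1] bounds row i's elements (size_t: can exceed
// 2^32); index[j] is the quantized bin id of element j (uint32_t: bounded
// by the number of bins). HitCounts mirrors the per-bin tally that
// GHistIndexMatrix::Init accumulates.
struct GHistIndexSketch {
  std::vector<size_t> row_ptr;
  std::vector<uint32_t> index;
};

std::vector<size_t> HitCounts(const GHistIndexSketch& g, uint32_t nbins) {
  std::vector<size_t> hit_count(nbins, 0);
  for (size_t i = 0; i + 1 < g.row_ptr.size(); ++i) {
    for (size_t j = g.row_ptr[i]; j < g.row_ptr[i + 1]; ++j) {
      ++hit_count[g.index[j]];
    }
  }
  return hit_count;
}
```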
diff --git a/src/common/hist_util.cc b/src/common/hist_util.cc
index fe27ac8c5..ee64d9778 100644
--- a/src/common/hist_util.cc
+++ b/src/common/hist_util.cc
@@ -16,7 +16,7 @@
 namespace xgboost {
 namespace common {
 
-void HistCutMatrix::Init(DMatrix* p_fmat, size_t max_num_bins) {
+void HistCutMatrix::Init(DMatrix* p_fmat, uint32_t max_num_bins) {
   typedef common::WXQuantileSketch<bst_float, bst_float> WXQSketch;
   const MetaInfo& info = p_fmat->info();
 
@@ -44,7 +44,7 @@ void HistCutMatrix::Init(DMatrix* p_fmat, uint32_t max_num_bins) {
       unsigned begin = std::min(nstep * tid, ncol);
       unsigned end = std::min(nstep * (tid + 1), ncol);
       for (size_t i = 0; i < batch.size; ++i) {  // NOLINT(*)
-        bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i);
+        size_t ridx = batch.base_rowid + i;
         RowBatch::Inst inst = batch[i];
         for (bst_uint j = 0; j < inst.length; ++j) {
           if (inst[j].index >= begin && inst[j].index < end) {
@@ -108,7 +108,7 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat) {
   dmlc::DataIter<RowBatch>* iter = p_fmat->RowIterator();
   const int nthread = omp_get_max_threads();
 
-  const unsigned nbins = cut->row_ptr.back();
+  const uint32_t nbins = cut->row_ptr.back();
   hit_count.resize(nbins, 0);
   hit_count_tloc_.resize(nthread * nbins, 0);
 
@@ -116,7 +116,7 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat) {
   row_ptr.push_back(0);
   while (iter->Next()) {
     const RowBatch& batch = iter->Value();
-    size_t rbegin = row_ptr.size() - 1;
+    const size_t rbegin = row_ptr.size() - 1;
     for (size_t i = 0; i < batch.size; ++i) {
       row_ptr.push_back(batch[i].length + row_ptr.back());
     }
@@ -140,7 +140,7 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat) {
         CHECK(cbegin != cend);
         auto it = std::upper_bound(cbegin, cend, inst[j].fvalue);
         if (it == cend) it = cend - 1;
-        unsigned idx = static_cast<unsigned>(it - cut->cut.begin());
+        uint32_t idx = static_cast<uint32_t>(it - cut->cut.begin());
         index[ibegin + j] = idx;
         ++hit_count_tloc_[tid * nbins + idx];
       }
@@ -148,7 +148,7 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat) {
   }
 
 #pragma omp parallel for num_threads(nthread) schedule(static)
-  for (omp_ulong idx = 0; idx < nbins; ++idx) {
+  for (bst_omp_uint idx = 0; idx < nbins; ++idx) {
     for (int tid = 0; tid < nthread; ++tid) {
       hit_count[idx] += hit_count_tloc_[tid * nbins + idx];
     }
@@ -157,10 +157,10 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat) {
 }
 
 template <typename T>
-static unsigned GetConflictCount(const std::vector<bool>& mark,
-                                 const Column<T>& column,
-                                 unsigned max_cnt) {
-  unsigned ret = 0;
+static size_t GetConflictCount(const std::vector<bool>& mark,
+                               const Column<T>& column,
+                               size_t max_cnt) {
+  size_t ret = 0;
   if (column.type == xgboost::common::kDenseColumn) {
     for (size_t i = 0; i < column.len; ++i) {
       if (column.index[i] != std::numeric_limits<T>::max() && mark[i]) {
@@ -203,9 +203,9 @@ MarkUsed(std::vector<bool>* p_mark, const Column<T>& column) {
 
 template <typename T>
 inline std::vector<std::vector<unsigned>>
 FindGroups_(const std::vector<unsigned>& feature_list,
-            const std::vector<bst_uint>& feature_nnz,
+            const std::vector<size_t>& feature_nnz,
             const ColumnMatrix& colmat,
-            unsigned nrow,
+            size_t nrow,
             const FastHistParam& param) {
   /* Goal: Bundle features together that has little or no "overlap", i.e.
            only a few data points should have nonzero values for
@@ -214,10 +214,10 @@ FindGroups_(const std::vector<unsigned>& feature_list,
 
   std::vector<std::vector<unsigned>> groups;
   std::vector<std::vector<bool>> conflict_marks;
-  std::vector<bst_uint> group_nnz;
-  std::vector<unsigned> group_conflict_cnt;
-  const unsigned max_conflict_cnt
-    = static_cast<unsigned>(param.max_conflict_rate * nrow);
+  std::vector<size_t> group_nnz;
+  std::vector<size_t> group_conflict_cnt;
+  const size_t max_conflict_cnt
+    = static_cast<size_t>(param.max_conflict_rate * nrow);
 
   for (auto fid : feature_list) {
     const Column<T>& column = colmat.GetColumn<T>(fid);
@@ -239,8 +239,8 @@ FindGroups_(const std::vector<unsigned>& feature_list,
     // examine each candidate group: is it okay to insert fid?
     for (auto gid : search_groups) {
-      const unsigned rest_max_cnt = max_conflict_cnt - group_conflict_cnt[gid];
-      const unsigned cnt = GetConflictCount(conflict_marks[gid], column, rest_max_cnt);
+      const size_t rest_max_cnt = max_conflict_cnt - group_conflict_cnt[gid];
+      const size_t cnt = GetConflictCount(conflict_marks[gid], column, rest_max_cnt);
       if (cnt <= rest_max_cnt) {
         need_new_group = false;
         groups[gid].push_back(fid);
@@ -267,9 +267,9 @@ FindGroups_(const std::vector<unsigned>& feature_list,
 
 inline std::vector<std::vector<unsigned>>
 FindGroups(const std::vector<unsigned>& feature_list,
-           const std::vector<bst_uint>& feature_nnz,
+           const std::vector<size_t>& feature_nnz,
            const ColumnMatrix& colmat,
-           unsigned nrow,
+           size_t nrow,
            const FastHistParam& param) {
   XGBOOST_TYPE_SWITCH(colmat.dtype, {
     return FindGroups_<DType>(feature_list, feature_nnz, colmat, nrow, param);
   });
@@ -288,11 +288,11 @@ FastFeatureGrouping(const GHistIndexMatrix& gmat,
   std::iota(feature_list.begin(), feature_list.end(), 0);
 
   // sort features by nonzero counts, descending order
-  std::vector<bst_uint> feature_nnz(nfeature);
+  std::vector<size_t> feature_nnz(nfeature);
   std::vector<unsigned> features_by_nnz(feature_list);
   gmat.GetFeatureCounts(&feature_nnz[0]);
   std::sort(features_by_nnz.begin(), features_by_nnz.end(),
-            [&feature_nnz](int a, int b) {
+            [&feature_nnz](unsigned a, unsigned b) {
               return feature_nnz[a] > feature_nnz[b];
             });
 
@@ -307,7 +307,7 @@ FastFeatureGrouping(const GHistIndexMatrix& gmat,
     if (group.size() <= 1 || group.size() >= 5) {
       ret.push_back(group);  // keep singleton groups and large (5+) groups
     } else {
-      unsigned nnz = 0;
+      size_t nnz = 0;
       for (auto fid : group) {
         nnz += feature_nnz[fid];
       }
@@ -338,37 +338,37 @@ void GHistIndexBlockMatrix::Init(const GHistIndexMatrix& gmat,
   cut = gmat.cut;
 
   const size_t nrow = gmat.row_ptr.size() - 1;
-  const size_t nbins = gmat.cut->row_ptr.back();
+  const uint32_t nbins = gmat.cut->row_ptr.back();
 
   /* step 1: form feature groups */
   auto groups = FastFeatureGrouping(gmat, colmat, param);
-  const size_t nblock = groups.size();
+  const uint32_t nblock = static_cast<uint32_t>(groups.size());
 
   /* step 2: build a new CSR matrix for each feature group */
-  std::vector<size_t> bin2block(nbins);  // lookup table [bin id] => [block id]
-  for (size_t group_id = 0; group_id < nblock; ++group_id) {
+  std::vector<uint32_t> bin2block(nbins);  // lookup table [bin id] => [block id]
+  for (uint32_t group_id = 0; group_id < nblock; ++group_id) {
     for (auto& fid : groups[group_id]) {
-      const unsigned bin_begin = gmat.cut->row_ptr[fid];
-      const unsigned bin_end = gmat.cut->row_ptr[fid + 1];
-      for (unsigned bin_id = bin_begin; bin_id < bin_end; ++bin_id) {
+      const uint32_t bin_begin = gmat.cut->row_ptr[fid];
+      const uint32_t bin_end = gmat.cut->row_ptr[fid + 1];
+      for (uint32_t bin_id = bin_begin; bin_id < bin_end; ++bin_id) {
         bin2block[bin_id] = group_id;
       }
     }
   }
-  std::vector<std::vector<unsigned>> index_temp(nblock);
-  std::vector<std::vector<unsigned>> row_ptr_temp(nblock);
-  for (size_t block_id = 0; block_id < nblock; ++block_id) {
+  std::vector<std::vector<uint32_t>> index_temp(nblock);
+  std::vector<std::vector<size_t>> row_ptr_temp(nblock);
+  for (uint32_t block_id = 0; block_id < nblock; ++block_id) {
     row_ptr_temp[block_id].push_back(0);
   }
   for (size_t rid = 0; rid < nrow; ++rid) {
-    const size_t ibegin = static_cast<size_t>(gmat.row_ptr[rid]);
-    const size_t iend = static_cast<size_t>(gmat.row_ptr[rid + 1]);
+    const size_t ibegin = gmat.row_ptr[rid];
+    const size_t iend = gmat.row_ptr[rid + 1];
     for (size_t j = ibegin; j < iend; ++j) {
-      const size_t bin_id = gmat.index[j];
-      const size_t block_id = bin2block[bin_id];
+      const uint32_t bin_id = gmat.index[j];
+      const uint32_t block_id = bin2block[bin_id];
       index_temp[block_id].push_back(bin_id);
     }
-    for (size_t block_id = 0; block_id < nblock; ++block_id) {
+    for (uint32_t block_id = 0; block_id < nblock; ++block_id) {
       row_ptr_temp[block_id].push_back(index_temp[block_id].size());
     }
   }
@@ -378,7 +378,7 @@ void GHistIndexBlockMatrix::Init(const GHistIndexMatrix& gmat,
   std::vector<size_t> row_ptr_blk_ptr;
   index_blk_ptr.push_back(0);
   row_ptr_blk_ptr.push_back(0);
-  for (size_t block_id = 0; block_id < nblock; ++block_id) {
+  for (uint32_t block_id = 0; block_id < nblock; ++block_id) {
     index.insert(index.end(), index_temp[block_id].begin(), index_temp[block_id].end());
     row_ptr.insert(row_ptr.end(), row_ptr_temp[block_id].begin(), row_ptr_temp[block_id].end());
     index_blk_ptr.push_back(index.size());
@@ -386,7 +386,7 @@ void GHistIndexBlockMatrix::Init(const GHistIndexMatrix& gmat,
   }
 
   // save shortcut for each block
-  for (size_t block_id = 0; block_id < nblock; ++block_id) {
+  for (uint32_t block_id = 0; block_id < nblock; ++block_id) {
     Block blk;
     blk.index_begin = &index[index_blk_ptr[block_id]];
     blk.row_ptr_begin = &row_ptr[row_ptr_blk_ptr[block_id]];
@@ -406,14 +406,14 @@ void GHistBuilder::BuildHist(const std::vector<bst_gpair>& gpair,
   const int K = 8;  // loop unrolling factor
   const bst_omp_uint nthread = static_cast<bst_omp_uint>(this->nthread_);
-  const bst_omp_uint nrows = row_indices.end - row_indices.begin;
-  const bst_omp_uint rest = nrows % K;
+  const size_t nrows = row_indices.end - row_indices.begin;
+  const size_t rest = nrows % K;
 #pragma omp parallel for num_threads(nthread) schedule(guided)
   for (bst_omp_uint i = 0; i < nrows - rest; i += K) {
     const bst_omp_uint tid = omp_get_thread_num();
     const size_t off = tid * nbins_;
-    bst_uint rid[K];
+    size_t rid[K];
     size_t ibegin[K];
     size_t iend[K];
     bst_gpair stat[K];
@@ -421,32 +421,32 @@ void GHistBuilder::BuildHist(const std::vector<bst_gpair>& gpair,
       rid[k] = row_indices.begin[i + k];
     }
     for (int k = 0; k < K; ++k) {
-      ibegin[k] = static_cast<size_t>(gmat.row_ptr[rid[k]]);
-      iend[k] = static_cast<size_t>(gmat.row_ptr[rid[k] + 1]);
+      ibegin[k] = gmat.row_ptr[rid[k]];
+      iend[k] = gmat.row_ptr[rid[k] + 1];
     }
     for (int k = 0; k < K; ++k) {
       stat[k] = gpair[rid[k]];
     }
     for (int k = 0; k < K; ++k) {
       for (size_t j = ibegin[k]; j < iend[k]; ++j) {
-        const size_t bin = gmat.index[j];
+        const uint32_t bin = gmat.index[j];
         data_[off + bin].Add(stat[k]);
       }
     }
   }
   for (bst_omp_uint i = nrows - rest; i < nrows; ++i) {
-    const bst_uint rid = row_indices.begin[i];
-    const size_t ibegin = static_cast<size_t>(gmat.row_ptr[rid]);
-    const size_t iend = static_cast<size_t>(gmat.row_ptr[rid + 1]);
+    const size_t rid = row_indices.begin[i];
+    const size_t ibegin = gmat.row_ptr[rid];
+    const size_t iend = gmat.row_ptr[rid + 1];
     const bst_gpair stat = gpair[rid];
     for (size_t j = ibegin; j < iend; ++j) {
-      const size_t bin = gmat.index[j];
+      const uint32_t bin = gmat.index[j];
       data_[bin].Add(stat);
     }
   }
 
   /* reduction */
-  const bst_omp_uint nbins = static_cast<bst_omp_uint>(nbins_);
+  const uint32_t nbins = nbins_;
 #pragma omp parallel for num_threads(nthread) schedule(static)
   for (bst_omp_uint bin_id = 0; bin_id < nbins; ++bin_id) {
     for (bst_omp_uint tid = 0; tid < nthread; ++tid) {
@@ -462,16 +462,16 @@ void GHistBuilder::BuildBlockHist(const std::vector<bst_gpair>& gpair,
                                   GHistRow hist) {
   const int K = 8;  // loop unrolling factor
   const bst_omp_uint nthread = static_cast<bst_omp_uint>(this->nthread_);
-  const bst_omp_uint nblock = gmatb.GetNumBlock();
-  const bst_omp_uint nrows = row_indices.end - row_indices.begin;
-  const bst_omp_uint rest = nrows % K;
+  const uint32_t nblock = gmatb.GetNumBlock();
+  const size_t nrows = row_indices.end - row_indices.begin;
+  const size_t rest = nrows % K;
 
 #pragma omp parallel for num_threads(nthread) schedule(guided)
   for (bst_omp_uint bid = 0; bid < nblock; ++bid) {
     auto gmat = gmatb[bid];
 
-    for (bst_omp_uint i = 0; i < nrows - rest; i += K) {
-      bst_uint rid[K];
+    for (size_t i = 0; i < nrows - rest; i += K) {
+      size_t rid[K];
       size_t ibegin[K];
       size_t iend[K];
       bst_gpair stat[K];
@@ -479,26 +479,26 @@ void GHistBuilder::BuildBlockHist(const std::vector<bst_gpair>& gpair,
         rid[k] = row_indices.begin[i + k];
       }
       for (int k = 0; k < K; ++k) {
-        ibegin[k] = static_cast<size_t>(gmat.row_ptr[rid[k]]);
-        iend[k] = static_cast<size_t>(gmat.row_ptr[rid[k] + 1]);
+        ibegin[k] = gmat.row_ptr[rid[k]];
+        iend[k] = gmat.row_ptr[rid[k] + 1];
       }
       for (int k = 0; k < K; ++k) {
         stat[k] = gpair[rid[k]];
       }
       for (int k = 0; k < K; ++k) {
         for (size_t j = ibegin[k]; j < iend[k]; ++j) {
-          const size_t bin = gmat.index[j];
+          const uint32_t bin = gmat.index[j];
           hist.begin[bin].Add(stat[k]);
         }
       }
     }
     for (bst_omp_uint i = nrows - rest; i < nrows; ++i) {
-      const bst_uint rid = row_indices.begin[i];
-      const size_t ibegin = static_cast<size_t>(gmat.row_ptr[rid]);
-      const size_t iend = static_cast<size_t>(gmat.row_ptr[rid + 1]);
+      const size_t rid = row_indices.begin[i];
+      const size_t ibegin = gmat.row_ptr[rid];
+      const size_t iend = gmat.row_ptr[rid + 1];
       const bst_gpair stat = gpair[rid];
       for (size_t j = ibegin; j < iend; ++j) {
-        const size_t bin = gmat.index[j];
+        const uint32_t bin = gmat.index[j];
         hist.begin[bin].Add(stat);
       }
     }
@@ -507,9 +507,9 @@ void GHistBuilder::SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow pa
   const bst_omp_uint nthread = static_cast<bst_omp_uint>(this->nthread_);
-  const bst_omp_uint nbins = static_cast<bst_omp_uint>(nbins_);
+  const uint32_t nbins = static_cast<uint32_t>(nbins_);
   const int K = 8;  // loop unrolling factor
-  const bst_omp_uint rest = nbins % K;
+  const uint32_t rest = nbins % K;
 #pragma omp parallel for num_threads(nthread) schedule(static)
   for (bst_omp_uint bin_id = 0; bin_id < nbins - rest; bin_id += K) {
     GHistEntry pb[K];
@@ -524,7 +524,7 @@ void GHistBuilder::SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow pa
       self.begin[bin_id + k].SetSubtract(pb[k], sb[k]);
     }
   }
-  for (bst_omp_uint bin_id = nbins - rest; bin_id < nbins; ++bin_id) {
+  for (uint32_t bin_id = nbins - rest; bin_id < nbins; ++bin_id) {
     self.begin[bin_id].SetSubtract(parent.begin[bin_id], sibling.begin[bin_id]);
   }
 }
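One routine touched above, `GHistBuilder::SubtractionTrick`, is worth spelling out: a child's histogram equals the parent's minus the sibling's, bin by bin, so only one child per split needs a fresh pass over the data. A minimal sketch with a simplified entry type (not the real `GHistEntry`/`GHistRow`):

```cpp
#include <cstdint>
#include <vector>

struct Entry {
  double sum_grad = 0;
  double sum_hess = 0;
};

// self[bin] = parent[bin] - sibling[bin] for every bin; all three
// histograms share the same bin layout, so nbins bounds the loop.
void SubtractionTrick(std::vector<Entry>* self,
                      const std::vector<Entry>& sibling,
                      const std::vector<Entry>& parent) {
  const uint32_t nbins = static_cast<uint32_t>(parent.size());
  for (uint32_t bin = 0; bin < nbins; ++bin) {
    (*self)[bin].sum_grad = parent[bin].sum_grad - sibling[bin].sum_grad;
    (*self)[bin].sum_hess = parent[bin].sum_hess - sibling[bin].sum_hess;
  }
}
```

This is also why `nbins` can safely stay `uint32_t` in the diff: the loop bound is the bin count, never the element count.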
diff --git a/src/common/hist_util.h b/src/common/hist_util.h
index 9c58cca73..4d5456e85 100644
--- a/src/common/hist_util.h
+++ b/src/common/hist_util.h
@@ -56,30 +56,30 @@ struct HistCutUnit {
   /*! \brief the index pointer of each histunit */
   const bst_float* cut;
   /*! \brief number of cutting point, containing the maximum point */
-  size_t size;
+  uint32_t size;
   // default constructor
   HistCutUnit() {}
   // constructor
-  HistCutUnit(const bst_float* cut, unsigned size)
+  HistCutUnit(const bst_float* cut, uint32_t size)
       : cut(cut), size(size) {}
 };
 
 /*! \brief cut configuration for all the features */
 struct HistCutMatrix {
-  /*! \brief actual unit pointer */
-  std::vector<unsigned> row_ptr;
+  /*! \brief unit pointer to rows by element position */
+  std::vector<uint32_t> row_ptr;
   /*! \brief minimum value of each feature */
   std::vector<bst_float> min_val;
   /*! \brief the cut field */
   std::vector<bst_float> cut;
   /*! \brief Get histogram bound for fid */
-  inline HistCutUnit operator[](unsigned fid) const {
+  inline HistCutUnit operator[](bst_uint fid) const {
     return HistCutUnit(dmlc::BeginPtr(cut) + row_ptr[fid],
                        row_ptr[fid + 1] - row_ptr[fid]);
   }
   // create histogram cut matrix given statistics from data
   // using approximate quantile sketch approach
-  void Init(DMatrix* p_fmat, size_t max_num_bins);
+  void Init(DMatrix* p_fmat, uint32_t max_num_bins);
 };
 
@@ -89,11 +89,11 @@ struct HistCutMatrix {
  */
 struct GHistIndexRow {
   /*! \brief The index of the histogram */
-  const unsigned* index;
+  const uint32_t* index;
   /*! \brief The size of the histogram */
-  unsigned size;
+  size_t size;
   GHistIndexRow() {}
-  GHistIndexRow(const unsigned* index, unsigned size)
+  GHistIndexRow(const uint32_t* index, size_t size)
       : index(index), size(size) {}
 };
 
@@ -103,21 +103,21 @@ struct GHistIndexRow {
  * This is a global histogram index.
  */
 struct GHistIndexMatrix {
-  /*! \brief row pointer */
-  std::vector<unsigned> row_ptr;
+  /*! \brief row pointer to rows by element position */
+  std::vector<size_t> row_ptr;
   /*! \brief The index data */
-  std::vector<unsigned> index;
+  std::vector<uint32_t> index;
   /*! \brief hit count of each index */
-  std::vector<unsigned> hit_count;
+  std::vector<size_t> hit_count;
   /*! \brief The corresponding cuts */
   const HistCutMatrix* cut;
   // Create a global histogram matrix, given cut
   void Init(DMatrix* p_fmat);
   // get i-th row
-  inline GHistIndexRow operator[](bst_uint i) const {
+  inline GHistIndexRow operator[](size_t i) const {
     return GHistIndexRow(&index[0] + row_ptr[i],
                          row_ptr[i + 1] - row_ptr[i]);
   }
-  inline void GetFeatureCounts(bst_uint* counts) const {
+  inline void GetFeatureCounts(size_t* counts) const {
     const unsigned nfeature = cut->row_ptr.size() - 1;
     for (unsigned fid = 0; fid < nfeature; ++fid) {
       const unsigned ibegin = cut->row_ptr[fid];
@@ -129,18 +129,18 @@ struct GHistIndexMatrix {
   }
 
  private:
-  std::vector<unsigned> hit_count_tloc_;
+  std::vector<size_t> hit_count_tloc_;
 };
 
 struct GHistIndexBlock {
-  const unsigned* row_ptr;
-  const unsigned* index;
+  const size_t* row_ptr;
+  const uint32_t* index;
 
-  inline GHistIndexBlock(const unsigned* row_ptr, const unsigned* index)
+  inline GHistIndexBlock(const size_t* row_ptr, const uint32_t* index)
     : row_ptr(row_ptr), index(index) {}
 
   // get i-th row
-  inline GHistIndexRow operator[](bst_uint i) const {
+  inline GHistIndexRow operator[](size_t i) const {
     return GHistIndexRow(&index[0] + row_ptr[i], row_ptr[i + 1] - row_ptr[i]);
   }
 };
 
@@ -153,23 +153,23 @@ class GHistIndexBlockMatrix {
            const ColumnMatrix& colmat,
            const FastHistParam& param);
 
-  inline GHistIndexBlock operator[](bst_uint i) const {
+  inline GHistIndexBlock operator[](size_t i) const {
     return GHistIndexBlock(blocks[i].row_ptr_begin, blocks[i].index_begin);
   }
 
-  inline unsigned GetNumBlock() const {
+  inline size_t GetNumBlock() const {
     return blocks.size();
   }
 
 private:
-  std::vector<unsigned> row_ptr;
-  std::vector<unsigned> index;
+  std::vector<size_t> row_ptr;
+  std::vector<uint32_t> index;
   const HistCutMatrix* cut;
   struct Block {
-    const unsigned* row_ptr_begin;
-    const unsigned* row_ptr_end;
-    const unsigned* index_begin;
-    const unsigned* index_end;
+    const size_t* row_ptr_begin;
+    const size_t* row_ptr_end;
+    const uint32_t* index_begin;
+    const uint32_t* index_end;
   };
   std::vector<Block> blocks;
 };
 
@@ -184,10 +184,10 @@ struct GHistRow {
   /*! \brief base pointer to first entry */
   GHistEntry* begin;
   /*! \brief number of entries */
-  unsigned size;
+  uint32_t size;
 
   GHistRow() {}
-  GHistRow(GHistEntry* begin, unsigned size)
+  GHistRow(GHistEntry* begin, uint32_t size)
     : begin(begin), size(size) {}
 };
 
@@ -198,19 +198,19 @@ class HistCollection {
  public:
   // access histogram for i-th node
   inline GHistRow operator[](bst_uint nid) const {
-    const size_t kMax = std::numeric_limits<size_t>::max();
+    const uint32_t kMax = std::numeric_limits<uint32_t>::max();
     CHECK_NE(row_ptr_[nid], kMax);
     return GHistRow(const_cast<GHistEntry*>(dmlc::BeginPtr(data_) + row_ptr_[nid]), nbins_);
   }
 
   // have we computed a histogram for i-th node?
   inline bool RowExists(bst_uint nid) const {
-    const size_t kMax = std::numeric_limits<size_t>::max();
+    const uint32_t kMax = std::numeric_limits<uint32_t>::max();
     return (nid < row_ptr_.size() && row_ptr_[nid] != kMax);
   }
 
   // initialize histogram collection
-  inline void Init(size_t nbins) {
+  inline void Init(uint32_t nbins) {
     nbins_ = nbins;
     row_ptr_.clear();
     data_.clear();
   }
 
@@ -218,7 +218,7 @@ class HistCollection {
   // create an empty histogram for i-th node
   inline void AddHistRow(bst_uint nid) {
-    const size_t kMax = std::numeric_limits<size_t>::max();
+    const uint32_t kMax = std::numeric_limits<uint32_t>::max();
     if (nid >= row_ptr_.size()) {
       row_ptr_.resize(nid + 1, kMax);
     }
@@ -230,12 +230,12 @@ class HistCollection {
 
  private:
   /*! \brief number of all bins over all features */
-  size_t nbins_;
+  uint32_t nbins_;
 
   std::vector<GHistEntry> data_;
 
   /*! \brief row_ptr_[nid] locates bin for histogram of node nid */
-  std::vector<size_t> row_ptr_;
+  std::vector<uint32_t> row_ptr_;
 };
 
 /*!
@@ -244,7 +244,7 @@ class HistCollection {
 class GHistBuilder {
  public:
   // initialize builder
-  inline void Init(size_t nthread, size_t nbins) {
+  inline void Init(size_t nthread, uint32_t nbins) {
     nthread_ = nthread;
     nbins_ = nbins;
   }
@@ -268,7 +268,7 @@ class GHistBuilder {
   /*! \brief number of threads for parallel computation */
   size_t nthread_;
   /*! \brief number of all bins over all features */
-  size_t nbins_;
+  uint32_t nbins_;
   std::vector<GHistEntry> data_;
 };
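`HistCollection` above uses a sentinel `kMax` in `row_ptr_` to mark nodes whose histogram has not been allocated; the change narrows that sentinel from `size_t` to `uint32_t` to match the new offset type. A condensed sketch of the pattern (hypothetical class, fixed-size rows packed back to back in one flat array):

```cpp
#include <cstdint>
#include <limits>
#include <vector>

class HistOffsets {
 public:
  // A node has a histogram iff its slot holds a real offset, not the sentinel.
  bool RowExists(uint32_t nid) const {
    return nid < row_ptr_.size() && row_ptr_[nid] != kMax();
  }
  // Grow the table with sentinels, then claim the next nbins-wide slice.
  void AddHistRow(uint32_t nid, uint32_t nbins) {
    if (nid >= row_ptr_.size()) row_ptr_.resize(nid + 1, kMax());
    row_ptr_[nid] = next_;
    next_ += nbins;
  }

 private:
  static uint32_t kMax() { return std::numeric_limits<uint32_t>::max(); }
  std::vector<uint32_t> row_ptr_;
  uint32_t next_ = 0;
};
```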
diff --git a/src/common/row_set.h b/src/common/row_set.h
index 15b308721..cc0f846a6 100644
--- a/src/common/row_set.h
+++ b/src/common/row_set.h
@@ -21,14 +21,14 @@ class RowSetCollection {
    * rows (instances) associated with a particular node in a decision
    * tree. */
   struct Elem {
-    const bst_uint* begin;
-    const bst_uint* end;
+    const size_t* begin;
+    const size_t* end;
     int node_id;
       // id of node associated with this instance set; -1 means uninitialized
     Elem(void)
         : begin(nullptr), end(nullptr), node_id(-1) {}
-    Elem(const bst_uint* begin,
-         const bst_uint* end,
+    Elem(const size_t* begin,
+         const size_t* end,
          int node_id)
         : begin(begin), end(end), node_id(node_id) {}
@@ -38,8 +38,8 @@ class RowSetCollection {
   };
   /* \brief specifies how to split a rowset into two */
   struct Split {
-    std::vector<bst_uint> left;
-    std::vector<bst_uint> right;
+    std::vector<size_t> left;
+    std::vector<size_t> right;
   };
 
   inline std::vector<Elem>::const_iterator begin() const {
@@ -65,8 +65,8 @@ class RowSetCollection {
   // initialize node id 0->everything
   inline void Init() {
     CHECK_EQ(elem_of_each_node_.size(), 0U);
-    const bst_uint* begin = dmlc::BeginPtr(row_indices_);
-    const bst_uint* end = dmlc::BeginPtr(row_indices_) + row_indices_.size();
+    const size_t* begin = dmlc::BeginPtr(row_indices_);
+    const size_t* end = dmlc::BeginPtr(row_indices_) + row_indices_.size();
     elem_of_each_node_.emplace_back(Elem(begin, end, 0));
   }
   // split rowset into two
@@ -77,16 +77,15 @@ class RowSetCollection {
     const Elem e = elem_of_each_node_[node_id];
     const unsigned nthread = row_split_tloc.size();
     CHECK(e.begin != nullptr);
-    bst_uint* all_begin = dmlc::BeginPtr(row_indices_);
-    bst_uint* begin = all_begin + (e.begin - all_begin);
+    size_t* all_begin = dmlc::BeginPtr(row_indices_);
+    size_t* begin = all_begin + (e.begin - all_begin);
 
-    bst_uint* it = begin;
-    // TODO(hcho3): parallelize this section
+    size_t* it = begin;
     for (bst_omp_uint tid = 0; tid < nthread; ++tid) {
       std::copy(row_split_tloc[tid].left.begin(), row_split_tloc[tid].left.end(), it);
       it += row_split_tloc[tid].left.size();
     }
-    bst_uint* split_pt = it;
+    size_t* split_pt = it;
     for (bst_omp_uint tid = 0; tid < nthread; ++tid) {
       std::copy(row_split_tloc[tid].right.begin(), row_split_tloc[tid].right.end(), it);
       it += row_split_tloc[tid].right.size();
     }
@@ -105,7 +104,7 @@ class RowSetCollection {
   }
 
   // stores the row indices in the set
-  std::vector<bst_uint> row_indices_;
+  std::vector<size_t> row_indices_;
 
 private:
   // vector: node_id -> elements
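`RowSetCollection::AddSplit` above stitches the per-thread partitions back into one flat index array: every thread's left half first, then every right half, with `split_pt` marking the boundary between the two children. A simplified single-function sketch of that repacking (hypothetical names; the destination must already hold exactly the rows being repartitioned):

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

struct ThreadSplit {
  std::vector<size_t> left, right;
};

// Returns the offset of the first right-child row. rows->size() must equal
// the total number of indices across all thread-local splits.
size_t Repack(std::vector<size_t>* rows,
              const std::vector<ThreadSplit>& tloc) {
  size_t* it = rows->data();
  for (const auto& t : tloc) {  // left halves, in thread order
    it = std::copy(t.left.begin(), t.left.end(), it);
  }
  const size_t split_pt = it - rows->data();
  for (const auto& t : tloc) {  // then right halves
    it = std::copy(t.right.begin(), t.right.end(), it);
  }
  return split_pt;
}
```

Widening the indices to `size_t` here is what lets the fast-hist updater below address row sets beyond 2^32 rows.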
diff --git a/src/tree/updater_fast_hist.cc b/src/tree/updater_fast_hist.cc
index 81f0b36d1..3f1c6c5ee 100644
--- a/src/tree/updater_fast_hist.cc
+++ b/src/tree/updater_fast_hist.cc
@@ -61,7 +61,7 @@ class FastHistMaker: public TreeUpdater {
     TStats::CheckInfo(dmat->info());
     if (is_gmat_initialized_ == false) {
       double tstart = dmlc::GetTime();
-      hmat_.Init(dmat, param.max_bin);
+      hmat_.Init(dmat, static_cast<uint32_t>(param.max_bin));
       gmat_.cut = &hmat_;
       gmat_.Init(dmat);
       column_matrix_.Init(gmat_, fhparam);
@@ -111,23 +111,6 @@ class FastHistMaker: public TreeUpdater {
   bool is_gmat_initialized_;
 
   // data structure
-  /*! \brief per thread x per node entry to store tmp data */
-  struct ThreadEntry {
-    /*! \brief statistics of data */
-    TStats stats;
-    /*! \brief extra statistics of data */
-    TStats stats_extra;
-    /*! \brief last feature value scanned */
-    float last_fvalue;
-    /*! \brief first feature value scanned */
-    float first_fvalue;
-    /*! \brief current best solution */
-    SplitEntry best;
-    // constructor
-    explicit ThreadEntry(const TrainParam& param)
-        : stats(param), stats_extra(param) {
-    }
-  };
   struct NodeEntry {
     /*! \brief statistics for node entry */
     TStats stats;
@@ -340,7 +323,7 @@
       }
       leaf_value = (*p_last_tree_)[nid].leaf_value();
 
-      for (const bst_uint* it = rowset.begin; it < rowset.end; ++it) {
+      for (const size_t* it = rowset.begin; it < rowset.end; ++it) {
         out_preds[*it] += leaf_value;
       }
     }
@@ -372,7 +355,7 @@ class FastHistMaker: public TreeUpdater {
       // clear local prediction cache
       leaf_value_cache_.clear();
       // initialize histogram collection
-      size_t nbins = gmat.cut->row_ptr.back();
+      uint32_t nbins = gmat.cut->row_ptr.back();
       hist_.Init(nbins);
 
       // initialize histogram builder
       hist_builder_.Init(this->nthread, nbins);
 
       CHECK_EQ(info.root_index.size(), 0U);
-      std::vector<bst_uint>& row_indices = row_set_collection_.row_indices_;
+      std::vector<size_t>& row_indices = row_set_collection_.row_indices_;
       // mark subsample and build list of member rows
       if (param.subsample < 1.0f) {
         std::bernoulli_distribution coin_flip(param.subsample);
         auto& rnd = common::GlobalRandom();
-        for (bst_uint i = 0; i < info.num_row; ++i) {
+        for (size_t i = 0; i < info.num_row; ++i) {
           if (gpair[i].hess >= 0.0f && coin_flip(rnd)) {
             row_indices.push_back(i);
           }
         }
       } else {
-        for (bst_uint i = 0; i < info.num_row; ++i) {
+        for (size_t i = 0; i < info.num_row; ++i) {
           if (gpair[i].hess >= 0.0f) {
             row_indices.push_back(i);
           }
@@ -405,11 +388,11 @@
 
       {
         /* determine layout of data */
-        const auto nrow = info.num_row;
-        const auto ncol = info.num_col;
-        const auto nnz = info.num_nonzero;
+        const size_t nrow = info.num_row;
+        const size_t ncol = info.num_col;
+        const size_t nnz = info.num_nonzero;
         // number of discrete bins for feature 0
-        const unsigned nbins_f0 = gmat.cut->row_ptr[1] - gmat.cut->row_ptr[0];
+        const uint32_t nbins_f0 = gmat.cut->row_ptr[1] - gmat.cut->row_ptr[0];
         if (nrow * ncol == nnz) {
           // dense data with zero-based indexing
           data_layout_ = kDenseDataZeroBased;
@@ -427,19 +410,19 @@
       // store a pointer to training data
       p_last_fmat_ = &fmat;
       // initialize feature index
-      unsigned ncol = static_cast<unsigned>(info.num_col);
+      bst_uint ncol = static_cast<bst_uint>(info.num_col);
       feat_index.clear();
       if (data_layout_ == kDenseDataOneBased) {
-        for (unsigned i = 1; i < ncol; ++i) {
+        for (bst_uint i = 1; i < ncol; ++i) {
           feat_index.push_back(i);
         }
       } else {
-        for (unsigned i = 0; i < ncol; ++i) {
+        for (bst_uint i = 0; i < ncol; ++i) {
           feat_index.push_back(i);
         }
       }
-      unsigned n = std::max(static_cast<unsigned>(1),
-                            static_cast<unsigned>(param.colsample_bytree * feat_index.size()));
+      bst_uint n = std::max(static_cast<bst_uint>(1),
+                            static_cast<bst_uint>(param.colsample_bytree * feat_index.size()));
       std::shuffle(feat_index.begin(), feat_index.end(), common::GlobalRandom());
       CHECK_GT(param.colsample_bytree, 0U)
           << "colsample_bytree cannot be zero.";
@@ -450,11 +433,11 @@
       /* specialized code for dense data:
          choose the column that has a least positive number of discrete bins.
         For dense data (with no missing value), the sum of gradient histogram
         is equal to snode[nid] */
-      const std::vector<unsigned>& row_ptr = gmat.cut->row_ptr;
-      const size_t nfeature = row_ptr.size() - 1;
-      size_t min_nbins_per_feature = 0;
-      for (size_t i = 0; i < nfeature; ++i) {
-        const unsigned nbins = row_ptr[i + 1] - row_ptr[i];
+      const std::vector<uint32_t>& row_ptr = gmat.cut->row_ptr;
+      const bst_uint nfeature = static_cast<bst_uint>(row_ptr.size() - 1);
+      uint32_t min_nbins_per_feature = 0;
+      for (bst_uint i = 0; i < nfeature; ++i) {
+        const uint32_t nbins = row_ptr[i + 1] - row_ptr[i];
         if (nbins > 0) {
           if (min_nbins_per_feature == 0 ||
               min_nbins_per_feature > nbins) {
             min_nbins_per_feature = nbins;
@@ -485,7 +468,7 @@
                           const std::vector<bst_uint>& feat_set) {
       // start enumeration
       const MetaInfo& info = fmat.info();
-      const bst_omp_uint nfeature = feat_set.size();
+      const bst_uint nfeature = static_cast<bst_uint>(feat_set.size());
       const bst_omp_uint nthread = static_cast<bst_omp_uint>(this->nthread);
       best_split_tloc_.resize(nthread);
 #pragma omp parallel for schedule(static) num_threads(nthread)
@@ -547,13 +530,17 @@
       const bool default_left = (*p_tree)[nid].default_left();
       const bst_uint fid = (*p_tree)[nid].split_index();
       const bst_float split_pt = (*p_tree)[nid].split_cond();
-      const bst_uint lower_bound = gmat.cut->row_ptr[fid];
-      const bst_uint upper_bound = gmat.cut->row_ptr[fid + 1];
-      bst_int split_cond = -1;
+      const uint32_t lower_bound = gmat.cut->row_ptr[fid];
+      const uint32_t upper_bound = gmat.cut->row_ptr[fid + 1];
+      int32_t split_cond = -1;
       // convert floating-point split_pt into corresponding bin_id
       // split_cond = -1 indicates that split_pt is less than all known cut points
-      for (unsigned i = gmat.cut->row_ptr[fid]; i < gmat.cut->row_ptr[fid + 1]; ++i) {
-        if (split_pt == gmat.cut->cut[i]) split_cond = static_cast<bst_int>(i);
+      CHECK_LT(upper_bound,
+               static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
+      for (uint32_t i = lower_bound; i < upper_bound; ++i) {
+        if (split_pt == gmat.cut->cut[i]) {
+          split_cond = static_cast<int32_t>(i);
+        }
       }
 
       const auto& rowset = row_set_collection_[nid];
@@ -580,15 +567,15 @@
                           bool default_left) {
       std::vector<RowSetCollection::Split>& row_split_tloc = *p_row_split_tloc;
       const int K = 8;  // loop unrolling factor
-      const bst_omp_uint nrows = rowset.end - rowset.begin;
-      const bst_omp_uint rest = nrows % K;
+      const size_t nrows = rowset.end - rowset.begin;
+      const size_t rest = nrows % K;
 #pragma omp parallel for num_threads(nthread) schedule(static)
       for (bst_omp_uint i = 0; i < nrows - rest; i += K) {
         const bst_uint tid = omp_get_thread_num();
         auto& left = row_split_tloc[tid].left;
         auto& right = row_split_tloc[tid].right;
-        bst_uint rid[K];
+        size_t rid[K];
         T rbin[K];
         for (int k = 0; k < K; ++k) {
           rid[k] = rowset.begin[i + k];
         }
@@ -604,7 +591,9 @@
               right.push_back(rid[k]);
             }
           } else {
-            if (static_cast<bst_int>(rbin[k] + column.index_base) <= split_cond) {
+            CHECK_LT(rbin[k] + column.index_base,
+                     static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
+            if (static_cast<int32_t>(rbin[k] + column.index_base) <= split_cond) {
              left.push_back(rid[k]);
            } else {
              right.push_back(rid[k]);
            }
          }
        }
      }
-      for (bst_omp_uint i = nrows - rest; i < nrows; ++i) {
+      for (size_t i = nrows - rest; i < nrows; ++i) {
         auto& left = row_split_tloc[nthread-1].left;
         auto& right = row_split_tloc[nthread-1].right;
-        const bst_uint rid = rowset.begin[i];
+        const size_t rid = rowset.begin[i];
         const T rbin = column.index[rid];
         if (rbin == std::numeric_limits<T>::max()) {  // missing value
           if (default_left) {
@@ -624,7 +613,9 @@
             right.push_back(rid);
           }
         } else {
-          if (static_cast<bst_int>(rbin + column.index_base) <= split_cond) {
+          CHECK_LT(rbin + column.index_base,
+                   static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
+          if (static_cast<int32_t>(rbin + column.index_base) <= split_cond) {
             left.push_back(rid);
           } else {
             right.push_back(rid);
@@ -642,13 +633,13 @@
                           bool default_left) {
       std::vector<RowSetCollection::Split>& row_split_tloc = *p_row_split_tloc;
       const int K = 8;  // loop unrolling factor
-      const bst_omp_uint nrows = rowset.end - rowset.begin;
-      const bst_omp_uint rest = nrows % K;
+      const size_t nrows = rowset.end - rowset.begin;
+      const size_t rest = nrows % K;
 #pragma omp parallel for num_threads(nthread) schedule(static)
       for (bst_omp_uint i = 0; i < nrows - rest; i += K) {
-        bst_uint rid[K];
+        size_t rid[K];
         GHistIndexRow row[K];
-        const unsigned* p[K];
+        const uint32_t* p[K];
         bst_uint tid = omp_get_thread_num();
         auto& left = row_split_tloc[tid].left;
         auto& right = row_split_tloc[tid].right;
@@ -663,7 +654,9 @@
         }
         for (int k = 0; k < K; ++k) {
           if (p[k] != row[k].index + row[k].size && *p[k] < upper_bound) {
-            if (static_cast<bst_int>(*p[k]) <= split_cond) {
+            CHECK_LT(*p[k],
+                     static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
+            if (static_cast<int32_t>(*p[k]) <= split_cond) {
               left.push_back(rid[k]);
             } else {
               right.push_back(rid[k]);
@@ -677,14 +670,15 @@
           }
         }
       }
-      for (bst_omp_uint i = nrows - rest; i < nrows; ++i) {
-        const bst_uint rid = rowset.begin[i];
+      for (size_t i = nrows - rest; i < nrows; ++i) {
+        const size_t rid = rowset.begin[i];
         const auto row = gmat[rid];
         const auto p = std::lower_bound(row.index, row.index + row.size, lower_bound);
         auto& left = row_split_tloc[0].left;
         auto& right = row_split_tloc[0].right;
         if (p != row.index + row.size && *p < upper_bound) {
-          if (static_cast<bst_int>(*p) <= split_cond) {
+          CHECK_LT(*p, static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
+          if (static_cast<int32_t>(*p) <= split_cond) {
             left.push_back(rid);
           } else {
             right.push_back(rid);
@@ -709,26 +703,26 @@
                           bst_int split_cond,
                           bool default_left) {
       std::vector<RowSetCollection::Split>& row_split_tloc = *p_row_split_tloc;
-      const bst_omp_uint nrows = rowset.end - rowset.begin;
+      const size_t nrows = rowset.end - rowset.begin;
 #pragma omp parallel num_threads(nthread)
       {
-        const bst_uint tid = omp_get_thread_num();
-        const bst_omp_uint ibegin = tid * nrows / nthread;
-        const bst_omp_uint iend = (tid + 1) * nrows / nthread;
+        const size_t tid = static_cast<size_t>(omp_get_thread_num());
+        const size_t ibegin = tid * nrows / nthread;
+        const size_t iend = (tid + 1) * nrows / nthread;
         if (ibegin < iend) {  // ensure that [ibegin, iend) is nonempty range
           // search first nonzero row with index >= rowset[ibegin]
-          const uint32_t* p = std::lower_bound(column.row_ind,
-                                               column.row_ind + column.len,
-                                               rowset.begin[ibegin]);
+          const size_t* p = std::lower_bound(column.row_ind,
+                                             column.row_ind + column.len,
+                                             rowset.begin[ibegin]);
           auto& left = row_split_tloc[tid].left;
           auto& right = row_split_tloc[tid].right;
           if (p != column.row_ind + column.len &&
               *p <= rowset.begin[iend - 1]) {
-            bst_omp_uint cursor = p - column.row_ind;
+            size_t cursor = p - column.row_ind;
 
-            for (bst_omp_uint i = ibegin; i < iend; ++i) {
-              const bst_uint rid = rowset.begin[i];
+            for (size_t i = ibegin; i < iend; ++i) {
+              const size_t rid = rowset.begin[i];
               while (cursor < column.len
                      && column.row_ind[cursor] < rid
                      && column.row_ind[cursor] <= rowset.begin[iend - 1]) {
                 ++cursor;
               }
               if (cursor < column.len && column.row_ind[cursor] == rid) {
                 const T rbin = column.index[cursor];
-                if (static_cast<bst_int>(rbin + column.index_base) <= split_cond) {
+                CHECK_LT(rbin + column.index_base,
+                         static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
+                if (static_cast<int32_t>(rbin + column.index_base) <= split_cond) {
                   left.push_back(rid);
                 } else {
                   right.push_back(rid);
@@ -753,13 +749,13 @@
             }
           } else {  // all rows in [ibegin, iend) have missing values
             if (default_left) {
-              for (bst_omp_uint i = ibegin; i < iend; ++i) {
-                const bst_uint rid = rowset.begin[i];
+              for (size_t i = ibegin; i < iend; ++i) {
+                const size_t rid = rowset.begin[i];
                 left.push_back(rid);
               }
             } else {
-              for (bst_omp_uint i = ibegin; i < iend; ++i) {
-                const bst_uint rid = rowset.begin[i];
+              for (size_t i = ibegin; i < iend; ++i) {
+                const size_t rid = rowset.begin[i];
                 right.push_back(rid);
               }
             }
@@ -786,17 +782,17 @@
            For dense data (with no missing value), the sum of gradient histogram
            is equal to snode[nid] */
         GHistRow hist = hist_[nid];
-        const std::vector<unsigned>& row_ptr = gmat.cut->row_ptr;
+        const std::vector<uint32_t>& row_ptr = gmat.cut->row_ptr;
 
-        const size_t ibegin = row_ptr[fid_least_bins_];
-        const size_t iend = row_ptr[fid_least_bins_ + 1];
-        for (size_t i = ibegin; i < iend; ++i) {
+        const uint32_t ibegin = row_ptr[fid_least_bins_];
+        const uint32_t iend = row_ptr[fid_least_bins_ + 1];
+        for (uint32_t i = ibegin; i < iend; ++i) {
           const GHistEntry et = hist.begin[i];
           stats.Add(et.sum_grad, et.sum_hess);
         }
       } else {
         const RowSetCollection::Elem e = row_set_collection_[nid];
-        for (const bst_uint* it = e.begin; it < e.end; ++it) {
+        for (const size_t* it = e.begin; it < e.end; ++it) {
           stats.Add(gpair[*it]);
         }
       }
@@ -831,7 +827,7 @@
       CHECK(d_step == +1 || d_step == -1);
 
       // aliases
-      const std::vector<unsigned>& cut_ptr = gmat.cut->row_ptr;
+      const std::vector<uint32_t>& cut_ptr = gmat.cut->row_ptr;
       const std::vector<bst_float>& cut_val = gmat.cut->cut;
 
       // statistics on both sides of split
@@ -841,20 +837,25 @@
       SplitEntry best;
 
       // bin boundaries
+      CHECK_LE(cut_ptr[fid],
+               static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
+      CHECK_LE(cut_ptr[fid + 1],
+               static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
       // imin: index (offset) of the minimum value for feature fid
       // need this for backward enumeration
-      const int imin = cut_ptr[fid];
+      const int32_t imin = static_cast<int32_t>(cut_ptr[fid]);
       // ibegin, iend: smallest/largest cut points for feature fid
-      int ibegin, iend;
+      // use int to allow for value -1
+      int32_t ibegin, iend;
       if (d_step > 0) {
-        ibegin = cut_ptr[fid];
-        iend = cut_ptr[fid + 1];
+        ibegin = static_cast<int32_t>(cut_ptr[fid]);
+        iend = static_cast<int32_t>(cut_ptr[fid + 1]);
       } else {
-        ibegin = cut_ptr[fid + 1] - 1;
-        iend = cut_ptr[fid] - 1;
+        ibegin = static_cast<int32_t>(cut_ptr[fid + 1]) - 1;
+        iend = static_cast<int32_t>(cut_ptr[fid]) - 1;
       }
-      for (int i = ibegin; i != iend; i += d_step) {
+      for (int32_t i = ibegin; i != iend; i += d_step) {
         // start working
         // try to find a split
         e.Add(hist.begin[i].sum_grad, hist.begin[i].sum_hess);
@@ -930,7 +931,7 @@
     HistCollection hist_;
     /*! \brief feature with least # of bins. to be used for dense specialization
                of InitNewNode() */
-    size_t fid_least_bins_;
+    uint32_t fid_least_bins_;
     /*! \brief local prediction cache; maps node id to leaf value */
     std::vector<float> leaf_value_cache_;
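The repeated `CHECK_LT(..., std::numeric_limits<int32_t>::max())` guards added above all follow one pattern: prove the unsigned bin value fits before narrowing it into the signed `split_cond` domain (signed because `-1` means "below all cut points"). Distilled into a tiny helper, with `assert` standing in for dmlc's `CHECK_LT`:

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Narrow a bin id into the signed comparison domain only after verifying it
// fits; an unchecked cast would silently wrap for bin ids >= 2^31.
inline int32_t ToSplitCond(uint32_t bin_id) {
  assert(bin_id <
         static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
  return static_cast<int32_t>(bin_id);
}
```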
diff --git a/tests/cpp/common/test_compressed_iterator.cc b/tests/cpp/common/test_compressed_iterator.cc
index 4ea1fc8e1..8e57cea77 100644
--- a/tests/cpp/common/test_compressed_iterator.cc
+++ b/tests/cpp/common/test_compressed_iterator.cc
@@ -1,5 +1,6 @@
 #include "../../../src/common/compressed_iterator.h"
 #include "gtest/gtest.h"
+#include <algorithm>
 
 namespace xgboost {
 namespace common {
@@ -51,4 +52,4 @@ TEST(CompressedIterator, Test) {
 }
 
 }  // namespace common
-}  // namespace xgboost
\ No newline at end of file
+}  // namespace xgboost