[GPU-Plugin] Fix gpu_hist to allow matrices with more than 2^32 elements. Also fix the CPU hist algorithm. (#2518)
This commit is contained in:
parent
c85bf9859e
commit
ca7fc9fda3
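The heart of the fix: hot loops and offset computations used 32-bit types (int, bst_uint) for element indices, which silently wrap once a matrix holds more than 2^32 entries; the patch widens them to 64-bit types (size_t, bst_ulong). A minimal standalone illustration of the failure mode (values assumed, not taken from the patch):

#include <cstdint>
#include <cstdio>

int main() {
  const size_t n_rows = 5000000;  // hypothetical shard
  const size_t n_cols = 1000;     // 5e9 elements: does not fit in 32 bits
  uint32_t bad  = static_cast<uint32_t>(n_rows * n_cols);  // wraps around
  size_t   good = n_rows * n_cols;
  std::printf("32-bit count: %u\n", bad);    // 705032704, silently wrong
  std::printf("64-bit count: %zu\n", good);  // 5000000000
  return 0;
}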
@@ -441,7 +441,7 @@ class bulk_allocator {
  public:
   ~bulk_allocator() {
-    for (int i = 0; i < d_ptr.size(); i++) {
+    for (size_t i = 0; i < d_ptr.size(); i++) {
       if (!(d_ptr[i] == nullptr)) {
         safe_cuda(cudaSetDevice(_device_idx[i]));
         safe_cuda(cudaFree(d_ptr[i]));
@@ -522,7 +522,7 @@ inline size_t available_memory(int device_idx) {
 template <typename T>
 void print(const thrust::device_vector<T> &v, size_t max_items = 10) {
   thrust::host_vector<T> h = v;
-  for (int i = 0; i < std::min(max_items, h.size()); i++) {
+  for (size_t i = 0; i < std::min(max_items, h.size()); i++) {
     std::cout << " " << h[i];
   }
   std::cout << "\n";
@@ -531,7 +531,7 @@ void print(const thrust::device_vector<T> &v, size_t max_items = 10) {
 template <typename T>
 void print(const dvec<T> &v, size_t max_items = 10) {
   std::vector<T> h = v.as_vector();
-  for (int i = 0; i < std::min(max_items, h.size()); i++) {
+  for (size_t i = 0; i < std::min(max_items, h.size()); i++) {
     std::cout << " " << h[i];
   }
   std::cout << "\n";
@@ -539,10 +539,10 @@ void print(const dvec<T> &v, size_t max_items = 10) {

 template <typename T>
 void print(char *label, const thrust::device_vector<T> &v,
-           const char *format = "%d ", int max = 10) {
+           const char *format = "%d ", size_t max = 10) {
   thrust::host_vector<T> h_v = v;
   std::cout << label << ":\n";
-  for (int i = 0; i < std::min(static_cast<int>(h_v.size()), max); i++) {
+  for (size_t i = 0; i < std::min(static_cast<size_t>(h_v.size()), max); i++) {
     printf(format, h_v[i]);
   }
   std::cout << "\n";
@@ -593,6 +593,7 @@ __global__ void launch_n_kernel(int device_idx, size_t begin, size_t end,
 template <int ITEMS_PER_THREAD = 8, int BLOCK_THREADS = 256, typename L>
 inline void launch_n(int device_idx, size_t n, L lambda) {
   safe_cuda(cudaSetDevice(device_idx));
+  // TODO: Template on n so GRID_SIZE always fits into int.
   const int GRID_SIZE = div_round_up(n, ITEMS_PER_THREAD * BLOCK_THREADS);
 #if defined(__CUDACC__)
   launch_n_kernel<<<GRID_SIZE, BLOCK_THREADS>>>(static_cast<size_t>(0), n,
@@ -607,6 +608,7 @@ inline void multi_launch_n(size_t n, int n_devices, L lambda) {
   CHECK_LE(n_devices, n_visible_devices()) << "Number of devices requested "
                                               "needs to be less than equal to "
                                               "number of visible devices.";
+  // TODO: Template on n so GRID_SIZE always fits into int.
   const int GRID_SIZE = div_round_up(n, ITEMS_PER_THREAD * BLOCK_THREADS);
 #if defined(__CUDACC__)
   n_devices = n_devices > n ? n : n_devices;

@@ -19,8 +19,8 @@ namespace xgboost {
 namespace tree {

 void DeviceGMat::Init(int device_idx, const common::GHistIndexMatrix& gmat,
-                      bst_uint element_begin, bst_uint element_end,
-                      bst_uint row_begin, bst_uint row_end, int n_bins) {
+                      bst_ulong element_begin, bst_ulong element_end,
+                      bst_ulong row_begin, bst_ulong row_end, int n_bins) {
   dh::safe_cuda(cudaSetDevice(device_idx));
   CHECK(gidx_buffer.size()) << "gidx_buffer must be externally allocated";
   CHECK_EQ(row_ptr.size(), (row_end - row_begin) + 1)
@@ -31,15 +31,15 @@ void DeviceGMat::Init(int device_idx, const common::GHistIndexMatrix& gmat,
   cbw.Write(host_buffer.data(), gmat.index.begin() + element_begin,
             gmat.index.begin() + element_end);
   gidx_buffer = host_buffer;
-  gidx = common::CompressedIterator<int>(gidx_buffer.data(), n_bins);
+  gidx = common::CompressedIterator<uint32_t>(gidx_buffer.data(), n_bins);

   // row_ptr
   thrust::copy(gmat.row_ptr.data() + row_begin,
                gmat.row_ptr.data() + row_end + 1, row_ptr.tbegin());
   // normalise row_ptr
-  bst_uint start = gmat.row_ptr[row_begin];
+  size_t start = gmat.row_ptr[row_begin];
   thrust::transform(row_ptr.tbegin(), row_ptr.tend(), row_ptr.tbegin(),
-                    [=] __device__(int val) { return val - start; });
+                    [=] __device__(size_t val) { return val - start; });
 }

 void DeviceHist::Init(int n_bins_in) {
@@ -193,7 +193,7 @@ void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
   device_row_segments.push_back(0);
   device_element_segments.push_back(0);
   bst_uint offset = 0;
-  size_t shard_size = std::ceil(static_cast<double>(num_rows) / n_devices);
+  bst_uint shard_size = std::ceil(static_cast<double>(num_rows) / n_devices);
   for (int d_idx = 0; d_idx < n_devices; d_idx++) {
     int device_idx = dList[d_idx];
     offset += shard_size;
@@ -261,7 +261,7 @@ void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
     int device_idx = dList[d_idx];
     bst_uint num_rows_segment =
         device_row_segments[d_idx + 1] - device_row_segments[d_idx];
-    bst_uint num_elements_segment =
+    bst_ulong num_elements_segment =
         device_element_segments[d_idx + 1] - device_element_segments[d_idx];
     ba.allocate(
         device_idx, &(hist_vec[d_idx].data),
@@ -360,7 +360,7 @@ void GPUHistBuilder::BuildHist(int depth) {
     auto hist_builder = hist_vec[d_idx].GetBuilder();
     dh::TransformLbs(
         device_idx, &temp_memory[d_idx], end - begin, d_row_ptr,
-        row_end - row_begin, [=] __device__(int local_idx, int local_ridx) {
+        row_end - row_begin, [=] __device__(size_t local_idx, int local_ridx) {
           int nidx = d_position[local_ridx];  // OPTMARK: latency
           if (!is_active(nidx, depth)) return;

@@ -606,7 +606,11 @@ __global__ void find_split_kernel(
 }

 #define MIN_BLOCK_THREADS 32
-#define MAX_BLOCK_THREADS 1024  // hard-coded maximum block size
+#define CHUNK_BLOCK_THREADS 32
+// MAX_BLOCK_THREADS of 1024 is hard-coded maximum block size due
+// to CUDA compatibility 35 and above requirement
+// for Maximum number of threads per block
+#define MAX_BLOCK_THREADS 1024

 void GPUHistBuilder::FindSplit(int depth) {
   // Specialised based on max_bins
@@ -622,7 +626,7 @@ void GPUHistBuilder::FindSplitSpecialize(int depth) {
   if (param.max_bin <= BLOCK_THREADS) {
     LaunchFindSplit<BLOCK_THREADS>(depth);
   } else {
-    this->FindSplitSpecialize<BLOCK_THREADS + 32>(depth);
+    this->FindSplitSpecialize<BLOCK_THREADS + CHUNK_BLOCK_THREADS>(depth);
   }
 }
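The specialisation above steps the block size in increments of CHUNK_BLOCK_THREADS until it covers max_bin. A host-side sketch of the same compile-time recursion pattern (simplified names; the real code dispatches a CUDA kernel and bounds the recursion at MAX_BLOCK_THREADS):

#include <cstdio>

constexpr int kChunk = 32;         // stands in for CHUNK_BLOCK_THREADS
constexpr int kMaxThreads = 1024;  // stands in for MAX_BLOCK_THREADS

template <int BLOCK_THREADS>
void FindSplitSpecialize(int max_bin) {
  if (max_bin <= BLOCK_THREADS) {
    std::printf("launch with %d threads per block\n", BLOCK_THREADS);
  } else {
    FindSplitSpecialize<BLOCK_THREADS + kChunk>(max_bin);
  }
}

// Terminate the compile-time recursion at the hardware limit.
template <>
void FindSplitSpecialize<kMaxThreads + kChunk>(int) {
  std::printf("max_bin exceeds the maximum block size\n");
}

int main() {
  FindSplitSpecialize<32>(1024);  // prints: launch with 1024 threads per block
  return 0;
}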
@@ -885,7 +889,7 @@ void GPUHistBuilder::UpdatePositionDense(int depth) {
     size_t begin = device_row_segments[d_idx];
     size_t end = device_row_segments[d_idx + 1];

-    dh::launch_n(device_idx, end - begin, [=] __device__(int local_idx) {
+    dh::launch_n(device_idx, end - begin, [=] __device__(size_t local_idx) {
       int pos = d_position[local_idx];
       if (!is_active(pos, depth)) {
         return;
@@ -896,7 +900,8 @@ void GPUHistBuilder::UpdatePositionDense(int depth) {
         return;
       }

-      int gidx = d_gidx[local_idx * n_columns + node.split.findex];
+      int gidx = d_gidx[local_idx * static_cast<size_t>(n_columns) +
+                        static_cast<size_t>(node.split.findex)];

       float fvalue = d_gidx_fvalue_map[gidx];
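The widened index above matters because local_idx * n_columns + findex was previously evaluated in 32-bit arithmetic and could overflow before being used as an array offset. A standalone sketch with assumed values:

#include <cstdint>
#include <cstdio>

int main() {
  int local_idx = 70000000;  // hypothetical row within the shard
  int n_columns = 31;
  // 70000000 * 31 = 2170000000 > INT_MAX: overflows if evaluated in int.
  size_t idx = static_cast<size_t>(local_idx) * static_cast<size_t>(n_columns);
  std::printf("%zu\n", idx);  // 2170000000, computed safely in 64 bits
  return 0;
}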
@@ -955,7 +960,7 @@ void GPUHistBuilder::UpdatePositionSparse(int depth) {

     dh::TransformLbs(
         device_idx, &temp_memory[d_idx], element_end - element_begin, d_row_ptr,
-        row_end - row_begin, [=] __device__(int local_idx, int local_ridx) {
+        row_end - row_begin, [=] __device__(size_t local_idx, int local_ridx) {
           int pos = d_position[local_ridx];
           if (!is_active(pos, depth)) {
             return;
@@ -1065,25 +1070,13 @@ void GPUHistBuilder::Update(const std::vector<bst_gpair>& gpair,
   this->InitData(gpair, *p_fmat, *p_tree);
   this->InitFirstNode(gpair);
   this->ColSampleTree();
-  // long long int elapsed=0;
   for (int depth = 0; depth < param.max_depth; depth++) {
     this->ColSampleLevel();
-
-    // dh::Timer time;
     this->BuildHist(depth);
-    // elapsed+=time.elapsed();
-    // printf("depth=%d\n",depth);
-    // time.printElapsed("BH Time");
-
-    // dh::Timer timesplit;
     this->FindSplit(depth);
-    // timesplit.printElapsed("FS Time");
-
-    // dh::Timer timeupdatepos;
     this->UpdatePosition(depth);
-    // timeupdatepos.printElapsed("UP Time");
   }
-  // printf("Total BuildHist Time=%lld\n",elapsed);

   // done with multi-GPU, pass back result from master to tree on host
   int master_device = dList[0];

@@ -18,10 +18,10 @@ namespace tree {

 struct DeviceGMat {
   dh::dvec<common::compressed_byte_t> gidx_buffer;
-  common::CompressedIterator<int> gidx;
-  dh::dvec<int> row_ptr;
+  common::CompressedIterator<uint32_t> gidx;
+  dh::dvec<size_t> row_ptr;
   void Init(int device_idx, const common::GHistIndexMatrix &gmat,
-            bst_uint begin, bst_uint end, bst_uint row_begin, bst_uint row_end, int n_bins);
+            bst_ulong element_begin, bst_ulong element_end, bst_ulong row_begin, bst_ulong row_end, int n_bins);
 };

 struct HistBuilder {
@@ -99,7 +99,7 @@ class GPUHistBuilder {
   // below vectors are for each devices used
   std::vector<int> dList;
   std::vector<int> device_row_segments;
-  std::vector<int> device_element_segments;
+  std::vector<size_t> device_element_segments;

   std::vector<dh::CubMemory> temp_memory;
   std::vector<DeviceHist> hist_vec;
plugin/updater_gpu/test/python/test_large.py (new file, 134 lines)
@@ -0,0 +1,134 @@
+from __future__ import print_function
+#pylint: skip-file
+import sys
+sys.path.append("../../tests/python")
+import xgboost as xgb
+import testing as tm
+import numpy as np
+import unittest
+from sklearn.datasets import make_classification
+
+def eprint(*args, **kwargs):
+    print(*args, file=sys.stderr, **kwargs) ; sys.stderr.flush()
+    print(*args, file=sys.stdout, **kwargs) ; sys.stdout.flush()
+
+eprint("Testing Big Data (this may take a while)")
+
+rng = np.random.RandomState(1994)
+
+# "realistic" size based upon http://stat-computing.org/dataexpo/2009/ , which has been processed to one-hot encode categoricals
+cols = 31
+# reduced to fit onto 1 gpu but still be large
+rows2 = 5000  # medium
+#rows2 = 4032  # fake large for testing
+rows1 = 42360032  # large
+#rows2 = 152360032  # can do this for multi-gpu test (very large)
+rowslist = [rows1, rows2]
+
+
+class TestGPU(unittest.TestCase):
+    def test_large(self):
+        eprint("Starting test for large data")
+        tm._skip_if_no_sklearn()
+        from sklearn.datasets import load_digits
+        try:
+            from sklearn.model_selection import train_test_split
+        except:
+            from sklearn.cross_validation import train_test_split
+
+        for rows in rowslist:
+            eprint("Creating train data rows=%d cols=%d" % (rows, cols))
+            X, y = make_classification(rows, n_features=cols, random_state=7)
+            rowstest = int(rows * 0.2)
+            eprint("Creating test data rows=%d cols=%d" % (rowstest, cols))
+            # note the new random state: if it matched the train random state, exact methods could memorize and do very well on test even for random data, while hist cannot
+            Xtest, ytest = make_classification(rowstest, n_features=cols, random_state=8)
+
+            eprint("Starting DMatrix(X,y)")
+            ag_dtrain = xgb.DMatrix(X, y)
+            eprint("Starting DMatrix(Xtest,ytest)")
+            ag_dtest = xgb.DMatrix(Xtest, ytest)
+
+            max_depth = 6
+            max_bin = 1024
+
+            # regression test --- hist must be same as exact on all-categorical data
+            ag_param = {'max_depth': max_depth,
+                        'tree_method': 'exact',
+                        #'nthread': 1,
+                        'eta': 1,
+                        'silent': 0,
+                        'objective': 'binary:logistic',
+                        'eval_metric': 'auc'}
+            ag_paramb = {'max_depth': max_depth,
+                         'tree_method': 'hist',
+                         #'nthread': 1,
+                         'eta': 1,
+                         'silent': 0,
+                         'objective': 'binary:logistic',
+                         'eval_metric': 'auc'}
+            ag_param2 = {'max_depth': max_depth,
+                         'tree_method': 'gpu_hist',
+                         'eta': 1,
+                         'silent': 0,
+                         'n_gpus': 1,
+                         'objective': 'binary:logistic',
+                         'max_bin': max_bin,
+                         'eval_metric': 'auc'}
+            ag_param3 = {'max_depth': max_depth,
+                         'tree_method': 'gpu_hist',
+                         'eta': 1,
+                         'silent': 0,
+                         'n_gpus': -1,
+                         'objective': 'binary:logistic',
+                         'max_bin': max_bin,
+                         'eval_metric': 'auc'}
+            #ag_param4 = {'max_depth': max_depth,
+            #             'tree_method': 'gpu_exact',
+            #             'eta': 1,
+            #             'silent': 0,
+            #             'n_gpus': 1,
+            #             'objective': 'binary:logistic',
+            #             'max_bin': max_bin,
+            #             'eval_metric': 'auc'}
+            ag_res = {}
+            ag_resb = {}
+            ag_res2 = {}
+            ag_res3 = {}
+            #ag_res4 = {}
+
+            num_rounds = 1
+
+            eprint("normal updater")
+            xgb.train(ag_param, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
+                      evals_result=ag_res)
+            eprint("hist updater")
+            xgb.train(ag_paramb, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
+                      evals_result=ag_resb)
+            eprint("gpu_hist updater 1 gpu")
+            xgb.train(ag_param2, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
+                      evals_result=ag_res2)
+            eprint("gpu_hist updater all gpus")
+            xgb.train(ag_param3, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
+                      evals_result=ag_res3)
+            #eprint("gpu_exact updater")
+            #xgb.train(ag_param4, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
+            #          evals_result=ag_res4)
+
+            assert np.fabs(ag_res['train']['auc'][0] - ag_resb['train']['auc'][0]) < 0.001
+            assert np.fabs(ag_res['train']['auc'][0] - ag_res2['train']['auc'][0]) < 0.001
+            assert np.fabs(ag_res['train']['auc'][0] - ag_res3['train']['auc'][0]) < 0.001
+            #assert np.fabs(ag_res['train']['auc'][0] - ag_res4['train']['auc'][0]) < 0.001
+
+            assert np.fabs(ag_res['test']['auc'][0] - ag_resb['test']['auc'][0]) < 0.01
+            assert np.fabs(ag_res['test']['auc'][0] - ag_res2['test']['auc'][0]) < 0.01
+            assert np.fabs(ag_res['test']['auc'][0] - ag_res3['test']['auc'][0]) < 0.01
+            #assert np.fabs(ag_res['test']['auc'][0] - ag_res4['test']['auc'][0]) < 0.01
@@ -57,7 +57,7 @@ class Column {
   ColumnType type;
   const T* index;
   uint32_t index_base;
-  const uint32_t* row_ind;
+  const size_t* row_ind;
   size_t len;
 };

@@ -66,8 +66,8 @@ class Column {
 class ColumnMatrix {
  public:
   // get number of features
-  inline uint32_t GetNumFeature() const {
-    return type_.size();
+  inline bst_uint GetNumFeature() const {
+    return static_cast<bst_uint>(type_.size());
   }

   // construct column matrix from GHistIndexMatrix
@@ -78,8 +78,8 @@ class ColumnMatrix {
        slot of internal buffer. */
     packing_factor_ = sizeof(uint32_t) / static_cast<size_t>(this->dtype);

-    const uint32_t nfeature = gmat.cut->row_ptr.size() - 1;
-    const omp_ulong nrow = static_cast<omp_ulong>(gmat.row_ptr.size() - 1);
+    const bst_uint nfeature = static_cast<bst_uint>(gmat.cut->row_ptr.size() - 1);
+    const size_t nrow = gmat.row_ptr.size() - 1;

     // identify type of each column
     feature_counts_.resize(nfeature);
@@ -90,13 +90,13 @@ class ColumnMatrix {
     XGBOOST_TYPE_SWITCH(this->dtype, {
       max_val = static_cast<uint32_t>(std::numeric_limits<DType>::max());
     });
-    for (uint32_t fid = 0; fid < nfeature; ++fid) {
+    for (bst_uint fid = 0; fid < nfeature; ++fid) {
       CHECK_LE(gmat.cut->row_ptr[fid + 1] - gmat.cut->row_ptr[fid], max_val);
     }

     gmat.GetFeatureCounts(&feature_counts_[0]);
     // classify features
-    for (uint32_t fid = 0; fid < nfeature; ++fid) {
+    for (bst_uint fid = 0; fid < nfeature; ++fid) {
       if (static_cast<double>(feature_counts_[fid])
           < param.sparse_threshold * nrow) {
         type_[fid] = kSparseColumn;
@@ -108,13 +108,13 @@ class ColumnMatrix {
     // want to compute storage boundary for each feature
     // using variants of prefix sum scan
     boundary_.resize(nfeature);
-    bst_uint accum_index_ = 0;
-    bst_uint accum_row_ind_ = 0;
-    for (uint32_t fid = 0; fid < nfeature; ++fid) {
+    size_t accum_index_ = 0;
+    size_t accum_row_ind_ = 0;
+    for (bst_uint fid = 0; fid < nfeature; ++fid) {
       boundary_[fid].index_begin = accum_index_;
       boundary_[fid].row_ind_begin = accum_row_ind_;
       if (type_[fid] == kDenseColumn) {
-        accum_index_ += nrow;
+        accum_index_ += static_cast<size_t>(nrow);
       } else {
         accum_index_ += feature_counts_[fid];
         accum_row_ind_ += feature_counts_[fid];
@@ -129,14 +129,14 @@ class ColumnMatrix {

     // store least bin id for each feature
     index_base_.resize(nfeature);
-    for (uint32_t fid = 0; fid < nfeature; ++fid) {
+    for (bst_uint fid = 0; fid < nfeature; ++fid) {
       index_base_[fid] = gmat.cut->row_ptr[fid];
     }

-    // fill index_ for dense columns
-    for (uint32_t fid = 0; fid < nfeature; ++fid) {
+    // pre-fill index_ for dense columns
+    for (bst_uint fid = 0; fid < nfeature; ++fid) {
       if (type_[fid] == kDenseColumn) {
-        const uint32_t ibegin = boundary_[fid].index_begin;
+        const size_t ibegin = boundary_[fid].index_begin;
         XGBOOST_TYPE_SWITCH(this->dtype, {
           const size_t block_offset = ibegin / packing_factor_;
           const size_t elem_offset = ibegin % packing_factor_;
@@ -150,15 +150,15 @@ class ColumnMatrix {

     // loop over all rows and fill column entries
     // num_nonzeros[fid] = how many nonzeros have this feature accumulated so far?
-    std::vector<uint32_t> num_nonzeros;
+    std::vector<size_t> num_nonzeros;
     num_nonzeros.resize(nfeature);
     std::fill(num_nonzeros.begin(), num_nonzeros.end(), 0);
-    for (uint32_t rid = 0; rid < nrow; ++rid) {
-      const size_t ibegin = static_cast<size_t>(gmat.row_ptr[rid]);
-      const size_t iend = static_cast<size_t>(gmat.row_ptr[rid + 1]);
+    for (size_t rid = 0; rid < nrow; ++rid) {
+      const size_t ibegin = gmat.row_ptr[rid];
+      const size_t iend = gmat.row_ptr[rid + 1];
       size_t fid = 0;
       for (size_t i = ibegin; i < iend; ++i) {
-        const size_t bin_id = gmat.index[i];
+        const uint32_t bin_id = gmat.index[i];
         while (bin_id >= gmat.cut->row_ptr[fid + 1]) {
           ++fid;
         }
@@ -167,14 +167,14 @@ class ColumnMatrix {
           const size_t block_offset = boundary_[fid].index_begin / packing_factor_;
           const size_t elem_offset = boundary_[fid].index_begin % packing_factor_;
           DType* begin = reinterpret_cast<DType*>(&index_[block_offset]) + elem_offset;
-          begin[rid] = bin_id - index_base_[fid];
+          begin[rid] = static_cast<DType>(bin_id - index_base_[fid]);
         });
       } else {
         XGBOOST_TYPE_SWITCH(this->dtype, {
           const size_t block_offset = boundary_[fid].index_begin / packing_factor_;
           const size_t elem_offset = boundary_[fid].index_begin % packing_factor_;
           DType* begin = reinterpret_cast<DType*>(&index_[block_offset]) + elem_offset;
-          begin[num_nonzeros[fid]] = bin_id - index_base_[fid];
+          begin[num_nonzeros[fid]] = static_cast<DType>(bin_id - index_base_[fid]);
         });
         row_ind_[boundary_[fid].row_ind_begin + num_nonzeros[fid]] = rid;
         ++num_nonzeros[fid];
@@ -213,16 +213,16 @@ class ColumnMatrix {
   // indicate where each column's index and row_ind is stored.
   // index_begin and index_end are logical offsets, so they should be converted to
   // actual offsets by scaling with packing_factor_
-    unsigned index_begin;
-    unsigned index_end;
-    unsigned row_ind_begin;
-    unsigned row_ind_end;
+    size_t index_begin;
+    size_t index_end;
+    size_t row_ind_begin;
+    size_t row_ind_end;
   };

-  std::vector<bst_uint> feature_counts_;
+  std::vector<size_t> feature_counts_;
   std::vector<ColumnType> type_;
   std::vector<uint32_t> index_;  // index_: may store smaller integers; needs padding
-  std::vector<uint32_t> row_ind_;
+  std::vector<size_t> row_ind_;
   std::vector<ColumnBoundary> boundary_;

   size_t packing_factor_;  // how many integers are stored in each slot of index_

@@ -46,11 +46,11 @@ static int SymbolBits(int num_symbols) {

 class CompressedBufferWriter {
  private:
-  int symbol_bits_;
+  size_t symbol_bits_;
   size_t offset_;

  public:
-  explicit CompressedBufferWriter(int num_symbols) : offset_(0) {
+  explicit CompressedBufferWriter(size_t num_symbols) : offset_(0) {
     symbol_bits_ = detail::SymbolBits(num_symbols);
   }
@@ -70,9 +70,9 @@ class CompressedBufferWriter {
   * \return The calculated buffer size.
   */

-  static size_t CalculateBufferSize(int num_elements, int num_symbols) {
+  static size_t CalculateBufferSize(size_t num_elements, size_t num_symbols) {
     const int bits_per_byte = 8;
-    int compressed_size = std::ceil(
+    size_t compressed_size = std::ceil(
         static_cast<double>(detail::SymbolBits(num_symbols) * num_elements) /
         bits_per_byte);
     return compressed_size + detail::padding;
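CalculateBufferSize now takes size_t so the bits-needed product can exceed 32 bits. A sketch of the same arithmetic with assumed values:

#include <cmath>
#include <cstdio>

int main() {
  const size_t num_elements = 5000000000ULL;  // > 2^32 matrix entries
  const size_t symbol_bits = 10;              // e.g. 1024 distinct bin ids
  const int bits_per_byte = 8;
  // In 64-bit arithmetic the product is exact; a 32-bit int would wrap.
  const size_t compressed_size = static_cast<size_t>(std::ceil(
      static_cast<double>(symbol_bits * num_elements) / bits_per_byte));
  std::printf("%zu bytes\n", compressed_size);  // 6250000000
  return 0;
}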
@@ -82,10 +82,10 @@ class CompressedBufferWriter {
   void WriteSymbol(compressed_byte_t *buffer, T symbol, size_t offset) {
     const int bits_per_byte = 8;

-    for (int i = 0; i < symbol_bits_; i++) {
+    for (size_t i = 0; i < symbol_bits_; i++) {
       size_t byte_idx = ((offset + 1) * symbol_bits_ - (i + 1)) / bits_per_byte;
       byte_idx += detail::padding;
-      int bit_idx =
+      size_t bit_idx =
           ((bits_per_byte + i) - ((offset + 1) * symbol_bits_)) % bits_per_byte;

       if (detail::CheckBit(symbol, i)) {
@@ -100,14 +100,14 @@ class CompressedBufferWriter {
     uint64_t tmp = 0;
     int stored_bits = 0;
     const int max_stored_bits = 64 - symbol_bits_;
-    int buffer_position = detail::padding;
-    const int num_symbols = input_end - input_begin;
-    for (int i = 0; i < num_symbols; i++) {
+    size_t buffer_position = detail::padding;
+    const size_t num_symbols = input_end - input_begin;
+    for (size_t i = 0; i < num_symbols; i++) {
       typename std::iterator_traits<iter_t>::value_type symbol = input_begin[i];
       if (stored_bits > max_stored_bits) {
         // Eject only full bytes
-        int tmp_bytes = stored_bits / 8;
-        for (int j = 0; j < tmp_bytes; j++) {
+        size_t tmp_bytes = stored_bits / 8;
+        for (size_t j = 0; j < tmp_bytes; j++) {
           buffer[buffer_position] = tmp >> (stored_bits - (j + 1) * 8);
           buffer_position++;
         }
@@ -121,8 +121,8 @@ class CompressedBufferWriter {
       }

       // Eject all bytes
-      int tmp_bytes = std::ceil(static_cast<float>(stored_bits) / 8);
-      for (int j = 0; j < tmp_bytes; j++) {
+      size_t tmp_bytes = std::ceil(static_cast<float>(stored_bits) / 8);
+      for (size_t j = 0; j < tmp_bytes; j++) {
        int shift_bits = stored_bits - (j + 1) * 8;
        if (shift_bits >= 0) {
          buffer[buffer_position] = tmp >> shift_bits;
@@ -159,7 +159,7 @@ class CompressedIterator {
   /// iterator can point to
  private:
   compressed_byte_t *buffer_;
-  int symbol_bits_;
+  size_t symbol_bits_;
   size_t offset_;

  public:
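For orientation, a self-contained sketch of reading fixed-width bit-packed symbols, the general scheme CompressedIterator decodes. This is a simplified MSB-first layout with no header padding, not the library's exact bit layout:

#include <cstdint>
#include <cstdio>
#include <vector>

// Read the i-th symbol of width `bits` from an MSB-first packed buffer.
uint32_t ReadSymbol(const std::vector<uint8_t>& buf, size_t bits, size_t i) {
  uint32_t out = 0;
  for (size_t b = 0; b < bits; ++b) {
    size_t bit_pos = i * bits + b;  // global bit offset into the stream
    uint8_t byte = buf[bit_pos / 8];
    uint8_t bit = (byte >> (7 - bit_pos % 8)) & 1;
    out = (out << 1) | bit;
  }
  return out;
}

int main() {
  // Two 10-bit symbols, 600 and 1023, packed into 3 bytes (4 pad bits).
  std::vector<uint8_t> buf = {0b10010110, 0b00111111, 0b11110000};
  std::printf("%u %u\n", ReadSymbol(buf, 10, 0), ReadSymbol(buf, 10, 1));  // 600 1023
  return 0;
}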
@@ -189,7 +189,7 @@ class CompressedIterator {
     return static_cast<T>(tmp & mask);
   }

-  XGBOOST_DEVICE reference operator[](int idx) const {
+  XGBOOST_DEVICE reference operator[](size_t idx) const {
     self_type offset = (*this);
     offset.offset_ += idx;
     return *offset;
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
|
||||
void HistCutMatrix::Init(DMatrix* p_fmat, size_t max_num_bins) {
|
||||
void HistCutMatrix::Init(DMatrix* p_fmat, uint32_t max_num_bins) {
|
||||
typedef common::WXQuantileSketch<bst_float, bst_float> WXQSketch;
|
||||
const MetaInfo& info = p_fmat->info();
|
||||
|
||||
@ -44,7 +44,7 @@ void HistCutMatrix::Init(DMatrix* p_fmat, size_t max_num_bins) {
|
||||
unsigned begin = std::min(nstep * tid, ncol);
|
||||
unsigned end = std::min(nstep * (tid + 1), ncol);
|
||||
for (size_t i = 0; i < batch.size; ++i) { // NOLINT(*)
|
||||
bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i);
|
||||
size_t ridx = batch.base_rowid + i;
|
||||
RowBatch::Inst inst = batch[i];
|
||||
for (bst_uint j = 0; j < inst.length; ++j) {
|
||||
if (inst[j].index >= begin && inst[j].index < end) {
|
||||
@ -108,7 +108,7 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat) {
|
||||
dmlc::DataIter<RowBatch>* iter = p_fmat->RowIterator();
|
||||
|
||||
const int nthread = omp_get_max_threads();
|
||||
const unsigned nbins = cut->row_ptr.back();
|
||||
const uint32_t nbins = cut->row_ptr.back();
|
||||
hit_count.resize(nbins, 0);
|
||||
hit_count_tloc_.resize(nthread * nbins, 0);
|
||||
|
||||
@ -116,7 +116,7 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat) {
|
||||
row_ptr.push_back(0);
|
||||
while (iter->Next()) {
|
||||
const RowBatch& batch = iter->Value();
|
||||
size_t rbegin = row_ptr.size() - 1;
|
||||
const size_t rbegin = row_ptr.size() - 1;
|
||||
for (size_t i = 0; i < batch.size; ++i) {
|
||||
row_ptr.push_back(batch[i].length + row_ptr.back());
|
||||
}
|
||||
@ -140,7 +140,7 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat) {
|
||||
CHECK(cbegin != cend);
|
||||
auto it = std::upper_bound(cbegin, cend, inst[j].fvalue);
|
||||
if (it == cend) it = cend - 1;
|
||||
unsigned idx = static_cast<unsigned>(it - cut->cut.begin());
|
||||
uint32_t idx = static_cast<uint32_t>(it - cut->cut.begin());
|
||||
index[ibegin + j] = idx;
|
||||
++hit_count_tloc_[tid * nbins + idx];
|
||||
}
|
||||
@ -148,7 +148,7 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat) {
|
||||
}
|
||||
|
||||
#pragma omp parallel for num_threads(nthread) schedule(static)
|
||||
for (omp_ulong idx = 0; idx < nbins; ++idx) {
|
||||
for (bst_omp_uint idx = 0; idx < nbins; ++idx) {
|
||||
for (int tid = 0; tid < nthread; ++tid) {
|
||||
hit_count[idx] += hit_count_tloc_[tid * nbins + idx];
|
||||
}
|
||||
@ -157,10 +157,10 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static unsigned GetConflictCount(const std::vector<bool>& mark,
|
||||
const Column<T>& column,
|
||||
unsigned max_cnt) {
|
||||
unsigned ret = 0;
|
||||
static size_t GetConflictCount(const std::vector<bool>& mark,
|
||||
const Column<T>& column,
|
||||
size_t max_cnt) {
|
||||
size_t ret = 0;
|
||||
if (column.type == xgboost::common::kDenseColumn) {
|
||||
for (size_t i = 0; i < column.len; ++i) {
|
||||
if (column.index[i] != std::numeric_limits<T>::max() && mark[i]) {
|
||||
@ -203,9 +203,9 @@ MarkUsed(std::vector<bool>* p_mark, const Column<T>& column) {
|
||||
template <typename T>
|
||||
inline std::vector<std::vector<unsigned>>
|
||||
FindGroups_(const std::vector<unsigned>& feature_list,
|
||||
const std::vector<bst_uint>& feature_nnz,
|
||||
const std::vector<size_t>& feature_nnz,
|
||||
const ColumnMatrix& colmat,
|
||||
unsigned nrow,
|
||||
size_t nrow,
|
||||
const FastHistParam& param) {
|
||||
/* Goal: Bundle features together that has little or no "overlap", i.e.
|
||||
only a few data points should have nonzero values for
|
||||
@ -214,10 +214,10 @@ FindGroups_(const std::vector<unsigned>& feature_list,
|
||||
|
||||
std::vector<std::vector<unsigned>> groups;
|
||||
std::vector<std::vector<bool>> conflict_marks;
|
||||
std::vector<unsigned> group_nnz;
|
||||
std::vector<unsigned> group_conflict_cnt;
|
||||
const unsigned max_conflict_cnt
|
||||
= static_cast<unsigned>(param.max_conflict_rate * nrow);
|
||||
std::vector<size_t> group_nnz;
|
||||
std::vector<size_t> group_conflict_cnt;
|
||||
const size_t max_conflict_cnt
|
||||
= static_cast<size_t>(param.max_conflict_rate * nrow);
|
||||
|
||||
for (auto fid : feature_list) {
|
||||
const Column<T>& column = colmat.GetColumn<T>(fid);
|
||||
@ -239,8 +239,8 @@ FindGroups_(const std::vector<unsigned>& feature_list,
|
||||
|
||||
// examine each candidate group: is it okay to insert fid?
|
||||
for (auto gid : search_groups) {
|
||||
const unsigned rest_max_cnt = max_conflict_cnt - group_conflict_cnt[gid];
|
||||
const unsigned cnt = GetConflictCount(conflict_marks[gid], column, rest_max_cnt);
|
||||
const size_t rest_max_cnt = max_conflict_cnt - group_conflict_cnt[gid];
|
||||
const size_t cnt = GetConflictCount(conflict_marks[gid], column, rest_max_cnt);
|
||||
if (cnt <= rest_max_cnt) {
|
||||
need_new_group = false;
|
||||
groups[gid].push_back(fid);
|
||||
@ -267,9 +267,9 @@ FindGroups_(const std::vector<unsigned>& feature_list,
|
||||
|
||||
inline std::vector<std::vector<unsigned>>
|
||||
FindGroups(const std::vector<unsigned>& feature_list,
|
||||
const std::vector<bst_uint>& feature_nnz,
|
||||
const std::vector<size_t>& feature_nnz,
|
||||
const ColumnMatrix& colmat,
|
||||
unsigned nrow,
|
||||
size_t nrow,
|
||||
const FastHistParam& param) {
|
||||
XGBOOST_TYPE_SWITCH(colmat.dtype, {
|
||||
return FindGroups_<DType>(feature_list, feature_nnz, colmat, nrow, param);
|
||||
@ -288,11 +288,11 @@ FastFeatureGrouping(const GHistIndexMatrix& gmat,
|
||||
std::iota(feature_list.begin(), feature_list.end(), 0);
|
||||
|
||||
// sort features by nonzero counts, descending order
|
||||
std::vector<bst_uint> feature_nnz(nfeature);
|
||||
std::vector<size_t> feature_nnz(nfeature);
|
||||
std::vector<unsigned> features_by_nnz(feature_list);
|
||||
gmat.GetFeatureCounts(&feature_nnz[0]);
|
||||
std::sort(features_by_nnz.begin(), features_by_nnz.end(),
|
||||
[&feature_nnz](int a, int b) {
|
||||
[&feature_nnz](unsigned a, unsigned b) {
|
||||
return feature_nnz[a] > feature_nnz[b];
|
||||
});
|
||||
|
||||
@ -307,7 +307,7 @@ FastFeatureGrouping(const GHistIndexMatrix& gmat,
|
||||
if (group.size() <= 1 || group.size() >= 5) {
|
||||
ret.push_back(group); // keep singleton groups and large (5+) groups
|
||||
} else {
|
||||
unsigned nnz = 0;
|
||||
size_t nnz = 0;
|
||||
for (auto fid : group) {
|
||||
nnz += feature_nnz[fid];
|
||||
}
|
||||
@@ -338,37 +338,37 @@ void GHistIndexBlockMatrix::Init(const GHistIndexMatrix& gmat,
   cut = gmat.cut;

   const size_t nrow = gmat.row_ptr.size() - 1;
-  const size_t nbins = gmat.cut->row_ptr.back();
+  const uint32_t nbins = gmat.cut->row_ptr.back();

   /* step 1: form feature groups */
   auto groups = FastFeatureGrouping(gmat, colmat, param);
-  const size_t nblock = groups.size();
+  const uint32_t nblock = static_cast<uint32_t>(groups.size());

   /* step 2: build a new CSR matrix for each feature group */
-  std::vector<unsigned> bin2block(nbins);  // lookup table [bin id] => [block id]
-  for (size_t group_id = 0; group_id < nblock; ++group_id) {
+  std::vector<uint32_t> bin2block(nbins);  // lookup table [bin id] => [block id]
+  for (uint32_t group_id = 0; group_id < nblock; ++group_id) {
     for (auto& fid : groups[group_id]) {
-      const unsigned bin_begin = gmat.cut->row_ptr[fid];
-      const unsigned bin_end = gmat.cut->row_ptr[fid + 1];
-      for (unsigned bin_id = bin_begin; bin_id < bin_end; ++bin_id) {
+      const uint32_t bin_begin = gmat.cut->row_ptr[fid];
+      const uint32_t bin_end = gmat.cut->row_ptr[fid + 1];
+      for (uint32_t bin_id = bin_begin; bin_id < bin_end; ++bin_id) {
         bin2block[bin_id] = group_id;
       }
     }
   }
-  std::vector<std::vector<unsigned>> index_temp(nblock);
-  std::vector<std::vector<unsigned>> row_ptr_temp(nblock);
-  for (size_t block_id = 0; block_id < nblock; ++block_id) {
+  std::vector<std::vector<uint32_t>> index_temp(nblock);
+  std::vector<std::vector<size_t>> row_ptr_temp(nblock);
+  for (uint32_t block_id = 0; block_id < nblock; ++block_id) {
     row_ptr_temp[block_id].push_back(0);
   }
   for (size_t rid = 0; rid < nrow; ++rid) {
-    const size_t ibegin = static_cast<size_t>(gmat.row_ptr[rid]);
-    const size_t iend = static_cast<size_t>(gmat.row_ptr[rid + 1]);
+    const size_t ibegin = gmat.row_ptr[rid];
+    const size_t iend = gmat.row_ptr[rid + 1];
     for (size_t j = ibegin; j < iend; ++j) {
-      const size_t bin_id = gmat.index[j];
-      const size_t block_id = bin2block[bin_id];
+      const uint32_t bin_id = gmat.index[j];
+      const uint32_t block_id = bin2block[bin_id];
       index_temp[block_id].push_back(bin_id);
     }
-    for (size_t block_id = 0; block_id < nblock; ++block_id) {
+    for (uint32_t block_id = 0; block_id < nblock; ++block_id) {
       row_ptr_temp[block_id].push_back(index_temp[block_id].size());
     }
   }
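A toy version of the bin2block table built in this hunk (sizes assumed): each feature's contiguous bin range maps to the group/block that owns the feature.

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // row_ptr over bins per feature: feature f owns bins [row_ptr[f], row_ptr[f+1]).
  std::vector<uint32_t> row_ptr = {0, 3, 5, 9};
  std::vector<std::vector<unsigned>> groups = {{0, 2}, {1}};  // feature groups
  std::vector<uint32_t> bin2block(row_ptr.back());
  for (uint32_t g = 0; g < groups.size(); ++g)
    for (auto fid : groups[g])
      for (uint32_t b = row_ptr[fid]; b < row_ptr[fid + 1]; ++b)
        bin2block[b] = g;
  for (uint32_t b = 0; b < bin2block.size(); ++b)
    std::printf("bin %u -> block %u\n", b, bin2block[b]);  // bins 0-2,5-8 -> 0; 3-4 -> 1
  return 0;
}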
@@ -378,7 +378,7 @@ void GHistIndexBlockMatrix::Init(const GHistIndexMatrix& gmat,
   std::vector<size_t> row_ptr_blk_ptr;
   index_blk_ptr.push_back(0);
   row_ptr_blk_ptr.push_back(0);
-  for (size_t block_id = 0; block_id < nblock; ++block_id) {
+  for (uint32_t block_id = 0; block_id < nblock; ++block_id) {
     index.insert(index.end(), index_temp[block_id].begin(), index_temp[block_id].end());
     row_ptr.insert(row_ptr.end(), row_ptr_temp[block_id].begin(), row_ptr_temp[block_id].end());
     index_blk_ptr.push_back(index.size());
@@ -386,7 +386,7 @@ void GHistIndexBlockMatrix::Init(const GHistIndexMatrix& gmat,
   }

   // save shortcut for each block
-  for (size_t block_id = 0; block_id < nblock; ++block_id) {
+  for (uint32_t block_id = 0; block_id < nblock; ++block_id) {
     Block blk;
     blk.index_begin = &index[index_blk_ptr[block_id]];
     blk.row_ptr_begin = &row_ptr[row_ptr_blk_ptr[block_id]];
@@ -406,14 +406,14 @@ void GHistBuilder::BuildHist(const std::vector<bst_gpair>& gpair,

   const int K = 8;  // loop unrolling factor
   const bst_omp_uint nthread = static_cast<bst_omp_uint>(this->nthread_);
-  const bst_omp_uint nrows = row_indices.end - row_indices.begin;
-  const bst_omp_uint rest = nrows % K;
+  const size_t nrows = row_indices.end - row_indices.begin;
+  const size_t rest = nrows % K;

 #pragma omp parallel for num_threads(nthread) schedule(guided)
   for (bst_omp_uint i = 0; i < nrows - rest; i += K) {
     const bst_omp_uint tid = omp_get_thread_num();
     const size_t off = tid * nbins_;
-    bst_uint rid[K];
+    size_t rid[K];
     size_t ibegin[K];
     size_t iend[K];
     bst_gpair stat[K];
@@ -421,32 +421,32 @@ void GHistBuilder::BuildHist(const std::vector<bst_gpair>& gpair,
       rid[k] = row_indices.begin[i + k];
     }
     for (int k = 0; k < K; ++k) {
-      ibegin[k] = static_cast<size_t>(gmat.row_ptr[rid[k]]);
-      iend[k] = static_cast<size_t>(gmat.row_ptr[rid[k] + 1]);
+      ibegin[k] = gmat.row_ptr[rid[k]];
+      iend[k] = gmat.row_ptr[rid[k] + 1];
     }
     for (int k = 0; k < K; ++k) {
       stat[k] = gpair[rid[k]];
     }
     for (int k = 0; k < K; ++k) {
       for (size_t j = ibegin[k]; j < iend[k]; ++j) {
-        const size_t bin = gmat.index[j];
+        const uint32_t bin = gmat.index[j];
         data_[off + bin].Add(stat[k]);
       }
     }
   }
   for (bst_omp_uint i = nrows - rest; i < nrows; ++i) {
-    const bst_uint rid = row_indices.begin[i];
-    const size_t ibegin = static_cast<size_t>(gmat.row_ptr[rid]);
-    const size_t iend = static_cast<size_t>(gmat.row_ptr[rid + 1]);
+    const size_t rid = row_indices.begin[i];
+    const size_t ibegin = gmat.row_ptr[rid];
+    const size_t iend = gmat.row_ptr[rid + 1];
     const bst_gpair stat = gpair[rid];
     for (size_t j = ibegin; j < iend; ++j) {
-      const size_t bin = gmat.index[j];
+      const uint32_t bin = gmat.index[j];
       data_[bin].Add(stat);
     }
   }

   /* reduction */
-  const bst_omp_uint nbins = static_cast<bst_omp_uint>(nbins_);
+  const uint32_t nbins = nbins_;
 #pragma omp parallel for num_threads(nthread) schedule(static)
   for (bst_omp_uint bin_id = 0; bin_id < nbins; ++bin_id) {
     for (bst_omp_uint tid = 0; tid < nthread; ++tid) {
@@ -462,16 +462,16 @@ void GHistBuilder::BuildBlockHist(const std::vector<bst_gpair>& gpair,
                                   GHistRow hist) {
   const int K = 8;  // loop unrolling factor
   const bst_omp_uint nthread = static_cast<bst_omp_uint>(this->nthread_);
-  const bst_omp_uint nblock = gmatb.GetNumBlock();
-  const bst_omp_uint nrows = row_indices.end - row_indices.begin;
-  const bst_omp_uint rest = nrows % K;
+  const uint32_t nblock = gmatb.GetNumBlock();
+  const size_t nrows = row_indices.end - row_indices.begin;
+  const size_t rest = nrows % K;

 #pragma omp parallel for num_threads(nthread) schedule(guided)
   for (bst_omp_uint bid = 0; bid < nblock; ++bid) {
     auto gmat = gmatb[bid];

-    for (bst_omp_uint i = 0; i < nrows - rest; i += K) {
-      bst_uint rid[K];
+    for (size_t i = 0; i < nrows - rest; i += K) {
+      size_t rid[K];
       size_t ibegin[K];
       size_t iend[K];
       bst_gpair stat[K];
@@ -479,26 +479,26 @@ void GHistBuilder::BuildBlockHist(const std::vector<bst_gpair>& gpair,
         rid[k] = row_indices.begin[i + k];
       }
       for (int k = 0; k < K; ++k) {
-        ibegin[k] = static_cast<size_t>(gmat.row_ptr[rid[k]]);
-        iend[k] = static_cast<size_t>(gmat.row_ptr[rid[k] + 1]);
+        ibegin[k] = gmat.row_ptr[rid[k]];
+        iend[k] = gmat.row_ptr[rid[k] + 1];
       }
       for (int k = 0; k < K; ++k) {
         stat[k] = gpair[rid[k]];
       }
       for (int k = 0; k < K; ++k) {
         for (size_t j = ibegin[k]; j < iend[k]; ++j) {
-          const size_t bin = gmat.index[j];
+          const uint32_t bin = gmat.index[j];
           hist.begin[bin].Add(stat[k]);
         }
       }
     }
     for (bst_omp_uint i = nrows - rest; i < nrows; ++i) {
-      const bst_uint rid = row_indices.begin[i];
-      const size_t ibegin = static_cast<size_t>(gmat.row_ptr[rid]);
-      const size_t iend = static_cast<size_t>(gmat.row_ptr[rid + 1]);
+      const size_t rid = row_indices.begin[i];
+      const size_t ibegin = gmat.row_ptr[rid];
+      const size_t iend = gmat.row_ptr[rid + 1];
       const bst_gpair stat = gpair[rid];
       for (size_t j = ibegin; j < iend; ++j) {
-        const size_t bin = gmat.index[j];
+        const uint32_t bin = gmat.index[j];
         hist.begin[bin].Add(stat);
       }
     }
@@ -507,9 +507,9 @@ void GHistBuilder::BuildBlockHist(const std::vector<bst_gpair>& gpair,

 void GHistBuilder::SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent) {
   const bst_omp_uint nthread = static_cast<bst_omp_uint>(this->nthread_);
-  const bst_omp_uint nbins = static_cast<bst_omp_uint>(nbins_);
+  const uint32_t nbins = static_cast<bst_omp_uint>(nbins_);
   const int K = 8;  // loop unrolling factor
-  const bst_omp_uint rest = nbins % K;
+  const uint32_t rest = nbins % K;
 #pragma omp parallel for num_threads(nthread) schedule(static)
   for (bst_omp_uint bin_id = 0; bin_id < nbins - rest; bin_id += K) {
     GHistEntry pb[K];
@@ -524,7 +524,7 @@ void GHistBuilder::SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow pa
       self.begin[bin_id + k].SetSubtract(pb[k], sb[k]);
     }
   }
-  for (bst_omp_uint bin_id = nbins - rest; bin_id < nbins; ++bin_id) {
+  for (uint32_t bin_id = nbins - rest; bin_id < nbins; ++bin_id) {
     self.begin[bin_id].SetSubtract(parent.begin[bin_id], sibling.begin[bin_id]);
   }
 }
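The loops above vectorise the standard histogram subtraction trick: one child's histogram is the parent's minus the sibling's, so only one child per split needs a full build. A minimal sketch:

#include <cstdio>
#include <vector>

struct Entry { double sum_grad; double sum_hess; };

int main() {
  const int nbins = 4;
  // value-initialized: all gradient/hessian sums start at zero
  std::vector<Entry> parent(nbins), sibling(nbins), self(nbins);
  parent[2] = {10.0, 5.0};   // stats accumulated at bin 2 in the parent node
  sibling[2] = {4.0, 2.0};   // stats of the child that was actually built
  for (int b = 0; b < nbins; ++b) {
    self[b].sum_grad = parent[b].sum_grad - sibling[b].sum_grad;
    self[b].sum_hess = parent[b].sum_hess - sibling[b].sum_hess;
  }
  std::printf("bin 2: grad=%.1f hess=%.1f\n", self[2].sum_grad, self[2].sum_hess);
  return 0;
}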
@@ -56,30 +56,30 @@ struct HistCutUnit {
   /*! \brief the index pointer of each histunit */
   const bst_float* cut;
   /*! \brief number of cutting point, containing the maximum point */
-  size_t size;
+  uint32_t size;
   // default constructor
   HistCutUnit() {}
   // constructor
-  HistCutUnit(const bst_float* cut, unsigned size)
+  HistCutUnit(const bst_float* cut, uint32_t size)
       : cut(cut), size(size) {}
 };

 /*! \brief cut configuration for all the features */
 struct HistCutMatrix {
-  /*! \brief actual unit pointer */
-  std::vector<unsigned> row_ptr;
+  /*! \brief unit pointer to rows by element position */
+  std::vector<uint32_t> row_ptr;
   /*! \brief minimum value of each feature */
   std::vector<bst_float> min_val;
   /*! \brief the cut field */
   std::vector<bst_float> cut;
   /*! \brief Get histogram bound for fid */
-  inline HistCutUnit operator[](unsigned fid) const {
+  inline HistCutUnit operator[](bst_uint fid) const {
     return HistCutUnit(dmlc::BeginPtr(cut) + row_ptr[fid],
                        row_ptr[fid + 1] - row_ptr[fid]);
   }
   // create histogram cut matrix given statistics from data
   // using approximate quantile sketch approach
-  void Init(DMatrix* p_fmat, size_t max_num_bins);
+  void Init(DMatrix* p_fmat, uint32_t max_num_bins);
 };

@@ -89,11 +89,11 @@ struct HistCutMatrix {
  */
 struct GHistIndexRow {
   /*! \brief The index of the histogram */
-  const unsigned* index;
+  const uint32_t* index;
   /*! \brief The size of the histogram */
-  unsigned size;
+  size_t size;
   GHistIndexRow() {}
-  GHistIndexRow(const unsigned* index, unsigned size)
+  GHistIndexRow(const uint32_t* index, size_t size)
       : index(index), size(size) {}
 };

@@ -103,21 +103,21 @@ struct GHistIndexRow {
  * This is a global histogram index.
  */
 struct GHistIndexMatrix {
-  /*! \brief row pointer */
-  std::vector<unsigned> row_ptr;
+  /*! \brief row pointer to rows by element position */
+  std::vector<size_t> row_ptr;
   /*! \brief The index data */
-  std::vector<unsigned> index;
+  std::vector<uint32_t> index;
   /*! \brief hit count of each index */
-  std::vector<unsigned> hit_count;
+  std::vector<size_t> hit_count;
   /*! \brief The corresponding cuts */
   const HistCutMatrix* cut;
   // Create a global histogram matrix, given cut
   void Init(DMatrix* p_fmat);
   // get i-th row
-  inline GHistIndexRow operator[](bst_uint i) const {
+  inline GHistIndexRow operator[](size_t i) const {
     return GHistIndexRow(&index[0] + row_ptr[i], row_ptr[i + 1] - row_ptr[i]);
   }
-  inline void GetFeatureCounts(bst_uint* counts) const {
+  inline void GetFeatureCounts(size_t* counts) const {
     const unsigned nfeature = cut->row_ptr.size() - 1;
     for (unsigned fid = 0; fid < nfeature; ++fid) {
       const unsigned ibegin = cut->row_ptr[fid];
@@ -129,18 +129,18 @@ struct GHistIndexMatrix {
   }

  private:
-  std::vector<unsigned> hit_count_tloc_;
+  std::vector<size_t> hit_count_tloc_;
 };

 struct GHistIndexBlock {
-  const unsigned* row_ptr;
-  const unsigned* index;
+  const size_t* row_ptr;
+  const uint32_t* index;

-  inline GHistIndexBlock(const unsigned* row_ptr, const unsigned* index)
+  inline GHistIndexBlock(const size_t* row_ptr, const uint32_t* index)
     : row_ptr(row_ptr), index(index) {}

   // get i-th row
-  inline GHistIndexRow operator[](bst_uint i) const {
+  inline GHistIndexRow operator[](size_t i) const {
     return GHistIndexRow(&index[0] + row_ptr[i], row_ptr[i + 1] - row_ptr[i]);
   }
 };
@@ -153,23 +153,23 @@ class GHistIndexBlockMatrix {
             const ColumnMatrix& colmat,
             const FastHistParam& param);

-  inline GHistIndexBlock operator[](bst_uint i) const {
+  inline GHistIndexBlock operator[](size_t i) const {
     return GHistIndexBlock(blocks[i].row_ptr_begin, blocks[i].index_begin);
   }

-  inline unsigned GetNumBlock() const {
+  inline size_t GetNumBlock() const {
     return blocks.size();
   }

  private:
-  std::vector<unsigned> row_ptr;
-  std::vector<unsigned> index;
+  std::vector<size_t> row_ptr;
+  std::vector<uint32_t> index;
   const HistCutMatrix* cut;
   struct Block {
-    const unsigned* row_ptr_begin;
-    const unsigned* row_ptr_end;
-    const unsigned* index_begin;
-    const unsigned* index_end;
+    const size_t* row_ptr_begin;
+    const size_t* row_ptr_end;
+    const uint32_t* index_begin;
+    const uint32_t* index_end;
   };
   std::vector<Block> blocks;
 };
@@ -184,10 +184,10 @@ struct GHistRow {
   /*! \brief base pointer to first entry */
   GHistEntry* begin;
   /*! \brief number of entries */
-  unsigned size;
+  uint32_t size;

   GHistRow() {}
-  GHistRow(GHistEntry* begin, unsigned size)
+  GHistRow(GHistEntry* begin, uint32_t size)
       : begin(begin), size(size) {}
 };
@@ -198,19 +198,19 @@ class HistCollection {
  public:
   // access histogram for i-th node
   inline GHistRow operator[](bst_uint nid) const {
-    const size_t kMax = std::numeric_limits<size_t>::max();
+    const uint32_t kMax = std::numeric_limits<uint32_t>::max();
     CHECK_NE(row_ptr_[nid], kMax);
     return GHistRow(const_cast<GHistEntry*>(dmlc::BeginPtr(data_) + row_ptr_[nid]), nbins_);
   }

   // have we computed a histogram for i-th node?
   inline bool RowExists(bst_uint nid) const {
-    const size_t kMax = std::numeric_limits<size_t>::max();
+    const uint32_t kMax = std::numeric_limits<uint32_t>::max();
     return (nid < row_ptr_.size() && row_ptr_[nid] != kMax);
   }

   // initialize histogram collection
-  inline void Init(size_t nbins) {
+  inline void Init(uint32_t nbins) {
     nbins_ = nbins;
     row_ptr_.clear();
     data_.clear();
@@ -218,7 +218,7 @@ class HistCollection {

   // create an empty histogram for i-th node
   inline void AddHistRow(bst_uint nid) {
-    const size_t kMax = std::numeric_limits<size_t>::max();
+    const uint32_t kMax = std::numeric_limits<uint32_t>::max();
     if (nid >= row_ptr_.size()) {
       row_ptr_.resize(nid + 1, kMax);
     }
@@ -230,12 +230,12 @@ class HistCollection {

  private:
   /*! \brief number of all bins over all features */
-  size_t nbins_;
+  uint32_t nbins_;

   std::vector<GHistEntry> data_;

   /*! \brief row_ptr_[nid] locates bin for historgram of node nid */
-  std::vector<size_t> row_ptr_;
+  std::vector<uint32_t> row_ptr_;
 };

 /*!
@@ -244,7 +244,7 @@ class HistCollection {
 class GHistBuilder {
  public:
   // initialize builder
-  inline void Init(size_t nthread, size_t nbins) {
+  inline void Init(size_t nthread, uint32_t nbins) {
     nthread_ = nthread;
     nbins_ = nbins;
   }
@@ -268,7 +268,7 @@ class GHistBuilder {
   /*! \brief number of threads for parallel computation */
   size_t nthread_;
   /*! \brief number of all bins over all features */
-  size_t nbins_;
+  uint32_t nbins_;
   std::vector<GHistEntry> data_;
 };
@@ -21,14 +21,14 @@ class RowSetCollection {
    * rows (instances) associated with a particular node in a decision
    * tree. */
   struct Elem {
-    const bst_uint* begin;
-    const bst_uint* end;
+    const size_t* begin;
+    const size_t* end;
     int node_id;
       // id of node associated with this instance set; -1 means uninitialized
     Elem(void)
         : begin(nullptr), end(nullptr), node_id(-1) {}
-    Elem(const bst_uint* begin,
-         const bst_uint* end,
+    Elem(const size_t* begin,
+         const size_t* end,
          int node_id)
         : begin(begin), end(end), node_id(node_id) {}

@@ -38,8 +38,8 @@ class RowSetCollection {
   };
   /* \brief specifies how to split a rowset into two */
   struct Split {
-    std::vector<bst_uint> left;
-    std::vector<bst_uint> right;
+    std::vector<size_t> left;
+    std::vector<size_t> right;
   };

   inline std::vector<Elem>::const_iterator begin() const {
@@ -65,8 +65,8 @@ class RowSetCollection {
   // initialize node id 0->everything
   inline void Init() {
     CHECK_EQ(elem_of_each_node_.size(), 0U);
-    const bst_uint* begin = dmlc::BeginPtr(row_indices_);
-    const bst_uint* end = dmlc::BeginPtr(row_indices_) + row_indices_.size();
+    const size_t* begin = dmlc::BeginPtr(row_indices_);
+    const size_t* end = dmlc::BeginPtr(row_indices_) + row_indices_.size();
     elem_of_each_node_.emplace_back(Elem(begin, end, 0));
   }
   // split rowset into two
@@ -77,16 +77,15 @@ class RowSetCollection {
     const Elem e = elem_of_each_node_[node_id];
     const unsigned nthread = row_split_tloc.size();
     CHECK(e.begin != nullptr);
-    bst_uint* all_begin = dmlc::BeginPtr(row_indices_);
-    bst_uint* begin = all_begin + (e.begin - all_begin);
+    size_t* all_begin = dmlc::BeginPtr(row_indices_);
+    size_t* begin = all_begin + (e.begin - all_begin);

-    bst_uint* it = begin;
-    // TODO(hcho3): parallelize this section
+    size_t* it = begin;
     for (bst_omp_uint tid = 0; tid < nthread; ++tid) {
       std::copy(row_split_tloc[tid].left.begin(), row_split_tloc[tid].left.end(), it);
       it += row_split_tloc[tid].left.size();
     }
-    bst_uint* split_pt = it;
+    size_t* split_pt = it;
     for (bst_omp_uint tid = 0; tid < nthread; ++tid) {
       std::copy(row_split_tloc[tid].right.begin(), row_split_tloc[tid].right.end(), it);
       it += row_split_tloc[tid].right.size();
@@ -105,7 +104,7 @@ class RowSetCollection {
   }

   // stores the row indices in the set
-  std::vector<bst_uint> row_indices_;
+  std::vector<size_t> row_indices_;

  private:
   // vector: node_id -> elements
@ -61,7 +61,7 @@ class FastHistMaker: public TreeUpdater {
|
||||
TStats::CheckInfo(dmat->info());
|
||||
if (is_gmat_initialized_ == false) {
|
||||
double tstart = dmlc::GetTime();
|
||||
hmat_.Init(dmat, param.max_bin);
|
||||
hmat_.Init(dmat, static_cast<uint32_t>(param.max_bin));
|
||||
gmat_.cut = &hmat_;
|
||||
gmat_.Init(dmat);
|
||||
column_matrix_.Init(gmat_, fhparam);
|
||||
@ -111,23 +111,6 @@ class FastHistMaker: public TreeUpdater {
|
||||
bool is_gmat_initialized_;
|
||||
|
||||
// data structure
|
||||
/*! \brief per thread x per node entry to store tmp data */
|
||||
struct ThreadEntry {
|
||||
/*! \brief statistics of data */
|
||||
TStats stats;
|
||||
/*! \brief extra statistics of data */
|
||||
TStats stats_extra;
|
||||
/*! \brief last feature value scanned */
|
||||
float last_fvalue;
|
||||
/*! \brief first feature value scanned */
|
||||
float first_fvalue;
|
||||
/*! \brief current best solution */
|
||||
SplitEntry best;
|
||||
// constructor
|
||||
explicit ThreadEntry(const TrainParam& param)
|
||||
: stats(param), stats_extra(param) {
|
||||
}
|
||||
};
|
||||
struct NodeEntry {
|
||||
/*! \brief statics for node entry */
|
||||
TStats stats;
|
||||
@ -340,7 +323,7 @@ class FastHistMaker: public TreeUpdater {
|
||||
}
|
||||
leaf_value = (*p_last_tree_)[nid].leaf_value();
|
||||
|
||||
for (const bst_uint* it = rowset.begin; it < rowset.end; ++it) {
|
||||
for (const size_t* it = rowset.begin; it < rowset.end; ++it) {
|
||||
out_preds[*it] += leaf_value;
|
||||
}
|
||||
}
|
||||
@ -372,7 +355,7 @@ class FastHistMaker: public TreeUpdater {
|
||||
// clear local prediction cache
|
||||
leaf_value_cache_.clear();
|
||||
// initialize histogram collection
|
||||
size_t nbins = gmat.cut->row_ptr.back();
|
||||
uint32_t nbins = gmat.cut->row_ptr.back();
|
||||
hist_.Init(nbins);
|
||||
|
||||
// initialize histogram builder
|
||||
@ -383,18 +366,18 @@ class FastHistMaker: public TreeUpdater {
|
||||
hist_builder_.Init(this->nthread, nbins);
|
||||
|
||||
CHECK_EQ(info.root_index.size(), 0U);
|
||||
std::vector<bst_uint>& row_indices = row_set_collection_.row_indices_;
|
||||
std::vector<size_t>& row_indices = row_set_collection_.row_indices_;
|
||||
// mark subsample and build list of member rows
|
||||
if (param.subsample < 1.0f) {
|
||||
std::bernoulli_distribution coin_flip(param.subsample);
|
||||
auto& rnd = common::GlobalRandom();
|
||||
for (bst_uint i = 0; i < info.num_row; ++i) {
|
||||
for (size_t i = 0; i < info.num_row; ++i) {
|
||||
if (gpair[i].hess >= 0.0f && coin_flip(rnd)) {
|
||||
row_indices.push_back(i);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (bst_uint i = 0; i < info.num_row; ++i) {
|
||||
for (size_t i = 0; i < info.num_row; ++i) {
|
||||
if (gpair[i].hess >= 0.0f) {
|
||||
row_indices.push_back(i);
|
||||
}
|
||||
@ -405,11 +388,11 @@ class FastHistMaker: public TreeUpdater {
|
||||
|
||||
{
|
||||
/* determine layout of data */
|
||||
const auto nrow = info.num_row;
|
||||
const auto ncol = info.num_col;
|
||||
const auto nnz = info.num_nonzero;
|
||||
const size_t nrow = info.num_row;
|
||||
const size_t ncol = info.num_col;
|
||||
const size_t nnz = info.num_nonzero;
|
||||
// number of discrete bins for feature 0
|
||||
const unsigned nbins_f0 = gmat.cut->row_ptr[1] - gmat.cut->row_ptr[0];
|
||||
const uint32_t nbins_f0 = gmat.cut->row_ptr[1] - gmat.cut->row_ptr[0];
|
||||
if (nrow * ncol == nnz) {
|
||||
// dense data with zero-based indexing
|
||||
data_layout_ = kDenseDataZeroBased;
|
||||
@ -427,19 +410,19 @@ class FastHistMaker: public TreeUpdater {
|
||||
// store a pointer to training data
|
||||
p_last_fmat_ = &fmat;
|
||||
// initialize feature index
|
||||
unsigned ncol = static_cast<unsigned>(info.num_col);
|
||||
bst_uint ncol = static_cast<bst_uint>(info.num_col);
|
||||
feat_index.clear();
|
||||
if (data_layout_ == kDenseDataOneBased) {
|
||||
for (unsigned i = 1; i < ncol; ++i) {
|
||||
for (bst_uint i = 1; i < ncol; ++i) {
|
||||
feat_index.push_back(i);
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < ncol; ++i) {
|
||||
for (bst_uint i = 0; i < ncol; ++i) {
|
||||
feat_index.push_back(i);
|
||||
}
|
||||
}
|
||||
unsigned n = std::max(static_cast<unsigned>(1),
|
||||
static_cast<unsigned>(param.colsample_bytree * feat_index.size()));
|
||||
bst_uint n = std::max(static_cast<bst_uint>(1),
|
||||
static_cast<bst_uint>(param.colsample_bytree * feat_index.size()));
|
||||
std::shuffle(feat_index.begin(), feat_index.end(), common::GlobalRandom());
|
||||
CHECK_GT(param.colsample_bytree, 0U)
|
||||
<< "colsample_bytree cannot be zero.";
|
||||
@ -450,11 +433,11 @@ class FastHistMaker: public TreeUpdater {
choose the column that has a least positive number of discrete bins.
For dense data (with no missing value),
the sum of gradient histogram is equal to snode[nid] */
const std::vector<unsigned>& row_ptr = gmat.cut->row_ptr;
const size_t nfeature = row_ptr.size() - 1;
size_t min_nbins_per_feature = 0;
for (size_t i = 0; i < nfeature; ++i) {
const unsigned nbins = row_ptr[i + 1] - row_ptr[i];
const std::vector<uint32_t>& row_ptr = gmat.cut->row_ptr;
const bst_uint nfeature = static_cast<bst_uint>(row_ptr.size() - 1);
uint32_t min_nbins_per_feature = 0;
for (bst_uint i = 0; i < nfeature; ++i) {
const uint32_t nbins = row_ptr[i + 1] - row_ptr[i];
if (nbins > 0) {
if (min_nbins_per_feature == 0 || min_nbins_per_feature > nbins) {
min_nbins_per_feature = nbins;
@ -485,7 +468,7 @@ class FastHistMaker: public TreeUpdater {
const std::vector<bst_uint>& feat_set) {
// start enumeration
const MetaInfo& info = fmat.info();
const bst_omp_uint nfeature = feat_set.size();
const bst_uint nfeature = static_cast<bst_uint>(feat_set.size());
const bst_omp_uint nthread = static_cast<bst_omp_uint>(this->nthread);
best_split_tloc_.resize(nthread);
#pragma omp parallel for schedule(static) num_threads(nthread)
@ -547,13 +530,17 @@ class FastHistMaker: public TreeUpdater {
const bool default_left = (*p_tree)[nid].default_left();
const bst_uint fid = (*p_tree)[nid].split_index();
const bst_float split_pt = (*p_tree)[nid].split_cond();
const bst_uint lower_bound = gmat.cut->row_ptr[fid];
const bst_uint upper_bound = gmat.cut->row_ptr[fid + 1];
bst_int split_cond = -1;
const uint32_t lower_bound = gmat.cut->row_ptr[fid];
const uint32_t upper_bound = gmat.cut->row_ptr[fid + 1];
int32_t split_cond = -1;
// convert floating-point split_pt into corresponding bin_id
// split_cond = -1 indicates that split_pt is less than all known cut points
for (unsigned i = gmat.cut->row_ptr[fid]; i < gmat.cut->row_ptr[fid + 1]; ++i) {
if (split_pt == gmat.cut->cut[i]) split_cond = static_cast<bst_int>(i);
CHECK_LT(upper_bound,
static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
for (uint32_t i = lower_bound; i < upper_bound; ++i) {
if (split_pt == gmat.cut->cut[i]) {
split_cond = static_cast<int32_t>(i);
}
}

const auto& rowset = row_set_collection_[nid];
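The new CHECK_LT guards the narrowing cast: split_cond is signed so it can hold the sentinel -1, and the cast from a uint32_t bin id is only safe while the id fits in int32_t. A standalone sketch of the failure mode it rules out (bin values hypothetical):

#include <cstdint>
#include <cstdio>
#include <limits>

int main() {
  // A bin id that fits int32_t survives the cast unchanged.
  const uint32_t ok_bin = 42;
  const int32_t split_cond = static_cast<int32_t>(ok_bin);
  std::printf("split_cond = %d\n", split_cond);  // prints 42

  // A bin id above INT32_MAX does not: the result is
  // implementation-defined before C++20, negative in practice.
  const uint32_t huge_bin = std::numeric_limits<int32_t>::max() + 1u;
  std::printf("huge bin after cast: %d\n", static_cast<int32_t>(huge_bin));
  return 0;
}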
@ -580,15 +567,15 @@ class FastHistMaker: public TreeUpdater {
bool default_left) {
std::vector<RowSetCollection::Split>& row_split_tloc = *p_row_split_tloc;
const int K = 8; // loop unrolling factor
const bst_omp_uint nrows = rowset.end - rowset.begin;
const bst_omp_uint rest = nrows % K;
const size_t nrows = rowset.end - rowset.begin;
const size_t rest = nrows % K;

#pragma omp parallel for num_threads(nthread) schedule(static)
for (bst_omp_uint i = 0; i < nrows - rest; i += K) {
const bst_uint tid = omp_get_thread_num();
auto& left = row_split_tloc[tid].left;
auto& right = row_split_tloc[tid].right;
bst_uint rid[K];
size_t rid[K];
T rbin[K];
for (int k = 0; k < K; ++k) {
rid[k] = rowset.begin[i + k];
@ -604,7 +591,9 @@ class FastHistMaker: public TreeUpdater {
right.push_back(rid[k]);
}
} else {
if (static_cast<bst_int>(rbin[k] + column.index_base) <= split_cond) {
CHECK_LT(rbin[k] + column.index_base,
static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
if (static_cast<int32_t>(rbin[k] + column.index_base) <= split_cond) {
left.push_back(rid[k]);
} else {
right.push_back(rid[k]);
@ -612,10 +601,10 @@ class FastHistMaker: public TreeUpdater {
}
}
}
for (bst_omp_uint i = nrows - rest; i < nrows; ++i) {
for (size_t i = nrows - rest; i < nrows; ++i) {
auto& left = row_split_tloc[nthread-1].left;
auto& right = row_split_tloc[nthread-1].right;
const bst_uint rid = rowset.begin[i];
const size_t rid = rowset.begin[i];
const T rbin = column.index[rid];
if (rbin == std::numeric_limits<T>::max()) { // missing value
if (default_left) {
@ -624,7 +613,9 @@ class FastHistMaker: public TreeUpdater {
right.push_back(rid);
}
} else {
if (static_cast<bst_int>(rbin + column.index_base) <= split_cond) {
CHECK_LT(rbin + column.index_base,
static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
if (static_cast<int32_t>(rbin + column.index_base) <= split_cond) {
left.push_back(rid);
} else {
right.push_back(rid);
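A minimal sketch of the K = 8 unrolling pattern used in this hunk: the bulk of the row set is processed in fixed-size chunks and a scalar loop handles the remainder (data hypothetical):

#include <cstdio>
#include <vector>

int main() {
  const int K = 8;  // unrolling factor
  std::vector<size_t> rows(21);  // hypothetical row set
  for (size_t i = 0; i < rows.size(); ++i) rows[i] = i;

  const size_t nrows = rows.size();
  const size_t rest = nrows % K;  // 21 % 8 == 5 leftover rows

  size_t sum = 0;
  for (size_t i = 0; i < nrows - rest; i += K) {
    for (int k = 0; k < K; ++k) {  // fixed-count body the compiler can unroll
      sum += rows[i + k];
    }
  }
  for (size_t i = nrows - rest; i < nrows; ++i) {  // remainder loop
    sum += rows[i];
  }
  std::printf("sum = %zu\n", sum);  // 0+1+...+20 = 210
  return 0;
}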
@ -642,13 +633,13 @@ class FastHistMaker: public TreeUpdater {
bool default_left) {
std::vector<RowSetCollection::Split>& row_split_tloc = *p_row_split_tloc;
const int K = 8; // loop unrolling factor
const bst_omp_uint nrows = rowset.end - rowset.begin;
const bst_omp_uint rest = nrows % K;
const size_t nrows = rowset.end - rowset.begin;
const size_t rest = nrows % K;
#pragma omp parallel for num_threads(nthread) schedule(static)
for (bst_omp_uint i = 0; i < nrows - rest; i += K) {
bst_uint rid[K];
size_t rid[K];
GHistIndexRow row[K];
const unsigned* p[K];
const uint32_t* p[K];
bst_uint tid = omp_get_thread_num();
auto& left = row_split_tloc[tid].left;
auto& right = row_split_tloc[tid].right;
@ -663,7 +654,9 @@ class FastHistMaker: public TreeUpdater {
}
for (int k = 0; k < K; ++k) {
if (p[k] != row[k].index + row[k].size && *p[k] < upper_bound) {
if (static_cast<bst_int>(*p[k]) <= split_cond) {
CHECK_LT(*p[k],
static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
if (static_cast<int32_t>(*p[k]) <= split_cond) {
left.push_back(rid[k]);
} else {
right.push_back(rid[k]);
@ -677,14 +670,15 @@ class FastHistMaker: public TreeUpdater {
}
}
}
for (bst_omp_uint i = nrows - rest; i < nrows; ++i) {
const bst_uint rid = rowset.begin[i];
for (size_t i = nrows - rest; i < nrows; ++i) {
const size_t rid = rowset.begin[i];
const auto row = gmat[rid];
const auto p = std::lower_bound(row.index, row.index + row.size, lower_bound);
auto& left = row_split_tloc[0].left;
auto& right = row_split_tloc[0].right;
if (p != row.index + row.size && *p < upper_bound) {
if (static_cast<bst_int>(*p) <= split_cond) {
CHECK_LT(*p, static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
if (static_cast<int32_t>(*p) <= split_cond) {
left.push_back(rid);
} else {
right.push_back(rid);
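A standalone sketch of the std::lower_bound lookup above: each row stores its bin ids in ascending order, and feature fid owns the half-open id range [lower_bound, upper_bound), so one binary search finds the row's entry for that feature, if present (values hypothetical):

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical row: one bin id per present feature, sorted ascending.
  std::vector<uint32_t> row = {3, 17, 42, 90};
  const uint32_t lo = 40, hi = 60;  // bin range owned by feature fid

  // First bin id >= lo; if it is also < hi, the row has feature fid.
  const auto p = std::lower_bound(row.begin(), row.end(), lo);
  if (p != row.end() && *p < hi) {
    std::printf("row has feature value in bin %u\n", *p);  // prints 42
  } else {
    std::printf("feature missing in this row\n");
  }
  return 0;
}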
@ -709,26 +703,26 @@ class FastHistMaker: public TreeUpdater {
bst_int split_cond,
bool default_left) {
std::vector<RowSetCollection::Split>& row_split_tloc = *p_row_split_tloc;
const bst_omp_uint nrows = rowset.end - rowset.begin;
const size_t nrows = rowset.end - rowset.begin;

#pragma omp parallel num_threads(nthread)
{
const bst_uint tid = omp_get_thread_num();
const bst_omp_uint ibegin = tid * nrows / nthread;
const bst_omp_uint iend = (tid + 1) * nrows / nthread;
const size_t tid = static_cast<size_t>(omp_get_thread_num());
const size_t ibegin = tid * nrows / nthread;
const size_t iend = (tid + 1) * nrows / nthread;
if (ibegin < iend) { // ensure that [ibegin, iend) is nonempty range
// search first nonzero row with index >= rowset[ibegin]
const uint32_t* p = std::lower_bound(column.row_ind,
column.row_ind + column.len,
rowset.begin[ibegin]);
const size_t* p = std::lower_bound(column.row_ind,
column.row_ind + column.len,
rowset.begin[ibegin]);

auto& left = row_split_tloc[tid].left;
auto& right = row_split_tloc[tid].right;
if (p != column.row_ind + column.len && *p <= rowset.begin[iend - 1]) {
bst_omp_uint cursor = p - column.row_ind;
size_t cursor = p - column.row_ind;

for (bst_omp_uint i = ibegin; i < iend; ++i) {
const bst_uint rid = rowset.begin[i];
for (size_t i = ibegin; i < iend; ++i) {
const size_t rid = rowset.begin[i];
while (cursor < column.len
&& column.row_ind[cursor] < rid
&& column.row_ind[cursor] <= rowset.begin[iend - 1]) {
@ -736,7 +730,9 @@ class FastHistMaker: public TreeUpdater {
}
if (cursor < column.len && column.row_ind[cursor] == rid) {
const T rbin = column.index[cursor];
if (static_cast<bst_int>(rbin + column.index_base) <= split_cond) {
CHECK_LT(rbin + column.index_base,
static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
if (static_cast<int32_t>(rbin + column.index_base) <= split_cond) {
left.push_back(rid);
} else {
right.push_back(rid);
@ -753,13 +749,13 @@ class FastHistMaker: public TreeUpdater {
}
} else { // all rows in [ibegin, iend) have missing values
if (default_left) {
for (bst_omp_uint i = ibegin; i < iend; ++i) {
const bst_uint rid = rowset.begin[i];
for (size_t i = ibegin; i < iend; ++i) {
const size_t rid = rowset.begin[i];
left.push_back(rid);
}
} else {
for (bst_omp_uint i = ibegin; i < iend; ++i) {
const bst_uint rid = rowset.begin[i];
for (size_t i = ibegin; i < iend; ++i) {
const size_t rid = rowset.begin[i];
right.push_back(rid);
}
}
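A sketch of the static partition used here, thread tid taking rows [tid * nrows / nthread, (tid + 1) * nrows / nthread): the product is formed before the division, so computing it in size_t rather than a 32-bit type is what keeps it from overflowing on large row sets (counts hypothetical):

#include <cstdio>

int main() {
  const size_t n = 10;       // hypothetical row count
  const size_t nthread = 3;
  for (size_t tid = 0; tid < nthread; ++tid) {
    const size_t ibegin = tid * n / nthread;        // product before division
    const size_t iend = (tid + 1) * n / nthread;
    std::printf("thread %zu: rows [%zu, %zu)\n", tid, ibegin, iend);
  }
  // Prints [0,3) [3,6) [6,10): every row covered exactly once.
  return 0;
}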
@ -786,17 +782,17 @@ class FastHistMaker: public TreeUpdater {
For dense data (with no missing value),
the sum of gradient histogram is equal to snode[nid] */
GHistRow hist = hist_[nid];
const std::vector<unsigned>& row_ptr = gmat.cut->row_ptr;
const std::vector<uint32_t>& row_ptr = gmat.cut->row_ptr;

const size_t ibegin = row_ptr[fid_least_bins_];
const size_t iend = row_ptr[fid_least_bins_ + 1];
for (size_t i = ibegin; i < iend; ++i) {
const uint32_t ibegin = row_ptr[fid_least_bins_];
const uint32_t iend = row_ptr[fid_least_bins_ + 1];
for (uint32_t i = ibegin; i < iend; ++i) {
const GHistEntry et = hist.begin[i];
stats.Add(et.sum_grad, et.sum_hess);
}
} else {
const RowSetCollection::Elem e = row_set_collection_[nid];
for (const bst_uint* it = e.begin; it < e.end; ++it) {
for (const size_t* it = e.begin; it < e.end; ++it) {
stats.Add(gpair[*it]);
}
}
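A sketch of the dense shortcut described in the comment above: when no values are missing, every row contributes to exactly one bin of every feature, so summing the histogram of the feature with the fewest bins reproduces the node's gradient statistics without a per-row pass (numbers hypothetical; GHistEntrySketch is a stand-in for the real GHistEntry):

#include <cstdio>

struct GHistEntrySketch { double sum_grad, sum_hess; };

int main() {
  // Hypothetical 3-bin histogram of the least-bin feature for one node.
  GHistEntrySketch hist[3] = {{1.0, 0.5}, {-2.0, 1.5}, {0.5, 1.0}};

  double sum_grad = 0, sum_hess = 0;
  for (int i = 0; i < 3; ++i) {  // walk only this feature's bin range
    sum_grad += hist[i].sum_grad;
    sum_hess += hist[i].sum_hess;
  }
  // Equals the sum over all rows in the node; no per-row pass needed.
  std::printf("node stats: grad=%.1f hess=%.1f\n", sum_grad, sum_hess);
  return 0;
}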
@ -831,7 +827,7 @@ class FastHistMaker: public TreeUpdater {
CHECK(d_step == +1 || d_step == -1);

// aliases
const std::vector<unsigned>& cut_ptr = gmat.cut->row_ptr;
const std::vector<uint32_t>& cut_ptr = gmat.cut->row_ptr;
const std::vector<bst_float>& cut_val = gmat.cut->cut;

// statistics on both sides of split
@ -841,20 +837,25 @@ class FastHistMaker: public TreeUpdater {
SplitEntry best;

// bin boundaries
CHECK_LE(cut_ptr[fid],
static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
CHECK_LE(cut_ptr[fid + 1],
static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
// imin: index (offset) of the minimum value for feature fid
// need this for backward enumeration
const int imin = cut_ptr[fid];
const int32_t imin = static_cast<int32_t>(cut_ptr[fid]);
// ibegin, iend: smallest/largest cut points for feature fid
int ibegin, iend;
// use int to allow for value -1
int32_t ibegin, iend;
if (d_step > 0) {
ibegin = cut_ptr[fid];
iend = cut_ptr[fid + 1];
ibegin = static_cast<int32_t>(cut_ptr[fid]);
iend = static_cast<int32_t>(cut_ptr[fid + 1]);
} else {
ibegin = cut_ptr[fid + 1] - 1;
iend = cut_ptr[fid] - 1;
ibegin = static_cast<int32_t>(cut_ptr[fid + 1]) - 1;
iend = static_cast<int32_t>(cut_ptr[fid]) - 1;
}

for (int i = ibegin; i != iend; i += d_step) {
for (int32_t i = ibegin; i != iend; i += d_step) {
// start working
// try to find a split
e.Add(hist.begin[i].sum_grad, hist.begin[i].sum_hess);
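A sketch of the forward/backward enumeration this hunk adjusts: with d_step = -1 the terminating index iend is one position before the first bin, possibly -1, which is why the indices must be signed int32_t (the CHECK_LEs above keep the casts safe; the bin range below is hypothetical):

#include <cstdint>
#include <cstdio>

int main() {
  const int32_t first = 0, last = 4;  // hypothetical bin range [0, 4)
  const int32_t steps[2] = {+1, -1};
  for (int s = 0; s < 2; ++s) {
    const int32_t d_step = steps[s];
    const int32_t ibegin = (d_step > 0) ? first : last - 1;
    const int32_t iend = (d_step > 0) ? last : first - 1;  // may be -1
    std::printf("d_step=%+d:", d_step);
    for (int32_t i = ibegin; i != iend; i += d_step) {
      std::printf(" %d", i);  // forward: 0 1 2 3, backward: 3 2 1 0
    }
    std::printf("\n");
  }
  return 0;
}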
@ -930,7 +931,7 @@ class FastHistMaker: public TreeUpdater {
HistCollection hist_;
/*! \brief feature with least # of bins. to be used for dense specialization
of InitNewNode() */
size_t fid_least_bins_;
uint32_t fid_least_bins_;
/*! \brief local prediction cache; maps node id to leaf value */
std::vector<float> leaf_value_cache_;

@ -1,5 +1,6 @@
#include "../../../src/common/compressed_iterator.h"
#include "gtest/gtest.h"
#include <algorithm>

namespace xgboost {
namespace common {
@ -51,4 +52,4 @@ TEST(CompressedIterator, Test) {
}

} // namespace common
} // namespace xgboost
} // namespace xgboost