Fix bugs in multithreaded ApplySplitSparseData() (#2161)
* Bugfix 1: Fix segfault in multithreaded ApplySplitSparseData() When there are more threads than rows in rowset, some threads end up with empty ranges, causing them to crash. (iend - 1 needs to be accessible as part of algorithm) Fix: run only those threads with nonempty ranges. * Add regression test for Bugfix 1 * Moving python_omp_test to existing python test group Turns out you don't need to set "OMP_NUM_THREADS" to enable multithreading. Just add nthread parameter. * Bugfix 2: Fix corner case of ApplySplitSparseData() for categorical feature When split value is less than all cut points, split_cond is set incorrectly. Fix: set split_cond = -1 to indicate this scenario * Bugfix 3: Initialize data layout indicator before using it data_layout_ is accessed before being set; this variable determines whether feature 0 is included in feat_set. Fix: re-order code in InitData() to initialize data_layout_ first * Adding regression test for Bugfix 2 Unfortunately, no regression test for Bugfix 3, as there is no way to deterministically assign value to an uninitialized variable.
This commit is contained in:
parent
ed5e75de2f
commit
2715baef64
@ -64,6 +64,7 @@ namespace xgboost {
|
|||||||
* used for feature index and row index.
|
* used for feature index and row index.
|
||||||
*/
|
*/
|
||||||
typedef uint32_t bst_uint;
|
typedef uint32_t bst_uint;
|
||||||
|
typedef int32_t bst_int;
|
||||||
/*! \brief long integers */
|
/*! \brief long integers */
|
||||||
typedef uint64_t bst_ulong; // NOLINT(*)
|
typedef uint64_t bst_ulong; // NOLINT(*)
|
||||||
/*! \brief float type, used for storing statistics */
|
/*! \brief float type, used for storing statistics */
|
||||||
|
|||||||
@ -374,6 +374,24 @@ class FastHistMaker: public TreeUpdater {
|
|||||||
row_set_collection_.Init();
|
row_set_collection_.Init();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
/* determine layout of data */
|
||||||
|
const auto nrow = info.num_row;
|
||||||
|
const auto ncol = info.num_col;
|
||||||
|
const auto nnz = info.num_nonzero;
|
||||||
|
// number of discrete bins for feature 0
|
||||||
|
const unsigned nbins_f0 = gmat.cut->row_ptr[1] - gmat.cut->row_ptr[0];
|
||||||
|
if (nrow * ncol == nnz) {
|
||||||
|
// dense data with zero-based indexing
|
||||||
|
data_layout_ = kDenseDataZeroBased;
|
||||||
|
} else if (nbins_f0 == 0 && nrow * (ncol - 1) == nnz) {
|
||||||
|
// dense data with one-based indexing
|
||||||
|
data_layout_ = kDenseDataOneBased;
|
||||||
|
} else {
|
||||||
|
// sparse data
|
||||||
|
data_layout_ = kSparseData;
|
||||||
|
}
|
||||||
|
}
|
||||||
{
|
{
|
||||||
// store a pointer to the tree
|
// store a pointer to the tree
|
||||||
p_last_tree_ = &tree;
|
p_last_tree_ = &tree;
|
||||||
@ -398,24 +416,6 @@ class FastHistMaker: public TreeUpdater {
|
|||||||
<< " is too small that no feature can be included";
|
<< " is too small that no feature can be included";
|
||||||
feat_index.resize(n);
|
feat_index.resize(n);
|
||||||
}
|
}
|
||||||
{
|
|
||||||
/* determine layout of data */
|
|
||||||
const auto nrow = info.num_row;
|
|
||||||
const auto ncol = info.num_col;
|
|
||||||
const auto nnz = info.num_nonzero;
|
|
||||||
// number of discrete bins for feature 0
|
|
||||||
const unsigned nbins_f0 = gmat.cut->row_ptr[1] - gmat.cut->row_ptr[0];
|
|
||||||
if (nrow * ncol == nnz) {
|
|
||||||
// dense data with zero-based indexing
|
|
||||||
data_layout_ = kDenseDataZeroBased;
|
|
||||||
} else if (nbins_f0 == 0 && nrow * (ncol - 1) == nnz) {
|
|
||||||
// dense data with one-based indexing
|
|
||||||
data_layout_ = kDenseDataOneBased;
|
|
||||||
} else {
|
|
||||||
// sparse data
|
|
||||||
data_layout_ = kSparseData;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) {
|
if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) {
|
||||||
/* specialized code for dense data:
|
/* specialized code for dense data:
|
||||||
choose the column that has a least positive number of discrete bins.
|
choose the column that has a least positive number of discrete bins.
|
||||||
@ -520,11 +520,11 @@ class FastHistMaker: public TreeUpdater {
|
|||||||
const bst_float split_pt = (*p_tree)[nid].split_cond();
|
const bst_float split_pt = (*p_tree)[nid].split_cond();
|
||||||
const bst_uint lower_bound = gmat.cut->row_ptr[fid];
|
const bst_uint lower_bound = gmat.cut->row_ptr[fid];
|
||||||
const bst_uint upper_bound = gmat.cut->row_ptr[fid + 1];
|
const bst_uint upper_bound = gmat.cut->row_ptr[fid + 1];
|
||||||
// set the split condition correctly
|
bst_int split_cond = -1;
|
||||||
bst_uint split_cond = 0;
|
// convert floating-point split_pt into corresponding bin_id
|
||||||
// set the condition
|
// split_cond = -1 indicates that split_pt is less than all known cut points
|
||||||
for (unsigned i = gmat.cut->row_ptr[fid]; i < gmat.cut->row_ptr[fid + 1]; ++i) {
|
for (unsigned i = gmat.cut->row_ptr[fid]; i < gmat.cut->row_ptr[fid + 1]; ++i) {
|
||||||
if (split_pt == gmat.cut->cut[i]) split_cond = i;
|
if (split_pt == gmat.cut->cut[i]) split_cond = static_cast<bst_int>(i);
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto& rowset = row_set_collection_[nid];
|
const auto& rowset = row_set_collection_[nid];
|
||||||
@ -547,7 +547,7 @@ class FastHistMaker: public TreeUpdater {
|
|||||||
const GHistIndexMatrix& gmat,
|
const GHistIndexMatrix& gmat,
|
||||||
std::vector<RowSetCollection::Split>* p_row_split_tloc,
|
std::vector<RowSetCollection::Split>* p_row_split_tloc,
|
||||||
const Column<T>& column,
|
const Column<T>& column,
|
||||||
bst_uint split_cond,
|
bst_int split_cond,
|
||||||
bool default_left) {
|
bool default_left) {
|
||||||
std::vector<RowSetCollection::Split>& row_split_tloc = *p_row_split_tloc;
|
std::vector<RowSetCollection::Split>& row_split_tloc = *p_row_split_tloc;
|
||||||
const int K = 8; // loop unrolling factor
|
const int K = 8; // loop unrolling factor
|
||||||
@ -575,7 +575,7 @@ class FastHistMaker: public TreeUpdater {
|
|||||||
right.push_back(rid[k]);
|
right.push_back(rid[k]);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (rbin[k] + column.index_base <= split_cond) {
|
if (static_cast<bst_int>(rbin[k] + column.index_base) <= split_cond) {
|
||||||
left.push_back(rid[k]);
|
left.push_back(rid[k]);
|
||||||
} else {
|
} else {
|
||||||
right.push_back(rid[k]);
|
right.push_back(rid[k]);
|
||||||
@ -595,7 +595,7 @@ class FastHistMaker: public TreeUpdater {
|
|||||||
right.push_back(rid);
|
right.push_back(rid);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (rbin + column.index_base <= split_cond) {
|
if (static_cast<bst_int>(rbin + column.index_base) <= split_cond) {
|
||||||
left.push_back(rid);
|
left.push_back(rid);
|
||||||
} else {
|
} else {
|
||||||
right.push_back(rid);
|
right.push_back(rid);
|
||||||
@ -609,7 +609,7 @@ class FastHistMaker: public TreeUpdater {
|
|||||||
std::vector<RowSetCollection::Split>* p_row_split_tloc,
|
std::vector<RowSetCollection::Split>* p_row_split_tloc,
|
||||||
bst_uint lower_bound,
|
bst_uint lower_bound,
|
||||||
bst_uint upper_bound,
|
bst_uint upper_bound,
|
||||||
bst_uint split_cond,
|
bst_int split_cond,
|
||||||
bool default_left) {
|
bool default_left) {
|
||||||
std::vector<RowSetCollection::Split>& row_split_tloc = *p_row_split_tloc;
|
std::vector<RowSetCollection::Split>& row_split_tloc = *p_row_split_tloc;
|
||||||
const int K = 8; // loop unrolling factor
|
const int K = 8; // loop unrolling factor
|
||||||
@ -634,7 +634,7 @@ class FastHistMaker: public TreeUpdater {
|
|||||||
}
|
}
|
||||||
for (int k = 0; k < K; ++k) {
|
for (int k = 0; k < K; ++k) {
|
||||||
if (p[k] != row[k].index + row[k].size && *p[k] < upper_bound) {
|
if (p[k] != row[k].index + row[k].size && *p[k] < upper_bound) {
|
||||||
if (*p[k] <= split_cond) {
|
if (static_cast<bst_int>(*p[k]) <= split_cond) {
|
||||||
left.push_back(rid[k]);
|
left.push_back(rid[k]);
|
||||||
} else {
|
} else {
|
||||||
right.push_back(rid[k]);
|
right.push_back(rid[k]);
|
||||||
@ -655,7 +655,7 @@ class FastHistMaker: public TreeUpdater {
|
|||||||
auto& left = row_split_tloc[0].left;
|
auto& left = row_split_tloc[0].left;
|
||||||
auto& right = row_split_tloc[0].right;
|
auto& right = row_split_tloc[0].right;
|
||||||
if (p != row.index + row.size && *p < upper_bound) {
|
if (p != row.index + row.size && *p < upper_bound) {
|
||||||
if (*p <= split_cond) {
|
if (static_cast<bst_int>(*p) <= split_cond) {
|
||||||
left.push_back(rid);
|
left.push_back(rid);
|
||||||
} else {
|
} else {
|
||||||
right.push_back(rid);
|
right.push_back(rid);
|
||||||
@ -677,7 +677,7 @@ class FastHistMaker: public TreeUpdater {
|
|||||||
const Column<T>& column,
|
const Column<T>& column,
|
||||||
bst_uint lower_bound,
|
bst_uint lower_bound,
|
||||||
bst_uint upper_bound,
|
bst_uint upper_bound,
|
||||||
bst_uint split_cond,
|
bst_int split_cond,
|
||||||
bool default_left) {
|
bool default_left) {
|
||||||
std::vector<RowSetCollection::Split>& row_split_tloc = *p_row_split_tloc;
|
std::vector<RowSetCollection::Split>& row_split_tloc = *p_row_split_tloc;
|
||||||
const bst_omp_uint nrows = rowset.end - rowset.begin;
|
const bst_omp_uint nrows = rowset.end - rowset.begin;
|
||||||
@ -687,6 +687,7 @@ class FastHistMaker: public TreeUpdater {
|
|||||||
const bst_uint tid = omp_get_thread_num();
|
const bst_uint tid = omp_get_thread_num();
|
||||||
const bst_omp_uint ibegin = tid * nrows / nthread;
|
const bst_omp_uint ibegin = tid * nrows / nthread;
|
||||||
const bst_omp_uint iend = (tid + 1) * nrows / nthread;
|
const bst_omp_uint iend = (tid + 1) * nrows / nthread;
|
||||||
|
if (ibegin < iend) { // ensure that [ibegin, iend) is nonempty range
|
||||||
// search first nonzero row with index >= rowset[ibegin]
|
// search first nonzero row with index >= rowset[ibegin]
|
||||||
const uint32_t* p = std::lower_bound(column.row_ind,
|
const uint32_t* p = std::lower_bound(column.row_ind,
|
||||||
column.row_ind + column.len,
|
column.row_ind + column.len,
|
||||||
@ -706,7 +707,7 @@ class FastHistMaker: public TreeUpdater {
|
|||||||
}
|
}
|
||||||
if (cursor < column.len && column.row_ind[cursor] == rid) {
|
if (cursor < column.len && column.row_ind[cursor] == rid) {
|
||||||
const T rbin = column.index[cursor];
|
const T rbin = column.index[cursor];
|
||||||
if (rbin + column.index_base <= split_cond) {
|
if (static_cast<bst_int>(rbin + column.index_base) <= split_cond) {
|
||||||
left.push_back(rid);
|
left.push_back(rid);
|
||||||
} else {
|
} else {
|
||||||
right.push_back(rid);
|
right.push_back(rid);
|
||||||
@ -736,6 +737,7 @@ class FastHistMaker: public TreeUpdater {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
inline void InitNewNode(int nid,
|
inline void InitNewNode(int nid,
|
||||||
const GHistIndexMatrix& gmat,
|
const GHistIndexMatrix& gmat,
|
||||||
|
|||||||
43
tests/python/test_openmp.py
Normal file
43
tests/python/test_openmp.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from scipy.sparse import csr_matrix
|
||||||
|
import xgboost as xgb
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
|
||||||
|
class TestOMP(unittest.TestCase):
|
||||||
|
def test_omp(self):
|
||||||
|
# a contrived example where one node has an instance set of size 2.
|
||||||
|
data = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
|
||||||
|
indices = [2, 1, 1, 2, 0, 0, 2, 0, 1, 3]
|
||||||
|
indptr = [0, 1, 2, 4, 5, 7, 9, 10]
|
||||||
|
A = csr_matrix((data, indices, indptr), shape=(7, 4))
|
||||||
|
y = [1, 1, 0, 0, 0, 1, 1]
|
||||||
|
dtrain = xgb.DMatrix(A, label=y)
|
||||||
|
|
||||||
|
# 1. use 3 threads to train a tree with an instance set of size 2
|
||||||
|
param = {'booster': 'gbtree',
|
||||||
|
'objective': 'binary:logistic',
|
||||||
|
'grow_policy': 'lossguide',
|
||||||
|
'tree_method': 'hist',
|
||||||
|
'eval_metric': 'auc',
|
||||||
|
'max_depth': 0,
|
||||||
|
'max_leaves': 1024,
|
||||||
|
'min_child_weight': 0,
|
||||||
|
'nthread': 3}
|
||||||
|
|
||||||
|
watchlist = [(dtrain, 'train')]
|
||||||
|
num_round = 1
|
||||||
|
res = {}
|
||||||
|
xgb.train(param, dtrain, num_round, watchlist, evals_result=res)
|
||||||
|
assert res['train']['auc'][-1] > 0.99
|
||||||
|
|
||||||
|
# 2. vary number of threads and test whether you get the same result
|
||||||
|
param['nthread'] = 1
|
||||||
|
res2 = {}
|
||||||
|
xgb.train(param, dtrain, num_round, watchlist, evals_result=res2)
|
||||||
|
assert res['train']['auc'][-1] == res2['train']['auc'][-1]
|
||||||
|
|
||||||
|
param['nthread'] = 2
|
||||||
|
res3 = {}
|
||||||
|
xgb.train(param, dtrain, num_round, watchlist, evals_result=res3)
|
||||||
|
assert res['train']['auc'][-1] == res3['train']['auc'][-1]
|
||||||
Loading…
x
Reference in New Issue
Block a user