Multi-threaded XGDMatrixCreateFromMat for faster DMatrix creation (#2530)
* Multi-threaded XGDMatrixCreateFromMat for faster DMatrix creation from numpy arrays in the Python interface: adds an XGDMatrixCreateFromMat_omp C API entry point and an optional nthread argument to the Python DMatrix constructor.
parent 56550ff3f1
commit 6b375f6ad8
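For orientation, a minimal sketch of how the new argument is used from the Python interface (assuming a build that includes this commit; array sizes are arbitrary):

import numpy as np
import xgboost as xgb

X = np.random.rand(100000, 50)
y = np.random.randint(2, size=100000)

dtrain = xgb.DMatrix(X, y)              # unchanged single-threaded path
dtrain = xgb.DMatrix(X, y, nthread=4)   # new multi-threaded path, 4 threads
dtrain = xgb.DMatrix(X, y, nthread=-1)  # <= 0 means use all available cores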
@@ -206,6 +206,22 @@ XGB_DLL int XGDMatrixCreateFromMat(const float *data,
                                    bst_ulong ncol,
                                    float missing,
                                    DMatrixHandle *out);
+/*!
+ * \brief create matrix content from dense matrix
+ * \param data pointer to the data space
+ * \param nrow number of rows
+ * \param ncol number of columns
+ * \param missing which value to represent missing value
+ * \param out created dmatrix
+ * \param nthread number of threads (up to maximum cores available, if <=0 use all cores)
+ * \return 0 when success, -1 when failure happens
+ */
+XGB_DLL int XGDMatrixCreateFromMat_omp(const float *data,
+                                       bst_ulong nrow,
+                                       bst_ulong ncol,
+                                       float missing,
+                                       DMatrixHandle *out,
+                                       int nthread);
 /*!
  * \brief create a new dmatrix from sliced content of existing matrix
  * \param handle instance of data matrix to be sliced
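For readers wiring this up by hand, a sketch of invoking the new entry point directly through ctypes (assumes a libxgboost shared library built from this commit is on the loader path; the Python wrapper changes later in this diff do the same thing):

import ctypes
import numpy as np

lib = ctypes.cdll.LoadLibrary("libxgboost.so")  # library name is platform-dependent
mat = np.ascontiguousarray(np.random.rand(1000, 50), dtype=np.float32)
handle = ctypes.c_void_p()
ret = lib.XGDMatrixCreateFromMat_omp(
    mat.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
    ctypes.c_uint64(mat.shape[0]),   # nrow; bst_ulong is a 64-bit unsigned int
    ctypes.c_uint64(mat.shape[1]),   # ncol
    ctypes.c_float(float('nan')),    # missing: NaN marks missing entries
    ctypes.byref(handle),
    ctypes.c_int(0))                 # nthread <= 0: use all available cores
assert ret == 0                      # 0 on success, -1 on failure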
@@ -5,13 +5,18 @@ import numpy as np
 from sklearn.datasets import make_classification
 import time
 
 n = 1000000
 num_rounds = 500
 
 def run_benchmark(args, gpu_algorithm, cpu_algorithm):
     print("Generating dataset: {} rows * {} columns".format(args.rows,args.columns))
     tmp = time.time()
     X, y = make_classification(args.rows, n_features=args.columns, random_state=7)
-    dtrain = xgb.DMatrix(X, y)
     print ("Generate Time: %s seconds" % (str(time.time() - tmp)))
+    tmp = time.time()
+    print ("DMatrix Start")
+    # omp way
+    dtrain = xgb.DMatrix(X, y, nthread=-1)
+    # non-omp way
+    #dtrain = xgb.DMatrix(X, y)
+    print ("DMatrix Time: %s seconds" % (str(time.time() - tmp)))
 
     param = {'objective': 'binary:logistic',
              'max_depth': 6,
@@ -24,7 +29,7 @@ def run_benchmark(args, gpu_algorithm, cpu_algorithm):
     print("Training with '%s'" % param['tree_method'])
     tmp = time.time()
     xgb.train(param, dtrain, args.iterations)
-    print ("Time: %s seconds" % (str(time.time() - tmp)))
+    print ("Train Time: %s seconds" % (str(time.time() - tmp)))
 
     param['silent'] = 1
     param['tree_method'] = cpu_algorithm
@@ -343,8 +343,6 @@ void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
 }
 
 void GPUHistBuilder::BuildHist(int depth) {
-  // dh::Timer time;
-
   for (int d_idx = 0; d_idx < n_devices; d_idx++) {
     int device_idx = dList[d_idx];
     size_t begin = device_element_segments[d_idx];
@@ -1070,9 +1068,9 @@ void GPUHistBuilder::Update(const std::vector<bst_gpair>& gpair,
   this->InitData(gpair, *p_fmat, *p_tree);
   this->InitFirstNode(gpair);
   this->ColSampleTree();
 
   for (int depth = 0; depth < param.max_depth; depth++) {
     this->ColSampleLevel();
 
     this->BuildHist(depth);
     this->FindSplit(depth);
     this->UpdatePosition(depth);
@@ -29,13 +29,14 @@ class TestGPU(unittest.TestCase):
 
         ag_param = {'max_depth': 2,
                     'tree_method': 'exact',
-                    'nthread': 1,
+                    'nthread': 0,
                     'eta': 1,
                     'silent': 1,
                     'objective': 'binary:logistic',
                     'eval_metric': 'auc'}
         ag_param2 = {'max_depth': 2,
                      'tree_method': 'gpu_exact',
+                     'nthread': 0,
                      'eta': 1,
                      'silent': 1,
                      'objective': 'binary:logistic',
@@ -59,6 +60,7 @@ class TestGPU(unittest.TestCase):
         dtest = xgb.DMatrix(X_test, y_test)
 
         param = {'objective': 'binary:logistic',
+                 'nthread': 0,
                  'tree_method': 'gpu_exact',
                  'max_depth': 3,
                  'eval_metric': 'auc'}
@@ -75,6 +77,7 @@ class TestGPU(unittest.TestCase):
         dtrain2 = xgb.DMatrix(X2, label=y2)
 
         param = {'objective': 'binary:logistic',
+                 'nthread': 0,
                  'tree_method': 'gpu_exact',
                  'max_depth': 2,
                  'eval_metric': 'auc'}
@@ -128,26 +131,28 @@ class TestGPU(unittest.TestCase):
         # regression test --- hist must be same as exact on all-categorial data
         ag_param = {'max_depth': max_depth,
                     'tree_method': 'exact',
-                    'nthread': 1,
+                    'nthread': 0,
                     'eta': 1,
                     'silent': 1,
                     'objective': 'binary:logistic',
                     'eval_metric': 'auc'}
         ag_param2 = {'max_depth': max_depth,
+                     'nthread': 0,
                      'tree_method': 'gpu_hist',
                      'eta': 1,
                      'silent': 1,
                      'n_gpus': 1,
                      'objective': 'binary:logistic',
                      'max_bin': max_bin,
                      'eval_metric': 'auc'}
         ag_param3 = {'max_depth': max_depth,
+                     'nthread': 0,
                      'tree_method': 'gpu_hist',
                      'eta': 1,
                      'silent': 1,
                      'n_gpus': n_gpus,
                      'objective': 'binary:logistic',
                      'max_bin': max_bin,
                      'eval_metric': 'auc'}
         ag_res = {}
         ag_res2 = {}
@@ -178,6 +183,7 @@ class TestGPU(unittest.TestCase):
 
         param = {'objective': 'binary:logistic',
                  'tree_method': 'gpu_hist',
+                 'nthread': 0,
                  'max_depth': max_depth,
                  'n_gpus': 1,
                  'max_bin': max_bin,
@@ -189,6 +195,7 @@ class TestGPU(unittest.TestCase):
         assert self.non_decreasing(res['train']['auc'])
         #assert self.non_decreasing(res['test']['auc'])
         param2 = {'objective': 'binary:logistic',
+                  'nthread': 0,
                   'tree_method': 'gpu_hist',
                   'max_depth': max_depth,
                   'n_gpus': n_gpus,
@@ -211,6 +218,7 @@ class TestGPU(unittest.TestCase):
         dtrain2 = xgb.DMatrix(X2, label=y2)
 
         param = {'objective': 'binary:logistic',
+                 'nthread': 0,
                  'tree_method': 'gpu_hist',
                  'max_depth': max_depth,
                  'n_gpus': n_gpus,
@@ -250,6 +258,7 @@ class TestGPU(unittest.TestCase):
         ######################################################################
         # fail-safe test for max_bin
         param = {'objective': 'binary:logistic',
+                 'nthread': 0,
                  'tree_method': 'gpu_hist',
                  'max_depth': max_depth,
                  'n_gpus': n_gpus,
@@ -263,6 +272,7 @@ class TestGPU(unittest.TestCase):
         ######################################################################
         # subsampling
         param = {'objective': 'binary:logistic',
+                 'nthread': 0,
                  'tree_method': 'gpu_hist',
                  'max_depth': max_depth,
                  'n_gpus': n_gpus,
@@ -279,6 +289,7 @@ class TestGPU(unittest.TestCase):
         ######################################################################
         # fail-safe test for max_bin=2
         param = {'objective': 'binary:logistic',
+                 'nthread': 0,
                  'tree_method': 'gpu_hist',
                  'max_depth': 2,
                  'n_gpus': n_gpus,
@@ -18,62 +18,48 @@ rng = np.random.RandomState(1994)
 # "realistic" size based upon http://stat-computing.org/dataexpo/2009/ , which has been processed to one-hot encode categoricals
 cols = 31
 # reduced to fit onto 1 gpu but still be large
-rows2 = 5000 # medium
-#rows2 = 4032 # fake large for testing
+rows3 = 5000 # small
+rows2 = 4360032 # medium
 rows1 = 42360032 # large
-#rows2 = 152360032 # can do this for multi-gpu test (very large)
-rowslist = [rows1, rows2]
+#rows1 = 152360032 # can do this for multi-gpu test (very large)
+rowslist = [rows1, rows2, rows3]
 
 
 class TestGPU(unittest.TestCase):
     def test_large(self):
         eprint("Starting test for large data")
         tm._skip_if_no_sklearn()
         from sklearn.datasets import load_digits
         try:
             from sklearn.model_selection import train_test_split
         except:
             from sklearn.cross_validation import train_test_split
 
         for rows in rowslist:
+            eprint("Creating train data rows=%d cols=%d" % (rows,cols))
+            X, y = make_classification(rows, n_features=cols, random_state=7)
+            rowstest = int(rows*0.2)
+            eprint("Creating test data rows=%d cols=%d" % (rowstest,cols))
+            # note the new random state. if chose same as train random state, exact methods can memorize and do very well on test even for random data, while hist cannot
+            Xtest, ytest = make_classification(rowstest, n_features=cols, random_state=8)
 
-            np.random.seed(7)
-            X = np.random.rand(rows, cols)
-            y = np.random.rand(rows)
             eprint("Starting DMatrix(X,y)")
-            ag_dtrain = xgb.DMatrix(X,y)
+            ag_dtrain = xgb.DMatrix(X,y,nthread=0)
             eprint("Starting DMatrix(Xtest,ytest)")
             ag_dtest = xgb.DMatrix(Xtest,ytest)
 
             max_depth=6
             max_bin=1024
 
             # regression test --- hist must be same as exact on all-categorial data
             ag_param = {'max_depth': max_depth,
                         'tree_method': 'exact',
                         #'nthread': 1,
                         'nthread': 0,
                         'eta': 1,
                         'silent': 0,
                         'objective': 'binary:logistic',
                         'eval_metric': 'auc'}
             ag_paramb = {'max_depth': max_depth,
                          'tree_method': 'hist',
                          #'nthread': 1,
                          'nthread': 0,
                          'eta': 1,
                          'silent': 0,
                          'objective': 'binary:logistic',
                          'eval_metric': 'auc'}
             ag_param2 = {'max_depth': max_depth,
                          'tree_method': 'gpu_hist',
                          'nthread': 0,
                          'eta': 1,
                          'silent': 0,
                          'n_gpus': 1,
@@ -82,26 +68,18 @@ class TestGPU(unittest.TestCase):
                          'eval_metric': 'auc'}
             ag_param3 = {'max_depth': max_depth,
                          'tree_method': 'gpu_hist',
                          'nthread': 0,
                          'eta': 1,
                          'silent': 0,
                          'n_gpus': -1,
                          'objective': 'binary:logistic',
                          'max_bin': max_bin,
                          'eval_metric': 'auc'}
-            #ag_param4 = {'max_depth': max_depth,
-            #             'tree_method': 'gpu_exact',
-            #             'eta': 1,
-            #             'silent': 0,
-            #             'n_gpus': 1,
-            #             'objective': 'binary:logistic',
-            #             'max_bin': max_bin,
-            #             'eval_metric': 'auc'}
             ag_res = {}
             ag_resb = {}
             ag_res2 = {}
             ag_res3 = {}
-            #ag_res4 = {}
 
             num_rounds = 1
 
             eprint("normal updater")
@@ -116,19 +94,10 @@ class TestGPU(unittest.TestCase):
             eprint("gpu_hist updater all gpus")
             xgb.train(ag_param3, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
                       evals_result=ag_res3)
-            #eprint("gpu_exact updater")
-            #xgb.train(ag_param4, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
-            #          evals_result=ag_res4)
 
             assert np.fabs(ag_res['train']['auc'][0] - ag_resb['train']['auc'][0])<0.001
             assert np.fabs(ag_res['train']['auc'][0] - ag_res2['train']['auc'][0])<0.001
             assert np.fabs(ag_res['train']['auc'][0] - ag_res3['train']['auc'][0])<0.001
-            #assert np.fabs(ag_res['train']['auc'][0] - ag_res4['train']['auc'][0])<0.001
 
             assert np.fabs(ag_res['test']['auc'][0] - ag_resb['test']['auc'][0])<0.01
             assert np.fabs(ag_res['test']['auc'][0] - ag_res2['test']['auc'][0])<0.01
             assert np.fabs(ag_res['test']['auc'][0] - ag_res3['test']['auc'][0])<0.01
-            #assert np.fabs(ag_res['test']['auc'][0] - ag_res4['test']['auc'][0])<0.01
@@ -224,7 +224,8 @@ class DMatrix(object):
 
     def __init__(self, data, label=None, missing=None,
                  weight=None, silent=False,
-                 feature_names=None, feature_types=None):
+                 feature_names=None, feature_types=None,
+                 nthread=None):
        """
        Data matrix used in XGBoost.
 
@@ -268,7 +269,7 @@ class DMatrix(object):
         elif isinstance(data, scipy.sparse.csc_matrix):
             self._init_from_csc(data)
         elif isinstance(data, np.ndarray):
-            self._init_from_npy2d(data, missing)
+            self._init_from_npy2d(data, missing, nthread)
         else:
             try:
                 csr = scipy.sparse.csr_matrix(data)
@@ -276,9 +277,15 @@ class DMatrix(object):
             except:
                 raise TypeError('can not initialize DMatrix from {}'.format(type(data).__name__))
         if label is not None:
-            self.set_label(label)
+            if isinstance(data, np.ndarray):
+                self.set_label_npy2d(label)
+            else:
+                self.set_label(label)
         if weight is not None:
-            self.set_weight(weight)
+            if isinstance(data, np.ndarray):
+                self.set_weight_npy2d(weight)
+            else:
+                self.set_weight(weight)
 
         self.feature_names = feature_names
         self.feature_types = feature_types
@@ -313,7 +320,7 @@ class DMatrix(object):
                                              ctypes.c_size_t(csc.shape[0]),
                                              ctypes.byref(self.handle)))
 
-    def _init_from_npy2d(self, mat, missing):
+    def _init_from_npy2d(self, mat, missing, nthread):
         """
         Initialize data from a 2-D numpy matrix.
 
@@ -333,11 +340,21 @@ class DMatrix(object):
         data = np.array(mat.reshape(mat.size), copy=False, dtype=np.float32)
         self.handle = ctypes.c_void_p()
         missing = missing if missing is not None else np.nan
-        _check_call(_LIB.XGDMatrixCreateFromMat(data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
-                                                c_bst_ulong(mat.shape[0]),
-                                                c_bst_ulong(mat.shape[1]),
-                                                ctypes.c_float(missing),
-                                                ctypes.byref(self.handle)))
+        if nthread is None:
+            _check_call(_LIB.XGDMatrixCreateFromMat(
+                data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
+                c_bst_ulong(mat.shape[0]),
+                c_bst_ulong(mat.shape[1]),
+                ctypes.c_float(missing),
+                ctypes.byref(self.handle)))
+        else:
+            _check_call(_LIB.XGDMatrixCreateFromMat_omp(
+                data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
+                c_bst_ulong(mat.shape[0]),
+                c_bst_ulong(mat.shape[1]),
+                ctypes.c_float(missing),
+                ctypes.byref(self.handle),
+                nthread))
 
     def __del__(self):
         _check_call(_LIB.XGDMatrixFree(self.handle))
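To summarize the dispatch above: nthread=None keeps the original single-threaded C entry point, while any integer routes to the new XGDMatrixCreateFromMat_omp (where values <= 0 mean all cores). In short:

dm = xgb.DMatrix(X)               # nthread omitted -> XGDMatrixCreateFromMat
dm = xgb.DMatrix(X, nthread=4)    # -> XGDMatrixCreateFromMat_omp, 4 threads
dm = xgb.DMatrix(X, nthread=0)    # -> XGDMatrixCreateFromMat_omp, all cores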
@@ -395,9 +412,29 @@ class DMatrix(object):
         data: numpy array
             The array of data to be set
         """
+        c_data = c_array(ctypes.c_float, data)
         _check_call(_LIB.XGDMatrixSetFloatInfo(self.handle,
                                                c_str(field),
-                                               c_array(ctypes.c_float, data),
+                                               c_data,
                                                c_bst_ulong(len(data))))
 
+    def set_float_info_npy2d(self, field, data):
+        """Set float type property into the DMatrix
+           for numpy 2d array input
+
+        Parameters
+        ----------
+        field: str
+            The field name of the information
+
+        data: numpy array
+            The array of data to be set
+        """
+        data = np.array(data, copy=False, dtype=np.float32)
+        c_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
+        _check_call(_LIB.XGDMatrixSetFloatInfo(self.handle,
+                                               c_str(field),
+                                               c_data,
+                                               c_bst_ulong(len(data))))
+
     def set_uint_info(self, field, data):
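The point of the new set_float_info_npy2d variant: c_array(ctypes.c_float, data) builds a fresh ctypes array, copying every element through Python, whereas the npy2d path casts the existing float32 buffer to a pointer and passes it zero-copy. A small standalone illustration of the two conversions:

import ctypes
import numpy as np

data = np.arange(5, dtype=np.float32)

# set_float_info path: a new ctypes array, one Python-level copy per element
copied = (ctypes.c_float * len(data))(*data)

# set_float_info_npy2d path: a pointer view over the existing buffer, no copy
viewed = data.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
assert copied[3] == viewed[3] == data[3]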
@@ -440,6 +477,17 @@ class DMatrix(object):
         """
         self.set_float_info('label', label)
 
+    def set_label_npy2d(self, label):
+        """Set label of dmatrix
+
+        Parameters
+        ----------
+        label: array like
+            The label information to be set into DMatrix
+            from numpy 2D array
+        """
+        self.set_float_info_npy2d('label', label)
+
     def set_weight(self, weight):
         """ Set weight of each instance.
 
@@ -450,6 +498,17 @@ class DMatrix(object):
         """
         self.set_float_info('weight', weight)
 
+    def set_weight_npy2d(self, weight):
+        """ Set weight of each instance
+            for numpy 2D array
+
+        Parameters
+        ----------
+        weight : array like
+            Weight for each data point in numpy 2D array
+        """
+        self.set_float_info_npy2d('weight', weight)
+
     def set_base_margin(self, margin):
         """ Set base margin of booster to start from.
@@ -290,6 +290,7 @@ XGB_DLL int XGDMatrixCreateFromCSCEx(const size_t* col_ptr,
   std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource());
 
   API_BEGIN();
+  // FIXME: User should be able to control number of threads
   const int nthread = omp_get_max_threads();
   data::SimpleCSRSource& mat = *source;
   common::ParallelGroupBuilder<RowBatch::Entry> builder(&mat.row_ptr_, &mat.row_data_);
@@ -350,24 +351,159 @@ XGB_DLL int XGDMatrixCreateFromMat(const bst_float* data,
 
   API_BEGIN();
   data::SimpleCSRSource& mat = *source;
+  mat.row_ptr_.resize(1+nrow);
   bool nan_missing = common::CheckNAN(missing);
   mat.info.num_row = nrow;
   mat.info.num_col = ncol;
+  const bst_float* data0 = data;
+
+  // count elements for sizing data
+  data = data0;
   for (xgboost::bst_ulong i = 0; i < nrow; ++i, data += ncol) {
     xgboost::bst_ulong nelem = 0;
     for (xgboost::bst_ulong j = 0; j < ncol; ++j) {
       if (common::CheckNAN(data[j])) {
         CHECK(nan_missing)
-          << "There are NAN in the matrix, however, you did not set missing=NAN";
+            << "There are NAN in the matrix, however, you did not set missing=NAN";
       } else {
         if (nan_missing || data[j] != missing) {
-          mat.row_data_.push_back(RowBatch::Entry(j, data[j]));
           ++nelem;
         }
       }
     }
-    mat.row_ptr_.push_back(mat.row_ptr_.back() + nelem);
+    mat.row_ptr_[i+1] = mat.row_ptr_[i] + nelem;
   }
+  mat.row_data_.resize(mat.row_data_.size() + mat.row_ptr_.back());
+
+  data = data0;
+  for (xgboost::bst_ulong i = 0; i < nrow; ++i, data += ncol) {
+    xgboost::bst_ulong matj = 0;
+    for (xgboost::bst_ulong j = 0; j < ncol; ++j) {
+      if (common::CheckNAN(data[j])) {
+      } else {
+        if (nan_missing || data[j] != missing) {
+          mat.row_data_[mat.row_ptr_[i] + matj] = RowBatch::Entry(j, data[j]);
+          ++matj;
+        }
+      }
+    }
+  }
+
   mat.info.num_nonzero = mat.row_data_.size();
   *out = new std::shared_ptr<DMatrix>(DMatrix::Create(std::move(source)));
   API_END();
 }
+void prefixsum_inplace(size_t *x, size_t N) {
+  size_t *suma;
+#pragma omp parallel
+  {
+    const int ithread = omp_get_thread_num();
+    const int nthreads = omp_get_num_threads();
+#pragma omp single
+    {
+      suma = new size_t[nthreads+1];
+      suma[0] = 0;
+    }
+    size_t sum = 0;
+#pragma omp for schedule(static)
+    for (omp_ulong i = 0; i < N; i++) {
+      sum += x[i];
+      x[i] = sum;
+    }
+    suma[ithread+1] = sum;
+#pragma omp barrier
+    size_t offset = 0;
+    for (omp_ulong i = 0; i < (ithread+1); i++) {
+      offset += suma[i];
+    }
+#pragma omp for schedule(static)
+    for (omp_ulong i = 0; i < N; i++) {
+      x[i] += offset;
+    }
+  }
+  delete[] suma;
+}
+
+XGB_DLL int XGDMatrixCreateFromMat_omp(const bst_float* data,
+                                       xgboost::bst_ulong nrow,
+                                       xgboost::bst_ulong ncol,
+                                       bst_float missing,
+                                       DMatrixHandle* out,
+                                       int nthread) {
+  // avoid openmp unless enough data to be worth it to avoid overhead costs
+  if (nrow*ncol <= 10000*50) {
+    return(XGDMatrixCreateFromMat(data, nrow, ncol, missing, out));
+  }
+
+  API_BEGIN();
+  const int nthreadmax = std::max(omp_get_num_procs() / 2 - 1, 1);
+  // const int nthreadmax = omp_get_max_threads();
+  if (nthread <= 0) nthread = nthreadmax;
+  omp_set_num_threads(nthread);
+  xgboost::bst_ulong nrow_reserve_per_thread = std::ceil(nrow/static_cast<double>(nthread));
+
+  std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource());
+  data::SimpleCSRSource& mat = *source;
+  mat.row_ptr_.resize(1+nrow);
+  mat.info.num_row = nrow;
+  mat.info.num_col = ncol;
+
+  // Check for errors in missing elements
+  // Count elements per row (to avoid otherwise need to copy)
+  bool nan_missing = common::CheckNAN(missing);
+  int *badnan;
+  badnan = new int[nthread];
+  for (int i = 0; i < nthread; i++) {
+    badnan[i] = 0;
+  }
+
+#pragma omp parallel num_threads(nthread)
+  {
+    int ithread = omp_get_thread_num();
+
+    // Count elements per row
+#pragma omp for schedule(static)
+    for (omp_ulong i = 0; i < nrow; ++i) {
+      xgboost::bst_ulong nelem = 0;
+      for (xgboost::bst_ulong j = 0; j < ncol; ++j) {
+        if (common::CheckNAN(data[ncol*i + j]) && !nan_missing) {
+          badnan[ithread] = 1;
+        } else if (common::CheckNAN(data[ncol * i + j])) {
+        } else if (nan_missing || data[ncol * i + j] != missing) {
+          ++nelem;
+        }
+      }
+      mat.row_ptr_[i+1] = nelem;
+    }
+  }
+  // Inform about any NaNs and resize data matrix
+  for (int i = 0; i < nthread; i++) {
+    CHECK(!badnan[i]) << "There are NAN in the matrix, however, you did not set missing=NAN";
+  }
+  delete[] badnan;  // editorial fix: free the per-thread flag array
+
+  // do cumulative sum (to avoid otherwise need to copy)
+  prefixsum_inplace(&mat.row_ptr_[0], mat.row_ptr_.size());
+  mat.row_data_.resize(mat.row_data_.size() + mat.row_ptr_.back());
+
+  // Fill data matrix (now that know size, no need for slow push_back())
+#pragma omp parallel num_threads(nthread)
+  {
+#pragma omp for schedule(static)
+    for (omp_ulong i = 0; i < nrow; ++i) {
+      xgboost::bst_ulong matj = 0;
+      for (xgboost::bst_ulong j = 0; j < ncol; ++j) {
+        if (common::CheckNAN(data[ncol * i + j])) {
+        } else if (nan_missing || data[ncol * i + j] != missing) {
+          mat.row_data_[mat.row_ptr_[i] + matj] =
+              RowBatch::Entry(j, data[ncol * i + j]);
+          ++matj;
+        }
+      }
+    }
+  }
+
+  mat.info.num_nonzero = mat.row_data_.size();
+  *out = new std::shared_ptr<DMatrix>(DMatrix::Create(std::move(source)));
+  API_END();
+}
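Taken together, XGDMatrixCreateFromMat_omp is the standard two-phase parallel dense-to-CSR conversion: count the kept entries of each row in parallel, turn counts into row offsets with the chunked in-place prefix sum above, then fill a preallocated entry array in a second parallel pass. A sequential Python sketch of the same pipeline, with chunks standing in for OpenMP threads (illustrative only, not part of the commit):

def prefixsum_inplace_sketch(x, nchunks):
    # pass 1: inclusive scan inside each chunk; remember each chunk's total
    n = len(x)
    bounds = [n * t // nchunks for t in range(nchunks + 1)]
    totals = [0] * (nchunks + 1)
    for t in range(nchunks):
        s = 0
        for i in range(bounds[t], bounds[t + 1]):
            s += x[i]
            x[i] = s
        totals[t + 1] = s
    # pass 2: shift every chunk by the summed totals of the earlier chunks
    for t in range(1, nchunks):
        offset = sum(totals[:t + 1])
        for i in range(bounds[t], bounds[t + 1]):
            x[i] += offset

def keep(v, missing):
    # NaN entries are never stored; otherwise drop values equal to missing
    # (the C code additionally errors if NaNs appear when missing is not NaN)
    if v != v:
        return False
    return missing != missing or v != missing

def dense_to_csr_sketch(mat, missing, nchunks=4):
    nrow = len(mat)
    # phase 1: row_ptr[i+1] holds the kept-entry count of row i
    row_ptr = [0] * (nrow + 1)
    for i in range(nrow):
        row_ptr[i + 1] = sum(1 for v in mat[i] if keep(v, missing))
    # counts -> offsets in place, as prefixsum_inplace does above
    prefixsum_inplace_sketch(row_ptr, nchunks)
    # phase 2: fill preallocated storage; no push_back needed
    entries = [None] * row_ptr[-1]
    for i in range(nrow):
        k = row_ptr[i]
        for j, v in enumerate(mat[i]):
            if keep(v, missing):
                entries[k] = (j, v)  # (feature index, value)
                k += 1
    return row_ptr, entries

nan = float('nan')
row_ptr, entries = dense_to_csr_sketch([[1.0, nan], [nan, 2.0]], nan)
assert row_ptr == [0, 1, 2] and entries == [(0, 1.0), (1, 2.0)]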
tests/cpp/c_api/test_c_api.cc (new file, 40 lines)
@@ -0,0 +1,40 @@
+// Copyright by Contributors
+#include <gtest/gtest.h>
+#include <xgboost/c_api.h>
+#include <xgboost/data.h>
+#include <limits>
+#include <vector>
+
+TEST(c_api, XGDMatrixCreateFromMat_omp) {
+  std::vector<int> num_rows = {100, 11374, 15000};
+  for (auto row : num_rows) {
+    int num_cols = 50;
+    int num_missing = 5;
+    DMatrixHandle handle;
+    std::vector<float> data(num_cols * row, 1.5);
+    for (int i = 0; i < num_missing; i++) {
+      data[i] = std::numeric_limits<float>::quiet_NaN();
+    }
+
+    XGDMatrixCreateFromMat_omp(data.data(), row, num_cols,
+                               std::numeric_limits<float>::quiet_NaN(), &handle,
+                               0);
+
+    std::shared_ptr<xgboost::DMatrix> dmat =
+        *static_cast<std::shared_ptr<xgboost::DMatrix> *>(handle);
+    xgboost::MetaInfo &info = dmat->info();
+    ASSERT_EQ(info.num_col, num_cols);
+    ASSERT_EQ(info.num_row, row);
+    ASSERT_EQ(info.num_nonzero, num_cols * row - num_missing);
+
+    auto iter = dmat->RowIterator();
+    iter->BeforeFirst();
+    while (iter->Next()) {
+      auto batch = iter->Value();
+      for (int i = 0; i < batch.size; i++) {
+        auto inst = batch[i];
+        for (int j = 0; j < inst.length; j++) {
+          ASSERT_EQ(inst[j].fvalue, 1.5);
+        }
+      }
+    }
+  }
+}
@@ -212,6 +212,23 @@ class TestBasic(unittest.TestCase):
         self.assertRaises(xgb.core.XGBoostError, xgb.Booster,
                           model_file=u'不正なパス')
 
+    def test_dmatrix_numpy_init_omp(self):
+
+        rows = [1000, 11326, 15000]
+        cols = 50
+        for row in rows:
+            X = np.random.randn(row, cols)
+            y = np.random.randn(row).astype('f')
+            dm = xgb.DMatrix(X, y, nthread=0)
+            np.testing.assert_array_equal(dm.get_label(), y)
+            assert dm.num_row() == row
+            assert dm.num_col() == cols
+
+            dm = xgb.DMatrix(X, y, nthread=10)
+            np.testing.assert_array_equal(dm.get_label(), y)
+            assert dm.num_row() == row
+            assert dm.num_col() == cols
+
     def test_dmatrix_numpy_init(self):
         data = np.random.randn(5, 5)
         dm = xgb.DMatrix(data)