Multi-threaded XGDMatrixCreateFromMat for faster DMatrix creation (#2530)

* Multi-threaded XGDMatrixCreateFromMat for faster DMatrix creation from NumPy arrays in the Python interface.
PSEUDOTENSOR / Jonathan McKinney 2017-07-20 19:43:17 -07:00 committed by Rory Mitchell
parent 56550ff3f1
commit 6b375f6ad8
9 changed files with 324 additions and 73 deletions

View File

@@ -206,6 +206,22 @@ XGB_DLL int XGDMatrixCreateFromMat(const float *data,
bst_ulong ncol,
float missing,
DMatrixHandle *out);
/*!
* \brief create matrix content from dense matrix
* \param data pointer to the data space
* \param nrow number of rows
 * \param ncol number of columns
 * \param missing the value that represents a missing entry
 * \param out created dmatrix
 * \param nthread number of threads (capped at the number of available cores; if <= 0, all cores are used)
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGDMatrixCreateFromMat_omp(const float *data,
bst_ulong nrow,
bst_ulong ncol,
float missing,
DMatrixHandle *out,
int nthread);
/*!
* \brief create a new dmatrix from sliced content of existing matrix
* \param handle instance of data matrix to be sliced
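As a usage sketch (not part of this commit), the new entry point can be driven directly from C++ through the C API; return-code checking is elided and the data values are arbitrary:

#include <limits>
#include <vector>
#include <xgboost/c_api.h>

int main() {
  const bst_ulong nrow = 100000, ncol = 50;
  std::vector<float> data(nrow * ncol, 1.0f);  // dense row-major input
  DMatrixHandle handle;
  // nthread <= 0 asks the library to use all available cores
  XGDMatrixCreateFromMat_omp(data.data(), nrow, ncol,
                             std::numeric_limits<float>::quiet_NaN(),
                             &handle, /*nthread=*/0);
  XGDMatrixFree(handle);
  return 0;
}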

View File

@@ -5,13 +5,18 @@ import numpy as np
from sklearn.datasets import make_classification
import time
n = 1000000
num_rounds = 500
def run_benchmark(args, gpu_algorithm, cpu_algorithm):
print("Generating dataset: {} rows * {} columns".format(args.rows,args.columns))
tmp = time.time()
X, y = make_classification(args.rows, n_features=args.columns, random_state=7)
dtrain = xgb.DMatrix(X, y)
print ("Generate Time: %s seconds" % (str(time.time() - tmp)))
tmp = time.time()
print ("DMatrix Start")
# omp way
dtrain = xgb.DMatrix(X, y, nthread=-1)
# non-omp way
#dtrain = xgb.DMatrix(X, y)
print ("DMatrix Time: %s seconds" % (str(time.time() - tmp)))
param = {'objective': 'binary:logistic',
'max_depth': 6,
@@ -24,7 +29,7 @@ def run_benchmark(args, gpu_algorithm, cpu_algorithm):
print("Training with '%s'" % param['tree_method'])
tmp = time.time()
xgb.train(param, dtrain, args.iterations)
print ("Time: %s seconds" % (str(time.time() - tmp)))
print ("Train Time: %s seconds" % (str(time.time() - tmp)))
param['silent'] = 1
param['tree_method'] = cpu_algorithm

View File

@@ -343,8 +343,6 @@ void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
}
void GPUHistBuilder::BuildHist(int depth) {
// dh::Timer time;
for (int d_idx = 0; d_idx < n_devices; d_idx++) {
int device_idx = dList[d_idx];
size_t begin = device_element_segments[d_idx];
@@ -1070,9 +1068,9 @@ void GPUHistBuilder::Update(const std::vector<bst_gpair>& gpair,
this->InitData(gpair, *p_fmat, *p_tree);
this->InitFirstNode(gpair);
this->ColSampleTree();
for (int depth = 0; depth < param.max_depth; depth++) {
this->ColSampleLevel();
this->BuildHist(depth);
this->FindSplit(depth);
this->UpdatePosition(depth);

View File

@@ -29,13 +29,14 @@ class TestGPU(unittest.TestCase):
ag_param = {'max_depth': 2,
'tree_method': 'exact',
'nthread': 1,
'nthread': 0,
'eta': 1,
'silent': 1,
'objective': 'binary:logistic',
'eval_metric': 'auc'}
ag_param2 = {'max_depth': 2,
'tree_method': 'gpu_exact',
'nthread': 0,
'eta': 1,
'silent': 1,
'objective': 'binary:logistic',
@@ -59,6 +60,7 @@ class TestGPU(unittest.TestCase):
dtest = xgb.DMatrix(X_test, y_test)
param = {'objective': 'binary:logistic',
'nthread': 0,
'tree_method': 'gpu_exact',
'max_depth': 3,
'eval_metric': 'auc'}
@@ -75,6 +77,7 @@ class TestGPU(unittest.TestCase):
dtrain2 = xgb.DMatrix(X2, label=y2)
param = {'objective': 'binary:logistic',
'nthread': 0,
'tree_method': 'gpu_exact',
'max_depth': 2,
'eval_metric': 'auc'}
@@ -128,26 +131,28 @@ class TestGPU(unittest.TestCase):
# regression test --- hist must be same as exact on all-categorical data
ag_param = {'max_depth': max_depth,
'tree_method': 'exact',
'nthread': 1,
'nthread': 0,
'eta': 1,
'silent': 1,
'objective': 'binary:logistic',
'eval_metric': 'auc'}
ag_param2 = {'max_depth': max_depth,
'nthread': 0,
'tree_method': 'gpu_hist',
'eta': 1,
'silent': 1,
'n_gpus': 1,
'objective': 'binary:logistic',
'max_bin': max_bin,
'eval_metric': 'auc'}
ag_param3 = {'max_depth': max_depth,
'nthread': 0,
'tree_method': 'gpu_hist',
'eta': 1,
'silent': 1,
'n_gpus': n_gpus,
'objective': 'binary:logistic',
'max_bin': max_bin,
'eval_metric': 'auc'}
ag_res = {}
ag_res2 = {}
@@ -178,6 +183,7 @@ class TestGPU(unittest.TestCase):
param = {'objective': 'binary:logistic',
'tree_method': 'gpu_hist',
'nthread': 0,
'max_depth': max_depth,
'n_gpus': 1,
'max_bin': max_bin,
@@ -189,6 +195,7 @@ class TestGPU(unittest.TestCase):
assert self.non_decreasing(res['train']['auc'])
#assert self.non_decreasing(res['test']['auc'])
param2 = {'objective': 'binary:logistic',
'nthread': 0,
'tree_method': 'gpu_hist',
'max_depth': max_depth,
'n_gpus': n_gpus,
@@ -211,6 +218,7 @@ class TestGPU(unittest.TestCase):
dtrain2 = xgb.DMatrix(X2, label=y2)
param = {'objective': 'binary:logistic',
'nthread': 0,
'tree_method': 'gpu_hist',
'max_depth': max_depth,
'n_gpus': n_gpus,
@@ -250,6 +258,7 @@ class TestGPU(unittest.TestCase):
######################################################################
# fail-safe test for max_bin
param = {'objective': 'binary:logistic',
'nthread': 0,
'tree_method': 'gpu_hist',
'max_depth': max_depth,
'n_gpus': n_gpus,
@@ -263,6 +272,7 @@ class TestGPU(unittest.TestCase):
######################################################################
# subsampling
param = {'objective': 'binary:logistic',
'nthread': 0,
'tree_method': 'gpu_hist',
'max_depth': max_depth,
'n_gpus': n_gpus,
@@ -279,6 +289,7 @@ class TestGPU(unittest.TestCase):
######################################################################
# fail-safe test for max_bin=2
param = {'objective': 'binary:logistic',
'nthread': 0,
'tree_method': 'gpu_hist',
'max_depth': 2,
'n_gpus': n_gpus,

View File

@@ -18,62 +18,48 @@ rng = np.random.RandomState(1994)
# "realistic" size based upon http://stat-computing.org/dataexpo/2009/ , which has been processed to one-hot encode categoricals
cols = 31
# reduced to fit onto 1 gpu but still be large
rows2 = 5000 # medium
#rows2 = 4032 # fake large for testing
rows3 = 5000 # small
rows2 = 4360032 # medium
rows1 = 42360032 # large
#rows2 = 152360032 # can do this for multi-gpu test (very large)
rowslist = [rows1, rows2]
#rows1 = 152360032 # can do this for multi-gpu test (very large)
rowslist = [rows1, rows2, rows3]
class TestGPU(unittest.TestCase):
def test_large(self):
eprint("Starting test for large data")
tm._skip_if_no_sklearn()
from sklearn.datasets import load_digits
try:
from sklearn.model_selection import train_test_split
except:
from sklearn.cross_validation import train_test_split
for rows in rowslist:
eprint("Creating train data rows=%d cols=%d" % (rows,cols))
X, y = make_classification(rows, n_features=cols, random_state=7)
rowstest = int(rows*0.2)
eprint("Creating test data rows=%d cols=%d" % (rowstest,cols))
# note the new random state: if it were the same as the train random state, exact methods could memorize and do very well on the test set even for random data, while hist cannot
Xtest, ytest = make_classification(rowstest, n_features=cols, random_state=8)
np.random.seed(7)
X = np.random.rand(rows, cols)
y = np.random.rand(rows)
eprint("Starting DMatrix(X,y)")
ag_dtrain = xgb.DMatrix(X,y)
eprint("Starting DMatrix(Xtest,ytest)")
ag_dtest = xgb.DMatrix(Xtest,ytest)
ag_dtrain = xgb.DMatrix(X,y,nthread=0)
max_depth=6
max_bin=1024
# regression test --- hist must be same as exact on all-categorical data
ag_param = {'max_depth': max_depth,
'tree_method': 'exact',
#'nthread': 1,
'nthread': 0,
'eta': 1,
'silent': 0,
'objective': 'binary:logistic',
'eval_metric': 'auc'}
ag_paramb = {'max_depth': max_depth,
'tree_method': 'hist',
#'nthread': 1,
'nthread': 0,
'eta': 1,
'silent': 0,
'objective': 'binary:logistic',
'eval_metric': 'auc'}
ag_param2 = {'max_depth': max_depth,
'tree_method': 'gpu_hist',
'nthread': 0,
'eta': 1,
'silent': 0,
'n_gpus': 1,
@@ -82,26 +68,18 @@ 'eval_metric': 'auc'}
'eval_metric': 'auc'}
ag_param3 = {'max_depth': max_depth,
'tree_method': 'gpu_hist',
'nthread': 0,
'eta': 1,
'silent': 0,
'n_gpus': -1,
'objective': 'binary:logistic',
'max_bin': max_bin,
'eval_metric': 'auc'}
#ag_param4 = {'max_depth': max_depth,
# 'tree_method': 'gpu_exact',
# 'eta': 1,
# 'silent': 0,
# 'n_gpus': 1,
# 'objective': 'binary:logistic',
# 'max_bin': max_bin,
# 'eval_metric': 'auc'}
ag_res = {}
ag_resb = {}
ag_res2 = {}
ag_res3 = {}
#ag_res4 = {}
num_rounds = 1
eprint("normal updater")
@@ -116,19 +94,10 @@ class TestGPU(unittest.TestCase):
eprint("gpu_hist updater all gpus")
xgb.train(ag_param3, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
evals_result=ag_res3)
#eprint("gpu_exact updater")
#xgb.train(ag_param4, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
# evals_result=ag_res4)
assert np.fabs(ag_res['train']['auc'][0] - ag_resb['train']['auc'][0])<0.001
assert np.fabs(ag_res['train']['auc'][0] - ag_res2['train']['auc'][0])<0.001
assert np.fabs(ag_res['train']['auc'][0] - ag_res3['train']['auc'][0])<0.001
#assert np.fabs(ag_res['train']['auc'][0] - ag_res4['train']['auc'][0])<0.001
assert np.fabs(ag_res['test']['auc'][0] - ag_resb['test']['auc'][0])<0.01
assert np.fabs(ag_res['test']['auc'][0] - ag_res2['test']['auc'][0])<0.01
assert np.fabs(ag_res['test']['auc'][0] - ag_res3['test']['auc'][0])<0.01
#assert np.fabs(ag_res['test']['auc'][0] - ag_res4['test']['auc'][0])<0.01

View File

@@ -224,7 +224,8 @@ class DMatrix(object):
def __init__(self, data, label=None, missing=None,
weight=None, silent=False,
feature_names=None, feature_types=None):
feature_names=None, feature_types=None,
nthread=None):
"""
Data matrix used in XGBoost.
@@ -268,7 +269,7 @@ class DMatrix(object):
elif isinstance(data, scipy.sparse.csc_matrix):
self._init_from_csc(data)
elif isinstance(data, np.ndarray):
self._init_from_npy2d(data, missing)
self._init_from_npy2d(data, missing, nthread)
else:
try:
csr = scipy.sparse.csr_matrix(data)
@@ -276,9 +277,15 @@ class DMatrix(object):
except:
raise TypeError('can not initialize DMatrix from {}'.format(type(data).__name__))
if label is not None:
self.set_label(label)
if isinstance(data, np.ndarray):
self.set_label_npy2d(label)
else:
self.set_label(label)
if weight is not None:
self.set_weight(weight)
if isinstance(data, np.ndarray):
self.set_weight_npy2d(weight)
else:
self.set_weight(weight)
self.feature_names = feature_names
self.feature_types = feature_types
@@ -313,7 +320,7 @@ class DMatrix(object):
ctypes.c_size_t(csc.shape[0]),
ctypes.byref(self.handle)))
def _init_from_npy2d(self, mat, missing):
def _init_from_npy2d(self, mat, missing, nthread):
"""
Initialize data from a 2-D numpy matrix.
@@ -333,11 +340,21 @@ class DMatrix(object):
data = np.array(mat.reshape(mat.size), copy=False, dtype=np.float32)
self.handle = ctypes.c_void_p()
missing = missing if missing is not None else np.nan
_check_call(_LIB.XGDMatrixCreateFromMat(data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
c_bst_ulong(mat.shape[0]),
c_bst_ulong(mat.shape[1]),
ctypes.c_float(missing),
ctypes.byref(self.handle)))
if nthread is None:
_check_call(_LIB.XGDMatrixCreateFromMat(
data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
c_bst_ulong(mat.shape[0]),
c_bst_ulong(mat.shape[1]),
ctypes.c_float(missing),
ctypes.byref(self.handle)))
else:
_check_call(_LIB.XGDMatrixCreateFromMat_omp(
data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
c_bst_ulong(mat.shape[0]),
c_bst_ulong(mat.shape[1]),
ctypes.c_float(missing),
ctypes.byref(self.handle),
nthread))
def __del__(self):
_check_call(_LIB.XGDMatrixFree(self.handle))
@@ -395,9 +412,29 @@ class DMatrix(object):
data: numpy array
The array of data to be set
"""
c_data = c_array(ctypes.c_float, data)
_check_call(_LIB.XGDMatrixSetFloatInfo(self.handle,
c_str(field),
c_array(ctypes.c_float, data),
c_data,
c_bst_ulong(len(data))))
def set_float_info_npy2d(self, field, data):
"""Set float type property into the DMatrix
for numpy 2d array input
Parameters
----------
field: str
The field name of the information
data: numpy array
The array of data to be set
"""
data = np.array(data, copy=False, dtype=np.float32)
c_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
_check_call(_LIB.XGDMatrixSetFloatInfo(self.handle,
c_str(field),
c_data,
c_bst_ulong(len(data))))
def set_uint_info(self, field, data):
@@ -440,6 +477,17 @@ class DMatrix(object):
"""
self.set_float_info('label', label)
def set_label_npy2d(self, label):
"""Set label of dmatrix
Parameters
----------
label: array like
The label information to be set into DMatrix
from numpy 2D array
"""
self.set_float_info_npy2d('label', label)
def set_weight(self, weight):
""" Set weight of each instance.
@@ -450,6 +498,17 @@ class DMatrix(object):
"""
self.set_float_info('weight', weight)
def set_weight_npy2d(self, weight):
""" Set weight of each instance
for numpy 2D array
Parameters
----------
weight : array like
Weight for each data point in numpy 2D array
"""
self.set_float_info_npy2d('weight', weight)
def set_base_margin(self, margin):
""" Set base margin of booster to start from.

View File

@@ -290,6 +290,7 @@ XGB_DLL int XGDMatrixCreateFromCSCEx(const size_t* col_ptr,
std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource());
API_BEGIN();
// FIXME: User should be able to control number of threads
const int nthread = omp_get_max_threads();
data::SimpleCSRSource& mat = *source;
common::ParallelGroupBuilder<RowBatch::Entry> builder(&mat.row_ptr_, &mat.row_data_);
@@ -350,24 +351,159 @@ XGB_DLL int XGDMatrixCreateFromMat(const bst_float* data,
API_BEGIN();
data::SimpleCSRSource& mat = *source;
mat.row_ptr_.resize(1+nrow);
bool nan_missing = common::CheckNAN(missing);
mat.info.num_row = nrow;
mat.info.num_col = ncol;
const bst_float* data0 = data;
// count elements for sizing data
data = data0;
for (xgboost::bst_ulong i = 0; i < nrow; ++i, data += ncol) {
xgboost::bst_ulong nelem = 0;
for (xgboost::bst_ulong j = 0; j < ncol; ++j) {
if (common::CheckNAN(data[j])) {
CHECK(nan_missing)
<< "There are NAN in the matrix, however, you did not set missing=NAN";
<< "There are NAN in the matrix, however, you did not set missing=NAN";
} else {
if (nan_missing || data[j] != missing) {
mat.row_data_.push_back(RowBatch::Entry(j, data[j]));
++nelem;
}
}
}
mat.row_ptr_.push_back(mat.row_ptr_.back() + nelem);
mat.row_ptr_[i+1] = mat.row_ptr_[i] + nelem;
}
mat.row_data_.resize(mat.row_data_.size() + mat.row_ptr_.back());
data = data0;
for (xgboost::bst_ulong i = 0; i < nrow; ++i, data += ncol) {
xgboost::bst_ulong matj = 0;
for (xgboost::bst_ulong j = 0; j < ncol; ++j) {
if (common::CheckNAN(data[j])) {
} else {
if (nan_missing || data[j] != missing) {
mat.row_data_[mat.row_ptr_[i] + matj] = RowBatch::Entry(j, data[j]);
++matj;
}
}
}
}
mat.info.num_nonzero = mat.row_data_.size();
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(std::move(source)));
API_END();
}
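For orientation (not part of the commit), a small worked example of the CSR layout that the count-then-fill passes above produce, assuming missing = NaN:

// Input, 2 x 3 dense row-major:
//   [[1, NaN, 3],
//    [4, 5,   NaN]]
// After counting and the cumulative sum:
//   row_ptr_  = {0, 2, 4}                      // row i spans [row_ptr_[i], row_ptr_[i+1])
// After the fill pass:
//   row_data_ = {(0,1), (2,3), (0,4), (1,5)}   // (column index, value) entries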
// Parallel in-place inclusive prefix sum over x[0..N-1].
// Two-pass scheme: each thread first scans its own static chunk, the
// per-thread totals are collected in suma[], and each thread then adds the
// combined total of all preceding chunks to its elements. Both omp-for
// loops must use the same schedule(static) so each thread revisits exactly
// the chunk it scanned in the first pass.
void prefixsum_inplace(size_t *x, size_t N) {
  size_t *suma;
#pragma omp parallel
  {
    const int ithread = omp_get_thread_num();
    const int nthreads = omp_get_num_threads();
#pragma omp single
    {
      suma = new size_t[nthreads+1];
      suma[0] = 0;
    }
    // pass 1: local inclusive scan of this thread's chunk
    size_t sum = 0;
#pragma omp for schedule(static)
    for (omp_ulong i = 0; i < N; i++) {
      sum += x[i];
      x[i] = sum;
    }
    suma[ithread+1] = sum;
#pragma omp barrier
    // offset for this chunk = sum of the totals of all earlier chunks
    size_t offset = 0;
    for (omp_ulong i = 0; i < (ithread+1); i++) {
      offset += suma[i];
    }
    // pass 2: shift the local scan by the chunk offset
#pragma omp for schedule(static)
    for (omp_ulong i = 0; i < N; i++) {
      x[i] += offset;
    }
  }
  delete[] suma;
}
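As a reading aid (not part of the commit), a hypothetical sequential equivalent of prefixsum_inplace is just an inclusive scan:

// Sequential reference: afterwards x[i] holds x_orig[0] + ... + x_orig[i].
void prefixsum_reference(size_t *x, size_t N) {
  size_t sum = 0;
  for (size_t i = 0; i < N; ++i) {
    sum += x[i];
    x[i] = sum;
  }
}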
XGB_DLL int XGDMatrixCreateFromMat_omp(const bst_float* data,
xgboost::bst_ulong nrow,
xgboost::bst_ulong ncol,
bst_float missing,
DMatrixHandle* out,
int nthread) {
// avoid openmp unless enough data to be worth it to avoid overhead costs
if (nrow*ncol <= 10000*50) {
return(XGDMatrixCreateFromMat(data, nrow, ncol, missing, out));
}
API_BEGIN();
const int nthreadmax = std::max(omp_get_num_procs() / 2 - 1, 1);
// const int nthreadmax = omp_get_max_threads();
if (nthread <= 0) nthread=nthreadmax;
omp_set_num_threads(nthread);
xgboost::bst_ulong nrow_reserve_per_thread = std::ceil(nrow/static_cast<double>(nthread));
std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource());
data::SimpleCSRSource& mat = *source;
mat.row_ptr_.resize(1+nrow);
mat.info.num_row = nrow;
mat.info.num_col = ncol;
// Check for errors in missing elements
// Count elements per row (to avoid otherwise need to copy)
bool nan_missing = common::CheckNAN(missing);
// one flag per thread; any NaN seen while missing != NaN marks the thread's slot
std::vector<int> badnan(nthread, 0);
#pragma omp parallel num_threads(nthread)
{
int ithread = omp_get_thread_num();
// Count elements per row
#pragma omp for schedule(static)
for (omp_ulong i = 0; i < nrow; ++i) {
xgboost::bst_ulong nelem = 0;
for (xgboost::bst_ulong j = 0; j < ncol; ++j) {
if (common::CheckNAN(data[ncol*i + j]) && !nan_missing) {
badnan[ithread] = 1;
} else if (common::CheckNAN(data[ncol * i + j])) {
} else if (nan_missing || data[ncol * i + j] != missing) {
++nelem;
}
}
mat.row_ptr_[i+1] = nelem;
}
}
// Inform about any NaNs and resize data matrix
for (int i = 0; i < nthread; i++) {
CHECK(!badnan[i]) << "There are NAN in the matrix, however, you did not set missing=NAN";
}
// do cumulative sum (to avoid otherwise need to copy)
prefixsum_inplace(&mat.row_ptr_[0], mat.row_ptr_.size());
mat.row_data_.resize(mat.row_data_.size() + mat.row_ptr_.back());
// Fill data matrix (now that know size, no need for slow push_back())
#pragma omp parallel num_threads(nthread)
{
#pragma omp for schedule(static)
for (omp_ulong i = 0; i < nrow; ++i) {
xgboost::bst_ulong matj = 0;
for (xgboost::bst_ulong j = 0; j < ncol; ++j) {
if (common::CheckNAN(data[ncol * i + j])) {
} else if (nan_missing || data[ncol * i + j] != missing) {
mat.row_data_[mat.row_ptr_[i] + matj] =
RowBatch::Entry(j, data[ncol * i + j]);
++matj;
}
}
}
}
mat.info.num_nonzero = mat.row_data_.size();
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(std::move(source)));
API_END();

View File

@@ -0,0 +1,40 @@
// Copyright by Contributors
#include <gtest/gtest.h>
#include <xgboost/c_api.h>
#include <xgboost/data.h>
TEST(c_api, XGDMatrixCreateFromMat_omp) {
std::vector<int> num_rows = {100, 11374, 15000};
for (auto row : num_rows) {
int num_cols = 50;
int num_missing = 5;
DMatrixHandle handle;
std::vector<float> data(num_cols * row, 1.5);
for (int i = 0; i < num_missing; i++) {
data[i] = std::numeric_limits<float>::quiet_NaN();
}
XGDMatrixCreateFromMat_omp(data.data(), row, num_cols,
std::numeric_limits<float>::quiet_NaN(), &handle,
0);
std::shared_ptr<xgboost::DMatrix> dmat =
*static_cast<std::shared_ptr<xgboost::DMatrix> *>(handle);
xgboost::MetaInfo &info = dmat->info();
ASSERT_EQ(info.num_col, num_cols);
ASSERT_EQ(info.num_row, row);
ASSERT_EQ(info.num_nonzero, num_cols * row - num_missing);
auto iter = dmat->RowIterator();
iter->BeforeFirst();
while (iter->Next()) {
auto batch = iter->Value();
for (int i = 0; i < batch.size; i++) {
auto inst = batch[i];
for (int j = 0; j < inst.length; j++) {
ASSERT_EQ(inst[j].fvalue, 1.5);
}
}
}
}
}

View File

@@ -212,6 +212,23 @@ class TestBasic(unittest.TestCase):
self.assertRaises(xgb.core.XGBoostError, xgb.Booster,
model_file=u'不正なパス')
def test_dmatrix_numpy_init_omp(self):
rows = [1000, 11326, 15000]
cols = 50
for row in rows:
X = np.random.randn(row, cols)
y = np.random.randn(row).astype('f')
dm = xgb.DMatrix(X, y, nthread=0)
np.testing.assert_array_equal(dm.get_label(), y)
assert dm.num_row() == row
assert dm.num_col() == cols
dm = xgb.DMatrix(X, y, nthread=10)
np.testing.assert_array_equal(dm.get_label(), y)
assert dm.num_row() == row
assert dm.num_col() == cols
def test_dmatrix_numpy_init(self):
data = np.random.randn(5, 5)
dm = xgb.DMatrix(data)