Fix several GPU bugs (#2916)

* Fix #2905

* Fix gpu_exact test failures

* Fix bug in GPU prediction where multiple calls to batch prediction can produce incorrect results

* Fix GPU documentation formatting
Rory Mitchell 2017-12-04 08:27:49 +13:00 committed by GitHub
parent 1e3aabbadc
commit 1b77903eeb
7 changed files with 109 additions and 46 deletions


@@ -12,13 +12,13 @@ Specify the 'tree_method' parameter as one of the following algorithms.
 ### Algorithms
 ```eval_rst
-+--------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
++--------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------+
 | tree_method  | Description |
-+==============+=================================================================================================================================================================================================================+
++==============+=======================================================================================================================================================================+
 | gpu_exact    | The standard XGBoost tree construction algorithm. Performs exact search for splits. Slower and uses considerably more memory than 'gpu_hist'. |
-+--------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
++--------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------+
 | gpu_hist     | Equivalent to the XGBoost fast histogram algorithm. Much faster and uses considerably less memory. NOTE: Will run very slowly on GPUs older than Pascal architecture. |
-+--------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
++--------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------+
 ```
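For illustration, a minimal sketch of selecting one of these algorithms; the toy data and round count here are made up:

```python
import numpy as np
import xgboost as xgb

# Hypothetical toy data, just to show where 'tree_method' goes.
X = np.random.randn(1000, 10)
y = np.random.randint(2, size=1000)
dtrain = xgb.DMatrix(X, label=y)

param = {'objective': 'binary:logistic',
         'tree_method': 'gpu_hist'}  # or 'gpu_exact'
bst = xgb.train(param, dtrain, num_boost_round=10)
```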
### Supported parameters
@@ -27,28 +27,27 @@ Specify the 'tree_method' parameter as one of the following algorithms.
 .. |tick| unicode:: U+2714
 .. |cross| unicode:: U+2718
-+--------------------+------------+-----------+
++----------------------+------------+-----------+
 | parameter            | gpu_exact  | gpu_hist  |
-+====================+============+===========+
++======================+============+===========+
 | subsample            | |cross|    | |tick|    |
-+--------------------+------------+-----------+
++----------------------+------------+-----------+
 | colsample_bytree     | |cross|    | |tick|    |
-+--------------------+------------+-----------+
++----------------------+------------+-----------+
 | colsample_bylevel    | |cross|    | |tick|    |
-+--------------------+------------+-----------+
++----------------------+------------+-----------+
 | max_bin              | |cross|    | |tick|    |
-+--------------------+------------+-----------+
++----------------------+------------+-----------+
 | gpu_id               | |tick|     | |tick|    |
-+--------------------+------------+-----------+
++----------------------+------------+-----------+
 | n_gpus               | |cross|    | |tick|    |
-+--------------------+------------+-----------+
++----------------------+------------+-----------+
 | predictor            | |tick|     | |tick|    |
-+--------------------+------------+-----------+
++----------------------+------------+-----------+
 | grow_policy          | |cross|    | |tick|    |
-+--------------------+------------+-----------+
++----------------------+------------+-----------+
 | monotone_constraints | |cross|    | |tick|    |
-+--------------------+------------+-----------+
++----------------------+------------+-----------+
 ```
 GPU-accelerated prediction is enabled by default for the above-mentioned 'tree_method' parameters, but it can be switched to CPU prediction by setting 'predictor':'cpu_predictor'. This may be useful if you want to conserve GPU memory. Likewise, when using CPU algorithms, GPU-accelerated prediction can be enabled by setting 'predictor':'gpu_predictor'.
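Continuing the sketch above (same placeholder data), switching predictors looks roughly like this:

```python
# GPU algorithm, but predict on the CPU to conserve GPU memory.
param = {'tree_method': 'gpu_hist', 'predictor': 'cpu_predictor'}
bst = xgb.train(param, dtrain, num_boost_round=10)

# CPU algorithm with GPU-accelerated prediction.
param = {'tree_method': 'hist', 'predictor': 'gpu_predictor'}
bst = xgb.train(param, dtrain, num_boost_round=10)
preds = bst.predict(dtrain)
```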


@@ -292,11 +292,9 @@ class GPUPredictor : public xgboost::Predictor {
     thrust::copy(model.tree_info.begin(), model.tree_info.end(),
                  tree_group.begin());
     if (device_matrix->predictions.size() != out_preds->size()) {
       device_matrix->predictions.resize(out_preds->size());
       thrust::copy(out_preds->begin(), out_preds->end(),
                    device_matrix->predictions.begin());
     }
     const int BLOCK_THREADS = 128;
     const int GRID_SIZE = static_cast<int>(


@@ -336,8 +336,8 @@ struct XGBOOST_ALIGNAS(16) GradStats {
     this->Add(b.GetGrad(), b.GetHess());
   }
   /*! \brief calculate leaf weight */
-  template <typename param_t>
-  inline double CalcWeight(const param_t& param) const {
+  template <typename param_t>
+  XGBOOST_DEVICE inline double CalcWeight(const param_t &param) const {
     return xgboost::tree::CalcWeight(param, sum_grad, sum_hess);
   }
   /*! \brief calculate gain of the solution */
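For reference, a rough Python sketch of the leaf weight this delegates to, assuming only L2 regularisation (reg_lambda) and the min_child_weight guard, and ignoring max_delta_step and reg_alpha:

```python
def calc_weight(sum_grad, sum_hess, reg_lambda=1.0, min_child_weight=1.0):
    # A node whose hessian sum falls below min_child_weight gets no weight.
    if sum_hess < min_child_weight:
        return 0.0
    # Closed-form minimiser of the regularised quadratic objective.
    return -sum_grad / (sum_hess + reg_lambda)
```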


@@ -302,7 +302,7 @@ DEV_INLINE void argMaxWithAtomics(
     ExactSplitCandidate s;
     bst_gpair missing = parentSum - colSum;
     s.score = loss_chg_missing(gradScans[id], missing, parentSum, parentGain,
-                               param, 0, ValueConstraint(), tmp);
+                               param, tmp);
     s.index = id;
     atomicArgMax(nodeSplits + uid, s);
   }  // end if nodeId != UNUSED_NODE

@@ -580,7 +580,7 @@ class GPUMaker : public TreeUpdater {
     // get the default direction for the current node
     bst_gpair missing = n.sum_gradients - gradSum;
     loss_chg_missing(gradScan, missing, n.sum_gradients, n.root_gain,
-                     gpu_param, 0, ValueConstraint(), missingLeft);
+                     gpu_param, missingLeft);
     // get the score/weight/id/gradSum for left and right child nodes
     bst_gpair lGradSum = missingLeft ? gradScan + missing : gradScan;
     bst_gpair rGradSum = n.sum_gradients - lGradSum;


@@ -240,6 +240,29 @@ __device__ inline float device_calc_loss_chg(const GPUTrainingParam& param,
   return left_gain + right_gain - parent_gain;
 }
+
+// Without constraints
+template <typename gpair_t>
+__device__ float inline loss_chg_missing(const gpair_t& scan,
+                                         const gpair_t& missing,
+                                         const gpair_t& parent_sum,
+                                         const float& parent_gain,
+                                         const GPUTrainingParam& param,
+                                         bool& missing_left_out) {  // NOLINT
+  float missing_left_loss =
+      device_calc_loss_chg(param, scan + missing, parent_sum, parent_gain);
+  float missing_right_loss =
+      device_calc_loss_chg(param, scan, parent_sum, parent_gain);
+  if (missing_left_loss >= missing_right_loss) {
+    missing_left_out = true;
+    return missing_left_loss;
+  } else {
+    missing_left_out = false;
+    return missing_right_loss;
+  }
+}
+
+// With constraints
 template <typename gpair_t>
 __device__ float inline loss_chg_missing(
     const gpair_t& scan, const gpair_t& missing, const gpair_t& parent_sum,
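A rough Python sketch of the unconstrained overload added above, assuming device_calc_loss_chg uses the usual gain G^2 / (H + lambda); all names here are illustrative:

```python
def gain(g, h, lam=1.0):
    # Gain of a leaf holding gradient sum g and hessian sum h.
    return g * g / (h + lam)

def loss_chg(left_g, left_h, parent_g, parent_h, parent_gain, lam=1.0):
    # Loss reduction from splitting the parent into (left, right).
    right_g, right_h = parent_g - left_g, parent_h - left_h
    return gain(left_g, left_h, lam) + gain(right_g, right_h, lam) - parent_gain

def loss_chg_missing(scan, missing, parent, parent_gain, lam=1.0):
    # scan/missing/parent are (grad_sum, hess_sum) pairs. Returns the loss
    # change and whether rows with missing values default to the left child.
    left = loss_chg(scan[0] + missing[0], scan[1] + missing[1],
                    parent[0], parent[1], parent_gain, lam)
    right = loss_chg(scan[0], scan[1], parent[0], parent[1], parent_gain, lam)
    return (left, True) if left >= right else (right, False)
```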


@@ -287,6 +287,10 @@ struct DeviceShard {
     size_t compressed_size_bytes =
         common::CompressedBufferWriter::CalculateBufferSize(
             ellpack_matrix.size(), num_symbols);
+    CHECK(!(param.max_leaves == 0 && param.max_depth == 0))
+        << "Max leaves and max depth cannot both be unconstrained for "
+           "gpu_hist.";
     int max_nodes =
         param.max_leaves > 0 ? param.max_leaves * 2 : n_nodes(param.max_depth);
     ba.allocate(device_idx, param.silent, &gidx_buffer, compressed_size_bytes,
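In user-facing terms, the new check rejects gpu_hist configurations in which max_depth and max_leaves are both zero (i.e. both unconstrained). A sketch of two parameter sets that satisfy it; the concrete values are arbitrary:

```python
# Depth-wise growth, bounded by max_depth.
params_depthwise = {'tree_method': 'gpu_hist', 'max_depth': 6}

# Loss-guided growth, bounded by max_leaves; max_depth may then be 0.
params_lossguide = {'tree_method': 'gpu_hist', 'grow_policy': 'lossguide',
                    'max_depth': 0, 'max_leaves': 255}
```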


@@ -1,35 +1,74 @@
 from __future__ import print_function

 import numpy as np
 import sys
 import unittest

 import xgboost as xgb
 from nose.plugins.attrib import attr

 rng = np.random.RandomState(1994)


 @attr('gpu')
 class TestGPUPredict(unittest.TestCase):
     def test_predict(self):
-        iterations = 1
+        iterations = 10
         np.random.seed(1)
         test_num_rows = [10, 1000, 5000]
         test_num_cols = [10, 50, 500]
         for num_rows in test_num_rows:
             for num_cols in test_num_cols:
-                dm = xgb.DMatrix(np.random.randn(num_rows, num_cols), label=[0, 1] * int(num_rows / 2))
-                watchlist = [(dm, 'train')]
+                dtrain = xgb.DMatrix(np.random.randn(num_rows, num_cols), label=[0, 1] * int(num_rows / 2))
+                dval = xgb.DMatrix(np.random.randn(num_rows, num_cols), label=[0, 1] * int(num_rows / 2))
+                dtest = xgb.DMatrix(np.random.randn(num_rows, num_cols), label=[0, 1] * int(num_rows / 2))
+                watchlist = [(dtrain, 'train'), (dval, 'validation')]
                 res = {}
                 param = {
                     "objective": "binary:logistic",
                     "predictor": "gpu_predictor",
                     'eval_metric': 'auc',
                 }
-                bst = xgb.train(param, dm, iterations, evals=watchlist, evals_result=res)
+                bst = xgb.train(param, dtrain, iterations, evals=watchlist, evals_result=res)
                 assert self.non_decreasing(res["train"]["auc"])
-                gpu_pred = bst.predict(dm, output_margin=True)
-                bst.set_param({"predictor": "cpu_predictor"})
-                cpu_pred = bst.predict(dm, output_margin=True)
-                np.testing.assert_allclose(cpu_pred, gpu_pred, rtol=1e-5)
+                gpu_pred_train = bst.predict(dtrain, output_margin=True)
+                gpu_pred_test = bst.predict(dtest, output_margin=True)
+                gpu_pred_val = bst.predict(dval, output_margin=True)
+                param["predictor"] = "cpu_predictor"
+                bst_cpu = xgb.train(param, dtrain, iterations, evals=watchlist)
+                cpu_pred_train = bst_cpu.predict(dtrain, output_margin=True)
+                cpu_pred_test = bst_cpu.predict(dtest, output_margin=True)
+                cpu_pred_val = bst_cpu.predict(dval, output_margin=True)
+                np.testing.assert_allclose(cpu_pred_train, gpu_pred_train, rtol=1e-5)
+                np.testing.assert_allclose(cpu_pred_val, gpu_pred_val, rtol=1e-5)
+                np.testing.assert_allclose(cpu_pred_test, gpu_pred_test, rtol=1e-5)

     def non_decreasing(self, L):
         return all((x - y) < 0.001 for x, y in zip(L, L[1:]))

+    # Test case for a bug where multiple batch predictions made on a test set
+    # produce incorrect results.
+    def test_multi_predict(self):
+        from sklearn.datasets import make_regression
+        from sklearn.cross_validation import train_test_split
+
+        n = 1000
+        X, y = make_regression(n, random_state=rng)
+        X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=123)
+        dtrain = xgb.DMatrix(X_train, label=y_train)
+        dtest = xgb.DMatrix(X_test)
+
+        params = {}
+        params["tree_method"] = "gpu_hist"
+        params['predictor'] = "gpu_predictor"
+        bst_gpu_predict = xgb.train(params, dtrain)
+        params['predictor'] = "cpu_predictor"
+        bst_cpu_predict = xgb.train(params, dtrain)
+
+        predict0 = bst_gpu_predict.predict(dtest)
+        predict1 = bst_gpu_predict.predict(dtest)
+        cpu_predict = bst_cpu_predict.predict(dtest)
+
+        assert np.allclose(predict0, predict1)
+        assert np.allclose(predict0, cpu_predict)