Fix several GPU bugs (#2916)

* Fix #2905

* Fix gpu_exact test failures

* Fix bug in GPU prediction where multiple calls to batch prediction can produce incorrect results

* Fix GPU documentation formatting
Rory Mitchell 2017-12-04 08:27:49 +13:00 committed by GitHub
parent 1e3aabbadc
commit 1b77903eeb
7 changed files with 109 additions and 46 deletions


@@ -12,13 +12,13 @@ Specify the 'tree_method' parameter as one of the following algorithms.
 ### Algorithms
 ```eval_rst
-+--------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| tree_method  | Description                                                                                                                                                                                                      |
-+==============+==================================================================================================================================================================================================================+
-| gpu_exact    | The standard XGBoost tree construction algorithm. Performs exact search for splits. Slower and uses considerably more memory than 'gpu_hist'                                                                     |
-+--------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| gpu_hist     | Equivalent to the XGBoost fast histogram algorithm. Much faster and uses considerably less memory. NOTE: Will run very slowly on GPUs older than Pascal architecture.                                            |
-+--------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
++--------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
+| tree_method  | Description                                                                                                                                                            |
++==============+========================================================================================================================================================================+
+| gpu_exact    | The standard XGBoost tree construction algorithm. Performs exact search for splits. Slower and uses considerably more memory than 'gpu_hist'                           |
++--------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
+| gpu_hist     | Equivalent to the XGBoost fast histogram algorithm. Much faster and uses considerably less memory. NOTE: Will run very slowly on GPUs older than Pascal architecture.  |
++--------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
 ```
 ### Supported parameters
@@ -27,28 +27,27 @@ Specify the 'tree_method' parameter as one of the following algorithms.
 .. |tick| unicode:: U+2714
 .. |cross| unicode:: U+2718
-+--------------------+------------+-----------+
-| parameter          | gpu_exact  | gpu_hist  |
-+====================+============+===========+
-| subsample          | |cross|    | |tick|    |
-+--------------------+------------+-----------+
-| colsample_bytree   | |cross|    | |tick|    |
-+--------------------+------------+-----------+
-| colsample_bylevel  | |cross|    | |tick|    |
-+--------------------+------------+-----------+
-| max_bin            | |cross|    | |tick|    |
-+--------------------+------------+-----------+
-| gpu_id             | |tick|     | |tick|    |
-+--------------------+------------+-----------+
-| n_gpus             | |cross|    | |tick|    |
-+--------------------+------------+-----------+
-| predictor          | |tick|     | |tick|    |
-+--------------------+------------+-----------+
-| grow_policy        | |cross|    | |tick|    |
-+--------------------+------------+-----------+
-| monotone_constraints | |cross|  | |tick|    |
-+--------------------+------------+-----------+
++----------------------+------------+-----------+
+| parameter            | gpu_exact  | gpu_hist  |
++======================+============+===========+
+| subsample            | |cross|    | |tick|    |
++----------------------+------------+-----------+
+| colsample_bytree     | |cross|    | |tick|    |
++----------------------+------------+-----------+
+| colsample_bylevel    | |cross|    | |tick|    |
++----------------------+------------+-----------+
+| max_bin              | |cross|    | |tick|    |
++----------------------+------------+-----------+
+| gpu_id               | |tick|     | |tick|    |
++----------------------+------------+-----------+
+| n_gpus               | |cross|    | |tick|    |
++----------------------+------------+-----------+
+| predictor            | |tick|     | |tick|    |
++----------------------+------------+-----------+
+| grow_policy          | |cross|    | |tick|    |
++----------------------+------------+-----------+
+| monotone_constraints | |cross|    | |tick|    |
++----------------------+------------+-----------+
 ```
 GPU-accelerated prediction is enabled by default for the above-mentioned 'tree_method' parameters, but can be switched to CPU prediction by setting 'predictor':'cpu_predictor'. This can be useful if you want to conserve GPU memory. Likewise, when using CPU algorithms, GPU-accelerated prediction can be enabled by setting 'predictor':'gpu_predictor'.
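
As a minimal illustration of switching predictors (a hedged sketch with synthetic data; `train`, `set_param`, and `predict` are the standard XGBoost Python APIs):

```python
import numpy as np
import xgboost as xgb

# Synthetic data, for illustration only.
X = np.random.randn(1000, 10)
y = np.random.randint(2, size=1000)
dtrain = xgb.DMatrix(X, label=y)

# GPU training; 'gpu_predictor' is the default for GPU tree methods.
param = {'objective': 'binary:logistic', 'tree_method': 'gpu_hist'}
bst = xgb.train(param, dtrain, num_boost_round=10)

# Switch to CPU prediction, e.g. to conserve GPU memory.
bst.set_param({'predictor': 'cpu_predictor'})
preds = bst.predict(dtrain)
```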


@@ -292,11 +292,9 @@ class GPUPredictor : public xgboost::Predictor {
     thrust::copy(model.tree_info.begin(), model.tree_info.end(),
                  tree_group.begin());
-    if (device_matrix->predictions.size() != out_preds->size()) {
-      device_matrix->predictions.resize(out_preds->size());
-      thrust::copy(out_preds->begin(), out_preds->end(),
-                   device_matrix->predictions.begin());
-    }
+    device_matrix->predictions.resize(out_preds->size());
+    thrust::copy(out_preds->begin(), out_preds->end(),
+                 device_matrix->predictions.begin());
     const int BLOCK_THREADS = 128;
     const int GRID_SIZE = static_cast<int>(


@ -336,8 +336,8 @@ struct XGBOOST_ALIGNAS(16) GradStats {
this->Add(b.GetGrad(), b.GetHess()); this->Add(b.GetGrad(), b.GetHess());
} }
/*! \brief calculate leaf weight */ /*! \brief calculate leaf weight */
template <typename param_t> template <typename param_t>
inline double CalcWeight(const param_t& param) const { XGBOOST_DEVICE inline double CalcWeight(const param_t &param) const {
return xgboost::tree::CalcWeight(param, sum_grad, sum_hess); return xgboost::tree::CalcWeight(param, sum_grad, sum_hess);
} }
/*! \brief calculate gain of the solution */ /*! \brief calculate gain of the solution */
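
For reference, the quantity `CalcWeight` evaluates is the standard XGBoost leaf weight; a host-side Python sketch (ignoring the L1 term and max_delta_step clipping; names are illustrative):

```python
def calc_weight(sum_grad, sum_hess, reg_lambda=1.0, min_child_weight=1.0):
    """Leaf weight w = -G / (H + lambda), as in xgboost::tree::CalcWeight."""
    # Nodes with too little hessian mass get a zero weight.
    if sum_hess < min_child_weight:
        return 0.0
    # Ridge-regularised Newton step.
    return -sum_grad / (sum_hess + reg_lambda)
```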


@@ -302,7 +302,7 @@ DEV_INLINE void argMaxWithAtomics(
     ExactSplitCandidate s;
     bst_gpair missing = parentSum - colSum;
     s.score = loss_chg_missing(gradScans[id], missing, parentSum, parentGain,
-                               param, 0, ValueConstraint(), tmp);
+                               param, tmp);
     s.index = id;
     atomicArgMax(nodeSplits + uid, s);
   }  // end if nodeId != UNUSED_NODE
@@ -580,7 +580,7 @@ class GPUMaker : public TreeUpdater {
     // get the default direction for the current node
     bst_gpair missing = n.sum_gradients - gradSum;
     loss_chg_missing(gradScan, missing, n.sum_gradients, n.root_gain,
-                     gpu_param, 0, ValueConstraint(), missingLeft);
+                     gpu_param, missingLeft);
     // get the score/weight/id/gradSum for left and right child nodes
     bst_gpair lGradSum = missingLeft ? gradScan + missing : gradScan;
     bst_gpair rGradSum = n.sum_gradients - lGradSum;


@@ -240,6 +240,29 @@ __device__ inline float device_calc_loss_chg(const GPUTrainingParam& param,
   return left_gain + right_gain - parent_gain;
 }
 
+// Without constraints
+template <typename gpair_t>
+__device__ float inline loss_chg_missing(const gpair_t& scan,
+                                         const gpair_t& missing,
+                                         const gpair_t& parent_sum,
+                                         const float& parent_gain,
+                                         const GPUTrainingParam& param,
+                                         bool& missing_left_out) {  // NOLINT
+  float missing_left_loss =
+      device_calc_loss_chg(param, scan + missing, parent_sum, parent_gain);
+  float missing_right_loss =
+      device_calc_loss_chg(param, scan, parent_sum, parent_gain);
+  if (missing_left_loss >= missing_right_loss) {
+    missing_left_out = true;
+    return missing_left_loss;
+  } else {
+    missing_left_out = false;
+    return missing_right_loss;
+  }
+}
+
+// With constraints
 template <typename gpair_t>
 __device__ float inline loss_chg_missing(
     const gpair_t& scan, const gpair_t& missing, const gpair_t& parent_sum,
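
For intuition, the unconstrained overload above can be mirrored on the host; a Python sketch of the same try-both-directions rule (assuming XGBoost's usual gain term G^2 / (H + lambda); all names are illustrative):

```python
REG_LAMBDA = 1.0  # L2 regularisation, XGBoost's 'lambda' parameter

def calc_gain(grad, hess):
    # Standard XGBoost gain term: G^2 / (H + lambda).
    return grad * grad / (hess + REG_LAMBDA)

def loss_chg(left, parent, parent_gain):
    # Mirrors device_calc_loss_chg: gain(left) + gain(right) - parent gain.
    right = (parent[0] - left[0], parent[1] - left[1])
    return calc_gain(*left) + calc_gain(*right) - parent_gain

def loss_chg_missing(scan, missing, parent_sum, parent_gain):
    """scan/missing/parent_sum are (grad, hess) pairs.

    Returns (best_loss_change, missing_goes_left): the missing-value
    bucket is tried on each side of the split and the better side wins.
    """
    left_with_missing = (scan[0] + missing[0], scan[1] + missing[1])
    missing_left = loss_chg(left_with_missing, parent_sum, parent_gain)
    missing_right = loss_chg(scan, parent_sum, parent_gain)
    if missing_left >= missing_right:
        return missing_left, True
    return missing_right, False
```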


@@ -287,6 +287,10 @@ struct DeviceShard {
     size_t compressed_size_bytes =
         common::CompressedBufferWriter::CalculateBufferSize(
             ellpack_matrix.size(), num_symbols);
+    CHECK(!(param.max_leaves == 0 && param.max_depth == 0))
+        << "Max leaves and max depth cannot both be unconstrained for "
+           "gpu_hist.";
+
     int max_nodes =
         param.max_leaves > 0 ? param.max_leaves * 2 : n_nodes(param.max_depth);
     ba.allocate(device_idx, param.silent, &gidx_buffer, compressed_size_bytes,
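
For reference, a sketch of a `gpu_hist` configuration that satisfies this check (synthetic data; `grow_policy`, `max_depth`, and `max_leaves` are standard XGBoost parameters):

```python
import numpy as np
import xgboost as xgb

# Synthetic regression data, for illustration only.
X = np.random.randn(500, 10)
y = np.random.randn(500)
dtrain = xgb.DMatrix(X, label=y)

# At least one of max_depth / max_leaves must bound tree size so the
# updater can pre-allocate node storage.
params = {
    'tree_method': 'gpu_hist',
    'grow_policy': 'lossguide',  # grow leaf-wise...
    'max_depth': 0,              # ...with depth unconstrained,
    'max_leaves': 255,           # so max_leaves must be set
}
bst = xgb.train(params, dtrain, num_boost_round=10)
```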


@@ -1,35 +1,74 @@
 from __future__ import print_function
 import numpy as np
+import sys
 import unittest
 import xgboost as xgb
 from nose.plugins.attrib import attr
 
 rng = np.random.RandomState(1994)
 
 
 @attr('gpu')
 class TestGPUPredict(unittest.TestCase):
     def test_predict(self):
-        iterations = 1
+        iterations = 10
         np.random.seed(1)
         test_num_rows = [10, 1000, 5000]
         test_num_cols = [10, 50, 500]
         for num_rows in test_num_rows:
             for num_cols in test_num_cols:
-                dm = xgb.DMatrix(np.random.randn(num_rows, num_cols), label=[0, 1] * int(num_rows / 2))
-                watchlist = [(dm, 'train')]
+                dtrain = xgb.DMatrix(np.random.randn(num_rows, num_cols), label=[0, 1] * int(num_rows / 2))
+                dval = xgb.DMatrix(np.random.randn(num_rows, num_cols), label=[0, 1] * int(num_rows / 2))
+                dtest = xgb.DMatrix(np.random.randn(num_rows, num_cols), label=[0, 1] * int(num_rows / 2))
+                watchlist = [(dtrain, 'train'), (dval, 'validation')]
                 res = {}
                 param = {
                     "objective": "binary:logistic",
                     "predictor": "gpu_predictor",
                     'eval_metric': 'auc',
                 }
-                bst = xgb.train(param, dm, iterations, evals=watchlist, evals_result=res)
+                bst = xgb.train(param, dtrain, iterations, evals=watchlist, evals_result=res)
                 assert self.non_decreasing(res["train"]["auc"])
-                gpu_pred = bst.predict(dm, output_margin=True)
-                bst.set_param({"predictor": "cpu_predictor"})
-                cpu_pred = bst.predict(dm, output_margin=True)
-                np.testing.assert_allclose(cpu_pred, gpu_pred, rtol=1e-5)
+
+                gpu_pred_train = bst.predict(dtrain, output_margin=True)
+                gpu_pred_test = bst.predict(dtest, output_margin=True)
+                gpu_pred_val = bst.predict(dval, output_margin=True)
+
+                param["predictor"] = "cpu_predictor"
+                bst_cpu = xgb.train(param, dtrain, iterations, evals=watchlist)
+                cpu_pred_train = bst_cpu.predict(dtrain, output_margin=True)
+                cpu_pred_test = bst_cpu.predict(dtest, output_margin=True)
+                cpu_pred_val = bst_cpu.predict(dval, output_margin=True)
+
+                np.testing.assert_allclose(cpu_pred_train, gpu_pred_train, rtol=1e-5)
+                np.testing.assert_allclose(cpu_pred_val, gpu_pred_val, rtol=1e-5)
+                np.testing.assert_allclose(cpu_pred_test, gpu_pred_test, rtol=1e-5)
 
     def non_decreasing(self, L):
         return all((x - y) < 0.001 for x, y in zip(L, L[1:]))
 
+    # Test case for a bug where multiple batch predictions made on a test set produce incorrect results
+    def test_multi_predict(self):
+        from sklearn.datasets import make_regression
+        from sklearn.cross_validation import train_test_split
+
+        n = 1000
+        X, y = make_regression(n, random_state=rng)
+        X_train, X_test, y_train, y_test = train_test_split(X, y,
+                                                            random_state=123)
+        dtrain = xgb.DMatrix(X_train, label=y_train)
+        dtest = xgb.DMatrix(X_test)
+
+        params = {}
+        params["tree_method"] = "gpu_hist"
+        params['predictor'] = "gpu_predictor"
+        bst_gpu_predict = xgb.train(params, dtrain)
+        params['predictor'] = "cpu_predictor"
+        bst_cpu_predict = xgb.train(params, dtrain)
+
+        predict0 = bst_gpu_predict.predict(dtest)
+        predict1 = bst_gpu_predict.predict(dtest)
+        cpu_predict = bst_cpu_predict.predict(dtest)
+
+        assert np.allclose(predict0, predict1)
+        assert np.allclose(predict0, cpu_predict)