Fix several GPU bugs (#2916)
* Fix #2905 * Fix gpu_exact test failures * Fix bug in GPU prediction where multiple calls to batch prediction can produce incorrect results * Fix GPU documentation formatting
This commit is contained in:
parent
1e3aabbadc
commit
1b77903eeb
@ -12,13 +12,13 @@ Specify the 'tree_method' parameter as one of the following algorithms.
|
||||
### Algorithms
|
||||
|
||||
```eval_rst
|
||||
+--------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|
||||
+--------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| tree_method | Description |
|
||||
+==============+=================================================================================================================================================================================================================+
|
||||
+==============+=======================================================================================================================================================================+
|
||||
| gpu_exact | The standard XGBoost tree construction algorithm. Performs exact search for splits. Slower and uses considerably more memory than 'gpu_hist' |
|
||||
+--------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|
||||
+--------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| gpu_hist | Equivalent to the XGBoost fast histogram algorithm. Much faster and uses considerably less memory. NOTE: Will run very slowly on GPUs older than Pascal architecture. |
|
||||
+--------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|
||||
+--------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|
||||
```
|
||||
|
||||
### Supported parameters
|
||||
@ -27,28 +27,27 @@ Specify the 'tree_method' parameter as one of the following algorithms.
|
||||
.. |tick| unicode:: U+2714
|
||||
.. |cross| unicode:: U+2718
|
||||
|
||||
+--------------------+------------+-----------+
|
||||
+----------------------+------------+-----------+
|
||||
| parameter | gpu_exact | gpu_hist |
|
||||
+====================+============+===========+
|
||||
+======================+============+===========+
|
||||
| subsample | |cross| | |tick| |
|
||||
+--------------------+------------+-----------+
|
||||
+----------------------+------------+-----------+
|
||||
| colsample_bytree | |cross| | |tick| |
|
||||
+--------------------+------------+-----------+
|
||||
+----------------------+------------+-----------+
|
||||
| colsample_bylevel | |cross| | |tick| |
|
||||
+--------------------+------------+-----------+
|
||||
+----------------------+------------+-----------+
|
||||
| max_bin | |cross| | |tick| |
|
||||
+--------------------+------------+-----------+
|
||||
+----------------------+------------+-----------+
|
||||
| gpu_id | |tick| | |tick| |
|
||||
+--------------------+------------+-----------+
|
||||
+----------------------+------------+-----------+
|
||||
| n_gpus | |cross| | |tick| |
|
||||
+--------------------+------------+-----------+
|
||||
+----------------------+------------+-----------+
|
||||
| predictor | |tick| | |tick| |
|
||||
+--------------------+------------+-----------+
|
||||
+----------------------+------------+-----------+
|
||||
| grow_policy | |cross| | |tick| |
|
||||
+--------------------+------------+-----------+
|
||||
+----------------------+------------+-----------+
|
||||
| monotone_constraints | |cross| | |tick| |
|
||||
+--------------------+------------+-----------+
|
||||
|
||||
+----------------------+------------+-----------+
|
||||
```
|
||||
|
||||
GPU accelerated prediction is enabled by default for the above mentioned 'tree_method' parameters but can be switched to CPU prediction by setting 'predictor':'cpu_predictor'. This could be useful if you want to conserve GPU memory. Likewise when using CPU algorithms, GPU accelerated prediction can be enabled by setting 'predictor':'gpu_predictor'.
|
||||
|
||||
@ -292,11 +292,9 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
thrust::copy(model.tree_info.begin(), model.tree_info.end(),
|
||||
tree_group.begin());
|
||||
|
||||
if (device_matrix->predictions.size() != out_preds->size()) {
|
||||
device_matrix->predictions.resize(out_preds->size());
|
||||
thrust::copy(out_preds->begin(), out_preds->end(),
|
||||
device_matrix->predictions.begin());
|
||||
}
|
||||
|
||||
const int BLOCK_THREADS = 128;
|
||||
const int GRID_SIZE = static_cast<int>(
|
||||
|
||||
@ -336,8 +336,8 @@ struct XGBOOST_ALIGNAS(16) GradStats {
|
||||
this->Add(b.GetGrad(), b.GetHess());
|
||||
}
|
||||
/*! \brief calculate leaf weight */
|
||||
template <typename param_t>
|
||||
inline double CalcWeight(const param_t& param) const {
|
||||
template <typename param_t>
|
||||
XGBOOST_DEVICE inline double CalcWeight(const param_t ¶m) const {
|
||||
return xgboost::tree::CalcWeight(param, sum_grad, sum_hess);
|
||||
}
|
||||
/*! \brief calculate gain of the solution */
|
||||
|
||||
@ -302,7 +302,7 @@ DEV_INLINE void argMaxWithAtomics(
|
||||
ExactSplitCandidate s;
|
||||
bst_gpair missing = parentSum - colSum;
|
||||
s.score = loss_chg_missing(gradScans[id], missing, parentSum, parentGain,
|
||||
param, 0, ValueConstraint(), tmp);
|
||||
param, tmp);
|
||||
s.index = id;
|
||||
atomicArgMax(nodeSplits + uid, s);
|
||||
} // end if nodeId != UNUSED_NODE
|
||||
@ -580,7 +580,7 @@ class GPUMaker : public TreeUpdater {
|
||||
// get the default direction for the current node
|
||||
bst_gpair missing = n.sum_gradients - gradSum;
|
||||
loss_chg_missing(gradScan, missing, n.sum_gradients, n.root_gain,
|
||||
gpu_param, 0, ValueConstraint(), missingLeft);
|
||||
gpu_param, missingLeft);
|
||||
// get the score/weight/id/gradSum for left and right child nodes
|
||||
bst_gpair lGradSum = missingLeft ? gradScan + missing : gradScan;
|
||||
bst_gpair rGradSum = n.sum_gradients - lGradSum;
|
||||
|
||||
@ -240,6 +240,29 @@ __device__ inline float device_calc_loss_chg(const GPUTrainingParam& param,
|
||||
return left_gain + right_gain - parent_gain;
|
||||
}
|
||||
|
||||
// Without constraints
|
||||
template <typename gpair_t>
|
||||
__device__ float inline loss_chg_missing(const gpair_t& scan,
|
||||
const gpair_t& missing,
|
||||
const gpair_t& parent_sum,
|
||||
const float& parent_gain,
|
||||
const GPUTrainingParam& param,
|
||||
bool& missing_left_out) { // NOLINT
|
||||
float missing_left_loss =
|
||||
device_calc_loss_chg(param, scan + missing, parent_sum, parent_gain);
|
||||
float missing_right_loss =
|
||||
device_calc_loss_chg(param, scan, parent_sum, parent_gain);
|
||||
|
||||
if (missing_left_loss >= missing_right_loss) {
|
||||
missing_left_out = true;
|
||||
return missing_left_loss;
|
||||
} else {
|
||||
missing_left_out = false;
|
||||
return missing_right_loss;
|
||||
}
|
||||
}
|
||||
|
||||
// With constraints
|
||||
template <typename gpair_t>
|
||||
__device__ float inline loss_chg_missing(
|
||||
const gpair_t& scan, const gpair_t& missing, const gpair_t& parent_sum,
|
||||
|
||||
@ -287,6 +287,10 @@ struct DeviceShard {
|
||||
size_t compressed_size_bytes =
|
||||
common::CompressedBufferWriter::CalculateBufferSize(
|
||||
ellpack_matrix.size(), num_symbols);
|
||||
|
||||
CHECK(!(param.max_leaves == 0 && param.max_depth == 0))
|
||||
<< "Max leaves and max depth cannot both be unconstrained for "
|
||||
"gpu_hist.";
|
||||
int max_nodes =
|
||||
param.max_leaves > 0 ? param.max_leaves * 2 : n_nodes(param.max_depth);
|
||||
ba.allocate(device_idx, param.silent, &gidx_buffer, compressed_size_bytes,
|
||||
|
||||
@ -1,35 +1,74 @@
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import sys
|
||||
import unittest
|
||||
import xgboost as xgb
|
||||
from nose.plugins.attrib import attr
|
||||
|
||||
rng = np.random.RandomState(1994)
|
||||
|
||||
|
||||
@attr('gpu')
|
||||
class TestGPUPredict(unittest.TestCase):
|
||||
def test_predict(self):
|
||||
iterations = 1
|
||||
iterations = 10
|
||||
np.random.seed(1)
|
||||
test_num_rows = [10, 1000, 5000]
|
||||
test_num_cols = [10, 50, 500]
|
||||
for num_rows in test_num_rows:
|
||||
for num_cols in test_num_cols:
|
||||
dm = xgb.DMatrix(np.random.randn(num_rows, num_cols), label=[0, 1] * int(num_rows / 2))
|
||||
watchlist = [(dm, 'train')]
|
||||
dtrain = xgb.DMatrix(np.random.randn(num_rows, num_cols), label=[0, 1] * int(num_rows / 2))
|
||||
dval = xgb.DMatrix(np.random.randn(num_rows, num_cols), label=[0, 1] * int(num_rows / 2))
|
||||
dtest = xgb.DMatrix(np.random.randn(num_rows, num_cols), label=[0, 1] * int(num_rows / 2))
|
||||
watchlist = [(dtrain, 'train'), (dval, 'validation')]
|
||||
res = {}
|
||||
param = {
|
||||
"objective": "binary:logistic",
|
||||
"predictor": "gpu_predictor",
|
||||
'eval_metric': 'auc',
|
||||
}
|
||||
bst = xgb.train(param, dm, iterations, evals=watchlist, evals_result=res)
|
||||
bst = xgb.train(param, dtrain, iterations, evals=watchlist, evals_result=res)
|
||||
assert self.non_decreasing(res["train"]["auc"])
|
||||
gpu_pred = bst.predict(dm, output_margin=True)
|
||||
bst.set_param({"predictor": "cpu_predictor"})
|
||||
cpu_pred = bst.predict(dm, output_margin=True)
|
||||
np.testing.assert_allclose(cpu_pred, gpu_pred, rtol=1e-5)
|
||||
gpu_pred_train = bst.predict(dtrain, output_margin=True)
|
||||
gpu_pred_test = bst.predict(dtest, output_margin=True)
|
||||
gpu_pred_val = bst.predict(dval, output_margin=True)
|
||||
|
||||
param["predictor"] = "cpu_predictor"
|
||||
bst_cpu = xgb.train(param, dtrain, iterations, evals=watchlist)
|
||||
cpu_pred_train = bst_cpu.predict(dtrain, output_margin=True)
|
||||
cpu_pred_test = bst_cpu.predict(dtest, output_margin=True)
|
||||
cpu_pred_val = bst_cpu.predict(dval, output_margin=True)
|
||||
np.testing.assert_allclose(cpu_pred_train, gpu_pred_train, rtol=1e-5)
|
||||
np.testing.assert_allclose(cpu_pred_val, gpu_pred_val, rtol=1e-5)
|
||||
np.testing.assert_allclose(cpu_pred_test, gpu_pred_test, rtol=1e-5)
|
||||
|
||||
def non_decreasing(self, L):
|
||||
return all((x - y) < 0.001 for x, y in zip(L, L[1:]))
|
||||
|
||||
# Test case for a bug where multiple batch predictions made on a test set produce incorrect results
|
||||
def test_multi_predict(self):
|
||||
from sklearn.datasets import make_regression
|
||||
from sklearn.cross_validation import train_test_split
|
||||
|
||||
n = 1000
|
||||
X, y = make_regression(n, random_state=rng)
|
||||
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=123)
|
||||
dtrain = xgb.DMatrix(X_train, label=y_train)
|
||||
dtest = xgb.DMatrix(X_test)
|
||||
|
||||
params = {}
|
||||
params["tree_method"] = "gpu_hist"
|
||||
|
||||
params['predictor'] = "gpu_predictor"
|
||||
bst_gpu_predict = xgb.train(params, dtrain)
|
||||
|
||||
params['predictor'] = "cpu_predictor"
|
||||
bst_cpu_predict = xgb.train(params, dtrain)
|
||||
|
||||
predict0 = bst_gpu_predict.predict(dtest)
|
||||
predict1 = bst_gpu_predict.predict(dtest)
|
||||
cpu_predict = bst_cpu_predict.predict(dtest)
|
||||
|
||||
assert np.allclose(predict0, predict1)
|
||||
assert np.allclose(predict0, cpu_predict)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user