Fix several GPU bugs (#2916)
* Fix #2905
* Fix gpu_exact test failures
* Fix bug in GPU prediction where multiple calls to batch prediction can produce incorrect results
* Fix GPU documentation formatting
This commit is contained in:
parent
1e3aabbadc
commit
1b77903eeb
@ -12,13 +12,13 @@ Specify the 'tree_method' parameter as one of the following algorithms.
|
|||||||
### Algorithms
|
### Algorithms
|
||||||
|
|
||||||
```eval_rst
|
```eval_rst
|
||||||
+--------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|
+--------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|
||||||
| tree_method | Description |
|
| tree_method | Description |
|
||||||
+==============+=================================================================================================================================================================================================================+
|
+==============+=======================================================================================================================================================================+
|
||||||
| gpu_exact | The standard XGBoost tree construction algorithm. Performs exact search for splits. Slower and uses considerably more memory than 'gpu_hist' |
|
| gpu_exact | The standard XGBoost tree construction algorithm. Performs exact search for splits. Slower and uses considerably more memory than 'gpu_hist' |
|
||||||
+--------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|
+--------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|
||||||
| gpu_hist | Equivalent to the XGBoost fast histogram algorithm. Much faster and uses considerably less memory. NOTE: Will run very slowly on GPUs older than Pascal architecture. |
|
| gpu_hist | Equivalent to the XGBoost fast histogram algorithm. Much faster and uses considerably less memory. NOTE: Will run very slowly on GPUs older than Pascal architecture. |
|
||||||
+--------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|
+--------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|
||||||
```
|
```
|
||||||
|
|
||||||
### Supported parameters
|
### Supported parameters
|
||||||
@ -27,28 +27,27 @@ Specify the 'tree_method' parameter as one of the following algorithms.
|
|||||||
.. |tick| unicode:: U+2714
|
.. |tick| unicode:: U+2714
|
||||||
.. |cross| unicode:: U+2718
|
.. |cross| unicode:: U+2718
|
||||||
|
|
||||||
+--------------------+------------+-----------+
|
+----------------------+------------+-----------+
|
||||||
| parameter | gpu_exact | gpu_hist |
|
| parameter | gpu_exact | gpu_hist |
|
||||||
+====================+============+===========+
|
+======================+============+===========+
|
||||||
| subsample | |cross| | |tick| |
|
| subsample | |cross| | |tick| |
|
||||||
+--------------------+------------+-----------+
|
+----------------------+------------+-----------+
|
||||||
| colsample_bytree | |cross| | |tick| |
|
| colsample_bytree | |cross| | |tick| |
|
||||||
+--------------------+------------+-----------+
|
+----------------------+------------+-----------+
|
||||||
| colsample_bylevel | |cross| | |tick| |
|
| colsample_bylevel | |cross| | |tick| |
|
||||||
+--------------------+------------+-----------+
|
+----------------------+------------+-----------+
|
||||||
| max_bin | |cross| | |tick| |
|
| max_bin | |cross| | |tick| |
|
||||||
+--------------------+------------+-----------+
|
+----------------------+------------+-----------+
|
||||||
| gpu_id | |tick| | |tick| |
|
| gpu_id | |tick| | |tick| |
|
||||||
+--------------------+------------+-----------+
|
+----------------------+------------+-----------+
|
||||||
| n_gpus | |cross| | |tick| |
|
| n_gpus | |cross| | |tick| |
|
||||||
+--------------------+------------+-----------+
|
+----------------------+------------+-----------+
|
||||||
| predictor | |tick| | |tick| |
|
| predictor | |tick| | |tick| |
|
||||||
+--------------------+------------+-----------+
|
+----------------------+------------+-----------+
|
||||||
| grow_policy | |cross| | |tick| |
|
| grow_policy | |cross| | |tick| |
|
||||||
+--------------------+------------+-----------+
|
+----------------------+------------+-----------+
|
||||||
| monotone_constraints | |cross| | |tick| |
|
| monotone_constraints | |cross| | |tick| |
|
||||||
+--------------------+------------+-----------+
|
+----------------------+------------+-----------+
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
GPU accelerated prediction is enabled by default for the above mentioned 'tree_method' parameters but can be switched to CPU prediction by setting 'predictor':'cpu_predictor'. This could be useful if you want to conserve GPU memory. Likewise when using CPU algorithms, GPU accelerated prediction can be enabled by setting 'predictor':'gpu_predictor'.
|
GPU accelerated prediction is enabled by default for the above mentioned 'tree_method' parameters but can be switched to CPU prediction by setting 'predictor':'cpu_predictor'. This could be useful if you want to conserve GPU memory. Likewise when using CPU algorithms, GPU accelerated prediction can be enabled by setting 'predictor':'gpu_predictor'.
|
||||||
|
|||||||
@ -292,11 +292,9 @@ class GPUPredictor : public xgboost::Predictor {
|
|||||||
thrust::copy(model.tree_info.begin(), model.tree_info.end(),
|
thrust::copy(model.tree_info.begin(), model.tree_info.end(),
|
||||||
tree_group.begin());
|
tree_group.begin());
|
||||||
|
|
||||||
if (device_matrix->predictions.size() != out_preds->size()) {
|
|
||||||
device_matrix->predictions.resize(out_preds->size());
|
device_matrix->predictions.resize(out_preds->size());
|
||||||
thrust::copy(out_preds->begin(), out_preds->end(),
|
thrust::copy(out_preds->begin(), out_preds->end(),
|
||||||
device_matrix->predictions.begin());
|
device_matrix->predictions.begin());
|
||||||
}
|
|
||||||
|
|
||||||
const int BLOCK_THREADS = 128;
|
const int BLOCK_THREADS = 128;
|
||||||
const int GRID_SIZE = static_cast<int>(
|
const int GRID_SIZE = static_cast<int>(
|
||||||
|
|||||||
@ -336,8 +336,8 @@ struct XGBOOST_ALIGNAS(16) GradStats {
|
|||||||
this->Add(b.GetGrad(), b.GetHess());
|
this->Add(b.GetGrad(), b.GetHess());
|
||||||
}
|
}
|
||||||
/*! \brief calculate leaf weight */
|
/*! \brief calculate leaf weight */
|
||||||
template <typename param_t>
|
template <typename param_t>
|
||||||
inline double CalcWeight(const param_t& param) const {
|
XGBOOST_DEVICE inline double CalcWeight(const param_t &param) const {
|
||||||
return xgboost::tree::CalcWeight(param, sum_grad, sum_hess);
|
return xgboost::tree::CalcWeight(param, sum_grad, sum_hess);
|
||||||
}
|
}
|
||||||
/*! \brief calculate gain of the solution */
|
/*! \brief calculate gain of the solution */
|
||||||
|
|||||||
@ -302,7 +302,7 @@ DEV_INLINE void argMaxWithAtomics(
|
|||||||
ExactSplitCandidate s;
|
ExactSplitCandidate s;
|
||||||
bst_gpair missing = parentSum - colSum;
|
bst_gpair missing = parentSum - colSum;
|
||||||
s.score = loss_chg_missing(gradScans[id], missing, parentSum, parentGain,
|
s.score = loss_chg_missing(gradScans[id], missing, parentSum, parentGain,
|
||||||
param, 0, ValueConstraint(), tmp);
|
param, tmp);
|
||||||
s.index = id;
|
s.index = id;
|
||||||
atomicArgMax(nodeSplits + uid, s);
|
atomicArgMax(nodeSplits + uid, s);
|
||||||
} // end if nodeId != UNUSED_NODE
|
} // end if nodeId != UNUSED_NODE
|
||||||
@ -580,7 +580,7 @@ class GPUMaker : public TreeUpdater {
|
|||||||
// get the default direction for the current node
|
// get the default direction for the current node
|
||||||
bst_gpair missing = n.sum_gradients - gradSum;
|
bst_gpair missing = n.sum_gradients - gradSum;
|
||||||
loss_chg_missing(gradScan, missing, n.sum_gradients, n.root_gain,
|
loss_chg_missing(gradScan, missing, n.sum_gradients, n.root_gain,
|
||||||
gpu_param, 0, ValueConstraint(), missingLeft);
|
gpu_param, missingLeft);
|
||||||
// get the score/weight/id/gradSum for left and right child nodes
|
// get the score/weight/id/gradSum for left and right child nodes
|
||||||
bst_gpair lGradSum = missingLeft ? gradScan + missing : gradScan;
|
bst_gpair lGradSum = missingLeft ? gradScan + missing : gradScan;
|
||||||
bst_gpair rGradSum = n.sum_gradients - lGradSum;
|
bst_gpair rGradSum = n.sum_gradients - lGradSum;
|
||||||
|
|||||||
@ -240,6 +240,29 @@ __device__ inline float device_calc_loss_chg(const GPUTrainingParam& param,
|
|||||||
return left_gain + right_gain - parent_gain;
|
return left_gain + right_gain - parent_gain;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Without constraints
|
||||||
|
template <typename gpair_t>
|
||||||
|
__device__ float inline loss_chg_missing(const gpair_t& scan,
|
||||||
|
const gpair_t& missing,
|
||||||
|
const gpair_t& parent_sum,
|
||||||
|
const float& parent_gain,
|
||||||
|
const GPUTrainingParam& param,
|
||||||
|
bool& missing_left_out) { // NOLINT
|
||||||
|
float missing_left_loss =
|
||||||
|
device_calc_loss_chg(param, scan + missing, parent_sum, parent_gain);
|
||||||
|
float missing_right_loss =
|
||||||
|
device_calc_loss_chg(param, scan, parent_sum, parent_gain);
|
||||||
|
|
||||||
|
if (missing_left_loss >= missing_right_loss) {
|
||||||
|
missing_left_out = true;
|
||||||
|
return missing_left_loss;
|
||||||
|
} else {
|
||||||
|
missing_left_out = false;
|
||||||
|
return missing_right_loss;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// With constraints
|
||||||
template <typename gpair_t>
|
template <typename gpair_t>
|
||||||
__device__ float inline loss_chg_missing(
|
__device__ float inline loss_chg_missing(
|
||||||
const gpair_t& scan, const gpair_t& missing, const gpair_t& parent_sum,
|
const gpair_t& scan, const gpair_t& missing, const gpair_t& parent_sum,
|
||||||
|
|||||||
@ -287,6 +287,10 @@ struct DeviceShard {
|
|||||||
size_t compressed_size_bytes =
|
size_t compressed_size_bytes =
|
||||||
common::CompressedBufferWriter::CalculateBufferSize(
|
common::CompressedBufferWriter::CalculateBufferSize(
|
||||||
ellpack_matrix.size(), num_symbols);
|
ellpack_matrix.size(), num_symbols);
|
||||||
|
|
||||||
|
CHECK(!(param.max_leaves == 0 && param.max_depth == 0))
|
||||||
|
<< "Max leaves and max depth cannot both be unconstrained for "
|
||||||
|
"gpu_hist.";
|
||||||
int max_nodes =
|
int max_nodes =
|
||||||
param.max_leaves > 0 ? param.max_leaves * 2 : n_nodes(param.max_depth);
|
param.max_leaves > 0 ? param.max_leaves * 2 : n_nodes(param.max_depth);
|
||||||
ba.allocate(device_idx, param.silent, &gidx_buffer, compressed_size_bytes,
|
ba.allocate(device_idx, param.silent, &gidx_buffer, compressed_size_bytes,
|
||||||
|
|||||||
@ -1,35 +1,74 @@
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import sys
|
||||||
import unittest
|
import unittest
|
||||||
import xgboost as xgb
|
import xgboost as xgb
|
||||||
from nose.plugins.attrib import attr
|
from nose.plugins.attrib import attr
|
||||||
|
|
||||||
rng = np.random.RandomState(1994)
|
rng = np.random.RandomState(1994)
|
||||||
|
|
||||||
|
|
||||||
@attr('gpu')
|
@attr('gpu')
|
||||||
class TestGPUPredict(unittest.TestCase):
|
class TestGPUPredict(unittest.TestCase):
|
||||||
def test_predict(self):
|
def test_predict(self):
|
||||||
iterations = 1
|
iterations = 10
|
||||||
np.random.seed(1)
|
np.random.seed(1)
|
||||||
test_num_rows = [10, 1000, 5000]
|
test_num_rows = [10, 1000, 5000]
|
||||||
test_num_cols = [10, 50, 500]
|
test_num_cols = [10, 50, 500]
|
||||||
for num_rows in test_num_rows:
|
for num_rows in test_num_rows:
|
||||||
for num_cols in test_num_cols:
|
for num_cols in test_num_cols:
|
||||||
dm = xgb.DMatrix(np.random.randn(num_rows, num_cols), label=[0, 1] * int(num_rows / 2))
|
dtrain = xgb.DMatrix(np.random.randn(num_rows, num_cols), label=[0, 1] * int(num_rows / 2))
|
||||||
watchlist = [(dm, 'train')]
|
dval = xgb.DMatrix(np.random.randn(num_rows, num_cols), label=[0, 1] * int(num_rows / 2))
|
||||||
|
dtest = xgb.DMatrix(np.random.randn(num_rows, num_cols), label=[0, 1] * int(num_rows / 2))
|
||||||
|
watchlist = [(dtrain, 'train'), (dval, 'validation')]
|
||||||
res = {}
|
res = {}
|
||||||
param = {
|
param = {
|
||||||
"objective": "binary:logistic",
|
"objective": "binary:logistic",
|
||||||
"predictor": "gpu_predictor",
|
"predictor": "gpu_predictor",
|
||||||
'eval_metric': 'auc',
|
'eval_metric': 'auc',
|
||||||
}
|
}
|
||||||
bst = xgb.train(param, dm, iterations, evals=watchlist, evals_result=res)
|
bst = xgb.train(param, dtrain, iterations, evals=watchlist, evals_result=res)
|
||||||
assert self.non_decreasing(res["train"]["auc"])
|
assert self.non_decreasing(res["train"]["auc"])
|
||||||
gpu_pred = bst.predict(dm, output_margin=True)
|
gpu_pred_train = bst.predict(dtrain, output_margin=True)
|
||||||
bst.set_param({"predictor": "cpu_predictor"})
|
gpu_pred_test = bst.predict(dtest, output_margin=True)
|
||||||
cpu_pred = bst.predict(dm, output_margin=True)
|
gpu_pred_val = bst.predict(dval, output_margin=True)
|
||||||
np.testing.assert_allclose(cpu_pred, gpu_pred, rtol=1e-5)
|
|
||||||
|
param["predictor"] = "cpu_predictor"
|
||||||
|
bst_cpu = xgb.train(param, dtrain, iterations, evals=watchlist)
|
||||||
|
cpu_pred_train = bst_cpu.predict(dtrain, output_margin=True)
|
||||||
|
cpu_pred_test = bst_cpu.predict(dtest, output_margin=True)
|
||||||
|
cpu_pred_val = bst_cpu.predict(dval, output_margin=True)
|
||||||
|
np.testing.assert_allclose(cpu_pred_train, gpu_pred_train, rtol=1e-5)
|
||||||
|
np.testing.assert_allclose(cpu_pred_val, gpu_pred_val, rtol=1e-5)
|
||||||
|
np.testing.assert_allclose(cpu_pred_test, gpu_pred_test, rtol=1e-5)
|
||||||
|
|
||||||
def non_decreasing(self, L):
|
def non_decreasing(self, L):
|
||||||
return all((x - y) < 0.001 for x, y in zip(L, L[1:]))
|
return all((x - y) < 0.001 for x, y in zip(L, L[1:]))
|
||||||
|
|
||||||
|
# Test case for a bug where multiple batch predictions made on a test set produce incorrect results
|
||||||
|
def test_multi_predict(self):
|
||||||
|
from sklearn.datasets import make_regression
|
||||||
|
from sklearn.cross_validation import train_test_split
|
||||||
|
|
||||||
|
n = 1000
|
||||||
|
X, y = make_regression(n, random_state=rng)
|
||||||
|
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=123)
|
||||||
|
dtrain = xgb.DMatrix(X_train, label=y_train)
|
||||||
|
dtest = xgb.DMatrix(X_test)
|
||||||
|
|
||||||
|
params = {}
|
||||||
|
params["tree_method"] = "gpu_hist"
|
||||||
|
|
||||||
|
params['predictor'] = "gpu_predictor"
|
||||||
|
bst_gpu_predict = xgb.train(params, dtrain)
|
||||||
|
|
||||||
|
params['predictor'] = "cpu_predictor"
|
||||||
|
bst_cpu_predict = xgb.train(params, dtrain)
|
||||||
|
|
||||||
|
predict0 = bst_gpu_predict.predict(dtest)
|
||||||
|
predict1 = bst_gpu_predict.predict(dtest)
|
||||||
|
cpu_predict = bst_cpu_predict.predict(dtest)
|
||||||
|
|
||||||
|
assert np.allclose(predict0, predict1)
|
||||||
|
assert np.allclose(predict0, cpu_predict)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user