diff --git a/doc/gpu/index.md b/doc/gpu/index.md index a19767e1d..e72a01b2d 100644 --- a/doc/gpu/index.md +++ b/doc/gpu/index.md @@ -12,13 +12,13 @@ Specify the 'tree_method' parameter as one of the following algorithms. ### Algorithms ```eval_rst -+--------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| tree_method | Description | -+==============+=================================================================================================================================================================================================================+ -| gpu_exact | The standard XGBoost tree construction algorithm. Performs exact search for splits. Slower and uses considerably more memory than 'gpu_hist' | -+--------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| gpu_hist | Equivalent to the XGBoost fast histogram algorithm. Much faster and uses considerably less memory. NOTE: Will run very slowly on GPUs older than Pascal architecture. | -+--------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ ++--------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| tree_method | Description | ++==============+=======================================================================================================================================================================+ +| gpu_exact | The standard XGBoost tree construction algorithm. Performs exact search for splits. Slower and uses considerably more memory than 'gpu_hist' | ++--------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| gpu_hist | Equivalent to the XGBoost fast histogram algorithm. Much faster and uses considerably less memory. NOTE: Will run very slowly on GPUs older than Pascal architecture. | ++--------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------+ ``` ### Supported parameters @@ -27,28 +27,27 @@ Specify the 'tree_method' parameter as one of the following algorithms. .. |tick| unicode:: U+2714 .. |cross| unicode:: U+2718 -+--------------------+------------+-----------+ -| parameter | gpu_exact | gpu_hist | -+====================+============+===========+ -| subsample | |cross| | |tick| | -+--------------------+------------+-----------+ -| colsample_bytree | |cross| | |tick| | -+--------------------+------------+-----------+ -| colsample_bylevel | |cross| | |tick| | -+--------------------+------------+-----------+ -| max_bin | |cross| | |tick| | -+--------------------+------------+-----------+ -| gpu_id | |tick| | |tick| | -+--------------------+------------+-----------+ -| n_gpus | |cross| | |tick| | -+--------------------+------------+-----------+ -| predictor | |tick| | |tick| | -+--------------------+------------+-----------+ -| grow_policy | |cross| | |tick| | -+--------------------+------------+-----------+ -| monotone_constraints | |cross| | |tick| | -+--------------------+------------+-----------+ - ++----------------------+------------+-----------+ +| parameter | gpu_exact | gpu_hist | ++======================+============+===========+ +| subsample | |cross| | |tick| | ++----------------------+------------+-----------+ +| colsample_bytree | |cross| | |tick| | ++----------------------+------------+-----------+ +| colsample_bylevel | |cross| | |tick| | ++----------------------+------------+-----------+ +| max_bin | |cross| | |tick| | ++----------------------+------------+-----------+ +| gpu_id | |tick| | |tick| | ++----------------------+------------+-----------+ +| n_gpus | |cross| | |tick| | ++----------------------+------------+-----------+ +| predictor | |tick| | |tick| | ++----------------------+------------+-----------+ +| grow_policy | |cross| | |tick| | ++----------------------+------------+-----------+ +| monotone_constraints | |cross| | |tick| | ++----------------------+------------+-----------+ ``` GPU accelerated prediction is enabled by default for the above mentioned 'tree_method' parameters but can be switched to CPU prediction by setting 'predictor':'cpu_predictor'. This could be useful if you want to conserve GPU memory. Likewise when using CPU algorithms, GPU accelerated prediction can be enabled by setting 'predictor':'gpu_predictor'. diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index a6f93ec71..f4f9f8cc6 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -292,11 +292,9 @@ class GPUPredictor : public xgboost::Predictor { thrust::copy(model.tree_info.begin(), model.tree_info.end(), tree_group.begin()); - if (device_matrix->predictions.size() != out_preds->size()) { - device_matrix->predictions.resize(out_preds->size()); - thrust::copy(out_preds->begin(), out_preds->end(), - device_matrix->predictions.begin()); - } + device_matrix->predictions.resize(out_preds->size()); + thrust::copy(out_preds->begin(), out_preds->end(), + device_matrix->predictions.begin()); const int BLOCK_THREADS = 128; const int GRID_SIZE = static_cast( diff --git a/src/tree/param.h b/src/tree/param.h index 6eddcc9d4..59c24eefa 100644 --- a/src/tree/param.h +++ b/src/tree/param.h @@ -336,8 +336,8 @@ struct XGBOOST_ALIGNAS(16) GradStats { this->Add(b.GetGrad(), b.GetHess()); } /*! \brief calculate leaf weight */ -template - inline double CalcWeight(const param_t& param) const { + template + XGBOOST_DEVICE inline double CalcWeight(const param_t ¶m) const { return xgboost::tree::CalcWeight(param, sum_grad, sum_hess); } /*! \brief calculate gain of the solution */ diff --git a/src/tree/updater_gpu.cu b/src/tree/updater_gpu.cu index e7a8285f0..50a466539 100644 --- a/src/tree/updater_gpu.cu +++ b/src/tree/updater_gpu.cu @@ -302,7 +302,7 @@ DEV_INLINE void argMaxWithAtomics( ExactSplitCandidate s; bst_gpair missing = parentSum - colSum; s.score = loss_chg_missing(gradScans[id], missing, parentSum, parentGain, - param, 0, ValueConstraint(), tmp); + param, tmp); s.index = id; atomicArgMax(nodeSplits + uid, s); } // end if nodeId != UNUSED_NODE @@ -580,7 +580,7 @@ class GPUMaker : public TreeUpdater { // get the default direction for the current node bst_gpair missing = n.sum_gradients - gradSum; loss_chg_missing(gradScan, missing, n.sum_gradients, n.root_gain, - gpu_param, 0, ValueConstraint(), missingLeft); + gpu_param, missingLeft); // get the score/weight/id/gradSum for left and right child nodes bst_gpair lGradSum = missingLeft ? gradScan + missing : gradScan; bst_gpair rGradSum = n.sum_gradients - lGradSum; diff --git a/src/tree/updater_gpu_common.cuh b/src/tree/updater_gpu_common.cuh index 462ae7f2a..63d5f98ef 100644 --- a/src/tree/updater_gpu_common.cuh +++ b/src/tree/updater_gpu_common.cuh @@ -240,6 +240,29 @@ __device__ inline float device_calc_loss_chg(const GPUTrainingParam& param, return left_gain + right_gain - parent_gain; } +// Without constraints +template +__device__ float inline loss_chg_missing(const gpair_t& scan, + const gpair_t& missing, + const gpair_t& parent_sum, + const float& parent_gain, + const GPUTrainingParam& param, + bool& missing_left_out) { // NOLINT + float missing_left_loss = + device_calc_loss_chg(param, scan + missing, parent_sum, parent_gain); + float missing_right_loss = + device_calc_loss_chg(param, scan, parent_sum, parent_gain); + + if (missing_left_loss >= missing_right_loss) { + missing_left_out = true; + return missing_left_loss; + } else { + missing_left_out = false; + return missing_right_loss; + } +} + +// With constraints template __device__ float inline loss_chg_missing( const gpair_t& scan, const gpair_t& missing, const gpair_t& parent_sum, diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 4e9235d31..3e3700a10 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -287,6 +287,10 @@ struct DeviceShard { size_t compressed_size_bytes = common::CompressedBufferWriter::CalculateBufferSize( ellpack_matrix.size(), num_symbols); + + CHECK(!(param.max_leaves == 0 && param.max_depth == 0)) + << "Max leaves and max depth cannot both be unconstrained for " + "gpu_hist."; int max_nodes = param.max_leaves > 0 ? param.max_leaves * 2 : n_nodes(param.max_depth); ba.allocate(device_idx, param.silent, &gidx_buffer, compressed_size_bytes, diff --git a/tests/python-gpu/test_gpu_prediction.py b/tests/python-gpu/test_gpu_prediction.py index 3322535dc..0b8d5b0ef 100644 --- a/tests/python-gpu/test_gpu_prediction.py +++ b/tests/python-gpu/test_gpu_prediction.py @@ -1,35 +1,74 @@ from __future__ import print_function import numpy as np +import sys import unittest import xgboost as xgb from nose.plugins.attrib import attr rng = np.random.RandomState(1994) + @attr('gpu') class TestGPUPredict(unittest.TestCase): def test_predict(self): - iterations = 1 + iterations = 10 np.random.seed(1) test_num_rows = [10, 1000, 5000] test_num_cols = [10, 50, 500] for num_rows in test_num_rows: for num_cols in test_num_cols: - dm = xgb.DMatrix(np.random.randn(num_rows, num_cols), label=[0, 1] * int(num_rows / 2)) - watchlist = [(dm, 'train')] + dtrain = xgb.DMatrix(np.random.randn(num_rows, num_cols), label=[0, 1] * int(num_rows / 2)) + dval = xgb.DMatrix(np.random.randn(num_rows, num_cols), label=[0, 1] * int(num_rows / 2)) + dtest = xgb.DMatrix(np.random.randn(num_rows, num_cols), label=[0, 1] * int(num_rows / 2)) + watchlist = [(dtrain, 'train'), (dval, 'validation')] res = {} param = { "objective": "binary:logistic", "predictor": "gpu_predictor", 'eval_metric': 'auc', } - bst = xgb.train(param, dm, iterations, evals=watchlist, evals_result=res) + bst = xgb.train(param, dtrain, iterations, evals=watchlist, evals_result=res) assert self.non_decreasing(res["train"]["auc"]) - gpu_pred = bst.predict(dm, output_margin=True) - bst.set_param({"predictor": "cpu_predictor"}) - cpu_pred = bst.predict(dm, output_margin=True) - np.testing.assert_allclose(cpu_pred, gpu_pred, rtol=1e-5) + gpu_pred_train = bst.predict(dtrain, output_margin=True) + gpu_pred_test = bst.predict(dtest, output_margin=True) + gpu_pred_val = bst.predict(dval, output_margin=True) + + param["predictor"] = "cpu_predictor" + bst_cpu = xgb.train(param, dtrain, iterations, evals=watchlist) + cpu_pred_train = bst_cpu.predict(dtrain, output_margin=True) + cpu_pred_test = bst_cpu.predict(dtest, output_margin=True) + cpu_pred_val = bst_cpu.predict(dval, output_margin=True) + np.testing.assert_allclose(cpu_pred_train, gpu_pred_train, rtol=1e-5) + np.testing.assert_allclose(cpu_pred_val, gpu_pred_val, rtol=1e-5) + np.testing.assert_allclose(cpu_pred_test, gpu_pred_test, rtol=1e-5) def non_decreasing(self, L): return all((x - y) < 0.001 for x, y in zip(L, L[1:])) + + # Test case for a bug where multiple batch predictions made on a test set produce incorrect results + def test_multi_predict(self): + from sklearn.datasets import make_regression + from sklearn.cross_validation import train_test_split + + n = 1000 + X, y = make_regression(n, random_state=rng) + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=123) + dtrain = xgb.DMatrix(X_train, label=y_train) + dtest = xgb.DMatrix(X_test) + + params = {} + params["tree_method"] = "gpu_hist" + + params['predictor'] = "gpu_predictor" + bst_gpu_predict = xgb.train(params, dtrain) + + params['predictor'] = "cpu_predictor" + bst_cpu_predict = xgb.train(params, dtrain) + + predict0 = bst_gpu_predict.predict(dtest) + predict1 = bst_gpu_predict.predict(dtest) + cpu_predict = bst_cpu_predict.predict(dtest) + + assert np.allclose(predict0, predict1) + assert np.allclose(predict0, cpu_predict)