Add GPU documentation (#2695)

* Add GPU documentation

* Update Python GPU tests
Rory Mitchell 2017-09-10 19:42:46 +12:00 committed by GitHub
parent e6a9063344
commit 9c85903f0b
4 changed files with 176 additions and 82 deletions

doc/gpu/index.md (new file, +102 lines)

@@ -0,0 +1,102 @@
XGBoost GPU Support
===================
This page contains information about GPU algorithms supported in XGBoost.
To install GPU support, check out the [build and installation instructions](../build.md).
# CUDA Accelerated Tree Construction Algorithms
This plugin adds GPU accelerated tree construction and prediction algorithms to XGBoost.
## Usage
Specify the 'tree_method' parameter as one of the following algorithms.
### Algorithms
```eval_rst
+--------------+-----------------------------------------------------------------------------------------------------------------------------------------------+
| tree_method | Description |
+==============+===============================================================================================================================================+
| gpu_exact | The standard XGBoost tree construction algorithm. Performs exact search for splits. Slower and uses considerably more memory than 'gpu_hist' |
+--------------+-----------------------------------------------------------------------------------------------------------------------------------------------+
| gpu_hist | Equivalent to the XGBoost fast histogram algorithm. Faster and uses considerably less memory. Splits may be less accurate. |
+--------------+-----------------------------------------------------------------------------------------------------------------------------------------------+
```
### Supported parameters
```eval_rst
.. |tick| unicode:: U+2714
.. |cross| unicode:: U+2718
+--------------------+------------+-----------+
| parameter | gpu_exact | gpu_hist |
+====================+============+===========+
| subsample | |cross| | |tick| |
+--------------------+------------+-----------+
| colsample_bytree | |cross| | |tick| |
+--------------------+------------+-----------+
| colsample_bylevel | |cross| | |tick| |
+--------------------+------------+-----------+
| max_bin | |cross| | |tick| |
+--------------------+------------+-----------+
| gpu_id | |tick| | |tick| |
+--------------------+------------+-----------+
| n_gpus | |cross| | |tick| |
+--------------------+------------+-----------+
| predictor | |tick| | |tick| |
+--------------------+------------+-----------+
```
GPU accelerated prediction is enabled by default for the above-mentioned 'tree_method' settings, but can be switched to CPU prediction by setting 'predictor':'cpu_predictor'. This could be useful if you want to conserve GPU memory. Likewise, when using CPU algorithms, GPU accelerated prediction can be enabled by setting 'predictor':'gpu_predictor'.
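For example, a minimal sketch of switching an already trained booster to CPU prediction (assuming `bst` is a trained `xgboost.Booster` and `dtest` is a `DMatrix`):
```python
# `bst` was trained with a GPU tree_method, so prediction runs on the GPU by default.
gpu_pred = bst.predict(dtest)

# Switch subsequent predictions to the CPU, e.g. to conserve GPU memory.
bst.set_param({'predictor': 'cpu_predictor'})
cpu_pred = bst.predict(dtest)
```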
The device ordinal can be selected using the 'gpu_id' parameter, which defaults to 0.
Multiple GPUs can be used with the 'gpu_hist' method via the 'n_gpus' parameter, which defaults to 1. If it is set to -1, all available GPUs are used. If 'gpu_id' is non-zero, the devices used are (gpu_id + i) % n_visible_devices for i = 0 to n_gpus - 1, as illustrated below. As with GPU vs. CPU, multi-GPU training is not always faster than a single GPU, because PCIe bus bandwidth can limit performance. For example, when n_features * n_bins * 2^depth divided by the time of each boosting round becomes comparable to the effective PCIe 16x bandwidth (on the order of 4 GB/s to 10 GB/s), the AllReduce step dominates running time and additional GPUs become ineffective at increasing performance. CPU overhead between GPU calls can also limit the usefulness of multiple GPUs.
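To illustrate the device selection rule above, a small sketch with hypothetical values:
```python
# Hypothetical setup: 4 visible devices, gpu_id=2, n_gpus=3.
n_visible_devices = 4
gpu_id, n_gpus = 2, 3
devices = [(gpu_id + i) % n_visible_devices for i in range(n_gpus)]
print(devices)  # [2, 3, 0]
```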
This plugin currently works with the CLI and Python interfaces.
Python example:
```python
param['gpu_id'] = 0
param['max_bin'] = 16
param['tree_method'] = 'gpu_hist'
```
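A slightly fuller sketch, assuming the bundled agaricus demo data (under `demo/data/` in the repository) is used for training and evaluation:
```python
import xgboost as xgb

dtrain = xgb.DMatrix('demo/data/agaricus.txt.train')
dtest = xgb.DMatrix('demo/data/agaricus.txt.test')

param = {'objective': 'binary:logistic',
         'tree_method': 'gpu_hist',  # GPU histogram algorithm
         'max_bin': 16,
         'gpu_id': 0,
         'eval_metric': 'auc'}

bst = xgb.train(param, dtrain, num_boost_round=10,
                evals=[(dtrain, 'train'), (dtest, 'test')])
preds = bst.predict(dtest)
```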
## Benchmarks
To run benchmarks on synthetic data for binary classification:
```bash
$ python tests/benchmark/benchmark.py
```
Training time on 1,000,000 rows x 50 columns with 500 boosting iterations and a 0.25/0.75 test/train split, measured on an i7-6700K CPU @ 4.00GHz and a Pascal Titan X.
```eval_rst
+--------------+----------+
| tree_method | Time (s) |
+==============+==========+
| gpu_hist | 13.87 |
+--------------+----------+
| hist | 63.55 |
+--------------+----------+
| gpu_exact | 161.08 |
+--------------+----------+
| exact | 1082.20 |
+--------------+----------+
```
[See here](http://dmlc.ml/2016/12/14/GPU-accelerated-xgboost.html) for additional performance benchmarks of the 'gpu_exact' tree_method.
## References
[Mitchell R, Frank E. (2017) Accelerating the XGBoost algorithm using GPU computing. PeerJ Computer Science 3:e127 https://doi.org/10.7717/peerj-cs.127](https://peerj.com/articles/cs-127/)
## Author
Rory Mitchell
Jonathan C. McKinney
Shankara Rao Thejaswi Nanditale
Vinay Deshpande
... and the rest of the H2O.ai and NVIDIA team.
Please report bugs to the xgboost/issues page.


@@ -1,37 +1,35 @@
from __future__ import print_function
#pylint: skip-file
import xgboost as xgb
import testing as tm
import numpy as np
import unittest
import xgboost as xgb
from nose.plugins.attrib import attr
rng = np.random.RandomState(1994)
@attr('gpu')
class TestGPUPredict (unittest.TestCase):
class TestGPUPredict(unittest.TestCase):
def test_predict(self):
iterations = 1
np.random.seed(1)
test_num_rows = [10,1000,5000]
test_num_cols = [10,50,500]
test_num_rows = [10, 1000, 5000]
test_num_cols = [10, 50, 500]
for num_rows in test_num_rows:
for num_cols in test_num_cols:
dm = xgb.DMatrix(np.random.randn(num_rows, num_cols), label=[0, 1] * int(num_rows/2))
dm = xgb.DMatrix(np.random.randn(num_rows, num_cols), label=[0, 1] * int(num_rows / 2))
watchlist = [(dm, 'train')]
res = {}
param = {
"objective":"binary:logistic",
"predictor":"gpu_predictor",
"objective": "binary:logistic",
"predictor": "gpu_predictor",
'eval_metric': 'auc',
}
bst = xgb.train(param, dm,iterations,evals=watchlist, evals_result=res)
bst = xgb.train(param, dm, iterations, evals=watchlist, evals_result=res)
assert self.non_decreasing(res["train"]["auc"])
gpu_pred = bst.predict(dm, output_margin=True)
bst.set_param({"predictor":"cpu_predictor"})
bst.set_param({"predictor": "cpu_predictor"})
cpu_pred = bst.predict(dm, output_margin=True)
np.testing.assert_allclose(cpu_pred, gpu_pred, rtol=1e-5)
def non_decreasing(self, L):
return all((x - y) < 0.001 for x, y in zip(L, L[1:]))


@@ -1,9 +1,9 @@
from __future__ import print_function
#pylint: skip-file
import sys
sys.path.append("../../tests/python")
import xgboost as xgb
import testing as tm
import numpy as np
import unittest
from nose.plugins.attrib import attr
@@ -12,14 +12,15 @@ rng = np.random.RandomState(1994)
dpath = 'demo/data/'
def eprint(*args, **kwargs):
print(*args, file=sys.stderr, **kwargs)
print(*args, file=sys.stdout, **kwargs)
@attr('gpu')
class TestGPU(unittest.TestCase):
def test_grow_gpu(self):
tm._skip_if_no_sklearn()
from sklearn.datasets import load_digits
try:
from sklearn.model_selection import train_test_split
@@ -115,10 +116,8 @@ class TestGPU(unittest.TestCase):
assert self.non_decreasing(res['train']['auc'])
assert res['train']['auc'][0] >= 0.85
def test_grow_gpu_hist(self):
n_gpus=-1
tm._skip_if_no_sklearn()
n_gpus = -1
from sklearn.datasets import load_digits
try:
from sklearn.model_selection import train_test_split
@@ -128,16 +127,15 @@ class TestGPU(unittest.TestCase):
ag_dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
ag_dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
for max_depth in range(3, 10): # TODO: Doesn't work with 2 for some tests
# eprint("max_depth=%d" % (max_depth))
for max_bin_i in range(3, 11):
max_bin = np.power(2, max_bin_i)
# eprint("max_bin=%d" % (max_bin))
for max_depth in range(3,10): # TODO: Doesn't work with 2 for some tests
#eprint("max_depth=%d" % (max_depth))
for max_bin_i in range(3,11):
max_bin = np.power(2,max_bin_i)
#eprint("max_bin=%d" % (max_bin))
# regression test --- hist must be same as exact on all-categorical data
ag_param = {'max_depth': max_depth,
'tree_method': 'exact',
@@ -172,13 +170,13 @@ class TestGPU(unittest.TestCase):
ag_res3 = {}
num_rounds = 10
#eprint("normal updater");
# eprint("normal updater");
xgb.train(ag_param, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
evals_result=ag_res)
#eprint("grow_gpu_hist updater 1 gpu");
# eprint("grow_gpu_hist updater 1 gpu");
xgb.train(ag_param2, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
evals_result=ag_res2)
#eprint("grow_gpu_hist updater %d gpus" % (n_gpus));
# eprint("grow_gpu_hist updater %d gpus" % (n_gpus));
xgb.train(ag_param3, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
evals_result=ag_res3)
# assert 1==0
@@ -203,11 +201,11 @@ class TestGPU(unittest.TestCase):
'debug_verbose': 0,
'eval_metric': 'auc'}
res = {}
#eprint("digits: grow_gpu_hist updater 1 gpu");
# eprint("digits: grow_gpu_hist updater 1 gpu");
xgb.train(param, dtrain, num_rounds, [(dtrain, 'train'), (dtest, 'test')],
evals_result=res)
assert self.non_decreasing(res['train']['auc'])
#assert self.non_decreasing(res['test']['auc'])
# assert self.non_decreasing(res['test']['auc'])
param2 = {'objective': 'binary:logistic',
'nthread': 0,
'tree_method': 'gpu_hist',
@@ -217,13 +215,13 @@ class TestGPU(unittest.TestCase):
'debug_verbose': 0,
'eval_metric': 'auc'}
res2 = {}
#eprint("digits: grow_gpu_hist updater %d gpus" % (n_gpus));
# eprint("digits: grow_gpu_hist updater %d gpus" % (n_gpus));
xgb.train(param2, dtrain, num_rounds, [(dtrain, 'train'), (dtest, 'test')],
evals_result=res2)
assert self.non_decreasing(res2['train']['auc'])
#assert self.non_decreasing(res2['test']['auc'])
# assert self.non_decreasing(res2['test']['auc'])
assert res['train']['auc'] == res2['train']['auc']
#assert res['test']['auc'] == res2['test']['auc']
# assert res['test']['auc'] == res2['test']['auc']
######################################################################
# fail-safe test for dense data
@@ -244,7 +242,7 @@ class TestGPU(unittest.TestCase):
xgb.train(param, dtrain2, num_rounds, [(dtrain2, 'train')], evals_result=res)
assert self.non_decreasing(res['train']['auc'])
if max_bin>32:
if max_bin > 32:
assert res['train']['auc'][0] >= 0.85
for j in range(X2.shape[1]):
@@ -257,7 +255,7 @@ class TestGPU(unittest.TestCase):
xgb.train(param, dtrain3, num_rounds, [(dtrain3, 'train')], evals_result=res)
assert self.non_decreasing(res['train']['auc'])
if max_bin>32:
if max_bin > 32:
assert res['train']['auc'][0] >= 0.85
for j in range(X2.shape[1]):
@@ -268,7 +266,7 @@ class TestGPU(unittest.TestCase):
res = {}
xgb.train(param, dtrain4, num_rounds, [(dtrain4, 'train')], evals_result=res)
assert self.non_decreasing(res['train']['auc'])
if max_bin>32:
if max_bin > 32:
assert res['train']['auc'][0] >= 0.85
######################################################################
@@ -284,7 +282,7 @@ class TestGPU(unittest.TestCase):
res = {}
xgb.train(param, dtrain2, num_rounds, [(dtrain2, 'train')], evals_result=res)
assert self.non_decreasing(res['train']['auc'])
if max_bin>32:
if max_bin > 32:
assert res['train']['auc'][0] >= 0.85
######################################################################
# subsampling
@@ -302,7 +300,7 @@ class TestGPU(unittest.TestCase):
res = {}
xgb.train(param, dtrain2, num_rounds, [(dtrain2, 'train')], evals_result=res)
assert self.non_decreasing(res['train']['auc'])
if max_bin>32:
if max_bin > 32:
assert res['train']['auc'][0] >= 0.85
######################################################################
# fail-safe test for max_bin=2
@@ -317,9 +315,8 @@ class TestGPU(unittest.TestCase):
res = {}
xgb.train(param, dtrain2, num_rounds, [(dtrain2, 'train')], evals_result=res)
assert self.non_decreasing(res['train']['auc'])
if max_bin>32:
if max_bin > 32:
assert res['train']['auc'][0] >= 0.85
def non_decreasing(self, L):
return all((x - y) < 0.001 for x, y in zip(L, L[1:]))
return all((x - y) < 0.001 for x, y in zip(L, L[1:]))


@@ -1,39 +1,38 @@
from __future__ import print_function
#pylint: skip-file
import sys
import time
sys.path.append("../../tests/python")
import xgboost as xgb
import testing as tm
import numpy as np
import unittest
from nose.plugins.attrib import attr
def eprint(*args, **kwargs):
print(*args, file=sys.stderr, **kwargs) ; sys.stderr.flush()
print(*args, file=sys.stdout, **kwargs) ; sys.stdout.flush()
print(*args, file=sys.stderr, **kwargs)
sys.stderr.flush()
print(*args, file=sys.stdout, **kwargs)
sys.stdout.flush()
rng = np.random.RandomState(1994)
# "realistic" size based upon http://stat-computing.org/dataexpo/2009/ , which has been processed to one-hot encode categoricalsxsy
cols = 31
# reduced to fit onto 1 gpu but still be large
rows3 = 5000 # small
rows2 = 4360032 # medium
rows1 = 42360032 # large
#rows1 = 152360032 # can do this for multi-gpu test (very large)
rows3 = 5000 # small
rows2 = 4360032 # medium
rows1 = 42360032 # large
# rows1 = 152360032 # can do this for multi-gpu test (very large)
rowslist = [rows1, rows2, rows3]
@attr('slow')
class TestGPU(unittest.TestCase):
def test_large(self):
eprint("Starting test for large data")
tm._skip_if_no_sklearn()
for rows in rowslist:
eprint("Creating train data rows=%d cols=%d" % (rows,cols))
eprint("Creating train data rows=%d cols=%d" % (rows, cols))
tmp = time.time()
np.random.seed(7)
X = np.random.rand(rows, cols)
@@ -42,12 +41,12 @@ class TestGPU(unittest.TestCase):
eprint("Starting DMatrix(X,y)")
tmp = time.time()
ag_dtrain = xgb.DMatrix(X,y,nthread=40)
ag_dtrain = xgb.DMatrix(X, y, nthread=40)
print("Time to DMatrix: %r" % (time.time() - tmp))
max_depth=6
max_bin=1024
max_depth = 6
max_bin = 1024
# regression test --- hist must be same as exact on all-categorical data
ag_param = {'max_depth': max_depth,
'tree_method': 'exact',
@@ -58,23 +57,23 @@ class TestGPU(unittest.TestCase):
'objective': 'binary:logistic',
'eval_metric': 'auc'}
ag_paramb = {'max_depth': max_depth,
'tree_method': 'hist',
'nthread': 0,
'eta': 1,
'silent': 0,
'debug_verbose': 5,
'objective': 'binary:logistic',
'eval_metric': 'auc'}
'tree_method': 'hist',
'nthread': 0,
'eta': 1,
'silent': 0,
'debug_verbose': 5,
'objective': 'binary:logistic',
'eval_metric': 'auc'}
ag_param2 = {'max_depth': max_depth,
'tree_method': 'gpu_hist',
'nthread': 0,
'eta': 1,
'silent': 0,
'debug_verbose': 5,
'n_gpus': 1,
'objective': 'binary:logistic',
'max_bin': max_bin,
'eval_metric': 'auc'}
'tree_method': 'gpu_hist',
'nthread': 0,
'eta': 1,
'silent': 0,
'debug_verbose': 5,
'n_gpus': 1,
'objective': 'binary:logistic',
'max_bin': max_bin,
'eval_metric': 'auc'}
ag_param3 = {'max_depth': max_depth,
'tree_method': 'gpu_hist',
'nthread': 0,
@@ -92,10 +91,10 @@ class TestGPU(unittest.TestCase):
num_rounds = 1
tmp = time.time()
#eprint("hist updater")
#xgb.train(ag_paramb, ag_dtrain, num_rounds, [(ag_dtrain, 'train')],
# eprint("hist updater")
# xgb.train(ag_paramb, ag_dtrain, num_rounds, [(ag_dtrain, 'train')],
# evals_result=ag_resb)
#print("Time to Train: %s seconds" % (str(time.time() - tmp)))
# print("Time to Train: %s seconds" % (str(time.time() - tmp)))
tmp = time.time()
eprint("gpu_hist updater 1 gpu")
@@ -108,5 +107,3 @@ class TestGPU(unittest.TestCase):
xgb.train(ag_param3, ag_dtrain, num_rounds, [(ag_dtrain, 'train')],
evals_result=ag_res3)
print("Time to Train: %s seconds" % (str(time.time() - tmp)))