Add GPU documentation (#2695)
* Add GPU documentation * Update Python GPU tests
This commit is contained in:
parent
e6a9063344
commit
9c85903f0b
102
doc/gpu/index.md
Normal file
102
doc/gpu/index.md
Normal file
@ -0,0 +1,102 @@
|
||||
XGBoost GPU Support
|
||||
===================
|
||||
|
||||
This page contains information about GPU algorithms supported in XGBoost.
|
||||
To install GPU support, checkout the [build and installation instructions](../build.md).
|
||||
|
||||
# CUDA Accelerated Tree Construction Algorithms
|
||||
This plugin adds GPU accelerated tree construction and prediction algorithms to XGBoost.
|
||||
## Usage
|
||||
Specify the 'tree_method' parameter as one of the following algorithms.
|
||||
|
||||
### Algorithms
|
||||
|
||||
```eval_rst
|
||||
+--------------+-----------------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| tree_method | Description |
|
||||
+==============+===============================================================================================================================================+
|
||||
| gpu_exact | The standard XGBoost tree construction algorithm. Performs exact search for splits. Slower and uses considerably more memory than 'gpu_hist' |
|
||||
+--------------+-----------------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| gpu_hist | Equivalent to the XGBoost fast histogram algorithm. Faster and uses considerably less memory. Splits may be less accurate. |
|
||||
+--------------+-----------------------------------------------------------------------------------------------------------------------------------------------+
|
||||
```
|
||||
|
||||
### Supported parameters
|
||||
|
||||
```eval_rst
|
||||
.. |tick| unicode:: U+2714
|
||||
.. |cross| unicode:: U+2718
|
||||
|
||||
+--------------------+------------+-----------+
|
||||
| parameter | gpu_exact | gpu_hist |
|
||||
+====================+============+===========+
|
||||
| subsample | |cross| | |tick| |
|
||||
+--------------------+------------+-----------+
|
||||
| colsample_bytree | |cross| | |tick| |
|
||||
+--------------------+------------+-----------+
|
||||
| colsample_bylevel | |cross| | |tick| |
|
||||
+--------------------+------------+-----------+
|
||||
| max_bin | |cross| | |tick| |
|
||||
+--------------------+------------+-----------+
|
||||
| gpu_id | |tick| | |tick| |
|
||||
+--------------------+------------+-----------+
|
||||
| n_gpus | |cross| | |tick| |
|
||||
+--------------------+------------+-----------+
|
||||
| predictor | |tick| | |tick| |
|
||||
+--------------------+------------+-----------+
|
||||
|
||||
|
|
||||
```
|
||||
|
||||
GPU accelerated prediction is enabled by default for the above mentioned 'tree_method' parameters but can be switched to CPU prediction by setting 'predictor':'cpu_predictor'. This could be useful if you want to conserve GPU memory. Likewise when using CPU algorithms, GPU accelerated prediction can be enabled by setting 'predictor':'gpu_predictor'.
|
||||
|
||||
The device ordinal can be selected using the 'gpu_id' parameter, which defaults to 0.
|
||||
|
||||
Multiple GPUs can be used with the grow_gpu_hist parameter using the n_gpus parameter. which defaults to 1. If this is set to -1 all available GPUs will be used. If gpu_id is specified as non-zero, the gpu device order is mod(gpu_id + i) % n_visible_devices for i=0 to n_gpus-1. As with GPU vs. CPU, multi-GPU will not always be faster than a single GPU due to PCI bus bandwidth that can limit performance. For example, when n_features * n_bins * 2^depth divided by time of each round/iteration becomes comparable to the real PCI 16x bus bandwidth of order 4GB/s to 10GB/s, then AllReduce will dominant code speed and multiple GPUs become ineffective at increasing performance. Also, CPU overhead between GPU calls can limit usefulness of multiple GPUs.
|
||||
|
||||
This plugin currently works with the CLI version and python version.
|
||||
|
||||
Python example:
|
||||
```python
|
||||
param['gpu_id'] = 0
|
||||
param['max_bin'] = 16
|
||||
param['tree_method'] = 'gpu_hist'
|
||||
```
|
||||
## Benchmarks
|
||||
To run benchmarks on synthetic data for binary classification:
|
||||
```bash
|
||||
$ python tests/benchmark/benchmark.py
|
||||
```
|
||||
|
||||
Training time time on 1,000,000 rows x 50 columns with 500 boosting iterations and 0.25/0.75 test/train split on i7-6700K CPU @ 4.00GHz and Pascal Titan X.
|
||||
|
||||
```eval_rst
|
||||
+--------------+----------+
|
||||
| tree_method | Time (s) |
|
||||
+==============+==========+
|
||||
| gpu_hist | 13.87 |
|
||||
+--------------+----------+
|
||||
| hist | 63.55 |
|
||||
+--------------+----------+
|
||||
| gpu_exact | 161.08 |
|
||||
+--------------+----------+
|
||||
| exact | 1082.20 |
|
||||
+--------------+----------+
|
||||
|
||||
|
|
||||
```
|
||||
|
||||
[See here](http://dmlc.ml/2016/12/14/GPU-accelerated-xgboost.html) for additional performance benchmarks of the 'gpu_exact' tree_method.
|
||||
|
||||
## References
|
||||
[Mitchell R, Frank E. (2017) Accelerating the XGBoost algorithm using GPU computing. PeerJ Computer Science 3:e127 https://doi.org/10.7717/peerj-cs.127](https://peerj.com/articles/cs-127/)
|
||||
|
||||
## Author
|
||||
Rory Mitchell
|
||||
Jonathan C. McKinney
|
||||
Shankara Rao Thejaswi Nanditale
|
||||
Vinay Deshpande
|
||||
... and the rest of the H2O.ai and NVIDIA team.
|
||||
|
||||
Please report bugs to the xgboost/issues page.
|
||||
|
||||
@ -1,37 +1,35 @@
|
||||
from __future__ import print_function
|
||||
#pylint: skip-file
|
||||
import xgboost as xgb
|
||||
import testing as tm
|
||||
|
||||
import numpy as np
|
||||
import unittest
|
||||
import xgboost as xgb
|
||||
from nose.plugins.attrib import attr
|
||||
|
||||
rng = np.random.RandomState(1994)
|
||||
|
||||
@attr('gpu')
|
||||
class TestGPUPredict (unittest.TestCase):
|
||||
class TestGPUPredict(unittest.TestCase):
|
||||
def test_predict(self):
|
||||
iterations = 1
|
||||
np.random.seed(1)
|
||||
test_num_rows = [10,1000,5000]
|
||||
test_num_cols = [10,50,500]
|
||||
test_num_rows = [10, 1000, 5000]
|
||||
test_num_cols = [10, 50, 500]
|
||||
for num_rows in test_num_rows:
|
||||
for num_cols in test_num_cols:
|
||||
dm = xgb.DMatrix(np.random.randn(num_rows, num_cols), label=[0, 1] * int(num_rows/2))
|
||||
dm = xgb.DMatrix(np.random.randn(num_rows, num_cols), label=[0, 1] * int(num_rows / 2))
|
||||
watchlist = [(dm, 'train')]
|
||||
res = {}
|
||||
param = {
|
||||
"objective":"binary:logistic",
|
||||
"predictor":"gpu_predictor",
|
||||
"objective": "binary:logistic",
|
||||
"predictor": "gpu_predictor",
|
||||
'eval_metric': 'auc',
|
||||
}
|
||||
bst = xgb.train(param, dm,iterations,evals=watchlist, evals_result=res)
|
||||
bst = xgb.train(param, dm, iterations, evals=watchlist, evals_result=res)
|
||||
assert self.non_decreasing(res["train"]["auc"])
|
||||
gpu_pred = bst.predict(dm, output_margin=True)
|
||||
bst.set_param({"predictor":"cpu_predictor"})
|
||||
bst.set_param({"predictor": "cpu_predictor"})
|
||||
cpu_pred = bst.predict(dm, output_margin=True)
|
||||
np.testing.assert_allclose(cpu_pred, gpu_pred, rtol=1e-5)
|
||||
|
||||
def non_decreasing(self, L):
|
||||
return all((x - y) < 0.001 for x, y in zip(L, L[1:]))
|
||||
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
from __future__ import print_function
|
||||
#pylint: skip-file
|
||||
|
||||
import sys
|
||||
|
||||
sys.path.append("../../tests/python")
|
||||
import xgboost as xgb
|
||||
import testing as tm
|
||||
import numpy as np
|
||||
import unittest
|
||||
from nose.plugins.attrib import attr
|
||||
@ -12,14 +12,15 @@ rng = np.random.RandomState(1994)
|
||||
|
||||
dpath = 'demo/data/'
|
||||
|
||||
|
||||
def eprint(*args, **kwargs):
|
||||
print(*args, file=sys.stderr, **kwargs)
|
||||
print(*args, file=sys.stdout, **kwargs)
|
||||
|
||||
|
||||
|
||||
@attr('gpu')
|
||||
class TestGPU(unittest.TestCase):
|
||||
def test_grow_gpu(self):
|
||||
tm._skip_if_no_sklearn()
|
||||
from sklearn.datasets import load_digits
|
||||
try:
|
||||
from sklearn.model_selection import train_test_split
|
||||
@ -115,10 +116,8 @@ class TestGPU(unittest.TestCase):
|
||||
assert self.non_decreasing(res['train']['auc'])
|
||||
assert res['train']['auc'][0] >= 0.85
|
||||
|
||||
|
||||
def test_grow_gpu_hist(self):
|
||||
n_gpus=-1
|
||||
tm._skip_if_no_sklearn()
|
||||
n_gpus = -1
|
||||
from sklearn.datasets import load_digits
|
||||
try:
|
||||
from sklearn.model_selection import train_test_split
|
||||
@ -128,16 +127,15 @@ class TestGPU(unittest.TestCase):
|
||||
ag_dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||
ag_dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
|
||||
|
||||
for max_depth in range(3, 10): # TODO: Doesn't work with 2 for some tests
|
||||
# eprint("max_depth=%d" % (max_depth))
|
||||
|
||||
for max_bin_i in range(3, 11):
|
||||
max_bin = np.power(2, max_bin_i)
|
||||
# eprint("max_bin=%d" % (max_bin))
|
||||
|
||||
|
||||
for max_depth in range(3,10): # TODO: Doesn't work with 2 for some tests
|
||||
#eprint("max_depth=%d" % (max_depth))
|
||||
|
||||
for max_bin_i in range(3,11):
|
||||
max_bin = np.power(2,max_bin_i)
|
||||
#eprint("max_bin=%d" % (max_bin))
|
||||
|
||||
|
||||
|
||||
# regression test --- hist must be same as exact on all-categorial data
|
||||
ag_param = {'max_depth': max_depth,
|
||||
'tree_method': 'exact',
|
||||
@ -172,13 +170,13 @@ class TestGPU(unittest.TestCase):
|
||||
ag_res3 = {}
|
||||
|
||||
num_rounds = 10
|
||||
#eprint("normal updater");
|
||||
# eprint("normal updater");
|
||||
xgb.train(ag_param, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
|
||||
evals_result=ag_res)
|
||||
#eprint("grow_gpu_hist updater 1 gpu");
|
||||
# eprint("grow_gpu_hist updater 1 gpu");
|
||||
xgb.train(ag_param2, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
|
||||
evals_result=ag_res2)
|
||||
#eprint("grow_gpu_hist updater %d gpus" % (n_gpus));
|
||||
# eprint("grow_gpu_hist updater %d gpus" % (n_gpus));
|
||||
xgb.train(ag_param3, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
|
||||
evals_result=ag_res3)
|
||||
# assert 1==0
|
||||
@ -203,11 +201,11 @@ class TestGPU(unittest.TestCase):
|
||||
'debug_verbose': 0,
|
||||
'eval_metric': 'auc'}
|
||||
res = {}
|
||||
#eprint("digits: grow_gpu_hist updater 1 gpu");
|
||||
# eprint("digits: grow_gpu_hist updater 1 gpu");
|
||||
xgb.train(param, dtrain, num_rounds, [(dtrain, 'train'), (dtest, 'test')],
|
||||
evals_result=res)
|
||||
assert self.non_decreasing(res['train']['auc'])
|
||||
#assert self.non_decreasing(res['test']['auc'])
|
||||
# assert self.non_decreasing(res['test']['auc'])
|
||||
param2 = {'objective': 'binary:logistic',
|
||||
'nthread': 0,
|
||||
'tree_method': 'gpu_hist',
|
||||
@ -217,13 +215,13 @@ class TestGPU(unittest.TestCase):
|
||||
'debug_verbose': 0,
|
||||
'eval_metric': 'auc'}
|
||||
res2 = {}
|
||||
#eprint("digits: grow_gpu_hist updater %d gpus" % (n_gpus));
|
||||
# eprint("digits: grow_gpu_hist updater %d gpus" % (n_gpus));
|
||||
xgb.train(param2, dtrain, num_rounds, [(dtrain, 'train'), (dtest, 'test')],
|
||||
evals_result=res2)
|
||||
assert self.non_decreasing(res2['train']['auc'])
|
||||
#assert self.non_decreasing(res2['test']['auc'])
|
||||
# assert self.non_decreasing(res2['test']['auc'])
|
||||
assert res['train']['auc'] == res2['train']['auc']
|
||||
#assert res['test']['auc'] == res2['test']['auc']
|
||||
# assert res['test']['auc'] == res2['test']['auc']
|
||||
|
||||
######################################################################
|
||||
# fail-safe test for dense data
|
||||
@ -244,7 +242,7 @@ class TestGPU(unittest.TestCase):
|
||||
xgb.train(param, dtrain2, num_rounds, [(dtrain2, 'train')], evals_result=res)
|
||||
|
||||
assert self.non_decreasing(res['train']['auc'])
|
||||
if max_bin>32:
|
||||
if max_bin > 32:
|
||||
assert res['train']['auc'][0] >= 0.85
|
||||
|
||||
for j in range(X2.shape[1]):
|
||||
@ -257,7 +255,7 @@ class TestGPU(unittest.TestCase):
|
||||
xgb.train(param, dtrain3, num_rounds, [(dtrain3, 'train')], evals_result=res)
|
||||
|
||||
assert self.non_decreasing(res['train']['auc'])
|
||||
if max_bin>32:
|
||||
if max_bin > 32:
|
||||
assert res['train']['auc'][0] >= 0.85
|
||||
|
||||
for j in range(X2.shape[1]):
|
||||
@ -268,7 +266,7 @@ class TestGPU(unittest.TestCase):
|
||||
res = {}
|
||||
xgb.train(param, dtrain4, num_rounds, [(dtrain4, 'train')], evals_result=res)
|
||||
assert self.non_decreasing(res['train']['auc'])
|
||||
if max_bin>32:
|
||||
if max_bin > 32:
|
||||
assert res['train']['auc'][0] >= 0.85
|
||||
|
||||
######################################################################
|
||||
@ -284,7 +282,7 @@ class TestGPU(unittest.TestCase):
|
||||
res = {}
|
||||
xgb.train(param, dtrain2, num_rounds, [(dtrain2, 'train')], evals_result=res)
|
||||
assert self.non_decreasing(res['train']['auc'])
|
||||
if max_bin>32:
|
||||
if max_bin > 32:
|
||||
assert res['train']['auc'][0] >= 0.85
|
||||
######################################################################
|
||||
# subsampling
|
||||
@ -302,7 +300,7 @@ class TestGPU(unittest.TestCase):
|
||||
res = {}
|
||||
xgb.train(param, dtrain2, num_rounds, [(dtrain2, 'train')], evals_result=res)
|
||||
assert self.non_decreasing(res['train']['auc'])
|
||||
if max_bin>32:
|
||||
if max_bin > 32:
|
||||
assert res['train']['auc'][0] >= 0.85
|
||||
######################################################################
|
||||
# fail-safe test for max_bin=2
|
||||
@ -317,9 +315,8 @@ class TestGPU(unittest.TestCase):
|
||||
res = {}
|
||||
xgb.train(param, dtrain2, num_rounds, [(dtrain2, 'train')], evals_result=res)
|
||||
assert self.non_decreasing(res['train']['auc'])
|
||||
if max_bin>32:
|
||||
if max_bin > 32:
|
||||
assert res['train']['auc'][0] >= 0.85
|
||||
|
||||
|
||||
|
||||
def non_decreasing(self, L):
|
||||
return all((x - y) < 0.001 for x, y in zip(L, L[1:]))
|
||||
return all((x - y) < 0.001 for x, y in zip(L, L[1:]))
|
||||
|
||||
@ -1,39 +1,38 @@
|
||||
from __future__ import print_function
|
||||
#pylint: skip-file
|
||||
|
||||
import sys
|
||||
import time
|
||||
|
||||
sys.path.append("../../tests/python")
|
||||
import xgboost as xgb
|
||||
import testing as tm
|
||||
import numpy as np
|
||||
import unittest
|
||||
from nose.plugins.attrib import attr
|
||||
|
||||
|
||||
def eprint(*args, **kwargs):
|
||||
print(*args, file=sys.stderr, **kwargs) ; sys.stderr.flush()
|
||||
print(*args, file=sys.stdout, **kwargs) ; sys.stdout.flush()
|
||||
print(*args, file=sys.stderr, **kwargs)
|
||||
sys.stderr.flush()
|
||||
print(*args, file=sys.stdout, **kwargs)
|
||||
sys.stdout.flush()
|
||||
|
||||
rng = np.random.RandomState(1994)
|
||||
|
||||
# "realistic" size based upon http://stat-computing.org/dataexpo/2009/ , which has been processed to one-hot encode categoricalsxsy
|
||||
cols = 31
|
||||
# reduced to fit onto 1 gpu but still be large
|
||||
rows3 = 5000 # small
|
||||
rows2 = 4360032 # medium
|
||||
rows1 = 42360032 # large
|
||||
#rows1 = 152360032 # can do this for multi-gpu test (very large)
|
||||
rows3 = 5000 # small
|
||||
rows2 = 4360032 # medium
|
||||
rows1 = 42360032 # large
|
||||
# rows1 = 152360032 # can do this for multi-gpu test (very large)
|
||||
rowslist = [rows1, rows2, rows3]
|
||||
|
||||
|
||||
@attr('slow')
|
||||
class TestGPU(unittest.TestCase):
|
||||
def test_large(self):
|
||||
eprint("Starting test for large data")
|
||||
tm._skip_if_no_sklearn()
|
||||
|
||||
for rows in rowslist:
|
||||
|
||||
eprint("Creating train data rows=%d cols=%d" % (rows,cols))
|
||||
eprint("Creating train data rows=%d cols=%d" % (rows, cols))
|
||||
tmp = time.time()
|
||||
np.random.seed(7)
|
||||
X = np.random.rand(rows, cols)
|
||||
@ -42,12 +41,12 @@ class TestGPU(unittest.TestCase):
|
||||
|
||||
eprint("Starting DMatrix(X,y)")
|
||||
tmp = time.time()
|
||||
ag_dtrain = xgb.DMatrix(X,y,nthread=40)
|
||||
ag_dtrain = xgb.DMatrix(X, y, nthread=40)
|
||||
print("Time to DMatrix: %r" % (time.time() - tmp))
|
||||
|
||||
max_depth=6
|
||||
max_bin=1024
|
||||
|
||||
max_depth = 6
|
||||
max_bin = 1024
|
||||
|
||||
# regression test --- hist must be same as exact on all-categorial data
|
||||
ag_param = {'max_depth': max_depth,
|
||||
'tree_method': 'exact',
|
||||
@ -58,23 +57,23 @@ class TestGPU(unittest.TestCase):
|
||||
'objective': 'binary:logistic',
|
||||
'eval_metric': 'auc'}
|
||||
ag_paramb = {'max_depth': max_depth,
|
||||
'tree_method': 'hist',
|
||||
'nthread': 0,
|
||||
'eta': 1,
|
||||
'silent': 0,
|
||||
'debug_verbose': 5,
|
||||
'objective': 'binary:logistic',
|
||||
'eval_metric': 'auc'}
|
||||
'tree_method': 'hist',
|
||||
'nthread': 0,
|
||||
'eta': 1,
|
||||
'silent': 0,
|
||||
'debug_verbose': 5,
|
||||
'objective': 'binary:logistic',
|
||||
'eval_metric': 'auc'}
|
||||
ag_param2 = {'max_depth': max_depth,
|
||||
'tree_method': 'gpu_hist',
|
||||
'nthread': 0,
|
||||
'eta': 1,
|
||||
'silent': 0,
|
||||
'debug_verbose': 5,
|
||||
'n_gpus': 1,
|
||||
'objective': 'binary:logistic',
|
||||
'max_bin': max_bin,
|
||||
'eval_metric': 'auc'}
|
||||
'tree_method': 'gpu_hist',
|
||||
'nthread': 0,
|
||||
'eta': 1,
|
||||
'silent': 0,
|
||||
'debug_verbose': 5,
|
||||
'n_gpus': 1,
|
||||
'objective': 'binary:logistic',
|
||||
'max_bin': max_bin,
|
||||
'eval_metric': 'auc'}
|
||||
ag_param3 = {'max_depth': max_depth,
|
||||
'tree_method': 'gpu_hist',
|
||||
'nthread': 0,
|
||||
@ -92,10 +91,10 @@ class TestGPU(unittest.TestCase):
|
||||
|
||||
num_rounds = 1
|
||||
tmp = time.time()
|
||||
#eprint("hist updater")
|
||||
#xgb.train(ag_paramb, ag_dtrain, num_rounds, [(ag_dtrain, 'train')],
|
||||
# eprint("hist updater")
|
||||
# xgb.train(ag_paramb, ag_dtrain, num_rounds, [(ag_dtrain, 'train')],
|
||||
# evals_result=ag_resb)
|
||||
#print("Time to Train: %s seconds" % (str(time.time() - tmp)))
|
||||
# print("Time to Train: %s seconds" % (str(time.time() - tmp)))
|
||||
|
||||
tmp = time.time()
|
||||
eprint("gpu_hist updater 1 gpu")
|
||||
@ -108,5 +107,3 @@ class TestGPU(unittest.TestCase):
|
||||
xgb.train(ag_param3, ag_dtrain, num_rounds, [(ag_dtrain, 'train')],
|
||||
evals_result=ag_res3)
|
||||
print("Time to Train: %s seconds" % (str(time.time() - tmp)))
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user