[GPU-Plugin] Major refactor 2 (#2664)

* Change cmake option

* Move source files

* Move google tests

* Move python tests

* Move benchmarks

* Move documentation

* Remove makefile support

* Fix test run

* Move GPU tests
This commit is contained in:
Rory Mitchell 2017-09-08 09:57:16 +12:00 committed by GitHub
parent 8244f6f120
commit 15267eedf2
21 changed files with 76 additions and 249 deletions

View File

@ -7,12 +7,17 @@ set_default_configuration_release()
msvc_use_static_runtime() msvc_use_static_runtime()
# Options # Options
option(PLUGIN_UPDATER_GPU "Build GPU accelerated tree construction plugin") option(USE_CUDA "Build with GPU acceleration")
option(JVM_BINDINGS "Build JVM bindings" OFF) option(JVM_BINDINGS "Build JVM bindings" OFF)
option(GOOGLE_TEST "Build google tests" OFF) option(GOOGLE_TEST "Build google tests" OFF)
set(GPU_COMPUTE_VER 35;50;52;60;61 CACHE STRING set(GPU_COMPUTE_VER 35;50;52;60;61 CACHE STRING
"Space separated list of compute versions to be built against") "Space separated list of compute versions to be built against")
# Deprecation warning
if(PLUGIN_UPDATER_GPU)
set(USE_CUDA ON)
message(WARNING "The option 'PLUGIN_UPDATER_GPU' is deprecated. Set 'USE_CUDA' instead.")
endif()
# Compiler flags # Compiler flags
set(CMAKE_CXX_STANDARD 11) set(CMAKE_CXX_STANDARD 11)
@ -40,9 +45,14 @@ file(GLOB_RECURSE SOURCES
src/*.h src/*.h
include/*.h include/*.h
) )
# Only add main function for executable target # Only add main function for executable target
list(REMOVE_ITEM SOURCES ${PROJECT_SOURCE_DIR}/src/cli_main.cc) list(REMOVE_ITEM SOURCES ${PROJECT_SOURCE_DIR}/src/cli_main.cc)
file(GLOB_RECURSE CUDA_SOURCES
src/*.cu
src/*.cuh
)
# rabit # rabit
# TODO: Create rabit cmakelists.txt # TODO: Create rabit cmakelists.txt
@ -68,13 +78,7 @@ endif()
add_subdirectory(dmlc-core) add_subdirectory(dmlc-core)
set(LINK_LIBRARIES dmlccore rabit) set(LINK_LIBRARIES dmlccore rabit)
if(USE_CUDA)
# GPU Plugin
file(GLOB_RECURSE CUDA_SOURCES
plugin/updater_gpu/src/*.cu
plugin/updater_gpu/src/*.cuh
)
if(PLUGIN_UPDATER_GPU)
find_package(CUDA 7.5 REQUIRED) find_package(CUDA 7.5 REQUIRED)
cmake_minimum_required(VERSION 3.5) cmake_minimum_required(VERSION 3.5)
@ -144,8 +148,8 @@ if(GOOGLE_TEST)
auto_source_group("${TEST_SOURCES}") auto_source_group("${TEST_SOURCES}")
include_directories(${GTEST_INCLUDE_DIRS}) include_directories(${GTEST_INCLUDE_DIRS})
if(PLUGIN_UPDATER_GPU) if(USE_CUDA)
file(GLOB_RECURSE CUDA_TEST_SOURCES "plugin/updater_gpu/test/cpp/*.cu") file(GLOB_RECURSE CUDA_TEST_SOURCES "tests/cpp/*.cu")
cuda_compile(CUDA_TEST_OBJS ${CUDA_TEST_SOURCES}) cuda_compile(CUDA_TEST_OBJS ${CUDA_TEST_SOURCES})
else() else()
set(CUDA_TEST_OBJS "") set(CUDA_TEST_OBJS "")

View File

@ -116,22 +116,6 @@ else
endif endif
CFLAGS += $(OPENMP_FLAGS) CFLAGS += $(OPENMP_FLAGS)
# for using GPUs
GPU_COMPUTE_VER ?= 35 50 52 60 61
NVCC = nvcc
INCLUDES = -Iinclude -I$(DMLC_CORE)/include -I$(RABIT)/include
INCLUDES += -I$(CUB_PATH)
INCLUDES += -I$(GTEST_PATH)/include
CODE = $(foreach ver,$(GPU_COMPUTE_VER),-gencode arch=compute_$(ver),code=sm_$(ver))
NVCC_FLAGS = --std=c++11 $(CODE) $(INCLUDES) -lineinfo --expt-extended-lambda
NVCC_FLAGS += -Xcompiler=$(OPENMP_FLAGS) -Xcompiler=-fPIC
ifeq ($(PLUGIN_UPDATER_GPU),ON)
CUDA_ROOT = $(shell dirname $(shell dirname $(shell which $(NVCC))))
INCLUDES += -I$(CUDA_ROOT)/include -Inccl/src/
LDFLAGS += -L$(CUDA_ROOT)/lib64 -lcudart -lcudadevrt -Lnccl/build/lib/ -lnccl_static -lm -ldl -lrt
CFLAGS += -DXGBOOST_USE_CUDA
endif
# specify tensor path # specify tensor path
.PHONY: clean all lint clean_all doxygen rcpplint pypack Rpack Rbuild Rcheck java pylint .PHONY: clean all lint clean_all doxygen rcpplint pypack Rpack Rbuild Rcheck java pylint
@ -153,31 +137,16 @@ ALL_DEP = $(filter-out build/cli_main.o, $(ALL_OBJ)) $(LIB_DEP)
CLI_OBJ = build/cli_main.o CLI_OBJ = build/cli_main.o
include tests/cpp/xgboost_test.mk include tests/cpp/xgboost_test.mk
# order of this rule matters wrt %.cc rule below!
build/%.o: src/%.cu
@mkdir -p $(@D)
$(NVCC) -c $(NVCC_FLAGS) $< -o $@
build/%.o: src/%.cc build/%.o: src/%.cc
@mkdir -p $(@D) @mkdir -p $(@D)
$(CXX) $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d $(CXX) $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d
$(CXX) -c $(CFLAGS) $< -o $@ $(CXX) -c $(CFLAGS) $< -o $@
# order of this rule matters wrt %.cc rule below!
build_plugin/%.o: plugin/%.cu build_nccl
@mkdir -p $(@D)
$(NVCC) -c $(NVCC_FLAGS) $< -o $@
build_plugin/%.o: plugin/%.cc build_plugin/%.o: plugin/%.cc
@mkdir -p $(@D) @mkdir -p $(@D)
$(CXX) $(CFLAGS) -MM -MT build_plugin/$*.o $< >build_plugin/$*.d $(CXX) $(CFLAGS) -MM -MT build_plugin/$*.o $< >build_plugin/$*.d
$(CXX) -c $(CFLAGS) $< -o $@ $(CXX) -c $(CFLAGS) $< -o $@
build_nccl:
@mkdir -p build/include
cd build/include ; ln -sf ../../nccl/src/nccl.h .
cd nccl ; make -j ; cd ..
# The should be equivalent to $(ALL_OBJ) except for build/cli_main.o # The should be equivalent to $(ALL_OBJ) except for build/cli_main.o
amalgamation/xgboost-all0.o: amalgamation/xgboost-all0.cc amalgamation/xgboost-all0.o: amalgamation/xgboost-all0.cc
$(CXX) -c $(CFLAGS) $< -o $@ $(CXX) -c $(CFLAGS) $< -o $@

View File

@ -23,6 +23,7 @@
<li><a href="{{url_root}}jvm/index.html">JVM</a></li> <li><a href="{{url_root}}jvm/index.html">JVM</a></li>
<li><a href="{{url_root}}julia/index.html">Julia</a></li> <li><a href="{{url_root}}julia/index.html">Julia</a></li>
<li><a href="{{url_root}}cli/index.html">CLI</a></li> <li><a href="{{url_root}}cli/index.html">CLI</a></li>
<li><a href="{{url_root}}gpu/index.html">GPU</a></li>
</ul> </ul>
</li> </li>
{% endfor %} {% endfor %}

View File

@ -28,6 +28,7 @@ even better to send pull request if you can fix the problem.
- [Building on Ubuntu/Debian](#building-on-ubuntu-debian) - [Building on Ubuntu/Debian](#building-on-ubuntu-debian)
- [Building on OSX](#building-on-osx) - [Building on OSX](#building-on-osx)
- [Building on Windows](#building-on-windows) - [Building on Windows](#building-on-windows)
- [Building with GPU support](#building-with-gpu-support)
- [Windows Binaries](#windows-binaries) - [Windows Binaries](#windows-binaries)
- [Customized Building](#customized-building) - [Customized Building](#customized-building)
- [Python Package Installation](#python-package-installation) - [Python Package Installation](#python-package-installation)
@ -131,6 +132,32 @@ This specifies an out of source build using the MSVC 12 64 bit generator. Open t
Other versions of Visual Studio may work but are untested. Other versions of Visual Studio may work but are untested.
### Building with GPU support
XGBoost can be built with GPU support for both Linux and Windows using cmake. GPU support works with the Python package as well as the CLI version. The R package is not yet supported.
An up-to-date version of the cuda toolkit is required.
From the command line on Linux starting from the xgboost directory:
```bash
$ mkdir build
$ cd build
$ cmake .. -DUSE_CUDA=ON
$ make -j
```
On Windows using cmake, see what options for Generators you have for cmake, and choose one with [arch] replaced by Win64:
```bash
cmake -help
```
Then run cmake as:
```bash
$ mkdir build
$ cd build
$ cmake .. -G"Visual Studio 14 2015 Win64" -DUSE_CUDA=ON
```
Cmake will create an xgboost.sln solution file in the build directory. Build this solution in release mode as a x64 build.
### Windows Binaries ### Windows Binaries
Unofficial windows binaries and instructions on how to use them are hosted on [Guido Tapia's blog](http://www.picnet.com.au/blogs/guido/post/2016/09/22/xgboost-windows-x64-binaries-for-download/) Unofficial windows binaries and instructions on how to use them are hosted on [Guido Tapia's blog](http://www.picnet.com.au/blogs/guido/post/2016/09/22/xgboost-windows-x64-binaries-for-download/)

View File

@ -168,6 +168,8 @@ def setup(app):
'_static/jquery.js') '_static/jquery.js')
app.add_config_value('recommonmark_config', { app.add_config_value('recommonmark_config', {
'url_resolver': lambda url: github_doc_root + url, 'url_resolver': lambda url: github_doc_root + url,
}, True) 'enable_eval_rst': True,
}, True,
)
app.add_transform(AutoStructify) app.add_transform(AutoStructify)
app.add_javascript('jquery.js') app.add_javascript('jquery.js')

View File

@ -8,6 +8,7 @@ This page contains guidelines to use and develop XGBoost.
## Use XGBoost in Specific Ways ## Use XGBoost in Specific Ways
- [Parameter tuning guide](param_tuning.md) - [Parameter tuning guide](param_tuning.md)
- [Use out of core computation for large dataset](external_memory.md) - [Use out of core computation for large dataset](external_memory.md)
- [Use XGBoost GPU algorithms](../gpu/index.md)
## Develop and Hack XGBoost ## Develop and Hack XGBoost
- [Contribute to XGBoost](contribute.md) - [Contribute to XGBoost](contribute.md)

View File

@ -10,6 +10,7 @@ These are used to generate the index used in search.
* [Java/Scala Package Document](jvm/index.md) * [Java/Scala Package Document](jvm/index.md)
* [Julia Package Document](julia/index.md) * [Julia Package Document](julia/index.md)
* [CLI Package Document](cli/index.md) * [CLI Package Document](cli/index.md)
* [GPU Support Document](gpu/index.md)
- [Howto Documents](how_to/index.md) - [Howto Documents](how_to/index.md)
- [Get Started Documents](get_started/index.md) - [Get Started Documents](get_started/index.md)
- [Tutorials](tutorials/index.md) - [Tutorials](tutorials/index.md)

View File

@ -1,161 +1,3 @@
# CUDA Accelerated Tree Construction Algorithms # CUDA Accelerated Tree Construction Algorithms
This plugin adds GPU accelerated tree construction and prediction algorithms to XGBoost.
## Usage
Specify the 'tree_method' parameter as one of the following algorithms.
### Algorithms
| tree_method | Description |
| --- | --- |
gpu_exact | The standard XGBoost tree construction algorithm. Performs exact search for splits. Slower and uses considerably more memory than 'gpu_hist' |
gpu_hist | Equivalent to the XGBoost fast histogram algorithm. Faster and uses considerably less memory. Splits may be less accurate. |
### Supported parameters
| parameter | gpu_exact | gpu_hist |
| --- | --- | --- |
subsample | &#10004; | &#10004; |
colsample_bytree | &#10004; | &#10004;|
colsample_bylevel | &#10004; | &#10004; |
max_bin | &#10006; | &#10004; |
gpu_id | &#10004; | &#10004; |
n_gpus | &#10006; | &#10004; |
predictor | &#10004; | &#10004; |
GPU accelerated prediction is enabled by default for the above mentioned 'tree_method' parameters but can be switched to CPU prediction by setting 'predictor':'cpu_predictor'. This could be useful if you want to conserve GPU memory. Likewise when using CPU algorithms, GPU accelerated prediction can be enabled by setting 'predictor':'gpu_predictor'.
The device ordinal can be selected using the 'gpu_id' parameter, which defaults to 0.
Multiple GPUs can be used with the grow_gpu_hist parameter using the n_gpus parameter. which defaults to 1. If this is set to -1 all available GPUs will be used. If gpu_id is specified as non-zero, the gpu device order is mod(gpu_id + i) % n_visible_devices for i=0 to n_gpus-1. As with GPU vs. CPU, multi-GPU will not always be faster than a single GPU due to PCI bus bandwidth that can limit performance. For example, when n_features * n_bins * 2^depth divided by time of each round/iteration becomes comparable to the real PCI 16x bus bandwidth of order 4GB/s to 10GB/s, then AllReduce will dominant code speed and multiple GPUs become ineffective at increasing performance. Also, CPU overhead between GPU calls can limit usefulness of multiple GPUs.
This plugin currently works with the CLI version and python version.
Python example:
```python
param['gpu_id'] = 0
param['max_bin'] = 16
param['tree_method'] = 'gpu_hist'
```
## Benchmarks
To run benchmarks on synthetic data for binary classification:
```bash
$ python benchmark/benchmark.py
```
Training time time on 1,000,000 rows x 50 columns with 500 boosting iterations and 0.25/0.75 test/train split on i7-6700K CPU @ 4.00GHz and Pascal Titan X.
| tree_method | Time (s) |
| --- | --- |
| gpu_hist | 13.87 |
| hist | 63.55 |
| gpu_exact | 161.08 |
| exact | 1082.20 |
[See here](http://dmlc.ml/2016/12/14/GPU-accelerated-xgboost.html) for additional performance benchmarks of the 'gpu_exact' tree_method.
## Test
To run python tests:
```bash
$ python -m nose test/python/
```
Google tests can be enabled by specifying -DGOOGLE_TEST=ON when building with cmake.
## Dependencies
A CUDA capable GPU with at least compute capability >= 3.5
Building the plug-in requires CUDA Toolkit 7.5 or later (https://developer.nvidia.com/cuda-downloads)
## Build
From the command line on Linux starting from the xgboost directory:
On Linux, from the xgboost directory:
```bash
$ mkdir build
$ cd build
$ cmake .. -DPLUGIN_UPDATER_GPU=ON
$ make -j
```
On Windows using cmake, see what options for Generators you have for cmake, and choose one with [arch] replaced by Win64:
```bash
cmake -help
```
Then run cmake as:
```bash
$ mkdir build
$ cd build
$ cmake .. -G"Visual Studio 14 2015 Win64" -DPLUGIN_UPDATER_GPU=ON
```
Cmake will create an xgboost.sln solution file in the build directory. Build this solution in release mode as a x64 build.
Visual studio community 2015, supported by cuda toolkit (http://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/#axzz4isREr2nS), can be downloaded from: https://my.visualstudio.com/Downloads?q=Visual%20Studio%20Community%202015 . You may also be able to use a later version of visual studio depending on whether the CUDA toolkit supports it. Note that Mingw cannot be used with cuda.
### For other nccl libraries
On some systems, nccl libraries are specific to a particular system (IBM Power or nvidia-docker) and can enable use of nvlink (between GPUs or even between GPUs and system memory). In that case, one wants to avoid the static nccl library by changing "STATIC" to "SHARED" in nccl/CMakeLists.txt and deleting the shared nccl library created (so that the system one is used).
### For Developers!
In case you want to build only for a specific GPU(s), for eg. GP100 and GP102,
whose compute capability are 60 and 61 respectively:
```bash
$ cmake .. -DPLUGIN_UPDATER_GPU=ON -DGPU_COMPUTE_VER="60;61"
```
### Using make
Now, it also supports the usual 'make' flow to build gpu-enabled tree construction plugins. It's currently only tested on Linux. From the xgboost directory
```bash
# make sure CUDA SDK bin directory is in the 'PATH' env variable
$ make -j PLUGIN_UPDATER_GPU=ON
```
Similar to cmake, if you want to build only for a specific GPU(s):
```bash
$ make -j PLUGIN_UPDATER_GPU=ON GPU_COMPUTE_VER="60 61"
```
## Changelog
##### 2017/8/14
* Added GPU accelerated prediction. Considerably improved performance when using test/eval sets.
##### 2017/7/10
* Memory performance improved 4x for gpu_hist
##### 2017/6/26
* Change API to use tree_method parameter
* Increase required cmake version to 3.5
* Add compute arch 3.5 to default archs
* Set default n_gpus to 1
##### 2017/6/5
* Multi-GPU support for histogram method using NVIDIA NCCL.
##### 2017/5/31
* Faster version of the grow_gpu plugin
* Added support for building gpu plugin through 'make' flow too
##### 2017/5/19
* Further performance enhancements for histogram method.
##### 2017/5/5
* Histogram performance improvements
* Fix gcc build issues
##### 2017/4/25
* Add fast histogram algorithm
* Fix Linux build
* Add 'gpu_id' parameter
## References
[Mitchell, Rory, and Eibe Frank. Accelerating the XGBoost algorithm using GPU computing. No. e2911v1. PeerJ Preprints, 2017.](https://peerj.com/preprints/2911/)
## Author
Rory Mitchell
Jonathan C. McKinney
Shankara Rao Thejaswi Nanditale
Vinay Deshpande
... and the rest of the H2O.ai and NVIDIA team.
Please report bugs to the xgboost/issues page.
[The XGBoost GPU documentation has moved here.](https://xgboost.readthedocs.io/en/latest/gpu/index.html)

View File

@ -1,6 +0,0 @@
PLUGIN_OBJS += build_plugin/updater_gpu/src/register_updater_gpu.o \
build_plugin/updater_gpu/src/updater_gpu.o \
build_plugin/updater_gpu/src/gpu_hist_builder.o \
build_plugin/updater_gpu/src/gpu_predictor.o
PLUGIN_LDFLAGS += -L$(CUDA_ROOT)/lib64 -lcudart

View File

@ -8,7 +8,7 @@
#include <xgboost/tree_model.h> #include <xgboost/tree_model.h>
#include <xgboost/tree_updater.h> #include <xgboost/tree_updater.h>
#include <memory> #include <memory>
#include "device_helpers.cuh" #include "../common/device_helpers.cuh"
namespace xgboost { namespace xgboost {
namespace predictor { namespace predictor {

View File

@ -4,7 +4,7 @@
#include <xgboost/tree_updater.h> #include <xgboost/tree_updater.h>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "../../../src/tree/param.h" #include "param.h"
#include "updater_gpu_common.cuh" #include "updater_gpu_common.cuh"
namespace xgboost { namespace xgboost {

View File

@ -7,10 +7,10 @@
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <vector> #include <vector>
#include "../../../src/common/random.h" #include "../common/random.h"
#include "../../../src/tree/param.h" #include "param.h"
#include "cub/cub.cuh" #include <cub/cub.cuh>
#include "device_helpers.cuh" #include "../common/device_helpers.cuh"
namespace xgboost { namespace xgboost {
namespace tree { namespace tree {

View File

@ -5,12 +5,11 @@
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "../../../src/common/sync.h" #include "param.h"
#include "../../../src/tree/param.h" #include "../common/compressed_iterator.h"
#include "../../src/common/compressed_iterator.h" #include "../common/hist_util.h"
#include "../../src/common/hist_util.h"
#include "updater_gpu_common.cuh" #include "updater_gpu_common.cuh"
#include "device_helpers.cuh" #include "../common/device_helpers.cuh"
namespace xgboost { namespace xgboost {
namespace tree { namespace tree {

View File

@ -4,7 +4,7 @@
*/ */
#include <thrust/device_vector.h> #include <thrust/device_vector.h>
#include <xgboost/base.h> #include <xgboost/base.h>
#include "../../src/device_helpers.cuh" #include "../../../src/common/device_helpers.cuh"
#include "gtest/gtest.h" #include "gtest/gtest.h"
void CreateTestData(xgboost::bst_uint num_rows, int max_row_size, void CreateTestData(xgboost::bst_uint num_rows, int max_row_size,

View File

@ -5,7 +5,7 @@
#include <xgboost/c_api.h> #include <xgboost/c_api.h>
#include <xgboost/predictor.h> #include <xgboost/predictor.h>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "../../../../tests/cpp/helpers.h" #include "../helpers.h"
namespace xgboost { namespace xgboost {
namespace predictor { namespace predictor {

View File

@ -5,17 +5,6 @@ UNITTEST=$(UTEST_ROOT)/xgboost_test
UNITTEST_SRC=$(wildcard $(UTEST_ROOT)/*.cc $(UTEST_ROOT)/*/*.cc) UNITTEST_SRC=$(wildcard $(UTEST_ROOT)/*.cc $(UTEST_ROOT)/*/*.cc)
UNITTEST_OBJ=$(patsubst $(UTEST_ROOT)%.cc, $(UTEST_OBJ_ROOT)%.o, $(UNITTEST_SRC)) UNITTEST_OBJ=$(patsubst $(UTEST_ROOT)%.cc, $(UTEST_OBJ_ROOT)%.o, $(UNITTEST_SRC))
# for if and when we add cuda source files into xgboost core
UNITTEST_CU_SRC=$(wildcard $(UTEST_ROOT)/*.cu $(UTEST_ROOT)/*/*.cu)
UNITTEST_OBJ += $(patsubst $(UTEST_ROOT)%.cu, $(UTEST_OBJ_ROOT)%.o, $(UNITTEST_CU_SRC))
# tests from grow_gpu plugin (only if CUDA path is enabled!)
ifeq ($(PLUGIN_UPDATER_GPU),ON)
GPU_PLUGIN_FOLDER = plugin/updater_gpu
UNITTEST_CU_PLUGIN_SRC = $(wildcard $(GPU_PLUGIN_FOLDER)/test/cpp/*.cu)
UNITTEST_OBJ += $(patsubst %.cu, $(UTEST_OBJ_ROOT)/%.o, $(UNITTEST_CU_PLUGIN_SRC))
endif
GTEST_LIB=$(GTEST_PATH)/lib/ GTEST_LIB=$(GTEST_PATH)/lib/
GTEST_INC=$(GTEST_PATH)/include/ GTEST_INC=$(GTEST_PATH)/include/
@ -26,14 +15,6 @@ UNITTEST_DEPS=lib/libxgboost.a $(DMLC_CORE)/libdmlc.a $(RABIT)/lib/$(LIB_RABIT)
COVER_OBJ=$(patsubst %.o, %.gcda, $(ALL_OBJ)) $(patsubst %.o, %.gcda, $(UNITTEST_OBJ)) COVER_OBJ=$(patsubst %.o, %.gcda, $(ALL_OBJ)) $(patsubst %.o, %.gcda, $(UNITTEST_OBJ))
# the order of the below targets matter! # the order of the below targets matter!
$(UTEST_OBJ_ROOT)/$(GPU_PLUGIN_FOLDER)/test/cpp/%.o: $(GPU_PLUGIN_FOLDER)/test/cpp/%.cu
@mkdir -p $(@D)
$(NVCC) $(NVCC_FLAGS) -I$(GTEST_INC) -o $@ -c $<
$(UTEST_OBJ_ROOT)/%.o: $(UTEST_ROOT)/%.cu
@mkdir -p $(@D)
$(NVCC) $(NVCC_FLAGS) -I$(GTEST_INC) -o $@ -c $<
$(UTEST_OBJ_ROOT)/$(GTEST_PATH)/%.o: $(GTEST_PATH)/%.cc $(UTEST_OBJ_ROOT)/$(GTEST_PATH)/%.o: $(GTEST_PATH)/%.cc
@mkdir -p $(@D) @mkdir -p $(@D)
$(CXX) $(UNITTEST_CFLAGS) -I$(GTEST_INC) -I$(GTEST_PATH) -o $@ -c $< $(CXX) $(UNITTEST_CFLAGS) -I$(GTEST_INC) -I$(GTEST_PATH) -o $@ -c $<

View File

@ -1,15 +1,14 @@
from __future__ import print_function from __future__ import print_function
#pylint: skip-file #pylint: skip-file
import sys
sys.path.append("../../tests/python")
import xgboost as xgb import xgboost as xgb
import testing as tm import testing as tm
import numpy as np import numpy as np
import unittest import unittest
from nose.plugins.attrib import attr
rng = np.random.RandomState(1994) rng = np.random.RandomState(1994)
@attr('gpu')
class TestGPUPredict (unittest.TestCase): class TestGPUPredict (unittest.TestCase):
def test_predict(self): def test_predict(self):
iterations = 1 iterations = 1
@ -35,3 +34,4 @@ class TestGPUPredict (unittest.TestCase):
def non_decreasing(self, L): def non_decreasing(self, L):
return all((x - y) < 0.001 for x, y in zip(L, L[1:])) return all((x - y) < 0.001 for x, y in zip(L, L[1:]))

View File

@ -6,18 +6,17 @@ import xgboost as xgb
import testing as tm import testing as tm
import numpy as np import numpy as np
import unittest import unittest
from nose.plugins.attrib import attr
rng = np.random.RandomState(1994) rng = np.random.RandomState(1994)
dpath = '../../demo/data/' dpath = 'demo/data/'
ag_dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
ag_dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
def eprint(*args, **kwargs): def eprint(*args, **kwargs):
print(*args, file=sys.stderr, **kwargs) print(*args, file=sys.stderr, **kwargs)
print(*args, file=sys.stdout, **kwargs) print(*args, file=sys.stdout, **kwargs)
@attr('gpu')
class TestGPU(unittest.TestCase): class TestGPU(unittest.TestCase):
def test_grow_gpu(self): def test_grow_gpu(self):
tm._skip_if_no_sklearn() tm._skip_if_no_sklearn()
@ -27,6 +26,9 @@ class TestGPU(unittest.TestCase):
except: except:
from sklearn.cross_validation import train_test_split from sklearn.cross_validation import train_test_split
ag_dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
ag_dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
ag_param = {'max_depth': 2, ag_param = {'max_depth': 2,
'tree_method': 'exact', 'tree_method': 'exact',
'nthread': 0, 'nthread': 0,
@ -123,6 +125,10 @@ class TestGPU(unittest.TestCase):
except: except:
from sklearn.cross_validation import train_test_split from sklearn.cross_validation import train_test_split
ag_dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
ag_dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
for max_depth in range(3,10): # TODO: Doesn't work with 2 for some tests for max_depth in range(3,10): # TODO: Doesn't work with 2 for some tests
#eprint("max_depth=%d" % (max_depth)) #eprint("max_depth=%d" % (max_depth))

View File

@ -7,13 +7,12 @@ import xgboost as xgb
import testing as tm import testing as tm
import numpy as np import numpy as np
import unittest import unittest
from sklearn.datasets import make_classification from nose.plugins.attrib import attr
def eprint(*args, **kwargs): def eprint(*args, **kwargs):
print(*args, file=sys.stderr, **kwargs) ; sys.stderr.flush() print(*args, file=sys.stderr, **kwargs) ; sys.stderr.flush()
print(*args, file=sys.stdout, **kwargs) ; sys.stdout.flush() print(*args, file=sys.stdout, **kwargs) ; sys.stdout.flush()
eprint("Testing Big Data (this may take a while)")
rng = np.random.RandomState(1994) rng = np.random.RandomState(1994)
# "realistic" size based upon http://stat-computing.org/dataexpo/2009/ , which has been processed to one-hot encode categoricalsxsy # "realistic" size based upon http://stat-computing.org/dataexpo/2009/ , which has been processed to one-hot encode categoricalsxsy
@ -26,6 +25,7 @@ rows1 = 42360032 # large
rowslist = [rows1, rows2, rows3] rowslist = [rows1, rows2, rows3]
@attr('slow')
class TestGPU(unittest.TestCase): class TestGPU(unittest.TestCase):
def test_large(self): def test_large(self):
eprint("Starting test for large data") eprint("Starting test for large data")