From ef23e424f1d07f3eb21bcec900548893906a4531 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Wed, 16 Aug 2017 12:31:59 +1200 Subject: [PATCH] [GPU-Plugin] Add GPU accelerated prediction (#2593) * [GPU-Plugin] Add GPU accelerated prediction * Improve allocation message * Update documentation * Resolve linker error for predictor * Add unit tests --- CMakeLists.txt | 2 + Makefile | 1 + doc/parameter.md | 9 +- include/xgboost/predictor.h | 136 ++++-- include/xgboost/tree_model.h | 4 + plugin/updater_gpu/README.md | 60 +-- plugin/updater_gpu/benchmark/benchmark.py | 28 +- plugin/updater_gpu/gitshallow_submodules.sh | 12 - plugin/updater_gpu/plugin.mk | 3 +- plugin/updater_gpu/src/device_helpers.cuh | 43 +- plugin/updater_gpu/src/exact/gpu_builder.cuh | 8 +- plugin/updater_gpu/src/gpu_hist_builder.cu | 21 +- plugin/updater_gpu/src/gpu_predictor.cu | 411 ++++++++++++++++++ .../test/cpp/test_device_helpers.cu | 4 +- .../test/cpp/test_gpu_predictor.cu | 73 ++++ plugin/updater_gpu/test/python/test_large.py | 2 - .../test/python/test_prediction.py | 37 ++ src/gbm/gbtree.cc | 18 +- src/learner.cc | 14 +- src/predictor/cpu_predictor.cc | 42 +- src/predictor/predictor.cc | 46 +- tests/cpp/helpers.cc | 22 + tests/cpp/helpers.h | 15 + tests/cpp/predictor/test_cpu_predictor.cc | 54 +++ tests/cpp/test_learner.cc | 14 + 25 files changed, 876 insertions(+), 203 deletions(-) delete mode 100644 plugin/updater_gpu/gitshallow_submodules.sh create mode 100644 plugin/updater_gpu/src/gpu_predictor.cu create mode 100644 plugin/updater_gpu/test/cpp/test_gpu_predictor.cu create mode 100644 plugin/updater_gpu/test/python/test_prediction.py create mode 100644 tests/cpp/predictor/test_cpu_predictor.cc create mode 100644 tests/cpp/test_learner.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 005ff3cc9..b12ba5662 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -77,6 +77,8 @@ file(GLOB_RECURSE CUDA_SOURCES if(PLUGIN_UPDATER_GPU) find_package(CUDA 7.5 REQUIRED) cmake_minimum_required(VERSION 3.5) + + add_definitions(-DXGBOOST_USE_CUDA) include_directories( nccl/src diff --git a/Makefile b/Makefile index b77a5c7f9..9d8c6b388 100644 --- a/Makefile +++ b/Makefile @@ -129,6 +129,7 @@ ifeq ($(PLUGIN_UPDATER_GPU),ON) CUDA_ROOT = $(shell dirname $(shell dirname $(shell which $(NVCC)))) INCLUDES += -I$(CUDA_ROOT)/include -Inccl/src/ LDFLAGS += -L$(CUDA_ROOT)/lib64 -lcudart -lcudadevrt -Lnccl/build/lib/ -lnccl_static -lm -ldl -lrt + CFLAGS += -DXGBOOST_USE_CUDA endif # specify tensor path diff --git a/doc/parameter.md b/doc/parameter.md index 31f5dd8f8..b2fcabbae 100644 --- a/doc/parameter.md +++ b/doc/parameter.md @@ -56,7 +56,7 @@ Parameters for Tree Booster * tree_method, string [default='auto'] - The tree construction algorithm used in XGBoost(see description in the [reference paper](http://arxiv.org/abs/1603.02754)) - Distributed and external memory version only support approximate algorithm. - - Choices: {'auto', 'exact', 'approx', 'hist'} + - Choices: {'auto', 'exact', 'approx', 'hist', 'gpu_exact', 'gpu_hist'} - 'auto': Use heuristic to choose faster one. - For small to medium dataset, exact greedy will be used. - For very large-dataset, approximate algorithm will be chosen. @@ -65,6 +65,8 @@ Parameters for Tree Booster - 'exact': Exact greedy algorithm. - 'approx': Approximate greedy algorithm using sketching and histogram. - 'hist': Fast histogram optimized approximate greedy algorithm. It uses some performance improvements such as bins caching. + - 'gpu_exact': GPU implementation of exact algorithm. + - 'gpu_hist': GPU implementation of hist algorithm. * sketch_eps, [default=0.03] - This is only used for approximate greedy algorithm. - This roughly translated into ```O(1 / sketch_eps)``` number of bins. @@ -107,7 +109,10 @@ Parameters for Tree Booster - This is only used if 'hist' is specified as `tree_method`. - Maximum number of discrete bins to bucket continuous features. - Increasing this number improves the optimality of splits at the cost of higher computation time. - +* predictor, [default='cpu_predictor'] + - The type of predictor algorithm to use. Provides the same results but allows the use of GPU or CPU. + - 'cpu_predictor': Multicore CPU prediction algorithm. + - 'gpu_predictor': Prediction using GPU. Default for 'gpu_exact' and 'gpu_hist' tree method. Additional parameters for Dart Booster -------------------------------------- * sample_type [default="uniform"] diff --git a/include/xgboost/predictor.h b/include/xgboost/predictor.h index bc37f66d7..8f0104f5a 100644 --- a/include/xgboost/predictor.h +++ b/include/xgboost/predictor.h @@ -6,33 +6,31 @@ */ #pragma once #include +#include #include #include -#include #include +#include +#include #include "../../src/gbm/gbtree_model.h" // Forward declarations namespace xgboost { -class DMatrix; class TreeUpdater; } -namespace xgboost { -namespace gbm { -struct GBTreeModel; -} -} // namespace xgboost namespace xgboost { /** * \class Predictor * - * \brief Performs prediction on individual training instances or batches of instances for GBTree. - * The predictor also manages a prediction cache associated with input matrices. If possible, - * it will use previously calculated predictions instead of calculating new predictions. - * Prediction functions all take a GBTreeModel and a DMatrix as input and output a vector of - * predictions. The predictor does not modify any state of the model itself. + * \brief Performs prediction on individual training instances or batches of + * instances for GBTree. The predictor also manages a prediction cache + * associated with input matrices. If possible, it will use previously + * calculated predictions instead of calculating new predictions. + * Prediction functions all take a GBTreeModel and a DMatrix as input and + * output a vector of predictions. The predictor does not modify any state of + * the model itself. */ class Predictor { @@ -40,36 +38,47 @@ class Predictor { virtual ~Predictor() {} /** - * \fn void Predictor::InitCache(const std::vector > &cache); + * \fn virtual void Predictor::Init(const std::vector >&cfg ,const std::vector > &cache); * - * \brief Register input matrices in prediction cache. + * \brief Configure and register input matrices in prediction cache. * + * \param cfg The configuration. * \param cache Vector of DMatrix's to be used in prediction. */ - void InitCache(const std::vector > &cache); + virtual void Init(const std::vector>& cfg, + const std::vector>& cache); /** - * \fn virtual void Predictor::PredictBatch( DMatrix* dmat, std::vector* out_preds, const gbm::GBTreeModel &model, int tree_begin, unsigned ntree_limit = 0) = 0; + * \fn virtual void Predictor::PredictBatch( DMatrix* dmat, + * std::vector* out_preds, const gbm::GBTreeModel &model, int + * tree_begin, unsigned ntree_limit = 0) = 0; * - * \brief Generate batch predictions for a given feature matrix. May use cached predictions if available instead of calculating from scratch. + * \brief Generate batch predictions for a given feature matrix. May use + * cached predictions if available instead of calculating from scratch. * * \param [in,out] dmat Feature matrix. * \param [in,out] out_preds The output preds. * \param model The model to predict from. * \param tree_begin The tree begin index. - * \param ntree_limit (Optional) The ntree limit. 0 means do not limit trees. + * \param ntree_limit (Optional) The ntree limit. 0 means do not + * limit trees. */ - virtual void PredictBatch( - DMatrix* dmat, std::vector* out_preds, const gbm::GBTreeModel &model, - int tree_begin, unsigned ntree_limit = 0) = 0; + virtual void PredictBatch(DMatrix* dmat, std::vector* out_preds, + const gbm::GBTreeModel& model, int tree_begin, + unsigned ntree_limit = 0) = 0; /** - * \fn virtual void Predictor::UpdatePredictionCache( const gbm::GBTreeModel &model, std::vector >* updaters, int num_new_trees) = 0; + * \fn virtual void Predictor::UpdatePredictionCache( const gbm::GBTreeModel + * &model, std::vector >* updaters, int + * num_new_trees) = 0; * - * \brief Update the internal prediction cache using newly added trees. Will use the tree updater - * to do this if possible. Should be called as a part of the tree boosting process to facilitate the look up of predictions at a later time. + * \brief Update the internal prediction cache using newly added trees. Will + * use the tree updater to do this if possible. Should be called as a part of + * the tree boosting process to facilitate the look up of predictions + * at a later time. * * \param model The model. * \param [in,out] updaters The updater sequence for gradient boosting. @@ -77,15 +86,19 @@ class Predictor { */ virtual void UpdatePredictionCache( - const gbm::GBTreeModel &model, std::vector >* updaters, + const gbm::GBTreeModel& model, + std::vector>* updaters, int num_new_trees) = 0; /** - * \fn virtual void Predictor::PredictInstance( const SparseBatch::Inst& inst, std::vector* out_preds, const gbm::GBTreeModel& model, unsigned ntree_limit = 0, unsigned root_index = 0) = 0; + * \fn virtual void Predictor::PredictInstance( const SparseBatch::Inst& + * inst, std::vector* out_preds, const gbm::GBTreeModel& model, + * unsigned ntree_limit = 0, unsigned root_index = 0) = 0; * - * \brief online prediction function, predict score for one instance at a time NOTE: use the batch - * prediction interface if possible, batch prediction is usually more efficient than online - * prediction This function is NOT threadsafe, make sure you only call from one thread. + * \brief online prediction function, predict score for one instance at a time + * NOTE: use the batch prediction interface if possible, batch prediction is + * usually more efficient than online prediction This function is NOT + * threadsafe, make sure you only call from one thread. * * \param inst The instance to predict. * \param [in,out] out_preds The output preds. @@ -94,15 +107,19 @@ class Predictor { * \param root_index (Optional) Zero-based index of the root. */ - virtual void PredictInstance( - const SparseBatch::Inst& inst, std::vector* out_preds, - const gbm::GBTreeModel& model, unsigned ntree_limit = 0, unsigned root_index = 0) = 0; + virtual void PredictInstance(const SparseBatch::Inst& inst, + std::vector* out_preds, + const gbm::GBTreeModel& model, + unsigned ntree_limit = 0, + unsigned root_index = 0) = 0; /** - * \fn virtual void Predictor::PredictLeaf(DMatrix* dmat, std::vector* out_preds, const gbm::GBTreeModel& model, unsigned ntree_limit = 0) = 0; + * \fn virtual void Predictor::PredictLeaf(DMatrix* dmat, + * std::vector* out_preds, const gbm::GBTreeModel& model, unsigned + * ntree_limit = 0) = 0; * - * \brief predict the leaf index of each tree, the output will be nsample * ntree vector this is - * only valid in gbtree predictor. + * \brief predict the leaf index of each tree, the output will be nsample * + * ntree vector this is only valid in gbtree predictor. * * \param [in,out] dmat The input feature matrix. * \param [in,out] out_preds The output preds. @@ -111,13 +128,17 @@ class Predictor { */ virtual void PredictLeaf(DMatrix* dmat, std::vector* out_preds, - const gbm::GBTreeModel& model, unsigned ntree_limit = 0) = 0; + const gbm::GBTreeModel& model, + unsigned ntree_limit = 0) = 0; /** - * \fn virtual void Predictor::PredictContribution( DMatrix* dmat, std::vector* out_contribs, const gbm::GBTreeModel& model, unsigned ntree_limit = 0) = 0; + * \fn virtual void Predictor::PredictContribution( DMatrix* dmat, + * std::vector* out_contribs, const gbm::GBTreeModel& model, + * unsigned ntree_limit = 0) = 0; * - * \brief feature contributions to individual predictions; the output will be a vector of length - * (nfeats + 1) * num_output_group * nsample, arranged in that order. + * \brief feature contributions to individual predictions; the output will be + * a vector of length (nfeats + 1) * num_output_group * nsample, arranged in + * that order. * * \param [in,out] dmat The input feature matrix. * \param [in,out] out_contribs The output feature contribs. @@ -125,9 +146,10 @@ class Predictor { * \param ntree_limit (Optional) The ntree limit. */ - virtual void PredictContribution( - DMatrix* dmat, std::vector* out_contribs, - const gbm::GBTreeModel& model, unsigned ntree_limit = 0) = 0; + virtual void PredictContribution(DMatrix* dmat, + std::vector* out_contribs, + const gbm::GBTreeModel& model, + unsigned ntree_limit = 0) = 0; /** * \fn static Predictor* Predictor::Create(std::string name); @@ -139,6 +161,32 @@ class Predictor { static Predictor* Create(std::string name); protected: + /** + * \fn bool PredictFromCache(DMatrix* dmat, std::vector* + * out_preds, const gbm::GBTreeModel& model, unsigned ntree_limit = 0) + * + * \brief Attempt to predict from cache. + * + * \return True if it succeeds, false if it fails. + */ + bool PredictFromCache(DMatrix* dmat, std::vector* out_preds, + const gbm::GBTreeModel& model, + unsigned ntree_limit = 0); + + /** + * \fn void Predictor::InitOutPredictions(const MetaInfo& info, + * std::vector* out_preds, const gbm::GBTreeModel& model) const; + * + * \brief Init out predictions according to base margin. + * + * \param info Dmatrix info possibly containing base margin. + * \param [in,out] out_preds The out preds. + * \param model The model. + */ + void InitOutPredictions(const MetaInfo& info, + std::vector* out_preds, + const gbm::GBTreeModel& model) const; + /** * \struct PredictionCacheEntry * @@ -151,8 +199,8 @@ class Predictor { }; /** - * \brief Map of matrices and associated cached predictions to facilitate storing and looking up - * predictions. + * \brief Map of matrices and associated cached predictions to facilitate + * storing and looking up predictions. */ std::unordered_map cache_; diff --git a/include/xgboost/tree_model.h b/include/xgboost/tree_model.h index 74cbed2e2..3cf33780d 100644 --- a/include/xgboost/tree_model.h +++ b/include/xgboost/tree_model.h @@ -271,6 +271,10 @@ class TreeModel { inline const Node& operator[](int nid) const { return nodes[nid]; } + + /*! \brief get const reference to nodes */ + inline const std::vector& GetNodes() const { return nodes; } + /*! \brief get node statistics given nid */ inline NodeStat& stat(int nid) { return stats[nid]; diff --git a/plugin/updater_gpu/README.md b/plugin/updater_gpu/README.md index baeaf710f..0cae6f716 100644 --- a/plugin/updater_gpu/README.md +++ b/plugin/updater_gpu/README.md @@ -1,5 +1,5 @@ # CUDA Accelerated Tree Construction Algorithms -This plugin adds GPU accelerated tree construction algorithms to XGBoost. +This plugin adds GPU accelerated tree construction and prediction algorithms to XGBoost. ## Usage Specify the 'tree_method' parameter as one of the following algorithms. @@ -18,6 +18,9 @@ colsample_bylevel | ✔ | ✔ | max_bin | ✖ | ✔ | gpu_id | ✔ | ✔ | n_gpus | ✖ | ✔ | +predictor | ✔ | ✔ | + +GPU accelerated prediction is enabled by default for the above mentioned 'tree_method' parameters but can be switched to CPU prediction by setting 'predictor':'cpu_predictor'. This could be useful if you want to conserve GPU memory. Likewise when using CPU algorithms, GPU accelerated prediction can be enabled by setting 'predictor':'gpu_predictor'. The device ordinal can be selected using the 'gpu_id' parameter, which defaults to 0. @@ -37,48 +40,31 @@ To run benchmarks on synthetic data for binary classification: $ python benchmark/benchmark.py ``` -Training time time on 1000000 rows x 50 columns with 500 boosting iterations on i7-6700K CPU @ 4.00GHz and Pascal Titan X. +Training time time on 1,000,000 rows x 50 columns with 500 boosting iterations and 0.25/0.75 test/train split on i7-6700K CPU @ 4.00GHz and Pascal Titan X. | tree_method | Time (s) | | --- | --- | -| gpu_hist | 11.09 | -| hist (histogram XGBoost - CPU) | 41.75 | -| gpu_exact | 193.90 | -| exact (standard XGBoost - CPU) | 720.12 | +| gpu_hist | 13.87 | +| hist | 63.55 | +| gpu_exact | 161.08 | +| exact | 1082.20 | [See here](http://dmlc.ml/2016/12/14/GPU-accelerated-xgboost.html) for additional performance benchmarks of the 'gpu_exact' tree_method. ## Test -To run tests:Will +To run python tests: ```bash $ python -m nose test/python/ ``` + +Google tests can be enabled by specifying -DGOOGLE_TEST=ON when building with cmake. + ## Dependencies -A CUDA capable GPU with at least compute capability >= 3.5 (the algorithm depends on shuffle and vote instructions introduced in Kepler). +A CUDA capable GPU with at least compute capability >= 3.5 Building the plug-in requires CUDA Toolkit 7.5 or later (https://developer.nvidia.com/cuda-downloads) -submodule: The plugin also depends on CUB 1.6.4 - https://nvlabs.github.io/cub/ . CUB is a header only cuda library which provides sort/reduce/scan primitives. - -submodule: NVIDIA NCCL from https://github.com/NVIDIA/nccl with windows port allowed by git@github.com:h2oai/nccl.git - -## Download full repo + full submodules for your choice (or empty) path - -git clone --recursive https://github.com/dmlc/xgboost.git - -## Download with shallow submodules for much quicker download: - -git 2.9.0+ (assumes only HEAD used for all submodules, but not true currently for dmlc-core and rabbit) - -git clone --recursive --shallow-submodules https://github.com/dmlc/xgboost.git - -git 2.9.0-: (only cub is shallow, as largest repo) - -git clone https://github.com/dmlc/xgboost.git -cd -bash plugin/updater/gpu/gitshallow_submodules.sh - ## Build From the command line on Linux starting from the xgboost directory: @@ -110,14 +96,11 @@ On some systems, nccl libraries are specific to a particular system (IBM Power o ### For Developers! - - In case you want to build only for a specific GPU(s), for eg. GP100 and GP102, whose compute capability are 60 and 61 respectively: ```bash $ cmake .. -DPLUGIN_UPDATER_GPU=ON -DGPU_COMPUTE_VER="60;61" ``` -By default, the versions will include support for all GPUs in Maxwell and Pascal architectures. ### Using make Now, it also supports the usual 'make' flow to build gpu-enabled tree construction plugins. It's currently only tested on Linux. From the xgboost directory @@ -131,19 +114,10 @@ Similar to cmake, if you want to build only for a specific GPU(s): $ make -j PLUGIN_UPDATER_GPU=ON GPU_COMPUTE_VER="60 61" ``` -### For Developers! - -Now, some of the code-base inside gpu plugins have googletest unit-tests inside 'tests/'. -They can be enabled run along with other unit-tests inside '/tests/cpp' using: -```bash -# make sure CUDA SDK bin directory is in the 'PATH' env variable -# below 2 commands need only be executed once -$ source ./dmlc-core/scripts/travis/travis_setup_env.sh -$ make -f dmlc-core/scripts/packages.mk gtest -$ make PLUGIN_UPDATER_GPU=ON GTEST_PATH=${CACHE_PREFIX} test -``` - ## Changelog +##### 2017/8/14 +* Added GPU accelerated prediction. Considerably improved performance when using test/eval sets. + ##### 2017/7/10 * Memory performance improved 4x for gpu_hist diff --git a/plugin/updater_gpu/benchmark/benchmark.py b/plugin/updater_gpu/benchmark/benchmark.py index 41615563a..e0b51a934 100644 --- a/plugin/updater_gpu/benchmark/benchmark.py +++ b/plugin/updater_gpu/benchmark/benchmark.py @@ -3,19 +3,22 @@ import sys, argparse import xgboost as xgb import numpy as np from sklearn.datasets import make_classification +from sklearn.model_selection import train_test_split import time + def run_benchmark(args, gpu_algorithm, cpu_algorithm): - print("Generating dataset: {} rows * {} columns".format(args.rows,args.columns)) + print("Generating dataset: {} rows * {} columns".format(args.rows, args.columns)) + print("{}/{} test/train split".format(args.test_size, 1.0 - args.test_size)) tmp = time.time() X, y = make_classification(args.rows, n_features=args.columns, random_state=7) + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=args.test_size, random_state=7) print ("Generate Time: %s seconds" % (str(time.time() - tmp))) tmp = time.time() print ("DMatrix Start") # omp way - dtrain = xgb.DMatrix(X, y, nthread=-1) - # non-omp way - #dtrain = xgb.DMatrix(X, y) + dtrain = xgb.DMatrix(X_train, y_train, nthread=-1) + dtest = xgb.DMatrix(X_test, y_test, nthread=-1) print ("DMatrix Time: %s seconds" % (str(time.time() - tmp))) param = {'objective': 'binary:logistic', @@ -23,28 +26,30 @@ def run_benchmark(args, gpu_algorithm, cpu_algorithm): 'silent': 0, 'n_gpus': 1, 'gpu_id': 0, - 'eval_metric': 'auc'} + 'eval_metric': 'error', + 'debug_verbose': 0, + } param['tree_method'] = gpu_algorithm print("Training with '%s'" % param['tree_method']) tmp = time.time() - xgb.train(param, dtrain, args.iterations) + xgb.train(param, dtrain, args.iterations, evals=[(dtest, "test")]) print ("Train Time: %s seconds" % (str(time.time() - tmp))) param['silent'] = 1 param['tree_method'] = cpu_algorithm print("Training with '%s'" % param['tree_method']) tmp = time.time() - xgb.train(param, dtrain, args.iterations) + xgb.train(param, dtrain, args.iterations, evals=[(dtest, "test")]) print ("Time: %s seconds" % (str(time.time() - tmp))) - parser = argparse.ArgumentParser() parser.add_argument('--algorithm', choices=['all', 'gpu_exact', 'gpu_hist'], default='all') -parser.add_argument('--rows',type=int,default=1000000) -parser.add_argument('--columns',type=int,default=50) -parser.add_argument('--iterations',type=int,default=500) +parser.add_argument('--rows', type=int, default=1000000) +parser.add_argument('--columns', type=int, default=50) +parser.add_argument('--iterations', type=int, default=500) +parser.add_argument('--test_size', type=float, default=0.25) args = parser.parse_args() if 'gpu_hist' in args.algorithm: @@ -54,4 +59,3 @@ elif 'gpu_exact' in args.algorithm: elif 'all' in args.algorithm: run_benchmark(args, 'gpu_exact', 'exact') run_benchmark(args, 'gpu_hist', 'hist') - diff --git a/plugin/updater_gpu/gitshallow_submodules.sh b/plugin/updater_gpu/gitshallow_submodules.sh deleted file mode 100644 index 68d3e36e5..000000000 --- a/plugin/updater_gpu/gitshallow_submodules.sh +++ /dev/null @@ -1,12 +0,0 @@ -#!/bin/bash -git submodule init -for i in $(git submodule | awk '{print $2}'); do - spath=$(git config -f .gitmodules --get submodule.$i.path) - surl=$(git config -f .gitmodules --get submodule.$i.url) - if [ $spath == "cub" ] - then - git submodule update --depth 3 $spath - else - git submodule update $spath - fi -done diff --git a/plugin/updater_gpu/plugin.mk b/plugin/updater_gpu/plugin.mk index b4fb52e3e..e8c642ebd 100644 --- a/plugin/updater_gpu/plugin.mk +++ b/plugin/updater_gpu/plugin.mk @@ -1,5 +1,6 @@ PLUGIN_OBJS += build_plugin/updater_gpu/src/register_updater_gpu.o \ build_plugin/updater_gpu/src/updater_gpu.o \ - build_plugin/updater_gpu/src/gpu_hist_builder.o + build_plugin/updater_gpu/src/gpu_hist_builder.o \ + build_plugin/updater_gpu/src/gpu_predictor.o PLUGIN_LDFLAGS += -L$(CUDA_ROOT)/lib64 -lcudart diff --git a/plugin/updater_gpu/src/device_helpers.cuh b/plugin/updater_gpu/src/device_helpers.cuh index eedfdea48..f1dd06c4b 100644 --- a/plugin/updater_gpu/src/device_helpers.cuh +++ b/plugin/updater_gpu/src/device_helpers.cuh @@ -2,7 +2,7 @@ * Copyright 2017 XGBoost contributors */ #pragma once -#include +#include #include #include #include @@ -121,6 +121,28 @@ inline std::string device_name(int device_idx) { return std::string(prop.name); } +inline size_t available_memory(int device_idx) { + size_t device_free = 0; + size_t device_total = 0; + safe_cuda(cudaSetDevice(device_idx)); + dh::safe_cuda(cudaMemGetInfo(&device_free, &device_total)); + return device_free; +} + +/** + * \fn inline int max_shared_memory(int device_idx) + * + * \brief Maximum shared memory per block on this device. + * + * \param device_idx Zero-based index of the device. + */ + +inline int max_shared_memory(int device_idx) { + cudaDeviceProp prop; + dh::safe_cuda(cudaGetDeviceProperties(&prop, device_idx)); + return prop.sharedMemPerBlock; +} + // ensure gpu_id is correct, so not dependent upon user knowing details inline int get_device_idx(int gpu_id) { // protect against overrun for gpu_id @@ -215,7 +237,7 @@ __device__ range block_stride_range(T begin, T end) { return r; } -// Threadblock iterates over range, filling with value +// Threadblock iterates over range, filling with value. Requires all threads in block to be active. template __device__ void block_fill(IterT begin, size_t n, ValueT value) { for (auto i : block_stride_range(static_cast(0), n)) { @@ -463,7 +485,7 @@ class bulk_allocator { } template - void allocate(int device_idx, Args... args) { + void allocate(int device_idx, bool silent ,Args... args) { size_t size = get_size_bytes(args...); char *ptr = allocate_device(device_idx, size, MemoryT); @@ -473,6 +495,14 @@ class bulk_allocator { d_ptr.push_back(ptr); _size.push_back(size); _device_idx.push_back(device_idx); + + if(!silent) + { + const int mb_size = 1048576; + LOG(CONSOLE) << "Allocated " << size / mb_size << "MB on [" << device_idx + << "] " << device_name(device_idx) << ", " + << available_memory(device_idx) / mb_size << "MB remaining."; + } } }; @@ -515,13 +545,6 @@ struct CubMemory { bool IsAllocated() { return d_temp_storage != NULL; } }; -inline size_t available_memory(int device_idx) { - size_t device_free = 0; - size_t device_total = 0; - safe_cuda(cudaSetDevice(device_idx)); - dh::safe_cuda(cudaMemGetInfo(&device_free, &device_total)); - return device_free; -} /* * Utility functions diff --git a/plugin/updater_gpu/src/exact/gpu_builder.cuh b/plugin/updater_gpu/src/exact/gpu_builder.cuh index 0ebe451c1..5de2f3841 100644 --- a/plugin/updater_gpu/src/exact/gpu_builder.cuh +++ b/plugin/updater_gpu/src/exact/gpu_builder.cuh @@ -232,7 +232,7 @@ class GPUBuilder { void allocateAllData(int offsetSize) { int tmpBuffSize = scanTempBufferSize(nVals); - ba.allocate(dh::get_device_idx(param.gpu_id), &vals, nVals, &vals_cached, + ba.allocate(dh::get_device_idx(param.gpu_id), param.silent, &vals, nVals, &vals_cached, nVals, &instIds, nVals, &instIds_cached, nVals, &colOffsets, offsetSize, &gradsInst, nRows, &nodeAssigns, nVals, &nodeLocations, nVals, &nodes, maxNodes, &nodeAssignsPerInst, @@ -252,12 +252,6 @@ class GPUBuilder { allocateAllData((int)offset.size()); transferAndSortData(fval, fId, offset); allocated = true; - if (!param.silent) { - const int mb_size = 1048576; - LOG(CONSOLE) << "Allocated " << ba.size() / mb_size << "/" - << free_memory / mb_size << " MB on " - << dh::device_name(dh::get_device_idx(param.gpu_id)); - } } void convertToCsc(DMatrix& hMat, std::vector& fval, diff --git a/plugin/updater_gpu/src/gpu_hist_builder.cu b/plugin/updater_gpu/src/gpu_hist_builder.cu index 5a89a15b8..e7109550a 100644 --- a/plugin/updater_gpu/src/gpu_hist_builder.cu +++ b/plugin/updater_gpu/src/gpu_hist_builder.cu @@ -127,16 +127,6 @@ void GPUHistBuilder::Init(const TrainParam& param) { this->param = param; CHECK(param.n_gpus != 0) << "Must have at least one device"; - int n_devices_all = dh::n_devices_all(param.n_gpus); - for (int device_idx = 0; device_idx < n_devices_all; device_idx++) { - if (!param.silent) { - size_t free_memory = dh::available_memory(device_idx); - const int mb_size = 1048576; - LOG(CONSOLE) << "[GPU Plug-in] Device: [" << device_idx << "] " - << dh::device_name(device_idx) << " with " - << free_memory / mb_size << " MB available device memory."; - } - } } void GPUHistBuilder::InitData(const std::vector& gpair, DMatrix& fmat, // NOLINT @@ -210,9 +200,6 @@ void GPUHistBuilder::InitData(const std::vector& gpair, // process) } - CHECK(fmat.SingleColBlock()) << "grow_gpu_hist: must have single column " - "block. Try setting 'tree_method' " - "parameter to 'exact'"; is_dense = info->num_nonzero == info->num_col * info->num_row; dh::Timer time0; hmat_.Init(&fmat, param.max_bin); @@ -326,7 +313,8 @@ void GPUHistBuilder::InitData(const std::vector& gpair, bst_ulong num_elements_segment = device_element_segments[d_idx + 1] - device_element_segments[d_idx]; ba.allocate( - device_idx, &(hist_vec[d_idx].data), + device_idx, param.silent, + &(hist_vec[d_idx].data), n_nodes(param.max_depth - 1) * n_bins, &nodes[d_idx], n_nodes(param.max_depth), &nodes_temp[d_idx], max_num_nodes_device, &nodes_child_temp[d_idx], max_num_nodes_device, @@ -367,11 +355,6 @@ void GPUHistBuilder::InitData(const std::vector& gpair, feature_flags[d_idx].fill(1); // init device object (assumes comes after // ba.allocate that sets device) } - - if (!param.silent) { - const int mb_size = 1048576; - LOG(CONSOLE) << "[GPU Plug-in] Allocated " << ba.size() / mb_size << " MB"; - } } // copy or init to do every iteration diff --git a/plugin/updater_gpu/src/gpu_predictor.cu b/plugin/updater_gpu/src/gpu_predictor.cu new file mode 100644 index 000000000..32f92a1b4 --- /dev/null +++ b/plugin/updater_gpu/src/gpu_predictor.cu @@ -0,0 +1,411 @@ +/*! + * Copyright by Contributors 2017 + */ +#include +#include +#include +#include +#include +#include +#include +#include "device_helpers.cuh" + +namespace xgboost { +namespace predictor { + +DMLC_REGISTRY_FILE_TAG(gpu_predictor); + +/*! \brief prediction parameters */ +struct GPUPredictionParam : public dmlc::Parameter { + int gpu_id; + int n_gpus; + bool silent; + // declare parameters + DMLC_DECLARE_PARAMETER(GPUPredictionParam) { + DMLC_DECLARE_FIELD(gpu_id).set_default(0).describe( + "Device ordinal for GPU prediction."); + DMLC_DECLARE_FIELD(n_gpus).set_default(1).describe( + "Number of devices to use for prediction (NOT IMPLEMENTED)."); + DMLC_DECLARE_FIELD(silent).set_default(false).describe( + "Do not print information during trainig."); + } +}; +DMLC_REGISTER_PARAMETER(GPUPredictionParam); + +template +void increment_offset(iter_t begin_itr, iter_t end_itr, size_t amount) { + thrust::transform(begin_itr, end_itr, begin_itr, + [=] __device__(size_t elem) { return elem + amount; }); +} + +/** + * \struct DeviceMatrix + * + * \brief A csr representation of the input matrix allocated on the device. + */ + +struct DeviceMatrix { + DMatrix* p_mat; // Pointer to the original matrix on the host + dh::bulk_allocator ba; + dh::dvec row_ptr; + dh::dvec data; + thrust::device_vector predictions; + + DeviceMatrix(DMatrix* dmat, int device_idx, bool silent) : p_mat(dmat) { + dh::safe_cuda(cudaSetDevice(device_idx)); + auto info = dmat->info(); + ba.allocate(device_idx, silent, &row_ptr, info.num_row + 1, &data, + info.num_nonzero); + auto iter = dmat->RowIterator(); + iter->BeforeFirst(); + size_t data_offset = 0; + while (iter->Next()) { + auto batch = iter->Value(); + // Copy row ptr + thrust::copy(batch.ind_ptr, batch.ind_ptr + batch.size + 1, + row_ptr.tbegin() + batch.base_rowid); + if (batch.base_rowid > 0) { + auto begin_itr = row_ptr.tbegin() + batch.base_rowid; + auto end_itr = begin_itr + batch.size + 1; + increment_offset(begin_itr, end_itr, batch.base_rowid); + } + // Copy data + thrust::copy(batch.data_ptr, batch.data_ptr + batch.ind_ptr[batch.size], + data.tbegin() + data_offset); + data_offset += batch.ind_ptr[batch.size]; + } + } +}; + +/** + * \struct DevicePredictionNode + * + * \brief Packed 16 byte representation of a tree node for use in device + * prediction + */ + +struct DevicePredictionNode { + XGBOOST_DEVICE DevicePredictionNode() + : fidx(-1), left_child_idx(-1), right_child_idx(-1) {} + + union NodeValue { + float leaf_weight; + float fvalue; + }; + + int fidx; + int left_child_idx; + int right_child_idx; + NodeValue val; + + DevicePredictionNode(const RegTree::Node& n) { // NOLINT + this->left_child_idx = n.cleft(); + this->right_child_idx = n.cright(); + this->fidx = n.split_index(); + if (n.default_left()) { + fidx |= (1U << 31); + } + + if (n.is_leaf()) { + this->val.leaf_weight = n.leaf_value(); + } else { + this->val.fvalue = n.split_cond(); + } + } + + XGBOOST_DEVICE bool IsLeaf() const { return left_child_idx == -1; } + + XGBOOST_DEVICE int GetFidx() const { return fidx & ((1U << 31) - 1U); } + + XGBOOST_DEVICE bool MissingLeft() const { return (fidx >> 31) != 0; } + + XGBOOST_DEVICE int MissingIdx() const { + if (MissingLeft()) { + return this->left_child_idx; + } else { + return this->right_child_idx; + } + } + + XGBOOST_DEVICE float GetFvalue() const { return val.fvalue; } + + XGBOOST_DEVICE float GetWeight() const { return val.leaf_weight; } +}; + +struct ElementLoader { + bool use_shared; + size_t* d_row_ptr; + SparseBatch::Entry* d_data; + int num_features; + float* smem; + + __device__ ElementLoader(bool use_shared, size_t* row_ptr, + SparseBatch::Entry* entry, int num_features, + float* smem, int num_rows) + : use_shared(use_shared), + d_row_ptr(row_ptr), + d_data(entry), + num_features(num_features), + smem(smem) { + // Copy instances + if (use_shared) { + bst_uint global_idx = blockDim.x * blockIdx.x + threadIdx.x; + int shared_elements = blockDim.x * num_features; + dh::block_fill(smem, shared_elements, nanf("")); + __syncthreads(); + if (global_idx < num_rows) { + bst_uint elem_begin = d_row_ptr[global_idx]; + bst_uint elem_end = d_row_ptr[global_idx + 1]; + for (bst_uint elem_idx = elem_begin; elem_idx < elem_end; elem_idx++) { + SparseBatch::Entry elem = d_data[elem_idx]; + smem[threadIdx.x * num_features + elem.index] = elem.fvalue; + } + } + __syncthreads(); + } + } + __device__ float GetFvalue(int ridx, int fidx) { + if (use_shared) { + return smem[threadIdx.x * num_features + fidx]; + } else { + // Binary search + auto begin_ptr = d_data + d_row_ptr[ridx]; + auto end_ptr = d_data + d_row_ptr[ridx + 1]; + SparseBatch::Entry* previous_middle = nullptr; + while (end_ptr != begin_ptr) { + auto middle = begin_ptr + (end_ptr - begin_ptr) / 2; + if (middle == previous_middle) { + break; + } else { + previous_middle = middle; + } + + if (middle->index == fidx) { + return middle->fvalue; + } else if (middle->index < fidx) { + begin_ptr = middle; + } else { + end_ptr = middle; + } + } + // Value is missing + return nanf(""); + } + } +}; + +__device__ float GetLeafWeight(bst_uint ridx, const DevicePredictionNode* tree, + ElementLoader* loader) { + DevicePredictionNode n = tree[0]; + while (!n.IsLeaf()) { + float fvalue = loader->GetFvalue(ridx, n.GetFidx()); + // Missing value + if (isnan(fvalue)) { + n = tree[n.MissingIdx()]; + } else { + if (fvalue < n.GetFvalue()) { + n = tree[n.left_child_idx]; + } else { + n = tree[n.right_child_idx]; + } + } + } + return n.GetWeight(); +} + +template +__global__ void PredictKernel(const DevicePredictionNode* d_nodes, + float* d_out_predictions, int* d_tree_segments, + int* d_tree_group, size_t* d_row_ptr, + SparseBatch::Entry* d_data, int tree_begin, + int tree_end, int num_features, bst_uint num_rows, + bool use_shared, int num_group) { + extern __shared__ float smem[]; + bst_uint global_idx = blockDim.x * blockIdx.x + threadIdx.x; + ElementLoader loader(use_shared, d_row_ptr, d_data, num_features, smem, + num_rows); + if (global_idx >= num_rows) return; + if (num_group == 1) { + float sum = 0; + for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) { + const DevicePredictionNode* d_tree = + d_nodes + d_tree_segments[tree_idx - tree_begin]; + sum += GetLeafWeight(global_idx, d_tree, &loader); + } + d_out_predictions[global_idx] += sum; + } else { + for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) { + int tree_group = d_tree_group[tree_idx]; + const DevicePredictionNode* d_tree = + d_nodes + d_tree_segments[tree_idx - tree_begin]; + bst_uint out_prediction_idx = global_idx * num_group + tree_group; + d_out_predictions[out_prediction_idx] += + GetLeafWeight(global_idx, d_tree, &loader); + } + } +} + +class GPUPredictor : public xgboost::Predictor { + private: + void DevicePredictInternal(DMatrix* dmat, std::vector* out_preds, + const gbm::GBTreeModel& model, int tree_begin, + int tree_end) { + if (tree_end - tree_begin == 0) { + return; + } + + // Add dmatrix to device if not seen before + if (this->device_matrix_cache_.find(dmat) == + this->device_matrix_cache_.end()) { + this->device_matrix_cache_.emplace( + dmat, std::unique_ptr( + new DeviceMatrix(dmat, param.gpu_id, param.silent))); + } + DeviceMatrix* device_matrix = device_matrix_cache_.find(dmat)->second.get(); + + dh::safe_cuda(cudaSetDevice(param.gpu_id)); + CHECK_EQ(model.param.size_leaf_vector, 0); + // Copy decision trees to device + thrust::host_vector h_tree_segments; + h_tree_segments.reserve((tree_end - tree_end) + 1); + int sum = 0; + h_tree_segments.push_back(sum); + for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) { + sum += model.trees[tree_idx]->GetNodes().size(); + h_tree_segments.push_back(sum); + } + + thrust::host_vector h_nodes(h_tree_segments.back()); + for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) { + auto& src_nodes = model.trees[tree_idx]->GetNodes(); + std::copy(src_nodes.begin(), src_nodes.end(), + h_nodes.begin() + h_tree_segments[tree_idx - tree_begin]); + } + + nodes.resize(h_nodes.size()); + thrust::copy(h_nodes.begin(), h_nodes.end(), nodes.begin()); + tree_segments.resize(h_tree_segments.size()); + thrust::copy(h_tree_segments.begin(), h_tree_segments.end(), + tree_segments.begin()); + tree_group.resize(model.tree_info.size()); + thrust::copy(model.tree_info.begin(), model.tree_info.end(), + tree_group.begin()); + + if (device_matrix->predictions.size() != out_preds->size()) { + device_matrix->predictions.resize(out_preds->size()); + thrust::copy(out_preds->begin(), out_preds->end(), + device_matrix->predictions.begin()); + } + + const int BLOCK_THREADS = 128; + const int GRID_SIZE = + dh::div_round_up(device_matrix->row_ptr.size() - 1, BLOCK_THREADS); + + int shared_memory_bytes = + sizeof(float) * device_matrix->p_mat->info().num_col * BLOCK_THREADS; + bool use_shared = true; + if (shared_memory_bytes > dh::max_shared_memory(param.gpu_id)) { + shared_memory_bytes = 0; + use_shared = false; + } + + PredictKernel + <<>>( + dh::raw(nodes), dh::raw(device_matrix->predictions), + dh::raw(tree_segments), dh::raw(tree_group), + device_matrix->row_ptr.data(), device_matrix->data.data(), + tree_begin, tree_end, device_matrix->p_mat->info().num_col, + device_matrix->p_mat->info().num_row, use_shared, + model.param.num_output_group); + + dh::safe_cuda(cudaDeviceSynchronize()); + thrust::copy(device_matrix->predictions.begin(), + device_matrix->predictions.end(), out_preds->begin()); + } + + public: + GPUPredictor() : cpu_predictor(Predictor::Create("cpu_predictor")) {} + + void PredictBatch(DMatrix* dmat, std::vector* out_preds, + const gbm::GBTreeModel& model, int tree_begin, + unsigned ntree_limit = 0) override { + if (this->PredictFromCache(dmat, out_preds, model, ntree_limit)) { + return; + } + this->InitOutPredictions(dmat->info(), out_preds, model); + + int tree_end = ntree_limit * model.param.num_output_group; + if (ntree_limit == 0 || ntree_limit > model.trees.size()) { + tree_end = static_cast(model.trees.size()); + } + + DevicePredictInternal(dmat, out_preds, model, tree_begin, tree_end); + } + + void UpdatePredictionCache( + const gbm::GBTreeModel& model, + std::vector>* updaters, + int num_new_trees) override { + // dh::Timer t; + int old_ntree = model.trees.size() - num_new_trees; + // update cache entry + for (auto& kv : cache_) { + PredictionCacheEntry& e = kv.second; + DMatrix* dmat = kv.first; + + if (e.predictions.size() == 0) { + cpu_predictor->PredictBatch(dmat, &(e.predictions), model, 0, + model.trees.size()); + } else if (model.param.num_output_group == 1 && updaters->size() > 0 && + num_new_trees == 1 && + updaters->back()->UpdatePredictionCache(e.data.get(), + &(e.predictions))) { + {} // do nothing + } else { + DevicePredictInternal(dmat, &(e.predictions), model, old_ntree, + model.trees.size()); + } + } + } + + void PredictInstance(const SparseBatch::Inst& inst, + std::vector* out_preds, + const gbm::GBTreeModel& model, unsigned ntree_limit, + unsigned root_index) override { + cpu_predictor->PredictInstance(inst, out_preds, model, root_index); + } + void PredictLeaf(DMatrix* p_fmat, std::vector* out_preds, + const gbm::GBTreeModel& model, + unsigned ntree_limit) override { + cpu_predictor->PredictLeaf(p_fmat, out_preds, model, ntree_limit); + } + + void PredictContribution(DMatrix* p_fmat, + std::vector* out_contribs, + const gbm::GBTreeModel& model, + unsigned ntree_limit) override { + cpu_predictor->PredictContribution(p_fmat, out_contribs, model, + ntree_limit); + } + + void Init(const std::vector>& cfg, + const std::vector>& cache) override { + Predictor::Init(cfg, cache); + cpu_predictor->Init(cfg, cache); + param.InitAllowUnknown(cfg); + } + + private: + GPUPredictionParam param; + std::unique_ptr cpu_predictor; + std::unordered_map> + device_matrix_cache_; + thrust::device_vector nodes; + thrust::device_vector tree_segments; + thrust::device_vector tree_group; +}; +XGBOOST_REGISTER_PREDICTOR(GPUPredictor, "gpu_predictor") + .describe("Make predictions using GPU.") + .set_body([]() { return new GPUPredictor(); }); +} // namespace predictor +} // namespace xgboost diff --git a/plugin/updater_gpu/test/cpp/test_device_helpers.cu b/plugin/updater_gpu/test/cpp/test_device_helpers.cu index 6bd520b71..4a5528adf 100644 --- a/plugin/updater_gpu/test/cpp/test_device_helpers.cu +++ b/plugin/updater_gpu/test/cpp/test_device_helpers.cu @@ -37,7 +37,7 @@ void SpeedTest() { dh::Timer t; dh::TransformLbs( - 0, &temp_memory, h_rows.size(), dh::raw(row_ptr), row_ptr.size() - 1, + 0, &temp_memory, h_rows.size(), dh::raw(row_ptr), row_ptr.size() - 1, false, [=] __device__(size_t idx, size_t ridx) { d_output_row[idx] = ridx; }); dh::safe_cuda(cudaDeviceSynchronize()); @@ -65,7 +65,7 @@ void TestLbs() { auto d_output_row = output_row.data(); dh::TransformLbs(0, &temp_memory, h_rows.size(), dh::raw(row_ptr), - row_ptr.size() - 1, + row_ptr.size() - 1, false, [=] __device__(size_t idx, size_t ridx) { d_output_row[idx] = ridx; }); diff --git a/plugin/updater_gpu/test/cpp/test_gpu_predictor.cu b/plugin/updater_gpu/test/cpp/test_gpu_predictor.cu new file mode 100644 index 000000000..8e0063823 --- /dev/null +++ b/plugin/updater_gpu/test/cpp/test_gpu_predictor.cu @@ -0,0 +1,73 @@ + +/*! + * Copyright 2017 XGBoost contributors + */ +#include +#include +#include "gtest/gtest.h" +#include "../../../../tests/cpp/helpers.h" + +namespace xgboost { +namespace predictor { +TEST(gpu_predictor, Test) { + std::unique_ptr gpu_predictor = + std::unique_ptr(Predictor::Create("gpu_predictor")); + std::unique_ptr cpu_predictor = + std::unique_ptr(Predictor::Create("cpu_predictor")); + + std::vector> trees; + trees.push_back(std::make_unique()); + trees.back()->InitModel(); + (*trees.back())[0].set_leaf(1.5f); + gbm::GBTreeModel model(0.5); + model.CommitModel(std::move(trees), 0); + model.param.num_output_group = 1; + + int n_row = 5; + int n_col = 5; + + auto dmat = CreateDMatrix(n_row, n_col, 0); + + // Test predict batch + std::vector gpu_out_predictions; + std::vector cpu_out_predictions; + gpu_predictor->PredictBatch(dmat.get(), &gpu_out_predictions, model, 0); + cpu_predictor->PredictBatch(dmat.get(), &cpu_out_predictions, model, 0); + float abs_tolerance = 0.001; + for (int i = 0; i < gpu_out_predictions.size(); i++) { + ASSERT_LT(std::abs(gpu_out_predictions[i] - cpu_out_predictions[i]), + abs_tolerance); + } + + // Test predict instance + auto batch = dmat->RowIterator()->Value(); + for (int i = 0; i < batch.size; i++) { + std::vector gpu_instance_out_predictions; + std::vector cpu_instance_out_predictions; + cpu_predictor->PredictInstance(batch[i], &cpu_instance_out_predictions, + model); + gpu_predictor->PredictInstance(batch[i], &gpu_instance_out_predictions, + model); + ASSERT_EQ(gpu_instance_out_predictions[0], cpu_instance_out_predictions[0]); + } + + // Test predict leaf + std::vector gpu_leaf_out_predictions; + std::vector cpu_leaf_out_predictions; + cpu_predictor->PredictLeaf(dmat.get(), &cpu_leaf_out_predictions, model); + gpu_predictor->PredictLeaf(dmat.get(), &gpu_leaf_out_predictions, model); + for (int i = 0; i < gpu_leaf_out_predictions.size(); i++) { + ASSERT_EQ(gpu_leaf_out_predictions[i], cpu_leaf_out_predictions[i]); + } + + // Test predict contribution + std::vector gpu_out_contribution; + std::vector cpu_out_contribution; + cpu_predictor->PredictContribution(dmat.get(), &cpu_out_contribution, model); + gpu_predictor->PredictContribution(dmat.get(), &gpu_out_contribution, model); + for (int i = 0; i < gpu_out_contribution.size(); i++) { + ASSERT_EQ(gpu_out_contribution[i], cpu_out_contribution[i]); + } +} +} // namespace predictor +} // namespace xgboost diff --git a/plugin/updater_gpu/test/python/test_large.py b/plugin/updater_gpu/test/python/test_large.py index 13aa930fc..66d6bbdd8 100644 --- a/plugin/updater_gpu/test/python/test_large.py +++ b/plugin/updater_gpu/test/python/test_large.py @@ -109,6 +109,4 @@ class TestGPU(unittest.TestCase): evals_result=ag_res3) print("Time to Train: %s seconds" % (str(time.time() - tmp))) - - diff --git a/plugin/updater_gpu/test/python/test_prediction.py b/plugin/updater_gpu/test/python/test_prediction.py new file mode 100644 index 000000000..d707c5239 --- /dev/null +++ b/plugin/updater_gpu/test/python/test_prediction.py @@ -0,0 +1,37 @@ +from __future__ import print_function +#pylint: skip-file +import sys +sys.path.append("../../tests/python") +import xgboost as xgb +import testing as tm +import numpy as np +import unittest + +rng = np.random.RandomState(1994) + + +class TestGPUPredict (unittest.TestCase): + def test_predict(self): + iterations = 1 + np.random.seed(1) + test_num_rows = [10,1000,5000] + test_num_cols = [10,50,500] + for num_rows in test_num_rows: + for num_cols in test_num_cols: + dm = xgb.DMatrix(np.random.randn(num_rows, num_cols), label=[0, 1] * int(num_rows/2)) + watchlist = [(dm, 'train')] + res = {} + param = { + "objective":"binary:logistic", + "predictor":"gpu_predictor", + 'eval_metric': 'auc', + } + bst = xgb.train(param, dm,iterations,evals=watchlist, evals_result=res) + assert self.non_decreasing(res["train"]["auc"]) + gpu_pred = bst.predict(dm, output_margin=True) + bst.set_param({"predictor":"cpu_predictor"}) + cpu_pred = bst.predict(dm, output_margin=True) + np.testing.assert_allclose(cpu_pred, gpu_pred, rtol=1e-5) + + def non_decreasing(self, L): + return all((x - y) < 0.001 for x, y in zip(L, L[1:])) diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index c5ca3aeb4..2235b79e6 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -45,6 +45,7 @@ struct GBTreeTrainParam : public dmlc::Parameter { int process_type; // flag to print out detailed breakdown of runtime int debug_verbose; + std::string predictor; // declare parameters DMLC_DECLARE_PARAMETER(GBTreeTrainParam) { DMLC_DECLARE_FIELD(num_parallel_tree) @@ -67,6 +68,9 @@ struct GBTreeTrainParam : public dmlc::Parameter { .describe("flag to print out detailed breakdown of runtime"); // add alias DMLC_DECLARE_ALIAS(updater_seq, updater); + DMLC_DECLARE_FIELD(predictor) + .set_default("cpu_predictor") + .describe("Predictor algorithm type"); } }; @@ -130,13 +134,10 @@ struct CacheEntry { // gradient boosted trees class GBTree : public GradientBooster { public: - explicit GBTree(bst_float base_margin) - : model_(base_margin), - predictor( - std::unique_ptr(Predictor::Create("cpu_predictor"))) {} + explicit GBTree(bst_float base_margin) : model_(base_margin) {} void InitCache(const std::vector > &cache) { - predictor->InitCache(cache); + cache_ = cache; } void Configure(const std::vector >& cfg) override { @@ -153,6 +154,10 @@ class GBTree : public GradientBooster { if (tparam.process_type == kUpdate) { model_.InitTreesToUpdate(); } + + // configure predictor + predictor = std::unique_ptr(Predictor::Create(tparam.predictor)); + predictor->Init(cfg, cache_); } void Load(dmlc::Stream* fi) override { @@ -300,7 +305,8 @@ class GBTree : public GradientBooster { std::vector > cfg; // the updaters that can be applied to each of tree std::vector> updaters; - + // Cached matrices + std::vector> cache_; std::unique_ptr predictor; }; diff --git a/src/learner.cc b/src/learner.cc index 9e225b031..689e1f977 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -165,9 +165,19 @@ class LearnerImpl : public Learner { << "grow_fast_histmaker."; cfg_["updater"] = "grow_fast_histmaker"; } else if (tparam.tree_method == 4) { - cfg_["updater"] = "grow_gpu,prune"; + if (cfg_.count("updater") == 0) { + cfg_["updater"] = "grow_gpu,prune"; + } + if (cfg_.count("predictor") == 0) { + cfg_["predictor"] = "gpu_predictor"; + } } else if (tparam.tree_method == 5) { - cfg_["updater"] = "grow_gpu_hist"; + if (cfg_.count("updater") == 0) { + cfg_["updater"] = "grow_gpu_hist"; + } + if (cfg_.count("predictor") == 0) { + cfg_["predictor"] = "gpu_predictor"; + } } } diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index 01ffe7bf5..0b5190bd6 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -9,6 +9,8 @@ namespace xgboost { namespace predictor { +DMLC_REGISTRY_FILE_TAG(cpu_predictor); + class CPUPredictor : public Predictor { protected: static bst_float PredValue(const RowBatch::Inst& inst, @@ -28,19 +30,6 @@ class CPUPredictor : public Predictor { return psum; } - void InitOutPredictions(const MetaInfo& info, - std::vector* out_preds, - const gbm::GBTreeModel& model) const { - size_t n = model.param.num_output_group * info.num_row; - const std::vector& base_margin = info.base_margin; - out_preds->resize(n); - if (base_margin.size() != 0) { - CHECK_EQ(out_preds->size(), n); - std::copy(base_margin.begin(), base_margin.end(), out_preds->begin()); - } else { - std::fill(out_preds->begin(), out_preds->end(), model.base_margin); - } - } // init thread buffers inline void InitThreadTemp(int nthread, int num_feature) { int prev_thread_temp_size = thread_temp.size(); @@ -106,33 +95,6 @@ class CPUPredictor : public Predictor { } } - /** - * \fn bool PredictFromCache(DMatrix* dmat, std::vector* - * out_preds, const gbm::GBTreeModel& model, unsigned ntree_limit = 0) - * - * \brief Attempt to predict from cache. - * - * \return True if it succeeds, false if it fails. - */ - bool PredictFromCache(DMatrix* dmat, std::vector* out_preds, - const gbm::GBTreeModel& model, - unsigned ntree_limit = 0) { - if (ntree_limit == 0 || - ntree_limit * model.param.num_output_group >= model.trees.size()) { - auto it = cache_.find(dmat); - if (it != cache_.end()) { - std::vector& y = it->second.predictions; - if (y.size() != 0) { - out_preds->resize(y.size()); - std::copy(y.begin(), y.end(), out_preds->begin()); - return true; - } - } - } - - return false; - } - void PredLoopInternal(DMatrix* dmat, std::vector* out_preds, const gbm::GBTreeModel& model, int tree_begin, unsigned ntree_limit) { diff --git a/src/predictor/predictor.cc b/src/predictor/predictor.cc index 82200771a..7e1ee3312 100644 --- a/src/predictor/predictor.cc +++ b/src/predictor/predictor.cc @@ -8,13 +8,47 @@ namespace dmlc { DMLC_REGISTRY_ENABLE(::xgboost::PredictorReg); } // namespace dmlc namespace xgboost { -void Predictor::InitCache(const std::vector >& cache) { +void Predictor::Init( + const std::vector>& cfg, + const std::vector>& cache) { for (const std::shared_ptr& d : cache) { PredictionCacheEntry e; e.data = d; cache_[d.get()] = std::move(e); } } +bool Predictor::PredictFromCache(DMatrix* dmat, + std::vector* out_preds, + const gbm::GBTreeModel& model, + unsigned ntree_limit) { + if (ntree_limit == 0 || + ntree_limit * model.param.num_output_group >= model.trees.size()) { + auto it = cache_.find(dmat); + if (it != cache_.end()) { + std::vector& y = it->second.predictions; + if (y.size() != 0) { + out_preds->resize(y.size()); + std::copy(y.begin(), y.end(), out_preds->begin()); + return true; + } + } + } + + return false; +} +void Predictor::InitOutPredictions(const MetaInfo& info, + std::vector* out_preds, + const gbm::GBTreeModel& model) const { + size_t n = model.param.num_output_group * info.num_row; + const std::vector& base_margin = info.base_margin; + out_preds->resize(n); + if (base_margin.size() != 0) { + CHECK_EQ(out_preds->size(), n); + std::copy(base_margin.begin(), base_margin.end(), out_preds->begin()); + } else { + std::fill(out_preds->begin(), out_preds->end(), model.base_margin); + } +} Predictor* Predictor::Create(std::string name) { auto* e = ::dmlc::Registry::Get()->Find(name); if (e == nullptr) { @@ -23,3 +57,13 @@ Predictor* Predictor::Create(std::string name) { return (e->body)(); } } // namespace xgboost + +namespace xgboost { +namespace predictor { +// List of files that will be force linked in static links. +#ifdef XGBOOST_USE_CUDA +DMLC_REGISTRY_LINK_TAG(gpu_predictor); +#endif +DMLC_REGISTRY_LINK_TAG(cpu_predictor); +} // namespace predictor +} // namespace xgboost diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc index f07c980ef..425fb91a3 100644 --- a/tests/cpp/helpers.cc +++ b/tests/cpp/helpers.cc @@ -1,4 +1,6 @@ #include "./helpers.h" +#include "xgboost/c_api.h" +#include std::string TempFileName() { return std::tmpnam(nullptr); @@ -60,3 +62,23 @@ xgboost::bst_float GetMetricEval(xgboost::Metric * metric, info.weights = weights; return metric->Eval(preds, info, false); } + +std::shared_ptr CreateDMatrix(int rows, int columns, + float sparsity, int seed) { + const float missing_value = -1; + std::vector test_data(rows * columns); + std::mt19937 gen(seed); + std::uniform_real_distribution dis(0.0f, 1.0f); + for (auto &e : test_data) { + if (dis(gen) < sparsity) { + e = missing_value; + } else { + e = dis(gen); + } + } + + DMatrixHandle handle; + XGDMatrixCreateFromMat(test_data.data(), rows, columns, missing_value, + &handle); + return *static_cast *>(handle); +} diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h index 94bef0771..6846075c4 100644 --- a/tests/cpp/helpers.h +++ b/tests/cpp/helpers.h @@ -36,4 +36,19 @@ xgboost::bst_float GetMetricEval( std::vector labels, std::vector weights = std::vector ()); +/** + * \fn std::shared_ptr CreateDMatrix(int rows, int columns, float sparsity, int seed); + * + * \brief Creates dmatrix with uniform random data between 0-1. + * + * \param rows The rows. + * \param columns The columns. + * \param sparsity The sparsity. + * \param seed The seed. + * + * \return The new d matrix. + */ + +std::shared_ptr CreateDMatrix(int rows, int columns, + float sparsity, int seed = 0); #endif diff --git a/tests/cpp/predictor/test_cpu_predictor.cc b/tests/cpp/predictor/test_cpu_predictor.cc new file mode 100644 index 000000000..011e2c392 --- /dev/null +++ b/tests/cpp/predictor/test_cpu_predictor.cc @@ -0,0 +1,54 @@ +// Copyright by Contributors +#include +#include +#include "../helpers.h" + +namespace xgboost { +TEST(cpu_predictor, Test) { + std::unique_ptr cpu_predictor = + std::unique_ptr(Predictor::Create("cpu_predictor")); + + std::vector> trees; + trees.push_back(std::unique_ptr(new RegTree)); + trees.back()->InitModel(); + (*trees.back())[0].set_leaf(1.5f); + gbm::GBTreeModel model(0.5); + model.CommitModel(std::move(trees), 0); + model.param.num_output_group = 1; + model.base_margin = 0; + + int n_row = 5; + int n_col = 5; + + auto dmat = CreateDMatrix(n_row, n_col, 0); + + // Test predict batch + std::vector out_predictions; + cpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0); + for (int i = 0; i < out_predictions.size(); i++) { + ASSERT_EQ(out_predictions[i], 1.5); + } + + // Test predict instance + auto batch = dmat->RowIterator()->Value(); + for (int i = 0; i < batch.size; i++) { + std::vector instance_out_predictions; + cpu_predictor->PredictInstance(batch[i], &instance_out_predictions, model); + ASSERT_EQ(instance_out_predictions[0], 1.5); + } + + // Test predict leaf + std::vector leaf_out_predictions; + cpu_predictor->PredictLeaf(dmat.get(), &leaf_out_predictions, model); + for (int i = 0; i < leaf_out_predictions.size(); i++) { + ASSERT_EQ(leaf_out_predictions[i], 0); + } + + // Test predict contribution + std::vector out_contribution; + cpu_predictor->PredictContribution(dmat.get(), &out_contribution, model); + for (int i = 0; i < out_contribution.size(); i++) { + ASSERT_EQ(out_contribution[i], 1.5); + } +} +} // namespace xgboost \ No newline at end of file diff --git a/tests/cpp/test_learner.cc b/tests/cpp/test_learner.cc new file mode 100644 index 000000000..25d10432d --- /dev/null +++ b/tests/cpp/test_learner.cc @@ -0,0 +1,14 @@ +// Copyright by Contributors +#include +#include "helpers.h" +#include "xgboost/learner.h" + +namespace xgboost { +TEST(learner, Test) { + typedef std::pair arg; + auto args = {arg("tree_method", "exact")}; + auto mat = {CreateDMatrix(10, 10, 0)}; + auto learner = std::unique_ptr(Learner::Create(mat)); + learner->Configure(args); +} +} // namespace xgboost \ No newline at end of file