[GPU-Plugin] Add GPU accelerated prediction (#2593)
* [GPU-Plugin] Add GPU accelerated prediction * Improve allocation message * Update documentation * Resolve linker error for predictor * Add unit tests
This commit is contained in:
parent
71e5e622b1
commit
ef23e424f1
@ -78,6 +78,8 @@ if(PLUGIN_UPDATER_GPU)
|
|||||||
find_package(CUDA 7.5 REQUIRED)
|
find_package(CUDA 7.5 REQUIRED)
|
||||||
cmake_minimum_required(VERSION 3.5)
|
cmake_minimum_required(VERSION 3.5)
|
||||||
|
|
||||||
|
add_definitions(-DXGBOOST_USE_CUDA)
|
||||||
|
|
||||||
include_directories(
|
include_directories(
|
||||||
nccl/src
|
nccl/src
|
||||||
cub
|
cub
|
||||||
|
|||||||
1
Makefile
1
Makefile
@ -129,6 +129,7 @@ ifeq ($(PLUGIN_UPDATER_GPU),ON)
|
|||||||
CUDA_ROOT = $(shell dirname $(shell dirname $(shell which $(NVCC))))
|
CUDA_ROOT = $(shell dirname $(shell dirname $(shell which $(NVCC))))
|
||||||
INCLUDES += -I$(CUDA_ROOT)/include -Inccl/src/
|
INCLUDES += -I$(CUDA_ROOT)/include -Inccl/src/
|
||||||
LDFLAGS += -L$(CUDA_ROOT)/lib64 -lcudart -lcudadevrt -Lnccl/build/lib/ -lnccl_static -lm -ldl -lrt
|
LDFLAGS += -L$(CUDA_ROOT)/lib64 -lcudart -lcudadevrt -Lnccl/build/lib/ -lnccl_static -lm -ldl -lrt
|
||||||
|
CFLAGS += -DXGBOOST_USE_CUDA
|
||||||
endif
|
endif
|
||||||
|
|
||||||
# specify tensor path
|
# specify tensor path
|
||||||
|
|||||||
@ -56,7 +56,7 @@ Parameters for Tree Booster
|
|||||||
* tree_method, string [default='auto']
|
* tree_method, string [default='auto']
|
||||||
- The tree construction algorithm used in XGBoost(see description in the [reference paper](http://arxiv.org/abs/1603.02754))
|
- The tree construction algorithm used in XGBoost(see description in the [reference paper](http://arxiv.org/abs/1603.02754))
|
||||||
- Distributed and external memory version only support approximate algorithm.
|
- Distributed and external memory version only support approximate algorithm.
|
||||||
- Choices: {'auto', 'exact', 'approx', 'hist'}
|
- Choices: {'auto', 'exact', 'approx', 'hist', 'gpu_exact', 'gpu_hist'}
|
||||||
- 'auto': Use heuristic to choose faster one.
|
- 'auto': Use heuristic to choose faster one.
|
||||||
- For small to medium dataset, exact greedy will be used.
|
- For small to medium dataset, exact greedy will be used.
|
||||||
- For very large-dataset, approximate algorithm will be chosen.
|
- For very large-dataset, approximate algorithm will be chosen.
|
||||||
@ -65,6 +65,8 @@ Parameters for Tree Booster
|
|||||||
- 'exact': Exact greedy algorithm.
|
- 'exact': Exact greedy algorithm.
|
||||||
- 'approx': Approximate greedy algorithm using sketching and histogram.
|
- 'approx': Approximate greedy algorithm using sketching and histogram.
|
||||||
- 'hist': Fast histogram optimized approximate greedy algorithm. It uses some performance improvements such as bins caching.
|
- 'hist': Fast histogram optimized approximate greedy algorithm. It uses some performance improvements such as bins caching.
|
||||||
|
- 'gpu_exact': GPU implementation of exact algorithm.
|
||||||
|
- 'gpu_hist': GPU implementation of hist algorithm.
|
||||||
* sketch_eps, [default=0.03]
|
* sketch_eps, [default=0.03]
|
||||||
- This is only used for approximate greedy algorithm.
|
- This is only used for approximate greedy algorithm.
|
||||||
- This roughly translated into ```O(1 / sketch_eps)``` number of bins.
|
- This roughly translated into ```O(1 / sketch_eps)``` number of bins.
|
||||||
@ -107,7 +109,10 @@ Parameters for Tree Booster
|
|||||||
- This is only used if 'hist' is specified as `tree_method`.
|
- This is only used if 'hist' is specified as `tree_method`.
|
||||||
- Maximum number of discrete bins to bucket continuous features.
|
- Maximum number of discrete bins to bucket continuous features.
|
||||||
- Increasing this number improves the optimality of splits at the cost of higher computation time.
|
- Increasing this number improves the optimality of splits at the cost of higher computation time.
|
||||||
|
* predictor, [default='cpu_predictor']
|
||||||
|
- The type of predictor algorithm to use. Provides the same results but allows the use of GPU or CPU.
|
||||||
|
- 'cpu_predictor': Multicore CPU prediction algorithm.
|
||||||
|
- 'gpu_predictor': Prediction using GPU. Default for 'gpu_exact' and 'gpu_hist' tree method.
|
||||||
Additional parameters for Dart Booster
|
Additional parameters for Dart Booster
|
||||||
--------------------------------------
|
--------------------------------------
|
||||||
* sample_type [default="uniform"]
|
* sample_type [default="uniform"]
|
||||||
|
|||||||
@ -6,33 +6,31 @@
|
|||||||
*/
|
*/
|
||||||
#pragma once
|
#pragma once
|
||||||
#include <xgboost/base.h>
|
#include <xgboost/base.h>
|
||||||
|
#include <xgboost/data.h>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
#include "../../src/gbm/gbtree_model.h"
|
#include "../../src/gbm/gbtree_model.h"
|
||||||
|
|
||||||
// Forward declarations
|
// Forward declarations
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
class DMatrix;
|
|
||||||
class TreeUpdater;
|
class TreeUpdater;
|
||||||
}
|
}
|
||||||
namespace xgboost {
|
|
||||||
namespace gbm {
|
|
||||||
struct GBTreeModel;
|
|
||||||
}
|
|
||||||
} // namespace xgboost
|
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \class Predictor
|
* \class Predictor
|
||||||
*
|
*
|
||||||
* \brief Performs prediction on individual training instances or batches of instances for GBTree.
|
* \brief Performs prediction on individual training instances or batches of
|
||||||
* The predictor also manages a prediction cache associated with input matrices. If possible,
|
* instances for GBTree. The predictor also manages a prediction cache
|
||||||
* it will use previously calculated predictions instead of calculating new predictions.
|
* associated with input matrices. If possible, it will use previously
|
||||||
* Prediction functions all take a GBTreeModel and a DMatrix as input and output a vector of
|
* calculated predictions instead of calculating new predictions.
|
||||||
* predictions. The predictor does not modify any state of the model itself.
|
* Prediction functions all take a GBTreeModel and a DMatrix as input and
|
||||||
|
* output a vector of predictions. The predictor does not modify any state of
|
||||||
|
* the model itself.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
class Predictor {
|
class Predictor {
|
||||||
@ -40,36 +38,47 @@ class Predictor {
|
|||||||
virtual ~Predictor() {}
|
virtual ~Predictor() {}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \fn void Predictor::InitCache(const std::vector<std::shared_ptr<DMatrix> > &cache);
|
* \fn virtual void Predictor::Init(const std::vector<std::pair<std::string,
|
||||||
|
* std::string> >&cfg ,const std::vector<std::shared_ptr<DMatrix> > &cache);
|
||||||
*
|
*
|
||||||
* \brief Register input matrices in prediction cache.
|
* \brief Configure and register input matrices in prediction cache.
|
||||||
*
|
*
|
||||||
|
* \param cfg The configuration.
|
||||||
* \param cache Vector of DMatrix's to be used in prediction.
|
* \param cache Vector of DMatrix's to be used in prediction.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
void InitCache(const std::vector<std::shared_ptr<DMatrix> > &cache);
|
virtual void Init(const std::vector<std::pair<std::string, std::string>>& cfg,
|
||||||
|
const std::vector<std::shared_ptr<DMatrix>>& cache);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \fn virtual void Predictor::PredictBatch( DMatrix* dmat, std::vector<bst_float>* out_preds, const gbm::GBTreeModel &model, int tree_begin, unsigned ntree_limit = 0) = 0;
|
* \fn virtual void Predictor::PredictBatch( DMatrix* dmat,
|
||||||
|
* std::vector<bst_float>* out_preds, const gbm::GBTreeModel &model, int
|
||||||
|
* tree_begin, unsigned ntree_limit = 0) = 0;
|
||||||
*
|
*
|
||||||
* \brief Generate batch predictions for a given feature matrix. May use cached predictions if available instead of calculating from scratch.
|
* \brief Generate batch predictions for a given feature matrix. May use
|
||||||
|
* cached predictions if available instead of calculating from scratch.
|
||||||
*
|
*
|
||||||
* \param [in,out] dmat Feature matrix.
|
* \param [in,out] dmat Feature matrix.
|
||||||
* \param [in,out] out_preds The output preds.
|
* \param [in,out] out_preds The output preds.
|
||||||
* \param model The model to predict from.
|
* \param model The model to predict from.
|
||||||
* \param tree_begin The tree begin index.
|
* \param tree_begin The tree begin index.
|
||||||
* \param ntree_limit (Optional) The ntree limit. 0 means do not limit trees.
|
* \param ntree_limit (Optional) The ntree limit. 0 means do not
|
||||||
|
* limit trees.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
virtual void PredictBatch(
|
virtual void PredictBatch(DMatrix* dmat, std::vector<bst_float>* out_preds,
|
||||||
DMatrix* dmat, std::vector<bst_float>* out_preds, const gbm::GBTreeModel &model,
|
const gbm::GBTreeModel& model, int tree_begin,
|
||||||
int tree_begin, unsigned ntree_limit = 0) = 0;
|
unsigned ntree_limit = 0) = 0;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \fn virtual void Predictor::UpdatePredictionCache( const gbm::GBTreeModel &model, std::vector<std::unique_ptr<TreeUpdater> >* updaters, int num_new_trees) = 0;
|
* \fn virtual void Predictor::UpdatePredictionCache( const gbm::GBTreeModel
|
||||||
|
* &model, std::vector<std::unique_ptr<TreeUpdater> >* updaters, int
|
||||||
|
* num_new_trees) = 0;
|
||||||
*
|
*
|
||||||
* \brief Update the internal prediction cache using newly added trees. Will use the tree updater
|
* \brief Update the internal prediction cache using newly added trees. Will
|
||||||
* to do this if possible. Should be called as a part of the tree boosting process to facilitate the look up of predictions at a later time.
|
* use the tree updater to do this if possible. Should be called as a part of
|
||||||
|
* the tree boosting process to facilitate the look up of predictions
|
||||||
|
* at a later time.
|
||||||
*
|
*
|
||||||
* \param model The model.
|
* \param model The model.
|
||||||
* \param [in,out] updaters The updater sequence for gradient boosting.
|
* \param [in,out] updaters The updater sequence for gradient boosting.
|
||||||
@ -77,15 +86,19 @@ class Predictor {
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
virtual void UpdatePredictionCache(
|
virtual void UpdatePredictionCache(
|
||||||
const gbm::GBTreeModel &model, std::vector<std::unique_ptr<TreeUpdater> >* updaters,
|
const gbm::GBTreeModel& model,
|
||||||
|
std::vector<std::unique_ptr<TreeUpdater>>* updaters,
|
||||||
int num_new_trees) = 0;
|
int num_new_trees) = 0;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \fn virtual void Predictor::PredictInstance( const SparseBatch::Inst& inst, std::vector<bst_float>* out_preds, const gbm::GBTreeModel& model, unsigned ntree_limit = 0, unsigned root_index = 0) = 0;
|
* \fn virtual void Predictor::PredictInstance( const SparseBatch::Inst&
|
||||||
|
* inst, std::vector<bst_float>* out_preds, const gbm::GBTreeModel& model,
|
||||||
|
* unsigned ntree_limit = 0, unsigned root_index = 0) = 0;
|
||||||
*
|
*
|
||||||
* \brief online prediction function, predict score for one instance at a time NOTE: use the batch
|
* \brief online prediction function, predict score for one instance at a time
|
||||||
* prediction interface if possible, batch prediction is usually more efficient than online
|
* NOTE: use the batch prediction interface if possible, batch prediction is
|
||||||
* prediction This function is NOT threadsafe, make sure you only call from one thread.
|
* usually more efficient than online prediction This function is NOT
|
||||||
|
* threadsafe, make sure you only call from one thread.
|
||||||
*
|
*
|
||||||
* \param inst The instance to predict.
|
* \param inst The instance to predict.
|
||||||
* \param [in,out] out_preds The output preds.
|
* \param [in,out] out_preds The output preds.
|
||||||
@ -94,15 +107,19 @@ class Predictor {
|
|||||||
* \param root_index (Optional) Zero-based index of the root.
|
* \param root_index (Optional) Zero-based index of the root.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
virtual void PredictInstance(
|
virtual void PredictInstance(const SparseBatch::Inst& inst,
|
||||||
const SparseBatch::Inst& inst, std::vector<bst_float>* out_preds,
|
std::vector<bst_float>* out_preds,
|
||||||
const gbm::GBTreeModel& model, unsigned ntree_limit = 0, unsigned root_index = 0) = 0;
|
const gbm::GBTreeModel& model,
|
||||||
|
unsigned ntree_limit = 0,
|
||||||
|
unsigned root_index = 0) = 0;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \fn virtual void Predictor::PredictLeaf(DMatrix* dmat, std::vector<bst_float>* out_preds, const gbm::GBTreeModel& model, unsigned ntree_limit = 0) = 0;
|
* \fn virtual void Predictor::PredictLeaf(DMatrix* dmat,
|
||||||
|
* std::vector<bst_float>* out_preds, const gbm::GBTreeModel& model, unsigned
|
||||||
|
* ntree_limit = 0) = 0;
|
||||||
*
|
*
|
||||||
* \brief predict the leaf index of each tree, the output will be nsample * ntree vector this is
|
* \brief predict the leaf index of each tree, the output will be nsample *
|
||||||
* only valid in gbtree predictor.
|
* ntree vector this is only valid in gbtree predictor.
|
||||||
*
|
*
|
||||||
* \param [in,out] dmat The input feature matrix.
|
* \param [in,out] dmat The input feature matrix.
|
||||||
* \param [in,out] out_preds The output preds.
|
* \param [in,out] out_preds The output preds.
|
||||||
@ -111,13 +128,17 @@ class Predictor {
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
virtual void PredictLeaf(DMatrix* dmat, std::vector<bst_float>* out_preds,
|
virtual void PredictLeaf(DMatrix* dmat, std::vector<bst_float>* out_preds,
|
||||||
const gbm::GBTreeModel& model, unsigned ntree_limit = 0) = 0;
|
const gbm::GBTreeModel& model,
|
||||||
|
unsigned ntree_limit = 0) = 0;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \fn virtual void Predictor::PredictContribution( DMatrix* dmat, std::vector<bst_float>* out_contribs, const gbm::GBTreeModel& model, unsigned ntree_limit = 0) = 0;
|
* \fn virtual void Predictor::PredictContribution( DMatrix* dmat,
|
||||||
|
* std::vector<bst_float>* out_contribs, const gbm::GBTreeModel& model,
|
||||||
|
* unsigned ntree_limit = 0) = 0;
|
||||||
*
|
*
|
||||||
* \brief feature contributions to individual predictions; the output will be a vector of length
|
* \brief feature contributions to individual predictions; the output will be
|
||||||
* (nfeats + 1) * num_output_group * nsample, arranged in that order.
|
* a vector of length (nfeats + 1) * num_output_group * nsample, arranged in
|
||||||
|
* that order.
|
||||||
*
|
*
|
||||||
* \param [in,out] dmat The input feature matrix.
|
* \param [in,out] dmat The input feature matrix.
|
||||||
* \param [in,out] out_contribs The output feature contribs.
|
* \param [in,out] out_contribs The output feature contribs.
|
||||||
@ -125,9 +146,10 @@ class Predictor {
|
|||||||
* \param ntree_limit (Optional) The ntree limit.
|
* \param ntree_limit (Optional) The ntree limit.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
virtual void PredictContribution(
|
virtual void PredictContribution(DMatrix* dmat,
|
||||||
DMatrix* dmat, std::vector<bst_float>* out_contribs,
|
std::vector<bst_float>* out_contribs,
|
||||||
const gbm::GBTreeModel& model, unsigned ntree_limit = 0) = 0;
|
const gbm::GBTreeModel& model,
|
||||||
|
unsigned ntree_limit = 0) = 0;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \fn static Predictor* Predictor::Create(std::string name);
|
* \fn static Predictor* Predictor::Create(std::string name);
|
||||||
@ -139,6 +161,32 @@ class Predictor {
|
|||||||
static Predictor* Create(std::string name);
|
static Predictor* Create(std::string name);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
/**
|
||||||
|
* \fn bool PredictFromCache(DMatrix* dmat, std::vector<bst_float>*
|
||||||
|
* out_preds, const gbm::GBTreeModel& model, unsigned ntree_limit = 0)
|
||||||
|
*
|
||||||
|
* \brief Attempt to predict from cache.
|
||||||
|
*
|
||||||
|
* \return True if it succeeds, false if it fails.
|
||||||
|
*/
|
||||||
|
bool PredictFromCache(DMatrix* dmat, std::vector<bst_float>* out_preds,
|
||||||
|
const gbm::GBTreeModel& model,
|
||||||
|
unsigned ntree_limit = 0);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \fn void Predictor::InitOutPredictions(const MetaInfo& info,
|
||||||
|
* std::vector<bst_float>* out_preds, const gbm::GBTreeModel& model) const;
|
||||||
|
*
|
||||||
|
* \brief Init out predictions according to base margin.
|
||||||
|
*
|
||||||
|
* \param info Dmatrix info possibly containing base margin.
|
||||||
|
* \param [in,out] out_preds The out preds.
|
||||||
|
* \param model The model.
|
||||||
|
*/
|
||||||
|
void InitOutPredictions(const MetaInfo& info,
|
||||||
|
std::vector<bst_float>* out_preds,
|
||||||
|
const gbm::GBTreeModel& model) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \struct PredictionCacheEntry
|
* \struct PredictionCacheEntry
|
||||||
*
|
*
|
||||||
@ -151,8 +199,8 @@ class Predictor {
|
|||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \brief Map of matrices and associated cached predictions to facilitate storing and looking up
|
* \brief Map of matrices and associated cached predictions to facilitate
|
||||||
* predictions.
|
* storing and looking up predictions.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
std::unordered_map<DMatrix*, PredictionCacheEntry> cache_;
|
std::unordered_map<DMatrix*, PredictionCacheEntry> cache_;
|
||||||
|
|||||||
@ -271,6 +271,10 @@ class TreeModel {
|
|||||||
inline const Node& operator[](int nid) const {
|
inline const Node& operator[](int nid) const {
|
||||||
return nodes[nid];
|
return nodes[nid];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*! \brief get const reference to nodes */
|
||||||
|
inline const std::vector<Node>& GetNodes() const { return nodes; }
|
||||||
|
|
||||||
/*! \brief get node statistics given nid */
|
/*! \brief get node statistics given nid */
|
||||||
inline NodeStat& stat(int nid) {
|
inline NodeStat& stat(int nid) {
|
||||||
return stats[nid];
|
return stats[nid];
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
# CUDA Accelerated Tree Construction Algorithms
|
# CUDA Accelerated Tree Construction Algorithms
|
||||||
This plugin adds GPU accelerated tree construction algorithms to XGBoost.
|
This plugin adds GPU accelerated tree construction and prediction algorithms to XGBoost.
|
||||||
## Usage
|
## Usage
|
||||||
Specify the 'tree_method' parameter as one of the following algorithms.
|
Specify the 'tree_method' parameter as one of the following algorithms.
|
||||||
|
|
||||||
@ -18,6 +18,9 @@ colsample_bylevel | ✔ | ✔ |
|
|||||||
max_bin | ✖ | ✔ |
|
max_bin | ✖ | ✔ |
|
||||||
gpu_id | ✔ | ✔ |
|
gpu_id | ✔ | ✔ |
|
||||||
n_gpus | ✖ | ✔ |
|
n_gpus | ✖ | ✔ |
|
||||||
|
predictor | ✔ | ✔ |
|
||||||
|
|
||||||
|
GPU accelerated prediction is enabled by default for the above mentioned 'tree_method' parameters but can be switched to CPU prediction by setting 'predictor':'cpu_predictor'. This could be useful if you want to conserve GPU memory. Likewise when using CPU algorithms, GPU accelerated prediction can be enabled by setting 'predictor':'gpu_predictor'.
|
||||||
|
|
||||||
The device ordinal can be selected using the 'gpu_id' parameter, which defaults to 0.
|
The device ordinal can be selected using the 'gpu_id' parameter, which defaults to 0.
|
||||||
|
|
||||||
@ -37,48 +40,31 @@ To run benchmarks on synthetic data for binary classification:
|
|||||||
$ python benchmark/benchmark.py
|
$ python benchmark/benchmark.py
|
||||||
```
|
```
|
||||||
|
|
||||||
Training time time on 1000000 rows x 50 columns with 500 boosting iterations on i7-6700K CPU @ 4.00GHz and Pascal Titan X.
|
Training time time on 1,000,000 rows x 50 columns with 500 boosting iterations and 0.25/0.75 test/train split on i7-6700K CPU @ 4.00GHz and Pascal Titan X.
|
||||||
|
|
||||||
| tree_method | Time (s) |
|
| tree_method | Time (s) |
|
||||||
| --- | --- |
|
| --- | --- |
|
||||||
| gpu_hist | 11.09 |
|
| gpu_hist | 13.87 |
|
||||||
| hist (histogram XGBoost - CPU) | 41.75 |
|
| hist | 63.55 |
|
||||||
| gpu_exact | 193.90 |
|
| gpu_exact | 161.08 |
|
||||||
| exact (standard XGBoost - CPU) | 720.12 |
|
| exact | 1082.20 |
|
||||||
|
|
||||||
|
|
||||||
[See here](http://dmlc.ml/2016/12/14/GPU-accelerated-xgboost.html) for additional performance benchmarks of the 'gpu_exact' tree_method.
|
[See here](http://dmlc.ml/2016/12/14/GPU-accelerated-xgboost.html) for additional performance benchmarks of the 'gpu_exact' tree_method.
|
||||||
|
|
||||||
## Test
|
## Test
|
||||||
To run tests:Will
|
To run python tests:
|
||||||
```bash
|
```bash
|
||||||
$ python -m nose test/python/
|
$ python -m nose test/python/
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Google tests can be enabled by specifying -DGOOGLE_TEST=ON when building with cmake.
|
||||||
|
|
||||||
## Dependencies
|
## Dependencies
|
||||||
A CUDA capable GPU with at least compute capability >= 3.5 (the algorithm depends on shuffle and vote instructions introduced in Kepler).
|
A CUDA capable GPU with at least compute capability >= 3.5
|
||||||
|
|
||||||
Building the plug-in requires CUDA Toolkit 7.5 or later (https://developer.nvidia.com/cuda-downloads)
|
Building the plug-in requires CUDA Toolkit 7.5 or later (https://developer.nvidia.com/cuda-downloads)
|
||||||
|
|
||||||
submodule: The plugin also depends on CUB 1.6.4 - https://nvlabs.github.io/cub/ . CUB is a header only cuda library which provides sort/reduce/scan primitives.
|
|
||||||
|
|
||||||
submodule: NVIDIA NCCL from https://github.com/NVIDIA/nccl with windows port allowed by git@github.com:h2oai/nccl.git
|
|
||||||
|
|
||||||
## Download full repo + full submodules for your choice (or empty) path <mypath>
|
|
||||||
|
|
||||||
git clone --recursive https://github.com/dmlc/xgboost.git <mypath>
|
|
||||||
|
|
||||||
## Download with shallow submodules for much quicker download:
|
|
||||||
|
|
||||||
git 2.9.0+ (assumes only HEAD used for all submodules, but not true currently for dmlc-core and rabbit)
|
|
||||||
|
|
||||||
git clone --recursive --shallow-submodules https://github.com/dmlc/xgboost.git <mypath>
|
|
||||||
|
|
||||||
git 2.9.0-: (only cub is shallow, as largest repo)
|
|
||||||
|
|
||||||
git clone https://github.com/dmlc/xgboost.git <mypath>
|
|
||||||
cd <mypath>
|
|
||||||
bash plugin/updater/gpu/gitshallow_submodules.sh
|
|
||||||
|
|
||||||
## Build
|
## Build
|
||||||
|
|
||||||
From the command line on Linux starting from the xgboost directory:
|
From the command line on Linux starting from the xgboost directory:
|
||||||
@ -110,14 +96,11 @@ On some systems, nccl libraries are specific to a particular system (IBM Power o
|
|||||||
|
|
||||||
### For Developers!
|
### For Developers!
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
In case you want to build only for a specific GPU(s), for eg. GP100 and GP102,
|
In case you want to build only for a specific GPU(s), for eg. GP100 and GP102,
|
||||||
whose compute capability are 60 and 61 respectively:
|
whose compute capability are 60 and 61 respectively:
|
||||||
```bash
|
```bash
|
||||||
$ cmake .. -DPLUGIN_UPDATER_GPU=ON -DGPU_COMPUTE_VER="60;61"
|
$ cmake .. -DPLUGIN_UPDATER_GPU=ON -DGPU_COMPUTE_VER="60;61"
|
||||||
```
|
```
|
||||||
By default, the versions will include support for all GPUs in Maxwell and Pascal architectures.
|
|
||||||
|
|
||||||
### Using make
|
### Using make
|
||||||
Now, it also supports the usual 'make' flow to build gpu-enabled tree construction plugins. It's currently only tested on Linux. From the xgboost directory
|
Now, it also supports the usual 'make' flow to build gpu-enabled tree construction plugins. It's currently only tested on Linux. From the xgboost directory
|
||||||
@ -131,19 +114,10 @@ Similar to cmake, if you want to build only for a specific GPU(s):
|
|||||||
$ make -j PLUGIN_UPDATER_GPU=ON GPU_COMPUTE_VER="60 61"
|
$ make -j PLUGIN_UPDATER_GPU=ON GPU_COMPUTE_VER="60 61"
|
||||||
```
|
```
|
||||||
|
|
||||||
### For Developers!
|
|
||||||
|
|
||||||
Now, some of the code-base inside gpu plugins have googletest unit-tests inside 'tests/'.
|
|
||||||
They can be enabled run along with other unit-tests inside '<xgboostRoot>/tests/cpp' using:
|
|
||||||
```bash
|
|
||||||
# make sure CUDA SDK bin directory is in the 'PATH' env variable
|
|
||||||
# below 2 commands need only be executed once
|
|
||||||
$ source ./dmlc-core/scripts/travis/travis_setup_env.sh
|
|
||||||
$ make -f dmlc-core/scripts/packages.mk gtest
|
|
||||||
$ make PLUGIN_UPDATER_GPU=ON GTEST_PATH=${CACHE_PREFIX} test
|
|
||||||
```
|
|
||||||
|
|
||||||
## Changelog
|
## Changelog
|
||||||
|
##### 2017/8/14
|
||||||
|
* Added GPU accelerated prediction. Considerably improved performance when using test/eval sets.
|
||||||
|
|
||||||
##### 2017/7/10
|
##### 2017/7/10
|
||||||
* Memory performance improved 4x for gpu_hist
|
* Memory performance improved 4x for gpu_hist
|
||||||
|
|
||||||
|
|||||||
@ -3,19 +3,22 @@ import sys, argparse
|
|||||||
import xgboost as xgb
|
import xgboost as xgb
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sklearn.datasets import make_classification
|
from sklearn.datasets import make_classification
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
|
||||||
def run_benchmark(args, gpu_algorithm, cpu_algorithm):
|
def run_benchmark(args, gpu_algorithm, cpu_algorithm):
|
||||||
print("Generating dataset: {} rows * {} columns".format(args.rows,args.columns))
|
print("Generating dataset: {} rows * {} columns".format(args.rows, args.columns))
|
||||||
|
print("{}/{} test/train split".format(args.test_size, 1.0 - args.test_size))
|
||||||
tmp = time.time()
|
tmp = time.time()
|
||||||
X, y = make_classification(args.rows, n_features=args.columns, random_state=7)
|
X, y = make_classification(args.rows, n_features=args.columns, random_state=7)
|
||||||
|
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=args.test_size, random_state=7)
|
||||||
print ("Generate Time: %s seconds" % (str(time.time() - tmp)))
|
print ("Generate Time: %s seconds" % (str(time.time() - tmp)))
|
||||||
tmp = time.time()
|
tmp = time.time()
|
||||||
print ("DMatrix Start")
|
print ("DMatrix Start")
|
||||||
# omp way
|
# omp way
|
||||||
dtrain = xgb.DMatrix(X, y, nthread=-1)
|
dtrain = xgb.DMatrix(X_train, y_train, nthread=-1)
|
||||||
# non-omp way
|
dtest = xgb.DMatrix(X_test, y_test, nthread=-1)
|
||||||
#dtrain = xgb.DMatrix(X, y)
|
|
||||||
print ("DMatrix Time: %s seconds" % (str(time.time() - tmp)))
|
print ("DMatrix Time: %s seconds" % (str(time.time() - tmp)))
|
||||||
|
|
||||||
param = {'objective': 'binary:logistic',
|
param = {'objective': 'binary:logistic',
|
||||||
@ -23,28 +26,30 @@ def run_benchmark(args, gpu_algorithm, cpu_algorithm):
|
|||||||
'silent': 0,
|
'silent': 0,
|
||||||
'n_gpus': 1,
|
'n_gpus': 1,
|
||||||
'gpu_id': 0,
|
'gpu_id': 0,
|
||||||
'eval_metric': 'auc'}
|
'eval_metric': 'error',
|
||||||
|
'debug_verbose': 0,
|
||||||
|
}
|
||||||
|
|
||||||
param['tree_method'] = gpu_algorithm
|
param['tree_method'] = gpu_algorithm
|
||||||
print("Training with '%s'" % param['tree_method'])
|
print("Training with '%s'" % param['tree_method'])
|
||||||
tmp = time.time()
|
tmp = time.time()
|
||||||
xgb.train(param, dtrain, args.iterations)
|
xgb.train(param, dtrain, args.iterations, evals=[(dtest, "test")])
|
||||||
print ("Train Time: %s seconds" % (str(time.time() - tmp)))
|
print ("Train Time: %s seconds" % (str(time.time() - tmp)))
|
||||||
|
|
||||||
param['silent'] = 1
|
param['silent'] = 1
|
||||||
param['tree_method'] = cpu_algorithm
|
param['tree_method'] = cpu_algorithm
|
||||||
print("Training with '%s'" % param['tree_method'])
|
print("Training with '%s'" % param['tree_method'])
|
||||||
tmp = time.time()
|
tmp = time.time()
|
||||||
xgb.train(param, dtrain, args.iterations)
|
xgb.train(param, dtrain, args.iterations, evals=[(dtest, "test")])
|
||||||
print ("Time: %s seconds" % (str(time.time() - tmp)))
|
print ("Time: %s seconds" % (str(time.time() - tmp)))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--algorithm', choices=['all', 'gpu_exact', 'gpu_hist'], default='all')
|
parser.add_argument('--algorithm', choices=['all', 'gpu_exact', 'gpu_hist'], default='all')
|
||||||
parser.add_argument('--rows',type=int,default=1000000)
|
parser.add_argument('--rows', type=int, default=1000000)
|
||||||
parser.add_argument('--columns',type=int,default=50)
|
parser.add_argument('--columns', type=int, default=50)
|
||||||
parser.add_argument('--iterations',type=int,default=500)
|
parser.add_argument('--iterations', type=int, default=500)
|
||||||
|
parser.add_argument('--test_size', type=float, default=0.25)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if 'gpu_hist' in args.algorithm:
|
if 'gpu_hist' in args.algorithm:
|
||||||
@ -54,4 +59,3 @@ elif 'gpu_exact' in args.algorithm:
|
|||||||
elif 'all' in args.algorithm:
|
elif 'all' in args.algorithm:
|
||||||
run_benchmark(args, 'gpu_exact', 'exact')
|
run_benchmark(args, 'gpu_exact', 'exact')
|
||||||
run_benchmark(args, 'gpu_hist', 'hist')
|
run_benchmark(args, 'gpu_hist', 'hist')
|
||||||
|
|
||||||
|
|||||||
@ -1,12 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
git submodule init
|
|
||||||
for i in $(git submodule | awk '{print $2}'); do
|
|
||||||
spath=$(git config -f .gitmodules --get submodule.$i.path)
|
|
||||||
surl=$(git config -f .gitmodules --get submodule.$i.url)
|
|
||||||
if [ $spath == "cub" ]
|
|
||||||
then
|
|
||||||
git submodule update --depth 3 $spath
|
|
||||||
else
|
|
||||||
git submodule update $spath
|
|
||||||
fi
|
|
||||||
done
|
|
||||||
@ -1,5 +1,6 @@
|
|||||||
|
|
||||||
PLUGIN_OBJS += build_plugin/updater_gpu/src/register_updater_gpu.o \
|
PLUGIN_OBJS += build_plugin/updater_gpu/src/register_updater_gpu.o \
|
||||||
build_plugin/updater_gpu/src/updater_gpu.o \
|
build_plugin/updater_gpu/src/updater_gpu.o \
|
||||||
build_plugin/updater_gpu/src/gpu_hist_builder.o
|
build_plugin/updater_gpu/src/gpu_hist_builder.o \
|
||||||
|
build_plugin/updater_gpu/src/gpu_predictor.o
|
||||||
PLUGIN_LDFLAGS += -L$(CUDA_ROOT)/lib64 -lcudart
|
PLUGIN_LDFLAGS += -L$(CUDA_ROOT)/lib64 -lcudart
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
* Copyright 2017 XGBoost contributors
|
* Copyright 2017 XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#pragma once
|
#pragma once
|
||||||
#include <dmlc/logging.h>
|
#include <xgboost/logging.h>
|
||||||
#include <thrust/binary_search.h>
|
#include <thrust/binary_search.h>
|
||||||
#include <thrust/device_vector.h>
|
#include <thrust/device_vector.h>
|
||||||
#include <thrust/random.h>
|
#include <thrust/random.h>
|
||||||
@ -121,6 +121,28 @@ inline std::string device_name(int device_idx) {
|
|||||||
return std::string(prop.name);
|
return std::string(prop.name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline size_t available_memory(int device_idx) {
|
||||||
|
size_t device_free = 0;
|
||||||
|
size_t device_total = 0;
|
||||||
|
safe_cuda(cudaSetDevice(device_idx));
|
||||||
|
dh::safe_cuda(cudaMemGetInfo(&device_free, &device_total));
|
||||||
|
return device_free;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \fn inline int max_shared_memory(int device_idx)
|
||||||
|
*
|
||||||
|
* \brief Maximum shared memory per block on this device.
|
||||||
|
*
|
||||||
|
* \param device_idx Zero-based index of the device.
|
||||||
|
*/
|
||||||
|
|
||||||
|
inline int max_shared_memory(int device_idx) {
|
||||||
|
cudaDeviceProp prop;
|
||||||
|
dh::safe_cuda(cudaGetDeviceProperties(&prop, device_idx));
|
||||||
|
return prop.sharedMemPerBlock;
|
||||||
|
}
|
||||||
|
|
||||||
// ensure gpu_id is correct, so not dependent upon user knowing details
|
// ensure gpu_id is correct, so not dependent upon user knowing details
|
||||||
inline int get_device_idx(int gpu_id) {
|
inline int get_device_idx(int gpu_id) {
|
||||||
// protect against overrun for gpu_id
|
// protect against overrun for gpu_id
|
||||||
@ -215,7 +237,7 @@ __device__ range block_stride_range(T begin, T end) {
|
|||||||
return r;
|
return r;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Threadblock iterates over range, filling with value
|
// Threadblock iterates over range, filling with value. Requires all threads in block to be active.
|
||||||
template <typename IterT, typename ValueT>
|
template <typename IterT, typename ValueT>
|
||||||
__device__ void block_fill(IterT begin, size_t n, ValueT value) {
|
__device__ void block_fill(IterT begin, size_t n, ValueT value) {
|
||||||
for (auto i : block_stride_range(static_cast<size_t>(0), n)) {
|
for (auto i : block_stride_range(static_cast<size_t>(0), n)) {
|
||||||
@ -463,7 +485,7 @@ class bulk_allocator {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename... Args>
|
template <typename... Args>
|
||||||
void allocate(int device_idx, Args... args) {
|
void allocate(int device_idx, bool silent ,Args... args) {
|
||||||
size_t size = get_size_bytes(args...);
|
size_t size = get_size_bytes(args...);
|
||||||
|
|
||||||
char *ptr = allocate_device(device_idx, size, MemoryT);
|
char *ptr = allocate_device(device_idx, size, MemoryT);
|
||||||
@ -473,6 +495,14 @@ class bulk_allocator {
|
|||||||
d_ptr.push_back(ptr);
|
d_ptr.push_back(ptr);
|
||||||
_size.push_back(size);
|
_size.push_back(size);
|
||||||
_device_idx.push_back(device_idx);
|
_device_idx.push_back(device_idx);
|
||||||
|
|
||||||
|
if(!silent)
|
||||||
|
{
|
||||||
|
const int mb_size = 1048576;
|
||||||
|
LOG(CONSOLE) << "Allocated " << size / mb_size << "MB on [" << device_idx
|
||||||
|
<< "] " << device_name(device_idx) << ", "
|
||||||
|
<< available_memory(device_idx) / mb_size << "MB remaining.";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -515,13 +545,6 @@ struct CubMemory {
|
|||||||
bool IsAllocated() { return d_temp_storage != NULL; }
|
bool IsAllocated() { return d_temp_storage != NULL; }
|
||||||
};
|
};
|
||||||
|
|
||||||
inline size_t available_memory(int device_idx) {
|
|
||||||
size_t device_free = 0;
|
|
||||||
size_t device_total = 0;
|
|
||||||
safe_cuda(cudaSetDevice(device_idx));
|
|
||||||
dh::safe_cuda(cudaMemGetInfo(&device_free, &device_total));
|
|
||||||
return device_free;
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Utility functions
|
* Utility functions
|
||||||
|
|||||||
@ -232,7 +232,7 @@ class GPUBuilder {
|
|||||||
|
|
||||||
void allocateAllData(int offsetSize) {
|
void allocateAllData(int offsetSize) {
|
||||||
int tmpBuffSize = scanTempBufferSize(nVals);
|
int tmpBuffSize = scanTempBufferSize(nVals);
|
||||||
ba.allocate(dh::get_device_idx(param.gpu_id), &vals, nVals, &vals_cached,
|
ba.allocate(dh::get_device_idx(param.gpu_id), param.silent, &vals, nVals, &vals_cached,
|
||||||
nVals, &instIds, nVals, &instIds_cached, nVals, &colOffsets,
|
nVals, &instIds, nVals, &instIds_cached, nVals, &colOffsets,
|
||||||
offsetSize, &gradsInst, nRows, &nodeAssigns, nVals,
|
offsetSize, &gradsInst, nRows, &nodeAssigns, nVals,
|
||||||
&nodeLocations, nVals, &nodes, maxNodes, &nodeAssignsPerInst,
|
&nodeLocations, nVals, &nodes, maxNodes, &nodeAssignsPerInst,
|
||||||
@ -252,12 +252,6 @@ class GPUBuilder {
|
|||||||
allocateAllData((int)offset.size());
|
allocateAllData((int)offset.size());
|
||||||
transferAndSortData(fval, fId, offset);
|
transferAndSortData(fval, fId, offset);
|
||||||
allocated = true;
|
allocated = true;
|
||||||
if (!param.silent) {
|
|
||||||
const int mb_size = 1048576;
|
|
||||||
LOG(CONSOLE) << "Allocated " << ba.size() / mb_size << "/"
|
|
||||||
<< free_memory / mb_size << " MB on "
|
|
||||||
<< dh::device_name(dh::get_device_idx(param.gpu_id));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void convertToCsc(DMatrix& hMat, std::vector<float>& fval,
|
void convertToCsc(DMatrix& hMat, std::vector<float>& fval,
|
||||||
|
|||||||
@ -127,16 +127,6 @@ void GPUHistBuilder::Init(const TrainParam& param) {
|
|||||||
this->param = param;
|
this->param = param;
|
||||||
|
|
||||||
CHECK(param.n_gpus != 0) << "Must have at least one device";
|
CHECK(param.n_gpus != 0) << "Must have at least one device";
|
||||||
int n_devices_all = dh::n_devices_all(param.n_gpus);
|
|
||||||
for (int device_idx = 0; device_idx < n_devices_all; device_idx++) {
|
|
||||||
if (!param.silent) {
|
|
||||||
size_t free_memory = dh::available_memory(device_idx);
|
|
||||||
const int mb_size = 1048576;
|
|
||||||
LOG(CONSOLE) << "[GPU Plug-in] Device: [" << device_idx << "] "
|
|
||||||
<< dh::device_name(device_idx) << " with "
|
|
||||||
<< free_memory / mb_size << " MB available device memory.";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
|
void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
|
||||||
DMatrix& fmat, // NOLINT
|
DMatrix& fmat, // NOLINT
|
||||||
@ -210,9 +200,6 @@ void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
|
|||||||
// process)
|
// process)
|
||||||
}
|
}
|
||||||
|
|
||||||
CHECK(fmat.SingleColBlock()) << "grow_gpu_hist: must have single column "
|
|
||||||
"block. Try setting 'tree_method' "
|
|
||||||
"parameter to 'exact'";
|
|
||||||
is_dense = info->num_nonzero == info->num_col * info->num_row;
|
is_dense = info->num_nonzero == info->num_col * info->num_row;
|
||||||
dh::Timer time0;
|
dh::Timer time0;
|
||||||
hmat_.Init(&fmat, param.max_bin);
|
hmat_.Init(&fmat, param.max_bin);
|
||||||
@ -326,7 +313,8 @@ void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
|
|||||||
bst_ulong num_elements_segment =
|
bst_ulong num_elements_segment =
|
||||||
device_element_segments[d_idx + 1] - device_element_segments[d_idx];
|
device_element_segments[d_idx + 1] - device_element_segments[d_idx];
|
||||||
ba.allocate(
|
ba.allocate(
|
||||||
device_idx, &(hist_vec[d_idx].data),
|
device_idx, param.silent,
|
||||||
|
&(hist_vec[d_idx].data),
|
||||||
n_nodes(param.max_depth - 1) * n_bins, &nodes[d_idx],
|
n_nodes(param.max_depth - 1) * n_bins, &nodes[d_idx],
|
||||||
n_nodes(param.max_depth), &nodes_temp[d_idx], max_num_nodes_device,
|
n_nodes(param.max_depth), &nodes_temp[d_idx], max_num_nodes_device,
|
||||||
&nodes_child_temp[d_idx], max_num_nodes_device,
|
&nodes_child_temp[d_idx], max_num_nodes_device,
|
||||||
@ -367,11 +355,6 @@ void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
|
|||||||
feature_flags[d_idx].fill(1); // init device object (assumes comes after
|
feature_flags[d_idx].fill(1); // init device object (assumes comes after
|
||||||
// ba.allocate that sets device)
|
// ba.allocate that sets device)
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!param.silent) {
|
|
||||||
const int mb_size = 1048576;
|
|
||||||
LOG(CONSOLE) << "[GPU Plug-in] Allocated " << ba.size() / mb_size << " MB";
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// copy or init to do every iteration
|
// copy or init to do every iteration
|
||||||
|
|||||||
411
plugin/updater_gpu/src/gpu_predictor.cu
Normal file
411
plugin/updater_gpu/src/gpu_predictor.cu
Normal file
@ -0,0 +1,411 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright by Contributors 2017
|
||||||
|
*/
|
||||||
|
#include <dmlc/parameter.h>
|
||||||
|
#include <thrust/copy.h>
|
||||||
|
#include <xgboost/data.h>
|
||||||
|
#include <xgboost/predictor.h>
|
||||||
|
#include <xgboost/tree_model.h>
|
||||||
|
#include <xgboost/tree_updater.h>
|
||||||
|
#include <memory>
|
||||||
|
#include "device_helpers.cuh"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace predictor {
|
||||||
|
|
||||||
|
DMLC_REGISTRY_FILE_TAG(gpu_predictor);
|
||||||
|
|
||||||
|
/*! \brief prediction parameters */
|
||||||
|
struct GPUPredictionParam : public dmlc::Parameter<GPUPredictionParam> {
|
||||||
|
int gpu_id;
|
||||||
|
int n_gpus;
|
||||||
|
bool silent;
|
||||||
|
// declare parameters
|
||||||
|
DMLC_DECLARE_PARAMETER(GPUPredictionParam) {
|
||||||
|
DMLC_DECLARE_FIELD(gpu_id).set_default(0).describe(
|
||||||
|
"Device ordinal for GPU prediction.");
|
||||||
|
DMLC_DECLARE_FIELD(n_gpus).set_default(1).describe(
|
||||||
|
"Number of devices to use for prediction (NOT IMPLEMENTED).");
|
||||||
|
DMLC_DECLARE_FIELD(silent).set_default(false).describe(
|
||||||
|
"Do not print information during trainig.");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
DMLC_REGISTER_PARAMETER(GPUPredictionParam);
|
||||||
|
|
||||||
|
template <typename iter_t>
|
||||||
|
void increment_offset(iter_t begin_itr, iter_t end_itr, size_t amount) {
|
||||||
|
thrust::transform(begin_itr, end_itr, begin_itr,
|
||||||
|
[=] __device__(size_t elem) { return elem + amount; });
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \struct DeviceMatrix
|
||||||
|
*
|
||||||
|
* \brief A csr representation of the input matrix allocated on the device.
|
||||||
|
*/
|
||||||
|
|
||||||
|
struct DeviceMatrix {
|
||||||
|
DMatrix* p_mat; // Pointer to the original matrix on the host
|
||||||
|
dh::bulk_allocator<dh::memory_type::DEVICE> ba;
|
||||||
|
dh::dvec<size_t> row_ptr;
|
||||||
|
dh::dvec<SparseBatch::Entry> data;
|
||||||
|
thrust::device_vector<float> predictions;
|
||||||
|
|
||||||
|
DeviceMatrix(DMatrix* dmat, int device_idx, bool silent) : p_mat(dmat) {
|
||||||
|
dh::safe_cuda(cudaSetDevice(device_idx));
|
||||||
|
auto info = dmat->info();
|
||||||
|
ba.allocate(device_idx, silent, &row_ptr, info.num_row + 1, &data,
|
||||||
|
info.num_nonzero);
|
||||||
|
auto iter = dmat->RowIterator();
|
||||||
|
iter->BeforeFirst();
|
||||||
|
size_t data_offset = 0;
|
||||||
|
while (iter->Next()) {
|
||||||
|
auto batch = iter->Value();
|
||||||
|
// Copy row ptr
|
||||||
|
thrust::copy(batch.ind_ptr, batch.ind_ptr + batch.size + 1,
|
||||||
|
row_ptr.tbegin() + batch.base_rowid);
|
||||||
|
if (batch.base_rowid > 0) {
|
||||||
|
auto begin_itr = row_ptr.tbegin() + batch.base_rowid;
|
||||||
|
auto end_itr = begin_itr + batch.size + 1;
|
||||||
|
increment_offset(begin_itr, end_itr, batch.base_rowid);
|
||||||
|
}
|
||||||
|
// Copy data
|
||||||
|
thrust::copy(batch.data_ptr, batch.data_ptr + batch.ind_ptr[batch.size],
|
||||||
|
data.tbegin() + data_offset);
|
||||||
|
data_offset += batch.ind_ptr[batch.size];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \struct DevicePredictionNode
|
||||||
|
*
|
||||||
|
* \brief Packed 16 byte representation of a tree node for use in device
|
||||||
|
* prediction
|
||||||
|
*/
|
||||||
|
|
||||||
|
struct DevicePredictionNode {
|
||||||
|
XGBOOST_DEVICE DevicePredictionNode()
|
||||||
|
: fidx(-1), left_child_idx(-1), right_child_idx(-1) {}
|
||||||
|
|
||||||
|
union NodeValue {
|
||||||
|
float leaf_weight;
|
||||||
|
float fvalue;
|
||||||
|
};
|
||||||
|
|
||||||
|
int fidx;
|
||||||
|
int left_child_idx;
|
||||||
|
int right_child_idx;
|
||||||
|
NodeValue val;
|
||||||
|
|
||||||
|
DevicePredictionNode(const RegTree::Node& n) { // NOLINT
|
||||||
|
this->left_child_idx = n.cleft();
|
||||||
|
this->right_child_idx = n.cright();
|
||||||
|
this->fidx = n.split_index();
|
||||||
|
if (n.default_left()) {
|
||||||
|
fidx |= (1U << 31);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (n.is_leaf()) {
|
||||||
|
this->val.leaf_weight = n.leaf_value();
|
||||||
|
} else {
|
||||||
|
this->val.fvalue = n.split_cond();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
XGBOOST_DEVICE bool IsLeaf() const { return left_child_idx == -1; }
|
||||||
|
|
||||||
|
XGBOOST_DEVICE int GetFidx() const { return fidx & ((1U << 31) - 1U); }
|
||||||
|
|
||||||
|
XGBOOST_DEVICE bool MissingLeft() const { return (fidx >> 31) != 0; }
|
||||||
|
|
||||||
|
XGBOOST_DEVICE int MissingIdx() const {
|
||||||
|
if (MissingLeft()) {
|
||||||
|
return this->left_child_idx;
|
||||||
|
} else {
|
||||||
|
return this->right_child_idx;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
XGBOOST_DEVICE float GetFvalue() const { return val.fvalue; }
|
||||||
|
|
||||||
|
XGBOOST_DEVICE float GetWeight() const { return val.leaf_weight; }
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ElementLoader {
|
||||||
|
bool use_shared;
|
||||||
|
size_t* d_row_ptr;
|
||||||
|
SparseBatch::Entry* d_data;
|
||||||
|
int num_features;
|
||||||
|
float* smem;
|
||||||
|
|
||||||
|
__device__ ElementLoader(bool use_shared, size_t* row_ptr,
|
||||||
|
SparseBatch::Entry* entry, int num_features,
|
||||||
|
float* smem, int num_rows)
|
||||||
|
: use_shared(use_shared),
|
||||||
|
d_row_ptr(row_ptr),
|
||||||
|
d_data(entry),
|
||||||
|
num_features(num_features),
|
||||||
|
smem(smem) {
|
||||||
|
// Copy instances
|
||||||
|
if (use_shared) {
|
||||||
|
bst_uint global_idx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||||
|
int shared_elements = blockDim.x * num_features;
|
||||||
|
dh::block_fill(smem, shared_elements, nanf(""));
|
||||||
|
__syncthreads();
|
||||||
|
if (global_idx < num_rows) {
|
||||||
|
bst_uint elem_begin = d_row_ptr[global_idx];
|
||||||
|
bst_uint elem_end = d_row_ptr[global_idx + 1];
|
||||||
|
for (bst_uint elem_idx = elem_begin; elem_idx < elem_end; elem_idx++) {
|
||||||
|
SparseBatch::Entry elem = d_data[elem_idx];
|
||||||
|
smem[threadIdx.x * num_features + elem.index] = elem.fvalue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
__device__ float GetFvalue(int ridx, int fidx) {
|
||||||
|
if (use_shared) {
|
||||||
|
return smem[threadIdx.x * num_features + fidx];
|
||||||
|
} else {
|
||||||
|
// Binary search
|
||||||
|
auto begin_ptr = d_data + d_row_ptr[ridx];
|
||||||
|
auto end_ptr = d_data + d_row_ptr[ridx + 1];
|
||||||
|
SparseBatch::Entry* previous_middle = nullptr;
|
||||||
|
while (end_ptr != begin_ptr) {
|
||||||
|
auto middle = begin_ptr + (end_ptr - begin_ptr) / 2;
|
||||||
|
if (middle == previous_middle) {
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
previous_middle = middle;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (middle->index == fidx) {
|
||||||
|
return middle->fvalue;
|
||||||
|
} else if (middle->index < fidx) {
|
||||||
|
begin_ptr = middle;
|
||||||
|
} else {
|
||||||
|
end_ptr = middle;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Value is missing
|
||||||
|
return nanf("");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
__device__ float GetLeafWeight(bst_uint ridx, const DevicePredictionNode* tree,
|
||||||
|
ElementLoader* loader) {
|
||||||
|
DevicePredictionNode n = tree[0];
|
||||||
|
while (!n.IsLeaf()) {
|
||||||
|
float fvalue = loader->GetFvalue(ridx, n.GetFidx());
|
||||||
|
// Missing value
|
||||||
|
if (isnan(fvalue)) {
|
||||||
|
n = tree[n.MissingIdx()];
|
||||||
|
} else {
|
||||||
|
if (fvalue < n.GetFvalue()) {
|
||||||
|
n = tree[n.left_child_idx];
|
||||||
|
} else {
|
||||||
|
n = tree[n.right_child_idx];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return n.GetWeight();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int BLOCK_THREADS>
|
||||||
|
__global__ void PredictKernel(const DevicePredictionNode* d_nodes,
|
||||||
|
float* d_out_predictions, int* d_tree_segments,
|
||||||
|
int* d_tree_group, size_t* d_row_ptr,
|
||||||
|
SparseBatch::Entry* d_data, int tree_begin,
|
||||||
|
int tree_end, int num_features, bst_uint num_rows,
|
||||||
|
bool use_shared, int num_group) {
|
||||||
|
extern __shared__ float smem[];
|
||||||
|
bst_uint global_idx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||||
|
ElementLoader loader(use_shared, d_row_ptr, d_data, num_features, smem,
|
||||||
|
num_rows);
|
||||||
|
if (global_idx >= num_rows) return;
|
||||||
|
if (num_group == 1) {
|
||||||
|
float sum = 0;
|
||||||
|
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
||||||
|
const DevicePredictionNode* d_tree =
|
||||||
|
d_nodes + d_tree_segments[tree_idx - tree_begin];
|
||||||
|
sum += GetLeafWeight(global_idx, d_tree, &loader);
|
||||||
|
}
|
||||||
|
d_out_predictions[global_idx] += sum;
|
||||||
|
} else {
|
||||||
|
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
||||||
|
int tree_group = d_tree_group[tree_idx];
|
||||||
|
const DevicePredictionNode* d_tree =
|
||||||
|
d_nodes + d_tree_segments[tree_idx - tree_begin];
|
||||||
|
bst_uint out_prediction_idx = global_idx * num_group + tree_group;
|
||||||
|
d_out_predictions[out_prediction_idx] +=
|
||||||
|
GetLeafWeight(global_idx, d_tree, &loader);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
class GPUPredictor : public xgboost::Predictor {
|
||||||
|
private:
|
||||||
|
void DevicePredictInternal(DMatrix* dmat, std::vector<bst_float>* out_preds,
|
||||||
|
const gbm::GBTreeModel& model, int tree_begin,
|
||||||
|
int tree_end) {
|
||||||
|
if (tree_end - tree_begin == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add dmatrix to device if not seen before
|
||||||
|
if (this->device_matrix_cache_.find(dmat) ==
|
||||||
|
this->device_matrix_cache_.end()) {
|
||||||
|
this->device_matrix_cache_.emplace(
|
||||||
|
dmat, std::unique_ptr<DeviceMatrix>(
|
||||||
|
new DeviceMatrix(dmat, param.gpu_id, param.silent)));
|
||||||
|
}
|
||||||
|
DeviceMatrix* device_matrix = device_matrix_cache_.find(dmat)->second.get();
|
||||||
|
|
||||||
|
dh::safe_cuda(cudaSetDevice(param.gpu_id));
|
||||||
|
CHECK_EQ(model.param.size_leaf_vector, 0);
|
||||||
|
// Copy decision trees to device
|
||||||
|
thrust::host_vector<int> h_tree_segments;
|
||||||
|
h_tree_segments.reserve((tree_end - tree_end) + 1);
|
||||||
|
int sum = 0;
|
||||||
|
h_tree_segments.push_back(sum);
|
||||||
|
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
||||||
|
sum += model.trees[tree_idx]->GetNodes().size();
|
||||||
|
h_tree_segments.push_back(sum);
|
||||||
|
}
|
||||||
|
|
||||||
|
thrust::host_vector<DevicePredictionNode> h_nodes(h_tree_segments.back());
|
||||||
|
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
||||||
|
auto& src_nodes = model.trees[tree_idx]->GetNodes();
|
||||||
|
std::copy(src_nodes.begin(), src_nodes.end(),
|
||||||
|
h_nodes.begin() + h_tree_segments[tree_idx - tree_begin]);
|
||||||
|
}
|
||||||
|
|
||||||
|
nodes.resize(h_nodes.size());
|
||||||
|
thrust::copy(h_nodes.begin(), h_nodes.end(), nodes.begin());
|
||||||
|
tree_segments.resize(h_tree_segments.size());
|
||||||
|
thrust::copy(h_tree_segments.begin(), h_tree_segments.end(),
|
||||||
|
tree_segments.begin());
|
||||||
|
tree_group.resize(model.tree_info.size());
|
||||||
|
thrust::copy(model.tree_info.begin(), model.tree_info.end(),
|
||||||
|
tree_group.begin());
|
||||||
|
|
||||||
|
if (device_matrix->predictions.size() != out_preds->size()) {
|
||||||
|
device_matrix->predictions.resize(out_preds->size());
|
||||||
|
thrust::copy(out_preds->begin(), out_preds->end(),
|
||||||
|
device_matrix->predictions.begin());
|
||||||
|
}
|
||||||
|
|
||||||
|
const int BLOCK_THREADS = 128;
|
||||||
|
const int GRID_SIZE =
|
||||||
|
dh::div_round_up(device_matrix->row_ptr.size() - 1, BLOCK_THREADS);
|
||||||
|
|
||||||
|
int shared_memory_bytes =
|
||||||
|
sizeof(float) * device_matrix->p_mat->info().num_col * BLOCK_THREADS;
|
||||||
|
bool use_shared = true;
|
||||||
|
if (shared_memory_bytes > dh::max_shared_memory(param.gpu_id)) {
|
||||||
|
shared_memory_bytes = 0;
|
||||||
|
use_shared = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
PredictKernel<BLOCK_THREADS>
|
||||||
|
<<<GRID_SIZE, BLOCK_THREADS, shared_memory_bytes>>>(
|
||||||
|
dh::raw(nodes), dh::raw(device_matrix->predictions),
|
||||||
|
dh::raw(tree_segments), dh::raw(tree_group),
|
||||||
|
device_matrix->row_ptr.data(), device_matrix->data.data(),
|
||||||
|
tree_begin, tree_end, device_matrix->p_mat->info().num_col,
|
||||||
|
device_matrix->p_mat->info().num_row, use_shared,
|
||||||
|
model.param.num_output_group);
|
||||||
|
|
||||||
|
dh::safe_cuda(cudaDeviceSynchronize());
|
||||||
|
thrust::copy(device_matrix->predictions.begin(),
|
||||||
|
device_matrix->predictions.end(), out_preds->begin());
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
GPUPredictor() : cpu_predictor(Predictor::Create("cpu_predictor")) {}
|
||||||
|
|
||||||
|
void PredictBatch(DMatrix* dmat, std::vector<bst_float>* out_preds,
|
||||||
|
const gbm::GBTreeModel& model, int tree_begin,
|
||||||
|
unsigned ntree_limit = 0) override {
|
||||||
|
if (this->PredictFromCache(dmat, out_preds, model, ntree_limit)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
this->InitOutPredictions(dmat->info(), out_preds, model);
|
||||||
|
|
||||||
|
int tree_end = ntree_limit * model.param.num_output_group;
|
||||||
|
if (ntree_limit == 0 || ntree_limit > model.trees.size()) {
|
||||||
|
tree_end = static_cast<unsigned>(model.trees.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
DevicePredictInternal(dmat, out_preds, model, tree_begin, tree_end);
|
||||||
|
}
|
||||||
|
|
||||||
|
void UpdatePredictionCache(
|
||||||
|
const gbm::GBTreeModel& model,
|
||||||
|
std::vector<std::unique_ptr<TreeUpdater>>* updaters,
|
||||||
|
int num_new_trees) override {
|
||||||
|
// dh::Timer t;
|
||||||
|
int old_ntree = model.trees.size() - num_new_trees;
|
||||||
|
// update cache entry
|
||||||
|
for (auto& kv : cache_) {
|
||||||
|
PredictionCacheEntry& e = kv.second;
|
||||||
|
DMatrix* dmat = kv.first;
|
||||||
|
|
||||||
|
if (e.predictions.size() == 0) {
|
||||||
|
cpu_predictor->PredictBatch(dmat, &(e.predictions), model, 0,
|
||||||
|
model.trees.size());
|
||||||
|
} else if (model.param.num_output_group == 1 && updaters->size() > 0 &&
|
||||||
|
num_new_trees == 1 &&
|
||||||
|
updaters->back()->UpdatePredictionCache(e.data.get(),
|
||||||
|
&(e.predictions))) {
|
||||||
|
{} // do nothing
|
||||||
|
} else {
|
||||||
|
DevicePredictInternal(dmat, &(e.predictions), model, old_ntree,
|
||||||
|
model.trees.size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void PredictInstance(const SparseBatch::Inst& inst,
|
||||||
|
std::vector<bst_float>* out_preds,
|
||||||
|
const gbm::GBTreeModel& model, unsigned ntree_limit,
|
||||||
|
unsigned root_index) override {
|
||||||
|
cpu_predictor->PredictInstance(inst, out_preds, model, root_index);
|
||||||
|
}
|
||||||
|
void PredictLeaf(DMatrix* p_fmat, std::vector<bst_float>* out_preds,
|
||||||
|
const gbm::GBTreeModel& model,
|
||||||
|
unsigned ntree_limit) override {
|
||||||
|
cpu_predictor->PredictLeaf(p_fmat, out_preds, model, ntree_limit);
|
||||||
|
}
|
||||||
|
|
||||||
|
void PredictContribution(DMatrix* p_fmat,
|
||||||
|
std::vector<bst_float>* out_contribs,
|
||||||
|
const gbm::GBTreeModel& model,
|
||||||
|
unsigned ntree_limit) override {
|
||||||
|
cpu_predictor->PredictContribution(p_fmat, out_contribs, model,
|
||||||
|
ntree_limit);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Init(const std::vector<std::pair<std::string, std::string>>& cfg,
|
||||||
|
const std::vector<std::shared_ptr<DMatrix>>& cache) override {
|
||||||
|
Predictor::Init(cfg, cache);
|
||||||
|
cpu_predictor->Init(cfg, cache);
|
||||||
|
param.InitAllowUnknown(cfg);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
GPUPredictionParam param;
|
||||||
|
std::unique_ptr<Predictor> cpu_predictor;
|
||||||
|
std::unordered_map<DMatrix*, std::unique_ptr<DeviceMatrix>>
|
||||||
|
device_matrix_cache_;
|
||||||
|
thrust::device_vector<DevicePredictionNode> nodes;
|
||||||
|
thrust::device_vector<int> tree_segments;
|
||||||
|
thrust::device_vector<int> tree_group;
|
||||||
|
};
|
||||||
|
XGBOOST_REGISTER_PREDICTOR(GPUPredictor, "gpu_predictor")
|
||||||
|
.describe("Make predictions using GPU.")
|
||||||
|
.set_body([]() { return new GPUPredictor(); });
|
||||||
|
} // namespace predictor
|
||||||
|
} // namespace xgboost
|
||||||
@ -37,7 +37,7 @@ void SpeedTest() {
 
   dh::Timer t;
   dh::TransformLbs(
-      0, &temp_memory, h_rows.size(), dh::raw(row_ptr), row_ptr.size() - 1,
+      0, &temp_memory, h_rows.size(), dh::raw(row_ptr), row_ptr.size() - 1, false,
       [=] __device__(size_t idx, size_t ridx) { d_output_row[idx] = ridx; });
 
   dh::safe_cuda(cudaDeviceSynchronize());
@ -65,7 +65,7 @@ void TestLbs() {
   auto d_output_row = output_row.data();
 
   dh::TransformLbs(0, &temp_memory, h_rows.size(), dh::raw(row_ptr),
-                   row_ptr.size() - 1,
+                   row_ptr.size() - 1, false,
                    [=] __device__(size_t idx, size_t ridx) {
                      d_output_row[idx] = ridx;
                    });
plugin/updater_gpu/test/cpp/test_gpu_predictor.cu (new file, 73 lines)
@ -0,0 +1,73 @@
/*!
 * Copyright 2017 XGBoost contributors
 */
#include <xgboost/c_api.h>
#include <xgboost/predictor.h>
#include "gtest/gtest.h"
#include "../../../../tests/cpp/helpers.h"

namespace xgboost {
namespace predictor {
TEST(gpu_predictor, Test) {
  std::unique_ptr<Predictor> gpu_predictor =
      std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor"));
  std::unique_ptr<Predictor> cpu_predictor =
      std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor"));

  std::vector<std::unique_ptr<RegTree>> trees;
  trees.push_back(std::make_unique<RegTree>());
  trees.back()->InitModel();
  (*trees.back())[0].set_leaf(1.5f);
  gbm::GBTreeModel model(0.5);
  model.CommitModel(std::move(trees), 0);
  model.param.num_output_group = 1;

  int n_row = 5;
  int n_col = 5;

  auto dmat = CreateDMatrix(n_row, n_col, 0);

  // Test predict batch
  std::vector<float> gpu_out_predictions;
  std::vector<float> cpu_out_predictions;
  gpu_predictor->PredictBatch(dmat.get(), &gpu_out_predictions, model, 0);
  cpu_predictor->PredictBatch(dmat.get(), &cpu_out_predictions, model, 0);
  float abs_tolerance = 0.001;
  for (int i = 0; i < gpu_out_predictions.size(); i++) {
    ASSERT_LT(std::abs(gpu_out_predictions[i] - cpu_out_predictions[i]),
              abs_tolerance);
  }

  // Test predict instance
  auto batch = dmat->RowIterator()->Value();
  for (int i = 0; i < batch.size; i++) {
    std::vector<float> gpu_instance_out_predictions;
    std::vector<float> cpu_instance_out_predictions;
    cpu_predictor->PredictInstance(batch[i], &cpu_instance_out_predictions,
                                   model);
    gpu_predictor->PredictInstance(batch[i], &gpu_instance_out_predictions,
                                   model);
    ASSERT_EQ(gpu_instance_out_predictions[0], cpu_instance_out_predictions[0]);
  }

  // Test predict leaf
  std::vector<float> gpu_leaf_out_predictions;
  std::vector<float> cpu_leaf_out_predictions;
  cpu_predictor->PredictLeaf(dmat.get(), &cpu_leaf_out_predictions, model);
  gpu_predictor->PredictLeaf(dmat.get(), &gpu_leaf_out_predictions, model);
  for (int i = 0; i < gpu_leaf_out_predictions.size(); i++) {
    ASSERT_EQ(gpu_leaf_out_predictions[i], cpu_leaf_out_predictions[i]);
  }

  // Test predict contribution
  std::vector<float> gpu_out_contribution;
  std::vector<float> cpu_out_contribution;
  cpu_predictor->PredictContribution(dmat.get(), &cpu_out_contribution, model);
  gpu_predictor->PredictContribution(dmat.get(), &gpu_out_contribution, model);
  for (int i = 0; i < gpu_out_contribution.size(); i++) {
    ASSERT_EQ(gpu_out_contribution[i], cpu_out_contribution[i]);
  }
}
} // namespace predictor
} // namespace xgboost
@ -110,5 +110,3 @@ class TestGPU(unittest.TestCase):
         print("Time to Train: %s seconds" % (str(time.time() - tmp)))
 
 
-
-
plugin/updater_gpu/test/python/test_prediction.py (new file, 37 lines)
@ -0,0 +1,37 @@
from __future__ import print_function
#pylint: skip-file
import sys
sys.path.append("../../tests/python")
import xgboost as xgb
import testing as tm
import numpy as np
import unittest

rng = np.random.RandomState(1994)


class TestGPUPredict(unittest.TestCase):
    def test_predict(self):
        iterations = 1
        np.random.seed(1)
        test_num_rows = [10, 1000, 5000]
        test_num_cols = [10, 50, 500]
        for num_rows in test_num_rows:
            for num_cols in test_num_cols:
                dm = xgb.DMatrix(np.random.randn(num_rows, num_cols), label=[0, 1] * int(num_rows / 2))
                watchlist = [(dm, 'train')]
                res = {}
                param = {
                    "objective": "binary:logistic",
                    "predictor": "gpu_predictor",
                    'eval_metric': 'auc',
                }
                bst = xgb.train(param, dm, iterations, evals=watchlist, evals_result=res)
                assert self.non_decreasing(res["train"]["auc"])
                gpu_pred = bst.predict(dm, output_margin=True)
                bst.set_param({"predictor": "cpu_predictor"})
                cpu_pred = bst.predict(dm, output_margin=True)
                np.testing.assert_allclose(cpu_pred, gpu_pred, rtol=1e-5)

    def non_decreasing(self, L):
        return all((x - y) < 0.001 for x, y in zip(L, L[1:]))
@ -45,6 +45,7 @@ struct GBTreeTrainParam : public dmlc::Parameter<GBTreeTrainParam> {
   int process_type;
   // flag to print out detailed breakdown of runtime
   int debug_verbose;
+  std::string predictor;
   // declare parameters
   DMLC_DECLARE_PARAMETER(GBTreeTrainParam) {
     DMLC_DECLARE_FIELD(num_parallel_tree)
@ -67,6 +68,9 @@ struct GBTreeTrainParam : public dmlc::Parameter<GBTreeTrainParam> {
         .describe("flag to print out detailed breakdown of runtime");
     // add alias
     DMLC_DECLARE_ALIAS(updater_seq, updater);
+    DMLC_DECLARE_FIELD(predictor)
+        .set_default("cpu_predictor")
+        .describe("Predictor algorithm type");
   }
 };
 
@ -130,13 +134,10 @@ struct CacheEntry {
 // gradient boosted trees
 class GBTree : public GradientBooster {
  public:
-  explicit GBTree(bst_float base_margin)
-      : model_(base_margin),
-        predictor(
-            std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor"))) {}
+  explicit GBTree(bst_float base_margin) : model_(base_margin) {}
 
   void InitCache(const std::vector<std::shared_ptr<DMatrix> > &cache) {
-    predictor->InitCache(cache);
+    cache_ = cache;
   }
 
   void Configure(const std::vector<std::pair<std::string, std::string> >& cfg) override {
@ -153,6 +154,10 @@ class GBTree : public GradientBooster {
     if (tparam.process_type == kUpdate) {
       model_.InitTreesToUpdate();
     }
+
+    // configure predictor
+    predictor = std::unique_ptr<Predictor>(Predictor::Create(tparam.predictor));
+    predictor->Init(cfg, cache_);
   }
 
   void Load(dmlc::Stream* fi) override {
@ -300,7 +305,8 @@ class GBTree : public GradientBooster {
   std::vector<std::pair<std::string, std::string> > cfg;
   // the updaters that can be applied to each of tree
   std::vector<std::unique_ptr<TreeUpdater>> updaters;
+  // Cached matrices
+  std::vector<std::shared_ptr<DMatrix>> cache_;
   std::unique_ptr<Predictor> predictor;
 };
 
@ -165,9 +165,19 @@ class LearnerImpl : public Learner {
           << "grow_fast_histmaker.";
       cfg_["updater"] = "grow_fast_histmaker";
     } else if (tparam.tree_method == 4) {
-      cfg_["updater"] = "grow_gpu,prune";
+      if (cfg_.count("updater") == 0) {
+        cfg_["updater"] = "grow_gpu,prune";
+      }
+      if (cfg_.count("predictor") == 0) {
+        cfg_["predictor"] = "gpu_predictor";
+      }
     } else if (tparam.tree_method == 5) {
-      cfg_["updater"] = "grow_gpu_hist";
+      if (cfg_.count("updater") == 0) {
+        cfg_["updater"] = "grow_gpu_hist";
+      }
+      if (cfg_.count("predictor") == 0) {
+        cfg_["predictor"] = "gpu_predictor";
+      }
     }
   }
 
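A hedged sketch of the defaulting behaviour introduced above: when a GPU tree method is requested ('gpu_exact' or 'gpu_hist' in the parameter docs), the learner now fills in 'updater' and 'predictor' only if the user has not set them, so an explicit choice is respected. The data below is illustrative only:

import numpy as np
import xgboost as xgb

dtrain = xgb.DMatrix(np.random.randn(100, 5),
                     label=np.random.randint(2, size=100))

# 'gpu_hist' implies updater 'grow_gpu_hist' and predictor 'gpu_predictor' by default.
bst_gpu = xgb.train({'tree_method': 'gpu_hist',
                     'objective': 'binary:logistic'}, dtrain, num_boost_round=5)

# An explicit predictor setting is left untouched by those defaults.
bst_mixed = xgb.train({'tree_method': 'gpu_hist',
                       'objective': 'binary:logistic',
                       'predictor': 'cpu_predictor'}, dtrain, num_boost_round=5)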
@ -9,6 +9,8 @@
 namespace xgboost {
 namespace predictor {
 
+DMLC_REGISTRY_FILE_TAG(cpu_predictor);
+
 class CPUPredictor : public Predictor {
  protected:
   static bst_float PredValue(const RowBatch::Inst& inst,
@ -28,19 +30,6 @@ class CPUPredictor : public Predictor {
     return psum;
   }
 
-  void InitOutPredictions(const MetaInfo& info,
-                          std::vector<bst_float>* out_preds,
-                          const gbm::GBTreeModel& model) const {
-    size_t n = model.param.num_output_group * info.num_row;
-    const std::vector<bst_float>& base_margin = info.base_margin;
-    out_preds->resize(n);
-    if (base_margin.size() != 0) {
-      CHECK_EQ(out_preds->size(), n);
-      std::copy(base_margin.begin(), base_margin.end(), out_preds->begin());
-    } else {
-      std::fill(out_preds->begin(), out_preds->end(), model.base_margin);
-    }
-  }
   // init thread buffers
   inline void InitThreadTemp(int nthread, int num_feature) {
     int prev_thread_temp_size = thread_temp.size();
@ -106,33 +95,6 @@ class CPUPredictor : public Predictor {
     }
   }
 
-  /**
-   * \fn bool PredictFromCache(DMatrix* dmat, std::vector<bst_float>*
-   * out_preds, const gbm::GBTreeModel& model, unsigned ntree_limit = 0)
-   *
-   * \brief Attempt to predict from cache.
-   *
-   * \return True if it succeeds, false if it fails.
-   */
-  bool PredictFromCache(DMatrix* dmat, std::vector<bst_float>* out_preds,
-                        const gbm::GBTreeModel& model,
-                        unsigned ntree_limit = 0) {
-    if (ntree_limit == 0 ||
-        ntree_limit * model.param.num_output_group >= model.trees.size()) {
-      auto it = cache_.find(dmat);
-      if (it != cache_.end()) {
-        std::vector<bst_float>& y = it->second.predictions;
-        if (y.size() != 0) {
-          out_preds->resize(y.size());
-          std::copy(y.begin(), y.end(), out_preds->begin());
-          return true;
-        }
-      }
-    }
-
-    return false;
-  }
-
   void PredLoopInternal(DMatrix* dmat, std::vector<bst_float>* out_preds,
                         const gbm::GBTreeModel& model, int tree_begin,
                         unsigned ntree_limit) {
@ -8,13 +8,47 @@ namespace dmlc {
 DMLC_REGISTRY_ENABLE(::xgboost::PredictorReg);
 } // namespace dmlc
 namespace xgboost {
-void Predictor::InitCache(const std::vector<std::shared_ptr<DMatrix> >& cache) {
+void Predictor::Init(
+    const std::vector<std::pair<std::string, std::string>>& cfg,
+    const std::vector<std::shared_ptr<DMatrix>>& cache) {
   for (const std::shared_ptr<DMatrix>& d : cache) {
     PredictionCacheEntry e;
     e.data = d;
     cache_[d.get()] = std::move(e);
   }
 }
+
+bool Predictor::PredictFromCache(DMatrix* dmat,
+                                 std::vector<bst_float>* out_preds,
+                                 const gbm::GBTreeModel& model,
+                                 unsigned ntree_limit) {
+  if (ntree_limit == 0 ||
+      ntree_limit * model.param.num_output_group >= model.trees.size()) {
+    auto it = cache_.find(dmat);
+    if (it != cache_.end()) {
+      std::vector<bst_float>& y = it->second.predictions;
+      if (y.size() != 0) {
+        out_preds->resize(y.size());
+        std::copy(y.begin(), y.end(), out_preds->begin());
+        return true;
+      }
+    }
+  }
+
+  return false;
+}
+
+void Predictor::InitOutPredictions(const MetaInfo& info,
+                                   std::vector<bst_float>* out_preds,
+                                   const gbm::GBTreeModel& model) const {
+  size_t n = model.param.num_output_group * info.num_row;
+  const std::vector<bst_float>& base_margin = info.base_margin;
+  out_preds->resize(n);
+  if (base_margin.size() != 0) {
+    CHECK_EQ(out_preds->size(), n);
+    std::copy(base_margin.begin(), base_margin.end(), out_preds->begin());
+  } else {
+    std::fill(out_preds->begin(), out_preds->end(), model.base_margin);
+  }
+}
+
 Predictor* Predictor::Create(std::string name) {
   auto* e = ::dmlc::Registry<PredictorReg>::Get()->Find(name);
   if (e == nullptr) {
@ -23,3 +57,13 @@ Predictor* Predictor::Create(std::string name) {
   return (e->body)();
 }
 } // namespace xgboost
+
+namespace xgboost {
+namespace predictor {
+// List of files that will be force linked in static links.
+#ifdef XGBOOST_USE_CUDA
+DMLC_REGISTRY_LINK_TAG(gpu_predictor);
+#endif
+DMLC_REGISTRY_LINK_TAG(cpu_predictor);
+} // namespace predictor
+} // namespace xgboost
@ -1,4 +1,6 @@
 #include "./helpers.h"
+#include "xgboost/c_api.h"
+#include <random>
 
 std::string TempFileName() {
   return std::tmpnam(nullptr);
@ -60,3 +62,23 @@ xgboost::bst_float GetMetricEval(xgboost::Metric * metric,
   info.weights = weights;
   return metric->Eval(preds, info, false);
 }
+
+std::shared_ptr<xgboost::DMatrix> CreateDMatrix(int rows, int columns,
+                                                float sparsity, int seed) {
+  const float missing_value = -1;
+  std::vector<float> test_data(rows * columns);
+  std::mt19937 gen(seed);
+  std::uniform_real_distribution<float> dis(0.0f, 1.0f);
+  for (auto &e : test_data) {
+    if (dis(gen) < sparsity) {
+      e = missing_value;
+    } else {
+      e = dis(gen);
+    }
+  }
+
+  DMatrixHandle handle;
+  XGDMatrixCreateFromMat(test_data.data(), rows, columns, missing_value,
+                         &handle);
+  return *static_cast<std::shared_ptr<xgboost::DMatrix> *>(handle);
+}
@ -36,4 +36,19 @@ xgboost::bst_float GetMetricEval(
     std::vector<xgboost::bst_float> labels,
     std::vector<xgboost::bst_float> weights = std::vector<xgboost::bst_float> ());
+
+/**
+ * \fn std::shared_ptr<xgboost::DMatrix> CreateDMatrix(int rows, int columns, float sparsity, int seed);
+ *
+ * \brief Creates dmatrix with uniform random data between 0-1.
+ *
+ * \param rows The rows.
+ * \param columns The columns.
+ * \param sparsity The sparsity.
+ * \param seed The seed.
+ *
+ * \return The new d matrix.
+ */
+
+std::shared_ptr<xgboost::DMatrix> CreateDMatrix(int rows, int columns,
+                                                float sparsity, int seed = 0);
 #endif
tests/cpp/predictor/test_cpu_predictor.cc (new file, 54 lines)
@ -0,0 +1,54 @@
// Copyright by Contributors
#include <gtest/gtest.h>
#include <xgboost/predictor.h>
#include "../helpers.h"

namespace xgboost {
TEST(cpu_predictor, Test) {
  std::unique_ptr<Predictor> cpu_predictor =
      std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor"));

  std::vector<std::unique_ptr<RegTree>> trees;
  trees.push_back(std::unique_ptr<RegTree>(new RegTree));
  trees.back()->InitModel();
  (*trees.back())[0].set_leaf(1.5f);
  gbm::GBTreeModel model(0.5);
  model.CommitModel(std::move(trees), 0);
  model.param.num_output_group = 1;
  model.base_margin = 0;

  int n_row = 5;
  int n_col = 5;

  auto dmat = CreateDMatrix(n_row, n_col, 0);

  // Test predict batch
  std::vector<float> out_predictions;
  cpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0);
  for (int i = 0; i < out_predictions.size(); i++) {
    ASSERT_EQ(out_predictions[i], 1.5);
  }

  // Test predict instance
  auto batch = dmat->RowIterator()->Value();
  for (int i = 0; i < batch.size; i++) {
    std::vector<float> instance_out_predictions;
    cpu_predictor->PredictInstance(batch[i], &instance_out_predictions, model);
    ASSERT_EQ(instance_out_predictions[0], 1.5);
  }

  // Test predict leaf
  std::vector<float> leaf_out_predictions;
  cpu_predictor->PredictLeaf(dmat.get(), &leaf_out_predictions, model);
  for (int i = 0; i < leaf_out_predictions.size(); i++) {
    ASSERT_EQ(leaf_out_predictions[i], 0);
  }

  // Test predict contribution
  std::vector<float> out_contribution;
  cpu_predictor->PredictContribution(dmat.get(), &out_contribution, model);
  for (int i = 0; i < out_contribution.size(); i++) {
    ASSERT_EQ(out_contribution[i], 1.5);
  }
}
} // namespace xgboost
tests/cpp/test_learner.cc (new file, 14 lines)
@ -0,0 +1,14 @@
// Copyright by Contributors
#include <gtest/gtest.h>
#include "helpers.h"
#include "xgboost/learner.h"

namespace xgboost {
TEST(learner, Test) {
  typedef std::pair<std::string, std::string> arg;
  auto args = {arg("tree_method", "exact")};
  auto mat = {CreateDMatrix(10, 10, 0)};
  auto learner = std::unique_ptr<Learner>(Learner::Create(mat));
  learner->Configure(args);
}
} // namespace xgboost