[GPU-Plugin] Add GPU accelerated prediction (#2593)
* [GPU-Plugin] Add GPU accelerated prediction * Improve allocation message * Update documentation * Resolve linker error for predictor * Add unit tests
This commit is contained in:
@@ -45,6 +45,7 @@ struct GBTreeTrainParam : public dmlc::Parameter<GBTreeTrainParam> {
|
||||
int process_type;
|
||||
// flag to print out detailed breakdown of runtime
|
||||
int debug_verbose;
|
||||
std::string predictor;
|
||||
// declare parameters
|
||||
DMLC_DECLARE_PARAMETER(GBTreeTrainParam) {
|
||||
DMLC_DECLARE_FIELD(num_parallel_tree)
|
||||
@@ -67,6 +68,9 @@ struct GBTreeTrainParam : public dmlc::Parameter<GBTreeTrainParam> {
|
||||
.describe("flag to print out detailed breakdown of runtime");
|
||||
// add alias
|
||||
DMLC_DECLARE_ALIAS(updater_seq, updater);
|
||||
DMLC_DECLARE_FIELD(predictor)
|
||||
.set_default("cpu_predictor")
|
||||
.describe("Predictor algorithm type");
|
||||
}
|
||||
};
|
||||
|
||||
@@ -130,13 +134,10 @@ struct CacheEntry {
|
||||
// gradient boosted trees
|
||||
class GBTree : public GradientBooster {
|
||||
public:
|
||||
explicit GBTree(bst_float base_margin)
|
||||
: model_(base_margin),
|
||||
predictor(
|
||||
std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor"))) {}
|
||||
explicit GBTree(bst_float base_margin) : model_(base_margin) {}
|
||||
|
||||
void InitCache(const std::vector<std::shared_ptr<DMatrix> > &cache) {
|
||||
predictor->InitCache(cache);
|
||||
cache_ = cache;
|
||||
}
|
||||
|
||||
void Configure(const std::vector<std::pair<std::string, std::string> >& cfg) override {
|
||||
@@ -153,6 +154,10 @@ class GBTree : public GradientBooster {
|
||||
if (tparam.process_type == kUpdate) {
|
||||
model_.InitTreesToUpdate();
|
||||
}
|
||||
|
||||
// configure predictor
|
||||
predictor = std::unique_ptr<Predictor>(Predictor::Create(tparam.predictor));
|
||||
predictor->Init(cfg, cache_);
|
||||
}
|
||||
|
||||
void Load(dmlc::Stream* fi) override {
|
||||
@@ -300,7 +305,8 @@ class GBTree : public GradientBooster {
|
||||
std::vector<std::pair<std::string, std::string> > cfg;
|
||||
// the updaters that can be applied to each of tree
|
||||
std::vector<std::unique_ptr<TreeUpdater>> updaters;
|
||||
|
||||
// Cached matrices
|
||||
std::vector<std::shared_ptr<DMatrix>> cache_;
|
||||
std::unique_ptr<Predictor> predictor;
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user