[GPU-Plugin] Resolve double compilation issue (#2479)
This commit is contained in:
@@ -1,8 +1,11 @@
|
||||
/*!
|
||||
* Copyright 2017 XGBoost contributors
|
||||
*/
|
||||
#include "updater_gpu.cuh"
|
||||
#include <xgboost/tree_updater.h>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <string>
|
||||
#include "../../../src/common/random.h"
|
||||
#include "../../../src/common/sync.h"
|
||||
#include "../../../src/tree/param.h"
|
||||
@@ -11,87 +14,64 @@
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
DMLC_REGISTRY_FILE_TAG(updater_gpumaker);
|
||||
|
||||
/*! \brief column-wise update to construct a tree */
|
||||
template <typename TStats>
|
||||
class GPUMaker : public TreeUpdater {
|
||||
public:
|
||||
void Init(
|
||||
const std::vector<std::pair<std::string, std::string>>& args) override {
|
||||
param.InitAllowUnknown(args);
|
||||
builder.Init(param);
|
||||
}
|
||||
GPUMaker::GPUMaker() : builder(new exact::GPUBuilder<int16_t>()) {}
|
||||
|
||||
void Update(const std::vector<bst_gpair>& gpair, DMatrix* dmat,
|
||||
const std::vector<RegTree*>& trees) override {
|
||||
TStats::CheckInfo(dmat->info());
|
||||
// rescale learning rate according to size of trees
|
||||
float lr = param.learning_rate;
|
||||
param.learning_rate = lr / trees.size();
|
||||
builder.UpdateParam(param);
|
||||
void GPUMaker::Init(
|
||||
const std::vector<std::pair<std::string, std::string>>& args) {
|
||||
param.InitAllowUnknown(args);
|
||||
builder->Init(param);
|
||||
}
|
||||
|
||||
try {
|
||||
// build tree
|
||||
for (size_t i = 0; i < trees.size(); ++i) {
|
||||
builder.Update(gpair, dmat, trees[i]);
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
LOG(FATAL) << "GPU plugin exception: " << e.what() << std::endl;
|
||||
}
|
||||
param.learning_rate = lr;
|
||||
}
|
||||
void GPUMaker::Update(const std::vector<bst_gpair>& gpair, DMatrix* dmat,
|
||||
const std::vector<RegTree*>& trees) {
|
||||
GradStats::CheckInfo(dmat->info());
|
||||
// rescale learning rate according to size of trees
|
||||
float lr = param.learning_rate;
|
||||
param.learning_rate = lr / trees.size();
|
||||
builder->UpdateParam(param);
|
||||
|
||||
protected:
|
||||
// training parameter
|
||||
TrainParam param;
|
||||
exact::GPUBuilder<int16_t> builder;
|
||||
};
|
||||
|
||||
template <typename TStats>
|
||||
class GPUHistMaker : public TreeUpdater {
|
||||
public:
|
||||
void Init(
|
||||
const std::vector<std::pair<std::string, std::string>>& args) override {
|
||||
param.InitAllowUnknown(args);
|
||||
builder.Init(param);
|
||||
}
|
||||
|
||||
void Update(const std::vector<bst_gpair>& gpair, DMatrix* dmat,
|
||||
const std::vector<RegTree*>& trees) override {
|
||||
TStats::CheckInfo(dmat->info());
|
||||
// rescale learning rate according to size of trees
|
||||
float lr = param.learning_rate;
|
||||
param.learning_rate = lr / trees.size();
|
||||
builder.UpdateParam(param);
|
||||
try {
|
||||
// build tree
|
||||
try {
|
||||
for (size_t i = 0; i < trees.size(); ++i) {
|
||||
builder.Update(gpair, dmat, trees[i]);
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
LOG(FATAL) << "GPU plugin exception: " << e.what() << std::endl;
|
||||
for (size_t i = 0; i < trees.size(); ++i) {
|
||||
builder->Update(gpair, dmat, trees[i]);
|
||||
}
|
||||
param.learning_rate = lr;
|
||||
} catch (const std::exception& e) {
|
||||
LOG(FATAL) << "GPU plugin exception: " << e.what() << std::endl;
|
||||
}
|
||||
param.learning_rate = lr;
|
||||
}
|
||||
|
||||
bool UpdatePredictionCache(const DMatrix* data,
|
||||
std::vector<bst_float>* out_preds) override {
|
||||
return builder.UpdatePredictionCache(data, out_preds);
|
||||
GPUHistMaker::GPUHistMaker() : builder(new GPUHistBuilder()) {}
|
||||
|
||||
void GPUHistMaker::Init(
|
||||
const std::vector<std::pair<std::string, std::string>>& args) {
|
||||
param.InitAllowUnknown(args);
|
||||
builder->Init(param);
|
||||
}
|
||||
|
||||
void GPUHistMaker::Update(const std::vector<bst_gpair>& gpair, DMatrix* dmat,
|
||||
const std::vector<RegTree*>& trees) {
|
||||
GradStats::CheckInfo(dmat->info());
|
||||
// rescale learning rate according to size of trees
|
||||
float lr = param.learning_rate;
|
||||
param.learning_rate = lr / trees.size();
|
||||
builder->UpdateParam(param);
|
||||
// build tree
|
||||
try {
|
||||
for (size_t i = 0; i < trees.size(); ++i) {
|
||||
builder->Update(gpair, dmat, trees[i]);
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
LOG(FATAL) << "GPU plugin exception: " << e.what() << std::endl;
|
||||
}
|
||||
param.learning_rate = lr;
|
||||
}
|
||||
|
||||
protected:
|
||||
// training parameter
|
||||
TrainParam param;
|
||||
GPUHistBuilder builder;
|
||||
};
|
||||
bool GPUHistMaker::UpdatePredictionCache(const DMatrix* data,
|
||||
std::vector<bst_float>* out_preds) {
|
||||
return builder->UpdatePredictionCache(data, out_preds);
|
||||
}
|
||||
|
||||
XGBOOST_REGISTER_TREE_UPDATER(GPUMaker, "grow_gpu")
|
||||
.describe("Grow tree with GPU.")
|
||||
.set_body([]() { return new GPUMaker<GradStats>(); });
|
||||
|
||||
XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist")
|
||||
.describe("Grow tree with GPU.")
|
||||
.set_body([]() { return new GPUHistMaker<GradStats>(); });
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user