Objective function evaluation on GPU with minimal PCIe transfers (#2935)

* Added GPU objective function and no-copy interface.

- xgboost::HostDeviceVector<T> syncs automatically between host and device
- no-copy interfaces have been added
- default implementations simply sync the data to the host
  and call the existing std::vector implementations (see the sketch below)
- the GPU objective function, predictor, and histogram updater process data
  directly on the GPU
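
A minimal illustration of that fall-back pattern, mirroring the default implementation this commit adds for ObjFunction::GetGradient in objective.cc:

// Default fall-back for a no-copy interface: data_h() performs a lazy
// device-to-host sync if needed, then the std::vector overload does the work.
void ObjFunction::GetGradient(HostDeviceVector<bst_float>* preds,
                              const MetaInfo& info,
                              int iteration,
                              HostDeviceVector<bst_gpair>* out_gpair) {
  GetGradient(preds->data_h(), info, iteration, &out_gpair->data_h());
}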
Thejaswi 2018-01-12 14:03:39 +05:30 committed by Rory Mitchell
parent a187ed6c8f
commit 84ab74f3a5
23 changed files with 1036 additions and 127 deletions

View File

@@ -57,6 +57,7 @@
#include "../src/learner.cc"
#include "../src/logging.cc"
#include "../src/common/common.cc"
#include "../src/common/host_device_vector.cc"
#include "../src/common/hist_util.cc"
// c_api

View File

@@ -165,6 +165,9 @@ Specify the learning task and the corresponding learning objective. The objective
- "reg:logistic" --logistic regression
- "binary:logistic" --logistic regression for binary classification, output probability
- "binary:logitraw" --logistic regression for binary classification, output score before logistic transformation
- "gpu:reg:linear", "gpu:reg:logistic", "gpu:binary:logistic", gpu:binary:logitraw" --versions
of the corresponding objective functions evaluated on the GPU; note that like the GPU histogram algorithm,
they can only be used when the entire training session uses the same dataset
- "count:poisson" --poisson regression for count data, output mean of poisson distribution
- max_delta_step is set to 0.7 by default in poisson regression (used to safeguard optimization)
- "multi:softmax" --set XGBoost to do multiclass classification using the softmax objective, you also need to set num_class(number of classes)

View File

@@ -18,6 +18,7 @@
#include "./data.h"
#include "./objective.h"
#include "./feature_map.h"
#include "../../src/common/host_device_vector.h"
namespace xgboost {
/*!
@@ -70,6 +71,10 @@ class GradientBooster {
virtual void DoBoost(DMatrix* p_fmat,
std::vector<bst_gpair>* in_gpair,
ObjFunction* obj = nullptr) = 0;
virtual void DoBoost(DMatrix* p_fmat,
HostDeviceVector<bst_gpair>* in_gpair,
ObjFunction* obj = nullptr);
/*!
* \brief generate predictions for given feature matrix
* \param dmat feature matrix
@@ -80,6 +85,9 @@ class GradientBooster {
virtual void PredictBatch(DMatrix* dmat,
std::vector<bst_float>* out_preds,
unsigned ntree_limit = 0) = 0;
virtual void PredictBatch(DMatrix* dmat,
HostDeviceVector<bst_float>* out_preds,
unsigned ntree_limit = 0);
/*!
* \brief online prediction function, predict score for one instance at a time
* NOTE: use the batch prediction interface if possible, batch prediction is usually

View File

@@ -14,8 +14,11 @@
#include <functional>
#include "./data.h"
#include "./base.h"
#include "../../src/common/host_device_vector.h"
namespace xgboost {
/*! \brief interface of objective function */
class ObjFunction {
public:
@@ -45,6 +48,11 @@ class ObjFunction {
const MetaInfo& info,
int iteration,
std::vector<bst_gpair>* out_gpair) = 0;
virtual void GetGradient(HostDeviceVector<bst_float>* preds,
const MetaInfo& info,
int iteration,
HostDeviceVector<bst_gpair>* out_gpair);
/*! \return the default evaluation metric for the objective */
virtual const char* DefaultEvalMetric() const = 0;
// the following functions are optional; most of the time the default implementation is good enough
@@ -53,6 +61,8 @@ class ObjFunction {
* \param io_preds prediction values, saves to this vector as well
*/
virtual void PredTransform(std::vector<bst_float> *io_preds) {}
virtual void PredTransform(HostDeviceVector<bst_float> *io_preds);
/*!
* \brief transform prediction values, this is only called when Eval is called,
* usually it redirects to PredTransform
@@ -61,6 +71,9 @@
virtual void EvalTransform(std::vector<bst_float> *io_preds) {
this->PredTransform(io_preds);
}
virtual void EvalTransform(HostDeviceVector<bst_float> *io_preds) {
this->PredTransform(io_preds);
}
/*!
* \brief transform probability value back to margin
* this is used to transform user-set base_score back to margin

View File

@@ -13,6 +13,7 @@
#include <utility>
#include <vector>
#include "../../src/gbm/gbtree_model.h"
#include "../../src/common/host_device_vector.h"
// Forward declarations
namespace xgboost {
@@ -51,10 +52,6 @@ class Predictor {
const std::vector<std::shared_ptr<DMatrix>>& cache);
/**
* \fn virtual void Predictor::PredictBatch( DMatrix* dmat,
* std::vector<bst_float>* out_preds, const gbm::GBTreeModel &model, int
* tree_begin, unsigned ntree_limit = 0) = 0;
*
* \brief Generate batch predictions for a given feature matrix. May use
* cached predictions if available instead of calculating from scratch.
*
@@ -70,6 +67,22 @@
const gbm::GBTreeModel& model, int tree_begin,
unsigned ntree_limit = 0) = 0;
/**
* \brief Generate batch predictions for a given feature matrix. May use
* cached predictions if available instead of calculating from scratch.
*
* \param [in,out] dmat Feature matrix.
* \param [in,out] out_preds The output preds.
* \param model The model to predict from.
* \param tree_begin The tree begin index.
* \param ntree_limit (Optional) The ntree limit. 0 means do not
* limit trees.
*/
virtual void PredictBatch(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds,
const gbm::GBTreeModel& model, int tree_begin,
unsigned ntree_limit = 0) = 0;
/**
* \fn virtual void Predictor::UpdatePredictionCache( const gbm::GBTreeModel
* &model, std::vector<std::unique_ptr<TreeUpdater> >* updaters, int

View File

@@ -16,6 +16,7 @@
#include "./base.h"
#include "./data.h"
#include "./tree_model.h"
#include "../../src/common/host_device_vector.h"
namespace xgboost {
/*!
@@ -42,6 +43,9 @@ class TreeUpdater {
virtual void Update(const std::vector<bst_gpair>& gpair,
DMatrix* data,
const std::vector<RegTree*>& trees) = 0;
virtual void Update(HostDeviceVector<bst_gpair>* gpair,
DMatrix* data,
const std::vector<RegTree*>& trees);
/*!
* \brief determines whether updater has enough knowledge about a given dataset
@@ -57,6 +61,9 @@
std::vector<bst_float>* out_preds) {
return false;
}
virtual bool UpdatePredictionCache(const DMatrix* data,
HostDeviceVector<bst_float>* out_preds);
/*!
* \brief Create a tree updater given name
* \param name Name of the tree updater.

View File

@@ -484,6 +484,13 @@ class bulk_allocator {
}
public:
bulk_allocator() {}
// prevent accidental copying, moving or assignment of this object
bulk_allocator(const bulk_allocator<MemoryT>&) = delete;
bulk_allocator(bulk_allocator<MemoryT>&&) = delete;
void operator=(const bulk_allocator<MemoryT>&) = delete;
void operator=(bulk_allocator<MemoryT>&&) = delete;
~bulk_allocator() {
for (size_t i = 0; i < d_ptr.size(); i++) {
if (!(d_ptr[i] == nullptr)) {

View File

@@ -0,0 +1,54 @@
/*!
* Copyright 2017 XGBoost contributors
*/
#ifndef XGBOOST_USE_CUDA
// dummy implementation of HostDeviceVector in case CUDA is not used
#include <xgboost/base.h>
#include "./host_device_vector.h"
namespace xgboost {
template <typename T>
struct HostDeviceVectorImpl {
explicit HostDeviceVectorImpl(size_t size) : data_h_(size) {}
std::vector<T> data_h_;
};
template <typename T>
HostDeviceVector<T>::HostDeviceVector(size_t size, int device) : impl_(nullptr) {
impl_ = new HostDeviceVectorImpl<T>(size);
}
template <typename T>
HostDeviceVector<T>::~HostDeviceVector() {
HostDeviceVectorImpl<T>* tmp = impl_;
impl_ = nullptr;
delete tmp;
}
template <typename T>
size_t HostDeviceVector<T>::size() const { return impl_->data_h_.size(); }
template <typename T>
int HostDeviceVector<T>::device() const { return -1; }
template <typename T>
T* HostDeviceVector<T>::ptr_d(int device) { return nullptr; }
template <typename T>
std::vector<T>& HostDeviceVector<T>::data_h() { return impl_->data_h_; }
template <typename T>
void HostDeviceVector<T>::resize(size_t new_size, int new_device) {
impl_->data_h_.resize(new_size);
}
// explicit instantiations are required, as HostDeviceVector isn't header-only
template class HostDeviceVector<bst_float>;
template class HostDeviceVector<bst_gpair>;
} // namespace xgboost
#endif

View File

@@ -0,0 +1,135 @@
/*!
* Copyright 2017 XGBoost contributors
*/
#include "./host_device_vector.h"
#include "./device_helpers.cuh"
namespace xgboost {
template <typename T>
struct HostDeviceVectorImpl {
HostDeviceVectorImpl(size_t size, int device)
: device_(device), on_d_(device >= 0) {
if (on_d_) {
dh::safe_cuda(cudaSetDevice(device_));
data_d_.resize(size);
} else {
data_h_.resize(size);
}
}
HostDeviceVectorImpl(const HostDeviceVectorImpl<T>&) = delete;
HostDeviceVectorImpl(HostDeviceVectorImpl<T>&&) = delete;
void operator=(const HostDeviceVectorImpl<T>&) = delete;
void operator=(HostDeviceVectorImpl<T>&&) = delete;
size_t size() const { return on_d_ ? data_d_.size() : data_h_.size(); }
int device() const { return device_; }
T* ptr_d(int device) {
lazy_sync_device(device);
return data_d_.data().get();
}
thrust::device_ptr<T> tbegin(int device) {
return thrust::device_ptr<T>(ptr_d(device));
}
thrust::device_ptr<T> tend(int device) {
auto begin = tbegin(device);
return begin + size();
}
std::vector<T>& data_h() {
lazy_sync_host();
return data_h_;
}
void resize(size_t new_size, int new_device) {
if (new_size == this->size() && new_device == device_)
return;
device_ = new_device;
// if the data is on the host and is non-empty, or no device is set,
// resize on the host; otherwise, resize the data on the device
if (!on_d_ && (data_h_.size() > 0 || device_ == -1)) {
data_h_.resize(new_size);
} else {
dh::safe_cuda(cudaSetDevice(device_));
data_d_.resize(new_size);
on_d_ = true;
}
}
void lazy_sync_host() {
if (!on_d_)
return;
if (data_h_.size() != this->size())
data_h_.resize(this->size());
dh::safe_cuda(cudaSetDevice(device_));
thrust::copy(data_d_.begin(), data_d_.end(), data_h_.begin());
on_d_ = false;
}
void lazy_sync_device(int device) {
if (on_d_)
return;
if (device != device_) {
CHECK_EQ(device_, -1);
device_ = device;
}
if (data_d_.size() != this->size()) {
dh::safe_cuda(cudaSetDevice(device_));
data_d_.resize(this->size());
}
dh::safe_cuda(cudaSetDevice(device_));
thrust::copy(data_h_.begin(), data_h_.end(), data_d_.begin());
on_d_ = true;
}
std::vector<T> data_h_;
thrust::device_vector<T> data_d_;
// true if there is an up-to-date copy of data on device, false otherwise
bool on_d_;
int device_;
};
template <typename T>
HostDeviceVector<T>::HostDeviceVector(size_t size, int device) : impl_(nullptr) {
impl_ = new HostDeviceVectorImpl<T>(size, device);
}
template <typename T>
HostDeviceVector<T>::~HostDeviceVector() {
HostDeviceVectorImpl<T>* tmp = impl_;
impl_ = nullptr;
delete tmp;
}
template <typename T>
size_t HostDeviceVector<T>::size() const { return impl_->size(); }
template <typename T>
int HostDeviceVector<T>::device() const { return impl_->device(); }
template <typename T>
T* HostDeviceVector<T>::ptr_d(int device) { return impl_->ptr_d(device); }
template <typename T>
thrust::device_ptr<T> HostDeviceVector<T>::tbegin(int device) {
return impl_->tbegin(device);
}
template <typename T>
thrust::device_ptr<T> HostDeviceVector<T>::tend(int device) {
return impl_->tend(device);
}
template <typename T>
std::vector<T>& HostDeviceVector<T>::data_h() { return impl_->data_h(); }
template <typename T>
void HostDeviceVector<T>::resize(size_t new_size, int new_device) {
impl_->resize(new_size, new_device);
}
// explicit instantiations are required, as HostDeviceVector isn't header-only
template class HostDeviceVector<bst_float>;
template class HostDeviceVector<bst_gpair>;
} // namespace xgboost

View File

@@ -0,0 +1,100 @@
/*!
* Copyright 2017 XGBoost contributors
*/
#ifndef XGBOOST_COMMON_HOST_DEVICE_VECTOR_H_
#define XGBOOST_COMMON_HOST_DEVICE_VECTOR_H_
#include <cstdlib>
#include <vector>
// only include thrust-related files if host_device_vector.h
// is included from a .cu file
#ifdef __CUDACC__
#include <thrust/device_ptr.h>
#endif
namespace xgboost {
template <typename T> struct HostDeviceVectorImpl;
/**
* @file host_device_vector.h
* @brief A device-and-host vector abstraction layer.
*
* Why HostDeviceVector?<br/>
* With CUDA, one has to explicitly manage memory through 'cudaMemcpy' calls.
* This wrapper class hides this management from the users, thereby making it
* easy to integrate GPU/CPU usage under a single interface.
*
* Initialization/Allocation:<br/>
* The vector can be initialized on the CPU or the GPU in the constructor
* (use the 'device' argument), or allocated/resized explicitly later via
* the 'resize' method.
*
* Accessing the underlying data:<br/>
* Use the 'data_h' method to explicitly query for the underlying std::vector.
* If you need the raw device pointer, use the 'ptr_d' method. For the perf
* implications of these calls, see below.
*
* Perf implications of accessing the underlying data:<br/>
* There are 4 scenarios to consider here:
* data_h and data on CPU --> no problem, the std::vector is returned immediately
* data_h but data on GPU --> this causes a cudaMemcpy to be issued internally;
* subsequent calls to data_h will NOT incur this penalty
* (assuming 'ptr_d' is not called in between)
* ptr_d but data on CPU --> this causes a cudaMemcpy to be issued internally;
* subsequent calls to ptr_d will NOT incur this penalty
* (assuming 'data_h' is not called in between)
* ptr_d and data on GPU --> no problem, the device ptr is returned immediately
*
* What if xgboost is compiled without CUDA?<br/>
* In that case, there is a special implementation that always falls back to
* working with std::vector; this logic can be found in host_device_vector.cc.
*
* Why not consider CUDA unified memory?<br/>
* We did consider it. However, it poses complications when supporting builds
* both with and without the CUDA toolkit. It was easier to give
* 'HostDeviceVector' a special-case implementation in host_device_vector.cc.
*
* @note: This is not thread-safe!
*/
template <typename T>
class HostDeviceVector {
public:
explicit HostDeviceVector(size_t size = 0, int device = -1);
~HostDeviceVector();
HostDeviceVector(const HostDeviceVector<T>&) = delete;
HostDeviceVector(HostDeviceVector<T>&&) = delete;
void operator=(const HostDeviceVector<T>&) = delete;
void operator=(HostDeviceVector<T>&&) = delete;
size_t size() const;
int device() const;
T* ptr_d(int device);
// only define functions returning device_ptr
// if HostDeviceVector.h is included from a .cu file
#ifdef __CUDACC__
thrust::device_ptr<T> tbegin(int device);
thrust::device_ptr<T> tend(int device);
#endif
std::vector<T>& data_h();
void resize(size_t new_size, int new_device);
// helper functions in case a function needs to be templated
// to work for both HostDeviceVector and std::vector
static std::vector<T>& data_h(HostDeviceVector<T>* v) {
return v->data_h();
}
static std::vector<T>& data_h(std::vector<T>* v) {
return *v;
}
private:
HostDeviceVectorImpl<T>* impl_;
};
} // namespace xgboost
#endif // XGBOOST_COMMON_HOST_DEVICE_VECTOR_H_
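
A small hypothetical usage sketch of the lazy synchronization described above (the function hdv_example is illustrative, not part of this commit; actual device transfers require a CUDA build):

// Illustrative sketch only: shows when the lazy host/device copies happen.
#include "../src/common/host_device_vector.h"

void hdv_example() {
  xgboost::HostDeviceVector<float> v(100, 0);  // allocated on device 0
  v.data_h()[0] = 1.0f;   // first host access: one device-to-host copy
  v.data_h()[1] = 2.0f;   // data now resides on the host: no further copy
  float* d = v.ptr_d(0);  // first device access: one host-to-device copy
  (void)d;                // repeated ptr_d(0) calls are then copy-free
}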

View File

@@ -20,8 +20,8 @@ namespace common {
* \param x input parameter
* \return the transformed value.
*/
inline float Sigmoid(float x) {
return 1.0f / (1.0f + std::exp(-x));
XGBOOST_DEVICE inline float Sigmoid(float x) {
return 1.0f / (1.0f + expf(-x));
}
inline avx::Float8 Sigmoid(avx::Float8 x) {

View File

@@ -21,6 +21,19 @@ GradientBooster* GradientBooster::Create(
}
return (e->body)(cache_mats, base_margin);
}
void GradientBooster::DoBoost(DMatrix* p_fmat,
HostDeviceVector<bst_gpair>* in_gpair,
ObjFunction* obj) {
DoBoost(p_fmat, &in_gpair->data_h(), obj);
}
void GradientBooster::PredictBatch(DMatrix* dmat,
HostDeviceVector<bst_float>* out_preds,
unsigned ntree_limit) {
PredictBatch(dmat, &out_preds->data_h(), ntree_limit);
}
} // namespace xgboost
namespace xgboost {

View File

@@ -18,6 +18,7 @@
#include <limits>
#include <algorithm>
#include "../common/common.h"
#include "../common/host_device_vector.h"
#include "../common/random.h"
#include "gbtree_model.h"
#include "../common/timer.h"
@@ -182,35 +183,13 @@ class GBTree : public GradientBooster {
void DoBoost(DMatrix* p_fmat,
std::vector<bst_gpair>* in_gpair,
ObjFunction* obj) override {
const std::vector<bst_gpair>& gpair = *in_gpair;
std::vector<std::vector<std::unique_ptr<RegTree> > > new_trees;
const int ngroup = model_.param.num_output_group;
monitor.Start("BoostNewTrees");
if (ngroup == 1) {
std::vector<std::unique_ptr<RegTree> > ret;
BoostNewTrees(gpair, p_fmat, 0, &ret);
new_trees.push_back(std::move(ret));
} else {
CHECK_EQ(gpair.size() % ngroup, 0U)
<< "must have exactly ngroup*nrow gpairs";
std::vector<bst_gpair> tmp(gpair.size() / ngroup);
for (int gid = 0; gid < ngroup; ++gid) {
bst_omp_uint nsize = static_cast<bst_omp_uint>(tmp.size());
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < nsize; ++i) {
tmp[i] = gpair[i * ngroup + gid];
}
std::vector<std::unique_ptr<RegTree> > ret;
BoostNewTrees(tmp, p_fmat, gid, &ret);
new_trees.push_back(std::move(ret));
}
}
monitor.Stop("BoostNewTrees");
monitor.Start("CommitModel");
for (int gid = 0; gid < ngroup; ++gid) {
this->CommitModel(std::move(new_trees[gid]), gid);
}
monitor.Stop("CommitModel");
DoBoostHelper(p_fmat, in_gpair, obj);
}
void DoBoost(DMatrix* p_fmat,
HostDeviceVector<bst_gpair>* in_gpair,
ObjFunction* obj) override {
DoBoostHelper(p_fmat, in_gpair, obj);
}
void PredictBatch(DMatrix* p_fmat,
@@ -219,6 +198,12 @@ class GBTree : public GradientBooster {
predictor->PredictBatch(p_fmat, out_preds, model_, 0, ntree_limit);
}
void PredictBatch(DMatrix* p_fmat,
HostDeviceVector<bst_float>* out_preds,
unsigned ntree_limit) override {
predictor->PredictBatch(p_fmat, out_preds, model_, 0, ntree_limit);
}
void PredictInstance(const SparseBatch::Inst& inst,
std::vector<bst_float>* out_preds,
unsigned ntree_limit,
@@ -257,9 +242,48 @@ class GBTree : public GradientBooster {
updaters.push_back(std::move(up));
}
}
// TVec is either std::vector<bst_gpair> or HostDeviceVector<bst_gpair>
template <typename TVec>
void DoBoostHelper(DMatrix* p_fmat,
TVec* in_gpair,
ObjFunction* obj) {
std::vector<std::vector<std::unique_ptr<RegTree> > > new_trees;
const int ngroup = model_.param.num_output_group;
monitor.Start("BoostNewTrees");
if (ngroup == 1) {
std::vector<std::unique_ptr<RegTree> > ret;
BoostNewTrees(in_gpair, p_fmat, 0, &ret);
new_trees.push_back(std::move(ret));
} else {
CHECK_EQ(in_gpair->size() % ngroup, 0U)
<< "must have exactly ngroup*nrow gpairs";
std::vector<bst_gpair> tmp(in_gpair->size() / ngroup);
auto& gpair_h = HostDeviceVector<bst_gpair>::data_h(in_gpair);
for (int gid = 0; gid < ngroup; ++gid) {
bst_omp_uint nsize = static_cast<bst_omp_uint>(tmp.size());
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < nsize; ++i) {
tmp[i] = gpair_h[i * ngroup + gid];
}
std::vector<std::unique_ptr<RegTree> > ret;
BoostNewTrees(&tmp, p_fmat, gid, &ret);
new_trees.push_back(std::move(ret));
}
}
monitor.Stop("BoostNewTrees");
monitor.Start("CommitModel");
for (int gid = 0; gid < ngroup; ++gid) {
this->CommitModel(std::move(new_trees[gid]), gid);
}
monitor.Stop("CommitModel");
}
// boost trees for a specific output group
// TVec is either const std::vector<bst_gpair> or HostDeviceVector<bst_gpair>
template <typename TVec>
inline void
BoostNewTrees(const std::vector<bst_gpair> &gpair,
BoostNewTrees(TVec* gpair,
DMatrix *p_fmat,
int bst_group,
std::vector<std::unique_ptr<RegTree> >* ret) {
@@ -286,9 +310,24 @@ class GBTree : public GradientBooster {
}
// update the trees
for (auto& up : updaters) {
up->Update(gpair, p_fmat, new_trees);
UpdateHelper(up.get(), gpair, p_fmat, new_trees);
}
}
void UpdateHelper(TreeUpdater* updater,
std::vector<bst_gpair>* gpair,
DMatrix *p_fmat,
const std::vector<RegTree*>& new_trees) {
updater->Update(*gpair, p_fmat, new_trees);
}
void UpdateHelper(TreeUpdater* updater,
HostDeviceVector<bst_gpair>* gpair,
DMatrix *p_fmat,
const std::vector<RegTree*>& new_trees) {
updater->Update(gpair, p_fmat, new_trees);
}
// commit new trees all at once
virtual void
CommitModel(std::vector<std::unique_ptr<RegTree> >&& new_trees,

View File

@@ -16,10 +16,12 @@
#include <utility>
#include <vector>
#include "./common/common.h"
#include "./common/host_device_vector.h"
#include "./common/io.h"
#include "./common/random.h"
#include "common/timer.h"
namespace xgboost {
// implementation of base learner.
bool Learner::AllowLazyCheckPoint() const {
@@ -360,10 +362,10 @@ class LearnerImpl : public Learner {
}
this->LazyInitDMatrix(train);
monitor.Start("PredictRaw");
this->PredictRaw(train, &preds_);
this->PredictRaw(train, &preds2_);
monitor.Stop("PredictRaw");
monitor.Start("GetGradient");
obj_->GetGradient(preds_, train->info(), iter, &gpair_);
obj_->GetGradient(&preds2_, train->info(), iter, &gpair_);
monitor.Stop("GetGradient");
gbm_->DoBoost(train, &gpair_, obj_.get());
monitor.Stop("UpdateOneIter");
@@ -547,6 +549,13 @@
<< "Predict must happen after Load or InitModel";
gbm_->PredictBatch(data, out_preds, ntree_limit);
}
inline void PredictRaw(DMatrix* data, HostDeviceVector<bst_float>* out_preds,
unsigned ntree_limit = 0) const {
CHECK(gbm_.get() != nullptr)
<< "Predict must happen after Load or InitModel";
gbm_->PredictBatch(data, out_preds, ntree_limit);
}
// model parameter
LearnerModelParam mparam;
// training parameter
@@ -561,8 +570,9 @@
std::string name_obj_;
// temporal storages for prediction
std::vector<bst_float> preds_;
HostDeviceVector<bst_float> preds2_;
// gradient pairs
std::vector<bst_gpair> gpair_;
HostDeviceVector<bst_gpair> gpair_;
private:
/*! \brief random number transformation seed. */

View File

@@ -6,6 +6,8 @@
#include <xgboost/objective.h>
#include <dmlc/registry.h>
#include "../common/host_device_vector.h"
namespace dmlc {
DMLC_REGISTRY_ENABLE(::xgboost::ObjFunctionReg);
} // namespace dmlc
@@ -22,12 +24,27 @@ ObjFunction* ObjFunction::Create(const std::string& name) {
}
return (e->body)();
}
void ObjFunction::GetGradient(HostDeviceVector<bst_float>* preds,
const MetaInfo& info,
int iteration,
HostDeviceVector<bst_gpair>* out_gpair) {
GetGradient(preds->data_h(), info, iteration, &out_gpair->data_h());
}
void ObjFunction::PredTransform(HostDeviceVector<bst_float> *io_preds) {
PredTransform(&io_preds->data_h());
}
} // namespace xgboost
namespace xgboost {
namespace obj {
// List of files that will be force-linked when linking statically.
DMLC_REGISTRY_LINK_TAG(regression_obj);
#ifdef XGBOOST_USE_CUDA
DMLC_REGISTRY_LINK_TAG(regression_obj_gpu);
#endif
DMLC_REGISTRY_LINK_TAG(multiclass_obj);
DMLC_REGISTRY_LINK_TAG(rank_obj);
} // namespace obj

View File

@@ -0,0 +1,110 @@
/*!
* Copyright 2017 XGBoost contributors
*/
#ifndef XGBOOST_OBJECTIVE_REGRESSION_LOSS_H_
#define XGBOOST_OBJECTIVE_REGRESSION_LOSS_H_
#include <dmlc/omp.h>
#include <xgboost/logging.h>
#include <algorithm>
#include "../common/math.h"
namespace xgboost {
namespace obj {
// common regressions
// linear regression
struct LinearSquareLoss {
// duplication is necessary, as __device__ specifier
// cannot be made conditional on template parameter
XGBOOST_DEVICE static bst_float PredTransform(bst_float x) { return x; }
XGBOOST_DEVICE static bool CheckLabel(bst_float x) { return true; }
XGBOOST_DEVICE static bst_float FirstOrderGradient(bst_float predt, bst_float label) {
return predt - label;
}
XGBOOST_DEVICE static bst_float SecondOrderGradient(bst_float predt, bst_float label) {
return 1.0f;
}
template <typename T>
static T PredTransform(T x) { return x; }
template <typename T>
static T FirstOrderGradient(T predt, T label) { return predt - label; }
template <typename T>
static T SecondOrderGradient(T predt, T label) { return T(1.0f); }
static bst_float ProbToMargin(bst_float base_score) { return base_score; }
static const char* LabelErrorMsg() { return ""; }
static const char* DefaultEvalMetric() { return "rmse"; }
};
// logistic loss for probability regression task
struct LogisticRegression {
// duplication is necessary, as __device__ specifier
// cannot be made conditional on template parameter
XGBOOST_DEVICE static bst_float PredTransform(bst_float x) { return common::Sigmoid(x); }
XGBOOST_DEVICE static bool CheckLabel(bst_float x) { return x >= 0.0f && x <= 1.0f; }
XGBOOST_DEVICE static bst_float FirstOrderGradient(bst_float predt, bst_float label) {
return predt - label;
}
XGBOOST_DEVICE static bst_float SecondOrderGradient(bst_float predt, bst_float label) {
const float eps = 1e-16f;
return fmaxf(predt * (1.0f - predt), eps);
}
template <typename T>
static T PredTransform(T x) { return common::Sigmoid(x); }
template <typename T>
static T FirstOrderGradient(T predt, T label) { return predt - label; }
template <typename T>
static T SecondOrderGradient(T predt, T label) {
const T eps = T(1e-16f);
return std::max(predt * (T(1.0f) - predt), eps);
}
static bst_float ProbToMargin(bst_float base_score) {
CHECK(base_score > 0.0f && base_score < 1.0f)
<< "base_score must be in (0,1) for logistic loss";
return -logf(1.0f / base_score - 1.0f);
}
static const char* LabelErrorMsg() {
return "label must be in [0,1] for logistic regression";
}
static const char* DefaultEvalMetric() { return "rmse"; }
};
// logistic loss for binary classification task
struct LogisticClassification : public LogisticRegression {
static const char* DefaultEvalMetric() { return "error"; }
};
// logistic loss, but predict un-transformed margin
struct LogisticRaw : public LogisticRegression {
// duplication is necessary, as __device__ specifier
// cannot be made conditional on template parameter
XGBOOST_DEVICE static bst_float PredTransform(bst_float x) { return x; }
XGBOOST_DEVICE static bst_float FirstOrderGradient(bst_float predt, bst_float label) {
predt = common::Sigmoid(predt);
return predt - label;
}
XGBOOST_DEVICE static bst_float SecondOrderGradient(bst_float predt, bst_float label) {
const float eps = 1e-16f;
predt = common::Sigmoid(predt);
return fmaxf(predt * (1.0f - predt), eps);
}
template <typename T>
static T PredTransform(T x) { return x; }
template <typename T>
static T FirstOrderGradient(T predt, T label) {
predt = common::Sigmoid(predt);
return predt - label;
}
template <typename T>
static T SecondOrderGradient(T predt, T label) {
const T eps = T(1e-16f);
predt = common::Sigmoid(predt);
return std::max(predt * (T(1.0f) - predt), eps);
}
static const char* DefaultEvalMetric() { return "auc"; }
};
} // namespace obj
} // namespace xgboost
#endif // XGBOOST_OBJECTIVE_REGRESSION_LOSS_H_
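
The duplicated scalar methods above exist because the __device__ qualifier cannot be made conditional on a template parameter; the XGBOOST_DEVICE macro carries the build-mode split instead. For reference, it resolves along these lines (a sketch of the macro from include/xgboost/base.h, not part of this diff):

// Sketch: XGBOOST_DEVICE marks a function __host__ __device__ under nvcc,
// and expands to nothing in a CPU-only build.
#ifdef __CUDACC__
#define XGBOOST_DEVICE __host__ __device__
#else
#define XGBOOST_DEVICE
#endif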

View File

@@ -12,70 +12,13 @@
#include <utility>
#include "../common/math.h"
#include "../common/avx_helpers.h"
#include "./regression_loss.h"
namespace xgboost {
namespace obj {
DMLC_REGISTRY_FILE_TAG(regression_obj);
// common regressions
// linear regression
struct LinearSquareLoss {
template <typename T>
static T PredTransform(T x) { return x; }
static bool CheckLabel(bst_float x) { return true; }
template <typename T>
static T FirstOrderGradient(T predt, T label) { return predt - label; }
template <typename T>
static T SecondOrderGradient(T predt, T label) { return T(1.0f); }
static bst_float ProbToMargin(bst_float base_score) { return base_score; }
static const char* LabelErrorMsg() { return ""; }
static const char* DefaultEvalMetric() { return "rmse"; }
};
// logistic loss for probability regression task
struct LogisticRegression {
template <typename T>
static T PredTransform(T x) { return common::Sigmoid(x); }
static bool CheckLabel(bst_float x) { return x >= 0.0f && x <= 1.0f; }
template <typename T>
static T FirstOrderGradient(T predt, T label) { return predt - label; }
template <typename T>
static T SecondOrderGradient(T predt, T label) {
const T eps = T(1e-16f);
return std::max(predt * (T(1.0f) - predt), eps);
}
static bst_float ProbToMargin(bst_float base_score) {
CHECK(base_score > 0.0f && base_score < 1.0f)
<< "base_score must be in (0,1) for logistic loss";
return -std::log(1.0f / base_score - 1.0f);
}
static const char* LabelErrorMsg() {
return "label must be in [0,1] for logistic regression";
}
static const char* DefaultEvalMetric() { return "rmse"; }
};
// logistic loss for binary classification task.
struct LogisticClassification : public LogisticRegression {
static const char* DefaultEvalMetric() { return "error"; }
};
// logistic loss, but predict un-transformed margin
struct LogisticRaw : public LogisticRegression {
template <typename T>
static T PredTransform(T x) { return x; }
template <typename T>
static T FirstOrderGradient(T predt, T label) {
predt = common::Sigmoid(predt);
return predt - label;
}
template <typename T>
static T SecondOrderGradient(T predt, T label) {
const T eps = T(1e-16f);
predt = common::Sigmoid(predt);
return std::max(predt * (T(1.0f) - predt), eps);
}
static const char* DefaultEvalMetric() { return "auc"; }
};
struct RegLossParam : public dmlc::Parameter<RegLossParam> {
float scale_pos_weight;
// declare parameters

View File

@@ -0,0 +1,241 @@
/*!
* Copyright 2017 XGBoost contributors
*/
// GPU implementation of objective function.
// Necessary to avoid extra copying of data to CPU.
#include <dmlc/omp.h>
#include <thrust/device_vector.h>
#include <thrust/copy.h>
#include <xgboost/logging.h>
#include <xgboost/objective.h>
#include <cmath>
#include <memory>
#include <vector>
#include "../common/device_helpers.cuh"
#include "../common/host_device_vector.h"
#include "./regression_loss.h"
using namespace dh;
namespace xgboost {
namespace obj {
DMLC_REGISTRY_FILE_TAG(regression_obj_gpu);
struct GPURegLossParam : public dmlc::Parameter<GPURegLossParam> {
float scale_pos_weight;
int n_gpus;
int gpu_id;
// declare parameters
DMLC_DECLARE_PARAMETER(GPURegLossParam) {
DMLC_DECLARE_FIELD(scale_pos_weight).set_default(1.0f).set_lower_bound(0.0f)
.describe("Scale the weight of positive examples by this factor");
DMLC_DECLARE_FIELD(n_gpus).set_default(1).set_lower_bound(-1)
.describe("Number of GPUs to use for multi-gpu algorithms (NOT IMPLEMENTED)");
DMLC_DECLARE_FIELD(gpu_id)
.set_lower_bound(0)
.set_default(0)
.describe("gpu to use for objective function evaluation");
}
};
// GPU kernel for gradient computation
template<typename Loss>
__global__ void get_gradient_k
(bst_gpair *__restrict__ out_gpair, uint *__restrict__ label_correct,
const float * __restrict__ preds, const float * __restrict__ labels,
const float * __restrict__ weights, int n, float scale_pos_weight) {
int i = threadIdx.x + blockIdx.x * blockDim.x;
if (i >= n)
return;
float p = Loss::PredTransform(preds[i]);
float w = weights == nullptr ? 1.0f : weights[i];
float label = labels[i];
if (label == 1.0f)
w *= scale_pos_weight;
if (!Loss::CheckLabel(label))
atomicAnd(label_correct, 0);
out_gpair[i] = bst_gpair
(Loss::FirstOrderGradient(p, label) * w, Loss::SecondOrderGradient(p, label) * w);
}
// GPU kernel for prediction transformation
template<typename Loss>
__global__ void pred_transform_k(float * __restrict__ preds, int n) {
int i = threadIdx.x + blockIdx.x * blockDim.x;
if (i >= n)
return;
preds[i] = Loss::PredTransform(preds[i]);
}
// regression loss function for evaluation on GPU (eventually)
template<typename Loss>
class GPURegLossObj : public ObjFunction {
protected:
// manages device data
struct DeviceData {
dvec<float> labels, weights;
dvec<uint> label_correct;
// allocate everything on device
DeviceData(bulk_allocator<memory_type::DEVICE>* ba, int device_idx, size_t n) {
ba->allocate(device_idx, false,
&labels, n,
&weights, n,
&label_correct, 1);
}
size_t size() const { return labels.size(); }
};
bool copied_;
std::unique_ptr<bulk_allocator<memory_type::DEVICE>> ba_;
std::unique_ptr<DeviceData> data_;
HostDeviceVector<bst_float> preds_d_;
HostDeviceVector<bst_gpair> out_gpair_d_;
// allocate device data for n elements, do nothing if enough memory is allocated already
void LazyResize(int n) {
if (data_.get() != nullptr && data_->size() >= n)
return;
copied_ = false;
// free the old data and allocate the new data
ba_.reset(new bulk_allocator<memory_type::DEVICE>());
data_.reset(new DeviceData(ba_.get(), 0, n));
preds_d_.resize(n, param_.gpu_id);
out_gpair_d_.resize(n, param_.gpu_id);
}
public:
GPURegLossObj() : copied_(false), preds_d_(0, -1), out_gpair_d_(0, -1) {}
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
param_.InitAllowUnknown(args);
CHECK(param_.n_gpus != 0) << "Must have at least one device";
}
void GetGradient(const std::vector<float> &preds,
const MetaInfo &info,
int iter,
std::vector<bst_gpair> *out_gpair) override {
CHECK_NE(info.labels.size(), 0U) << "label set cannot be empty";
CHECK_EQ(preds.size(), info.labels.size())
<< "labels are not correctly provided"
<< "preds.size=" << preds.size() << ", label.size=" << info.labels.size();
size_t ndata = preds.size();
out_gpair->resize(ndata);
LazyResize(ndata);
thrust::copy(preds.begin(), preds.end(), preds_d_.tbegin(param_.gpu_id));
GetGradientDevice(preds_d_.ptr_d(param_.gpu_id), info, iter,
out_gpair_d_.ptr_d(param_.gpu_id), ndata);
thrust::copy_n(out_gpair_d_.tbegin(param_.gpu_id), ndata, out_gpair->begin());
}
void GetGradient(HostDeviceVector<float>* preds,
const MetaInfo &info,
int iter,
HostDeviceVector<bst_gpair>* out_gpair) override {
CHECK_NE(info.labels.size(), 0U) << "label set cannot be empty";
CHECK_EQ(preds->size(), info.labels.size())
<< "labels are not correctly provided"
<< "preds.size=" << preds->size() << ", label.size=" << info.labels.size();
size_t ndata = preds->size();
out_gpair->resize(ndata, param_.gpu_id);
LazyResize(ndata);
GetGradientDevice(preds->ptr_d(param_.gpu_id), info, iter,
out_gpair->ptr_d(param_.gpu_id), ndata);
}
private:
void GetGradientDevice(float* preds,
const MetaInfo &info,
int iter,
bst_gpair* out_gpair, size_t n) {
safe_cuda(cudaSetDevice(param_.gpu_id));
DeviceData& d = *data_;
d.label_correct.fill(1);
// only copy the labels and weights once, similar to how the data is copied
if (!copied_) {
thrust::copy(info.labels.begin(), info.labels.begin() + n,
d.labels.tbegin());
if (info.weights.size() > 0) {
thrust::copy(info.weights.begin(), info.weights.begin() + n,
d.weights.tbegin());
}
copied_ = true;
}
// run the kernel
const int block = 256;
get_gradient_k<Loss><<<div_round_up(n, block), block>>>
(out_gpair, d.label_correct.data(), preds,
d.labels.data(), info.weights.size() > 0 ? d.weights.data() : nullptr,
n, param_.scale_pos_weight);
safe_cuda(cudaGetLastError());
// copy output data from the GPU
uint label_correct_h;
thrust::copy_n(d.label_correct.tbegin(), 1, &label_correct_h);
bool label_correct = label_correct_h != 0;
if (!label_correct) {
LOG(FATAL) << Loss::LabelErrorMsg();
}
}
public:
const char* DefaultEvalMetric() const override {
return Loss::DefaultEvalMetric();
}
void PredTransform(std::vector<float> *io_preds) override {
LazyResize(io_preds->size());
thrust::copy(io_preds->begin(), io_preds->end(), preds_d_.tbegin(param_.gpu_id));
PredTransformDevice(preds_d_.ptr_d(param_.gpu_id), io_preds->size());
thrust::copy_n(preds_d_.tbegin(param_.gpu_id), io_preds->size(), io_preds->begin());
}
void PredTransform(HostDeviceVector<float> *io_preds) override {
PredTransformDevice(io_preds->ptr_d(param_.gpu_id), io_preds->size());
}
void PredTransformDevice(float* preds, size_t n) {
safe_cuda(cudaSetDevice(param_.gpu_id));
const int block = 256;
pred_transform_k<Loss><<<div_round_up(n, block), block>>>(preds, n);
safe_cuda(cudaGetLastError());
safe_cuda(cudaDeviceSynchronize());
}
float ProbToMargin(float base_score) const override {
return Loss::ProbToMargin(base_score);
}
protected:
GPURegLossParam param_;
};
// register the objective functions
DMLC_REGISTER_PARAMETER(GPURegLossParam);
XGBOOST_REGISTER_OBJECTIVE(GPULinearRegression, "gpu:reg:linear")
.describe("Linear regression (computed on GPU).")
.set_body([]() { return new GPURegLossObj<LinearSquareLoss>(); });
XGBOOST_REGISTER_OBJECTIVE(GPULogisticRegression, "gpu:reg:logistic")
.describe("Logistic regression for probability regression task (computed on GPU).")
.set_body([]() { return new GPURegLossObj<LogisticRegression>(); });
XGBOOST_REGISTER_OBJECTIVE(GPULogisticClassification, "gpu:binary:logistic")
.describe("Logistic regression for binary classification task (computed on GPU).")
.set_body([]() { return new GPURegLossObj<LogisticClassification>(); });
XGBOOST_REGISTER_OBJECTIVE(GPULogisticRaw, "gpu:binary:logitraw")
.describe("Logistic regression for classification, output score "
"before logistic transformation (computed on GPU)")
.set_body([]() { return new GPURegLossObj<LogisticRaw>(); });
} // namespace obj
} // namespace xgboost

View File

@@ -5,6 +5,7 @@
#include <xgboost/tree_model.h>
#include <xgboost/tree_updater.h>
#include "dmlc/logging.h"
#include "../common/host_device_vector.h"
namespace xgboost {
namespace predictor {
@@ -108,6 +109,12 @@ class CPUPredictor : public Predictor {
}
public:
void PredictBatch(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds,
const gbm::GBTreeModel& model, int tree_begin,
unsigned ntree_limit = 0) override {
PredictBatch(dmat, &out_preds->data_h(), model, tree_begin, ntree_limit);
}
void PredictBatch(DMatrix* dmat, std::vector<bst_float>* out_preds,
const gbm::GBTreeModel& model, int tree_begin,
unsigned ntree_limit = 0) override {

View File

@@ -3,13 +3,16 @@
*/
#include <dmlc/parameter.h>
#include <thrust/copy.h>
#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/fill.h>
#include <xgboost/data.h>
#include <xgboost/predictor.h>
#include <xgboost/tree_model.h>
#include <xgboost/tree_updater.h>
#include <memory>
#include "../common/device_helpers.cuh"
#include "../common/host_device_vector.h"
namespace xgboost {
namespace predictor {
@@ -247,8 +250,16 @@ __global__ void PredictKernel(const DevicePredictionNode* d_nodes,
}
class GPUPredictor : public xgboost::Predictor {
protected:
struct DevicePredictionCacheEntry {
std::shared_ptr<DMatrix> data;
HostDeviceVector<bst_float> predictions;
};
std::unordered_map<DMatrix*, DevicePredictionCacheEntry> device_cache_;
private:
void DevicePredictInternal(DMatrix* dmat, std::vector<bst_float>* out_preds,
void DevicePredictInternal(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds,
const gbm::GBTreeModel& model, size_t tree_begin,
size_t tree_end) {
if (tree_end - tree_begin == 0) {
@@ -293,7 +304,7 @@ class GPUPredictor : public xgboost::Predictor {
tree_group.begin());
device_matrix->predictions.resize(out_preds->size());
thrust::copy(out_preds->begin(), out_preds->end(),
thrust::copy(out_preds->tbegin(param.gpu_id), out_preds->tend(param.gpu_id),
device_matrix->predictions.begin());
const int BLOCK_THREADS = 128;
@@ -319,19 +330,30 @@ class GPUPredictor : public xgboost::Predictor {
dh::safe_cuda(cudaDeviceSynchronize());
thrust::copy(device_matrix->predictions.begin(),
device_matrix->predictions.end(), out_preds->begin());
device_matrix->predictions.end(), out_preds->tbegin(param.gpu_id));
}
public:
GPUPredictor() : cpu_predictor(Predictor::Create("cpu_predictor")) {}
void PredictBatch(DMatrix* dmat, std::vector<bst_float>* out_preds,
const gbm::GBTreeModel& model, int tree_begin,
unsigned ntree_limit = 0) override {
if (this->PredictFromCache(dmat, out_preds, model, ntree_limit)) {
HostDeviceVector<bst_float> out_preds_d;
PredictBatch(dmat, &out_preds_d, model, tree_begin, ntree_limit);
out_preds->resize(out_preds_d.size());
thrust::copy(out_preds_d.tbegin(param.gpu_id),
out_preds_d.tend(param.gpu_id), out_preds->begin());
}
void PredictBatch(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds,
const gbm::GBTreeModel& model, int tree_begin,
unsigned ntree_limit = 0) override {
if (this->PredictFromCacheDevice(dmat, out_preds, model, ntree_limit)) {
return;
}
this->InitOutPredictions(dmat->info(), out_preds, model);
this->InitOutPredictionsDevice(dmat->info(), out_preds, model);
int tree_end = ntree_limit * model.param.num_output_group;
if (ntree_limit == 0 || ntree_limit > model.trees.size()) {
@@ -341,26 +363,78 @@ class GPUPredictor : public xgboost::Predictor {
DevicePredictInternal(dmat, out_preds, model, tree_begin, tree_end);
}
void UpdatePredictionCache(
const gbm::GBTreeModel& model,
std::vector<std::unique_ptr<TreeUpdater>>* updaters,
int num_new_trees) override {
void InitOutPredictionsDevice(const MetaInfo& info,
HostDeviceVector<bst_float>* out_preds,
const gbm::GBTreeModel& model) const {
size_t n = model.param.num_output_group * info.num_row;
const std::vector<bst_float>& base_margin = info.base_margin;
out_preds->resize(n, param.gpu_id);
if (base_margin.size() != 0) {
CHECK_EQ(out_preds->size(), n);
thrust::copy(base_margin.begin(), base_margin.end(), out_preds->tbegin(param.gpu_id));
} else {
thrust::fill(out_preds->tbegin(param.gpu_id),
out_preds->tend(param.gpu_id), model.base_margin);
}
}
bool PredictFromCache(DMatrix* dmat,
std::vector<bst_float>* out_preds,
const gbm::GBTreeModel& model,
unsigned ntree_limit) {
HostDeviceVector<bst_float> out_preds_d(0, -1);
bool result = PredictFromCacheDevice(dmat, &out_preds_d, model, ntree_limit);
if (!result) return false;
out_preds->resize(out_preds_d.size());
thrust::copy(out_preds_d.tbegin(param.gpu_id),
out_preds_d.tend(param.gpu_id), out_preds->begin());
return true;
}
bool PredictFromCacheDevice(DMatrix* dmat,
HostDeviceVector<bst_float>* out_preds,
const gbm::GBTreeModel& model,
unsigned ntree_limit) {
if (ntree_limit == 0 ||
ntree_limit * model.param.num_output_group >= model.trees.size()) {
auto it = device_cache_.find(dmat);
if (it != device_cache_.end()) {
HostDeviceVector<bst_float>& y = it->second.predictions;
if (y.size() != 0) {
out_preds->resize(y.size(), param.gpu_id);
thrust::copy(y.tbegin(param.gpu_id), y.tend(param.gpu_id),
out_preds->tbegin(param.gpu_id));
return true;
}
}
}
return false;
}
void UpdatePredictionCache(const gbm::GBTreeModel& model,
std::vector<std::unique_ptr<TreeUpdater>>* updaters,
int num_new_trees) override {
auto old_ntree = model.trees.size() - num_new_trees;
// update cache entry
for (auto& kv : cache_) {
PredictionCacheEntry& e = kv.second;
for (auto& kv : device_cache_) {
DevicePredictionCacheEntry& e = kv.second;
DMatrix* dmat = kv.first;
HostDeviceVector<bst_float>& predictions = e.predictions;
if (e.predictions.size() == 0) {
cpu_predictor->PredictBatch(dmat, &(e.predictions), model, 0,
if (predictions.size() == 0) {
// ensure that the device in predictions is correct
predictions.resize(0, param.gpu_id);
cpu_predictor->PredictBatch(dmat, &predictions.data_h(), model, 0,
static_cast<bst_uint>(model.trees.size()));
} else if (model.param.num_output_group == 1 && updaters->size() > 0 &&
num_new_trees == 1 &&
updaters->back()->UpdatePredictionCache(e.data.get(),
&(e.predictions))) {
{} // do nothing
&predictions)) {
// do nothing
} else {
DevicePredictInternal(dmat, &(e.predictions), model, old_ntree,
DevicePredictInternal(dmat, &predictions, model, old_ntree,
model.trees.size());
}
}
@@ -391,6 +465,8 @@ class GPUPredictor : public xgboost::Predictor {
Predictor::Init(cfg, cache);
cpu_predictor->Init(cfg, cache);
param.InitAllowUnknown(cfg);
for (const std::shared_ptr<DMatrix>& d : cache)
device_cache_[d.get()].data = d;
max_shared_memory_bytes = dh::max_shared_memory(param.gpu_id);
}

View File

@@ -6,6 +6,8 @@
#include <xgboost/tree_updater.h>
#include <dmlc/registry.h>
#include "../common/host_device_vector.h"
namespace dmlc {
DMLC_REGISTRY_ENABLE(::xgboost::TreeUpdaterReg);
} // namespace dmlc
@@ -20,6 +22,17 @@ TreeUpdater* TreeUpdater::Create(const std::string& name) {
return (e->body)();
}
void TreeUpdater::Update(HostDeviceVector<bst_gpair>* gpair,
DMatrix* data,
const std::vector<RegTree*>& trees) {
Update(gpair->data_h(), data, trees);
}
bool TreeUpdater::UpdatePredictionCache(const DMatrix* data,
HostDeviceVector<bst_float>* out_preds) {
return UpdatePredictionCache(data, &out_preds->data_h());
}
} // namespace xgboost
namespace xgboost {

View File

@@ -12,6 +12,7 @@
#include <vector>
#include "../common/compressed_iterator.h"
#include "../common/device_helpers.cuh"
#include "../common/host_device_vector.h"
#include "../common/hist_util.h"
#include "../common/timer.h"
#include "param.h"
@@ -349,7 +350,8 @@ struct DeviceShard {
}
// Reset values for each update iteration
void Reset(const std::vector<bst_gpair>& host_gpair) {
void Reset(HostDeviceVector<bst_gpair>* dh_gpair, int device) {
auto begin = dh_gpair->tbegin(device);
dh::safe_cuda(cudaSetDevice(device_idx));
position.current_dvec().fill(0);
std::fill(node_sum_gradients.begin(), node_sum_gradients.end(),
@@ -359,8 +361,8 @@
std::fill(ridx_segments.begin(), ridx_segments.end(), Segment(0, 0));
ridx_segments.front() = Segment(0, ridx.size());
this->gpair.copy(host_gpair.begin() + row_begin_idx,
host_gpair.begin() + row_end_idx);
this->gpair.copy(begin + row_begin_idx,
begin + row_end_idx);
subsample_gpair(&gpair, param.subsample, row_begin_idx);
hist.Reset();
}
@@ -504,9 +506,28 @@ class GPUHistMaker : public TreeUpdater {
monitor.Init("updater_gpu_hist", param.debug_verbose);
}
void Update(const std::vector<bst_gpair>& gpair, DMatrix* dmat,
const std::vector<RegTree*>& trees) override {
monitor.Start("Update", dList);
// TODO(canonizer): move it into the class if this ever becomes a bottleneck
HostDeviceVector<bst_gpair> gpair_d(gpair.size(), param.gpu_id);
dh::safe_cuda(cudaSetDevice(param.gpu_id));
thrust::copy(gpair.begin(), gpair.end(), gpair_d.tbegin(param.gpu_id));
Update(&gpair_d, dmat, trees);
monitor.Stop("Update", dList);
}
void Update(HostDeviceVector<bst_gpair>* gpair, DMatrix* dmat,
const std::vector<RegTree*>& trees) override {
monitor.Start("Update", dList);
UpdateHelper(gpair, dmat, trees);
monitor.Stop("Update", dList);
}
private:
void UpdateHelper(HostDeviceVector<bst_gpair>* gpair, DMatrix* dmat,
const std::vector<RegTree*>& trees) {
GradStats::CheckInfo(dmat->info());
// rescale learning rate according to size of trees
float lr = param.learning_rate;
@@ -521,9 +542,9 @@ class GPUHistMaker : public TreeUpdater {
LOG(FATAL) << "GPU plugin exception: " << e.what() << std::endl;
}
param.learning_rate = lr;
monitor.Stop("Update", dList);
}
public:
void InitDataOnce(DMatrix* dmat) {
info = &dmat->info();
monitor.Start("Quantiles", dList);
@@ -572,7 +593,7 @@
initialised = true;
}
void InitData(const std::vector<bst_gpair>& gpair, DMatrix* dmat,
void InitData(HostDeviceVector<bst_gpair>* gpair, DMatrix* dmat,
const RegTree& tree) {
monitor.Start("InitDataOnce", dList);
if (!initialised) {
@@ -585,11 +606,10 @@
// Copy gpair & reset memory
monitor.Start("InitDataReset", dList);
omp_set_num_threads(shards.size());
#pragma omp parallel
{
auto cpu_thread_id = omp_get_thread_num();
shards[cpu_thread_id]->Reset(gpair);
}
// TODO(canonizer): make it parallel again once HostDeviceVector is thread-safe
for (int shard = 0; shard < shards.size(); ++shard)
shards[shard]->Reset(gpair, param.gpu_id);
monitor.Stop("InitDataReset", dList);
}
@@ -687,7 +707,7 @@ class GPUHistMaker : public TreeUpdater {
return std::move(best_splits);
}
void InitRoot(const std::vector<bst_gpair>& gpair, RegTree* p_tree) {
void InitRoot(RegTree* p_tree) {
auto root_nidx = 0;
// Sum gradients
std::vector<bst_gpair> tmp_sums(shards.size());
@@ -800,7 +820,7 @@
this->UpdatePosition(candidate, p_tree);
}
void UpdateTree(const std::vector<bst_gpair>& gpair, DMatrix* p_fmat,
void UpdateTree(HostDeviceVector<bst_gpair>* gpair, DMatrix* p_fmat,
RegTree* p_tree) {
// Temporarily store number of threads so we can change it back later
int nthread = omp_get_max_threads();
@@ -811,7 +831,7 @@
this->InitData(gpair, p_fmat, *p_tree);
monitor.Stop("InitData", dList);
monitor.Start("InitRoot", dList);
this->InitRoot(gpair, p_tree);
this->InitRoot(p_tree);
monitor.Stop("InitRoot", dList);
auto timestamp = qexpand_->size();
@@ -854,6 +874,16 @@
omp_set_num_threads(nthread);
}
bool UpdatePredictionCache(const DMatrix* data,
std::vector<bst_float>* p_out_preds) override {
return false;
}
bool UpdatePredictionCache(const DMatrix* data,
HostDeviceVector<bst_float>* p_out_preds) override {
return false;
}
struct ExpandEntry {
int nid;
int depth;

View File

@@ -0,0 +1,69 @@
/*!
* Copyright 2017 XGBoost contributors
*/
#include <xgboost/objective.h>
#include "../helpers.h"
TEST(Objective, GPULinearRegressionGPair) {
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("gpu:reg:linear");
std::vector<std::pair<std::string, std::string> > args;
obj->Configure(args);
CheckObjFunction(obj,
{0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1},
{0, 0, 0, 0, 1, 1, 1, 1},
{1, 1, 1, 1, 1, 1, 1, 1},
{0, 0.1f, 0.9f, 1.0f, -1.0f, -0.9f, -0.1f, 0},
{1, 1, 1, 1, 1, 1, 1, 1});
ASSERT_NO_THROW(obj->DefaultEvalMetric());
}
TEST(Objective, GPULogisticRegressionGPair) {
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("gpu:reg:logistic");
std::vector<std::pair<std::string, std::string> > args;
obj->Configure(args);
CheckObjFunction(obj,
{ 0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1},
{ 0, 0, 0, 0, 1, 1, 1, 1},
{ 1, 1, 1, 1, 1, 1, 1, 1},
{ 0.5f, 0.52f, 0.71f, 0.73f, -0.5f, -0.47f, -0.28f, -0.26f},
{0.25f, 0.24f, 0.20f, 0.19f, 0.25f, 0.24f, 0.20f, 0.19f});
}
TEST(Objective, GPULogisticRegressionBasic) {
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("gpu:reg:logistic");
std::vector<std::pair<std::string, std::string> > args;
obj->Configure(args);
// test label validation
EXPECT_ANY_THROW(CheckObjFunction(obj, {0}, {10}, {1}, {0}, {0}))
<< "Expected error when label not in range [0,1f] for LogisticRegression";
// test ProbToMargin
EXPECT_NEAR(obj->ProbToMargin(0.1f), -2.197f, 0.01f);
EXPECT_NEAR(obj->ProbToMargin(0.5f), 0, 0.01f);
EXPECT_NEAR(obj->ProbToMargin(0.9f), 2.197f, 0.01f);
EXPECT_ANY_THROW(obj->ProbToMargin(10))
<< "Expected error when base_score not in range [0,1f] for LogisticRegression";
// test PredTransform
std::vector<xgboost::bst_float> preds = {0, 0.1f, 0.5f, 0.9f, 1};
std::vector<xgboost::bst_float> out_preds = {0.5f, 0.524f, 0.622f, 0.710f, 0.731f};
obj->PredTransform(&preds);
for (int i = 0; i < static_cast<int>(preds.size()); ++i) {
EXPECT_NEAR(preds[i], out_preds[i], 0.01f);
}
}
TEST(Objective, GPULogisticRawGPair) {
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("gpu:binary:logitraw");
std::vector<std::pair<std::string, std::string> > args;
obj->Configure(args);
CheckObjFunction(obj,
{ 0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1},
{ 0, 0, 0, 0, 1, 1, 1, 1},
{ 1, 1, 1, 1, 1, 1, 1, 1},
{ 0.5f, 0.52f, 0.71f, 0.73f, -0.5f, -0.47f, -0.28f, -0.26f},
{0.25f, 0.24f, 0.20f, 0.19f, 0.25f, 0.24f, 0.20f, 0.19f});
}