Objective function evaluation on GPU with minimal PCIe transfers (#2935)
* Added GPU objective function and no-copy interface. - xgboost::HostDeviceVector<T> syncs automatically between host and device - no-copy interfaces have been added - default implementations just sync the data to host and call the implementations with std::vector - GPU objective function, predictor, histogram updater process data directly on GPU
This commit is contained in:
parent
a187ed6c8f
commit
84ab74f3a5
@ -57,6 +57,7 @@
|
||||
#include "../src/learner.cc"
|
||||
#include "../src/logging.cc"
|
||||
#include "../src/common/common.cc"
|
||||
#include "../src/common/host_device_vector.cc"
|
||||
#include "../src/common/hist_util.cc"
|
||||
|
||||
// c_api
|
||||
|
||||
@ -165,6 +165,9 @@ Specify the learning task and the corresponding learning objective. The objectiv
|
||||
- "reg:logistic" --logistic regression
|
||||
- "binary:logistic" --logistic regression for binary classification, output probability
|
||||
- "binary:logitraw" --logistic regression for binary classification, output score before logistic transformation
|
||||
- "gpu:reg:linear", "gpu:reg:logistic", "gpu:binary:logistic", gpu:binary:logitraw" --versions
|
||||
of the corresponding objective functions evaluated on the GPU; note that like the GPU histogram algorithm,
|
||||
they can only be used when the entire training session uses the same dataset
|
||||
- "count:poisson" --poisson regression for count data, output mean of poisson distribution
|
||||
- max_delta_step is set to 0.7 by default in poisson regression (used to safeguard optimization)
|
||||
- "multi:softmax" --set XGBoost to do multiclass classification using the softmax objective, you also need to set num_class(number of classes)
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
#include "./data.h"
|
||||
#include "./objective.h"
|
||||
#include "./feature_map.h"
|
||||
#include "../../src/common/host_device_vector.h"
|
||||
|
||||
namespace xgboost {
|
||||
/*!
|
||||
@ -70,6 +71,10 @@ class GradientBooster {
|
||||
virtual void DoBoost(DMatrix* p_fmat,
|
||||
std::vector<bst_gpair>* in_gpair,
|
||||
ObjFunction* obj = nullptr) = 0;
|
||||
virtual void DoBoost(DMatrix* p_fmat,
|
||||
HostDeviceVector<bst_gpair>* in_gpair,
|
||||
ObjFunction* obj = nullptr);
|
||||
|
||||
/*!
|
||||
* \brief generate predictions for given feature matrix
|
||||
* \param dmat feature matrix
|
||||
@ -80,6 +85,9 @@ class GradientBooster {
|
||||
virtual void PredictBatch(DMatrix* dmat,
|
||||
std::vector<bst_float>* out_preds,
|
||||
unsigned ntree_limit = 0) = 0;
|
||||
virtual void PredictBatch(DMatrix* dmat,
|
||||
HostDeviceVector<bst_float>* out_preds,
|
||||
unsigned ntree_limit = 0);
|
||||
/*!
|
||||
* \brief online prediction function, predict score for one instance at a time
|
||||
* NOTE: use the batch prediction interface if possible, batch prediction is usually
|
||||
|
||||
@ -14,8 +14,11 @@
|
||||
#include <functional>
|
||||
#include "./data.h"
|
||||
#include "./base.h"
|
||||
#include "../../src/common/host_device_vector.h"
|
||||
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
/*! \brief interface of objective function */
|
||||
class ObjFunction {
|
||||
public:
|
||||
@ -45,6 +48,11 @@ class ObjFunction {
|
||||
const MetaInfo& info,
|
||||
int iteration,
|
||||
std::vector<bst_gpair>* out_gpair) = 0;
|
||||
virtual void GetGradient(HostDeviceVector<bst_float>* preds,
|
||||
const MetaInfo& info,
|
||||
int iteration,
|
||||
HostDeviceVector<bst_gpair>* out_gpair);
|
||||
|
||||
/*! \return the default evaluation metric for the objective */
|
||||
virtual const char* DefaultEvalMetric() const = 0;
|
||||
// the following functions are optional, most of time default implementation is good enough
|
||||
@ -53,6 +61,8 @@ class ObjFunction {
|
||||
* \param io_preds prediction values, saves to this vector as well
|
||||
*/
|
||||
virtual void PredTransform(std::vector<bst_float> *io_preds) {}
|
||||
virtual void PredTransform(HostDeviceVector<bst_float> *io_preds);
|
||||
|
||||
/*!
|
||||
* \brief transform prediction values, this is only called when Eval is called,
|
||||
* usually it redirect to PredTransform
|
||||
@ -61,6 +71,9 @@ class ObjFunction {
|
||||
virtual void EvalTransform(std::vector<bst_float> *io_preds) {
|
||||
this->PredTransform(io_preds);
|
||||
}
|
||||
virtual void EvalTransform(HostDeviceVector<bst_float> *io_preds) {
|
||||
this->PredTransform(io_preds);
|
||||
}
|
||||
/*!
|
||||
* \brief transform probability value back to margin
|
||||
* this is used to transform user-set base_score back to margin
|
||||
|
||||
@ -13,6 +13,7 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "../../src/gbm/gbtree_model.h"
|
||||
#include "../../src/common/host_device_vector.h"
|
||||
|
||||
// Forward declarations
|
||||
namespace xgboost {
|
||||
@ -51,10 +52,6 @@ class Predictor {
|
||||
const std::vector<std::shared_ptr<DMatrix>>& cache);
|
||||
|
||||
/**
|
||||
* \fn virtual void Predictor::PredictBatch( DMatrix* dmat,
|
||||
* std::vector<bst_float>* out_preds, const gbm::GBTreeModel &model, int
|
||||
* tree_begin, unsigned ntree_limit = 0) = 0;
|
||||
*
|
||||
* \brief Generate batch predictions for a given feature matrix. May use
|
||||
* cached predictions if available instead of calculating from scratch.
|
||||
*
|
||||
@ -70,6 +67,22 @@ class Predictor {
|
||||
const gbm::GBTreeModel& model, int tree_begin,
|
||||
unsigned ntree_limit = 0) = 0;
|
||||
|
||||
/**
|
||||
* \brief Generate batch predictions for a given feature matrix. May use
|
||||
* cached predictions if available instead of calculating from scratch.
|
||||
*
|
||||
* \param [in,out] dmat Feature matrix.
|
||||
* \param [in,out] out_preds The output preds.
|
||||
* \param model The model to predict from.
|
||||
* \param tree_begin The tree begin index.
|
||||
* \param ntree_limit (Optional) The ntree limit. 0 means do not
|
||||
* limit trees.
|
||||
*/
|
||||
|
||||
virtual void PredictBatch(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds,
|
||||
const gbm::GBTreeModel& model, int tree_begin,
|
||||
unsigned ntree_limit = 0) = 0;
|
||||
|
||||
/**
|
||||
* \fn virtual void Predictor::UpdatePredictionCache( const gbm::GBTreeModel
|
||||
* &model, std::vector<std::unique_ptr<TreeUpdater> >* updaters, int
|
||||
|
||||
@ -16,6 +16,7 @@
|
||||
#include "./base.h"
|
||||
#include "./data.h"
|
||||
#include "./tree_model.h"
|
||||
#include "../../src/common/host_device_vector.h"
|
||||
|
||||
namespace xgboost {
|
||||
/*!
|
||||
@ -42,6 +43,9 @@ class TreeUpdater {
|
||||
virtual void Update(const std::vector<bst_gpair>& gpair,
|
||||
DMatrix* data,
|
||||
const std::vector<RegTree*>& trees) = 0;
|
||||
virtual void Update(HostDeviceVector<bst_gpair>* gpair,
|
||||
DMatrix* data,
|
||||
const std::vector<RegTree*>& trees);
|
||||
|
||||
/*!
|
||||
* \brief determines whether updater has enough knowledge about a given dataset
|
||||
@ -57,6 +61,9 @@ class TreeUpdater {
|
||||
std::vector<bst_float>* out_preds) {
|
||||
return false;
|
||||
}
|
||||
virtual bool UpdatePredictionCache(const DMatrix* data,
|
||||
HostDeviceVector<bst_float>* out_preds);
|
||||
|
||||
/*!
|
||||
* \brief Create a tree updater given name
|
||||
* \param name Name of the tree updater.
|
||||
|
||||
@ -484,6 +484,13 @@ class bulk_allocator {
|
||||
}
|
||||
|
||||
public:
|
||||
bulk_allocator() {}
|
||||
// prevent accidental copying, moving or assignment of this object
|
||||
bulk_allocator(const bulk_allocator<MemoryT>&) = delete;
|
||||
bulk_allocator(bulk_allocator<MemoryT>&&) = delete;
|
||||
void operator=(const bulk_allocator<MemoryT>&) = delete;
|
||||
void operator=(bulk_allocator<MemoryT>&&) = delete;
|
||||
|
||||
~bulk_allocator() {
|
||||
for (size_t i = 0; i < d_ptr.size(); i++) {
|
||||
if (!(d_ptr[i] == nullptr)) {
|
||||
|
||||
54
src/common/host_device_vector.cc
Normal file
54
src/common/host_device_vector.cc
Normal file
@ -0,0 +1,54 @@
|
||||
/*!
|
||||
* Copyright 2017 XGBoost contributors
|
||||
*/
|
||||
#ifndef XGBOOST_USE_CUDA
|
||||
|
||||
// dummy implementation of HostDeviceVector in case CUDA is not used
|
||||
|
||||
#include <xgboost/base.h>
|
||||
#include "./host_device_vector.h"
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
template <typename T>
|
||||
struct HostDeviceVectorImpl {
|
||||
explicit HostDeviceVectorImpl(size_t size) : data_h_(size) {}
|
||||
std::vector<T> data_h_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
HostDeviceVector<T>::HostDeviceVector(size_t size, int device) : impl_(nullptr) {
|
||||
impl_ = new HostDeviceVectorImpl<T>(size);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
HostDeviceVector<T>::~HostDeviceVector() {
|
||||
HostDeviceVectorImpl<T>* tmp = impl_;
|
||||
impl_ = nullptr;
|
||||
delete tmp;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
size_t HostDeviceVector<T>::size() const { return impl_->data_h_.size(); }
|
||||
|
||||
template <typename T>
|
||||
int HostDeviceVector<T>::device() const { return -1; }
|
||||
|
||||
template <typename T>
|
||||
T* HostDeviceVector<T>::ptr_d(int device) { return nullptr; }
|
||||
|
||||
template <typename T>
|
||||
std::vector<T>& HostDeviceVector<T>::data_h() { return impl_->data_h_; }
|
||||
|
||||
template <typename T>
|
||||
void HostDeviceVector<T>::resize(size_t new_size, int new_device) {
|
||||
impl_->data_h_.resize(new_size);
|
||||
}
|
||||
|
||||
// explicit instantiations are required, as HostDeviceVector isn't header-only
|
||||
template class HostDeviceVector<bst_float>;
|
||||
template class HostDeviceVector<bst_gpair>;
|
||||
|
||||
} // namespace xgboost
|
||||
|
||||
#endif
|
||||
135
src/common/host_device_vector.cu
Normal file
135
src/common/host_device_vector.cu
Normal file
@ -0,0 +1,135 @@
|
||||
/*!
|
||||
* Copyright 2017 XGBoost contributors
|
||||
*/
|
||||
#include "./host_device_vector.h"
|
||||
#include "./device_helpers.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
template <typename T>
|
||||
struct HostDeviceVectorImpl {
|
||||
HostDeviceVectorImpl(size_t size, int device)
|
||||
: device_(device), on_d_(device >= 0) {
|
||||
if (on_d_) {
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
data_d_.resize(size);
|
||||
} else {
|
||||
data_h_.resize(size);
|
||||
}
|
||||
}
|
||||
HostDeviceVectorImpl(const HostDeviceVectorImpl<T>&) = delete;
|
||||
HostDeviceVectorImpl(HostDeviceVectorImpl<T>&&) = delete;
|
||||
void operator=(const HostDeviceVectorImpl<T>&) = delete;
|
||||
void operator=(HostDeviceVectorImpl<T>&&) = delete;
|
||||
|
||||
size_t size() const { return on_d_ ? data_d_.size() : data_h_.size(); }
|
||||
|
||||
int device() const { return device_; }
|
||||
|
||||
T* ptr_d(int device) {
|
||||
lazy_sync_device(device);
|
||||
return data_d_.data().get();
|
||||
}
|
||||
thrust::device_ptr<T> tbegin(int device) {
|
||||
return thrust::device_ptr<T>(ptr_d(device));
|
||||
}
|
||||
thrust::device_ptr<T> tend(int device) {
|
||||
auto begin = tbegin(device);
|
||||
return begin + size();
|
||||
}
|
||||
std::vector<T>& data_h() {
|
||||
lazy_sync_host();
|
||||
return data_h_;
|
||||
}
|
||||
void resize(size_t new_size, int new_device) {
|
||||
if (new_size == this->size() && new_device == device_)
|
||||
return;
|
||||
device_ = new_device;
|
||||
// if !on_d_, but the data size is 0 and the device is set,
|
||||
// resize the data on device instead
|
||||
if (!on_d_ && (data_h_.size() > 0 || device_ == -1)) {
|
||||
data_h_.resize(new_size);
|
||||
} else {
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
data_d_.resize(new_size);
|
||||
on_d_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
void lazy_sync_host() {
|
||||
if (!on_d_)
|
||||
return;
|
||||
if (data_h_.size() != this->size())
|
||||
data_h_.resize(this->size());
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
thrust::copy(data_d_.begin(), data_d_.end(), data_h_.begin());
|
||||
on_d_ = false;
|
||||
}
|
||||
|
||||
void lazy_sync_device(int device) {
|
||||
if (on_d_)
|
||||
return;
|
||||
if (device != device_) {
|
||||
CHECK_EQ(device_, -1);
|
||||
device_ = device;
|
||||
}
|
||||
if (data_d_.size() != this->size()) {
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
data_d_.resize(this->size());
|
||||
}
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
thrust::copy(data_h_.begin(), data_h_.end(), data_d_.begin());
|
||||
on_d_ = true;
|
||||
}
|
||||
|
||||
std::vector<T> data_h_;
|
||||
thrust::device_vector<T> data_d_;
|
||||
// true if there is an up-to-date copy of data on device, false otherwise
|
||||
bool on_d_;
|
||||
int device_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
HostDeviceVector<T>::HostDeviceVector(size_t size, int device) : impl_(nullptr) {
|
||||
impl_ = new HostDeviceVectorImpl<T>(size, device);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
HostDeviceVector<T>::~HostDeviceVector() {
|
||||
HostDeviceVectorImpl<T>* tmp = impl_;
|
||||
impl_ = nullptr;
|
||||
delete tmp;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
size_t HostDeviceVector<T>::size() const { return impl_->size(); }
|
||||
|
||||
template <typename T>
|
||||
int HostDeviceVector<T>::device() const { return impl_->device(); }
|
||||
|
||||
template <typename T>
|
||||
T* HostDeviceVector<T>::ptr_d(int device) { return impl_->ptr_d(device); }
|
||||
|
||||
template <typename T>
|
||||
thrust::device_ptr<T> HostDeviceVector<T>::tbegin(int device) {
|
||||
return impl_->tbegin(device);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
thrust::device_ptr<T> HostDeviceVector<T>::tend(int device) {
|
||||
return impl_->tend(device);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<T>& HostDeviceVector<T>::data_h() { return impl_->data_h(); }
|
||||
|
||||
template <typename T>
|
||||
void HostDeviceVector<T>::resize(size_t new_size, int new_device) {
|
||||
impl_->resize(new_size, new_device);
|
||||
}
|
||||
|
||||
// explicit instantiations are required, as HostDeviceVector isn't header-only
|
||||
template class HostDeviceVector<bst_float>;
|
||||
template class HostDeviceVector<bst_gpair>;
|
||||
|
||||
} // namespace xgboost
|
||||
100
src/common/host_device_vector.h
Normal file
100
src/common/host_device_vector.h
Normal file
@ -0,0 +1,100 @@
|
||||
/*!
|
||||
* Copyright 2017 XGBoost contributors
|
||||
*/
|
||||
#ifndef XGBOOST_COMMON_HOST_DEVICE_VECTOR_H_
|
||||
#define XGBOOST_COMMON_HOST_DEVICE_VECTOR_H_
|
||||
|
||||
#include <cstdlib>
|
||||
#include <vector>
|
||||
|
||||
// only include thrust-related files if host_device_vector.h
|
||||
// is included from a .cu file
|
||||
#ifdef __CUDACC__
|
||||
#include <thrust/device_ptr.h>
|
||||
#endif
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
template <typename T> struct HostDeviceVectorImpl;
|
||||
|
||||
/**
|
||||
* @file host_device_vector.h
|
||||
* @brief A device-and-host vector abstraction layer.
|
||||
*
|
||||
* Why HostDeviceVector?<br/>
|
||||
* With CUDA, one has to explicitly manage memory through 'cudaMemcpy' calls.
|
||||
* This wrapper class hides this management from the users, thereby making it
|
||||
* easy to integrate GPU/CPU usage under a single interface.
|
||||
*
|
||||
* Initialization/Allocation:<br/>
|
||||
* One can choose to initialize the vector on CPU or GPU during constructor.
|
||||
* (use the 'device' argument) Or, can choose to use the 'resize' method to
|
||||
* allocate/resize memory explicitly.
|
||||
*
|
||||
* Accessing underling data:<br/>
|
||||
* Use 'data_h' method to explicitly query for the underlying std::vector.
|
||||
* If you need the raw device pointer, use the 'ptr_d' method. For perf
|
||||
* implications of these calls, see below.
|
||||
*
|
||||
* Accessing underling data and their perf implications:<br/>
|
||||
* There are 4 scenarios to be considered here:
|
||||
* data_h and data on CPU --> no problems, std::vector returned immediately
|
||||
* data_h but data on GPU --> this causes a cudaMemcpy to be issued internally.
|
||||
* subsequent calls to data_h, will NOT incur this penalty.
|
||||
* (assuming 'ptr_d' is not called in between)
|
||||
* ptr_d but data on CPU --> this causes a cudaMemcpy to be issued internally.
|
||||
* subsequent calls to ptr_d, will NOT incur this penalty.
|
||||
* (assuming 'data_h' is not called in between)
|
||||
* ptr_d and data on GPU --> no problems, the device ptr will be returned immediately
|
||||
*
|
||||
* What if xgboost is compiled without CUDA?<br/>
|
||||
* In that case, there's a special implementation which always falls-back to
|
||||
* working with std::vector. This logic can be found in host_device_vector.cc
|
||||
*
|
||||
* Why not consider CUDA unified memory?<br/>
|
||||
* We did consider. However, it poses complications if we need to support both
|
||||
* compiling with and without CUDA toolkit. It was easier to have
|
||||
* 'HostDeviceVector' with a special-case implementation in host_device_vector.cc
|
||||
*
|
||||
* @note: This is not thread-safe!
|
||||
*/
|
||||
template <typename T>
|
||||
class HostDeviceVector {
|
||||
public:
|
||||
explicit HostDeviceVector(size_t size = 0, int device = -1);
|
||||
~HostDeviceVector();
|
||||
HostDeviceVector(const HostDeviceVector<T>&) = delete;
|
||||
HostDeviceVector(HostDeviceVector<T>&&) = delete;
|
||||
void operator=(const HostDeviceVector<T>&) = delete;
|
||||
void operator=(HostDeviceVector<T>&&) = delete;
|
||||
size_t size() const;
|
||||
int device() const;
|
||||
T* ptr_d(int device);
|
||||
|
||||
// only define functions returning device_ptr
|
||||
// if HostDeviceVector.h is included from a .cu file
|
||||
#ifdef __CUDACC__
|
||||
thrust::device_ptr<T> tbegin(int device);
|
||||
thrust::device_ptr<T> tend(int device);
|
||||
#endif
|
||||
|
||||
std::vector<T>& data_h();
|
||||
void resize(size_t new_size, int new_device);
|
||||
|
||||
// helper functions in case a function needs to be templated
|
||||
// to work for both HostDeviceVector and std::vector
|
||||
static std::vector<T>& data_h(HostDeviceVector<T>* v) {
|
||||
return v->data_h();
|
||||
}
|
||||
|
||||
static std::vector<T>& data_h(std::vector<T>* v) {
|
||||
return *v;
|
||||
}
|
||||
|
||||
private:
|
||||
HostDeviceVectorImpl<T>* impl_;
|
||||
};
|
||||
|
||||
} // namespace xgboost
|
||||
|
||||
#endif // XGBOOST_COMMON_HOST_DEVICE_VECTOR_H_
|
||||
@ -20,8 +20,8 @@ namespace common {
|
||||
* \param x input parameter
|
||||
* \return the transformed value.
|
||||
*/
|
||||
inline float Sigmoid(float x) {
|
||||
return 1.0f / (1.0f + std::exp(-x));
|
||||
XGBOOST_DEVICE inline float Sigmoid(float x) {
|
||||
return 1.0f / (1.0f + expf(-x));
|
||||
}
|
||||
|
||||
inline avx::Float8 Sigmoid(avx::Float8 x) {
|
||||
|
||||
@ -21,6 +21,19 @@ GradientBooster* GradientBooster::Create(
|
||||
}
|
||||
return (e->body)(cache_mats, base_margin);
|
||||
}
|
||||
|
||||
void GradientBooster::DoBoost(DMatrix* p_fmat,
|
||||
HostDeviceVector<bst_gpair>* in_gpair,
|
||||
ObjFunction* obj) {
|
||||
DoBoost(p_fmat, &in_gpair->data_h(), obj);
|
||||
}
|
||||
|
||||
void GradientBooster::PredictBatch(DMatrix* dmat,
|
||||
HostDeviceVector<bst_float>* out_preds,
|
||||
unsigned ntree_limit) {
|
||||
PredictBatch(dmat, &out_preds->data_h(), ntree_limit);
|
||||
}
|
||||
|
||||
} // namespace xgboost
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
#include <limits>
|
||||
#include <algorithm>
|
||||
#include "../common/common.h"
|
||||
#include "../common/host_device_vector.h"
|
||||
#include "../common/random.h"
|
||||
#include "gbtree_model.h"
|
||||
#include "../common/timer.h"
|
||||
@ -182,35 +183,13 @@ class GBTree : public GradientBooster {
|
||||
void DoBoost(DMatrix* p_fmat,
|
||||
std::vector<bst_gpair>* in_gpair,
|
||||
ObjFunction* obj) override {
|
||||
const std::vector<bst_gpair>& gpair = *in_gpair;
|
||||
std::vector<std::vector<std::unique_ptr<RegTree> > > new_trees;
|
||||
const int ngroup = model_.param.num_output_group;
|
||||
monitor.Start("BoostNewTrees");
|
||||
if (ngroup == 1) {
|
||||
std::vector<std::unique_ptr<RegTree> > ret;
|
||||
BoostNewTrees(gpair, p_fmat, 0, &ret);
|
||||
new_trees.push_back(std::move(ret));
|
||||
} else {
|
||||
CHECK_EQ(gpair.size() % ngroup, 0U)
|
||||
<< "must have exactly ngroup*nrow gpairs";
|
||||
std::vector<bst_gpair> tmp(gpair.size() / ngroup);
|
||||
for (int gid = 0; gid < ngroup; ++gid) {
|
||||
bst_omp_uint nsize = static_cast<bst_omp_uint>(tmp.size());
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < nsize; ++i) {
|
||||
tmp[i] = gpair[i * ngroup + gid];
|
||||
}
|
||||
std::vector<std::unique_ptr<RegTree> > ret;
|
||||
BoostNewTrees(tmp, p_fmat, gid, &ret);
|
||||
new_trees.push_back(std::move(ret));
|
||||
}
|
||||
}
|
||||
monitor.Stop("BoostNewTrees");
|
||||
monitor.Start("CommitModel");
|
||||
for (int gid = 0; gid < ngroup; ++gid) {
|
||||
this->CommitModel(std::move(new_trees[gid]), gid);
|
||||
}
|
||||
monitor.Stop("CommitModel");
|
||||
DoBoostHelper(p_fmat, in_gpair, obj);
|
||||
}
|
||||
|
||||
void DoBoost(DMatrix* p_fmat,
|
||||
HostDeviceVector<bst_gpair>* in_gpair,
|
||||
ObjFunction* obj) override {
|
||||
DoBoostHelper(p_fmat, in_gpair, obj);
|
||||
}
|
||||
|
||||
void PredictBatch(DMatrix* p_fmat,
|
||||
@ -219,6 +198,12 @@ class GBTree : public GradientBooster {
|
||||
predictor->PredictBatch(p_fmat, out_preds, model_, 0, ntree_limit);
|
||||
}
|
||||
|
||||
void PredictBatch(DMatrix* p_fmat,
|
||||
HostDeviceVector<bst_float>* out_preds,
|
||||
unsigned ntree_limit) override {
|
||||
predictor->PredictBatch(p_fmat, out_preds, model_, 0, ntree_limit);
|
||||
}
|
||||
|
||||
void PredictInstance(const SparseBatch::Inst& inst,
|
||||
std::vector<bst_float>* out_preds,
|
||||
unsigned ntree_limit,
|
||||
@ -257,9 +242,48 @@ class GBTree : public GradientBooster {
|
||||
updaters.push_back(std::move(up));
|
||||
}
|
||||
}
|
||||
|
||||
// TVec is either std::vector<bst_gpair> or HostDeviceVector<bst_gpair>
|
||||
template <typename TVec>
|
||||
void DoBoostHelper(DMatrix* p_fmat,
|
||||
TVec* in_gpair,
|
||||
ObjFunction* obj) {
|
||||
std::vector<std::vector<std::unique_ptr<RegTree> > > new_trees;
|
||||
const int ngroup = model_.param.num_output_group;
|
||||
monitor.Start("BoostNewTrees");
|
||||
if (ngroup == 1) {
|
||||
std::vector<std::unique_ptr<RegTree> > ret;
|
||||
BoostNewTrees(in_gpair, p_fmat, 0, &ret);
|
||||
new_trees.push_back(std::move(ret));
|
||||
} else {
|
||||
CHECK_EQ(in_gpair->size() % ngroup, 0U)
|
||||
<< "must have exactly ngroup*nrow gpairs";
|
||||
std::vector<bst_gpair> tmp(in_gpair->size() / ngroup);
|
||||
auto& gpair_h = HostDeviceVector<bst_gpair>::data_h(in_gpair);
|
||||
for (int gid = 0; gid < ngroup; ++gid) {
|
||||
bst_omp_uint nsize = static_cast<bst_omp_uint>(tmp.size());
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < nsize; ++i) {
|
||||
tmp[i] = gpair_h[i * ngroup + gid];
|
||||
}
|
||||
std::vector<std::unique_ptr<RegTree> > ret;
|
||||
BoostNewTrees(&tmp, p_fmat, gid, &ret);
|
||||
new_trees.push_back(std::move(ret));
|
||||
}
|
||||
}
|
||||
monitor.Stop("BoostNewTrees");
|
||||
monitor.Start("CommitModel");
|
||||
for (int gid = 0; gid < ngroup; ++gid) {
|
||||
this->CommitModel(std::move(new_trees[gid]), gid);
|
||||
}
|
||||
monitor.Stop("CommitModel");
|
||||
}
|
||||
|
||||
// do group specific group
|
||||
// TVec is either const std::vector<bst_gpair> or HostDeviceVector<bst_gpair>
|
||||
template <typename TVec>
|
||||
inline void
|
||||
BoostNewTrees(const std::vector<bst_gpair> &gpair,
|
||||
BoostNewTrees(TVec* gpair,
|
||||
DMatrix *p_fmat,
|
||||
int bst_group,
|
||||
std::vector<std::unique_ptr<RegTree> >* ret) {
|
||||
@ -286,9 +310,24 @@ class GBTree : public GradientBooster {
|
||||
}
|
||||
// update the trees
|
||||
for (auto& up : updaters) {
|
||||
up->Update(gpair, p_fmat, new_trees);
|
||||
UpdateHelper(up.get(), gpair, p_fmat, new_trees);
|
||||
}
|
||||
}
|
||||
|
||||
void UpdateHelper(TreeUpdater* updater,
|
||||
std::vector<bst_gpair>* gpair,
|
||||
DMatrix *p_fmat,
|
||||
const std::vector<RegTree*>& new_trees) {
|
||||
updater->Update(*gpair, p_fmat, new_trees);
|
||||
}
|
||||
|
||||
void UpdateHelper(TreeUpdater* updater,
|
||||
HostDeviceVector<bst_gpair>* gpair,
|
||||
DMatrix *p_fmat,
|
||||
const std::vector<RegTree*>& new_trees) {
|
||||
updater->Update(gpair, p_fmat, new_trees);
|
||||
}
|
||||
|
||||
// commit new trees all at once
|
||||
virtual void
|
||||
CommitModel(std::vector<std::unique_ptr<RegTree> >&& new_trees,
|
||||
|
||||
@ -16,10 +16,12 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "./common/common.h"
|
||||
#include "./common/host_device_vector.h"
|
||||
#include "./common/io.h"
|
||||
#include "./common/random.h"
|
||||
#include "common/timer.h"
|
||||
|
||||
|
||||
namespace xgboost {
|
||||
// implementation of base learner.
|
||||
bool Learner::AllowLazyCheckPoint() const {
|
||||
@ -360,10 +362,10 @@ class LearnerImpl : public Learner {
|
||||
}
|
||||
this->LazyInitDMatrix(train);
|
||||
monitor.Start("PredictRaw");
|
||||
this->PredictRaw(train, &preds_);
|
||||
this->PredictRaw(train, &preds2_);
|
||||
monitor.Stop("PredictRaw");
|
||||
monitor.Start("GetGradient");
|
||||
obj_->GetGradient(preds_, train->info(), iter, &gpair_);
|
||||
obj_->GetGradient(&preds2_, train->info(), iter, &gpair_);
|
||||
monitor.Stop("GetGradient");
|
||||
gbm_->DoBoost(train, &gpair_, obj_.get());
|
||||
monitor.Stop("UpdateOneIter");
|
||||
@ -547,6 +549,13 @@ class LearnerImpl : public Learner {
|
||||
<< "Predict must happen after Load or InitModel";
|
||||
gbm_->PredictBatch(data, out_preds, ntree_limit);
|
||||
}
|
||||
inline void PredictRaw(DMatrix* data, HostDeviceVector<bst_float>* out_preds,
|
||||
unsigned ntree_limit = 0) const {
|
||||
CHECK(gbm_.get() != nullptr)
|
||||
<< "Predict must happen after Load or InitModel";
|
||||
gbm_->PredictBatch(data, out_preds, ntree_limit);
|
||||
}
|
||||
|
||||
// model parameter
|
||||
LearnerModelParam mparam;
|
||||
// training parameter
|
||||
@ -561,8 +570,9 @@ class LearnerImpl : public Learner {
|
||||
std::string name_obj_;
|
||||
// temporal storages for prediction
|
||||
std::vector<bst_float> preds_;
|
||||
HostDeviceVector<bst_float> preds2_;
|
||||
// gradient pairs
|
||||
std::vector<bst_gpair> gpair_;
|
||||
HostDeviceVector<bst_gpair> gpair_;
|
||||
|
||||
private:
|
||||
/*! \brief random number transformation seed. */
|
||||
|
||||
@ -6,6 +6,8 @@
|
||||
#include <xgboost/objective.h>
|
||||
#include <dmlc/registry.h>
|
||||
|
||||
#include "../common/host_device_vector.h"
|
||||
|
||||
namespace dmlc {
|
||||
DMLC_REGISTRY_ENABLE(::xgboost::ObjFunctionReg);
|
||||
} // namespace dmlc
|
||||
@ -22,12 +24,27 @@ ObjFunction* ObjFunction::Create(const std::string& name) {
|
||||
}
|
||||
return (e->body)();
|
||||
}
|
||||
|
||||
void ObjFunction::GetGradient(HostDeviceVector<bst_float>* preds,
|
||||
const MetaInfo& info,
|
||||
int iteration,
|
||||
HostDeviceVector<bst_gpair>* out_gpair) {
|
||||
GetGradient(preds->data_h(), info, iteration, &out_gpair->data_h());
|
||||
}
|
||||
|
||||
void ObjFunction::PredTransform(HostDeviceVector<bst_float> *io_preds) {
|
||||
PredTransform(&io_preds->data_h());
|
||||
}
|
||||
|
||||
} // namespace xgboost
|
||||
|
||||
namespace xgboost {
|
||||
namespace obj {
|
||||
// List of files that will be force linked in static links.
|
||||
DMLC_REGISTRY_LINK_TAG(regression_obj);
|
||||
#ifdef XGBOOST_USE_CUDA
|
||||
DMLC_REGISTRY_LINK_TAG(regression_obj_gpu);
|
||||
#endif
|
||||
DMLC_REGISTRY_LINK_TAG(multiclass_obj);
|
||||
DMLC_REGISTRY_LINK_TAG(rank_obj);
|
||||
} // namespace obj
|
||||
|
||||
110
src/objective/regression_loss.h
Normal file
110
src/objective/regression_loss.h
Normal file
@ -0,0 +1,110 @@
|
||||
/*!
|
||||
* Copyright 2017 XGBoost contributors
|
||||
*/
|
||||
#ifndef XGBOOST_OBJECTIVE_REGRESSION_LOSS_H_
|
||||
#define XGBOOST_OBJECTIVE_REGRESSION_LOSS_H_
|
||||
|
||||
#include <dmlc/omp.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <algorithm>
|
||||
#include "../common/math.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace obj {
|
||||
|
||||
// common regressions
|
||||
// linear regression
|
||||
struct LinearSquareLoss {
|
||||
// duplication is necessary, as __device__ specifier
|
||||
// cannot be made conditional on template parameter
|
||||
XGBOOST_DEVICE static bst_float PredTransform(bst_float x) { return x; }
|
||||
XGBOOST_DEVICE static bool CheckLabel(bst_float x) { return true; }
|
||||
XGBOOST_DEVICE static bst_float FirstOrderGradient(bst_float predt, bst_float label) {
|
||||
return predt - label;
|
||||
}
|
||||
XGBOOST_DEVICE static bst_float SecondOrderGradient(bst_float predt, bst_float label) {
|
||||
return 1.0f;
|
||||
}
|
||||
template <typename T>
|
||||
static T PredTransform(T x) { return x; }
|
||||
template <typename T>
|
||||
static T FirstOrderGradient(T predt, T label) { return predt - label; }
|
||||
template <typename T>
|
||||
static T SecondOrderGradient(T predt, T label) { return T(1.0f); }
|
||||
static bst_float ProbToMargin(bst_float base_score) { return base_score; }
|
||||
static const char* LabelErrorMsg() { return ""; }
|
||||
static const char* DefaultEvalMetric() { return "rmse"; }
|
||||
};
|
||||
|
||||
// logistic loss for probability regression task
|
||||
struct LogisticRegression {
|
||||
// duplication is necessary, as __device__ specifier
|
||||
// cannot be made conditional on template parameter
|
||||
XGBOOST_DEVICE static bst_float PredTransform(bst_float x) { return common::Sigmoid(x); }
|
||||
XGBOOST_DEVICE static bool CheckLabel(bst_float x) { return x >= 0.0f && x <= 1.0f; }
|
||||
XGBOOST_DEVICE static bst_float FirstOrderGradient(bst_float predt, bst_float label) {
|
||||
return predt - label;
|
||||
}
|
||||
XGBOOST_DEVICE static bst_float SecondOrderGradient(bst_float predt, bst_float label) {
|
||||
const float eps = 1e-16f;
|
||||
return fmaxf(predt * (1.0f - predt), eps);
|
||||
}
|
||||
template <typename T>
|
||||
static T PredTransform(T x) { return common::Sigmoid(x); }
|
||||
template <typename T>
|
||||
static T FirstOrderGradient(T predt, T label) { return predt - label; }
|
||||
template <typename T>
|
||||
static T SecondOrderGradient(T predt, T label) {
|
||||
const T eps = T(1e-16f);
|
||||
return std::max(predt * (T(1.0f) - predt), eps);
|
||||
}
|
||||
static bst_float ProbToMargin(bst_float base_score) {
|
||||
CHECK(base_score > 0.0f && base_score < 1.0f)
|
||||
<< "base_score must be in (0,1) for logistic loss";
|
||||
return -logf(1.0f / base_score - 1.0f);
|
||||
}
|
||||
static const char* LabelErrorMsg() {
|
||||
return "label must be in [0,1] for logistic regression";
|
||||
}
|
||||
static const char* DefaultEvalMetric() { return "rmse"; }
|
||||
};
|
||||
|
||||
// logistic loss for binary classification task
|
||||
struct LogisticClassification : public LogisticRegression {
|
||||
static const char* DefaultEvalMetric() { return "error"; }
|
||||
};
|
||||
|
||||
// logistic loss, but predict un-transformed margin
|
||||
struct LogisticRaw : public LogisticRegression {
|
||||
// duplication is necessary, as __device__ specifier
|
||||
// cannot be made conditional on template parameter
|
||||
XGBOOST_DEVICE static bst_float PredTransform(bst_float x) { return x; }
|
||||
XGBOOST_DEVICE static bst_float FirstOrderGradient(bst_float predt, bst_float label) {
|
||||
predt = common::Sigmoid(predt);
|
||||
return predt - label;
|
||||
}
|
||||
XGBOOST_DEVICE static bst_float SecondOrderGradient(bst_float predt, bst_float label) {
|
||||
const float eps = 1e-16f;
|
||||
predt = common::Sigmoid(predt);
|
||||
return fmaxf(predt * (1.0f - predt), eps);
|
||||
}
|
||||
template <typename T>
|
||||
static T PredTransform(T x) { return x; }
|
||||
template <typename T>
|
||||
static T FirstOrderGradient(T predt, T label) {
|
||||
predt = common::Sigmoid(predt);
|
||||
return predt - label;
|
||||
}
|
||||
template <typename T>
|
||||
static T SecondOrderGradient(T predt, T label) {
|
||||
const T eps = T(1e-16f);
|
||||
predt = common::Sigmoid(predt);
|
||||
return std::max(predt * (T(1.0f) - predt), eps);
|
||||
}
|
||||
static const char* DefaultEvalMetric() { return "auc"; }
|
||||
};
|
||||
|
||||
} // namespace obj
|
||||
} // namespace xgboost
|
||||
|
||||
#endif // XGBOOST_OBJECTIVE_REGRESSION_LOSS_H_
|
||||
@ -12,70 +12,13 @@
|
||||
#include <utility>
|
||||
#include "../common/math.h"
|
||||
#include "../common/avx_helpers.h"
|
||||
#include "./regression_loss.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace obj {
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(regression_obj);
|
||||
|
||||
// common regressions
|
||||
// linear regression
|
||||
struct LinearSquareLoss {
|
||||
template <typename T>
|
||||
static T PredTransform(T x) { return x; }
|
||||
static bool CheckLabel(bst_float x) { return true; }
|
||||
template <typename T>
|
||||
static T FirstOrderGradient(T predt, T label) { return predt - label; }
|
||||
template <typename T>
|
||||
static T SecondOrderGradient(T predt, T label) { return T(1.0f); }
|
||||
static bst_float ProbToMargin(bst_float base_score) { return base_score; }
|
||||
static const char* LabelErrorMsg() { return ""; }
|
||||
static const char* DefaultEvalMetric() { return "rmse"; }
|
||||
};
|
||||
// logistic loss for probability regression task
|
||||
struct LogisticRegression {
|
||||
template <typename T>
|
||||
static T PredTransform(T x) { return common::Sigmoid(x); }
|
||||
static bool CheckLabel(bst_float x) { return x >= 0.0f && x <= 1.0f; }
|
||||
template <typename T>
|
||||
static T FirstOrderGradient(T predt, T label) { return predt - label; }
|
||||
template <typename T>
|
||||
static T SecondOrderGradient(T predt, T label) {
|
||||
const T eps = T(1e-16f);
|
||||
return std::max(predt * (T(1.0f) - predt), eps);
|
||||
}
|
||||
static bst_float ProbToMargin(bst_float base_score) {
|
||||
CHECK(base_score > 0.0f && base_score < 1.0f)
|
||||
<< "base_score must be in (0,1) for logistic loss";
|
||||
return -std::log(1.0f / base_score - 1.0f);
|
||||
}
|
||||
static const char* LabelErrorMsg() {
|
||||
return "label must be in [0,1] for logistic regression";
|
||||
}
|
||||
static const char* DefaultEvalMetric() { return "rmse"; }
|
||||
};
|
||||
// logistic loss for binary classification task.
|
||||
struct LogisticClassification : public LogisticRegression {
|
||||
static const char* DefaultEvalMetric() { return "error"; }
|
||||
};
|
||||
// logistic loss, but predict un-transformed margin
|
||||
struct LogisticRaw : public LogisticRegression {
|
||||
template <typename T>
|
||||
static T PredTransform(T x) { return x; }
|
||||
template <typename T>
|
||||
static T FirstOrderGradient(T predt, T label) {
|
||||
predt = common::Sigmoid(predt);
|
||||
return predt - label;
|
||||
}
|
||||
template <typename T>
|
||||
static T SecondOrderGradient(T predt, T label) {
|
||||
const T eps = T(1e-16f);
|
||||
predt = common::Sigmoid(predt);
|
||||
return std::max(predt * (T(1.0f) - predt), eps);
|
||||
}
|
||||
static const char* DefaultEvalMetric() { return "auc"; }
|
||||
};
|
||||
|
||||
struct RegLossParam : public dmlc::Parameter<RegLossParam> {
|
||||
float scale_pos_weight;
|
||||
// declare parameters
|
||||
|
||||
241
src/objective/regression_obj_gpu.cu
Normal file
241
src/objective/regression_obj_gpu.cu
Normal file
@ -0,0 +1,241 @@
|
||||
/*!
|
||||
* Copyright 2017 XGBoost contributors
|
||||
*/
|
||||
// GPU implementation of objective function.
|
||||
// Necessary to avoid extra copying of data to CPU.
|
||||
#include <dmlc/omp.h>
|
||||
#include <thrust/device_vector.h>
|
||||
#include <thrust/copy.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <xgboost/objective.h>
|
||||
#include <cmath>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "../common/host_device_vector.h"
|
||||
#include "./regression_loss.h"
|
||||
|
||||
using namespace dh;
|
||||
|
||||
namespace xgboost {
|
||||
namespace obj {
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(regression_obj_gpu);
|
||||
|
||||
struct GPURegLossParam : public dmlc::Parameter<GPURegLossParam> {
|
||||
float scale_pos_weight;
|
||||
int n_gpus;
|
||||
int gpu_id;
|
||||
// declare parameters
|
||||
DMLC_DECLARE_PARAMETER(GPURegLossParam) {
|
||||
DMLC_DECLARE_FIELD(scale_pos_weight).set_default(1.0f).set_lower_bound(0.0f)
|
||||
.describe("Scale the weight of positive examples by this factor");
|
||||
DMLC_DECLARE_FIELD(n_gpus).set_default(1).set_lower_bound(-1)
|
||||
.describe("Number of GPUs to use for multi-gpu algorithms (NOT IMPLEMENTED)");
|
||||
DMLC_DECLARE_FIELD(gpu_id)
|
||||
.set_lower_bound(0)
|
||||
.set_default(0)
|
||||
.describe("gpu to use for objective function evaluation");
|
||||
}
|
||||
};
|
||||
|
||||
// GPU kernel for gradient computation
|
||||
template<typename Loss>
|
||||
__global__ void get_gradient_k
|
||||
(bst_gpair *__restrict__ out_gpair, uint *__restrict__ label_correct,
|
||||
const float * __restrict__ preds, const float * __restrict__ labels,
|
||||
const float * __restrict__ weights, int n, float scale_pos_weight) {
|
||||
int i = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (i >= n)
|
||||
return;
|
||||
float p = Loss::PredTransform(preds[i]);
|
||||
float w = weights == nullptr ? 1.0f : weights[i];
|
||||
float label = labels[i];
|
||||
if (label == 1.0f)
|
||||
w *= scale_pos_weight;
|
||||
if (!Loss::CheckLabel(label))
|
||||
atomicAnd(label_correct, 0);
|
||||
out_gpair[i] = bst_gpair
|
||||
(Loss::FirstOrderGradient(p, label) * w, Loss::SecondOrderGradient(p, label) * w);
|
||||
}
|
||||
|
||||
// GPU kernel for predicate transformation
|
||||
template<typename Loss>
|
||||
__global__ void pred_transform_k(float * __restrict__ preds, int n) {
|
||||
int i = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (i >= n)
|
||||
return;
|
||||
preds[i] = Loss::PredTransform(preds[i]);
|
||||
}
|
||||
|
||||
// regression loss function for evaluation on GPU (eventually)
|
||||
template<typename Loss>
|
||||
class GPURegLossObj : public ObjFunction {
|
||||
protected:
|
||||
// manages device data
|
||||
struct DeviceData {
|
||||
dvec<float> labels, weights;
|
||||
dvec<uint> label_correct;
|
||||
|
||||
// allocate everything on device
|
||||
DeviceData(bulk_allocator<memory_type::DEVICE>* ba, int device_idx, size_t n) {
|
||||
ba->allocate(device_idx, false,
|
||||
&labels, n,
|
||||
&weights, n,
|
||||
&label_correct, 1);
|
||||
}
|
||||
size_t size() const { return labels.size(); }
|
||||
};
|
||||
|
||||
|
||||
bool copied_;
|
||||
std::unique_ptr<bulk_allocator<memory_type::DEVICE>> ba_;
|
||||
std::unique_ptr<DeviceData> data_;
|
||||
HostDeviceVector<bst_float> preds_d_;
|
||||
HostDeviceVector<bst_gpair> out_gpair_d_;
|
||||
|
||||
// allocate device data for n elements, do nothing if enough memory is allocated already
|
||||
void LazyResize(int n) {
|
||||
if (data_.get() != nullptr && data_->size() >= n)
|
||||
return;
|
||||
copied_ = false;
|
||||
// free the old data and allocate the new data
|
||||
ba_.reset(new bulk_allocator<memory_type::DEVICE>());
|
||||
data_.reset(new DeviceData(ba_.get(), 0, n));
|
||||
preds_d_.resize(n, param_.gpu_id);
|
||||
out_gpair_d_.resize(n, param_.gpu_id);
|
||||
}
|
||||
|
||||
public:
|
||||
GPURegLossObj() : copied_(false), preds_d_(0, -1), out_gpair_d_(0, -1) {}
|
||||
|
||||
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||
param_.InitAllowUnknown(args);
|
||||
CHECK(param_.n_gpus != 0) << "Must have at least one device";
|
||||
}
|
||||
void GetGradient(const std::vector<float> &preds,
|
||||
const MetaInfo &info,
|
||||
int iter,
|
||||
std::vector<bst_gpair> *out_gpair) override {
|
||||
CHECK_NE(info.labels.size(), 0U) << "label set cannot be empty";
|
||||
CHECK_EQ(preds.size(), info.labels.size())
|
||||
<< "labels are not correctly provided"
|
||||
<< "preds.size=" << preds.size() << ", label.size=" << info.labels.size();
|
||||
|
||||
size_t ndata = preds.size();
|
||||
out_gpair->resize(ndata);
|
||||
LazyResize(ndata);
|
||||
thrust::copy(preds.begin(), preds.end(), preds_d_.tbegin(param_.gpu_id));
|
||||
GetGradientDevice(preds_d_.ptr_d(param_.gpu_id), info, iter,
|
||||
out_gpair_d_.ptr_d(param_.gpu_id), ndata);
|
||||
thrust::copy_n(out_gpair_d_.tbegin(param_.gpu_id), ndata, out_gpair->begin());
|
||||
}
|
||||
|
||||
void GetGradient(HostDeviceVector<float>* preds,
|
||||
const MetaInfo &info,
|
||||
int iter,
|
||||
HostDeviceVector<bst_gpair>* out_gpair) override {
|
||||
CHECK_NE(info.labels.size(), 0U) << "label set cannot be empty";
|
||||
CHECK_EQ(preds->size(), info.labels.size())
|
||||
<< "labels are not correctly provided"
|
||||
<< "preds.size=" << preds->size() << ", label.size=" << info.labels.size();
|
||||
size_t ndata = preds->size();
|
||||
out_gpair->resize(ndata, param_.gpu_id);
|
||||
LazyResize(ndata);
|
||||
GetGradientDevice(preds->ptr_d(param_.gpu_id), info, iter,
|
||||
out_gpair->ptr_d(param_.gpu_id), ndata);
|
||||
}
|
||||
|
||||
private:
|
||||
void GetGradientDevice(float* preds,
|
||||
const MetaInfo &info,
|
||||
int iter,
|
||||
bst_gpair* out_gpair, size_t n) {
|
||||
safe_cuda(cudaSetDevice(param_.gpu_id));
|
||||
DeviceData& d = *data_;
|
||||
d.label_correct.fill(1);
|
||||
// only copy the labels and weights once, similar to how the data is copied
|
||||
if (!copied_) {
|
||||
thrust::copy(info.labels.begin(), info.labels.begin() + n,
|
||||
d.labels.tbegin());
|
||||
if (info.weights.size() > 0) {
|
||||
thrust::copy(info.weights.begin(), info.weights.begin() + n,
|
||||
d.weights.tbegin());
|
||||
}
|
||||
copied_ = true;
|
||||
}
|
||||
|
||||
// run the kernel
|
||||
const int block = 256;
|
||||
get_gradient_k<Loss><<<div_round_up(n, block), block>>>
|
||||
(out_gpair, d.label_correct.data(), preds,
|
||||
d.labels.data(), info.weights.size() > 0 ? d.weights.data() : nullptr,
|
||||
n, param_.scale_pos_weight);
|
||||
safe_cuda(cudaGetLastError());
|
||||
|
||||
// copy output data from the GPU
|
||||
uint label_correct_h;
|
||||
thrust::copy_n(d.label_correct.tbegin(), 1, &label_correct_h);
|
||||
|
||||
bool label_correct = label_correct_h != 0;
|
||||
if (!label_correct) {
|
||||
LOG(FATAL) << Loss::LabelErrorMsg();
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
const char* DefaultEvalMetric() const override {
|
||||
return Loss::DefaultEvalMetric();
|
||||
}
|
||||
|
||||
void PredTransform(std::vector<float> *io_preds) override {
|
||||
LazyResize(io_preds->size());
|
||||
thrust::copy(io_preds->begin(), io_preds->end(), preds_d_.tbegin(param_.gpu_id));
|
||||
PredTransformDevice(preds_d_.ptr_d(param_.gpu_id), io_preds->size());
|
||||
thrust::copy_n(preds_d_.tbegin(param_.gpu_id), io_preds->size(), io_preds->begin());
|
||||
}
|
||||
|
||||
void PredTransform(HostDeviceVector<float> *io_preds) override {
|
||||
PredTransformDevice(io_preds->ptr_d(param_.gpu_id), io_preds->size());
|
||||
}
|
||||
|
||||
void PredTransformDevice(float* preds, size_t n) {
|
||||
safe_cuda(cudaSetDevice(param_.gpu_id));
|
||||
const int block = 256;
|
||||
pred_transform_k<Loss><<<div_round_up(n, block), block>>>(preds, n);
|
||||
safe_cuda(cudaGetLastError());
|
||||
safe_cuda(cudaDeviceSynchronize());
|
||||
}
|
||||
|
||||
|
||||
float ProbToMargin(float base_score) const override {
|
||||
return Loss::ProbToMargin(base_score);
|
||||
}
|
||||
|
||||
protected:
|
||||
GPURegLossParam param_;
|
||||
};
|
||||
|
||||
// register the objective functions
|
||||
DMLC_REGISTER_PARAMETER(GPURegLossParam);
|
||||
|
||||
XGBOOST_REGISTER_OBJECTIVE(GPULinearRegression, "gpu:reg:linear")
|
||||
.describe("Linear regression (computed on GPU).")
|
||||
.set_body([]() { return new GPURegLossObj<LinearSquareLoss>(); });
|
||||
|
||||
XGBOOST_REGISTER_OBJECTIVE(GPULogisticRegression, "gpu:reg:logistic")
|
||||
.describe("Logistic regression for probability regression task (computed on GPU).")
|
||||
.set_body([]() { return new GPURegLossObj<LogisticRegression>(); });
|
||||
|
||||
XGBOOST_REGISTER_OBJECTIVE(GPULogisticClassification, "gpu:binary:logistic")
|
||||
.describe("Logistic regression for binary classification task (computed on GPU).")
|
||||
.set_body([]() { return new GPURegLossObj<LogisticClassification>(); });
|
||||
|
||||
XGBOOST_REGISTER_OBJECTIVE(GPULogisticRaw, "gpu:binary:logitraw")
|
||||
.describe("Logistic regression for classification, output score "
|
||||
"before logistic transformation (computed on GPU)")
|
||||
.set_body([]() { return new GPURegLossObj<LogisticRaw>(); });
|
||||
|
||||
} // namespace obj
|
||||
} // namespace xgboost
|
||||
@ -5,6 +5,7 @@
|
||||
#include <xgboost/tree_model.h>
|
||||
#include <xgboost/tree_updater.h>
|
||||
#include "dmlc/logging.h"
|
||||
#include "../common/host_device_vector.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace predictor {
|
||||
@ -108,6 +109,12 @@ class CPUPredictor : public Predictor {
|
||||
}
|
||||
|
||||
public:
|
||||
void PredictBatch(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds,
|
||||
const gbm::GBTreeModel& model, int tree_begin,
|
||||
unsigned ntree_limit = 0) override {
|
||||
PredictBatch(dmat, &out_preds->data_h(), model, tree_begin, ntree_limit);
|
||||
}
|
||||
|
||||
void PredictBatch(DMatrix* dmat, std::vector<bst_float>* out_preds,
|
||||
const gbm::GBTreeModel& model, int tree_begin,
|
||||
unsigned ntree_limit = 0) override {
|
||||
|
||||
@ -3,13 +3,16 @@
|
||||
*/
|
||||
#include <dmlc/parameter.h>
|
||||
#include <thrust/copy.h>
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <thrust/device_vector.h>
|
||||
#include <thrust/fill.h>
|
||||
#include <xgboost/data.h>
|
||||
#include <xgboost/predictor.h>
|
||||
#include <xgboost/tree_model.h>
|
||||
#include <xgboost/tree_updater.h>
|
||||
#include <memory>
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "../common/host_device_vector.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace predictor {
|
||||
@ -247,8 +250,16 @@ __global__ void PredictKernel(const DevicePredictionNode* d_nodes,
|
||||
}
|
||||
|
||||
class GPUPredictor : public xgboost::Predictor {
|
||||
protected:
|
||||
struct DevicePredictionCacheEntry {
|
||||
std::shared_ptr<DMatrix> data;
|
||||
HostDeviceVector<bst_float> predictions;
|
||||
};
|
||||
|
||||
std::unordered_map<DMatrix*, DevicePredictionCacheEntry> device_cache_;
|
||||
|
||||
private:
|
||||
void DevicePredictInternal(DMatrix* dmat, std::vector<bst_float>* out_preds,
|
||||
void DevicePredictInternal(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds,
|
||||
const gbm::GBTreeModel& model, size_t tree_begin,
|
||||
size_t tree_end) {
|
||||
if (tree_end - tree_begin == 0) {
|
||||
@ -293,7 +304,7 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
tree_group.begin());
|
||||
|
||||
device_matrix->predictions.resize(out_preds->size());
|
||||
thrust::copy(out_preds->begin(), out_preds->end(),
|
||||
thrust::copy(out_preds->tbegin(param.gpu_id), out_preds->tend(param.gpu_id),
|
||||
device_matrix->predictions.begin());
|
||||
|
||||
const int BLOCK_THREADS = 128;
|
||||
@ -319,19 +330,30 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
|
||||
dh::safe_cuda(cudaDeviceSynchronize());
|
||||
thrust::copy(device_matrix->predictions.begin(),
|
||||
device_matrix->predictions.end(), out_preds->begin());
|
||||
device_matrix->predictions.end(), out_preds->tbegin(param.gpu_id));
|
||||
}
|
||||
|
||||
|
||||
public:
|
||||
GPUPredictor() : cpu_predictor(Predictor::Create("cpu_predictor")) {}
|
||||
|
||||
void PredictBatch(DMatrix* dmat, std::vector<bst_float>* out_preds,
|
||||
const gbm::GBTreeModel& model, int tree_begin,
|
||||
unsigned ntree_limit = 0) override {
|
||||
if (this->PredictFromCache(dmat, out_preds, model, ntree_limit)) {
|
||||
HostDeviceVector<bst_float> out_preds_d;
|
||||
PredictBatch(dmat, &out_preds_d, model, tree_begin, ntree_limit);
|
||||
out_preds->resize(out_preds_d.size());
|
||||
thrust::copy(out_preds_d.tbegin(param.gpu_id),
|
||||
out_preds_d.tend(param.gpu_id), out_preds->begin());
|
||||
}
|
||||
|
||||
void PredictBatch(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds,
|
||||
const gbm::GBTreeModel& model, int tree_begin,
|
||||
unsigned ntree_limit = 0) override {
|
||||
if (this->PredictFromCacheDevice(dmat, out_preds, model, ntree_limit)) {
|
||||
return;
|
||||
}
|
||||
this->InitOutPredictions(dmat->info(), out_preds, model);
|
||||
this->InitOutPredictionsDevice(dmat->info(), out_preds, model);
|
||||
|
||||
int tree_end = ntree_limit * model.param.num_output_group;
|
||||
if (ntree_limit == 0 || ntree_limit > model.trees.size()) {
|
||||
@ -341,26 +363,78 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
DevicePredictInternal(dmat, out_preds, model, tree_begin, tree_end);
|
||||
}
|
||||
|
||||
void UpdatePredictionCache(
|
||||
const gbm::GBTreeModel& model,
|
||||
std::vector<std::unique_ptr<TreeUpdater>>* updaters,
|
||||
int num_new_trees) override {
|
||||
|
||||
void InitOutPredictionsDevice(const MetaInfo& info,
|
||||
HostDeviceVector<bst_float>* out_preds,
|
||||
const gbm::GBTreeModel& model) const {
|
||||
size_t n = model.param.num_output_group * info.num_row;
|
||||
const std::vector<bst_float>& base_margin = info.base_margin;
|
||||
out_preds->resize(n, param.gpu_id);
|
||||
if (base_margin.size() != 0) {
|
||||
CHECK_EQ(out_preds->size(), n);
|
||||
thrust::copy(base_margin.begin(), base_margin.end(), out_preds->tbegin(param.gpu_id));
|
||||
} else {
|
||||
thrust::fill(out_preds->tbegin(param.gpu_id),
|
||||
out_preds->tend(param.gpu_id), model.base_margin);
|
||||
}
|
||||
}
|
||||
|
||||
bool PredictFromCache(DMatrix* dmat,
|
||||
std::vector<bst_float>* out_preds,
|
||||
const gbm::GBTreeModel& model,
|
||||
unsigned ntree_limit) {
|
||||
HostDeviceVector<bst_float> out_preds_d(0, -1);
|
||||
bool result = PredictFromCacheDevice(dmat, &out_preds_d, model, ntree_limit);
|
||||
if (!result) return false;
|
||||
out_preds->resize(out_preds_d.size(), param.gpu_id);
|
||||
thrust::copy(out_preds_d.tbegin(param.gpu_id),
|
||||
out_preds_d.tend(param.gpu_id), out_preds->begin());
|
||||
return true;
|
||||
}
|
||||
|
||||
bool PredictFromCacheDevice(DMatrix* dmat,
|
||||
HostDeviceVector<bst_float>* out_preds,
|
||||
const gbm::GBTreeModel& model,
|
||||
unsigned ntree_limit) {
|
||||
if (ntree_limit == 0 ||
|
||||
ntree_limit * model.param.num_output_group >= model.trees.size()) {
|
||||
auto it = device_cache_.find(dmat);
|
||||
if (it != device_cache_.end()) {
|
||||
HostDeviceVector<bst_float>& y = it->second.predictions;
|
||||
if (y.size() != 0) {
|
||||
out_preds->resize(y.size(), param.gpu_id);
|
||||
thrust::copy(y.tbegin(param.gpu_id), y.tend(param.gpu_id),
|
||||
out_preds->tbegin(param.gpu_id));
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
void UpdatePredictionCache(const gbm::GBTreeModel& model,
|
||||
std::vector<std::unique_ptr<TreeUpdater>>* updaters,
|
||||
int num_new_trees) override {
|
||||
auto old_ntree = model.trees.size() - num_new_trees;
|
||||
// update cache entry
|
||||
for (auto& kv : cache_) {
|
||||
PredictionCacheEntry& e = kv.second;
|
||||
for (auto& kv : device_cache_) {
|
||||
DevicePredictionCacheEntry& e = kv.second;
|
||||
DMatrix* dmat = kv.first;
|
||||
HostDeviceVector<bst_float>& predictions = e.predictions;
|
||||
|
||||
if (e.predictions.size() == 0) {
|
||||
cpu_predictor->PredictBatch(dmat, &(e.predictions), model, 0,
|
||||
if (predictions.size() == 0) {
|
||||
// ensure that the device in predictions is correct
|
||||
predictions.resize(0, param.gpu_id);
|
||||
cpu_predictor->PredictBatch(dmat, &predictions.data_h(), model, 0,
|
||||
static_cast<bst_uint>(model.trees.size()));
|
||||
} else if (model.param.num_output_group == 1 && updaters->size() > 0 &&
|
||||
num_new_trees == 1 &&
|
||||
updaters->back()->UpdatePredictionCache(e.data.get(),
|
||||
&(e.predictions))) {
|
||||
{} // do nothing
|
||||
&predictions)) {
|
||||
// do nothing
|
||||
} else {
|
||||
DevicePredictInternal(dmat, &(e.predictions), model, old_ntree,
|
||||
DevicePredictInternal(dmat, &predictions, model, old_ntree,
|
||||
model.trees.size());
|
||||
}
|
||||
}
|
||||
@ -391,6 +465,8 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
Predictor::Init(cfg, cache);
|
||||
cpu_predictor->Init(cfg, cache);
|
||||
param.InitAllowUnknown(cfg);
|
||||
for (const std::shared_ptr<DMatrix>& d : cache)
|
||||
device_cache_[d.get()].data = d;
|
||||
max_shared_memory_bytes = dh::max_shared_memory(param.gpu_id);
|
||||
}
|
||||
|
||||
|
||||
@ -6,6 +6,8 @@
|
||||
#include <xgboost/tree_updater.h>
|
||||
#include <dmlc/registry.h>
|
||||
|
||||
#include "../common/host_device_vector.h"
|
||||
|
||||
namespace dmlc {
|
||||
DMLC_REGISTRY_ENABLE(::xgboost::TreeUpdaterReg);
|
||||
} // namespace dmlc
|
||||
@ -20,6 +22,17 @@ TreeUpdater* TreeUpdater::Create(const std::string& name) {
|
||||
return (e->body)();
|
||||
}
|
||||
|
||||
void TreeUpdater::Update(HostDeviceVector<bst_gpair>* gpair,
|
||||
DMatrix* data,
|
||||
const std::vector<RegTree*>& trees) {
|
||||
Update(gpair->data_h(), data, trees);
|
||||
}
|
||||
|
||||
bool TreeUpdater::UpdatePredictionCache(const DMatrix* data,
|
||||
HostDeviceVector<bst_float>* out_preds) {
|
||||
return UpdatePredictionCache(data, &out_preds->data_h());
|
||||
}
|
||||
|
||||
} // namespace xgboost
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
@ -12,6 +12,7 @@
|
||||
#include <vector>
|
||||
#include "../common/compressed_iterator.h"
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "../common/host_device_vector.h"
|
||||
#include "../common/hist_util.h"
|
||||
#include "../common/timer.h"
|
||||
#include "param.h"
|
||||
@ -349,7 +350,8 @@ struct DeviceShard {
|
||||
}
|
||||
|
||||
// Reset values for each update iteration
|
||||
void Reset(const std::vector<bst_gpair>& host_gpair) {
|
||||
void Reset(HostDeviceVector<bst_gpair>* dh_gpair, int device) {
|
||||
auto begin = dh_gpair->tbegin(device);
|
||||
dh::safe_cuda(cudaSetDevice(device_idx));
|
||||
position.current_dvec().fill(0);
|
||||
std::fill(node_sum_gradients.begin(), node_sum_gradients.end(),
|
||||
@ -359,8 +361,8 @@ struct DeviceShard {
|
||||
|
||||
std::fill(ridx_segments.begin(), ridx_segments.end(), Segment(0, 0));
|
||||
ridx_segments.front() = Segment(0, ridx.size());
|
||||
this->gpair.copy(host_gpair.begin() + row_begin_idx,
|
||||
host_gpair.begin() + row_end_idx);
|
||||
this->gpair.copy(begin + row_begin_idx,
|
||||
begin + row_end_idx);
|
||||
subsample_gpair(&gpair, param.subsample, row_begin_idx);
|
||||
hist.Reset();
|
||||
}
|
||||
@ -504,9 +506,28 @@ class GPUHistMaker : public TreeUpdater {
|
||||
|
||||
monitor.Init("updater_gpu_hist", param.debug_verbose);
|
||||
}
|
||||
|
||||
void Update(const std::vector<bst_gpair>& gpair, DMatrix* dmat,
|
||||
const std::vector<RegTree*>& trees) override {
|
||||
monitor.Start("Update", dList);
|
||||
// TODO(canonizer): move it into the class if this ever becomes a bottleneck
|
||||
HostDeviceVector<bst_gpair> gpair_d(gpair.size(), param.gpu_id);
|
||||
dh::safe_cuda(cudaSetDevice(param.gpu_id));
|
||||
thrust::copy(gpair.begin(), gpair.end(), gpair_d.tbegin(param.gpu_id));
|
||||
Update(&gpair_d, dmat, trees);
|
||||
monitor.Stop("Update", dList);
|
||||
}
|
||||
|
||||
void Update(HostDeviceVector<bst_gpair>* gpair, DMatrix* dmat,
|
||||
const std::vector<RegTree*>& trees) override {
|
||||
monitor.Start("Update", dList);
|
||||
UpdateHelper(gpair, dmat, trees);
|
||||
monitor.Stop("Update", dList);
|
||||
}
|
||||
|
||||
private:
|
||||
void UpdateHelper(HostDeviceVector<bst_gpair>* gpair, DMatrix* dmat,
|
||||
const std::vector<RegTree*>& trees) {
|
||||
GradStats::CheckInfo(dmat->info());
|
||||
// rescale learning rate according to size of trees
|
||||
float lr = param.learning_rate;
|
||||
@ -521,9 +542,9 @@ class GPUHistMaker : public TreeUpdater {
|
||||
LOG(FATAL) << "GPU plugin exception: " << e.what() << std::endl;
|
||||
}
|
||||
param.learning_rate = lr;
|
||||
monitor.Stop("Update", dList);
|
||||
}
|
||||
|
||||
public:
|
||||
void InitDataOnce(DMatrix* dmat) {
|
||||
info = &dmat->info();
|
||||
monitor.Start("Quantiles", dList);
|
||||
@ -572,7 +593,7 @@ class GPUHistMaker : public TreeUpdater {
|
||||
initialised = true;
|
||||
}
|
||||
|
||||
void InitData(const std::vector<bst_gpair>& gpair, DMatrix* dmat,
|
||||
void InitData(HostDeviceVector<bst_gpair>* gpair, DMatrix* dmat,
|
||||
const RegTree& tree) {
|
||||
monitor.Start("InitDataOnce", dList);
|
||||
if (!initialised) {
|
||||
@ -585,11 +606,10 @@ class GPUHistMaker : public TreeUpdater {
|
||||
// Copy gpair & reset memory
|
||||
monitor.Start("InitDataReset", dList);
|
||||
omp_set_num_threads(shards.size());
|
||||
#pragma omp parallel
|
||||
{
|
||||
auto cpu_thread_id = omp_get_thread_num();
|
||||
shards[cpu_thread_id]->Reset(gpair);
|
||||
}
|
||||
|
||||
// TODO(canonizer): make it parallel again once HostDeviceVector is thread-safe
|
||||
for (int shard = 0; shard < shards.size(); ++shard)
|
||||
shards[shard]->Reset(gpair, param.gpu_id);
|
||||
monitor.Stop("InitDataReset", dList);
|
||||
}
|
||||
|
||||
@ -687,7 +707,7 @@ class GPUHistMaker : public TreeUpdater {
|
||||
return std::move(best_splits);
|
||||
}
|
||||
|
||||
void InitRoot(const std::vector<bst_gpair>& gpair, RegTree* p_tree) {
|
||||
void InitRoot(RegTree* p_tree) {
|
||||
auto root_nidx = 0;
|
||||
// Sum gradients
|
||||
std::vector<bst_gpair> tmp_sums(shards.size());
|
||||
@ -800,7 +820,7 @@ class GPUHistMaker : public TreeUpdater {
|
||||
this->UpdatePosition(candidate, p_tree);
|
||||
}
|
||||
|
||||
void UpdateTree(const std::vector<bst_gpair>& gpair, DMatrix* p_fmat,
|
||||
void UpdateTree(HostDeviceVector<bst_gpair>* gpair, DMatrix* p_fmat,
|
||||
RegTree* p_tree) {
|
||||
// Temporarily store number of threads so we can change it back later
|
||||
int nthread = omp_get_max_threads();
|
||||
@ -811,7 +831,7 @@ class GPUHistMaker : public TreeUpdater {
|
||||
this->InitData(gpair, p_fmat, *p_tree);
|
||||
monitor.Stop("InitData", dList);
|
||||
monitor.Start("InitRoot", dList);
|
||||
this->InitRoot(gpair, p_tree);
|
||||
this->InitRoot(p_tree);
|
||||
monitor.Stop("InitRoot", dList);
|
||||
|
||||
auto timestamp = qexpand_->size();
|
||||
@ -854,6 +874,16 @@ class GPUHistMaker : public TreeUpdater {
|
||||
omp_set_num_threads(nthread);
|
||||
}
|
||||
|
||||
bool UpdatePredictionCache(const DMatrix* data,
|
||||
std::vector<bst_float>* p_out_preds) override {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool UpdatePredictionCache(const DMatrix* data,
|
||||
HostDeviceVector<bst_float>* p_out_preds) override {
|
||||
return false;
|
||||
}
|
||||
|
||||
struct ExpandEntry {
|
||||
int nid;
|
||||
int depth;
|
||||
|
||||
69
tests/cpp/objective/test_regression_obj_gpu.cu
Normal file
69
tests/cpp/objective/test_regression_obj_gpu.cu
Normal file
@ -0,0 +1,69 @@
|
||||
/*!
|
||||
* Copyright 2017 XGBoost contributors
|
||||
*/
|
||||
#include <xgboost/objective.h>
|
||||
|
||||
#include "../helpers.h"
|
||||
|
||||
TEST(Objective, GPULinearRegressionGPair) {
|
||||
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("gpu:reg:linear");
|
||||
std::vector<std::pair<std::string, std::string> > args;
|
||||
obj->Configure(args);
|
||||
CheckObjFunction(obj,
|
||||
{0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1},
|
||||
{0, 0, 0, 0, 1, 1, 1, 1},
|
||||
{1, 1, 1, 1, 1, 1, 1, 1},
|
||||
{0, 0.1f, 0.9f, 1.0f, -1.0f, -0.9f, -0.1f, 0},
|
||||
{1, 1, 1, 1, 1, 1, 1, 1});
|
||||
|
||||
ASSERT_NO_THROW(obj->DefaultEvalMetric());
|
||||
}
|
||||
|
||||
TEST(Objective, GPULogisticRegressionGPair) {
|
||||
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("gpu:reg:logistic");
|
||||
std::vector<std::pair<std::string, std::string> > args;
|
||||
obj->Configure(args);
|
||||
CheckObjFunction(obj,
|
||||
{ 0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1},
|
||||
{ 0, 0, 0, 0, 1, 1, 1, 1},
|
||||
{ 1, 1, 1, 1, 1, 1, 1, 1},
|
||||
{ 0.5f, 0.52f, 0.71f, 0.73f, -0.5f, -0.47f, -0.28f, -0.26f},
|
||||
{0.25f, 0.24f, 0.20f, 0.19f, 0.25f, 0.24f, 0.20f, 0.19f});
|
||||
}
|
||||
|
||||
TEST(Objective, GPULogisticRegressionBasic) {
|
||||
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("gpu:reg:logistic");
|
||||
std::vector<std::pair<std::string, std::string> > args;
|
||||
obj->Configure(args);
|
||||
|
||||
// test label validation
|
||||
EXPECT_ANY_THROW(CheckObjFunction(obj, {0}, {10}, {1}, {0}, {0}))
|
||||
<< "Expected error when label not in range [0,1f] for LogisticRegression";
|
||||
|
||||
// test ProbToMargin
|
||||
EXPECT_NEAR(obj->ProbToMargin(0.1f), -2.197f, 0.01f);
|
||||
EXPECT_NEAR(obj->ProbToMargin(0.5f), 0, 0.01f);
|
||||
EXPECT_NEAR(obj->ProbToMargin(0.9f), 2.197f, 0.01f);
|
||||
EXPECT_ANY_THROW(obj->ProbToMargin(10))
|
||||
<< "Expected error when base_score not in range [0,1f] for LogisticRegression";
|
||||
|
||||
// test PredTransform
|
||||
std::vector<xgboost::bst_float> preds = {0, 0.1f, 0.5f, 0.9f, 1};
|
||||
std::vector<xgboost::bst_float> out_preds = {0.5f, 0.524f, 0.622f, 0.710f, 0.731f};
|
||||
obj->PredTransform(&preds);
|
||||
for (int i = 0; i < static_cast<int>(preds.size()); ++i) {
|
||||
EXPECT_NEAR(preds[i], out_preds[i], 0.01f);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(Objective, GPULogisticRawGPair) {
|
||||
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("gpu:binary:logitraw");
|
||||
std::vector<std::pair<std::string, std::string> > args;
|
||||
obj->Configure(args);
|
||||
CheckObjFunction(obj,
|
||||
{ 0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1},
|
||||
{ 0, 0, 0, 0, 1, 1, 1, 1},
|
||||
{ 1, 1, 1, 1, 1, 1, 1, 1},
|
||||
{ 0.5f, 0.52f, 0.71f, 0.73f, -0.5f, -0.47f, -0.28f, -0.26f},
|
||||
{0.25f, 0.24f, 0.20f, 0.19f, 0.25f, 0.24f, 0.20f, 0.19f});
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user