From 84ab74f3a56739829b03161fb9c249f3a760a518 Mon Sep 17 00:00:00 2001 From: Thejaswi Date: Fri, 12 Jan 2018 14:03:39 +0530 Subject: [PATCH] Objective function evaluation on GPU with minimal PCIe transfers (#2935) * Added GPU objective function and no-copy interface. - xgboost::HostDeviceVector syncs automatically between host and device - no-copy interfaces have been added - default implementations just sync the data to host and call the implementations with std::vector - GPU objective function, predictor, histogram updater process data directly on GPU --- amalgamation/xgboost-all0.cc | 1 + doc/parameter.md | 3 + include/xgboost/gbm.h | 8 + include/xgboost/objective.h | 13 + include/xgboost/predictor.h | 21 +- include/xgboost/tree_updater.h | 7 + src/common/device_helpers.cuh | 7 + src/common/host_device_vector.cc | 54 ++++ src/common/host_device_vector.cu | 135 ++++++++++ src/common/host_device_vector.h | 100 ++++++++ src/common/math.h | 4 +- src/gbm/gbm.cc | 13 + src/gbm/gbtree.cc | 101 +++++--- src/learner.cc | 16 +- src/objective/objective.cc | 17 ++ src/objective/regression_loss.h | 110 ++++++++ src/objective/regression_obj.cc | 59 +---- src/objective/regression_obj_gpu.cu | 241 ++++++++++++++++++ src/predictor/cpu_predictor.cc | 7 + src/predictor/gpu_predictor.cu | 108 ++++++-- src/tree/tree_updater.cc | 13 + src/tree/updater_gpu_hist.cu | 56 +++- .../cpp/objective/test_regression_obj_gpu.cu | 69 +++++ 23 files changed, 1036 insertions(+), 127 deletions(-) create mode 100644 src/common/host_device_vector.cc create mode 100644 src/common/host_device_vector.cu create mode 100644 src/common/host_device_vector.h create mode 100644 src/objective/regression_loss.h create mode 100644 src/objective/regression_obj_gpu.cu create mode 100644 tests/cpp/objective/test_regression_obj_gpu.cu diff --git a/amalgamation/xgboost-all0.cc b/amalgamation/xgboost-all0.cc index 88f780cbf..4ad5fe96a 100644 --- a/amalgamation/xgboost-all0.cc +++ b/amalgamation/xgboost-all0.cc @@ -57,6 +57,7 @@ #include "../src/learner.cc" #include "../src/logging.cc" #include "../src/common/common.cc" +#include "../src/common/host_device_vector.cc" #include "../src/common/hist_util.cc" // c_api diff --git a/doc/parameter.md b/doc/parameter.md index d9eaadab1..d1a8e8f58 100644 --- a/doc/parameter.md +++ b/doc/parameter.md @@ -165,6 +165,9 @@ Specify the learning task and the corresponding learning objective. 
The objective options are below:
 - "reg:logistic" --logistic regression
 - "binary:logistic" --logistic regression for binary classification, output probability
 - "binary:logitraw" --logistic regression for binary classification, output score before logistic transformation
+- "gpu:reg:linear", "gpu:reg:logistic", "gpu:binary:logistic", "gpu:binary:logitraw" --versions
+  of the corresponding objective functions evaluated on the GPU; note that, like the GPU histogram algorithm,
+  they can only be used when the entire training session uses the same dataset
 - "count:poisson" --poisson regression for count data, output mean of poisson distribution
   - max_delta_step is set to 0.7 by default in poisson regression (used to safeguard optimization)
 - "multi:softmax" --set XGBoost to do multiclass classification using the softmax objective, you also need to set num_class(number of classes)
diff --git a/include/xgboost/gbm.h b/include/xgboost/gbm.h
index 877890509..58babdf29 100644
--- a/include/xgboost/gbm.h
+++ b/include/xgboost/gbm.h
@@ -18,6 +18,7 @@
 #include "./data.h"
 #include "./objective.h"
 #include "./feature_map.h"
+#include "../../src/common/host_device_vector.h"
 
 namespace xgboost {
 /*!
@@ -70,6 +71,10 @@ class GradientBooster {
   virtual void DoBoost(DMatrix* p_fmat,
                        std::vector<bst_gpair>* in_gpair,
                        ObjFunction* obj = nullptr) = 0;
+  virtual void DoBoost(DMatrix* p_fmat,
+                       HostDeviceVector<bst_gpair>* in_gpair,
+                       ObjFunction* obj = nullptr);
+
   /*!
    * \brief generate predictions for given feature matrix
    * \param dmat feature matrix
@@ -80,6 +85,9 @@ class GradientBooster {
   virtual void PredictBatch(DMatrix* dmat,
                             std::vector<bst_float>* out_preds,
                             unsigned ntree_limit = 0) = 0;
+  virtual void PredictBatch(DMatrix* dmat,
+                            HostDeviceVector<bst_float>* out_preds,
+                            unsigned ntree_limit = 0);
   /*!
    * \brief online prediction function, predict score for one instance at a time
    * NOTE: use the batch prediction interface if possible, batch prediction is usually
diff --git a/include/xgboost/objective.h b/include/xgboost/objective.h
index 08201d59e..3f26db891 100644
--- a/include/xgboost/objective.h
+++ b/include/xgboost/objective.h
@@ -14,8 +14,11 @@
 #include <functional>
 #include "./data.h"
 #include "./base.h"
+#include "../../src/common/host_device_vector.h"
+
 namespace xgboost {
+
 /*! \brief interface of objective function */
 class ObjFunction {
  public:
@@ -45,6 +48,11 @@ class ObjFunction {
                            const MetaInfo& info,
                            int iteration,
                            std::vector<bst_gpair>* out_gpair) = 0;
+  virtual void GetGradient(HostDeviceVector<bst_float>* preds,
+                           const MetaInfo& info,
+                           int iteration,
+                           HostDeviceVector<bst_gpair>* out_gpair);
+
   /*! \return the default evaluation metric for the objective */
   virtual const char* DefaultEvalMetric() const = 0;
   // the following functions are optional, most of time default implementation is good enough
@@ -53,6 +61,8 @@ class ObjFunction {
    * \param io_preds prediction values, saves to this vector as well
    */
   virtual void PredTransform(std::vector<bst_float> *io_preds) {}
+  virtual void PredTransform(HostDeviceVector<bst_float> *io_preds);
+
   /*!
    * \brief transform prediction values, this is only called when Eval is called,
    * usually it redirect to PredTransform
    */
   virtual void EvalTransform(std::vector<bst_float> *io_preds) {
     this->PredTransform(io_preds);
   }
+  virtual void EvalTransform(HostDeviceVector<bst_float> *io_preds) {
+    this->PredTransform(io_preds);
+  }
   /*!
* \brief transform probability value back to margin * this is used to transform user-set base_score back to margin diff --git a/include/xgboost/predictor.h b/include/xgboost/predictor.h index cc89bb60a..5fd857654 100644 --- a/include/xgboost/predictor.h +++ b/include/xgboost/predictor.h @@ -13,6 +13,7 @@ #include #include #include "../../src/gbm/gbtree_model.h" +#include "../../src/common/host_device_vector.h" // Forward declarations namespace xgboost { @@ -51,10 +52,6 @@ class Predictor { const std::vector>& cache); /** - * \fn virtual void Predictor::PredictBatch( DMatrix* dmat, - * std::vector* out_preds, const gbm::GBTreeModel &model, int - * tree_begin, unsigned ntree_limit = 0) = 0; - * * \brief Generate batch predictions for a given feature matrix. May use * cached predictions if available instead of calculating from scratch. * @@ -70,6 +67,22 @@ class Predictor { const gbm::GBTreeModel& model, int tree_begin, unsigned ntree_limit = 0) = 0; + /** + * \brief Generate batch predictions for a given feature matrix. May use + * cached predictions if available instead of calculating from scratch. + * + * \param [in,out] dmat Feature matrix. + * \param [in,out] out_preds The output preds. + * \param model The model to predict from. + * \param tree_begin The tree begin index. + * \param ntree_limit (Optional) The ntree limit. 0 means do not + * limit trees. + */ + + virtual void PredictBatch(DMatrix* dmat, HostDeviceVector* out_preds, + const gbm::GBTreeModel& model, int tree_begin, + unsigned ntree_limit = 0) = 0; + /** * \fn virtual void Predictor::UpdatePredictionCache( const gbm::GBTreeModel * &model, std::vector >* updaters, int diff --git a/include/xgboost/tree_updater.h b/include/xgboost/tree_updater.h index 5f5dd5ecf..8dbfa6cae 100644 --- a/include/xgboost/tree_updater.h +++ b/include/xgboost/tree_updater.h @@ -16,6 +16,7 @@ #include "./base.h" #include "./data.h" #include "./tree_model.h" +#include "../../src/common/host_device_vector.h" namespace xgboost { /*! @@ -42,6 +43,9 @@ class TreeUpdater { virtual void Update(const std::vector& gpair, DMatrix* data, const std::vector& trees) = 0; + virtual void Update(HostDeviceVector* gpair, + DMatrix* data, + const std::vector& trees); /*! * \brief determines whether updater has enough knowledge about a given dataset @@ -57,6 +61,9 @@ class TreeUpdater { std::vector* out_preds) { return false; } + virtual bool UpdatePredictionCache(const DMatrix* data, + HostDeviceVector* out_preds); + /*! * \brief Create a tree updater given name * \param name Name of the tree updater. diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index ace609b6e..2d7c602d3 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -484,6 +484,13 @@ class bulk_allocator { } public: + bulk_allocator() {} + // prevent accidental copying, moving or assignment of this object + bulk_allocator(const bulk_allocator&) = delete; + bulk_allocator(bulk_allocator&&) = delete; + void operator=(const bulk_allocator&) = delete; + void operator=(bulk_allocator&&) = delete; + ~bulk_allocator() { for (size_t i = 0; i < d_ptr.size(); i++) { if (!(d_ptr[i] == nullptr)) { diff --git a/src/common/host_device_vector.cc b/src/common/host_device_vector.cc new file mode 100644 index 000000000..154a80cf3 --- /dev/null +++ b/src/common/host_device_vector.cc @@ -0,0 +1,54 @@ +/*! 
+ * Copyright 2017 XGBoost contributors + */ +#ifndef XGBOOST_USE_CUDA + +// dummy implementation of HostDeviceVector in case CUDA is not used + +#include +#include "./host_device_vector.h" + +namespace xgboost { + +template +struct HostDeviceVectorImpl { + explicit HostDeviceVectorImpl(size_t size) : data_h_(size) {} + std::vector data_h_; +}; + +template +HostDeviceVector::HostDeviceVector(size_t size, int device) : impl_(nullptr) { + impl_ = new HostDeviceVectorImpl(size); +} + +template +HostDeviceVector::~HostDeviceVector() { + HostDeviceVectorImpl* tmp = impl_; + impl_ = nullptr; + delete tmp; +} + +template +size_t HostDeviceVector::size() const { return impl_->data_h_.size(); } + +template +int HostDeviceVector::device() const { return -1; } + +template +T* HostDeviceVector::ptr_d(int device) { return nullptr; } + +template +std::vector& HostDeviceVector::data_h() { return impl_->data_h_; } + +template +void HostDeviceVector::resize(size_t new_size, int new_device) { + impl_->data_h_.resize(new_size); +} + +// explicit instantiations are required, as HostDeviceVector isn't header-only +template class HostDeviceVector; +template class HostDeviceVector; + +} // namespace xgboost + +#endif diff --git a/src/common/host_device_vector.cu b/src/common/host_device_vector.cu new file mode 100644 index 000000000..4370ef21e --- /dev/null +++ b/src/common/host_device_vector.cu @@ -0,0 +1,135 @@ +/*! + * Copyright 2017 XGBoost contributors + */ +#include "./host_device_vector.h" +#include "./device_helpers.cuh" + +namespace xgboost { + +template +struct HostDeviceVectorImpl { + HostDeviceVectorImpl(size_t size, int device) + : device_(device), on_d_(device >= 0) { + if (on_d_) { + dh::safe_cuda(cudaSetDevice(device_)); + data_d_.resize(size); + } else { + data_h_.resize(size); + } + } + HostDeviceVectorImpl(const HostDeviceVectorImpl&) = delete; + HostDeviceVectorImpl(HostDeviceVectorImpl&&) = delete; + void operator=(const HostDeviceVectorImpl&) = delete; + void operator=(HostDeviceVectorImpl&&) = delete; + + size_t size() const { return on_d_ ? 
data_d_.size() : data_h_.size(); } + + int device() const { return device_; } + + T* ptr_d(int device) { + lazy_sync_device(device); + return data_d_.data().get(); + } + thrust::device_ptr tbegin(int device) { + return thrust::device_ptr(ptr_d(device)); + } + thrust::device_ptr tend(int device) { + auto begin = tbegin(device); + return begin + size(); + } + std::vector& data_h() { + lazy_sync_host(); + return data_h_; + } + void resize(size_t new_size, int new_device) { + if (new_size == this->size() && new_device == device_) + return; + device_ = new_device; + // if !on_d_, but the data size is 0 and the device is set, + // resize the data on device instead + if (!on_d_ && (data_h_.size() > 0 || device_ == -1)) { + data_h_.resize(new_size); + } else { + dh::safe_cuda(cudaSetDevice(device_)); + data_d_.resize(new_size); + on_d_ = true; + } + } + + void lazy_sync_host() { + if (!on_d_) + return; + if (data_h_.size() != this->size()) + data_h_.resize(this->size()); + dh::safe_cuda(cudaSetDevice(device_)); + thrust::copy(data_d_.begin(), data_d_.end(), data_h_.begin()); + on_d_ = false; + } + + void lazy_sync_device(int device) { + if (on_d_) + return; + if (device != device_) { + CHECK_EQ(device_, -1); + device_ = device; + } + if (data_d_.size() != this->size()) { + dh::safe_cuda(cudaSetDevice(device_)); + data_d_.resize(this->size()); + } + dh::safe_cuda(cudaSetDevice(device_)); + thrust::copy(data_h_.begin(), data_h_.end(), data_d_.begin()); + on_d_ = true; + } + + std::vector data_h_; + thrust::device_vector data_d_; + // true if there is an up-to-date copy of data on device, false otherwise + bool on_d_; + int device_; +}; + +template +HostDeviceVector::HostDeviceVector(size_t size, int device) : impl_(nullptr) { + impl_ = new HostDeviceVectorImpl(size, device); +} + +template +HostDeviceVector::~HostDeviceVector() { + HostDeviceVectorImpl* tmp = impl_; + impl_ = nullptr; + delete tmp; +} + +template +size_t HostDeviceVector::size() const { return impl_->size(); } + +template +int HostDeviceVector::device() const { return impl_->device(); } + +template +T* HostDeviceVector::ptr_d(int device) { return impl_->ptr_d(device); } + +template +thrust::device_ptr HostDeviceVector::tbegin(int device) { + return impl_->tbegin(device); +} + +template +thrust::device_ptr HostDeviceVector::tend(int device) { + return impl_->tend(device); +} + +template +std::vector& HostDeviceVector::data_h() { return impl_->data_h(); } + +template +void HostDeviceVector::resize(size_t new_size, int new_device) { + impl_->resize(new_size, new_device); +} + +// explicit instantiations are required, as HostDeviceVector isn't header-only +template class HostDeviceVector; +template class HostDeviceVector; + +} // namespace xgboost diff --git a/src/common/host_device_vector.h b/src/common/host_device_vector.h new file mode 100644 index 000000000..fc0ca0660 --- /dev/null +++ b/src/common/host_device_vector.h @@ -0,0 +1,100 @@ +/*! + * Copyright 2017 XGBoost contributors + */ +#ifndef XGBOOST_COMMON_HOST_DEVICE_VECTOR_H_ +#define XGBOOST_COMMON_HOST_DEVICE_VECTOR_H_ + +#include +#include + +// only include thrust-related files if host_device_vector.h +// is included from a .cu file +#ifdef __CUDACC__ +#include +#endif + +namespace xgboost { + +template struct HostDeviceVectorImpl; + +/** + * @file host_device_vector.h + * @brief A device-and-host vector abstraction layer. + * + * Why HostDeviceVector?
+ * With CUDA, one has to explicitly manage memory through 'cudaMemcpy' calls. + * This wrapper class hides this management from the users, thereby making it + * easy to integrate GPU/CPU usage under a single interface. + * + * Initialization/Allocation:
 * The vector can be initialized on either the CPU or the GPU in the
 * constructor (use the 'device' argument), or memory can be
 * allocated/resized explicitly later with the 'resize' method.
 * A brief usage sketch follows.
 *
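 * A minimal usage sketch (illustrative only; assumes a CUDA build and that
 * GPU 0 exists):
 *
 *   HostDeviceVector<bst_float> v(100, 0);   // 100 floats, allocated on GPU 0
 *   bst_float* dp = v.ptr_d(0);              // device pointer; data already on GPU, no copy
 *   std::vector<bst_float>& h = v.data_h();  // syncs to host and returns the host data
 *   v.resize(200, -1);                       // data now lives on the host, so resizes there
 *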
 * Accessing underlying data:
 * Use the 'data_h' method to access the underlying std::vector. If you need
 * the raw device pointer, use the 'ptr_d' method. For the performance
 * implications of these calls, see below.
 *
 * Accessing underlying data and the performance implications:
 * There are 4 scenarios to be considered here:
 * - data_h, data on CPU --> no problem; the std::vector is returned immediately
 * - data_h, data on GPU --> a cudaMemcpy is issued internally; subsequent calls
 *   to data_h will NOT incur this penalty (assuming 'ptr_d' is not called in between)
 * - ptr_d, data on CPU  --> a cudaMemcpy is issued internally; subsequent calls
 *   to ptr_d will NOT incur this penalty (assuming 'data_h' is not called in between)
 * - ptr_d, data on GPU  --> no problem; the device pointer is returned immediately
 *
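 * For example (a sketch of the lazy syncing behaviour; assumes a CUDA build
 * and that GPU 0 exists):
 *
 *   HostDeviceVector<bst_float> v(1000, 0);  // data starts on GPU 0
 *   v.data_h();   // device -> host copy issued here
 *   v.data_h();   // no copy: the host copy is still current
 *   v.ptr_d(0);   // host -> device copy issued here
 *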
 * What if xgboost is compiled without CUDA?
 * In that case, there is a special fallback implementation that always works
 * with std::vector. This logic can be found in host_device_vector.cc.
 *
 * Why not consider CUDA unified memory?
+ * We did consider. However, it poses complications if we need to support both + * compiling with and without CUDA toolkit. It was easier to have + * 'HostDeviceVector' with a special-case implementation in host_device_vector.cc + * + * @note: This is not thread-safe! + */ +template +class HostDeviceVector { + public: + explicit HostDeviceVector(size_t size = 0, int device = -1); + ~HostDeviceVector(); + HostDeviceVector(const HostDeviceVector&) = delete; + HostDeviceVector(HostDeviceVector&&) = delete; + void operator=(const HostDeviceVector&) = delete; + void operator=(HostDeviceVector&&) = delete; + size_t size() const; + int device() const; + T* ptr_d(int device); + + // only define functions returning device_ptr + // if HostDeviceVector.h is included from a .cu file +#ifdef __CUDACC__ + thrust::device_ptr tbegin(int device); + thrust::device_ptr tend(int device); +#endif + + std::vector& data_h(); + void resize(size_t new_size, int new_device); + + // helper functions in case a function needs to be templated + // to work for both HostDeviceVector and std::vector + static std::vector& data_h(HostDeviceVector* v) { + return v->data_h(); + } + + static std::vector& data_h(std::vector* v) { + return *v; + } + + private: + HostDeviceVectorImpl* impl_; +}; + +} // namespace xgboost + +#endif // XGBOOST_COMMON_HOST_DEVICE_VECTOR_H_ diff --git a/src/common/math.h b/src/common/math.h index 6e594032e..fb2459f44 100644 --- a/src/common/math.h +++ b/src/common/math.h @@ -20,8 +20,8 @@ namespace common { * \param x input parameter * \return the transformed value. */ -inline float Sigmoid(float x) { - return 1.0f / (1.0f + std::exp(-x)); +XGBOOST_DEVICE inline float Sigmoid(float x) { + return 1.0f / (1.0f + expf(-x)); } inline avx::Float8 Sigmoid(avx::Float8 x) { diff --git a/src/gbm/gbm.cc b/src/gbm/gbm.cc index aa93e51db..4d7ee0975 100644 --- a/src/gbm/gbm.cc +++ b/src/gbm/gbm.cc @@ -21,6 +21,19 @@ GradientBooster* GradientBooster::Create( } return (e->body)(cache_mats, base_margin); } + +void GradientBooster::DoBoost(DMatrix* p_fmat, + HostDeviceVector* in_gpair, + ObjFunction* obj) { + DoBoost(p_fmat, &in_gpair->data_h(), obj); +} + +void GradientBooster::PredictBatch(DMatrix* dmat, + HostDeviceVector* out_preds, + unsigned ntree_limit) { + PredictBatch(dmat, &out_preds->data_h(), ntree_limit); +} + } // namespace xgboost namespace xgboost { diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index c9055d7c0..0e80386bd 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -18,6 +18,7 @@ #include #include #include "../common/common.h" +#include "../common/host_device_vector.h" #include "../common/random.h" #include "gbtree_model.h" #include "../common/timer.h" @@ -182,35 +183,13 @@ class GBTree : public GradientBooster { void DoBoost(DMatrix* p_fmat, std::vector* in_gpair, ObjFunction* obj) override { - const std::vector& gpair = *in_gpair; - std::vector > > new_trees; - const int ngroup = model_.param.num_output_group; - monitor.Start("BoostNewTrees"); - if (ngroup == 1) { - std::vector > ret; - BoostNewTrees(gpair, p_fmat, 0, &ret); - new_trees.push_back(std::move(ret)); - } else { - CHECK_EQ(gpair.size() % ngroup, 0U) - << "must have exactly ngroup*nrow gpairs"; - std::vector tmp(gpair.size() / ngroup); - for (int gid = 0; gid < ngroup; ++gid) { - bst_omp_uint nsize = static_cast(tmp.size()); - #pragma omp parallel for schedule(static) - for (bst_omp_uint i = 0; i < nsize; ++i) { - tmp[i] = gpair[i * ngroup + gid]; - } - std::vector > ret; - BoostNewTrees(tmp, p_fmat, gid, &ret); - 
new_trees.push_back(std::move(ret)); - } - } - monitor.Stop("BoostNewTrees"); - monitor.Start("CommitModel"); - for (int gid = 0; gid < ngroup; ++gid) { - this->CommitModel(std::move(new_trees[gid]), gid); - } - monitor.Stop("CommitModel"); + DoBoostHelper(p_fmat, in_gpair, obj); + } + + void DoBoost(DMatrix* p_fmat, + HostDeviceVector* in_gpair, + ObjFunction* obj) override { + DoBoostHelper(p_fmat, in_gpair, obj); } void PredictBatch(DMatrix* p_fmat, @@ -219,6 +198,12 @@ class GBTree : public GradientBooster { predictor->PredictBatch(p_fmat, out_preds, model_, 0, ntree_limit); } + void PredictBatch(DMatrix* p_fmat, + HostDeviceVector* out_preds, + unsigned ntree_limit) override { + predictor->PredictBatch(p_fmat, out_preds, model_, 0, ntree_limit); + } + void PredictInstance(const SparseBatch::Inst& inst, std::vector* out_preds, unsigned ntree_limit, @@ -257,9 +242,48 @@ class GBTree : public GradientBooster { updaters.push_back(std::move(up)); } } + + // TVec is either std::vector or HostDeviceVector + template + void DoBoostHelper(DMatrix* p_fmat, + TVec* in_gpair, + ObjFunction* obj) { + std::vector > > new_trees; + const int ngroup = model_.param.num_output_group; + monitor.Start("BoostNewTrees"); + if (ngroup == 1) { + std::vector > ret; + BoostNewTrees(in_gpair, p_fmat, 0, &ret); + new_trees.push_back(std::move(ret)); + } else { + CHECK_EQ(in_gpair->size() % ngroup, 0U) + << "must have exactly ngroup*nrow gpairs"; + std::vector tmp(in_gpair->size() / ngroup); + auto& gpair_h = HostDeviceVector::data_h(in_gpair); + for (int gid = 0; gid < ngroup; ++gid) { + bst_omp_uint nsize = static_cast(tmp.size()); + #pragma omp parallel for schedule(static) + for (bst_omp_uint i = 0; i < nsize; ++i) { + tmp[i] = gpair_h[i * ngroup + gid]; + } + std::vector > ret; + BoostNewTrees(&tmp, p_fmat, gid, &ret); + new_trees.push_back(std::move(ret)); + } + } + monitor.Stop("BoostNewTrees"); + monitor.Start("CommitModel"); + for (int gid = 0; gid < ngroup; ++gid) { + this->CommitModel(std::move(new_trees[gid]), gid); + } + monitor.Stop("CommitModel"); + } + // do group specific group + // TVec is either const std::vector or HostDeviceVector + template inline void - BoostNewTrees(const std::vector &gpair, + BoostNewTrees(TVec* gpair, DMatrix *p_fmat, int bst_group, std::vector >* ret) { @@ -286,9 +310,24 @@ class GBTree : public GradientBooster { } // update the trees for (auto& up : updaters) { - up->Update(gpair, p_fmat, new_trees); + UpdateHelper(up.get(), gpair, p_fmat, new_trees); } } + + void UpdateHelper(TreeUpdater* updater, + std::vector* gpair, + DMatrix *p_fmat, + const std::vector& new_trees) { + updater->Update(*gpair, p_fmat, new_trees); + } + + void UpdateHelper(TreeUpdater* updater, + HostDeviceVector* gpair, + DMatrix *p_fmat, + const std::vector& new_trees) { + updater->Update(gpair, p_fmat, new_trees); + } + // commit new trees all at once virtual void CommitModel(std::vector >&& new_trees, diff --git a/src/learner.cc b/src/learner.cc index 117f73c92..321338d4a 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -16,10 +16,12 @@ #include #include #include "./common/common.h" +#include "./common/host_device_vector.h" #include "./common/io.h" #include "./common/random.h" #include "common/timer.h" + namespace xgboost { // implementation of base learner. 
bool Learner::AllowLazyCheckPoint() const { @@ -360,10 +362,10 @@ class LearnerImpl : public Learner { } this->LazyInitDMatrix(train); monitor.Start("PredictRaw"); - this->PredictRaw(train, &preds_); + this->PredictRaw(train, &preds2_); monitor.Stop("PredictRaw"); monitor.Start("GetGradient"); - obj_->GetGradient(preds_, train->info(), iter, &gpair_); + obj_->GetGradient(&preds2_, train->info(), iter, &gpair_); monitor.Stop("GetGradient"); gbm_->DoBoost(train, &gpair_, obj_.get()); monitor.Stop("UpdateOneIter"); @@ -547,6 +549,13 @@ class LearnerImpl : public Learner { << "Predict must happen after Load or InitModel"; gbm_->PredictBatch(data, out_preds, ntree_limit); } + inline void PredictRaw(DMatrix* data, HostDeviceVector* out_preds, + unsigned ntree_limit = 0) const { + CHECK(gbm_.get() != nullptr) + << "Predict must happen after Load or InitModel"; + gbm_->PredictBatch(data, out_preds, ntree_limit); + } + // model parameter LearnerModelParam mparam; // training parameter @@ -561,8 +570,9 @@ class LearnerImpl : public Learner { std::string name_obj_; // temporal storages for prediction std::vector preds_; + HostDeviceVector preds2_; // gradient pairs - std::vector gpair_; + HostDeviceVector gpair_; private: /*! \brief random number transformation seed. */ diff --git a/src/objective/objective.cc b/src/objective/objective.cc index 413494d3d..53f52ac9f 100644 --- a/src/objective/objective.cc +++ b/src/objective/objective.cc @@ -6,6 +6,8 @@ #include #include +#include "../common/host_device_vector.h" + namespace dmlc { DMLC_REGISTRY_ENABLE(::xgboost::ObjFunctionReg); } // namespace dmlc @@ -22,12 +24,27 @@ ObjFunction* ObjFunction::Create(const std::string& name) { } return (e->body)(); } + +void ObjFunction::GetGradient(HostDeviceVector* preds, + const MetaInfo& info, + int iteration, + HostDeviceVector* out_gpair) { + GetGradient(preds->data_h(), info, iteration, &out_gpair->data_h()); +} + +void ObjFunction::PredTransform(HostDeviceVector *io_preds) { + PredTransform(&io_preds->data_h()); +} + } // namespace xgboost namespace xgboost { namespace obj { // List of files that will be force linked in static links. DMLC_REGISTRY_LINK_TAG(regression_obj); +#ifdef XGBOOST_USE_CUDA + DMLC_REGISTRY_LINK_TAG(regression_obj_gpu); +#endif DMLC_REGISTRY_LINK_TAG(multiclass_obj); DMLC_REGISTRY_LINK_TAG(rank_obj); } // namespace obj diff --git a/src/objective/regression_loss.h b/src/objective/regression_loss.h new file mode 100644 index 000000000..16cc27092 --- /dev/null +++ b/src/objective/regression_loss.h @@ -0,0 +1,110 @@ +/*! 
+ * Copyright 2017 XGBoost contributors + */ +#ifndef XGBOOST_OBJECTIVE_REGRESSION_LOSS_H_ +#define XGBOOST_OBJECTIVE_REGRESSION_LOSS_H_ + +#include +#include +#include +#include "../common/math.h" + +namespace xgboost { +namespace obj { + +// common regressions +// linear regression +struct LinearSquareLoss { + // duplication is necessary, as __device__ specifier + // cannot be made conditional on template parameter + XGBOOST_DEVICE static bst_float PredTransform(bst_float x) { return x; } + XGBOOST_DEVICE static bool CheckLabel(bst_float x) { return true; } + XGBOOST_DEVICE static bst_float FirstOrderGradient(bst_float predt, bst_float label) { + return predt - label; + } + XGBOOST_DEVICE static bst_float SecondOrderGradient(bst_float predt, bst_float label) { + return 1.0f; + } + template + static T PredTransform(T x) { return x; } + template + static T FirstOrderGradient(T predt, T label) { return predt - label; } + template + static T SecondOrderGradient(T predt, T label) { return T(1.0f); } + static bst_float ProbToMargin(bst_float base_score) { return base_score; } + static const char* LabelErrorMsg() { return ""; } + static const char* DefaultEvalMetric() { return "rmse"; } +}; + +// logistic loss for probability regression task +struct LogisticRegression { + // duplication is necessary, as __device__ specifier + // cannot be made conditional on template parameter + XGBOOST_DEVICE static bst_float PredTransform(bst_float x) { return common::Sigmoid(x); } + XGBOOST_DEVICE static bool CheckLabel(bst_float x) { return x >= 0.0f && x <= 1.0f; } + XGBOOST_DEVICE static bst_float FirstOrderGradient(bst_float predt, bst_float label) { + return predt - label; + } + XGBOOST_DEVICE static bst_float SecondOrderGradient(bst_float predt, bst_float label) { + const float eps = 1e-16f; + return fmaxf(predt * (1.0f - predt), eps); + } + template + static T PredTransform(T x) { return common::Sigmoid(x); } + template + static T FirstOrderGradient(T predt, T label) { return predt - label; } + template + static T SecondOrderGradient(T predt, T label) { + const T eps = T(1e-16f); + return std::max(predt * (T(1.0f) - predt), eps); + } + static bst_float ProbToMargin(bst_float base_score) { + CHECK(base_score > 0.0f && base_score < 1.0f) + << "base_score must be in (0,1) for logistic loss"; + return -logf(1.0f / base_score - 1.0f); + } + static const char* LabelErrorMsg() { + return "label must be in [0,1] for logistic regression"; + } + static const char* DefaultEvalMetric() { return "rmse"; } +}; + +// logistic loss for binary classification task +struct LogisticClassification : public LogisticRegression { + static const char* DefaultEvalMetric() { return "error"; } +}; + +// logistic loss, but predict un-transformed margin +struct LogisticRaw : public LogisticRegression { + // duplication is necessary, as __device__ specifier + // cannot be made conditional on template parameter + XGBOOST_DEVICE static bst_float PredTransform(bst_float x) { return x; } + XGBOOST_DEVICE static bst_float FirstOrderGradient(bst_float predt, bst_float label) { + predt = common::Sigmoid(predt); + return predt - label; + } + XGBOOST_DEVICE static bst_float SecondOrderGradient(bst_float predt, bst_float label) { + const float eps = 1e-16f; + predt = common::Sigmoid(predt); + return fmaxf(predt * (1.0f - predt), eps); + } + template + static T PredTransform(T x) { return x; } + template + static T FirstOrderGradient(T predt, T label) { + predt = common::Sigmoid(predt); + return predt - label; + } + template + static T 
SecondOrderGradient(T predt, T label) { + const T eps = T(1e-16f); + predt = common::Sigmoid(predt); + return std::max(predt * (T(1.0f) - predt), eps); + } + static const char* DefaultEvalMetric() { return "auc"; } +}; + +} // namespace obj +} // namespace xgboost + +#endif // XGBOOST_OBJECTIVE_REGRESSION_LOSS_H_ diff --git a/src/objective/regression_obj.cc b/src/objective/regression_obj.cc index 83db376c8..92f1c5d3f 100644 --- a/src/objective/regression_obj.cc +++ b/src/objective/regression_obj.cc @@ -12,70 +12,13 @@ #include #include "../common/math.h" #include "../common/avx_helpers.h" +#include "./regression_loss.h" namespace xgboost { namespace obj { DMLC_REGISTRY_FILE_TAG(regression_obj); -// common regressions -// linear regression -struct LinearSquareLoss { - template - static T PredTransform(T x) { return x; } - static bool CheckLabel(bst_float x) { return true; } - template - static T FirstOrderGradient(T predt, T label) { return predt - label; } - template - static T SecondOrderGradient(T predt, T label) { return T(1.0f); } - static bst_float ProbToMargin(bst_float base_score) { return base_score; } - static const char* LabelErrorMsg() { return ""; } - static const char* DefaultEvalMetric() { return "rmse"; } -}; -// logistic loss for probability regression task -struct LogisticRegression { - template - static T PredTransform(T x) { return common::Sigmoid(x); } - static bool CheckLabel(bst_float x) { return x >= 0.0f && x <= 1.0f; } - template - static T FirstOrderGradient(T predt, T label) { return predt - label; } - template - static T SecondOrderGradient(T predt, T label) { - const T eps = T(1e-16f); - return std::max(predt * (T(1.0f) - predt), eps); - } - static bst_float ProbToMargin(bst_float base_score) { - CHECK(base_score > 0.0f && base_score < 1.0f) - << "base_score must be in (0,1) for logistic loss"; - return -std::log(1.0f / base_score - 1.0f); - } - static const char* LabelErrorMsg() { - return "label must be in [0,1] for logistic regression"; - } - static const char* DefaultEvalMetric() { return "rmse"; } -}; -// logistic loss for binary classification task. -struct LogisticClassification : public LogisticRegression { - static const char* DefaultEvalMetric() { return "error"; } -}; -// logistic loss, but predict un-transformed margin -struct LogisticRaw : public LogisticRegression { - template - static T PredTransform(T x) { return x; } - template - static T FirstOrderGradient(T predt, T label) { - predt = common::Sigmoid(predt); - return predt - label; - } - template - static T SecondOrderGradient(T predt, T label) { - const T eps = T(1e-16f); - predt = common::Sigmoid(predt); - return std::max(predt * (T(1.0f) - predt), eps); - } - static const char* DefaultEvalMetric() { return "auc"; } -}; - struct RegLossParam : public dmlc::Parameter { float scale_pos_weight; // declare parameters diff --git a/src/objective/regression_obj_gpu.cu b/src/objective/regression_obj_gpu.cu new file mode 100644 index 000000000..a55ad85bd --- /dev/null +++ b/src/objective/regression_obj_gpu.cu @@ -0,0 +1,241 @@ +/*! + * Copyright 2017 XGBoost contributors + */ +// GPU implementation of objective function. +// Necessary to avoid extra copying of data to CPU. 
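+//
+// For illustration (a sketch, not part of this patch itself): the objectives
+// registered in this file are created and configured through the usual
+// factory interface, e.g.
+//
+//   std::unique_ptr<xgboost::ObjFunction> obj(
+//       xgboost::ObjFunction::Create("gpu:reg:logistic"));
+//   obj->Configure({{"gpu_id", "0"}, {"scale_pos_weight", "2"}});
+//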
+#include +#include +#include +#include +#include +#include +#include +#include + +#include "../common/device_helpers.cuh" +#include "../common/host_device_vector.h" +#include "./regression_loss.h" + +using namespace dh; + +namespace xgboost { +namespace obj { + +DMLC_REGISTRY_FILE_TAG(regression_obj_gpu); + +struct GPURegLossParam : public dmlc::Parameter { + float scale_pos_weight; + int n_gpus; + int gpu_id; + // declare parameters + DMLC_DECLARE_PARAMETER(GPURegLossParam) { + DMLC_DECLARE_FIELD(scale_pos_weight).set_default(1.0f).set_lower_bound(0.0f) + .describe("Scale the weight of positive examples by this factor"); + DMLC_DECLARE_FIELD(n_gpus).set_default(1).set_lower_bound(-1) + .describe("Number of GPUs to use for multi-gpu algorithms (NOT IMPLEMENTED)"); + DMLC_DECLARE_FIELD(gpu_id) + .set_lower_bound(0) + .set_default(0) + .describe("gpu to use for objective function evaluation"); + } +}; + +// GPU kernel for gradient computation +template +__global__ void get_gradient_k +(bst_gpair *__restrict__ out_gpair, uint *__restrict__ label_correct, + const float * __restrict__ preds, const float * __restrict__ labels, + const float * __restrict__ weights, int n, float scale_pos_weight) { + int i = threadIdx.x + blockIdx.x * blockDim.x; + if (i >= n) + return; + float p = Loss::PredTransform(preds[i]); + float w = weights == nullptr ? 1.0f : weights[i]; + float label = labels[i]; + if (label == 1.0f) + w *= scale_pos_weight; + if (!Loss::CheckLabel(label)) + atomicAnd(label_correct, 0); + out_gpair[i] = bst_gpair + (Loss::FirstOrderGradient(p, label) * w, Loss::SecondOrderGradient(p, label) * w); +} + +// GPU kernel for predicate transformation +template +__global__ void pred_transform_k(float * __restrict__ preds, int n) { + int i = threadIdx.x + blockIdx.x * blockDim.x; + if (i >= n) + return; + preds[i] = Loss::PredTransform(preds[i]); +} + +// regression loss function for evaluation on GPU (eventually) +template +class GPURegLossObj : public ObjFunction { + protected: + // manages device data + struct DeviceData { + dvec labels, weights; + dvec label_correct; + + // allocate everything on device + DeviceData(bulk_allocator* ba, int device_idx, size_t n) { + ba->allocate(device_idx, false, + &labels, n, + &weights, n, + &label_correct, 1); + } + size_t size() const { return labels.size(); } + }; + + + bool copied_; + std::unique_ptr> ba_; + std::unique_ptr data_; + HostDeviceVector preds_d_; + HostDeviceVector out_gpair_d_; + + // allocate device data for n elements, do nothing if enough memory is allocated already + void LazyResize(int n) { + if (data_.get() != nullptr && data_->size() >= n) + return; + copied_ = false; + // free the old data and allocate the new data + ba_.reset(new bulk_allocator()); + data_.reset(new DeviceData(ba_.get(), 0, n)); + preds_d_.resize(n, param_.gpu_id); + out_gpair_d_.resize(n, param_.gpu_id); + } + + public: + GPURegLossObj() : copied_(false), preds_d_(0, -1), out_gpair_d_(0, -1) {} + + void Configure(const std::vector >& args) override { + param_.InitAllowUnknown(args); + CHECK(param_.n_gpus != 0) << "Must have at least one device"; + } + void GetGradient(const std::vector &preds, + const MetaInfo &info, + int iter, + std::vector *out_gpair) override { + CHECK_NE(info.labels.size(), 0U) << "label set cannot be empty"; + CHECK_EQ(preds.size(), info.labels.size()) + << "labels are not correctly provided" + << "preds.size=" << preds.size() << ", label.size=" << info.labels.size(); + + size_t ndata = preds.size(); + out_gpair->resize(ndata); + 
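+    // This host-vector path stages data through preds_d_ and out_gpair_d_,
+    // so the PCIe traffic per call stays minimal: one host->device copy of
+    // the predictions and one device->host copy of the computed gradients.
+    // Labels and weights are copied to the device only once and reused
+    // across iterations (see the copied_ flag in GetGradientDevice).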
LazyResize(ndata); + thrust::copy(preds.begin(), preds.end(), preds_d_.tbegin(param_.gpu_id)); + GetGradientDevice(preds_d_.ptr_d(param_.gpu_id), info, iter, + out_gpair_d_.ptr_d(param_.gpu_id), ndata); + thrust::copy_n(out_gpair_d_.tbegin(param_.gpu_id), ndata, out_gpair->begin()); + } + + void GetGradient(HostDeviceVector* preds, + const MetaInfo &info, + int iter, + HostDeviceVector* out_gpair) override { + CHECK_NE(info.labels.size(), 0U) << "label set cannot be empty"; + CHECK_EQ(preds->size(), info.labels.size()) + << "labels are not correctly provided" + << "preds.size=" << preds->size() << ", label.size=" << info.labels.size(); + size_t ndata = preds->size(); + out_gpair->resize(ndata, param_.gpu_id); + LazyResize(ndata); + GetGradientDevice(preds->ptr_d(param_.gpu_id), info, iter, + out_gpair->ptr_d(param_.gpu_id), ndata); + } + + private: + void GetGradientDevice(float* preds, + const MetaInfo &info, + int iter, + bst_gpair* out_gpair, size_t n) { + safe_cuda(cudaSetDevice(param_.gpu_id)); + DeviceData& d = *data_; + d.label_correct.fill(1); + // only copy the labels and weights once, similar to how the data is copied + if (!copied_) { + thrust::copy(info.labels.begin(), info.labels.begin() + n, + d.labels.tbegin()); + if (info.weights.size() > 0) { + thrust::copy(info.weights.begin(), info.weights.begin() + n, + d.weights.tbegin()); + } + copied_ = true; + } + + // run the kernel + const int block = 256; + get_gradient_k<<>> + (out_gpair, d.label_correct.data(), preds, + d.labels.data(), info.weights.size() > 0 ? d.weights.data() : nullptr, + n, param_.scale_pos_weight); + safe_cuda(cudaGetLastError()); + + // copy output data from the GPU + uint label_correct_h; + thrust::copy_n(d.label_correct.tbegin(), 1, &label_correct_h); + + bool label_correct = label_correct_h != 0; + if (!label_correct) { + LOG(FATAL) << Loss::LabelErrorMsg(); + } + } + + public: + const char* DefaultEvalMetric() const override { + return Loss::DefaultEvalMetric(); + } + + void PredTransform(std::vector *io_preds) override { + LazyResize(io_preds->size()); + thrust::copy(io_preds->begin(), io_preds->end(), preds_d_.tbegin(param_.gpu_id)); + PredTransformDevice(preds_d_.ptr_d(param_.gpu_id), io_preds->size()); + thrust::copy_n(preds_d_.tbegin(param_.gpu_id), io_preds->size(), io_preds->begin()); + } + + void PredTransform(HostDeviceVector *io_preds) override { + PredTransformDevice(io_preds->ptr_d(param_.gpu_id), io_preds->size()); + } + + void PredTransformDevice(float* preds, size_t n) { + safe_cuda(cudaSetDevice(param_.gpu_id)); + const int block = 256; + pred_transform_k<<>>(preds, n); + safe_cuda(cudaGetLastError()); + safe_cuda(cudaDeviceSynchronize()); + } + + + float ProbToMargin(float base_score) const override { + return Loss::ProbToMargin(base_score); + } + + protected: + GPURegLossParam param_; +}; + +// register the objective functions +DMLC_REGISTER_PARAMETER(GPURegLossParam); + +XGBOOST_REGISTER_OBJECTIVE(GPULinearRegression, "gpu:reg:linear") +.describe("Linear regression (computed on GPU).") +.set_body([]() { return new GPURegLossObj(); }); + +XGBOOST_REGISTER_OBJECTIVE(GPULogisticRegression, "gpu:reg:logistic") +.describe("Logistic regression for probability regression task (computed on GPU).") +.set_body([]() { return new GPURegLossObj(); }); + +XGBOOST_REGISTER_OBJECTIVE(GPULogisticClassification, "gpu:binary:logistic") +.describe("Logistic regression for binary classification task (computed on GPU).") +.set_body([]() { return new GPURegLossObj(); }); + 
+XGBOOST_REGISTER_OBJECTIVE(GPULogisticRaw, "gpu:binary:logitraw") +.describe("Logistic regression for classification, output score " + "before logistic transformation (computed on GPU)") +.set_body([]() { return new GPURegLossObj(); }); + +} // namespace obj +} // namespace xgboost diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index de0a85c49..a56a3d85e 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -5,6 +5,7 @@ #include #include #include "dmlc/logging.h" +#include "../common/host_device_vector.h" namespace xgboost { namespace predictor { @@ -108,6 +109,12 @@ class CPUPredictor : public Predictor { } public: + void PredictBatch(DMatrix* dmat, HostDeviceVector* out_preds, + const gbm::GBTreeModel& model, int tree_begin, + unsigned ntree_limit = 0) override { + PredictBatch(dmat, &out_preds->data_h(), model, tree_begin, ntree_limit); + } + void PredictBatch(DMatrix* dmat, std::vector* out_preds, const gbm::GBTreeModel& model, int tree_begin, unsigned ntree_limit = 0) override { diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index d3d7cf421..aace2ebe6 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -3,13 +3,16 @@ */ #include #include +#include #include +#include #include #include #include #include #include #include "../common/device_helpers.cuh" +#include "../common/host_device_vector.h" namespace xgboost { namespace predictor { @@ -247,8 +250,16 @@ __global__ void PredictKernel(const DevicePredictionNode* d_nodes, } class GPUPredictor : public xgboost::Predictor { + protected: + struct DevicePredictionCacheEntry { + std::shared_ptr data; + HostDeviceVector predictions; + }; + + std::unordered_map device_cache_; + private: - void DevicePredictInternal(DMatrix* dmat, std::vector* out_preds, + void DevicePredictInternal(DMatrix* dmat, HostDeviceVector* out_preds, const gbm::GBTreeModel& model, size_t tree_begin, size_t tree_end) { if (tree_end - tree_begin == 0) { @@ -293,7 +304,7 @@ class GPUPredictor : public xgboost::Predictor { tree_group.begin()); device_matrix->predictions.resize(out_preds->size()); - thrust::copy(out_preds->begin(), out_preds->end(), + thrust::copy(out_preds->tbegin(param.gpu_id), out_preds->tend(param.gpu_id), device_matrix->predictions.begin()); const int BLOCK_THREADS = 128; @@ -319,19 +330,30 @@ class GPUPredictor : public xgboost::Predictor { dh::safe_cuda(cudaDeviceSynchronize()); thrust::copy(device_matrix->predictions.begin(), - device_matrix->predictions.end(), out_preds->begin()); + device_matrix->predictions.end(), out_preds->tbegin(param.gpu_id)); } + public: GPUPredictor() : cpu_predictor(Predictor::Create("cpu_predictor")) {} void PredictBatch(DMatrix* dmat, std::vector* out_preds, const gbm::GBTreeModel& model, int tree_begin, unsigned ntree_limit = 0) override { - if (this->PredictFromCache(dmat, out_preds, model, ntree_limit)) { + HostDeviceVector out_preds_d; + PredictBatch(dmat, &out_preds_d, model, tree_begin, ntree_limit); + out_preds->resize(out_preds_d.size()); + thrust::copy(out_preds_d.tbegin(param.gpu_id), + out_preds_d.tend(param.gpu_id), out_preds->begin()); + } + + void PredictBatch(DMatrix* dmat, HostDeviceVector* out_preds, + const gbm::GBTreeModel& model, int tree_begin, + unsigned ntree_limit = 0) override { + if (this->PredictFromCacheDevice(dmat, out_preds, model, ntree_limit)) { return; } - this->InitOutPredictions(dmat->info(), out_preds, model); + this->InitOutPredictionsDevice(dmat->info(), 
out_preds, model); int tree_end = ntree_limit * model.param.num_output_group; if (ntree_limit == 0 || ntree_limit > model.trees.size()) { @@ -341,26 +363,78 @@ class GPUPredictor : public xgboost::Predictor { DevicePredictInternal(dmat, out_preds, model, tree_begin, tree_end); } - void UpdatePredictionCache( - const gbm::GBTreeModel& model, - std::vector>* updaters, - int num_new_trees) override { + + void InitOutPredictionsDevice(const MetaInfo& info, + HostDeviceVector* out_preds, + const gbm::GBTreeModel& model) const { + size_t n = model.param.num_output_group * info.num_row; + const std::vector& base_margin = info.base_margin; + out_preds->resize(n, param.gpu_id); + if (base_margin.size() != 0) { + CHECK_EQ(out_preds->size(), n); + thrust::copy(base_margin.begin(), base_margin.end(), out_preds->tbegin(param.gpu_id)); + } else { + thrust::fill(out_preds->tbegin(param.gpu_id), + out_preds->tend(param.gpu_id), model.base_margin); + } + } + + bool PredictFromCache(DMatrix* dmat, + std::vector* out_preds, + const gbm::GBTreeModel& model, + unsigned ntree_limit) { + HostDeviceVector out_preds_d(0, -1); + bool result = PredictFromCacheDevice(dmat, &out_preds_d, model, ntree_limit); + if (!result) return false; + out_preds->resize(out_preds_d.size(), param.gpu_id); + thrust::copy(out_preds_d.tbegin(param.gpu_id), + out_preds_d.tend(param.gpu_id), out_preds->begin()); + return true; + } + + bool PredictFromCacheDevice(DMatrix* dmat, + HostDeviceVector* out_preds, + const gbm::GBTreeModel& model, + unsigned ntree_limit) { + if (ntree_limit == 0 || + ntree_limit * model.param.num_output_group >= model.trees.size()) { + auto it = device_cache_.find(dmat); + if (it != device_cache_.end()) { + HostDeviceVector& y = it->second.predictions; + if (y.size() != 0) { + out_preds->resize(y.size(), param.gpu_id); + thrust::copy(y.tbegin(param.gpu_id), y.tend(param.gpu_id), + out_preds->tbegin(param.gpu_id)); + return true; + } + } + } + + return false; + } + + void UpdatePredictionCache(const gbm::GBTreeModel& model, + std::vector>* updaters, + int num_new_trees) override { auto old_ntree = model.trees.size() - num_new_trees; // update cache entry - for (auto& kv : cache_) { - PredictionCacheEntry& e = kv.second; + for (auto& kv : device_cache_) { + DevicePredictionCacheEntry& e = kv.second; DMatrix* dmat = kv.first; + HostDeviceVector& predictions = e.predictions; - if (e.predictions.size() == 0) { - cpu_predictor->PredictBatch(dmat, &(e.predictions), model, 0, + if (predictions.size() == 0) { + // ensure that the device in predictions is correct + predictions.resize(0, param.gpu_id); + cpu_predictor->PredictBatch(dmat, &predictions.data_h(), model, 0, static_cast(model.trees.size())); } else if (model.param.num_output_group == 1 && updaters->size() > 0 && num_new_trees == 1 && updaters->back()->UpdatePredictionCache(e.data.get(), - &(e.predictions))) { - {} // do nothing + &predictions)) { + // do nothing } else { - DevicePredictInternal(dmat, &(e.predictions), model, old_ntree, + DevicePredictInternal(dmat, &predictions, model, old_ntree, model.trees.size()); } } @@ -391,6 +465,8 @@ class GPUPredictor : public xgboost::Predictor { Predictor::Init(cfg, cache); cpu_predictor->Init(cfg, cache); param.InitAllowUnknown(cfg); + for (const std::shared_ptr& d : cache) + device_cache_[d.get()].data = d; max_shared_memory_bytes = dh::max_shared_memory(param.gpu_id); } diff --git a/src/tree/tree_updater.cc b/src/tree/tree_updater.cc index 32630903e..2ca949e21 100644 --- a/src/tree/tree_updater.cc +++ 
b/src/tree/tree_updater.cc @@ -6,6 +6,8 @@ #include #include +#include "../common/host_device_vector.h" + namespace dmlc { DMLC_REGISTRY_ENABLE(::xgboost::TreeUpdaterReg); } // namespace dmlc @@ -20,6 +22,17 @@ TreeUpdater* TreeUpdater::Create(const std::string& name) { return (e->body)(); } +void TreeUpdater::Update(HostDeviceVector* gpair, + DMatrix* data, + const std::vector& trees) { + Update(gpair->data_h(), data, trees); +} + +bool TreeUpdater::UpdatePredictionCache(const DMatrix* data, + HostDeviceVector* out_preds) { + return UpdatePredictionCache(data, &out_preds->data_h()); +} + } // namespace xgboost namespace xgboost { diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 6930ec1e2..3b778f355 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -12,6 +12,7 @@ #include #include "../common/compressed_iterator.h" #include "../common/device_helpers.cuh" +#include "../common/host_device_vector.h" #include "../common/hist_util.h" #include "../common/timer.h" #include "param.h" @@ -349,7 +350,8 @@ struct DeviceShard { } // Reset values for each update iteration - void Reset(const std::vector& host_gpair) { + void Reset(HostDeviceVector* dh_gpair, int device) { + auto begin = dh_gpair->tbegin(device); dh::safe_cuda(cudaSetDevice(device_idx)); position.current_dvec().fill(0); std::fill(node_sum_gradients.begin(), node_sum_gradients.end(), @@ -359,8 +361,8 @@ struct DeviceShard { std::fill(ridx_segments.begin(), ridx_segments.end(), Segment(0, 0)); ridx_segments.front() = Segment(0, ridx.size()); - this->gpair.copy(host_gpair.begin() + row_begin_idx, - host_gpair.begin() + row_end_idx); + this->gpair.copy(begin + row_begin_idx, + begin + row_end_idx); subsample_gpair(&gpair, param.subsample, row_begin_idx); hist.Reset(); } @@ -504,9 +506,28 @@ class GPUHistMaker : public TreeUpdater { monitor.Init("updater_gpu_hist", param.debug_verbose); } + void Update(const std::vector& gpair, DMatrix* dmat, const std::vector& trees) override { monitor.Start("Update", dList); + // TODO(canonizer): move it into the class if this ever becomes a bottleneck + HostDeviceVector gpair_d(gpair.size(), param.gpu_id); + dh::safe_cuda(cudaSetDevice(param.gpu_id)); + thrust::copy(gpair.begin(), gpair.end(), gpair_d.tbegin(param.gpu_id)); + Update(&gpair_d, dmat, trees); + monitor.Stop("Update", dList); + } + + void Update(HostDeviceVector* gpair, DMatrix* dmat, + const std::vector& trees) override { + monitor.Start("Update", dList); + UpdateHelper(gpair, dmat, trees); + monitor.Stop("Update", dList); + } + + private: + void UpdateHelper(HostDeviceVector* gpair, DMatrix* dmat, + const std::vector& trees) { GradStats::CheckInfo(dmat->info()); // rescale learning rate according to size of trees float lr = param.learning_rate; @@ -521,9 +542,9 @@ class GPUHistMaker : public TreeUpdater { LOG(FATAL) << "GPU plugin exception: " << e.what() << std::endl; } param.learning_rate = lr; - monitor.Stop("Update", dList); } + public: void InitDataOnce(DMatrix* dmat) { info = &dmat->info(); monitor.Start("Quantiles", dList); @@ -572,7 +593,7 @@ class GPUHistMaker : public TreeUpdater { initialised = true; } - void InitData(const std::vector& gpair, DMatrix* dmat, + void InitData(HostDeviceVector* gpair, DMatrix* dmat, const RegTree& tree) { monitor.Start("InitDataOnce", dList); if (!initialised) { @@ -585,11 +606,10 @@ class GPUHistMaker : public TreeUpdater { // Copy gpair & reset memory monitor.Start("InitDataReset", dList); omp_set_num_threads(shards.size()); -#pragma omp 
parallel - { - auto cpu_thread_id = omp_get_thread_num(); - shards[cpu_thread_id]->Reset(gpair); - } + + // TODO(canonizer): make it parallel again once HostDeviceVector is thread-safe + for (int shard = 0; shard < shards.size(); ++shard) + shards[shard]->Reset(gpair, param.gpu_id); monitor.Stop("InitDataReset", dList); } @@ -687,7 +707,7 @@ class GPUHistMaker : public TreeUpdater { return std::move(best_splits); } - void InitRoot(const std::vector& gpair, RegTree* p_tree) { + void InitRoot(RegTree* p_tree) { auto root_nidx = 0; // Sum gradients std::vector tmp_sums(shards.size()); @@ -800,7 +820,7 @@ class GPUHistMaker : public TreeUpdater { this->UpdatePosition(candidate, p_tree); } - void UpdateTree(const std::vector& gpair, DMatrix* p_fmat, + void UpdateTree(HostDeviceVector* gpair, DMatrix* p_fmat, RegTree* p_tree) { // Temporarily store number of threads so we can change it back later int nthread = omp_get_max_threads(); @@ -811,7 +831,7 @@ class GPUHistMaker : public TreeUpdater { this->InitData(gpair, p_fmat, *p_tree); monitor.Stop("InitData", dList); monitor.Start("InitRoot", dList); - this->InitRoot(gpair, p_tree); + this->InitRoot(p_tree); monitor.Stop("InitRoot", dList); auto timestamp = qexpand_->size(); @@ -854,6 +874,16 @@ class GPUHistMaker : public TreeUpdater { omp_set_num_threads(nthread); } + bool UpdatePredictionCache(const DMatrix* data, + std::vector* p_out_preds) override { + return false; + } + + bool UpdatePredictionCache(const DMatrix* data, + HostDeviceVector* p_out_preds) override { + return false; + } + struct ExpandEntry { int nid; int depth; diff --git a/tests/cpp/objective/test_regression_obj_gpu.cu b/tests/cpp/objective/test_regression_obj_gpu.cu new file mode 100644 index 000000000..0e507dc07 --- /dev/null +++ b/tests/cpp/objective/test_regression_obj_gpu.cu @@ -0,0 +1,69 @@ +/*! 
+ * Copyright 2017 XGBoost contributors + */ +#include + +#include "../helpers.h" + +TEST(Objective, GPULinearRegressionGPair) { + xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("gpu:reg:linear"); + std::vector > args; + obj->Configure(args); + CheckObjFunction(obj, + {0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1}, + {0, 0, 0, 0, 1, 1, 1, 1}, + {1, 1, 1, 1, 1, 1, 1, 1}, + {0, 0.1f, 0.9f, 1.0f, -1.0f, -0.9f, -0.1f, 0}, + {1, 1, 1, 1, 1, 1, 1, 1}); + + ASSERT_NO_THROW(obj->DefaultEvalMetric()); +} + +TEST(Objective, GPULogisticRegressionGPair) { + xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("gpu:reg:logistic"); + std::vector > args; + obj->Configure(args); + CheckObjFunction(obj, + { 0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1}, + { 0, 0, 0, 0, 1, 1, 1, 1}, + { 1, 1, 1, 1, 1, 1, 1, 1}, + { 0.5f, 0.52f, 0.71f, 0.73f, -0.5f, -0.47f, -0.28f, -0.26f}, + {0.25f, 0.24f, 0.20f, 0.19f, 0.25f, 0.24f, 0.20f, 0.19f}); +} + +TEST(Objective, GPULogisticRegressionBasic) { + xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("gpu:reg:logistic"); + std::vector > args; + obj->Configure(args); + + // test label validation + EXPECT_ANY_THROW(CheckObjFunction(obj, {0}, {10}, {1}, {0}, {0})) + << "Expected error when label not in range [0,1f] for LogisticRegression"; + + // test ProbToMargin + EXPECT_NEAR(obj->ProbToMargin(0.1f), -2.197f, 0.01f); + EXPECT_NEAR(obj->ProbToMargin(0.5f), 0, 0.01f); + EXPECT_NEAR(obj->ProbToMargin(0.9f), 2.197f, 0.01f); + EXPECT_ANY_THROW(obj->ProbToMargin(10)) + << "Expected error when base_score not in range [0,1f] for LogisticRegression"; + + // test PredTransform + std::vector preds = {0, 0.1f, 0.5f, 0.9f, 1}; + std::vector out_preds = {0.5f, 0.524f, 0.622f, 0.710f, 0.731f}; + obj->PredTransform(&preds); + for (int i = 0; i < static_cast(preds.size()); ++i) { + EXPECT_NEAR(preds[i], out_preds[i], 0.01f); + } +} + +TEST(Objective, GPULogisticRawGPair) { + xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("gpu:binary:logitraw"); + std::vector > args; + obj->Configure(args); + CheckObjFunction(obj, + { 0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1}, + { 0, 0, 0, 0, 1, 1, 1, 1}, + { 1, 1, 1, 1, 1, 1, 1, 1}, + { 0.5f, 0.52f, 0.71f, 0.73f, -0.5f, -0.47f, -0.28f, -0.26f}, + {0.25f, 0.24f, 0.20f, 0.19f, 0.25f, 0.24f, 0.20f, 0.19f}); +}
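+
+// A possible companion test (a sketch, not part of the original patch): it
+// exercises the HostDeviceVector overload of PredTransform directly, assuming
+// a CUDA build with at least one GPU and the default gpu_id of 0.
+TEST(Objective, GPULogisticPredTransformDevice) {
+  xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("gpu:reg:logistic");
+  std::vector<std::pair<std::string, std::string> > args;
+  obj->Configure(args);
+
+  // start with predictions on the host; ptr_d() inside PredTransform will
+  // lazily copy them to the device before the sigmoid kernel runs
+  xgboost::HostDeviceVector<xgboost::bst_float> preds(3);
+  preds.data_h() = {0.0f, 1.0f, -1.0f};
+  obj->PredTransform(&preds);
+
+  // data_h() syncs the transformed values back from the device
+  std::vector<xgboost::bst_float>& out = preds.data_h();
+  EXPECT_NEAR(out[0], 0.5f, 0.01f);
+  EXPECT_NEAR(out[1], 0.731f, 0.01f);
+  EXPECT_NEAR(out[2], 0.269f, 0.01f);
+
+  delete obj;
+}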