Objective function evaluation on GPU with minimal PCIe transfers (#2935)
* Added GPU objective function and no-copy interface: xgboost::HostDeviceVector<T> syncs automatically between host and device; no-copy interfaces have been added; default implementations just sync the data to host and call the implementations with std::vector; the GPU objective function, predictor, and histogram updater process data directly on the GPU.
This commit is contained in:
@@ -484,6 +484,13 @@ class bulk_allocator {
|
||||
}
|
||||
|
||||
public:
|
||||
bulk_allocator() {}
|
||||
// prevent accidental copying, moving or assignment of this object
|
||||
bulk_allocator(const bulk_allocator<MemoryT>&) = delete;
|
||||
bulk_allocator(bulk_allocator<MemoryT>&&) = delete;
|
||||
void operator=(const bulk_allocator<MemoryT>&) = delete;
|
||||
void operator=(bulk_allocator<MemoryT>&&) = delete;
|
||||
|
||||
~bulk_allocator() {
|
||||
for (size_t i = 0; i < d_ptr.size(); i++) {
|
||||
if (!(d_ptr[i] == nullptr)) {
|
||||
|
||||
54
src/common/host_device_vector.cc
Normal file
54
src/common/host_device_vector.cc
Normal file
@@ -0,0 +1,54 @@
|
||||
/*!
 * Copyright 2017 XGBoost contributors
 */
#ifndef XGBOOST_USE_CUDA

// Host-only fallback for HostDeviceVector, compiled when xgboost is built
// without CUDA. All storage lives in a std::vector and every device-related
// query reports "no device".

#include <xgboost/base.h>
#include "./host_device_vector.h"

namespace xgboost {

// Pimpl payload: nothing but the host-side buffer.
template <typename T>
struct HostDeviceVectorImpl {
  explicit HostDeviceVectorImpl(size_t size) : data_h_(size) {}
  std::vector<T> data_h_;
};

// The 'device' argument is accepted for interface compatibility with the
// CUDA build but has no effect here.
template <typename T>
HostDeviceVector<T>::HostDeviceVector(size_t size, int device)
    : impl_(new HostDeviceVectorImpl<T>(size)) {}

template <typename T>
HostDeviceVector<T>::~HostDeviceVector() {
  // Detach the pointer before freeing it so a partially destroyed vector
  // never exposes a dangling impl_.
  HostDeviceVectorImpl<T>* detached = impl_;
  impl_ = nullptr;
  delete detached;
}

template <typename T>
size_t HostDeviceVector<T>::size() const { return impl_->data_h_.size(); }

// -1 means "the data resides on the host".
template <typename T>
int HostDeviceVector<T>::device() const { return -1; }

// No device memory exists in this build, so there is no device pointer.
template <typename T>
T* HostDeviceVector<T>::ptr_d(int device) { return nullptr; }

template <typename T>
std::vector<T>& HostDeviceVector<T>::data_h() { return impl_->data_h_; }

// 'new_device' is ignored; only the host buffer is resized.
template <typename T>
void HostDeviceVector<T>::resize(size_t new_size, int new_device) {
  impl_->data_h_.resize(new_size);
}

// explicit instantiations are required, as HostDeviceVector isn't header-only
template class HostDeviceVector<bst_float>;
template class HostDeviceVector<bst_gpair>;

}  // namespace xgboost

#endif
|
||||
135
src/common/host_device_vector.cu
Normal file
135
src/common/host_device_vector.cu
Normal file
@@ -0,0 +1,135 @@
|
||||
/*!
 * Copyright 2017 XGBoost contributors
 */
#include "./host_device_vector.h"
#include "./device_helpers.cuh"

namespace xgboost {

// CUDA-enabled backing store for HostDeviceVector. At most one copy of the
// data is authoritative at a time: data_h_ (host) or data_d_ (device), with
// on_d_ recording which side is current. Transfers happen lazily, on the
// first access from the other side.
template <typename T>
struct HostDeviceVectorImpl {
  // NOTE(review): members are declared in the order data_h_, data_d_, on_d_,
  // device_, so on_d_ is initialized before device_ despite the init-list
  // order (-Wreorder warning); harmless here because both initializers read
  // the constructor parameter 'device', not a member.
  HostDeviceVectorImpl(size_t size, int device)
    : device_(device), on_d_(device >= 0) {
    if (on_d_) {
      // device >= 0: allocate directly on the requested GPU.
      dh::safe_cuda(cudaSetDevice(device_));
      data_d_.resize(size);
    } else {
      // device == -1: host-only allocation.
      data_h_.resize(size);
    }
  }
  // Non-copyable and non-movable: the object owns device memory bound to a
  // specific GPU.
  HostDeviceVectorImpl(const HostDeviceVectorImpl<T>&) = delete;
  HostDeviceVectorImpl(HostDeviceVectorImpl<T>&&) = delete;
  void operator=(const HostDeviceVectorImpl<T>&) = delete;
  void operator=(HostDeviceVectorImpl<T>&&) = delete;

  // Size of whichever copy is currently authoritative.
  size_t size() const { return on_d_ ? data_d_.size() : data_h_.size(); }

  int device() const { return device_; }

  // Raw device pointer; syncs host -> device first if needed.
  T* ptr_d(int device) {
    lazy_sync_device(device);
    return data_d_.data().get();
  }
  thrust::device_ptr<T> tbegin(int device) {
    return thrust::device_ptr<T>(ptr_d(device));
  }
  thrust::device_ptr<T> tend(int device) {
    auto begin = tbegin(device);
    return begin + size();
  }
  // Host-side vector; syncs device -> host first if needed.
  std::vector<T>& data_h() {
    lazy_sync_host();
    return data_h_;
  }
  void resize(size_t new_size, int new_device) {
    if (new_size == this->size() && new_device == device_)
      return;
    device_ = new_device;
    // if !on_d_, but the data size is 0 and the device is set,
    // resize the data on device instead
    if (!on_d_ && (data_h_.size() > 0 || device_ == -1)) {
      data_h_.resize(new_size);
    } else {
      dh::safe_cuda(cudaSetDevice(device_));
      data_d_.resize(new_size);
      on_d_ = true;
    }
  }

  // Copy device -> host; no-op when the host copy is already current.
  void lazy_sync_host() {
    if (!on_d_)
      return;
    if (data_h_.size() != this->size())
      data_h_.resize(this->size());
    dh::safe_cuda(cudaSetDevice(device_));
    thrust::copy(data_d_.begin(), data_d_.end(), data_h_.begin());
    on_d_ = false;
  }

  // Copy host -> device; no-op when the device copy is already current.
  // 'device' may bind a previously host-only vector (device_ == -1) to a
  // GPU; requesting a different GPU than the one already bound is a logic
  // error (CHECK_EQ below fires).
  void lazy_sync_device(int device) {
    if (on_d_)
      return;
    if (device != device_) {
      CHECK_EQ(device_, -1);
      device_ = device;
    }
    if (data_d_.size() != this->size()) {
      dh::safe_cuda(cudaSetDevice(device_));
      data_d_.resize(this->size());
    }
    dh::safe_cuda(cudaSetDevice(device_));
    thrust::copy(data_h_.begin(), data_h_.end(), data_d_.begin());
    on_d_ = true;
  }

  std::vector<T> data_h_;
  thrust::device_vector<T> data_d_;
  // true if there is an up-to-date copy of data on device, false otherwise
  bool on_d_;
  int device_;
};

template <typename T>
HostDeviceVector<T>::HostDeviceVector(size_t size, int device) : impl_(nullptr) {
  impl_ = new HostDeviceVectorImpl<T>(size, device);
}

template <typename T>
HostDeviceVector<T>::~HostDeviceVector() {
  // Detach the pointer before freeing it so a partially destroyed vector
  // never exposes a dangling impl_.
  HostDeviceVectorImpl<T>* tmp = impl_;
  impl_ = nullptr;
  delete tmp;
}

template <typename T>
size_t HostDeviceVector<T>::size() const { return impl_->size(); }

template <typename T>
int HostDeviceVector<T>::device() const { return impl_->device(); }

template <typename T>
T* HostDeviceVector<T>::ptr_d(int device) { return impl_->ptr_d(device); }

template <typename T>
thrust::device_ptr<T> HostDeviceVector<T>::tbegin(int device) {
  return impl_->tbegin(device);
}

template <typename T>
thrust::device_ptr<T> HostDeviceVector<T>::tend(int device) {
  return impl_->tend(device);
}

template <typename T>
std::vector<T>& HostDeviceVector<T>::data_h() { return impl_->data_h(); }

template <typename T>
void HostDeviceVector<T>::resize(size_t new_size, int new_device) {
  impl_->resize(new_size, new_device);
}

// explicit instantiations are required, as HostDeviceVector isn't header-only
template class HostDeviceVector<bst_float>;
template class HostDeviceVector<bst_gpair>;

}  // namespace xgboost
|
||||
100
src/common/host_device_vector.h
Normal file
100
src/common/host_device_vector.h
Normal file
@@ -0,0 +1,100 @@
|
||||
/*!
 * Copyright 2017 XGBoost contributors
 */
#ifndef XGBOOST_COMMON_HOST_DEVICE_VECTOR_H_
#define XGBOOST_COMMON_HOST_DEVICE_VECTOR_H_

#include <cstdlib>
#include <vector>

// only include thrust-related files if host_device_vector.h
// is included from a .cu file
#ifdef __CUDACC__
#include <thrust/device_ptr.h>
#endif

namespace xgboost {

// Opaque implementation; defined in host_device_vector.cc (CPU-only build)
// or host_device_vector.cu (CUDA build).
template <typename T> struct HostDeviceVectorImpl;

/**
 * @file host_device_vector.h
 * @brief A device-and-host vector abstraction layer.
 *
 * Why HostDeviceVector?<br/>
 * With CUDA, one has to explicitly manage memory through 'cudaMemcpy' calls.
 * This wrapper class hides this management from the users, thereby making it
 * easy to integrate GPU/CPU usage under a single interface.
 *
 * Initialization/Allocation:<br/>
 * One can choose to initialize the vector on CPU or GPU during constructor.
 * (use the 'device' argument) Or, can choose to use the 'resize' method to
 * allocate/resize memory explicitly.
 *
 * Accessing underlying data:<br/>
 * Use 'data_h' method to explicitly query for the underlying std::vector.
 * If you need the raw device pointer, use the 'ptr_d' method. For perf
 * implications of these calls, see below.
 *
 * Accessing underlying data and their perf implications:<br/>
 * There are 4 scenarios to be considered here:
 * data_h and data on CPU --> no problems, std::vector returned immediately
 * data_h but data on GPU --> this causes a cudaMemcpy to be issued internally.
 *                        subsequent calls to data_h, will NOT incur this penalty.
 *                        (assuming 'ptr_d' is not called in between)
 * ptr_d but data on CPU  --> this causes a cudaMemcpy to be issued internally.
 *                        subsequent calls to ptr_d, will NOT incur this penalty.
 *                        (assuming 'data_h' is not called in between)
 * ptr_d and data on GPU  --> no problems, the device ptr will be returned immediately
 *
 * What if xgboost is compiled without CUDA?<br/>
 * In that case, there's a special implementation which always falls-back to
 * working with std::vector. This logic can be found in host_device_vector.cc
 *
 * Why not consider CUDA unified memory?<br/>
 * We did consider. However, it poses complications if we need to support both
 * compiling with and without CUDA toolkit. It was easier to have
 * 'HostDeviceVector' with a special-case implementation in host_device_vector.cc
 *
 * @note: This is not thread-safe!
 */
template <typename T>
class HostDeviceVector {
 public:
  // device == -1 allocates on the host; device >= 0 allocates on that GPU.
  explicit HostDeviceVector(size_t size = 0, int device = -1);
  ~HostDeviceVector();
  // Non-copyable and non-movable: the implementation owns device memory.
  HostDeviceVector(const HostDeviceVector<T>&) = delete;
  HostDeviceVector(HostDeviceVector<T>&&) = delete;
  void operator=(const HostDeviceVector<T>&) = delete;
  void operator=(HostDeviceVector<T>&&) = delete;
  size_t size() const;
  int device() const;
  T* ptr_d(int device);

  // only define functions returning device_ptr
  // if HostDeviceVector.h is included from a .cu file
#ifdef __CUDACC__
  thrust::device_ptr<T> tbegin(int device);
  thrust::device_ptr<T> tend(int device);
#endif

  std::vector<T>& data_h();
  void resize(size_t new_size, int new_device);

  // helper functions in case a function needs to be templated
  // to work for both HostDeviceVector and std::vector
  static std::vector<T>& data_h(HostDeviceVector<T>* v) {
    return v->data_h();
  }

  static std::vector<T>& data_h(std::vector<T>* v) {
    return *v;
  }

 private:
  HostDeviceVectorImpl<T>* impl_;  // pimpl; owned, freed in the destructor
};

}  // namespace xgboost

#endif  // XGBOOST_COMMON_HOST_DEVICE_VECTOR_H_
|
||||
@@ -20,8 +20,8 @@ namespace common {
|
||||
* \param x input parameter
|
||||
* \return the transformed value.
|
||||
*/
|
||||
inline float Sigmoid(float x) {
|
||||
return 1.0f / (1.0f + std::exp(-x));
|
||||
XGBOOST_DEVICE inline float Sigmoid(float x) {
|
||||
return 1.0f / (1.0f + expf(-x));
|
||||
}
|
||||
|
||||
inline avx::Float8 Sigmoid(avx::Float8 x) {
|
||||
|
||||
Reference in New Issue
Block a user