[GPU-Plugin] Multi-GPU gpu_id bug fixes for grow_gpu_hist and grow_gpu methods, and additional documentation for the gpu plugin. (#2463)

This commit is contained in:
PSEUDOTENSOR / Jonathan McKinney
2017-06-30 01:04:17 -07:00
committed by Rory Mitchell
parent 91dae84a00
commit 6b287177c8
21 changed files with 578 additions and 449 deletions

View File

@@ -9,11 +9,11 @@
#include <algorithm>
#include <chrono>
#include <ctime>
#include <cub/cub.cuh>
#include <numeric>
#include <sstream>
#include <string>
#include <vector>
#include <numeric>
#include <cub/cub.cuh>
#ifndef NCCL
#define NCCL 1
@@ -29,8 +29,8 @@
namespace dh {
#define HOST_DEV_INLINE __host__ __device__ __forceinline__
#define DEV_INLINE __device__ __forceinline__
#define HOST_DEV_INLINE __host__ __device__ __forceinline__
#define DEV_INLINE __device__ __forceinline__
/*
* Error handling functions
@@ -126,6 +126,11 @@ inline std::string device_name(int device_idx) {
return std::string(prop.name);
}
// Map a user-supplied gpu_id onto a valid device index, so callers do not
// need to know how many devices are visible.
//
// @param gpu_id requested device id; any int is accepted, the sign is ignored
// @return |gpu_id| modulo the visible-device count, i.e. a value in
//         [0, dh::n_visible_devices())
// @throws std::runtime_error when no device is visible (the modulo below
//         would otherwise divide by zero)
inline int get_device_idx(int gpu_id) {
  // NOTE(review): assumes dh::n_visible_devices() returns an int-compatible
  // count -- confirm against its definition.
  const int n_devices = dh::n_visible_devices();
  if (n_devices <= 0) {
    throw std::runtime_error("No visible GPU devices");
  }
  // Take the remainder first and fix the sign afterwards: this yields the
  // same value as std::abs(gpu_id) % n_devices, but does not overflow when
  // gpu_id is INT_MIN (where std::abs is undefined).
  const int remainder = gpu_id % n_devices;
  return remainder < 0 ? -remainder : remainder;
}
/*
* Timers
@@ -309,11 +314,13 @@ enum memory_type { DEVICE, DEVICE_MANAGED };
template <memory_type MemoryT>
class bulk_allocator;
template <typename T> class dvec2;
template <typename T>
class dvec2;
template <typename T>
class dvec {
friend class dvec2<T>;
friend class dvec2<T>;
private:
T *_ptr;
size_t _size;
@@ -327,9 +334,10 @@ class dvec {
_ptr = static_cast<T *>(ptr);
_size = size;
_device_idx = device_idx;
safe_cuda(cudaSetDevice(_device_idx));
}
dvec() : _ptr(NULL), _size(0), _device_idx(0) {}
dvec() : _ptr(NULL), _size(0), _device_idx(-1) {}
// Returns the stored _size.
size_t size() const { return _size; }
// Returns the stored _device_idx.
int device_idx() const { return _device_idx; }
// True when there is no backing pointer (_ptr == NULL) or the size is zero.
bool empty() const { return _ptr == NULL || _size == 0; }
@@ -378,6 +386,10 @@ class dvec {
if (other.device_idx() == this->device_idx()) {
thrust::copy(other.tbegin(), other.tend(), this->tbegin());
} else {
std::cout << "deviceother: " << other.device_idx()
<< " devicethis: " << this->device_idx() << std::endl;
std::cout << "size deviceother: " << other.size()
<< " devicethis: " << this->device_idx() << std::endl;
throw std::runtime_error("Cannot copy to/from different devices");
}
@@ -401,26 +413,24 @@ class dvec {
*/
template <typename T>
class dvec2 {
private:
dvec<T> _d1, _d2;
cub::DoubleBuffer<T> _buff;
int _device_idx;
public:
// Bind this dvec2 to two caller-supplied pointers on device `device_idx`.
// `size` is forwarded unchanged to both underlying dvecs (_d1 and _d2).
// Throws std::runtime_error if empty() reports that the dvec2 is already
// allocated.
void external_allocate(int device_idx, void *ptr1, void *ptr2, size_t size) {
if (!empty()) {
throw std::runtime_error("Tried to allocate dvec2 but already allocated");
}
// Record the device first: _device_idx is the value forwarded to _d1/_d2.
_device_idx = device_idx;
_d1.external_allocate(_device_idx, ptr1, size);
_d2.external_allocate(_device_idx, ptr2, size);
// Point the cub::DoubleBuffer at the same two pointers and select slot 0.
_buff.d_buffers[0] = static_cast<T *>(ptr1);
_buff.d_buffers[1] = static_cast<T *>(ptr2);
_buff.selector = 0;
// NOTE(review): repeats the assignment made above with the same value.
_device_idx = device_idx;
}
dvec2() : _d1(), _d2(), _buff(), _device_idx(0) {}
dvec2() : _d1(), _d2(), _buff(), _device_idx(-1) {}
// Size as reported by the first underlying dvec (_d1).
size_t size() const { return _d1.size(); }
// Returns the stored _device_idx.
int device_idx() const { return _device_idx; }
@@ -433,7 +443,7 @@ class dvec2 {
// Pointer returned by the cub::DoubleBuffer's Current().
T *current() { return _buff.Current(); }
// Returns d1() when _buff.selector is 0, otherwise d2().
// NOTE(review): the accessor is listed twice, differing only in spacing
// before the `?`.
dvec<T> &current_dvec() { return _buff.selector == 0? d1() : d2(); }
dvec<T> &current_dvec() { return _buff.selector == 0 ? d1() : d2(); }
// Pointer returned by the cub::DoubleBuffer's Alternate().
T *other() { return _buff.Alternate(); }
};
@@ -459,7 +469,8 @@ class bulk_allocator {
// Variadic overload for a list that starts with a (dvec<T>*, size) pair:
// adds the result of the two-argument get_size_bytes for the first pair to
// the result of get_size_bytes applied to the remaining arguments.
template <typename T, typename SizeT, typename... Args>
size_t get_size_bytes(dvec<T> *first_vec, SizeT first_size, Args... args) {
// NOTE(review): both return statements compute the same sum and differ only
// in formatting; the second one is never reached.
return get_size_bytes<T,SizeT>(first_vec, first_size) + get_size_bytes(args...);
return get_size_bytes<T, SizeT>(first_vec, first_size) +
get_size_bytes(args...);
}
template <typename T, typename SizeT>
@@ -496,20 +507,23 @@ class bulk_allocator {
// Variadic overload for a list that starts with a (dvec2<T>*, size) pair:
// adds the result of the two-argument get_size_bytes for the first pair to
// the result of get_size_bytes applied to the remaining arguments.
template <typename T, typename SizeT, typename... Args>
size_t get_size_bytes(dvec2<T> *first_vec, SizeT first_size, Args... args) {
// NOTE(review): both return statements compute the same sum and differ only
// in formatting; the second one is never reached.
return get_size_bytes<T,SizeT>(first_vec, first_size) + get_size_bytes(args...);
return get_size_bytes<T, SizeT>(first_vec, first_size) +
get_size_bytes(args...);
}
template <typename T, typename SizeT>
void allocate_dvec(int device_idx, char *ptr, dvec2<T> *first_vec, SizeT first_size) {
first_vec->external_allocate(device_idx, static_cast<void *>(ptr),
static_cast<void *>(ptr+align_round_up(first_size * sizeof(T))),
first_size);
void allocate_dvec(int device_idx, char *ptr, dvec2<T> *first_vec,
SizeT first_size) {
first_vec->external_allocate(
device_idx, static_cast<void *>(ptr),
static_cast<void *>(ptr + align_round_up(first_size * sizeof(T))),
first_size);
}
template <typename T, typename SizeT, typename... Args>
void allocate_dvec(int device_idx, char *ptr, dvec2<T> *first_vec, SizeT first_size,
Args... args) {
allocate_dvec<T,SizeT>(device_idx, ptr, first_vec, first_size);
void allocate_dvec(int device_idx, char *ptr, dvec2<T> *first_vec,
SizeT first_size, Args... args) {
allocate_dvec<T, SizeT>(device_idx, ptr, first_vec, first_size);
ptr += (align_round_up(first_size * sizeof(T)) * 2);
allocate_dvec(device_idx, ptr, args...);
}
@@ -706,11 +720,11 @@ struct BernoulliRng {
* @param name name used to track later
* @param stream cuda stream where to measure time
*/
#define TIMEIT(call, name) \
do { \
dh::Timer t1234; \
call; \
t1234.printElapsed(name); \
} while(0)
#define TIMEIT(call, name) \
do { \
dh::Timer t1234; \
call; \
t1234.printElapsed(name); \
} while (0)
} // namespace dh