Modify caching allocator/vector and fix issues relating to inability to train large datasets (#4615)
This commit is contained in:
parent
cd1526d3b1
commit
7a388cbf8b
@ -22,7 +22,6 @@
|
|||||||
#include "./common/common.h"
|
#include "./common/common.h"
|
||||||
#include "./common/config.h"
|
#include "./common/config.h"
|
||||||
|
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
|
|
||||||
enum CLITask {
|
enum CLITask {
|
||||||
@ -240,6 +239,7 @@ void CLITrain(const CLIParam& param) {
|
|||||||
version += 1;
|
version += 1;
|
||||||
CHECK_EQ(version, rabit::VersionNumber());
|
CHECK_EQ(version, rabit::VersionNumber());
|
||||||
}
|
}
|
||||||
|
LOG(INFO) << "Complete Training loop time: " << dmlc::GetTime() - start << " sec";
|
||||||
// always save final round
|
// always save final round
|
||||||
if ((param.save_period == 0 || param.num_round % param.save_period != 0) &&
|
if ((param.save_period == 0 || param.num_round % param.save_period != 0) &&
|
||||||
param.model_out != "NONE" &&
|
param.model_out != "NONE" &&
|
||||||
|
|||||||
@ -305,11 +305,11 @@ struct XGBDefaultDeviceAllocatorImpl : thrust::device_malloc_allocator<T> {
|
|||||||
};
|
};
|
||||||
pointer allocate(size_t n) {
|
pointer allocate(size_t n) {
|
||||||
pointer ptr = super_t::allocate(n);
|
pointer ptr = super_t::allocate(n);
|
||||||
GlobalMemoryLogger().RegisterAllocation(ptr.get(), n);
|
GlobalMemoryLogger().RegisterAllocation(ptr.get(), n * sizeof(T));
|
||||||
return ptr;
|
return ptr;
|
||||||
}
|
}
|
||||||
void deallocate(pointer ptr, size_t n) {
|
void deallocate(pointer ptr, size_t n) {
|
||||||
GlobalMemoryLogger().RegisterDeallocation(ptr.get(), n);
|
GlobalMemoryLogger().RegisterDeallocation(ptr.get(), n * sizeof(T));
|
||||||
return super_t::deallocate(ptr, n);
|
return super_t::deallocate(ptr, n);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -329,19 +329,19 @@ struct XGBCachingDeviceAllocatorImpl : thrust::device_malloc_allocator<T> {
|
|||||||
{
|
{
|
||||||
// Configure allocator with maximum cached bin size of ~1GB and no limit on
|
// Configure allocator with maximum cached bin size of ~512MB (2^29 bytes) and no limit on
|
||||||
// maximum cached bytes
|
// maximum cached bytes
|
||||||
static cub::CachingDeviceAllocator allocator(8,3,10);
|
static cub::CachingDeviceAllocator *allocator = new cub::CachingDeviceAllocator(2, 9, 29);
|
||||||
return allocator;
|
return *allocator;
|
||||||
}
|
}
|
||||||
pointer allocate(size_t n) {
|
pointer allocate(size_t n) {
|
||||||
T *ptr;
|
T *ptr;
|
||||||
GetGlobalCachingAllocator().DeviceAllocate(reinterpret_cast<void **>(&ptr),
|
GetGlobalCachingAllocator().DeviceAllocate(reinterpret_cast<void **>(&ptr),
|
||||||
n * sizeof(T));
|
n * sizeof(T));
|
||||||
pointer thrust_ptr = thrust::device_ptr<T>(ptr);
|
pointer thrust_ptr(ptr);
|
||||||
GlobalMemoryLogger().RegisterAllocation(thrust_ptr.get(), n);
|
GlobalMemoryLogger().RegisterAllocation(thrust_ptr.get(), n * sizeof(T));
|
||||||
return thrust_ptr;
|
return thrust_ptr;
|
||||||
}
|
}
|
||||||
void deallocate(pointer ptr, size_t n) {
|
void deallocate(pointer ptr, size_t n) {
|
||||||
GlobalMemoryLogger().RegisterDeallocation(ptr.get(), n);
|
GlobalMemoryLogger().RegisterDeallocation(ptr.get(), n * sizeof(T));
|
||||||
GetGlobalCachingAllocator().DeviceFree(ptr.get());
|
GetGlobalCachingAllocator().DeviceFree(ptr.get());
|
||||||
}
|
}
|
||||||
__host__ __device__
|
__host__ __device__
|
||||||
@ -363,6 +363,7 @@ template <typename T>
|
|||||||
using device_vector = thrust::device_vector<T, XGBDeviceAllocator<T>>;
|
using device_vector = thrust::device_vector<T, XGBDeviceAllocator<T>>;
|
||||||
template <typename T>
|
template <typename T>
|
||||||
using caching_device_vector = thrust::device_vector<T, XGBCachingDeviceAllocator<T>>;
|
using caching_device_vector = thrust::device_vector<T, XGBCachingDeviceAllocator<T>>;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \brief A double buffer, useful for algorithms like sort.
|
* \brief A double buffer, useful for algorithms like sort.
|
||||||
*/
|
*/
|
||||||
@ -376,9 +377,7 @@ class DoubleBuffer {
|
|||||||
DoubleBuffer(VectorT *v1, VectorT *v2) {
|
DoubleBuffer(VectorT *v1, VectorT *v2) {
|
||||||
a = xgboost::common::Span<T>(v1->data().get(), v1->size());
|
a = xgboost::common::Span<T>(v1->data().get(), v1->size());
|
||||||
b = xgboost::common::Span<T>(v2->data().get(), v2->size());
|
b = xgboost::common::Span<T>(v2->data().get(), v2->size());
|
||||||
buff.d_buffers[0] = v1->data().get();
|
buff = cub::DoubleBuffer<T>(a.data(), b.data());
|
||||||
buff.d_buffers[1] = v2->data().get();
|
|
||||||
buff.selector = 0;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t Size() const {
|
size_t Size() const {
|
||||||
|
|||||||
@ -250,6 +250,10 @@ class GPUPredictor : public xgboost::Predictor {
|
|||||||
struct DeviceShard {
|
struct DeviceShard {
|
||||||
DeviceShard() : device_{-1} {}
|
DeviceShard() : device_{-1} {}
|
||||||
|
|
||||||
|
~DeviceShard() {
|
||||||
|
dh::safe_cuda(cudaSetDevice(device_));
|
||||||
|
}
|
||||||
|
|
||||||
void Init(int device) {
|
void Init(int device) {
|
||||||
this->device_ = device;
|
this->device_ = device;
|
||||||
max_shared_memory_bytes_ = dh::MaxSharedMemory(this->device_);
|
max_shared_memory_bytes_ = dh::MaxSharedMemory(this->device_);
|
||||||
|
|||||||
@ -611,8 +611,6 @@ struct DeviceShard {
|
|||||||
/*! \brief Sum gradient for each node. */
|
/*! \brief Sum gradient for each node. */
|
||||||
std::vector<GradientPair> node_sum_gradients;
|
std::vector<GradientPair> node_sum_gradients;
|
||||||
common::Span<GradientPair> node_sum_gradients_d;
|
common::Span<GradientPair> node_sum_gradients_d;
|
||||||
/*! \brief On-device feature set, only actually used on one of the devices */
|
|
||||||
dh::device_vector<int> feature_set_d;
|
|
||||||
/*! The row offset for this shard. */
|
/*! The row offset for this shard. */
|
||||||
bst_uint row_begin_idx;
|
bst_uint row_begin_idx;
|
||||||
bst_uint row_end_idx;
|
bst_uint row_end_idx;
|
||||||
@ -700,6 +698,7 @@ struct DeviceShard {
|
|||||||
this->interaction_constraints.Reset();
|
this->interaction_constraints.Reset();
|
||||||
std::fill(node_sum_gradients.begin(), node_sum_gradients.end(),
|
std::fill(node_sum_gradients.begin(), node_sum_gradients.end(),
|
||||||
GradientPair());
|
GradientPair());
|
||||||
|
row_partitioner.reset(); // Release the device memory first before reallocating
|
||||||
row_partitioner.reset(new RowPartitioner(device_id, n_rows));
|
row_partitioner.reset(new RowPartitioner(device_id, n_rows));
|
||||||
|
|
||||||
dh::safe_cuda(cudaMemcpyAsync(
|
dh::safe_cuda(cudaMemcpyAsync(
|
||||||
@ -921,6 +920,7 @@ struct DeviceShard {
|
|||||||
dh::safe_cuda(cudaMemcpy(
|
dh::safe_cuda(cudaMemcpy(
|
||||||
out_preds_d, prediction_cache.data(),
|
out_preds_d, prediction_cache.data(),
|
||||||
prediction_cache.size() * sizeof(bst_float), cudaMemcpyDefault));
|
prediction_cache.size() * sizeof(bst_float), cudaMemcpyDefault));
|
||||||
|
row_partitioner.reset();
|
||||||
}
|
}
|
||||||
|
|
||||||
void AllReduceHist(int nidx, dh::AllReducer* reducer) {
|
void AllReduceHist(int nidx, dh::AllReducer* reducer) {
|
||||||
|
|||||||
@ -11,15 +11,16 @@ namespace tree {
|
|||||||
|
|
||||||
void TestSortPosition(const std::vector<int>& position_in, int left_idx,
|
void TestSortPosition(const std::vector<int>& position_in, int left_idx,
|
||||||
int right_idx) {
|
int right_idx) {
|
||||||
|
dh::safe_cuda(cudaSetDevice(0));
|
||||||
std::vector<int64_t> left_count = {
|
std::vector<int64_t> left_count = {
|
||||||
std::count(position_in.begin(), position_in.end(), left_idx)};
|
std::count(position_in.begin(), position_in.end(), left_idx)};
|
||||||
thrust::device_vector<int64_t> d_left_count = left_count;
|
dh::caching_device_vector<int64_t> d_left_count = left_count;
|
||||||
thrust::device_vector<int> position = position_in;
|
dh::caching_device_vector<int> position = position_in;
|
||||||
thrust::device_vector<int> position_out(position.size());
|
dh::caching_device_vector<int> position_out(position.size());
|
||||||
|
|
||||||
thrust::device_vector<RowPartitioner::RowIndexT> ridx(position.size());
|
dh::caching_device_vector<RowPartitioner::RowIndexT> ridx(position.size());
|
||||||
thrust::sequence(ridx.begin(), ridx.end());
|
thrust::sequence(ridx.begin(), ridx.end());
|
||||||
thrust::device_vector<RowPartitioner::RowIndexT> ridx_out(ridx.size());
|
dh::caching_device_vector<RowPartitioner::RowIndexT> ridx_out(ridx.size());
|
||||||
RowPartitioner rp(0,10);
|
RowPartitioner rp(0,10);
|
||||||
rp.SortPosition(
|
rp.SortPosition(
|
||||||
common::Span<int>(position.data().get(), position.size()),
|
common::Span<int>(position.data().get(), position.size()),
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user