Optimisations for gpu_hist. (#4248)
* Optimisations for gpu_hist. * Use streams to overlap operations. * ColumnSampler now uses HostDeviceVector to prevent repeatedly copying feature vectors to the device.
This commit is contained in:
@@ -208,16 +208,23 @@ __global__ void LaunchNKernel(int device_idx, size_t begin, size_t end,
|
||||
}
|
||||
|
||||
template <int ITEMS_PER_THREAD = 8, int BLOCK_THREADS = 256, typename L>
|
||||
inline void LaunchN(int device_idx, size_t n, L lambda) {
|
||||
inline void LaunchN(int device_idx, size_t n, cudaStream_t stream, L lambda) {
|
||||
if (n == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
safe_cuda(cudaSetDevice(device_idx));
|
||||
|
||||
const int GRID_SIZE =
|
||||
static_cast<int>(DivRoundUp(n, ITEMS_PER_THREAD * BLOCK_THREADS));
|
||||
LaunchNKernel<<<GRID_SIZE, BLOCK_THREADS>>>(static_cast<size_t>(0), n,
|
||||
lambda);
|
||||
LaunchNKernel<<<GRID_SIZE, BLOCK_THREADS, 0, stream>>>(static_cast<size_t>(0),
|
||||
n, lambda);
|
||||
}
|
||||
|
||||
// Default stream version
|
||||
template <int ITEMS_PER_THREAD = 8, int BLOCK_THREADS = 256, typename L>
|
||||
inline void LaunchN(int device_idx, size_t n, L lambda) {
|
||||
LaunchN<ITEMS_PER_THREAD, BLOCK_THREADS>(device_idx, n, nullptr, lambda);
|
||||
}
|
||||
|
||||
/*
|
||||
@@ -500,6 +507,31 @@ class BulkAllocator {
|
||||
}
|
||||
};
|
||||
|
||||
// Keep track of pinned memory allocation
|
||||
struct PinnedMemory {
|
||||
void *temp_storage{nullptr};
|
||||
size_t temp_storage_bytes{0};
|
||||
|
||||
~PinnedMemory() { Free(); }
|
||||
|
||||
template <typename T>
|
||||
xgboost::common::Span<T> GetSpan(size_t size) {
|
||||
size_t num_bytes = size * sizeof(T);
|
||||
if (num_bytes > temp_storage_bytes) {
|
||||
Free();
|
||||
safe_cuda(cudaMallocHost(&temp_storage, num_bytes));
|
||||
temp_storage_bytes = num_bytes;
|
||||
}
|
||||
return xgboost::common::Span<T>(static_cast<T *>(temp_storage), size);
|
||||
}
|
||||
|
||||
void Free() {
|
||||
if (temp_storage != nullptr) {
|
||||
safe_cuda(cudaFreeHost(temp_storage));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Keep track of cub library device allocation
|
||||
struct CubMemory {
|
||||
void *d_temp_storage;
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
#include <random>
|
||||
|
||||
#include "io.h"
|
||||
#include "host_device_vector.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
@@ -84,26 +85,29 @@ GlobalRandomEngine& GlobalRandom(); // NOLINT(*)
|
||||
*/
|
||||
|
||||
class ColumnSampler {
|
||||
std::shared_ptr<std::vector<int>> feature_set_tree_;
|
||||
std::map<int, std::shared_ptr<std::vector<int>>> feature_set_level_;
|
||||
std::shared_ptr<HostDeviceVector<int>> feature_set_tree_;
|
||||
std::map<int, std::shared_ptr<HostDeviceVector<int>>> feature_set_level_;
|
||||
float colsample_bylevel_{1.0f};
|
||||
float colsample_bytree_{1.0f};
|
||||
float colsample_bynode_{1.0f};
|
||||
GlobalRandomEngine rng_;
|
||||
|
||||
std::shared_ptr<std::vector<int>> ColSample
|
||||
(std::shared_ptr<std::vector<int>> p_features, float colsample) {
|
||||
std::shared_ptr<HostDeviceVector<int>> ColSample(
|
||||
std::shared_ptr<HostDeviceVector<int>> p_features, float colsample) {
|
||||
if (colsample == 1.0f) return p_features;
|
||||
const auto& features = *p_features;
|
||||
const auto& features = p_features->HostVector();
|
||||
CHECK_GT(features.size(), 0);
|
||||
int n = std::max(1, static_cast<int>(colsample * features.size()));
|
||||
auto p_new_features = std::make_shared<std::vector<int>>();
|
||||
auto p_new_features = std::make_shared<HostDeviceVector<int>>();
|
||||
auto& new_features = *p_new_features;
|
||||
new_features.resize(features.size());
|
||||
std::copy(features.begin(), features.end(), new_features.begin());
|
||||
std::shuffle(new_features.begin(), new_features.end(), rng_);
|
||||
new_features.resize(n);
|
||||
std::sort(new_features.begin(), new_features.end());
|
||||
new_features.Resize(features.size());
|
||||
std::copy(features.begin(), features.end(),
|
||||
new_features.HostVector().begin());
|
||||
std::shuffle(new_features.HostVector().begin(),
|
||||
new_features.HostVector().end(), rng_);
|
||||
new_features.Resize(n);
|
||||
std::sort(new_features.HostVector().begin(),
|
||||
new_features.HostVector().end());
|
||||
|
||||
return p_new_features;
|
||||
}
|
||||
@@ -135,13 +139,14 @@ class ColumnSampler {
|
||||
colsample_bynode_ = colsample_bynode;
|
||||
|
||||
if (feature_set_tree_ == nullptr) {
|
||||
feature_set_tree_ = std::make_shared<std::vector<int>>();
|
||||
feature_set_tree_ = std::make_shared<HostDeviceVector<int>>();
|
||||
}
|
||||
Reset();
|
||||
|
||||
int begin_idx = skip_index_0 ? 1 : 0;
|
||||
feature_set_tree_->resize(num_col - begin_idx);
|
||||
std::iota(feature_set_tree_->begin(), feature_set_tree_->end(), begin_idx);
|
||||
feature_set_tree_->Resize(num_col - begin_idx);
|
||||
std::iota(feature_set_tree_->HostVector().begin(),
|
||||
feature_set_tree_->HostVector().end(), begin_idx);
|
||||
|
||||
feature_set_tree_ = ColSample(feature_set_tree_, colsample_bytree_);
|
||||
}
|
||||
@@ -150,7 +155,7 @@ class ColumnSampler {
|
||||
* \brief Resets this object.
|
||||
*/
|
||||
void Reset() {
|
||||
feature_set_tree_->clear();
|
||||
feature_set_tree_->Resize(0);
|
||||
feature_set_level_.clear();
|
||||
}
|
||||
|
||||
@@ -165,7 +170,7 @@ class ColumnSampler {
|
||||
* construction of each tree node, and must be called the same number of times in each
|
||||
* process and with the same parameters to return the same feature set across processes.
|
||||
*/
|
||||
std::shared_ptr<std::vector<int>> GetFeatureSet(int depth) {
|
||||
std::shared_ptr<HostDeviceVector<int>> GetFeatureSet(int depth) {
|
||||
if (colsample_bylevel_ == 1.0f && colsample_bynode_ == 1.0f) {
|
||||
return feature_set_tree_;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user