GPU Plugin: Add subsample, colsample_bytree, colsample_bylevel (#1895)

parent cee4aafb93
commit b49b339183
@@ -9,10 +9,10 @@ https://www.kaggle.com/c/bosch-production-line-performance/data
 
 Copy train_numeric.csv into xgboost/demo/data.
 
-The subsample parameter can be changed so you can run the script first on a small portion of the data. Processing the entire dataset can take a long time and requires about 8GB of device memory. It is initially set to 0.4, using about 2650/3380MB on a GTX 970.
+The subset parameter changes the proportion of rows loaded from the CSV file. Processing the entire dataset can take a long time and requires about 8GB of device memory. It is initially set to 0.4, using about 2650/3380MB on a GTX 970. Lower the parameter if your device runs out of memory.
 
 ```python
-subsample = 0.4
+subset = 0.4
 ```
 
 Parameters are set as usual except that we set silent to 0 to see how much memory is being allocated on the GPU and we change 'updater' to 'grow_gpu' to activate the GPU plugin.
@@ -5,12 +5,12 @@ import time
 import random
 from sklearn.cross_validation import StratifiedKFold
 
-#For sub sampling rows from input file
+#For sampling rows from input file
 random_seed = 9
-subsample = 0.4
+subset = 0.4
 
 n_rows = 1183747;
-train_rows = int(n_rows * subsample)
+train_rows = int(n_rows * subset)
 random.seed(random_seed)
 skip = sorted(random.sample(xrange(1,n_rows + 1),n_rows-train_rows))
 data = pd.read_csv("../data/train_numeric.csv", index_col=0, dtype=np.float32, skiprows=skip)
@@ -32,8 +32,6 @@ Data is stored in a sparse format. For example, missing values produced by one h
 
 A 4GB graphics card will process approximately 3.5 million rows of the well known Kaggle higgs dataset.
 
-The algorithm will automatically perform row subsampling if it detects there is not enough memory on the device.
-
 ## Dependencies
 A CUDA capable GPU with at least compute capability >= 3.5 (the algorithm depends on shuffle and vote instructions introduced in Kepler).
@@ -7,6 +7,7 @@
 #include <thrust/device_vector.h>
 #include <thrust/system/cuda/error.h>
 #include <thrust/system_error.h>
+#include <thrust/random.h>
 #include <algorithm>
 #include <ctime>
 #include <sstream>
@@ -147,6 +148,8 @@ struct Timer {
     LARGE_INTEGER now;
     QueryPerformanceCounter(&now);
     return static_cast<double>(now.QuadPart) / s_frequency.QuadPart;
+#else
+    return 0;
 #endif
   }
 
@@ -160,12 +163,14 @@ struct Timer {
 #ifdef _WIN32
     _ReadWriteBarrier();
     return seconds_now() - start;
+#else
+    return 0;
 #endif
   }
-  void printElapsed(char *label) {
+  void printElapsed(std::string label) {
 #ifdef TIMERS
     safe_cuda(cudaDeviceSynchronize());
-    printf("%s:\t %1.4fs\n", label, elapsed());
+    printf("%s:\t %1.4fs\n", label.c_str(), elapsed());
 #endif
   }
 };
@@ -233,46 +238,6 @@ template <typename T> __device__ range block_stride_range(T begin, T end) {
   return r;
 }
 
-/*
- * Utility functions
- */
-
-template <typename T>
-void print(const thrust::device_vector<T> &v, size_t max_items = 10) {
-  thrust::host_vector<T> h = v;
-  for (int i = 0; i < std::min(max_items, h.size()); i++) {
-    std::cout << " " << h[i];
-  }
-  std::cout << "\n";
-}
-
-template <typename T>
-void print(char *label, const thrust::device_vector<T> &v,
-           const char *format = "%d ", int max = 10) {
-  thrust::host_vector<T> h_v = v;
-
-  std::cout << label << ":\n";
-  for (int i = 0; i < std::min(static_cast<int>(h_v.size()), max); i++) {
-    printf(format, h_v[i]);
-  }
-  std::cout << "\n";
-}
-
-template <typename T1, typename T2> T1 div_round_up(const T1 a, const T2 b) {
-  return static_cast<T1>(ceil(static_cast<double>(a) / b));
-}
-
-template <typename T> thrust::device_ptr<T> dptr(T *d_ptr) {
-  return thrust::device_pointer_cast(d_ptr);
-}
-
-template <typename T> T *raw(thrust::device_vector<T> &v) { // NOLINT
-  return raw_pointer_cast(v.data());
-}
-
-template <typename T> size_t size_bytes(const thrust::device_vector<T> &v) {
-  return sizeof(T) * v.size();
-}
-
 // Threadblock iterates over range, filling with value
 template <typename IterT, typename ValueT>
@@ -306,11 +271,11 @@ template <typename T> class dvec {
 
  public:
   dvec() : _ptr(NULL), _size(0) {}
-  size_t size() { return _size; }
-  bool empty() { return _ptr == NULL || _size == 0; }
+  size_t size() const { return _size; }
+  bool empty() const { return _ptr == NULL || _size == 0; }
   T *data() { return _ptr; }
 
-  std::vector<T> as_vector() {
+  std::vector<T> as_vector() const {
     std::vector<T> h_vector(size());
     safe_cuda(cudaMemcpy(h_vector.data(), _ptr, size() * sizeof(T),
                          cudaMemcpyDeviceToHost));
@@ -454,6 +419,55 @@ inline std::string device_name() {
   return std::string(prop.name);
 }
 
+/*
+ * Utility functions
+ */
+
+template <typename T>
+void print(const thrust::device_vector<T> &v, size_t max_items = 10) {
+  thrust::host_vector<T> h = v;
+  for (int i = 0; i < std::min(max_items, h.size()); i++) {
+    std::cout << " " << h[i];
+  }
+  std::cout << "\n";
+}
+
+template <typename T>
+void print(const dvec<T> &v, size_t max_items = 10) {
+  std::vector<T> h = v.as_vector();
+  for (int i = 0; i < std::min(max_items, h.size()); i++) {
+    std::cout << " " << h[i];
+  }
+  std::cout << "\n";
+}
+
+template <typename T>
+void print(char *label, const thrust::device_vector<T> &v,
+           const char *format = "%d ", int max = 10) {
+  thrust::host_vector<T> h_v = v;
+
+  std::cout << label << ":\n";
+  for (int i = 0; i < std::min(static_cast<int>(h_v.size()), max); i++) {
+    printf(format, h_v[i]);
+  }
+  std::cout << "\n";
+}
+
+template <typename T1, typename T2> T1 div_round_up(const T1 a, const T2 b) {
+  return static_cast<T1>(ceil(static_cast<double>(a) / b));
+}
+
+template <typename T> thrust::device_ptr<T> dptr(T *d_ptr) {
+  return thrust::device_pointer_cast(d_ptr);
+}
+
+template <typename T> T *raw(thrust::device_vector<T> &v) { // NOLINT
+  return raw_pointer_cast(v.data());
+}
+
+template <typename T> size_t size_bytes(const thrust::device_vector<T> &v) {
+  return sizeof(T) * v.size();
+}
 /*
  * Kernel launcher
  */
@@ -470,4 +484,25 @@ inline void launch_n(size_t n, L lambda) {
 
   launch_n_kernel<<<GRID_SIZE, BLOCK_THREADS>>>(n, lambda);
 }
+
+/*
+ * Random
+ */
+
+struct BernoulliRng {
+  float p;
+  int seed;
+
+  __host__ __device__ BernoulliRng(float p, int seed) : p(p), seed(seed) {}
+
+  __host__ __device__ bool operator()(const int i) const {
+    thrust::default_random_engine rng(seed);
+    thrust::uniform_real_distribution<float> dist;
+    rng.discard(i);
+
+    return dist(rng) <= p;
+  }
+};
+
 
 } // namespace dh
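The BernoulliRng functor added above is counter-based: every call seeds a fresh engine and discards i draws, so the decision for element i depends only on (p, seed, i) and is reproducible no matter how threads are scheduled. A minimal host-side sketch of the same idiom, assuming only <thrust/random.h> (the helper name keep_index is illustrative, not part of the patch):

```cpp
#include <thrust/random.h>
#include <cstdio>

// Same discard-based idiom as dh::BernoulliRng: a pure function of (p, seed, i).
bool keep_index(float p, int seed, int i) {
  thrust::default_random_engine rng(seed);
  thrust::uniform_real_distribution<float> dist;  // defaults to [0, 1)
  rng.discard(i);         // skip ahead to the i-th draw of this stream
  return dist(rng) <= p;  // true with probability ~p
}

int main() {
  int kept = 0;
  for (int i = 0; i < 100000; ++i) kept += keep_index(0.4f, 9, i);
  std::printf("kept %d of 100000 (expect ~40000)\n", kept);
  return 0;
}
```

The discard makes each index an independent offset into one stream, which is what lets MarkSubsample (added in gpu_data.cuh below) evaluate the functor in parallel, one thread per instance.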
@@ -4,9 +4,11 @@
 #pragma once
 #include <cub/cub.cuh>
 #include <xgboost/base.h>
+#include <vector>
 #include "device_helpers.cuh"
 #include "find_split_multiscan.cuh"
 #include "find_split_sorting.cuh"
+#include "gpu_data.cuh"
 #include "types_functions.cuh"
 
 namespace xgboost {
@@ -62,24 +64,47 @@ void reduce_split_candidates(Split *d_split_candidates, Node *d_nodes,
   dh::safe_cuda(cudaDeviceSynchronize());
 }
 
-void find_split(const ItemIter items_iter, Split *d_split_candidates,
-                Node *d_nodes, bst_uint num_items, int num_features,
-                const int *d_feature_offsets, gpu_gpair *d_node_sums,
-                int *d_node_offsets, const GPUTrainingParam param,
-                const int level, bool multiscan_algorithm) {
+void colsample_level(GPUData *data, const TrainParam xgboost_param,
+                     const std::vector<int> &feature_set_tree,
+                     std::vector<int> *feature_set_level) {
+  unsigned n_bytree =
+      static_cast<unsigned>(xgboost_param.colsample_bytree * data->n_features);
+  unsigned n =
+      static_cast<unsigned>(n_bytree * xgboost_param.colsample_bylevel);
+  CHECK_GT(n, 0);
+
+  *feature_set_level = feature_set_tree;
+
+  std::shuffle((*feature_set_level).begin(),
+               (*feature_set_level).begin() + n_bytree, common::GlobalRandom());
+
+  data->feature_set = *feature_set_level;
+
+  data->feature_flags.fill(0);
+  auto d_feature_set = data->feature_set.data();
+  auto d_feature_flags = data->feature_flags.data();
+
+  dh::launch_n(
+      n, [=] __device__(int i) { d_feature_flags[d_feature_set[i]] = 1; });
+}
+
+void find_split(GPUData *data, const TrainParam xgboost_param, const int level,
+                bool multiscan_algorithm,
+                const std::vector<int> &feature_set_tree,
+                std::vector<int> *feature_set_level) {
+  colsample_level(data, xgboost_param, feature_set_tree, feature_set_level);
   // Reset split candidates
+  data->split_candidates.fill(Split());
 
   if (multiscan_algorithm) {
-    find_split_candidates_multiscan(items_iter, d_split_candidates, d_nodes,
-                                    num_items, num_features, d_feature_offsets,
-                                    param, level);
+    find_split_candidates_multiscan(data, level);
   } else {
-    find_split_candidates_sorted(items_iter, d_split_candidates, d_nodes,
-                                 num_items, num_features, d_feature_offsets,
-                                 d_node_sums, d_node_offsets, param, level);
+    find_split_candidates_sorted(data, level);
   }
 
   // Find the best split for each node
-  reduce_split_candidates(d_split_candidates, d_nodes, level, num_features,
-                          param);
+  reduce_split_candidates(data->split_candidates.data(), data->nodes.data(),
+                          level, data->n_features, data->param);
 }
 } // namespace tree
 } // namespace xgboost
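The two column-sampling rates compose multiplicatively: with colsample_bytree = 0.5 and colsample_bylevel = 0.5 over 100 features, n_bytree = 50 features survive for the tree and n = 25 of those are flagged at each level. A host-only sketch of the selection logic under that reading (std::mt19937 stands in for common::GlobalRandom(); sample_level is an illustrative name):

```cpp
#include <algorithm>
#include <numeric>
#include <random>
#include <vector>

// Sketch of the per-level column sample: shuffle the per-tree subset,
// then take the first n features as the active set for this level.
std::vector<int> sample_level(int n_features, float bytree, float bylevel,
                              std::mt19937 &rng) {
  unsigned n_bytree = static_cast<unsigned>(bytree * n_features);
  unsigned n = static_cast<unsigned>(n_bytree * bylevel);

  std::vector<int> features(n_features);
  std::iota(features.begin(), features.end(), 0);  // 0, 1, ..., n_features - 1
  std::shuffle(features.begin(), features.end(), rng);  // per-tree draw
  std::shuffle(features.begin(), features.begin() + n_bytree,
               rng);  // per-level draw within the tree's subset
  return std::vector<int>(features.begin(), features.begin() + n);
}
```

In the patch itself the per-tree shuffle happens once in GPUBuilder::ColsampleTree and only the per-level reshuffle runs here, with the result scattered into feature_flags on the device.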
@@ -5,6 +5,7 @@
 #include <cub/cub.cuh>
 #include <xgboost/base.h>
 #include "device_helpers.cuh"
+#include "gpu_data.cuh"
 #include "types_functions.cuh"
 
 namespace xgboost {
@@ -609,22 +610,11 @@ struct FindSplitEnactorMultiscan {
     }
   }
 
-  __device__ __forceinline__ void ResetSplitCandidates() {
-    const int max_nodes = 1 << level;
-    const int begin = blockIdx.x * max_nodes;
-    const int end = begin + max_nodes;
-
-    for (auto i : dh::block_stride_range(begin, end)) {
-      d_split_candidates_out[i] = Split();
-    }
-  }
-
   __device__ __forceinline__ void ProcessRegion(const bst_uint &segment_begin,
                                                 const bst_uint &segment_end) {
     // Current position
     bst_uint offset = segment_begin;
 
-    ResetSplitCandidates();
     ResetTileCarry();
     ResetSplits();
     CacheNodes();
@@ -654,8 +644,9 @@ __launch_bounds__(1024, 2)
     const ItemIter items_iter, Split *d_split_candidates_out,
     const Node *d_nodes, const int node_begin, bst_uint num_items,
     int num_features, const int *d_feature_offsets,
-    const GPUTrainingParam param, const int level) {
-  if (num_items <= 0) {
+    const GPUTrainingParam param, const int *d_feature_flags,
+    const int level) {
+  if (num_items <= 0 || d_feature_flags[blockIdx.x] != 1) {
     return;
   }
 
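Both split kernels launch one thread block per feature (grid_size is the feature count), so the added d_feature_flags[blockIdx.x] check retires the whole block of an unsampled feature at its first instruction. A standalone CUDA sketch of that gating pattern, assuming nothing from the plugin (per_feature_kernel is an illustrative name):

```cpp
#include <cstdio>

// One block per feature; blocks whose feature is not flagged return at once.
__global__ void per_feature_kernel(const int *d_feature_flags) {
  if (d_feature_flags[blockIdx.x] != 1) {
    return;  // feature excluded by this level's column sample
  }
  if (threadIdx.x == 0) {
    printf("processing feature %d\n", blockIdx.x);
  }
}

int main() {
  const int n_features = 4;
  int h_flags[n_features] = {1, 0, 1, 0};  // features 1 and 3 are sampled out
  int *d_flags;
  cudaMalloc(&d_flags, sizeof(h_flags));
  cudaMemcpy(d_flags, h_flags, sizeof(h_flags), cudaMemcpyHostToDevice);

  per_feature_kernel<<<n_features, 32>>>(d_flags);  // grid size = n_features
  cudaDeviceSynchronize();
  cudaFree(d_flags);
  return 0;
}
```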
@@ -685,69 +676,45 @@ __launch_bounds__(1024, 2)
 }
 
 template <int N_NODES>
-void find_split_candidates_multiscan_variation(
-    const ItemIter items_iter, Split *d_split_candidates, const Node *d_nodes,
-    int node_begin, int node_end, bst_uint num_items, int num_features,
-    const int *d_feature_offsets, const GPUTrainingParam param,
-    const int level) {
-
+void find_split_candidates_multiscan_variation(GPUData *data, const int level) {
+  const int node_begin = (1 << level) - 1;
   const int BLOCK_THREADS = 512;
 
-  CHECK((node_end - node_begin) <= N_NODES) << "Multiscan: N_NODES template "
-                                               "parameter too small for given "
-                                               "node range.";
   CHECK(BLOCK_THREADS / 32 < 32)
       << "Too many active warps. See FindSplitEnactor - ReduceSplits.";
 
   typedef FindSplitParamsMultiscan<BLOCK_THREADS, N_NODES, false>
       find_split_params;
   typedef ReduceParamsMultiscan<BLOCK_THREADS, N_NODES, false> reduce_params;
-  int grid_size = num_features;
+  int grid_size = data->n_features;
 
   find_split_candidates_multiscan_kernel<
       find_split_params,
       reduce_params><<<grid_size, find_split_params::BLOCK_THREADS>>>(
-      items_iter, d_split_candidates, d_nodes, node_begin, num_items,
-      num_features, d_feature_offsets, param, level);
+      data->items_iter, data->split_candidates.data(), data->nodes.data(),
+      node_begin, data->fvalues.size(), data->n_features, data->foffsets.data(),
+      data->param, data->feature_flags.data(), level);
 
   dh::safe_cuda(cudaDeviceSynchronize());
 }
 
-void find_split_candidates_multiscan(
-    const ItemIter items_iter, Split *d_split_candidates, const Node *d_nodes,
-    bst_uint num_items, int num_features, const int *d_feature_offsets,
-    const GPUTrainingParam param, const int level) {
+void find_split_candidates_multiscan(GPUData *data, const int level) {
   // Select templated variation of split finding algorithm
   switch (level) {
     case 0:
-      find_split_candidates_multiscan_variation<1>(
-          items_iter, d_split_candidates, d_nodes, 0, 1, num_items, num_features,
-          d_feature_offsets, param, level);
+      find_split_candidates_multiscan_variation<1>(data, level);
       break;
     case 1:
-      find_split_candidates_multiscan_variation<2>(
-          items_iter, d_split_candidates, d_nodes, 1, 3, num_items, num_features,
-          d_feature_offsets, param, level);
+      find_split_candidates_multiscan_variation<2>(data, level);
       break;
     case 2:
-      find_split_candidates_multiscan_variation<4>(
-          items_iter, d_split_candidates, d_nodes, 3, 7, num_items, num_features,
-          d_feature_offsets, param, level);
+      find_split_candidates_multiscan_variation<4>(data, level);
      break;
    case 3:
-      find_split_candidates_multiscan_variation<8>(
-          items_iter, d_split_candidates, d_nodes, 7, 15, num_items, num_features,
-          d_feature_offsets, param, level);
+      find_split_candidates_multiscan_variation<8>(data, level);
      break;
    case 4:
-      find_split_candidates_multiscan_variation<16>(
-          items_iter, d_split_candidates, d_nodes, 15, 31, num_items,
-          num_features, d_feature_offsets, param, level);
-      break;
-    case 5:
-      find_split_candidates_multiscan_variation<32>(
-          items_iter, d_split_candidates, d_nodes, 31, 63, num_items,
-          num_features, d_feature_offsets, param, level);
+      find_split_candidates_multiscan_variation<16>(data, level);
       break;
   }
 }
@@ -337,17 +337,8 @@ struct FindSplitEnactorSorting {
     WriteBestSplit(node_id_adjusted);
   }
 
-  __device__ __forceinline__ void ResetSplitCandidates() {
-    const int max_nodes = 1 << level;
-    const int begin = blockIdx.x * max_nodes;
-
-    dh::block_fill(d_split_candidates_out + begin, max_nodes, Split());
-  }
-
   __device__ __forceinline__ void ProcessFeature(const bst_uint &segment_begin,
                                                  const bst_uint &segment_end) {
-    ResetSplitCandidates();
-
     int node_begin = segment_begin;
 
     const int max_nodes = 1 << level;
@@ -377,9 +368,9 @@ __global__ __launch_bounds__(1024, 1) void find_split_candidates_sorted_kernel(
     const ItemIter items_iter, Split *d_split_candidates_out,
     const Node *d_nodes, bst_uint num_items, const int num_features,
     const int *d_feature_offsets, gpu_gpair *d_node_sums, int *d_node_offsets,
-    const GPUTrainingParam param, const int level) {
+    const GPUTrainingParam param, const int *d_feature_flags, const int level) {
 
-  if (num_items <= 0) {
+  if (num_items <= 0 || d_feature_flags[blockIdx.x] != 1) {
     return;
   }
 
@@ -408,23 +399,19 @@ __global__ __launch_bounds__(1024, 1) void find_split_candidates_sorted_kernel(
           .ProcessFeature(segment_begin, segment_end);
 }
 
-void find_split_candidates_sorted(const ItemIter items_iter,
-                                  Split *d_split_candidates, Node *d_nodes,
-                                  bst_uint num_items, int num_features,
-                                  const int *d_feature_offsets,
-                                  gpu_gpair *d_node_sums, int *d_node_offsets,
-                                  const GPUTrainingParam param,
-                                  const int level) {
+void find_split_candidates_sorted(GPUData * data, const int level) {
   const int BLOCK_THREADS = 512;
 
   CHECK(BLOCK_THREADS / 32 < 32) << "Too many active warps.";
 
-  int grid_size = num_features;
+  int grid_size = data->n_features;
 
   find_split_candidates_sorted_kernel<
       BLOCK_THREADS><<<grid_size, BLOCK_THREADS>>>(
-      items_iter, d_split_candidates, d_nodes, num_items, num_features,
-      d_feature_offsets, d_node_sums, d_node_offsets, param, level);
+      data->items_iter, data->split_candidates.data(), data->nodes.data(),
+      data->fvalues.size(), data->n_features,
+      data->foffsets.data(), data->node_sums.data(), data->node_offsets.data(),
+      data->param, data->feature_flags.data(), level);
 
   dh::safe_cuda(cudaGetLastError());
   dh::safe_cuda(cudaDeviceSynchronize());
 
@@ -12,143 +12,17 @@
 #include <thrust/sequence.h>
 #include <algorithm>
-#include <random>
+#include <numeric>
 #include <vector>
 #include "../../../src/common/random.h"
 #include "device_helpers.cuh"
 #include "find_split.cuh"
 #include "gpu_builder.cuh"
 #include "types_functions.cuh"
+#include "gpu_data.cuh"
 
 namespace xgboost {
 namespace tree {
-struct GPUData {
-  GPUData() : allocated(false), n_features(0), n_instances(0) {}
-
-  bool allocated;
-  int n_features;
-  int n_instances;
-
-  dh::bulk_allocator ba;
-  GPUTrainingParam param;
-
-  dh::dvec<float> fvalues;
-  dh::dvec<float> fvalues_temp;
-  dh::dvec<float> fvalues_cached;
-  dh::dvec<int> foffsets;
-  dh::dvec<bst_uint> instance_id;
-  dh::dvec<bst_uint> instance_id_temp;
-  dh::dvec<bst_uint> instance_id_cached;
-  dh::dvec<int> feature_id;
-  dh::dvec<NodeIdT> node_id;
-  dh::dvec<NodeIdT> node_id_temp;
-  dh::dvec<NodeIdT> node_id_instance;
-  dh::dvec<gpu_gpair> gpair;
-  dh::dvec<Node> nodes;
-  dh::dvec<Split> split_candidates;
-  dh::dvec<gpu_gpair> node_sums;
-  dh::dvec<int> node_offsets;
-  dh::dvec<int> sort_index_in;
-  dh::dvec<int> sort_index_out;
-
-  dh::dvec<char> cub_mem;
-
-  ItemIter items_iter;
-
-  void Init(const std::vector<float> &in_fvalues,
-            const std::vector<int> &in_foffsets,
-            const std::vector<bst_uint> &in_instance_id,
-            const std::vector<int> &in_feature_id,
-            const std::vector<bst_gpair> &in_gpair, bst_uint n_instances_in,
-            bst_uint n_features_in, int max_depth, const TrainParam &param_in) {
-    n_features = n_features_in;
-    n_instances = n_instances_in;
-
-    uint32_t max_nodes = (1 << (max_depth + 1)) - 1;
-    uint32_t max_nodes_level = 1 << max_depth;
-
-    // Calculate memory for sort
-    size_t cub_mem_size = 0;
-    cub::DoubleBuffer<NodeIdT> db_key;
-    cub::DoubleBuffer<int> db_value;
-
-    cub::DeviceSegmentedRadixSort::SortPairs(
-        cub_mem.data(), cub_mem_size, db_key,
-        db_value, in_fvalues.size(), n_features,
-        foffsets.data(), foffsets.data() + 1);
-
-    // Allocate memory
-    size_t free_memory = dh::available_memory();
-    ba.allocate(&fvalues, in_fvalues.size(), &fvalues_temp, in_fvalues.size(),
-                &fvalues_cached, in_fvalues.size(), &foffsets,
-                in_foffsets.size(), &instance_id, in_instance_id.size(),
-                &instance_id_temp, in_instance_id.size(), &instance_id_cached,
-                in_instance_id.size(), &feature_id, in_feature_id.size(),
-                &node_id, in_fvalues.size(), &node_id_temp, in_fvalues.size(),
-                &node_id_instance, n_instances, &gpair, n_instances, &nodes,
-                max_nodes, &split_candidates, max_nodes_level * n_features,
-                &node_sums, max_nodes_level * n_features, &node_offsets,
-                max_nodes_level * n_features, &sort_index_in, in_fvalues.size(),
-                &sort_index_out, in_fvalues.size(), &cub_mem, cub_mem_size);
-
-    if (!param_in.silent) {
-      const int mb_size = 1048576;
-      LOG(CONSOLE) << "Allocated " << ba.size() / mb_size << "/"
-                   << free_memory / mb_size << " MB on " << dh::device_name();
-    }
-    node_id.fill(0);
-    node_id_instance.fill(0);
-
-    fvalues = in_fvalues;
-    fvalues_cached = fvalues;
-    foffsets = in_foffsets;
-    instance_id = in_instance_id;
-    instance_id_cached = instance_id;
-    feature_id = in_feature_id;
-
-    param = GPUTrainingParam(param_in.min_child_weight, param_in.reg_lambda,
-                             param_in.reg_alpha, param_in.max_delta_step);
-
-    gpair = in_gpair;
-
-    nodes.fill(Node());
-
-    items_iter = thrust::make_zip_iterator(thrust::make_tuple(
-        thrust::make_permutation_iterator(gpair.tbegin(), instance_id.tbegin()),
-        fvalues.tbegin(), node_id.tbegin()));
-
-    allocated = true;
-
-    dh::safe_cuda(cudaGetLastError());
-  }
-
-  ~GPUData() {}
-
-  // Reset memory for new boosting iteration
-  void Reset(const std::vector<bst_gpair> &in_gpair) {
-    CHECK(allocated);
-    gpair = in_gpair;
-    instance_id = instance_id_cached;
-    fvalues = fvalues_cached;
-    nodes.fill(Node());
-    node_id_instance.fill(0);
-    node_id.fill(0);
-  }
-
-  bool IsAllocated() { return allocated; }
-
-  // Gather from node_id_instance into node_id according to instance_id
-  void GatherNodeId() {
-    // Update node_id for each item
-    auto d_node_id = node_id.data();
-    auto d_node_id_instance = node_id_instance.data();
-    auto d_instance_id = instance_id.data();
-
-    dh::launch_n(fvalues.size(), [=] __device__(bst_uint i) {
-      // Item item = d_items[i];
-      d_node_id[i] = d_node_id_instance[d_instance_id[i]];
-    });
-  }
-};
 
 GPUBuilder::GPUBuilder() { gpu_data = new GPUData(); }
@@ -253,15 +127,26 @@ void GPUBuilder::Sort(int level) {
   }
 }
 
+void GPUBuilder::ColsampleTree() {
+  unsigned n = static_cast<unsigned>(
+      param.colsample_bytree * gpu_data->n_features);
+  CHECK_GT(n, 0);
+
+  feature_set_tree.resize(gpu_data->n_features);
+  std::iota(feature_set_tree.begin(), feature_set_tree.end(), 0);
+  std::shuffle(feature_set_tree.begin(), feature_set_tree.end(),
+               common::GlobalRandom());
+}
+
 void GPUBuilder::Update(const std::vector<bst_gpair> &gpair, DMatrix *p_fmat,
                         RegTree *p_tree) {
   cudaProfilerStart();
   try {
     dh::Timer update;
     dh::Timer t;
     this->InitData(gpair, *p_fmat, *p_tree);
     t.printElapsed("init data");
     this->InitFirstNode();
+    this->ColsampleTree();
 
     for (int level = 0; level < param.max_depth; level++) {
       bool use_multiscan_algorithm = level < multiscan_levels;
@@ -280,11 +165,8 @@ void GPUBuilder::Update(const std::vector<bst_gpair> &gpair, DMatrix *p_fmat,
       }
 
       dh::Timer split;
-      find_split(gpu_data->items_iter, gpu_data->split_candidates.data(),
-                 gpu_data->nodes.data(), (bst_uint)gpu_data->fvalues.size(),
-                 gpu_data->n_features, gpu_data->foffsets.data(),
-                 gpu_data->node_sums.data(), gpu_data->node_offsets.data(),
-                 gpu_data->param, level, use_multiscan_algorithm);
+      find_split(gpu_data, param, level, use_multiscan_algorithm,
+                 feature_set_tree, &feature_set_level);
 
       split.printElapsed("split");
 
@@ -302,22 +184,6 @@ void GPUBuilder::Update(const std::vector<bst_gpair> &gpair, DMatrix *p_fmat,
     std::cerr << "Unknown exception." << std::endl;
     exit(-1);
   }
   cudaProfilerStop();
 }
 
-float GPUBuilder::GetSubsamplingRate(MetaInfo info) {
-  float subsample = 1.0;
-  uint32_t max_nodes = (1 << (param.max_depth + 1)) - 1;
-  uint32_t max_nodes_level = 1 << param.max_depth;
-  size_t required = 10 * info.num_row + 40 * info.num_nonzero
-                    + 64 * max_nodes + 76 * max_nodes_level * info.num_col;
-  size_t available = dh::available_memory();
-  while (available < required) {
-    subsample -= 0.05;
-    required = 10 * info.num_row + subsample * (44 * info.num_nonzero);
-  }
-
-  return subsample;
-}
-
 void GPUBuilder::InitData(const std::vector<bst_gpair> &gpair, DMatrix &fmat,
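For scale, plugging illustrative numbers into the heuristic removed above: 1M rows, 100M nonzeros, 968 columns and max_depth = 5 give an estimate near 3.7 GB, dominated by the 40-bytes-per-nonzero term. A sketch of the arithmetic (the inputs are made up for illustration, not Bosch's real statistics):

```cpp
#include <cstdio>
#include <cstdint>

int main() {
  // Illustrative inputs; the removed code read these from MetaInfo.
  const size_t num_row = 1000000, num_nonzero = 100000000, num_col = 968;
  const int max_depth = 5;
  const uint32_t max_nodes = (1 << (max_depth + 1)) - 1;  // 63
  const uint32_t max_nodes_level = 1 << max_depth;        // 32

  size_t required = 10 * num_row + 40 * num_nonzero + 64 * max_nodes +
                    76 * max_nodes_level * num_col;
  std::printf("estimated requirement: %.2f GB\n",
              required / 1073741824.0);  // ~3.74 GB for these inputs
  return 0;
}
```

The commit drops this automatic fallback in favour of the explicit subsample parameter, applied to gradient pairs on the device (see gpu_data.cuh below).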
@@ -325,7 +191,7 @@ void GPUBuilder::InitData(const std::vector<bst_gpair> &gpair, DMatrix &fmat,
   CHECK(fmat.SingleColBlock()) << "GPUMaker: must have single column block";
 
   if (gpu_data->IsAllocated()) {
-    gpu_data->Reset(gpair);
+    gpu_data->Reset(gpair, param.subsample);
     return;
   }
 
@@ -333,35 +199,6 @@ void GPUBuilder::InitData(const std::vector<bst_gpair> &gpair, DMatrix &fmat,
 
   MetaInfo info = fmat.info();
 
-  // Work out if dataset will fit on GPU
-  float subsample = this->GetSubsamplingRate(info);
-  CHECK(subsample > 0.0);
-  if (!param.silent && subsample < param.subsample) {
-    LOG(CONSOLE) << "Not enough device memory for entire dataset.";
-  }
-
-  // Override subsample parameter if user-specified parameter is lower
-  subsample = std::min(param.subsample, subsample);
-
-  std::vector<bool> row_flags;
-
-  if (subsample < 1.0) {
-    if (!param.silent && subsample < 1.0) {
-      LOG(CONSOLE) << "Subsampling " << subsample * 100 << "% of rows.";
-    }
-
-    const RowSet &rowset = fmat.buffered_rowset();
-    row_flags.resize(info.num_row);
-    std::bernoulli_distribution coin_flip(subsample);
-    auto &rnd = common::GlobalRandom();
-    for (size_t i = 0; i < rowset.size(); ++i) {
-      const bst_uint ridx = rowset[i];
-      if (gpair[ridx].hess < 0.0f)
-        continue;
-      row_flags[ridx] = coin_flip(rnd);
-    }
-  }
-
   std::vector<int> foffsets;
   foffsets.push_back(0);
   std::vector<int> feature_id;
@@ -382,18 +219,10 @@ void GPUBuilder::InitData(const std::vector<bst_gpair> &gpair, DMatrix &fmat,
       for (const ColBatch::Entry *it = col.data; it != col.data + col.length;
            it++) {
         bst_uint inst_id = it->index;
-        if (subsample < 1.0) {
-          if (row_flags[inst_id]) {
-            fvalues.push_back(it->fvalue);
-            instance_id.push_back(inst_id);
-            feature_id.push_back(i);
-          }
-        } else {
         fvalues.push_back(it->fvalue);
         instance_id.push_back(inst_id);
         feature_id.push_back(i);
-        }
       }
       foffsets.push_back(fvalues.size());
     }
   }
@@ -23,6 +23,7 @@ class GPUBuilder {
                RegTree *p_tree);
 
+  void UpdateNodeId(int level);
 
  private:
   void InitData(const std::vector<bst_gpair> &gpair, DMatrix &fmat, // NOLINT
                 const RegTree &tree);
@@ -31,9 +32,12 @@ class GPUBuilder {
   void Sort(int level);
   void InitFirstNode();
   void CopyTree(RegTree &tree); // NOLINT
+  void ColsampleTree();
 
   TrainParam param;
   GPUData *gpu_data;
+  std::vector<int> feature_set_tree;
+  std::vector<int> feature_set_level;
 
   int multiscan_levels =
       5; // Number of levels before switching to sorting algorithm
plugin/updater_gpu/src/gpu_data.cuh (new file, 162 lines)
@@ -0,0 +1,162 @@
+/*!
+ * Copyright 2016 Rory mitchell
+ */
+#pragma once
+#include <cub/cub.cuh>
+#include <xgboost/logging.h>
+#include <thrust/sequence.h>
+#include <vector>
+#include "device_helpers.cuh"
+#include "../../src/tree/param.h"
+#include "types_functions.cuh"
+
+namespace xgboost {
+namespace tree {
+
+struct GPUData {
+  GPUData() : allocated(false), n_features(0), n_instances(0) {}
+
+  bool allocated;
+  int n_features;
+  int n_instances;
+
+  dh::bulk_allocator ba;
+  GPUTrainingParam param;
+
+  dh::dvec<float> fvalues;
+  dh::dvec<float> fvalues_temp;
+  dh::dvec<float> fvalues_cached;
+  dh::dvec<int> foffsets;
+  dh::dvec<bst_uint> instance_id;
+  dh::dvec<bst_uint> instance_id_temp;
+  dh::dvec<bst_uint> instance_id_cached;
+  dh::dvec<int> feature_id;
+  dh::dvec<NodeIdT> node_id;
+  dh::dvec<NodeIdT> node_id_temp;
+  dh::dvec<NodeIdT> node_id_instance;
+  dh::dvec<gpu_gpair> gpair;
+  dh::dvec<Node> nodes;
+  dh::dvec<Split> split_candidates;
+  dh::dvec<gpu_gpair> node_sums;
+  dh::dvec<int> node_offsets;
+  dh::dvec<int> sort_index_in;
+  dh::dvec<int> sort_index_out;
+
+  dh::dvec<char> cub_mem;
+
+  dh::dvec<int> feature_flags;
+  dh::dvec<int> feature_set;
+
+  ItemIter items_iter;
+
+  void Init(const std::vector<float> &in_fvalues,
+            const std::vector<int> &in_foffsets,
+            const std::vector<bst_uint> &in_instance_id,
+            const std::vector<int> &in_feature_id,
+            const std::vector<bst_gpair> &in_gpair, bst_uint n_instances_in,
+            bst_uint n_features_in, int max_depth, const TrainParam &param_in) {
+    n_features = n_features_in;
+    n_instances = n_instances_in;
+
+    uint32_t max_nodes = (1 << (max_depth + 1)) - 1;
+    uint32_t max_nodes_level = 1 << max_depth;
+
+    // Calculate memory for sort
+    size_t cub_mem_size = 0;
+    cub::DoubleBuffer<NodeIdT> db_key;
+    cub::DoubleBuffer<int> db_value;
+
+    cub::DeviceSegmentedRadixSort::SortPairs(
+        cub_mem.data(), cub_mem_size, db_key,
+        db_value, in_fvalues.size(), n_features,
+        foffsets.data(), foffsets.data() + 1);
+
+    // Allocate memory
+    size_t free_memory = dh::available_memory();
+    ba.allocate(&fvalues, in_fvalues.size(), &fvalues_temp, in_fvalues.size(),
+                &fvalues_cached, in_fvalues.size(), &foffsets,
+                in_foffsets.size(), &instance_id, in_instance_id.size(),
+                &instance_id_temp, in_instance_id.size(), &instance_id_cached,
+                in_instance_id.size(), &feature_id, in_feature_id.size(),
+                &node_id, in_fvalues.size(), &node_id_temp, in_fvalues.size(),
+                &node_id_instance, n_instances, &gpair, n_instances, &nodes,
+                max_nodes, &split_candidates, max_nodes_level * n_features,
+                &node_sums, max_nodes_level * n_features, &node_offsets,
+                max_nodes_level * n_features, &sort_index_in, in_fvalues.size(),
+                &sort_index_out, in_fvalues.size(), &cub_mem, cub_mem_size,
+                &feature_flags, n_features, &feature_set, n_features);
+
+    if (!param_in.silent) {
+      const int mb_size = 1048576;
+      LOG(CONSOLE) << "Allocated " << ba.size() / mb_size << "/"
+                   << free_memory / mb_size << " MB on " << dh::device_name();
+    }
+
+    fvalues_cached = in_fvalues;
+    foffsets = in_foffsets;
+    instance_id_cached = in_instance_id;
+    feature_id = in_feature_id;
+
+    param = GPUTrainingParam(param_in.min_child_weight, param_in.reg_lambda,
+                             param_in.reg_alpha, param_in.max_delta_step);
+
+    allocated = true;
+
+    this->Reset(in_gpair, param_in.subsample);
+
+    items_iter = thrust::make_zip_iterator(thrust::make_tuple(
+        thrust::make_permutation_iterator(gpair.tbegin(), instance_id.tbegin()),
+        fvalues.tbegin(), node_id.tbegin()));
+
+    dh::safe_cuda(cudaGetLastError());
+  }
+
+  ~GPUData() {}
+
+  // Set gradient pair to 0 with p = 1 - subsample
+  void MarkSubsample(float subsample) {
+    if (subsample == 1.0) {
+      return;
+    }
+
+    auto d_gpair = gpair.data();
+    dh::BernoulliRng rng(subsample, common::GlobalRandom()());
+
+    dh::launch_n(n_instances, [=] __device__(int i) {
+      if (!rng(i)) {
+        d_gpair[i] = gpu_gpair();
+      }
+    });
+  }
+
+  // Reset memory for new boosting iteration
+  void Reset(const std::vector<bst_gpair> &in_gpair, float subsample) {
+    CHECK(allocated);
+    gpair = in_gpair;
+    this->MarkSubsample(subsample);
+    instance_id = instance_id_cached;
+    fvalues = fvalues_cached;
+    nodes.fill(Node());
+    node_id_instance.fill(0);
+    node_id.fill(0);
+  }
+
+  bool IsAllocated() { return allocated; }
+
+  // Gather from node_id_instance into node_id according to instance_id
+  void GatherNodeId() {
+    // Update node_id for each item
+    auto d_node_id = node_id.data();
+    auto d_node_id_instance = node_id_instance.data();
+    auto d_instance_id = instance_id.data();
+
+    dh::launch_n(fvalues.size(), [=] __device__(bst_uint i) {
+      // Item item = d_items[i];
+      d_node_id[i] = d_node_id_instance[d_instance_id[i]];
+    });
+  }
+};
+} // namespace tree
+} // namespace xgboost
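Unlike the removed CPU path, which physically dropped rows before upload, MarkSubsample keeps the whole dataset resident and zeroes the gradient pair of a skipped instance, so a fresh Bernoulli draw happens on every call to Reset, i.e. every boosting iteration. A host-side sketch of the same semantics, assuming a plain struct in place of gpu_gpair (mark_subsample and GradPair are illustrative names):

```cpp
#include <random>
#include <vector>

struct GradPair { float grad = 0.f, hess = 0.f; };  // stand-in for gpu_gpair

// Zero each instance's gradient pair with probability 1 - subsample,
// mirroring GPUData::MarkSubsample but running on the host.
void mark_subsample(std::vector<GradPair> *gpair, float subsample,
                    std::mt19937 *rng) {
  if (subsample == 1.0f) return;  // fast path: keep every row
  std::bernoulli_distribution keep(subsample);
  for (auto &g : *gpair) {
    if (!keep(*rng)) g = GradPair();  // a zeroed pair contributes nothing
  }
}
```

Zeroing rather than compacting costs memory for the skipped rows but leaves fvalues, instance_id and the segmented layout untouched, which keeps Reset cheap between iterations.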