GPU Plugin: Add subsample, colsample_bytree, colsample_bylevel (#1895)

This commit is contained in:
Rory Mitchell 2016-12-23 04:30:36 +13:00 committed by Tianqi Chen
parent cee4aafb93
commit b49b339183
10 changed files with 331 additions and 324 deletions

View File

@ -9,10 +9,10 @@ https://www.kaggle.com/c/bosch-production-line-performance/data
Copy train_numeric.csv into xgboost/demo/data.
The subsample parameter can be changed so you can run the script first on a small portion of the data. Processing the entire dataset can take a long time and requires about 8GB of device memory. It is initially set to 0.4, using about 2650/3380MB on a GTX 970.
The subset parameter changes the proportion of rows loaded from the CSV file. Processing the entire dataset can take a long time and requires about 8GB of device memory. It is initially set to 0.4, using about 2650/3380MB on a GTX 970. Lower the parameter if your device runs out of memory.
```python
subsample = 0.4
subset = 0.4
```
Parameters are set as usual except that we set silent to 0 to see how much memory is being allocated on the GPU and we change 'updater' to 'grow_gpu' to activate the GPU plugin.

View File

@ -5,12 +5,12 @@ import time
import random
from sklearn.cross_validation import StratifiedKFold
#For sub sampling rows from input file
#For sampling rows from input file
random_seed = 9
subsample = 0.4
subset = 0.4
n_rows = 1183747;
train_rows = int(n_rows * subsample)
train_rows = int(n_rows * subset)
random.seed(random_seed)
skip = sorted(random.sample(xrange(1,n_rows + 1),n_rows-train_rows))
data = pd.read_csv("../data/train_numeric.csv", index_col=0, dtype=np.float32, skiprows=skip)

View File

@ -32,8 +32,6 @@ Data is stored in a sparse format. For example, missing values produced by one h
A 4GB graphics card will process approximately 3.5 million rows of the well known Kaggle higgs dataset.
The algorithm will automatically perform row subsampling if it detects there is not enough memory on the device.
## Dependencies
A CUDA capable GPU with at least compute capability >= 3.5 (the algorithm depends on shuffle and vote instructions introduced in Kepler).

View File

@ -7,6 +7,7 @@
#include <thrust/device_vector.h>
#include <thrust/system/cuda/error.h>
#include <thrust/system_error.h>
#include <thrust/random.h>
#include <algorithm>
#include <ctime>
#include <sstream>
@ -147,6 +148,8 @@ struct Timer {
LARGE_INTEGER now;
QueryPerformanceCounter(&now);
return static_cast<double>(now.QuadPart) / s_frequency.QuadPart;
#else
return 0;
#endif
}
@ -160,12 +163,14 @@ struct Timer {
#ifdef _WIN32
_ReadWriteBarrier();
return seconds_now() - start;
#else
return 0;
#endif
}
void printElapsed(char *label) {
void printElapsed(std::string label) {
#ifdef TIMERS
safe_cuda(cudaDeviceSynchronize());
printf("%s:\t %1.4fs\n", label, elapsed());
printf("%s:\t %1.4fs\n", label.c_str(), elapsed());
#endif
}
};
@ -233,46 +238,6 @@ template <typename T> __device__ range block_stride_range(T begin, T end) {
return r;
}
/*
* Utility functions
*/
template <typename T>
void print(const thrust::device_vector<T> &v, size_t max_items = 10) {
thrust::host_vector<T> h = v;
for (int i = 0; i < std::min(max_items, h.size()); i++) {
std::cout << " " << h[i];
}
std::cout << "\n";
}
template <typename T>
void print(char *label, const thrust::device_vector<T> &v,
const char *format = "%d ", int max = 10) {
thrust::host_vector<T> h_v = v;
std::cout << label << ":\n";
for (int i = 0; i < std::min(static_cast<int>(h_v.size()), max); i++) {
printf(format, h_v[i]);
}
std::cout << "\n";
}
template <typename T1, typename T2> T1 div_round_up(const T1 a, const T2 b) {
return static_cast<T1>(ceil(static_cast<double>(a) / b));
}
template <typename T> thrust::device_ptr<T> dptr(T *d_ptr) {
return thrust::device_pointer_cast(d_ptr);
}
template <typename T> T *raw(thrust::device_vector<T> &v) { // NOLINT
return raw_pointer_cast(v.data());
}
template <typename T> size_t size_bytes(const thrust::device_vector<T> &v) {
return sizeof(T) * v.size();
}
// Threadblock iterates over range, filling with value
template <typename IterT, typename ValueT>
@ -306,11 +271,11 @@ template <typename T> class dvec {
public:
dvec() : _ptr(NULL), _size(0) {}
size_t size() { return _size; }
bool empty() { return _ptr == NULL || _size == 0; }
size_t size() const { return _size; }
bool empty() const { return _ptr == NULL || _size == 0; }
T *data() { return _ptr; }
std::vector<T> as_vector() {
std::vector<T> as_vector() const {
std::vector<T> h_vector(size());
safe_cuda(cudaMemcpy(h_vector.data(), _ptr, size() * sizeof(T),
cudaMemcpyDeviceToHost));
@ -454,6 +419,55 @@ inline std::string device_name() {
return std::string(prop.name);
}
/*
* Utility functions
*/
template <typename T>
void print(const thrust::device_vector<T> &v, size_t max_items = 10) {
thrust::host_vector<T> h = v;
for (int i = 0; i < std::min(max_items, h.size()); i++) {
std::cout << " " << h[i];
}
std::cout << "\n";
}
template <typename T>
void print(const dvec<T> &v, size_t max_items = 10) {
std::vector<T> h = v.as_vector();
for (int i = 0; i < std::min(max_items, h.size()); i++) {
std::cout << " " << h[i];
}
std::cout << "\n";
}
template <typename T>
void print(char *label, const thrust::device_vector<T> &v,
const char *format = "%d ", int max = 10) {
thrust::host_vector<T> h_v = v;
std::cout << label << ":\n";
for (int i = 0; i < std::min(static_cast<int>(h_v.size()), max); i++) {
printf(format, h_v[i]);
}
std::cout << "\n";
}
template <typename T1, typename T2> T1 div_round_up(const T1 a, const T2 b) {
return static_cast<T1>(ceil(static_cast<double>(a) / b));
}
template <typename T> thrust::device_ptr<T> dptr(T *d_ptr) {
return thrust::device_pointer_cast(d_ptr);
}
template <typename T> T *raw(thrust::device_vector<T> &v) { // NOLINT
return raw_pointer_cast(v.data());
}
template <typename T> size_t size_bytes(const thrust::device_vector<T> &v) {
return sizeof(T) * v.size();
}
/*
* Kernel launcher
*/
@ -470,4 +484,25 @@ inline void launch_n(size_t n, L lambda) {
launch_n_kernel<<<GRID_SIZE, BLOCK_THREADS>>>(n, lambda);
}
/*
* Random
*/
struct BernoulliRng {
float p;
int seed;
__host__ __device__ BernoulliRng(float p, int seed):p(p), seed(seed) {}
__host__ __device__ bool operator()(const int i) const {
thrust::default_random_engine rng(seed);
thrust::uniform_real_distribution<float> dist;
rng.discard(i);
return dist(rng) <= p;
}
};
} // namespace dh

View File

@ -4,9 +4,11 @@
#pragma once
#include <cub/cub.cuh>
#include <xgboost/base.h>
#include <vector>
#include "device_helpers.cuh"
#include "find_split_multiscan.cuh"
#include "find_split_sorting.cuh"
#include "gpu_data.cuh"
#include "types_functions.cuh"
namespace xgboost {
@ -62,24 +64,47 @@ void reduce_split_candidates(Split *d_split_candidates, Node *d_nodes,
dh::safe_cuda(cudaDeviceSynchronize());
}
void find_split(const ItemIter items_iter, Split *d_split_candidates,
Node *d_nodes, bst_uint num_items, int num_features,
const int *d_feature_offsets, gpu_gpair *d_node_sums,
int *d_node_offsets, const GPUTrainingParam param,
const int level, bool multiscan_algorithm) {
void colsample_level(GPUData *data, const TrainParam xgboost_param,
const std::vector<int> &feature_set_tree,
std::vector<int> *feature_set_level) {
unsigned n_bytree =
static_cast<unsigned>(xgboost_param.colsample_bytree * data->n_features);
unsigned n =
static_cast<unsigned>(n_bytree * xgboost_param.colsample_bylevel);
CHECK_GT(n, 0);
*feature_set_level = feature_set_tree;
std::shuffle((*feature_set_level).begin(),
(*feature_set_level).begin() + n_bytree, common::GlobalRandom());
data->feature_set = *feature_set_level;
data->feature_flags.fill(0);
auto d_feature_set = data->feature_set.data();
auto d_feature_flags = data->feature_flags.data();
dh::launch_n(
n, [=] __device__(int i) { d_feature_flags[d_feature_set[i]] = 1; });
}
void find_split(GPUData *data, const TrainParam xgboost_param, const int level,
bool multiscan_algorithm,
const std::vector<int> &feature_set_tree,
std::vector<int> *feature_set_level) {
colsample_level(data, xgboost_param, feature_set_tree, feature_set_level);
// Reset split candidates
data->split_candidates.fill(Split());
if (multiscan_algorithm) {
find_split_candidates_multiscan(items_iter, d_split_candidates, d_nodes,
num_items, num_features, d_feature_offsets,
param, level);
find_split_candidates_multiscan(data, level);
} else {
find_split_candidates_sorted(items_iter, d_split_candidates, d_nodes,
num_items, num_features, d_feature_offsets,
d_node_sums, d_node_offsets, param, level);
find_split_candidates_sorted(data, level);
}
// Find the best split for each node
reduce_split_candidates(d_split_candidates, d_nodes, level, num_features,
param);
reduce_split_candidates(data->split_candidates.data(), data->nodes.data(),
level, data->n_features, data->param);
}
} // namespace tree
} // namespace xgboost

View File

@ -5,6 +5,7 @@
#include <cub/cub.cuh>
#include <xgboost/base.h>
#include "device_helpers.cuh"
#include "gpu_data.cuh"
#include "types_functions.cuh"
namespace xgboost {
@ -609,22 +610,11 @@ struct FindSplitEnactorMultiscan {
}
}
__device__ __forceinline__ void ResetSplitCandidates() {
const int max_nodes = 1 << level;
const int begin = blockIdx.x * max_nodes;
const int end = begin + max_nodes;
for (auto i : dh::block_stride_range(begin, end)) {
d_split_candidates_out[i] = Split();
}
}
__device__ __forceinline__ void ProcessRegion(const bst_uint &segment_begin,
const bst_uint &segment_end) {
// Current position
bst_uint offset = segment_begin;
ResetSplitCandidates();
ResetTileCarry();
ResetSplits();
CacheNodes();
@ -654,8 +644,9 @@ __launch_bounds__(1024, 2)
const ItemIter items_iter, Split *d_split_candidates_out,
const Node *d_nodes, const int node_begin, bst_uint num_items,
int num_features, const int *d_feature_offsets,
const GPUTrainingParam param, const int level) {
if (num_items <= 0) {
const GPUTrainingParam param, const int *d_feature_flags,
const int level) {
if (num_items <= 0 || d_feature_flags[blockIdx.x] != 1) {
return;
}
@ -685,69 +676,45 @@ __launch_bounds__(1024, 2)
}
template <int N_NODES>
void find_split_candidates_multiscan_variation(
const ItemIter items_iter, Split *d_split_candidates, const Node *d_nodes,
int node_begin, int node_end, bst_uint num_items, int num_features,
const int *d_feature_offsets, const GPUTrainingParam param,
const int level) {
void find_split_candidates_multiscan_variation(GPUData *data, const int level) {
const int node_begin = (1 << level) - 1;
const int BLOCK_THREADS = 512;
CHECK((node_end - node_begin) <= N_NODES) << "Multiscan: N_NODES template "
"parameter too small for given "
"node range.";
CHECK(BLOCK_THREADS / 32 < 32)
<< "Too many active warps. See FindSplitEnactor - ReduceSplits.";
typedef FindSplitParamsMultiscan<BLOCK_THREADS, N_NODES, false>
find_split_params;
typedef ReduceParamsMultiscan<BLOCK_THREADS, N_NODES, false> reduce_params;
int grid_size = num_features;
int grid_size = data->n_features;
find_split_candidates_multiscan_kernel<
find_split_params,
reduce_params><<<grid_size, find_split_params::BLOCK_THREADS>>>(
items_iter, d_split_candidates, d_nodes, node_begin, num_items,
num_features, d_feature_offsets, param, level);
data->items_iter, data->split_candidates.data(), data->nodes.data(),
node_begin, data->fvalues.size(), data->n_features, data->foffsets.data(),
data->param, data->feature_flags.data(), level);
dh::safe_cuda(cudaDeviceSynchronize());
}
void find_split_candidates_multiscan(
const ItemIter items_iter, Split *d_split_candidates, const Node *d_nodes,
bst_uint num_items, int num_features, const int *d_feature_offsets,
const GPUTrainingParam param, const int level) {
void find_split_candidates_multiscan(GPUData *data, const int level) {
// Select templated variation of split finding algorithm
switch (level) {
case 0:
find_split_candidates_multiscan_variation<1>(
items_iter, d_split_candidates, d_nodes, 0, 1, num_items, num_features,
d_feature_offsets, param, level);
find_split_candidates_multiscan_variation<1>(data, level);
break;
case 1:
find_split_candidates_multiscan_variation<2>(
items_iter, d_split_candidates, d_nodes, 1, 3, num_items, num_features,
d_feature_offsets, param, level);
find_split_candidates_multiscan_variation<2>(data, level);
break;
case 2:
find_split_candidates_multiscan_variation<4>(
items_iter, d_split_candidates, d_nodes, 3, 7, num_items, num_features,
d_feature_offsets, param, level);
find_split_candidates_multiscan_variation<4>(data, level);
break;
case 3:
find_split_candidates_multiscan_variation<8>(
items_iter, d_split_candidates, d_nodes, 7, 15, num_items, num_features,
d_feature_offsets, param, level);
find_split_candidates_multiscan_variation<8>(data, level);
break;
case 4:
find_split_candidates_multiscan_variation<16>(
items_iter, d_split_candidates, d_nodes, 15, 31, num_items,
num_features, d_feature_offsets, param, level);
break;
case 5:
find_split_candidates_multiscan_variation<32>(
items_iter, d_split_candidates, d_nodes, 31, 63, num_items,
num_features, d_feature_offsets, param, level);
find_split_candidates_multiscan_variation<16>(data, level);
break;
}
}

View File

@ -337,17 +337,8 @@ struct FindSplitEnactorSorting {
WriteBestSplit(node_id_adjusted);
}
__device__ __forceinline__ void ResetSplitCandidates() {
const int max_nodes = 1 << level;
const int begin = blockIdx.x * max_nodes;
dh::block_fill(d_split_candidates_out + begin, max_nodes, Split());
}
__device__ __forceinline__ void ProcessFeature(const bst_uint &segment_begin,
const bst_uint &segment_end) {
ResetSplitCandidates();
int node_begin = segment_begin;
const int max_nodes = 1 << level;
@ -377,9 +368,9 @@ __global__ __launch_bounds__(1024, 1) void find_split_candidates_sorted_kernel(
const ItemIter items_iter, Split *d_split_candidates_out,
const Node *d_nodes, bst_uint num_items, const int num_features,
const int *d_feature_offsets, gpu_gpair *d_node_sums, int *d_node_offsets,
const GPUTrainingParam param, const int level) {
const GPUTrainingParam param, const int *d_feature_flags, const int level) {
if (num_items <= 0) {
if (num_items <= 0 || d_feature_flags[blockIdx.x] != 1) {
return;
}
@ -408,23 +399,19 @@ __global__ __launch_bounds__(1024, 1) void find_split_candidates_sorted_kernel(
.ProcessFeature(segment_begin, segment_end);
}
void find_split_candidates_sorted(const ItemIter items_iter,
Split *d_split_candidates, Node *d_nodes,
bst_uint num_items, int num_features,
const int *d_feature_offsets,
gpu_gpair *d_node_sums, int *d_node_offsets,
const GPUTrainingParam param,
const int level) {
void find_split_candidates_sorted(GPUData * data, const int level) {
const int BLOCK_THREADS = 512;
CHECK(BLOCK_THREADS / 32 < 32) << "Too many active warps.";
int grid_size = num_features;
int grid_size = data->n_features;
find_split_candidates_sorted_kernel<
BLOCK_THREADS><<<grid_size, BLOCK_THREADS>>>(
items_iter, d_split_candidates, d_nodes, num_items, num_features,
d_feature_offsets, d_node_sums, d_node_offsets, param, level);
data->items_iter, data->split_candidates.data(), data->nodes.data(),
data->fvalues.size(), data->n_features,
data->foffsets.data(), data->node_sums.data(), data->node_offsets.data(),
data->param, data->feature_flags.data(), level);
dh::safe_cuda(cudaGetLastError());
dh::safe_cuda(cudaDeviceSynchronize());

View File

@ -12,143 +12,17 @@
#include <thrust/sequence.h>
#include <algorithm>
#include <random>
#include <numeric>
#include <vector>
#include "../../../src/common/random.h"
#include "device_helpers.cuh"
#include "find_split.cuh"
#include "gpu_builder.cuh"
#include "types_functions.cuh"
#include "gpu_data.cuh"
namespace xgboost {
namespace tree {
struct GPUData {
GPUData() : allocated(false), n_features(0), n_instances(0) {}
bool allocated;
int n_features;
int n_instances;
dh::bulk_allocator ba;
GPUTrainingParam param;
dh::dvec<float> fvalues;
dh::dvec<float> fvalues_temp;
dh::dvec<float> fvalues_cached;
dh::dvec<int> foffsets;
dh::dvec<bst_uint> instance_id;
dh::dvec<bst_uint> instance_id_temp;
dh::dvec<bst_uint> instance_id_cached;
dh::dvec<int> feature_id;
dh::dvec<NodeIdT> node_id;
dh::dvec<NodeIdT> node_id_temp;
dh::dvec<NodeIdT> node_id_instance;
dh::dvec<gpu_gpair> gpair;
dh::dvec<Node> nodes;
dh::dvec<Split> split_candidates;
dh::dvec<gpu_gpair> node_sums;
dh::dvec<int> node_offsets;
dh::dvec<int> sort_index_in;
dh::dvec<int> sort_index_out;
dh::dvec<char> cub_mem;
ItemIter items_iter;
void Init(const std::vector<float> &in_fvalues,
const std::vector<int> &in_foffsets,
const std::vector<bst_uint> &in_instance_id,
const std::vector<int> &in_feature_id,
const std::vector<bst_gpair> &in_gpair, bst_uint n_instances_in,
bst_uint n_features_in, int max_depth, const TrainParam &param_in) {
n_features = n_features_in;
n_instances = n_instances_in;
uint32_t max_nodes = (1 << (max_depth + 1)) - 1;
uint32_t max_nodes_level = 1 << max_depth;
// Calculate memory for sort
size_t cub_mem_size = 0;
cub::DoubleBuffer<NodeIdT> db_key;
cub::DoubleBuffer<int> db_value;
cub::DeviceSegmentedRadixSort::SortPairs(
cub_mem.data(), cub_mem_size, db_key,
db_value, in_fvalues.size(), n_features,
foffsets.data(), foffsets.data() + 1);
// Allocate memory
size_t free_memory = dh::available_memory();
ba.allocate(&fvalues, in_fvalues.size(), &fvalues_temp, in_fvalues.size(),
&fvalues_cached, in_fvalues.size(), &foffsets,
in_foffsets.size(), &instance_id, in_instance_id.size(),
&instance_id_temp, in_instance_id.size(), &instance_id_cached,
in_instance_id.size(), &feature_id, in_feature_id.size(),
&node_id, in_fvalues.size(), &node_id_temp, in_fvalues.size(),
&node_id_instance, n_instances, &gpair, n_instances, &nodes,
max_nodes, &split_candidates, max_nodes_level * n_features,
&node_sums, max_nodes_level * n_features, &node_offsets,
max_nodes_level * n_features, &sort_index_in, in_fvalues.size(),
&sort_index_out, in_fvalues.size(), &cub_mem, cub_mem_size);
if (!param_in.silent) {
const int mb_size = 1048576;
LOG(CONSOLE) << "Allocated " << ba.size() / mb_size << "/"
<< free_memory / mb_size << " MB on " << dh::device_name();
}
node_id.fill(0);
node_id_instance.fill(0);
fvalues = in_fvalues;
fvalues_cached = fvalues;
foffsets = in_foffsets;
instance_id = in_instance_id;
instance_id_cached = instance_id;
feature_id = in_feature_id;
param = GPUTrainingParam(param_in.min_child_weight, param_in.reg_lambda,
param_in.reg_alpha, param_in.max_delta_step);
gpair = in_gpair;
nodes.fill(Node());
items_iter = thrust::make_zip_iterator(thrust::make_tuple(
thrust::make_permutation_iterator(gpair.tbegin(), instance_id.tbegin()),
fvalues.tbegin(), node_id.tbegin()));
allocated = true;
dh::safe_cuda(cudaGetLastError());
}
~GPUData() {}
// Reset memory for new boosting iteration
void Reset(const std::vector<bst_gpair> &in_gpair) {
CHECK(allocated);
gpair = in_gpair;
instance_id = instance_id_cached;
fvalues = fvalues_cached;
nodes.fill(Node());
node_id_instance.fill(0);
node_id.fill(0);
}
bool IsAllocated() { return allocated; }
// Gather from node_id_instance into node_id according to instance_id
void GatherNodeId() {
// Update node_id for each item
auto d_node_id = node_id.data();
auto d_node_id_instance = node_id_instance.data();
auto d_instance_id = instance_id.data();
dh::launch_n(fvalues.size(), [=] __device__(bst_uint i) {
// Item item = d_items[i];
d_node_id[i] = d_node_id_instance[d_instance_id[i]];
});
}
};
GPUBuilder::GPUBuilder() { gpu_data = new GPUData(); }
@ -253,15 +127,26 @@ void GPUBuilder::Sort(int level) {
}
}
void GPUBuilder::ColsampleTree() {
unsigned n = static_cast<unsigned>(
param.colsample_bytree * gpu_data->n_features);
CHECK_GT(n, 0);
feature_set_tree.resize(gpu_data->n_features);
std::iota(feature_set_tree.begin(), feature_set_tree.end(), 0);
std::shuffle(feature_set_tree.begin(), feature_set_tree.end(),
common::GlobalRandom());
}
void GPUBuilder::Update(const std::vector<bst_gpair> &gpair, DMatrix *p_fmat,
RegTree *p_tree) {
cudaProfilerStart();
try {
dh::Timer update;
dh::Timer t;
this->InitData(gpair, *p_fmat, *p_tree);
t.printElapsed("init data");
this->InitFirstNode();
this->ColsampleTree();
for (int level = 0; level < param.max_depth; level++) {
bool use_multiscan_algorithm = level < multiscan_levels;
@ -280,11 +165,8 @@ void GPUBuilder::Update(const std::vector<bst_gpair> &gpair, DMatrix *p_fmat,
}
dh::Timer split;
find_split(gpu_data->items_iter, gpu_data->split_candidates.data(),
gpu_data->nodes.data(), (bst_uint)gpu_data->fvalues.size(),
gpu_data->n_features, gpu_data->foffsets.data(),
gpu_data->node_sums.data(), gpu_data->node_offsets.data(),
gpu_data->param, level, use_multiscan_algorithm);
find_split(gpu_data, param, level, use_multiscan_algorithm,
feature_set_tree, &feature_set_level);
split.printElapsed("split");
@ -302,22 +184,6 @@ void GPUBuilder::Update(const std::vector<bst_gpair> &gpair, DMatrix *p_fmat,
std::cerr << "Unknown exception." << std::endl;
exit(-1);
}
cudaProfilerStop();
}
float GPUBuilder::GetSubsamplingRate(MetaInfo info) {
float subsample = 1.0;
uint32_t max_nodes = (1 << (param.max_depth + 1)) - 1;
uint32_t max_nodes_level = 1 << param.max_depth;
size_t required = 10 * info.num_row + 40 * info.num_nonzero
+ 64 * max_nodes + 76 * max_nodes_level * info.num_col;
size_t available = dh::available_memory();
while (available < required) {
subsample -= 0.05;
required = 10 * info.num_row + subsample * (44 * info.num_nonzero);
}
return subsample;
}
void GPUBuilder::InitData(const std::vector<bst_gpair> &gpair, DMatrix &fmat,
@ -325,7 +191,7 @@ void GPUBuilder::InitData(const std::vector<bst_gpair> &gpair, DMatrix &fmat,
CHECK(fmat.SingleColBlock()) << "GPUMaker: must have single column block";
if (gpu_data->IsAllocated()) {
gpu_data->Reset(gpair);
gpu_data->Reset(gpair, param.subsample);
return;
}
@ -333,35 +199,6 @@ void GPUBuilder::InitData(const std::vector<bst_gpair> &gpair, DMatrix &fmat,
MetaInfo info = fmat.info();
// Work out if dataset will fit on GPU
float subsample = this->GetSubsamplingRate(info);
CHECK(subsample > 0.0);
if (!param.silent && subsample < param.subsample) {
LOG(CONSOLE) << "Not enough device memory for entire dataset.";
}
// Override subsample parameter if user-specified parameter is lower
subsample = std::min(param.subsample, subsample);
std::vector<bool> row_flags;
if (subsample < 1.0) {
if (!param.silent && subsample < 1.0) {
LOG(CONSOLE) << "Subsampling " << subsample * 100 << "% of rows.";
}
const RowSet &rowset = fmat.buffered_rowset();
row_flags.resize(info.num_row);
std::bernoulli_distribution coin_flip(subsample);
auto &rnd = common::GlobalRandom();
for (size_t i = 0; i < rowset.size(); ++i) {
const bst_uint ridx = rowset[i];
if (gpair[ridx].hess < 0.0f)
continue;
row_flags[ridx] = coin_flip(rnd);
}
}
std::vector<int> foffsets;
foffsets.push_back(0);
std::vector<int> feature_id;
@ -382,17 +219,9 @@ void GPUBuilder::InitData(const std::vector<bst_gpair> &gpair, DMatrix &fmat,
for (const ColBatch::Entry *it = col.data; it != col.data + col.length;
it++) {
bst_uint inst_id = it->index;
if (subsample < 1.0) {
if (row_flags[inst_id]) {
fvalues.push_back(it->fvalue);
instance_id.push_back(inst_id);
feature_id.push_back(i);
}
} else {
fvalues.push_back(it->fvalue);
instance_id.push_back(inst_id);
feature_id.push_back(i);
}
}
foffsets.push_back(fvalues.size());
}

View File

@ -23,6 +23,7 @@ class GPUBuilder {
RegTree *p_tree);
void UpdateNodeId(int level);
private:
void InitData(const std::vector<bst_gpair> &gpair, DMatrix &fmat, // NOLINT
const RegTree &tree);
@ -31,12 +32,15 @@ class GPUBuilder {
void Sort(int level);
void InitFirstNode();
void CopyTree(RegTree &tree); // NOLINT
void ColsampleTree();
TrainParam param;
GPUData *gpu_data;
std::vector<int> feature_set_tree;
std::vector<int> feature_set_level;
int multiscan_levels =
5; // Number of levels before switching to sorting algorithm
5; // Number of levels before switching to sorting algorithm
};
} // namespace tree
} // namespace xgboost

View File

@ -0,0 +1,162 @@
/*!
* Copyright 2016 Rory mitchell
*/
#pragma once
#include <cub/cub.cuh>
#include <xgboost/logging.h>
#include <thrust/sequence.h>
#include <vector>
#include "device_helpers.cuh"
#include "../../src/tree/param.h"
#include "types_functions.cuh"
namespace xgboost {
namespace tree {
struct GPUData {
GPUData() : allocated(false), n_features(0), n_instances(0) {}
bool allocated;
int n_features;
int n_instances;
dh::bulk_allocator ba;
GPUTrainingParam param;
dh::dvec<float> fvalues;
dh::dvec<float> fvalues_temp;
dh::dvec<float> fvalues_cached;
dh::dvec<int> foffsets;
dh::dvec<bst_uint> instance_id;
dh::dvec<bst_uint> instance_id_temp;
dh::dvec<bst_uint> instance_id_cached;
dh::dvec<int> feature_id;
dh::dvec<NodeIdT> node_id;
dh::dvec<NodeIdT> node_id_temp;
dh::dvec<NodeIdT> node_id_instance;
dh::dvec<gpu_gpair> gpair;
dh::dvec<Node> nodes;
dh::dvec<Split> split_candidates;
dh::dvec<gpu_gpair> node_sums;
dh::dvec<int> node_offsets;
dh::dvec<int> sort_index_in;
dh::dvec<int> sort_index_out;
dh::dvec<char> cub_mem;
dh::dvec<int> feature_flags;
dh::dvec<int> feature_set;
ItemIter items_iter;
void Init(const std::vector<float> &in_fvalues,
const std::vector<int> &in_foffsets,
const std::vector<bst_uint> &in_instance_id,
const std::vector<int> &in_feature_id,
const std::vector<bst_gpair> &in_gpair, bst_uint n_instances_in,
bst_uint n_features_in, int max_depth, const TrainParam &param_in) {
n_features = n_features_in;
n_instances = n_instances_in;
uint32_t max_nodes = (1 << (max_depth + 1)) - 1;
uint32_t max_nodes_level = 1 << max_depth;
// Calculate memory for sort
size_t cub_mem_size = 0;
cub::DoubleBuffer<NodeIdT> db_key;
cub::DoubleBuffer<int> db_value;
cub::DeviceSegmentedRadixSort::SortPairs(
cub_mem.data(), cub_mem_size, db_key,
db_value, in_fvalues.size(), n_features,
foffsets.data(), foffsets.data() + 1);
// Allocate memory
size_t free_memory = dh::available_memory();
ba.allocate(&fvalues, in_fvalues.size(), &fvalues_temp, in_fvalues.size(),
&fvalues_cached, in_fvalues.size(), &foffsets,
in_foffsets.size(), &instance_id, in_instance_id.size(),
&instance_id_temp, in_instance_id.size(), &instance_id_cached,
in_instance_id.size(), &feature_id, in_feature_id.size(),
&node_id, in_fvalues.size(), &node_id_temp, in_fvalues.size(),
&node_id_instance, n_instances, &gpair, n_instances, &nodes,
max_nodes, &split_candidates, max_nodes_level * n_features,
&node_sums, max_nodes_level * n_features, &node_offsets,
max_nodes_level * n_features, &sort_index_in, in_fvalues.size(),
&sort_index_out, in_fvalues.size(), &cub_mem, cub_mem_size,
&feature_flags, n_features, &feature_set, n_features);
if (!param_in.silent) {
const int mb_size = 1048576;
LOG(CONSOLE) << "Allocated " << ba.size() / mb_size << "/"
<< free_memory / mb_size << " MB on " << dh::device_name();
}
fvalues_cached = in_fvalues;
foffsets = in_foffsets;
instance_id_cached = in_instance_id;
feature_id = in_feature_id;
param = GPUTrainingParam(param_in.min_child_weight, param_in.reg_lambda,
param_in.reg_alpha, param_in.max_delta_step);
allocated = true;
this->Reset(in_gpair, param_in.subsample);
items_iter = thrust::make_zip_iterator(thrust::make_tuple(
thrust::make_permutation_iterator(gpair.tbegin(), instance_id.tbegin()),
fvalues.tbegin(), node_id.tbegin()));
dh::safe_cuda(cudaGetLastError());
}
~GPUData() {}
// Set gradient pair to 0 with p = 1 - subsample
void MarkSubsample(float subsample) {
if (subsample == 1.0) {
return;
}
auto d_gpair = gpair.data();
dh::BernoulliRng rng(subsample, common::GlobalRandom()());
dh::launch_n(n_instances, [=] __device__(int i) {
if (!rng(i)) {
d_gpair[i] = gpu_gpair();
}
});
}
// Reset memory for new boosting iteration
void Reset(const std::vector<bst_gpair> &in_gpair, float subsample) {
CHECK(allocated);
gpair = in_gpair;
this->MarkSubsample(subsample);
instance_id = instance_id_cached;
fvalues = fvalues_cached;
nodes.fill(Node());
node_id_instance.fill(0);
node_id.fill(0);
}
bool IsAllocated() { return allocated; }
// Gather from node_id_instance into node_id according to instance_id
void GatherNodeId() {
// Update node_id for each item
auto d_node_id = node_id.data();
auto d_node_id_instance = node_id_instance.data();
auto d_instance_id = instance_id.data();
dh::launch_n(fvalues.size(), [=] __device__(bst_uint i) {
// Item item = d_items[i];
d_node_id[i] = d_node_id_instance[d_instance_id[i]];
});
}
};
} // namespace tree
} // namespace xgboost