diff --git a/demo/gpu_acceleration/README.md b/demo/gpu_acceleration/README.md
index ab668cfcf..43591fc67 100644
--- a/demo/gpu_acceleration/README.md
+++ b/demo/gpu_acceleration/README.md
@@ -9,10 +9,10 @@ https://www.kaggle.com/c/bosch-production-line-performance/data
 
 Copy train_numeric.csv into xgboost/demo/data.
 
-The subsample parameter can be changed so you can run the script first on a small portion of the data. Processing the entire dataset can take a long time and requires about 8GB of device memory. It is initially set to 0.4, using about 2650/3380MB on a GTX 970.
+The subset parameter controls the proportion of rows loaded from the CSV file. Processing the entire dataset can take a long time and requires about 8GB of device memory. The parameter is initially set to 0.4, which uses about 2650/3380MB on a GTX 970. Lower it if your device runs out of memory.
 
 ```python
-subsample = 0.4
+subset = 0.4
 ```
 
 Parameters are set as usual, except that we set silent to 0 to see how much memory is being allocated on the GPU, and we change 'updater' to 'grow_gpu' to activate the GPU plugin.
diff --git a/demo/gpu_acceleration/bosch.py b/demo/gpu_acceleration/bosch.py
index 63f269c90..f717532cc 100644
--- a/demo/gpu_acceleration/bosch.py
+++ b/demo/gpu_acceleration/bosch.py
@@ -5,12 +5,12 @@ import time
 import random
 from sklearn.cross_validation import StratifiedKFold
 
-#For sub sampling rows from input file
+# For sampling rows from the input file
 random_seed = 9
-subsample = 0.4
+subset = 0.4
 n_rows = 1183747;
-train_rows = int(n_rows * subsample)
+train_rows = int(n_rows * subset)
 random.seed(random_seed)
 skip = sorted(random.sample(xrange(1,n_rows + 1),n_rows-train_rows))
 data = pd.read_csv("../data/train_numeric.csv", index_col=0, dtype=np.float32, skiprows=skip)
diff --git a/plugin/updater_gpu/README.md b/plugin/updater_gpu/README.md
index faecd2b20..9a31869aa 100644
--- a/plugin/updater_gpu/README.md
+++ b/plugin/updater_gpu/README.md
@@ -32,8 +32,6 @@ Data is stored in a sparse format. For example, missing values produced by one h
 
 A 4GB graphics card will process approximately 3.5 million rows of the well known Kaggle higgs dataset.
 
-The algorithm will automatically perform row subsampling if it detects there is not enough memory on the device.
-
 ## Dependencies
 A CUDA capable GPU with at least compute capability >= 3.5 (the algorithm depends on shuffle and vote instructions introduced in Kepler).
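Not part of the diff: for orientation, a minimal sketch of the parameter setup this README change refers to. Only `silent` and `updater` come from the README; the DMatrix construction, objective, and round count are illustrative assumptions.

```python
# Hypothetical usage sketch: train with the GPU plugin enabled.
# Assumes `data` (features) and `y` (labels) are already loaded,
# e.g. from train_numeric.csv as in bosch.py.
import xgboost as xgb

dtrain = xgb.DMatrix(data, label=y)
param = {
    'objective': 'binary:logistic',  # illustrative choice
    'silent': 0,           # 0 so GPU memory allocation is reported
    'updater': 'grow_gpu'  # activate the GPU plugin
}
bst = xgb.train(param, dtrain, num_boost_round=20)
```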
diff --git a/plugin/updater_gpu/src/device_helpers.cuh b/plugin/updater_gpu/src/device_helpers.cuh
index 5af46351e..ee8e4a2fc 100644
--- a/plugin/updater_gpu/src/device_helpers.cuh
+++ b/plugin/updater_gpu/src/device_helpers.cuh
@@ -7,6 +7,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -147,6 +148,8 @@ struct Timer {
     LARGE_INTEGER now;
     QueryPerformanceCounter(&now);
     return static_cast(now.QuadPart) / s_frequency.QuadPart;
+#else
+    return 0;
 #endif
   }
 
@@ -160,12 +163,14 @@ struct Timer {
 #ifdef _WIN32
     _ReadWriteBarrier();
     return seconds_now() - start;
+#else
+    return 0;
 #endif
   }
 
-  void printElapsed(char *label) {
+  void printElapsed(std::string label) {
 #ifdef TIMERS
     safe_cuda(cudaDeviceSynchronize());
-    printf("%s:\t %1.4fs\n", label, elapsed());
+    printf("%s:\t %1.4fs\n", label.c_str(), elapsed());
 #endif
   }
 };
 
@@ -233,46 +238,6 @@ template __device__ range block_stride_range(T begin, T end) {
   return r;
 }
 
-/*
- * Utility functions
- */
-
-template
-void print(const thrust::device_vector &v, size_t max_items = 10) {
-  thrust::host_vector h = v;
-  for (int i = 0; i < std::min(max_items, h.size()); i++) {
-    std::cout << " " << h[i];
-  }
-  std::cout << "\n";
-}
-
-template
-void print(char *label, const thrust::device_vector &v,
-           const char *format = "%d ", int max = 10) {
-  thrust::host_vector h_v = v;
-
-  std::cout << label << ":\n";
-  for (int i = 0; i < std::min(static_cast(h_v.size()), max); i++) {
-    printf(format, h_v[i]);
-  }
-  std::cout << "\n";
-}
-
-template T1 div_round_up(const T1 a, const T2 b) {
-  return static_cast(ceil(static_cast(a) / b));
-}
-
-template thrust::device_ptr dptr(T *d_ptr) {
-  return thrust::device_pointer_cast(d_ptr);
-}
-
-template T *raw(thrust::device_vector &v) {  // NOLINT
-  return raw_pointer_cast(v.data());
-}
-
-template size_t size_bytes(const thrust::device_vector &v) {
-  return sizeof(T) * v.size();
-}
 
 // Threadblock iterates over range, filling with value
 template
@@ -306,11 +271,11 @@ template class dvec {
  public:
   dvec() : _ptr(NULL), _size(0) {}
 
-  size_t size() { return _size; }
-  bool empty() { return _ptr == NULL || _size == 0; }
+  size_t size() const { return _size; }
+  bool empty() const { return _ptr == NULL || _size == 0; }
   T *data() { return _ptr; }
 
-  std::vector as_vector() {
+  std::vector as_vector() const {
     std::vector h_vector(size());
     safe_cuda(cudaMemcpy(h_vector.data(), _ptr, size() * sizeof(T),
                          cudaMemcpyDeviceToHost));
@@ -454,6 +419,55 @@ inline std::string device_name() {
   return std::string(prop.name);
 }
 
+/*
+ * Utility functions
+ */
+
+template
+void print(const thrust::device_vector &v, size_t max_items = 10) {
+  thrust::host_vector h = v;
+  for (int i = 0; i < std::min(max_items, h.size()); i++) {
+    std::cout << " " << h[i];
+  }
+  std::cout << "\n";
+}
+
+template
+void print(const dvec &v, size_t max_items = 10) {
+  std::vector h = v.as_vector();
+  for (int i = 0; i < std::min(max_items, h.size()); i++) {
+    std::cout << " " << h[i];
+  }
+  std::cout << "\n";
+}
+
+template
+void print(char *label, const thrust::device_vector &v,
+           const char *format = "%d ", int max = 10) {
+  thrust::host_vector h_v = v;
+
+  std::cout << label << ":\n";
+  for (int i = 0; i < std::min(static_cast(h_v.size()), max); i++) {
+    printf(format, h_v[i]);
+  }
+  std::cout << "\n";
+}
+
+template T1 div_round_up(const T1 a, const T2 b) {
+  return static_cast(ceil(static_cast(a) / b));
+}
+
+template thrust::device_ptr dptr(T *d_ptr) {
+  return thrust::device_pointer_cast(d_ptr);
+}
+
+template T *raw(thrust::device_vector &v) {  // NOLINT
+  return raw_pointer_cast(v.data());
+}
+
+template size_t size_bytes(const thrust::device_vector &v) {
+  return sizeof(T) * v.size();
+}
+
 /*
  * Kernel launcher
  */
@@ -470,4 +484,25 @@ inline void launch_n(size_t n, L lambda) {
 
   launch_n_kernel<<>>(n, lambda);
 }
+
+/*
+ * Random
+ */
+
+struct BernoulliRng {
+  float p;
+  int seed;
+
+  __host__ __device__ BernoulliRng(float p, int seed) : p(p), seed(seed) {}
+
+  __host__ __device__ bool operator()(const int i) const {
+    thrust::default_random_engine rng(seed);
+    thrust::uniform_real_distribution dist;
+    rng.discard(i);
+
+    return dist(rng) <= p;
+  }
+};
+
+
 }  // namespace dh
diff --git a/plugin/updater_gpu/src/find_split.cuh b/plugin/updater_gpu/src/find_split.cuh
index 458d8fba7..070ea2efd 100644
--- a/plugin/updater_gpu/src/find_split.cuh
+++ b/plugin/updater_gpu/src/find_split.cuh
@@ -4,9 +4,11 @@
 #pragma once
 #include
 #include
+#include
 #include "device_helpers.cuh"
 #include "find_split_multiscan.cuh"
 #include "find_split_sorting.cuh"
+#include "gpu_data.cuh"
 #include "types_functions.cuh"
 
 namespace xgboost {
@@ -62,24 +64,47 @@ void reduce_split_candidates(Split *d_split_candidates, Node *d_nodes,
   dh::safe_cuda(cudaDeviceSynchronize());
 }
 
-void find_split(const ItemIter items_iter, Split *d_split_candidates,
-                Node *d_nodes, bst_uint num_items, int num_features,
-                const int *d_feature_offsets, gpu_gpair *d_node_sums,
-                int *d_node_offsets, const GPUTrainingParam param,
-                const int level, bool multiscan_algorithm) {
+void colsample_level(GPUData *data, const TrainParam xgboost_param,
+                     const std::vector &feature_set_tree,
+                     std::vector *feature_set_level) {
+  unsigned n_bytree =
+      static_cast(xgboost_param.colsample_bytree * data->n_features);
+  unsigned n =
+      static_cast(n_bytree * xgboost_param.colsample_bylevel);
+  CHECK_GT(n, 0);
+
+  *feature_set_level = feature_set_tree;
+
+  std::shuffle((*feature_set_level).begin(),
+               (*feature_set_level).begin() + n_bytree, common::GlobalRandom());
+
+  data->feature_set = *feature_set_level;
+
+  data->feature_flags.fill(0);
+  auto d_feature_set = data->feature_set.data();
+  auto d_feature_flags = data->feature_flags.data();
+
+  dh::launch_n(
+      n, [=] __device__(int i) { d_feature_flags[d_feature_set[i]] = 1; });
+}
+
+void find_split(GPUData *data, const TrainParam xgboost_param, const int level,
+                bool multiscan_algorithm,
+                const std::vector &feature_set_tree,
+                std::vector *feature_set_level) {
+  colsample_level(data, xgboost_param, feature_set_tree, feature_set_level);
+
+  // Reset split candidates
+  data->split_candidates.fill(Split());
+
   if (multiscan_algorithm) {
-    find_split_candidates_multiscan(items_iter, d_split_candidates, d_nodes,
-                                    num_items, num_features, d_feature_offsets,
-                                    param, level);
+    find_split_candidates_multiscan(data, level);
   } else {
-    find_split_candidates_sorted(items_iter, d_split_candidates, d_nodes,
-                                 num_items, num_features, d_feature_offsets,
-                                 d_node_sums, d_node_offsets, param, level);
+    find_split_candidates_sorted(data, level);
   }
 
   // Find the best split for each node
-  reduce_split_candidates(d_split_candidates, d_nodes, level, num_features,
-                          param);
+  reduce_split_candidates(data->split_candidates.data(), data->nodes.data(),
+                          level, data->n_features, data->param);
 }
 }  // namespace tree
 }  // namespace xgboost
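For intuition, a pure-Python stand-in for the per-level column sampling that `colsample_level` performs above: reshuffle the per-tree prefix, keep the first `n` features, and raise a flag for each so the kernels can skip unsampled columns. Names mirror the device code; the sizes and seed are illustrative.

```python
import random

def colsample_level(feature_set_tree, n_features,
                    colsample_bytree, colsample_bylevel, rng):
    n_bytree = int(colsample_bytree * n_features)
    n = int(n_bytree * colsample_bylevel)
    assert n > 0
    feature_set_level = list(feature_set_tree)
    # std::shuffle(begin, begin + n_bytree): reshuffle only the per-tree prefix
    prefix = feature_set_level[:n_bytree]
    rng.shuffle(prefix)
    feature_set_level[:n_bytree] = prefix
    # Flag the first n features; split kernels skip features whose flag is 0
    feature_flags = [0] * n_features
    for f in feature_set_level[:n]:
        feature_flags[f] = 1
    return feature_set_level, feature_flags

features, flags = colsample_level(list(range(10)), 10, 0.5, 0.6,
                                  random.Random(0))
```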
diff --git a/plugin/updater_gpu/src/find_split_multiscan.cuh b/plugin/updater_gpu/src/find_split_multiscan.cuh
index 0cd906e73..41af54c7e 100644
--- a/plugin/updater_gpu/src/find_split_multiscan.cuh
+++ b/plugin/updater_gpu/src/find_split_multiscan.cuh
@@ -5,6 +5,7 @@
 #include
 #include
 #include "device_helpers.cuh"
+#include "gpu_data.cuh"
 #include "types_functions.cuh"
 
 namespace xgboost {
@@ -609,22 +610,11 @@ struct FindSplitEnactorMultiscan {
     }
   }
 
-  __device__ __forceinline__ void ResetSplitCandidates() {
-    const int max_nodes = 1 << level;
-    const int begin = blockIdx.x * max_nodes;
-    const int end = begin + max_nodes;
-
-    for (auto i : dh::block_stride_range(begin, end)) {
-      d_split_candidates_out[i] = Split();
-    }
-  }
-
   __device__ __forceinline__ void ProcessRegion(const bst_uint &segment_begin,
                                                 const bst_uint &segment_end) {
     // Current position
     bst_uint offset = segment_begin;
 
-    ResetSplitCandidates();
     ResetTileCarry();
     ResetSplits();
    CacheNodes();
@@ -654,8 +644,9 @@ __launch_bounds__(1024, 2)
     const ItemIter items_iter, Split *d_split_candidates_out,
     const Node *d_nodes, const int node_begin, bst_uint num_items,
     int num_features, const int *d_feature_offsets,
-    const GPUTrainingParam param, const int level) {
-  if (num_items <= 0) {
+    const GPUTrainingParam param, const int *d_feature_flags,
+    const int level) {
+  if (num_items <= 0 || d_feature_flags[blockIdx.x] != 1) {
     return;
   }
 
@@ -685,69 +676,45 @@
 }
 
 template
-void find_split_candidates_multiscan_variation(
-    const ItemIter items_iter, Split *d_split_candidates, const Node *d_nodes,
-    int node_begin, int node_end, bst_uint num_items, int num_features,
-    const int *d_feature_offsets, const GPUTrainingParam param,
-    const int level) {
-
+void find_split_candidates_multiscan_variation(GPUData *data, const int level) {
+  const int node_begin = (1 << level) - 1;
   const int BLOCK_THREADS = 512;
 
-  CHECK((node_end - node_begin) <= N_NODES) << "Multiscan: N_NODES template "
-                                               "parameter too small for given "
-                                               "node range.";
   CHECK(BLOCK_THREADS / 32 < 32) << "Too many active warps. See FindSplitEnactor - ReduceSplits.";
 
   typedef FindSplitParamsMultiscan find_split_params;
   typedef ReduceParamsMultiscan reduce_params;
 
-  int grid_size = num_features;
+  int grid_size = data->n_features;
 
   find_split_candidates_multiscan_kernel<
       find_split_params, reduce_params><<>>(
-      items_iter, d_split_candidates, d_nodes, node_begin, num_items,
-      num_features, d_feature_offsets, param, level);
+      data->items_iter, data->split_candidates.data(), data->nodes.data(),
+      node_begin, data->fvalues.size(), data->n_features,
+      data->foffsets.data(), data->param, data->feature_flags.data(), level);
 
   dh::safe_cuda(cudaDeviceSynchronize());
 }
 
-void find_split_candidates_multiscan(
-    const ItemIter items_iter, Split *d_split_candidates, const Node *d_nodes,
-    bst_uint num_items, int num_features, const int *d_feature_offsets,
-    const GPUTrainingParam param, const int level) {
+void find_split_candidates_multiscan(GPUData *data, const int level) {
   // Select templated variation of split finding algorithm
   switch (level) {
     case 0:
-      find_split_candidates_multiscan_variation<1>(
-          items_iter, d_split_candidates, d_nodes, 0, 1, num_items,
-          num_features, d_feature_offsets, param, level);
+      find_split_candidates_multiscan_variation<1>(data, level);
      break;
    case 1:
-      find_split_candidates_multiscan_variation<2>(
-          items_iter, d_split_candidates, d_nodes, 1, 3, num_items,
-          num_features, d_feature_offsets, param, level);
+      find_split_candidates_multiscan_variation<2>(data, level);
      break;
    case 2:
-      find_split_candidates_multiscan_variation<4>(
-          items_iter, d_split_candidates, d_nodes, 3, 7, num_items,
-          num_features, d_feature_offsets, param, level);
+      find_split_candidates_multiscan_variation<4>(data, level);
      break;
    case 3:
-      find_split_candidates_multiscan_variation<8>(
-          items_iter, d_split_candidates, d_nodes, 7, 15, num_items,
-          num_features, d_feature_offsets, param, level);
+      find_split_candidates_multiscan_variation<8>(data, level);
      break;
    case 4:
-      find_split_candidates_multiscan_variation<16>(
-          items_iter, d_split_candidates, d_nodes, 15, 31, num_items,
-          num_features, d_feature_offsets, param, level);
-      break;
-    case 5:
-      find_split_candidates_multiscan_variation<32>(
-          items_iter, d_split_candidates, d_nodes, 31, 63, num_items,
-          num_features, d_feature_offsets, param, level);
+      find_split_candidates_multiscan_variation<16>(data, level);
      break;
  }
 }
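A note on the dispatch above: at tree level L the multiscan kernel must track 2^L candidate nodes, so the switch instantiates the variation template with N_NODES = 2^L for levels 0 to 4 only. The removed `case 5` was unreachable, since `multiscan_levels = 5` in gpu_builder.cuh routes level 5 and deeper to the sorting implementation. A tiny Python model of the arithmetic:

```python
# Node bookkeeping per tree level (BFS node numbering).
def nodes_at_level(level):
    return 1 << level          # N_NODES template argument

def node_begin(level):
    return (1 << level) - 1    # first node index at this level

assert [nodes_at_level(l) for l in range(5)] == [1, 2, 4, 8, 16]
assert [node_begin(l) for l in range(5)] == [0, 1, 3, 7, 15]
```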
diff --git a/plugin/updater_gpu/src/find_split_sorting.cuh b/plugin/updater_gpu/src/find_split_sorting.cuh
index 661e7df91..dae597ce1 100644
--- a/plugin/updater_gpu/src/find_split_sorting.cuh
+++ b/plugin/updater_gpu/src/find_split_sorting.cuh
@@ -337,17 +337,8 @@ struct FindSplitEnactorSorting {
     WriteBestSplit(node_id_adjusted);
   }
 
-  __device__ __forceinline__ void ResetSplitCandidates() {
-    const int max_nodes = 1 << level;
-    const int begin = blockIdx.x * max_nodes;
-
-    dh::block_fill(d_split_candidates_out + begin, max_nodes, Split());
-  }
-
   __device__ __forceinline__ void ProcessFeature(const bst_uint &segment_begin,
                                                  const bst_uint &segment_end) {
-    ResetSplitCandidates();
-
     int node_begin = segment_begin;
 
     const int max_nodes = 1 << level;
@@ -377,9 +368,9 @@ __global__ __launch_bounds__(1024, 1) void find_split_candidates_sorted_kernel(
     const ItemIter items_iter, Split *d_split_candidates_out,
     const Node *d_nodes, bst_uint num_items, const int num_features,
     const int *d_feature_offsets, gpu_gpair *d_node_sums, int *d_node_offsets,
-    const GPUTrainingParam param, const int level) {
+    const GPUTrainingParam param, const int *d_feature_flags, const int level) {
 
-  if (num_items <= 0) {
+  if (num_items <= 0 || d_feature_flags[blockIdx.x] != 1) {
     return;
   }
 
@@ -408,23 +399,19 @@ __global__ __launch_bounds__(1024, 1) void find_split_candidates_sorted_kernel(
       .ProcessFeature(segment_begin, segment_end);
 }
 
-void find_split_candidates_sorted(const ItemIter items_iter,
-                                  Split *d_split_candidates, Node *d_nodes,
-                                  bst_uint num_items, int num_features,
-                                  const int *d_feature_offsets,
-                                  gpu_gpair *d_node_sums, int *d_node_offsets,
-                                  const GPUTrainingParam param,
-                                  const int level) {
+void find_split_candidates_sorted(GPUData *data, const int level) {
   const int BLOCK_THREADS = 512;
 
   CHECK(BLOCK_THREADS / 32 < 32) << "Too many active warps.";
 
-  int grid_size = num_features;
+  int grid_size = data->n_features;
 
   find_split_candidates_sorted_kernel<
       BLOCK_THREADS><<>>(
-      items_iter, d_split_candidates, d_nodes, num_items, num_features,
-      d_feature_offsets, d_node_sums, d_node_offsets, param, level);
+      data->items_iter, data->split_candidates.data(), data->nodes.data(),
+      data->fvalues.size(), data->n_features,
+      data->foffsets.data(), data->node_sums.data(), data->node_offsets.data(),
+      data->param, data->feature_flags.data(), level);
 
   dh::safe_cuda(cudaGetLastError());
   dh::safe_cuda(cudaDeviceSynchronize());
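Both split-finding kernels consume `feature_flags` the same way: the grid launches one thread block per feature (`grid_size = n_features`), and a block whose feature was not sampled this level returns before touching any data. A pure-Python stand-in for the launch pattern (names illustrative):

```python
# Serial stand-in for the one-block-per-feature CUDA grid.
def launch_split_kernel(n_features, feature_flags, process_feature):
    for block_idx in range(n_features):    # blockIdx.x
        if feature_flags[block_idx] != 1:  # column not sampled this level
            continue                       # kernel's early return
        process_feature(block_idx)         # scan this feature's segment

launch_split_kernel(3, [1, 0, 1], lambda f: print("feature", f))
```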
diff --git a/plugin/updater_gpu/src/gpu_builder.cu b/plugin/updater_gpu/src/gpu_builder.cu
index 14bb5616c..a8c242443 100644
--- a/plugin/updater_gpu/src/gpu_builder.cu
+++ b/plugin/updater_gpu/src/gpu_builder.cu
@@ -12,143 +12,17 @@
 #include
 #include
 #include
+#include
 #include
 #include "../../../src/common/random.h"
 #include "device_helpers.cuh"
 #include "find_split.cuh"
 #include "gpu_builder.cuh"
 #include "types_functions.cuh"
+#include "gpu_data.cuh"
 
 namespace xgboost {
 namespace tree {
 
-struct GPUData {
-  GPUData() : allocated(false), n_features(0), n_instances(0) {}
-
-  bool allocated;
-  int n_features;
-  int n_instances;
-
-  dh::bulk_allocator ba;
-  GPUTrainingParam param;
-
-  dh::dvec fvalues;
-  dh::dvec fvalues_temp;
-  dh::dvec fvalues_cached;
-  dh::dvec foffsets;
-  dh::dvec instance_id;
-  dh::dvec instance_id_temp;
-  dh::dvec instance_id_cached;
-  dh::dvec feature_id;
-  dh::dvec node_id;
-  dh::dvec node_id_temp;
-  dh::dvec node_id_instance;
-  dh::dvec gpair;
-  dh::dvec nodes;
-  dh::dvec split_candidates;
-  dh::dvec node_sums;
-  dh::dvec node_offsets;
-  dh::dvec sort_index_in;
-  dh::dvec sort_index_out;
-
-  dh::dvec cub_mem;
-
-  ItemIter items_iter;
-
-  void Init(const std::vector &in_fvalues,
-            const std::vector &in_foffsets,
-            const std::vector &in_instance_id,
-            const std::vector &in_feature_id,
-            const std::vector &in_gpair, bst_uint n_instances_in,
-            bst_uint n_features_in, int max_depth, const TrainParam &param_in) {
-    n_features = n_features_in;
-    n_instances = n_instances_in;
-
-    uint32_t max_nodes = (1 << (max_depth + 1)) - 1;
-    uint32_t max_nodes_level = 1 << max_depth;
-
-    // Calculate memory for sort
-    size_t cub_mem_size = 0;
-    cub::DoubleBuffer db_key;
-    cub::DoubleBuffer db_value;
-
-    cub::DeviceSegmentedRadixSort::SortPairs(
-        cub_mem.data(), cub_mem_size, db_key,
-        db_value, in_fvalues.size(), n_features,
-        foffsets.data(), foffsets.data() + 1);
-
-    // Allocate memory
-    size_t free_memory = dh::available_memory();
-    ba.allocate(&fvalues, in_fvalues.size(), &fvalues_temp, in_fvalues.size(),
-                &fvalues_cached, in_fvalues.size(), &foffsets,
-                in_foffsets.size(), &instance_id, in_instance_id.size(),
-                &instance_id_temp, in_instance_id.size(), &instance_id_cached,
-                in_instance_id.size(), &feature_id, in_feature_id.size(),
-                &node_id, in_fvalues.size(), &node_id_temp, in_fvalues.size(),
-                &node_id_instance, n_instances, &gpair, n_instances, &nodes,
-                max_nodes, &split_candidates, max_nodes_level * n_features,
-                &node_sums, max_nodes_level * n_features, &node_offsets,
-                max_nodes_level * n_features, &sort_index_in, in_fvalues.size(),
-                &sort_index_out, in_fvalues.size(), &cub_mem, cub_mem_size);
-
-    if (!param_in.silent) {
-      const int mb_size = 1048576;
-      LOG(CONSOLE) << "Allocated " << ba.size() / mb_size << "/"
-                   << free_memory / mb_size << " MB on " << dh::device_name();
-    }
-    node_id.fill(0);
-    node_id_instance.fill(0);
-
-    fvalues = in_fvalues;
-    fvalues_cached = fvalues;
-    foffsets = in_foffsets;
-    instance_id = in_instance_id;
-    instance_id_cached = instance_id;
-    feature_id = in_feature_id;
-
-    param = GPUTrainingParam(param_in.min_child_weight, param_in.reg_lambda,
-                             param_in.reg_alpha, param_in.max_delta_step);
-
-    gpair = in_gpair;
-
-    nodes.fill(Node());
-
-    items_iter = thrust::make_zip_iterator(thrust::make_tuple(
-        thrust::make_permutation_iterator(gpair.tbegin(), instance_id.tbegin()),
-        fvalues.tbegin(), node_id.tbegin()));
-
-    allocated = true;
-
-    dh::safe_cuda(cudaGetLastError());
-  }
-
-  ~GPUData() {}
-
-  // Reset memory for new boosting iteration
-  void Reset(const std::vector &in_gpair) {
-    CHECK(allocated);
-    gpair = in_gpair;
-    instance_id = instance_id_cached;
-    fvalues = fvalues_cached;
-    nodes.fill(Node());
-    node_id_instance.fill(0);
-    node_id.fill(0);
-  }
-
-  bool IsAllocated() { return allocated; }
-
-  // Gather from node_id_instance into node_id according to instance_id
-  void GatherNodeId() {
-    // Update node_id for each item
-    auto d_node_id = node_id.data();
-    auto d_node_id_instance = node_id_instance.data();
-    auto d_instance_id = instance_id.data();
-
-    dh::launch_n(fvalues.size(), [=] __device__(bst_uint i) {
-      // Item item = d_items[i];
-      d_node_id[i] = d_node_id_instance[d_instance_id[i]];
-    });
-  }
-};
 
 GPUBuilder::GPUBuilder() { gpu_data = new GPUData(); }
 
@@ -253,15 +127,26 @@ void GPUBuilder::Sort(int level) {
   }
 }
 
+void GPUBuilder::ColsampleTree() {
+  unsigned n = static_cast(
+      param.colsample_bytree * gpu_data->n_features);
+  CHECK_GT(n, 0);
+
+  feature_set_tree.resize(gpu_data->n_features);
+  std::iota(feature_set_tree.begin(), feature_set_tree.end(), 0);
+  std::shuffle(feature_set_tree.begin(), feature_set_tree.end(),
+               common::GlobalRandom());
+}
+
 void GPUBuilder::Update(const std::vector &gpair, DMatrix *p_fmat,
                         RegTree *p_tree) {
-  cudaProfilerStart();
   try {
     dh::Timer update;
     dh::Timer t;
     this->InitData(gpair, *p_fmat, *p_tree);
     t.printElapsed("init data");
     this->InitFirstNode();
+    this->ColsampleTree();
 
     for (int level = 0; level < param.max_depth; level++) {
       bool use_multiscan_algorithm = level < multiscan_levels;
 
@@ -280,11 +165,8 @@
       }
 
       dh::Timer split;
-      find_split(gpu_data->items_iter, gpu_data->split_candidates.data(),
-                 gpu_data->nodes.data(), (bst_uint)gpu_data->fvalues.size(),
-                 gpu_data->n_features, gpu_data->foffsets.data(),
-                 gpu_data->node_sums.data(), gpu_data->node_offsets.data(),
-                 gpu_data->param, level, use_multiscan_algorithm);
+      find_split(gpu_data, param, level, use_multiscan_algorithm,
+                 feature_set_tree, &feature_set_level);
 
       split.printElapsed("split");
 
@@ -302,22 +184,6 @@
     std::cerr << "Unknown exception." << std::endl;
     exit(-1);
   }
-  cudaProfilerStop();
-}
-
-float GPUBuilder::GetSubsamplingRate(MetaInfo info) {
-  float subsample = 1.0;
-  uint32_t max_nodes = (1 << (param.max_depth + 1)) - 1;
-  uint32_t max_nodes_level = 1 << param.max_depth;
-  size_t required = 10 * info.num_row + 40 * info.num_nonzero
-      + 64 * max_nodes + 76 * max_nodes_level * info.num_col;
-  size_t available = dh::available_memory();
-  while (available < required) {
-    subsample -= 0.05;
-    required = 10 * info.num_row + subsample * (44 * info.num_nonzero);
-  }
-
-  return subsample;
 }
 
 void GPUBuilder::InitData(const std::vector &gpair, DMatrix &fmat,
@@ -325,7 +191,7 @@ void GPUBuilder::InitData(const std::vector &gpair, DMatrix &fmat,
   CHECK(fmat.SingleColBlock()) << "GPUMaker: must have single column block";
 
   if (gpu_data->IsAllocated()) {
-    gpu_data->Reset(gpair);
+    gpu_data->Reset(gpair, param.subsample);
     return;
   }
 
@@ -333,35 +199,6 @@ void GPUBuilder::InitData(const std::vector &gpair, DMatrix &fmat,
   MetaInfo info = fmat.info();
 
-  // Work out if dataset will fit on GPU
-  float subsample = this->GetSubsamplingRate(info);
-  CHECK(subsample > 0.0);
-  if (!param.silent && subsample < param.subsample) {
-    LOG(CONSOLE) << "Not enough device memory for entire dataset.";
-  }
-
-  // Override subsample parameter if user-specified parameter is lower
-  subsample = std::min(param.subsample, subsample);
-
-  std::vector row_flags;
-
-  if (subsample < 1.0) {
-    if (!param.silent && subsample < 1.0) {
-      LOG(CONSOLE) << "Subsampling " << subsample * 100 << "% of rows.";
-    }
-
-    const RowSet &rowset = fmat.buffered_rowset();
-    row_flags.resize(info.num_row);
-    std::bernoulli_distribution coin_flip(subsample);
-    auto &rnd = common::GlobalRandom();
-    for (size_t i = 0; i < rowset.size(); ++i) {
-      const bst_uint ridx = rowset[i];
-      if (gpair[ridx].hess < 0.0f)
-        continue;
-      row_flags[ridx] = coin_flip(rnd);
-    }
-  }
-
   std::vector foffsets;
   foffsets.push_back(0);
   std::vector feature_id;
@@ -382,17 +219,9 @@
     for (const ColBatch::Entry *it = col.data; it != col.data + col.length;
          it++) {
       bst_uint inst_id = it->index;
-      if (subsample < 1.0) {
-        if (row_flags[inst_id]) {
-          fvalues.push_back(it->fvalue);
-          instance_id.push_back(inst_id);
-          feature_id.push_back(i);
-        }
-      } else {
         fvalues.push_back(it->fvalue);
         instance_id.push_back(inst_id);
         feature_id.push_back(i);
-      }
     }
     foffsets.push_back(fvalues.size());
   }
diff --git a/plugin/updater_gpu/src/gpu_builder.cuh b/plugin/updater_gpu/src/gpu_builder.cuh
index 61ccdbbcf..ba1521d35 100644
--- a/plugin/updater_gpu/src/gpu_builder.cuh
+++ b/plugin/updater_gpu/src/gpu_builder.cuh
@@ -23,6 +23,7 @@ class GPUBuilder {
           RegTree *p_tree);
 
   void UpdateNodeId(int level);
+ private:
   void InitData(const std::vector &gpair, DMatrix &fmat,  // NOLINT
                 const RegTree &tree);
@@ -31,12 +32,15 @@ class GPUBuilder {
   void Sort(int level);
   void InitFirstNode();
   void CopyTree(RegTree &tree);  // NOLINT
+  void ColsampleTree();
 
   TrainParam param;
   GPUData *gpu_data;
+  std::vector feature_set_tree;
+  std::vector feature_set_level;
   int multiscan_levels =
-      5; // Number of levels before switching to sorting algorithm
+      5;  // Number of levels before switching to sorting algorithm
 };
 }  // namespace tree
 }  // namespace xgboost
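The per-tree half of the sampling lives in `GPUBuilder::ColsampleTree`: build the identity permutation of feature ids once per boosting iteration and shuffle it; `colsample_level` then reshuffles and flags prefixes of it at each level (see the earlier sketch). A pure-Python stand-in, with illustrative sizes:

```python
import random

def colsample_tree(n_features, colsample_bytree, rng):
    n = int(colsample_bytree * n_features)
    assert n > 0                                 # mirrors CHECK_GT(n, 0)
    feature_set_tree = list(range(n_features))   # std::iota
    rng.shuffle(feature_set_tree)                # std::shuffle
    return feature_set_tree

feature_set = colsample_tree(16, 0.5, random.Random(9))
```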
diff --git a/plugin/updater_gpu/src/gpu_data.cuh b/plugin/updater_gpu/src/gpu_data.cuh
new file mode 100644
index 000000000..f3bc675d7
--- /dev/null
+++ b/plugin/updater_gpu/src/gpu_data.cuh
@@ -0,0 +1,162 @@
+/*!
+ * Copyright 2016 Rory Mitchell
+ */
+#pragma once
+#include
+#include
+#include
+#include
+#include "device_helpers.cuh"
+#include "../../../src/tree/param.h"
+#include "types_functions.cuh"
+
+namespace xgboost {
+namespace tree {
+
+struct GPUData {
+  GPUData() : allocated(false), n_features(0), n_instances(0) {}
+
+  bool allocated;
+  int n_features;
+  int n_instances;
+
+  dh::bulk_allocator ba;
+  GPUTrainingParam param;
+
+  dh::dvec fvalues;
+  dh::dvec fvalues_temp;
+  dh::dvec fvalues_cached;
+  dh::dvec foffsets;
+  dh::dvec instance_id;
+  dh::dvec instance_id_temp;
+  dh::dvec instance_id_cached;
+  dh::dvec feature_id;
+  dh::dvec node_id;
+  dh::dvec node_id_temp;
+  dh::dvec node_id_instance;
+  dh::dvec gpair;
+  dh::dvec nodes;
+  dh::dvec split_candidates;
+  dh::dvec node_sums;
+  dh::dvec node_offsets;
+  dh::dvec sort_index_in;
+  dh::dvec sort_index_out;
+
+  dh::dvec cub_mem;
+
+  dh::dvec feature_flags;
+  dh::dvec feature_set;
+
+  ItemIter items_iter;
+
+  void Init(const std::vector &in_fvalues,
+            const std::vector &in_foffsets,
+            const std::vector &in_instance_id,
+            const std::vector &in_feature_id,
+            const std::vector &in_gpair, bst_uint n_instances_in,
+            bst_uint n_features_in, int max_depth, const TrainParam &param_in) {
+    n_features = n_features_in;
+    n_instances = n_instances_in;
+
+    uint32_t max_nodes = (1 << (max_depth + 1)) - 1;
+    uint32_t max_nodes_level = 1 << max_depth;
+
+    // Calculate memory for sort
+    size_t cub_mem_size = 0;
+    cub::DoubleBuffer db_key;
+    cub::DoubleBuffer db_value;
+
+    cub::DeviceSegmentedRadixSort::SortPairs(
+        cub_mem.data(), cub_mem_size, db_key,
+        db_value, in_fvalues.size(), n_features,
+        foffsets.data(), foffsets.data() + 1);
+
+    // Allocate memory
+    size_t free_memory = dh::available_memory();
+    ba.allocate(&fvalues, in_fvalues.size(), &fvalues_temp, in_fvalues.size(),
+                &fvalues_cached, in_fvalues.size(), &foffsets,
+                in_foffsets.size(), &instance_id, in_instance_id.size(),
+                &instance_id_temp, in_instance_id.size(), &instance_id_cached,
+                in_instance_id.size(), &feature_id, in_feature_id.size(),
+                &node_id, in_fvalues.size(), &node_id_temp, in_fvalues.size(),
+                &node_id_instance, n_instances, &gpair, n_instances, &nodes,
+                max_nodes, &split_candidates, max_nodes_level * n_features,
+                &node_sums, max_nodes_level * n_features, &node_offsets,
+                max_nodes_level * n_features, &sort_index_in, in_fvalues.size(),
+                &sort_index_out, in_fvalues.size(), &cub_mem, cub_mem_size,
+                &feature_flags, n_features, &feature_set, n_features);
+
+    if (!param_in.silent) {
+      const int mb_size = 1048576;
+      LOG(CONSOLE) << "Allocated " << ba.size() / mb_size << "/"
+                   << free_memory / mb_size << " MB on " << dh::device_name();
+    }
+
+    fvalues_cached = in_fvalues;
+    foffsets = in_foffsets;
+    instance_id_cached = in_instance_id;
+    feature_id = in_feature_id;
+
+    param = GPUTrainingParam(param_in.min_child_weight, param_in.reg_lambda,
+                             param_in.reg_alpha, param_in.max_delta_step);
+
+    allocated = true;
+
+    this->Reset(in_gpair, param_in.subsample);
+
+    items_iter = thrust::make_zip_iterator(thrust::make_tuple(
+        thrust::make_permutation_iterator(gpair.tbegin(), instance_id.tbegin()),
+        fvalues.tbegin(), node_id.tbegin()));
+
+    dh::safe_cuda(cudaGetLastError());
+  }
+
+  ~GPUData() {}
+
+  // Set gradient pair to 0 with p = 1 - subsample
+  void MarkSubsample(float subsample) {
+    if (subsample == 1.0) {
+      return;
+    }
+
+    auto d_gpair = gpair.data();
+    dh::BernoulliRng rng(subsample, common::GlobalRandom()());
+
+    dh::launch_n(n_instances, [=] __device__(int i) {
+      if (!rng(i)) {
+        d_gpair[i] = gpu_gpair();
+      }
+    });
+  }
+
+  // Reset memory for new boosting iteration
+  void Reset(const std::vector &in_gpair, float subsample) {
+    CHECK(allocated);
+    gpair = in_gpair;
+    this->MarkSubsample(subsample);
+    instance_id = instance_id_cached;
+    fvalues = fvalues_cached;
+    nodes.fill(Node());
+    node_id_instance.fill(0);
+    node_id.fill(0);
+  }
+
+  bool IsAllocated() { return allocated; }
+
+  // Gather from node_id_instance into node_id according to instance_id
+  void GatherNodeId() {
+    // Update node_id for each item
+    auto d_node_id = node_id.data();
+    auto d_node_id_instance = node_id_instance.data();
+    auto d_instance_id = instance_id.data();
+
+    dh::launch_n(fvalues.size(), [=] __device__(bst_uint i) {
+      d_node_id[i] = d_node_id_instance[d_instance_id[i]];
+    });
+  }
+};
+}  // namespace tree
+}  // namespace xgboost
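To see what the new device-side row subsampling does, here is a pure-Python model of `dh::BernoulliRng` plus `MarkSubsample`: every instance keeps its slot in device memory, but with probability `1 - subsample` its gradient pair is zeroed, so it contributes nothing to any node sum; no host-side row filtering or reallocation is needed. The counter-based draw mirrors `rng.discard(i)`, making the result independent of thread execution order; the O(i) skip loop here is only an illustrative stand-in.

```python
import random

def bernoulli_rng(seed, i, p):
    # Stand-in for dh::BernoulliRng: one engine seeded once, advanced to
    # draw i, so every index gets a reproducible, order-independent flip.
    rng = random.Random(seed)
    for _ in range(i):  # models thrust's rng.discard(i)
        rng.random()
    return rng.random() <= p

def mark_subsample(gpair, subsample, seed):
    # Zero the (grad, hess) pair of dropped rows, as MarkSubsample does.
    if subsample == 1.0:
        return
    for i in range(len(gpair)):
        if not bernoulli_rng(seed, i, subsample):
            gpair[i] = (0.0, 0.0)  # gpu_gpair(): zero gradient and hessian

gpair = [(0.5, 1.0)] * 8
mark_subsample(gpair, 0.5, seed=9)
```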