[GPU-Plugin] Major refactor (#2644)

* Removal of redundant code/files.
* Removal of the exact namespace in the GPU plugin
* Revert double precision histograms to single precision for performance on Maxwell/Kepler
Rory Mitchell 2017-08-30 10:53:52 +12:00 committed by GitHub
parent 39adba51c5
commit 19a53814ce
26 changed files with 2170 additions and 5637 deletions

View File

@@ -96,7 +96,6 @@ if(PLUGIN_UPDATER_GPU)
cuda_add_library(gpuxgboost ${CUDA_SOURCES} STATIC)
target_link_libraries(gpuxgboost nccl)
list(APPEND LINK_LIBRARIES gpuxgboost)
list(APPEND SOURCES plugin/updater_gpu/src/register_updater_gpu.cc)
endif()
add_library(objxgboost OBJECT ${SOURCES})

View File

@@ -214,7 +214,6 @@ pylint:
flake8 --ignore E501 tests/python
test: $(ALL_TEST)
./plugin/updater_gpu/test/cpp/generate_data.sh
$(ALL_TEST)
check: test

View File

@@ -1,285 +0,0 @@
/*!
* Copyright 2017 XGBoost contributors
*/
#pragma once
#include <cstdio>
#include <stdexcept>
#include <string>
#include <vector>
#include "../../../src/common/random.h"
#include "../../../src/tree/param.h"
#include "cub/cub.cuh"
#include "device_helpers.cuh"
#include "types.cuh"
namespace xgboost {
namespace tree {
template <typename gpair_t>
__device__ inline float device_calc_loss_chg(const GPUTrainingParam& param,
const gpair_t& scan,
const gpair_t& missing,
const gpair_t& parent_sum,
const float& parent_gain,
bool missing_left) {
gpair_t left = scan;
if (missing_left) {
left += missing;
}
gpair_t right = parent_sum - left;
float left_gain = CalcGain(param, left.grad, left.hess);
float right_gain = CalcGain(param, right.grad, right.hess);
return left_gain + right_gain - parent_gain;
}
template <typename gpair_t>
__device__ float inline loss_chg_missing(const gpair_t& scan,
const gpair_t& missing,
const gpair_t& parent_sum,
const float& parent_gain,
const GPUTrainingParam& param,
bool& missing_left_out) { // NOLINT
float missing_left_loss =
device_calc_loss_chg(param, scan, missing, parent_sum, parent_gain, true);
float missing_right_loss = device_calc_loss_chg(
param, scan, missing, parent_sum, parent_gain, false);
if (missing_left_loss >= missing_right_loss) {
missing_left_out = true;
return missing_left_loss;
} else {
missing_left_out = false;
return missing_right_loss;
}
}
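// Illustrative note (hedged, not from the original header): for a node with
// statistics parent_sum and gain parent_gain, and a feature column whose
// scanned prefix is `scan` with `missing = parent_sum - column_sum`, the
// helper evaluates the candidate split twice, once sending the missing
// statistics to the left child (left = scan + missing) and once to the right
// (left = scan), then returns the larger loss change and reports the chosen
// direction through missing_left_out.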
// Total number of nodes in tree, given depth
__host__ __device__ inline int n_nodes(int depth) {
return (1 << (depth + 1)) - 1;
}
// Number of nodes at this level of the tree
__host__ __device__ inline int n_nodes_level(int depth) { return 1 << depth; }
// Whether a node is currently being processed at current depth
__host__ __device__ inline bool is_active(int nidx, int depth) {
return nidx >= n_nodes(depth - 1);
}
__host__ __device__ inline int parent_nidx(int nidx) { return (nidx - 1) / 2; }
__host__ __device__ inline int left_child_nidx(int nidx) {
return nidx * 2 + 1;
}
__host__ __device__ inline int right_child_nidx(int nidx) {
return nidx * 2 + 2;
}
__host__ __device__ inline bool is_left_child(int nidx) {
return nidx % 2 == 1;
}
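// Worked example (illustrative only) for the BFS tree layout used above:
// at depth 2, n_nodes(2) == 7 and n_nodes_level(2) == 4; for nidx == 4,
//   parent_nidx(4)      == 1
//   left_child_nidx(4)  == 9
//   right_child_nidx(4) == 10
//   is_left_child(4)    == false  (node 4 is the right child of node 1)
//   is_active(4, 2)     == true   (4 >= n_nodes(1) == 3)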
enum NodeType {
NODE = 0,
LEAF = 1,
UNUSED = 2,
};
// Recursively label node types
inline void flag_nodes(const thrust::host_vector<Node>& nodes,
std::vector<NodeType>* node_flags, int nid,
NodeType type) {
if (nid >= nodes.size() || type == UNUSED) {
return;
}
const Node& n = nodes[nid];
// Current node and all children are valid
if (n.split.loss_chg > rt_eps) {
(*node_flags)[nid] = NODE;
flag_nodes(nodes, node_flags, nid * 2 + 1, NODE);
flag_nodes(nodes, node_flags, nid * 2 + 2, NODE);
} else {
// Current node is leaf, therefore is valid but all children are invalid
(*node_flags)[nid] = LEAF;
flag_nodes(nodes, node_flags, nid * 2 + 1, UNUSED);
flag_nodes(nodes, node_flags, nid * 2 + 2, UNUSED);
}
}
// Copy gpu dense representation of tree to xgboost sparse representation
inline void dense2sparse_tree(RegTree* p_tree,
thrust::device_ptr<Node> nodes_begin,
thrust::device_ptr<Node> nodes_end,
const TrainParam& param) {
RegTree& tree = *p_tree;
thrust::host_vector<Node> h_nodes(nodes_begin, nodes_end);
std::vector<NodeType> node_flags(h_nodes.size(), UNUSED);
flag_nodes(h_nodes, &node_flags, 0, NODE);
int nid = 0;
for (int gpu_nid = 0; gpu_nid < h_nodes.size(); gpu_nid++) {
NodeType flag = node_flags[gpu_nid];
const Node& n = h_nodes[gpu_nid];
if (flag == NODE) {
tree.AddChilds(nid);
tree[nid].set_split(n.split.findex, n.split.fvalue, n.split.missing_left);
tree.stat(nid).loss_chg = n.split.loss_chg;
tree.stat(nid).base_weight = n.weight;
tree.stat(nid).sum_hess = n.sum_gradients.hess;
tree[tree[nid].cleft()].set_leaf(0);
tree[tree[nid].cright()].set_leaf(0);
nid++;
} else if (flag == LEAF) {
tree[nid].set_leaf(n.weight * param.learning_rate);
tree.stat(nid).sum_hess = n.sum_gradients.hess;
nid++;
}
}
}
// Set gradient pair to 0 with p = 1 - subsample
inline void subsample_gpair(dh::dvec<bst_gpair>* p_gpair, float subsample,
int offset) {
if (subsample == 1.0) {
return;
}
dh::dvec<bst_gpair>& gpair = *p_gpair;
auto d_gpair = gpair.data();
dh::BernoulliRng rng(subsample, common::GlobalRandom()());
dh::launch_n(gpair.device_idx(), gpair.size(), [=] __device__(int i) {
if (!rng(i + offset)) {
d_gpair[i] = bst_gpair();
}
});
}
// Set gradient pair to 0 with p = 1 - subsample
inline void subsample_gpair(dh::dvec<bst_gpair>* p_gpair, float subsample) {
int offset = 0;
subsample_gpair(p_gpair, subsample, offset);
}
inline std::vector<int> col_sample(std::vector<int> features, float colsample) {
CHECK_GT(features.size(), 0);
int n = std::max(1, static_cast<int>(colsample * features.size()));
std::shuffle(features.begin(), features.end(), common::GlobalRandom());
features.resize(n);
return features;
}
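// Hedged usage sketch (illustrative values): with ten features and
// colsample == 0.3, col_sample keeps max(1, int(0.3 * 10)) == 3 features,
// chosen by shuffling the input vector and truncating it:
//   std::vector<int> feats(10);
//   std::iota(feats.begin(), feats.end(), 0);            // requires <numeric>
//   std::vector<int> picked = col_sample(feats, 0.3f);   // picked.size() == 3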
struct GpairCallbackOp {
// Running prefix
bst_gpair_precise running_total;
// Constructor
__device__ GpairCallbackOp() : running_total(bst_gpair_precise()) {}
__device__ bst_gpair_precise operator()(bst_gpair_precise block_aggregate) {
bst_gpair_precise old_prefix = running_total;
running_total += block_aggregate;
return old_prefix;
}
};
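// Hedged usage sketch (illustrative; LoadTile/StoreTile are hypothetical
// placeholders): GpairCallbackOp follows cub::BlockScan's block-prefix
// callback protocol, so a multi-tile exclusive scan over gradient pairs can
// carry the running total between tiles roughly like this:
//   typedef cub::BlockScan<bst_gpair_precise, kBlockThreads> BlockScanT;
//   __shared__ typename BlockScanT::TempStorage temp_storage;
//   GpairCallbackOp prefix_op;
//   for (int tile = 0; tile < num_tiles; ++tile) {
//     bst_gpair_precise item = LoadTile(tile);
//     BlockScanT(temp_storage).ExclusiveSum(item, item, prefix_op);
//     StoreTile(tile, item);
//   }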
/**
* @brief Helper function to sort the pairs using cub's segmented RadixSortPairs
* @param tmp_mem cub temporary memory info
* @param keys keys double-buffer array
* @param vals the values double-buffer array
* @param nVals number of elements in the array
* @param nSegs number of segments
* @param offsets the segments
*/
template <typename T1, typename T2>
void segmentedSort(dh::CubMemory* tmp_mem, dh::dvec2<T1>* keys,
dh::dvec2<T2>* vals, int nVals, int nSegs,
const dh::dvec<int>& offsets, int start = 0,
int end = sizeof(T1) * 8) {
size_t tmpSize;
dh::safe_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
NULL, tmpSize, keys->buff(), vals->buff(), nVals, nSegs, offsets.data(),
offsets.data() + 1, start, end));
tmp_mem->LazyAllocate(tmpSize);
dh::safe_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
tmp_mem->d_temp_storage, tmpSize, keys->buff(), vals->buff(), nVals, nSegs,
offsets.data(), offsets.data() + 1, start, end));
}
/**
* @brief Helper function to perform device-wide sum-reduction
* @param tmp_mem cub temporary memory info
* @param in the input array to be reduced
* @param out the output reduced value
* @param nVals number of elements in the input array
*/
template <typename T>
void sumReduction(dh::CubMemory& tmp_mem, dh::dvec<T>& in, dh::dvec<T>& out,
int nVals) {
size_t tmpSize;
dh::safe_cuda(
cub::DeviceReduce::Sum(NULL, tmpSize, in.data(), out.data(), nVals));
tmp_mem.LazyAllocate(tmpSize);
dh::safe_cuda(cub::DeviceReduce::Sum(tmp_mem.d_temp_storage, tmpSize,
in.data(), out.data(), nVals));
}
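// Hedged usage sketch (names illustrative), mirroring how the GPU builders
// call this helper: the caller owns the device vectors and a dh::CubMemory
// scratch object, and the single reduced value ends up at the front of `out`:
//   dh::CubMemory tmp_mem;
//   // gradients: dh::dvec<bst_gpair> with one entry per training instance
//   // sums:      dh::dvec<bst_gpair> whose first element receives the total
//   sumReduction<bst_gpair>(tmp_mem, gradients, sums, n_instances);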
/**
* @brief Fill a given constant value across all elements in the buffer
* @param out the buffer to be filled
* @param len number of elements in the buffer
* @param def default value to be filled
*/
template <typename T, int BlkDim = 256, int ItemsPerThread = 4>
void fillConst(int device_idx, T* out, int len, T def) {
dh::launch_n<ItemsPerThread, BlkDim>(device_idx, len,
[=] __device__(int i) { out[i] = def; });
}
/**
* @brief gather elements
* @param out1 output gathered array for the first buffer
* @param in1 first input buffer
* @param out2 output gathered array for the second buffer
* @param in2 second input buffer
* @param instId gather indices
* @param nVals length of the buffers
*/
template <typename T1, typename T2, int BlkDim = 256, int ItemsPerThread = 4>
void gather(int device_idx, T1* out1, const T1* in1, T2* out2, const T2* in2,
const int* instId, int nVals) {
dh::launch_n<ItemsPerThread, BlkDim>(device_idx, nVals,
[=] __device__(int i) {
int iid = instId[i];
T1 v1 = in1[iid];
T2 v2 = in2[iid];
out1[i] = v1;
out2[i] = v2;
});
}
/**
* @brief gather elements
* @param out output gathered array
* @param in input buffer
* @param instId gather indices
* @param nVals length of the buffers
*/
template <typename T, int BlkDim = 256, int ItemsPerThread = 4>
void gather(int device_idx, T* out, const T* in, const int* instId, int nVals) {
dh::launch_n<ItemsPerThread, BlkDim>(device_idx, nVals,
[=] __device__(int i) {
int iid = instId[i];
out[i] = in[iid];
});
}
} // namespace tree
} // namespace xgboost

View File

@@ -2,13 +2,11 @@
* Copyright 2017 XGBoost contributors
*/
#pragma once
#include <xgboost/logging.h>
#include <thrust/binary_search.h>
#include <thrust/device_vector.h>
#include <thrust/random.h>
#include <thrust/system/cuda/error.h>
#include <thrust/system/cuda/execution_policy.h>
#include <thrust/system_error.h>
#include <xgboost/logging.h>
#include <algorithm>
#include <chrono>
#include <ctime>
@@ -20,7 +18,6 @@
#include "nccl.h"
// Uncomment to enable
// #define DEVICE_TIMER
#define TIMERS
namespace dh {
@@ -61,25 +58,6 @@ inline ncclResult_t throw_on_nccl_error(ncclResult_t code, const char *file,
return code;
}
#define gpuErrchk(ans) \
{ gpuAssert((ans), __FILE__, __LINE__); }
inline void gpuAssert(cudaError_t code, const char *file, int line,
bool abort = true) {
if (code != cudaSuccess) {
fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file,
line);
if (abort){
std::stringstream ss;
ss << file << "(" << line << ")";
std::string file_and_line;
ss >> file_and_line;
throw thrust::system_error(code, thrust::cuda_category(), file_and_line);
}
}
}
inline int n_visible_devices() {
int n_visgpus = 0;
@@ -237,7 +215,8 @@ __device__ range block_stride_range(T begin, T end) {
return r;
}
// Threadblock iterates over range, filling with value. Requires all threads in
// block to be active.
template <typename IterT, typename ValueT>
__device__ void block_fill(IterT begin, size_t n, ValueT value) {
for (auto i : block_stride_range(static_cast<size_t>(0), n)) {
@@ -485,7 +464,7 @@ class bulk_allocator {
}
template <typename... Args>
void allocate(int device_idx, bool silent, Args... args) {
size_t size = get_size_bytes(args...);
char *ptr = allocate_device(device_idx, size, MemoryT);
@@ -496,8 +475,7 @@
_size.push_back(size);
_device_idx.push_back(device_idx);
if (!silent) {
const int mb_size = 1048576;
LOG(CONSOLE) << "Allocated " << size / mb_size << "MB on [" << device_idx
<< "] " << device_name(device_idx) << ", "
@@ -545,7 +523,6 @@ struct CubMemory {
bool IsAllocated() { return d_temp_storage != NULL; }
};
/*
* Utility functions
*/
@@ -653,24 +630,6 @@ inline void multi_launch_n(size_t n, int n_devices, L lambda) {
#endif
}
/*
* Random
*/
struct BernoulliRng {
float p;
int seed;
__host__ __device__ BernoulliRng(float p, int seed) : p(p), seed(seed) {}
__host__ __device__ bool operator()(const int i) const {
thrust::default_random_engine rng(seed);
thrust::uniform_real_distribution<float> dist;
rng.discard(i);
return dist(rng) <= p;
}
};
/**
* @brief Helper macro to measure timing on GPU
* @param call the GPU call
@@ -687,9 +646,9 @@ struct BernoulliRng {
// Load balancing search
template <typename coordinate_t, typename segments_t, typename offset_t>
void FindMergePartitions(int device_idx, coordinate_t *d_tile_coordinates,
int num_tiles, int tile_size, segments_t segments,
offset_t num_rows, offset_t num_elements) {
dh::launch_n(device_idx, num_tiles + 1, [=] __device__(int idx) {
offset_t diagonal = idx * tile_size;
coordinate_t tile_coordinate;
@@ -761,8 +720,9 @@ __global__ void LbsKernel(coordinate_t *d_coordinates,
}
template <typename func_t, typename segments_iter, typename offset_t>
void SparseTransformLbs(int device_idx, dh::CubMemory *temp_memory,
offset_t count, segments_iter segments,
offset_t num_segments, func_t f) {
typedef typename cub::CubVector<offset_t, 2>::Type coordinate_t;
dh::safe_cuda(cudaSetDevice(device_idx));
const int BLOCK_THREADS = 256;
@@ -774,8 +734,8 @@ void SparseTransformLbs(int device_idx, dh::CubMemory *temp_memory, offset_t cou
coordinate_t *tmp_tile_coordinates =
reinterpret_cast<coordinate_t *>(temp_memory->d_temp_storage);
FindMergePartitions(device_idx, tmp_tile_coordinates, num_tiles,
BLOCK_THREADS, segments, num_segments, count);
LbsKernel<TILE_SIZE, ITEMS_PER_THREAD, BLOCK_THREADS, offset_t>
<<<num_tiles, BLOCK_THREADS>>>(tmp_tile_coordinates, segments + 1, f,
@@ -783,22 +743,24 @@ void SparseTransformLbs(int device_idx, dh::CubMemory *temp_memory, offset_t cou
}
template <typename func_t, typename offset_t>
void DenseTransformLbs(int device_idx, offset_t count, offset_t num_segments,
func_t f) {
CHECK(count % num_segments == 0) << "Data is not dense.";
launch_n(device_idx, count, [=] __device__(offset_t idx) {
offset_t segment = idx / (count / num_segments);
f(idx, segment);
});
}
/**
* \fn template <typename func_t, typename segments_iter, typename offset_t>
* void TransformLbs(int device_idx, dh::CubMemory *temp_memory, offset_t count,
* segments_iter segments, offset_t num_segments, bool is_dense, func_t f)
*
* \brief Load balancing search function. Reads a CSR type matrix description
* and allows a function to be executed on each element. Search 'modern GPU load
* balancing search' for more information.
*
* \author Rory
* \date 7/9/2017
@@ -817,12 +779,106 @@ void DenseTransformLbs(int device_idx, offset_t count, offset_t num_segments, fu
template <typename func_t, typename segments_iter, typename offset_t>
void TransformLbs(int device_idx, dh::CubMemory *temp_memory, offset_t count,
segments_iter segments, offset_t num_segments, bool is_dense,
func_t f) {
if (is_dense) {
DenseTransformLbs(device_idx, count, num_segments, f);
} else {
SparseTransformLbs(device_idx, temp_memory, count, segments, num_segments,
f);
}
}
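// Hedged usage sketch (illustrative names): counting the stored elements of
// each CSR row, with one function application per element. d_row_ptr holds
// num_rows + 1 offsets and d_row_counts is a zero-initialised device buffer;
// both names and the unsigned int counter type are hypothetical.
//   dh::CubMemory temp_memory;
//   TransformLbs(device_idx, &temp_memory, nnz, d_row_ptr, num_rows, is_dense,
//                [=] __device__(size_t idx, size_t ridx) {
//                  atomicAdd(&d_row_counts[ridx], 1u);
//                });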
/**
* @brief Helper function to sort the pairs using cub's segmented RadixSortPairs
* @param tmp_mem cub temporary memory info
* @param keys keys double-buffer array
* @param vals the values double-buffer array
* @param nVals number of elements in the array
* @param nSegs number of segments
* @param offsets the segments
*/
template <typename T1, typename T2>
void segmentedSort(dh::CubMemory *tmp_mem, dh::dvec2<T1> *keys,
dh::dvec2<T2> *vals, int nVals, int nSegs,
const dh::dvec<int> &offsets, int start = 0,
int end = sizeof(T1) * 8) {
size_t tmpSize;
dh::safe_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
NULL, tmpSize, keys->buff(), vals->buff(), nVals, nSegs, offsets.data(),
offsets.data() + 1, start, end));
tmp_mem->LazyAllocate(tmpSize);
dh::safe_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
tmp_mem->d_temp_storage, tmpSize, keys->buff(), vals->buff(), nVals,
nSegs, offsets.data(), offsets.data() + 1, start, end));
}
/**
* @brief Helper function to perform device-wide sum-reduction
* @param tmp_mem cub temporary memory info
* @param in the input array to be reduced
* @param out the output reduced value
* @param nVals number of elements in the input array
*/
template <typename T>
void sumReduction(dh::CubMemory &tmp_mem, dh::dvec<T> &in, dh::dvec<T> &out,
int nVals) {
size_t tmpSize;
dh::safe_cuda(
cub::DeviceReduce::Sum(NULL, tmpSize, in.data(), out.data(), nVals));
tmp_mem.LazyAllocate(tmpSize);
dh::safe_cuda(cub::DeviceReduce::Sum(tmp_mem.d_temp_storage, tmpSize,
in.data(), out.data(), nVals));
}
/**
* @brief Fill a given constant value across all elements in the buffer
* @param out the buffer to be filled
* @param len number of elements in the buffer
* @param def default value to be filled
*/
template <typename T, int BlkDim = 256, int ItemsPerThread = 4>
void fillConst(int device_idx, T *out, int len, T def) {
dh::launch_n<ItemsPerThread, BlkDim>(device_idx, len,
[=] __device__(int i) { out[i] = def; });
}
/**
* @brief gather elements
* @param out1 output gathered array for the first buffer
* @param in1 first input buffer
* @param out2 output gathered array for the second buffer
* @param in2 second input buffer
* @param instId gather indices
* @param nVals length of the buffers
*/
template <typename T1, typename T2, int BlkDim = 256, int ItemsPerThread = 4>
void gather(int device_idx, T1 *out1, const T1 *in1, T2 *out2, const T2 *in2,
const int *instId, int nVals) {
dh::launch_n<ItemsPerThread, BlkDim>(device_idx, nVals,
[=] __device__(int i) {
int iid = instId[i];
T1 v1 = in1[iid];
T2 v2 = in2[iid];
out1[i] = v1;
out2[i] = v2;
});
}
/**
* @brief gather elements
* @param out output gathered array
* @param in input buffer
* @param instId gather indices
* @param nVals length of the buffers
*/
template <typename T, int BlkDim = 256, int ItemsPerThread = 4>
void gather(int device_idx, T *out, const T *in, const int *instId, int nVals) {
dh::launch_n<ItemsPerThread, BlkDim>(device_idx, nVals,
[=] __device__(int i) {
int iid = instId[i];
out[i] = in[iid];
});
}
} // namespace dh

View File

@@ -1,184 +0,0 @@
/*
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../../../../src/tree/param.h"
#include "../common.cuh"
#include "node.cuh"
#include "../types.cuh"
namespace xgboost {
namespace tree {
namespace exact {
/**
* @enum ArgMaxByKeyAlgo best_split_evaluation.cuh
* @brief Help decide which algorithm to use for multi-argmax operation
*/
enum ArgMaxByKeyAlgo {
/** simplest, use gmem-atomics for all updates */
ABK_GMEM = 0,
/** use smem-atomics for updates (when number of keys are less) */
ABK_SMEM
};
/** max depth until which to use shared mem based atomics for argmax */
static const int MAX_ABK_LEVELS = 3;
HOST_DEV_INLINE Split maxSplit(Split a, Split b) {
Split out;
if (a.score < b.score) {
out.score = b.score;
out.index = b.index;
} else if (a.score == b.score) {
out.score = a.score;
out.index = (a.index < b.index) ? a.index : b.index;
} else {
out.score = a.score;
out.index = a.index;
}
return out;
}
DEV_INLINE void atomicArgMax(Split* address, Split val) {
unsigned long long* intAddress = (unsigned long long*)address;
unsigned long long old = *intAddress;
unsigned long long assumed;
do {
assumed = old;
Split res = maxSplit(val, *(Split*)&assumed);
old = atomicCAS(intAddress, assumed, *(uint64_t*)&res);
} while (assumed != old);
}
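// Note (hedged): atomicArgMax above uses the standard CUDA retry idiom for
// types without a native atomic, reinterpreting the 8-byte Split as an
// unsigned long long and looping on atomicCAS until no other thread has
// intervened. The same pattern for a plain float maximum, shown only as an
// illustrative sketch (not part of the original header), would be:
DEV_INLINE void atomicMaxFloat(float* address, float val) {
  int* addr_as_int = reinterpret_cast<int*>(address);
  int old = *addr_as_int;
  int assumed;
  do {
    assumed = old;
    float current = __int_as_float(assumed);
    float desired = val > current ? val : current;
    // Retry if another thread updated *address between the read and the CAS.
    old = atomicCAS(addr_as_int, assumed, __float_as_int(desired));
  } while (assumed != old);
}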
template <typename node_id_t>
DEV_INLINE void argMaxWithAtomics(
int id, Split* nodeSplits, const bst_gpair* gradScans,
const bst_gpair* gradSums, const float* vals, const int* colIds,
const node_id_t* nodeAssigns, const Node<node_id_t>* nodes, int nUniqKeys,
node_id_t nodeStart, int len, const GPUTrainingParam& param) {
int nodeId = nodeAssigns[id];
///@todo: this is really a bad check! but will be fixed when we move
/// to key-based reduction
if ((id == 0) ||
!((nodeId == nodeAssigns[id - 1]) && (colIds[id] == colIds[id - 1]) &&
(vals[id] == vals[id - 1]))) {
if (nodeId != UNUSED_NODE) {
int sumId = abs2uniqKey(id, nodeAssigns, colIds, nodeStart, nUniqKeys);
bst_gpair colSum = gradSums[sumId];
int uid = nodeId - nodeStart;
Node<node_id_t> n = nodes[nodeId];
bst_gpair parentSum = n.gradSum;
float parentGain = n.score;
bool tmp;
Split s;
bst_gpair missing = parentSum - colSum;
s.score = loss_chg_missing(gradScans[id], missing, parentSum, parentGain,
param, tmp);
s.index = id;
atomicArgMax(nodeSplits + uid, s);
} // end if nodeId != UNUSED_NODE
} // end if id == 0 ...
}
template <typename node_id_t>
__global__ void atomicArgMaxByKeyGmem(
Split* nodeSplits, const bst_gpair* gradScans, const bst_gpair* gradSums,
const float* vals, const int* colIds, const node_id_t* nodeAssigns,
const Node<node_id_t>* nodes, int nUniqKeys, node_id_t nodeStart, int len,
const TrainParam param) {
int id = threadIdx.x + (blockIdx.x * blockDim.x);
const int stride = blockDim.x * gridDim.x;
for (; id < len; id += stride) {
argMaxWithAtomics(id, nodeSplits, gradScans, gradSums, vals, colIds,
nodeAssigns, nodes, nUniqKeys, nodeStart, len, GPUTrainingParam(param));
}
}
template <typename node_id_t>
__global__ void atomicArgMaxByKeySmem(
Split* nodeSplits, const bst_gpair* gradScans, const bst_gpair* gradSums,
const float* vals, const int* colIds, const node_id_t* nodeAssigns,
const Node<node_id_t>* nodes, int nUniqKeys, node_id_t nodeStart, int len,
const TrainParam param) {
extern __shared__ char sArr[];
Split* sNodeSplits = reinterpret_cast<Split*>(sArr);
int tid = threadIdx.x;
Split defVal;
#pragma unroll 1
for (int i = tid; i < nUniqKeys; i += blockDim.x) {
sNodeSplits[i] = defVal;
}
__syncthreads();
int id = tid + (blockIdx.x * blockDim.x);
const int stride = blockDim.x * gridDim.x;
for (; id < len; id += stride) {
argMaxWithAtomics(id, sNodeSplits, gradScans, gradSums, vals, colIds,
nodeAssigns, nodes, nUniqKeys, nodeStart, len, param);
}
__syncthreads();
for (int i = tid; i < nUniqKeys; i += blockDim.x) {
Split s = sNodeSplits[i];
atomicArgMax(nodeSplits + i, s);
}
}
/**
* @brief Performs argmax_by_key functionality but for cases when keys need not
* occur contiguously
* @param nodeSplits will contain information on best split for each node
* @param gradScans exclusive sum on sorted segments for each col
* @param gradSums gradient sum for each column in DMatrix based on to node-ids
* @param vals feature values
* @param colIds column index for each element in the feature values array
* @param nodeAssigns node-id assignments to each element in DMatrix
* @param nodes pointer to all nodes for this tree in BFS order
* @param nUniqKeys number of unique node-ids in this level
* @param nodeStart start index of the node-ids in this level
* @param len number of elements
* @param param training parameters
* @param algo which algorithm to use for argmax_by_key
*/
template <typename node_id_t, int BLKDIM = 256, int ITEMS_PER_THREAD = 4>
void argMaxByKey(Split* nodeSplits, const bst_gpair* gradScans,
const bst_gpair* gradSums, const float* vals,
const int* colIds, const node_id_t* nodeAssigns,
const Node<node_id_t>* nodes, int nUniqKeys,
node_id_t nodeStart, int len, const TrainParam param,
ArgMaxByKeyAlgo algo) {
fillConst<Split, BLKDIM, ITEMS_PER_THREAD>(dh::get_device_idx(param.gpu_id),
nodeSplits, nUniqKeys, Split());
int nBlks = dh::div_round_up(len, ITEMS_PER_THREAD * BLKDIM);
switch (algo) {
case ABK_GMEM:
atomicArgMaxByKeyGmem<node_id_t><<<nBlks, BLKDIM>>>(
nodeSplits, gradScans, gradSums, vals, colIds, nodeAssigns, nodes,
nUniqKeys, nodeStart, len, param);
break;
case ABK_SMEM:
atomicArgMaxByKeySmem<
node_id_t><<<nBlks, BLKDIM, sizeof(Split) * nUniqKeys>>>(
nodeSplits, gradScans, gradSums, vals, colIds, nodeAssigns, nodes,
nUniqKeys, nodeStart, len, param);
break;
default:
throw std::runtime_error("argMaxByKey: Bad algo passed!");
}
}
} // namespace exact
} // namespace tree
} // namespace xgboost

View File

@@ -1,209 +0,0 @@
/*
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../common.cuh"
namespace xgboost {
namespace tree {
namespace exact {
/**
* @struct Pair fused_scan_reduce_by_key.cuh
* @brief Pair used for key-based scan operations on bst_gpair
*/
struct Pair {
int key;
bst_gpair value;
};
/** define a key that's not used at all in the entire boosting process */
static const int NONE_KEY = -100;
/**
* @brief Allocate temporary buffers needed for scan operations
* @param tmpScans gradient buffer
* @param tmpKeys keys buffer
* @param size number of elements that will be scanned
*/
template <int BLKDIM_L1L3 = 256>
int scanTempBufferSize(int size) {
int nBlks = dh::div_round_up(size, BLKDIM_L1L3);
return nBlks;
}
struct AddByKey {
template <typename T>
HOST_DEV_INLINE T operator()(const T& first, const T& second) const {
T result;
if (first.key == second.key) {
result.key = first.key;
result.value = first.value + second.value;
} else {
result.key = second.key;
result.value = second.value;
}
return result;
}
};
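// Worked example (illustrative): combining {key: 3, value: (g1, h1)} with
// {key: 3, value: (g2, h2)} yields {key: 3, value: (g1 + g2, h1 + h2)};
// combining {key: 3, value: x} with {key: 4, value: v} discards the running
// sum and yields {key: 4, value: v}. Used as the scan operator, this turns a
// block-wide prefix sum into a per-key (segmented) prefix sum.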
/**
* @brief Gradient value getter function
* @param id the index into the vals or instIds array to which to fetch
* @param vals the gradient value buffer
* @param instIds instance index buffer
* @return the expected gradient value
*/
HOST_DEV_INLINE bst_gpair get(int id, const bst_gpair* vals, const int* instIds) {
id = instIds[id];
return vals[id];
}
template <typename node_id_t, int BLKDIM_L1L3>
__global__ void cubScanByKeyL1(bst_gpair* scans, const bst_gpair* vals,
const int* instIds, bst_gpair* mScans,
int* mKeys, const node_id_t* keys, int nUniqKeys,
const int* colIds, node_id_t nodeStart,
const int size) {
Pair rootPair = {NONE_KEY, bst_gpair(0.f, 0.f)};
int myKey;
bst_gpair myValue;
typedef cub::BlockScan<Pair, BLKDIM_L1L3> BlockScan;
__shared__ typename BlockScan::TempStorage temp_storage;
Pair threadData;
int tid = blockIdx.x * BLKDIM_L1L3 + threadIdx.x;
if (tid < size) {
myKey = abs2uniqKey(tid, keys, colIds, nodeStart, nUniqKeys);
myValue = get(tid, vals, instIds);
} else {
myKey = NONE_KEY;
myValue = 0.f;
}
threadData.key = myKey;
threadData.value = myValue;
// get previous key, especially needed for the last thread in this block
// in order to pass on the partial scan values.
// this statement MUST appear before the checks below!
// else, the result of this shuffle operation will be undefined
int previousKey = __shfl_up(myKey, 1);
// Collectively compute the block-wide exclusive prefix sum
BlockScan(temp_storage)
.ExclusiveScan(threadData, threadData, rootPair, AddByKey());
if (tid < size) {
scans[tid] = threadData.value;
} else {
return;
}
if (threadIdx.x == BLKDIM_L1L3 - 1) {
threadData.value =
(myKey == previousKey) ? threadData.value : bst_gpair(0.0f, 0.0f);
mKeys[blockIdx.x] = myKey;
mScans[blockIdx.x] = threadData.value + myValue;
}
}
template <int BLKSIZE>
__global__ void cubScanByKeyL2(bst_gpair* mScans, int* mKeys, int mLength) {
typedef cub::BlockScan<Pair, BLKSIZE, cub::BLOCK_SCAN_WARP_SCANS> BlockScan;
Pair threadData;
__shared__ typename BlockScan::TempStorage temp_storage;
for (int i = threadIdx.x; i < mLength; i += BLKSIZE - 1) {
threadData.key = mKeys[i];
threadData.value = mScans[i];
BlockScan(temp_storage).InclusiveScan(threadData, threadData, AddByKey());
mScans[i] = threadData.value;
__syncthreads();
}
}
template <typename node_id_t, int BLKDIM_L1L3>
__global__ void cubScanByKeyL3(bst_gpair* sums, bst_gpair* scans,
const bst_gpair* vals, const int* instIds,
const bst_gpair* mScans, const int* mKeys,
const node_id_t* keys, int nUniqKeys,
const int* colIds, node_id_t nodeStart,
const int size) {
int relId = threadIdx.x;
int tid = (blockIdx.x * BLKDIM_L1L3) + relId;
// to avoid the following warning from nvcc:
// __shared__ memory variable with non-empty constructor or destructor
// (potential race between threads)
__shared__ char gradBuff[sizeof(bst_gpair)];
__shared__ int s_mKeys;
bst_gpair* s_mScans = reinterpret_cast<bst_gpair*>(gradBuff);
if (tid >= size) return;
// cache block-wide partial scan info
if (relId == 0) {
s_mKeys = (blockIdx.x > 0) ? mKeys[blockIdx.x - 1] : NONE_KEY;
s_mScans[0] = (blockIdx.x > 0) ? mScans[blockIdx.x - 1] : bst_gpair();
}
int myKey = abs2uniqKey(tid, keys, colIds, nodeStart, nUniqKeys);
int previousKey = tid == 0 ? NONE_KEY : abs2uniqKey(tid - 1, keys, colIds,
nodeStart, nUniqKeys);
bst_gpair myValue = scans[tid];
__syncthreads();
if (blockIdx.x > 0 && s_mKeys == previousKey) {
myValue += s_mScans[0];
}
if (tid == size - 1) {
sums[previousKey] = myValue + get(tid, vals, instIds);
}
if ((previousKey != myKey) && (previousKey >= 0)) {
sums[previousKey] = myValue;
myValue = bst_gpair(0.0f, 0.0f);
}
scans[tid] = myValue;
}
/**
* @brief Performs fused reduce and scan by key functionality. It is assumed
* that
* the keys occur contiguously!
* @param sums the output gradient reductions for each element performed
* key-wise
* @param scans the output gradient scans for each element performed key-wise
* @param vals the gradients evaluated for each observation.
* @param instIds instance ids for each element
* @param keys keys to be used to segment the reductions. They need not occur
* contiguously in contrast to scan_by_key. Currently, we need one key per
* value in the 'vals' array.
* @param size number of elements in the 'vals' array
* @param nUniqKeys max number of uniq keys found per column
* @param nCols number of columns
* @param tmpScans temporary scan buffer needed for cub-pyramid algo
* @param tmpKeys temporary key buffer needed for cub-pyramid algo
* @param colIds column indices for each element in the array
* @param nodeStart index of the leftmost node in the current level
*/
template <typename node_id_t, int BLKDIM_L1L3 = 256, int BLKDIM_L2 = 512>
void reduceScanByKey(bst_gpair* sums, bst_gpair* scans, const bst_gpair* vals,
const int* instIds, const node_id_t* keys, int size,
int nUniqKeys, int nCols, bst_gpair* tmpScans,
int* tmpKeys, const int* colIds, node_id_t nodeStart) {
int nBlks = dh::div_round_up(size, BLKDIM_L1L3);
cudaMemset(sums, 0, nUniqKeys * nCols * sizeof(bst_gpair));
cubScanByKeyL1<node_id_t, BLKDIM_L1L3><<<nBlks, BLKDIM_L1L3>>>(
scans, vals, instIds, tmpScans, tmpKeys, keys, nUniqKeys, colIds,
nodeStart, size);
cubScanByKeyL2<BLKDIM_L2><<<1, BLKDIM_L2>>>(tmpScans, tmpKeys, nBlks);
cubScanByKeyL3<node_id_t, BLKDIM_L1L3><<<nBlks, BLKDIM_L1L3>>>(
sums, scans, vals, instIds, tmpScans, tmpKeys, keys, nUniqKeys, colIds,
nodeStart, size);
}
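// Hedged illustration of the intended result (keys already uniquified,
// values fetched through instIds): for keys {0, 0, 1, 1, 1} and values
// {a, b, c, d, e},
//   scans = {0, a, 0, c, c + d}   (exclusive per-key prefix sums)
//   sums  = {a + b, c + d + e}    (per-key totals)
// computed in three passes: block-local scans (L1), a scan of the per-block
// aggregates (L2), and a final fix-up pass (L3).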
} // namespace exact
} // namespace tree
} // namespace xgboost

View File

@@ -1,386 +0,0 @@
/*
* Copyright (c) 2017, NVIDIA CORPORATION, Xgboost contributors. All rights
* reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <string>
#include <vector>
#include "../../../../src/tree/param.h"
#include "../common.cuh"
#include "argmax_by_key.cuh"
#include "fused_scan_reduce_by_key.cuh"
#include "node.cuh"
#include "split2node.cuh"
#include "xgboost/tree_updater.h"
namespace xgboost {
namespace tree {
namespace exact {
template <typename node_id_t>
__global__ void initRootNode(Node<node_id_t>* nodes, const bst_gpair* sums,
const TrainParam param) {
// gradients already evaluated inside transferGrads
Node<node_id_t> n;
n.gradSum = sums[0];
n.score = CalcGain(param, n.gradSum.grad, n.gradSum.hess);
n.weight = CalcWeight(param, n.gradSum.grad, n.gradSum.hess);
n.id = 0;
nodes[0] = n;
}
template <typename node_id_t>
__global__ void assignColIds(int* colIds, const int* colOffsets) {
int myId = blockIdx.x;
int start = colOffsets[myId];
int end = colOffsets[myId + 1];
for (int id = start + threadIdx.x; id < end; id += blockDim.x) {
colIds[id] = myId;
}
}
template <typename node_id_t>
__global__ void fillDefaultNodeIds(node_id_t* nodeIdsPerInst,
const Node<node_id_t>* nodes, int nRows) {
int id = threadIdx.x + (blockIdx.x * blockDim.x);
if (id >= nRows) {
return;
}
// if this element belongs to none of the currently active node-id's
node_id_t nId = nodeIdsPerInst[id];
if (nId == UNUSED_NODE) {
return;
}
const Node<node_id_t> n = nodes[nId];
node_id_t result;
if (n.isLeaf() || n.isUnused()) {
result = UNUSED_NODE;
} else if (n.isDefaultLeft()) {
result = (2 * n.id) + 1;
} else {
result = (2 * n.id) + 2;
}
nodeIdsPerInst[id] = result;
}
template <typename node_id_t>
__global__ void assignNodeIds(node_id_t* nodeIdsPerInst, int* nodeLocations,
const node_id_t* nodeIds, const int* instId,
const Node<node_id_t>* nodes,
const int* colOffsets, const float* vals,
int nVals, int nCols) {
int id = threadIdx.x + (blockIdx.x * blockDim.x);
const int stride = blockDim.x * gridDim.x;
for (; id < nVals; id += stride) {
// fusing generation of indices for node locations
nodeLocations[id] = id;
// using nodeIds here since the previous kernel would have updated
// the nodeIdsPerInst with all default assignments
int nId = nodeIds[id];
// if this element belongs to none of the currently active node-id's
if (nId != UNUSED_NODE) {
const Node<node_id_t> n = nodes[nId];
int colId = n.colIdx;
// printf("nid=%d colId=%d id=%d\n", nId, colId, id);
int start = colOffsets[colId];
int end = colOffsets[colId + 1];
///@todo: too much wasteful threads!!
if ((id >= start) && (id < end) && !(n.isLeaf() || n.isUnused())) {
node_id_t result = (2 * n.id) + 1 + (vals[id] >= n.threshold);
nodeIdsPerInst[instId[id]] = result;
}
}
}
}
template <typename node_id_t>
__global__ void markLeavesKernel(Node<node_id_t>* nodes, int len) {
int id = (blockIdx.x * blockDim.x) + threadIdx.x;
if ((id < len) && !nodes[id].isUnused()) {
int lid = (id << 1) + 1;
int rid = (id << 1) + 2;
if ((lid >= len) || (rid >= len)) {
nodes[id].score = -FLT_MAX; // bottom-most nodes
} else if (nodes[lid].isUnused() && nodes[rid].isUnused()) {
nodes[id].score = -FLT_MAX; // unused child nodes
}
}
}
// unit test forward declaration for friend function access
template <typename node_id_t>
void testSmallData();
template <typename node_id_t>
void testLargeData();
template <typename node_id_t>
void testAllocate();
template <typename node_id_t>
void testMarkLeaves();
template <typename node_id_t>
void testDense2Sparse();
template <typename node_id_t>
class GPUBuilder;
template <typename node_id_t>
std::shared_ptr<xgboost::DMatrix> setupGPUBuilder(
const std::string& file,
xgboost::tree::exact::GPUBuilder<node_id_t>& builder);
template <typename node_id_t>
class GPUBuilder {
public:
GPUBuilder() : allocated(false) {}
~GPUBuilder() {}
void Init(const TrainParam& p) {
param = p;
maxNodes = (1 << (param.max_depth + 1)) - 1;
maxLeaves = 1 << param.max_depth;
}
void UpdateParam(const TrainParam& param) { this->param = param; }
/// @note: Update should be only after Init!!
void Update(const std::vector<bst_gpair>& gpair, DMatrix* hMat,
RegTree* hTree) {
if (!allocated) {
setupOneTimeData(*hMat);
}
for (int i = 0; i < param.max_depth; ++i) {
if (i == 0) {
// make sure to start on a fresh tree with sorted values!
vals.current_dvec() = vals_cached;
instIds.current_dvec() = instIds_cached;
transferGrads(gpair);
}
int nNodes = 1 << i;
node_id_t nodeStart = nNodes - 1;
initNodeData(i, nodeStart, nNodes);
findSplit(i, nodeStart, nNodes);
}
// mark all the used nodes with unused children as leaf nodes
markLeaves();
dense2sparse(hTree);
}
private:
friend void testSmallData<node_id_t>();
friend void testLargeData<node_id_t>();
friend void testAllocate<node_id_t>();
friend void testMarkLeaves<node_id_t>();
friend void testDense2Sparse<node_id_t>();
friend std::shared_ptr<xgboost::DMatrix> setupGPUBuilder<node_id_t>(
const std::string& file, GPUBuilder<node_id_t>& builder);
TrainParam param;
/** whether we have initialized memory already (so as not to repeat!) */
bool allocated;
/** feature values stored in column-major compressed format */
dh::dvec2<float> vals;
dh::dvec<float> vals_cached;
/** corresponding instance ids of these feature values */
dh::dvec2<int> instIds;
dh::dvec<int> instIds_cached;
/** column offsets for these feature values */
dh::dvec<int> colOffsets;
dh::dvec<bst_gpair> gradsInst;
dh::dvec2<node_id_t> nodeAssigns;
dh::dvec2<int> nodeLocations;
dh::dvec<Node<node_id_t>> nodes;
dh::dvec<node_id_t> nodeAssignsPerInst;
dh::dvec<bst_gpair> gradSums;
dh::dvec<bst_gpair> gradScans;
dh::dvec<Split> nodeSplits;
int nVals;
int nRows;
int nCols;
int maxNodes;
int maxLeaves;
dh::CubMemory tmp_mem;
dh::dvec<bst_gpair> tmpScanGradBuff;
dh::dvec<int> tmpScanKeyBuff;
dh::dvec<int> colIds;
dh::bulk_allocator<dh::memory_type::DEVICE> ba;
void findSplit(int level, node_id_t nodeStart, int nNodes) {
reduceScanByKey(gradSums.data(), gradScans.data(), gradsInst.data(),
instIds.current(), nodeAssigns.current(), nVals, nNodes,
nCols, tmpScanGradBuff.data(), tmpScanKeyBuff.data(),
colIds.data(), nodeStart);
argMaxByKey(nodeSplits.data(), gradScans.data(), gradSums.data(),
vals.current(), colIds.data(), nodeAssigns.current(),
nodes.data(), nNodes, nodeStart, nVals, param,
level <= MAX_ABK_LEVELS ? ABK_SMEM : ABK_GMEM);
split2node(nodes.data(), nodeSplits.data(), gradScans.data(),
gradSums.data(), vals.current(), colIds.data(),
colOffsets.data(), nodeAssigns.current(), nNodes, nodeStart,
nCols, param);
}
void allocateAllData(int offsetSize) {
int tmpBuffSize = scanTempBufferSize(nVals);
ba.allocate(dh::get_device_idx(param.gpu_id), param.silent, &vals, nVals, &vals_cached,
nVals, &instIds, nVals, &instIds_cached, nVals, &colOffsets,
offsetSize, &gradsInst, nRows, &nodeAssigns, nVals,
&nodeLocations, nVals, &nodes, maxNodes, &nodeAssignsPerInst,
nRows, &gradSums, maxLeaves * nCols, &gradScans, nVals,
&nodeSplits, maxLeaves, &tmpScanGradBuff, tmpBuffSize,
&tmpScanKeyBuff, tmpBuffSize, &colIds, nVals);
}
void setupOneTimeData(DMatrix& hMat) {
size_t free_memory = dh::available_memory(dh::get_device_idx(param.gpu_id));
if (!hMat.SingleColBlock()) {
throw std::runtime_error("exact::GPUBuilder - must have 1 column block");
}
std::vector<float> fval;
std::vector<int> fId, offset;
convertToCsc(hMat, fval, fId, offset);
allocateAllData((int)offset.size());
transferAndSortData(fval, fId, offset);
allocated = true;
}
void convertToCsc(DMatrix& hMat, std::vector<float>& fval,
std::vector<int>& fId, std::vector<int>& offset) {
MetaInfo info = hMat.info();
nRows = info.num_row;
nCols = info.num_col;
offset.reserve(nCols + 1);
offset.push_back(0);
fval.reserve(nCols * nRows);
fId.reserve(nCols * nRows);
// in case you end up with a DMatrix having no column access
// then make sure to enable that before copying the data!
if (!hMat.HaveColAccess()) {
const std::vector<bool> enable(nCols, true);
hMat.InitColAccess(enable, 1, nRows);
}
dmlc::DataIter<ColBatch>* iter = hMat.ColIterator();
iter->BeforeFirst();
while (iter->Next()) {
const ColBatch& batch = iter->Value();
for (int i = 0; i < batch.size; i++) {
const ColBatch::Inst& col = batch[i];
for (const ColBatch::Entry* it = col.data; it != col.data + col.length;
it++) {
int inst_id = static_cast<int>(it->index);
fval.push_back(it->fvalue);
fId.push_back(inst_id);
}
offset.push_back(fval.size());
}
}
nVals = fval.size();
}
void transferAndSortData(const std::vector<float>& fval,
const std::vector<int>& fId,
const std::vector<int>& offset) {
vals.current_dvec() = fval;
instIds.current_dvec() = fId;
colOffsets = offset;
segmentedSort<float, int>(&tmp_mem, &vals, &instIds, nVals, nCols,
colOffsets);
vals_cached = vals.current_dvec();
instIds_cached = instIds.current_dvec();
assignColIds<node_id_t><<<nCols, 512>>>(colIds.data(), colOffsets.data());
}
void transferGrads(const std::vector<bst_gpair>& gpair) {
// HACK
dh::safe_cuda(cudaMemcpy(gradsInst.data(), &(gpair[0]),
sizeof(bst_gpair) * nRows,
cudaMemcpyHostToDevice));
// evaluate the full-grad reduction for the root node
sumReduction<bst_gpair>(tmp_mem, gradsInst, gradSums, nRows);
}
void initNodeData(int level, node_id_t nodeStart, int nNodes) {
// all instances belong to root node at the beginning!
if (level == 0) {
nodes.fill(Node<node_id_t>());
nodeAssigns.current_dvec().fill(0);
nodeAssignsPerInst.fill(0);
// for root node, just update the gradient/score/weight/id info
// before splitting it! Currently all data is on GPU, hence this
// stupid little kernel
initRootNode<<<1, 1>>>(nodes.data(), gradSums.data(), param);
} else {
const int BlkDim = 256;
const int ItemsPerThread = 4;
// assign default node ids first
int nBlks = dh::div_round_up(nRows, BlkDim);
fillDefaultNodeIds<<<nBlks, BlkDim>>>(nodeAssignsPerInst.data(),
nodes.data(), nRows);
// evaluate the correct child indices of non-missing values next
nBlks = dh::div_round_up(nVals, BlkDim * ItemsPerThread);
assignNodeIds<<<nBlks, BlkDim>>>(
nodeAssignsPerInst.data(), nodeLocations.current(),
nodeAssigns.current(), instIds.current(), nodes.data(),
colOffsets.data(), vals.current(), nVals, nCols);
// gather the node assignments across all other columns too
gather<node_id_t>(dh::get_device_idx(param.gpu_id), nodeAssigns.current(),
nodeAssignsPerInst.data(), instIds.current(), nVals);
sortKeys(level);
}
}
void sortKeys(int level) {
// segmented-sort the arrays based on node-id's
// but we don't need more than level+1 bits for sorting!
segmentedSort(&tmp_mem, &nodeAssigns, &nodeLocations, nVals, nCols,
colOffsets, 0, level + 1);
gather<float, int>(dh::get_device_idx(param.gpu_id), vals.other(),
vals.current(), instIds.other(), instIds.current(),
nodeLocations.current(), nVals);
vals.buff().selector ^= 1;
instIds.buff().selector ^= 1;
}
void markLeaves() {
const int BlkDim = 128;
int nBlks = dh::div_round_up(maxNodes, BlkDim);
markLeavesKernel<<<nBlks, BlkDim>>>(nodes.data(), maxNodes);
}
void dense2sparse(RegTree* p_tree) {
RegTree& tree = *p_tree;
std::vector<Node<node_id_t>> hNodes = nodes.as_vector();
int nodeId = 0;
for (int i = 0; i < maxNodes; ++i) {
const Node<node_id_t>& n = hNodes[i];
if ((i != 0) && hNodes[i].isLeaf()) {
tree[nodeId].set_leaf(n.weight * param.learning_rate);
tree.stat(nodeId).sum_hess = n.gradSum.hess;
++nodeId;
} else if (!hNodes[i].isUnused()) {
tree.AddChilds(nodeId);
tree[nodeId].set_split(n.colIdx, n.threshold, n.dir == LeftDir);
tree.stat(nodeId).loss_chg = n.score;
tree.stat(nodeId).sum_hess = n.gradSum.hess;
tree.stat(nodeId).base_weight = n.weight;
tree[tree[nodeId].cleft()].set_leaf(0);
tree[tree[nodeId].cright()].set_leaf(0);
++nodeId;
}
}
}
};
} // namespace exact
} // namespace tree
} // namespace xgboost

View File

@@ -1,156 +0,0 @@
/*
* Copyright (c) 2017, NVIDIA CORPORATION, Xgboost contributors. All rights
* reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../common.cuh"
namespace xgboost {
namespace tree {
namespace exact {
/**
* @enum DefaultDirection node.cuh
* @brief Default direction to be followed in case of missing values
*/
enum DefaultDirection {
/** move to left child */
LeftDir = 0,
/** move to right child */
RightDir
};
/** used to assign default id to a Node */
static const int UNUSED_NODE = -1;
/**
* @struct Split node.cuh
* @brief Abstraction of a possible split in the decision tree
*/
struct Split {
/** the optimal gain score for this node */
float score;
/** index where to split in the DMatrix */
int index;
HOST_DEV_INLINE Split() : score(-FLT_MAX), index(INT_MAX) {}
/**
* @brief Whether the split info is valid to be used to create a new child
* @param minSplitLoss minimum score above which decision to split is made
* @return true if splittable, else false
*/
HOST_DEV_INLINE bool isSplittable(float minSplitLoss) const {
return ((score >= minSplitLoss) && (index != INT_MAX));
}
};
/**
* @struct Node node.cuh
* @brief Abstraction of a node in the decision tree
*/
template <typename node_id_t>
class Node {
public:
/** sum of gradients across all training samples part of this node */
bst_gpair gradSum;
/** the optimal score for this node */
float score;
/** weightage for this node */
float weight;
/** default direction for missing values */
DefaultDirection dir;
/** threshold value for comparison */
float threshold;
/** column (feature) index whose value needs to be compared in this node */
int colIdx;
/** node id (used as key for reduce/scan) */
node_id_t id;
HOST_DEV_INLINE Node()
: gradSum(),
score(-FLT_MAX),
weight(-FLT_MAX),
dir(LeftDir),
threshold(0.f),
colIdx(UNUSED_NODE),
id(UNUSED_NODE) {}
/** Tells whether this node is part of the decision tree */
HOST_DEV_INLINE bool isUnused() const { return (id == UNUSED_NODE); }
/** Tells whether this node is a leaf of the decision tree */
HOST_DEV_INLINE bool isLeaf() const {
return (!isUnused() && (score == -FLT_MAX));
}
/** Tells whether default direction is left child or not */
HOST_DEV_INLINE bool isDefaultLeft() const { return (dir == LeftDir); }
};
/**
* @struct Segment node.cuh
* @brief Space inefficient, but super easy to implement structure to define
* the start and end of a segment in the input array
*/
struct Segment {
/** start index of the segment */
int start;
/** end index of the segment */
int end;
HOST_DEV_INLINE Segment() : start(-1), end(-1) {}
/** Checks whether the current structure defines a valid segment */
HOST_DEV_INLINE bool isValid() const {
return !((start == -1) || (end == -1));
}
};
/**
* @enum NodeType node.cuh
* @brief Useful to describe the node type in a dense BFS-order tree array
*/
enum NodeType {
/** a non-leaf node */
NODE = 0,
/** leaf node */
LEAF,
/** unused node */
UNUSED
};
/**
* @brief Absolute BFS order IDs to col-wise unique IDs based on user input
* @param tid the index of the element that this thread should access
* @param abs the array of absolute IDs
* @param colIds the array of column IDs for each element
* @param nodeStart the start of the node ID at this level
* @param nKeys number of nodes at this level.
* @return the uniq key
*/
template <typename node_id_t>
HOST_DEV_INLINE int abs2uniqKey(int tid, const node_id_t* abs,
const int* colIds, node_id_t nodeStart,
int nKeys) {
int a = abs[tid];
if (a == UNUSED_NODE) return a;
return ((a - nodeStart) + (colIds[tid] * nKeys));
}
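// Worked example (illustrative): at tree level 2 the active node ids are 3..6,
// so nodeStart == 3 and nKeys == 4. An element with absolute node id 5 in
// column 2 maps to (5 - 3) + (2 * 4) == 10; keys are therefore laid out with
// one contiguous block of nKeys ids per column.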
} // namespace exact
} // namespace tree
} // namespace xgboost

View File

@@ -1,145 +0,0 @@
/*
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../../../../src/tree/param.h"
#include "node.cuh"
namespace xgboost {
namespace tree {
namespace exact {
/**
* @brief Helper function to update the child node based on the current status
* of its parent node
* @param nodes the nodes array in which the position at 'nid' will be updated
* @param nid the nodeId in the 'nodes' array corresponding to this child node
* @param grad gradient sum for this child node
* @param minChildWeight minimum child weight for the split
* @param alpha L1 regularizer for weight updates
* @param lambda lambda as in xgboost
* @param maxStep max weight step update
*/
template <typename node_id_t>
DEV_INLINE void updateOneChildNode(Node<node_id_t>* nodes, int nid,
const bst_gpair& grad,
const TrainParam& param) {
nodes[nid].gradSum = grad;
nodes[nid].score = CalcGain(param, grad.grad, grad.hess);
nodes[nid].weight = CalcWeight(param, grad.grad, grad.hess);
nodes[nid].id = nid;
}
/**
* @brief Helper function to update the child nodes based on the current status
* of their parent node
* @param nodes the nodes array in which the position at 'nid' will be updated
* @param pid the nodeId of the parent
* @param gradL gradient sum for the left child node
* @param gradR gradient sum for the right child node
* @param param the training parameter struct
*/
template <typename node_id_t>
DEV_INLINE void updateChildNodes(Node<node_id_t>* nodes, int pid,
const bst_gpair& gradL, const bst_gpair& gradR,
const TrainParam& param) {
int childId = (pid * 2) + 1;
updateOneChildNode(nodes, childId, gradL, param);
updateOneChildNode(nodes, childId + 1, gradR, param);
}
template <typename node_id_t>
DEV_INLINE void updateNodeAndChildren(Node<node_id_t>* nodes, const Split& s,
const Node<node_id_t>& n, int absNodeId,
int colId, const bst_gpair& gradScan,
const bst_gpair& colSum, float thresh,
const TrainParam& param) {
bool missingLeft = true;
// get the default direction for the current node
bst_gpair missing = n.gradSum - colSum;
loss_chg_missing(gradScan, missing, n.gradSum, n.score, param, missingLeft);
// get the score/weight/id/gradSum for left and right child nodes
bst_gpair lGradSum, rGradSum;
if (missingLeft) {
lGradSum = gradScan + n.gradSum - colSum;
} else {
lGradSum = gradScan;
}
rGradSum = n.gradSum - lGradSum;
updateChildNodes(nodes, absNodeId, lGradSum, rGradSum, param);
// update default-dir, threshold and feature id for current node
nodes[absNodeId].dir = missingLeft ? LeftDir : RightDir;
nodes[absNodeId].colIdx = colId;
nodes[absNodeId].threshold = thresh;
}
template <typename node_id_t, int BLKDIM = 256>
__global__ void split2nodeKernel(
Node<node_id_t>* nodes, const Split* nodeSplits, const bst_gpair* gradScans,
const bst_gpair* gradSums, const float* vals, const int* colIds,
const int* colOffsets, const node_id_t* nodeAssigns, int nUniqKeys,
node_id_t nodeStart, int nCols, const TrainParam param) {
int uid = (blockIdx.x * blockDim.x) + threadIdx.x;
if (uid >= nUniqKeys) {
return;
}
int absNodeId = uid + nodeStart;
Split s = nodeSplits[uid];
if (s.isSplittable(param.min_split_loss)) {
int idx = s.index;
int nodeInstId =
abs2uniqKey(idx, nodeAssigns, colIds, nodeStart, nUniqKeys);
updateNodeAndChildren(nodes, s, nodes[absNodeId], absNodeId, colIds[idx],
gradScans[idx], gradSums[nodeInstId], vals[idx],
param);
} else {
// cannot be split further, so this node is a leaf!
nodes[absNodeId].score = -FLT_MAX;
}
}
/**
* @brief function to convert split information into node
* @param nodes the output nodes
* @param nodeSplits split information
* @param gradScans scan of sorted gradients across columns
* @param gradSums key-wise gradient reduction across columns
* @param vals the feature values
* @param colIds column indices for each element in the array
* @param colOffsets column segment offsets
* @param nodeAssigns node-id assignment to every feature value
* @param nUniqKeys number of nodes that we are currently working on
* @param nodeStart start offset of the nodes in the overall BFS tree
* @param nCols number of columns
* @param preUniquifiedKeys whether to uniquify the keys from inside kernel or
* not
* @param param the training parameter struct
*/
template <typename node_id_t, int BLKDIM = 256>
void split2node(Node<node_id_t>* nodes, const Split* nodeSplits,
const bst_gpair* gradScans, const bst_gpair* gradSums,
const float* vals, const int* colIds, const int* colOffsets,
const node_id_t* nodeAssigns, int nUniqKeys,
node_id_t nodeStart, int nCols, const TrainParam param) {
int nBlks = dh::div_round_up(nUniqKeys, BLKDIM);
split2nodeKernel<<<nBlks, BLKDIM>>>(nodes, nodeSplits, gradScans, gradSums,
vals, colIds, colOffsets, nodeAssigns,
nUniqKeys, nodeStart, nCols, param);
}
} // namespace exact
} // namespace tree
} // namespace xgboost

View File

@@ -1,147 +0,0 @@
/*!
* Copyright 2016 Rory mitchell
*/
#pragma once
#include <thrust/sequence.h>
#include <xgboost/logging.h>
#include <cub/cub.cuh>
#include <vector>
#include "../../src/tree/param.h"
#include "common.cuh"
#include "device_helpers.cuh"
#include "types.cuh"
namespace xgboost {
namespace tree {
struct GPUData {
GPUData() : allocated(false), n_features(0), n_instances(0) {}
bool allocated;
int n_features;
int n_instances;
dh::bulk_allocator<dh::memory_type::DEVICE> ba;
// dh::bulk_allocator<int> ba;
GPUTrainingParam param;
dh::dvec<float> fvalues;
dh::dvec<float> fvalues_temp;
dh::dvec<float> fvalues_cached;
dh::dvec<int> foffsets;
dh::dvec<bst_uint> instance_id;
dh::dvec<bst_uint> instance_id_temp;
dh::dvec<bst_uint> instance_id_cached;
dh::dvec<int> feature_id;
dh::dvec<NodeIdT> node_id;
dh::dvec<NodeIdT> node_id_temp;
dh::dvec<NodeIdT> node_id_instance;
dh::dvec<gpu_gpair> gpair;
dh::dvec<Node> nodes;
dh::dvec<Split> split_candidates;
dh::dvec<gpu_gpair> node_sums;
dh::dvec<int> node_offsets;
dh::dvec<int> sort_index_in;
dh::dvec<int> sort_index_out;
dh::dvec<char> cub_mem;
dh::dvec<int> feature_flags;
dh::dvec<int> feature_set;
ItemIter items_iter;
void Init(const std::vector<float> &in_fvalues,
const std::vector<int> &in_foffsets,
const std::vector<bst_uint> &in_instance_id,
const std::vector<int> &in_feature_id,
const std::vector<bst_gpair> &in_gpair, bst_uint n_instances_in,
bst_uint n_features_in, int max_depth, const TrainParam &param_in) {
n_features = n_features_in;
n_instances = n_instances_in;
uint32_t max_nodes = (1 << (max_depth + 1)) - 1;
uint32_t max_nodes_level = 1 << max_depth;
// Calculate memory for sort
size_t cub_mem_size = 0;
cub::DoubleBuffer<NodeIdT> db_key;
cub::DoubleBuffer<int> db_value;
cub::DeviceSegmentedRadixSort::SortPairs(
cub_mem.data(), cub_mem_size, db_key, db_value, in_fvalues.size(),
n_features, foffsets.data(), foffsets.data() + 1);
// Allocate memory
size_t free_memory =
dh::available_memory(dh::get_device_idx(param_in.gpu_id));
ba.allocate(
dh::get_device_idx(param_in.gpu_id), &fvalues, in_fvalues.size(),
&fvalues_temp, in_fvalues.size(), &fvalues_cached, in_fvalues.size(),
&foffsets, in_foffsets.size(), &instance_id, in_instance_id.size(),
&instance_id_temp, in_instance_id.size(), &instance_id_cached,
in_instance_id.size(), &feature_id, in_feature_id.size(), &node_id,
in_fvalues.size(), &node_id_temp, in_fvalues.size(), &node_id_instance,
n_instances, &gpair, n_instances, &nodes, max_nodes, &split_candidates,
max_nodes_level * n_features, &node_sums, max_nodes_level * n_features,
&node_offsets, max_nodes_level * n_features, &sort_index_in,
in_fvalues.size(), &sort_index_out, in_fvalues.size(), &cub_mem,
cub_mem_size, &feature_flags, n_features, &feature_set, n_features);
if (!param_in.silent) {
const int mb_size = 1048576;
LOG(CONSOLE) << "Allocated " << ba.size() / mb_size << "/"
<< free_memory / mb_size << " MB on "
<< dh::device_name(dh::get_device_idx(param_in.gpu_id));
}
fvalues_cached = in_fvalues;
foffsets = in_foffsets;
instance_id_cached = in_instance_id;
feature_id = in_feature_id;
param = GPUTrainingParam(param_in.min_child_weight, param_in.reg_lambda,
param_in.reg_alpha, param_in.max_delta_step);
allocated = true;
this->Reset(in_gpair, param_in.subsample);
items_iter = thrust::make_zip_iterator(thrust::make_tuple(
thrust::make_permutation_iterator(gpair.tbegin(), instance_id.tbegin()),
fvalues.tbegin(), node_id.tbegin()));
dh::safe_cuda(cudaGetLastError());
}
~GPUData() {}
// Reset memory for new boosting iteration
void Reset(const std::vector<bst_gpair> &in_gpair, float subsample) {
CHECK(allocated);
gpair = in_gpair;
subsample_gpair(&gpair, subsample);
instance_id = instance_id_cached;
fvalues = fvalues_cached;
nodes.fill(Node());
node_id_instance.fill(0);
node_id.fill(0);
}
bool IsAllocated() { return allocated; }
// Gather from node_id_instance into node_id according to instance_id
void GatherNodeId() {
// Update node_id for each item
auto d_node_id = node_id.data();
auto d_node_id_instance = node_id_instance.data();
auto d_instance_id = instance_id.data();
dh::launch_n(node_id.device_idx(), fvalues.size(),
[=] __device__(bst_uint i) {
d_node_id[i] = d_node_id_instance[d_instance_id[i]];
});
}
};
} // namespace tree
} // namespace xgboost
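
GPUData::GatherNodeId above runs as a device lambda; a host-side equivalent of the same gather, with simplified types, is roughly:

#include <cstddef>
#include <vector>

// Host analogue of GPUData::GatherNodeId: every feature value inherits the node id
// of the instance (row) it belongs to. Types are simplified stand-ins.
void gather_node_id(std::vector<int>* node_id,                 // per feature value
                    const std::vector<int>& node_id_instance,  // per row
                    const std::vector<unsigned>& instance_id) {
  for (std::size_t i = 0; i < node_id->size(); ++i) {
    (*node_id)[i] = node_id_instance[instance_id[i]];
  }
}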

File diff suppressed because it is too large

View File

@ -1,134 +0,0 @@
/*!
* Copyright 2017 XGBoost contributors
*/
#pragma once
#include <thrust/device_vector.h>
#include <xgboost/tree_updater.h>
#include <cub/util_type.cuh> // Need key value pair definition
#include <vector>
#include "../../src/common/hist_util.h"
#include "../../src/tree/param.h"
#include "../../src/common/compressed_iterator.h"
#include "device_helpers.cuh"
#include "types.cuh"
#include "nccl.h"
namespace xgboost {
namespace tree {
struct DeviceGMat {
dh::dvec<common::compressed_byte_t> gidx_buffer;
common::CompressedIterator<uint32_t> gidx;
dh::dvec<size_t> row_ptr;
void Init(int device_idx, const common::GHistIndexMatrix &gmat,
bst_ulong element_begin, bst_ulong element_end, bst_ulong row_begin, bst_ulong row_end,int n_bins);
};
struct HistBuilder {
bst_gpair_precise *d_hist;
int n_bins;
__host__ __device__ HistBuilder(bst_gpair_precise *ptr, int n_bins);
__device__ void Add(bst_gpair_precise gpair, int gidx, int nidx) const;
__device__ bst_gpair_precise Get(int gidx, int nidx) const;
};
struct DeviceHist {
int n_bins;
dh::dvec<bst_gpair_precise> data;
void Init(int max_depth);
void Reset(int device_idx);
HistBuilder GetBuilder();
bst_gpair_precise *GetLevelPtr(int depth);
int LevelSize(int depth);
};
class GPUHistBuilder {
public:
GPUHistBuilder();
~GPUHistBuilder();
void Init(const TrainParam &param);
void UpdateParam(const TrainParam &param) {
this->param = param;
}
void InitData(const std::vector<bst_gpair> &gpair, DMatrix &fmat, // NOLINT
const RegTree &tree);
void Update(const std::vector<bst_gpair> &gpair, DMatrix *p_fmat,
RegTree *p_tree);
void BuildHist(int depth);
void FindSplit(int depth);
template <int BLOCK_THREADS>
void FindSplitSpecialize(int depth);
template <int BLOCK_THREADS>
void LaunchFindSplit(int depth);
void InitFirstNode(const std::vector<bst_gpair> &gpair);
void UpdatePosition(int depth);
void UpdatePositionDense(int depth);
void UpdatePositionSparse(int depth);
void ColSampleTree();
void ColSampleLevel();
bool UpdatePredictionCache(const DMatrix *data,
std::vector<bst_float> *p_out_preds);
TrainParam param;
common::HistCutMatrix hmat_;
common::GHistIndexMatrix gmat_;
MetaInfo *info;
bool initialised;
bool is_dense;
const DMatrix *p_last_fmat_;
bool prediction_cache_initialised;
// choose which memory type to use (DEVICE or DEVICE_MANAGED)
dh::bulk_allocator<dh::memory_type::DEVICE> ba;
// dh::bulk_allocator<dh::memory_type::DEVICE_MANAGED> ba; // can't be used
// with NCCL
std::vector<int> feature_set_tree;
std::vector<int> feature_set_level;
bst_uint num_rows;
int n_devices;
// below vectors are for each devices used
std::vector<int> dList;
std::vector<int> device_row_segments;
std::vector<size_t> device_element_segments;
std::vector<dh::CubMemory> temp_memory;
std::vector<DeviceHist> hist_vec;
std::vector<dh::dvec<Node>> nodes;
std::vector<dh::dvec<Node>> nodes_temp;
std::vector<dh::dvec<Node>> nodes_child_temp;
std::vector<dh::dvec<bool>> left_child_smallest;
std::vector<dh::dvec<bool>> left_child_smallest_temp;
std::vector<dh::dvec<int>> feature_flags;
std::vector<dh::dvec<float>> fidx_min_map;
std::vector<dh::dvec<int>> feature_segments;
std::vector<dh::dvec<bst_float>> prediction_cache;
std::vector<dh::dvec<int>> position;
std::vector<dh::dvec<int>> position_tmp;
std::vector<DeviceGMat> device_matrix;
std::vector<dh::dvec<bst_gpair>> device_gpair;
std::vector<dh::dvec<int>> gidx_feature_map;
std::vector<dh::dvec<float>> gidx_fvalue_map;
std::vector<cudaStream_t *> streams;
std::vector<ncclComm_t> comms;
std::vector<std::vector<ncclComm_t>> find_split_comms;
double cpu_init_time;
double gpu_init_time;
dh::Timer cpu_time;
double gpu_time;
};
} // namespace tree
} // namespace xgboost

View File

@ -1,20 +0,0 @@
/*!
* Copyright 2017 XGBoost contributors
*/
#include <xgboost/tree_updater.h>
#include "updater_gpu.cuh"
namespace xgboost {
namespace tree {
DMLC_REGISTRY_FILE_TAG(updater_gpumaker);
XGBOOST_REGISTER_TREE_UPDATER(GPUMaker, "grow_gpu")
.describe("Grow tree with GPU.")
.set_body([]() { return new GPUMaker(); });
XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist")
.describe("Grow tree with GPU.")
.set_body([]() { return new GPUHistMaker(); });
} // namespace tree
} // namespace xgboost
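
The registered names ("grow_gpu", "grow_gpu_hist") are what callers pass through the updater training parameter. A minimal sketch using the public C API is shown below; the input file name is hypothetical and return codes are ignored for brevity.

#include <xgboost/c_api.h>

int main() {
  DMatrixHandle dtrain;
  BoosterHandle booster;
  XGDMatrixCreateFromFile("train.libsvm", 1, &dtrain);   // hypothetical input file
  XGBoosterCreate(&dtrain, 1, &booster);
  XGBoosterSetParam(booster, "updater", "grow_gpu");     // or "grow_gpu_hist"
  XGBoosterSetParam(booster, "max_depth", "6");
  for (int iter = 0; iter < 10; ++iter) {
    XGBoosterUpdateOneIter(booster, iter, dtrain);       // one boosting round
  }
  XGBoosterFree(booster);
  XGDMatrixFree(dtrain);
  return 0;
}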

View File

@ -1,122 +0,0 @@
/*!
* Copyright 2017 XGBoost contributors
*/
#pragma once
#include <thrust/device_vector.h>
#include <xgboost/base.h>
#include <xgboost/tree_model.h>
#include <cfloat>
#include <tuple> // The linter is not very smart and thinks we need this
namespace xgboost {
namespace tree {
struct GPUTrainingParam {
// minimum amount of hessian(weight) allowed in a child
float min_child_weight;
// L2 regularization factor
float reg_lambda;
// L1 regularization factor
float reg_alpha;
// maximum delta update we can add in weight estimation
// this parameter can be used to stabilize update
// default=0 means no constraint on weight delta
float max_delta_step;
__host__ __device__ GPUTrainingParam() {}
__host__ __device__ GPUTrainingParam(const TrainParam &param)
: min_child_weight(param.min_child_weight),
reg_lambda(param.reg_lambda),
reg_alpha(param.reg_alpha),
max_delta_step(param.max_delta_step) {}
__host__ __device__ GPUTrainingParam(float min_child_weight_in,
float reg_lambda_in, float reg_alpha_in,
float max_delta_step_in)
: min_child_weight(min_child_weight_in),
reg_lambda(reg_lambda_in),
reg_alpha(reg_alpha_in),
max_delta_step(max_delta_step_in) {}
};
struct Split {
float loss_chg;
bool missing_left;
float fvalue;
int findex;
bst_gpair left_sum;
bst_gpair right_sum;
__host__ __device__ Split()
: loss_chg(-FLT_MAX), missing_left(true), fvalue(0), findex(-1) {}
__device__ void Update(float loss_chg_in, bool missing_left_in,
float fvalue_in, int findex_in, bst_gpair left_sum_in,
bst_gpair right_sum_in,
const GPUTrainingParam &param) {
if (loss_chg_in > loss_chg &&
left_sum_in.hess>= param.min_child_weight &&
right_sum_in.hess>= param.min_child_weight) {
loss_chg = loss_chg_in;
missing_left = missing_left_in;
fvalue = fvalue_in;
left_sum = left_sum_in;
right_sum = right_sum_in;
findex = findex_in;
}
}
// Does not check minimum weight
__device__ void Update(Split &s) { // NOLINT
if (s.loss_chg > loss_chg) {
loss_chg = s.loss_chg;
missing_left = s.missing_left;
fvalue = s.fvalue;
findex = s.findex;
left_sum = s.left_sum;
right_sum = s.right_sum;
}
}
//__host__ __device__ void Print() {
// printf("Loss: %1.4f\n", loss_chg);
// printf("Missing left: %d\n", missing_left);
// printf("fvalue: %1.4f\n", fvalue);
// printf("Left sum: ");
// left_sum.print();
// printf("Right sum: ");
// right_sum.print();
//}
};
struct split_reduce_op {
template <typename T>
__device__ __forceinline__ T operator()(T &a, T b) { // NOLINT
b.Update(a);
return b;
}
};
struct Node {
bst_gpair sum_gradients;
float root_gain;
float weight;
Split split;
__host__ __device__ Node() : weight(0), root_gain(0) {}
__host__ __device__ Node(bst_gpair sum_gradients_in, float root_gain_in,
float weight_in) {
sum_gradients = sum_gradients_in;
root_gain = root_gain_in;
weight = weight_in;
}
__host__ __device__ bool IsLeaf() { return split.loss_chg == -FLT_MAX; }
};
} // namespace tree
} // namespace xgboost
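
split_reduce_op above exists so cub reductions can pick the best Split by loss_chg. A host-side analogue of the same associative combine, using a simplified stand-in struct, would look like:

#include <numeric>
#include <vector>

// Minimal host stand-in for Split: only the field the reduction compares.
struct HostSplit { float loss_chg = -1e38f; int findex = -1; };

// Same combine rule as Split::Update(Split&): keep whichever candidate has the
// larger loss_chg (ties keep the accumulated one, mirroring the strict '>').
HostSplit best_split(const std::vector<HostSplit>& candidates) {
  return std::accumulate(candidates.begin(), candidates.end(), HostSplit{},
                         [](HostSplit acc, const HostSplit& s) {
                           return (s.loss_chg > acc.loss_chg) ? s : acc;
                         });
}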

View File

@ -1,77 +1,754 @@
/*!
 * Copyright 2017 XGBoost contributors
 */
#include "updater_gpu.cuh"
#include <xgboost/tree_updater.h>
#include <vector>
#include <utility>
#include <string>
#include "../../../src/common/random.h"
#include "../../../src/common/sync.h"
#include "../../../src/tree/param.h"
#include "exact/gpu_builder.cuh"
#include "gpu_hist_builder.cuh"

namespace xgboost {
namespace tree {

GPUMaker::GPUMaker() : builder(new exact::GPUBuilder<int16_t>()) {}

void GPUMaker::Init(
    const std::vector<std::pair<std::string, std::string>>& args) {
  param.InitAllowUnknown(args);
  builder->Init(param);
}

void GPUMaker::Update(const std::vector<bst_gpair>& gpair, DMatrix* dmat,
                      const std::vector<RegTree*>& trees) {
  GradStats::CheckInfo(dmat->info());
  // rescale learning rate according to size of trees
  float lr = param.learning_rate;
  param.learning_rate = lr / trees.size();
  builder->UpdateParam(param);
  try {
    // build tree
    for (size_t i = 0; i < trees.size(); ++i) {
      builder->Update(gpair, dmat, trees[i]);
    }
  } catch (const std::exception& e) {
    LOG(FATAL) << "GPU plugin exception: " << e.what() << std::endl;
  }
  param.learning_rate = lr;
}

GPUHistMaker::GPUHistMaker() : builder(new GPUHistBuilder()) {}

void GPUHistMaker::Init(
    const std::vector<std::pair<std::string, std::string>>& args) {
  param.InitAllowUnknown(args);
  builder->Init(param);
}

void GPUHistMaker::Update(const std::vector<bst_gpair>& gpair, DMatrix* dmat,
                          const std::vector<RegTree*>& trees) {
  GradStats::CheckInfo(dmat->info());
  // rescale learning rate according to size of trees
  float lr = param.learning_rate;
  param.learning_rate = lr / trees.size();
  builder->UpdateParam(param);
  // build tree
  try {
    for (size_t i = 0; i < trees.size(); ++i) {
      builder->Update(gpair, dmat, trees[i]);
    }
  } catch (const std::exception& e) {
    LOG(FATAL) << "GPU plugin exception: " << e.what() << std::endl;
  }
  param.learning_rate = lr;
}

bool GPUHistMaker::UpdatePredictionCache(const DMatrix* data,
                                         std::vector<bst_float>* out_preds) {
  return builder->UpdatePredictionCache(data, out_preds);
}
} // namespace tree
} // namespace xgboost

/*!
 * Copyright 2017 XGBoost contributors
 */
#include <xgboost/tree_updater.h>
#include <utility>
#include <vector>
#include "../../../src/tree/param.h"
#include "updater_gpu_common.cuh"

namespace xgboost {
namespace tree {

DMLC_REGISTRY_FILE_TAG(updater_gpu);

/**
 * @brief Absolute BFS order IDs to col-wise unique IDs based on user input
 * @param tid the index of the element that this thread should access
 * @param abs the array of absolute IDs
 * @param colIds the array of column IDs for each element
 * @param nodeStart the start of the node ID at this level
 * @param nKeys number of nodes at this level.
 * @return the uniq key
 */
static HOST_DEV_INLINE node_id_t abs2uniqKey(int tid, const node_id_t* abs,
                                             const int* colIds,
                                             node_id_t nodeStart, int nKeys) {
  int a = abs[tid];
  if (a == UNUSED_NODE) return a;
  return ((a - nodeStart) + (colIds[tid] * nKeys));
}

/**
 * @struct Pair
 * @brief Pair used for key based scan operations on bst_gpair
 */
struct Pair {
int key;
bst_gpair value;
};
/** define a key that's not used at all in the entire boosting process */
static const int NONE_KEY = -100;
/**
 * @brief Size (in blocks) of the temporary buffers (tmpScans/tmpKeys) needed
 *   for the scan operations
 * @param size number of elements that will be scanned
 */
template <int BLKDIM_L1L3 = 256>
int scanTempBufferSize(int size) {
int nBlks = dh::div_round_up(size, BLKDIM_L1L3);
return nBlks;
}
struct AddByKey {
template <typename T>
HOST_DEV_INLINE T operator()(const T& first, const T& second) const {
T result;
if (first.key == second.key) {
result.key = first.key;
result.value = first.value + second.value;
} else {
result.key = second.key;
result.value = second.value;
    }
    return result;
  }
};
/**
* @brief Gradient value getter function
* @param id the index into the vals or instIds array to which to fetch
* @param vals the gradient value buffer
* @param instIds instance index buffer
* @return the expected gradient value
*/
HOST_DEV_INLINE bst_gpair get(int id, const bst_gpair* vals,
const int* instIds) {
id = instIds[id];
return vals[id];
}

template <int BLKDIM_L1L3>
__global__ void cubScanByKeyL1(bst_gpair* scans, const bst_gpair* vals,
                               const int* instIds, bst_gpair* mScans,
                               int* mKeys, const node_id_t* keys, int nUniqKeys,
                               const int* colIds, node_id_t nodeStart,
                               const int size) {
  Pair rootPair = {NONE_KEY, bst_gpair(0.f, 0.f)};
int myKey;
bst_gpair myValue;
typedef cub::BlockScan<Pair, BLKDIM_L1L3> BlockScan;
__shared__ typename BlockScan::TempStorage temp_storage;
Pair threadData;
int tid = blockIdx.x * BLKDIM_L1L3 + threadIdx.x;
if (tid < size) {
myKey = abs2uniqKey(tid, keys, colIds, nodeStart, nUniqKeys);
myValue = get(tid, vals, instIds);
} else {
myKey = NONE_KEY;
myValue = 0.f;
}
threadData.key = myKey;
threadData.value = myValue;
// get previous key, especially needed for the last thread in this block
// in order to pass on the partial scan values.
// this statement MUST appear before the checks below!
// else, the result of this shuffle operation will be undefined
int previousKey = __shfl_up(myKey, 1);
// Collectively compute the block-wide exclusive prefix sum
BlockScan(temp_storage)
.ExclusiveScan(threadData, threadData, rootPair, AddByKey());
if (tid < size) {
scans[tid] = threadData.value;
} else {
return;
}
if (threadIdx.x == BLKDIM_L1L3 - 1) {
threadData.value =
(myKey == previousKey) ? threadData.value : bst_gpair(0.0f, 0.0f);
mKeys[blockIdx.x] = myKey;
mScans[blockIdx.x] = threadData.value + myValue;
}
}

template <int BLKSIZE>
__global__ void cubScanByKeyL2(bst_gpair* mScans, int* mKeys, int mLength) {
  typedef cub::BlockScan<Pair, BLKSIZE, cub::BLOCK_SCAN_WARP_SCANS> BlockScan;
  Pair threadData;
  __shared__ typename BlockScan::TempStorage temp_storage;
  for (int i = threadIdx.x; i < mLength; i += BLKSIZE - 1) {
    threadData.key = mKeys[i];
    threadData.value = mScans[i];
    BlockScan(temp_storage).InclusiveScan(threadData, threadData, AddByKey());
    mScans[i] = threadData.value;
    __syncthreads();
  }
}
__global__ void cubScanByKeyL3(bst_gpair* sums, bst_gpair* scans,
const bst_gpair* vals, const int* instIds,
const bst_gpair* mScans, const int* mKeys,
const node_id_t* keys, int nUniqKeys,
const int* colIds, node_id_t nodeStart,
const int size) {
int relId = threadIdx.x;
int tid = (blockIdx.x * BLKDIM_L1L3) + relId;
// to avoid the following warning from nvcc:
// __shared__ memory variable with non-empty constructor or destructor
// (potential race between threads)
__shared__ char gradBuff[sizeof(bst_gpair)];
__shared__ int s_mKeys;
bst_gpair* s_mScans = reinterpret_cast<bst_gpair*>(gradBuff);
if (tid >= size) return;
// cache block-wide partial scan info
if (relId == 0) {
s_mKeys = (blockIdx.x > 0) ? mKeys[blockIdx.x - 1] : NONE_KEY;
s_mScans[0] = (blockIdx.x > 0) ? mScans[blockIdx.x - 1] : bst_gpair();
}
int myKey = abs2uniqKey(tid, keys, colIds, nodeStart, nUniqKeys);
int previousKey =
tid == 0 ? NONE_KEY
: abs2uniqKey(tid - 1, keys, colIds, nodeStart, nUniqKeys);
bst_gpair myValue = scans[tid];
__syncthreads();
if (blockIdx.x > 0 && s_mKeys == previousKey) {
myValue += s_mScans[0];
}
if (tid == size - 1) {
sums[previousKey] = myValue + get(tid, vals, instIds);
}
if ((previousKey != myKey) && (previousKey >= 0)) {
sums[previousKey] = myValue;
myValue = bst_gpair(0.0f, 0.0f);
}
scans[tid] = myValue;
}
/**
* @brief Performs fused reduce and scan by key functionality. It is assumed
* that
* the keys occur contiguously!
* @param sums the output gradient reductions for each element performed
* key-wise
* @param scans the output gradient scans for each element performed key-wise
* @param vals the gradients evaluated for each observation.
* @param instIds instance ids for each element
* @param keys keys to be used to segment the reductions. They need not occur
* contiguously in contrast to scan_by_key. Currently, we need one key per
* value in the 'vals' array.
* @param size number of elements in the 'vals' array
* @param nUniqKeys max number of uniq keys found per column
* @param nCols number of columns
* @param tmpScans temporary scan buffer needed for cub-pyramid algo
* @param tmpKeys temporary key buffer needed for cub-pyramid algo
* @param colIds column indices for each element in the array
* @param nodeStart index of the leftmost node in the current level
*/
template <int BLKDIM_L1L3 = 256, int BLKDIM_L2 = 512>
void reduceScanByKey(bst_gpair* sums, bst_gpair* scans, const bst_gpair* vals,
const int* instIds, const node_id_t* keys, int size,
int nUniqKeys, int nCols, bst_gpair* tmpScans,
int* tmpKeys, const int* colIds, node_id_t nodeStart) {
int nBlks = dh::div_round_up(size, BLKDIM_L1L3);
cudaMemset(sums, 0, nUniqKeys * nCols * sizeof(bst_gpair));
cubScanByKeyL1<BLKDIM_L1L3>
<<<nBlks, BLKDIM_L1L3>>>(scans, vals, instIds, tmpScans, tmpKeys, keys,
nUniqKeys, colIds, nodeStart, size);
cubScanByKeyL2<BLKDIM_L2><<<1, BLKDIM_L2>>>(tmpScans, tmpKeys, nBlks);
cubScanByKeyL3<BLKDIM_L1L3>
<<<nBlks, BLKDIM_L1L3>>>(sums, scans, vals, instIds, tmpScans, tmpKeys,
keys, nUniqKeys, colIds, nodeStart, size);
}
/**
* @struct ExactSplitCandidate
* @brief Abstraction of a possible split in the decision tree
*/
struct ExactSplitCandidate {
/** the optimal gain score for this node */
float score;
/** index where to split in the DMatrix */
int index;
HOST_DEV_INLINE ExactSplitCandidate() : score(-FLT_MAX), index(INT_MAX) {}
/**
* @brief Whether the split info is valid to be used to create a new child
* @param minSplitLoss minimum score above which decision to split is made
* @return true if splittable, else false
*/
HOST_DEV_INLINE bool isSplittable(float minSplitLoss) const {
return ((score >= minSplitLoss) && (index != INT_MAX));
}
};
/**
* @enum ArgMaxByKeyAlgo best_split_evaluation.cuh
* @brief Help decide which algorithm to use for multi-argmax operation
*/
enum ArgMaxByKeyAlgo {
/** simplest, use gmem-atomics for all updates */
ABK_GMEM = 0,
/** use smem-atomics for updates (when number of keys are less) */
ABK_SMEM
};
/** max depth until which to use shared mem based atomics for argmax */
static const int MAX_ABK_LEVELS = 3;
HOST_DEV_INLINE ExactSplitCandidate maxSplit(ExactSplitCandidate a,
ExactSplitCandidate b) {
ExactSplitCandidate out;
if (a.score < b.score) {
out.score = b.score;
out.index = b.index;
} else if (a.score == b.score) {
out.score = a.score;
out.index = (a.index < b.index) ? a.index : b.index;
} else {
out.score = a.score;
out.index = a.index;
}
return out;
}
DEV_INLINE void atomicArgMax(ExactSplitCandidate* address,
ExactSplitCandidate val) {
unsigned long long* intAddress = (unsigned long long*)address; // NOLINT
unsigned long long old = *intAddress; // NOLINT
unsigned long long assumed; // NOLINT
do {
assumed = old;
ExactSplitCandidate res =
maxSplit(val, *reinterpret_cast<ExactSplitCandidate*>(&assumed));
old = atomicCAS(intAddress, assumed, *reinterpret_cast<uint64_t*>(&res));
} while (assumed != old);
}
DEV_INLINE void argMaxWithAtomics(
int id, ExactSplitCandidate* nodeSplits, const bst_gpair* gradScans,
const bst_gpair* gradSums, const float* vals, const int* colIds,
const node_id_t* nodeAssigns, const DeviceDenseNode* nodes, int nUniqKeys,
node_id_t nodeStart, int len, const GPUTrainingParam& param) {
int nodeId = nodeAssigns[id];
// @todo: this is really a bad check! but will be fixed when we move
// to key-based reduction
if ((id == 0) ||
!((nodeId == nodeAssigns[id - 1]) && (colIds[id] == colIds[id - 1]) &&
(vals[id] == vals[id - 1]))) {
if (nodeId != UNUSED_NODE) {
int sumId = abs2uniqKey(id, nodeAssigns, colIds, nodeStart, nUniqKeys);
bst_gpair colSum = gradSums[sumId];
int uid = nodeId - nodeStart;
DeviceDenseNode n = nodes[nodeId];
bst_gpair parentSum = n.sum_gradients;
float parentGain = n.root_gain;
bool tmp;
ExactSplitCandidate s;
bst_gpair missing = parentSum - colSum;
s.score = loss_chg_missing(gradScans[id], missing, parentSum, parentGain,
param, tmp);
s.index = id;
atomicArgMax(nodeSplits + uid, s);
} // end if nodeId != UNUSED_NODE
} // end if id == 0 ...
}
__global__ void atomicArgMaxByKeyGmem(
ExactSplitCandidate* nodeSplits, const bst_gpair* gradScans,
const bst_gpair* gradSums, const float* vals, const int* colIds,
const node_id_t* nodeAssigns, const DeviceDenseNode* nodes, int nUniqKeys,
node_id_t nodeStart, int len, const TrainParam param) {
int id = threadIdx.x + (blockIdx.x * blockDim.x);
const int stride = blockDim.x * gridDim.x;
for (; id < len; id += stride) {
argMaxWithAtomics(id, nodeSplits, gradScans, gradSums, vals, colIds,
nodeAssigns, nodes, nUniqKeys, nodeStart, len,
GPUTrainingParam(param));
}
}
__global__ void atomicArgMaxByKeySmem(
ExactSplitCandidate* nodeSplits, const bst_gpair* gradScans,
const bst_gpair* gradSums, const float* vals, const int* colIds,
const node_id_t* nodeAssigns, const DeviceDenseNode* nodes, int nUniqKeys,
node_id_t nodeStart, int len, const TrainParam param) {
extern __shared__ char sArr[];
ExactSplitCandidate* sNodeSplits =
reinterpret_cast<ExactSplitCandidate*>(sArr);
int tid = threadIdx.x;
ExactSplitCandidate defVal;
#pragma unroll 1
for (int i = tid; i < nUniqKeys; i += blockDim.x) {
sNodeSplits[i] = defVal;
}
__syncthreads();
int id = tid + (blockIdx.x * blockDim.x);
const int stride = blockDim.x * gridDim.x;
for (; id < len; id += stride) {
argMaxWithAtomics(id, sNodeSplits, gradScans, gradSums, vals, colIds,
nodeAssigns, nodes, nUniqKeys, nodeStart, len, param);
}
__syncthreads();
for (int i = tid; i < nUniqKeys; i += blockDim.x) {
ExactSplitCandidate s = sNodeSplits[i];
atomicArgMax(nodeSplits + i, s);
}
}
/**
* @brief Performs argmax_by_key functionality but for cases when keys need not
* occur contiguously
* @param nodeSplits will contain information on best split for each node
* @param gradScans exclusive sum on sorted segments for each col
 * @param gradSums gradient sum for each column in DMatrix based on node-ids
* @param vals feature values
* @param colIds column index for each element in the feature values array
* @param nodeAssigns node-id assignments to each element in DMatrix
* @param nodes pointer to all nodes for this tree in BFS order
* @param nUniqKeys number of unique node-ids in this level
* @param nodeStart start index of the node-ids in this level
* @param len number of elements
* @param param training parameters
* @param algo which algorithm to use for argmax_by_key
*/
template <int BLKDIM = 256, int ITEMS_PER_THREAD = 4>
void argMaxByKey(ExactSplitCandidate* nodeSplits, const bst_gpair* gradScans,
const bst_gpair* gradSums, const float* vals,
const int* colIds, const node_id_t* nodeAssigns,
const DeviceDenseNode* nodes, int nUniqKeys,
node_id_t nodeStart, int len, const TrainParam param,
ArgMaxByKeyAlgo algo) {
dh::fillConst<ExactSplitCandidate, BLKDIM, ITEMS_PER_THREAD>(
dh::get_device_idx(param.gpu_id), nodeSplits, nUniqKeys,
ExactSplitCandidate());
int nBlks = dh::div_round_up(len, ITEMS_PER_THREAD * BLKDIM);
switch (algo) {
case ABK_GMEM:
atomicArgMaxByKeyGmem<<<nBlks, BLKDIM>>>(
nodeSplits, gradScans, gradSums, vals, colIds, nodeAssigns, nodes,
nUniqKeys, nodeStart, len, param);
break;
case ABK_SMEM:
atomicArgMaxByKeySmem<<<nBlks, BLKDIM,
sizeof(ExactSplitCandidate) * nUniqKeys>>>(
nodeSplits, gradScans, gradSums, vals, colIds, nodeAssigns, nodes,
nUniqKeys, nodeStart, len, param);
break;
default:
throw std::runtime_error("argMaxByKey: Bad algo passed!");
}
}
__global__ void assignColIds(int* colIds, const int* colOffsets) {
int myId = blockIdx.x;
int start = colOffsets[myId];
int end = colOffsets[myId + 1];
for (int id = start + threadIdx.x; id < end; id += blockDim.x) {
colIds[id] = myId;
}
}
__global__ void fillDefaultNodeIds(node_id_t* nodeIdsPerInst,
const DeviceDenseNode* nodes, int nRows) {
int id = threadIdx.x + (blockIdx.x * blockDim.x);
if (id >= nRows) {
return;
}
// if this element belongs to none of the currently active node-id's
node_id_t nId = nodeIdsPerInst[id];
if (nId == UNUSED_NODE) {
return;
}
const DeviceDenseNode n = nodes[nId];
node_id_t result;
if (n.IsLeaf() || n.IsUnused()) {
result = UNUSED_NODE;
} else if (n.dir == LeftDir) {
result = (2 * n.idx) + 1;
} else {
result = (2 * n.idx) + 2;
}
nodeIdsPerInst[id] = result;
}
__global__ void assignNodeIds(node_id_t* nodeIdsPerInst, int* nodeLocations,
const node_id_t* nodeIds, const int* instId,
const DeviceDenseNode* nodes,
const int* colOffsets, const float* vals,
int nVals, int nCols) {
int id = threadIdx.x + (blockIdx.x * blockDim.x);
const int stride = blockDim.x * gridDim.x;
for (; id < nVals; id += stride) {
// fusing generation of indices for node locations
nodeLocations[id] = id;
// using nodeIds here since the previous kernel would have updated
// the nodeIdsPerInst with all default assignments
int nId = nodeIds[id];
// if this element belongs to none of the currently active node-id's
if (nId != UNUSED_NODE) {
const DeviceDenseNode n = nodes[nId];
int colId = n.fidx;
// printf("nid=%d colId=%d id=%d\n", nId, colId, id);
int start = colOffsets[colId];
int end = colOffsets[colId + 1];
// @todo: too much wasteful threads!!
if ((id >= start) && (id < end) && !(n.IsLeaf() || n.IsUnused())) {
node_id_t result = (2 * n.idx) + 1 + (vals[id] >= n.fvalue);
nodeIdsPerInst[instId[id]] = result;
}
    }
  }
}

__global__ void markLeavesKernel(DeviceDenseNode* nodes, int len) {
  int id = (blockIdx.x * blockDim.x) + threadIdx.x;
  if ((id < len) && !nodes[id].IsUnused()) {
int lid = (id << 1) + 1;
int rid = (id << 1) + 2;
if ((lid >= len) || (rid >= len)) {
nodes[id].root_gain = -FLT_MAX; // bottom-most nodes
} else if (nodes[lid].IsUnused() && nodes[rid].IsUnused()) {
nodes[id].root_gain = -FLT_MAX; // unused child nodes
}
}
}
class GPUMaker : public TreeUpdater {
protected:
TrainParam param;
/** whether we have initialized memory already (so as not to repeat!) */
bool allocated;
/** feature values stored in column-major compressed format */
dh::dvec2<float> vals;
dh::dvec<float> vals_cached;
  /** corresponding instance id's of these feature values */
dh::dvec2<int> instIds;
dh::dvec<int> instIds_cached;
/** column offsets for these feature values */
dh::dvec<int> colOffsets;
dh::dvec<bst_gpair> gradsInst;
dh::dvec2<node_id_t> nodeAssigns;
dh::dvec2<int> nodeLocations;
dh::dvec<DeviceDenseNode> nodes;
dh::dvec<node_id_t> nodeAssignsPerInst;
dh::dvec<bst_gpair> gradSums;
dh::dvec<bst_gpair> gradScans;
dh::dvec<ExactSplitCandidate> nodeSplits;
int nVals;
int nRows;
int nCols;
int maxNodes;
int maxLeaves;
dh::CubMemory tmp_mem;
dh::dvec<bst_gpair> tmpScanGradBuff;
dh::dvec<int> tmpScanKeyBuff;
dh::dvec<int> colIds;
dh::bulk_allocator<dh::memory_type::DEVICE> ba;
public:
GPUMaker() : allocated(false) {}
~GPUMaker() {}
void Init(
const std::vector<std::pair<std::string, std::string>>& args) override {
param.InitAllowUnknown(args);
maxNodes = (1 << (param.max_depth + 1)) - 1;
maxLeaves = 1 << param.max_depth;
}
void Update(const std::vector<bst_gpair>& gpair, DMatrix* dmat,
const std::vector<RegTree*>& trees) override {
GradStats::CheckInfo(dmat->info());
// rescale learning rate according to size of trees
float lr = param.learning_rate;
param.learning_rate = lr / trees.size();
try {
// build tree
for (size_t i = 0; i < trees.size(); ++i) {
UpdateTree(gpair, dmat, trees[i]);
}
} catch (const std::exception& e) {
LOG(FATAL) << "GPU plugin exception: " << e.what() << std::endl;
}
param.learning_rate = lr;
}
/// @note: Update should be only after Init!!
void UpdateTree(const std::vector<bst_gpair>& gpair, DMatrix* dmat,
RegTree* hTree) {
if (!allocated) {
setupOneTimeData(dmat);
}
for (int i = 0; i < param.max_depth; ++i) {
if (i == 0) {
// make sure to start on a fresh tree with sorted values!
vals.current_dvec() = vals_cached;
instIds.current_dvec() = instIds_cached;
transferGrads(gpair);
}
int nNodes = 1 << i;
node_id_t nodeStart = nNodes - 1;
initNodeData(i, nodeStart, nNodes);
findSplit(i, nodeStart, nNodes);
}
// mark all the used nodes with unused children as leaf nodes
markLeaves();
dense2sparse_tree(hTree, nodes, param);
}
void split2node(int nNodes, node_id_t nodeStart) {
auto d_nodes = nodes.data();
auto d_gradScans = gradScans.data();
auto d_gradSums = gradSums.data();
auto d_nodeAssigns = nodeAssigns.current();
auto d_colIds = colIds.data();
auto d_vals = vals.current();
auto d_nodeSplits = nodeSplits.data();
int nUniqKeys = nNodes;
float min_split_loss = param.min_split_loss;
auto gpu_param = GPUTrainingParam(param);
dh::launch_n(param.gpu_id, nNodes, [=] __device__(int uid) {
int absNodeId = uid + nodeStart;
ExactSplitCandidate s = d_nodeSplits[uid];
if (s.isSplittable(min_split_loss)) {
int idx = s.index;
int nodeInstId =
abs2uniqKey(idx, d_nodeAssigns, d_colIds, nodeStart, nUniqKeys);
bool missingLeft = true;
const DeviceDenseNode& n = d_nodes[absNodeId];
bst_gpair gradScan = d_gradScans[idx];
bst_gpair gradSum = d_gradSums[nodeInstId];
float thresh = d_vals[idx];
int colId = d_colIds[idx];
// get the default direction for the current node
bst_gpair missing = n.sum_gradients - gradSum;
loss_chg_missing(gradScan, missing, n.sum_gradients, n.root_gain,
gpu_param, missingLeft);
// get the score/weight/id/gradSum for left and right child nodes
bst_gpair lGradSum = missingLeft ? gradScan + missing : gradScan;
bst_gpair rGradSum = n.sum_gradients - lGradSum;
// Create children
d_nodes[left_child_nidx(absNodeId)] =
DeviceDenseNode(lGradSum, left_child_nidx(absNodeId), gpu_param);
d_nodes[right_child_nidx(absNodeId)] =
DeviceDenseNode(rGradSum, right_child_nidx(absNodeId), gpu_param);
// Set split for parent
d_nodes[absNodeId].SetSplit(thresh, colId,
missingLeft ? LeftDir : RightDir);
} else {
// cannot be split further, so this node is a leaf!
d_nodes[absNodeId].root_gain = -FLT_MAX;
}
});
}
void findSplit(int level, node_id_t nodeStart, int nNodes) {
reduceScanByKey(gradSums.data(), gradScans.data(), gradsInst.data(),
instIds.current(), nodeAssigns.current(), nVals, nNodes,
nCols, tmpScanGradBuff.data(), tmpScanKeyBuff.data(),
colIds.data(), nodeStart);
argMaxByKey(nodeSplits.data(), gradScans.data(), gradSums.data(),
vals.current(), colIds.data(), nodeAssigns.current(),
nodes.data(), nNodes, nodeStart, nVals, param,
level <= MAX_ABK_LEVELS ? ABK_SMEM : ABK_GMEM);
split2node(nNodes, nodeStart);
}
void allocateAllData(int offsetSize) {
int tmpBuffSize = scanTempBufferSize(nVals);
ba.allocate(dh::get_device_idx(param.gpu_id), param.silent, &vals, nVals,
&vals_cached, nVals, &instIds, nVals, &instIds_cached, nVals,
&colOffsets, offsetSize, &gradsInst, nRows, &nodeAssigns, nVals,
&nodeLocations, nVals, &nodes, maxNodes, &nodeAssignsPerInst,
nRows, &gradSums, maxLeaves * nCols, &gradScans, nVals,
&nodeSplits, maxLeaves, &tmpScanGradBuff, tmpBuffSize,
&tmpScanKeyBuff, tmpBuffSize, &colIds, nVals);
}
void setupOneTimeData(DMatrix* dmat) {
size_t free_memory = dh::available_memory(dh::get_device_idx(param.gpu_id));
if (!dmat->SingleColBlock()) {
throw std::runtime_error("exact::GPUBuilder - must have 1 column block");
}
std::vector<float> fval;
std::vector<int> fId, offset;
convertToCsc(dmat, &fval, &fId, &offset);
allocateAllData(static_cast<int>(offset.size()));
transferAndSortData(fval, fId, offset);
allocated = true;
}
void convertToCsc(DMatrix* dmat, std::vector<float>* fval,
std::vector<int>* fId, std::vector<int>* offset) {
MetaInfo info = dmat->info();
nRows = info.num_row;
nCols = info.num_col;
offset->reserve(nCols + 1);
offset->push_back(0);
fval->reserve(nCols * nRows);
fId->reserve(nCols * nRows);
// in case you end up with a DMatrix having no column access
// then make sure to enable that before copying the data!
if (!dmat->HaveColAccess()) {
const std::vector<bool> enable(nCols, true);
dmat->InitColAccess(enable, 1, nRows);
}
dmlc::DataIter<ColBatch>* iter = dmat->ColIterator();
iter->BeforeFirst();
while (iter->Next()) {
const ColBatch& batch = iter->Value();
for (int i = 0; i < batch.size; i++) {
const ColBatch::Inst& col = batch[i];
for (const ColBatch::Entry* it = col.data; it != col.data + col.length;
it++) {
int inst_id = static_cast<int>(it->index);
fval->push_back(it->fvalue);
fId->push_back(inst_id);
}
offset->push_back(fval->size());
}
}
nVals = fval->size();
}
void transferAndSortData(const std::vector<float>& fval,
const std::vector<int>& fId,
const std::vector<int>& offset) {
vals.current_dvec() = fval;
instIds.current_dvec() = fId;
colOffsets = offset;
dh::segmentedSort<float, int>(&tmp_mem, &vals, &instIds, nVals, nCols,
colOffsets);
vals_cached = vals.current_dvec();
instIds_cached = instIds.current_dvec();
assignColIds<<<nCols, 512>>>(colIds.data(), colOffsets.data());
}
void transferGrads(const std::vector<bst_gpair>& gpair) {
// HACK
dh::safe_cuda(cudaMemcpy(gradsInst.data(), &(gpair[0]),
sizeof(bst_gpair) * nRows,
cudaMemcpyHostToDevice));
// evaluate the full-grad reduction for the root node
dh::sumReduction<bst_gpair>(tmp_mem, gradsInst, gradSums, nRows);
}
void initNodeData(int level, node_id_t nodeStart, int nNodes) {
// all instances belong to root node at the beginning!
if (level == 0) {
nodes.fill(DeviceDenseNode());
nodeAssigns.current_dvec().fill(0);
nodeAssignsPerInst.fill(0);
// for root node, just update the gradient/score/weight/id info
// before splitting it! Currently all data is on GPU, hence this
// stupid little kernel
auto d_nodes = nodes.data();
auto d_sums = gradSums.data();
auto gpu_params = GPUTrainingParam(param);
dh::launch_n(param.gpu_id, 1, [=] __device__(int idx) {
d_nodes[0] = DeviceDenseNode(d_sums[0], 0, gpu_params);
});
} else {
const int BlkDim = 256;
const int ItemsPerThread = 4;
// assign default node ids first
int nBlks = dh::div_round_up(nRows, BlkDim);
fillDefaultNodeIds<<<nBlks, BlkDim>>>(nodeAssignsPerInst.data(),
nodes.data(), nRows);
// evaluate the correct child indices of non-missing values next
nBlks = dh::div_round_up(nVals, BlkDim * ItemsPerThread);
assignNodeIds<<<nBlks, BlkDim>>>(
nodeAssignsPerInst.data(), nodeLocations.current(),
nodeAssigns.current(), instIds.current(), nodes.data(),
colOffsets.data(), vals.current(), nVals, nCols);
// gather the node assignments across all other columns too
dh::gather(dh::get_device_idx(param.gpu_id), nodeAssigns.current(),
nodeAssignsPerInst.data(), instIds.current(), nVals);
sortKeys(level);
}
}
void sortKeys(int level) {
// segmented-sort the arrays based on node-id's
// but we don't need more than level+1 bits for sorting!
segmentedSort(&tmp_mem, &nodeAssigns, &nodeLocations, nVals, nCols,
colOffsets, 0, level + 1);
dh::gather<float, int>(dh::get_device_idx(param.gpu_id), vals.other(),
vals.current(), instIds.other(), instIds.current(),
nodeLocations.current(), nVals);
vals.buff().selector ^= 1;
instIds.buff().selector ^= 1;
}
void markLeaves() {
const int BlkDim = 128;
int nBlks = dh::div_round_up(maxNodes, BlkDim);
markLeavesKernel<<<nBlks, BlkDim>>>(nodes.data(), maxNodes);
}
};
XGBOOST_REGISTER_TREE_UPDATER(GPUMaker, "grow_gpu")
.describe("Grow tree with GPU.")
.set_body([]() { return new GPUMaker(); });
} // namespace tree
} // namespace xgboost
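
The three cubScanByKey kernels together implement the fused reduce/scan described in the reduceScanByKey comment. As a host-side reference for the intended semantics only (not the GPU implementation): each element's key is the (node, column) id from abs2uniqKey, sums[key] accumulates all gradients with that key, and scans[i] is the exclusive running sum of the preceding same-key elements. GPair below is a simplified stand-in for bst_gpair.

#include <cstddef>
#include <vector>

struct GPair { float grad = 0.f, hess = 0.f; };
inline GPair operator+(GPair a, GPair b) { return {a.grad + b.grad, a.hess + b.hess}; }

// Reference model of reduceScanByKey: per-element exclusive scan and per-key
// reduction of gradients, keyed by (node, column) exactly as in abs2uniqKey.
// sums must be pre-sized to nUniqKeys * nCols, scans to vals-per-element count.
void reduce_scan_by_key_ref(std::vector<GPair>* sums, std::vector<GPair>* scans,
                            const std::vector<GPair>& vals,       // per-instance gradients
                            const std::vector<int>& instIds,
                            const std::vector<int>& nodeAssigns,
                            const std::vector<int>& colIds,
                            int nodeStart, int nUniqKeys) {
  std::vector<GPair> running(sums->size());
  for (std::size_t i = 0; i < scans->size(); ++i) {
    int node = nodeAssigns[i];
    if (node < 0) continue;                                  // UNUSED_NODE
    int key = (node - nodeStart) + colIds[i] * nUniqKeys;    // abs2uniqKey
    (*scans)[i] = running[key];                              // exclusive scan within key
    running[key] = running[key] + vals[instIds[i]];
  }
  *sums = running;                                           // per-key totals
}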

View File

@ -1,48 +0,0 @@
/*!
* Copyright 2017 XGBoost contributors
*/
#pragma once
#include <xgboost/tree_updater.h>
#include <memory>
#include "../../../src/tree/param.h"
namespace xgboost {
namespace tree {
// Forward declare builder classes
class GPUHistBuilder;
namespace exact {
template <typename node_id_t>
class GPUBuilder;
}
class GPUMaker : public TreeUpdater {
protected:
TrainParam param;
std::unique_ptr<exact::GPUBuilder<int16_t>> builder;
public:
GPUMaker();
void Init(
const std::vector<std::pair<std::string, std::string>>& args) override;
void Update(const std::vector<bst_gpair>& gpair, DMatrix* dmat,
const std::vector<RegTree*>& trees);
};
class GPUHistMaker : public TreeUpdater {
public:
GPUHistMaker();
void Init(
const std::vector<std::pair<std::string, std::string>>& args) override;
void Update(const std::vector<bst_gpair>& gpair, DMatrix* dmat,
const std::vector<RegTree*>& trees) override;
bool UpdatePredictionCache(const DMatrix* data,
std::vector<bst_float>* out_preds) override;
protected:
TrainParam param;
std::unique_ptr<GPUHistBuilder> builder;
};
} // namespace tree
} // namespace xgboost

View File

@ -0,0 +1,243 @@
/*!
* Copyright 2017 XGBoost contributors
*/
#pragma once
#include <thrust/random.h>
#include <cstdio>
#include <stdexcept>
#include <string>
#include <vector>
#include "../../../src/common/random.h"
#include "../../../src/tree/param.h"
#include "cub/cub.cuh"
#include "device_helpers.cuh"
namespace xgboost {
namespace tree {
struct GPUTrainingParam {
// minimum amount of hessian(weight) allowed in a child
float min_child_weight;
// L2 regularization factor
float reg_lambda;
// L1 regularization factor
float reg_alpha;
// maximum delta update we can add in weight estimation
// this parameter can be used to stabilize update
// default=0 means no constraint on weight delta
float max_delta_step;
__host__ __device__ GPUTrainingParam() {}
__host__ __device__ GPUTrainingParam(const TrainParam& param)
: min_child_weight(param.min_child_weight),
reg_lambda(param.reg_lambda),
reg_alpha(param.reg_alpha),
max_delta_step(param.max_delta_step) {}
};
typedef int node_id_t;
/** used to assign default id to a Node */
static const int UNUSED_NODE = -1;
/**
* @enum DefaultDirection node.cuh
* @brief Default direction to be followed in case of missing values
*/
enum DefaultDirection {
/** move to left child */
LeftDir = 0,
/** move to right child */
RightDir
};
struct DeviceDenseNode {
bst_gpair sum_gradients;
float root_gain;
float weight;
/** default direction for missing values */
DefaultDirection dir;
/** threshold value for comparison */
float fvalue;
/** \brief The feature index. */
int fidx;
/** node id (used as key for reduce/scan) */
node_id_t idx;
HOST_DEV_INLINE DeviceDenseNode()
: sum_gradients(),
root_gain(-FLT_MAX),
weight(-FLT_MAX),
dir(LeftDir),
fvalue(0.f),
fidx(UNUSED_NODE),
idx(UNUSED_NODE) {}
HOST_DEV_INLINE DeviceDenseNode(bst_gpair sum_gradients, node_id_t nidx,
const GPUTrainingParam& param)
: sum_gradients(sum_gradients),
dir(LeftDir),
fvalue(0.f),
fidx(UNUSED_NODE),
idx(nidx) {
this->root_gain = CalcGain(param, sum_gradients.grad, sum_gradients.hess);
this->weight = CalcWeight(param, sum_gradients.grad, sum_gradients.hess);
}
HOST_DEV_INLINE void SetSplit(float fvalue, int fidx, DefaultDirection dir) {
this->fvalue = fvalue;
this->fidx = fidx;
this->dir = dir;
}
/** Tells whether this node is part of the decision tree */
HOST_DEV_INLINE bool IsUnused() const { return (idx == UNUSED_NODE); }
/** Tells whether this node is a leaf of the decision tree */
HOST_DEV_INLINE bool IsLeaf() const {
return (!IsUnused() && (fidx == UNUSED_NODE));
}
};
template <typename gpair_t>
__device__ inline float device_calc_loss_chg(
const GPUTrainingParam& param, const gpair_t& scan, const gpair_t& missing,
const gpair_t& parent_sum, const float& parent_gain, bool missing_left) {
gpair_t left = scan;
if (missing_left) {
left += missing;
}
gpair_t right = parent_sum - left;
float left_gain = CalcGain(param, left.grad, left.hess);
float right_gain = CalcGain(param, right.grad, right.hess);
return left_gain + right_gain - parent_gain;
}
template <typename gpair_t>
__device__ float inline loss_chg_missing(const gpair_t& scan,
const gpair_t& missing,
const gpair_t& parent_sum,
const float& parent_gain,
const GPUTrainingParam& param,
bool& missing_left_out) { // NOLINT
float missing_left_loss =
device_calc_loss_chg(param, scan, missing, parent_sum, parent_gain, true);
float missing_right_loss = device_calc_loss_chg(
param, scan, missing, parent_sum, parent_gain, false);
if (missing_left_loss >= missing_right_loss) {
missing_left_out = true;
return missing_left_loss;
} else {
missing_left_out = false;
return missing_right_loss;
}
}
// Total number of nodes in tree, given depth
__host__ __device__ inline int n_nodes(int depth) {
return (1 << (depth + 1)) - 1;
}
// Number of nodes at this level of the tree
__host__ __device__ inline int n_nodes_level(int depth) { return 1 << depth; }
// Whether a node is currently being processed at current depth
__host__ __device__ inline bool is_active(int nidx, int depth) {
return nidx >= n_nodes(depth - 1);
}
__host__ __device__ inline int parent_nidx(int nidx) { return (nidx - 1) / 2; }
__host__ __device__ inline int left_child_nidx(int nidx) {
return nidx * 2 + 1;
}
__host__ __device__ inline int right_child_nidx(int nidx) {
return nidx * 2 + 2;
}
__host__ __device__ inline bool is_left_child(int nidx) {
return nidx % 2 == 1;
}
// Copy gpu dense representation of tree to xgboost sparse representation
inline void dense2sparse_tree(RegTree* p_tree,
const dh::dvec<DeviceDenseNode>& nodes,
const TrainParam& param) {
RegTree& tree = *p_tree;
std::vector<DeviceDenseNode> h_nodes = nodes.as_vector();
int nid = 0;
for (int gpu_nid = 0; gpu_nid < h_nodes.size(); gpu_nid++) {
const DeviceDenseNode& n = h_nodes[gpu_nid];
if (!n.IsUnused() && !n.IsLeaf()) {
tree.AddChilds(nid);
tree[nid].set_split(n.fidx, n.fvalue, n.dir == LeftDir);
tree.stat(nid).loss_chg = n.root_gain;
tree.stat(nid).base_weight = n.weight;
tree.stat(nid).sum_hess = n.sum_gradients.hess;
tree[tree[nid].cleft()].set_leaf(0);
tree[tree[nid].cright()].set_leaf(0);
nid++;
} else if (n.IsLeaf()) {
tree[nid].set_leaf(n.weight * param.learning_rate);
tree.stat(nid).sum_hess = n.sum_gradients.hess;
nid++;
}
}
}
/*
* Random
*/
struct BernoulliRng {
float p;
int seed;
__host__ __device__ BernoulliRng(float p, int seed) : p(p), seed(seed) {}
__host__ __device__ bool operator()(const int i) const {
thrust::default_random_engine rng(seed);
thrust::uniform_real_distribution<float> dist;
rng.discard(i);
return dist(rng) <= p;
}
};
// Set gradient pair to 0 with p = 1 - subsample
inline void subsample_gpair(dh::dvec<bst_gpair>* p_gpair, float subsample,
int offset = 0) {
if (subsample == 1.0) {
return;
}
dh::dvec<bst_gpair>& gpair = *p_gpair;
auto d_gpair = gpair.data();
BernoulliRng rng(subsample, common::GlobalRandom()());
dh::launch_n(gpair.device_idx(), gpair.size(), [=] __device__(int i) {
if (!rng(i + offset)) {
d_gpair[i] = bst_gpair();
}
});
}
inline std::vector<int> col_sample(std::vector<int> features, float colsample) {
CHECK_GT(features.size(), 0);
int n = std::max(1, static_cast<int>(colsample * features.size()));
std::shuffle(features.begin(), features.end(), common::GlobalRandom());
features.resize(n);
return features;
}
} // namespace tree
} // namespace xgboost
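
The BFS helpers above (n_nodes, n_nodes_level and the parent/child index math) assume the implicit binary-heap layout of the tree; a tiny host check of the arithmetic for a depth-2 tree:

#include <cassert>

// Same formulas as n_nodes / n_nodes_level / left_child_nidx / right_child_nidx /
// parent_nidx / is_left_child above, spelled out for depth 2.
int main() {
  // depth 0: node 0; depth 1: nodes 1-2; depth 2: nodes 3-6
  assert(((1 << (2 + 1)) - 1) == 7);          // n_nodes(2)
  assert((1 << 2) == 4);                      // n_nodes_level(2)
  assert(1 * 2 + 1 == 3 && 1 * 2 + 2 == 4);   // children of node 1
  assert((4 - 1) / 2 == 1);                   // parent_nidx(4)
  assert(3 % 2 == 1);                         // node 3 is a left child
  return 0;
}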

File diff suppressed because it is too large

View File

@ -1,126 +0,0 @@
/*
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <gtest/gtest.h>
#include "../../src/exact/argmax_by_key.cuh"
#include "../../src/exact/node.cuh"
#include "utils.cuh"
namespace xgboost {
namespace tree {
namespace exact {
TEST(ArgMaxByKey, maxSplit) {
Split a, b, out;
a.score = 2.f;
a.index = 3;
b.score = 3.f;
b.index = 4;
out = maxSplit(a, b);
EXPECT_FLOAT_EQ(out.score, b.score);
EXPECT_EQ(out.index, b.index);
b.score = 2.f;
b.index = 4;
out = maxSplit(a, b);
EXPECT_FLOAT_EQ(out.score, a.score);
EXPECT_EQ(out.index, a.index);
b.score = 2.f;
b.index = 2;
out = maxSplit(a, b);
EXPECT_FLOAT_EQ(out.score, a.score);
EXPECT_EQ(out.index, b.index);
b.score = 1.f;
b.index = 1;
out = maxSplit(a, b);
EXPECT_FLOAT_EQ(out.score, a.score);
EXPECT_EQ(out.index, a.index);
}
template <typename node_id_t>
void argMaxTest(ArgMaxByKeyAlgo algo) {
const int nVals = 1024;
const int level = 0;
const int nKeys = 1 << level;
bst_gpair* scans = new bst_gpair[nVals];
float* vals = new float[nVals];
int* colIds = new int[nVals];
scans[0] = bst_gpair();
vals[0] = 0.f;
colIds[0] = 0;
for (int i = 1; i < nVals; ++i) {
scans[i].grad = scans[i-1].grad + (0.1f * 2.f);
scans[i].hess = scans[i-1].hess + (0.1f * 2.f);
vals[i] = static_cast<float>(i) * 0.1f;
colIds[i] = 0;
}
float* dVals;
allocateAndUpdateOnGpu<float>(dVals, vals, nVals);
bst_gpair* dScans;
allocateAndUpdateOnGpu<bst_gpair>(dScans, scans, nVals);
bst_gpair* sums = new bst_gpair[nKeys];
sums[0].grad = sums[0].hess = (0.1f * 2.f * nVals);
bst_gpair* dSums;
allocateAndUpdateOnGpu<bst_gpair>(dSums, sums, nKeys);
int* dColIds;
allocateAndUpdateOnGpu<int>(dColIds, colIds, nVals);
Split* splits = new Split[nKeys];
Split* dSplits;
allocateOnGpu<Split>(dSplits, nKeys);
node_id_t* nodeAssigns = new node_id_t[nVals];
memset(nodeAssigns, 0, sizeof(node_id_t)*nVals);
node_id_t* dNodeAssigns;
allocateAndUpdateOnGpu<node_id_t>(dNodeAssigns, nodeAssigns, nVals);
Node<node_id_t>* nodes = new Node<node_id_t>[nKeys];
nodes[0].gradSum = sums[0];
nodes[0].id = 0;
TrainParam param;
param.min_child_weight = 0.0f;
param.reg_alpha = 0.f;
param.reg_lambda = 2.f;
param.max_delta_step = 0.f;
nodes[0].score = CalcGain(param, sums[0].grad, sums[0].hess);
Node<node_id_t>* dNodes;
allocateAndUpdateOnGpu<Node<node_id_t> >(dNodes, nodes, nKeys);
argMaxByKey<node_id_t>(dSplits, dScans, dSums, dVals, dColIds, dNodeAssigns,
dNodes, nKeys, 0, nVals, param, algo);
updateHostPtr<Split>(splits, dSplits, nKeys);
EXPECT_FLOAT_EQ(0.f, splits->score);
EXPECT_EQ(0, splits->index);
dh::safe_cuda(cudaFree(dNodeAssigns));
delete [] nodeAssigns;
dh::safe_cuda(cudaFree(dSplits));
delete [] splits;
dh::safe_cuda(cudaFree(dColIds));
delete [] colIds;
dh::safe_cuda(cudaFree(dSums));
delete [] sums;
dh::safe_cuda(cudaFree(dVals));
delete [] vals;
dh::safe_cuda(cudaFree(dScans));
delete [] scans;
}
TEST(ArgMaxByKey, testOneColGmem) {
argMaxTest<int16_t>(ABK_GMEM);
}
TEST(ArgMaxByKey, testOneColSmem) {
argMaxTest<int16_t>(ABK_SMEM);
}
} // namespace exact
} // namespace tree
} // namespace xgboost
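
For reference, the expectation EXPECT_FLOAT_EQ(0.f, splits->score) in argMaxTest follows from the data the test constructs. Assuming CalcGain reduces to G*G / (H + lambda) when reg_alpha and max_delta_step are zero (an assumption about ../../src/tree/param.h, not something this file states), a quick host re-computation shows that no split improves on the parent, so the best loss change is 0 and it is attained at index 0:

#include <cstdio>

int main() {
  const int n = 1024;
  const float reg_lambda = 2.f;               // matches param.reg_lambda in the test
  const float parent = 0.1f * 2.f * n;        // grad and hess both sum to 204.8
  auto gain = [&](float g, float h) { return g * g / (h + reg_lambda); };
  float best = -1e30f;
  int best_idx = -1;
  for (int i = 0; i < n; ++i) {
    float left = 0.1f * 2.f * i;              // exclusive scan value at position i
    float right = parent - left;
    float chg = gain(left, left) + gain(right, right) - gain(parent, parent);
    if (chg > best) { best = chg; best_idx = i; }
  }
  // Because x*x/(x+lambda) is convex with f(0)=0, the sum of child gains never
  // exceeds the parent gain: the maximum loss change is 0, at index 0.
  std::printf("best index = %d, best loss change = %f\n", best_idx, best);
  return 0;
}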

View File

@ -1,117 +0,0 @@
/*
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "gtest/gtest.h"
#include "../../src/exact/fused_scan_reduce_by_key.cuh"
#include "../../src/exact/node.cuh"
#include "utils.cuh"
namespace xgboost {
namespace tree {
namespace exact {
template <typename node_id_t>
class ReduceScanByKey: public Generator<node_id_t> {
public:
ReduceScanByKey(int nc, int nr, int nk, const std::string& tName):
Generator<node_id_t>(nc, nr, nk, tName),
hSums(nullptr), dSums(nullptr), hScans(nullptr), dScans(nullptr),
outSize(this->size), nSegments(this->nKeys*this->nCols),
hOffsets(nullptr), dOffsets(nullptr) {
hSums = new bst_gpair[nSegments];
allocateOnGpu<bst_gpair>(dSums, nSegments);
hScans = new bst_gpair[outSize];
allocateOnGpu<bst_gpair>(dScans, outSize);
bst_gpair* buckets = new bst_gpair[nSegments];
for (int i = 0; i < nSegments; i++) {
buckets[i] = bst_gpair();
}
for (int i = 0; i < nSegments; i++) {
hSums[i] = bst_gpair();
}
for (size_t i = 0; i < this->size; i++) {
if (this->hKeys[i] >= 0 && this->hKeys[i] < nSegments) {
node_id_t key = abs2uniqKey<node_id_t>(i, this->hKeys,
this->hColIds, 0,
this->nKeys);
hSums[key] += this->hVals[i];
}
}
for (int i = 0; i < this->size; ++i) {
node_id_t key = abs2uniqKey<node_id_t>(i, this->hKeys,
this->hColIds, 0,
this->nKeys);
hScans[i] = buckets[key];
buckets[key] += this->hVals[i];
}
// it's a dense matrix that we are currently looking at, so offsets
// are nicely aligned! (need not be the case in real datasets)
hOffsets = new int[this->nCols];
size_t off = 0;
for (int i = 0; i < this->nCols; ++i, off+=this->nRows) {
hOffsets[i] = off;
}
allocateAndUpdateOnGpu<int>(dOffsets, hOffsets, this->nCols);
}
~ReduceScanByKey() {
delete [] hScans;
delete [] hSums;
delete [] hOffsets;
dh::safe_cuda(cudaFree(dScans));
dh::safe_cuda(cudaFree(dSums));
dh::safe_cuda(cudaFree(dOffsets));
}
void run() {
bst_gpair* tmpScans;
int* tmpKeys;
int tmpSize = scanTempBufferSize(this->size);
allocateOnGpu<bst_gpair>(tmpScans, tmpSize);
allocateOnGpu<int>(tmpKeys, tmpSize);
TIMEIT(reduceScanByKey<node_id_t>
(dSums, dScans, this->dVals, this->dInstIds, this->dKeys,
this->size, this->nKeys, this->nCols, tmpScans, tmpKeys,
this->dColIds, 0),
this->testName);
dh::safe_cuda(cudaFree(tmpScans));
dh::safe_cuda(cudaFree(tmpKeys));
this->compare(hSums, dSums, nSegments);
this->compare(hScans, dScans, outSize);
}
private:
bst_gpair* hSums;
bst_gpair* dSums;
bst_gpair* hScans;
bst_gpair* dScans;
int outSize;
int nSegments;
int* hOffsets;
int* dOffsets;
};
TEST(ReduceScanByKey, testInt16) {
ReduceScanByKey<int16_t>(32, 512, 32, "ReduceScanByKey").run();
}
TEST(ReduceScanByKey, testInt32) {
ReduceScanByKey<int>(32, 512, 32, "ReduceScanByKey").run();
}
} // namespace exact
} // namespace tree
} // namespace xgboost

File diff suppressed because it is too large

View File

@ -1,308 +0,0 @@
/*
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "gtest/gtest.h"
#include "utils.cuh"
#include "../../src/exact/gpu_builder.cuh"
#include "../../src/exact/node.cuh"
namespace xgboost {
namespace tree {
namespace exact {
static const std::vector<int> smallColSizes = {0, 5, 0, 6, 4, 0, 0, 2, 0, 11,
2, 9, 0, 5, 1, 0, 12, 3};
template <typename node_id_t>
void testSmallData() {
GPUBuilder<node_id_t> builder;
std::shared_ptr<DMatrix> dm =
setupGPUBuilder<node_id_t>("plugin/updater_gpu/test/cpp/data/small.sample.libsvm",
builder, 1);
// data dimensions
ASSERT_EQ(60, builder.nVals);
ASSERT_EQ(15, builder.nRows);
ASSERT_EQ(18, builder.nCols);
ASSERT_TRUE(builder.allocated);
// column counts
int* tmpOff = new int[builder.nCols+1];
updateHostPtr<int>(tmpOff, builder.colOffsets.data(), builder.nCols+1);
for (int i = 0; i < 15; ++i) {
EXPECT_EQ(smallColSizes[i], tmpOff[i+1]-tmpOff[i]);
}
float* tmpVal = new float[builder.nVals];
updateHostPtr<float>(tmpVal, builder.vals.current(), builder.nVals);
int* tmpInst = new int[builder.nVals];
updateHostPtr<int>(tmpInst, builder.instIds.current(), builder.nVals);
bst_gpair* tmpGrad = new bst_gpair[builder.nRows];
updateHostPtr<bst_gpair>(tmpGrad, builder.gradsInst.data(), builder.nRows);
EXPECT_EQ(0, tmpInst[0]);
EXPECT_FLOAT_EQ(1.f, tmpVal[0]);
EXPECT_FLOAT_EQ(1.f+(float)(tmpInst[0]%10), get(0, tmpGrad, tmpInst).grad);
EXPECT_FLOAT_EQ(.5f+(float)(tmpInst[0]%10), get(0, tmpGrad, tmpInst).hess);
EXPECT_EQ(2, tmpInst[1]);
EXPECT_FLOAT_EQ(1.f, tmpVal[1]);
EXPECT_FLOAT_EQ(1.f+(float)(tmpInst[1]%10), get(1, tmpGrad, tmpInst).grad);
EXPECT_FLOAT_EQ(.5f+(float)(tmpInst[1]%10), get(1, tmpGrad, tmpInst).hess);
EXPECT_EQ(7, tmpInst[2]);
EXPECT_FLOAT_EQ(1.f, tmpVal[2]);
EXPECT_FLOAT_EQ(1.f+(float)(tmpInst[2]%10), get(2, tmpGrad, tmpInst).grad);
EXPECT_FLOAT_EQ(.5f+(float)(tmpInst[2]%10), get(2, tmpGrad, tmpInst).hess);
delete [] tmpGrad;
delete [] tmpOff;
delete [] tmpInst;
delete [] tmpVal;
int* colIds = new int[builder.nVals];
updateHostPtr<int>(colIds, builder.colIds.data(), builder.nVals);
std::vector<int> colSizeCopy(smallColSizes);
int colIdxCurr = 0;
for (int i = 0; i < builder.nVals; ++i) {
while (colSizeCopy[colIdxCurr] == 0) {
++colIdxCurr;
}
--colSizeCopy[colIdxCurr];
EXPECT_EQ(colIdxCurr, colIds[i]);
}
delete [] colIds;
}
TEST(CudaGPUBuilderTest, SetupOneTimeDataSmallInt16) {
testSmallData<int16_t>();
}
TEST(CudaGPUBuilderTest, SetupOneTimeDataSmallInt32) {
testSmallData<int>();
}
template <typename node_id_t>
void testLargeData() {
GPUBuilder<node_id_t> builder;
std::shared_ptr<DMatrix> dm =
setupGPUBuilder<node_id_t>("plugin/updater_gpu/test/cpp/data/sample.libsvm",
builder, 1);
ASSERT_EQ(35442, builder.nVals);
ASSERT_EQ(1611, builder.nRows);
ASSERT_EQ(127, builder.nCols);
ASSERT_TRUE(builder.allocated);
int* tmpOff = new int[builder.nCols+1];
updateHostPtr<int>(tmpOff, builder.colOffsets.data(), builder.nCols+1);
EXPECT_EQ(0, tmpOff[1]-tmpOff[0]); // 1st col
EXPECT_EQ(83, tmpOff[2]-tmpOff[1]); // 2nd col
EXPECT_EQ(1, tmpOff[3]-tmpOff[2]); // 3rd col
float* tmpVal = new float[builder.nVals];
updateHostPtr<float>(tmpVal, builder.vals.current(), builder.nVals);
int* tmpInst = new int[builder.nVals];
updateHostPtr<int>(tmpInst, builder.instIds.current(), builder.nVals);
bst_gpair* tmpGrad = new bst_gpair[builder.nRows];
updateHostPtr<bst_gpair>(tmpGrad, builder.gradsInst.data(), builder.nRows);
// the order of observations is messed up before the convertToCsc call!
// hence, the instance IDs have been manually checked and put here.
EXPECT_EQ(1164, tmpInst[0]);
EXPECT_FLOAT_EQ(1.f, tmpVal[0]);
EXPECT_FLOAT_EQ(1.f+(float)(tmpInst[0]%10), get(0, tmpGrad, tmpInst).grad);
EXPECT_FLOAT_EQ(.5f+(float)(tmpInst[0]%10), get(0, tmpGrad, tmpInst).hess);
EXPECT_EQ(1435, tmpInst[1]);
EXPECT_FLOAT_EQ(1.f, tmpVal[1]);
EXPECT_FLOAT_EQ(1.f+(float)(tmpInst[1]%10), get(1, tmpGrad, tmpInst).grad);
EXPECT_FLOAT_EQ(.5f+(float)(tmpInst[1]%10), get(1, tmpGrad, tmpInst).hess);
EXPECT_EQ(1421, tmpInst[2]);
EXPECT_FLOAT_EQ(1.f, tmpVal[2]);
EXPECT_FLOAT_EQ(1.f+(float)(tmpInst[2]%10), get(2, tmpGrad, tmpInst).grad);
EXPECT_FLOAT_EQ(.5f+(float)(tmpInst[2]%10), get(2, tmpGrad, tmpInst).hess);
delete [] tmpGrad;
delete [] tmpOff;
delete [] tmpInst;
delete [] tmpVal;
}
TEST(CudaGPUBuilderTest, SetupOneTimeDataLargeInt16) {
testLargeData<int16_t>();
}
TEST(CudaGPUBuilderTest, SetupOneTimeDataLargeInt32) {
testLargeData<int>();
}
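// Maps a flat CSC value index back to its column id by scanning the column
// offsets; returns -1 if the index falls outside the last column's range.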
int getColId(int* offsets, int id, int nCols) {
for (int i = 1; i <= nCols; ++i) {
if (id < offsets[i]) {
return (i-1);
}
}
return -1;
}
template <typename node_id_t>
void testAllocate() {
GPUBuilder<node_id_t> builder;
std::shared_ptr<DMatrix> dm =
setupGPUBuilder<node_id_t>("plugin/updater_gpu/test/cpp/data/small.sample.libsvm",
builder, 1);
ASSERT_EQ(3, builder.maxNodes);
ASSERT_EQ(2, builder.maxLeaves);
Node<node_id_t>* n = new Node<node_id_t>[builder.maxNodes];
updateHostPtr<Node<node_id_t> >(n, builder.nodes.data(), builder.maxNodes);
for (int i = 0; i < builder.maxNodes; ++i) {
if (i == 0) {
EXPECT_FALSE(n[i].isLeaf());
EXPECT_FALSE(n[i].isUnused());
} else {
EXPECT_TRUE(n[i].isLeaf());
EXPECT_FALSE(n[i].isUnused());
}
}
bst_gpair sum;
sum.grad = 0.f;
sum.hess = 0.f;
for (int i = 0; i < builder.maxNodes; ++i) {
if (!n[i].isUnused()) {
sum += n[i].gradSum;
}
}
// law of conservation of gradients! :) with max_depth = 1 only the root and
// its two children are in use, and the children's gradient sums partition the
// root's, so the total over all used nodes is exactly twice the root's sum.
EXPECT_FLOAT_EQ(2.f*n[0].gradSum.grad, sum.grad);
EXPECT_FLOAT_EQ(2.f*n[0].gradSum.hess, sum.hess);
node_id_t* assigns = new node_id_t[builder.nVals];
int* offsets = new int[builder.nCols+1];
updateHostPtr<node_id_t>(assigns, builder.nodeAssigns.current(),
builder.nVals);
updateHostPtr<int>(offsets, builder.colOffsets.data(), builder.nCols+1);
for (int i = 0; i < builder.nVals; ++i) {
EXPECT_EQ((node_id_t)0, assigns[i]);
}
delete [] n;
delete [] assigns;
delete [] offsets;
}
TEST(CudaGPUBuilderTest, AllocateNodeDataInt16) {
testAllocate<int16_t>();
}
TEST(CudaGPUBuilderTest, AllocateNodeDataInt32) {
testAllocate<int>();
}
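// Convenience helper: populates a host-side Node image so it can be copied to
// the device with updateDevicePtr before exercising the builder.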
template <typename node_id_t>
void assign(Node<node_id_t> *n, float g, float h, float sc, float wt,
DefaultDirection d, float th, int c, int i) {
n->gradSum.grad = g;
n->gradSum.hess = h;
n->score = sc;
n->weight = wt;
n->dir = d;
n->threshold = th;
n->colIdx = c;
n->id = (node_id_t)i;
}
template <typename node_id_t>
void testMarkLeaves() {
GPUBuilder<node_id_t> builder;
std::shared_ptr<DMatrix> dm =
setupGPUBuilder<node_id_t>("plugin/updater_gpu/test/cpp/data/small.sample.libsvm",
builder, 3);
ASSERT_EQ(15, builder.maxNodes);
ASSERT_EQ(8, builder.maxLeaves);
Node<node_id_t>* hNodes = new Node<node_id_t>[builder.maxNodes];
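// Layout: nodes 0-4 and 6 are filled in as split nodes, while node 5 and the
// whole last level (7-14) are left unused. markLeaves should therefore turn
// 3, 4 and 6 into leaves (both of their children are unused) and keep 5 and
// 7-14 flagged as unused.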
assign<node_id_t>(&hNodes[0], 2.f, 1.f, .75f, 0.5f, LeftDir, 0.25f, 0, 0);
assign<node_id_t>(&hNodes[1], 2.f, 1.f, .75f, 0.5f, RightDir, 0.5f, 1, 1);
assign<node_id_t>(&hNodes[2], 2.f, 1.f, .75f, 0.5f, LeftDir, 0.75f, 2, 2);
assign<node_id_t>(&hNodes[3], 2.f, 1.f, .75f, 0.5f, RightDir, 1.f, 3, 3);
assign<node_id_t>(&hNodes[4], 2.f, 1.f, .75f, 0.5f, LeftDir, 1.25f, 4, 4);
hNodes[5] = Node<node_id_t>();
assign<node_id_t>(&hNodes[6], 2.f, 1.f, .75f, 0.5f, LeftDir, 1.75f, 6, 6);
hNodes[7] = Node<node_id_t>();
hNodes[8] = Node<node_id_t>();
hNodes[9] = Node<node_id_t>();
hNodes[10] = Node<node_id_t>();
hNodes[11] = Node<node_id_t>();
hNodes[12] = Node<node_id_t>();
hNodes[13] = Node<node_id_t>();
hNodes[14] = Node<node_id_t>();
updateDevicePtr<Node<node_id_t> >(builder.nodes.data(), hNodes, builder.maxNodes);
builder.markLeaves();
Node<node_id_t>* outNodes = new Node<node_id_t>[builder.maxNodes];
updateHostPtr<Node<node_id_t> >(outNodes, builder.nodes.data(), builder.maxNodes);
for (int i = 0; i < builder.maxNodes; ++i) {
if ((i >= 7) || (i == 5)) {
EXPECT_TRUE(outNodes[i].isUnused());
} else {
EXPECT_FALSE(outNodes[i].isUnused());
}
}
for (int i = 0; i < builder.maxNodes; ++i) {
if ((i == 3) || (i == 4) || (i == 6)) {
EXPECT_TRUE(outNodes[i].isLeaf());
} else {
EXPECT_FALSE(outNodes[i].isLeaf());
}
}
delete [] outNodes;
delete [] hNodes;
}
TEST(CudaGPUBuilderTest, MarkLeavesInt16) {
testMarkLeaves<int16_t>();
}
TEST(CudaGPUBuilderTest, MarkLeavesInt32) {
testMarkLeaves<int>();
}
template <typename node_id_t>
void testDense2Sparse() {
GPUBuilder<node_id_t> builder;
std::shared_ptr<DMatrix> dm =
setupGPUBuilder<node_id_t>("plugin/updater_gpu/test/cpp/data/small.sample.libsvm",
builder, 3);
ASSERT_EQ(15, builder.maxNodes);
ASSERT_EQ(8, builder.maxLeaves);
Node<node_id_t>* hNodes = new Node<node_id_t>[builder.maxNodes];
assign<node_id_t>(&hNodes[0], 2.f, 1.f, .75f, 0.5f, LeftDir, 0.25f, 0, 0);
assign<node_id_t>(&hNodes[1], 2.f, 1.f, .75f, 0.5f, RightDir, 0.5f, 1, 1);
assign<node_id_t>(&hNodes[2], 2.f, 1.f, .75f, 0.5f, LeftDir, 0.75f, 2, 2);
assign<node_id_t>(&hNodes[3], 2.f, 1.f, .75f, 0.5f, RightDir, 1.f, 3, 3);
assign<node_id_t>(&hNodes[4], 2.f, 1.f, .75f, 0.5f, LeftDir, 1.25f, 4, 4);
hNodes[5] = Node<node_id_t>();
assign<node_id_t>(&hNodes[6], 2.f, 1.f, .75f, 0.5f, LeftDir, 1.75f, 6, 6);
assign<node_id_t>(&hNodes[7], 2.f, 1.f, .75f, 0.5f, LeftDir, 1.75f, 7, 7);
hNodes[8] = Node<node_id_t>();
hNodes[9] = Node<node_id_t>();
hNodes[10] = Node<node_id_t>();
hNodes[11] = Node<node_id_t>();
hNodes[12] = Node<node_id_t>();
hNodes[13] = Node<node_id_t>();
hNodes[14] = Node<node_id_t>();
updateDevicePtr<Node<node_id_t> >(builder.nodes.data(), hNodes, builder.maxNodes);
builder.markLeaves();
RegTree tree;
builder.dense2sparse(&tree);
EXPECT_EQ(9, tree.param.num_nodes);
delete [] hNodes;
}
TEST(CudaGPUBuilderTest, Dense2SparseInt16) {
testDense2Sparse<int16_t>();
}
TEST(CudaGPUBuilderTest, Dense2SparseInt32) {
testDense2Sparse<int>();
}
} // namespace exact
} // namespace tree
} // namespace xgboost

View File

@ -1,64 +0,0 @@
/*
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "gtest/gtest.h"
#include "../../src/exact/node.cuh"
namespace xgboost {
namespace tree {
namespace exact {
TEST(Split, Test) {
Split s;
EXPECT_FALSE(s.isSplittable(0.5f));
s.score = 1.f;
EXPECT_FALSE(s.isSplittable(0.5f));
s.index = 2;
EXPECT_TRUE(s.isSplittable(0.5f));
EXPECT_FALSE(s.isSplittable(1.5f));
}
TEST(Node, Test) {
Node<int16_t> n;
EXPECT_TRUE(n.isUnused());
EXPECT_FALSE(n.isLeaf());
EXPECT_TRUE(n.isDefaultLeft());
n.dir = RightDir;
EXPECT_TRUE(n.isUnused());
EXPECT_FALSE(n.isLeaf());
EXPECT_FALSE(n.isDefaultLeft());
n.id = 123;
EXPECT_FALSE(n.isUnused());
EXPECT_TRUE(n.isLeaf());
EXPECT_FALSE(n.isDefaultLeft());
n.score = 0.5f;
EXPECT_FALSE(n.isUnused());
EXPECT_FALSE(n.isLeaf());
EXPECT_FALSE(n.isDefaultLeft());
}
TEST(Segment, Test) {
Segment s;
EXPECT_FALSE(s.isValid());
s.start = 2;
EXPECT_FALSE(s.isValid());
s.end = 4;
EXPECT_TRUE(s.isValid());
}
} // namespace exact
} // namespace tree
} // namespace xgboost

View File

@ -1,41 +0,0 @@
/*
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "utils.cuh"
namespace xgboost {
namespace tree {
namespace exact {
std::shared_ptr<DMatrix> generateData(const std::string& file) {
std::shared_ptr<DMatrix> data(DMatrix::Load(file, false, false, "libsvm"));
return data;
}
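// Synthetic gradient pairs: row i gets grad = 1 + (i % 10) and
// hess = 0.5 + (i % 10), which is exactly what the builder tests assert on.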
std::shared_ptr<DMatrix> preparePluginInputs(const std::string &file,
std::vector<bst_gpair> *gpair) {
std::shared_ptr<DMatrix> dm = generateData(file);
gpair->reserve(dm->info().num_row);
for (int i = 0; i < dm->info().num_row; ++i) {
gpair->push_back(bst_gpair(1.f+static_cast<float>(i%10),
0.5f+static_cast<float>(i%10)));
}
return dm;
}
} // namespace exact
} // namespace tree
} // namespace xgboost

View File

@ -1,231 +0,0 @@
/*
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <random>
#include <memory>
#include <string>
#include <xgboost/data.h>
#include "gtest/gtest.h"
#include "../../src/exact/gpu_builder.cuh"
#include "../../src/device_helpers.cuh"
#include <vector>
#include <stdlib.h>
namespace xgboost {
namespace tree {
namespace exact {
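// Thin checked wrappers around cudaMalloc/cudaMemcpy used by the tests to move
// data between host and device.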
template <typename T>
inline void allocateOnGpu(T*& arr, size_t nElems) {
dh::safe_cuda(cudaMalloc((void**)&arr, sizeof(T)*nElems));
}
template <typename T>
inline void updateDevicePtr(T* dArr, const T* hArr, size_t nElems) {
dh::safe_cuda(cudaMemcpy(dArr, hArr, sizeof(T)*nElems, cudaMemcpyHostToDevice));
}
template <typename T>
inline void updateHostPtr(T* hArr, const T* dArr, size_t nElems) {
dh::safe_cuda(cudaMemcpy(hArr, dArr, sizeof(T)*nElems, cudaMemcpyDeviceToHost));
}
template <typename T>
inline void allocateAndUpdateOnGpu(T*& dArr, const T* hArr, size_t nElems) {
allocateOnGpu<T>(dArr, nElems);
updateDevicePtr<T>(dArr, hArr, nElems);
}
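// Float-comparison tolerances used by compare()/diffRatio(): Thresh is the
// relative tolerance, while SuperSmallThresh is the absolute tolerance applied
// when both values are smaller than SuperSmall.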
static const float Thresh = 0.005f;
static const float SuperSmall = 0.001f;
static const float SuperSmallThresh = 0.00001f;
// let's assume a dense matrix for simplicity
template <typename T>
class Generator {
public:
Generator(int nc, int nr, int nk, const std::string& tName):
nCols(nc), nRows(nr), nKeys(nk), size(nc*nr), hKeys(nullptr),
dKeys(nullptr), hVals(nullptr), dVals(nullptr), testName(tName),
dColIds(nullptr), hColIds(nullptr), dInstIds(nullptr),
hInstIds(nullptr) {
generateKeys();
generateVals();
// sortKeyValues() would reproduce the sorted key-value pairs used in the main
// code, but it is unnecessary here: generateKeys() already emits keys in
// sorted order within each column.
// sortKeyValues();
evalColIds();
evalInstIds();
}
virtual ~Generator() {
delete [] hKeys;
delete [] hVals;
delete [] hColIds;
delete [] hInstIds;
dh::safe_cuda(cudaFree(dColIds));
dh::safe_cuda(cudaFree(dKeys));
dh::safe_cuda(cudaFree(dVals));
dh::safe_cuda(cudaFree(dInstIds));
}
virtual void run() = 0;
protected:
int nCols;
int nRows;
int nKeys;
int size;
T* hKeys;
T* dKeys;
bst_gpair* hVals;
bst_gpair* dVals;
std::string testName;
int* dColIds;
int* hColIds;
int* dInstIds;
int* hInstIds;
void evalColIds() {
hColIds = new int[size];
for (int i=0;i<size;++i) {
hColIds[i] = i / nRows;
}
allocateAndUpdateOnGpu<int>(dColIds, hColIds, size);
}
void evalInstIds() {
hInstIds = new int[size];
for (int i=0;i<size;++i) {
hInstIds[i] = i;
}
allocateAndUpdateOnGpu<int>(dInstIds, hInstIds, size);
}
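// Relative difference between a and b, with an absolute-difference fallback
// when both magnitudes are tiny; isSmall tells the caller which tolerance to
// apply.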
float diffRatio(float a, float b, bool& isSmall) {
isSmall = true;
if (a == 0.f) return fabs(b);
else if (b == 0.f) return fabs(a);
else if ((fabs(a) < SuperSmall) && (fabs(b) < SuperSmall)) {
return fabs(a - b);
}
else {
isSmall = false;
return fabs((a < b)? (b - a)/b : (a - b)/a);
}
}
void compare(bst_gpair* exp, bst_gpair* dAct, size_t len) {
bst_gpair* act = new bst_gpair[len];
updateHostPtr<bst_gpair>(act, dAct, len);
for (size_t i=0;i<len;++i) {
bool isSmallG, isSmallH;
float ratioG = diffRatio(exp[i].grad, act[i].grad, isSmallG);
float ratioH = diffRatio(exp[i].hess, act[i].hess, isSmallH);
float thresh = (isSmallG && isSmallH)? SuperSmallThresh : Thresh;
if ((ratioG >= thresh) || (ratioH >= thresh)) {
printf("(exp) %f %f -> (act) %f %f : rG=%f rH=%f th=%f @%lu\n",
exp[i].grad, exp[i].hess, act[i].grad, act[i].hess, ratioG, ratioH,
thresh, i);
}
ASSERT_TRUE(ratioG < thresh);
ASSERT_TRUE(ratioH < thresh);
}
delete [] act;
}
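// Generates non-decreasing keys per column: each column restarts at 0 and the
// key is randomly advanced (roughly 20% of the time) up to nKeys - 1, so the
// keys are already sorted within every column.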
void generateKeys() {
hKeys = new T[size];
T currKey = 0;
for (int i=0;i<size;++i) {
if (i % nRows == 0) { // start fresh for a new column
currKey = 0;
}
hKeys[i] = currKey;
float val = randVal();
if ((val > 0.8f) && (currKey < nKeys-1)) {
++currKey;
}
}
allocateAndUpdateOnGpu<T>(dKeys, hKeys, size);
}
void generateVals() {
hVals = new bst_gpair[size];
for (size_t i=0;i<size;++i) {
hVals[i].grad = randVal(-1.f, 1.f);
hVals[i].hess = randVal(-1.f, 1.f);
}
allocateAndUpdateOnGpu<bst_gpair>(dVals, hVals, size);
}
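// Sorts the (key, gpair) pairs on the device with CUB radix sort: the first
// SortPairs call with a null temp-storage pointer only computes the required
// temp-storage size, and the second call performs the actual sort.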
void sortKeyValues() {
char* storage = nullptr;
size_t tmpSize = 0;
dh::safe_cuda(cub::DeviceRadixSort::SortPairs(NULL, tmpSize, dKeys, dKeys,
dVals, dVals, size));
allocateOnGpu<char>(storage, tmpSize);
void* tmpStorage = static_cast<void*>(storage);
dh::safe_cuda(cub::DeviceRadixSort::SortPairs(tmpStorage, tmpSize, dKeys,
dKeys, dVals, dVals, size));
dh::safe_cuda(cudaFree(storage));
updateHostPtr<bst_gpair>(hVals, dVals, size);
updateHostPtr<T>(hKeys, dKeys, size);
}
float randVal() const {
float val = rand() * 1.f / RAND_MAX;
return val;
}
float randVal(float min, float max) const {
float val = randVal();
val = (val * (max - min)) + min;  // affine map from [0, 1] to [min, max]
return val;
}
};
std::shared_ptr<DMatrix> generateData(const std::string& file);
std::shared_ptr<DMatrix> preparePluginInputs(const std::string& file,
std::vector<bst_gpair> *gpair);
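// Loads the libsvm file, synthesizes gradient pairs, configures a minimal
// TrainParam, and runs one GPUBuilder::Update pass so that the tests can then
// inspect the builder's internal device buffers.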
template <typename node_id_t>
std::shared_ptr<DMatrix> setupGPUBuilder(const std::string& file,
GPUBuilder<node_id_t> &builder,
int max_depth=1) {
std::vector<bst_gpair> gpair;
std::shared_ptr<DMatrix> dm = preparePluginInputs(file, &gpair);
TrainParam p;
RegTree tree;
p.min_split_loss = 0.f;
p.max_depth = max_depth;
p.min_child_weight = 0.f;
p.reg_alpha = 0.f;
p.reg_lambda = 1.f;
p.max_delta_step = 0.f;
builder.Init(p);
builder.Update(gpair, dm.get(), &tree);
return dm;
}
} // namespace exact
} // namespace tree
} // namespace xgboost

View File

@ -32,5 +32,9 @@ DMLC_REGISTRY_LINK_TAG(updater_prune);
DMLC_REGISTRY_LINK_TAG(updater_fast_hist);
DMLC_REGISTRY_LINK_TAG(updater_histmaker);
DMLC_REGISTRY_LINK_TAG(updater_sync);
#ifdef XGBOOST_USE_CUDA
DMLC_REGISTRY_LINK_TAG(updater_gpu);
DMLC_REGISTRY_LINK_TAG(updater_gpu_hist);
#endif
} // namespace tree
} // namespace xgboost
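
For context on the link tags added above: each GPU updater registers itself through a static object in its own translation unit, and DMLC_REGISTRY_LINK_TAG exists so that the core library references that unit, which keeps the linker from dropping it (and its registration) in a static build. Below is a minimal, self-contained sketch of that pattern in plain C++; the registry type, the RegisterGrowGpu struct, and the "grow_gpu" name are illustrative stand-ins rather than the actual dmlc-core or plugin code.

#include <functional>
#include <iostream>
#include <map>
#include <string>

// Illustrative stand-in for a dmlc-style registry: factories insert themselves
// during static initialization and are later looked up by name.
struct UpdaterRegistry {
  static std::map<std::string, std::function<std::string()>>& Get() {
    static std::map<std::string, std::function<std::string()>> registry;
    return registry;
  }
};

// In the real plugin a registration object like this lives in the GPU
// updater's translation unit. If nothing else references that unit, a
// static-library link may drop its object file and the registration never
// runs; the DMLC_REGISTRY_LINK_TAG / DMLC_REGISTRY_FILE_TAG pair is what
// forces that reference.
struct RegisterGrowGpu {
  RegisterGrowGpu() {
    UpdaterRegistry::Get()["grow_gpu"] = []() { return std::string("gpu updater"); };
  }
};
static RegisterGrowGpu register_grow_gpu_;

int main() {
  // Resolve the updater by name, much like XGBoost resolves the "updater"
  // training parameter against its registry.
  auto it = UpdaterRegistry::Get().find("grow_gpu");
  std::cout << (it == UpdaterRegistry::Get().end() ? std::string("not registered")
                                                   : it->second())
            << std::endl;
  return 0;
}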