[GPU-Plugin] Improved split finding performance. (#2325)

PSEUDOTENSOR / Jonathan McKinney authored on 2017-05-19 19:16:24 -07:00; committed by Rory Mitchell
parent 29289d2302
commit 3ca64ffa02
4 changed files with 206 additions and 193 deletions

View File

@@ -168,5 +168,17 @@ inline std::vector<int> col_sample(std::vector<int> features, float colsample) {
   return features;
 }
 
+struct GpairCallbackOp {
+  // Running prefix
+  gpu_gpair running_total;
+  // Constructor
+  __device__ GpairCallbackOp() : running_total(gpu_gpair()) {}
+  __device__ gpu_gpair operator()(gpu_gpair block_aggregate) {
+    gpu_gpair old_prefix = running_total;
+    running_total += block_aggregate;
+    return old_prefix;
+  }
+};
+
 }  // namespace tree
 }  // namespace xgboost
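
The GpairCallbackOp moved into this header is the stock CUB running-prefix callback: cub::BlockScan invokes it once per tile with that tile's block-wide aggregate, and the returned value seeds the exclusive scan of the next tile. A minimal stand-alone sketch of the same pattern, compiled as a .cu file with nvcc and assuming only that CUB is on the include path (PrefixOp and TiledExclusiveScan are illustrative names, with plain int standing in for gpu_gpair):

#include <cub/cub.cuh>
#include <cstdio>

// Running-prefix functor: same shape as GpairCallbackOp above.
struct PrefixOp {
  int running_total;
  __device__ explicit PrefixOp(int init) : running_total(init) {}
  // CUB calls this once per tile with the tile's aggregate; the return value
  // seeds the exclusive scan of that tile.
  __device__ int operator()(int block_aggregate) {
    int old_prefix = running_total;
    running_total += block_aggregate;
    return old_prefix;
  }
};

template <int BLOCK_THREADS>
__global__ void TiledExclusiveScan(const int* d_in, int* d_out, int n) {
  typedef cub::BlockScan<int, BLOCK_THREADS> BlockScanT;
  __shared__ typename BlockScanT::TempStorage temp_storage;
  PrefixOp prefix_op(0);
  for (int tile = 0; tile < n; tile += BLOCK_THREADS) {
    int i = tile + threadIdx.x;
    int x = (i < n) ? d_in[i] : 0;
    // The prefix carried by prefix_op makes the per-tile scans behave like
    // one long exclusive scan over the whole input.
    BlockScanT(temp_storage).ExclusiveScan(x, x, cub::Sum(), prefix_op);
    if (i < n) d_out[i] = x;
    __syncthreads();  // temp_storage is reused by the next tile
  }
}

int main() {
  const int kN = 1000;
  int h_in[kN], h_out[kN];
  for (int i = 0; i < kN; ++i) h_in[i] = 1;
  int *d_in, *d_out;
  cudaMalloc(&d_in, kN * sizeof(int));
  cudaMalloc(&d_out, kN * sizeof(int));
  cudaMemcpy(d_in, h_in, kN * sizeof(int), cudaMemcpyHostToDevice);
  TiledExclusiveScan<256><<<1, 256>>>(d_in, d_out, kN);
  cudaMemcpy(h_out, d_out, kN * sizeof(int), cudaMemcpyDeviceToHost);
  std::printf("out[%d] = %d (expect %d)\n", kN - 1, h_out[kN - 1], kN - 1);
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}

With all-ones input the last element scans to 999, i.e. the prefix really is carried across the four 256-wide tiles; find_split_general_kernel later in this commit uses exactly this mechanism to scan a feature's histogram bins in gpu_gpair units.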

View File

@@ -2,11 +2,11 @@
  * Copyright 2016 Rory mitchell
  */
 #pragma once
-#include <cub/cub.cuh>
 #include <xgboost/base.h>
+#include <cub/cub.cuh>
+#include "common.cuh"
 #include "device_helpers.cuh"
 #include "types.cuh"
-#include "common.cuh"
 
 namespace xgboost {
 namespace tree {
@@ -48,19 +48,8 @@ struct GpairTupleCallbackOp {
   }
 };
 
-struct GpairCallbackOp {
-  // Running prefix
-  gpu_gpair running_total;
-  // Constructor
-  __device__ GpairCallbackOp() : running_total(gpu_gpair()) {}
-  __device__ gpu_gpair operator()(gpu_gpair block_aggregate) {
-    gpu_gpair old_prefix = running_total;
-    running_total += block_aggregate;
-    return old_prefix;
-  }
-};
-
-template <int BLOCK_THREADS> struct ReduceEnactorSorting {
+template <int BLOCK_THREADS>
+struct ReduceEnactorSorting {
   typedef cub::BlockScan<ScanTuple, BLOCK_THREADS> GpairScanT;
   struct _TempStorage {
     typename GpairScanT::TempStorage gpair_scan;
@@ -82,13 +71,15 @@ template <int BLOCK_THREADS> struct ReduceEnactorSorting {
   const int level;
 
   __device__ __forceinline__
   ReduceEnactorSorting(TempStorage &temp_storage,  // NOLINT
                        gpu_gpair *d_block_node_sums, int *d_block_node_offsets,
                        ItemIter item_iter, const int level)
       : temp_storage(temp_storage.Alias()),
         d_block_node_sums(d_block_node_sums),
-        d_block_node_offsets(d_block_node_offsets), item_iter(item_iter),
-        callback_op(), level(level) {}
+        d_block_node_offsets(d_block_node_offsets),
+        item_iter(item_iter),
+        callback_op(),
+        level(level) {}
 
   __device__ __forceinline__ void LoadTile(const bst_uint &offset,
                                            const bst_uint &num_remaining) {
@@ -102,7 +93,7 @@ template <int BLOCK_THREADS> struct ReduceEnactorSorting {
       // Prevent overflow
      const int level_begin = (1 << level) - 1;
       node_id_adjusted =
           max(static_cast<int>(node_id) - level_begin, -1);  // NOLINT
     }
   }
@@ -175,15 +166,18 @@ struct FindSplitEnactorSorting {
   const int level;
 
   __device__ __forceinline__ FindSplitEnactorSorting(
       TempStorage &temp_storage, gpu_gpair *d_block_node_sums,  // NOLINT
       int *d_block_node_offsets, const ItemIter item_iter, const Node *d_nodes,
       const GPUTrainingParam &param, Split *d_split_candidates_out,
       const int level)
       : temp_storage(temp_storage.Alias()),
         d_block_node_sums(d_block_node_sums),
-        d_block_node_offsets(d_block_node_offsets), item_iter(item_iter),
-        d_nodes(d_nodes), d_split_candidates_out(d_split_candidates_out),
-        level(level), param(param) {}
+        d_block_node_offsets(d_block_node_offsets),
+        item_iter(item_iter),
+        d_nodes(d_nodes),
+        d_split_candidates_out(d_split_candidates_out),
+        level(level),
+        param(param) {}
 
   __device__ __forceinline__ void LoadTile(NodeIdT node_id_adjusted,
                                            const bst_uint &node_begin,
@@ -254,9 +248,9 @@ struct FindSplitEnactorSorting {
     return fvalue != left_fvalue;
   }
 
-  __device__ __forceinline__ void
-  EvaluateSplits(const NodeIdT &node_id_adjusted, const bst_uint &node_begin,
-                 const bst_uint &offset, const bst_uint &num_remaining) {
+  __device__ __forceinline__ void EvaluateSplits(
+      const NodeIdT &node_id_adjusted, const bst_uint &node_begin,
+      const bst_uint &offset, const bst_uint &num_remaining) {
     bool thread_active = LeftmostFvalue() && threadIdx.x < num_remaining &&
                          node_id_adjusted >= 0 && node_id >= 0;
@@ -289,10 +283,10 @@ struct FindSplitEnactorSorting {
     }
   }
 
-  __device__ __forceinline__ void
-  ProcessTile(const NodeIdT &node_id_adjusted, const bst_uint &node_begin,
-              const bst_uint &offset, const bst_uint &num_remaining,
-              GpairCallbackOp &callback_op) {  // NOLINT
+  __device__ __forceinline__ void ProcessTile(
+      const NodeIdT &node_id_adjusted, const bst_uint &node_begin,
+      const bst_uint &offset, const bst_uint &num_remaining,
+      GpairCallbackOp &callback_op) {  // NOLINT
     LoadTile(node_id_adjusted, node_begin, offset, num_remaining);
 
     // Scan gpair
@@ -304,8 +298,8 @@ struct FindSplitEnactorSorting {
     EvaluateSplits(node_id_adjusted, node_begin, offset, num_remaining);
   }
 
-  __device__ __forceinline__ void
-  WriteBestSplit(const NodeIdT &node_id_adjusted) {
+  __device__ __forceinline__ void WriteBestSplit(
+      const NodeIdT &node_id_adjusted) {
     if (threadIdx.x < 32) {
       bool active = threadIdx.x < N_WARPS;
       float warp_loss =
@@ -370,7 +364,6 @@ __global__ __launch_bounds__(1024, 1) void find_split_candidates_sorted_kernel(
     const Node *d_nodes, bst_uint num_items, const int num_features,
     const int *d_feature_offsets, gpu_gpair *d_node_sums, int *d_node_offsets,
     const GPUTrainingParam param, const int *d_feature_flags, const int level) {
-
   if (num_items <= 0 || d_feature_flags[blockIdx.x] != 1) {
     return;
   }
@@ -400,7 +393,7 @@ __global__ __launch_bounds__(1024, 1) void find_split_candidates_sorted_kernel(
       .ProcessFeature(segment_begin, segment_end);
 }
 
-void find_split_candidates_sorted(GPUData * data, const int level) {
+void find_split_candidates_sorted(GPUData *data, const int level) {
   const int BLOCK_THREADS = 512;
 
   CHECK(BLOCK_THREADS / 32 < 32) << "Too many active warps.";
@@ -410,9 +403,9 @@ void find_split_candidates_sorted(GPUData * data, const int level) {
   find_split_candidates_sorted_kernel<
       BLOCK_THREADS><<<grid_size, BLOCK_THREADS>>>(
      data->items_iter, data->split_candidates.data(), data->nodes.data(),
-      data->fvalues.size(), data->n_features,
-      data->foffsets.data(), data->node_sums.data(), data->node_offsets.data(),
-      data->param, data->feature_flags.data(), level);
+      data->fvalues.size(), data->n_features, data->foffsets.data(),
+      data->node_sums.data(), data->node_offsets.data(), data->param,
+      data->feature_flags.data(), level);
 
   dh::safe_cuda(cudaGetLastError());
   dh::safe_cuda(cudaDeviceSynchronize());

View File

@@ -245,174 +245,183 @@ __global__ void find_split_kernel(
     }
   }
 }
 
-void GPUHistBuilder::FindSplit(int depth) {
-  // Specialised based on max_bins
-  if (param.max_bin <= 256) {
-    this->FindSplit256(depth);
-  } else if (param.max_bin <= 1024) {
-    this->FindSplit1024(depth);
-  } else {
-    this->FindSplitLarge(depth);
-  }
-}
-
-void GPUHistBuilder::FindSplit256(int depth) {
-  CHECK_LE(param.max_bin, 256);
-  const int BLOCK_THREADS = 256;
-  const int GRID_SIZE = n_nodes_level(depth);
-  bool colsample =
-      param.colsample_bylevel < 1.0 || param.colsample_bytree < 1.0;
-  find_split_kernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS>>>(
-      hist.GetLevelPtr(depth), feature_segments.data(), depth, info->num_col,
-      hmat_.row_ptr.back(), nodes.data(), fidx_min_map.data(),
-      gidx_fvalue_map.data(), gpu_param, left_child_smallest.data(), colsample,
-      feature_flags.data());
-  dh::safe_cuda(cudaDeviceSynchronize());
-}
-
-void GPUHistBuilder::FindSplit1024(int depth) {
-  CHECK_LE(param.max_bin, 1024);
-  const int BLOCK_THREADS = 1024;
-  const int GRID_SIZE = n_nodes_level(depth);
-  bool colsample =
-      param.colsample_bylevel < 1.0 || param.colsample_bytree < 1.0;
-  find_split_kernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS>>>(
-      hist.GetLevelPtr(depth), feature_segments.data(), depth, info->num_col,
-      hmat_.row_ptr.back(), nodes.data(), fidx_min_map.data(),
-      gidx_fvalue_map.data(), gpu_param, left_child_smallest.data(), colsample,
-      feature_flags.data());
-  dh::safe_cuda(cudaDeviceSynchronize());
-}
-
-void GPUHistBuilder::FindSplitLarge(int depth) {
-  auto counting = thrust::make_counting_iterator(0);
-  auto d_gidx_feature_map = gidx_feature_map.data();
-  int n_bins = hmat_.row_ptr.back();
-  int n_features = hmat_.row_ptr.size() - 1;
-  auto feature_boundary = [=] __device__(int idx_a, int idx_b) {
-    int gidx_a = idx_a % n_bins;
-    int gidx_b = idx_b % n_bins;
-    return d_gidx_feature_map[gidx_a] == d_gidx_feature_map[gidx_b];
-  };  // NOLINT
-
-  // Reduce node sums
-  {
-    size_t temp_storage_bytes;
-    cub::DeviceSegmentedReduce::Reduce(
-        nullptr, temp_storage_bytes, hist.GetLevelPtr(depth), node_sums.data(),
-        n_nodes_level(depth) * n_features, feature_segments.data(),
-        feature_segments.data() + 1, cub::Sum(), gpu_gpair());
-    cub_mem.LazyAllocate(temp_storage_bytes);
-    cub::DeviceSegmentedReduce::Reduce(
-        cub_mem.d_temp_storage, cub_mem.temp_storage_bytes,
-        hist.GetLevelPtr(depth), node_sums.data(),
-        n_nodes_level(depth) * n_features, feature_segments.data(),
-        feature_segments.data() + 1, cub::Sum(), gpu_gpair());
-  }
-
-  // Scan
-  thrust::exclusive_scan_by_key(
-      counting, counting + hist.LevelSize(depth),
-      thrust::device_pointer_cast(hist.GetLevelPtr(depth)), hist_scan.tbegin(),
-      gpu_gpair(), feature_boundary);
-
-  // Calculate gain
-  auto d_gain = gain.data();
-  auto d_nodes = nodes.data();
-  auto d_node_sums = node_sums.data();
-  auto d_hist_scan = hist_scan.data();
-  GPUTrainingParam gpu_param_alias =
-      gpu_param;  // Must be local variable to be used in device lambda
-  bool colsample =
-      param.colsample_bylevel < 1.0 || param.colsample_bytree < 1.0;
-  auto d_feature_flags = feature_flags.data();
-
-  dh::launch_n(hist.LevelSize(depth), [=] __device__(int idx) {
-    int node_segment = idx / n_bins;
-    int node_idx = n_nodes(depth - 1) + node_segment;
-    gpu_gpair parent_sum = d_nodes[node_idx].sum_gradients;
-    float parent_gain = d_nodes[node_idx].root_gain;
-    int gidx = idx % n_bins;
-    int findex = d_gidx_feature_map[gidx];
-
-    // colsample
-    if (colsample && d_feature_flags[d_gidx_feature_map[gidx]] == 0) {
-      d_gain[idx] = 0;
-    } else {
-      gpu_gpair scan = d_hist_scan[idx];
-      gpu_gpair sum = d_node_sums[node_segment * n_features + findex];
-      gpu_gpair missing = parent_sum - sum;
-      bool missing_left;
-      d_gain[idx] = loss_chg_missing(scan, missing, parent_sum, parent_gain,
-                                     gpu_param_alias, missing_left);
-    }
-  });
-  dh::safe_cuda(cudaDeviceSynchronize());
-
-  // Find best gain
-  {
-    size_t temp_storage_bytes;
-    cub::DeviceSegmentedReduce::ArgMax(nullptr, temp_storage_bytes, gain.data(),
-                                       argmax.data(), n_nodes_level(depth),
-                                       hist_node_segments.data(),
-                                       hist_node_segments.data() + 1);
-    cub_mem.LazyAllocate(temp_storage_bytes);
-    cub::DeviceSegmentedReduce::ArgMax(
-        cub_mem.d_temp_storage, cub_mem.temp_storage_bytes, gain.data(),
-        argmax.data(), n_nodes_level(depth), hist_node_segments.data(),
-        hist_node_segments.data() + 1);
-  }
-
-  auto d_argmax = argmax.data();
-  auto d_gidx_fvalue_map = gidx_fvalue_map.data();
-  auto d_fidx_min_map = fidx_min_map.data();
-  auto d_left_child_smallest = left_child_smallest.data();
-
-  dh::launch_n(n_nodes_level(depth), [=] __device__(int idx) {
-    int max_idx = n_bins * idx + d_argmax[idx].key;
-    int gidx = max_idx % n_bins;
-    int fidx = d_gidx_feature_map[gidx];
-    int node_segment = max_idx / n_bins;
-    int node_idx = n_nodes(depth - 1) + node_segment;
-    gpu_gpair scan = d_hist_scan[max_idx];
-    gpu_gpair parent_sum = d_nodes[node_idx].sum_gradients;
-    float parent_gain = d_nodes[node_idx].root_gain;
-    gpu_gpair sum = d_node_sums[node_segment * n_features + fidx];
-    gpu_gpair missing = parent_sum - sum;
-    bool missing_left;
-    float loss_chg = loss_chg_missing(scan, missing, parent_sum, parent_gain,
-                                      gpu_param_alias, missing_left);
-
-    float fvalue;
-    if (gidx == 0 || fidx != d_gidx_feature_map[gidx - 1]) {
-      fvalue = d_fidx_min_map[fidx];
-    } else {
-      fvalue = d_gidx_fvalue_map[gidx - 1];
-    }
-
-    gpu_gpair left = missing_left ? scan + missing : scan;
-    gpu_gpair right = parent_sum - left;
-
-    d_nodes[node_idx].split.Update(loss_chg, missing_left, fvalue, fidx, left,
-                                   right, gpu_param_alias);
-
-    d_nodes[left_child_nidx(node_idx)] =
-        Node(left, CalcGain(gpu_param_alias, left.grad(), left.hess()),
-             CalcWeight(gpu_param_alias, left.grad(), left.hess()));
-
-    d_nodes[right_child_nidx(node_idx)] =
-        Node(right, CalcGain(gpu_param_alias, right.grad(), right.hess()),
-             CalcWeight(gpu_param_alias, right.grad(), right.hess()));
-
-    // Record smallest node
-    if (left.hess() <= right.hess()) {
-      d_left_child_smallest[node_idx] = true;
-    } else {
-      d_left_child_smallest[node_idx] = false;
-    }
-  });
-}
+template <int BLOCK_THREADS>
+__global__ void find_split_general_kernel(
+    const gpu_gpair* d_level_hist, int* d_feature_segments, int depth,
+    int n_features, int n_bins, Node* d_nodes, float* d_fidx_min_map,
+    float* d_gidx_fvalue_map, GPUTrainingParam gpu_param,
+    bool* d_left_child_smallest, bool colsample, int* d_feature_flags) {
+  typedef cub::KeyValuePair<int, float> ArgMaxT;
+  typedef cub::BlockScan<gpu_gpair, BLOCK_THREADS, cub::BLOCK_SCAN_WARP_SCANS>
+      BlockScanT;
+  typedef cub::BlockReduce<ArgMaxT, BLOCK_THREADS> MaxReduceT;
+  typedef cub::BlockReduce<gpu_gpair, BLOCK_THREADS> SumReduceT;
+
+  union TempStorage {
+    typename BlockScanT::TempStorage scan;
+    typename MaxReduceT::TempStorage max_reduce;
+    typename SumReduceT::TempStorage sum_reduce;
+  };
+
+  struct UninitializedSplit : cub::Uninitialized<Split> {};
+  struct UninitializedGpair : cub::Uninitialized<gpu_gpair> {};
+
+  __shared__ UninitializedSplit uninitialized_split;
+  Split& split = uninitialized_split.Alias();
+  __shared__ UninitializedGpair uninitialized_sum;
+  gpu_gpair& shared_sum = uninitialized_sum.Alias();
+  __shared__ ArgMaxT block_max;
+  __shared__ TempStorage temp_storage;
+
+  if (threadIdx.x == 0) {
+    split = Split();
+  }
+
+  __syncthreads();
+
+  int node_idx = n_nodes(depth - 1) + blockIdx.x;
+
+  for (int fidx = 0; fidx < n_features; fidx++) {
+    if (colsample && d_feature_flags[fidx] == 0) continue;
+
+    int begin = d_feature_segments[blockIdx.x * n_features + fidx];
+    int end = d_feature_segments[blockIdx.x * n_features + fidx + 1];
+
+    int gidx = (begin - (blockIdx.x * n_bins)) + threadIdx.x;
+    bool thread_active = threadIdx.x < end - begin;
+
+    gpu_gpair feature_sum = gpu_gpair();
+    for (int reduce_begin = begin; reduce_begin < end;
+         reduce_begin += BLOCK_THREADS) {
+      // Scan histogram
+      gpu_gpair bin = thread_active ? d_level_hist[reduce_begin + threadIdx.x]
+                                    : gpu_gpair();
+      feature_sum +=
+          SumReduceT(temp_storage.sum_reduce).Reduce(bin, cub::Sum());
+    }
+
+    if (threadIdx.x == 0) {
+      shared_sum = feature_sum;
+    }
+    // __syncthreads(); // no need to synch because below there is a Scan
+
+    GpairCallbackOp prefix_op = GpairCallbackOp();
+    for (int scan_begin = begin; scan_begin < end;
+         scan_begin += BLOCK_THREADS) {
+      gpu_gpair bin =
+          thread_active ? d_level_hist[scan_begin + threadIdx.x] : gpu_gpair();
+
+      BlockScanT(temp_storage.scan)
+          .ExclusiveScan(bin, bin, cub::Sum(), prefix_op);
+
+      // Calculate gain
+      gpu_gpair parent_sum = d_nodes[node_idx].sum_gradients;
+      float parent_gain = d_nodes[node_idx].root_gain;
+      gpu_gpair missing = parent_sum - shared_sum;
+      bool missing_left;
+      float gain = thread_active
+                       ? loss_chg_missing(bin, missing, parent_sum, parent_gain,
+                                          gpu_param, missing_left)
+                       : -FLT_MAX;
+      __syncthreads();
+
+      // Find thread with best gain
+      ArgMaxT tuple(threadIdx.x, gain);
+      ArgMaxT best =
+          MaxReduceT(temp_storage.max_reduce).Reduce(tuple, cub::ArgMax());
+
+      if (threadIdx.x == 0) {
+        block_max = best;
+      }
+
+      __syncthreads();
+
+      // Best thread updates split
+      if (threadIdx.x == block_max.key) {
+        float fvalue;
+        if (threadIdx.x == 0 &&
+            begin == scan_begin) {  // check at start of first tile
+          fvalue = d_fidx_min_map[fidx];
+        } else {
+          fvalue = d_gidx_fvalue_map[gidx - 1];
+        }
+
+        gpu_gpair left = missing_left ? bin + missing : bin;
+        gpu_gpair right = parent_sum - left;
+
+        split.Update(gain, missing_left, fvalue, fidx, left, right, gpu_param);
+      }
+      __syncthreads();
+    }  // end scan
+  }  // end over features
+
+  // Create node
+  if (threadIdx.x == 0) {
+    d_nodes[node_idx].split = split;
+    if (depth == 0) {
+      // split.Print();
+    }
+
+    d_nodes[left_child_nidx(node_idx)] = Node(
+        split.left_sum,
+        CalcGain(gpu_param, split.left_sum.grad(), split.left_sum.hess()),
+        CalcWeight(gpu_param, split.left_sum.grad(), split.left_sum.hess()));
+
+    d_nodes[right_child_nidx(node_idx)] = Node(
+        split.right_sum,
+        CalcGain(gpu_param, split.right_sum.grad(), split.right_sum.hess()),
+        CalcWeight(gpu_param, split.right_sum.grad(), split.right_sum.hess()));
+
+    // Record smallest node
+    if (split.left_sum.hess() <= split.right_sum.hess()) {
+      d_left_child_smallest[node_idx] = true;
+    } else {
+      d_left_child_smallest[node_idx] = false;
+    }
+  }
+}
+
+#define MIN_BLOCK_THREADS 32
+#define MAX_BLOCK_THREADS 1024  // hard-coded maximum block size
+
+void GPUHistBuilder::FindSplit(int depth) {
+  // Specialised based on max_bins
+  this->FindSplitSpecialize<MIN_BLOCK_THREADS>(depth);
+}
+
+template <>
+void GPUHistBuilder::FindSplitSpecialize<MAX_BLOCK_THREADS>(int depth) {
+  const int GRID_SIZE = n_nodes_level(depth);
+  bool colsample =
+      param.colsample_bylevel < 1.0 || param.colsample_bytree < 1.0;
+  find_split_general_kernel<
+      MAX_BLOCK_THREADS><<<GRID_SIZE, MAX_BLOCK_THREADS>>>(
+      hist.GetLevelPtr(depth), feature_segments.data(), depth, info->num_col,
+      hmat_.row_ptr.back(), nodes.data(), fidx_min_map.data(),
+      gidx_fvalue_map.data(), gpu_param, left_child_smallest.data(), colsample,
+      feature_flags.data());
+  dh::safe_cuda(cudaDeviceSynchronize());
+}
+
+template <int BLOCK_THREADS>
+void GPUHistBuilder::FindSplitSpecialize(int depth) {
+  if (param.max_bin <= BLOCK_THREADS) {
+    const int GRID_SIZE = n_nodes_level(depth);
+    bool colsample =
+        param.colsample_bylevel < 1.0 || param.colsample_bytree < 1.0;
+    find_split_general_kernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS>>>(
+        hist.GetLevelPtr(depth), feature_segments.data(), depth, info->num_col,
+        hmat_.row_ptr.back(), nodes.data(), fidx_min_map.data(),
+        gidx_fvalue_map.data(), gpu_param, left_child_smallest.data(),
+        colsample, feature_flags.data());
+  } else {
+    this->FindSplitSpecialize<BLOCK_THREADS + 32>(depth);
+  }
+  dh::safe_cuda(cudaDeviceSynchronize());
+}
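
find_split_general_kernel evaluates one node per thread block: it first reduces a feature's bins to the feature sum, then exclusive-scans the bins with GpairCallbackOp as the running prefix, and finally runs a block-wide argmax so the thread holding the best gain can update the shared Split. A stand-alone sketch of just that argmax step, assuming CUB is available (BestGainKernel and the toy gains are illustrative, not xgboost code):

#include <cub/cub.cuh>
#include <cstdio>

template <int BLOCK_THREADS>
__global__ void BestGainKernel(const float* d_gain, int* d_best_idx) {
  typedef cub::KeyValuePair<int, float> ArgMaxT;
  typedef cub::BlockReduce<ArgMaxT, BLOCK_THREADS> MaxReduceT;
  __shared__ typename MaxReduceT::TempStorage temp_storage;

  // Pair this thread's index with its candidate gain.
  ArgMaxT tuple(threadIdx.x, d_gain[threadIdx.x]);

  // cub::ArgMax keeps the pair with the larger value (lower key on ties).
  // Only thread 0 receives the valid reduction result.
  ArgMaxT best = MaxReduceT(temp_storage).Reduce(tuple, cub::ArgMax());

  if (threadIdx.x == 0) {
    *d_best_idx = best.key;
  }
}

int main() {
  const int kThreads = 128;
  float h_gain[kThreads];
  for (int i = 0; i < kThreads; ++i) h_gain[i] = static_cast<float>(i % 37);
  float* d_gain;
  int* d_best_idx;
  cudaMalloc(&d_gain, kThreads * sizeof(float));
  cudaMalloc(&d_best_idx, sizeof(int));
  cudaMemcpy(d_gain, h_gain, kThreads * sizeof(float), cudaMemcpyHostToDevice);
  BestGainKernel<kThreads><<<1, kThreads>>>(d_gain, d_best_idx);
  int h_best_idx = -1;
  cudaMemcpy(&h_best_idx, d_best_idx, sizeof(int), cudaMemcpyDeviceToHost);
  std::printf("best gain at thread %d\n", h_best_idx);  // expect 36
  cudaFree(d_gain);
  cudaFree(d_best_idx);
  return 0;
}

Because only thread 0 holds the reduction result, the kernel in this commit first publishes it to shared memory (block_max) and only then lets the winning thread write the split candidate.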

View File

@@ -62,17 +62,16 @@ class GPUHistBuilder {
                RegTree *p_tree);
   void BuildHist(int depth);
   void FindSplit(int depth);
-  void FindSplit256(int depth);
-  void FindSplit1024(int depth);
-  void FindSplitLarge(int depth);
+  template <int BLOCK_THREADS>
+  void FindSplitSpecialize(int depth);
   void InitFirstNode();
   void UpdatePosition(int depth);
   void UpdatePositionDense(int depth);
   void UpdatePositionSparse(int depth);
   void ColSampleTree();
   void ColSampleLevel();
-  bool UpdatePredictionCache(const DMatrix* data,
-                             std::vector<bst_float>* p_out_preds);
+  bool UpdatePredictionCache(const DMatrix *data,
+                             std::vector<bst_float> *p_out_preds);
 
   TrainParam param;
   GPUTrainingParam gpu_param;
@@ -82,7 +81,7 @@ class GPUHistBuilder {
   bool initialised;
   bool is_dense;
   DeviceGMat device_matrix;
-  const DMatrix* p_last_fmat_;
+  const DMatrix *p_last_fmat_;
 
   dh::bulk_allocator ba;
   dh::CubMemory cub_mem;
@@ -101,8 +100,8 @@ class GPUHistBuilder {
   dh::dvec<gpu_gpair> device_gpair;
   dh::dvec<Node> nodes;
   dh::dvec<int> feature_flags;
   dh::dvec<bool> left_child_smallest;
   dh::dvec<bst_float> prediction_cache;
   bool prediction_cache_initialised;
 
   std::vector<int> feature_set_tree;
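
The header now exposes a single FindSplitSpecialize<BLOCK_THREADS> template in place of FindSplit256/FindSplit1024/FindSplitLarge; the .cu diff earlier in this commit walks block sizes upward in steps of 32 at compile time until BLOCK_THREADS covers param.max_bin, with an explicit specialization at 1024 terminating the chain. A stand-alone sketch of that dispatch pattern (LaunchFindSplit and the two constants are illustrative names, not xgboost code):

#include <cstdio>

constexpr int kMinBlockThreads = 32;    // one warp
constexpr int kMaxBlockThreads = 1024;  // CUDA limit on threads per block

template <int BLOCK_THREADS>
void LaunchFindSplit(int max_bin) {
  if (max_bin <= BLOCK_THREADS) {
    // A real implementation would launch
    // find_split_general_kernel<BLOCK_THREADS><<<grid, BLOCK_THREADS>>>(...).
    std::printf("max_bin=%d -> BLOCK_THREADS=%d\n", max_bin, BLOCK_THREADS);
  } else {
    // Try the next larger block size; the chain unrolls at compile time.
    LaunchFindSplit<BLOCK_THREADS + 32>(max_bin);
  }
}

// Explicit specialization terminates the recursion at the largest legal size.
template <>
void LaunchFindSplit<kMaxBlockThreads>(int max_bin) {
  std::printf("max_bin=%d -> BLOCK_THREADS=%d (maximum)\n", max_bin,
              kMaxBlockThreads);
}

int main() {
  LaunchFindSplit<kMinBlockThreads>(255);   // picks 256
  LaunchFindSplit<kMinBlockThreads>(1000);  // picks 1024
  return 0;
}

Stepping by 32 matches the warp size and 1024 is the hard CUDA limit on threads per block, so the recursion always ends at a legal launch configuration; the extra comparisons cost nothing next to the kernel launch itself.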