Add categorical data support to GPU Hist. (#6164)
This commit is contained in:
parent
798af22ff4
commit
444131a2e6
@ -536,6 +536,21 @@ void CopyDeviceSpanToVector(std::vector<T> *dst, xgboost::common::Span<const T>
|
||||
cudaMemcpyDeviceToHost));
|
||||
}
|
||||
|
||||
template <class HContainer, class DContainer>
|
||||
void CopyToD(HContainer const &h, DContainer *d) {
|
||||
if (h.empty()) {
|
||||
d->clear();
|
||||
return;
|
||||
}
|
||||
d->resize(h.size());
|
||||
using HVT = std::remove_cv_t<typename HContainer::value_type>;
|
||||
using DVT = std::remove_cv_t<typename DContainer::value_type>;
|
||||
static_assert(std::is_same<HVT, DVT>::value,
|
||||
"Host and device containers must have same value type.");
|
||||
dh::safe_cuda(cudaMemcpyAsync(d->data().get(), h.data(), h.size() * sizeof(HVT),
|
||||
cudaMemcpyHostToDevice));
|
||||
}
|
||||
|
||||
// Keep track of pinned memory allocation
|
||||
struct PinnedMemory {
|
||||
void *temp_storage{nullptr};
|
||||
|
||||
@ -178,7 +178,7 @@ void ProcessBatch(int device, MetaInfo const &info, const SparsePage &page,
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
const auto& host_data = page.data.ConstHostVector();
|
||||
dh::device_vector<Entry> sorted_entries(host_data.begin() + begin,
|
||||
host_data.begin() + end);
|
||||
host_data.begin() + end);
|
||||
thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
|
||||
sorted_entries.end(), detail::EntryCompareOp());
|
||||
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
/*!
|
||||
* Copyright 2020 by XGBoost Contributors
|
||||
*/
|
||||
#include "evaluate_splits.cuh"
|
||||
#include <limits>
|
||||
#include "evaluate_splits.cuh"
|
||||
#include "../../common/categorical.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
@ -66,13 +67,84 @@ ReduceFeature(common::Span<const GradientSumT> feature_histogram,
|
||||
if (threadIdx.x == 0) {
|
||||
shared_sum = local_sum;
|
||||
}
|
||||
__syncthreads();
|
||||
cub::CTA_SYNC();
|
||||
return shared_sum;
|
||||
}
|
||||
|
||||
template <typename GradientSumT, typename TempStorageT> struct OneHotBin {
|
||||
GradientSumT __device__ operator()(
|
||||
bool thread_active, uint32_t scan_begin,
|
||||
SumCallbackOp<GradientSumT>*,
|
||||
GradientSumT const &missing,
|
||||
EvaluateSplitInputs<GradientSumT> const &inputs, TempStorageT *) {
|
||||
GradientSumT bin = thread_active
|
||||
? inputs.gradient_histogram[scan_begin + threadIdx.x]
|
||||
: GradientSumT();
|
||||
auto rest = inputs.parent_sum - bin - missing;
|
||||
return rest;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename GradientSumT>
|
||||
struct UpdateOneHot {
|
||||
void __device__ operator()(bool missing_left, uint32_t scan_begin, float gain,
|
||||
bst_feature_t fidx, GradientSumT const &missing,
|
||||
GradientSumT const &bin,
|
||||
EvaluateSplitInputs<GradientSumT> const &inputs,
|
||||
DeviceSplitCandidate *best_split) {
|
||||
int split_gidx = (scan_begin + threadIdx.x);
|
||||
float fvalue = inputs.feature_values[split_gidx];
|
||||
GradientSumT left = missing_left ? bin + missing : bin;
|
||||
GradientSumT right = inputs.parent_sum - left;
|
||||
best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue, fidx,
|
||||
GradientPair(left), GradientPair(right), true,
|
||||
inputs.param);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename GradientSumT, typename TempStorageT, typename ScanT>
|
||||
struct NumericBin {
|
||||
GradientSumT __device__ operator()(bool thread_active, uint32_t scan_begin,
|
||||
SumCallbackOp<GradientSumT>* prefix_callback,
|
||||
GradientSumT const &missing,
|
||||
EvaluateSplitInputs<GradientSumT> inputs,
|
||||
TempStorageT *temp_storage) {
|
||||
GradientSumT bin = thread_active
|
||||
? inputs.gradient_histogram[scan_begin + threadIdx.x]
|
||||
: GradientSumT();
|
||||
ScanT(temp_storage->scan).ExclusiveScan(bin, bin, cub::Sum(), *prefix_callback);
|
||||
return bin;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename GradientSumT>
|
||||
struct UpdateNumeric {
|
||||
void __device__ operator()(bool missing_left, uint32_t scan_begin, float gain,
|
||||
bst_feature_t fidx, GradientSumT const &missing,
|
||||
GradientSumT const &bin,
|
||||
EvaluateSplitInputs<GradientSumT> const &inputs,
|
||||
DeviceSplitCandidate *best_split) {
|
||||
// Use pointer from cut to indicate begin and end of bins for each feature.
|
||||
uint32_t gidx_begin = inputs.feature_segments[fidx]; // begining bin
|
||||
int split_gidx = (scan_begin + threadIdx.x) - 1;
|
||||
float fvalue;
|
||||
if (split_gidx < static_cast<int>(gidx_begin)) {
|
||||
fvalue = inputs.min_fvalue[fidx];
|
||||
} else {
|
||||
fvalue = inputs.feature_values[split_gidx];
|
||||
}
|
||||
GradientSumT left = missing_left ? bin + missing : bin;
|
||||
GradientSumT right = inputs.parent_sum - left;
|
||||
best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue,
|
||||
fidx, GradientPair(left), GradientPair(right),
|
||||
false, inputs.param);
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief Find the thread with best gain. */
|
||||
template <int BLOCK_THREADS, typename ReduceT, typename ScanT,
|
||||
typename MaxReduceT, typename TempStorageT, typename GradientSumT>
|
||||
typename MaxReduceT, typename TempStorageT, typename GradientSumT,
|
||||
typename BinFn, typename UpdateFn>
|
||||
__device__ void EvaluateFeature(
|
||||
int fidx, EvaluateSplitInputs<GradientSumT> inputs,
|
||||
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
|
||||
@ -83,12 +155,14 @@ __device__ void EvaluateFeature(
|
||||
uint32_t gidx_begin = inputs.feature_segments[fidx]; // begining bin
|
||||
uint32_t gidx_end =
|
||||
inputs.feature_segments[fidx + 1]; // end bin for i^th feature
|
||||
auto feature_hist = inputs.gradient_histogram.subspan(gidx_begin, gidx_end - gidx_begin);
|
||||
auto bin_fn = BinFn();
|
||||
auto update_fn = UpdateFn();
|
||||
|
||||
// Sum histogram bins for current feature
|
||||
GradientSumT const feature_sum =
|
||||
ReduceFeature<BLOCK_THREADS, ReduceT, TempStorageT, GradientSumT>(
|
||||
inputs.gradient_histogram.subspan(gidx_begin, gidx_end - gidx_begin),
|
||||
temp_storage);
|
||||
feature_hist, temp_storage);
|
||||
|
||||
GradientSumT const missing = inputs.parent_sum - feature_sum;
|
||||
float const null_gain = -std::numeric_limits<bst_float>::infinity();
|
||||
@ -97,12 +171,7 @@ __device__ void EvaluateFeature(
|
||||
for (int scan_begin = gidx_begin; scan_begin < gidx_end;
|
||||
scan_begin += BLOCK_THREADS) {
|
||||
bool thread_active = (scan_begin + threadIdx.x) < gidx_end;
|
||||
|
||||
// Gradient value for current bin.
|
||||
GradientSumT bin = thread_active
|
||||
? inputs.gradient_histogram[scan_begin + threadIdx.x]
|
||||
: GradientSumT();
|
||||
ScanT(temp_storage->scan).ExclusiveScan(bin, bin, cub::Sum(), prefix_op);
|
||||
auto bin = bin_fn(thread_active, scan_begin, &prefix_op, missing, inputs, temp_storage);
|
||||
|
||||
// Whether the gradient of missing values is put to the left side.
|
||||
bool missing_left = true;
|
||||
@ -127,24 +196,14 @@ __device__ void EvaluateFeature(
|
||||
block_max = best;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
cub::CTA_SYNC();
|
||||
|
||||
// Best thread updates split
|
||||
if (threadIdx.x == block_max.key) {
|
||||
int split_gidx = (scan_begin + threadIdx.x) - 1;
|
||||
float fvalue;
|
||||
if (split_gidx < static_cast<int>(gidx_begin)) {
|
||||
fvalue = inputs.min_fvalue[fidx];
|
||||
} else {
|
||||
fvalue = inputs.feature_values[split_gidx];
|
||||
}
|
||||
GradientSumT left = missing_left ? bin + missing : bin;
|
||||
GradientSumT right = inputs.parent_sum - left;
|
||||
best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue,
|
||||
fidx, GradientPair(left), GradientPair(right),
|
||||
inputs.param);
|
||||
update_fn(missing_left, scan_begin, gain, fidx, missing, bin, inputs,
|
||||
best_split);
|
||||
}
|
||||
__syncthreads();
|
||||
cub::CTA_SYNC();
|
||||
}
|
||||
}
|
||||
|
||||
@ -186,11 +245,21 @@ __global__ void EvaluateSplitsKernel(
|
||||
// One block for each feature. Features are sampled, so fidx != blockIdx.x
|
||||
int fidx = inputs.feature_set[is_left ? blockIdx.x
|
||||
: blockIdx.x - left.feature_set.size()];
|
||||
if (common::IsCat(inputs.feature_types, fidx)) {
|
||||
EvaluateFeature<BLOCK_THREADS, SumReduceT, BlockScanT, MaxReduceT,
|
||||
TempStorage, GradientSumT,
|
||||
OneHotBin<GradientSumT, TempStorage>,
|
||||
UpdateOneHot<GradientSumT>>(fidx, inputs, evaluator, &best_split,
|
||||
&temp_storage);
|
||||
} else {
|
||||
EvaluateFeature<BLOCK_THREADS, SumReduceT, BlockScanT, MaxReduceT,
|
||||
TempStorage, GradientSumT,
|
||||
NumericBin<GradientSumT, TempStorage, BlockScanT>,
|
||||
UpdateNumeric<GradientSumT>>(fidx, inputs, evaluator, &best_split,
|
||||
&temp_storage);
|
||||
}
|
||||
|
||||
EvaluateFeature<BLOCK_THREADS, SumReduceT, BlockScanT, MaxReduceT>(
|
||||
fidx, inputs, evaluator, &best_split, &temp_storage);
|
||||
|
||||
__syncthreads();
|
||||
cub::CTA_SYNC();
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
// Record best loss for each feature
|
||||
|
||||
@ -18,6 +18,7 @@ struct EvaluateSplitInputs {
|
||||
GradientSumT parent_sum;
|
||||
GPUTrainingParam param;
|
||||
common::Span<const bst_feature_t> feature_set;
|
||||
common::Span<FeatureType const> feature_types;
|
||||
common::Span<const uint32_t> feature_segments;
|
||||
common::Span<const float> feature_values;
|
||||
common::Span<const float> min_fvalue;
|
||||
|
||||
@ -59,6 +59,7 @@ struct DeviceSplitCandidate {
|
||||
DefaultDirection dir {kLeftDir};
|
||||
int findex {-1};
|
||||
float fvalue {0};
|
||||
bool is_cat { false };
|
||||
|
||||
GradientPair left_sum;
|
||||
GradientPair right_sum;
|
||||
@ -79,6 +80,7 @@ struct DeviceSplitCandidate {
|
||||
float fvalue_in, int findex_in,
|
||||
GradientPair left_sum_in,
|
||||
GradientPair right_sum_in,
|
||||
bool cat,
|
||||
const GPUTrainingParam& param) {
|
||||
if (loss_chg_in > loss_chg &&
|
||||
left_sum_in.GetHess() >= param.min_child_weight &&
|
||||
@ -86,6 +88,7 @@ struct DeviceSplitCandidate {
|
||||
loss_chg = loss_chg_in;
|
||||
dir = dir_in;
|
||||
fvalue = fvalue_in;
|
||||
is_cat = cat;
|
||||
left_sum = left_sum_in;
|
||||
right_sum = right_sum_in;
|
||||
findex = findex_in;
|
||||
@ -98,6 +101,7 @@ struct DeviceSplitCandidate {
|
||||
<< "dir: " << c.dir << ", "
|
||||
<< "findex: " << c.findex << ", "
|
||||
<< "fvalue: " << c.fvalue << ", "
|
||||
<< "is_cat: " << c.is_cat << ", "
|
||||
<< "left sum: " << c.left_sum << ", "
|
||||
<< "right sum: " << c.right_sum << std::endl;
|
||||
return os;
|
||||
|
||||
@ -19,7 +19,9 @@
|
||||
#include "../common/io.h"
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "../common/hist_util.h"
|
||||
#include "../common/bitfield.h"
|
||||
#include "../common/timer.h"
|
||||
#include "../common/categorical.h"
|
||||
#include "../data/ellpack_page.cuh"
|
||||
|
||||
#include "param.h"
|
||||
@ -161,6 +163,7 @@ template <typename GradientSumT>
|
||||
struct GPUHistMakerDevice {
|
||||
int device_id;
|
||||
EllpackPageImpl* page;
|
||||
common::Span<FeatureType const> feature_types;
|
||||
BatchParam batch_param;
|
||||
|
||||
std::unique_ptr<RowPartitioner> row_partitioner;
|
||||
@ -169,7 +172,6 @@ struct GPUHistMakerDevice {
|
||||
common::Span<GradientPair> gpair;
|
||||
|
||||
dh::caching_device_vector<int> monotone_constraints;
|
||||
dh::caching_device_vector<bst_float> prediction_cache;
|
||||
|
||||
/*! \brief Sum gradient for each node. */
|
||||
std::vector<GradientPair> node_sum_gradients;
|
||||
@ -191,9 +193,12 @@ struct GPUHistMakerDevice {
|
||||
std::unique_ptr<GradientBasedSampler> sampler;
|
||||
|
||||
std::unique_ptr<FeatureGroups> feature_groups;
|
||||
// Storing split categories for last node.
|
||||
dh::caching_device_vector<uint32_t> node_categories;
|
||||
|
||||
GPUHistMakerDevice(int _device_id,
|
||||
EllpackPageImpl* _page,
|
||||
common::Span<FeatureType const> _feature_types,
|
||||
bst_uint _n_rows,
|
||||
TrainParam _param,
|
||||
uint32_t column_sampler_seed,
|
||||
@ -202,6 +207,7 @@ struct GPUHistMakerDevice {
|
||||
BatchParam _batch_param)
|
||||
: device_id(_device_id),
|
||||
page(_page),
|
||||
feature_types{_feature_types},
|
||||
param(std::move(_param)),
|
||||
tree_evaluator(param, n_features, _device_id),
|
||||
column_sampler(column_sampler_seed),
|
||||
@ -293,6 +299,7 @@ struct GPUHistMakerDevice {
|
||||
{root_sum.GetGrad(), root_sum.GetHess()},
|
||||
gpu_param,
|
||||
feature_set,
|
||||
feature_types,
|
||||
matrix.feature_segments,
|
||||
matrix.gidx_fvalue_map,
|
||||
matrix.min_fvalue,
|
||||
@ -331,6 +338,7 @@ struct GPUHistMakerDevice {
|
||||
candidate.split.left_sum.GetHess()},
|
||||
gpu_param,
|
||||
left_feature_set,
|
||||
feature_types,
|
||||
matrix.feature_segments,
|
||||
matrix.gidx_fvalue_map,
|
||||
matrix.min_fvalue,
|
||||
@ -341,6 +349,7 @@ struct GPUHistMakerDevice {
|
||||
candidate.split.right_sum.GetHess()},
|
||||
gpu_param,
|
||||
right_feature_set,
|
||||
feature_types,
|
||||
matrix.feature_segments,
|
||||
matrix.gidx_fvalue_map,
|
||||
matrix.min_fvalue,
|
||||
@ -399,8 +408,11 @@ struct GPUHistMakerDevice {
|
||||
hist.HistogramExists(nidx_parent);
|
||||
}
|
||||
|
||||
void UpdatePosition(int nidx, RegTree::Node split_node) {
|
||||
void UpdatePosition(int nidx, RegTree* p_tree) {
|
||||
RegTree::Node split_node = (*p_tree)[nidx];
|
||||
auto split_type = p_tree->NodeSplitType(nidx);
|
||||
auto d_matrix = page->GetDeviceAccessor(device_id);
|
||||
auto node_cats = dh::ToSpan(node_categories);
|
||||
|
||||
row_partitioner->UpdatePosition(
|
||||
nidx, split_node.LeftChild(), split_node.RightChild(),
|
||||
@ -409,11 +421,17 @@ struct GPUHistMakerDevice {
|
||||
bst_float cut_value =
|
||||
d_matrix.GetFvalue(ridx, split_node.SplitIndex());
|
||||
// Missing value
|
||||
int new_position = 0;
|
||||
bst_node_t new_position = 0;
|
||||
if (isnan(cut_value)) {
|
||||
new_position = split_node.DefaultChild();
|
||||
} else {
|
||||
if (cut_value <= split_node.SplitCond()) {
|
||||
bool go_left = true;
|
||||
if (split_type == FeatureType::kCategorical) {
|
||||
go_left = common::Decision(node_cats, common::AsCat(cut_value));
|
||||
} else {
|
||||
go_left = cut_value <= split_node.SplitCond();
|
||||
}
|
||||
if (go_left) {
|
||||
new_position = split_node.LeftChild();
|
||||
} else {
|
||||
new_position = split_node.RightChild();
|
||||
@ -428,59 +446,84 @@ struct GPUHistMakerDevice {
|
||||
// prediction cache
|
||||
void FinalisePosition(RegTree const* p_tree, DMatrix* p_fmat) {
|
||||
dh::TemporaryArray<RegTree::Node> d_nodes(p_tree->GetNodes().size());
|
||||
dh::safe_cuda(cudaMemcpy(d_nodes.data().get(), p_tree->GetNodes().data(),
|
||||
d_nodes.size() * sizeof(RegTree::Node),
|
||||
cudaMemcpyHostToDevice));
|
||||
dh::safe_cuda(cudaMemcpyAsync(d_nodes.data().get(), p_tree->GetNodes().data(),
|
||||
d_nodes.size() * sizeof(RegTree::Node),
|
||||
cudaMemcpyHostToDevice));
|
||||
auto const& h_split_types = p_tree->GetSplitTypes();
|
||||
auto const& categories = p_tree->GetSplitCategories();
|
||||
auto const& categories_segments = p_tree->GetSplitCategoriesPtr();
|
||||
|
||||
dh::caching_device_vector<FeatureType> d_split_types;
|
||||
dh::caching_device_vector<uint32_t> d_categories;
|
||||
dh::caching_device_vector<RegTree::Segment> d_categories_segments;
|
||||
|
||||
if (!categories.empty()) {
|
||||
dh::CopyToD(h_split_types, &d_split_types);
|
||||
dh::CopyToD(categories, &d_categories);
|
||||
dh::CopyToD(categories_segments, &d_categories_segments);
|
||||
}
|
||||
|
||||
if (row_partitioner->GetRows().size() != p_fmat->Info().num_row_) {
|
||||
row_partitioner.reset(); // Release the device memory first before reallocating
|
||||
row_partitioner.reset(new RowPartitioner(device_id, p_fmat->Info().num_row_));
|
||||
}
|
||||
if (page->n_rows == p_fmat->Info().num_row_) {
|
||||
FinalisePositionInPage(page, dh::ToSpan(d_nodes));
|
||||
FinalisePositionInPage(page, dh::ToSpan(d_nodes),
|
||||
dh::ToSpan(d_split_types), dh::ToSpan(d_categories),
|
||||
dh::ToSpan(d_categories_segments));
|
||||
} else {
|
||||
for (auto& batch : p_fmat->GetBatches<EllpackPage>(batch_param)) {
|
||||
FinalisePositionInPage(batch.Impl(), dh::ToSpan(d_nodes));
|
||||
FinalisePositionInPage(batch.Impl(), dh::ToSpan(d_nodes),
|
||||
dh::ToSpan(d_split_types), dh::ToSpan(d_categories),
|
||||
dh::ToSpan(d_categories_segments));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void FinalisePositionInPage(EllpackPageImpl* page, const common::Span<RegTree::Node> d_nodes) {
|
||||
void FinalisePositionInPage(EllpackPageImpl *page,
|
||||
const common::Span<RegTree::Node> d_nodes,
|
||||
common::Span<FeatureType const> d_feature_types,
|
||||
common::Span<uint32_t const> categories,
|
||||
common::Span<RegTree::Segment> categories_segments) {
|
||||
auto d_matrix = page->GetDeviceAccessor(device_id);
|
||||
row_partitioner->FinalisePosition(
|
||||
[=] __device__(size_t row_id, int position) {
|
||||
if (!d_matrix.IsInRange(row_id)) {
|
||||
return RowPartitioner::kIgnoredTreePosition;
|
||||
}
|
||||
auto node = d_nodes[position];
|
||||
|
||||
while (!node.IsLeaf()) {
|
||||
bst_float element = d_matrix.GetFvalue(row_id, node.SplitIndex());
|
||||
// Missing value
|
||||
if (isnan(element)) {
|
||||
position = node.DefaultChild();
|
||||
} else {
|
||||
if (element <= node.SplitCond()) {
|
||||
position = node.LeftChild();
|
||||
} else {
|
||||
position = node.RightChild();
|
||||
// What happens if user prune the tree?
|
||||
if (!d_matrix.IsInRange(row_id)) {
|
||||
return RowPartitioner::kIgnoredTreePosition;
|
||||
}
|
||||
}
|
||||
node = d_nodes[position];
|
||||
}
|
||||
return position;
|
||||
});
|
||||
auto node = d_nodes[position];
|
||||
|
||||
while (!node.IsLeaf()) {
|
||||
bst_float element = d_matrix.GetFvalue(row_id, node.SplitIndex());
|
||||
// Missing value
|
||||
if (isnan(element)) {
|
||||
position = node.DefaultChild();
|
||||
} else {
|
||||
bool go_left = true;
|
||||
if (common::IsCat(d_feature_types, position)) {
|
||||
auto node_cats =
|
||||
categories.subspan(categories_segments[position].beg,
|
||||
categories_segments[position].size);
|
||||
go_left = common::Decision(node_cats, common::AsCat(element));
|
||||
} else {
|
||||
go_left = element <= node.SplitCond();
|
||||
}
|
||||
if (go_left) {
|
||||
position = node.LeftChild();
|
||||
} else {
|
||||
position = node.RightChild();
|
||||
}
|
||||
}
|
||||
node = d_nodes[position];
|
||||
}
|
||||
return position;
|
||||
});
|
||||
}
|
||||
|
||||
void UpdatePredictionCache(bst_float* out_preds_d) {
|
||||
void UpdatePredictionCache(common::Span<bst_float> out_preds_d) {
|
||||
dh::safe_cuda(cudaSetDevice(device_id));
|
||||
auto d_ridx = row_partitioner->GetRows();
|
||||
if (prediction_cache.size() != d_ridx.size()) {
|
||||
prediction_cache.resize(d_ridx.size());
|
||||
dh::safe_cuda(cudaMemcpyAsync(prediction_cache.data().get(), out_preds_d,
|
||||
prediction_cache.size() * sizeof(bst_float),
|
||||
cudaMemcpyDefault));
|
||||
}
|
||||
|
||||
GPUTrainingParam param_d(param);
|
||||
dh::TemporaryArray<GradientPair> device_node_sum_gradients(node_sum_gradients.size());
|
||||
@ -491,21 +534,16 @@ struct GPUHistMakerDevice {
|
||||
cudaMemcpyHostToDevice));
|
||||
auto d_position = row_partitioner->GetPosition();
|
||||
auto d_node_sum_gradients = device_node_sum_gradients.data().get();
|
||||
auto d_prediction_cache = prediction_cache.data().get();
|
||||
auto evaluator = tree_evaluator.GetEvaluator<GPUTrainingParam>();
|
||||
|
||||
dh::LaunchN(
|
||||
device_id, prediction_cache.size(), [=] __device__(int local_idx) {
|
||||
device_id, out_preds_d.size(), [=] __device__(int local_idx) {
|
||||
int pos = d_position[local_idx];
|
||||
bst_float weight = evaluator.CalcWeight(pos, param_d,
|
||||
GradStats{d_node_sum_gradients[pos]});
|
||||
d_prediction_cache[d_ridx[local_idx]] +=
|
||||
out_preds_d[d_ridx[local_idx]] +=
|
||||
weight * param_d.learning_rate;
|
||||
});
|
||||
|
||||
dh::safe_cuda(cudaMemcpyAsync(
|
||||
out_preds_d, prediction_cache.data().get(),
|
||||
prediction_cache.size() * sizeof(bst_float), cudaMemcpyDefault));
|
||||
row_partitioner.reset();
|
||||
}
|
||||
|
||||
@ -561,11 +599,27 @@ struct GPUHistMakerDevice {
|
||||
auto left_weight = candidate.left_weight * param.learning_rate;
|
||||
auto right_weight = candidate.right_weight * param.learning_rate;
|
||||
|
||||
tree.ExpandNode(candidate.nid, candidate.split.findex,
|
||||
candidate.split.fvalue, candidate.split.dir == kLeftDir,
|
||||
base_weight, left_weight, right_weight,
|
||||
candidate.split.loss_chg, parent_sum.GetHess(),
|
||||
candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess());
|
||||
auto is_cat = candidate.split.is_cat;
|
||||
if (is_cat) {
|
||||
auto cat = common::AsCat(candidate.split.fvalue);
|
||||
std::vector<uint32_t> split_cats(LBitField32::ComputeStorageSize(std::max(cat+1, 1)), 0);
|
||||
LBitField32 cats_bits(split_cats);
|
||||
cats_bits.Set(cat);
|
||||
dh::CopyToD(split_cats, &node_categories);
|
||||
tree.ExpandCategorical(
|
||||
candidate.nid, candidate.split.findex, split_cats,
|
||||
candidate.split.dir == kLeftDir, base_weight, left_weight,
|
||||
right_weight, candidate.split.loss_chg, parent_sum.GetHess(),
|
||||
candidate.split.left_sum.GetHess(),
|
||||
candidate.split.right_sum.GetHess());
|
||||
} else {
|
||||
tree.ExpandNode(candidate.nid, candidate.split.findex,
|
||||
candidate.split.fvalue, candidate.split.dir == kLeftDir,
|
||||
base_weight, left_weight, right_weight,
|
||||
candidate.split.loss_chg, parent_sum.GetHess(),
|
||||
candidate.split.left_sum.GetHess(),
|
||||
candidate.split.right_sum.GetHess());
|
||||
}
|
||||
|
||||
// Set up child constraints
|
||||
auto left_child = tree[candidate.nid].LeftChild();
|
||||
@ -664,7 +718,7 @@ struct GPUHistMakerDevice {
|
||||
if (ExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx),
|
||||
num_leaves)) {
|
||||
monitor.Start("UpdatePosition");
|
||||
this->UpdatePosition(candidate.nid, (*p_tree)[candidate.nid]);
|
||||
this->UpdatePosition(candidate.nid, p_tree);
|
||||
monitor.Stop("UpdatePosition");
|
||||
|
||||
monitor.Start("BuildHist");
|
||||
@ -752,8 +806,10 @@ class GPUHistMakerSpecialised {
|
||||
};
|
||||
auto page = (*dmat->GetBatches<EllpackPage>(batch_param).begin()).Impl();
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
info_->feature_types.SetDevice(device_);
|
||||
maker.reset(new GPUHistMakerDevice<GradientSumT>(device_,
|
||||
page,
|
||||
info_->feature_types.ConstDeviceSpan(),
|
||||
info_->num_row_,
|
||||
param_,
|
||||
column_sampling_seed,
|
||||
@ -804,7 +860,7 @@ class GPUHistMakerSpecialised {
|
||||
}
|
||||
monitor_.Start("UpdatePredictionCache");
|
||||
p_out_preds->SetDevice(device_);
|
||||
maker->UpdatePredictionCache(p_out_preds->DevicePointer());
|
||||
maker->UpdatePredictionCache(p_out_preds->DeviceSpan());
|
||||
monitor_.Stop("UpdatePredictionCache");
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -95,8 +95,7 @@ void TestQuantileElemRank(int32_t device, Span<SketchEntry const> in,
|
||||
|
||||
TEST(GPUQuantile, Prune) {
|
||||
constexpr size_t kRows = 1000, kCols = 100;
|
||||
RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins,
|
||||
MetaInfo const &info) {
|
||||
RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const& info) {
|
||||
HostDeviceVector<FeatureType> ft;
|
||||
SketchContainer sketch(ft, n_bins, kCols, kRows, 0);
|
||||
|
||||
@ -293,9 +292,8 @@ TEST(GPUQuantile, AllReduceBasic) {
|
||||
}
|
||||
|
||||
constexpr size_t kRows = 1000, kCols = 100;
|
||||
RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins,
|
||||
MetaInfo const &info) {
|
||||
// Set up single node version
|
||||
RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const& info) {
|
||||
// Set up single node version;
|
||||
HostDeviceVector<FeatureType> ft;
|
||||
SketchContainer sketch_on_single_node(ft, n_bins, kCols, kRows, 0);
|
||||
|
||||
|
||||
@ -15,7 +15,7 @@ auto ZeroParam() {
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
TEST(GpuHist, EvaluateSingleSplit) {
|
||||
void TestEvaluateSingleSplit(bool is_categorical) {
|
||||
thrust::device_vector<DeviceSplitCandidate> out_splits(1);
|
||||
GradientPair parent_sum(0.0, 1.0);
|
||||
TrainParam tparam = ZeroParam();
|
||||
@ -33,11 +33,19 @@ TEST(GpuHist, EvaluateSingleSplit) {
|
||||
thrust::device_vector<GradientPair> feature_histogram =
|
||||
std::vector<GradientPair>{
|
||||
{-0.5, 0.5}, {0.5, 0.5}, {-1.0, 0.5}, {1.0, 0.5}};
|
||||
|
||||
thrust::device_vector<int> monotonic_constraints(feature_set.size(), 0);
|
||||
dh::device_vector<FeatureType> feature_types(feature_set.size(),
|
||||
FeatureType::kCategorical);
|
||||
common::Span<FeatureType> d_feature_types;
|
||||
if (is_categorical) {
|
||||
d_feature_types = dh::ToSpan(feature_types);
|
||||
}
|
||||
EvaluateSplitInputs<GradientPair> input{1,
|
||||
parent_sum,
|
||||
param,
|
||||
dh::ToSpan(feature_set),
|
||||
d_feature_types,
|
||||
dh::ToSpan(feature_segments),
|
||||
dh::ToSpan(feature_values),
|
||||
dh::ToSpan(feature_min_values),
|
||||
@ -55,6 +63,14 @@ TEST(GpuHist, EvaluateSingleSplit) {
|
||||
parent_sum.GetHess());
|
||||
}
|
||||
|
||||
TEST(GpuHist, EvaluateSingleSplit) {
|
||||
TestEvaluateSingleSplit(false);
|
||||
}
|
||||
|
||||
TEST(GpuHist, EvaluateCategoricalSplit) {
|
||||
TestEvaluateSingleSplit(true);
|
||||
}
|
||||
|
||||
TEST(GpuHist, EvaluateSingleSplitMissing) {
|
||||
thrust::device_vector<DeviceSplitCandidate> out_splits(1);
|
||||
GradientPair parent_sum(1.0, 1.5);
|
||||
@ -74,6 +90,7 @@ TEST(GpuHist, EvaluateSingleSplitMissing) {
|
||||
parent_sum,
|
||||
param,
|
||||
dh::ToSpan(feature_set),
|
||||
{},
|
||||
dh::ToSpan(feature_segments),
|
||||
dh::ToSpan(feature_values),
|
||||
dh::ToSpan(feature_min_values),
|
||||
@ -134,6 +151,7 @@ TEST(GpuHist, EvaluateSingleSplitFeatureSampling) {
|
||||
parent_sum,
|
||||
param,
|
||||
dh::ToSpan(feature_set),
|
||||
{},
|
||||
dh::ToSpan(feature_segments),
|
||||
dh::ToSpan(feature_values),
|
||||
dh::ToSpan(feature_min_values),
|
||||
@ -174,6 +192,7 @@ TEST(GpuHist, EvaluateSingleSplitBreakTies) {
|
||||
parent_sum,
|
||||
param,
|
||||
dh::ToSpan(feature_set),
|
||||
{},
|
||||
dh::ToSpan(feature_segments),
|
||||
dh::ToSpan(feature_values),
|
||||
dh::ToSpan(feature_min_values),
|
||||
@ -215,6 +234,7 @@ TEST(GpuHist, EvaluateSplits) {
|
||||
parent_sum,
|
||||
param,
|
||||
dh::ToSpan(feature_set),
|
||||
{},
|
||||
dh::ToSpan(feature_segments),
|
||||
dh::ToSpan(feature_values),
|
||||
dh::ToSpan(feature_min_values),
|
||||
@ -224,6 +244,7 @@ TEST(GpuHist, EvaluateSplits) {
|
||||
parent_sum,
|
||||
param,
|
||||
dh::ToSpan(feature_set),
|
||||
{},
|
||||
dh::ToSpan(feature_segments),
|
||||
dh::ToSpan(feature_values),
|
||||
dh::ToSpan(feature_min_values),
|
||||
@ -241,6 +262,5 @@ TEST(GpuHist, EvaluateSplits) {
|
||||
EXPECT_EQ(result_right.findex, 0);
|
||||
EXPECT_EQ(result_right.fvalue, 1.0);
|
||||
}
|
||||
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
@ -80,7 +80,7 @@ void TestBuildHist(bool use_shared_memory_histograms) {
|
||||
param.Init(args);
|
||||
auto page = BuildEllpackPage(kNRows, kNCols);
|
||||
BatchParam batch_param{};
|
||||
GPUHistMakerDevice<GradientSumT> maker(0, page.get(), kNRows, param, kNCols, kNCols,
|
||||
GPUHistMakerDevice<GradientSumT> maker(0, page.get(), {}, kNRows, param, kNCols, kNCols,
|
||||
true, batch_param);
|
||||
xgboost::SimpleLCG gen;
|
||||
xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f);
|
||||
@ -130,6 +130,48 @@ TEST(GpuHist, BuildHistSharedMem) {
|
||||
TestBuildHist<GradientPair>(true);
|
||||
}
|
||||
|
||||
TEST(GpuHist, ApplySplit) {
|
||||
RegTree tree;
|
||||
ExpandEntry candidate;
|
||||
candidate.nid = 0;
|
||||
candidate.left_weight = 1.0f;
|
||||
candidate.right_weight = 2.0f;
|
||||
candidate.base_weight = 3.0f;
|
||||
candidate.split.is_cat = true;
|
||||
candidate.split.fvalue = 1.0f; // at cat 1
|
||||
|
||||
size_t n_rows = 10;
|
||||
size_t n_cols = 10;
|
||||
|
||||
auto m = RandomDataGenerator{n_rows, n_cols, 0}.GenerateDMatrix(true);
|
||||
GenericParameter p;
|
||||
p.InitAllowUnknown(Args{});
|
||||
|
||||
TrainParam tparam;
|
||||
tparam.InitAllowUnknown(Args{});
|
||||
BatchParam bparam;
|
||||
bparam.gpu_id = 0;
|
||||
bparam.max_bin = 3;
|
||||
bparam.gpu_page_size = 0;
|
||||
|
||||
for (auto& ellpack : m->GetBatches<EllpackPage>(bparam)){
|
||||
auto impl = ellpack.Impl();
|
||||
HostDeviceVector<FeatureType> feature_types(10, FeatureType::kCategorical);
|
||||
feature_types.SetDevice(bparam.gpu_id);
|
||||
tree::GPUHistMakerDevice<GradientPairPrecise> updater(
|
||||
0, impl, feature_types.ConstDeviceSpan(), n_rows, tparam, 0, n_cols, true, bparam);
|
||||
updater.ApplySplit(candidate, &tree);
|
||||
|
||||
ASSERT_EQ(tree.GetSplitTypes().size(), 3);
|
||||
ASSERT_EQ(tree.GetSplitTypes()[0], FeatureType::kCategorical);
|
||||
ASSERT_EQ(tree.GetSplitCategories().size(), 1);
|
||||
uint32_t bits = 1u << 30; // bits: 0, 1, 0, 0, 0, ..., 0
|
||||
ASSERT_EQ(tree.GetSplitCategories().back(), bits);
|
||||
|
||||
ASSERT_EQ(updater.node_categories.size(), 1);
|
||||
}
|
||||
}
|
||||
|
||||
HistogramCutsWrapper GetHostCutMatrix () {
|
||||
HistogramCutsWrapper cmat;
|
||||
cmat.SetPtrs({0, 3, 6, 9, 12, 15, 18, 21, 24});
|
||||
@ -154,19 +196,18 @@ TEST(GpuHist, EvaluateRootSplit) {
|
||||
|
||||
TrainParam param;
|
||||
|
||||
std::vector<std::pair<std::string, std::string>> args {
|
||||
{"max_depth", "1"},
|
||||
{"max_leaves", "0"},
|
||||
std::vector<std::pair<std::string, std::string>> args{
|
||||
{"max_depth", "1"},
|
||||
{"max_leaves", "0"},
|
||||
|
||||
// Disable all other parameters.
|
||||
{"colsample_bynode", "1"},
|
||||
{"colsample_bylevel", "1"},
|
||||
{"colsample_bytree", "1"},
|
||||
{"min_child_weight", "0.01"},
|
||||
{"reg_alpha", "0"},
|
||||
{"reg_lambda", "0"},
|
||||
{"max_delta_step", "0"}
|
||||
};
|
||||
// Disable all other parameters.
|
||||
{"colsample_bynode", "1"},
|
||||
{"colsample_bylevel", "1"},
|
||||
{"colsample_bytree", "1"},
|
||||
{"min_child_weight", "0.01"},
|
||||
{"reg_alpha", "0"},
|
||||
{"reg_lambda", "0"},
|
||||
{"max_delta_step", "0"}};
|
||||
param.Init(args);
|
||||
for (size_t i = 0; i < kNCols; ++i) {
|
||||
param.monotone_constraints.emplace_back(0);
|
||||
@ -178,7 +219,7 @@ TEST(GpuHist, EvaluateRootSplit) {
|
||||
auto page = BuildEllpackPage(kNRows, kNCols);
|
||||
BatchParam batch_param{};
|
||||
GPUHistMakerDevice<GradientPairPrecise>
|
||||
maker(0, page.get(), kNRows, param, kNCols, kNCols, true, batch_param);
|
||||
maker(0, page.get(), {}, kNRows, param, kNCols, kNCols, true, batch_param);
|
||||
// Initialize GPUHistMakerDevice::node_sum_gradients
|
||||
maker.node_sum_gradients = {};
|
||||
|
||||
@ -257,7 +298,6 @@ void TestHistogramIndexImpl() {
|
||||
|
||||
ASSERT_EQ(maker->page->Cuts().TotalBins(), maker_ext->page->Cuts().TotalBins());
|
||||
ASSERT_EQ(maker->page->gidx_buffer.Size(), maker_ext->page->gidx_buffer.Size());
|
||||
|
||||
}
|
||||
|
||||
TEST(GpuHist, TestHistogramIndex) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user