Add categorical data support to GPU Hist. (#6164)

Jiaming Yuan 2020-09-29 11:27:25 +08:00 committed by GitHub
parent 798af22ff4
commit 444131a2e6
9 changed files with 306 additions and 103 deletions

View File

@ -536,6 +536,21 @@ void CopyDeviceSpanToVector(std::vector<T> *dst, xgboost::common::Span<const T>
cudaMemcpyDeviceToHost));
}
template <class HContainer, class DContainer>
void CopyToD(HContainer const &h, DContainer *d) {
if (h.empty()) {
d->clear();
return;
}
d->resize(h.size());
using HVT = std::remove_cv_t<typename HContainer::value_type>;
using DVT = std::remove_cv_t<typename DContainer::value_type>;
static_assert(std::is_same<HVT, DVT>::value,
"Host and device containers must have same value type.");
dh::safe_cuda(cudaMemcpyAsync(d->data().get(), h.data(), h.size() * sizeof(HVT),
cudaMemcpyHostToDevice));
}
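As a usage sketch (not part of this commit), the new helper pairs a host std::vector with a dh::caching_device_vector, the way the tree updater later uploads split categories; the function and variable names below are hypothetical.

// Hypothetical example of the CopyToD helper defined above: the destination is
// resized and an async host-to-device copy of matching value types is issued.
inline void UploadCategoriesExample(std::vector<uint32_t> const &h_categories,
                                    dh::caching_device_vector<uint32_t> *d_categories) {
  dh::CopyToD(h_categories, d_categories);  // clears the device vector when the host side is empty
}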
// Keep track of pinned memory allocation
struct PinnedMemory {
void *temp_storage{nullptr};

View File

@ -178,7 +178,7 @@ void ProcessBatch(int device, MetaInfo const &info, const SparsePage &page,
dh::XGBCachingDeviceAllocator<char> alloc;
const auto& host_data = page.data.ConstHostVector();
dh::device_vector<Entry> sorted_entries(host_data.begin() + begin,
host_data.begin() + end);
thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
sorted_entries.end(), detail::EntryCompareOp());

View File

@ -1,8 +1,9 @@
/*!
* Copyright 2020 by XGBoost Contributors
*/
#include "evaluate_splits.cuh"
#include <limits>
#include "evaluate_splits.cuh"
#include "../../common/categorical.h"
namespace xgboost {
namespace tree {
@ -66,13 +67,84 @@ ReduceFeature(common::Span<const GradientSumT> feature_histogram,
if (threadIdx.x == 0) {
shared_sum = local_sum;
}
cub::CTA_SYNC();
return shared_sum;
}
template <typename GradientSumT, typename TempStorageT> struct OneHotBin {
GradientSumT __device__ operator()(
bool thread_active, uint32_t scan_begin,
SumCallbackOp<GradientSumT>*,
GradientSumT const &missing,
EvaluateSplitInputs<GradientSumT> const &inputs, TempStorageT *) {
GradientSumT bin = thread_active
? inputs.gradient_histogram[scan_begin + threadIdx.x]
: GradientSumT();
auto rest = inputs.parent_sum - bin - missing;
return rest;
}
};
template <typename GradientSumT>
struct UpdateOneHot {
void __device__ operator()(bool missing_left, uint32_t scan_begin, float gain,
bst_feature_t fidx, GradientSumT const &missing,
GradientSumT const &bin,
EvaluateSplitInputs<GradientSumT> const &inputs,
DeviceSplitCandidate *best_split) {
int split_gidx = (scan_begin + threadIdx.x);
float fvalue = inputs.feature_values[split_gidx];
GradientSumT left = missing_left ? bin + missing : bin;
GradientSumT right = inputs.parent_sum - left;
best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue, fidx,
GradientPair(left), GradientPair(right), true,
inputs.param);
}
};
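A host-side sketch of the arithmetic behind OneHotBin and UpdateOneHot (illustrative only, not part of this diff): for a categorical feature each histogram bin holds the gradient sum of a single category, so a candidate split separates that one category from all others, with the missing-value sum assigned to whichever side the default direction picks. The Sum type and helper below are hypothetical stand-ins for GradientSumT.

struct Sum { double grad{0}, hess{0}; };  // stand-in for GradientSumT

inline Sum operator-(Sum a, Sum b) { return {a.grad - b.grad, a.hess - b.hess}; }
inline Sum operator+(Sum a, Sum b) { return {a.grad + b.grad, a.hess + b.hess}; }

// parent = sum over all rows, category_bin = sum for the candidate category,
// missing = parent minus the sum of all bins of this feature.
inline void OneHotSplitSums(Sum parent, Sum category_bin, Sum missing,
                            bool missing_left, Sum *left, Sum *right) {
  Sum rest = parent - category_bin - missing;    // what OneHotBin returns
  *left = missing_left ? rest + missing : rest;  // what UpdateOneHot computes
  *right = parent - *left;                       // the single category (plus missing if it goes right)
}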
template <typename GradientSumT, typename TempStorageT, typename ScanT>
struct NumericBin {
GradientSumT __device__ operator()(bool thread_active, uint32_t scan_begin,
SumCallbackOp<GradientSumT>* prefix_callback,
GradientSumT const &missing,
EvaluateSplitInputs<GradientSumT> inputs,
TempStorageT *temp_storage) {
GradientSumT bin = thread_active
? inputs.gradient_histogram[scan_begin + threadIdx.x]
: GradientSumT();
ScanT(temp_storage->scan).ExclusiveScan(bin, bin, cub::Sum(), *prefix_callback);
return bin;
}
};
template <typename GradientSumT>
struct UpdateNumeric {
void __device__ operator()(bool missing_left, uint32_t scan_begin, float gain,
bst_feature_t fidx, GradientSumT const &missing,
GradientSumT const &bin,
EvaluateSplitInputs<GradientSumT> const &inputs,
DeviceSplitCandidate *best_split) {
// Use pointer from cut to indicate begin and end of bins for each feature.
uint32_t gidx_begin = inputs.feature_segments[fidx]; // beginning bin
int split_gidx = (scan_begin + threadIdx.x) - 1;
float fvalue;
if (split_gidx < static_cast<int>(gidx_begin)) {
fvalue = inputs.min_fvalue[fidx];
} else {
fvalue = inputs.feature_values[split_gidx];
}
GradientSumT left = missing_left ? bin + missing : bin;
GradientSumT right = inputs.parent_sum - left;
best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue,
fidx, GradientPair(left), GradientPair(right),
false, inputs.param);
}
};
/*! \brief Find the thread with best gain. */
template <int BLOCK_THREADS, typename ReduceT, typename ScanT,
typename MaxReduceT, typename TempStorageT, typename GradientSumT,
typename BinFn, typename UpdateFn>
__device__ void EvaluateFeature(
int fidx, EvaluateSplitInputs<GradientSumT> inputs,
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
@ -83,12 +155,14 @@ __device__ void EvaluateFeature(
uint32_t gidx_begin = inputs.feature_segments[fidx]; // beginning bin
uint32_t gidx_end =
inputs.feature_segments[fidx + 1]; // end bin for i^th feature
auto feature_hist = inputs.gradient_histogram.subspan(gidx_begin, gidx_end - gidx_begin);
auto bin_fn = BinFn();
auto update_fn = UpdateFn();
// Sum histogram bins for current feature
GradientSumT const feature_sum =
ReduceFeature<BLOCK_THREADS, ReduceT, TempStorageT, GradientSumT>(
feature_hist, temp_storage);
GradientSumT const missing = inputs.parent_sum - feature_sum;
float const null_gain = -std::numeric_limits<bst_float>::infinity();
@ -97,12 +171,7 @@ __device__ void EvaluateFeature(
for (int scan_begin = gidx_begin; scan_begin < gidx_end;
scan_begin += BLOCK_THREADS) {
bool thread_active = (scan_begin + threadIdx.x) < gidx_end;
auto bin = bin_fn(thread_active, scan_begin, &prefix_op, missing, inputs, temp_storage);
// Whether the gradient of missing values is put to the left side.
bool missing_left = true;
@ -127,24 +196,14 @@ __device__ void EvaluateFeature(
block_max = best;
}
cub::CTA_SYNC();
// Best thread updates split
if (threadIdx.x == block_max.key) {
update_fn(missing_left, scan_begin, gain, fidx, missing, bin, inputs,
best_split);
}
cub::CTA_SYNC();
}
}
@ -186,11 +245,21 @@ __global__ void EvaluateSplitsKernel(
// One block for each feature. Features are sampled, so fidx != blockIdx.x
int fidx = inputs.feature_set[is_left ? blockIdx.x
: blockIdx.x - left.feature_set.size()];
if (common::IsCat(inputs.feature_types, fidx)) {
EvaluateFeature<BLOCK_THREADS, SumReduceT, BlockScanT, MaxReduceT,
TempStorage, GradientSumT,
OneHotBin<GradientSumT, TempStorage>,
UpdateOneHot<GradientSumT>>(fidx, inputs, evaluator, &best_split,
&temp_storage);
} else {
EvaluateFeature<BLOCK_THREADS, SumReduceT, BlockScanT, MaxReduceT,
TempStorage, GradientSumT,
NumericBin<GradientSumT, TempStorage, BlockScanT>,
UpdateNumeric<GradientSumT>>(fidx, inputs, evaluator, &best_split,
&temp_storage);
}
cub::CTA_SYNC();
if (threadIdx.x == 0) {
// Record best loss for each feature

View File

@ -18,6 +18,7 @@ struct EvaluateSplitInputs {
GradientSumT parent_sum;
GPUTrainingParam param;
common::Span<const bst_feature_t> feature_set;
common::Span<FeatureType const> feature_types;
common::Span<const uint32_t> feature_segments;
common::Span<const float> feature_values;
common::Span<const float> min_fvalue;

View File

@ -59,6 +59,7 @@ struct DeviceSplitCandidate {
DefaultDirection dir {kLeftDir};
int findex {-1};
float fvalue {0};
bool is_cat { false };
GradientPair left_sum;
GradientPair right_sum;
@ -79,6 +80,7 @@ struct DeviceSplitCandidate {
float fvalue_in, int findex_in,
GradientPair left_sum_in,
GradientPair right_sum_in,
bool cat,
const GPUTrainingParam& param) {
if (loss_chg_in > loss_chg &&
left_sum_in.GetHess() >= param.min_child_weight &&
@ -86,6 +88,7 @@ struct DeviceSplitCandidate {
loss_chg = loss_chg_in;
dir = dir_in;
fvalue = fvalue_in;
is_cat = cat;
left_sum = left_sum_in;
right_sum = right_sum_in;
findex = findex_in;
@ -98,6 +101,7 @@ struct DeviceSplitCandidate {
<< "dir: " << c.dir << ", "
<< "findex: " << c.findex << ", "
<< "fvalue: " << c.fvalue << ", "
<< "is_cat: " << c.is_cat << ", "
<< "left sum: " << c.left_sum << ", "
<< "right sum: " << c.right_sum << std::endl;
return os;

View File

@ -19,7 +19,9 @@
#include "../common/io.h"
#include "../common/device_helpers.cuh"
#include "../common/hist_util.h"
#include "../common/bitfield.h"
#include "../common/timer.h"
#include "../common/categorical.h"
#include "../data/ellpack_page.cuh"
#include "param.h"
@ -161,6 +163,7 @@ template <typename GradientSumT>
struct GPUHistMakerDevice {
int device_id;
EllpackPageImpl* page;
common::Span<FeatureType const> feature_types;
BatchParam batch_param;
std::unique_ptr<RowPartitioner> row_partitioner;
@ -169,7 +172,6 @@ struct GPUHistMakerDevice {
common::Span<GradientPair> gpair;
dh::caching_device_vector<int> monotone_constraints;
/*! \brief Sum gradient for each node. */
std::vector<GradientPair> node_sum_gradients;
@ -191,9 +193,12 @@ struct GPUHistMakerDevice {
std::unique_ptr<GradientBasedSampler> sampler;
std::unique_ptr<FeatureGroups> feature_groups;
// Storing split categories for last node.
dh::caching_device_vector<uint32_t> node_categories;
GPUHistMakerDevice(int _device_id,
EllpackPageImpl* _page,
common::Span<FeatureType const> _feature_types,
bst_uint _n_rows,
TrainParam _param,
uint32_t column_sampler_seed,
@ -202,6 +207,7 @@ struct GPUHistMakerDevice {
BatchParam _batch_param)
: device_id(_device_id),
page(_page),
feature_types{_feature_types},
param(std::move(_param)),
tree_evaluator(param, n_features, _device_id),
column_sampler(column_sampler_seed),
@ -293,6 +299,7 @@ struct GPUHistMakerDevice {
{root_sum.GetGrad(), root_sum.GetHess()},
gpu_param,
feature_set,
feature_types,
matrix.feature_segments,
matrix.gidx_fvalue_map,
matrix.min_fvalue,
@ -331,6 +338,7 @@ struct GPUHistMakerDevice {
candidate.split.left_sum.GetHess()},
gpu_param,
left_feature_set,
feature_types,
matrix.feature_segments,
matrix.gidx_fvalue_map,
matrix.min_fvalue,
@ -341,6 +349,7 @@ struct GPUHistMakerDevice {
candidate.split.right_sum.GetHess()},
gpu_param,
right_feature_set,
feature_types,
matrix.feature_segments,
matrix.gidx_fvalue_map,
matrix.min_fvalue,
@ -399,8 +408,11 @@ struct GPUHistMakerDevice {
hist.HistogramExists(nidx_parent);
}
void UpdatePosition(int nidx, RegTree* p_tree) {
RegTree::Node split_node = (*p_tree)[nidx];
auto split_type = p_tree->NodeSplitType(nidx);
auto d_matrix = page->GetDeviceAccessor(device_id);
auto node_cats = dh::ToSpan(node_categories);
row_partitioner->UpdatePosition(
nidx, split_node.LeftChild(), split_node.RightChild(),
@ -409,11 +421,17 @@ struct GPUHistMakerDevice {
bst_float cut_value =
d_matrix.GetFvalue(ridx, split_node.SplitIndex());
// Missing value
bst_node_t new_position = 0;
if (isnan(cut_value)) {
new_position = split_node.DefaultChild();
} else {
bool go_left = true;
if (split_type == FeatureType::kCategorical) {
go_left = common::Decision(node_cats, common::AsCat(cut_value));
} else {
go_left = cut_value <= split_node.SplitCond();
}
if (go_left) {
new_position = split_node.LeftChild();
} else {
new_position = split_node.RightChild();
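A simplified host-side model of the categorical routing above (not part of this commit): node_cats is a span of 32-bit words forming a bitset over category values, and common::Decision sends a row left unless its category is the split category, consistent with UpdateOneHot above, which places the single matched category in the right child. The bit layout here is a plain LSB-first sketch and may differ from the LBitField32 layout used by the library.

// Hypothetical stand-in for common::Decision: go left when the category bit is unset.
inline bool DecisionSketch(std::vector<uint32_t> const &node_cats, int32_t cat) {
  size_t word = static_cast<size_t>(cat) / 32;
  uint32_t mask = 1u << (static_cast<uint32_t>(cat) % 32);  // LSB-first in this sketch only
  bool is_split_category = word < node_cats.size() && (node_cats[word] & mask) != 0;
  return !is_split_category;
}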
@ -428,59 +446,84 @@ struct GPUHistMakerDevice {
// prediction cache
void FinalisePosition(RegTree const* p_tree, DMatrix* p_fmat) {
dh::TemporaryArray<RegTree::Node> d_nodes(p_tree->GetNodes().size());
dh::safe_cuda(cudaMemcpyAsync(d_nodes.data().get(), p_tree->GetNodes().data(),
d_nodes.size() * sizeof(RegTree::Node),
cudaMemcpyHostToDevice));
auto const& h_split_types = p_tree->GetSplitTypes();
auto const& categories = p_tree->GetSplitCategories();
auto const& categories_segments = p_tree->GetSplitCategoriesPtr();
dh::caching_device_vector<FeatureType> d_split_types;
dh::caching_device_vector<uint32_t> d_categories;
dh::caching_device_vector<RegTree::Segment> d_categories_segments;
if (!categories.empty()) {
dh::CopyToD(h_split_types, &d_split_types);
dh::CopyToD(categories, &d_categories);
dh::CopyToD(categories_segments, &d_categories_segments);
}
if (row_partitioner->GetRows().size() != p_fmat->Info().num_row_) {
row_partitioner.reset(); // Release the device memory first before reallocating
row_partitioner.reset(new RowPartitioner(device_id, p_fmat->Info().num_row_));
}
if (page->n_rows == p_fmat->Info().num_row_) {
FinalisePositionInPage(page, dh::ToSpan(d_nodes),
dh::ToSpan(d_split_types), dh::ToSpan(d_categories),
dh::ToSpan(d_categories_segments));
} else {
for (auto& batch : p_fmat->GetBatches<EllpackPage>(batch_param)) {
FinalisePositionInPage(batch.Impl(), dh::ToSpan(d_nodes),
dh::ToSpan(d_split_types), dh::ToSpan(d_categories),
dh::ToSpan(d_categories_segments));
}
}
}
void FinalisePositionInPage(EllpackPageImpl *page,
                            const common::Span<RegTree::Node> d_nodes,
                            common::Span<FeatureType const> d_feature_types,
                            common::Span<uint32_t const> categories,
                            common::Span<RegTree::Segment> categories_segments) {
  auto d_matrix = page->GetDeviceAccessor(device_id);
  row_partitioner->FinalisePosition(
      [=] __device__(size_t row_id, int position) {
        // What happens if user prune the tree?
        if (!d_matrix.IsInRange(row_id)) {
          return RowPartitioner::kIgnoredTreePosition;
        }
        auto node = d_nodes[position];
        while (!node.IsLeaf()) {
          bst_float element = d_matrix.GetFvalue(row_id, node.SplitIndex());
          // Missing value
          if (isnan(element)) {
            position = node.DefaultChild();
          } else {
            bool go_left = true;
            if (common::IsCat(d_feature_types, position)) {
              auto node_cats =
                  categories.subspan(categories_segments[position].beg,
                                     categories_segments[position].size);
              go_left = common::Decision(node_cats, common::AsCat(element));
            } else {
              go_left = element <= node.SplitCond();
            }
            if (go_left) {
              position = node.LeftChild();
            } else {
              position = node.RightChild();
            }
          }
          node = d_nodes[position];
        }
        return position;
      });
}
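A small sketch of the per-node category lookup performed in the traversal above (illustrative, not part of this diff): each categorical node owns a contiguous [beg, beg + size) slice of the flat array of bitset words, recorded in categories_segments. SegmentSketch is a hypothetical stand-in for RegTree::Segment.

struct SegmentSketch { size_t beg{0}; size_t size{0}; };  // stand-in for RegTree::Segment

inline std::vector<uint32_t> NodeCategoriesSketch(std::vector<uint32_t> const &categories,
                                                  std::vector<SegmentSketch> const &segments,
                                                  int nidx) {
  SegmentSketch s = segments[nidx];
  // Equivalent of categories.subspan(s.beg, s.size) in the device lambda above.
  return std::vector<uint32_t>(categories.begin() + s.beg,
                               categories.begin() + s.beg + s.size);
}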
void UpdatePredictionCache(common::Span<bst_float> out_preds_d) {
dh::safe_cuda(cudaSetDevice(device_id));
auto d_ridx = row_partitioner->GetRows();
GPUTrainingParam param_d(param);
dh::TemporaryArray<GradientPair> device_node_sum_gradients(node_sum_gradients.size());
@ -491,21 +534,16 @@ struct GPUHistMakerDevice {
cudaMemcpyHostToDevice));
auto d_position = row_partitioner->GetPosition();
auto d_node_sum_gradients = device_node_sum_gradients.data().get();
auto evaluator = tree_evaluator.GetEvaluator<GPUTrainingParam>();
dh::LaunchN(
device_id, out_preds_d.size(), [=] __device__(int local_idx) {
int pos = d_position[local_idx];
bst_float weight = evaluator.CalcWeight(pos, param_d,
GradStats{d_node_sum_gradients[pos]});
out_preds_d[d_ridx[local_idx]] +=
weight * param_d.learning_rate;
});
row_partitioner.reset();
}
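A host-side model of the kernel launched above (illustrative, not part of this commit): with the separate prediction cache gone, each row adds the learning-rate-scaled weight of the leaf it ended up in directly into the output span. The container types below are simplifications of the device buffers.

inline void UpdatePredictionsSketch(std::vector<int> const &position,       // leaf index per local row
                                    std::vector<size_t> const &ridx,        // global row index per local row
                                    std::vector<float> const &leaf_weight,  // CalcWeight result per leaf
                                    float learning_rate,
                                    std::vector<float> *out_preds) {
  for (size_t i = 0; i < ridx.size(); ++i) {
    (*out_preds)[ridx[i]] += leaf_weight[position[i]] * learning_rate;
  }
}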
@ -561,11 +599,27 @@ struct GPUHistMakerDevice {
auto left_weight = candidate.left_weight * param.learning_rate;
auto right_weight = candidate.right_weight * param.learning_rate;
auto is_cat = candidate.split.is_cat;
if (is_cat) {
auto cat = common::AsCat(candidate.split.fvalue);
std::vector<uint32_t> split_cats(LBitField32::ComputeStorageSize(std::max(cat+1, 1)), 0);
LBitField32 cats_bits(split_cats);
cats_bits.Set(cat);
dh::CopyToD(split_cats, &node_categories);
tree.ExpandCategorical(
candidate.nid, candidate.split.findex, split_cats,
candidate.split.dir == kLeftDir, base_weight, left_weight,
right_weight, candidate.split.loss_chg, parent_sum.GetHess(),
candidate.split.left_sum.GetHess(),
candidate.split.right_sum.GetHess());
} else {
tree.ExpandNode(candidate.nid, candidate.split.findex,
candidate.split.fvalue, candidate.split.dir == kLeftDir,
base_weight, left_weight, right_weight,
candidate.split.loss_chg, parent_sum.GetHess(),
candidate.split.left_sum.GetHess(),
candidate.split.right_sum.GetHess());
}
// Set up child constraints
auto left_child = tree[candidate.nid].LeftChild();
@ -664,7 +718,7 @@ struct GPUHistMakerDevice {
if (ExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx),
num_leaves)) {
monitor.Start("UpdatePosition");
this->UpdatePosition(candidate.nid, p_tree);
monitor.Stop("UpdatePosition");
monitor.Start("BuildHist");
@ -752,8 +806,10 @@ class GPUHistMakerSpecialised {
};
auto page = (*dmat->GetBatches<EllpackPage>(batch_param).begin()).Impl();
dh::safe_cuda(cudaSetDevice(device_));
info_->feature_types.SetDevice(device_);
maker.reset(new GPUHistMakerDevice<GradientSumT>(device_,
page,
info_->feature_types.ConstDeviceSpan(),
info_->num_row_,
param_,
column_sampling_seed,
@ -804,7 +860,7 @@ class GPUHistMakerSpecialised {
}
monitor_.Start("UpdatePredictionCache");
p_out_preds->SetDevice(device_);
maker->UpdatePredictionCache(p_out_preds->DeviceSpan());
monitor_.Stop("UpdatePredictionCache");
return true;
}

View File

@ -95,8 +95,7 @@ void TestQuantileElemRank(int32_t device, Span<SketchEntry const> in,
TEST(GPUQuantile, Prune) {
constexpr size_t kRows = 1000, kCols = 100;
RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const& info) {
HostDeviceVector<FeatureType> ft;
SketchContainer sketch(ft, n_bins, kCols, kRows, 0);
@ -293,9 +292,8 @@ TEST(GPUQuantile, AllReduceBasic) {
}
constexpr size_t kRows = 1000, kCols = 100;
RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const& info) {
// Set up single node version;
HostDeviceVector<FeatureType> ft;
SketchContainer sketch_on_single_node(ft, n_bins, kCols, kRows, 0);

View File

@ -15,7 +15,7 @@ auto ZeroParam() {
}
} // anonymous namespace
void TestEvaluateSingleSplit(bool is_categorical) {
thrust::device_vector<DeviceSplitCandidate> out_splits(1);
GradientPair parent_sum(0.0, 1.0);
TrainParam tparam = ZeroParam();
@ -33,11 +33,19 @@ TEST(GpuHist, EvaluateSingleSplit) {
thrust::device_vector<GradientPair> feature_histogram =
std::vector<GradientPair>{
{-0.5, 0.5}, {0.5, 0.5}, {-1.0, 0.5}, {1.0, 0.5}};
thrust::device_vector<int> monotonic_constraints(feature_set.size(), 0);
dh::device_vector<FeatureType> feature_types(feature_set.size(),
FeatureType::kCategorical);
common::Span<FeatureType> d_feature_types;
if (is_categorical) {
d_feature_types = dh::ToSpan(feature_types);
}
EvaluateSplitInputs<GradientPair> input{1,
parent_sum,
param,
dh::ToSpan(feature_set),
d_feature_types,
dh::ToSpan(feature_segments),
dh::ToSpan(feature_values),
dh::ToSpan(feature_min_values),
@ -55,6 +63,14 @@ TEST(GpuHist, EvaluateSingleSplit) {
parent_sum.GetHess());
}
TEST(GpuHist, EvaluateSingleSplit) {
TestEvaluateSingleSplit(false);
}
TEST(GpuHist, EvaluateCategoricalSplit) {
TestEvaluateSingleSplit(true);
}
TEST(GpuHist, EvaluateSingleSplitMissing) {
thrust::device_vector<DeviceSplitCandidate> out_splits(1);
GradientPair parent_sum(1.0, 1.5);
@ -74,6 +90,7 @@ TEST(GpuHist, EvaluateSingleSplitMissing) {
parent_sum,
param,
dh::ToSpan(feature_set),
{},
dh::ToSpan(feature_segments),
dh::ToSpan(feature_values),
dh::ToSpan(feature_min_values),
@ -134,6 +151,7 @@ TEST(GpuHist, EvaluateSingleSplitFeatureSampling) {
parent_sum,
param,
dh::ToSpan(feature_set),
{},
dh::ToSpan(feature_segments),
dh::ToSpan(feature_values),
dh::ToSpan(feature_min_values),
@ -174,6 +192,7 @@ TEST(GpuHist, EvaluateSingleSplitBreakTies) {
parent_sum,
param,
dh::ToSpan(feature_set),
{},
dh::ToSpan(feature_segments),
dh::ToSpan(feature_values),
dh::ToSpan(feature_min_values),
@ -215,6 +234,7 @@ TEST(GpuHist, EvaluateSplits) {
parent_sum,
param,
dh::ToSpan(feature_set),
{},
dh::ToSpan(feature_segments),
dh::ToSpan(feature_values),
dh::ToSpan(feature_min_values),
@ -224,6 +244,7 @@ TEST(GpuHist, EvaluateSplits) {
parent_sum,
param,
dh::ToSpan(feature_set),
{},
dh::ToSpan(feature_segments),
dh::ToSpan(feature_values),
dh::ToSpan(feature_min_values),
@ -241,6 +262,5 @@ TEST(GpuHist, EvaluateSplits) {
EXPECT_EQ(result_right.findex, 0);
EXPECT_EQ(result_right.fvalue, 1.0);
}
} // namespace tree
} // namespace xgboost

View File

@ -80,7 +80,7 @@ void TestBuildHist(bool use_shared_memory_histograms) {
param.Init(args);
auto page = BuildEllpackPage(kNRows, kNCols);
BatchParam batch_param{};
GPUHistMakerDevice<GradientSumT> maker(0, page.get(), {}, kNRows, param, kNCols, kNCols,
true, batch_param);
xgboost::SimpleLCG gen;
xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f);
@ -130,6 +130,48 @@ TEST(GpuHist, BuildHistSharedMem) {
TestBuildHist<GradientPair>(true);
}
TEST(GpuHist, ApplySplit) {
RegTree tree;
ExpandEntry candidate;
candidate.nid = 0;
candidate.left_weight = 1.0f;
candidate.right_weight = 2.0f;
candidate.base_weight = 3.0f;
candidate.split.is_cat = true;
candidate.split.fvalue = 1.0f; // at cat 1
size_t n_rows = 10;
size_t n_cols = 10;
auto m = RandomDataGenerator{n_rows, n_cols, 0}.GenerateDMatrix(true);
GenericParameter p;
p.InitAllowUnknown(Args{});
TrainParam tparam;
tparam.InitAllowUnknown(Args{});
BatchParam bparam;
bparam.gpu_id = 0;
bparam.max_bin = 3;
bparam.gpu_page_size = 0;
for (auto& ellpack : m->GetBatches<EllpackPage>(bparam)){
auto impl = ellpack.Impl();
HostDeviceVector<FeatureType> feature_types(10, FeatureType::kCategorical);
feature_types.SetDevice(bparam.gpu_id);
tree::GPUHistMakerDevice<GradientPairPrecise> updater(
0, impl, feature_types.ConstDeviceSpan(), n_rows, tparam, 0, n_cols, true, bparam);
updater.ApplySplit(candidate, &tree);
ASSERT_EQ(tree.GetSplitTypes().size(), 3);
ASSERT_EQ(tree.GetSplitTypes()[0], FeatureType::kCategorical);
ASSERT_EQ(tree.GetSplitCategories().size(), 1);
uint32_t bits = 1u << 30; // bits: 0, 1, 0, 0, 0, ..., 0
ASSERT_EQ(tree.GetSplitCategories().back(), bits);
ASSERT_EQ(updater.node_categories.size(), 1);
}
}
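To make the expectation in the test above concrete: LBitField32 addresses bits from the most significant end of each 32-bit word, so setting category 1 yields the single word 1u << 30. A simplified model of that layout (not the library's implementation; the helper name is hypothetical):

inline std::vector<uint32_t> OneHotCategoryBits(int32_t cat) {
  std::vector<uint32_t> bits(static_cast<size_t>(cat) / 32 + 1, 0u);
  bits[cat / 32] |= 1u << (31 - (cat % 32));  // MSB-first position within the word
  return bits;                                // cat == 1  ->  {1u << 30}
}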
HistogramCutsWrapper GetHostCutMatrix () {
HistogramCutsWrapper cmat;
cmat.SetPtrs({0, 3, 6, 9, 12, 15, 18, 21, 24});
@ -154,19 +196,18 @@ TEST(GpuHist, EvaluateRootSplit) {
TrainParam param;
std::vector<std::pair<std::string, std::string>> args{
    {"max_depth", "1"},
    {"max_leaves", "0"},
    // Disable all other parameters.
    {"colsample_bynode", "1"},
    {"colsample_bylevel", "1"},
    {"colsample_bytree", "1"},
    {"min_child_weight", "0.01"},
    {"reg_alpha", "0"},
    {"reg_lambda", "0"},
    {"max_delta_step", "0"}};
param.Init(args);
for (size_t i = 0; i < kNCols; ++i) {
param.monotone_constraints.emplace_back(0);
@ -178,7 +219,7 @@ TEST(GpuHist, EvaluateRootSplit) {
auto page = BuildEllpackPage(kNRows, kNCols);
BatchParam batch_param{};
GPUHistMakerDevice<GradientPairPrecise>
maker(0, page.get(), {}, kNRows, param, kNCols, kNCols, true, batch_param);
// Initialize GPUHistMakerDevice::node_sum_gradients
maker.node_sum_gradients = {};
@ -257,7 +298,6 @@ void TestHistogramIndexImpl() {
ASSERT_EQ(maker->page->Cuts().TotalBins(), maker_ext->page->Cuts().TotalBins());
ASSERT_EQ(maker->page->gidx_buffer.Size(), maker_ext->page->gidx_buffer.Size());
}
TEST(GpuHist, TestHistogramIndex) {