Add categorical data support to GPU Hist. (#6164)
This commit is contained in:
@@ -19,7 +19,9 @@
|
||||
#include "../common/io.h"
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "../common/hist_util.h"
|
||||
#include "../common/bitfield.h"
|
||||
#include "../common/timer.h"
|
||||
#include "../common/categorical.h"
|
||||
#include "../data/ellpack_page.cuh"
|
||||
|
||||
#include "param.h"
|
||||
@@ -161,6 +163,7 @@ template <typename GradientSumT>
|
||||
struct GPUHistMakerDevice {
|
||||
int device_id;
|
||||
EllpackPageImpl* page;
|
||||
common::Span<FeatureType const> feature_types;
|
||||
BatchParam batch_param;
|
||||
|
||||
std::unique_ptr<RowPartitioner> row_partitioner;
|
||||
@@ -169,7 +172,6 @@ struct GPUHistMakerDevice {
|
||||
common::Span<GradientPair> gpair;
|
||||
|
||||
dh::caching_device_vector<int> monotone_constraints;
|
||||
dh::caching_device_vector<bst_float> prediction_cache;
|
||||
|
||||
/*! \brief Sum gradient for each node. */
|
||||
std::vector<GradientPair> node_sum_gradients;
|
||||
@@ -191,9 +193,12 @@ struct GPUHistMakerDevice {
|
||||
std::unique_ptr<GradientBasedSampler> sampler;
|
||||
|
||||
std::unique_ptr<FeatureGroups> feature_groups;
|
||||
// Storing split categories for last node.
|
||||
dh::caching_device_vector<uint32_t> node_categories;
|
||||
|
||||
GPUHistMakerDevice(int _device_id,
|
||||
EllpackPageImpl* _page,
|
||||
common::Span<FeatureType const> _feature_types,
|
||||
bst_uint _n_rows,
|
||||
TrainParam _param,
|
||||
uint32_t column_sampler_seed,
|
||||
@@ -202,6 +207,7 @@ struct GPUHistMakerDevice {
|
||||
BatchParam _batch_param)
|
||||
: device_id(_device_id),
|
||||
page(_page),
|
||||
feature_types{_feature_types},
|
||||
param(std::move(_param)),
|
||||
tree_evaluator(param, n_features, _device_id),
|
||||
column_sampler(column_sampler_seed),
|
||||
@@ -293,6 +299,7 @@ struct GPUHistMakerDevice {
|
||||
{root_sum.GetGrad(), root_sum.GetHess()},
|
||||
gpu_param,
|
||||
feature_set,
|
||||
feature_types,
|
||||
matrix.feature_segments,
|
||||
matrix.gidx_fvalue_map,
|
||||
matrix.min_fvalue,
|
||||
@@ -331,6 +338,7 @@ struct GPUHistMakerDevice {
|
||||
candidate.split.left_sum.GetHess()},
|
||||
gpu_param,
|
||||
left_feature_set,
|
||||
feature_types,
|
||||
matrix.feature_segments,
|
||||
matrix.gidx_fvalue_map,
|
||||
matrix.min_fvalue,
|
||||
@@ -341,6 +349,7 @@ struct GPUHistMakerDevice {
|
||||
candidate.split.right_sum.GetHess()},
|
||||
gpu_param,
|
||||
right_feature_set,
|
||||
feature_types,
|
||||
matrix.feature_segments,
|
||||
matrix.gidx_fvalue_map,
|
||||
matrix.min_fvalue,
|
||||
@@ -399,8 +408,11 @@ struct GPUHistMakerDevice {
|
||||
hist.HistogramExists(nidx_parent);
|
||||
}
|
||||
|
||||
void UpdatePosition(int nidx, RegTree::Node split_node) {
|
||||
void UpdatePosition(int nidx, RegTree* p_tree) {
|
||||
RegTree::Node split_node = (*p_tree)[nidx];
|
||||
auto split_type = p_tree->NodeSplitType(nidx);
|
||||
auto d_matrix = page->GetDeviceAccessor(device_id);
|
||||
auto node_cats = dh::ToSpan(node_categories);
|
||||
|
||||
row_partitioner->UpdatePosition(
|
||||
nidx, split_node.LeftChild(), split_node.RightChild(),
|
||||
@@ -409,11 +421,17 @@ struct GPUHistMakerDevice {
|
||||
bst_float cut_value =
|
||||
d_matrix.GetFvalue(ridx, split_node.SplitIndex());
|
||||
// Missing value
|
||||
int new_position = 0;
|
||||
bst_node_t new_position = 0;
|
||||
if (isnan(cut_value)) {
|
||||
new_position = split_node.DefaultChild();
|
||||
} else {
|
||||
if (cut_value <= split_node.SplitCond()) {
|
||||
bool go_left = true;
|
||||
if (split_type == FeatureType::kCategorical) {
|
||||
go_left = common::Decision(node_cats, common::AsCat(cut_value));
|
||||
} else {
|
||||
go_left = cut_value <= split_node.SplitCond();
|
||||
}
|
||||
if (go_left) {
|
||||
new_position = split_node.LeftChild();
|
||||
} else {
|
||||
new_position = split_node.RightChild();
|
||||
@@ -428,59 +446,84 @@ struct GPUHistMakerDevice {
|
||||
// prediction cache
|
||||
void FinalisePosition(RegTree const* p_tree, DMatrix* p_fmat) {
|
||||
dh::TemporaryArray<RegTree::Node> d_nodes(p_tree->GetNodes().size());
|
||||
dh::safe_cuda(cudaMemcpy(d_nodes.data().get(), p_tree->GetNodes().data(),
|
||||
d_nodes.size() * sizeof(RegTree::Node),
|
||||
cudaMemcpyHostToDevice));
|
||||
dh::safe_cuda(cudaMemcpyAsync(d_nodes.data().get(), p_tree->GetNodes().data(),
|
||||
d_nodes.size() * sizeof(RegTree::Node),
|
||||
cudaMemcpyHostToDevice));
|
||||
auto const& h_split_types = p_tree->GetSplitTypes();
|
||||
auto const& categories = p_tree->GetSplitCategories();
|
||||
auto const& categories_segments = p_tree->GetSplitCategoriesPtr();
|
||||
|
||||
dh::caching_device_vector<FeatureType> d_split_types;
|
||||
dh::caching_device_vector<uint32_t> d_categories;
|
||||
dh::caching_device_vector<RegTree::Segment> d_categories_segments;
|
||||
|
||||
if (!categories.empty()) {
|
||||
dh::CopyToD(h_split_types, &d_split_types);
|
||||
dh::CopyToD(categories, &d_categories);
|
||||
dh::CopyToD(categories_segments, &d_categories_segments);
|
||||
}
|
||||
|
||||
if (row_partitioner->GetRows().size() != p_fmat->Info().num_row_) {
|
||||
row_partitioner.reset(); // Release the device memory first before reallocating
|
||||
row_partitioner.reset(new RowPartitioner(device_id, p_fmat->Info().num_row_));
|
||||
}
|
||||
if (page->n_rows == p_fmat->Info().num_row_) {
|
||||
FinalisePositionInPage(page, dh::ToSpan(d_nodes));
|
||||
FinalisePositionInPage(page, dh::ToSpan(d_nodes),
|
||||
dh::ToSpan(d_split_types), dh::ToSpan(d_categories),
|
||||
dh::ToSpan(d_categories_segments));
|
||||
} else {
|
||||
for (auto& batch : p_fmat->GetBatches<EllpackPage>(batch_param)) {
|
||||
FinalisePositionInPage(batch.Impl(), dh::ToSpan(d_nodes));
|
||||
FinalisePositionInPage(batch.Impl(), dh::ToSpan(d_nodes),
|
||||
dh::ToSpan(d_split_types), dh::ToSpan(d_categories),
|
||||
dh::ToSpan(d_categories_segments));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void FinalisePositionInPage(EllpackPageImpl* page, const common::Span<RegTree::Node> d_nodes) {
|
||||
void FinalisePositionInPage(EllpackPageImpl *page,
|
||||
const common::Span<RegTree::Node> d_nodes,
|
||||
common::Span<FeatureType const> d_feature_types,
|
||||
common::Span<uint32_t const> categories,
|
||||
common::Span<RegTree::Segment> categories_segments) {
|
||||
auto d_matrix = page->GetDeviceAccessor(device_id);
|
||||
row_partitioner->FinalisePosition(
|
||||
[=] __device__(size_t row_id, int position) {
|
||||
if (!d_matrix.IsInRange(row_id)) {
|
||||
return RowPartitioner::kIgnoredTreePosition;
|
||||
}
|
||||
auto node = d_nodes[position];
|
||||
|
||||
while (!node.IsLeaf()) {
|
||||
bst_float element = d_matrix.GetFvalue(row_id, node.SplitIndex());
|
||||
// Missing value
|
||||
if (isnan(element)) {
|
||||
position = node.DefaultChild();
|
||||
} else {
|
||||
if (element <= node.SplitCond()) {
|
||||
position = node.LeftChild();
|
||||
} else {
|
||||
position = node.RightChild();
|
||||
// What happens if user prune the tree?
|
||||
if (!d_matrix.IsInRange(row_id)) {
|
||||
return RowPartitioner::kIgnoredTreePosition;
|
||||
}
|
||||
}
|
||||
node = d_nodes[position];
|
||||
}
|
||||
return position;
|
||||
});
|
||||
auto node = d_nodes[position];
|
||||
|
||||
while (!node.IsLeaf()) {
|
||||
bst_float element = d_matrix.GetFvalue(row_id, node.SplitIndex());
|
||||
// Missing value
|
||||
if (isnan(element)) {
|
||||
position = node.DefaultChild();
|
||||
} else {
|
||||
bool go_left = true;
|
||||
if (common::IsCat(d_feature_types, position)) {
|
||||
auto node_cats =
|
||||
categories.subspan(categories_segments[position].beg,
|
||||
categories_segments[position].size);
|
||||
go_left = common::Decision(node_cats, common::AsCat(element));
|
||||
} else {
|
||||
go_left = element <= node.SplitCond();
|
||||
}
|
||||
if (go_left) {
|
||||
position = node.LeftChild();
|
||||
} else {
|
||||
position = node.RightChild();
|
||||
}
|
||||
}
|
||||
node = d_nodes[position];
|
||||
}
|
||||
return position;
|
||||
});
|
||||
}
|
||||
|
||||
void UpdatePredictionCache(bst_float* out_preds_d) {
|
||||
void UpdatePredictionCache(common::Span<bst_float> out_preds_d) {
|
||||
dh::safe_cuda(cudaSetDevice(device_id));
|
||||
auto d_ridx = row_partitioner->GetRows();
|
||||
if (prediction_cache.size() != d_ridx.size()) {
|
||||
prediction_cache.resize(d_ridx.size());
|
||||
dh::safe_cuda(cudaMemcpyAsync(prediction_cache.data().get(), out_preds_d,
|
||||
prediction_cache.size() * sizeof(bst_float),
|
||||
cudaMemcpyDefault));
|
||||
}
|
||||
|
||||
GPUTrainingParam param_d(param);
|
||||
dh::TemporaryArray<GradientPair> device_node_sum_gradients(node_sum_gradients.size());
|
||||
@@ -491,21 +534,16 @@ struct GPUHistMakerDevice {
|
||||
cudaMemcpyHostToDevice));
|
||||
auto d_position = row_partitioner->GetPosition();
|
||||
auto d_node_sum_gradients = device_node_sum_gradients.data().get();
|
||||
auto d_prediction_cache = prediction_cache.data().get();
|
||||
auto evaluator = tree_evaluator.GetEvaluator<GPUTrainingParam>();
|
||||
|
||||
dh::LaunchN(
|
||||
device_id, prediction_cache.size(), [=] __device__(int local_idx) {
|
||||
device_id, out_preds_d.size(), [=] __device__(int local_idx) {
|
||||
int pos = d_position[local_idx];
|
||||
bst_float weight = evaluator.CalcWeight(pos, param_d,
|
||||
GradStats{d_node_sum_gradients[pos]});
|
||||
d_prediction_cache[d_ridx[local_idx]] +=
|
||||
out_preds_d[d_ridx[local_idx]] +=
|
||||
weight * param_d.learning_rate;
|
||||
});
|
||||
|
||||
dh::safe_cuda(cudaMemcpyAsync(
|
||||
out_preds_d, prediction_cache.data().get(),
|
||||
prediction_cache.size() * sizeof(bst_float), cudaMemcpyDefault));
|
||||
row_partitioner.reset();
|
||||
}
|
||||
|
||||
@@ -561,11 +599,27 @@ struct GPUHistMakerDevice {
|
||||
auto left_weight = candidate.left_weight * param.learning_rate;
|
||||
auto right_weight = candidate.right_weight * param.learning_rate;
|
||||
|
||||
tree.ExpandNode(candidate.nid, candidate.split.findex,
|
||||
candidate.split.fvalue, candidate.split.dir == kLeftDir,
|
||||
base_weight, left_weight, right_weight,
|
||||
candidate.split.loss_chg, parent_sum.GetHess(),
|
||||
candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess());
|
||||
auto is_cat = candidate.split.is_cat;
|
||||
if (is_cat) {
|
||||
auto cat = common::AsCat(candidate.split.fvalue);
|
||||
std::vector<uint32_t> split_cats(LBitField32::ComputeStorageSize(std::max(cat+1, 1)), 0);
|
||||
LBitField32 cats_bits(split_cats);
|
||||
cats_bits.Set(cat);
|
||||
dh::CopyToD(split_cats, &node_categories);
|
||||
tree.ExpandCategorical(
|
||||
candidate.nid, candidate.split.findex, split_cats,
|
||||
candidate.split.dir == kLeftDir, base_weight, left_weight,
|
||||
right_weight, candidate.split.loss_chg, parent_sum.GetHess(),
|
||||
candidate.split.left_sum.GetHess(),
|
||||
candidate.split.right_sum.GetHess());
|
||||
} else {
|
||||
tree.ExpandNode(candidate.nid, candidate.split.findex,
|
||||
candidate.split.fvalue, candidate.split.dir == kLeftDir,
|
||||
base_weight, left_weight, right_weight,
|
||||
candidate.split.loss_chg, parent_sum.GetHess(),
|
||||
candidate.split.left_sum.GetHess(),
|
||||
candidate.split.right_sum.GetHess());
|
||||
}
|
||||
|
||||
// Set up child constraints
|
||||
auto left_child = tree[candidate.nid].LeftChild();
|
||||
@@ -664,7 +718,7 @@ struct GPUHistMakerDevice {
|
||||
if (ExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx),
|
||||
num_leaves)) {
|
||||
monitor.Start("UpdatePosition");
|
||||
this->UpdatePosition(candidate.nid, (*p_tree)[candidate.nid]);
|
||||
this->UpdatePosition(candidate.nid, p_tree);
|
||||
monitor.Stop("UpdatePosition");
|
||||
|
||||
monitor.Start("BuildHist");
|
||||
@@ -752,8 +806,10 @@ class GPUHistMakerSpecialised {
|
||||
};
|
||||
auto page = (*dmat->GetBatches<EllpackPage>(batch_param).begin()).Impl();
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
info_->feature_types.SetDevice(device_);
|
||||
maker.reset(new GPUHistMakerDevice<GradientSumT>(device_,
|
||||
page,
|
||||
info_->feature_types.ConstDeviceSpan(),
|
||||
info_->num_row_,
|
||||
param_,
|
||||
column_sampling_seed,
|
||||
@@ -804,7 +860,7 @@ class GPUHistMakerSpecialised {
|
||||
}
|
||||
monitor_.Start("UpdatePredictionCache");
|
||||
p_out_preds->SetDevice(device_);
|
||||
maker->UpdatePredictionCache(p_out_preds->DevicePointer());
|
||||
maker->UpdatePredictionCache(p_out_preds->DeviceSpan());
|
||||
monitor_.Stop("UpdatePredictionCache");
|
||||
return true;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user