Small refactor to categoricals (#7858)
This commit is contained in:
@@ -194,8 +194,6 @@ struct GPUHistMakerDevice {
|
||||
std::unique_ptr<GradientBasedSampler> sampler;
|
||||
|
||||
std::unique_ptr<FeatureGroups> feature_groups;
|
||||
// Storing split categories for last node.
|
||||
dh::caching_device_vector<uint32_t> node_categories;
|
||||
|
||||
GPUHistMakerDevice(Context const* ctx, EllpackPageImpl const* _page,
|
||||
common::Span<FeatureType const> _feature_types, bst_uint _n_rows,
|
||||
@@ -239,7 +237,8 @@ struct GPUHistMakerDevice {
|
||||
param.colsample_bytree);
|
||||
dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
|
||||
|
||||
this->evaluator_.Reset(page->Cuts(), feature_types, dmat->Info().num_col_, param, ctx_->gpu_id);
|
||||
this->evaluator_.Reset(page->Cuts(), feature_types, dmat->Info().num_col_, param,
|
||||
ctx_->gpu_id);
|
||||
|
||||
this->interaction_constraints.Reset();
|
||||
std::fill(node_sum_gradients.begin(), node_sum_gradients.end(), GradientPairPrecise{});
|
||||
@@ -349,14 +348,14 @@ struct GPUHistMakerDevice {
|
||||
return hist.HistogramExists(nidx_histogram) && hist.HistogramExists(nidx_parent);
|
||||
}
|
||||
|
||||
void UpdatePosition(int nidx, RegTree* p_tree) {
|
||||
RegTree::Node split_node = (*p_tree)[nidx];
|
||||
auto split_type = p_tree->NodeSplitType(nidx);
|
||||
void UpdatePosition(const GPUExpandEntry &e, RegTree* p_tree) {
|
||||
RegTree::Node split_node = (*p_tree)[e.nid];
|
||||
auto split_type = p_tree->NodeSplitType(e.nid);
|
||||
auto d_matrix = page->GetDeviceAccessor(ctx_->gpu_id);
|
||||
auto node_cats = dh::ToSpan(node_categories);
|
||||
auto node_cats = e.split.split_cats.Bits();
|
||||
|
||||
row_partitioner->UpdatePosition(
|
||||
nidx, split_node.LeftChild(), split_node.RightChild(),
|
||||
e.nid, split_node.LeftChild(), split_node.RightChild(),
|
||||
[=] __device__(bst_uint ridx) {
|
||||
// given a row index, returns the node id it belongs to
|
||||
bst_float cut_value =
|
||||
@@ -569,27 +568,12 @@ struct GPUHistMakerDevice {
|
||||
CHECK_LT(candidate.split.fvalue, std::numeric_limits<bst_cat_t>::max())
|
||||
<< "Categorical feature value too large.";
|
||||
std::vector<uint32_t> split_cats;
|
||||
if (candidate.split.split_cats.Bits().empty()) {
|
||||
if (common::InvalidCat(candidate.split.fvalue)) {
|
||||
common::InvalidCategory();
|
||||
}
|
||||
auto cat = common::AsCat(candidate.split.fvalue);
|
||||
split_cats.resize(LBitField32::ComputeStorageSize(cat + 1), 0);
|
||||
common::CatBitField cats_bits(split_cats);
|
||||
cats_bits.Set(cat);
|
||||
dh::CopyToD(split_cats, &node_categories);
|
||||
} else {
|
||||
auto h_cats = this->evaluator_.GetHostNodeCats(candidate.nid);
|
||||
auto max_cat = candidate.split.MaxCat();
|
||||
split_cats.resize(common::CatBitField::ComputeStorageSize(max_cat + 1), 0);
|
||||
CHECK_LE(split_cats.size(), h_cats.size());
|
||||
std::copy(h_cats.data(), h_cats.data() + split_cats.size(), split_cats.data());
|
||||
|
||||
node_categories.resize(candidate.split.split_cats.Bits().size());
|
||||
dh::safe_cuda(cudaMemcpyAsync(
|
||||
node_categories.data().get(), candidate.split.split_cats.Data(),
|
||||
candidate.split.split_cats.Bits().size_bytes(), cudaMemcpyDeviceToDevice));
|
||||
}
|
||||
CHECK_GT(candidate.split.split_cats.Bits().size(), 0);
|
||||
auto h_cats = this->evaluator_.GetHostNodeCats(candidate.nid);
|
||||
auto max_cat = candidate.split.MaxCat();
|
||||
split_cats.resize(common::CatBitField::ComputeStorageSize(max_cat + 1), 0);
|
||||
CHECK_LE(split_cats.size(), h_cats.size());
|
||||
std::copy(h_cats.data(), h_cats.data() + split_cats.size(), split_cats.data());
|
||||
|
||||
tree.ExpandCategorical(
|
||||
candidate.nid, candidate.split.findex, split_cats, candidate.split.dir == kLeftDir,
|
||||
@@ -676,7 +660,7 @@ struct GPUHistMakerDevice {
|
||||
// Update position is only run when child is valid, instead of right after apply
|
||||
// split (as in approx tree method). Hence we have the finalise position call
|
||||
// in GPU Hist.
|
||||
this->UpdatePosition(candidate.nid, p_tree);
|
||||
this->UpdatePosition(candidate, p_tree);
|
||||
monitor.Stop("UpdatePosition");
|
||||
|
||||
monitor.Start("BuildHist");
|
||||
|
||||
Reference in New Issue
Block a user