Small refactor to categoricals (#7858)
This commit is contained in:
@@ -58,9 +58,12 @@ class GPUHistEvaluator {
|
||||
dh::device_vector<bst_feature_t> feature_idx_;
|
||||
// Training param used for evaluation
|
||||
TrainParam param_;
|
||||
// whether the input data requires sort based split, which is more complicated so we try
|
||||
// to avoid it if possible.
|
||||
bool has_sort_{false};
|
||||
// Do we have any categorical features that require sorting histograms?
|
||||
// use this to skip the expensive sort step
|
||||
bool need_sort_histogram_ = false;
|
||||
// Number of elements of categorical storage type
|
||||
// needed to hold categoricals for a single node
|
||||
std::size_t node_categorical_storage_size_ = 0;
|
||||
|
||||
// Copy the categories from device to host asynchronously.
|
||||
void CopyToHost(EvaluateSplitInputs<GradientSumT> const &input, common::Span<CatST> cats_out);
|
||||
@@ -69,12 +72,17 @@ class GPUHistEvaluator {
|
||||
* \brief Get host category storage of nidx for internal calculation.
|
||||
*/
|
||||
auto HostCatStorage(bst_node_t nidx) {
  // Host-side category storage is grown lazily: make sure the buffer holds at
  // least (nidx + 2) node-sized slots before handing out a view into it.
  // NOTE(review): growing the buffer may reallocate, so spans returned by
  // earlier calls should presumably not be held across calls -- confirm with callers.
  std::size_t min_size = (nidx + 2) * node_categorical_storage_size_;
  if (h_split_cats_.size() < min_size) {
    h_split_cats_.resize(min_size);
  }

  // The root node gets a single slot of node_categorical_storage_size_ elements.
  if (nidx == RegTree::kRoot) {
    auto cats_out = common::Span<CatST>{h_split_cats_}.subspan(
        nidx * node_categorical_storage_size_, node_categorical_storage_size_);
    return cats_out;
  }
  // Any other node gets two consecutive slots starting at its own offset
  // (presumably the node and its sibling evaluated together -- confirm with caller).
  auto cats_out = common::Span<CatST>{h_split_cats_}.subspan(
      nidx * node_categorical_storage_size_, node_categorical_storage_size_ * 2);
  return cats_out;
}
|
||||
|
||||
@@ -82,12 +90,15 @@ class GPUHistEvaluator {
|
||||
* \brief Get device category storage of nidx for internal calculation.
|
||||
*/
|
||||
auto DeviceCatStorage(bst_node_t nidx) {
  // Device-side category storage is grown lazily: make sure the buffer holds
  // at least (nidx + 2) node-sized slots before handing out a view into it.
  // NOTE(review): growing the buffer may reallocate, so spans returned by
  // earlier calls should presumably not be held across calls -- confirm with callers.
  std::size_t min_size = (nidx + 2) * node_categorical_storage_size_;
  if (split_cats_.size() < min_size) {
    split_cats_.resize(min_size);
  }

  // The root node gets a single slot of node_categorical_storage_size_ elements.
  if (nidx == RegTree::kRoot) {
    auto cats_out = dh::ToSpan(split_cats_).subspan(
        nidx * node_categorical_storage_size_, node_categorical_storage_size_);
    return cats_out;
  }
  // Any other node gets two consecutive slots starting at its own offset
  // (presumably the node and its sibling evaluated together -- confirm with caller).
  auto cats_out = dh::ToSpan(split_cats_).subspan(
      nidx * node_categorical_storage_size_, node_categorical_storage_size_ * 2);
  return cats_out;
}
|
||||
|
||||
@@ -123,8 +134,7 @@ class GPUHistEvaluator {
|
||||
*/
|
||||
common::Span<CatST const> GetHostNodeCats(bst_node_t nidx) const {
  // Wait on the copy stream before reading the host buffer
  // (presumably so a pending asynchronous device-to-host copy has landed -- confirm).
  copy_stream_.View().Sync();
  // Read-only view of the single node-sized slot belonging to nidx. Unlike the
  // mutable accessors, this const method does not grow the buffer.
  auto cats_out = common::Span<CatST const>{h_split_cats_}.subspan(
      nidx * node_categorical_storage_size_, node_categorical_storage_size_);
  return cats_out;
}
|
||||
/**
|
||||
|
||||
Reference in New Issue
Block a user