[backport] Always use partition based categorical splits. (#7857) (#7865)

Jiaming Yuan 2022-05-06 19:14:19 +08:00 committed by GitHub
parent f4eb6b984e
commit b1b6246e35
13 changed files with 77 additions and 103 deletions


@ -74,23 +74,20 @@ Optimal Partitioning
.. versionadded:: 1.6
Optimal partitioning is a technique for partitioning the categorical predictors for each
node split, the proof of optimality for numerical objectives like ``RMSE`` was first
introduced by `[1] <#references>`__. The algorithm is used in decision trees for handling
regression and binary classification tasks `[2] <#references>`__, later LightGBM `[3]
<#references>`__ brought it to the context of gradient boosting trees and now is also
adopted in XGBoost as an optional feature for handling categorical splits. More
specifically, the proof by Fisher `[1] <#references>`__ states that, when trying to
partition a set of discrete values into groups based on the distances between a measure of
these values, one only needs to look at sorted partitions instead of enumerating all
possible permutations. In the context of decision trees, the discrete values are
categories, and the measure is the output leaf value. Intuitively, we want to group the
categories that output similar leaf values. During split finding, we first sort the
gradient histogram to prepare the contiguous partitions then enumerate the splits
node split, the proof of optimality for numerical output was first introduced by `[1]
<#references>`__. The algorithm is used in decision trees `[2] <#references>`__, later
LightGBM `[3] <#references>`__ brought it to the context of gradient boosting trees and
now is also adopted in XGBoost as an optional feature for handling categorical
splits. More specifically, the proof by Fisher `[1] <#references>`__ states that, when
trying to partition a set of discrete values into groups based on the distances between a
measure of these values, one only needs to look at sorted partitions instead of
enumerating all possible permutations. In the context of decision trees, the discrete
values are categories, and the measure is the output leaf value. Intuitively, we want to
group the categories that output similar leaf values. During split finding, we first sort
the gradient histogram to prepare the contiguous partitions then enumerate the splits
according to these sorted values. One of the related parameters for XGBoost is
``max_cat_to_one_hot``, which controls whether one-hot encoding or partitioning should be
used for each feature, see :doc:`/parameter` for details. When objective is not
regression or binary classification, XGBoost will fallback to using onehot encoding
instead.
used for each feature, see :doc:`/parameter` for details.
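
The sorted-partition search described above can be illustrated with a small self-contained sketch (plain C++, not XGBoost's implementation): per-category gradient and hessian sums are assumed to be already accumulated from the gradient histogram, categories are ordered by their gradient/hessian ratio, and the standard second-order split gain is evaluated at every cut point of that ordering. The lambda value stands for the usual L2 regularization on leaf weights.

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <numeric>
#include <vector>

struct CatStats { double grad; double hess; };  // per-category sums from the gradient histogram

// Standard second-order gain of splitting totals (G, H) into (gl, hl) and (G - gl, H - hl).
double SplitGain(double gl, double hl, double G, double H, double lambda) {
  auto score = [&](double g, double h) { return g * g / (h + lambda); };
  return score(gl, hl) + score(G - gl, H - hl) - score(G, H);
}

int main() {
  std::vector<CatStats> hist = {{-1.5, 1.0}, {0.2, 1.0}, {2.0, 1.0}, {-0.3, 1.0}};
  double const lambda = 1.0;

  // 1. Sort categories by grad/hess, i.e. by (the negative of) their ideal leaf value.
  std::vector<std::size_t> order(hist.size());
  std::iota(order.begin(), order.end(), 0);
  std::sort(order.begin(), order.end(), [&](std::size_t a, std::size_t b) {
    return hist[a].grad / hist[a].hess < hist[b].grad / hist[b].hess;
  });

  double G = 0, H = 0;
  for (auto const &c : hist) { G += c.grad; H += c.hess; }

  // 2. Only contiguous prefixes of the sorted order need to be tried (Fisher's result);
  //    each prefix is one candidate left partition.
  double best_gain = 0, gl = 0, hl = 0;
  std::size_t best_cut = 0;
  for (std::size_t i = 0; i + 1 < order.size(); ++i) {
    gl += hist[order[i]].grad;
    hl += hist[order[i]].hess;
    double gain = SplitGain(gl, hl, G, H, lambda);
    if (gain > best_gain) { best_gain = gain; best_cut = i + 1; }
  }

  std::cout << "best gain: " << best_gain << ", left partition: {";
  for (std::size_t i = 0; i < best_cut; ++i) std::cout << " cat" << order[i];
  std::cout << " }\n";
}

In XGBoost itself this enumeration runs over histogram bins per node and per feature; the sketch only shows the ordering argument.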
**********************


@ -36,10 +36,6 @@ struct ObjInfo {
explicit ObjInfo(Task t) : task{t} {}
ObjInfo(Task t, bool khess) : task{t}, const_hess{khess} {}
XGBOOST_DEVICE bool UseOneHot() const {
return (task != ObjInfo::kRegression && task != ObjInfo::kBinary);
}
};
} // namespace xgboost
#endif // XGBOOST_TASK_H_


@ -12,7 +12,6 @@
#include "xgboost/data.h"
#include "xgboost/parameter.h"
#include "xgboost/span.h"
#include "xgboost/task.h"
namespace xgboost {
namespace common {
@ -82,8 +81,8 @@ inline void InvalidCategory() {
/*!
* \brief Whether we should use one-hot encoding for categorical data.
*/
XGBOOST_DEVICE inline bool UseOneHot(uint32_t n_cats, uint32_t max_cat_to_onehot, ObjInfo task) {
bool use_one_hot = n_cats < max_cat_to_onehot || task.UseOneHot();
XGBOOST_DEVICE inline bool UseOneHot(uint32_t n_cats, uint32_t max_cat_to_onehot) {
bool use_one_hot = n_cats < max_cat_to_onehot;
return use_one_hot;
}
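
For reference, a minimal host-only sketch of how the simplified predicate behaves once the task argument is gone (a local copy for illustration, not the header above): the choice between one-hot and partition-based splits now depends only on the category count and max_cat_to_onehot.

#include <cstdint>
#include <iostream>

// Local copy of the simplified rule for illustration; the real function is
// XGBOOST_DEVICE and lives in the header shown above.
inline bool UseOneHotSketch(uint32_t n_cats, uint32_t max_cat_to_onehot) {
  return n_cats < max_cat_to_onehot;
}

int main() {
  uint32_t const max_cat_to_onehot = 4;
  std::cout << UseOneHotSketch(3, max_cat_to_onehot) << "\n";   // 1: few categories -> one-hot split
  std::cout << UseOneHotSketch(16, max_cat_to_onehot) << "\n";  // 0: enough categories -> partition-based split
}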


@ -199,13 +199,11 @@ __device__ void EvaluateFeature(
}
template <int BLOCK_THREADS, typename GradientSumT>
__global__ void EvaluateSplitsKernel(
EvaluateSplitInputs<GradientSumT> left,
EvaluateSplitInputs<GradientSumT> right,
ObjInfo task,
common::Span<bst_feature_t> sorted_idx,
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
common::Span<DeviceSplitCandidate> out_candidates) {
__global__ void EvaluateSplitsKernel(EvaluateSplitInputs<GradientSumT> left,
EvaluateSplitInputs<GradientSumT> right,
common::Span<bst_feature_t> sorted_idx,
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
common::Span<DeviceSplitCandidate> out_candidates) {
// KeyValuePair here used as threadIdx.x -> gain_value
using ArgMaxT = cub::KeyValuePair<int, float>;
using BlockScanT =
@ -241,7 +239,7 @@ __global__ void EvaluateSplitsKernel(
if (common::IsCat(inputs.feature_types, fidx)) {
auto n_bins_in_feat = inputs.feature_segments[fidx + 1] - inputs.feature_segments[fidx];
if (common::UseOneHot(n_bins_in_feat, inputs.param.max_cat_to_onehot, task)) {
if (common::UseOneHot(n_bins_in_feat, inputs.param.max_cat_to_onehot)) {
EvaluateFeature<BLOCK_THREADS, SumReduceT, BlockScanT, MaxReduceT, TempStorage, GradientSumT,
kOneHot>(fidx, inputs, evaluator, sorted_idx, 0, &best_split, &temp_storage);
} else {
@ -310,7 +308,7 @@ __device__ void SortBasedSplit(EvaluateSplitInputs<GradientSumT> const &input,
template <typename GradientSumT>
void GPUHistEvaluator<GradientSumT>::EvaluateSplits(
EvaluateSplitInputs<GradientSumT> left, EvaluateSplitInputs<GradientSumT> right, ObjInfo task,
EvaluateSplitInputs<GradientSumT> left, EvaluateSplitInputs<GradientSumT> right,
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
common::Span<DeviceSplitCandidate> out_splits) {
if (!split_cats_.empty()) {
@ -323,7 +321,7 @@ void GPUHistEvaluator<GradientSumT>::EvaluateSplits(
// One block for each feature
uint32_t constexpr kBlockThreads = 256;
dh::LaunchKernel {static_cast<uint32_t>(combined_num_features), kBlockThreads, 0}(
EvaluateSplitsKernel<kBlockThreads, GradientSumT>, left, right, task, this->SortedIdx(left),
EvaluateSplitsKernel<kBlockThreads, GradientSumT>, left, right, this->SortedIdx(left),
evaluator, dh::ToSpan(feature_best_splits));
// Reduce to get best candidate for left and right child over all features
@ -365,7 +363,7 @@ void GPUHistEvaluator<GradientSumT>::CopyToHost(EvaluateSplitInputs<GradientSumT
}
template <typename GradientSumT>
void GPUHistEvaluator<GradientSumT>::EvaluateSplits(GPUExpandEntry candidate, ObjInfo task,
void GPUHistEvaluator<GradientSumT>::EvaluateSplits(GPUExpandEntry candidate,
EvaluateSplitInputs<GradientSumT> left,
EvaluateSplitInputs<GradientSumT> right,
common::Span<GPUExpandEntry> out_entries) {
@ -373,7 +371,7 @@ void GPUHistEvaluator<GradientSumT>::EvaluateSplits(GPUExpandEntry candidate, Ob
dh::TemporaryArray<DeviceSplitCandidate> splits_out_storage(2);
auto out_splits = dh::ToSpan(splits_out_storage);
this->EvaluateSplits(left, right, task, evaluator, out_splits);
this->EvaluateSplits(left, right, evaluator, out_splits);
auto d_sorted_idx = this->SortedIdx(left);
auto d_entries = out_entries;
@ -385,7 +383,7 @@ void GPUHistEvaluator<GradientSumT>::EvaluateSplits(GPUExpandEntry candidate, Ob
auto fidx = out_splits[i].findex;
if (split.is_cat &&
!common::UseOneHot(input.FeatureBins(fidx), input.param.max_cat_to_onehot, task)) {
!common::UseOneHot(input.FeatureBins(fidx), input.param.max_cat_to_onehot)) {
bool is_left = i == 0;
auto out = is_left ? cats_out.first(cats_out.size() / 2) : cats_out.last(cats_out.size() / 2);
SortBasedSplit(input, d_sorted_idx, fidx, is_left, out, &out_splits[i]);
@ -405,11 +403,11 @@ void GPUHistEvaluator<GradientSumT>::EvaluateSplits(GPUExpandEntry candidate, Ob
template <typename GradientSumT>
GPUExpandEntry GPUHistEvaluator<GradientSumT>::EvaluateSingleSplit(
EvaluateSplitInputs<GradientSumT> input, float weight, ObjInfo task) {
EvaluateSplitInputs<GradientSumT> input, float weight) {
dh::TemporaryArray<DeviceSplitCandidate> splits_out(1);
auto out_split = dh::ToSpan(splits_out);
auto evaluator = tree_evaluator_.GetEvaluator<GPUTrainingParam>();
this->EvaluateSplits(input, {}, task, evaluator, out_split);
this->EvaluateSplits(input, {}, evaluator, out_split);
auto cats_out = this->DeviceCatStorage(input.nidx);
auto d_sorted_idx = this->SortedIdx(input);
@ -421,7 +419,7 @@ GPUExpandEntry GPUHistEvaluator<GradientSumT>::EvaluateSingleSplit(
auto fidx = out_split[i].findex;
if (split.is_cat &&
!common::UseOneHot(input.FeatureBins(fidx), input.param.max_cat_to_onehot, task)) {
!common::UseOneHot(input.FeatureBins(fidx), input.param.max_cat_to_onehot)) {
SortBasedSplit(input, d_sorted_idx, fidx, true, cats_out, &out_split[i]);
}


@ -114,7 +114,7 @@ class GPUHistEvaluator {
/**
* \brief Reset the evaluator, should be called before any use.
*/
void Reset(common::HistogramCuts const &cuts, common::Span<FeatureType const> ft, ObjInfo task,
void Reset(common::HistogramCuts const &cuts, common::Span<FeatureType const> ft,
bst_feature_t n_features, TrainParam const &param, int32_t device);
/**
@ -150,21 +150,20 @@ class GPUHistEvaluator {
// impl of evaluate splits, contains CUDA kernels so it's public
void EvaluateSplits(EvaluateSplitInputs<GradientSumT> left,
EvaluateSplitInputs<GradientSumT> right, ObjInfo task,
EvaluateSplitInputs<GradientSumT> right,
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
common::Span<DeviceSplitCandidate> out_splits);
/**
* \brief Evaluate splits for left and right nodes.
*/
void EvaluateSplits(GPUExpandEntry candidate, ObjInfo task,
void EvaluateSplits(GPUExpandEntry candidate,
EvaluateSplitInputs<GradientSumT> left,
EvaluateSplitInputs<GradientSumT> right,
common::Span<GPUExpandEntry> out_splits);
/**
* \brief Evaluate splits for root node.
*/
GPUExpandEntry EvaluateSingleSplit(EvaluateSplitInputs<GradientSumT> input, float weight,
ObjInfo task);
GPUExpandEntry EvaluateSingleSplit(EvaluateSplitInputs<GradientSumT> input, float weight);
};
} // namespace tree
} // namespace xgboost


@ -16,12 +16,12 @@ namespace xgboost {
namespace tree {
template <typename GradientSumT>
void GPUHistEvaluator<GradientSumT>::Reset(common::HistogramCuts const &cuts,
common::Span<FeatureType const> ft, ObjInfo task,
common::Span<FeatureType const> ft,
bst_feature_t n_features, TrainParam const &param,
int32_t device) {
param_ = param;
tree_evaluator_ = TreeEvaluator{param, n_features, device};
if (cuts.HasCategorical() && !task.UseOneHot()) {
if (cuts.HasCategorical()) {
dh::XGBCachingDeviceAllocator<char> alloc;
auto ptrs = cuts.cut_ptrs_.ConstDeviceSpan();
auto beg = thrust::make_counting_iterator<size_t>(1ul);
@ -34,7 +34,7 @@ void GPUHistEvaluator<GradientSumT>::Reset(common::HistogramCuts const &cuts,
auto idx = i - 1;
if (common::IsCat(ft, idx)) {
auto n_bins = ptrs[i] - ptrs[idx];
bool use_sort = !common::UseOneHot(n_bins, to_onehot, task);
bool use_sort = !common::UseOneHot(n_bins, to_onehot);
return use_sort;
}
return false;
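
A rough host-side sketch of the per-feature decision this Reset path computes on the device (the helper name here is hypothetical, not the library's API), assuming categorical features contribute one histogram bin per category:

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper: mark the categorical features that take the sort-based
// (partition) path rather than one-hot, mirroring the use_sort flag above.
std::vector<char> SortBasedFeatures(std::vector<uint32_t> const &cut_ptrs,
                                    std::vector<char> const &is_categorical,
                                    uint32_t max_cat_to_onehot) {
  std::vector<char> use_sort(is_categorical.size(), 0);
  for (std::size_t f = 0; f < is_categorical.size(); ++f) {
    uint32_t n_bins = cut_ptrs[f + 1] - cut_ptrs[f];  // one bin per category
    bool one_hot = n_bins < max_cat_to_onehot;        // same rule as UseOneHot
    use_sort[f] = is_categorical[f] && !one_hot;
  }
  return use_sort;
}

int main() {
  // Three features; feature 1 is categorical with 8 categories (bins 3..10).
  std::vector<uint32_t> cut_ptrs = {0, 3, 11, 15};
  std::vector<char> is_cat = {0, 1, 0};
  auto use_sort = SortBasedFeatures(cut_ptrs, is_cat, /*max_cat_to_onehot=*/4);
  std::cout << int(use_sort[1]) << "\n";  // 1: this feature uses the sort-based path
}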


@ -11,7 +11,6 @@
#include <utility>
#include <vector>
#include "xgboost/task.h"
#include "../param.h"
#include "../constraints.h"
#include "../split_evaluator.h"
@ -39,7 +38,6 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
int32_t n_threads_ {0};
FeatureInteractionConstraintHost interaction_constraints_;
std::vector<NodeEntry> snode_;
ObjInfo task_;
// if sum of statistics for non-missing values in the node
// is equal to sum of statistics for all values:
@ -244,7 +242,7 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
}
if (is_cat) {
auto n_bins = cut_ptrs.at(fidx + 1) - cut_ptrs[fidx];
if (common::UseOneHot(n_bins, param_.max_cat_to_onehot, task_)) {
if (common::UseOneHot(n_bins, param_.max_cat_to_onehot)) {
EnumerateSplit<+1, kOneHot>(cut, {}, histogram, fidx, nidx, evaluator, best);
EnumerateSplit<-1, kOneHot>(cut, {}, histogram, fidx, nidx, evaluator, best);
} else {
@ -345,7 +343,6 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
auto Evaluator() const { return tree_evaluator_.GetEvaluator(); }
auto const& Stats() const { return snode_; }
auto Task() const { return task_; }
float InitRoot(GradStats const& root_sum) {
snode_.resize(1);
@ -363,12 +360,11 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
// The column sampler must be constructed by caller since we need to preserve the rng
// for the entire training session.
explicit HistEvaluator(TrainParam const &param, MetaInfo const &info, int32_t n_threads,
std::shared_ptr<common::ColumnSampler> sampler, ObjInfo task)
std::shared_ptr<common::ColumnSampler> sampler)
: param_{param},
column_sampler_{std::move(sampler)},
tree_evaluator_{param, static_cast<bst_feature_t>(info.num_col_), GenericParameter::kCpuId},
n_threads_{n_threads},
task_{task} {
n_threads_{n_threads} {
interaction_constraints_.Configure(param, info.num_col_);
column_sampler_->Init(info.num_col_, info.feature_weights.HostVector(), param_.colsample_bynode,
param_.colsample_bylevel, param_.colsample_bytree);


@ -28,10 +28,8 @@ DMLC_REGISTRY_FILE_TAG(updater_approx);
namespace {
// Return the BatchParam used by DMatrix.
template <typename GradientSumT>
auto BatchSpec(TrainParam const &p, common::Span<float> hess,
HistEvaluator<GradientSumT, CPUExpandEntry> const &evaluator) {
return BatchParam{p.max_bin, hess, !evaluator.Task().const_hess};
auto BatchSpec(TrainParam const &p, common::Span<float> hess, ObjInfo const task) {
return BatchParam{p.max_bin, hess, !task.const_hess};
}
auto BatchSpec(TrainParam const &p, common::Span<float> hess) {
@ -46,7 +44,8 @@ class GloablApproxBuilder {
std::shared_ptr<common::ColumnSampler> col_sampler_;
HistEvaluator<GradientSumT, CPUExpandEntry> evaluator_;
HistogramBuilder<GradientSumT, CPUExpandEntry> histogram_builder_;
GenericParameter const *ctx_;
Context const *ctx_;
ObjInfo const task_;
std::vector<ApproxRowPartitioner> partitioner_;
// Pointer to last updated tree, used for update prediction cache.
@ -64,8 +63,7 @@ class GloablApproxBuilder {
int32_t n_total_bins = 0;
partitioner_.clear();
// Generating the GHistIndexMatrix is quite slow, is there a way to speed it up?
for (auto const &page :
p_fmat->GetBatches<GHistIndexMatrix>(BatchSpec(param_, hess, evaluator_))) {
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(BatchSpec(param_, hess, task_))) {
if (n_total_bins == 0) {
n_total_bins = page.cut.TotalBins();
feature_values_ = page.cut;
@ -160,8 +158,9 @@ class GloablApproxBuilder {
common::Monitor *monitor)
: param_{std::move(param)},
col_sampler_{std::move(column_sampler)},
evaluator_{param_, info, ctx->Threads(), col_sampler_, task},
evaluator_{param_, info, ctx->Threads(), col_sampler_},
ctx_{ctx},
task_{task},
monitor_{monitor} {}
void UpdateTree(RegTree *p_tree, std::vector<GradientPair> const &gpair, common::Span<float> hess,


@ -229,16 +229,14 @@ struct GPUHistMakerDevice {
// Reset values for each update iteration
// Note that the column sampler must be passed by value because it is not
// thread safe
void Reset(HostDeviceVector<GradientPair>* dh_gpair, DMatrix* dmat, int64_t num_columns,
ObjInfo task) {
void Reset(HostDeviceVector<GradientPair>* dh_gpair, DMatrix* dmat, int64_t num_columns) {
auto const& info = dmat->Info();
this->column_sampler.Init(num_columns, info.feature_weights.HostVector(),
param.colsample_bynode, param.colsample_bylevel,
param.colsample_bytree);
dh::safe_cuda(cudaSetDevice(device_id));
this->evaluator_.Reset(page->Cuts(), feature_types, task, dmat->Info().num_col_, param,
device_id);
this->evaluator_.Reset(page->Cuts(), feature_types, dmat->Info().num_col_, param, device_id);
this->interaction_constraints.Reset();
std::fill(node_sum_gradients.begin(), node_sum_gradients.end(), GradientPairPrecise{});
@ -260,7 +258,7 @@ struct GPUHistMakerDevice {
hist.Reset();
}
GPUExpandEntry EvaluateRootSplit(GradientPairPrecise root_sum, float weight, ObjInfo task) {
GPUExpandEntry EvaluateRootSplit(GradientPairPrecise root_sum, float weight) {
int nidx = RegTree::kRoot;
GPUTrainingParam gpu_param(param);
auto sampled_features = column_sampler.GetFeatureSet(0);
@ -277,12 +275,12 @@ struct GPUHistMakerDevice {
matrix.gidx_fvalue_map,
matrix.min_fvalue,
hist.GetNodeHistogram(nidx)};
auto split = this->evaluator_.EvaluateSingleSplit(inputs, weight, task);
auto split = this->evaluator_.EvaluateSingleSplit(inputs, weight);
return split;
}
void EvaluateLeftRightSplits(GPUExpandEntry candidate, ObjInfo task, int left_nidx,
int right_nidx, const RegTree& tree,
void EvaluateLeftRightSplits(GPUExpandEntry candidate, int left_nidx, int right_nidx,
const RegTree& tree,
common::Span<GPUExpandEntry> pinned_candidates_out) {
dh::TemporaryArray<DeviceSplitCandidate> splits_out(2);
GPUTrainingParam gpu_param(param);
@ -316,7 +314,7 @@ struct GPUHistMakerDevice {
hist.GetNodeHistogram(right_nidx)};
dh::TemporaryArray<GPUExpandEntry> entries(2);
this->evaluator_.EvaluateSplits(candidate, task, left, right, dh::ToSpan(entries));
this->evaluator_.EvaluateSplits(candidate, left, right, dh::ToSpan(entries));
dh::safe_cuda(cudaMemcpyAsync(pinned_candidates_out.data(), entries.data().get(),
sizeof(GPUExpandEntry) * entries.size(), cudaMemcpyDeviceToHost));
}
@ -584,7 +582,7 @@ struct GPUHistMakerDevice {
tree[candidate.nid].RightChild());
}
GPUExpandEntry InitRoot(RegTree* p_tree, ObjInfo task, dh::AllReducer* reducer) {
GPUExpandEntry InitRoot(RegTree* p_tree, dh::AllReducer* reducer) {
constexpr bst_node_t kRootNIdx = 0;
dh::XGBCachingDeviceAllocator<char> alloc;
auto gpair_it = dh::MakeTransformIterator<GradientPairPrecise>(
@ -605,7 +603,7 @@ struct GPUHistMakerDevice {
(*p_tree)[kRootNIdx].SetLeaf(param.learning_rate * weight);
// Generate first split
auto root_entry = this->EvaluateRootSplit(root_sum, weight, task);
auto root_entry = this->EvaluateRootSplit(root_sum, weight);
return root_entry;
}
@ -615,11 +613,11 @@ struct GPUHistMakerDevice {
Driver<GPUExpandEntry> driver(static_cast<TrainParam::TreeGrowPolicy>(param.grow_policy));
monitor.Start("Reset");
this->Reset(gpair_all, p_fmat, p_fmat->Info().num_col_, task);
this->Reset(gpair_all, p_fmat, p_fmat->Info().num_col_);
monitor.Stop("Reset");
monitor.Start("InitRoot");
driver.Push({ this->InitRoot(p_tree, task, reducer) });
driver.Push({ this->InitRoot(p_tree, reducer) });
monitor.Stop("InitRoot");
auto num_leaves = 1;
@ -656,7 +654,7 @@ struct GPUHistMakerDevice {
monitor.Stop("BuildHist");
monitor.Start("EvaluateSplits");
this->EvaluateLeftRightSplits(candidate, task, left_child_nidx, right_child_nidx, *p_tree,
this->EvaluateLeftRightSplits(candidate, left_child_nidx, right_child_nidx, *p_tree,
new_candidates.subspan(i * 2, 2));
monitor.Stop("EvaluateSplits");
} else {


@ -342,7 +342,7 @@ void QuantileHistMaker::Builder<GradientSumT>::InitData(DMatrix *fmat, const Reg
// store a pointer to the tree
p_last_tree_ = &tree;
evaluator_.reset(new HistEvaluator<GradientSumT, CPUExpandEntry>{
param_, info, this->ctx_->Threads(), column_sampler_, task_});
param_, info, this->ctx_->Threads(), column_sampler_});
monitor_->Stop(__func__);
}


@ -57,8 +57,7 @@ void TestEvaluateSingleSplit(bool is_categorical) {
GPUHistEvaluator<GradientPair> evaluator{
tparam, static_cast<bst_feature_t>(feature_min_values.size()), 0};
dh::device_vector<common::CatBitField::value_type> out_cats;
DeviceSplitCandidate result =
evaluator.EvaluateSingleSplit(input, 0, ObjInfo{ObjInfo::kRegression}).split;
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, 0).split;
EXPECT_EQ(result.findex, 1);
EXPECT_EQ(result.fvalue, 11.0);
@ -101,8 +100,7 @@ TEST(GpuHist, EvaluateSingleSplitMissing) {
dh::ToSpan(feature_histogram)};
GPUHistEvaluator<GradientPair> evaluator(tparam, feature_set.size(), 0);
DeviceSplitCandidate result =
evaluator.EvaluateSingleSplit(input, 0, ObjInfo{ObjInfo::kRegression}).split;
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, 0).split;
EXPECT_EQ(result.findex, 0);
EXPECT_EQ(result.fvalue, 1.0);
@ -114,10 +112,8 @@ TEST(GpuHist, EvaluateSingleSplitMissing) {
TEST(GpuHist, EvaluateSingleSplitEmpty) {
TrainParam tparam = ZeroParam();
GPUHistEvaluator<GradientPair> evaluator(tparam, 1, 0);
DeviceSplitCandidate result = evaluator
.EvaluateSingleSplit(EvaluateSplitInputs<GradientPair>{}, 0,
ObjInfo{ObjInfo::kRegression})
.split;
DeviceSplitCandidate result =
evaluator.EvaluateSingleSplit(EvaluateSplitInputs<GradientPair>{}, 0).split;
EXPECT_EQ(result.findex, -1);
EXPECT_LT(result.loss_chg, 0.0f);
}
@ -152,8 +148,7 @@ TEST(GpuHist, EvaluateSingleSplitFeatureSampling) {
dh::ToSpan(feature_histogram)};
GPUHistEvaluator<GradientPair> evaluator(tparam, feature_min_values.size(), 0);
DeviceSplitCandidate result =
evaluator.EvaluateSingleSplit(input, 0, ObjInfo{ObjInfo::kRegression}).split;
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, 0).split;
EXPECT_EQ(result.findex, 1);
EXPECT_EQ(result.fvalue, 11.0);
@ -191,8 +186,7 @@ TEST(GpuHist, EvaluateSingleSplitBreakTies) {
dh::ToSpan(feature_histogram)};
GPUHistEvaluator<GradientPair> evaluator(tparam, feature_min_values.size(), 0);
DeviceSplitCandidate result =
evaluator.EvaluateSingleSplit(input, 0, ObjInfo{ObjInfo::kRegression}).split;
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, 0).split;
EXPECT_EQ(result.findex, 0);
EXPECT_EQ(result.fvalue, 1.0);
@ -243,8 +237,8 @@ TEST(GpuHist, EvaluateSplits) {
GPUHistEvaluator<GradientPair> evaluator{
tparam, static_cast<bst_feature_t>(feature_min_values.size()), 0};
evaluator.EvaluateSplits(input_left, input_right, ObjInfo{ObjInfo::kRegression},
evaluator.GetEvaluator(), dh::ToSpan(out_splits));
evaluator.EvaluateSplits(input_left, input_right, evaluator.GetEvaluator(),
dh::ToSpan(out_splits));
DeviceSplitCandidate result_left = out_splits[0];
EXPECT_EQ(result_left.findex, 1);
@ -264,8 +258,7 @@ TEST_F(TestPartitionBasedSplit, GpuHist) {
cuts_.cut_values_.SetDevice(0);
cuts_.min_vals_.SetDevice(0);
ObjInfo task{ObjInfo::kRegression};
evaluator.Reset(cuts_, dh::ToSpan(ft), task, info_.num_col_, param_, 0);
evaluator.Reset(cuts_, dh::ToSpan(ft), info_.num_col_, param_, 0);
dh::device_vector<GradientPairPrecise> d_hist(hist_[0].size());
auto node_hist = hist_[0];
@ -282,7 +275,7 @@ TEST_F(TestPartitionBasedSplit, GpuHist) {
cuts_.cut_values_.ConstDeviceSpan(),
cuts_.min_vals_.ConstDeviceSpan(),
dh::ToSpan(d_hist)};
auto split = evaluator.EvaluateSingleSplit(input, 0, ObjInfo{ObjInfo::kRegression}).split;
auto split = evaluator.EvaluateSingleSplit(input, 0).split;
ASSERT_NEAR(split.loss_chg, best_score_, 1e-16);
}
} // namespace tree


@ -24,8 +24,8 @@ template <typename GradientSumT> void TestEvaluateSplits() {
auto dmat = RandomDataGenerator(kRows, kCols, 0).Seed(3).GenerateDMatrix();
auto evaluator = HistEvaluator<GradientSumT, CPUExpandEntry>{
param, dmat->Info(), n_threads, sampler, ObjInfo{ObjInfo::kRegression}};
auto evaluator =
HistEvaluator<GradientSumT, CPUExpandEntry>{param, dmat->Info(), n_threads, sampler};
common::HistCollection<GradientSumT> hist;
std::vector<GradientPair> row_gpairs = {
{1.23f, 0.24f}, {0.24f, 0.25f}, {0.26f, 0.27f}, {2.27f, 0.28f},
@ -97,8 +97,7 @@ TEST(HistEvaluator, Apply) {
param.UpdateAllowUnknown(Args{{"min_child_weight", "0"}, {"reg_lambda", "0.0"}});
auto dmat = RandomDataGenerator(kNRows, kNCols, 0).Seed(3).GenerateDMatrix();
auto sampler = std::make_shared<common::ColumnSampler>();
auto evaluator_ = HistEvaluator<float, CPUExpandEntry>{param, dmat->Info(), 4, sampler,
ObjInfo{ObjInfo::kRegression}};
auto evaluator_ = HistEvaluator<float, CPUExpandEntry>{param, dmat->Info(), 4, sampler};
CPUExpandEntry entry{0, 0, 10.0f};
entry.split.left_sum = GradStats{0.4, 0.6f};
@ -125,7 +124,7 @@ TEST_F(TestPartitionBasedSplit, CPUHist) {
std::vector<FeatureType> ft{FeatureType::kCategorical};
auto sampler = std::make_shared<common::ColumnSampler>();
HistEvaluator<double, CPUExpandEntry> evaluator{param_, info_, common::OmpGetNumThreads(0),
sampler, ObjInfo{ObjInfo::kRegression}};
sampler};
evaluator.InitRoot(GradStats{total_gpair_});
RegTree tree;
std::vector<CPUExpandEntry> entries(1);
@ -156,8 +155,8 @@ auto CompareOneHotAndPartition(bool onehot) {
int32_t n_threads = 16;
auto sampler = std::make_shared<common::ColumnSampler>();
auto evaluator = HistEvaluator<GradientSumT, CPUExpandEntry>{
param, dmat->Info(), n_threads, sampler, ObjInfo{ObjInfo::kRegression}};
auto evaluator =
HistEvaluator<GradientSumT, CPUExpandEntry>{param, dmat->Info(), n_threads, sampler};
std::vector<CPUExpandEntry> entries(1);
for (auto const &gmat : dmat->GetBatches<GHistIndexMatrix>({32, param.sparse_threshold})) {


@ -262,7 +262,7 @@ TEST(GpuHist, EvaluateRootSplit) {
info.num_col_ = kNCols;
DeviceSplitCandidate res =
maker.EvaluateRootSplit({6.4f, 12.8f}, 0, ObjInfo{ObjInfo::kRegression}).split;
maker.EvaluateRootSplit({6.4f, 12.8f}, 0).split;
ASSERT_EQ(res.findex, 7);
ASSERT_NEAR(res.fvalue, 0.26, xgboost::kRtEps);
@ -300,11 +300,11 @@ void TestHistogramIndexImpl() {
const auto &maker = hist_maker.maker;
auto grad = GenerateRandomGradients(kNRows);
grad.SetDevice(0);
maker->Reset(&grad, hist_maker_dmat.get(), kNCols, ObjInfo{ObjInfo::kRegression});
maker->Reset(&grad, hist_maker_dmat.get(), kNCols);
std::vector<common::CompressedByteT> h_gidx_buffer(maker->page->gidx_buffer.HostVector());
const auto &maker_ext = hist_maker_ext.maker;
maker_ext->Reset(&grad, hist_maker_ext_dmat.get(), kNCols, ObjInfo{ObjInfo::kRegression});
maker_ext->Reset(&grad, hist_maker_ext_dmat.get(), kNCols);
std::vector<common::CompressedByteT> h_gidx_buffer_ext(maker_ext->page->gidx_buffer.HostVector());
ASSERT_EQ(maker->page->Cuts().TotalBins(), maker_ext->page->Cuts().TotalBins());