Implement hist evaluator for multi-target tree. (#8908)
This commit is contained in:
@@ -304,7 +304,7 @@ void TestEvaluateSingleSplit(bool is_categorical) {
|
||||
thrust::device_vector<bst_feature_t> feature_set = std::vector<bst_feature_t>{0, 1};
|
||||
|
||||
// Setup gradients so that second feature gets higher gain
|
||||
auto feature_histogram = ConvertToInteger({ {-0.5, 0.5}, {0.5, 0.5}, {-1.0, 0.5}, {1.0, 0.5}});
|
||||
auto feature_histogram = ConvertToInteger({{-0.5, 0.5}, {0.5, 0.5}, {-1.0, 0.5}, {1.0, 0.5}});
|
||||
|
||||
dh::device_vector<FeatureType> feature_types(feature_set.size(),
|
||||
FeatureType::kCategorical);
|
||||
|
||||
@@ -1,18 +1,27 @@
|
||||
/**
|
||||
* Copyright 2021-2023 by XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/base.h>
|
||||
|
||||
#include "../../../../src/common/hist_util.h"
|
||||
#include "../../../../src/tree/common_row_partitioner.h"
|
||||
#include "../../../../src/tree/hist/evaluate_splits.h"
|
||||
#include "../test_evaluate_splits.h"
|
||||
#include "../../helpers.h"
|
||||
#include "xgboost/context.h" // Context
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/base.h> // for GradientPairPrecise, Args, Gradie...
|
||||
#include <xgboost/context.h> // for Context
|
||||
#include <xgboost/data.h> // for FeatureType, DMatrix, MetaInfo
|
||||
#include <xgboost/logging.h> // for CHECK_EQ
|
||||
#include <xgboost/tree_model.h> // for RegTree, RTreeNodeStat
|
||||
|
||||
#include <memory> // for make_shared, shared_ptr, addressof
|
||||
|
||||
#include "../../../../src/common/hist_util.h" // for HistCollection, HistogramCuts
|
||||
#include "../../../../src/common/random.h" // for ColumnSampler
|
||||
#include "../../../../src/common/row_set.h" // for RowSetCollection
|
||||
#include "../../../../src/data/gradient_index.h" // for GHistIndexMatrix
|
||||
#include "../../../../src/tree/hist/evaluate_splits.h" // for HistEvaluator
|
||||
#include "../../../../src/tree/hist/expand_entry.h" // for CPUExpandEntry
|
||||
#include "../../../../src/tree/param.h" // for GradStats, TrainParam
|
||||
#include "../../helpers.h" // for RandomDataGenerator, AllThreadsFo...
|
||||
|
||||
namespace xgboost::tree {
|
||||
void TestEvaluateSplits(bool force_read_by_column) {
|
||||
Context ctx;
|
||||
ctx.nthread = 4;
|
||||
@@ -87,6 +96,68 @@ TEST(HistEvaluator, Evaluate) {
|
||||
TestEvaluateSplits(true);
|
||||
}
|
||||
|
||||
TEST(HistMultiEvaluator, Evaluate) {
|
||||
Context ctx;
|
||||
ctx.nthread = 1;
|
||||
|
||||
TrainParam param;
|
||||
param.Init(Args{{"min_child_weight", "0"}, {"reg_lambda", "0"}});
|
||||
auto sampler = std::make_shared<common::ColumnSampler>();
|
||||
|
||||
std::size_t n_samples = 3;
|
||||
bst_feature_t n_features = 2;
|
||||
bst_target_t n_targets = 2;
|
||||
bst_bin_t n_bins = 2;
|
||||
|
||||
auto p_fmat =
|
||||
RandomDataGenerator{n_samples, n_features, 0.5}.Targets(n_targets).GenerateDMatrix(true);
|
||||
|
||||
HistMultiEvaluator evaluator{&ctx, p_fmat->Info(), ¶m, sampler};
|
||||
std::vector<common::HistCollection> histogram(n_targets);
|
||||
linalg::Vector<GradientPairPrecise> root_sum({2}, Context::kCpuId);
|
||||
for (bst_target_t t{0}; t < n_targets; ++t) {
|
||||
auto &hist = histogram[t];
|
||||
hist.Init(n_bins * n_features);
|
||||
hist.AddHistRow(0);
|
||||
hist.AllocateAllData();
|
||||
auto node_hist = hist[0];
|
||||
node_hist[0] = {-0.5, 0.5};
|
||||
node_hist[1] = {2.0, 0.5};
|
||||
node_hist[2] = {0.5, 0.5};
|
||||
node_hist[3] = {1.0, 0.5};
|
||||
|
||||
root_sum(t) += node_hist[0];
|
||||
root_sum(t) += node_hist[1];
|
||||
}
|
||||
|
||||
RegTree tree{n_targets, n_features};
|
||||
auto weight = evaluator.InitRoot(root_sum.HostView());
|
||||
tree.SetLeaf(RegTree::kRoot, weight.HostView());
|
||||
auto w = weight.HostView();
|
||||
ASSERT_EQ(w.Size(), n_targets);
|
||||
ASSERT_EQ(w(0), -1.5);
|
||||
ASSERT_EQ(w(1), -1.5);
|
||||
|
||||
common::HistogramCuts cuts;
|
||||
cuts.cut_ptrs_ = {0, 2, 4};
|
||||
cuts.cut_values_ = {0.5, 1.0, 2.0, 3.0};
|
||||
cuts.min_vals_ = {-0.2, 1.8};
|
||||
|
||||
std::vector<MultiExpandEntry> entries(1, {/*nidx=*/0, /*depth=*/0});
|
||||
|
||||
std::vector<common::HistCollection const *> ptrs;
|
||||
std::transform(histogram.cbegin(), histogram.cend(), std::back_inserter(ptrs),
|
||||
[](auto const &h) { return std::addressof(h); });
|
||||
|
||||
evaluator.EvaluateSplits(tree, ptrs, cuts, &entries);
|
||||
|
||||
ASSERT_EQ(entries.front().split.loss_chg, 12.5);
|
||||
ASSERT_EQ(entries.front().split.split_value, 0.5);
|
||||
ASSERT_EQ(entries.front().split.SplitIndex(), 0);
|
||||
|
||||
ASSERT_EQ(sampler->GetFeatureSet(0)->Size(), n_features);
|
||||
}
|
||||
|
||||
TEST(HistEvaluator, Apply) {
|
||||
Context ctx;
|
||||
ctx.nthread = 4;
|
||||
@@ -211,12 +282,11 @@ TEST_F(TestCategoricalSplitWithMissing, HistEvaluator) {
|
||||
std::vector<CPUExpandEntry> entries(1);
|
||||
RegTree tree;
|
||||
evaluator.EvaluateSplits(hist, cuts_, info.feature_types.ConstHostSpan(), tree, &entries);
|
||||
auto const& split = entries.front().split;
|
||||
auto const &split = entries.front().split;
|
||||
|
||||
this->CheckResult(split.loss_chg, split.SplitIndex(), split.split_value, split.is_cat,
|
||||
split.DefaultLeft(),
|
||||
GradientPairPrecise{split.left_sum.GetGrad(), split.left_sum.GetHess()},
|
||||
GradientPairPrecise{split.right_sum.GetGrad(), split.right_sum.GetHess()});
|
||||
}
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::tree
|
||||
|
||||
@@ -2,15 +2,26 @@
|
||||
* Copyright 2022-2023 by XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/data.h>
|
||||
#include <xgboost/base.h> // for GradientPairInternal, GradientPairPrecise
|
||||
#include <xgboost/data.h> // for MetaInfo
|
||||
#include <xgboost/host_device_vector.h> // for HostDeviceVector
|
||||
#include <xgboost/span.h> // for operator!=, Span, SpanIterator
|
||||
|
||||
#include <algorithm> // next_permutation
|
||||
#include <numeric> // iota
|
||||
#include <algorithm> // for max, max_element, next_permutation, copy
|
||||
#include <cmath> // for isnan
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int32_t, uint64_t, uint32_t
|
||||
#include <limits> // for numeric_limits
|
||||
#include <numeric> // for iota
|
||||
#include <tuple> // for make_tuple, tie, tuple
|
||||
#include <utility> // for pair
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../../../src/common/hist_util.h" // HistogramCuts,HistCollection
|
||||
#include "../../../src/tree/param.h" // TrainParam
|
||||
#include "../../../src/tree/split_evaluator.h"
|
||||
#include "../helpers.h"
|
||||
#include "../../../src/common/hist_util.h" // for HistogramCuts, HistCollection, GHistRow
|
||||
#include "../../../src/tree/param.h" // for TrainParam, GradStats
|
||||
#include "../../../src/tree/split_evaluator.h" // for TreeEvaluator
|
||||
#include "../helpers.h" // for SimpleLCG, SimpleRealUniformDistribution
|
||||
#include "gtest/gtest_pred_impl.h" // for AssertionResult, ASSERT_EQ, ASSERT_TRUE
|
||||
|
||||
namespace xgboost::tree {
|
||||
/**
|
||||
|
||||
Reference in New Issue
Block a user