Pass obj info by reference instead of by value. (#8889)
- Pass obj info into tree updater as const pointer. This way we don't have to initialize the learner model param before configuring gbm, hence breaking up the dependency of configurations.
This commit is contained in:
@@ -170,8 +170,8 @@ void TestHistogramIndexImpl() {
|
||||
|
||||
// Build 2 matrices and build a histogram maker with that
|
||||
Context ctx(CreateEmptyGenericParam(0));
|
||||
tree::GPUHistMaker hist_maker{&ctx, ObjInfo{ObjInfo::kRegression}},
|
||||
hist_maker_ext{&ctx, ObjInfo{ObjInfo::kRegression}};
|
||||
ObjInfo task{ObjInfo::kRegression};
|
||||
tree::GPUHistMaker hist_maker{&ctx, &task}, hist_maker_ext{&ctx, &task};
|
||||
std::unique_ptr<DMatrix> hist_maker_dmat(
|
||||
CreateSparsePageDMatrixWithRC(kNRows, kNCols, 0, true));
|
||||
|
||||
@@ -240,7 +240,8 @@ void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
|
||||
param.UpdateAllowUnknown(args);
|
||||
|
||||
Context ctx(CreateEmptyGenericParam(0));
|
||||
tree::GPUHistMaker hist_maker{&ctx,ObjInfo{ObjInfo::kRegression}};
|
||||
ObjInfo task{ObjInfo::kRegression};
|
||||
tree::GPUHistMaker hist_maker{&ctx, &task};
|
||||
|
||||
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
||||
hist_maker.Update(¶m, gpair, dmat, common::Span<HostDeviceVector<bst_node_t>>{position},
|
||||
@@ -385,8 +386,8 @@ TEST(GpuHist, ExternalMemoryWithSampling) {
|
||||
|
||||
TEST(GpuHist, ConfigIO) {
|
||||
Context ctx(CreateEmptyGenericParam(0));
|
||||
std::unique_ptr<TreeUpdater> updater{
|
||||
TreeUpdater::Create("grow_gpu_hist", &ctx, ObjInfo{ObjInfo::kRegression})};
|
||||
ObjInfo task{ObjInfo::kRegression};
|
||||
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_gpu_hist", &ctx, &task)};
|
||||
updater->Configure(Args{});
|
||||
|
||||
Json j_updater { Object() };
|
||||
|
||||
@@ -37,13 +37,13 @@ TEST(GrowHistMaker, InteractionConstraint)
|
||||
auto p_gradients = GenerateGradients(kRows);
|
||||
|
||||
Context ctx;
|
||||
ObjInfo task{ObjInfo::kRegression};
|
||||
{
|
||||
// With constraints
|
||||
RegTree tree;
|
||||
tree.param.num_feature = kCols;
|
||||
|
||||
std::unique_ptr<TreeUpdater> updater{
|
||||
TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})};
|
||||
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_histmaker", &ctx, &task)};
|
||||
TrainParam param;
|
||||
param.UpdateAllowUnknown(
|
||||
Args{{"interaction_constraints", "[[0, 1]]"}, {"num_feature", std::to_string(kCols)}});
|
||||
@@ -61,8 +61,7 @@ TEST(GrowHistMaker, InteractionConstraint)
|
||||
RegTree tree;
|
||||
tree.param.num_feature = kCols;
|
||||
|
||||
std::unique_ptr<TreeUpdater> updater{
|
||||
TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})};
|
||||
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_histmaker", &ctx, &task)};
|
||||
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
||||
TrainParam param;
|
||||
param.Init(Args{});
|
||||
@@ -81,8 +80,8 @@ void TestColumnSplit(int32_t rows, int32_t cols, RegTree const& expected_tree) {
|
||||
auto p_dmat = GenerateDMatrix(rows, cols);
|
||||
auto p_gradients = GenerateGradients(rows);
|
||||
Context ctx;
|
||||
std::unique_ptr<TreeUpdater> updater{
|
||||
TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})};
|
||||
ObjInfo task{ObjInfo::kRegression};
|
||||
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_histmaker", &ctx, &task)};
|
||||
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
||||
|
||||
std::unique_ptr<DMatrix> sliced{
|
||||
@@ -110,12 +109,12 @@ TEST(GrowHistMaker, ColumnSplit) {
|
||||
|
||||
RegTree expected_tree;
|
||||
expected_tree.param.num_feature = kCols;
|
||||
ObjInfo task{ObjInfo::kRegression};
|
||||
{
|
||||
auto p_dmat = GenerateDMatrix(kRows, kCols);
|
||||
auto p_gradients = GenerateGradients(kRows);
|
||||
Context ctx;
|
||||
std::unique_ptr<TreeUpdater> updater{
|
||||
TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})};
|
||||
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_histmaker", &ctx, &task)};
|
||||
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
||||
TrainParam param;
|
||||
param.Init(Args{});
|
||||
|
||||
@@ -2,22 +2,25 @@
|
||||
* Copyright 2023 by XGBoost contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/task.h>
|
||||
#include <xgboost/tree_updater.h>
|
||||
#include <xgboost/context.h> // for Context
|
||||
#include <xgboost/task.h> // for ObjInfo
|
||||
#include <xgboost/tree_updater.h> // for TreeUpdater
|
||||
|
||||
#include <memory> // for unique_ptr
|
||||
|
||||
namespace xgboost {
|
||||
TEST(Updater, HasNodePosition) {
|
||||
Context ctx;
|
||||
ObjInfo task{ObjInfo::kRegression, true, true};
|
||||
std::unique_ptr<TreeUpdater> up{TreeUpdater::Create("grow_histmaker", &ctx, task)};
|
||||
std::unique_ptr<TreeUpdater> up{TreeUpdater::Create("grow_histmaker", &ctx, &task)};
|
||||
ASSERT_TRUE(up->HasNodePosition());
|
||||
|
||||
up.reset(TreeUpdater::Create("grow_quantile_histmaker", &ctx, task));
|
||||
up.reset(TreeUpdater::Create("grow_quantile_histmaker", &ctx, &task));
|
||||
ASSERT_TRUE(up->HasNodePosition());
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
ctx.gpu_id = 0;
|
||||
up.reset(TreeUpdater::Create("grow_gpu_hist", &ctx, task));
|
||||
up.reset(TreeUpdater::Create("grow_gpu_hist", &ctx, &task));
|
||||
ASSERT_TRUE(up->HasNodePosition());
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
|
||||
#include "../../../src/tree/param.h" // for TrainParam
|
||||
#include "../helpers.h"
|
||||
#include "xgboost/task.h" // for ObjInfo
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
@@ -71,8 +72,8 @@ class TestPredictionCache : public ::testing::Test {
|
||||
ctx.gpu_id = Context::kCpuId;
|
||||
}
|
||||
|
||||
std::unique_ptr<TreeUpdater> updater{
|
||||
TreeUpdater::Create(updater_name, &ctx, ObjInfo{ObjInfo::kRegression})};
|
||||
ObjInfo task{ObjInfo::kRegression};
|
||||
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create(updater_name, &ctx, &task)};
|
||||
RegTree tree;
|
||||
std::vector<RegTree *> trees{&tree};
|
||||
auto gpair = GenerateRandomGradients(n_samples_);
|
||||
|
||||
@@ -39,8 +39,8 @@ TEST(Updater, Prune) {
|
||||
TrainParam param;
|
||||
param.UpdateAllowUnknown(cfg);
|
||||
|
||||
std::unique_ptr<TreeUpdater> pruner(
|
||||
TreeUpdater::Create("prune", &ctx, ObjInfo{ObjInfo::kRegression}));
|
||||
ObjInfo task{ObjInfo::kRegression};
|
||||
std::unique_ptr<TreeUpdater> pruner(TreeUpdater::Create("prune", &ctx, &task));
|
||||
|
||||
// loss_chg < min_split_loss;
|
||||
std::vector<HostDeviceVector<bst_node_t>> position(trees.size());
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
/**
|
||||
* Copyright 2018-2013 by XGBoost Contributors
|
||||
* Copyright 2018-2023 by XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/host_device_vector.h>
|
||||
#include <xgboost/task.h> // for ObjInfo
|
||||
#include <xgboost/tree_updater.h>
|
||||
|
||||
#include <memory>
|
||||
@@ -12,9 +13,7 @@
|
||||
#include "../../../src/tree/param.h" // for TrainParam
|
||||
#include "../helpers.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
|
||||
namespace xgboost::tree {
|
||||
TEST(Updater, Refresh) {
|
||||
bst_row_t constexpr kRows = 8;
|
||||
bst_feature_t constexpr kCols = 16;
|
||||
@@ -33,8 +32,9 @@ TEST(Updater, Refresh) {
|
||||
auto ctx = CreateEmptyGenericParam(GPUIDX);
|
||||
tree.param.UpdateAllowUnknown(cfg);
|
||||
std::vector<RegTree*> trees{&tree};
|
||||
std::unique_ptr<TreeUpdater> refresher(
|
||||
TreeUpdater::Create("refresh", &ctx, ObjInfo{ObjInfo::kRegression}));
|
||||
|
||||
ObjInfo task{ObjInfo::kRegression};
|
||||
std::unique_ptr<TreeUpdater> refresher(TreeUpdater::Create("refresh", &ctx, &task));
|
||||
|
||||
tree.ExpandNode(0, 2, 0.2f, false, 0.0, 0.2f, 0.8f, 0.0f, 0.0f,
|
||||
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
|
||||
@@ -57,6 +57,4 @@ TEST(Updater, Refresh) {
|
||||
ASSERT_NEAR(0, tree.Stat(1).loss_chg, kEps);
|
||||
ASSERT_NEAR(0, tree.Stat(2).loss_chg, kEps);
|
||||
}
|
||||
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::tree
|
||||
|
||||
@@ -2,9 +2,13 @@
|
||||
* Copyright 2020-2023 by XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/context.h> // for Context
|
||||
#include <xgboost/task.h> // for ObjInfo
|
||||
#include <xgboost/tree_model.h>
|
||||
#include <xgboost/tree_updater.h>
|
||||
|
||||
#include <memory> // for unique_ptr
|
||||
|
||||
#include "../../../src/tree/param.h" // for TrainParam
|
||||
#include "../helpers.h"
|
||||
|
||||
@@ -26,12 +30,12 @@ class UpdaterTreeStatTest : public ::testing::Test {
|
||||
|
||||
void RunTest(std::string updater) {
|
||||
tree::TrainParam param;
|
||||
ObjInfo task{ObjInfo::kRegression};
|
||||
param.Init(Args{});
|
||||
|
||||
Context ctx(updater == "grow_gpu_hist" ? CreateEmptyGenericParam(0)
|
||||
: CreateEmptyGenericParam(Context::kCpuId));
|
||||
auto up = std::unique_ptr<TreeUpdater>{
|
||||
TreeUpdater::Create(updater, &ctx, ObjInfo{ObjInfo::kRegression})};
|
||||
auto up = std::unique_ptr<TreeUpdater>{TreeUpdater::Create(updater, &ctx, &task)};
|
||||
up->Configure(Args{});
|
||||
RegTree tree;
|
||||
tree.param.num_feature = kCols;
|
||||
@@ -74,18 +78,18 @@ class UpdaterEtaTest : public ::testing::Test {
|
||||
}
|
||||
|
||||
void RunTest(std::string updater) {
|
||||
ObjInfo task{ObjInfo::kClassification};
|
||||
|
||||
Context ctx(updater == "grow_gpu_hist" ? CreateEmptyGenericParam(0)
|
||||
: CreateEmptyGenericParam(Context::kCpuId));
|
||||
|
||||
float eta = 0.4;
|
||||
auto up_0 = std::unique_ptr<TreeUpdater>{
|
||||
TreeUpdater::Create(updater, &ctx, ObjInfo{ObjInfo::kClassification})};
|
||||
auto up_0 = std::unique_ptr<TreeUpdater>{TreeUpdater::Create(updater, &ctx, &task)};
|
||||
up_0->Configure(Args{});
|
||||
tree::TrainParam param0;
|
||||
param0.Init(Args{{"eta", std::to_string(eta)}});
|
||||
|
||||
auto up_1 = std::unique_ptr<TreeUpdater>{
|
||||
TreeUpdater::Create(updater, &ctx, ObjInfo{ObjInfo::kClassification})};
|
||||
auto up_1 = std::unique_ptr<TreeUpdater>{TreeUpdater::Create(updater, &ctx, &task)};
|
||||
up_1->Configure(Args{{"eta", "1.0"}});
|
||||
tree::TrainParam param1;
|
||||
param1.Init(Args{{"eta", "1.0"}});
|
||||
@@ -153,11 +157,11 @@ class TestMinSplitLoss : public ::testing::Test {
|
||||
{"gamma", std::to_string(gamma)}};
|
||||
tree::TrainParam param;
|
||||
param.UpdateAllowUnknown(args);
|
||||
ObjInfo task{ObjInfo::kRegression};
|
||||
|
||||
Context ctx(updater == "grow_gpu_hist" ? CreateEmptyGenericParam(0)
|
||||
: CreateEmptyGenericParam(Context::kCpuId));
|
||||
auto up = std::unique_ptr<TreeUpdater>{
|
||||
TreeUpdater::Create(updater, &ctx, ObjInfo{ObjInfo::kRegression})};
|
||||
auto up = std::unique_ptr<TreeUpdater>{TreeUpdater::Create(updater, &ctx, &task)};
|
||||
up->Configure({});
|
||||
|
||||
RegTree tree;
|
||||
|
||||
Reference in New Issue
Block a user