Pass obj info by reference instead of by value. (#8889)

- Pass obj info into tree updater as const pointer.

This way we don't have to initialize the learner model param before configuring gbm, hence
breaking up the dependency of configurations.
This commit is contained in:
Jiaming Yuan
2023-03-11 01:38:28 +08:00
committed by GitHub
parent 54e001bbf4
commit 6deaec8027
18 changed files with 125 additions and 112 deletions

View File

@@ -1,5 +1,5 @@
/*!
* Copyright 2014-2022 by XGBoost Contributors
/**
* Copyright 2014-2023 by XGBoost Contributors
* \file tree_updater.h
* \brief General primitive for tree learning,
* Updating a collection of trees given the information.
@@ -9,19 +9,17 @@
#define XGBOOST_TREE_UPDATER_H_
#include <dmlc/registry.h>
#include <xgboost/base.h>
#include <xgboost/context.h>
#include <xgboost/data.h>
#include <xgboost/host_device_vector.h>
#include <xgboost/linalg.h>
#include <xgboost/model.h>
#include <xgboost/task.h>
#include <xgboost/tree_model.h>
#include <xgboost/base.h> // for Args, GradientPair
#include <xgboost/data.h> // DMatrix
#include <xgboost/host_device_vector.h> // for HostDeviceVector
#include <xgboost/linalg.h> // for VectorView
#include <xgboost/model.h> // for Configurable
#include <xgboost/span.h> // for Span
#include <xgboost/tree_model.h> // for RegTree
#include <functional>
#include <string>
#include <utility>
#include <vector>
#include <functional> // for function
#include <string> // for string
#include <vector> // for vector
namespace xgboost {
namespace tree {
@@ -30,8 +28,9 @@ struct TrainParam;
class Json;
struct Context;
struct ObjInfo;
/*!
/**
* \brief interface of tree update module, that performs update of a tree.
*/
class TreeUpdater : public Configurable {
@@ -53,12 +52,12 @@ class TreeUpdater : public Configurable {
* used for modifying existing trees (like `prune`). Return true if it can modify
* existing trees.
*/
virtual bool CanModifyTree() const { return false; }
[[nodiscard]] virtual bool CanModifyTree() const { return false; }
/*!
* \brief Wether the out_position in `Update` is valid. This determines whether adaptive
* tree can be used.
*/
virtual bool HasNodePosition() const { return false; }
[[nodiscard]] virtual bool HasNodePosition() const { return false; }
/**
* \brief perform update to the tree models
*
@@ -91,14 +90,15 @@ class TreeUpdater : public Configurable {
return false;
}
virtual char const* Name() const = 0;
[[nodiscard]] virtual char const* Name() const = 0;
/*!
/**
* \brief Create a tree updater given name
* \param name Name of the tree updater.
* \param ctx A global runtime parameter
* \param task Infomation about the objective.
*/
static TreeUpdater* Create(const std::string& name, Context const* ctx, ObjInfo task);
static TreeUpdater* Create(const std::string& name, Context const* ctx, ObjInfo const* task);
};
/*!
@@ -106,7 +106,7 @@ class TreeUpdater : public Configurable {
*/
struct TreeUpdaterReg
: public dmlc::FunctionRegEntryBase<
TreeUpdaterReg, std::function<TreeUpdater*(Context const* ctx, ObjInfo task)>> {};
TreeUpdaterReg, std::function<TreeUpdater*(Context const* ctx, ObjInfo const* task)>> {};
/*!
* \brief Macro to register tree updater.