[Breaking] Remove num roots. (#5059)
This commit is contained in:
@@ -48,11 +48,6 @@ class MetaInfo {
|
||||
uint64_t num_nonzero_{0};
|
||||
/*! \brief label of each instance */
|
||||
HostDeviceVector<bst_float> labels_;
|
||||
/*!
|
||||
* \brief specified root index of each instance,
|
||||
* can be used for multi task setting
|
||||
*/
|
||||
std::vector<bst_uint> root_index_;
|
||||
/*!
|
||||
* \brief the index of begin and end of a group
|
||||
* needed when the learning task is ranking.
|
||||
@@ -76,14 +71,6 @@ class MetaInfo {
|
||||
inline bst_float GetWeight(size_t i) const {
|
||||
return weights_.Size() != 0 ? weights_.HostVector()[i] : 1.0f;
|
||||
}
|
||||
/*!
|
||||
* \brief Get the root index of i-th instance.
|
||||
* \param i Instance index.
|
||||
* \return The pre-defined root index of i-th instance.
|
||||
*/
|
||||
inline unsigned GetRoot(size_t i) const {
|
||||
return !root_index_.empty() ? root_index_[i] : 0U;
|
||||
}
|
||||
/*! \brief get sorted indexes (argsort) of labels by absolute value (used by cox loss) */
|
||||
inline const std::vector<size_t>& LabelAbsSort() const {
|
||||
if (label_order_cache_.size() == labels_.Size()) {
|
||||
|
||||
@@ -88,13 +88,11 @@ class GradientBooster {
|
||||
* \param inst the instance you want to predict
|
||||
* \param out_preds output vector to hold the predictions
|
||||
* \param ntree_limit limit the number of trees used in prediction
|
||||
* \param root_index the root index
|
||||
* \sa Predict
|
||||
*/
|
||||
virtual void PredictInstance(const SparsePage::Inst& inst,
|
||||
std::vector<bst_float>* out_preds,
|
||||
unsigned ntree_limit = 0,
|
||||
unsigned root_index = 0) = 0;
|
||||
std::vector<bst_float>* out_preds,
|
||||
unsigned ntree_limit = 0) = 0;
|
||||
/*!
|
||||
* \brief predict the leaf index of each tree, the output will be nsample * ntree vector
|
||||
* this is only valid in gbtree predictor
|
||||
|
||||
@@ -98,7 +98,6 @@ class Predictor {
|
||||
/**
|
||||
* \fn virtual void Predictor::PredictInstance( const SparsePage::Inst&
|
||||
* inst, std::vector<bst_float>* out_preds, const gbm::GBTreeModel& model,
|
||||
* unsigned ntree_limit = 0, unsigned root_index = 0) = 0;
|
||||
*
|
||||
* \brief online prediction function, predict score for one instance at a time
|
||||
* NOTE: use the batch prediction interface if possible, batch prediction is
|
||||
@@ -109,14 +108,12 @@ class Predictor {
|
||||
* \param [in,out] out_preds The output preds.
|
||||
* \param model The model to predict from
|
||||
* \param ntree_limit (Optional) The ntree limit.
|
||||
* \param root_index (Optional) Zero-based index of the root.
|
||||
*/
|
||||
|
||||
virtual void PredictInstance(const SparsePage::Inst& inst,
|
||||
std::vector<bst_float>* out_preds,
|
||||
const gbm::GBTreeModel& model,
|
||||
unsigned ntree_limit = 0,
|
||||
unsigned root_index = 0) = 0;
|
||||
unsigned ntree_limit = 0) = 0;
|
||||
|
||||
/**
|
||||
* \fn virtual void Predictor::PredictLeaf(DMatrix* dmat,
|
||||
|
||||
@@ -29,8 +29,8 @@ struct PathElement; // forward declaration
|
||||
|
||||
/*! \brief meta parameters of the tree */
|
||||
struct TreeParam : public dmlc::Parameter<TreeParam> {
|
||||
/*! \brief number of start root */
|
||||
int num_roots;
|
||||
/*! \brief (Deprecated) number of start root */
|
||||
int deprecated_num_roots;
|
||||
/*! \brief total number of nodes */
|
||||
int num_nodes;
|
||||
/*!\brief number of deleted nodes */
|
||||
@@ -52,14 +52,12 @@ struct TreeParam : public dmlc::Parameter<TreeParam> {
|
||||
static_assert(sizeof(TreeParam) == (31 + 6) * sizeof(int),
|
||||
"TreeParam: 64 bit align");
|
||||
std::memset(this, 0, sizeof(TreeParam));
|
||||
num_nodes = num_roots = 1;
|
||||
num_nodes = 1;
|
||||
}
|
||||
// declare the parameters
|
||||
DMLC_DECLARE_PARAMETER(TreeParam) {
|
||||
// only declare the parameters that can be set by the user.
|
||||
// other arguments are set by the algorithm.
|
||||
DMLC_DECLARE_FIELD(num_roots).set_lower_bound(1).set_default(1)
|
||||
.describe("Number of start root of trees.");
|
||||
DMLC_DECLARE_FIELD(num_nodes).set_lower_bound(1).set_default(1);
|
||||
DMLC_DECLARE_FIELD(num_feature)
|
||||
.describe("Number of features used in tree construction.");
|
||||
@@ -68,7 +66,7 @@ struct TreeParam : public dmlc::Parameter<TreeParam> {
|
||||
}
|
||||
|
||||
bool operator==(const TreeParam& b) const {
|
||||
return num_roots == b.num_roots && num_nodes == b.num_nodes &&
|
||||
return num_nodes == b.num_nodes &&
|
||||
num_deleted == b.num_deleted && max_depth == b.max_depth &&
|
||||
num_feature == b.num_feature &&
|
||||
size_leaf_vector == b.size_leaf_vector;
|
||||
@@ -269,7 +267,6 @@ class RegTree : public Model {
|
||||
/*! \brief constructor */
|
||||
RegTree() {
|
||||
param.num_nodes = 1;
|
||||
param.num_roots = 1;
|
||||
param.num_deleted = 0;
|
||||
nodes_.resize(param.num_nodes);
|
||||
stats_.resize(param.num_nodes);
|
||||
@@ -380,16 +377,12 @@ class RegTree : public Model {
|
||||
* \brief get maximum depth
|
||||
*/
|
||||
int MaxDepth() {
|
||||
int maxd = 0;
|
||||
for (int i = 0; i < param.num_roots; ++i) {
|
||||
maxd = std::max(maxd, MaxDepth(i));
|
||||
}
|
||||
return maxd;
|
||||
return MaxDepth(0);
|
||||
}
|
||||
|
||||
/*! \brief number of extra nodes besides the root */
|
||||
int NumExtraNodes() const {
|
||||
return param.num_nodes - param.num_roots - param.num_deleted;
|
||||
return param.num_nodes - 1 - param.num_deleted;
|
||||
}
|
||||
|
||||
/*!
|
||||
@@ -444,19 +437,17 @@ class RegTree : public Model {
|
||||
/*!
|
||||
* \brief get the leaf index
|
||||
* \param feat dense feature vector, if the feature is missing the field is set to NaN
|
||||
* \param root_id starting root index of the instance
|
||||
* \return the leaf index of the given feature
|
||||
*/
|
||||
int GetLeafIndex(const FVec& feat, unsigned root_id = 0) const;
|
||||
int GetLeafIndex(const FVec& feat) const;
|
||||
/*!
|
||||
* \brief calculate the feature contributions (https://arxiv.org/abs/1706.06060) for the tree
|
||||
* \param feat dense feature vector, if the feature is missing the field is set to NaN
|
||||
* \param root_id starting root index of the instance
|
||||
* \param out_contribs output vector to hold the contributions
|
||||
* \param condition fix one feature to either off (-1) on (1) or not fixed (0 default)
|
||||
* \param condition_feature the index of the feature to fix
|
||||
*/
|
||||
void CalculateContributions(const RegTree::FVec& feat, unsigned root_id,
|
||||
void CalculateContributions(const RegTree::FVec& feat,
|
||||
bst_float* out_contribs, int condition = 0,
|
||||
unsigned condition_feature = 0) const;
|
||||
/*!
|
||||
@@ -482,10 +473,9 @@ class RegTree : public Model {
|
||||
/*!
|
||||
* \brief calculate the approximate feature contributions for the given root
|
||||
* \param feat dense feature vector, if the feature is missing the field is set to NaN
|
||||
* \param root_id starting root index of the instance
|
||||
* \param out_contribs output vector to hold the contributions
|
||||
*/
|
||||
void CalculateContributionsApprox(const RegTree::FVec& feat, unsigned root_id,
|
||||
void CalculateContributionsApprox(const RegTree::FVec& feat,
|
||||
bst_float* out_contribs) const;
|
||||
/*!
|
||||
* \brief get next position of the tree given current pid
|
||||
@@ -536,7 +526,7 @@ class RegTree : public Model {
|
||||
}
|
||||
// delete a tree node, keep the parent field to allow trace back
|
||||
void DeleteNode(int nid) {
|
||||
CHECK_GE(nid, param.num_roots);
|
||||
CHECK_GE(nid, 1);
|
||||
deleted_nodes_.push_back(nid);
|
||||
nodes_[nid].MarkDelete();
|
||||
++param.num_deleted;
|
||||
@@ -576,14 +566,13 @@ inline bool RegTree::FVec::IsMissing(size_t i) const {
|
||||
return data_[i].flag == -1;
|
||||
}
|
||||
|
||||
inline int RegTree::GetLeafIndex(const RegTree::FVec& feat,
|
||||
unsigned root_id) const {
|
||||
auto pid = static_cast<int>(root_id);
|
||||
while (!(*this)[pid].IsLeaf()) {
|
||||
unsigned split_index = (*this)[pid].SplitIndex();
|
||||
pid = this->GetNext(pid, feat.Fvalue(split_index), feat.IsMissing(split_index));
|
||||
inline int RegTree::GetLeafIndex(const RegTree::FVec& feat) const {
|
||||
bst_node_t nid = 0;
|
||||
while (!(*this)[nid].IsLeaf()) {
|
||||
unsigned split_index = (*this)[nid].SplitIndex();
|
||||
nid = this->GetNext(nid, feat.Fvalue(split_index), feat.IsMissing(split_index));
|
||||
}
|
||||
return pid;
|
||||
return nid;
|
||||
}
|
||||
|
||||
/*! \brief get next position of the tree given current pid */
|
||||
|
||||
Reference in New Issue
Block a user