Predictor for vector leaf. (#8898)
This commit is contained in:
@@ -1,13 +1,12 @@
|
||||
/*!
|
||||
* Copyright 2021 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2021-2023 by XGBoost Contributors
|
||||
*/
|
||||
#ifndef XGBOOST_PREDICTOR_PREDICT_FN_H_
|
||||
#define XGBOOST_PREDICTOR_PREDICT_FN_H_
|
||||
#include "../common/categorical.h"
|
||||
#include "xgboost/tree_model.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace predictor {
|
||||
namespace xgboost::predictor {
|
||||
template <bool has_missing, bool has_categorical>
|
||||
inline XGBOOST_DEVICE bst_node_t GetNextNode(const RegTree::Node &node, const bst_node_t nid,
|
||||
float fvalue, bool is_missing,
|
||||
@@ -24,6 +23,25 @@ inline XGBOOST_DEVICE bst_node_t GetNextNode(const RegTree::Node &node, const bs
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace predictor
|
||||
} // namespace xgboost
|
||||
|
||||
template <bool has_missing, bool has_categorical>
|
||||
inline XGBOOST_DEVICE bst_node_t GetNextNodeMulti(MultiTargetTree const &tree,
|
||||
bst_node_t const nidx, float fvalue,
|
||||
bool is_missing,
|
||||
RegTree::CategoricalSplitMatrix const &cats) {
|
||||
if (has_missing && is_missing) {
|
||||
return tree.DefaultChild(nidx);
|
||||
} else {
|
||||
if (has_categorical && common::IsCat(cats.split_type, nidx)) {
|
||||
auto node_categories =
|
||||
cats.categories.subspan(cats.node_ptr[nidx].beg, cats.node_ptr[nidx].size);
|
||||
return common::Decision(node_categories, fvalue) ? tree.LeftChild(nidx)
|
||||
: tree.RightChild(nidx);
|
||||
} else {
|
||||
return tree.LeftChild(nidx) + !(fvalue < tree.SplitCond(nidx));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace xgboost::predictor
|
||||
#endif // XGBOOST_PREDICTOR_PREDICT_FN_H_
|
||||
|
||||
Reference in New Issue
Block a user