merge latest, Jan 12 2024
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2015-2023 by Contributors
|
||||
* Copyright 2015-2023, XGBoost Contributors
|
||||
* \file tree_model.cc
|
||||
* \brief model structure for tree
|
||||
*/
|
||||
@@ -15,9 +15,9 @@
|
||||
#include <type_traits>
|
||||
|
||||
#include "../common/categorical.h"
|
||||
#include "../common/common.h"
|
||||
#include "../common/common.h" // for EscapeU8
|
||||
#include "../predictor/predict_fn.h"
|
||||
#include "io_utils.h" // GetElem
|
||||
#include "io_utils.h" // for GetElem
|
||||
#include "param.h"
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/data.h"
|
||||
@@ -207,8 +207,9 @@ TreeGenerator* TreeGenerator::Create(std::string const& attrs, FeatureMap const&
|
||||
__make_ ## TreeGenReg ## _ ## UniqueId ## __ = \
|
||||
::dmlc::Registry< ::xgboost::TreeGenReg>::Get()->__REGISTER__(Name)
|
||||
|
||||
std::vector<bst_cat_t> GetSplitCategories(RegTree const &tree, int32_t nidx) {
|
||||
auto const &csr = tree.GetCategoriesMatrix();
|
||||
namespace {
|
||||
std::vector<bst_cat_t> GetSplitCategories(RegTree const& tree, int32_t nidx) {
|
||||
auto const& csr = tree.GetCategoriesMatrix();
|
||||
auto seg = csr.node_ptr[nidx];
|
||||
auto split = common::KCatBitField{csr.categories.subspan(seg.beg, seg.size)};
|
||||
|
||||
@@ -221,7 +222,7 @@ std::vector<bst_cat_t> GetSplitCategories(RegTree const &tree, int32_t nidx) {
|
||||
return cats;
|
||||
}
|
||||
|
||||
std::string PrintCatsAsSet(std::vector<bst_cat_t> const &cats) {
|
||||
std::string PrintCatsAsSet(std::vector<bst_cat_t> const& cats) {
|
||||
std::stringstream ss;
|
||||
ss << "{";
|
||||
for (size_t i = 0; i < cats.size(); ++i) {
|
||||
@@ -234,6 +235,15 @@ std::string PrintCatsAsSet(std::vector<bst_cat_t> const &cats) {
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
std::string GetFeatureName(FeatureMap const& fmap, bst_feature_t split_index) {
|
||||
CHECK_LE(fmap.Size(), std::numeric_limits<decltype(split_index)>::max());
|
||||
auto fname = split_index < static_cast<decltype(split_index)>(fmap.Size())
|
||||
? fmap.Name(split_index)
|
||||
: ('f' + std::to_string(split_index));
|
||||
return common::EscapeU8(fname);
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
class TextGenerator : public TreeGenerator {
|
||||
using SuperT = TreeGenerator;
|
||||
|
||||
@@ -263,7 +273,7 @@ class TextGenerator : public TreeGenerator {
|
||||
std::string result = SuperT::Match(
|
||||
kIndicatorTemplate,
|
||||
{{"{nid}", std::to_string(nid)},
|
||||
{"{fname}", fmap_.Name(split_index)},
|
||||
{"{fname}", GetFeatureName(fmap_, split_index)},
|
||||
{"{yes}", std::to_string(nyes)},
|
||||
{"{no}", std::to_string(tree[nid].DefaultChild())}});
|
||||
return result;
|
||||
@@ -277,8 +287,7 @@ class TextGenerator : public TreeGenerator {
|
||||
template_str,
|
||||
{{"{tabs}", SuperT::Tabs(depth)},
|
||||
{"{nid}", std::to_string(nid)},
|
||||
{"{fname}", split_index < fmap_.Size() ? fmap_.Name(split_index) :
|
||||
std::to_string(split_index)},
|
||||
{"{fname}", GetFeatureName(fmap_, split_index)},
|
||||
{"{cond}", cond},
|
||||
{"{left}", std::to_string(tree[nid].LeftChild())},
|
||||
{"{right}", std::to_string(tree[nid].RightChild())},
|
||||
@@ -308,7 +317,7 @@ class TextGenerator : public TreeGenerator {
|
||||
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
|
||||
auto cond = tree[nid].SplitCond();
|
||||
static std::string const kNodeTemplate =
|
||||
"{tabs}{nid}:[f{fname}<{cond}] yes={left},no={right},missing={missing}";
|
||||
"{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}";
|
||||
return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth);
|
||||
}
|
||||
|
||||
@@ -376,7 +385,7 @@ class JsonGenerator : public TreeGenerator {
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t) const override {
|
||||
std::string LeafNode(RegTree const& tree, bst_node_t nid, uint32_t) const override {
|
||||
static std::string const kLeafTemplate =
|
||||
R"L({ "nodeid": {nid}, "leaf": {leaf} {stat}})L";
|
||||
static std::string const kStatTemplate =
|
||||
@@ -392,26 +401,22 @@ class JsonGenerator : public TreeGenerator {
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) const override {
|
||||
std::string Indicator(RegTree const& tree, bst_node_t nid, uint32_t depth) const override {
|
||||
int32_t nyes = tree[nid].DefaultLeft() ?
|
||||
tree[nid].RightChild() : tree[nid].LeftChild();
|
||||
static std::string const kIndicatorTemplate =
|
||||
R"ID( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", "yes": {yes}, "no": {no})ID";
|
||||
auto split_index = tree[nid].SplitIndex();
|
||||
auto fname = fmap_.Name(split_index);
|
||||
std::string qfname; // quoted
|
||||
common::EscapeU8(fname, &qfname);
|
||||
auto result = SuperT::Match(
|
||||
kIndicatorTemplate,
|
||||
{{"{nid}", std::to_string(nid)},
|
||||
{"{depth}", std::to_string(depth)},
|
||||
{"{fname}", qfname},
|
||||
{"{yes}", std::to_string(nyes)},
|
||||
{"{no}", std::to_string(tree[nid].DefaultChild())}});
|
||||
auto result =
|
||||
SuperT::Match(kIndicatorTemplate, {{"{nid}", std::to_string(nid)},
|
||||
{"{depth}", std::to_string(depth)},
|
||||
{"{fname}", GetFeatureName(fmap_, split_index)},
|
||||
{"{yes}", std::to_string(nyes)},
|
||||
{"{no}", std::to_string(tree[nid].DefaultChild())}});
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string Categorical(RegTree const& tree, int32_t nid, uint32_t depth) const override {
|
||||
std::string Categorical(RegTree const& tree, bst_node_t nid, uint32_t depth) const override {
|
||||
auto cats = GetSplitCategories(tree, nid);
|
||||
static std::string const kCategoryTemplate =
|
||||
R"I( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", )I"
|
||||
@@ -429,22 +434,17 @@ class JsonGenerator : public TreeGenerator {
|
||||
return results;
|
||||
}
|
||||
|
||||
std::string SplitNodeImpl(RegTree const &tree, int32_t nid,
|
||||
std::string const &template_str, std::string cond,
|
||||
uint32_t depth) const {
|
||||
std::string SplitNodeImpl(RegTree const& tree, bst_node_t nid, std::string const& template_str,
|
||||
std::string cond, uint32_t depth) const {
|
||||
auto split_index = tree[nid].SplitIndex();
|
||||
auto fname = split_index < fmap_.Size() ? fmap_.Name(split_index) : std::to_string(split_index);
|
||||
std::string qfname; // quoted
|
||||
common::EscapeU8(fname, &qfname);
|
||||
std::string const result = SuperT::Match(
|
||||
template_str,
|
||||
{{"{nid}", std::to_string(nid)},
|
||||
{"{depth}", std::to_string(depth)},
|
||||
{"{fname}", qfname},
|
||||
{"{cond}", cond},
|
||||
{"{left}", std::to_string(tree[nid].LeftChild())},
|
||||
{"{right}", std::to_string(tree[nid].RightChild())},
|
||||
{"{missing}", std::to_string(tree[nid].DefaultChild())}});
|
||||
std::string const result =
|
||||
SuperT::Match(template_str, {{"{nid}", std::to_string(nid)},
|
||||
{"{depth}", std::to_string(depth)},
|
||||
{"{fname}", GetFeatureName(fmap_, split_index)},
|
||||
{"{cond}", cond},
|
||||
{"{left}", std::to_string(tree[nid].LeftChild())},
|
||||
{"{right}", std::to_string(tree[nid].RightChild())},
|
||||
{"{missing}", std::to_string(tree[nid].DefaultChild())}});
|
||||
return result;
|
||||
}
|
||||
|
||||
@@ -605,9 +605,8 @@ class GraphvizGenerator : public TreeGenerator {
|
||||
auto const& extra = kwargs["graph_attrs"];
|
||||
static std::string const kGraphTemplate = " graph [ {key}=\"{value}\" ]\n";
|
||||
for (auto const& kv : extra) {
|
||||
param_.graph_attrs += SuperT::Match(kGraphTemplate,
|
||||
{{"{key}", kv.first},
|
||||
{"{value}", kv.second}});
|
||||
param_.graph_attrs +=
|
||||
SuperT::Match(kGraphTemplate, {{"{key}", kv.first}, {"{value}", kv.second}});
|
||||
}
|
||||
|
||||
kwargs.erase("graph_attrs");
|
||||
@@ -646,20 +645,18 @@ class GraphvizGenerator : public TreeGenerator {
|
||||
// Only indicator is different, so we combine all different node types into this
|
||||
// function.
|
||||
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t) const override {
|
||||
auto split = tree[nid].SplitIndex();
|
||||
auto split_index = tree[nid].SplitIndex();
|
||||
auto cond = tree[nid].SplitCond();
|
||||
static std::string const kNodeTemplate =
|
||||
" {nid} [ label=\"{fname}{<}{cond}\" {params}]\n";
|
||||
static std::string const kNodeTemplate = " {nid} [ label=\"{fname}{<}{cond}\" {params}]\n";
|
||||
|
||||
// Indicator only has fname.
|
||||
bool has_less = (split >= fmap_.Size()) || fmap_.TypeOf(split) != FeatureMap::kIndicator;
|
||||
std::string result = SuperT::Match(kNodeTemplate, {
|
||||
{"{nid}", std::to_string(nid)},
|
||||
{"{fname}", split < fmap_.Size() ? fmap_.Name(split) :
|
||||
'f' + std::to_string(split)},
|
||||
{"{<}", has_less ? "<" : ""},
|
||||
{"{cond}", has_less ? SuperT::ToStr(cond) : ""},
|
||||
{"{params}", param_.condition_node_params}});
|
||||
bool has_less =
|
||||
(split_index >= fmap_.Size()) || fmap_.TypeOf(split_index) != FeatureMap::kIndicator;
|
||||
std::string result =
|
||||
SuperT::Match(kNodeTemplate, {{"{nid}", std::to_string(nid)},
|
||||
{"{fname}", GetFeatureName(fmap_, split_index)},
|
||||
{"{<}", has_less ? "<" : ""},
|
||||
{"{cond}", has_less ? SuperT::ToStr(cond) : ""},
|
||||
{"{params}", param_.condition_node_params}});
|
||||
|
||||
result += BuildEdge<false>(tree, nid, tree[nid].LeftChild(), true);
|
||||
result += BuildEdge<false>(tree, nid, tree[nid].RightChild(), false);
|
||||
@@ -672,14 +669,13 @@ class GraphvizGenerator : public TreeGenerator {
|
||||
" {nid} [ label=\"{fname}:{cond}\" {params}]\n";
|
||||
auto cats = GetSplitCategories(tree, nid);
|
||||
auto cats_str = PrintCatsAsSet(cats);
|
||||
auto split = tree[nid].SplitIndex();
|
||||
std::string result = SuperT::Match(
|
||||
kLabelTemplate,
|
||||
{{"{nid}", std::to_string(nid)},
|
||||
{"{fname}", split < fmap_.Size() ? fmap_.Name(split)
|
||||
: 'f' + std::to_string(split)},
|
||||
{"{cond}", cats_str},
|
||||
{"{params}", param_.condition_node_params}});
|
||||
auto split_index = tree[nid].SplitIndex();
|
||||
|
||||
std::string result =
|
||||
SuperT::Match(kLabelTemplate, {{"{nid}", std::to_string(nid)},
|
||||
{"{fname}", GetFeatureName(fmap_, split_index)},
|
||||
{"{cond}", cats_str},
|
||||
{"{params}", param_.condition_node_params}});
|
||||
|
||||
result += BuildEdge<true>(tree, nid, tree[nid].LeftChild(), true);
|
||||
result += BuildEdge<true>(tree, nid, tree[nid].RightChild(), false);
|
||||
|
||||
@@ -1,21 +1,22 @@
|
||||
/**
|
||||
* Copyright 2014-2023 by XGBoost Contributors
|
||||
* Copyright 2014-2024, XGBoost Contributors
|
||||
* \file updater_colmaker.cc
|
||||
* \brief use columnwise update to construct a tree
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
|
||||
#include "../common/error_msg.h" // for NoCategorical
|
||||
#include "../common/random.h"
|
||||
#include "constraints.h"
|
||||
#include "param.h"
|
||||
#include "split_evaluator.h"
|
||||
#include "xgboost/json.h"
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/parameter.h"
|
||||
#include "xgboost/tree_updater.h"
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/json.h"
|
||||
#include "param.h"
|
||||
#include "constraints.h"
|
||||
#include "../common/random.h"
|
||||
#include "split_evaluator.h"
|
||||
|
||||
namespace xgboost::tree {
|
||||
|
||||
@@ -102,6 +103,9 @@ class ColMaker: public TreeUpdater {
|
||||
LOG(FATAL) << "Updater `grow_colmaker` or `exact` tree method doesn't "
|
||||
"support external memory training.";
|
||||
}
|
||||
if (dmat->Info().HasCategorical()) {
|
||||
LOG(FATAL) << error::NoCategorical("Updater `grow_colmaker` or `exact` tree method");
|
||||
}
|
||||
this->LazyGetColumnDensity(dmat);
|
||||
// rescale learning rate according to size of trees
|
||||
interaction_constraints_.Configure(*param, dmat->Info().num_row_);
|
||||
|
||||
@@ -545,12 +545,12 @@ class QuantileHistMaker : public TreeUpdater {
|
||||
}
|
||||
|
||||
bool UpdatePredictionCache(const DMatrix *data, linalg::MatrixView<float> out_preds) override {
|
||||
if (p_impl_) {
|
||||
return p_impl_->UpdatePredictionCache(data, out_preds);
|
||||
} else if (p_mtimpl_) {
|
||||
if (out_preds.Shape(1) > 1) {
|
||||
CHECK(p_mtimpl_);
|
||||
return p_mtimpl_->UpdatePredictionCache(data, out_preds);
|
||||
} else {
|
||||
return false;
|
||||
CHECK(p_impl_);
|
||||
return p_impl_->UpdatePredictionCache(data, out_preds);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user