Support categorical split in tree model dump. (#7036)
This commit is contained in:
parent
7968c0d051
commit
29f8fd6fee
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/*!
|
||||||
* Copyright 2014 by Contributors
|
* Copyright 2014-2021 by Contributors
|
||||||
* \file feature_map.h
|
* \file feature_map.h
|
||||||
* \brief Feature map data structure to help visualization and model dump.
|
* \brief Feature map data structure to help visualization and model dump.
|
||||||
* \author Tianqi Chen
|
* \author Tianqi Chen
|
||||||
@ -26,7 +26,8 @@ class FeatureMap {
|
|||||||
kIndicator = 0,
|
kIndicator = 0,
|
||||||
kQuantitive = 1,
|
kQuantitive = 1,
|
||||||
kInteger = 2,
|
kInteger = 2,
|
||||||
kFloat = 3
|
kFloat = 3,
|
||||||
|
kCategorical = 4
|
||||||
};
|
};
|
||||||
/*!
|
/*!
|
||||||
* \brief load feature map from input stream
|
* \brief load feature map from input stream
|
||||||
@ -82,6 +83,7 @@ class FeatureMap {
|
|||||||
if (!strcmp("q", tname)) return kQuantitive;
|
if (!strcmp("q", tname)) return kQuantitive;
|
||||||
if (!strcmp("int", tname)) return kInteger;
|
if (!strcmp("int", tname)) return kInteger;
|
||||||
if (!strcmp("float", tname)) return kFloat;
|
if (!strcmp("float", tname)) return kFloat;
|
||||||
|
if (!strcmp("categorical", tname)) return kCategorical;
|
||||||
LOG(FATAL) << "unknown feature type, use i for indicator and q for quantity";
|
LOG(FATAL) << "unknown feature type, use i for indicator and q for quantity";
|
||||||
return kIndicator;
|
return kIndicator;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -3,6 +3,7 @@
|
|||||||
# coding: utf-8
|
# coding: utf-8
|
||||||
"""Plotting Library."""
|
"""Plotting Library."""
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
import json
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from .core import Booster
|
from .core import Booster
|
||||||
from .sklearn import XGBModel
|
from .sklearn import XGBModel
|
||||||
@ -203,7 +204,7 @@ def to_graphviz(booster, fmap='', num_trees=0, rankdir=None,
|
|||||||
|
|
||||||
if kwargs:
|
if kwargs:
|
||||||
parameters += ':'
|
parameters += ':'
|
||||||
parameters += str(kwargs)
|
parameters += json.dumps(kwargs)
|
||||||
tree = booster.get_dump(
|
tree = booster.get_dump(
|
||||||
fmap=fmap,
|
fmap=fmap,
|
||||||
dump_format=parameters)[num_trees]
|
dump_format=parameters)[num_trees]
|
||||||
|
|||||||
@ -52,11 +52,6 @@ bst_float PredValue(const SparsePage::Inst &inst,
|
|||||||
if (tree_info[i] == bst_group) {
|
if (tree_info[i] == bst_group) {
|
||||||
auto const &tree = *trees[i];
|
auto const &tree = *trees[i];
|
||||||
bool has_categorical = tree.HasCategoricalSplit();
|
bool has_categorical = tree.HasCategoricalSplit();
|
||||||
|
|
||||||
auto categories = common::Span<uint32_t const>{tree.GetSplitCategories()};
|
|
||||||
auto split_types = tree.GetSplitTypes();
|
|
||||||
auto categories_ptr =
|
|
||||||
common::Span<RegTree::Segment const>{tree.GetSplitCategoriesPtr()};
|
|
||||||
auto cats = tree.GetCategoriesMatrix();
|
auto cats = tree.GetCategoriesMatrix();
|
||||||
bst_node_t nidx = -1;
|
bst_node_t nidx = -1;
|
||||||
if (has_categorical) {
|
if (has_categorical) {
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/*!
|
||||||
* Copyright 2015-2020 by Contributors
|
* Copyright 2015-2021 by Contributors
|
||||||
* \file tree_model.cc
|
* \file tree_model.cc
|
||||||
* \brief model structure for tree
|
* \brief model structure for tree
|
||||||
*/
|
*/
|
||||||
@ -74,6 +74,7 @@ class TreeGenerator {
|
|||||||
int32_t /*nid*/, uint32_t /*depth*/) const {
|
int32_t /*nid*/, uint32_t /*depth*/) const {
|
||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
|
virtual std::string Categorical(RegTree const&, int32_t, uint32_t) const = 0;
|
||||||
virtual std::string Integer(RegTree const& /*tree*/,
|
virtual std::string Integer(RegTree const& /*tree*/,
|
||||||
int32_t /*nid*/, uint32_t /*depth*/) const {
|
int32_t /*nid*/, uint32_t /*depth*/) const {
|
||||||
return "";
|
return "";
|
||||||
@ -92,27 +93,52 @@ class TreeGenerator {
|
|||||||
virtual std::string SplitNode(RegTree const& tree, int32_t nid, uint32_t depth) {
|
virtual std::string SplitNode(RegTree const& tree, int32_t nid, uint32_t depth) {
|
||||||
auto const split_index = tree[nid].SplitIndex();
|
auto const split_index = tree[nid].SplitIndex();
|
||||||
std::string result;
|
std::string result;
|
||||||
|
auto is_categorical = tree.GetSplitTypes()[nid] == FeatureType::kCategorical;
|
||||||
if (split_index < fmap_.Size()) {
|
if (split_index < fmap_.Size()) {
|
||||||
|
auto check_categorical = [&]() {
|
||||||
|
CHECK(is_categorical)
|
||||||
|
<< fmap_.Name(split_index)
|
||||||
|
<< " in feature map is numerical but tree node is categorical.";
|
||||||
|
};
|
||||||
|
auto check_numerical = [&]() {
|
||||||
|
auto is_numerical = !is_categorical;
|
||||||
|
CHECK(is_numerical)
|
||||||
|
<< fmap_.Name(split_index)
|
||||||
|
<< " in feature map is categorical but tree node is numerical.";
|
||||||
|
};
|
||||||
|
|
||||||
switch (fmap_.TypeOf(split_index)) {
|
switch (fmap_.TypeOf(split_index)) {
|
||||||
|
case FeatureMap::kCategorical: {
|
||||||
|
check_categorical();
|
||||||
|
result = this->Categorical(tree, nid, depth);
|
||||||
|
break;
|
||||||
|
}
|
||||||
case FeatureMap::kIndicator: {
|
case FeatureMap::kIndicator: {
|
||||||
|
check_numerical();
|
||||||
result = this->Indicator(tree, nid, depth);
|
result = this->Indicator(tree, nid, depth);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case FeatureMap::kInteger: {
|
case FeatureMap::kInteger: {
|
||||||
|
check_numerical();
|
||||||
result = this->Integer(tree, nid, depth);
|
result = this->Integer(tree, nid, depth);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case FeatureMap::kFloat:
|
case FeatureMap::kFloat:
|
||||||
case FeatureMap::kQuantitive: {
|
case FeatureMap::kQuantitive: {
|
||||||
|
check_numerical();
|
||||||
result = this->Quantitive(tree, nid, depth);
|
result = this->Quantitive(tree, nid, depth);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
LOG(FATAL) << "Unknown feature map type.";
|
LOG(FATAL) << "Unknown feature map type.";
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
if (is_categorical) {
|
||||||
|
result = this->Categorical(tree, nid, depth);
|
||||||
} else {
|
} else {
|
||||||
result = this->PlainNode(tree, nid, depth);
|
result = this->PlainNode(tree, nid, depth);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -179,6 +205,32 @@ TreeGenerator* TreeGenerator::Create(std::string const& attrs, FeatureMap const&
|
|||||||
__make_ ## TreeGenReg ## _ ## UniqueId ## __ = \
|
__make_ ## TreeGenReg ## _ ## UniqueId ## __ = \
|
||||||
::dmlc::Registry< ::xgboost::TreeGenReg>::Get()->__REGISTER__(Name)
|
::dmlc::Registry< ::xgboost::TreeGenReg>::Get()->__REGISTER__(Name)
|
||||||
|
|
||||||
|
std::vector<bst_cat_t> GetSplitCategories(RegTree const &tree, int32_t nidx) {
|
||||||
|
auto const &csr = tree.GetCategoriesMatrix();
|
||||||
|
auto seg = csr.node_ptr[nidx];
|
||||||
|
auto split = common::KCatBitField{csr.categories.subspan(seg.beg, seg.size)};
|
||||||
|
|
||||||
|
std::vector<bst_cat_t> cats;
|
||||||
|
for (size_t i = 0; i < split.Size(); ++i) {
|
||||||
|
if (split.Check(i)) {
|
||||||
|
cats.push_back(static_cast<bst_cat_t>(i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return cats;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string PrintCatsAsSet(std::vector<bst_cat_t> const &cats) {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "{";
|
||||||
|
for (size_t i = 0; i < cats.size(); ++i) {
|
||||||
|
ss << cats[i];
|
||||||
|
if (i != cats.size() - 1) {
|
||||||
|
ss << ",";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ss << "}";
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
class TextGenerator : public TreeGenerator {
|
class TextGenerator : public TreeGenerator {
|
||||||
using SuperT = TreeGenerator;
|
using SuperT = TreeGenerator;
|
||||||
@ -258,6 +310,17 @@ class TextGenerator : public TreeGenerator {
|
|||||||
return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth);
|
return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string Categorical(RegTree const &tree, int32_t nid,
|
||||||
|
uint32_t depth) const override {
|
||||||
|
auto cats = GetSplitCategories(tree, nid);
|
||||||
|
std::string cats_str = PrintCatsAsSet(cats);
|
||||||
|
static std::string const kNodeTemplate =
|
||||||
|
"{tabs}{nid}:[{fname}:{cond}] yes={right},no={left},missing={missing}";
|
||||||
|
std::string const result =
|
||||||
|
SplitNodeImpl(tree, nid, kNodeTemplate, cats_str, depth);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
std::string NodeStat(RegTree const& tree, int32_t nid) const override {
|
std::string NodeStat(RegTree const& tree, int32_t nid) const override {
|
||||||
static std::string const kStatTemplate = ",gain={loss_chg},cover={sum_hess}";
|
static std::string const kStatTemplate = ",gain={loss_chg},cover={sum_hess}";
|
||||||
std::string const result = SuperT::Match(
|
std::string const result = SuperT::Match(
|
||||||
@ -343,6 +406,24 @@ class JsonGenerator : public TreeGenerator {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string Categorical(RegTree const& tree, int32_t nid, uint32_t depth) const override {
|
||||||
|
auto cats = GetSplitCategories(tree, nid);
|
||||||
|
static std::string const kCategoryTemplate =
|
||||||
|
R"I( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", )I"
|
||||||
|
R"I("split_condition": {cond}, "yes": {right}, "no": {left}, )I"
|
||||||
|
R"I("missing": {missing})I";
|
||||||
|
std::string cats_ptr = "[";
|
||||||
|
for (size_t i = 0; i < cats.size(); ++i) {
|
||||||
|
cats_ptr += std::to_string(cats[i]);
|
||||||
|
if (i != cats.size() - 1) {
|
||||||
|
cats_ptr += ", ";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cats_ptr += "]";
|
||||||
|
auto results = SplitNodeImpl(tree, nid, kCategoryTemplate, cats_ptr, depth);
|
||||||
|
return results;
|
||||||
|
}
|
||||||
|
|
||||||
std::string SplitNodeImpl(RegTree const &tree, int32_t nid,
|
std::string SplitNodeImpl(RegTree const &tree, int32_t nid,
|
||||||
std::string const &template_str, std::string cond,
|
std::string const &template_str, std::string cond,
|
||||||
uint32_t depth) const {
|
uint32_t depth) const {
|
||||||
@ -534,6 +615,27 @@ class GraphvizGenerator : public TreeGenerator {
|
|||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
template <bool is_categorical>
|
||||||
|
std::string BuildEdge(RegTree const &tree, bst_node_t nid, int32_t child, bool left) const {
|
||||||
|
static std::string const kEdgeTemplate =
|
||||||
|
" {nid} -> {child} [label=\"{branch}\" color=\"{color}\"]\n";
|
||||||
|
// Is this the default child for missing value?
|
||||||
|
bool is_missing = tree[nid].DefaultChild() == child;
|
||||||
|
std::string branch;
|
||||||
|
if (is_categorical) {
|
||||||
|
branch = std::string{left ? "no" : "yes"} + std::string{is_missing ? ", missing" : ""};
|
||||||
|
} else {
|
||||||
|
branch = std::string{left ? "yes" : "no"} + std::string{is_missing ? ", missing" : ""};
|
||||||
|
}
|
||||||
|
std::string buffer =
|
||||||
|
SuperT::Match(kEdgeTemplate,
|
||||||
|
{{"{nid}", std::to_string(nid)},
|
||||||
|
{"{child}", std::to_string(child)},
|
||||||
|
{"{color}", is_missing ? param_.yes_color : param_.no_color},
|
||||||
|
{"{branch}", branch}});
|
||||||
|
return buffer;
|
||||||
|
}
|
||||||
|
|
||||||
// Only indicator is different, so we combine all different node types into this
|
// Only indicator is different, so we combine all different node types into this
|
||||||
// function.
|
// function.
|
||||||
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t) const override {
|
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t) const override {
|
||||||
@ -552,27 +654,32 @@ class GraphvizGenerator : public TreeGenerator {
|
|||||||
{"{cond}", has_less ? SuperT::ToStr(cond) : ""},
|
{"{cond}", has_less ? SuperT::ToStr(cond) : ""},
|
||||||
{"{params}", param_.condition_node_params}});
|
{"{params}", param_.condition_node_params}});
|
||||||
|
|
||||||
static std::string const kEdgeTemplate =
|
result += BuildEdge<false>(tree, nid, tree[nid].LeftChild(), true);
|
||||||
" {nid} -> {child} [label=\"{branch}\" color=\"{color}\"]\n";
|
result += BuildEdge<false>(tree, nid, tree[nid].RightChild(), false);
|
||||||
auto MatchFn = SuperT::Match; // mingw failed to capture protected fn.
|
|
||||||
auto BuildEdge =
|
|
||||||
[&tree, nid, MatchFn, this](int32_t child, bool left) {
|
|
||||||
// Is this the default child for missing value?
|
|
||||||
bool is_missing = tree[nid].DefaultChild() == child;
|
|
||||||
std::string branch = std::string {left ? "yes" : "no"} +
|
|
||||||
std::string {is_missing ? ", missing" : ""};
|
|
||||||
std::string buffer = MatchFn(kEdgeTemplate, {
|
|
||||||
{"{nid}", std::to_string(nid)},
|
|
||||||
{"{child}", std::to_string(child)},
|
|
||||||
{"{color}", is_missing ? param_.yes_color : param_.no_color},
|
|
||||||
{"{branch}", branch}});
|
|
||||||
return buffer;
|
|
||||||
};
|
|
||||||
result += BuildEdge(tree[nid].LeftChild(), true);
|
|
||||||
result += BuildEdge(tree[nid].RightChild(), false);
|
|
||||||
return result;
|
return result;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
std::string Categorical(RegTree const& tree, int32_t nid, uint32_t) const override {
|
||||||
|
static std::string const kLabelTemplate =
|
||||||
|
" {nid} [ label=\"{fname}:{cond}\" {params}]\n";
|
||||||
|
auto cats = GetSplitCategories(tree, nid);
|
||||||
|
auto cats_str = PrintCatsAsSet(cats);
|
||||||
|
auto split = tree[nid].SplitIndex();
|
||||||
|
std::string result = SuperT::Match(
|
||||||
|
kLabelTemplate,
|
||||||
|
{{"{nid}", std::to_string(nid)},
|
||||||
|
{"{fname}", split < fmap_.Size() ? fmap_.Name(split)
|
||||||
|
: 'f' + std::to_string(split)},
|
||||||
|
{"{cond}", cats_str},
|
||||||
|
{"{params}", param_.condition_node_params}});
|
||||||
|
|
||||||
|
result += BuildEdge<true>(tree, nid, tree[nid].LeftChild(), true);
|
||||||
|
result += BuildEdge<true>(tree, nid, tree[nid].RightChild(), false);
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t) const override {
|
std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t) const override {
|
||||||
static std::string const kLeafTemplate =
|
static std::string const kLeafTemplate =
|
||||||
" {nid} [ label=\"leaf={leaf-value}\" {params}]\n";
|
" {nid} [ label=\"leaf={leaf-value}\" {params}]\n";
|
||||||
@ -588,9 +695,12 @@ class GraphvizGenerator : public TreeGenerator {
|
|||||||
return this->LeafNode(tree, nid, depth);
|
return this->LeafNode(tree, nid, depth);
|
||||||
}
|
}
|
||||||
static std::string const kNodeTemplate = "{parent}\n{left}\n{right}";
|
static std::string const kNodeTemplate = "{parent}\n{left}\n{right}";
|
||||||
|
auto node = tree.GetSplitTypes()[nid] == FeatureType::kCategorical
|
||||||
|
? this->Categorical(tree, nid, depth)
|
||||||
|
: this->PlainNode(tree, nid, depth);
|
||||||
auto result = SuperT::Match(
|
auto result = SuperT::Match(
|
||||||
kNodeTemplate,
|
kNodeTemplate,
|
||||||
{{"{parent}", this->PlainNode(tree, nid, depth)},
|
{{"{parent}", node},
|
||||||
{"{left}", this->BuildTree(tree, tree[nid].LeftChild(), depth+1)},
|
{"{left}", this->BuildTree(tree, tree[nid].LeftChild(), depth+1)},
|
||||||
{"{right}", this->BuildTree(tree, tree[nid].RightChild(), depth+1)}});
|
{"{right}", this->BuildTree(tree, tree[nid].RightChild(), depth+1)}});
|
||||||
return result;
|
return result;
|
||||||
|
|||||||
@ -241,6 +241,65 @@ RegTree ConstructTree() {
|
|||||||
/*right_sum=*/0.0f);
|
/*right_sum=*/0.0f);
|
||||||
return tree;
|
return tree;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
RegTree ConstructTreeCat(std::vector<bst_cat_t>* cond) {
|
||||||
|
RegTree tree;
|
||||||
|
std::vector<uint32_t> cats_storage(common::CatBitField::ComputeStorageSize(33), 0);
|
||||||
|
common::CatBitField split_cats(cats_storage);
|
||||||
|
split_cats.Set(0);
|
||||||
|
split_cats.Set(14);
|
||||||
|
split_cats.Set(32);
|
||||||
|
|
||||||
|
cond->push_back(0);
|
||||||
|
cond->push_back(14);
|
||||||
|
cond->push_back(32);
|
||||||
|
|
||||||
|
tree.ExpandCategorical(0, /*split_index=*/0, cats_storage, true, 0.0f, 2.0,
|
||||||
|
3.00, 11.0, 2.0, 3.0, 4.0);
|
||||||
|
auto left = tree[0].LeftChild();
|
||||||
|
auto right = tree[0].RightChild();
|
||||||
|
tree.ExpandNode(
|
||||||
|
/*nid=*/left, /*split_index=*/1, /*split_value=*/1.0f,
|
||||||
|
/*default_left=*/false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, /*left_sum=*/0.0f,
|
||||||
|
/*right_sum=*/0.0f);
|
||||||
|
tree.ExpandCategorical(right, /*split_index=*/0, cats_storage, true, 0.0f,
|
||||||
|
2.0, 3.00, 11.0, 2.0, 3.0, 4.0);
|
||||||
|
return tree;
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestCategoricalTreeDump(std::string format, std::string sep) {
|
||||||
|
std::vector<bst_cat_t> cond;
|
||||||
|
auto tree = ConstructTreeCat(&cond);
|
||||||
|
|
||||||
|
FeatureMap fmap;
|
||||||
|
auto str = tree.DumpModel(fmap, true, format);
|
||||||
|
std::string cond_str;
|
||||||
|
for (size_t c = 0; c < cond.size(); ++c) {
|
||||||
|
cond_str += std::to_string(cond[c]);
|
||||||
|
if (c != cond.size() - 1) {
|
||||||
|
cond_str += sep;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto pos = str.find(cond_str);
|
||||||
|
ASSERT_NE(pos, std::string::npos);
|
||||||
|
pos = str.find(cond_str, pos + 1);
|
||||||
|
ASSERT_NE(pos, std::string::npos);
|
||||||
|
|
||||||
|
fmap.PushBack(0, "feat_0", "categorical");
|
||||||
|
fmap.PushBack(1, "feat_1", "q");
|
||||||
|
fmap.PushBack(2, "feat_2", "int");
|
||||||
|
|
||||||
|
str = tree.DumpModel(fmap, true, format);
|
||||||
|
pos = str.find(cond_str);
|
||||||
|
ASSERT_NE(pos, std::string::npos);
|
||||||
|
pos = str.find(cond_str, pos + 1);
|
||||||
|
ASSERT_NE(pos, std::string::npos);
|
||||||
|
|
||||||
|
if (format == "json") {
|
||||||
|
// Make sure it's valid JSON
|
||||||
|
Json::Load(StringView{str});
|
||||||
|
}
|
||||||
|
}
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
TEST(Tree, DumpJson) {
|
TEST(Tree, DumpJson) {
|
||||||
@ -278,6 +337,10 @@ TEST(Tree, DumpJson) {
|
|||||||
ASSERT_EQ(get<Array>(j_tree["children"]).size(), 2ul);
|
ASSERT_EQ(get<Array>(j_tree["children"]).size(), 2ul);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(Tree, DumpJsonCategorical) {
|
||||||
|
TestCategoricalTreeDump("json", ", ");
|
||||||
|
}
|
||||||
|
|
||||||
TEST(Tree, DumpText) {
|
TEST(Tree, DumpText) {
|
||||||
auto tree = ConstructTree();
|
auto tree = ConstructTree();
|
||||||
FeatureMap fmap;
|
FeatureMap fmap;
|
||||||
@ -313,6 +376,10 @@ TEST(Tree, DumpText) {
|
|||||||
ASSERT_EQ(str.find("cover"), std::string::npos);
|
ASSERT_EQ(str.find("cover"), std::string::npos);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(Tree, DumpTextCategorical) {
|
||||||
|
TestCategoricalTreeDump("text", ",");
|
||||||
|
}
|
||||||
|
|
||||||
TEST(Tree, DumpDot) {
|
TEST(Tree, DumpDot) {
|
||||||
auto tree = ConstructTree();
|
auto tree = ConstructTree();
|
||||||
FeatureMap fmap;
|
FeatureMap fmap;
|
||||||
@ -350,6 +417,10 @@ TEST(Tree, DumpDot) {
|
|||||||
ASSERT_NE(str.find(R"(1 -> 4 [label="no, missing")"), std::string::npos);
|
ASSERT_NE(str.find(R"(1 -> 4 [label="no, missing")"), std::string::npos);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(Tree, DumpDotCategorical) {
|
||||||
|
TestCategoricalTreeDump("dot", ",");
|
||||||
|
}
|
||||||
|
|
||||||
TEST(Tree, JsonIO) {
|
TEST(Tree, JsonIO) {
|
||||||
RegTree tree;
|
RegTree tree;
|
||||||
tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
|
tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
|
||||||
|
|||||||
40
tests/python-gpu/test_gpu_plotting.py
Normal file
40
tests/python-gpu/test_gpu_plotting.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
import sys
|
||||||
|
import xgboost as xgb
|
||||||
|
import pytest
|
||||||
|
import json
|
||||||
|
|
||||||
|
sys.path.append("tests/python")
|
||||||
|
import testing as tm
|
||||||
|
|
||||||
|
try:
|
||||||
|
import matplotlib
|
||||||
|
|
||||||
|
matplotlib.use("Agg")
|
||||||
|
from matplotlib.axes import Axes
|
||||||
|
from graphviz import Source
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.skipif(**tm.no_multiple(tm.no_matplotlib(), tm.no_graphviz()))
|
||||||
|
|
||||||
|
|
||||||
|
class TestPlotting:
|
||||||
|
@pytest.mark.skipif(**tm.no_pandas())
|
||||||
|
def test_categorical(self):
|
||||||
|
X, y = tm.make_categorical(1000, 31, 19, onehot=False)
|
||||||
|
reg = xgb.XGBRegressor(
|
||||||
|
enable_categorical=True, n_estimators=10, tree_method="gpu_hist"
|
||||||
|
)
|
||||||
|
reg.fit(X, y)
|
||||||
|
trees = reg.get_booster().get_dump(dump_format="json")
|
||||||
|
for tree in trees:
|
||||||
|
j_tree = json.loads(tree)
|
||||||
|
assert "leaf" in j_tree.keys() or isinstance(
|
||||||
|
j_tree["split_condition"], list
|
||||||
|
)
|
||||||
|
|
||||||
|
graph = xgb.to_graphviz(reg, num_trees=len(j_tree) - 1)
|
||||||
|
assert isinstance(graph, Source)
|
||||||
|
ax = xgb.plot_tree(reg, num_trees=len(j_tree) - 1)
|
||||||
|
assert isinstance(ax, Axes)
|
||||||
@ -71,7 +71,6 @@ class TestGPUUpdaters:
|
|||||||
@settings(deadline=None)
|
@settings(deadline=None)
|
||||||
@pytest.mark.skipif(**tm.no_pandas())
|
@pytest.mark.skipif(**tm.no_pandas())
|
||||||
def test_categorical(self, rows, cols, rounds, cats):
|
def test_categorical(self, rows, cols, rounds, cats):
|
||||||
pytest.xfail(reason='TestGPUUpdaters::test_categorical is flaky')
|
|
||||||
self.run_categorical_basic(rows, cols, rounds, cats)
|
self.run_categorical_basic(rows, cols, rounds, cats)
|
||||||
|
|
||||||
def test_categorical_32_cat(self):
|
def test_categorical_32_cat(self):
|
||||||
|
|||||||
@ -55,7 +55,6 @@ def test_categorical():
|
|||||||
tree_method="gpu_hist",
|
tree_method="gpu_hist",
|
||||||
use_label_encoder=False,
|
use_label_encoder=False,
|
||||||
enable_categorical=True,
|
enable_categorical=True,
|
||||||
predictor="gpu_predictor",
|
|
||||||
n_estimators=10,
|
n_estimators=10,
|
||||||
)
|
)
|
||||||
X = pd.DataFrame(X.todense()).astype("category")
|
X = pd.DataFrame(X.todense()).astype("category")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user