Support categorical split in tree model dump. (#7036)
This commit is contained in:
@@ -241,6 +241,65 @@ RegTree ConstructTree() {
|
||||
/*right_sum=*/0.0f);
|
||||
return tree;
|
||||
}
|
||||
|
||||
/*!
 * \brief Build a small tree with two categorical splits and one numerical split.
 *
 * The root and its right child are expanded with ExpandCategorical on feature 0,
 * both using the same category storage; the root's left child is expanded with
 * ExpandNode on feature 1.
 *
 * \param cond Output; the categories set in the split bit field (0, 14, 32) are
 *             appended to it in that order.
 *
 * \return The constructed tree.
 */
RegTree ConstructTreeCat(std::vector<bst_cat_t>* cond) {
  RegTree tree;

  // Storage sized by ComputeStorageSize(33) so that category 32 can be set.
  std::vector<uint32_t> cats_storage(common::CatBitField::ComputeStorageSize(33), 0);
  common::CatBitField split_cats(cats_storage);

  // Record every chosen category both in the bit field and in the caller's list.
  for (bst_cat_t cat : {0, 14, 32}) {
    split_cats.Set(cat);
    cond->push_back(cat);
  }

  // Root: categorical split on feature 0.
  tree.ExpandCategorical(0, /*split_index=*/0, cats_storage, true, 0.0f, 2.0,
                         3.00, 11.0, 2.0, 3.0, 4.0);
  auto const left_child = tree[0].LeftChild();
  auto const right_child = tree[0].RightChild();

  // Left child: numerical split on feature 1.
  tree.ExpandNode(/*nid=*/left_child, /*split_index=*/1, /*split_value=*/1.0f,
                  /*default_left=*/false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
                  /*left_sum=*/0.0f, /*right_sum=*/0.0f);

  // Right child: second categorical split reusing the same categories.
  tree.ExpandCategorical(right_child, /*split_index=*/0, cats_storage, true,
                         0.0f, 2.0, 3.00, 11.0, 2.0, 3.0, 4.0);
  return tree;
}
|
||||
|
||||
void TestCategoricalTreeDump(std::string format, std::string sep) {
|
||||
std::vector<bst_cat_t> cond;
|
||||
auto tree = ConstructTreeCat(&cond);
|
||||
|
||||
FeatureMap fmap;
|
||||
auto str = tree.DumpModel(fmap, true, format);
|
||||
std::string cond_str;
|
||||
for (size_t c = 0; c < cond.size(); ++c) {
|
||||
cond_str += std::to_string(cond[c]);
|
||||
if (c != cond.size() - 1) {
|
||||
cond_str += sep;
|
||||
}
|
||||
}
|
||||
auto pos = str.find(cond_str);
|
||||
ASSERT_NE(pos, std::string::npos);
|
||||
pos = str.find(cond_str, pos + 1);
|
||||
ASSERT_NE(pos, std::string::npos);
|
||||
|
||||
fmap.PushBack(0, "feat_0", "categorical");
|
||||
fmap.PushBack(1, "feat_1", "q");
|
||||
fmap.PushBack(2, "feat_2", "int");
|
||||
|
||||
str = tree.DumpModel(fmap, true, format);
|
||||
pos = str.find(cond_str);
|
||||
ASSERT_NE(pos, std::string::npos);
|
||||
pos = str.find(cond_str, pos + 1);
|
||||
ASSERT_NE(pos, std::string::npos);
|
||||
|
||||
if (format == "json") {
|
||||
// Make sure it's valid JSON
|
||||
Json::Load(StringView{str});
|
||||
}
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
TEST(Tree, DumpJson) {
|
||||
@@ -278,6 +337,10 @@ TEST(Tree, DumpJson) {
|
||||
ASSERT_EQ(get<Array>(j_tree["children"]).size(), 2ul);
|
||||
}
|
||||
|
||||
// Categorical dump check for the "json" format; categories are expected to be
// separated by a comma followed by a space.
TEST(Tree, DumpJsonCategorical) {
  TestCategoricalTreeDump(/*format=*/"json", /*sep=*/", ");
}
|
||||
|
||||
TEST(Tree, DumpText) {
|
||||
auto tree = ConstructTree();
|
||||
FeatureMap fmap;
|
||||
@@ -313,6 +376,10 @@ TEST(Tree, DumpText) {
|
||||
ASSERT_EQ(str.find("cover"), std::string::npos);
|
||||
}
|
||||
|
||||
// Categorical dump check for the "text" format; categories are expected to be
// separated by a bare comma.
TEST(Tree, DumpTextCategorical) {
  TestCategoricalTreeDump(/*format=*/"text", /*sep=*/",");
}
|
||||
|
||||
TEST(Tree, DumpDot) {
|
||||
auto tree = ConstructTree();
|
||||
FeatureMap fmap;
|
||||
@@ -350,6 +417,10 @@ TEST(Tree, DumpDot) {
|
||||
ASSERT_NE(str.find(R"(1 -> 4 [label="no, missing")"), std::string::npos);
|
||||
}
|
||||
|
||||
// Categorical dump check for the "dot" format; categories are expected to be
// separated by a bare comma.
TEST(Tree, DumpDotCategorical) {
  TestCategoricalTreeDump(/*format=*/"dot", /*sep=*/",");
}
|
||||
|
||||
TEST(Tree, JsonIO) {
|
||||
RegTree tree;
|
||||
tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
|
||||
|
||||
40
tests/python-gpu/test_gpu_plotting.py
Normal file
40
tests/python-gpu/test_gpu_plotting.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import sys
|
||||
import xgboost as xgb
|
||||
import pytest
|
||||
import json
|
||||
|
||||
sys.path.append("tests/python")
|
||||
import testing as tm
|
||||
|
||||
try:
|
||||
import matplotlib
|
||||
|
||||
matplotlib.use("Agg")
|
||||
from matplotlib.axes import Axes
|
||||
from graphviz import Source
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
# Module-level pytest marker: applies one skipif condition to every test in this
# file. The condition is built by the `tm` helpers; presumably it skips when
# matplotlib or graphviz is unavailable -- confirm against tests/python/testing.py.
pytestmark = pytest.mark.skipif(**tm.no_multiple(tm.no_matplotlib(), tm.no_graphviz()))
|
||||
|
||||
|
||||
class TestPlotting:
    """Plotting checks for a model trained with categorical features on GPU."""

    @pytest.mark.skipif(**tm.no_pandas())
    def test_categorical(self):
        """Train a categorical regressor, then validate its JSON dump and plots.

        Every dumped tree must be either a single leaf or have a list-valued
        ``split_condition`` at its root. The last tree of the model is then
        rendered with both ``to_graphviz`` and ``plot_tree``.
        """
        X, y = tm.make_categorical(1000, 31, 19, onehot=False)
        reg = xgb.XGBRegressor(
            enable_categorical=True, n_estimators=10, tree_method="gpu_hist"
        )
        reg.fit(X, y)

        trees = reg.get_booster().get_dump(dump_format="json")
        for tree in trees:
            j_tree = json.loads(tree)
            assert "leaf" in j_tree.keys() or isinstance(
                j_tree["split_condition"], list
            )

        # Index of the last tree in the model. This must come from the number
        # of dumped trees, not from the number of keys in one tree's JSON dict.
        last_tree = len(trees) - 1

        graph = xgb.to_graphviz(reg, num_trees=last_tree)
        assert isinstance(graph, Source)
        ax = xgb.plot_tree(reg, num_trees=last_tree)
        assert isinstance(ax, Axes)
|
||||
@@ -71,7 +71,6 @@ class TestGPUUpdaters:
|
||||
@settings(deadline=None)
|
||||
@pytest.mark.skipif(**tm.no_pandas())
|
||||
def test_categorical(self, rows, cols, rounds, cats):
|
||||
pytest.xfail(reason='TestGPUUpdaters::test_categorical is flaky')
|
||||
self.run_categorical_basic(rows, cols, rounds, cats)
|
||||
|
||||
def test_categorical_32_cat(self):
|
||||
|
||||
@@ -55,7 +55,6 @@ def test_categorical():
|
||||
tree_method="gpu_hist",
|
||||
use_label_encoder=False,
|
||||
enable_categorical=True,
|
||||
predictor="gpu_predictor",
|
||||
n_estimators=10,
|
||||
)
|
||||
X = pd.DataFrame(X.todense()).astype("category")
|
||||
|
||||
Reference in New Issue
Block a user