Fix feature names with special characters. (#9923)
This commit is contained in:
parent
a197899161
commit
a7226c0222
@ -66,8 +66,20 @@ inline std::vector<std::string> Split(const std::string& s, char delim) {
|
|||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Add escapes for a UTF-8 string.
|
||||||
|
*/
|
||||||
void EscapeU8(std::string const &string, std::string *p_buffer);
|
void EscapeU8(std::string const &string, std::string *p_buffer);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Add escapes for a UTF-8 string with newly created buffer as return.
|
||||||
|
*/
|
||||||
|
inline std::string EscapeU8(std::string const &str) {
|
||||||
|
std::string buffer;
|
||||||
|
EscapeU8(str, &buffer);
|
||||||
|
return buffer;
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
XGBOOST_DEVICE T Max(T a, T b) {
|
XGBOOST_DEVICE T Max(T a, T b) {
|
||||||
return a < b ? b : a;
|
return a < b ? b : a;
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2015-2023 by Contributors
|
* Copyright 2015-2023, XGBoost Contributors
|
||||||
* \file tree_model.cc
|
* \file tree_model.cc
|
||||||
* \brief model structure for tree
|
* \brief model structure for tree
|
||||||
*/
|
*/
|
||||||
@ -15,9 +15,9 @@
|
|||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
|
|
||||||
#include "../common/categorical.h"
|
#include "../common/categorical.h"
|
||||||
#include "../common/common.h"
|
#include "../common/common.h" // for EscapeU8
|
||||||
#include "../predictor/predict_fn.h"
|
#include "../predictor/predict_fn.h"
|
||||||
#include "io_utils.h" // GetElem
|
#include "io_utils.h" // for GetElem
|
||||||
#include "param.h"
|
#include "param.h"
|
||||||
#include "xgboost/base.h"
|
#include "xgboost/base.h"
|
||||||
#include "xgboost/data.h"
|
#include "xgboost/data.h"
|
||||||
@ -207,8 +207,9 @@ TreeGenerator* TreeGenerator::Create(std::string const& attrs, FeatureMap const&
|
|||||||
__make_ ## TreeGenReg ## _ ## UniqueId ## __ = \
|
__make_ ## TreeGenReg ## _ ## UniqueId ## __ = \
|
||||||
::dmlc::Registry< ::xgboost::TreeGenReg>::Get()->__REGISTER__(Name)
|
::dmlc::Registry< ::xgboost::TreeGenReg>::Get()->__REGISTER__(Name)
|
||||||
|
|
||||||
std::vector<bst_cat_t> GetSplitCategories(RegTree const &tree, int32_t nidx) {
|
namespace {
|
||||||
auto const &csr = tree.GetCategoriesMatrix();
|
std::vector<bst_cat_t> GetSplitCategories(RegTree const& tree, int32_t nidx) {
|
||||||
|
auto const& csr = tree.GetCategoriesMatrix();
|
||||||
auto seg = csr.node_ptr[nidx];
|
auto seg = csr.node_ptr[nidx];
|
||||||
auto split = common::KCatBitField{csr.categories.subspan(seg.beg, seg.size)};
|
auto split = common::KCatBitField{csr.categories.subspan(seg.beg, seg.size)};
|
||||||
|
|
||||||
@ -221,7 +222,7 @@ std::vector<bst_cat_t> GetSplitCategories(RegTree const &tree, int32_t nidx) {
|
|||||||
return cats;
|
return cats;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string PrintCatsAsSet(std::vector<bst_cat_t> const &cats) {
|
std::string PrintCatsAsSet(std::vector<bst_cat_t> const& cats) {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << "{";
|
ss << "{";
|
||||||
for (size_t i = 0; i < cats.size(); ++i) {
|
for (size_t i = 0; i < cats.size(); ++i) {
|
||||||
@ -234,6 +235,15 @@ std::string PrintCatsAsSet(std::vector<bst_cat_t> const &cats) {
|
|||||||
return ss.str();
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string GetFeatureName(FeatureMap const& fmap, bst_feature_t split_index) {
|
||||||
|
CHECK_LE(fmap.Size(), std::numeric_limits<decltype(split_index)>::max());
|
||||||
|
auto fname = split_index < static_cast<decltype(split_index)>(fmap.Size())
|
||||||
|
? fmap.Name(split_index)
|
||||||
|
: ('f' + std::to_string(split_index));
|
||||||
|
return common::EscapeU8(fname);
|
||||||
|
}
|
||||||
|
} // anonymous namespace
|
||||||
|
|
||||||
class TextGenerator : public TreeGenerator {
|
class TextGenerator : public TreeGenerator {
|
||||||
using SuperT = TreeGenerator;
|
using SuperT = TreeGenerator;
|
||||||
|
|
||||||
@ -263,7 +273,7 @@ class TextGenerator : public TreeGenerator {
|
|||||||
std::string result = SuperT::Match(
|
std::string result = SuperT::Match(
|
||||||
kIndicatorTemplate,
|
kIndicatorTemplate,
|
||||||
{{"{nid}", std::to_string(nid)},
|
{{"{nid}", std::to_string(nid)},
|
||||||
{"{fname}", fmap_.Name(split_index)},
|
{"{fname}", GetFeatureName(fmap_, split_index)},
|
||||||
{"{yes}", std::to_string(nyes)},
|
{"{yes}", std::to_string(nyes)},
|
||||||
{"{no}", std::to_string(tree[nid].DefaultChild())}});
|
{"{no}", std::to_string(tree[nid].DefaultChild())}});
|
||||||
return result;
|
return result;
|
||||||
@ -277,8 +287,7 @@ class TextGenerator : public TreeGenerator {
|
|||||||
template_str,
|
template_str,
|
||||||
{{"{tabs}", SuperT::Tabs(depth)},
|
{{"{tabs}", SuperT::Tabs(depth)},
|
||||||
{"{nid}", std::to_string(nid)},
|
{"{nid}", std::to_string(nid)},
|
||||||
{"{fname}", split_index < fmap_.Size() ? fmap_.Name(split_index) :
|
{"{fname}", GetFeatureName(fmap_, split_index)},
|
||||||
std::to_string(split_index)},
|
|
||||||
{"{cond}", cond},
|
{"{cond}", cond},
|
||||||
{"{left}", std::to_string(tree[nid].LeftChild())},
|
{"{left}", std::to_string(tree[nid].LeftChild())},
|
||||||
{"{right}", std::to_string(tree[nid].RightChild())},
|
{"{right}", std::to_string(tree[nid].RightChild())},
|
||||||
@ -308,7 +317,7 @@ class TextGenerator : public TreeGenerator {
|
|||||||
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
|
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
|
||||||
auto cond = tree[nid].SplitCond();
|
auto cond = tree[nid].SplitCond();
|
||||||
static std::string const kNodeTemplate =
|
static std::string const kNodeTemplate =
|
||||||
"{tabs}{nid}:[f{fname}<{cond}] yes={left},no={right},missing={missing}";
|
"{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}";
|
||||||
return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth);
|
return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -376,7 +385,7 @@ class JsonGenerator : public TreeGenerator {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t) const override {
|
std::string LeafNode(RegTree const& tree, bst_node_t nid, uint32_t) const override {
|
||||||
static std::string const kLeafTemplate =
|
static std::string const kLeafTemplate =
|
||||||
R"L({ "nodeid": {nid}, "leaf": {leaf} {stat}})L";
|
R"L({ "nodeid": {nid}, "leaf": {leaf} {stat}})L";
|
||||||
static std::string const kStatTemplate =
|
static std::string const kStatTemplate =
|
||||||
@ -392,26 +401,22 @@ class JsonGenerator : public TreeGenerator {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) const override {
|
std::string Indicator(RegTree const& tree, bst_node_t nid, uint32_t depth) const override {
|
||||||
int32_t nyes = tree[nid].DefaultLeft() ?
|
int32_t nyes = tree[nid].DefaultLeft() ?
|
||||||
tree[nid].RightChild() : tree[nid].LeftChild();
|
tree[nid].RightChild() : tree[nid].LeftChild();
|
||||||
static std::string const kIndicatorTemplate =
|
static std::string const kIndicatorTemplate =
|
||||||
R"ID( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", "yes": {yes}, "no": {no})ID";
|
R"ID( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", "yes": {yes}, "no": {no})ID";
|
||||||
auto split_index = tree[nid].SplitIndex();
|
auto split_index = tree[nid].SplitIndex();
|
||||||
auto fname = fmap_.Name(split_index);
|
auto result =
|
||||||
std::string qfname; // quoted
|
SuperT::Match(kIndicatorTemplate, {{"{nid}", std::to_string(nid)},
|
||||||
common::EscapeU8(fname, &qfname);
|
|
||||||
auto result = SuperT::Match(
|
|
||||||
kIndicatorTemplate,
|
|
||||||
{{"{nid}", std::to_string(nid)},
|
|
||||||
{"{depth}", std::to_string(depth)},
|
{"{depth}", std::to_string(depth)},
|
||||||
{"{fname}", qfname},
|
{"{fname}", GetFeatureName(fmap_, split_index)},
|
||||||
{"{yes}", std::to_string(nyes)},
|
{"{yes}", std::to_string(nyes)},
|
||||||
{"{no}", std::to_string(tree[nid].DefaultChild())}});
|
{"{no}", std::to_string(tree[nid].DefaultChild())}});
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string Categorical(RegTree const& tree, int32_t nid, uint32_t depth) const override {
|
std::string Categorical(RegTree const& tree, bst_node_t nid, uint32_t depth) const override {
|
||||||
auto cats = GetSplitCategories(tree, nid);
|
auto cats = GetSplitCategories(tree, nid);
|
||||||
static std::string const kCategoryTemplate =
|
static std::string const kCategoryTemplate =
|
||||||
R"I( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", )I"
|
R"I( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", )I"
|
||||||
@ -429,18 +434,13 @@ class JsonGenerator : public TreeGenerator {
|
|||||||
return results;
|
return results;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string SplitNodeImpl(RegTree const &tree, int32_t nid,
|
std::string SplitNodeImpl(RegTree const& tree, bst_node_t nid, std::string const& template_str,
|
||||||
std::string const &template_str, std::string cond,
|
std::string cond, uint32_t depth) const {
|
||||||
uint32_t depth) const {
|
|
||||||
auto split_index = tree[nid].SplitIndex();
|
auto split_index = tree[nid].SplitIndex();
|
||||||
auto fname = split_index < fmap_.Size() ? fmap_.Name(split_index) : std::to_string(split_index);
|
std::string const result =
|
||||||
std::string qfname; // quoted
|
SuperT::Match(template_str, {{"{nid}", std::to_string(nid)},
|
||||||
common::EscapeU8(fname, &qfname);
|
|
||||||
std::string const result = SuperT::Match(
|
|
||||||
template_str,
|
|
||||||
{{"{nid}", std::to_string(nid)},
|
|
||||||
{"{depth}", std::to_string(depth)},
|
{"{depth}", std::to_string(depth)},
|
||||||
{"{fname}", qfname},
|
{"{fname}", GetFeatureName(fmap_, split_index)},
|
||||||
{"{cond}", cond},
|
{"{cond}", cond},
|
||||||
{"{left}", std::to_string(tree[nid].LeftChild())},
|
{"{left}", std::to_string(tree[nid].LeftChild())},
|
||||||
{"{right}", std::to_string(tree[nid].RightChild())},
|
{"{right}", std::to_string(tree[nid].RightChild())},
|
||||||
@ -605,9 +605,8 @@ class GraphvizGenerator : public TreeGenerator {
|
|||||||
auto const& extra = kwargs["graph_attrs"];
|
auto const& extra = kwargs["graph_attrs"];
|
||||||
static std::string const kGraphTemplate = " graph [ {key}=\"{value}\" ]\n";
|
static std::string const kGraphTemplate = " graph [ {key}=\"{value}\" ]\n";
|
||||||
for (auto const& kv : extra) {
|
for (auto const& kv : extra) {
|
||||||
param_.graph_attrs += SuperT::Match(kGraphTemplate,
|
param_.graph_attrs +=
|
||||||
{{"{key}", kv.first},
|
SuperT::Match(kGraphTemplate, {{"{key}", kv.first}, {"{value}", kv.second}});
|
||||||
{"{value}", kv.second}});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
kwargs.erase("graph_attrs");
|
kwargs.erase("graph_attrs");
|
||||||
@ -646,17 +645,15 @@ class GraphvizGenerator : public TreeGenerator {
|
|||||||
// Only indicator is different, so we combine all different node types into this
|
// Only indicator is different, so we combine all different node types into this
|
||||||
// function.
|
// function.
|
||||||
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t) const override {
|
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t) const override {
|
||||||
auto split = tree[nid].SplitIndex();
|
auto split_index = tree[nid].SplitIndex();
|
||||||
auto cond = tree[nid].SplitCond();
|
auto cond = tree[nid].SplitCond();
|
||||||
static std::string const kNodeTemplate =
|
static std::string const kNodeTemplate = " {nid} [ label=\"{fname}{<}{cond}\" {params}]\n";
|
||||||
" {nid} [ label=\"{fname}{<}{cond}\" {params}]\n";
|
|
||||||
|
|
||||||
// Indicator only has fname.
|
bool has_less =
|
||||||
bool has_less = (split >= fmap_.Size()) || fmap_.TypeOf(split) != FeatureMap::kIndicator;
|
(split_index >= fmap_.Size()) || fmap_.TypeOf(split_index) != FeatureMap::kIndicator;
|
||||||
std::string result = SuperT::Match(kNodeTemplate, {
|
std::string result =
|
||||||
{"{nid}", std::to_string(nid)},
|
SuperT::Match(kNodeTemplate, {{"{nid}", std::to_string(nid)},
|
||||||
{"{fname}", split < fmap_.Size() ? fmap_.Name(split) :
|
{"{fname}", GetFeatureName(fmap_, split_index)},
|
||||||
'f' + std::to_string(split)},
|
|
||||||
{"{<}", has_less ? "<" : ""},
|
{"{<}", has_less ? "<" : ""},
|
||||||
{"{cond}", has_less ? SuperT::ToStr(cond) : ""},
|
{"{cond}", has_less ? SuperT::ToStr(cond) : ""},
|
||||||
{"{params}", param_.condition_node_params}});
|
{"{params}", param_.condition_node_params}});
|
||||||
@ -672,12 +669,11 @@ class GraphvizGenerator : public TreeGenerator {
|
|||||||
" {nid} [ label=\"{fname}:{cond}\" {params}]\n";
|
" {nid} [ label=\"{fname}:{cond}\" {params}]\n";
|
||||||
auto cats = GetSplitCategories(tree, nid);
|
auto cats = GetSplitCategories(tree, nid);
|
||||||
auto cats_str = PrintCatsAsSet(cats);
|
auto cats_str = PrintCatsAsSet(cats);
|
||||||
auto split = tree[nid].SplitIndex();
|
auto split_index = tree[nid].SplitIndex();
|
||||||
std::string result = SuperT::Match(
|
|
||||||
kLabelTemplate,
|
std::string result =
|
||||||
{{"{nid}", std::to_string(nid)},
|
SuperT::Match(kLabelTemplate, {{"{nid}", std::to_string(nid)},
|
||||||
{"{fname}", split < fmap_.Size() ? fmap_.Name(split)
|
{"{fname}", GetFeatureName(fmap_, split_index)},
|
||||||
: 'f' + std::to_string(split)},
|
|
||||||
{"{cond}", cats_str},
|
{"{cond}", cats_str},
|
||||||
{"{params}", param_.condition_node_params}});
|
{"{params}", param_.condition_node_params}});
|
||||||
|
|
||||||
|
|||||||
@ -404,7 +404,7 @@ TEST(Tree, DumpText) {
|
|||||||
}
|
}
|
||||||
ASSERT_EQ(n_conditions, 3ul);
|
ASSERT_EQ(n_conditions, 3ul);
|
||||||
|
|
||||||
ASSERT_NE(str.find("[f0<0]"), std::string::npos);
|
ASSERT_NE(str.find("[f0<0]"), std::string::npos) << str;
|
||||||
ASSERT_NE(str.find("[f1<1]"), std::string::npos);
|
ASSERT_NE(str.find("[f1<1]"), std::string::npos);
|
||||||
ASSERT_NE(str.find("[f2<2]"), std::string::npos);
|
ASSERT_NE(str.find("[f2<2]"), std::string::npos);
|
||||||
|
|
||||||
|
|||||||
@ -28,10 +28,11 @@ def json_model(model_path: str, parameters: dict) -> dict:
|
|||||||
|
|
||||||
if model_path.endswith("ubj"):
|
if model_path.endswith("ubj"):
|
||||||
import ubjson
|
import ubjson
|
||||||
|
|
||||||
with open(model_path, "rb") as ubjfd:
|
with open(model_path, "rb") as ubjfd:
|
||||||
model = ubjson.load(ubjfd)
|
model = ubjson.load(ubjfd)
|
||||||
else:
|
else:
|
||||||
with open(model_path, 'r') as fd:
|
with open(model_path, "r") as fd:
|
||||||
model = json.load(fd)
|
model = json.load(fd)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
@ -439,25 +440,34 @@ class TestModels:
|
|||||||
'objective': 'multi:softmax'}
|
'objective': 'multi:softmax'}
|
||||||
validate_model(parameters)
|
validate_model(parameters)
|
||||||
|
|
||||||
def test_special_model_dump_characters(self):
|
def test_special_model_dump_characters(self) -> None:
|
||||||
params = {"objective": "reg:squarederror", "max_depth": 3}
|
params = {"objective": "reg:squarederror", "max_depth": 3}
|
||||||
feature_names = ['"feature 0"', "\tfeature\n1", "feature 2"]
|
feature_names = ['"feature 0"', "\tfeature\n1", """feature "2"."""]
|
||||||
X, y, w = tm.make_regression(n_samples=128, n_features=3, use_cupy=False)
|
X, y, w = tm.make_regression(n_samples=128, n_features=3, use_cupy=False)
|
||||||
Xy = xgb.DMatrix(X, label=y, feature_names=feature_names)
|
Xy = xgb.DMatrix(X, label=y, feature_names=feature_names)
|
||||||
booster = xgb.train(params, Xy, num_boost_round=3)
|
booster = xgb.train(params, Xy, num_boost_round=3)
|
||||||
|
|
||||||
json_dump = booster.get_dump(dump_format="json")
|
json_dump = booster.get_dump(dump_format="json")
|
||||||
assert len(json_dump) == 3
|
assert len(json_dump) == 3
|
||||||
|
|
||||||
def validate(obj: dict) -> None:
|
def validate_json(obj: dict) -> None:
|
||||||
for k, v in obj.items():
|
for k, v in obj.items():
|
||||||
if k == "split":
|
if k == "split":
|
||||||
assert v in feature_names
|
assert v in feature_names
|
||||||
elif isinstance(v, dict):
|
elif isinstance(v, dict):
|
||||||
validate(v)
|
validate_json(v)
|
||||||
|
|
||||||
for j_tree in json_dump:
|
for j_tree in json_dump:
|
||||||
loaded = json.loads(j_tree)
|
loaded = json.loads(j_tree)
|
||||||
validate(loaded)
|
validate_json(loaded)
|
||||||
|
|
||||||
|
dot_dump = booster.get_dump(dump_format="dot")
|
||||||
|
for d in dot_dump:
|
||||||
|
assert d.find(r"feature \"2\"") != -1
|
||||||
|
|
||||||
|
text_dump = booster.get_dump(dump_format="text")
|
||||||
|
for d in text_dump:
|
||||||
|
assert d.find(r"feature \"2\"") != -1
|
||||||
|
|
||||||
def test_categorical_model_io(self):
|
def test_categorical_model_io(self):
|
||||||
X, y = tm.make_categorical(256, 16, 71, False)
|
X, y = tm.make_categorical(256, 16, 71, False)
|
||||||
@ -485,6 +495,7 @@ class TestModels:
|
|||||||
@pytest.mark.skipif(**tm.no_sklearn())
|
@pytest.mark.skipif(**tm.no_sklearn())
|
||||||
def test_attributes(self):
|
def test_attributes(self):
|
||||||
from sklearn.datasets import load_iris
|
from sklearn.datasets import load_iris
|
||||||
|
|
||||||
X, y = load_iris(return_X_y=True)
|
X, y = load_iris(return_X_y=True)
|
||||||
cls = xgb.XGBClassifier(n_estimators=2)
|
cls = xgb.XGBClassifier(n_estimators=2)
|
||||||
cls.fit(X, y, early_stopping_rounds=1, eval_set=[(X, y)])
|
cls.fit(X, y, early_stopping_rounds=1, eval_set=[(X, y)])
|
||||||
@ -674,6 +685,7 @@ class TestModels:
|
|||||||
@pytest.mark.skipif(**tm.no_pandas())
|
@pytest.mark.skipif(**tm.no_pandas())
|
||||||
def test_feature_info(self):
|
def test_feature_info(self):
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
rows = 100
|
rows = 100
|
||||||
cols = 10
|
cols = 10
|
||||||
X = rng.randn(rows, cols)
|
X = rng.randn(rows, cols)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user