More categorical tests and disable shap sparse test. (#6219)
* Fix tree load with 32 category.
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
#include "xgboost/json_io.h"
|
||||
#include "xgboost/tree_model.h"
|
||||
#include "../../../src/common/bitfield.h"
|
||||
#include "../../../src/common/categorical.h"
|
||||
|
||||
namespace xgboost {
|
||||
#if DMLC_IO_NO_ENDIAN_SWAP // skip on big-endian machines
|
||||
@@ -150,6 +151,77 @@ TEST(Tree, ExpandCategoricalFeature) {
|
||||
}
|
||||
}
|
||||
|
||||
void GrowTree(RegTree* p_tree) {
|
||||
SimpleLCG lcg;
|
||||
size_t n_expands = 10;
|
||||
constexpr size_t kCols = 256;
|
||||
SimpleRealUniformDistribution<double> coin(0.0, 1.0);
|
||||
SimpleRealUniformDistribution<double> feat(0.0, kCols);
|
||||
SimpleRealUniformDistribution<double> split_cat(0.0, 128.0);
|
||||
SimpleRealUniformDistribution<double> split_value(0.0, kCols);
|
||||
|
||||
std::stack<bst_node_t> stack;
|
||||
stack.push(RegTree::kRoot);
|
||||
auto& tree = *p_tree;
|
||||
|
||||
for (size_t i = 0; i < n_expands; ++i) {
|
||||
auto is_cat = coin(&lcg) <= 0.5;
|
||||
bst_node_t node = stack.top();
|
||||
stack.pop();
|
||||
|
||||
bst_feature_t f = feat(&lcg);
|
||||
if (is_cat) {
|
||||
bst_cat_t cat = common::AsCat(split_cat(&lcg));
|
||||
std::vector<uint32_t> split_cats(
|
||||
LBitField32::ComputeStorageSize(cat + 1));
|
||||
LBitField32 bitset{split_cats};
|
||||
bitset.Set(cat);
|
||||
tree.ExpandCategorical(node, f, split_cats, true, 1.0, 2.0, 3.0, 11.0, 2.0,
|
||||
/*left_sum=*/3.0, /*right_sum=*/4.0);
|
||||
} else {
|
||||
auto split = split_value(&lcg);
|
||||
tree.ExpandNode(node, f, split, true, 1.0, 2.0, 3.0, 11.0, 2.0,
|
||||
/*left_sum=*/3.0, /*right_sum=*/4.0);
|
||||
}
|
||||
|
||||
stack.push(tree[node].LeftChild());
|
||||
stack.push(tree[node].RightChild());
|
||||
}
|
||||
}
|
||||
|
||||
void CheckReload(RegTree const &tree) {
|
||||
Json out{Object()};
|
||||
tree.SaveModel(&out);
|
||||
|
||||
RegTree loaded_tree;
|
||||
loaded_tree.LoadModel(out);
|
||||
Json saved{Object()};
|
||||
loaded_tree.SaveModel(&saved);
|
||||
|
||||
auto same = out == saved;
|
||||
ASSERT_TRUE(same);
|
||||
}
|
||||
|
||||
TEST(Tree, CategoricalIO) {
|
||||
{
|
||||
RegTree tree;
|
||||
bst_cat_t cat = 32;
|
||||
std::vector<uint32_t> split_cats(LBitField32::ComputeStorageSize(cat + 1));
|
||||
LBitField32 bitset{split_cats};
|
||||
bitset.Set(cat);
|
||||
tree.ExpandCategorical(0, 0, split_cats, true, 1.0, 2.0, 3.0, 11.0, 2.0,
|
||||
/*left_sum=*/3.0, /*right_sum=*/4.0);
|
||||
|
||||
CheckReload(tree);
|
||||
}
|
||||
|
||||
{
|
||||
RegTree tree;
|
||||
GrowTree(&tree);
|
||||
CheckReload(tree);
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
RegTree ConstructTree() {
|
||||
RegTree tree;
|
||||
|
||||
Reference in New Issue
Block a user