More categorical tests and disable shap sparse test. (#6219)

* Fix tree load with category 32.
This commit is contained in:
Jiaming Yuan
2020-10-10 16:12:37 +08:00
committed by GitHub
parent c991eb612d
commit b5b24354b8
5 changed files with 120 additions and 27 deletions

View File

@@ -5,6 +5,7 @@
#include "xgboost/json_io.h"
#include "xgboost/tree_model.h"
#include "../../../src/common/bitfield.h"
#include "../../../src/common/categorical.h"
namespace xgboost {
#if DMLC_IO_NO_ENDIAN_SWAP // skip on big-endian machines
@@ -150,6 +151,77 @@ TEST(Tree, ExpandCategoricalFeature) {
}
}
// Expand `p_tree` ten times with pseudo-random splits, starting at the root.
//
// Each round pops the most recently pushed node, splits it either on a single
// category or on a numerical threshold (decided by a coin flip), and pushes the
// two new children (left first, so the right child is expanded next).
//
// All randomness comes from one default-constructed SimpleLCG.
// NOTE(review): presumably this makes the generated tree reproducible across
// runs -- confirm against SimpleLCG's default seed.
void GrowTree(RegTree* p_tree) {
  constexpr size_t kCols = 256;
  constexpr size_t kExpands = 10;

  SimpleLCG rng;
  SimpleRealUniformDistribution<double> coin_dist(0.0, 1.0);
  SimpleRealUniformDistribution<double> feature_dist(0.0, kCols);
  SimpleRealUniformDistribution<double> category_dist(0.0, 128.0);
  SimpleRealUniformDistribution<double> threshold_dist(0.0, kCols);

  RegTree& tree = *p_tree;
  std::stack<bst_node_t> pending;
  pending.push(RegTree::kRoot);

  for (size_t remaining = kExpands; remaining > 0; --remaining) {
    // The draw order below (coin, feature, split value) fixes the sequence
    // consumed from the generator; keep it stable.
    bool const categorical = coin_dist(&rng) <= 0.5;

    bst_node_t const nidx = pending.top();
    pending.pop();

    bst_feature_t const fidx = feature_dist(&rng);

    if (!categorical) {
      auto const threshold = threshold_dist(&rng);
      tree.ExpandNode(nidx, fidx, threshold, true, 1.0, 2.0, 3.0, 11.0, 2.0,
                      /*left_sum=*/3.0, /*right_sum=*/4.0);
    } else {
      bst_cat_t const chosen = common::AsCat(category_dist(&rng));
      // Storage sized so that bit `chosen` itself is addressable.
      std::vector<uint32_t> cat_storage(
          LBitField32::ComputeStorageSize(chosen + 1));
      LBitField32 cat_bits{cat_storage};
      cat_bits.Set(chosen);
      tree.ExpandCategorical(nidx, fidx, cat_storage, true, 1.0, 2.0, 3.0, 11.0,
                             2.0, /*left_sum=*/3.0, /*right_sum=*/4.0);
    }

    pending.push(tree[nidx].LeftChild());
    pending.push(tree[nidx].RightChild());
  }
}
// Round-trip check for JSON model IO: serialize `tree`, load the result into a
// fresh RegTree, serialize that one again, and assert both JSON documents
// compare equal.
void CheckReload(RegTree const &tree) {
  // First dump, taken straight from the input tree.
  Json original_dump{Object()};
  tree.SaveModel(&original_dump);

  // Rebuild a tree from that dump and dump it a second time.
  RegTree reloaded;
  reloaded.LoadModel(original_dump);
  Json reloaded_dump{Object()};
  reloaded.SaveModel(&reloaded_dump);

  // The comparison is evaluated into a local before being asserted on.
  auto const same = (original_dump == reloaded_dump);
  ASSERT_TRUE(same);
}
// Save/load round trips for trees that contain categorical splits.
TEST(Tree, CategoricalIO) {
  // Case 1: a single root split on category 32.
  // NOTE(review): presumably chosen because bit 32 no longer fits in one
  // uint32_t word of the bit field -- confirm against LBitField32.
  {
    bst_cat_t const category = 32;
    std::vector<uint32_t> cat_storage(
        LBitField32::ComputeStorageSize(category + 1));
    LBitField32 cat_bits{cat_storage};
    cat_bits.Set(category);

    RegTree single_split;
    single_split.ExpandCategorical(0, 0, cat_storage, true, 1.0, 2.0, 3.0, 11.0,
                                   2.0, /*left_sum=*/3.0, /*right_sum=*/4.0);
    CheckReload(single_split);
  }

  // Case 2: a tree grown by GrowTree, mixing categorical and numerical splits.
  {
    RegTree grown;
    GrowTree(&grown);
    CheckReload(grown);
  }
}
namespace {
RegTree ConstructTree() {
RegTree tree;