More categorical tests and disable shap sparse test. (#6219)
* Fix tree load with 32 category.
This commit is contained in:
parent
c991eb612d
commit
b5b24354b8
@ -186,7 +186,9 @@ Json& JsonObject::operator[](int ind) {
|
||||
}
|
||||
|
||||
bool JsonObject::operator==(Value const& rhs) const {
|
||||
if (!IsA<JsonObject>(&rhs)) { return false; }
|
||||
if (!IsA<JsonObject>(&rhs)) {
|
||||
return false;
|
||||
}
|
||||
return object_ == Cast<JsonObject const>(&rhs)->GetObject();
|
||||
}
|
||||
|
||||
@ -275,10 +277,14 @@ Json& JsonNumber::operator[](int ind) {
|
||||
|
||||
bool JsonNumber::operator==(Value const& rhs) const {
|
||||
if (!IsA<JsonNumber>(&rhs)) { return false; }
|
||||
auto r_num = Cast<JsonNumber const>(&rhs)->GetNumber();
|
||||
if (std::isinf(number_)) {
|
||||
return std::isinf(Cast<JsonNumber const>(&rhs)->GetNumber());
|
||||
return std::isinf(r_num);
|
||||
}
|
||||
return std::abs(number_ - Cast<JsonNumber const>(&rhs)->GetNumber()) < kRtEps;
|
||||
if (std::isnan(number_)) {
|
||||
return std::isnan(r_num);
|
||||
}
|
||||
return number_ - r_num == 0;
|
||||
}
|
||||
|
||||
Value & JsonNumber::operator=(Value const &rhs) {
|
||||
|
||||
@ -792,16 +792,17 @@ void RegTree::LoadCategoricalSplit(Json const& in) {
|
||||
auto j_begin = get<Integer const>(categories_segments[cnt]);
|
||||
auto j_end = get<Integer const>(categories_sizes[cnt]) + j_begin;
|
||||
bst_cat_t max_cat{std::numeric_limits<bst_cat_t>::min()};
|
||||
CHECK_NE(j_end - j_begin, 0) << nidx;
|
||||
|
||||
for (auto j = j_begin; j < j_end; ++j) {
|
||||
auto const &category = get<Integer const>(categories[j]);
|
||||
auto cat = common::AsCat(category);
|
||||
max_cat = std::max(max_cat, cat);
|
||||
}
|
||||
size_t size = max_cat == std::numeric_limits<bst_cat_t>::min()
|
||||
? 0
|
||||
: common::KCatBitField::ComputeStorageSize(max_cat);
|
||||
size = size == 0 ? 1 : size;
|
||||
// Have at least 1 category in split.
|
||||
CHECK_NE(std::numeric_limits<bst_cat_t>::min(), max_cat);
|
||||
size_t n_cats = max_cat + 1; // cat 0
|
||||
size_t size = common::KCatBitField::ComputeStorageSize(n_cats);
|
||||
std::vector<uint32_t> cat_bits_storage(size, 0);
|
||||
common::CatBitField cat_bits{common::Span<uint32_t>(cat_bits_storage)};
|
||||
for (auto j = j_begin; j < j_end; ++j) {
|
||||
|
||||
@ -5,6 +5,7 @@
|
||||
#include "xgboost/json_io.h"
|
||||
#include "xgboost/tree_model.h"
|
||||
#include "../../../src/common/bitfield.h"
|
||||
#include "../../../src/common/categorical.h"
|
||||
|
||||
namespace xgboost {
|
||||
#if DMLC_IO_NO_ENDIAN_SWAP // skip on big-endian machines
|
||||
@ -150,6 +151,77 @@ TEST(Tree, ExpandCategoricalFeature) {
|
||||
}
|
||||
}
|
||||
|
||||
void GrowTree(RegTree* p_tree) {
|
||||
SimpleLCG lcg;
|
||||
size_t n_expands = 10;
|
||||
constexpr size_t kCols = 256;
|
||||
SimpleRealUniformDistribution<double> coin(0.0, 1.0);
|
||||
SimpleRealUniformDistribution<double> feat(0.0, kCols);
|
||||
SimpleRealUniformDistribution<double> split_cat(0.0, 128.0);
|
||||
SimpleRealUniformDistribution<double> split_value(0.0, kCols);
|
||||
|
||||
std::stack<bst_node_t> stack;
|
||||
stack.push(RegTree::kRoot);
|
||||
auto& tree = *p_tree;
|
||||
|
||||
for (size_t i = 0; i < n_expands; ++i) {
|
||||
auto is_cat = coin(&lcg) <= 0.5;
|
||||
bst_node_t node = stack.top();
|
||||
stack.pop();
|
||||
|
||||
bst_feature_t f = feat(&lcg);
|
||||
if (is_cat) {
|
||||
bst_cat_t cat = common::AsCat(split_cat(&lcg));
|
||||
std::vector<uint32_t> split_cats(
|
||||
LBitField32::ComputeStorageSize(cat + 1));
|
||||
LBitField32 bitset{split_cats};
|
||||
bitset.Set(cat);
|
||||
tree.ExpandCategorical(node, f, split_cats, true, 1.0, 2.0, 3.0, 11.0, 2.0,
|
||||
/*left_sum=*/3.0, /*right_sum=*/4.0);
|
||||
} else {
|
||||
auto split = split_value(&lcg);
|
||||
tree.ExpandNode(node, f, split, true, 1.0, 2.0, 3.0, 11.0, 2.0,
|
||||
/*left_sum=*/3.0, /*right_sum=*/4.0);
|
||||
}
|
||||
|
||||
stack.push(tree[node].LeftChild());
|
||||
stack.push(tree[node].RightChild());
|
||||
}
|
||||
}
|
||||
|
||||
void CheckReload(RegTree const &tree) {
|
||||
Json out{Object()};
|
||||
tree.SaveModel(&out);
|
||||
|
||||
RegTree loaded_tree;
|
||||
loaded_tree.LoadModel(out);
|
||||
Json saved{Object()};
|
||||
loaded_tree.SaveModel(&saved);
|
||||
|
||||
auto same = out == saved;
|
||||
ASSERT_TRUE(same);
|
||||
}
|
||||
|
||||
TEST(Tree, CategoricalIO) {
|
||||
{
|
||||
RegTree tree;
|
||||
bst_cat_t cat = 32;
|
||||
std::vector<uint32_t> split_cats(LBitField32::ComputeStorageSize(cat + 1));
|
||||
LBitField32 bitset{split_cats};
|
||||
bitset.Set(cat);
|
||||
tree.ExpandCategorical(0, 0, split_cats, true, 1.0, 2.0, 3.0, 11.0, 2.0,
|
||||
/*left_sum=*/3.0, /*right_sum=*/4.0);
|
||||
|
||||
CheckReload(tree);
|
||||
}
|
||||
|
||||
{
|
||||
RegTree tree;
|
||||
GrowTree(&tree);
|
||||
CheckReload(tree);
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
RegTree ConstructTree() {
|
||||
RegTree tree;
|
||||
|
||||
@ -212,6 +212,10 @@ class TestGPUPredict(unittest.TestCase):
|
||||
tm.dataset_strategy, shap_parameter_strategy)
|
||||
@settings(deadline=None, max_examples=20)
|
||||
def test_shap_interactions(self, num_rounds, dataset, param):
|
||||
if dataset.name == 'sparse':
|
||||
issue = 'https://github.com/dmlc/xgboost/issues/6074'
|
||||
pytest.xfail(reason=f'GPU shap with sparse is flaky: {issue}')
|
||||
|
||||
param.update({"predictor": "gpu_predictor", "gpu_id": 0})
|
||||
param = dataset.set_params(param)
|
||||
dmat = dataset.get_dmat()
|
||||
@ -220,5 +224,6 @@ class TestGPUPredict(unittest.TestCase):
|
||||
shap = bst.predict(test_dmat, pred_interactions=True)
|
||||
margin = bst.predict(test_dmat, output_margin=True)
|
||||
assume(len(dataset.y) > 0)
|
||||
assert np.allclose(np.sum(shap, axis=(len(shap.shape) - 1, len(shap.shape) - 2)), margin,
|
||||
assert np.allclose(np.sum(shap, axis=(len(shap.shape) - 1, len(shap.shape) - 2)),
|
||||
margin,
|
||||
1e-3, 1e-3)
|
||||
|
||||
@ -41,7 +41,24 @@ class TestGPUUpdaters:
|
||||
note(result)
|
||||
assert tm.non_increasing(result['train'][dataset.metric])
|
||||
|
||||
def run_categorical_basic(self, cat, onehot, label, rounds):
|
||||
def run_categorical_basic(self, rows, cols, rounds, cats):
|
||||
import pandas as pd
|
||||
rng = np.random.RandomState(1994)
|
||||
|
||||
pd_dict = {}
|
||||
for i in range(cols):
|
||||
c = rng.randint(low=0, high=cats+1, size=rows)
|
||||
pd_dict[str(i)] = pd.Series(c, dtype=np.int64)
|
||||
|
||||
df = pd.DataFrame(pd_dict)
|
||||
label = df.iloc[:, 0]
|
||||
for i in range(0, cols-1):
|
||||
label += df.iloc[:, i]
|
||||
label += 1
|
||||
df = df.astype('category')
|
||||
onehot = pd.get_dummies(df)
|
||||
cat = df
|
||||
|
||||
by_etl_results = {}
|
||||
by_builtin_results = {}
|
||||
|
||||
@ -64,28 +81,20 @@ class TestGPUUpdaters:
|
||||
rtol=1e-3)
|
||||
assert tm.non_increasing(by_builtin_results['Train']['rmse'])
|
||||
|
||||
@given(strategies.integers(10, 400), strategies.integers(5, 10),
|
||||
strategies.integers(1, 5), strategies.integers(4, 8))
|
||||
@given(strategies.integers(10, 400), strategies.integers(3, 8),
|
||||
strategies.integers(1, 5), strategies.integers(4, 7))
|
||||
@settings(deadline=None)
|
||||
@pytest.mark.skipif(**tm.no_pandas())
|
||||
def test_categorical(self, rows, cols, rounds, cats):
|
||||
import pandas as pd
|
||||
rng = np.random.RandomState(1994)
|
||||
self.run_categorical_basic(rows, cols, rounds, cats)
|
||||
|
||||
pd_dict = {}
|
||||
for i in range(cols):
|
||||
c = rng.randint(low=0, high=cats+1, size=rows)
|
||||
pd_dict[str(i)] = pd.Series(c, dtype=np.int64)
|
||||
|
||||
df = pd.DataFrame(pd_dict)
|
||||
label = df.iloc[:, 0]
|
||||
for i in range(0, cols-1):
|
||||
label += df.iloc[:, i]
|
||||
label += 1
|
||||
df = df.astype('category')
|
||||
x = pd.get_dummies(df)
|
||||
|
||||
self.run_categorical_basic(df, x, label, rounds)
|
||||
def test_categorical_32_cat(self):
|
||||
'''32 hits the bound of integer bitset, so special test'''
|
||||
rows = 1000
|
||||
cols = 10
|
||||
cats = 32
|
||||
rounds = 4
|
||||
self.run_categorical_basic(rows, cols, rounds, cats)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
@given(parameter_strategy, strategies.integers(1, 20),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user