Small refactor to categoricals (#7858)

This commit is contained in:
Rory Mitchell
2022-05-05 17:47:02 +02:00
committed by GitHub
parent 14ef38b834
commit 7ef54e39ec
7 changed files with 110 additions and 150 deletions

View File

@@ -137,48 +137,6 @@ TEST(GpuHist, BuildHistSharedMem) {
TestBuildHist<GradientPair>(true);
}
TEST(GpuHist, ApplySplit) {
RegTree tree;
GPUExpandEntry candidate;
candidate.nid = 0;
candidate.left_weight = 1.0f;
candidate.right_weight = 2.0f;
candidate.base_weight = 3.0f;
candidate.split.is_cat = true;
candidate.split.fvalue = 1.0f; // at cat 1
size_t n_rows = 10;
size_t n_cols = 10;
auto m = RandomDataGenerator{n_rows, n_cols, 0}.GenerateDMatrix(true);
GenericParameter p;
p.InitAllowUnknown(Args{});
TrainParam tparam;
tparam.InitAllowUnknown(Args{});
BatchParam bparam;
bparam.gpu_id = 0;
bparam.max_bin = 3;
Context ctx{CreateEmptyGenericParam(0)};
for (auto& ellpack : m->GetBatches<EllpackPage>(bparam)){
auto impl = ellpack.Impl();
HostDeviceVector<FeatureType> feature_types(10, FeatureType::kCategorical);
feature_types.SetDevice(bparam.gpu_id);
tree::GPUHistMakerDevice<GradientPairPrecise> updater(
&ctx, impl, feature_types.ConstDeviceSpan(), n_rows, tparam, 0, n_cols, bparam);
updater.ApplySplit(candidate, &tree);
ASSERT_EQ(tree.GetSplitTypes().size(), 3);
ASSERT_EQ(tree.GetSplitTypes()[0], FeatureType::kCategorical);
ASSERT_EQ(tree.GetSplitCategories().size(), 1);
uint32_t bits = 1u << 30; // bits: 0, 1, 0, 0, 0, ..., 0
ASSERT_EQ(tree.GetSplitCategories().back(), bits);
ASSERT_EQ(updater.node_categories.size(), 1);
}
}
HistogramCutsWrapper GetHostCutMatrix () {
HistogramCutsWrapper cmat;
cmat.SetPtrs({0, 3, 6, 9, 12, 15, 18, 21, 24});