Small refactor to categoricals (#7858)
This commit is contained in:
@@ -137,48 +137,6 @@ TEST(GpuHist, BuildHistSharedMem) {
|
||||
TestBuildHist<GradientPair>(true);
|
||||
}
|
||||
|
||||
TEST(GpuHist, ApplySplit) {
|
||||
RegTree tree;
|
||||
GPUExpandEntry candidate;
|
||||
candidate.nid = 0;
|
||||
candidate.left_weight = 1.0f;
|
||||
candidate.right_weight = 2.0f;
|
||||
candidate.base_weight = 3.0f;
|
||||
candidate.split.is_cat = true;
|
||||
candidate.split.fvalue = 1.0f; // at cat 1
|
||||
|
||||
size_t n_rows = 10;
|
||||
size_t n_cols = 10;
|
||||
|
||||
auto m = RandomDataGenerator{n_rows, n_cols, 0}.GenerateDMatrix(true);
|
||||
GenericParameter p;
|
||||
p.InitAllowUnknown(Args{});
|
||||
|
||||
TrainParam tparam;
|
||||
tparam.InitAllowUnknown(Args{});
|
||||
BatchParam bparam;
|
||||
bparam.gpu_id = 0;
|
||||
bparam.max_bin = 3;
|
||||
Context ctx{CreateEmptyGenericParam(0)};
|
||||
|
||||
for (auto& ellpack : m->GetBatches<EllpackPage>(bparam)){
|
||||
auto impl = ellpack.Impl();
|
||||
HostDeviceVector<FeatureType> feature_types(10, FeatureType::kCategorical);
|
||||
feature_types.SetDevice(bparam.gpu_id);
|
||||
tree::GPUHistMakerDevice<GradientPairPrecise> updater(
|
||||
&ctx, impl, feature_types.ConstDeviceSpan(), n_rows, tparam, 0, n_cols, bparam);
|
||||
updater.ApplySplit(candidate, &tree);
|
||||
|
||||
ASSERT_EQ(tree.GetSplitTypes().size(), 3);
|
||||
ASSERT_EQ(tree.GetSplitTypes()[0], FeatureType::kCategorical);
|
||||
ASSERT_EQ(tree.GetSplitCategories().size(), 1);
|
||||
uint32_t bits = 1u << 30; // bits: 0, 1, 0, 0, 0, ..., 0
|
||||
ASSERT_EQ(tree.GetSplitCategories().back(), bits);
|
||||
|
||||
ASSERT_EQ(updater.node_categories.size(), 1);
|
||||
}
|
||||
}
|
||||
|
||||
HistogramCutsWrapper GetHostCutMatrix () {
|
||||
HistogramCutsWrapper cmat;
|
||||
cmat.SetPtrs({0, 3, 6, 9, 12, 15, 18, 21, 24});
|
||||
|
||||
Reference in New Issue
Block a user