Small refactor to categoricals (#7858)
This commit is contained in:
@@ -24,14 +24,16 @@ void TestEvaluateSingleSplit(bool is_categorical) {
|
||||
TrainParam tparam = ZeroParam();
|
||||
GPUTrainingParam param{tparam};
|
||||
|
||||
common::HistogramCuts cuts;
|
||||
cuts.cut_values_.HostVector() = std::vector<float>{1.0, 2.0, 11.0, 12.0};
|
||||
cuts.cut_ptrs_.HostVector() = std::vector<uint32_t>{0, 2, 4};
|
||||
cuts.min_vals_.HostVector() = std::vector<float>{0.0, 0.0};
|
||||
cuts.cut_ptrs_.SetDevice(0);
|
||||
cuts.cut_values_.SetDevice(0);
|
||||
cuts.min_vals_.SetDevice(0);
|
||||
thrust::device_vector<bst_feature_t> feature_set =
|
||||
std::vector<bst_feature_t>{0, 1};
|
||||
thrust::device_vector<uint32_t> feature_segments =
|
||||
std::vector<bst_row_t>{0, 2, 4};
|
||||
thrust::device_vector<float> feature_values =
|
||||
std::vector<float>{1.0, 2.0, 11.0, 12.0};
|
||||
thrust::device_vector<float> feature_min_values =
|
||||
std::vector<float>{0.0, 0.0};
|
||||
|
||||
// Setup gradients so that second feature gets higher gain
|
||||
thrust::device_vector<GradientPair> feature_histogram =
|
||||
std::vector<GradientPair>{
|
||||
@@ -42,22 +44,27 @@ void TestEvaluateSingleSplit(bool is_categorical) {
|
||||
FeatureType::kCategorical);
|
||||
common::Span<FeatureType> d_feature_types;
|
||||
if (is_categorical) {
|
||||
auto max_cat = *std::max_element(cuts.cut_values_.HostVector().begin(),
|
||||
cuts.cut_values_.HostVector().end());
|
||||
cuts.SetCategorical(true, max_cat);
|
||||
d_feature_types = dh::ToSpan(feature_types);
|
||||
}
|
||||
|
||||
EvaluateSplitInputs<GradientPair> input{1,
|
||||
parent_sum,
|
||||
param,
|
||||
dh::ToSpan(feature_set),
|
||||
d_feature_types,
|
||||
dh::ToSpan(feature_segments),
|
||||
dh::ToSpan(feature_values),
|
||||
dh::ToSpan(feature_min_values),
|
||||
cuts.cut_ptrs_.ConstDeviceSpan(),
|
||||
cuts.cut_values_.ConstDeviceSpan(),
|
||||
cuts.min_vals_.ConstDeviceSpan(),
|
||||
dh::ToSpan(feature_histogram)};
|
||||
|
||||
GPUHistEvaluator<GradientPair> evaluator{
|
||||
tparam, static_cast<bst_feature_t>(feature_min_values.size()), 0};
|
||||
dh::device_vector<common::CatBitField::value_type> out_cats;
|
||||
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, 0).split;
|
||||
tparam, static_cast<bst_feature_t>(feature_set.size()), 0};
|
||||
evaluator.Reset(cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, 0);
|
||||
DeviceSplitCandidate result =
|
||||
evaluator.EvaluateSingleSplit(input, 0).split;
|
||||
|
||||
EXPECT_EQ(result.findex, 1);
|
||||
EXPECT_EQ(result.fvalue, 11.0);
|
||||
|
||||
@@ -137,48 +137,6 @@ TEST(GpuHist, BuildHistSharedMem) {
|
||||
TestBuildHist<GradientPair>(true);
|
||||
}
|
||||
|
||||
TEST(GpuHist, ApplySplit) {
|
||||
RegTree tree;
|
||||
GPUExpandEntry candidate;
|
||||
candidate.nid = 0;
|
||||
candidate.left_weight = 1.0f;
|
||||
candidate.right_weight = 2.0f;
|
||||
candidate.base_weight = 3.0f;
|
||||
candidate.split.is_cat = true;
|
||||
candidate.split.fvalue = 1.0f; // at cat 1
|
||||
|
||||
size_t n_rows = 10;
|
||||
size_t n_cols = 10;
|
||||
|
||||
auto m = RandomDataGenerator{n_rows, n_cols, 0}.GenerateDMatrix(true);
|
||||
GenericParameter p;
|
||||
p.InitAllowUnknown(Args{});
|
||||
|
||||
TrainParam tparam;
|
||||
tparam.InitAllowUnknown(Args{});
|
||||
BatchParam bparam;
|
||||
bparam.gpu_id = 0;
|
||||
bparam.max_bin = 3;
|
||||
Context ctx{CreateEmptyGenericParam(0)};
|
||||
|
||||
for (auto& ellpack : m->GetBatches<EllpackPage>(bparam)){
|
||||
auto impl = ellpack.Impl();
|
||||
HostDeviceVector<FeatureType> feature_types(10, FeatureType::kCategorical);
|
||||
feature_types.SetDevice(bparam.gpu_id);
|
||||
tree::GPUHistMakerDevice<GradientPairPrecise> updater(
|
||||
&ctx, impl, feature_types.ConstDeviceSpan(), n_rows, tparam, 0, n_cols, bparam);
|
||||
updater.ApplySplit(candidate, &tree);
|
||||
|
||||
ASSERT_EQ(tree.GetSplitTypes().size(), 3);
|
||||
ASSERT_EQ(tree.GetSplitTypes()[0], FeatureType::kCategorical);
|
||||
ASSERT_EQ(tree.GetSplitCategories().size(), 1);
|
||||
uint32_t bits = 1u << 30; // bits: 0, 1, 0, 0, 0, ..., 0
|
||||
ASSERT_EQ(tree.GetSplitCategories().back(), bits);
|
||||
|
||||
ASSERT_EQ(updater.node_categories.size(), 1);
|
||||
}
|
||||
}
|
||||
|
||||
HistogramCutsWrapper GetHostCutMatrix () {
|
||||
HistogramCutsWrapper cmat;
|
||||
cmat.SetPtrs({0, 3, 6, 9, 12, 15, 18, 21, 24});
|
||||
|
||||
Reference in New Issue
Block a user