/** * Copyright 2023-2024, XGBoost Contributors */ #pragma once #include // for FeatureType, DMatrix #include // for size_t #include // for shared_ptr #include // for vector #include "../helpers.h" // for RandomDataGenerator namespace xgboost::tree { inline std::shared_ptr GenerateCatDMatrix(std::size_t rows, std::size_t cols, float sparsity, bool categorical) { if (categorical) { std::vector ft(cols); for (size_t i = 0; i < ft.size(); ++i) { ft[i] = (i % 3 == 0) ? FeatureType::kNumerical : FeatureType::kCategorical; } return RandomDataGenerator(rows, cols, sparsity) .Seed(3) .Type(ft) .MaxCategory(17) .GenerateDMatrix(); } else { return RandomDataGenerator{rows, cols, sparsity}.Seed(3).GenerateDMatrix(); } } void TestColumnSplit(bst_target_t n_targets, bool categorical, std::string name, float sparsity); } // namespace xgboost::tree