Support column-wise data split with in-memory inputs (#9628)
--------- Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
@@ -108,6 +108,7 @@ TEST(CAPI, XGDMatrixCreateFromCSR) {
|
||||
Json::Dump(data_arr, &sdata);
|
||||
Json config{Object{}};
|
||||
config["missing"] = Number{std::numeric_limits<float>::quiet_NaN()};
|
||||
config["data_split_mode"] = Integer{static_cast<int64_t>(DataSplitMode::kCol)};
|
||||
Json::Dump(config, &sconfig);
|
||||
|
||||
DMatrixHandle handle;
|
||||
@@ -120,6 +121,8 @@ TEST(CAPI, XGDMatrixCreateFromCSR) {
|
||||
ASSERT_EQ(n, 3);
|
||||
ASSERT_EQ(XGDMatrixNumNonMissing(handle, &n), 0);
|
||||
ASSERT_EQ(n, 3);
|
||||
ASSERT_EQ(XGDMatrixDataSplitMode(handle, &n), 0);
|
||||
ASSERT_EQ(n, static_cast<int64_t>(DataSplitMode::kCol));
|
||||
|
||||
std::shared_ptr<xgboost::DMatrix> *pp_fmat =
|
||||
static_cast<std::shared_ptr<xgboost::DMatrix> *>(handle);
|
||||
|
||||
@@ -74,6 +74,49 @@ TEST(MetaInfo, GetSetFeature) {
|
||||
// Other conditions are tested in `SaveLoadBinary`.
|
||||
}
|
||||
|
||||
namespace {
|
||||
void VerifyGetSetFeatureColumnSplit() {
|
||||
xgboost::MetaInfo info;
|
||||
info.data_split_mode = DataSplitMode::kCol;
|
||||
auto const world_size = collective::GetWorldSize();
|
||||
|
||||
auto constexpr kCols{2};
|
||||
std::vector<std::string> types{u8"float", u8"c"};
|
||||
std::vector<char const *> c_types(kCols);
|
||||
std::transform(types.cbegin(), types.cend(), c_types.begin(),
|
||||
[](auto const &str) { return str.c_str(); });
|
||||
info.num_col_ = kCols;
|
||||
EXPECT_THROW(info.SetFeatureInfo(u8"feature_type", c_types.data(), c_types.size()), dmlc::Error);
|
||||
info.num_col_ = kCols * world_size;
|
||||
EXPECT_NO_THROW(info.SetFeatureInfo(u8"feature_type", c_types.data(), c_types.size()));
|
||||
std::vector<std::string> expected_type_names{u8"float", u8"c", u8"float",
|
||||
u8"c", u8"float", u8"c"};
|
||||
EXPECT_EQ(info.feature_type_names, expected_type_names);
|
||||
std::vector<xgboost::FeatureType> expected_types{
|
||||
xgboost::FeatureType::kNumerical, xgboost::FeatureType::kCategorical,
|
||||
xgboost::FeatureType::kNumerical, xgboost::FeatureType::kCategorical,
|
||||
xgboost::FeatureType::kNumerical, xgboost::FeatureType::kCategorical};
|
||||
EXPECT_EQ(info.feature_types.HostVector(), expected_types);
|
||||
|
||||
std::vector<std::string> names{u8"feature0", u8"feature1"};
|
||||
std::vector<char const *> c_names(kCols);
|
||||
std::transform(names.cbegin(), names.cend(), c_names.begin(),
|
||||
[](auto const &str) { return str.c_str(); });
|
||||
info.num_col_ = kCols;
|
||||
EXPECT_THROW(info.SetFeatureInfo(u8"feature_name", c_names.data(), c_names.size()), dmlc::Error);
|
||||
info.num_col_ = kCols * world_size;
|
||||
EXPECT_NO_THROW(info.SetFeatureInfo(u8"feature_name", c_names.data(), c_names.size()));
|
||||
std::vector<std::string> expected_names{u8"0.feature0", u8"0.feature1", u8"1.feature0",
|
||||
u8"1.feature1", u8"2.feature0", u8"2.feature1"};
|
||||
EXPECT_EQ(info.feature_names, expected_names);
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
TEST(MetaInfo, GetSetFeatureColumnSplit) {
|
||||
auto constexpr kWorldSize{3};
|
||||
RunWithInMemoryCommunicator(kWorldSize, VerifyGetSetFeatureColumnSplit);
|
||||
}
|
||||
|
||||
TEST(MetaInfo, SaveLoadBinary) {
|
||||
xgboost::MetaInfo info;
|
||||
xgboost::Context ctx;
|
||||
|
||||
Reference in New Issue
Block a user