[breaking] Change DMatrix construction to be distributed (#9623)

* Change column-split DMatrix construction to be distributed

* Remove splitting code for row split
This commit is contained in:
Rong Ou
2023-10-10 08:35:57 -07:00
committed by GitHub
parent b14e535e78
commit 0ecb4de963
7 changed files with 36 additions and 65 deletions

View File

@@ -378,9 +378,8 @@ void RandomDataGenerator::GenerateCSR(
CHECK_EQ(columns->Size(), value->Size());
}
[[nodiscard]] std::shared_ptr<DMatrix> RandomDataGenerator::GenerateDMatrix(bool with_label,
bool float_label,
size_t classes) const {
[[nodiscard]] std::shared_ptr<DMatrix> RandomDataGenerator::GenerateDMatrix(
bool with_label, bool float_label, size_t classes, DataSplitMode data_split_mode) const {
HostDeviceVector<float> data;
HostDeviceVector<bst_row_t> rptrs;
HostDeviceVector<bst_feature_t> columns;
@@ -388,7 +387,7 @@ void RandomDataGenerator::GenerateCSR(
data::CSRAdapter adapter(rptrs.HostPointer(), columns.HostPointer(), data.HostPointer(), rows_,
data.Size(), cols_);
std::shared_ptr<DMatrix> out{
DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 1)};
DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 1, "", data_split_mode)};
if (with_label) {
RandomDataGenerator gen{rows_, n_targets_, 0.0f};