[breaking] Change DMatrix construction to be distributed (#9623)
* Change column-split DMatrix construction to be distributed
* remove splitting code for row split
This commit is contained in:
@@ -428,3 +428,21 @@ TEST(SimpleDMatrix, Threads) {
|
||||
DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 0, "")};
|
||||
ASSERT_EQ(p_fmat->Ctx()->Threads(), AllThreadsForTest());
|
||||
}
|
||||
|
||||
namespace {
|
||||
void VerifyColumnSplit() {
|
||||
size_t constexpr kRows {16};
|
||||
size_t constexpr kCols {8};
|
||||
auto dmat =
|
||||
RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(false, false, 1, DataSplitMode::kCol);
|
||||
|
||||
ASSERT_EQ(dmat->Info().num_col_, kCols * collective::GetWorldSize());
|
||||
ASSERT_EQ(dmat->Info().num_row_, kRows);
|
||||
ASSERT_EQ(dmat->Info().data_split_mode, DataSplitMode::kCol);
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
// Runs the column-split verification through the in-memory communicator
// with a world size of three.
TEST(SimpleDMatrix, ColumnSplit) {
  constexpr int kNumWorkers = 3;
  RunWithInMemoryCommunicator(kNumWorkers, VerifyColumnSplit);
}
|
||||
|
||||
@@ -378,9 +378,8 @@ void RandomDataGenerator::GenerateCSR(
|
||||
CHECK_EQ(columns->Size(), value->Size());
|
||||
}
|
||||
|
||||
[[nodiscard]] std::shared_ptr<DMatrix> RandomDataGenerator::GenerateDMatrix(bool with_label,
|
||||
bool float_label,
|
||||
size_t classes) const {
|
||||
[[nodiscard]] std::shared_ptr<DMatrix> RandomDataGenerator::GenerateDMatrix(
|
||||
bool with_label, bool float_label, size_t classes, DataSplitMode data_split_mode) const {
|
||||
HostDeviceVector<float> data;
|
||||
HostDeviceVector<bst_row_t> rptrs;
|
||||
HostDeviceVector<bst_feature_t> columns;
|
||||
@@ -388,7 +387,7 @@ void RandomDataGenerator::GenerateCSR(
|
||||
data::CSRAdapter adapter(rptrs.HostPointer(), columns.HostPointer(), data.HostPointer(), rows_,
|
||||
data.Size(), cols_);
|
||||
std::shared_ptr<DMatrix> out{
|
||||
DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 1)};
|
||||
DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 1, "", data_split_mode)};
|
||||
|
||||
if (with_label) {
|
||||
RandomDataGenerator gen{rows_, n_targets_, 0.0f};
|
||||
|
||||
@@ -310,9 +310,9 @@ class RandomDataGenerator {
|
||||
void GenerateCSR(HostDeviceVector<float>* value, HostDeviceVector<bst_row_t>* row_ptr,
|
||||
HostDeviceVector<bst_feature_t>* columns) const;
|
||||
|
||||
[[nodiscard]] std::shared_ptr<DMatrix> GenerateDMatrix(bool with_label = false,
|
||||
bool float_label = true,
|
||||
size_t classes = 1) const;
|
||||
[[nodiscard]] std::shared_ptr<DMatrix> GenerateDMatrix(
|
||||
bool with_label = false, bool float_label = true, size_t classes = 1,
|
||||
DataSplitMode data_split_mode = DataSplitMode::kRow) const;
|
||||
|
||||
[[nodiscard]] std::shared_ptr<DMatrix> GenerateSparsePageDMatrix(std::string prefix,
|
||||
bool with_label) const;
|
||||
|
||||
Reference in New Issue
Block a user