Use Booster context in DMatrix. (#8896)
- Pass context from booster to DMatrix. - Use context instead of integer for `n_threads`. - Check the consistency configuration for `max_bin`. - Test for all combinations of initialization options.
This commit is contained in:
@@ -16,7 +16,8 @@ TEST(Quantile, LoadBalance) {
|
||||
size_t constexpr kRows = 1000, kCols = 100;
|
||||
auto m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix();
|
||||
std::vector<bst_feature_t> cols_ptr;
|
||||
for (auto const& page : m->GetBatches<SparsePage>()) {
|
||||
Context ctx;
|
||||
for (auto const& page : m->GetBatches<SparsePage>(&ctx)) {
|
||||
data::SparsePageAdapterBatch adapter{page.GetView()};
|
||||
cols_ptr = LoadBalance(adapter, page.data.Size(), kCols, 13, [](auto) { return true; });
|
||||
}
|
||||
@@ -43,6 +44,7 @@ void PushPage(HostSketchContainer* container, SparsePage const& page, MetaInfo c
|
||||
|
||||
template <bool use_column>
|
||||
void DoTestDistributedQuantile(size_t rows, size_t cols) {
|
||||
Context ctx;
|
||||
auto const world = collective::GetWorldSize();
|
||||
std::vector<MetaInfo> infos(2);
|
||||
auto& h_weights = infos.front().weights_.HostVector();
|
||||
@@ -51,7 +53,7 @@ void DoTestDistributedQuantile(size_t rows, size_t cols) {
|
||||
SimpleRealUniformDistribution<float> dist(3, 1000);
|
||||
std::generate(h_weights.begin(), h_weights.end(), [&]() { return dist(&lcg); });
|
||||
std::vector<bst_row_t> column_size(cols, rows);
|
||||
size_t n_bins = 64;
|
||||
bst_bin_t n_bins = 64;
|
||||
|
||||
// Generate cuts for distributed environment.
|
||||
auto sparsity = 0.5f;
|
||||
@@ -72,15 +74,15 @@ void DoTestDistributedQuantile(size_t rows, size_t cols) {
|
||||
std::vector<float> hessian(rows, 1.0);
|
||||
auto hess = Span<float const>{hessian};
|
||||
|
||||
ContainerType<use_column> sketch_distributed(n_bins, m->Info().feature_types.ConstHostSpan(),
|
||||
column_size, false, AllThreadsForTest());
|
||||
ContainerType<use_column> sketch_distributed(
|
||||
&ctx, n_bins, m->Info().feature_types.ConstHostSpan(), column_size, false);
|
||||
|
||||
if (use_column) {
|
||||
for (auto const& page : m->GetBatches<SortedCSCPage>()) {
|
||||
for (auto const& page : m->GetBatches<SortedCSCPage>(&ctx)) {
|
||||
PushPage(&sketch_distributed, page, m->Info(), hess);
|
||||
}
|
||||
} else {
|
||||
for (auto const& page : m->GetBatches<SparsePage>()) {
|
||||
for (auto const& page : m->GetBatches<SparsePage>(&ctx)) {
|
||||
PushPage(&sketch_distributed, page, m->Info(), hess);
|
||||
}
|
||||
}
|
||||
@@ -93,8 +95,8 @@ void DoTestDistributedQuantile(size_t rows, size_t cols) {
|
||||
CHECK_EQ(collective::GetWorldSize(), 1);
|
||||
std::for_each(column_size.begin(), column_size.end(), [=](auto& size) { size *= world; });
|
||||
m->Info().num_row_ = world * rows;
|
||||
ContainerType<use_column> sketch_on_single_node(n_bins, m->Info().feature_types.ConstHostSpan(),
|
||||
column_size, false, AllThreadsForTest());
|
||||
ContainerType<use_column> sketch_on_single_node(
|
||||
&ctx, n_bins, m->Info().feature_types.ConstHostSpan(), column_size, false);
|
||||
m->Info().num_row_ = rows;
|
||||
|
||||
for (auto rank = 0; rank < world; ++rank) {
|
||||
@@ -106,7 +108,7 @@ void DoTestDistributedQuantile(size_t rows, size_t cols) {
|
||||
.Upper(1.0f)
|
||||
.GenerateDMatrix();
|
||||
if (use_column) {
|
||||
for (auto const& page : m->GetBatches<SortedCSCPage>()) {
|
||||
for (auto const& page : m->GetBatches<SortedCSCPage>(&ctx)) {
|
||||
PushPage(&sketch_on_single_node, page, m->Info(), hess);
|
||||
}
|
||||
} else {
|
||||
@@ -172,6 +174,7 @@ TEST(Quantile, SortedDistributed) {
|
||||
namespace {
|
||||
template <bool use_column>
|
||||
void DoTestColSplitQuantile(size_t rows, size_t cols) {
|
||||
Context ctx;
|
||||
auto const world = collective::GetWorldSize();
|
||||
auto const rank = collective::GetRank();
|
||||
|
||||
@@ -204,17 +207,17 @@ void DoTestColSplitQuantile(size_t rows, size_t cols) {
|
||||
// Generate cuts for distributed environment.
|
||||
HistogramCuts distributed_cuts;
|
||||
{
|
||||
ContainerType<use_column> sketch_distributed(n_bins, m->Info().feature_types.ConstHostSpan(),
|
||||
column_size, false, AllThreadsForTest());
|
||||
ContainerType<use_column> sketch_distributed(
|
||||
&ctx, n_bins, m->Info().feature_types.ConstHostSpan(), column_size, false);
|
||||
|
||||
std::vector<float> hessian(rows, 1.0);
|
||||
auto hess = Span<float const>{hessian};
|
||||
if (use_column) {
|
||||
for (auto const& page : m->GetBatches<SortedCSCPage>()) {
|
||||
for (auto const& page : m->GetBatches<SortedCSCPage>(&ctx)) {
|
||||
PushPage(&sketch_distributed, page, m->Info(), hess);
|
||||
}
|
||||
} else {
|
||||
for (auto const& page : m->GetBatches<SparsePage>()) {
|
||||
for (auto const& page : m->GetBatches<SparsePage>(&ctx)) {
|
||||
PushPage(&sketch_distributed, page, m->Info(), hess);
|
||||
}
|
||||
}
|
||||
@@ -227,17 +230,17 @@ void DoTestColSplitQuantile(size_t rows, size_t cols) {
|
||||
CHECK_EQ(collective::GetWorldSize(), 1);
|
||||
HistogramCuts single_node_cuts;
|
||||
{
|
||||
ContainerType<use_column> sketch_on_single_node(n_bins, m->Info().feature_types.ConstHostSpan(),
|
||||
column_size, false, AllThreadsForTest());
|
||||
ContainerType<use_column> sketch_on_single_node(
|
||||
&ctx, n_bins, m->Info().feature_types.ConstHostSpan(), column_size, false);
|
||||
|
||||
std::vector<float> hessian(rows, 1.0);
|
||||
auto hess = Span<float const>{hessian};
|
||||
if (use_column) {
|
||||
for (auto const& page : m->GetBatches<SortedCSCPage>()) {
|
||||
for (auto const& page : m->GetBatches<SortedCSCPage>(&ctx)) {
|
||||
PushPage(&sketch_on_single_node, page, m->Info(), hess);
|
||||
}
|
||||
} else {
|
||||
for (auto const& page : m->GetBatches<SparsePage>()) {
|
||||
for (auto const& page : m->GetBatches<SparsePage>(&ctx)) {
|
||||
PushPage(&sketch_on_single_node, page, m->Info(), hess);
|
||||
}
|
||||
}
|
||||
@@ -299,8 +302,10 @@ namespace {
|
||||
void TestSameOnAllWorkers() {
|
||||
auto const world = collective::GetWorldSize();
|
||||
constexpr size_t kRows = 1000, kCols = 100;
|
||||
auto ctx = CreateEmptyGenericParam(Context::kCpuId);
|
||||
|
||||
RunWithSeedsAndBins(
|
||||
kRows, [=](int32_t seed, size_t n_bins, MetaInfo const&) {
|
||||
kRows, [=, &ctx](int32_t seed, size_t n_bins, MetaInfo const&) {
|
||||
auto rank = collective::GetRank();
|
||||
HostDeviceVector<float> storage;
|
||||
std::vector<FeatureType> ft(kCols);
|
||||
@@ -314,7 +319,7 @@ void TestSameOnAllWorkers() {
|
||||
.MaxCategory(17)
|
||||
.Seed(rank + seed)
|
||||
.GenerateDMatrix();
|
||||
auto cuts = SketchOnDMatrix(m.get(), n_bins, AllThreadsForTest());
|
||||
auto cuts = SketchOnDMatrix(&ctx, m.get(), n_bins);
|
||||
std::vector<float> cut_values(cuts.Values().size() * world, 0);
|
||||
std::vector<
|
||||
typename std::remove_reference_t<decltype(cuts.Ptrs())>::value_type>
|
||||
|
||||
Reference in New Issue
Block a user