Use Booster context in DMatrix. (#8896)
- Pass context from booster to DMatrix. - Use context instead of integer for `n_threads`. - Check the consistency configuration for `max_bin`. - Test for all combinations of initialization options.
This commit is contained in:
@@ -14,11 +14,12 @@ TEST(DenseColumn, Test) {
|
||||
int32_t max_num_bins[] = {static_cast<int32_t>(std::numeric_limits<uint8_t>::max()) + 1,
|
||||
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 1,
|
||||
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 2};
|
||||
auto ctx = CreateEmptyGenericParam(Context::kCpuId);
|
||||
BinTypeSize last{kUint8BinsTypeSize};
|
||||
for (int32_t max_num_bin : max_num_bins) {
|
||||
auto dmat = RandomDataGenerator(100, 10, 0.0).GenerateDMatrix();
|
||||
auto sparse_thresh = 0.2;
|
||||
GHistIndexMatrix gmat{dmat.get(), max_num_bin, sparse_thresh, false, AllThreadsForTest()};
|
||||
GHistIndexMatrix gmat{&ctx, dmat.get(), max_num_bin, sparse_thresh, false};
|
||||
ColumnMatrix column_matrix;
|
||||
for (auto const& page : dmat->GetBatches<SparsePage>()) {
|
||||
column_matrix.InitFromSparse(page, gmat, sparse_thresh, AllThreadsForTest());
|
||||
@@ -62,9 +63,10 @@ TEST(SparseColumn, Test) {
|
||||
int32_t max_num_bins[] = {static_cast<int32_t>(std::numeric_limits<uint8_t>::max()) + 1,
|
||||
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 1,
|
||||
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 2};
|
||||
auto ctx = CreateEmptyGenericParam(Context::kCpuId);
|
||||
for (int32_t max_num_bin : max_num_bins) {
|
||||
auto dmat = RandomDataGenerator(100, 1, 0.85).GenerateDMatrix();
|
||||
GHistIndexMatrix gmat{dmat.get(), max_num_bin, 0.5f, false, AllThreadsForTest()};
|
||||
GHistIndexMatrix gmat{&ctx, dmat.get(), max_num_bin, 0.5f, false};
|
||||
ColumnMatrix column_matrix;
|
||||
for (auto const& page : dmat->GetBatches<SparsePage>()) {
|
||||
column_matrix.InitFromSparse(page, gmat, 1.0, AllThreadsForTest());
|
||||
@@ -90,9 +92,10 @@ TEST(DenseColumnWithMissing, Test) {
|
||||
int32_t max_num_bins[] = {static_cast<int32_t>(std::numeric_limits<uint8_t>::max()) + 1,
|
||||
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 1,
|
||||
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 2};
|
||||
auto ctx = CreateEmptyGenericParam(Context::kCpuId);
|
||||
for (int32_t max_num_bin : max_num_bins) {
|
||||
auto dmat = RandomDataGenerator(100, 1, 0.5).GenerateDMatrix();
|
||||
GHistIndexMatrix gmat(dmat.get(), max_num_bin, 0.2, false, AllThreadsForTest());
|
||||
GHistIndexMatrix gmat(&ctx, dmat.get(), max_num_bin, 0.2, false);
|
||||
ColumnMatrix column_matrix;
|
||||
for (auto const& page : dmat->GetBatches<SparsePage>()) {
|
||||
column_matrix.InitFromSparse(page, gmat, 0.2, AllThreadsForTest());
|
||||
|
||||
@@ -156,6 +156,7 @@ TEST(CutsBuilder, SearchGroupInd) {
|
||||
}
|
||||
|
||||
TEST(HistUtil, DenseCutsCategorical) {
|
||||
auto ctx = CreateEmptyGenericParam(Context::kCpuId);
|
||||
int categorical_sizes[] = {2, 6, 8, 12};
|
||||
int num_bins = 256;
|
||||
int sizes[] = {25, 100, 1000};
|
||||
@@ -165,7 +166,7 @@ TEST(HistUtil, DenseCutsCategorical) {
|
||||
std::vector<float> x_sorted(x);
|
||||
std::sort(x_sorted.begin(), x_sorted.end());
|
||||
auto dmat = GetDMatrixFromData(x, n, 1);
|
||||
HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, AllThreadsForTest());
|
||||
HistogramCuts cuts = SketchOnDMatrix(&ctx, dmat.get(), num_bins);
|
||||
auto cuts_from_sketch = cuts.Values();
|
||||
EXPECT_LT(cuts.MinValues()[0], x_sorted.front());
|
||||
EXPECT_GT(cuts_from_sketch.front(), x_sorted.front());
|
||||
@@ -176,6 +177,7 @@ TEST(HistUtil, DenseCutsCategorical) {
|
||||
}
|
||||
|
||||
TEST(HistUtil, DenseCutsAccuracyTest) {
|
||||
auto ctx = CreateEmptyGenericParam(Context::kCpuId);
|
||||
int bin_sizes[] = {2, 16, 256, 512};
|
||||
int sizes[] = {100};
|
||||
int num_columns = 5;
|
||||
@@ -183,7 +185,7 @@ TEST(HistUtil, DenseCutsAccuracyTest) {
|
||||
auto x = GenerateRandom(num_rows, num_columns);
|
||||
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
|
||||
for (auto num_bins : bin_sizes) {
|
||||
HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, AllThreadsForTest());
|
||||
HistogramCuts cuts = SketchOnDMatrix(&ctx, dmat.get(), num_bins);
|
||||
ValidateCuts(cuts, dmat.get(), num_bins);
|
||||
}
|
||||
}
|
||||
@@ -193,6 +195,7 @@ TEST(HistUtil, DenseCutsAccuracyTestWeights) {
|
||||
int bin_sizes[] = {2, 16, 256, 512};
|
||||
int sizes[] = {100, 1000, 1500};
|
||||
int num_columns = 5;
|
||||
auto ctx = CreateEmptyGenericParam(Context::kCpuId);
|
||||
for (auto num_rows : sizes) {
|
||||
auto x = GenerateRandom(num_rows, num_columns);
|
||||
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
|
||||
@@ -200,11 +203,11 @@ TEST(HistUtil, DenseCutsAccuracyTestWeights) {
|
||||
dmat->Info().weights_.HostVector() = w;
|
||||
for (auto num_bins : bin_sizes) {
|
||||
{
|
||||
HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, AllThreadsForTest(), true);
|
||||
HistogramCuts cuts = SketchOnDMatrix(&ctx, dmat.get(), num_bins, true);
|
||||
ValidateCuts(cuts, dmat.get(), num_bins);
|
||||
}
|
||||
{
|
||||
HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, AllThreadsForTest(), false);
|
||||
HistogramCuts cuts = SketchOnDMatrix(&ctx, dmat.get(), num_bins, false);
|
||||
ValidateCuts(cuts, dmat.get(), num_bins);
|
||||
}
|
||||
}
|
||||
@@ -215,6 +218,7 @@ void TestQuantileWithHessian(bool use_sorted) {
|
||||
int bin_sizes[] = {2, 16, 256, 512};
|
||||
int sizes[] = {1000, 1500};
|
||||
int num_columns = 5;
|
||||
auto ctx = CreateEmptyGenericParam(Context::kCpuId);
|
||||
for (auto num_rows : sizes) {
|
||||
auto x = GenerateRandom(num_rows, num_columns);
|
||||
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
|
||||
@@ -225,15 +229,13 @@ void TestQuantileWithHessian(bool use_sorted) {
|
||||
dmat->Info().weights_.HostVector() = w;
|
||||
|
||||
for (auto num_bins : bin_sizes) {
|
||||
HistogramCuts cuts_hess =
|
||||
SketchOnDMatrix(dmat.get(), num_bins, AllThreadsForTest(), use_sorted, hessian);
|
||||
HistogramCuts cuts_hess = SketchOnDMatrix(&ctx, dmat.get(), num_bins, use_sorted, hessian);
|
||||
for (size_t i = 0; i < w.size(); ++i) {
|
||||
dmat->Info().weights_.HostVector()[i] = w[i] * hessian[i];
|
||||
}
|
||||
ValidateCuts(cuts_hess, dmat.get(), num_bins);
|
||||
|
||||
HistogramCuts cuts_wh =
|
||||
SketchOnDMatrix(dmat.get(), num_bins, AllThreadsForTest(), use_sorted);
|
||||
HistogramCuts cuts_wh = SketchOnDMatrix(&ctx, dmat.get(), num_bins, use_sorted);
|
||||
ValidateCuts(cuts_wh, dmat.get(), num_bins);
|
||||
|
||||
ASSERT_EQ(cuts_hess.Values().size(), cuts_wh.Values().size());
|
||||
@@ -255,12 +257,13 @@ TEST(HistUtil, DenseCutsExternalMemory) {
|
||||
int bin_sizes[] = {2, 16, 256, 512};
|
||||
int sizes[] = {100, 1000, 1500};
|
||||
int num_columns = 5;
|
||||
auto ctx = CreateEmptyGenericParam(Context::kCpuId);
|
||||
for (auto num_rows : sizes) {
|
||||
auto x = GenerateRandom(num_rows, num_columns);
|
||||
dmlc::TemporaryDirectory tmpdir;
|
||||
auto dmat = GetExternalMemoryDMatrixFromData(x, num_rows, num_columns, tmpdir);
|
||||
for (auto num_bins : bin_sizes) {
|
||||
HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, AllThreadsForTest());
|
||||
HistogramCuts cuts = SketchOnDMatrix(&ctx, dmat.get(), num_bins);
|
||||
ValidateCuts(cuts, dmat.get(), num_bins);
|
||||
}
|
||||
}
|
||||
@@ -275,12 +278,12 @@ TEST(HistUtil, IndexBinBound) {
|
||||
kUint32BinsTypeSize};
|
||||
size_t constexpr kRows = 100;
|
||||
size_t constexpr kCols = 10;
|
||||
|
||||
auto ctx = CreateEmptyGenericParam(Context::kCpuId);
|
||||
size_t bin_id = 0;
|
||||
for (auto max_bin : bin_sizes) {
|
||||
auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
|
||||
|
||||
GHistIndexMatrix hmat(p_fmat.get(), max_bin, 0.5, false, AllThreadsForTest());
|
||||
GHistIndexMatrix hmat(&ctx, p_fmat.get(), max_bin, 0.5, false);
|
||||
EXPECT_EQ(hmat.index.Size(), kRows*kCols);
|
||||
EXPECT_EQ(expected_bin_type_sizes[bin_id++], hmat.index.GetBinTypeSize());
|
||||
}
|
||||
@@ -300,10 +303,11 @@ TEST(HistUtil, IndexBinData) {
|
||||
static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 2 };
|
||||
size_t constexpr kRows = 100;
|
||||
size_t constexpr kCols = 10;
|
||||
auto ctx = CreateEmptyGenericParam(Context::kCpuId);
|
||||
|
||||
for (auto max_bin : kBinSizes) {
|
||||
auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
|
||||
GHistIndexMatrix hmat(p_fmat.get(), max_bin, 0.5, false, AllThreadsForTest());
|
||||
GHistIndexMatrix hmat(&ctx, p_fmat.get(), max_bin, 0.5, false);
|
||||
uint32_t const* offsets = hmat.index.Offset();
|
||||
EXPECT_EQ(hmat.index.Size(), kRows*kCols);
|
||||
switch (max_bin) {
|
||||
@@ -327,10 +331,10 @@ void TestSketchFromWeights(bool with_group) {
|
||||
size_t constexpr kRows = 300, kCols = 20, kBins = 256;
|
||||
size_t constexpr kGroups = 10;
|
||||
auto m = RandomDataGenerator{kRows, kCols, 0}.Device(0).GenerateDMatrix();
|
||||
common::HistogramCuts cuts = SketchOnDMatrix(m.get(), kBins, AllThreadsForTest());
|
||||
auto ctx = CreateEmptyGenericParam(Context::kCpuId);
|
||||
common::HistogramCuts cuts = SketchOnDMatrix(&ctx, m.get(), kBins);
|
||||
|
||||
MetaInfo info;
|
||||
Context ctx;
|
||||
auto& h_weights = info.weights_.HostVector();
|
||||
if (with_group) {
|
||||
h_weights.resize(kGroups);
|
||||
@@ -363,7 +367,7 @@ void TestSketchFromWeights(bool with_group) {
|
||||
|
||||
if (with_group) {
|
||||
m->Info().weights_ = decltype(m->Info().weights_)(); // remove weight
|
||||
HistogramCuts non_weighted = SketchOnDMatrix(m.get(), kBins, AllThreadsForTest());
|
||||
HistogramCuts non_weighted = SketchOnDMatrix(&ctx, m.get(), kBins);
|
||||
for (size_t i = 0; i < cuts.Values().size(); ++i) {
|
||||
EXPECT_EQ(cuts.Values()[i], non_weighted.Values()[i]);
|
||||
}
|
||||
@@ -382,7 +386,7 @@ void TestSketchFromWeights(bool with_group) {
|
||||
for (size_t i = 0; i < h_weights.size(); ++i) {
|
||||
h_weights[i] = static_cast<float>(i + 1) / static_cast<float>(kGroups);
|
||||
}
|
||||
HistogramCuts weighted = SketchOnDMatrix(m.get(), kBins, AllThreadsForTest());
|
||||
HistogramCuts weighted = SketchOnDMatrix(&ctx, m.get(), kBins);
|
||||
ValidateCuts(weighted, m.get(), kBins);
|
||||
}
|
||||
}
|
||||
@@ -393,11 +397,12 @@ TEST(HistUtil, SketchFromWeights) {
|
||||
}
|
||||
|
||||
TEST(HistUtil, SketchCategoricalFeatures) {
|
||||
TestCategoricalSketch(1000, 256, 32, false, [](DMatrix* p_fmat, int32_t num_bins) {
|
||||
return SketchOnDMatrix(p_fmat, num_bins, AllThreadsForTest());
|
||||
auto ctx = CreateEmptyGenericParam(Context::kCpuId);
|
||||
TestCategoricalSketch(1000, 256, 32, false, [&ctx](DMatrix* p_fmat, int32_t num_bins) {
|
||||
return SketchOnDMatrix(&ctx, p_fmat, num_bins);
|
||||
});
|
||||
TestCategoricalSketch(1000, 256, 32, true, [](DMatrix* p_fmat, int32_t num_bins) {
|
||||
return SketchOnDMatrix(p_fmat, num_bins, AllThreadsForTest());
|
||||
TestCategoricalSketch(1000, 256, 32, true, [&ctx](DMatrix* p_fmat, int32_t num_bins) {
|
||||
return SketchOnDMatrix(&ctx, p_fmat, num_bins);
|
||||
});
|
||||
}
|
||||
} // namespace common
|
||||
|
||||
@@ -25,9 +25,9 @@ namespace xgboost {
|
||||
namespace common {
|
||||
|
||||
template <typename AdapterT>
|
||||
HistogramCuts GetHostCuts(AdapterT *adapter, int num_bins, float missing) {
|
||||
HistogramCuts GetHostCuts(Context const* ctx, AdapterT* adapter, int num_bins, float missing) {
|
||||
data::SimpleDMatrix dmat(adapter, missing, 1);
|
||||
HistogramCuts cuts = SketchOnDMatrix(&dmat, num_bins, AllThreadsForTest());
|
||||
HistogramCuts cuts = SketchOnDMatrix(ctx, &dmat, num_bins);
|
||||
return cuts;
|
||||
}
|
||||
|
||||
@@ -39,7 +39,9 @@ TEST(HistUtil, DeviceSketch) {
|
||||
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
|
||||
|
||||
auto device_cuts = DeviceSketch(0, dmat.get(), num_bins);
|
||||
HistogramCuts host_cuts = SketchOnDMatrix(dmat.get(), num_bins, AllThreadsForTest());
|
||||
|
||||
Context ctx;
|
||||
HistogramCuts host_cuts = SketchOnDMatrix(&ctx, dmat.get(), num_bins);
|
||||
|
||||
EXPECT_EQ(device_cuts.Values(), host_cuts.Values());
|
||||
EXPECT_EQ(device_cuts.Ptrs(), host_cuts.Ptrs());
|
||||
@@ -308,7 +310,8 @@ TEST(HistUtil, AdapterDeviceSketch) {
|
||||
data::CupyAdapter adapter(str);
|
||||
|
||||
auto device_cuts = MakeUnweightedCutsForTest(adapter, num_bins, missing);
|
||||
auto host_cuts = GetHostCuts(&adapter, num_bins, missing);
|
||||
auto ctx = CreateEmptyGenericParam(Context::kCpuId);
|
||||
auto host_cuts = GetHostCuts(&ctx, &adapter, num_bins, missing);
|
||||
|
||||
EXPECT_EQ(device_cuts.Values(), host_cuts.Values());
|
||||
EXPECT_EQ(device_cuts.Ptrs(), host_cuts.Ptrs());
|
||||
|
||||
@@ -16,7 +16,8 @@ TEST(Quantile, LoadBalance) {
|
||||
size_t constexpr kRows = 1000, kCols = 100;
|
||||
auto m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix();
|
||||
std::vector<bst_feature_t> cols_ptr;
|
||||
for (auto const& page : m->GetBatches<SparsePage>()) {
|
||||
Context ctx;
|
||||
for (auto const& page : m->GetBatches<SparsePage>(&ctx)) {
|
||||
data::SparsePageAdapterBatch adapter{page.GetView()};
|
||||
cols_ptr = LoadBalance(adapter, page.data.Size(), kCols, 13, [](auto) { return true; });
|
||||
}
|
||||
@@ -43,6 +44,7 @@ void PushPage(HostSketchContainer* container, SparsePage const& page, MetaInfo c
|
||||
|
||||
template <bool use_column>
|
||||
void DoTestDistributedQuantile(size_t rows, size_t cols) {
|
||||
Context ctx;
|
||||
auto const world = collective::GetWorldSize();
|
||||
std::vector<MetaInfo> infos(2);
|
||||
auto& h_weights = infos.front().weights_.HostVector();
|
||||
@@ -51,7 +53,7 @@ void DoTestDistributedQuantile(size_t rows, size_t cols) {
|
||||
SimpleRealUniformDistribution<float> dist(3, 1000);
|
||||
std::generate(h_weights.begin(), h_weights.end(), [&]() { return dist(&lcg); });
|
||||
std::vector<bst_row_t> column_size(cols, rows);
|
||||
size_t n_bins = 64;
|
||||
bst_bin_t n_bins = 64;
|
||||
|
||||
// Generate cuts for distributed environment.
|
||||
auto sparsity = 0.5f;
|
||||
@@ -72,15 +74,15 @@ void DoTestDistributedQuantile(size_t rows, size_t cols) {
|
||||
std::vector<float> hessian(rows, 1.0);
|
||||
auto hess = Span<float const>{hessian};
|
||||
|
||||
ContainerType<use_column> sketch_distributed(n_bins, m->Info().feature_types.ConstHostSpan(),
|
||||
column_size, false, AllThreadsForTest());
|
||||
ContainerType<use_column> sketch_distributed(
|
||||
&ctx, n_bins, m->Info().feature_types.ConstHostSpan(), column_size, false);
|
||||
|
||||
if (use_column) {
|
||||
for (auto const& page : m->GetBatches<SortedCSCPage>()) {
|
||||
for (auto const& page : m->GetBatches<SortedCSCPage>(&ctx)) {
|
||||
PushPage(&sketch_distributed, page, m->Info(), hess);
|
||||
}
|
||||
} else {
|
||||
for (auto const& page : m->GetBatches<SparsePage>()) {
|
||||
for (auto const& page : m->GetBatches<SparsePage>(&ctx)) {
|
||||
PushPage(&sketch_distributed, page, m->Info(), hess);
|
||||
}
|
||||
}
|
||||
@@ -93,8 +95,8 @@ void DoTestDistributedQuantile(size_t rows, size_t cols) {
|
||||
CHECK_EQ(collective::GetWorldSize(), 1);
|
||||
std::for_each(column_size.begin(), column_size.end(), [=](auto& size) { size *= world; });
|
||||
m->Info().num_row_ = world * rows;
|
||||
ContainerType<use_column> sketch_on_single_node(n_bins, m->Info().feature_types.ConstHostSpan(),
|
||||
column_size, false, AllThreadsForTest());
|
||||
ContainerType<use_column> sketch_on_single_node(
|
||||
&ctx, n_bins, m->Info().feature_types.ConstHostSpan(), column_size, false);
|
||||
m->Info().num_row_ = rows;
|
||||
|
||||
for (auto rank = 0; rank < world; ++rank) {
|
||||
@@ -106,7 +108,7 @@ void DoTestDistributedQuantile(size_t rows, size_t cols) {
|
||||
.Upper(1.0f)
|
||||
.GenerateDMatrix();
|
||||
if (use_column) {
|
||||
for (auto const& page : m->GetBatches<SortedCSCPage>()) {
|
||||
for (auto const& page : m->GetBatches<SortedCSCPage>(&ctx)) {
|
||||
PushPage(&sketch_on_single_node, page, m->Info(), hess);
|
||||
}
|
||||
} else {
|
||||
@@ -172,6 +174,7 @@ TEST(Quantile, SortedDistributed) {
|
||||
namespace {
|
||||
template <bool use_column>
|
||||
void DoTestColSplitQuantile(size_t rows, size_t cols) {
|
||||
Context ctx;
|
||||
auto const world = collective::GetWorldSize();
|
||||
auto const rank = collective::GetRank();
|
||||
|
||||
@@ -204,17 +207,17 @@ void DoTestColSplitQuantile(size_t rows, size_t cols) {
|
||||
// Generate cuts for distributed environment.
|
||||
HistogramCuts distributed_cuts;
|
||||
{
|
||||
ContainerType<use_column> sketch_distributed(n_bins, m->Info().feature_types.ConstHostSpan(),
|
||||
column_size, false, AllThreadsForTest());
|
||||
ContainerType<use_column> sketch_distributed(
|
||||
&ctx, n_bins, m->Info().feature_types.ConstHostSpan(), column_size, false);
|
||||
|
||||
std::vector<float> hessian(rows, 1.0);
|
||||
auto hess = Span<float const>{hessian};
|
||||
if (use_column) {
|
||||
for (auto const& page : m->GetBatches<SortedCSCPage>()) {
|
||||
for (auto const& page : m->GetBatches<SortedCSCPage>(&ctx)) {
|
||||
PushPage(&sketch_distributed, page, m->Info(), hess);
|
||||
}
|
||||
} else {
|
||||
for (auto const& page : m->GetBatches<SparsePage>()) {
|
||||
for (auto const& page : m->GetBatches<SparsePage>(&ctx)) {
|
||||
PushPage(&sketch_distributed, page, m->Info(), hess);
|
||||
}
|
||||
}
|
||||
@@ -227,17 +230,17 @@ void DoTestColSplitQuantile(size_t rows, size_t cols) {
|
||||
CHECK_EQ(collective::GetWorldSize(), 1);
|
||||
HistogramCuts single_node_cuts;
|
||||
{
|
||||
ContainerType<use_column> sketch_on_single_node(n_bins, m->Info().feature_types.ConstHostSpan(),
|
||||
column_size, false, AllThreadsForTest());
|
||||
ContainerType<use_column> sketch_on_single_node(
|
||||
&ctx, n_bins, m->Info().feature_types.ConstHostSpan(), column_size, false);
|
||||
|
||||
std::vector<float> hessian(rows, 1.0);
|
||||
auto hess = Span<float const>{hessian};
|
||||
if (use_column) {
|
||||
for (auto const& page : m->GetBatches<SortedCSCPage>()) {
|
||||
for (auto const& page : m->GetBatches<SortedCSCPage>(&ctx)) {
|
||||
PushPage(&sketch_on_single_node, page, m->Info(), hess);
|
||||
}
|
||||
} else {
|
||||
for (auto const& page : m->GetBatches<SparsePage>()) {
|
||||
for (auto const& page : m->GetBatches<SparsePage>(&ctx)) {
|
||||
PushPage(&sketch_on_single_node, page, m->Info(), hess);
|
||||
}
|
||||
}
|
||||
@@ -299,8 +302,10 @@ namespace {
|
||||
void TestSameOnAllWorkers() {
|
||||
auto const world = collective::GetWorldSize();
|
||||
constexpr size_t kRows = 1000, kCols = 100;
|
||||
auto ctx = CreateEmptyGenericParam(Context::kCpuId);
|
||||
|
||||
RunWithSeedsAndBins(
|
||||
kRows, [=](int32_t seed, size_t n_bins, MetaInfo const&) {
|
||||
kRows, [=, &ctx](int32_t seed, size_t n_bins, MetaInfo const&) {
|
||||
auto rank = collective::GetRank();
|
||||
HostDeviceVector<float> storage;
|
||||
std::vector<FeatureType> ft(kCols);
|
||||
@@ -314,7 +319,7 @@ void TestSameOnAllWorkers() {
|
||||
.MaxCategory(17)
|
||||
.Seed(rank + seed)
|
||||
.GenerateDMatrix();
|
||||
auto cuts = SketchOnDMatrix(m.get(), n_bins, AllThreadsForTest());
|
||||
auto cuts = SketchOnDMatrix(&ctx, m.get(), n_bins);
|
||||
std::vector<float> cut_values(cuts.Values().size() * world, 0);
|
||||
std::vector<
|
||||
typename std::remove_reference_t<decltype(cuts.Ptrs())>::value_type>
|
||||
|
||||
Reference in New Issue
Block a user