Avoid the use of size_t in the partitioner. (#10541)
- Avoid the use of size_t in the partitioner. - Use `Span` instead of `Elem` where `node_id` is not needed. - Remove the `const_cast`. - Make sure the constness is not removed in the `Elem` by making it reference only. size_t is implementation-defined, which causes issue when we want to pass pointer or span.
This commit is contained in:
@@ -45,7 +45,7 @@ void TestEvaluateSplits(bool force_read_by_column) {
|
||||
// dense, no missing values
|
||||
GHistIndexMatrix gmat(&ctx, dmat.get(), kMaxBins, 0.5, false);
|
||||
common::RowSetCollection row_set_collection;
|
||||
std::vector<size_t> &row_indices = *row_set_collection.Data();
|
||||
std::vector<bst_idx_t> &row_indices = *row_set_collection.Data();
|
||||
row_indices.resize(kRows);
|
||||
std::iota(row_indices.begin(), row_indices.end(), 0);
|
||||
row_set_collection.Init();
|
||||
@@ -53,7 +53,9 @@ void TestEvaluateSplits(bool force_read_by_column) {
|
||||
HistMakerTrainParam hist_param;
|
||||
hist.Reset(gmat.cut.Ptrs().back(), hist_param.max_cached_hist_node);
|
||||
hist.AllocateHistograms({0});
|
||||
common::BuildHist<false>(row_gpairs, row_set_collection[0], gmat, hist[0], force_read_by_column);
|
||||
auto const &elem = row_set_collection[0];
|
||||
common::BuildHist<false>(row_gpairs, common::Span{elem.begin(), elem.end()}, gmat, hist[0],
|
||||
force_read_by_column);
|
||||
|
||||
// Compute total gradient for all data points
|
||||
GradientPairPrecise total_gpair;
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
#include <algorithm> // for max
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int32_t, uint32_t
|
||||
#include <functional> // for function
|
||||
#include <iterator> // for back_inserter
|
||||
#include <limits> // for numeric_limits
|
||||
#include <memory> // for shared_ptr, allocator, unique_ptr
|
||||
@@ -108,7 +107,7 @@ void TestSyncHist(bool is_distributed) {
|
||||
common::RowSetCollection row_set_collection;
|
||||
{
|
||||
row_set_collection.Clear();
|
||||
std::vector<size_t> &row_indices = *row_set_collection.Data();
|
||||
std::vector<bst_idx_t> &row_indices = *row_set_collection.Data();
|
||||
row_indices.resize(kNRows);
|
||||
std::iota(row_indices.begin(), row_indices.end(), 0);
|
||||
row_set_collection.Init();
|
||||
@@ -251,7 +250,7 @@ void TestBuildHistogram(bool is_distributed, bool force_read_by_column, bool is_
|
||||
|
||||
common::RowSetCollection row_set_collection;
|
||||
row_set_collection.Clear();
|
||||
std::vector<size_t> &row_indices = *row_set_collection.Data();
|
||||
std::vector<bst_idx_t> &row_indices = *row_set_collection.Data();
|
||||
row_indices.resize(kNRows);
|
||||
std::iota(row_indices.begin(), row_indices.end(), 0);
|
||||
row_set_collection.Init();
|
||||
@@ -345,7 +344,7 @@ void TestHistogramCategorical(size_t n_categories, bool force_read_by_column) {
|
||||
|
||||
common::RowSetCollection row_set_collection;
|
||||
row_set_collection.Clear();
|
||||
std::vector<size_t> &row_indices = *row_set_collection.Data();
|
||||
std::vector<bst_idx_t> &row_indices = *row_set_collection.Data();
|
||||
row_indices.resize(kRows);
|
||||
std::iota(row_indices.begin(), row_indices.end(), 0);
|
||||
row_set_collection.Init();
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "../../../src/common/numeric.h"
|
||||
#include "../../../src/tree/common_row_partitioner.h"
|
||||
#include "../collective/test_worker.h" // for TestDistributedGlobal
|
||||
#include "../helpers.h"
|
||||
@@ -54,20 +53,23 @@ TEST(Approx, Partitioner) {
|
||||
GetSplit(&tree, split_value, &candidates);
|
||||
partitioner.UpdatePosition(&ctx, page, candidates, &tree);
|
||||
|
||||
auto left_nidx = tree[RegTree::kRoot].LeftChild();
|
||||
auto elem = partitioner[left_nidx];
|
||||
ASSERT_LT(elem.Size(), n_samples);
|
||||
ASSERT_GT(elem.Size(), 1);
|
||||
for (auto it = elem.begin; it != elem.end; ++it) {
|
||||
auto value = page.cut.Values().at(page.index[*it]);
|
||||
ASSERT_LE(value, split_value);
|
||||
{
|
||||
auto left_nidx = tree[RegTree::kRoot].LeftChild();
|
||||
auto const& elem = partitioner[left_nidx];
|
||||
ASSERT_LT(elem.Size(), n_samples);
|
||||
ASSERT_GT(elem.Size(), 1);
|
||||
for (auto& it : elem) {
|
||||
auto value = page.cut.Values().at(page.index[it]);
|
||||
ASSERT_LE(value, split_value);
|
||||
}
|
||||
}
|
||||
|
||||
auto right_nidx = tree[RegTree::kRoot].RightChild();
|
||||
elem = partitioner[right_nidx];
|
||||
for (auto it = elem.begin; it != elem.end; ++it) {
|
||||
auto value = page.cut.Values().at(page.index[*it]);
|
||||
ASSERT_GT(value, split_value) << *it;
|
||||
{
|
||||
auto right_nidx = tree[RegTree::kRoot].RightChild();
|
||||
auto const& elem = partitioner[right_nidx];
|
||||
for (auto& it : elem) {
|
||||
auto value = page.cut.Values().at(page.index[it]);
|
||||
ASSERT_GT(value, split_value) << it;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -99,23 +101,25 @@ void TestColumnSplitPartitioner(size_t n_samples, size_t base_rowid, std::shared
|
||||
RegTree tree;
|
||||
GetSplit(&tree, mid_value, &candidates);
|
||||
partitioner.UpdatePosition(&ctx, page, candidates, &tree);
|
||||
|
||||
auto left_nidx = tree[RegTree::kRoot].LeftChild();
|
||||
auto elem = partitioner[left_nidx];
|
||||
ASSERT_LT(elem.Size(), n_samples);
|
||||
ASSERT_GT(elem.Size(), 1);
|
||||
auto expected_elem = expected_mid_partitioner[left_nidx];
|
||||
ASSERT_EQ(elem.Size(), expected_elem.Size());
|
||||
for (auto it = elem.begin, eit = expected_elem.begin; it != elem.end; ++it, ++eit) {
|
||||
ASSERT_EQ(*it, *eit);
|
||||
{
|
||||
auto left_nidx = tree[RegTree::kRoot].LeftChild();
|
||||
auto const& elem = partitioner[left_nidx];
|
||||
ASSERT_LT(elem.Size(), n_samples);
|
||||
ASSERT_GT(elem.Size(), 1);
|
||||
auto const& expected_elem = expected_mid_partitioner[left_nidx];
|
||||
ASSERT_EQ(elem.Size(), expected_elem.Size());
|
||||
for (auto it = elem.begin(), eit = expected_elem.begin(); it != elem.end(); ++it, ++eit) {
|
||||
ASSERT_EQ(*it, *eit);
|
||||
}
|
||||
}
|
||||
|
||||
auto right_nidx = tree[RegTree::kRoot].RightChild();
|
||||
elem = partitioner[right_nidx];
|
||||
expected_elem = expected_mid_partitioner[right_nidx];
|
||||
ASSERT_EQ(elem.Size(), expected_elem.Size());
|
||||
for (auto it = elem.begin, eit = expected_elem.begin; it != elem.end; ++it, ++eit) {
|
||||
ASSERT_EQ(*it, *eit);
|
||||
{
|
||||
auto right_nidx = tree[RegTree::kRoot].RightChild();
|
||||
auto const& elem = partitioner[right_nidx];
|
||||
auto const& expected_elem = expected_mid_partitioner[right_nidx];
|
||||
ASSERT_EQ(elem.Size(), expected_elem.Size());
|
||||
for (auto it = elem.begin(), eit = expected_elem.begin(); it != elem.end(); ++it, ++eit) {
|
||||
ASSERT_EQ(*it, *eit);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
#include <xgboost/host_device_vector.h>
|
||||
#include <xgboost/tree_updater.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstddef> // for size_t
|
||||
#include <string>
|
||||
#include <vector>
|
||||
@@ -68,21 +67,24 @@ void TestPartitioner(bst_target_t n_targets) {
|
||||
} else {
|
||||
GetMultiSplitForTest(&tree, split_value, &candidates);
|
||||
}
|
||||
auto left_nidx = tree.LeftChild(RegTree::kRoot);
|
||||
partitioner.UpdatePosition<false, true>(&ctx, gmat, column_indices, candidates, &tree);
|
||||
|
||||
auto elem = partitioner[left_nidx];
|
||||
ASSERT_LT(elem.Size(), n_samples);
|
||||
ASSERT_GT(elem.Size(), 1);
|
||||
for (auto it = elem.begin; it != elem.end; ++it) {
|
||||
auto value = gmat.cut.Values().at(gmat.index[*it]);
|
||||
ASSERT_LE(value, split_value);
|
||||
{
|
||||
auto left_nidx = tree.LeftChild(RegTree::kRoot);
|
||||
auto const& elem = partitioner[left_nidx];
|
||||
ASSERT_LT(elem.Size(), n_samples);
|
||||
ASSERT_GT(elem.Size(), 1);
|
||||
for (auto& it : elem) {
|
||||
auto value = gmat.cut.Values().at(gmat.index[it]);
|
||||
ASSERT_LE(value, split_value);
|
||||
}
|
||||
}
|
||||
auto right_nidx = tree.RightChild(RegTree::kRoot);
|
||||
elem = partitioner[right_nidx];
|
||||
for (auto it = elem.begin; it != elem.end; ++it) {
|
||||
auto value = gmat.cut.Values().at(gmat.index[*it]);
|
||||
ASSERT_GT(value, split_value);
|
||||
{
|
||||
auto right_nidx = tree.RightChild(RegTree::kRoot);
|
||||
auto const& elem = partitioner[right_nidx];
|
||||
for (auto& it : elem) {
|
||||
auto value = gmat.cut.Values().at(gmat.index[it]);
|
||||
ASSERT_GT(value, split_value);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -138,21 +140,24 @@ void VerifyColumnSplitPartitioner(bst_target_t n_targets, size_t n_samples,
|
||||
auto left_nidx = tree.LeftChild(RegTree::kRoot);
|
||||
partitioner.UpdatePosition<false, true>(&ctx, gmat, column_indices, candidates, &tree);
|
||||
|
||||
auto elem = partitioner[left_nidx];
|
||||
ASSERT_LT(elem.Size(), n_samples);
|
||||
ASSERT_GT(elem.Size(), 1);
|
||||
auto expected_elem = expected_mid_partitioner[left_nidx];
|
||||
ASSERT_EQ(elem.Size(), expected_elem.Size());
|
||||
for (auto it = elem.begin, eit = expected_elem.begin; it != elem.end; ++it, ++eit) {
|
||||
ASSERT_EQ(*it, *eit);
|
||||
{
|
||||
auto const& elem = partitioner[left_nidx];
|
||||
ASSERT_LT(elem.Size(), n_samples);
|
||||
ASSERT_GT(elem.Size(), 1);
|
||||
auto const& expected_elem = expected_mid_partitioner[left_nidx];
|
||||
ASSERT_EQ(elem.Size(), expected_elem.Size());
|
||||
for (auto it = elem.begin(), eit = expected_elem.begin(); it != elem.end(); ++it, ++eit) {
|
||||
ASSERT_EQ(*it, *eit);
|
||||
}
|
||||
}
|
||||
|
||||
auto right_nidx = tree.RightChild(RegTree::kRoot);
|
||||
elem = partitioner[right_nidx];
|
||||
expected_elem = expected_mid_partitioner[right_nidx];
|
||||
ASSERT_EQ(elem.Size(), expected_elem.Size());
|
||||
for (auto it = elem.begin, eit = expected_elem.begin; it != elem.end; ++it, ++eit) {
|
||||
ASSERT_EQ(*it, *eit);
|
||||
{
|
||||
auto right_nidx = tree.RightChild(RegTree::kRoot);
|
||||
auto const& elem = partitioner[right_nidx];
|
||||
auto const& expected_elem = expected_mid_partitioner[right_nidx];
|
||||
ASSERT_EQ(elem.Size(), expected_elem.Size());
|
||||
for (auto it = elem.begin(), eit = expected_elem.begin(); it != elem.end(); ++it, ++eit) {
|
||||
ASSERT_EQ(*it, *eit);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user