Support column-split in row partitioner (#8828)
This commit is contained in:
parent 90c0633a28
commit d9688f93c7
@@ -31,6 +31,8 @@ namespace common {
 // BlockSize is template to enable memory alignment easily with C++11 'alignas()' feature
 template<size_t BlockSize>
 class PartitionBuilder {
+  using BitVector = RBitField8;
+
  public:
   template<typename Func>
   void Init(const size_t n_tasks, size_t n_nodes, Func funcNTask) {
@@ -121,27 +123,11 @@ class PartitionBuilder {
    bool default_left = tree[nid].DefaultLeft();
    bool is_cat = tree.GetSplitTypes()[nid] == FeatureType::kCategorical;
    auto node_cats = tree.NodeCats(nid);

-   auto const& index = gmat.index;
    auto const& cut_values = gmat.cut.Values();
-   auto const& cut_ptrs = gmat.cut.Ptrs();
-
-   auto gidx_calc = [&](auto ridx) {
-     auto begin = gmat.RowIdx(ridx);
-     if (gmat.IsDense()) {
-       return static_cast<bst_bin_t>(index[begin + fid]);
-     }
-     auto end = gmat.RowIdx(ridx + 1);
-     auto f_begin = cut_ptrs[fid];
-     auto f_end = cut_ptrs[fid + 1];
-     // bypassing the column matrix as we need the cut value instead of bin idx for categorical
-     // features.
-     return BinarySearchBin(begin, end, index, f_begin, f_end);
-   };
-
    auto pred_hist = [&](auto ridx, auto bin_id) {
      if (any_cat && is_cat) {
-       auto gidx = gidx_calc(ridx);
+       auto gidx = gmat.GetGindex(ridx, fid);
        bool go_left = default_left;
        if (gidx > -1) {
          go_left = Decision(node_cats, cut_values[gidx]);
@@ -153,7 +139,7 @@ class PartitionBuilder {
    };

    auto pred_approx = [&](auto ridx) {
-     auto gidx = gidx_calc(ridx);
+     auto gidx = gmat.GetGindex(ridx, fid);
      bool go_left = default_left;
      if (gidx > -1) {
        if (is_cat) {
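Aside from the refactor itself, the direction logic inside these predicate lambdas is easy to state in isolation. The following is a minimal, self-contained sketch of that decision; the names (GoLeft, the fixed-size std::bitset category set) are illustrative, not the library's API. A missing value follows the node's default direction, a categorical split tests membership of the category code, and a numerical split compares the bin's cut value against the split threshold.

#include <bitset>
#include <cstddef>
#include <vector>

// Illustrative sketch only: the real code does this inside the pred_hist /
// pred_approx lambdas, using RegTree node data and common::Decision.
bool GoLeft(int gidx,                              // bin index, -1 when the value is missing
            bool is_cat, bool default_left,
            std::vector<float> const& cut_values,  // cut value / category code per bin
            std::bitset<64> const& left_cats,      // categories routed left (assumed convention)
            float split_value) {
  if (gidx < 0) {
    return default_left;                           // missing value: take the default branch
  }
  if (is_cat) {
    // Categorical split: the bin stores the category code as a float.
    return left_cats.test(static_cast<std::size_t>(cut_values[gidx]));
  }
  return cut_values[gidx] <= split_value;          // numerical split
}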
@@ -199,6 +185,84 @@ class PartitionBuilder {
    SetNRightElems(node_in_set, range.begin(), n_right);
  }

+ /**
+  * @brief When data is split by column, we don't have all the features locally on the current
+  * worker, so we go through all the rows and mark the bit vectors on whether the decision is made
+  * to go left, or if the feature value used for the split is missing.
+  */
+ void MaskRows(const size_t node_in_set, std::vector<xgboost::tree::CPUExpandEntry> const &nodes,
+               const common::Range1d range, GHistIndexMatrix const& gmat,
+               const common::ColumnMatrix& column_matrix,
+               const RegTree& tree, const size_t* rid,
+               BitVector* decision_bits, BitVector* missing_bits) {
+   common::Span<const size_t> rid_span(rid + range.begin(), rid + range.end());
+   std::size_t nid = nodes[node_in_set].nid;
+   bst_feature_t fid = tree[nid].SplitIndex();
+   bool is_cat = tree.GetSplitTypes()[nid] == FeatureType::kCategorical;
+   auto node_cats = tree.NodeCats(nid);
+   auto const& cut_values = gmat.cut.Values();
+
+   if (!column_matrix.IsInitialized()) {
+     for (auto row_id : rid_span) {
+       auto gidx = gmat.GetGindex(row_id, fid);
+       if (gidx > -1) {
+         bool go_left = false;
+         if (is_cat) {
+           go_left = Decision(node_cats, cut_values[gidx]);
+         } else {
+           go_left = cut_values[gidx] <= nodes[node_in_set].split.split_value;
+         }
+         if (go_left) {
+           decision_bits->Set(row_id - gmat.base_rowid);
+         }
+       } else {
+         missing_bits->Set(row_id - gmat.base_rowid);
+       }
+     }
+   } else {
+     LOG(FATAL) << "Column data split is only supported for the `approx` tree method";
+   }
+ }
+
+ /**
+  * @brief Once we've aggregated the decision and missing bits from all the workers, we can then
+  * use them to partition the rows accordingly.
+  */
+ void PartitionByMask(const size_t node_in_set,
+                      std::vector<xgboost::tree::CPUExpandEntry> const& nodes,
+                      const common::Range1d range, GHistIndexMatrix const& gmat,
+                      const common::ColumnMatrix& column_matrix, const RegTree& tree,
+                      const size_t* rid, BitVector const& decision_bits,
+                      BitVector const& missing_bits) {
+   common::Span<const size_t> rid_span(rid + range.begin(), rid + range.end());
+   common::Span<size_t> left = GetLeftBuffer(node_in_set, range.begin(), range.end());
+   common::Span<size_t> right = GetRightBuffer(node_in_set, range.begin(), range.end());
+   std::size_t nid = nodes[node_in_set].nid;
+   bool default_left = tree[nid].DefaultLeft();
+
+   auto pred_approx = [&](auto ridx) {
+     bool go_left = default_left;
+     bool is_missing = missing_bits.Check(ridx - gmat.base_rowid);
+     if (!is_missing) {
+       go_left = decision_bits.Check(ridx - gmat.base_rowid);
+     }
+     return go_left;
+   };
+
+   std::pair<size_t, size_t> child_nodes_sizes;
+   if (!column_matrix.IsInitialized()) {
+     child_nodes_sizes = PartitionRangeKernel(rid_span, left, right, pred_approx);
+   } else {
+     LOG(FATAL) << "Column data split is only supported for the `approx` tree method";
+   }
+
+   const size_t n_left = child_nodes_sizes.first;
+   const size_t n_right = child_nodes_sizes.second;
+
+   SetNLeftElems(node_in_set, range.begin(), n_left);
+   SetNRightElems(node_in_set, range.begin(), n_right);
+ }
+
  // allocate thread local memory, should be called for each specific task
  void AllocateForTask(size_t id) {
    if (mem_blocks_[id].get() == nullptr) {

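The two new member functions above split the column-split update into two passes. A rough standalone sketch of the same idea, using one byte per row instead of the packed RBitField8 and hypothetical helper names, may make the data flow easier to follow: the first pass records what can be decided locally, the second pass routes every row using only the (already aggregated) masks. The categorical branch is omitted here for brevity.

#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical sketch of the two-pass, mask-based partition used for column split.
struct RowMasks {
  explicit RowMasks(std::size_t n_rows) : decision(n_rows, 0), missing(n_rows, 0) {}
  std::vector<std::uint8_t> decision;  // 1 -> row goes left
  std::vector<std::uint8_t> missing;   // 1 -> split feature undecidable on this worker
};

// Pass 1: every worker marks what it can decide locally (cf. MaskRows).
void MaskRowsSketch(std::vector<int> const& local_gidx,  // bin per row, -1 when unavailable locally
                    std::vector<float> const& cut_values, float split_value, RowMasks* masks) {
  for (std::size_t r = 0; r < local_gidx.size(); ++r) {
    if (local_gidx[r] < 0) {
      masks->missing[r] = 1;             // cannot decide on this worker
    } else if (cut_values[local_gidx[r]] <= split_value) {
      masks->decision[r] = 1;            // local feature value says: go left
    }
  }
}

// Pass 2: after the masks are aggregated across workers, partition the rows
// (cf. PartitionByMask); rows that stayed missing follow the default direction.
std::pair<std::vector<std::size_t>, std::vector<std::size_t>> PartitionByMaskSketch(
    RowMasks const& masks, bool default_left) {
  std::vector<std::size_t> left, right;
  for (std::size_t r = 0; r < masks.decision.size(); ++r) {
    bool go_left = masks.missing[r] ? default_left : (masks.decision[r] != 0);
    (go_left ? left : right).push_back(r);
  }
  return {left, right};
}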
@@ -150,16 +150,24 @@ common::ColumnMatrix const &GHistIndexMatrix::Transpose() const {
  return *columns_;
}

+bst_bin_t GHistIndexMatrix::GetGindex(size_t ridx, size_t fidx) const {
+  auto begin = RowIdx(ridx);
+  if (IsDense()) {
+    return static_cast<bst_bin_t>(index[begin + fidx]);
+  }
+  auto end = RowIdx(ridx + 1);
+  auto const& cut_ptrs = cut.Ptrs();
+  auto f_begin = cut_ptrs[fidx];
+  auto f_end = cut_ptrs[fidx + 1];
+  return BinarySearchBin(begin, end, index, f_begin, f_end);
+}
+
float GHistIndexMatrix::GetFvalue(size_t ridx, size_t fidx, bool is_cat) const {
  auto const &values = cut.Values();
  auto const &mins = cut.MinValues();
  auto const &ptrs = cut.Ptrs();
  if (is_cat) {
-   auto f_begin = ptrs[fidx];
-   auto f_end = ptrs[fidx + 1];
-   auto begin = RowIdx(ridx);
-   auto end = RowIdx(ridx + 1);
-   auto gidx = BinarySearchBin(begin, end, index, f_begin, f_end);
+   auto gidx = GetGindex(ridx, fidx);
    if (gidx == -1) {
      return std::numeric_limits<float>::quiet_NaN();
    }

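GetGindex is a direct index lookup for dense data; for sparse data it has to search the row's recorded bins for one that belongs to the requested feature, which works because every feature owns the contiguous bin-id range [cut_ptrs[fidx], cut_ptrs[fidx + 1]) and a row's bins are stored in ascending order. A small, self-contained sketch of that search follows; FindBinForFeature is a hypothetical stand-in for BinarySearchBin.

#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical sketch of the sparse-row bin lookup performed by GetGindex.
std::int32_t FindBinForFeature(std::vector<std::uint32_t> const& row_bins,  // bins of one row, sorted
                               std::uint32_t f_begin, std::uint32_t f_end) {
  // Binary search for a bin id that falls inside the feature's range.
  auto it = std::lower_bound(row_bins.begin(), row_bins.end(), f_begin);
  if (it != row_bins.end() && *it < f_end) {
    return static_cast<std::int32_t>(*it);
  }
  return -1;  // the feature value is missing in this row
}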
@@ -228,6 +228,8 @@ class GHistIndexMatrix {

  common::ColumnMatrix const& Transpose() const;

+ bst_bin_t GetGindex(size_t ridx, size_t fidx) const;
+
  float GetFvalue(size_t ridx, size_t fidx, bool is_cat) const;

 private:

@@ -9,6 +9,7 @@
#include <limits>  // std::numeric_limits
#include <vector>

+#include "../collective/communicator-inl.h"
#include "../common/numeric.h"  // Iota
#include "../common/partition_builder.h"
#include "hist/expand_entry.h"  // CPUExpandEntry
@@ -16,17 +17,73 @@
namespace xgboost {
namespace tree {
- class CommonRowPartitioner {
-   static constexpr size_t kPartitionBlockSize = 2048;
-   common::PartitionBuilder<kPartitionBlockSize> partition_builder_;
-   common::RowSetCollection row_set_collection_;

+ static constexpr size_t kPartitionBlockSize = 2048;
+
+ class ColumnSplitHelper {
+  public:
+   ColumnSplitHelper() = default;
+
+   ColumnSplitHelper(bst_row_t num_row,
+                     common::PartitionBuilder<kPartitionBlockSize>* partition_builder,
+                     common::RowSetCollection* row_set_collection)
+       : partition_builder_{partition_builder}, row_set_collection_{row_set_collection} {
+     decision_storage_.resize(num_row);
+     decision_bits_ = BitVector(common::Span<BitVector::value_type>(decision_storage_));
+     missing_storage_.resize(num_row);
+     missing_bits_ = BitVector(common::Span<BitVector::value_type>(missing_storage_));
+   }
+
+   void Partition(common::BlockedSpace2d const& space, std::int32_t n_threads,
+                  GHistIndexMatrix const& gmat, common::ColumnMatrix const& column_matrix,
+                  std::vector<CPUExpandEntry> const& nodes, RegTree const* p_tree) {
+     // When data is split by column, we don't have all the feature values in the local worker, so
+     // we first collect all the decisions and whether the feature is missing into bit vectors.
+     std::fill(decision_storage_.begin(), decision_storage_.end(), 0);
+     std::fill(missing_storage_.begin(), missing_storage_.end(), 0);
+     common::ParallelFor2d(space, n_threads, [&](size_t node_in_set, common::Range1d r) {
+       const int32_t nid = nodes[node_in_set].nid;
+       partition_builder_->MaskRows(node_in_set, nodes, r, gmat, column_matrix, *p_tree,
+                                    (*row_set_collection_)[nid].begin, &decision_bits_,
+                                    &missing_bits_);
+     });
+
+     // Then aggregate the bit vectors across all the workers.
+     collective::Allreduce<collective::Operation::kBitwiseOR>(decision_storage_.data(),
+                                                              decision_storage_.size());
+     collective::Allreduce<collective::Operation::kBitwiseAND>(missing_storage_.data(),
+                                                               missing_storage_.size());
+
+     // Finally use the bit vectors to partition the rows.
+     common::ParallelFor2d(space, n_threads, [&](size_t node_in_set, common::Range1d r) {
+       size_t begin = r.begin();
+       const int32_t nid = nodes[node_in_set].nid;
+       const size_t task_id = partition_builder_->GetTaskIdx(node_in_set, begin);
+       partition_builder_->AllocateForTask(task_id);
+       partition_builder_->PartitionByMask(node_in_set, nodes, r, gmat, column_matrix, *p_tree,
+                                           (*row_set_collection_)[nid].begin, decision_bits_,
+                                           missing_bits_);
+     });
+   }
+
+  private:
+   using BitVector = RBitField8;
+   std::vector<BitVector::value_type> decision_storage_{};
+   BitVector decision_bits_{};
+   std::vector<BitVector::value_type> missing_storage_{};
+   BitVector missing_bits_{};
+   common::PartitionBuilder<kPartitionBlockSize>* partition_builder_;
+   common::RowSetCollection* row_set_collection_;
+ };
+
+ class CommonRowPartitioner {
  public:
   bst_row_t base_rowid = 0;

   CommonRowPartitioner() = default;
-  CommonRowPartitioner(Context const* ctx, bst_row_t num_row, bst_row_t _base_rowid)
-      : base_rowid{_base_rowid} {
+  CommonRowPartitioner(Context const* ctx, bst_row_t num_row, bst_row_t _base_rowid,
+                       bool is_col_split)
+      : base_rowid{_base_rowid}, is_col_split_{is_col_split} {
    row_set_collection_.Clear();
    std::vector<size_t>& row_indices = *row_set_collection_.Data();
    row_indices.resize(num_row);
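The choice of reduction operators in ColumnSplitHelper::Partition is what makes the column-split protocol work: only the worker holding the split feature can vote "go left", so decision bits are combined with bitwise OR, while a row counts as missing only if no worker could decide it, so missing bits are combined with bitwise AND. Below is a minimal sketch of that combination step, assuming plain per-worker byte masks instead of collective::Allreduce over RBitField8 storage; CombineMasks is a hypothetical name.

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical sketch of how the per-worker masks combine; in the actual code
// this is done in place by collective::Allreduce with kBitwiseOR / kBitwiseAND.
void CombineMasks(std::vector<std::vector<std::uint8_t>> const& worker_decision,
                  std::vector<std::vector<std::uint8_t>> const& worker_missing,
                  std::vector<std::uint8_t>* decision,   // OR: any worker may decide "left"
                  std::vector<std::uint8_t>* missing) {  // AND: missing only if missing everywhere
  const std::size_t n_rows = decision->size();
  std::fill(decision->begin(), decision->end(), 0);
  std::fill(missing->begin(), missing->end(), 0xFF);
  for (std::size_t w = 0; w < worker_decision.size(); ++w) {
    for (std::size_t r = 0; r < n_rows; ++r) {
      (*decision)[r] |= worker_decision[w][r];
      (*missing)[r] &= worker_missing[w][r];
    }
  }
}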
@@ -34,6 +91,10 @@ class CommonRowPartitioner {
    std::size_t* p_row_indices = row_indices.data();
    common::Iota(ctx, p_row_indices, p_row_indices + row_indices.size(), base_rowid);
    row_set_collection_.Init();
+
+   if (is_col_split_) {
+     column_split_helper_ = ColumnSplitHelper{num_row, &partition_builder_, &row_set_collection_};
+   }
  }

  void FindSplitConditions(const std::vector<CPUExpandEntry>& nodes, const RegTree& tree,
@@ -156,16 +217,20 @@ class CommonRowPartitioner {

    // 2.3 Split elements of row_set_collection_ to left and right child-nodes for each node
    // Store results in intermediate buffers from partition_builder_
-   common::ParallelFor2d(space, ctx->Threads(), [&](size_t node_in_set, common::Range1d r) {
-     size_t begin = r.begin();
-     const int32_t nid = nodes[node_in_set].nid;
-     const size_t task_id = partition_builder_.GetTaskIdx(node_in_set, begin);
-     partition_builder_.AllocateForTask(task_id);
-     bst_bin_t split_cond = column_matrix.IsInitialized() ? split_conditions[node_in_set] : 0;
-     partition_builder_.template Partition<BinIdxType, any_missing, any_cat>(
-         node_in_set, nodes, r, split_cond, gmat, column_matrix, *p_tree,
-         row_set_collection_[nid].begin);
-   });
+   if (is_col_split_) {
+     column_split_helper_.Partition(space, ctx->Threads(), gmat, column_matrix, nodes, p_tree);
+   } else {
+     common::ParallelFor2d(space, ctx->Threads(), [&](size_t node_in_set, common::Range1d r) {
+       size_t begin = r.begin();
+       const int32_t nid = nodes[node_in_set].nid;
+       const size_t task_id = partition_builder_.GetTaskIdx(node_in_set, begin);
+       partition_builder_.AllocateForTask(task_id);
+       bst_bin_t split_cond = column_matrix.IsInitialized() ? split_conditions[node_in_set] : 0;
+       partition_builder_.template Partition<BinIdxType, any_missing, any_cat>(
+           node_in_set, nodes, r, split_cond, gmat, column_matrix, *p_tree,
+           row_set_collection_[nid].begin);
+     });
+   }

    // 3. Compute offsets to copy blocks of row-indexes
    // from partition_builder_ to row_set_collection_
@@ -205,6 +270,12 @@ class CommonRowPartitioner {
        ctx, tree, this->Partitions(), p_out_position,
        [&](size_t idx) -> bool { return gpair[idx].GetHess() - .0f == .0f; });
  }

 private:
+  common::PartitionBuilder<kPartitionBlockSize> partition_builder_;
+  common::RowSetCollection row_set_collection_;
+  bool is_col_split_;
+  ColumnSplitHelper column_split_helper_;
};

}  // namespace tree

@@ -71,7 +71,7 @@ class GloablApproxBuilder {
    } else {
      CHECK_EQ(n_total_bins, page.cut.TotalBins());
    }
-   partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid);
+   partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid, p_fmat->IsColumnSplit());
    n_batches_++;
  }

@@ -277,7 +277,7 @@ void QuantileHistMaker::Builder::InitData(DMatrix *fmat, const RegTree &tree,
    } else {
      CHECK_EQ(n_total_bins, page.cut.TotalBins());
    }
-   partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid);
+   partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid, fmat->IsColumnSplit());
    ++page_id;
  }
  histogram_builder_->Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id,

@@ -10,29 +10,36 @@

namespace xgboost {
namespace tree {
- TEST(Approx, Partitioner) {
-   size_t n_samples = 1024, n_features = 1, base_rowid = 0;
-   Context ctx;
-   CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid};
-   ASSERT_EQ(partitioner.base_rowid, base_rowid);
-   ASSERT_EQ(partitioner.Size(), 1);
-   ASSERT_EQ(partitioner.Partitions()[0].Size(), n_samples);
-
-   auto Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true);
-   ctx.InitAllowUnknown(Args{});
-   std::vector<CPUExpandEntry> candidates{{0, 0, 0.4}};
-
+ namespace {
+ std::vector<float> GenerateHess(size_t n_samples) {
+   auto grad = GenerateRandomGradients(n_samples);
+   std::vector<float> hess(grad.Size());
+   std::transform(grad.HostVector().cbegin(), grad.HostVector().cend(), hess.begin(),
+                  [](auto gpair) { return gpair.GetHess(); });
+   return hess;
+ }
+ }  // anonymous namespace
+
+ TEST(Approx, Partitioner) {
+   size_t n_samples = 1024, n_features = 1, base_rowid = 0;
+   Context ctx;
+   ctx.InitAllowUnknown(Args{});
+   CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid, false};
+   ASSERT_EQ(partitioner.base_rowid, base_rowid);
+   ASSERT_EQ(partitioner.Size(), 1);
+   ASSERT_EQ(partitioner.Partitions()[0].Size(), n_samples);
+
+   auto const Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true);
+   auto hess = GenerateHess(n_samples);
+   std::vector<CPUExpandEntry> candidates{{0, 0, 0.4}};
+
   for (auto const& page : Xy->GetBatches<GHistIndexMatrix>({64, hess, true})) {
     bst_feature_t const split_ind = 0;
     {
       auto min_value = page.cut.MinValues()[split_ind];
       RegTree tree;
-      CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid};
+      CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid, false};
       GetSplit(&tree, min_value, &candidates);
       partitioner.UpdatePosition(&ctx, page, candidates, &tree);
       ASSERT_EQ(partitioner.Size(), 3);
@@ -40,7 +47,7 @@ TEST(Approx, Partitioner) {
      ASSERT_EQ(partitioner[2].Size(), n_samples);
    }
    {
-     CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid};
+     CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid, false};
      auto ptr = page.cut.Ptrs()[split_ind + 1];
      float split_value = page.cut.Values().at(ptr / 2);
      RegTree tree;
@@ -66,12 +73,85 @@ TEST(Approx, Partitioner) {
    }
  }

+ namespace {
+ void TestColumnSplitPartitioner(size_t n_samples, size_t base_rowid, std::shared_ptr<DMatrix> Xy,
+                                 std::vector<float>* hess, float min_value, float mid_value,
+                                 CommonRowPartitioner const& expected_mid_partitioner) {
+   auto dmat =
+       std::unique_ptr<DMatrix>{Xy->SliceCol(collective::GetWorldSize(), collective::GetRank())};
+   std::vector<CPUExpandEntry> candidates{{0, 0, 0.4}};
+   Context ctx;
+   ctx.InitAllowUnknown(Args{});
+   for (auto const& page : dmat->GetBatches<GHistIndexMatrix>({64, *hess, true})) {
+     {
+       RegTree tree;
+       CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid, true};
+       GetSplit(&tree, min_value, &candidates);
+       partitioner.UpdatePosition(&ctx, page, candidates, &tree);
+       ASSERT_EQ(partitioner.Size(), 3);
+       ASSERT_EQ(partitioner[1].Size(), 0);
+       ASSERT_EQ(partitioner[2].Size(), n_samples);
+     }
+     {
+       CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid, true};
+       RegTree tree;
+       GetSplit(&tree, mid_value, &candidates);
+       partitioner.UpdatePosition(&ctx, page, candidates, &tree);
+
+       auto left_nidx = tree[RegTree::kRoot].LeftChild();
+       auto elem = partitioner[left_nidx];
+       ASSERT_LT(elem.Size(), n_samples);
+       ASSERT_GT(elem.Size(), 1);
+       auto expected_elem = expected_mid_partitioner[left_nidx];
+       ASSERT_EQ(elem.Size(), expected_elem.Size());
+       for (auto it = elem.begin, eit = expected_elem.begin; it != elem.end; ++it, ++eit) {
+         ASSERT_EQ(*it, *eit);
+       }
+
+       auto right_nidx = tree[RegTree::kRoot].RightChild();
+       elem = partitioner[right_nidx];
+       expected_elem = expected_mid_partitioner[right_nidx];
+       ASSERT_EQ(elem.Size(), expected_elem.Size());
+       for (auto it = elem.begin, eit = expected_elem.begin; it != elem.end; ++it, ++eit) {
+         ASSERT_EQ(*it, *eit);
+       }
+     }
+   }
+ }
+ }  // anonymous namespace
+
+ TEST(Approx, PartitionerColSplit) {
+   size_t n_samples = 1024, n_features = 16, base_rowid = 0;
+   auto const Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true);
+   auto hess = GenerateHess(n_samples);
+   std::vector<CPUExpandEntry> candidates{{0, 0, 0.4}};
+
+   float min_value, mid_value;
+   Context ctx;
+   ctx.InitAllowUnknown(Args{});
+   CommonRowPartitioner mid_partitioner{&ctx, n_samples, base_rowid, false};
+   for (auto const& page : Xy->GetBatches<GHistIndexMatrix>({64, hess, true})) {
+     bst_feature_t const split_ind = 0;
+     min_value = page.cut.MinValues()[split_ind];
+
+     auto ptr = page.cut.Ptrs()[split_ind + 1];
+     mid_value = page.cut.Values().at(ptr / 2);
+     RegTree tree;
+     GetSplit(&tree, mid_value, &candidates);
+     mid_partitioner.UpdatePosition(&ctx, page, candidates, &tree);
+   }
+
+   auto constexpr kWorkers = 4;
+   RunWithInMemoryCommunicator(kWorkers, TestColumnSplitPartitioner, n_samples, base_rowid, Xy,
+                               &hess, min_value, mid_value, mid_partitioner);
+ }
+
  namespace {
  void TestLeafPartition(size_t n_samples) {
    size_t const n_features = 2, base_rowid = 0;
    Context ctx;
    common::RowSetCollection row_set;
-   CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid};
+   CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid, false};

    auto Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true);
    std::vector<CPUExpandEntry> candidates{{0, 0, 0.4}};

@@ -23,7 +23,7 @@ TEST(QuantileHist, Partitioner) {
  Context ctx;
  ctx.InitAllowUnknown(Args{});

- CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid};
+ CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid, false};
  ASSERT_EQ(partitioner.base_rowid, base_rowid);
  ASSERT_EQ(partitioner.Size(), 1);
  ASSERT_EQ(partitioner.Partitions()[0].Size(), n_samples);
@@ -41,7 +41,7 @@ TEST(QuantileHist, Partitioner) {
    {
      auto min_value = gmat.cut.MinValues()[split_ind];
      RegTree tree;
-     CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid};
+     CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid, false};
      GetSplit(&tree, min_value, &candidates);
      partitioner.UpdatePosition<false, true>(&ctx, gmat, column_indices, candidates, &tree);
      ASSERT_EQ(partitioner.Size(), 3);
@@ -49,7 +49,7 @@ TEST(QuantileHist, Partitioner) {
      ASSERT_EQ(partitioner[2].Size(), n_samples);
    }
    {
-     CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid};
+     CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid, false};
      auto ptr = gmat.cut.Ptrs()[split_ind + 1];
      float split_value = gmat.cut.Values().at(ptr / 2);
      RegTree tree;