Support hist in the partition builder under column split (#9120)
This commit is contained in:
@@ -183,14 +183,28 @@ class PartitionBuilder {
|
||||
SetNRightElems(node_in_set, range.begin(), n_right);
|
||||
}
|
||||
|
||||
template <bool any_missing, typename ColumnType, typename Predicate>
|
||||
void MaskKernel(ColumnType* p_column, common::Span<const size_t> row_indices, size_t base_rowid,
|
||||
BitVector* decision_bits, BitVector* missing_bits, Predicate&& pred) {
|
||||
auto& column = *p_column;
|
||||
for (auto const row_id : row_indices) {
|
||||
auto const bin_id = column[row_id - base_rowid];
|
||||
if (any_missing && bin_id == ColumnType::kMissingId) {
|
||||
missing_bits->Set(row_id - base_rowid);
|
||||
} else if (pred(row_id, bin_id)) {
|
||||
decision_bits->Set(row_id - base_rowid);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief When data is split by column, we don't have all the features locally on the current
|
||||
* worker, so we go through all the rows and mark the bit vectors on whether the decision is made
|
||||
* to go right, or if the feature value used for the split is missing.
|
||||
*/
|
||||
template <typename ExpandEntry>
|
||||
template <typename BinIdxType, bool any_missing, bool any_cat, typename ExpandEntry>
|
||||
void MaskRows(const size_t node_in_set, std::vector<ExpandEntry> const& nodes,
|
||||
const common::Range1d range, GHistIndexMatrix const& gmat,
|
||||
const common::Range1d range, bst_bin_t split_cond, GHistIndexMatrix const& gmat,
|
||||
const common::ColumnMatrix& column_matrix, const RegTree& tree, const size_t* rid,
|
||||
BitVector* decision_bits, BitVector* missing_bits) {
|
||||
common::Span<const size_t> rid_span(rid + range.begin(), rid + range.end());
|
||||
@@ -204,7 +218,7 @@ class PartitionBuilder {
|
||||
for (auto row_id : rid_span) {
|
||||
auto gidx = gmat.GetGindex(row_id, fid);
|
||||
if (gidx > -1) {
|
||||
bool go_left = false;
|
||||
bool go_left;
|
||||
if (is_cat) {
|
||||
go_left = Decision(node_cats, cut_values[gidx]);
|
||||
} else {
|
||||
@@ -218,7 +232,27 @@ class PartitionBuilder {
|
||||
}
|
||||
}
|
||||
} else {
|
||||
LOG(FATAL) << "Column data split is only supported for the `approx` tree method";
|
||||
auto pred_hist = [&](auto ridx, auto bin_id) {
|
||||
if (any_cat && is_cat) {
|
||||
auto gidx = gmat.GetGindex(ridx, fid);
|
||||
CHECK_GT(gidx, -1);
|
||||
return Decision(node_cats, cut_values[gidx]);
|
||||
} else {
|
||||
return bin_id <= split_cond;
|
||||
}
|
||||
};
|
||||
|
||||
if (column_matrix.GetColumnType(fid) == xgboost::common::kDenseColumn) {
|
||||
auto column = column_matrix.DenseColumn<BinIdxType, any_missing>(fid);
|
||||
MaskKernel<any_missing>(&column, rid_span, gmat.base_rowid, decision_bits, missing_bits,
|
||||
pred_hist);
|
||||
} else {
|
||||
CHECK_EQ(any_missing, true);
|
||||
auto column =
|
||||
column_matrix.SparseColumn<BinIdxType>(fid, rid_span.front() - gmat.base_rowid);
|
||||
MaskKernel<any_missing>(&column, rid_span, gmat.base_rowid, decision_bits, missing_bits,
|
||||
pred_hist);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -238,7 +272,7 @@ class PartitionBuilder {
|
||||
std::size_t nid = nodes[node_in_set].nid;
|
||||
bool default_left = tree[nid].DefaultLeft();
|
||||
|
||||
auto pred_approx = [&](auto ridx) {
|
||||
auto pred = [&](auto ridx) {
|
||||
bool go_left = default_left;
|
||||
bool is_missing = missing_bits.Check(ridx - gmat.base_rowid);
|
||||
if (!is_missing) {
|
||||
@@ -248,11 +282,7 @@ class PartitionBuilder {
|
||||
};
|
||||
|
||||
std::pair<size_t, size_t> child_nodes_sizes;
|
||||
if (!column_matrix.IsInitialized()) {
|
||||
child_nodes_sizes = PartitionRangeKernel(rid_span, left, right, pred_approx);
|
||||
} else {
|
||||
LOG(FATAL) << "Column data split is only supported for the `approx` tree method";
|
||||
}
|
||||
child_nodes_sizes = PartitionRangeKernel(rid_span, left, right, pred);
|
||||
|
||||
const size_t n_left = child_nodes_sizes.first;
|
||||
const size_t n_right = child_nodes_sizes.second;
|
||||
|
||||
Reference in New Issue
Block a user