Support categorical data for hist. (#7695)

* Extract partitioner from hist.
* Implement categorical data support by passing the gradient index directly into the partitioner.
* Organize/update documentation.
* Remove code for negative hessian.
This commit is contained in:
Jiaming Yuan
2022-02-25 03:47:14 +08:00
committed by GitHub
parent f60d95b0ba
commit 83a66b4994
15 changed files with 402 additions and 498 deletions

View File

@@ -1,5 +1,5 @@
/*!
* Copyright 2021 by Contributors
* Copyright 2021-2022 by Contributors
* \file partition_builder.h
* \brief Quick Utility to compute subset of rows
* \author Philip Cho, Tianqi Chen
@@ -8,12 +8,15 @@
#define XGBOOST_COMMON_PARTITION_BUILDER_H_
#include <xgboost/data.h>
#include <algorithm>
#include <vector>
#include <utility>
#include <memory>
#include <utility>
#include <vector>
#include "categorical.h"
#include "column_matrix.h"
#include "xgboost/tree_model.h"
#include "../common/column_matrix.h"
namespace xgboost {
namespace common {
@@ -46,18 +49,24 @@ class PartitionBuilder {
// on comparison of indexes values (idx_span) and split point (split_cond)
// Handle dense columns
// Analog of std::stable_partition, but not in-place: rows are written to separate left/right buffers
template <bool default_left, bool any_missing, typename ColumnType>
template <bool default_left, bool any_missing, typename ColumnType, typename Predicate>
inline std::pair<size_t, size_t> PartitionKernel(const ColumnType& column,
common::Span<const size_t> rid_span, const int32_t split_cond,
common::Span<size_t> left_part, common::Span<size_t> right_part) {
common::Span<const size_t> row_indices,
common::Span<size_t> left_part,
common::Span<size_t> right_part,
size_t base_rowid, Predicate&& pred) {
size_t* p_left_part = left_part.data();
size_t* p_right_part = right_part.data();
size_t nleft_elems = 0;
size_t nright_elems = 0;
auto state = column.GetInitialState(rid_span.front());
auto state = column.GetInitialState(row_indices.front() - base_rowid);
for (auto rid : rid_span) {
const int32_t bin_id = column.GetBinIdx(rid, &state);
auto p_row_indices = row_indices.data();
auto n_samples = row_indices.size();
for (size_t i = 0; i < n_samples; ++i) {
auto rid = p_row_indices[i];
const int32_t bin_id = column.GetBinIdx(rid - base_rowid, &state);
if (any_missing && bin_id == ColumnType::kMissingId) {
if (default_left) {
p_left_part[nleft_elems++] = rid;
@@ -65,7 +74,7 @@ class PartitionBuilder {
p_right_part[nright_elems++] = rid;
}
} else {
if (bin_id <= split_cond) {
if (pred(rid, bin_id)) {
p_left_part[nleft_elems++] = rid;
} else {
p_right_part[nright_elems++] = rid;
@@ -95,41 +104,66 @@ class PartitionBuilder {
return {nleft_elems, nright_elems};
}
template <typename BinIdxType, bool any_missing>
template <typename BinIdxType, bool any_missing, bool any_cat>
void Partition(const size_t node_in_set, const size_t nid, const common::Range1d range,
const int32_t split_cond,
const ColumnMatrix& column_matrix, const RegTree& tree, const size_t* rid) {
const int32_t split_cond, GHistIndexMatrix const& gmat,
const ColumnMatrix& column_matrix, const RegTree& tree, const size_t* rid) {
common::Span<const size_t> rid_span(rid + range.begin(), rid + range.end());
common::Span<size_t> left = GetLeftBuffer(node_in_set,
range.begin(), range.end());
common::Span<size_t> right = GetRightBuffer(node_in_set,
range.begin(), range.end());
common::Span<size_t> left = GetLeftBuffer(node_in_set, range.begin(), range.end());
common::Span<size_t> right = GetRightBuffer(node_in_set, range.begin(), range.end());
const bst_uint fid = tree[nid].SplitIndex();
const bool default_left = tree[nid].DefaultLeft();
const auto column_ptr = column_matrix.GetColumn<BinIdxType, any_missing>(fid);
std::pair<size_t, size_t> child_nodes_sizes;
bool is_cat = tree.GetSplitTypes()[nid] == FeatureType::kCategorical;
auto node_cats = tree.NodeCats(nid);
auto const& index = gmat.index;
auto const& cut_values = gmat.cut.Values();
auto const& cut_ptrs = gmat.cut.Ptrs();
auto pred = [&](auto ridx, auto bin_id) {
if (any_cat && is_cat) {
auto begin = gmat.RowIdx(ridx);
auto end = gmat.RowIdx(ridx + 1);
auto f_begin = cut_ptrs[fid];
auto f_end = cut_ptrs[fid + 1];
// bypassing the column matrix as we need the cut value instead of bin idx for categorical
// features.
auto gidx = BinarySearchBin(begin, end, index, f_begin, f_end);
bool go_left;
if (gidx == -1) {
go_left = default_left;
} else {
go_left = Decision(node_cats, cut_values[gidx], default_left);
}
return go_left;
} else {
return bin_id <= split_cond;
}
};
std::pair<size_t, size_t> child_nodes_sizes;
if (column_ptr->GetType() == xgboost::common::kDenseColumn) {
const common::DenseColumn<BinIdxType, any_missing>& column =
static_cast<const common::DenseColumn<BinIdxType, any_missing>& >(*(column_ptr.get()));
if (default_left) {
child_nodes_sizes = PartitionKernel<true, any_missing>(column, rid_span,
split_cond, left, right);
child_nodes_sizes = PartitionKernel<true, any_missing>(column, rid_span, left, right,
gmat.base_rowid, pred);
} else {
child_nodes_sizes = PartitionKernel<false, any_missing>(column, rid_span,
split_cond, left, right);
child_nodes_sizes = PartitionKernel<false, any_missing>(column, rid_span, left, right,
gmat.base_rowid, pred);
}
} else {
CHECK_EQ(any_missing, true);
const common::SparseColumn<BinIdxType>& column
= static_cast<const common::SparseColumn<BinIdxType>& >(*(column_ptr.get()));
if (default_left) {
child_nodes_sizes = PartitionKernel<true, any_missing>(column, rid_span,
split_cond, left, right);
child_nodes_sizes = PartitionKernel<true, any_missing>(column, rid_span, left, right,
gmat.base_rowid, pred);
} else {
child_nodes_sizes = PartitionKernel<false, any_missing>(column, rid_span,
split_cond, left, right);
child_nodes_sizes = PartitionKernel<false, any_missing>(column, rid_span, left, right,
gmat.base_rowid, pred);
}
}

View File

@@ -275,9 +275,6 @@ class MemStackAllocator {
T& operator[](size_t i) { return ptr_[i]; }
T const& operator[](size_t i) const { return ptr_[i]; }
// FIXME(jiamingy): Remove this once we merge partitioner cleanup for hist.
auto Get() { return ptr_; }
private:
T* ptr_ = nullptr;
size_t required_size_;