Unify the partitioner for hist and approx.

Co-authored-by: dmitry.razdoburdin <drazdobu@jfldaal005.jf.intel.com>
Co-authored-by: jiamingy <jm.yuan@outlook.com>

parent c69af90319
commit 5bd849f1b5
@@ -103,15 +103,18 @@ class SparseColumnIter : public Column<BinIdxT> {
 template <typename BinIdxT, bool any_missing>
 class DenseColumnIter : public Column<BinIdxT> {
+ public:
+  using ByteType = bool;
+
  private:
   using Base = Column<BinIdxT>;
   /* flags for missing values in dense columns */
-  std::vector<bool> const& missing_flags_;
+  std::vector<ByteType> const& missing_flags_;
   size_t feature_offset_;

  public:
   explicit DenseColumnIter(common::Span<const BinIdxT> index, bst_bin_t index_base,
-                           std::vector<bool> const& missing_flags, size_t feature_offset)
+                           std::vector<ByteType> const& missing_flags, size_t feature_offset)
       : Base{index, index_base}, missing_flags_{missing_flags}, feature_offset_{feature_offset} {}
   DenseColumnIter(DenseColumnIter const&) = delete;
   DenseColumnIter(DenseColumnIter&&) = default;
@@ -153,6 +156,7 @@ class ColumnMatrix {
   }

  public:
+  using ByteType = bool;
   // get number of features
   bst_feature_t GetNumFeature() const { return static_cast<bst_feature_t>(type_.size()); }

@@ -195,6 +199,8 @@ class ColumnMatrix {
     }
   }

+  bool IsInitialized() const { return !type_.empty(); }
+
   /**
    * \brief Push batch of data for Quantile DMatrix support.
    *
@@ -352,6 +358,13 @@ class ColumnMatrix {

     fi->Read(&row_ind_);
     fi->Read(&feature_offsets_);

+    std::vector<std::uint8_t> missing;
+    fi->Read(&missing);
+    missing_flags_.resize(missing.size());
+    std::transform(missing.cbegin(), missing.cend(), missing_flags_.begin(),
+                   [](std::uint8_t flag) { return !!flag; });
+
     index_base_ = index_base;
 #if !DMLC_LITTLE_ENDIAN
     std::underlying_type<BinTypeSize>::type v;
@@ -386,6 +399,11 @@ class ColumnMatrix {
 #endif  // !DMLC_LITTLE_ENDIAN
     write_vec(row_ind_);
     write_vec(feature_offsets_);
+    // dmlc can not handle bool vector
+    std::vector<std::uint8_t> missing(missing_flags_.size());
+    std::transform(missing_flags_.cbegin(), missing_flags_.cend(), missing.begin(),
+                   [](bool flag) { return static_cast<std::uint8_t>(flag); });
+    write_vec(missing);

 #if !DMLC_LITTLE_ENDIAN
     auto v = static_cast<std::underlying_type<BinTypeSize>::type>(bins_type_size_);
@@ -413,7 +431,7 @@ class ColumnMatrix {

   // index_base_[fid]: least bin id for feature fid
   uint32_t const* index_base_;
-  std::vector<bool> missing_flags_;
+  std::vector<ByteType> missing_flags_;
   BinTypeSize bins_type_size_;
   bool any_missing_;
 };
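The Read/Write changes above exist because dmlc's serializer cannot handle std::vector<bool>, so the missing-value flags are round-tripped through a std::vector<std::uint8_t>. A standalone sketch of that conversion (the helper names are illustrative, not part of ColumnMatrix):

  #include <algorithm>
  #include <cstdint>
  #include <vector>

  // bool -> byte before writing the flags to the stream.
  std::vector<std::uint8_t> ToBytes(std::vector<bool> const& flags) {
    std::vector<std::uint8_t> bytes(flags.size());
    std::transform(flags.cbegin(), flags.cend(), bytes.begin(),
                   [](bool flag) { return static_cast<std::uint8_t>(flag); });
    return bytes;
  }

  // byte -> bool after reading the flags back.
  std::vector<bool> FromBytes(std::vector<std::uint8_t> const& bytes) {
    std::vector<bool> flags(bytes.size());
    std::transform(bytes.cbegin(), bytes.cend(), flags.begin(),
                   [](std::uint8_t flag) { return !!flag; });
    return flags;
  }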
@@ -4,6 +4,8 @@
 #ifndef XGBOOST_COMMON_NUMERIC_H_
 #define XGBOOST_COMMON_NUMERIC_H_

+#include <dmlc/common.h>  // OMPException
+
 #include <algorithm>  // std::max
 #include <iterator>   // std::iterator_traits
 #include <vector>
@@ -106,6 +108,26 @@ inline double Reduce(Context const*, HostDeviceVector<float> const&) {
  * \brief Reduction with summation.
  */
 double Reduce(Context const* ctx, HostDeviceVector<float> const& values);

+template <typename It>
+void Iota(Context const* ctx, It first, It last,
+          typename std::iterator_traits<It>::value_type const& value) {
+  auto n = std::distance(first, last);
+  std::int32_t n_threads = ctx->Threads();
+  const size_t block_size = n / n_threads + !!(n % n_threads);
+  dmlc::OMPException exc;
+#pragma omp parallel num_threads(n_threads)
+  {
+    exc.Run([&]() {
+      const size_t tid = omp_get_thread_num();
+      const size_t ibegin = tid * block_size;
+      const size_t iend = std::min(ibegin + block_size, static_cast<size_t>(n));
+      for (size_t i = ibegin; i < iend; ++i) {
+        first[i] = i + value;
+      }
+    });
+  }
+}
 }  // namespace common
 }  // namespace xgboost
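The new common::Iota is a thread-parallel std::iota over an arbitrary iterator range; it is used below to seed the partitioner's row-index buffer. A minimal usage sketch (num_row and base_rowid are assumed to be given, ctx is a configured Context pointer):

  // Fills row_indices with base_rowid, base_rowid + 1, ... using ctx->Threads() threads.
  std::vector<std::size_t> row_indices(num_row);
  xgboost::common::Iota(ctx, row_indices.data(), row_indices.data() + row_indices.size(),
                        base_rowid);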
@@ -17,6 +17,7 @@

 #include "categorical.h"
 #include "column_matrix.h"
+#include "../tree/hist/expand_entry.h"
 #include "xgboost/generic_parameters.h"
 #include "xgboost/tree_model.h"

@@ -107,14 +108,17 @@ class PartitionBuilder {
   }

   template <typename BinIdxType, bool any_missing, bool any_cat>
-  void Partition(const size_t node_in_set, const size_t nid, const common::Range1d range,
+  void Partition(const size_t node_in_set, std::vector<xgboost::tree::CPUExpandEntry> const &nodes,
+                 const common::Range1d range,
                  const bst_bin_t split_cond, GHistIndexMatrix const& gmat,
-                 const ColumnMatrix& column_matrix, const RegTree& tree, const size_t* rid) {
+                 const common::ColumnMatrix& column_matrix,
+                 const RegTree& tree, const size_t* rid) {
     common::Span<const size_t> rid_span(rid + range.begin(), rid + range.end());
     common::Span<size_t> left = GetLeftBuffer(node_in_set, range.begin(), range.end());
     common::Span<size_t> right = GetRightBuffer(node_in_set, range.begin(), range.end());
-    const bst_uint fid = tree[nid].SplitIndex();
-    const bool default_left = tree[nid].DefaultLeft();
+    std::size_t nid = nodes[node_in_set].nid;
+    bst_feature_t fid = tree[nid].SplitIndex();
+    bool default_left = tree[nid].DefaultLeft();
     bool is_cat = tree.GetSplitTypes()[nid] == FeatureType::kCategorical;
     auto node_cats = tree.NodeCats(nid);

@@ -122,19 +126,24 @@ class PartitionBuilder {
     auto const& cut_values = gmat.cut.Values();
     auto const& cut_ptrs = gmat.cut.Ptrs();

-    auto pred = [&](auto ridx, auto bin_id) {
+    auto gidx_calc = [&](auto ridx) {
+      auto begin = gmat.RowIdx(ridx);
+      if (gmat.IsDense()) {
+        return static_cast<bst_bin_t>(index[begin + fid]);
+      }
+      auto end = gmat.RowIdx(ridx + 1);
+      auto f_begin = cut_ptrs[fid];
+      auto f_end = cut_ptrs[fid + 1];
+      // bypassing the column matrix as we need the cut value instead of bin idx for categorical
+      // features.
+      return BinarySearchBin(begin, end, index, f_begin, f_end);
+    };
+
+    auto pred_hist = [&](auto ridx, auto bin_id) {
       if (any_cat && is_cat) {
-        auto begin = gmat.RowIdx(ridx);
-        auto end = gmat.RowIdx(ridx + 1);
-        auto f_begin = cut_ptrs[fid];
-        auto f_end = cut_ptrs[fid + 1];
-        // bypassing the column matrix as we need the cut value instead of bin idx for categorical
-        // features.
-        auto gidx = BinarySearchBin(begin, end, index, f_begin, f_end);
-        bool go_left;
-        if (gidx == -1) {
-          go_left = default_left;
-        } else {
+        auto gidx = gidx_calc(ridx);
+        bool go_left = default_left;
+        if (gidx > -1) {
           go_left = Decision(node_cats, cut_values[gidx], default_left);
         }
         return go_left;
@@ -143,25 +152,43 @@ class PartitionBuilder {
       }
     };

-    std::pair<size_t, size_t> child_nodes_sizes;
-    if (column_matrix.GetColumnType(fid) == xgboost::common::kDenseColumn) {
-      auto column = column_matrix.DenseColumn<BinIdxType, any_missing>(fid);
-      if (default_left) {
-        child_nodes_sizes = PartitionKernel<true, any_missing>(&column, rid_span, left, right,
-                                                               gmat.base_rowid, pred);
-      } else {
-        child_nodes_sizes = PartitionKernel<false, any_missing>(&column, rid_span, left, right,
-                                                                gmat.base_rowid, pred);
-      }
+    auto pred_approx = [&](auto ridx) {
+      auto gidx = gidx_calc(ridx);
+      bool go_left = default_left;
+      if (gidx > -1) {
+        if (is_cat) {
+          go_left = Decision(node_cats, cut_values[gidx], default_left);
+        } else {
+          go_left = cut_values[gidx] <= nodes[node_in_set].split.split_value;
+        }
+      }
+      return go_left;
+    };
+
+    std::pair<size_t, size_t> child_nodes_sizes;
+    if (!column_matrix.IsInitialized()) {
+      child_nodes_sizes = PartitionRangeKernel(rid_span, left, right, pred_approx);
     } else {
-      CHECK_EQ(any_missing, true);
-      auto column = column_matrix.SparseColumn<BinIdxType>(fid, rid_span.front() - gmat.base_rowid);
-      if (default_left) {
-        child_nodes_sizes = PartitionKernel<true, any_missing>(&column, rid_span, left, right,
-                                                               gmat.base_rowid, pred);
+      if (column_matrix.GetColumnType(fid) == xgboost::common::kDenseColumn) {
+        auto column = column_matrix.DenseColumn<BinIdxType, any_missing>(fid);
+        if (default_left) {
+          child_nodes_sizes = PartitionKernel<true, any_missing>(&column, rid_span, left, right,
+                                                                 gmat.base_rowid, pred_hist);
+        } else {
+          child_nodes_sizes = PartitionKernel<false, any_missing>(&column, rid_span, left, right,
+                                                                  gmat.base_rowid, pred_hist);
+        }
       } else {
-        child_nodes_sizes = PartitionKernel<false, any_missing>(&column, rid_span, left, right,
-                                                                gmat.base_rowid, pred);
+        CHECK_EQ(any_missing, true);
+        auto column =
+            column_matrix.SparseColumn<BinIdxType>(fid, rid_span.front() - gmat.base_rowid);
+        if (default_left) {
+          child_nodes_sizes = PartitionKernel<true, any_missing>(&column, rid_span, left, right,
+                                                                 gmat.base_rowid, pred_hist);
+        } else {
+          child_nodes_sizes = PartitionKernel<false, any_missing>(&column, rid_span, left, right,
+                                                                  gmat.base_rowid, pred_hist);
+        }
       }
     }

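With both predicates in place, the unified Partition picks its strategy at run time: when the ColumnMatrix has been built (the hist path) it walks the dense or sparse column with pred_hist on bin indices, and when it is absent (the approx path) it falls back to PartitionRangeKernel with pred_approx on cut values. A rough sketch of the call site, mirroring how CommonRowPartitioner below invokes it (the surrounding names are assumed from that context, not defined here):

  // Hypothetical invocation; split_cond is only meaningful on the hist path.
  builder.template Partition<uint8_t, /*any_missing=*/true, /*any_cat=*/true>(
      node_in_set, nodes, range, split_cond, gmat, column_matrix, tree,
      row_set_collection[nid].begin);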
@@ -172,37 +199,6 @@ class PartitionBuilder {
     SetNRightElems(node_in_set, range.begin(), n_right);
   }

-  /**
-   * \brief Partition tree nodes with specific range of row indices.
-   *
-   * \tparam Pred Predicate for whether a row should be partitioned to the left node.
-   *
-   * \param node_in_set The index of node in current batch of nodes.
-   * \param nid The canonical node index (node index in the tree).
-   * \param range The range of input row index.
-   * \param fidx Feature index.
-   * \param p_row_set_collection Pointer to rows that are being partitioned.
-   * \param pred A callback function that returns whether current row should be
-   *             partitioned to the left node, it should accept the row index as
-   *             input and returns a boolean value.
-   */
-  template <typename Pred>
-  void PartitionRange(const size_t node_in_set, const size_t nid, common::Range1d range,
-                      common::RowSetCollection* p_row_set_collection, Pred pred) {
-    auto& row_set_collection = *p_row_set_collection;
-    const size_t* p_ridx = row_set_collection[nid].begin;
-    common::Span<const size_t> ridx(p_ridx + range.begin(), p_ridx + range.end());
-    common::Span<size_t> left = this->GetLeftBuffer(node_in_set, range.begin(), range.end());
-    common::Span<size_t> right = this->GetRightBuffer(node_in_set, range.begin(), range.end());
-    std::pair<size_t, size_t> child_nodes_sizes = PartitionRangeKernel(ridx, left, right, pred);
-
-    const size_t n_left = child_nodes_sizes.first;
-    const size_t n_right = child_nodes_sizes.second;
-
-    this->SetNLeftElems(node_in_set, range.begin(), n_left);
-    this->SetNRightElems(node_in_set, range.begin(), n_right);
-  }
-
   // allocate thread local memory, should be called for each specific task
   void AllocateForTask(size_t id) {
     if (mem_blocks_[id].get() == nullptr) {
src/tree/common_row_partitioner.h (new file, 212 lines)
@@ -0,0 +1,212 @@
+/*!
+ * Copyright 2021-2022 XGBoost contributors
+ * \file common_row_partitioner.h
+ * \brief Common partitioner logic for hist and approx methods.
+ */
+#ifndef XGBOOST_TREE_COMMON_ROW_PARTITIONER_H_
+#define XGBOOST_TREE_COMMON_ROW_PARTITIONER_H_
+
+#include <limits>  // std::numeric_limits
+#include <vector>
+
+#include "../common/numeric.h"  // Iota
+#include "../common/partition_builder.h"
+#include "hist/expand_entry.h"  // CPUExpandEntry
+#include "xgboost/generic_parameters.h"  // Context
+
+namespace xgboost {
+namespace tree {
+class CommonRowPartitioner {
+  static constexpr size_t kPartitionBlockSize = 2048;
+  common::PartitionBuilder<kPartitionBlockSize> partition_builder_;
+  common::RowSetCollection row_set_collection_;
+
+ public:
+  bst_row_t base_rowid = 0;
+
+  CommonRowPartitioner() = default;
+  CommonRowPartitioner(Context const* ctx, bst_row_t num_row, bst_row_t _base_rowid)
+      : base_rowid{_base_rowid} {
+    row_set_collection_.Clear();
+    std::vector<size_t>& row_indices = *row_set_collection_.Data();
+    row_indices.resize(num_row);
+
+    std::size_t* p_row_indices = row_indices.data();
+    common::Iota(ctx, p_row_indices, p_row_indices + row_indices.size(), base_rowid);
+    row_set_collection_.Init();
+  }
+
+  void FindSplitConditions(const std::vector<CPUExpandEntry>& nodes, const RegTree& tree,
+                           const GHistIndexMatrix& gmat, std::vector<int32_t>* split_conditions) {
+    for (size_t i = 0; i < nodes.size(); ++i) {
+      const int32_t nid = nodes[i].nid;
+      const bst_uint fid = tree[nid].SplitIndex();
+      const bst_float split_pt = tree[nid].SplitCond();
+      const uint32_t lower_bound = gmat.cut.Ptrs()[fid];
+      const uint32_t upper_bound = gmat.cut.Ptrs()[fid + 1];
+      bst_bin_t split_cond = -1;
+      // convert floating-point split_pt into corresponding bin_id
+      // split_cond = -1 indicates that split_pt is less than all known cut points
+      CHECK_LT(upper_bound, static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
+      for (auto bound = lower_bound; bound < upper_bound; ++bound) {
+        if (split_pt == gmat.cut.Values()[bound]) {
+          split_cond = static_cast<int32_t>(bound);
+        }
+      }
+      (*split_conditions).at(i) = split_cond;
+    }
+  }
+
+  void AddSplitsToRowSet(const std::vector<CPUExpandEntry>& nodes, RegTree const* p_tree) {
+    const size_t n_nodes = nodes.size();
+    for (unsigned int i = 0; i < n_nodes; ++i) {
+      const int32_t nid = nodes[i].nid;
+      const size_t n_left = partition_builder_.GetNLeftElems(i);
+      const size_t n_right = partition_builder_.GetNRightElems(i);
+      CHECK_EQ((*p_tree)[nid].LeftChild() + 1, (*p_tree)[nid].RightChild());
+      row_set_collection_.AddSplit(nid, (*p_tree)[nid].LeftChild(), (*p_tree)[nid].RightChild(),
+                                   n_left, n_right);
+    }
+  }
+
+  void UpdatePosition(Context const* ctx, GHistIndexMatrix const& gmat,
+                      std::vector<CPUExpandEntry> const& nodes, RegTree const* p_tree) {
+    auto const& column_matrix = gmat.Transpose();
+    if (column_matrix.IsInitialized()) {
+      if (gmat.cut.HasCategorical()) {
+        this->template UpdatePosition<true>(ctx, gmat, column_matrix, nodes, p_tree);
+      } else {
+        this->template UpdatePosition<false>(ctx, gmat, column_matrix, nodes, p_tree);
+      }
+    } else {
+      /* ColumnMatrix is not initialized.
+       * It means that we use the 'approx' method.
+       * any_missing and any_cat do not matter in this case.
+       * Jump directly to the main method.
+       */
+      this->template UpdatePosition<uint8_t, true, true>(ctx, gmat, column_matrix, nodes, p_tree);
+    }
+  }
+
+  template <bool any_cat>
+  void UpdatePosition(Context const* ctx, GHistIndexMatrix const& gmat,
+                      const common::ColumnMatrix& column_matrix,
+                      std::vector<CPUExpandEntry> const& nodes, RegTree const* p_tree) {
+    if (column_matrix.AnyMissing()) {
+      this->template UpdatePosition<true, any_cat>(ctx, gmat, column_matrix, nodes, p_tree);
+    } else {
+      this->template UpdatePosition<false, any_cat>(ctx, gmat, column_matrix, nodes, p_tree);
+    }
+  }
+
+  template <bool any_missing, bool any_cat>
+  void UpdatePosition(Context const* ctx, GHistIndexMatrix const& gmat,
+                      const common::ColumnMatrix& column_matrix,
+                      std::vector<CPUExpandEntry> const& nodes, RegTree const* p_tree) {
+    switch (column_matrix.GetTypeSize()) {
+      case common::kUint8BinsTypeSize:
+        this->template UpdatePosition<uint8_t, any_missing, any_cat>(ctx, gmat, column_matrix,
+                                                                     nodes, p_tree);
+        break;
+      case common::kUint16BinsTypeSize:
+        this->template UpdatePosition<uint16_t, any_missing, any_cat>(ctx, gmat, column_matrix,
+                                                                      nodes, p_tree);
+        break;
+      case common::kUint32BinsTypeSize:
+        this->template UpdatePosition<uint32_t, any_missing, any_cat>(ctx, gmat, column_matrix,
+                                                                      nodes, p_tree);
+        break;
+      default:
+        // no default behavior
+        CHECK(false) << column_matrix.GetTypeSize();
+    }
+  }
+
+  template <typename BinIdxType, bool any_missing, bool any_cat>
+  void UpdatePosition(Context const* ctx, GHistIndexMatrix const& gmat,
+                      const common::ColumnMatrix& column_matrix,
+                      std::vector<CPUExpandEntry> const& nodes, RegTree const* p_tree) {
+    // 1. Find split condition for each split
+    size_t n_nodes = nodes.size();
+
+    std::vector<int32_t> split_conditions;
+    if (column_matrix.IsInitialized()) {
+      split_conditions.resize(n_nodes);
+      FindSplitConditions(nodes, *p_tree, gmat, &split_conditions);
+    }
+
+    // 2.1 Create a blocked space of size SUM(samples in each node)
+    common::BlockedSpace2d space(
+        n_nodes,
+        [&](size_t node_in_set) {
+          int32_t nid = nodes[node_in_set].nid;
+          return row_set_collection_[nid].Size();
+        },
+        kPartitionBlockSize);
+
+    // 2.2 Initialize the partition builder
+    // allocate buffers for storage intermediate results by each thread
+    partition_builder_.Init(space.Size(), n_nodes, [&](size_t node_in_set) {
+      const int32_t nid = nodes[node_in_set].nid;
+      const size_t size = row_set_collection_[nid].Size();
+      const size_t n_tasks = size / kPartitionBlockSize + !!(size % kPartitionBlockSize);
+      return n_tasks;
+    });
+    CHECK_EQ(base_rowid, gmat.base_rowid);
+
+    // 2.3 Split elements of row_set_collection_ to left and right child-nodes for each node
+    // Store results in intermediate buffers from partition_builder_
+    common::ParallelFor2d(space, ctx->Threads(), [&](size_t node_in_set, common::Range1d r) {
+      size_t begin = r.begin();
+      const int32_t nid = nodes[node_in_set].nid;
+      const size_t task_id = partition_builder_.GetTaskIdx(node_in_set, begin);
+      partition_builder_.AllocateForTask(task_id);
+      bst_bin_t split_cond = column_matrix.IsInitialized() ? split_conditions[node_in_set] : 0;
+      partition_builder_.template Partition<BinIdxType, any_missing, any_cat>(
+          node_in_set, nodes, r, split_cond, gmat, column_matrix, *p_tree,
+          row_set_collection_[nid].begin);
+    });
+
+    // 3. Compute offsets to copy blocks of row-indexes
+    // from partition_builder_ to row_set_collection_
+    partition_builder_.CalculateRowOffsets();
+
+    // 4. Copy elements from partition_builder_ to row_set_collection_ back
+    // with updated row-indexes for each tree-node
+    common::ParallelFor2d(space, ctx->Threads(), [&](size_t node_in_set, common::Range1d r) {
+      const int32_t nid = nodes[node_in_set].nid;
+      partition_builder_.MergeToArray(node_in_set, r.begin(),
+                                      const_cast<size_t*>(row_set_collection_[nid].begin));
+    });
+
+    // 5. Add info about splits into row_set_collection_
+    AddSplitsToRowSet(nodes, p_tree);
+  }
+
+  auto const& Partitions() const { return row_set_collection_; }
+
+  size_t Size() const {
+    return std::distance(row_set_collection_.begin(), row_set_collection_.end());
+  }
+
+  auto& operator[](bst_node_t nidx) { return row_set_collection_[nidx]; }
+  auto const& operator[](bst_node_t nidx) const { return row_set_collection_[nidx]; }
+
+  void LeafPartition(Context const* ctx, RegTree const& tree, common::Span<float const> hess,
+                     std::vector<bst_node_t>* p_out_position) const {
+    partition_builder_.LeafPartition(ctx, tree, this->Partitions(), p_out_position,
+                                     [&](size_t idx) -> bool { return hess[idx] - .0f == .0f; });
+  }
+
+  void LeafPartition(Context const* ctx, RegTree const& tree,
+                     common::Span<GradientPair const> gpair,
+                     std::vector<bst_node_t>* p_out_position) const {
+    partition_builder_.LeafPartition(
+        ctx, tree, this->Partitions(), p_out_position,
+        [&](size_t idx) -> bool { return gpair[idx].GetHess() - .0f == .0f; });
+  }
+};
+
+}  // namespace tree
+}  // namespace xgboost
+#endif  // XGBOOST_TREE_COMMON_ROW_PARTITIONER_H_
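Both updaters drive the unified partitioner the same way: one instance per GHistIndexMatrix page, refreshed with the batch of expanded nodes after each round of split evaluation. A minimal usage sketch (ctx, page, nodes, and tree are assumed to be set up as in the updaters below):

  // Hypothetical driver code; mirrors GloablApproxBuilder and QuantileHistMaker below.
  CommonRowPartitioner partitioner{ctx, page.Size(), page.base_rowid};
  partitioner.UpdatePosition(ctx, page, nodes, &tree);   // dispatches to the hist or approx path
  auto const& row_sets = partitioner.Partitions();        // per-node row index sets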
@@ -3,14 +3,13 @@
  *
  * \brief Implementation for the approx tree method.
  */
-#include "updater_approx.h"
-
 #include <algorithm>
 #include <memory>
 #include <vector>

 #include "../common/random.h"
 #include "../data/gradient_index.h"
+#include "common_row_partitioner.h"
 #include "constraints.h"
 #include "driver.h"
 #include "hist/evaluate_splits.h"
@@ -46,7 +45,7 @@ class GloablApproxBuilder {
   Context const *ctx_;
   ObjInfo const task_;

-  std::vector<ApproxRowPartitioner> partitioner_;
+  std::vector<CommonRowPartitioner> partitioner_;
   // Pointer to last updated tree, used for update prediction cache.
   RegTree *p_last_tree_{nullptr};
   common::Monitor *monitor_;
@@ -69,7 +68,7 @@ class GloablApproxBuilder {
     } else {
       CHECK_EQ(n_total_bins, page.cut.TotalBins());
     }
-    partitioner_.emplace_back(page.Size(), page.base_rowid);
+    partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid);
     n_batches_++;
   }

@@ -151,7 +150,7 @@ class GloablApproxBuilder {
     monitor_->Stop(__func__);
   }

-  void LeafPartition(RegTree const &tree, common::Span<float> hess,
+  void LeafPartition(RegTree const &tree, common::Span<float const> hess,
                      std::vector<bst_node_t> *p_out_position) {
     monitor_->Start(__func__);
     if (!task_.UpdateTreeLeaf()) {
@@ -1,150 +0,0 @@
-/*!
- * Copyright 2021-2022 XGBoost contributors
- *
- * \brief Implementation for the approx tree method.
- */
-#ifndef XGBOOST_TREE_UPDATER_APPROX_H_
-#define XGBOOST_TREE_UPDATER_APPROX_H_
-
-#include <limits>
-#include <utility>
-#include <vector>
-
-#include "../common/partition_builder.h"
-#include "../common/random.h"
-#include "constraints.h"
-#include "driver.h"
-#include "hist/evaluate_splits.h"
-#include "hist/expand_entry.h"
-#include "param.h"
-#include "xgboost/generic_parameters.h"
-#include "xgboost/json.h"
-#include "xgboost/tree_updater.h"
-
-namespace xgboost {
-namespace tree {
-class ApproxRowPartitioner {
-  static constexpr size_t kPartitionBlockSize = 2048;
-  common::PartitionBuilder<kPartitionBlockSize> partition_builder_;
-  common::RowSetCollection row_set_collection_;
-
- public:
-  bst_row_t base_rowid = 0;
-
-  static auto SearchCutValue(bst_row_t ridx, bst_feature_t fidx, GHistIndexMatrix const &index,
-                             std::vector<uint32_t> const &cut_ptrs,
-                             std::vector<float> const &cut_values) {
-    int32_t gidx = -1;
-    if (index.IsDense()) {
-      // RowIdx returns the starting pos of this row
-      gidx = index.index[index.RowIdx(ridx) + fidx];
-    } else {
-      auto begin = index.RowIdx(ridx);
-      auto end = index.RowIdx(ridx + 1);
-      auto f_begin = cut_ptrs[fidx];
-      auto f_end = cut_ptrs[fidx + 1];
-      gidx = common::BinarySearchBin(begin, end, index.index, f_begin, f_end);
-    }
-    if (gidx == -1) {
-      return std::numeric_limits<float>::quiet_NaN();
-    }
-    return cut_values[gidx];
-  }
-
- public:
-  void UpdatePosition(GenericParameter const *ctx, GHistIndexMatrix const &index,
-                      std::vector<CPUExpandEntry> const &candidates, RegTree const *p_tree) {
-    size_t n_nodes = candidates.size();
-
-    auto const &cut_values = index.cut.Values();
-    auto const &cut_ptrs = index.cut.Ptrs();
-
-    common::BlockedSpace2d space{n_nodes,
-                                 [&](size_t node_in_set) {
-                                   auto candidate = candidates[node_in_set];
-                                   int32_t nid = candidate.nid;
-                                   return row_set_collection_[nid].Size();
-                                 },
-                                 kPartitionBlockSize};
-    partition_builder_.Init(space.Size(), n_nodes, [&](size_t node_in_set) {
-      auto candidate = candidates[node_in_set];
-      const int32_t nid = candidate.nid;
-      const size_t size = row_set_collection_[nid].Size();
-      const size_t n_tasks = size / kPartitionBlockSize + !!(size % kPartitionBlockSize);
-      return n_tasks;
-    });
-    auto node_ptr = p_tree->GetCategoriesMatrix().node_ptr;
-    auto categories = p_tree->GetCategoriesMatrix().categories;
-    common::ParallelFor2d(space, ctx->Threads(), [&](size_t node_in_set, common::Range1d r) {
-      auto candidate = candidates[node_in_set];
-      auto is_cat = candidate.split.is_cat;
-      const int32_t nid = candidate.nid;
-      auto fidx = candidate.split.SplitIndex();
-      const size_t task_id = partition_builder_.GetTaskIdx(node_in_set, r.begin());
-      partition_builder_.AllocateForTask(task_id);
-      partition_builder_.PartitionRange(
-          node_in_set, nid, r, &row_set_collection_, [&](size_t row_id) {
-            auto cut_value = SearchCutValue(row_id, fidx, index, cut_ptrs, cut_values);
-            if (std::isnan(cut_value)) {
-              return candidate.split.DefaultLeft();
-            }
-            bst_node_t nidx = candidate.nid;
-            auto segment = node_ptr[nidx];
-            auto node_cats = categories.subspan(segment.beg, segment.size);
-            bool go_left = true;
-            if (is_cat) {
-              go_left = common::Decision(node_cats, cut_value, candidate.split.DefaultLeft());
-            } else {
-              go_left = cut_value <= candidate.split.split_value;
-            }
-            return go_left;
-          });
-    });
-
-    partition_builder_.CalculateRowOffsets();
-    common::ParallelFor2d(space, ctx->Threads(), [&](size_t node_in_set, common::Range1d r) {
-      auto candidate = candidates[node_in_set];
-      const int32_t nid = candidate.nid;
-      partition_builder_.MergeToArray(node_in_set, r.begin(),
-                                      const_cast<size_t *>(row_set_collection_[nid].begin));
-    });
-    for (size_t i = 0; i < candidates.size(); ++i) {
-      auto const &candidate = candidates[i];
-      auto nidx = candidate.nid;
-      auto n_left = partition_builder_.GetNLeftElems(i);
-      auto n_right = partition_builder_.GetNRightElems(i);
-      CHECK_EQ(n_left + n_right, row_set_collection_[nidx].Size());
-      bst_node_t left_nidx = (*p_tree)[nidx].LeftChild();
-      bst_node_t right_nidx = (*p_tree)[nidx].RightChild();
-      row_set_collection_.AddSplit(nidx, left_nidx, right_nidx, n_left, n_right);
-    }
-  }
-
-  auto const &Partitions() const { return row_set_collection_; }
-
-  void LeafPartition(Context const *ctx, RegTree const &tree, common::Span<float const> hess,
-                     std::vector<bst_node_t> *p_out_position) const {
-    partition_builder_.LeafPartition(ctx, tree, this->Partitions(), p_out_position,
-                                     [&](size_t idx) -> bool { return hess[idx] - .0f == .0f; });
-  }
-
-  auto operator[](bst_node_t nidx) { return row_set_collection_[nidx]; }
-  auto const &operator[](bst_node_t nidx) const { return row_set_collection_[nidx]; }
-
-  size_t Size() const {
-    return std::distance(row_set_collection_.begin(), row_set_collection_.end());
-  }
-
-  ApproxRowPartitioner() = default;
-  explicit ApproxRowPartitioner(bst_row_t num_row, bst_row_t _base_rowid)
-      : base_rowid{_base_rowid} {
-    row_set_collection_.Clear();
-    auto p_positions = row_set_collection_.Data();
-    p_positions->resize(num_row);
-    std::iota(p_positions->begin(), p_positions->end(), base_rowid);
-    row_set_collection_.Init();
-  }
-};
-}  // namespace tree
-}  // namespace xgboost
-#endif  // XGBOOST_TREE_UPDATER_APPROX_H_
@@ -12,7 +12,9 @@
 #include <utility>
 #include <vector>

+#include "common_row_partitioner.h"
 #include "constraints.h"
+#include "hist/histogram.h"
 #include "hist/evaluate_splits.h"
 #include "param.h"
 #include "xgboost/logging.h"
@@ -309,7 +311,7 @@ void QuantileHistMaker::Builder::InitData(DMatrix *fmat, const RegTree &tree,
     } else {
       CHECK_EQ(n_total_bins, page.cut.TotalBins());
     }
-    partitioner_.emplace_back(page.Size(), page.base_rowid, this->ctx_->Threads());
+    partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid);
     ++page_id;
   }
   histogram_builder_->Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id,
@@ -331,44 +333,6 @@ void QuantileHistMaker::Builder::InitData(DMatrix *fmat, const RegTree &tree,
   monitor_->Stop(__func__);
 }

-void HistRowPartitioner::FindSplitConditions(const std::vector<CPUExpandEntry> &nodes,
-                                             const RegTree &tree, const GHistIndexMatrix &gmat,
-                                             std::vector<int32_t> *split_conditions) {
-  const size_t n_nodes = nodes.size();
-  split_conditions->resize(n_nodes);
-
-  for (size_t i = 0; i < nodes.size(); ++i) {
-    const int32_t nid = nodes[i].nid;
-    const bst_uint fid = tree[nid].SplitIndex();
-    const bst_float split_pt = tree[nid].SplitCond();
-    const uint32_t lower_bound = gmat.cut.Ptrs()[fid];
-    const uint32_t upper_bound = gmat.cut.Ptrs()[fid + 1];
-    bst_bin_t split_cond = -1;
-    // convert floating-point split_pt into corresponding bin_id
-    // split_cond = -1 indicates that split_pt is less than all known cut points
-    CHECK_LT(upper_bound, static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
-    for (auto bound = lower_bound; bound < upper_bound; ++bound) {
-      if (split_pt == gmat.cut.Values()[bound]) {
-        split_cond = static_cast<int32_t>(bound);
-      }
-    }
-    (*split_conditions)[i] = split_cond;
-  }
-}
-
-void HistRowPartitioner::AddSplitsToRowSet(const std::vector<CPUExpandEntry> &nodes,
-                                           RegTree const *p_tree) {
-  const size_t n_nodes = nodes.size();
-  for (unsigned int i = 0; i < n_nodes; ++i) {
-    const int32_t nid = nodes[i].nid;
-    const size_t n_left = partition_builder_.GetNLeftElems(i);
-    const size_t n_right = partition_builder_.GetNRightElems(i);
-    CHECK_EQ((*p_tree)[nid].LeftChild() + 1, (*p_tree)[nid].RightChild());
-    row_set_collection_.AddSplit(nid, (*p_tree)[nid].LeftChild(), (*p_tree)[nid].RightChild(),
-                                 n_left, n_right);
-  }
-}
-
 XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker")
     .describe("Grow tree using quantized histogram.")
     .set_body([](GenericParameter const *ctx, ObjInfo task) {
@@ -24,6 +24,7 @@
 #include "hist/histogram.h"
 #include "hist/expand_entry.h"

+#include "common_row_partitioner.h"
 #include "constraints.h"
 #include "./param.h"
 #include "./driver.h"
@@ -77,155 +78,6 @@ struct RandomReplace {
 };

 namespace tree {
-class HistRowPartitioner {
-  // heuristically chosen block size of parallel partitioning
-  static constexpr size_t kPartitionBlockSize = 2048;
-  // worker class that partition a block of rows
-  common::PartitionBuilder<kPartitionBlockSize> partition_builder_;
-  // storage for row index
-  common::RowSetCollection row_set_collection_;
-
-  /**
-   * \brief Turn split values into discrete bin indices.
-   */
-  static void FindSplitConditions(const std::vector<CPUExpandEntry>& nodes, const RegTree& tree,
-                                  const GHistIndexMatrix& gmat,
-                                  std::vector<int32_t>* split_conditions);
-  /**
-   * \brief Update the row set for new splits specifed by nodes.
-   */
-  void AddSplitsToRowSet(const std::vector<CPUExpandEntry>& nodes, RegTree const* p_tree);
-
- public:
-  bst_row_t base_rowid = 0;
-
- public:
-  HistRowPartitioner(size_t n_samples, size_t base_rowid, int32_t n_threads) {
-    row_set_collection_.Clear();
-    const size_t block_size = n_samples / n_threads + !!(n_samples % n_threads);
-    dmlc::OMPException exc;
-    std::vector<size_t>& row_indices = *row_set_collection_.Data();
-    row_indices.resize(n_samples);
-    size_t* p_row_indices = row_indices.data();
-    // parallel initialization o f row indices. (std::iota)
-#pragma omp parallel num_threads(n_threads)
-    {
-      exc.Run([&]() {
-        const size_t tid = omp_get_thread_num();
-        const size_t ibegin = tid * block_size;
-        const size_t iend = std::min(static_cast<size_t>(ibegin + block_size), n_samples);
-        for (size_t i = ibegin; i < iend; ++i) {
-          p_row_indices[i] = i + base_rowid;
-        }
-      });
-    }
-    row_set_collection_.Init();
-    this->base_rowid = base_rowid;
-  }
-
-  template <bool any_missing, bool any_cat>
-  void UpdatePosition(GenericParameter const* ctx, GHistIndexMatrix const& gmat,
-                      common::ColumnMatrix const& column_matrix,
-                      std::vector<CPUExpandEntry> const& nodes, RegTree const* p_tree) {
-    // 1. Find split condition for each split
-    const size_t n_nodes = nodes.size();
-    std::vector<int32_t> split_conditions;
-    FindSplitConditions(nodes, *p_tree, gmat, &split_conditions);
-    // 2.1 Create a blocked space of size SUM(samples in each node)
-    common::BlockedSpace2d space(
-        n_nodes,
-        [&](size_t node_in_set) {
-          int32_t nid = nodes[node_in_set].nid;
-          return row_set_collection_[nid].Size();
-        },
-        kPartitionBlockSize);
-    // 2.2 Initialize the partition builder
-    // allocate buffers for storage intermediate results by each thread
-    partition_builder_.Init(space.Size(), n_nodes, [&](size_t node_in_set) {
-      const int32_t nid = nodes[node_in_set].nid;
-      const size_t size = row_set_collection_[nid].Size();
-      const size_t n_tasks = size / kPartitionBlockSize + !!(size % kPartitionBlockSize);
-      return n_tasks;
-    });
-    CHECK_EQ(base_rowid, gmat.base_rowid);
-    // 2.3 Split elements of row_set_collection_ to left and right child-nodes for each node
-    // Store results in intermediate buffers from partition_builder_
-    common::ParallelFor2d(space, ctx->Threads(), [&](size_t node_in_set, common::Range1d r) {
-      size_t begin = r.begin();
-      const int32_t nid = nodes[node_in_set].nid;
-      const size_t task_id = partition_builder_.GetTaskIdx(node_in_set, begin);
-      partition_builder_.AllocateForTask(task_id);
-      switch (column_matrix.GetTypeSize()) {
-        case common::kUint8BinsTypeSize:
-          partition_builder_.template Partition<uint8_t, any_missing, any_cat>(
-              node_in_set, nid, r, split_conditions[node_in_set], gmat, column_matrix, *p_tree,
-              row_set_collection_[nid].begin);
-          break;
-        case common::kUint16BinsTypeSize:
-          partition_builder_.template Partition<uint16_t, any_missing, any_cat>(
-              node_in_set, nid, r, split_conditions[node_in_set], gmat, column_matrix, *p_tree,
-              row_set_collection_[nid].begin);
-          break;
-        case common::kUint32BinsTypeSize:
-          partition_builder_.template Partition<uint32_t, any_missing, any_cat>(
-              node_in_set, nid, r, split_conditions[node_in_set], gmat, column_matrix, *p_tree,
-              row_set_collection_[nid].begin);
-          break;
-        default:
-          // no default behavior
-          CHECK(false) << column_matrix.GetTypeSize();
-      }
-    });
-    // 3. Compute offsets to copy blocks of row-indexes
-    // from partition_builder_ to row_set_collection_
-    partition_builder_.CalculateRowOffsets();
-
-    // 4. Copy elements from partition_builder_ to row_set_collection_ back
-    // with updated row-indexes for each tree-node
-    common::ParallelFor2d(space, ctx->Threads(), [&](size_t node_in_set, common::Range1d r) {
-      const int32_t nid = nodes[node_in_set].nid;
-      partition_builder_.MergeToArray(node_in_set, r.begin(),
-                                      const_cast<size_t*>(row_set_collection_[nid].begin));
-    });
-    // 5. Add info about splits into row_set_collection_
-    AddSplitsToRowSet(nodes, p_tree);
-  }
-
-  void UpdatePosition(GenericParameter const* ctx, GHistIndexMatrix const& page,
-                      std::vector<CPUExpandEntry> const& applied, RegTree const* p_tree) {
-    auto const& column_matrix = page.Transpose();
-    if (page.cut.HasCategorical()) {
-      if (column_matrix.AnyMissing()) {
-        this->template UpdatePosition<true, true>(ctx, page, column_matrix, applied, p_tree);
-      } else {
-        this->template UpdatePosition<false, true>(ctx, page, column_matrix, applied, p_tree);
-      }
-    } else {
-      if (column_matrix.AnyMissing()) {
-        this->template UpdatePosition<true, false>(ctx, page, column_matrix, applied, p_tree);
-      } else {
-        this->template UpdatePosition<false, false>(ctx, page, column_matrix, applied, p_tree);
-      }
-    }
-  }
-
-  auto const& Partitions() const { return row_set_collection_; }
-  size_t Size() const {
-    return std::distance(row_set_collection_.begin(), row_set_collection_.end());
-  }
-
-  void LeafPartition(Context const* ctx, RegTree const& tree,
-                     common::Span<GradientPair const> gpair,
-                     std::vector<bst_node_t>* p_out_position) const {
-    partition_builder_.LeafPartition(
-        ctx, tree, this->Partitions(), p_out_position,
-        [&](size_t idx) -> bool { return gpair[idx].GetHess() - .0f == .0f; });
-  }
-
-  auto& operator[](bst_node_t nidx) { return row_set_collection_[nidx]; }
-  auto const& operator[](bst_node_t nidx) const { return row_set_collection_[nidx]; }
-};

 inline BatchParam HistBatch(TrainParam const& param) {
   return {param.max_bin, param.sparse_threshold};
 }
@@ -314,7 +166,7 @@ class QuantileHistMaker: public TreeUpdater {
   std::vector<GradientPair> gpair_local_;

   std::unique_ptr<HistEvaluator<CPUExpandEntry>> evaluator_;
-  std::vector<HistRowPartitioner> partitioner_;
+  std::vector<CommonRowPartitioner> partitioner_;

   // back pointers to tree and data matrix
   const RegTree* p_last_tree_{nullptr};
@@ -5,8 +5,8 @@
 #include <xgboost/base.h>

 #include "../../../../src/common/hist_util.h"
+#include "../../../../src/tree/common_row_partitioner.h"
 #include "../../../../src/tree/hist/evaluate_splits.h"
-#include "../../../../src/tree/updater_quantile_hist.h"
 #include "../test_evaluate_splits.h"
 #include "../../helpers.h"

@@ -4,7 +4,7 @@
 #include <gtest/gtest.h>

 #include "../../../src/common/numeric.h"
-#include "../../../src/tree/updater_approx.h"
+#include "../../../src/tree/common_row_partitioner.h"
 #include "../helpers.h"
 #include "test_partitioner.h"

@@ -12,13 +12,13 @@ namespace xgboost {
 namespace tree {
 TEST(Approx, Partitioner) {
   size_t n_samples = 1024, n_features = 1, base_rowid = 0;
-  ApproxRowPartitioner partitioner{n_samples, base_rowid};
+  GenericParameter ctx;
+  CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid};
   ASSERT_EQ(partitioner.base_rowid, base_rowid);
   ASSERT_EQ(partitioner.Size(), 1);
   ASSERT_EQ(partitioner.Partitions()[0].Size(), n_samples);

   auto Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true);
-  GenericParameter ctx;
   ctx.InitAllowUnknown(Args{});
   std::vector<CPUExpandEntry> candidates{{0, 0, 0.4}};

@@ -32,7 +32,7 @@ TEST(Approx, Partitioner) {
   {
     auto min_value = page.cut.MinValues()[split_ind];
     RegTree tree;
-    ApproxRowPartitioner partitioner{n_samples, base_rowid};
+    CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid};
     GetSplit(&tree, min_value, &candidates);
     partitioner.UpdatePosition(&ctx, page, candidates, &tree);
     ASSERT_EQ(partitioner.Size(), 3);
@@ -40,7 +40,7 @@ TEST(Approx, Partitioner) {
     ASSERT_EQ(partitioner[2].Size(), n_samples);
   }
   {
-    ApproxRowPartitioner partitioner{n_samples, base_rowid};
+    CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid};
     auto ptr = page.cut.Ptrs()[split_ind + 1];
     float split_value = page.cut.Values().at(ptr / 2);
     RegTree tree;
@@ -65,14 +65,15 @@ TEST(Approx, Partitioner) {
     }
   }
 }

 namespace {
 void TestLeafPartition(size_t n_samples) {
   size_t const n_features = 2, base_rowid = 0;
+  GenericParameter ctx;
   common::RowSetCollection row_set;
-  ApproxRowPartitioner partitioner{n_samples, base_rowid};
+  CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid};

   auto Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true);
-  GenericParameter ctx;
   std::vector<CPUExpandEntry> candidates{{0, 0, 0.4}};
   RegTree tree;
   std::vector<float> hess(n_samples, 0);
@@ -81,11 +82,9 @@ void TestLeafPartition(size_t n_samples) {
     size_t const kSampleFactor{3};
     return i % kSampleFactor != 0;
   };
-  size_t n{0};
   for (size_t i = 0; i < hess.size(); ++i) {
     if (not_sampled(i)) {
       hess[i] = 1.0f;
-      ++n;
     }
   }

@@ -12,8 +12,7 @@ TEST(GrowHistMaker, InteractionConstraint) {
   size_t constexpr kRows = 32;
   size_t constexpr kCols = 16;

-  GenericParameter param;
-  param.UpdateAllowUnknown(Args{{"gpu_id", "0"}});
+  Context ctx;

   auto p_dmat = RandomDataGenerator{kRows, kCols, 0.6f}.Seed(3).GenerateDMatrix();

@@ -35,7 +34,7 @@ TEST(GrowHistMaker, InteractionConstraint) {
     tree.param.num_feature = kCols;

     std::unique_ptr<TreeUpdater> updater{
-        TreeUpdater::Create("grow_histmaker", &param, ObjInfo{ObjInfo::kRegression})};
+        TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})};
     updater->Configure(Args{
         {"interaction_constraints", "[[0, 1]]"},
         {"num_feature", std::to_string(kCols)}});
@@ -54,7 +53,7 @@ TEST(GrowHistMaker, InteractionConstraint) {
     tree.param.num_feature = kCols;

     std::unique_ptr<TreeUpdater> updater{
-        TreeUpdater::Create("grow_histmaker", &param, ObjInfo{ObjInfo::kRegression})};
+        TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})};
     updater->Configure(Args{{"num_feature", std::to_string(kCols)}});
     std::vector<HostDeviceVector<bst_node_t>> position(1);
     updater->Update(&gradients, p_dmat.get(), position, {&tree});
@@ -11,7 +11,7 @@

 #include "../../../src/tree/param.h"
 #include "../../../src/tree/split_evaluator.h"
-#include "../../../src/tree/updater_quantile_hist.h"
+#include "../../../src/tree/common_row_partitioner.h"
 #include "../helpers.h"
 #include "test_partitioner.h"
 #include "xgboost/data.h"
@@ -23,7 +23,7 @@ TEST(QuantileHist, Partitioner) {
   GenericParameter ctx;
   ctx.InitAllowUnknown(Args{});

-  HistRowPartitioner partitioner{n_samples, base_rowid, ctx.Threads()};
+  CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid};
   ASSERT_EQ(partitioner.base_rowid, base_rowid);
   ASSERT_EQ(partitioner.Size(), 1);
   ASSERT_EQ(partitioner.Partitions()[0].Size(), n_samples);
@@ -41,7 +41,7 @@ TEST(QuantileHist, Partitioner) {
   {
     auto min_value = gmat.cut.MinValues()[split_ind];
     RegTree tree;
-    HistRowPartitioner partitioner{n_samples, base_rowid, ctx.Threads()};
+    CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid};
     GetSplit(&tree, min_value, &candidates);
     partitioner.UpdatePosition<false, true>(&ctx, gmat, column_indices, candidates, &tree);
     ASSERT_EQ(partitioner.Size(), 3);
@@ -49,7 +49,7 @@ TEST(QuantileHist, Partitioner) {
     ASSERT_EQ(partitioner[2].Size(), n_samples);
   }
   {
-    HistRowPartitioner partitioner{n_samples, base_rowid, ctx.Threads()};
+    CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid};
     auto ptr = gmat.cut.Ptrs()[split_ind + 1];
     float split_value = gmat.cut.Values().at(ptr / 2);
     RegTree tree;
@@ -1,6 +1,6 @@
-#include <xgboost/tree_updater.h>
-#include <xgboost/tree_model.h>
 #include <gtest/gtest.h>
+#include <xgboost/tree_model.h>
+#include <xgboost/tree_updater.h>

 #include "../helpers.h"

@@ -21,9 +21,10 @@ class UpdaterTreeStatTest : public ::testing::Test {
   }

   void RunTest(std::string updater) {
-    auto tparam = CreateEmptyGenericParam(0);
+    Context ctx(updater == "grow_gpu_hist" ? CreateEmptyGenericParam(0)
+                                           : CreateEmptyGenericParam(Context::kCpuId));
     auto up = std::unique_ptr<TreeUpdater>{
-        TreeUpdater::Create(updater, &tparam, ObjInfo{ObjInfo::kRegression})};
+        TreeUpdater::Create(updater, &ctx, ObjInfo{ObjInfo::kRegression})};
     up->Configure(Args{});
     RegTree tree;
     tree.param.num_feature = kCols;
@@ -41,22 +42,14 @@ class UpdaterTreeStatTest : public ::testing::Test {
 };

 #if defined(XGBOOST_USE_CUDA)
-TEST_F(UpdaterTreeStatTest, GpuHist) {
-  this->RunTest("grow_gpu_hist");
-}
+TEST_F(UpdaterTreeStatTest, GpuHist) { this->RunTest("grow_gpu_hist"); }
 #endif  // defined(XGBOOST_USE_CUDA)

-TEST_F(UpdaterTreeStatTest, Hist) {
-  this->RunTest("grow_quantile_histmaker");
-}
+TEST_F(UpdaterTreeStatTest, Hist) { this->RunTest("grow_quantile_histmaker"); }

-TEST_F(UpdaterTreeStatTest, Exact) {
-  this->RunTest("grow_colmaker");
-}
+TEST_F(UpdaterTreeStatTest, Exact) { this->RunTest("grow_colmaker"); }

-TEST_F(UpdaterTreeStatTest, Approx) {
-  this->RunTest("grow_histmaker");
-}
+TEST_F(UpdaterTreeStatTest, Approx) { this->RunTest("grow_histmaker"); }

 class UpdaterEtaTest : public ::testing::Test {
  protected:
@@ -74,14 +67,15 @@ class UpdaterEtaTest : public ::testing::Test {
   }

   void RunTest(std::string updater) {
-    auto tparam = CreateEmptyGenericParam(0);
+    GenericParameter ctx(updater == "grow_gpu_hist" ? CreateEmptyGenericParam(0)
+                                                    : CreateEmptyGenericParam(Context::kCpuId));
     float eta = 0.4;
     auto up_0 = std::unique_ptr<TreeUpdater>{
-        TreeUpdater::Create(updater, &tparam, ObjInfo{ObjInfo::kClassification})};
+        TreeUpdater::Create(updater, &ctx, ObjInfo{ObjInfo::kClassification})};
     up_0->Configure(Args{{"eta", std::to_string(eta)}});

     auto up_1 = std::unique_ptr<TreeUpdater>{
-        TreeUpdater::Create(updater, &tparam, ObjInfo{ObjInfo::kClassification})};
+        TreeUpdater::Create(updater, &ctx, ObjInfo{ObjInfo::kClassification})};
     up_1->Configure(Args{{"eta", "1.0"}});

     for (size_t iter = 0; iter < 4; ++iter) {
@@ -130,7 +124,7 @@ class TestMinSplitLoss : public ::testing::Test {
     gpair_ = GenerateRandomGradients(kRows);
   }

-  int32_t Update(std::string updater, float gamma) {
+  std::int32_t Update(std::string updater, float gamma) {
     Args args{{"max_depth", "1"},
               {"max_leaves", "0"},

@@ -146,9 +140,12 @@ class TestMinSplitLoss : public ::testing::Test {
               // test gamma
               {"gamma", std::to_string(gamma)}};

-    GenericParameter generic_param(CreateEmptyGenericParam(0));
+    std::cout << "updater:" << updater << std::endl;
+    GenericParameter ctx(updater == "grow_gpu_hist" ? CreateEmptyGenericParam(0)
+                                                    : CreateEmptyGenericParam(Context::kCpuId));
+    std::cout << ctx.gpu_id << std::endl;
     auto up = std::unique_ptr<TreeUpdater>{
-        TreeUpdater::Create(updater, &generic_param, ObjInfo{ObjInfo::kRegression})};
+        TreeUpdater::Create(updater, &ctx, ObjInfo{ObjInfo::kRegression})};
     up->Configure(args);

     RegTree tree;