Bound the size of the histogram cache. (#9440)
- Add a new histogram collection with a size limit.
- Unify histogram building logic between hist, multi-hist, and approx.
parent 5bd163aa25
commit 54029a59af
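As a rough, non-authoritative illustration of the bound this commit introduces (described in the new BoundedHistCollection documentation further down): the number of cached node histograms never grows past max(|node_batch|, internal_max_cached_hist_node). A toy C++ sketch of that bookkeeping, with made-up names:

#include <cstddef>

// Toy model of the cache-size rule: if a new batch of nodes still fits under
// the configured limit it is appended, otherwise the cache is cleared and only
// the new batch is kept, so the size stays within max(batch_size, limit).
std::size_t CachedNodesAfterBatch(std::size_t currently_cached, std::size_t batch_size,
                                  std::size_t configured_limit) {
  if (currently_cached + batch_size <= configured_limit) {
    return currently_cached + batch_size;  // batch fits next to what is already cached
  }
  return batch_size;  // cache cleared, only the new batch remains
}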
@@ -69,6 +69,7 @@ OBJECTS= \
 $(PKGROOT)/src/tree/updater_refresh.o \
 $(PKGROOT)/src/tree/updater_sync.o \
 $(PKGROOT)/src/tree/hist/param.o \
+$(PKGROOT)/src/tree/hist/histogram.o \
 $(PKGROOT)/src/linear/linear_updater.o \
 $(PKGROOT)/src/linear/updater_coordinate.o \
 $(PKGROOT)/src/linear/updater_shotgun.o \
@@ -91,8 +91,6 @@ namespace xgboost {
 
 /*! \brief unsigned integer type used for feature index. */
 using bst_uint = uint32_t;  // NOLINT
-/*! \brief integer type. */
-using bst_int = int32_t;  // NOLINT
 /*! \brief unsigned long integers */
 using bst_ulong = uint64_t;  // NOLINT
 /*! \brief float type, used for storing statistics */
@@ -138,9 +136,9 @@ namespace detail {
 template <typename T>
 class GradientPairInternal {
 /*! \brief gradient statistics */
-T grad_;
+T grad_{0};
 /*! \brief second order gradient statistics */
-T hess_;
+T hess_{0};
 
 XGBOOST_DEVICE void SetGrad(T g) { grad_ = g; }
 XGBOOST_DEVICE void SetHess(T h) { hess_ = h; }
@@ -157,7 +155,7 @@ class GradientPairInternal {
 a += b;
 }
 
-XGBOOST_DEVICE GradientPairInternal() : grad_(0), hess_(0) {}
+GradientPairInternal() = default;
 
 XGBOOST_DEVICE GradientPairInternal(T grad, T hess) {
 SetGrad(grad);
python-package/xgboost/testing/data_iter.py  (new file, 34 lines)
@@ -0,0 +1,34 @@
+"""Tests related to the `DataIter` interface."""
+import numpy as np
+
+import xgboost
+from xgboost import testing as tm
+
+
+def run_mixed_sparsity(device: str) -> None:
+    """Check QDM with mixed batches."""
+    X_0, y_0, _ = tm.make_regression(128, 16, False)
+    if device.startswith("cuda"):
+        X_1, y_1 = tm.make_sparse_regression(256, 16, 0.1, True)
+    else:
+        X_1, y_1 = tm.make_sparse_regression(256, 16, 0.1, False)
+    X_2, y_2 = tm.make_sparse_regression(512, 16, 0.9, True)
+    X = [X_0, X_1, X_2]
+    y = [y_0, y_1, y_2]
+
+    if device.startswith("cuda"):
+        import cupy as cp  # pylint: disable=import-error
+
+        X = [cp.array(batch) for batch in X]
+
+    it = tm.IteratorForTest(X, y, None, None)
+    Xy_0 = xgboost.QuantileDMatrix(it)
+
+    X_1, y_1 = tm.make_sparse_regression(256, 16, 0.1, True)
+    X = [X_0, X_1, X_2]
+    y = [y_0, y_1, y_2]
+    X_arr = np.concatenate(X, axis=0)
+    y_arr = np.concatenate(y, axis=0)
+    Xy_1 = xgboost.QuantileDMatrix(X_arr, y_arr)
+
+    assert tm.predictor_equal(Xy_0, Xy_1)
@@ -41,6 +41,10 @@ hist_parameter_strategy = strategies.fixed_dictionaries(
 and (cast(int, x["max_depth"]) > 0 or x["grow_policy"] == "lossguide")
 )
 
+hist_cache_strategy = strategies.fixed_dictionaries(
+    {"internal_max_cached_hist_node": strategies.sampled_from([1, 4, 1024, 2**31])}
+)
+
 hist_multi_parameter_strategy = strategies.fixed_dictionaries(
 {
 "max_depth": strategies.integers(1, 11),
@@ -67,17 +67,6 @@ HistogramCuts SketchOnDMatrix(Context const *ctx, DMatrix *m, bst_bin_t max_bins
 return out;
 }
 
-/*!
- * \brief fill a histogram by zeros in range [begin, end)
- */
-void InitilizeHistByZeroes(GHistRow hist, size_t begin, size_t end) {
-#if defined(XGBOOST_STRICT_R_MODE) && XGBOOST_STRICT_R_MODE == 1
-std::fill(hist.begin() + begin, hist.begin() + end, xgboost::GradientPairPrecise());
-#else // defined(XGBOOST_STRICT_R_MODE) && XGBOOST_STRICT_R_MODE == 1
-memset(hist.data() + begin, '\0', (end - begin) * sizeof(xgboost::GradientPairPrecise));
-#endif // defined(XGBOOST_STRICT_R_MODE) && XGBOOST_STRICT_R_MODE == 1
-}
-
 /*!
  * \brief Increment hist as dst += add in range [begin, end)
  */
@@ -364,11 +364,6 @@ bst_bin_t XGBOOST_HOST_DEV_INLINE BinarySearchBin(std::size_t begin, std::size_t
 using GHistRow = Span<xgboost::GradientPairPrecise>;
 using ConstGHistRow = Span<xgboost::GradientPairPrecise const>;
 
-/*!
- * \brief fill a histogram by zeros
- */
-void InitilizeHistByZeroes(GHistRow hist, size_t begin, size_t end);
-
 /*!
  * \brief Increment hist as dst += add in range [begin, end)
  */
@@ -395,12 +390,7 @@ class HistCollection {
 constexpr uint32_t kMax = std::numeric_limits<uint32_t>::max();
 const size_t id = row_ptr_.at(nid);
 CHECK_NE(id, kMax);
-GradientPairPrecise* ptr = nullptr;
-if (contiguous_allocation_) {
-ptr = const_cast<GradientPairPrecise*>(data_[0].data() + nbins_*id);
-} else {
-ptr = const_cast<GradientPairPrecise*>(data_[id].data());
-}
+GradientPairPrecise* ptr = const_cast<GradientPairPrecise*>(data_[id].data());
 return {ptr, nbins_};
 }
 
@@ -445,24 +435,12 @@ class HistCollection {
 data_[row_ptr_[nid]].resize(nbins_, {0, 0});
 }
 }
-// allocate common buffer contiguously for all nodes, need for single Allreduce call
-void AllocateAllData() {
-const size_t new_size = nbins_*data_.size();
-contiguous_allocation_ = true;
-if (data_[0].size() != new_size) {
-data_[0].resize(new_size);
-}
-}
-[[nodiscard]] bool IsContiguous() const { return contiguous_allocation_; }
 
 private:
 /*! \brief number of all bins over all features */
 uint32_t nbins_ = 0;
 /*! \brief amount of active nodes in hist collection */
 uint32_t n_nodes_added_ = 0;
-/*! \brief flag to identify contiguous memory allocation */
-bool contiguous_allocation_ = false;
-
 std::vector<std::vector<GradientPairPrecise>> data_;
 
 /*! \brief row_ptr_[nid] locates bin for histogram of node nid */
@@ -518,7 +496,7 @@ class ParallelGHistBuilder {
 GHistRow hist = idx == -1 ? targeted_hists_[nid] : hist_buffer_[idx];
 
 if (!hist_was_used_[tid * nodes_ + nid]) {
-InitilizeHistByZeroes(hist, 0, hist.size());
+std::fill_n(hist.data(), hist.size(), GradientPairPrecise{});
 hist_was_used_[tid * nodes_ + nid] = static_cast<int>(true);
 }
 
@@ -548,7 +526,7 @@ class ParallelGHistBuilder {
 if (!is_updated) {
 // In distributed mode - some tree nodes can be empty on local machines,
 // So we need just set local hist by zeros in this case
-InitilizeHistByZeroes(dst, begin, end);
+std::fill(dst.data() + begin, dst.data() + end, GradientPairPrecise{});
 }
 }
 
@@ -7,13 +7,14 @@
 #include <dmlc/common.h>
 #include <dmlc/omp.h>
 
-#include <algorithm>
-#include <cstdint> // for int32_t
-#include <cstdlib> // for malloc, free
-#include <limits>
+#include <algorithm> // for min
+#include <cstddef> // for size_t
+#include <cstdint> // for int32_t
+#include <cstdlib> // for malloc, free
+#include <functional> // for function
 #include <new> // for bad_alloc
-#include <type_traits> // for is_signed
-#include <vector>
+#include <type_traits> // for is_signed, conditional_t
+#include <vector> // for vector
 
 #include "xgboost/logging.h"
 
@@ -25,6 +26,8 @@ inline int32_t omp_get_thread_limit() __GOMP_NOTHROW { return 1; }  // NOLINT
 
 // MSVC doesn't implement the thread limit.
 #if defined(_OPENMP) && defined(_MSC_VER)
+#include <limits>
+
 extern "C" {
 inline int32_t omp_get_thread_limit() { return std::numeric_limits<int32_t>::max(); } // NOLINT
 }
@@ -84,8 +87,8 @@ class BlockedSpace2d {
 // dim1 - size of the first dimension in the space
 // getter_size_dim2 - functor to get the second dimensions for each 'row' by row-index
 // grain_size - max size of produced blocks
-template <typename Func>
-BlockedSpace2d(std::size_t dim1, Func getter_size_dim2, std::size_t grain_size) {
+BlockedSpace2d(std::size_t dim1, std::function<std::size_t(std::size_t)> getter_size_dim2,
+               std::size_t grain_size) {
 for (std::size_t i = 0; i < dim1; ++i) {
 std::size_t size = getter_size_dim2(i);
 // Each row (second dim) is divided into n_blocks
@@ -104,13 +107,13 @@ class BlockedSpace2d {
 }
 
 // get index of the first dimension of i-th block(task)
-[[nodiscard]] std::size_t GetFirstDimension(size_t i) const {
+[[nodiscard]] std::size_t GetFirstDimension(std::size_t i) const {
 CHECK_LT(i, first_dimension_.size());
 return first_dimension_[i];
 }
 
 // get a range of indexes for the second dimension of i-th block(task)
-[[nodiscard]] Range1d GetRange(size_t i) const {
+[[nodiscard]] Range1d GetRange(std::size_t i) const {
 CHECK_LT(i, ranges_.size());
 return ranges_[i];
 }
@@ -129,22 +132,22 @@ class BlockedSpace2d {
 }
 
 std::vector<Range1d> ranges_;
-std::vector<size_t> first_dimension_;
+std::vector<std::size_t> first_dimension_;
 };
 
 
 // Wrapper to implement nested parallelism with simple omp parallel for
-template <typename Func>
-void ParallelFor2d(const BlockedSpace2d& space, int nthreads, Func func) {
+inline void ParallelFor2d(BlockedSpace2d const& space, std::int32_t n_threads,
+                          std::function<void(std::size_t, Range1d)> func) {
 std::size_t n_blocks_in_space = space.Size();
-CHECK_GE(nthreads, 1);
+CHECK_GE(n_threads, 1);
 
 dmlc::OMPException exc;
-#pragma omp parallel num_threads(nthreads)
+#pragma omp parallel num_threads(n_threads)
 {
 exc.Run([&]() {
-size_t tid = omp_get_thread_num();
-size_t chunck_size = n_blocks_in_space / nthreads + !!(n_blocks_in_space % nthreads);
+std::size_t tid = omp_get_thread_num();
+std::size_t chunck_size = n_blocks_in_space / n_threads + !!(n_blocks_in_space % n_threads);
 
 std::size_t begin = chunck_size * tid;
 std::size_t end = std::min(begin + chunck_size, n_blocks_in_space);
@@ -477,7 +477,6 @@ class CSCArrayAdapterBatch : public detail::NoMetaInfo {
 ArrayInterface<1> indptr_;
 ArrayInterface<1> indices_;
 ArrayInterface<1> values_;
-bst_row_t n_rows_;
 
 class Line {
 std::size_t column_idx_;
@@ -503,11 +502,8 @@ class CSCArrayAdapterBatch : public detail::NoMetaInfo {
 static constexpr bool kIsRowMajor = false;
 
 CSCArrayAdapterBatch(ArrayInterface<1> indptr, ArrayInterface<1> indices,
-                     ArrayInterface<1> values, bst_row_t n_rows)
-    : indptr_{std::move(indptr)},
-      indices_{std::move(indices)},
-      values_{std::move(values)},
-      n_rows_{n_rows} {}
+                     ArrayInterface<1> values)
+    : indptr_{std::move(indptr)}, indices_{std::move(indices)}, values_{std::move(values)} {}
 
 std::size_t Size() const { return indptr_.n - 1; }
 Line GetLine(std::size_t idx) const {
@@ -542,8 +538,7 @@ class CSCArrayAdapter : public detail::SingleBatchDataIter<CSCArrayAdapterBatch>
 indices_{indices},
 values_{values},
 num_rows_{num_rows},
-batch_{
-    CSCArrayAdapterBatch{indptr_, indices_, values_, static_cast<bst_row_t>(num_rows_)}} {}
+batch_{CSCArrayAdapterBatch{indptr_, indices_, values_}} {}
 
 // JVM package sends 0 as unknown
 size_t NumRows() const { return num_rows_ == 0 ? kAdapterUnknownSize : num_rows_; }
@@ -4,13 +4,13 @@
 #ifndef XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_
 #define XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_
 
 #include <algorithm> // for copy
 #include <cstddef> // for size_t
 #include <limits> // for numeric_limits
 #include <memory> // for shared_ptr
 #include <numeric> // for accumulate
 #include <utility> // for move
 #include <vector> // for vector
 
 #include "../../common/categorical.h" // for CatBitField
 #include "../../common/hist_util.h" // for GHistRow, HistogramCuts
@@ -20,6 +20,7 @@
 #include "../param.h" // for TrainParam
 #include "../split_evaluator.h" // for TreeEvaluator
 #include "expand_entry.h" // for MultiExpandEntry
+#include "hist_cache.h" // for BoundedHistCollection
 #include "xgboost/base.h" // for bst_node_t, bst_target_t, bst_feature_t
 #include "xgboost/context.h" // for COntext
 #include "xgboost/linalg.h" // for Constants, Vector
@@ -317,7 +318,7 @@ class HistEvaluator {
 }
 
 public:
-void EvaluateSplits(const common::HistCollection &hist, common::HistogramCuts const &cut,
+void EvaluateSplits(const BoundedHistCollection &hist, common::HistogramCuts const &cut,
                     common::Span<FeatureType const> feature_types, const RegTree &tree,
                     std::vector<CPUExpandEntry> *p_entries) {
 auto n_threads = ctx_->Threads();
@@ -623,7 +624,7 @@ class HistMultiEvaluator {
 }
 
 public:
-void EvaluateSplits(RegTree const &tree, common::Span<const common::HistCollection *> hist,
+void EvaluateSplits(RegTree const &tree, common::Span<const BoundedHistCollection *> hist,
                     common::HistogramCuts const &cut, std::vector<MultiExpandEntry> *p_entries) {
 auto &entries = *p_entries;
 std::vector<std::shared_ptr<HostDeviceVector<bst_feature_t>>> features(entries.size());
@@ -18,8 +18,8 @@ namespace xgboost::tree {
 */
 template <typename Impl>
 struct ExpandEntryImpl {
-bst_node_t nid;
-bst_node_t depth;
+bst_node_t nid{0};
+bst_node_t depth{0};
 
 [[nodiscard]] float GetLossChange() const {
 return static_cast<Impl const*>(this)->split.loss_chg;
src/tree/hist/hist_cache.h  (new file, 109 lines)
@@ -0,0 +1,109 @@
+/**
+ * Copyright 2023 by XGBoost Contributors
+ */
+#ifndef XGBOOST_TREE_HIST_HIST_CACHE_H_
+#define XGBOOST_TREE_HIST_HIST_CACHE_H_
+#include <cstddef> // for size_t
+#include <map> // for map
+#include <vector> // for vector
+
+#include "../../common/hist_util.h" // for GHistRow, ConstGHistRow
+#include "xgboost/base.h" // for bst_node_t, bst_bin_t
+#include "xgboost/logging.h" // for CHECK_GT
+#include "xgboost/span.h" // for Span
+
+namespace xgboost::tree {
+/**
+ * @brief A persistent cache for CPU histogram.
+ *
+ * The size of the cache is first bounded by the `Driver` class then by this cache
+ * implementaiton. The former limits the number of nodes that can be built for each node
+ * batch, while this cache limits the number of all nodes up to the size of
+ * max(|node_batch|, n_cached_node).
+ *
+ * The caller is responsible for clearing up the cache as it needs to rearrange the
+ * nodes before making overflowed allocations. The strcut only reports whether the size
+ * limit has benn reached.
+ */
+class BoundedHistCollection {
+  // maps node index to offset in `data_`.
+  std::map<bst_node_t, std::size_t> node_map_;
+  // currently allocated bins, used for tracking consistentcy.
+  std::size_t current_size_{0};
+
+  // stores the histograms in a contiguous buffer
+  std::vector<GradientPairPrecise> data_;
+
+  // number of histogram bins across all features
+  bst_bin_t n_total_bins_{0};
+  // limits the number of nodes that can be in the cache for each tree
+  std::size_t n_cached_nodes_{0};
+  // whether the tree has grown beyond the cache limit
+  bool has_exceeded_{false};
+
+ public:
+  common::GHistRow operator[](std::size_t idx) {
+    auto offset = node_map_.at(idx);
+    return common::Span{data_.data(), data_.size()}.subspan(offset, n_total_bins_);
+  }
+  common::ConstGHistRow operator[](std::size_t idx) const {
+    auto offset = node_map_.at(idx);
+    return common::Span{data_.data(), data_.size()}.subspan(offset, n_total_bins_);
+  }
+  void Reset(bst_bin_t n_total_bins, std::size_t n_cached_nodes) {
+    n_total_bins_ = n_total_bins;
+    n_cached_nodes_ = n_cached_nodes;
+    this->Clear(false);
+  }
+  /**
+   * @brief Clear the cache, mark whether the cache is exceeded the limit.
+   */
+  void Clear(bool exceeded) {
+    node_map_.clear();
+    current_size_ = 0;
+    has_exceeded_ = exceeded;
+  }
+
+  [[nodiscard]] bool CanHost(common::Span<bst_node_t const> nodes_to_build,
+                             common::Span<bst_node_t const> nodes_to_sub) const {
+    auto n_new_nodes = nodes_to_build.size() + nodes_to_sub.size();
+    return n_new_nodes + node_map_.size() <= n_cached_nodes_;
+  }
+
+  /**
+   * @brief Allocate histogram buffers for all nodes.
+   *
+   * The resulting histogram buffer is contiguous for all nodes in the order of
+   * allocation.
+   */
+  void AllocateHistograms(common::Span<bst_node_t const> nodes_to_build,
+                          common::Span<bst_node_t const> nodes_to_sub) {
+    auto n_new_nodes = nodes_to_build.size() + nodes_to_sub.size();
+    auto alloc_size = n_new_nodes * n_total_bins_;
+    auto new_size = alloc_size + current_size_;
+    if (new_size > data_.size()) {
+      data_.resize(new_size);
+    }
+    for (auto nidx : nodes_to_build) {
+      node_map_[nidx] = current_size_;
+      current_size_ += n_total_bins_;
+    }
+    for (auto nidx : nodes_to_sub) {
+      node_map_[nidx] = current_size_;
+      current_size_ += n_total_bins_;
+    }
+    CHECK_EQ(current_size_, new_size);
+  }
+  void AllocateHistograms(std::vector<bst_node_t> const& nodes) {
+    this->AllocateHistograms(common::Span<bst_node_t const>{nodes},
+                             common::Span<bst_node_t const>{});
+  }
+
+  [[nodiscard]] bool HasExceeded() const { return has_exceeded_; }
+  [[nodiscard]] bool HistogramExists(bst_node_t nidx) const {
+    return node_map_.find(nidx) != node_map_.cend();
+  }
+  [[nodiscard]] std::size_t Size() const { return current_size_; }
+};
+} // namespace xgboost::tree
+#endif // XGBOOST_TREE_HIST_HIST_CACHE_H_
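A minimal usage sketch of the collection added above (not part of the patch; the bin count, node ids, and include path are made up for illustration):

#include <vector>

#include "hist_cache.h"  // hypothetical include path; adjust to where the header lives

void SketchBoundedCacheUsage() {
  using xgboost::bst_node_t;
  xgboost::tree::BoundedHistCollection cache;
  // 256 bins per histogram, at most 4 cached nodes (illustrative numbers).
  cache.Reset(/*n_total_bins=*/256, /*n_cached_nodes=*/4);

  std::vector<bst_node_t> to_build{1, 2};  // children built from scratch
  std::vector<bst_node_t> to_sub{3, 4};    // children derived via subtraction
  xgboost::common::Span<bst_node_t const> build{to_build.data(), to_build.size()};
  xgboost::common::Span<bst_node_t const> sub{to_sub.data(), to_sub.size()};

  if (!cache.CanHost(build, sub)) {
    // Tree outgrew the cache: drop old histograms and remember the overflow so
    // callers know parent histograms may be missing from now on.
    cache.Clear(/*exceeded=*/true);
  }
  cache.AllocateHistograms(build, sub);
  auto hist = cache[1];  // span of 256 GradientPairPrecise for node 1
  (void)hist;
}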
src/tree/hist/histogram.cc  (new file, 63 lines)
@@ -0,0 +1,63 @@
+/**
+ * Copyright 2023 by XGBoost Contributors
+ */
+#include "histogram.h"
+
+#include <cstddef> // for size_t
+#include <numeric> // for accumulate
+#include <utility> // for swap
+#include <vector> // for vector
+
+#include "../../common/transform_iterator.h" // for MakeIndexTransformIter
+#include "expand_entry.h" // for MultiExpandEntry, CPUExpandEntry
+#include "xgboost/logging.h" // for CHECK_NE
+#include "xgboost/span.h" // for Span
+#include "xgboost/tree_model.h" // for RegTree
+
+namespace xgboost::tree {
+void AssignNodes(RegTree const *p_tree, std::vector<MultiExpandEntry> const &valid_candidates,
+                 common::Span<bst_node_t> nodes_to_build, common::Span<bst_node_t> nodes_to_sub) {
+  CHECK_EQ(nodes_to_build.size(), valid_candidates.size());
+
+  std::size_t n_idx = 0;
+  for (auto const &c : valid_candidates) {
+    auto left_nidx = p_tree->LeftChild(c.nid);
+    auto right_nidx = p_tree->RightChild(c.nid);
+
+    auto build_nidx = left_nidx;
+    auto subtract_nidx = right_nidx;
+    auto lit =
+        common::MakeIndexTransformIter([&](auto i) { return c.split.left_sum[i].GetHess(); });
+    auto left_sum = std::accumulate(lit, lit + c.split.left_sum.size(), .0);
+    auto rit =
+        common::MakeIndexTransformIter([&](auto i) { return c.split.right_sum[i].GetHess(); });
+    auto right_sum = std::accumulate(rit, rit + c.split.right_sum.size(), .0);
+    auto fewer_right = right_sum < left_sum;
+    if (fewer_right) {
+      std::swap(build_nidx, subtract_nidx);
+    }
+    nodes_to_build[n_idx] = build_nidx;
+    nodes_to_sub[n_idx] = subtract_nidx;
+    ++n_idx;
+  }
+}
+
+void AssignNodes(RegTree const *p_tree, std::vector<CPUExpandEntry> const &candidates,
+                 common::Span<bst_node_t> nodes_to_build, common::Span<bst_node_t> nodes_to_sub) {
+  std::size_t n_idx = 0;
+  for (auto const &c : candidates) {
+    auto left_nidx = (*p_tree)[c.nid].LeftChild();
+    auto right_nidx = (*p_tree)[c.nid].RightChild();
+    auto fewer_right = c.split.right_sum.GetHess() < c.split.left_sum.GetHess();
+
+    auto build_nidx = left_nidx;
+    auto subtract_nidx = right_nidx;
+    if (fewer_right) {
+      std::swap(build_nidx, subtract_nidx);
+    }
+    nodes_to_build[n_idx] = build_nidx;
+    nodes_to_sub[n_idx] = subtract_nidx;
+    ++n_idx;
+  }
+}
+} // namespace xgboost::tree
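For context, the reason AssignNodes picks the child with the smaller hessian sum as the build node is the histogram subtraction trick: a parent's histogram equals the sum of its two children's, so the sibling can be recovered by subtraction. A self-contained sketch with plain doubles standing in for GradientPairPrecise (not the project's actual SubtractionHist):

#include <cassert>
#include <cstddef>
#include <vector>

// Recover the sibling histogram as parent - built_child; building the cheaper
// child explicitly and deriving the other is what AssignNodes sets up.
std::vector<double> SubtractSibling(std::vector<double> const& parent,
                                    std::vector<double> const& built_child) {
  assert(parent.size() == built_child.size());
  std::vector<double> sibling(parent.size());
  for (std::size_t i = 0; i < parent.size(); ++i) {
    sibling[i] = parent[i] - built_child[i];
  }
  return sibling;
}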
@@ -4,80 +4,85 @@
 #ifndef XGBOOST_TREE_HIST_HISTOGRAM_H_
 #define XGBOOST_TREE_HIST_HISTOGRAM_H_
 
-#include <algorithm>
-#include <limits>
-#include <vector>
+#include <algorithm> // for max
+#include <cstddef> // for size_t
+#include <cstdint> // for int32_t
+#include <functional> // for function
+#include <utility> // for move
+#include <vector> // for vector
 
-#include "../../collective/communicator-inl.h"
-#include "../../common/hist_util.h"
-#include "../../data/gradient_index.h"
-#include "expand_entry.h"
-#include "xgboost/tree_model.h" // for RegTree
+#include "../../collective/communicator-inl.h" // for Allreduce
+#include "../../collective/communicator.h" // for Operation
+#include "../../common/hist_util.h" // for GHistRow, ParallelGHi...
+#include "../../common/row_set.h" // for RowSetCollection
+#include "../../common/threading_utils.h" // for ParallelFor2d, Range1d, BlockedSpace2d
+#include "../../data/gradient_index.h" // for GHistIndexMatrix
+#include "expand_entry.h" // for MultiExpandEntry, CPUExpandEntry
+#include "hist_cache.h" // for BoundedHistCollection
+#include "param.h" // for HistMakerTrainParam
+#include "xgboost/base.h" // for bst_node_t, bst_target_t, bst_bin_t
+#include "xgboost/context.h" // for Context
+#include "xgboost/data.h" // for BatchIterator, BatchSet
+#include "xgboost/linalg.h" // for MatrixView, All, Vect...
+#include "xgboost/logging.h" // for CHECK_GE
+#include "xgboost/span.h" // for Span
+#include "xgboost/tree_model.h" // for RegTree
 
 namespace xgboost::tree {
-template <typename ExpandEntry>
+/**
+ * @brief Decide which node as the build node for multi-target trees.
+ */
+void AssignNodes(RegTree const *p_tree, std::vector<MultiExpandEntry> const &valid_candidates,
+                 common::Span<bst_node_t> nodes_to_build, common::Span<bst_node_t> nodes_to_sub);
+
+/**
+ * @brief Decide which node as the build node.
+ */
+void AssignNodes(RegTree const *p_tree, std::vector<CPUExpandEntry> const &candidates,
+                 common::Span<bst_node_t> nodes_to_build, common::Span<bst_node_t> nodes_to_sub);
+
 class HistogramBuilder {
 /*! \brief culmulative histogram of gradients. */
-common::HistCollection hist_;
+BoundedHistCollection hist_;
 common::ParallelGHistBuilder buffer_;
 BatchParam param_;
 int32_t n_threads_{-1};
-size_t n_batches_{0};
 // Whether XGBoost is running in distributed environment.
 bool is_distributed_{false};
 bool is_col_split_{false};
 
 public:
 /**
- * \param total_bins Total number of bins across all features
- * \param max_bin_per_feat Maximum number of bins per feature, same as the `max_bin`
- * training parameter.
- * \param n_threads Number of threads.
- * \param is_distributed Mostly used for testing to allow injecting parameters instead
+ * @brief Reset the builder, should be called before growing a new tree.
+ *
+ * @param total_bins Total number of bins across all features
+ * @param is_distributed Mostly used for testing to allow injecting parameters instead
 * of using global rabit variable.
 */
-void Reset(uint32_t total_bins, BatchParam p, int32_t n_threads, size_t n_batches,
-           bool is_distributed, bool is_col_split) {
-CHECK_GE(n_threads, 1);
-n_threads_ = n_threads;
-n_batches_ = n_batches;
+void Reset(Context const *ctx, bst_bin_t total_bins, BatchParam const &p, bool is_distributed,
+           bool is_col_split, HistMakerTrainParam const *param) {
+n_threads_ = ctx->Threads();
 param_ = p;
-hist_.Init(total_bins);
+hist_.Reset(total_bins, param->internal_max_cached_hist_node);
 buffer_.Init(total_bins);
 is_distributed_ = is_distributed;
 is_col_split_ = is_col_split;
 }
 
 template <bool any_missing>
-void BuildLocalHistograms(size_t page_idx, common::BlockedSpace2d space,
-                          GHistIndexMatrix const &gidx,
-                          std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
+void BuildLocalHistograms(common::BlockedSpace2d const &space, GHistIndexMatrix const &gidx,
+                          std::vector<bst_node_t> const &nodes_to_build,
                           common::RowSetCollection const &row_set_collection,
                           common::Span<GradientPair const> gpair_h, bool force_read_by_column) {
-const size_t n_nodes = nodes_for_explicit_hist_build.size();
-CHECK_GT(n_nodes, 0);
-
-std::vector<common::GHistRow> target_hists(n_nodes);
-for (size_t i = 0; i < n_nodes; ++i) {
-auto const nidx = nodes_for_explicit_hist_build[i].nid;
-target_hists[i] = hist_[nidx];
-}
-if (page_idx == 0) {
-// FIXME(jiamingy): Handle different size of space. Right now we use the maximum
-// partition size for the buffer, which might not be efficient if partition sizes
-// has significant variance.
-buffer_.Reset(this->n_threads_, n_nodes, space, target_hists);
-}
-
 // Parallel processing by nodes and data in each node
 common::ParallelFor2d(space, this->n_threads_, [&](size_t nid_in_set, common::Range1d r) {
 const auto tid = static_cast<unsigned>(omp_get_thread_num());
-const int32_t nid = nodes_for_explicit_hist_build[nid_in_set].nid;
-auto elem = row_set_collection[nid];
+bst_node_t const nidx = nodes_to_build[nid_in_set];
+auto elem = row_set_collection[nidx];
 auto start_of_row_set = std::min(r.begin(), elem.Size());
 auto end_of_row_set = std::min(r.end(), elem.Size());
 auto rid_set = common::RowSetCollection::Elem(elem.begin + start_of_row_set,
-                                              elem.begin + end_of_row_set, nid);
+                                              elem.begin + end_of_row_set, nidx);
 auto hist = buffer_.GetInitializedHist(tid, nid_in_set);
 if (rid_set.Size() != 0) {
 common::BuildHist<any_missing>(gpair_h, rid_set, gidx, hist, force_read_by_column);
@@ -85,117 +90,143 @@ class HistogramBuilder {
 });
 }
 
-void AddHistRows(int *starting_index,
-                 std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
-                 std::vector<ExpandEntry> const &nodes_for_subtraction_trick) {
-for (auto const &entry : nodes_for_explicit_hist_build) {
-int nid = entry.nid;
-this->hist_.AddHistRow(nid);
-(*starting_index) = std::min(nid, (*starting_index));
+/**
+ * @brief Allocate histogram, rearrange the nodes if `rearrange` is true and the tree
+ * has reached the cache size limit.
+ */
+void AddHistRows(RegTree const *p_tree, std::vector<bst_node_t> *p_nodes_to_build,
+                 std::vector<bst_node_t> *p_nodes_to_sub, bool rearrange) {
+CHECK(p_nodes_to_build);
+auto &nodes_to_build = *p_nodes_to_build;
+CHECK(p_nodes_to_sub);
+auto &nodes_to_sub = *p_nodes_to_sub;
+
+// We first check whether the cache size is already exceeded or about to be exceeded.
+// If not, then we can allocate histograms without clearing the cache and without
+// worrying about missing parent histogram.
+//
+// Otherwise, we need to rearrange the nodes before the allocation to make sure the
+// resulting buffer is contiguous. This is to facilitate efficient allreduce.
+
+bool can_host = this->hist_.CanHost(nodes_to_build, nodes_to_sub);
+// True if the tree is still within the size of cache limit. Allocate histogram as
+// usual.
+auto cache_is_valid = can_host && !this->hist_.HasExceeded();
+
+if (!can_host) {
+this->hist_.Clear(true);
 }
 
-for (auto const &node : nodes_for_subtraction_trick) {
-this->hist_.AddHistRow(node.nid);
-}
-this->hist_.AllocateAllData();
-}
-
-/** Main entry point of this class, build histogram for tree nodes. */
-void BuildHist(size_t page_id, common::BlockedSpace2d space, GHistIndexMatrix const &gidx,
-               RegTree const *p_tree, common::RowSetCollection const &row_set_collection,
-               std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
-               std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
-               common::Span<GradientPair const> gpair, bool force_read_by_column = false) {
-int starting_index = std::numeric_limits<int>::max();
-if (page_id == 0) {
-this->AddHistRows(&starting_index, nodes_for_explicit_hist_build,
-                  nodes_for_subtraction_trick);
-}
-if (gidx.IsDense()) {
-this->BuildLocalHistograms<false>(page_id, space, gidx, nodes_for_explicit_hist_build,
-                                  row_set_collection, gpair, force_read_by_column);
-} else {
-this->BuildLocalHistograms<true>(page_id, space, gidx, nodes_for_explicit_hist_build,
-                                 row_set_collection, gpair, force_read_by_column);
-}
-
-CHECK_GE(n_batches_, 1);
-if (page_id != n_batches_ - 1) {
+if (!rearrange || cache_is_valid) {
+// If not rearrange, we allocate the histogram as usual, assuming the nodes have
+// been properly arranged by other builders.
+this->hist_.AllocateHistograms(nodes_to_build, nodes_to_sub);
+if (rearrange) {
+CHECK(!this->hist_.HasExceeded());
+}
 return;
 }
 
-this->SyncHistogram(p_tree, nodes_for_explicit_hist_build,
-                    nodes_for_subtraction_trick, starting_index);
-}
-/** same as the other build hist but handles only single batch data (in-core) */
-void BuildHist(size_t page_id, GHistIndexMatrix const &gidx, RegTree *p_tree,
-               common::RowSetCollection const &row_set_collection,
-               std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
-               std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
-               common::Span<GradientPair const> gpair, bool force_read_by_column = false) {
-const size_t n_nodes = nodes_for_explicit_hist_build.size();
-// create space of size (# rows in each node)
-common::BlockedSpace2d space(
-    n_nodes,
-    [&](size_t nidx_in_set) {
-      const int32_t nidx = nodes_for_explicit_hist_build[nidx_in_set].nid;
-      return row_set_collection[nidx].Size();
-    },
-    256);
-this->BuildHist(page_id, space, gidx, p_tree, row_set_collection, nodes_for_explicit_hist_build,
-                nodes_for_subtraction_trick, gpair, force_read_by_column);
-}
-
-void SyncHistogram(RegTree const *p_tree,
-                   std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
-                   std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
-                   int starting_index) {
-auto n_bins = buffer_.TotalBins();
-common::BlockedSpace2d space(
-    nodes_for_explicit_hist_build.size(), [&](size_t) { return n_bins; }, 1024);
-CHECK(hist_.IsContiguous());
-common::ParallelFor2d(space, this->n_threads_, [&](size_t node, common::Range1d r) {
-const auto &entry = nodes_for_explicit_hist_build[node];
-auto this_hist = this->hist_[entry.nid];
-// Merging histograms from each thread into once
-this->buffer_.ReduceHist(node, r.begin(), r.end());
-});
-
-if (is_distributed_ && !is_col_split_) {
-collective::Allreduce<collective::Operation::kSum>(
-    reinterpret_cast<double *>(this->hist_[starting_index].data()),
-    n_bins * nodes_for_explicit_hist_build.size() * 2);
+// The cache is full, parent histogram might be removed in previous iterations to
+// saved memory.
+std::vector<bst_node_t> can_subtract;
+for (auto const &v : nodes_to_sub) {
+if (this->hist_.HistogramExists(p_tree->Parent(v))) {
+// We can still use the subtraction trick for this node
+can_subtract.push_back(v);
+} else {
+// This node requires a full build
+nodes_to_build.push_back(v);
+}
 }
 
-common::ParallelFor2d(space, this->n_threads_, [&](std::size_t nidx_in_set, common::Range1d r) {
-const auto &entry = nodes_for_explicit_hist_build[nidx_in_set];
-auto this_hist = this->hist_[entry.nid];
-if (!p_tree->IsRoot(entry.nid)) {
-auto const parent_id = p_tree->Parent(entry.nid);
-auto const subtraction_node_id = nodes_for_subtraction_trick[nidx_in_set].nid;
-auto parent_hist = this->hist_[parent_id];
-auto sibling_hist = this->hist_[subtraction_node_id];
-common::SubtractionHist(sibling_hist, parent_hist, this_hist, r.begin(), r.end());
+nodes_to_sub = std::move(can_subtract);
+this->hist_.AllocateHistograms(nodes_to_build, nodes_to_sub);
+}
+
+/** Main entry point of this class, build histogram for tree nodes. */
+void BuildHist(std::size_t page_idx, common::BlockedSpace2d const &space,
+               GHistIndexMatrix const &gidx, common::RowSetCollection const &row_set_collection,
+               std::vector<bst_node_t> const &nodes_to_build,
+               linalg::VectorView<GradientPair const> gpair, bool force_read_by_column = false) {
+CHECK(gpair.Contiguous());
+
+if (page_idx == 0) {
+// Add the local histogram cache to the parallel buffer before processing the first page.
+auto n_nodes = nodes_to_build.size();
+std::vector<common::GHistRow> target_hists(n_nodes);
+for (size_t i = 0; i < n_nodes; ++i) {
+auto const nidx = nodes_to_build[i];
+target_hists[i] = hist_[nidx];
 }
+buffer_.Reset(this->n_threads_, n_nodes, space, target_hists);
+}
+
+if (gidx.IsDense()) {
+this->BuildLocalHistograms<false>(space, gidx, nodes_to_build, row_set_collection,
+                                  gpair.Values(), force_read_by_column);
+} else {
+this->BuildLocalHistograms<true>(space, gidx, nodes_to_build, row_set_collection,
+                                 gpair.Values(), force_read_by_column);
+}
+}
+
+void SyncHistogram(RegTree const *p_tree, std::vector<bst_node_t> const &nodes_to_build,
+                   std::vector<bst_node_t> const &nodes_to_trick) {
+auto n_total_bins = buffer_.TotalBins();
+common::BlockedSpace2d space(
+    nodes_to_build.size(), [&](std::size_t) { return n_total_bins; }, 1024);
+common::ParallelFor2d(space, this->n_threads_, [&](size_t node, common::Range1d r) {
+// Merging histograms from each thread.
+this->buffer_.ReduceHist(node, r.begin(), r.end());
 });
+if (is_distributed_ && !is_col_split_) {
+// The cache is contiguous, we can perform allreduce for all nodes in one go.
+CHECK(!nodes_to_build.empty());
+auto first_nidx = nodes_to_build.front();
+std::size_t n = n_total_bins * nodes_to_build.size() * 2;
+collective::Allreduce<collective::Operation::kSum>(
+    reinterpret_cast<double *>(this->hist_[first_nidx].data()), n);
+}
+
+common::BlockedSpace2d const &subspace =
+    nodes_to_trick.size() == nodes_to_build.size()
+        ? space
+        : common::BlockedSpace2d{nodes_to_trick.size(),
+                                 [&](std::size_t) { return n_total_bins; }, 1024};
+common::ParallelFor2d(
+    subspace, this->n_threads_, [&](std::size_t nidx_in_set, common::Range1d r) {
+      auto subtraction_nidx = nodes_to_trick[nidx_in_set];
+      auto parent_id = p_tree->Parent(subtraction_nidx);
+      auto sibling_nidx = p_tree->IsLeftChild(subtraction_nidx) ? p_tree->RightChild(parent_id)
+                                                                : p_tree->LeftChild(parent_id);
+      auto sibling_hist = this->hist_[sibling_nidx];
+      auto parent_hist = this->hist_[parent_id];
+      auto subtract_hist = this->hist_[subtraction_nidx];
+      common::SubtractionHist(subtract_hist, parent_hist, sibling_hist, r.begin(), r.end());
+    });
 }
 
 public:
 /* Getters for tests. */
-common::HistCollection const &Histogram() { return hist_; }
+[[nodiscard]] BoundedHistCollection const &Histogram() const { return hist_; }
+[[nodiscard]] BoundedHistCollection &Histogram() { return hist_; }
 auto &Buffer() { return buffer_; }
 };
 
 // Construct a work space for building histogram. Eventually we should move this
 // function into histogram builder once hist tree method supports external memory.
-template <typename Partitioner, typename ExpandEntry = CPUExpandEntry>
+template <typename Partitioner>
 common::BlockedSpace2d ConstructHistSpace(Partitioner const &partitioners,
-                                          std::vector<ExpandEntry> const &nodes_to_build) {
-std::vector<size_t> partition_size(nodes_to_build.size(), 0);
+                                          std::vector<bst_node_t> const &nodes_to_build) {
+// FIXME(jiamingy): Handle different size of space. Right now we use the maximum
+// partition size for the buffer, which might not be efficient if partition sizes
+// has significant variance.
+std::vector<std::size_t> partition_size(nodes_to_build.size(), 0);
 for (auto const &partition : partitioners) {
 size_t k = 0;
-for (auto node : nodes_to_build) {
-auto n_rows_in_node = partition.Partitions()[node.nid].Size();
+for (auto nidx : nodes_to_build) {
+auto n_rows_in_node = partition.Partitions()[nidx].Size();
 partition_size[k] = std::max(partition_size[k], n_rows_in_node);
 k++;
 }
@@ -204,5 +235,107 @@ common::BlockedSpace2d ConstructHistSpace(Partitioner const &partitioners,
 nodes_to_build.size(), [&](size_t nidx_in_set) { return partition_size[nidx_in_set]; }, 256};
 return space;
 }
+
+/**
+ * @brief Histogram builder that can handle multiple targets.
+ */
+class MultiHistogramBuilder {
+  std::vector<HistogramBuilder> target_builders_;
+  Context const *ctx_;
+
+ public:
+  /**
+   * @brief Build the histogram for root node.
+   */
+  template <typename Partitioner, typename ExpandEntry>
+  void BuildRootHist(DMatrix *p_fmat, RegTree const *p_tree,
+                     std::vector<Partitioner> const &partitioners,
+                     linalg::MatrixView<GradientPair const> gpair, ExpandEntry const &best,
+                     BatchParam const &param, bool force_read_by_column = false) {
+    auto n_targets = p_tree->NumTargets();
+    CHECK_EQ(gpair.Shape(1), n_targets);
+    CHECK_EQ(p_fmat->Info().num_row_, gpair.Shape(0));
+    CHECK_EQ(target_builders_.size(), n_targets);
+    std::vector<bst_node_t> nodes{best.nid};
+    std::vector<bst_node_t> dummy_sub;
+
+    auto space = ConstructHistSpace(partitioners, nodes);
+    for (bst_target_t t{0}; t < n_targets; ++t) {
+      this->target_builders_[t].AddHistRows(p_tree, &nodes, &dummy_sub, false);
+    }
+    CHECK(dummy_sub.empty());
+
+    std::size_t page_idx{0};
+    for (auto const &gidx : p_fmat->GetBatches<GHistIndexMatrix>(ctx_, param)) {
+      for (bst_target_t t{0}; t < n_targets; ++t) {
+        auto t_gpair = gpair.Slice(linalg::All(), t);
+        this->target_builders_[t].BuildHist(page_idx, space, gidx,
+                                            partitioners[page_idx].Partitions(), nodes, t_gpair,
+                                            force_read_by_column);
+      }
+      ++page_idx;
+    }
+
+    for (bst_target_t t = 0; t < p_tree->NumTargets(); ++t) {
+      this->target_builders_[t].SyncHistogram(p_tree, nodes, dummy_sub);
+    }
+  }
+  /**
+   * @brief Build histogram for left and right child of valid candidates
+   */
+  template <typename Partitioner, typename ExpandEntry>
+  void BuildHistLeftRight(DMatrix *p_fmat, RegTree const *p_tree,
+                          std::vector<Partitioner> const &partitioners,
+                          std::vector<ExpandEntry> const &valid_candidates,
+                          linalg::MatrixView<GradientPair const> gpair, BatchParam const &param,
+                          bool force_read_by_column = false) {
+    std::vector<bst_node_t> nodes_to_build(valid_candidates.size());
+    std::vector<bst_node_t> nodes_to_sub(valid_candidates.size());
+    AssignNodes(p_tree, valid_candidates, nodes_to_build, nodes_to_sub);
+
+    // use the first builder for getting number of valid nodes.
+    target_builders_.front().AddHistRows(p_tree, &nodes_to_build, &nodes_to_sub, true);
+    CHECK_GE(nodes_to_build.size(), nodes_to_sub.size());
+    CHECK_EQ(nodes_to_sub.size() + nodes_to_build.size(), valid_candidates.size() * 2);
+
+    // allocate storage for the rest of the builders
+    for (bst_target_t t = 1; t < target_builders_.size(); ++t) {
+      target_builders_[t].AddHistRows(p_tree, &nodes_to_build, &nodes_to_sub, false);
+    }
+
+    auto space = ConstructHistSpace(partitioners, nodes_to_build);
+    std::size_t page_idx{0};
+    for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(ctx_, param)) {
+      CHECK_EQ(gpair.Shape(1), p_tree->NumTargets());
+      for (bst_target_t t = 0; t < p_tree->NumTargets(); ++t) {
+        auto t_gpair = gpair.Slice(linalg::All(), t);
+        CHECK_EQ(t_gpair.Shape(0), p_fmat->Info().num_row_);
+        this->target_builders_[t].BuildHist(page_idx, space, page,
+                                            partitioners[page_idx].Partitions(), nodes_to_build,
+                                            t_gpair, force_read_by_column);
+      }
+      page_idx++;
+    }
+
+    for (bst_target_t t = 0; t < p_tree->NumTargets(); ++t) {
+      this->target_builders_[t].SyncHistogram(p_tree, nodes_to_build, nodes_to_sub);
+    }
+  }
+
+  [[nodiscard]] auto const &Histogram(bst_target_t t) const {
+    return target_builders_[t].Histogram();
+  }
+  [[nodiscard]] auto &Histogram(bst_target_t t) { return target_builders_[t].Histogram(); }
+
+  void Reset(Context const *ctx, bst_bin_t total_bins, bst_target_t n_targets, BatchParam const &p,
+             bool is_distributed, bool is_col_split, HistMakerTrainParam const *param) {
+    ctx_ = ctx;
+    target_builders_.resize(n_targets);
+    CHECK_GE(n_targets, 1);
+    for (auto &v : target_builders_) {
+      v.Reset(ctx, total_bins, p, is_distributed, is_col_split, param);
+    }
+  }
+};
 } // namespace xgboost::tree
 #endif // XGBOOST_TREE_HIST_HISTOGRAM_H_
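The rearrangement performed by AddHistRows above can be summarised in isolation: once the cache has overflowed, any node whose parent histogram was evicted loses the subtraction trick and is moved to the build list. A hedged, stand-alone sketch (the predicate is a stand-in for hist_.HistogramExists applied to the node's parent):

#include <functional>
#include <utility>
#include <vector>

// Toy version of the node rearrangement in AddHistRows: keep the subtraction
// trick only where the parent histogram survived the cache eviction.
void RearrangeAfterEviction(std::vector<int>* nodes_to_build, std::vector<int>* nodes_to_sub,
                            std::function<bool(int)> parent_hist_exists) {
  std::vector<int> can_subtract;
  for (int nidx : *nodes_to_sub) {
    if (parent_hist_exists(nidx)) {
      can_subtract.push_back(nidx);  // parent still cached, subtraction is possible
    } else {
      nodes_to_build->push_back(nidx);  // parent evicted, needs a full build
    }
  }
  *nodes_to_sub = std::move(can_subtract);
}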
@@ -2,12 +2,19 @@
 * Copyright 2021-2023, XGBoost Contributors
 */
 #pragma once
-#include "xgboost/parameter.h"
+
+#include <cstddef> // for size_t
+
+#include "xgboost/parameter.h" // for XGBoostParameter
 #include "xgboost/tree_model.h" // for RegTree
 
 namespace xgboost::tree {
 struct HistMakerTrainParam : public XGBoostParameter<HistMakerTrainParam> {
-bool debug_synchronize;
+constexpr static std::size_t DefaultNodes() { return static_cast<std::size_t>(1) << 16; }
+
+bool debug_synchronize{false};
+std::size_t internal_max_cached_hist_node{DefaultNodes()};
+
 void CheckTreesSynchronized(RegTree const* local_tree) const;
 
 // declare parameters
@@ -15,6 +22,10 @@ struct HistMakerTrainParam : public XGBoostParameter<HistMakerTrainParam> {
 DMLC_DECLARE_FIELD(debug_synchronize)
 .set_default(false)
 .describe("Check if all distributed tree are identical after tree construction.");
+DMLC_DECLARE_FIELD(internal_max_cached_hist_node)
+    .set_default(DefaultNodes())
+    .set_lower_bound(1)
+    .describe("Maximum number of nodes in CPU histogram cache. Only for internal usage.");
 }
 };
 } // namespace xgboost::tree
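A short sketch of flipping the new knob from C++ (illustrative only; UpdateAllowUnknown is assumed here to be the usual XGBoostParameter entry point and is not part of this patch):

#include "xgboost/base.h"  // for Args

// Assumes this translation unit can also see HistMakerTrainParam (src/tree/hist/param.h).
void SketchConfigureHistCache() {
  xgboost::tree::HistMakerTrainParam hist_param;
  // Cap the CPU histogram cache at 1024 nodes instead of the 2^16 default.
  hist_param.UpdateAllowUnknown(xgboost::Args{{"internal_max_cached_hist_node", "1024"}});
}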
@ -526,7 +526,7 @@ struct SplitEntryContainer {
|
|||||||
* \return whether the proposed split is better and can replace current split
|
* \return whether the proposed split is better and can replace current split
|
||||||
*/
|
*/
|
||||||
template <typename GradientSumT>
|
template <typename GradientSumT>
|
||||||
bool Update(bst_float new_loss_chg, unsigned split_index, bst_float new_split_value,
|
bool Update(bst_float new_loss_chg, bst_feature_t split_index, float new_split_value,
|
||||||
bool default_left, bool is_cat, GradientSumT const &left_sum,
|
bool default_left, bool is_cat, GradientSumT const &left_sum,
|
||||||
GradientSumT const &right_sum) {
|
GradientSumT const &right_sum) {
|
||||||
if (this->NeedReplace(new_loss_chg, split_index)) {
|
if (this->NeedReplace(new_loss_chg, split_index)) {
|
||||||
@@ -3,27 +3,39 @@
  *
  * \brief Implementation for the approx tree method.
  */
-#include <algorithm>
-#include <memory>
-#include <vector>
+#include <algorithm>  // for max, transform, fill_n
+#include <cstddef>    // for size_t
+#include <map>        // for map
+#include <memory>     // for allocator, unique_ptr, make_shared, make_unique
+#include <utility>    // for move
+#include <vector>     // for vector
 
-#include "../collective/aggregator.h"
-#include "../common/random.h"
-#include "../data/gradient_index.h"
-#include "common_row_partitioner.h"
-#include "driver.h"
-#include "hist/evaluate_splits.h"
-#include "hist/histogram.h"
-#include "hist/param.h"
-#include "hist/sampler.h"  // for SampleGradient
-#include "param.h"  // for HistMakerTrainParam
-#include "xgboost/base.h"
-#include "xgboost/data.h"
-#include "xgboost/json.h"
-#include "xgboost/linalg.h"
-#include "xgboost/task.h"  // for ObjInfo
-#include "xgboost/tree_model.h"
-#include "xgboost/tree_updater.h"  // for TreeUpdater
+#include "../collective/aggregator.h"        // for GlobalSum
+#include "../collective/communicator-inl.h"  // for IsDistributed
+#include "../common/hist_util.h"             // for HistogramCuts
+#include "../common/random.h"                // for ColumnSampler
+#include "../common/timer.h"                 // for Monitor
+#include "../data/gradient_index.h"          // for GHistIndexMatrix
+#include "common_row_partitioner.h"          // for CommonRowPartitioner
+#include "dmlc/registry.h"                   // for DMLC_REGISTRY_FILE_TAG
+#include "driver.h"                          // for Driver
+#include "hist/evaluate_splits.h"            // for HistEvaluator, UpdatePredictionCacheImpl
+#include "hist/expand_entry.h"               // for CPUExpandEntry
+#include "hist/histogram.h"                  // for MultiHistogramBuilder
+#include "hist/param.h"                      // for HistMakerTrainParam
+#include "hist/sampler.h"                    // for SampleGradient
+#include "param.h"                           // for GradStats, TrainParam
+#include "xgboost/base.h"                    // for Args, GradientPair, bst_node_t, bst_bin_t
+#include "xgboost/context.h"                 // for Context
+#include "xgboost/data.h"                    // for DMatrix, BatchSet, BatchIterator, MetaInfo
+#include "xgboost/host_device_vector.h"      // for HostDeviceVector
+#include "xgboost/json.h"                    // for Object, Json, FromJson, ToJson, get
+#include "xgboost/linalg.h"                  // for Matrix, MakeTensorView, Empty, MatrixView
+#include "xgboost/logging.h"                 // for LogCheck_EQ, CHECK_EQ, CHECK
+#include "xgboost/span.h"                    // for Span
+#include "xgboost/task.h"                    // for ObjInfo
+#include "xgboost/tree_model.h"              // for RegTree, RTreeNodeStat
+#include "xgboost/tree_updater.h"            // for TreeUpdater, TreeUpdaterReg, XGBOOST_REGISTE...
 
 namespace xgboost::tree {
 
@@ -46,7 +58,7 @@ class GloablApproxBuilder {
   HistMakerTrainParam const *hist_param_{nullptr};
   std::shared_ptr<common::ColumnSampler> col_sampler_;
   HistEvaluator evaluator_;
-  HistogramBuilder<CPUExpandEntry> histogram_builder_;
+  MultiHistogramBuilder histogram_builder_;
   Context const *ctx_;
   ObjInfo const *const task_;
 
@@ -59,7 +71,7 @@ class GloablApproxBuilder {
   common::HistogramCuts feature_values_;
 
  public:
-  void InitData(DMatrix *p_fmat, common::Span<float> hess) {
+  void InitData(DMatrix *p_fmat, RegTree const *p_tree, common::Span<float> hess) {
     monitor_->Start(__func__);
 
     n_batches_ = 0;
@@ -79,8 +91,9 @@ class GloablApproxBuilder {
       n_batches_++;
     }
 
-    histogram_builder_.Reset(n_total_bins, BatchSpec(*param_, hess), ctx_->Threads(), n_batches_,
-                             collective::IsDistributed(), p_fmat->Info().IsColumnSplit());
+    histogram_builder_.Reset(ctx_, n_total_bins, p_tree->NumTargets(), BatchSpec(*param_, hess),
+                             collective::IsDistributed(), p_fmat->Info().IsColumnSplit(),
+                             hist_param_);
     monitor_->Stop(__func__);
   }
 
@@ -96,20 +109,16 @@ class GloablApproxBuilder {
     }
     collective::GlobalSum(p_fmat->Info(), reinterpret_cast<double *>(&root_sum), 2);
     std::vector<CPUExpandEntry> nodes{best};
-    size_t i = 0;
-    auto space = ConstructHistSpace(partitioner_, nodes);
-    for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(ctx_, BatchSpec(*param_, hess))) {
-      histogram_builder_.BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(), nodes,
-                                   {}, gpair);
-      i++;
-    }
+    this->histogram_builder_.BuildRootHist(p_fmat, p_tree, partitioner_,
+                                           linalg::MakeTensorView(ctx_, gpair, gpair.size(), 1),
+                                           best, BatchSpec(*param_, hess));
 
     auto weight = evaluator_.InitRoot(root_sum);
     p_tree->Stat(RegTree::kRoot).sum_hess = root_sum.GetHess();
     p_tree->Stat(RegTree::kRoot).base_weight = weight;
     (*p_tree)[RegTree::kRoot].SetLeaf(param_->learning_rate * weight);
 
-    auto const &histograms = histogram_builder_.Histogram();
+    auto const &histograms = histogram_builder_.Histogram(0);
     auto ft = p_fmat->Info().feature_types.ConstHostSpan();
     evaluator_.EvaluateSplits(histograms, feature_values_, ft, *p_tree, &nodes);
     monitor_->Stop(__func__);
@@ -130,30 +139,9 @@ class GloablApproxBuilder {
                       std::vector<CPUExpandEntry> const &valid_candidates,
                       std::vector<GradientPair> const &gpair, common::Span<float> hess) {
     monitor_->Start(__func__);
-    std::vector<CPUExpandEntry> nodes_to_build;
-    std::vector<CPUExpandEntry> nodes_to_sub;
-
-    for (auto const &c : valid_candidates) {
-      auto left_nidx = (*p_tree)[c.nid].LeftChild();
-      auto right_nidx = (*p_tree)[c.nid].RightChild();
-      auto fewer_right = c.split.right_sum.GetHess() < c.split.left_sum.GetHess();
-
-      auto build_nidx = left_nidx;
-      auto subtract_nidx = right_nidx;
-      if (fewer_right) {
-        std::swap(build_nidx, subtract_nidx);
-      }
-      nodes_to_build.push_back(CPUExpandEntry{build_nidx, p_tree->GetDepth(build_nidx), {}});
-      nodes_to_sub.push_back(CPUExpandEntry{subtract_nidx, p_tree->GetDepth(subtract_nidx), {}});
-    }
-
-    size_t i = 0;
-    auto space = ConstructHistSpace(partitioner_, nodes_to_build);
-    for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(ctx_, BatchSpec(*param_, hess))) {
-      histogram_builder_.BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(),
-                                   nodes_to_build, nodes_to_sub, gpair);
-      i++;
-    }
+    this->histogram_builder_.BuildHistLeftRight(
+        p_fmat, p_tree, partitioner_, valid_candidates,
+        linalg::MakeTensorView(ctx_, gpair, gpair.size(), 1), BatchSpec(*param_, hess));
     monitor_->Stop(__func__);
   }
 
@@ -185,7 +173,7 @@ class GloablApproxBuilder {
   void UpdateTree(DMatrix *p_fmat, std::vector<GradientPair> const &gpair, common::Span<float> hess,
                   RegTree *p_tree, HostDeviceVector<bst_node_t> *p_out_position) {
     p_last_tree_ = p_tree;
-    this->InitData(p_fmat, hess);
+    this->InitData(p_fmat, p_tree, hess);
 
     Driver<CPUExpandEntry> driver(*param_);
     auto &tree = *p_tree;
@@ -235,7 +223,7 @@ class GloablApproxBuilder {
         best_splits.push_back(l_best);
        best_splits.push_back(r_best);
       }
-      auto const &histograms = histogram_builder_.Histogram();
+      auto const &histograms = histogram_builder_.Histogram(0);
       auto ft = p_fmat->Info().feature_types.ConstHostSpan();
       monitor_->Start("EvaluateSplits");
       evaluator_.EvaluateSplits(histograms, feature_values_, ft, *p_tree, &best_splits);
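Note (not part of the patch): the hunks above remove the hand-rolled "build one child, subtract for the sibling" loop from the approx updater; that logic is now expected to live behind BuildHistLeftRight. For reference, this is the pairing rule the removed lines implemented, condensed into a fragment; the vectors, the entry type `c`, and `tree` are placeholders taken from the removed code, not a new API.

// The subtraction trick: accumulate gradients only for the child with the
// smaller hessian sum, then derive the sibling as parent - built child.
for (auto const &c : candidates) {
  bst_node_t build_nidx = tree[c.nid].LeftChild();
  bst_node_t subtract_nidx = tree[c.nid].RightChild();
  if (c.split.right_sum.GetHess() < c.split.left_sum.GetHess()) {
    std::swap(build_nidx, subtract_nidx);  // right child is cheaper to build
  }
  nodes_to_build.push_back(build_nidx);  // scanned over the data
  nodes_to_sub.push_back(subtract_nidx); // filled by subtraction, no second pass
}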
@@ -7,35 +7,37 @@
 #include <algorithm>  // for max, copy, transform
 #include <cstddef>    // for size_t
 #include <cstdint>    // for uint32_t, int32_t
-#include <memory>     // for unique_ptr, allocator, make_unique, shared_ptr
-#include <numeric>    // for accumulate
-#include <ostream>    // for basic_ostream, char_traits, operator<<
-#include <utility>    // for move, swap
+#include <exception>  // for exception
+#include <memory>     // for allocator, unique_ptr, make_unique, shared_ptr
+#include <ostream>    // for operator<<, basic_ostream, char_traits
+#include <utility>    // for move
 #include <vector>     // for vector
 
 #include "../collective/aggregator.h"        // for GlobalSum
-#include "../collective/communicator-inl.h"  // for Allreduce, IsDistributed
-#include "../common/hist_util.h"             // for HistogramCuts, HistCollection
+#include "../collective/communicator-inl.h"  // for IsDistributed
+#include "../common/hist_util.h"             // for HistogramCuts, GHistRow
 #include "../common/linalg_op.h"             // for begin, cbegin, cend
 #include "../common/random.h"                // for ColumnSampler
 #include "../common/threading_utils.h"       // for ParallelFor
 #include "../common/timer.h"                 // for Monitor
-#include "../common/transform_iterator.h"    // for IndexTransformIter, MakeIndexTransformIter
+#include "../common/transform_iterator.h"    // for IndexTransformIter
 #include "../data/gradient_index.h"          // for GHistIndexMatrix
 #include "common_row_partitioner.h"          // for CommonRowPartitioner
 #include "dmlc/registry.h"                   // for DMLC_REGISTRY_FILE_TAG
 #include "driver.h"                          // for Driver
 #include "hist/evaluate_splits.h"            // for HistEvaluator, HistMultiEvaluator, UpdatePre...
 #include "hist/expand_entry.h"               // for MultiExpandEntry, CPUExpandEntry
-#include "hist/histogram.h"                  // for HistogramBuilder, ConstructHistSpace
+#include "hist/hist_cache.h"                 // for BoundedHistCollection
+#include "hist/histogram.h"                  // for MultiHistogramBuilder
 #include "hist/param.h"                      // for HistMakerTrainParam
 #include "hist/sampler.h"                    // for SampleGradient
-#include "param.h"                           // for TrainParam, SplitEntryContainer, GradStats
-#include "xgboost/base.h"                    // for GradientPairInternal, GradientPair, bst_targ...
+#include "param.h"                           // for TrainParam, GradStats
+#include "xgboost/base.h"                    // for Args, GradientPairPrecise, GradientPair, Gra...
 #include "xgboost/context.h"                 // for Context
-#include "xgboost/data.h"                    // for BatchIterator, BatchSet, DMatrix, MetaInfo
+#include "xgboost/data.h"                    // for BatchSet, DMatrix, BatchIterator, MetaInfo
 #include "xgboost/host_device_vector.h"      // for HostDeviceVector
-#include "xgboost/linalg.h"                  // for All, MatrixView, TensorView, Matrix, Empty
+#include "xgboost/json.h"                    // for Object, Json, FromJson, ToJson, get
+#include "xgboost/linalg.h"                  // for MatrixView, TensorView, All, Matrix, Empty
 #include "xgboost/logging.h"                 // for LogCheck_EQ, CHECK_EQ, CHECK, LogCheck_GE
 #include "xgboost/span.h"                    // for Span, operator!=, SpanIterator
 #include "xgboost/string_view.h"             // for operator<<
@@ -120,7 +122,7 @@ class MultiTargetHistBuilder {
   std::shared_ptr<common::ColumnSampler> col_sampler_;
   std::unique_ptr<HistMultiEvaluator> evaluator_;
   // Histogram builder for each target.
-  std::vector<HistogramBuilder<MultiExpandEntry>> histogram_builder_;
+  std::unique_ptr<MultiHistogramBuilder> histogram_builder_;
   Context const *ctx_{nullptr};
   // Partitioner for each data batch.
   std::vector<CommonRowPartitioner> partitioner_;
@@ -150,7 +152,6 @@ class MultiTargetHistBuilder {
     monitor_->Start(__func__);
 
     p_last_fmat_ = p_fmat;
-    std::size_t page_id = 0;
     bst_bin_t n_total_bins = 0;
     partitioner_.clear();
     for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(ctx_, HistBatch(param_))) {
@@ -160,16 +161,13 @@ class MultiTargetHistBuilder {
         CHECK_EQ(n_total_bins, page.cut.TotalBins());
       }
       partitioner_.emplace_back(ctx_, page.Size(), page.base_rowid, p_fmat->Info().IsColumnSplit());
-      page_id++;
     }
 
     bst_target_t n_targets = p_tree->NumTargets();
-    histogram_builder_.clear();
-    for (std::size_t i = 0; i < n_targets; ++i) {
-      histogram_builder_.emplace_back();
-      histogram_builder_.back().Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id,
-                                      collective::IsDistributed(), p_fmat->Info().IsColumnSplit());
-    }
+    histogram_builder_ = std::make_unique<MultiHistogramBuilder>();
+    histogram_builder_->Reset(ctx_, n_total_bins, n_targets, HistBatch(param_),
+                              collective::IsDistributed(), p_fmat->Info().IsColumnSplit(),
+                              hist_param_);
 
     evaluator_ = std::make_unique<HistMultiEvaluator>(ctx_, p_fmat->Info(), param_, col_sampler_);
     p_last_tree_ = p_tree;
@@ -204,17 +202,7 @@ class MultiTargetHistBuilder {
     collective::GlobalSum(p_fmat->Info(), reinterpret_cast<double *>(root_sum.Values().data()),
                           root_sum.Size() * 2);
 
-    std::vector<MultiExpandEntry> nodes{best};
-    std::size_t i = 0;
-    auto space = ConstructHistSpace(partitioner_, nodes);
-    for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(ctx_, HistBatch(param_))) {
-      for (bst_target_t t{0}; t < n_targets; ++t) {
-        auto t_gpair = gpair.Slice(linalg::All(), t);
-        histogram_builder_[t].BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(),
-                                        nodes, {}, t_gpair.Values());
-      }
-      i++;
-    }
+    histogram_builder_->BuildRootHist(p_fmat, p_tree, partitioner_, gpair, best, HistBatch(param_));
 
     auto weight = evaluator_->InitRoot(root_sum);
     auto weight_t = weight.HostView();
@@ -222,9 +210,10 @@ class MultiTargetHistBuilder {
                    [&](float w) { return w * param_->learning_rate; });
 
     p_tree->SetLeaf(RegTree::kRoot, weight_t);
-    std::vector<common::HistCollection const *> hists;
+    std::vector<BoundedHistCollection const *> hists;
+    std::vector<MultiExpandEntry> nodes{{RegTree::kRoot, 0}};
     for (bst_target_t t{0}; t < p_tree->NumTargets(); ++t) {
-      hists.push_back(&histogram_builder_[t].Histogram());
+      hists.push_back(&(*histogram_builder_).Histogram(t));
     }
     for (auto const &gmat : p_fmat->GetBatches<GHistIndexMatrix>(ctx_, HistBatch(param_))) {
       evaluator_->EvaluateSplits(*p_tree, hists, gmat.cut, &nodes);
@@ -239,50 +228,17 @@ class MultiTargetHistBuilder {
                       std::vector<MultiExpandEntry> const &valid_candidates,
                       linalg::MatrixView<GradientPair const> gpair) {
     monitor_->Start(__func__);
-    std::vector<MultiExpandEntry> nodes_to_build;
-    std::vector<MultiExpandEntry> nodes_to_sub;
-
-    for (auto const &c : valid_candidates) {
-      auto left_nidx = p_tree->LeftChild(c.nid);
-      auto right_nidx = p_tree->RightChild(c.nid);
-
-      auto build_nidx = left_nidx;
-      auto subtract_nidx = right_nidx;
-      auto lit =
-          common::MakeIndexTransformIter([&](auto i) { return c.split.left_sum[i].GetHess(); });
-      auto left_sum = std::accumulate(lit, lit + c.split.left_sum.size(), .0);
-      auto rit =
-          common::MakeIndexTransformIter([&](auto i) { return c.split.right_sum[i].GetHess(); });
-      auto right_sum = std::accumulate(rit, rit + c.split.right_sum.size(), .0);
-      auto fewer_right = right_sum < left_sum;
-      if (fewer_right) {
-        std::swap(build_nidx, subtract_nidx);
-      }
-      nodes_to_build.emplace_back(build_nidx, p_tree->GetDepth(build_nidx));
-      nodes_to_sub.emplace_back(subtract_nidx, p_tree->GetDepth(subtract_nidx));
-    }
-
-    std::size_t i = 0;
-    auto space = ConstructHistSpace(partitioner_, nodes_to_build);
-    for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(ctx_, HistBatch(param_))) {
-      for (std::size_t t = 0; t < p_tree->NumTargets(); ++t) {
-        auto t_gpair = gpair.Slice(linalg::All(), t);
-        // Make sure the gradient matrix is f-order.
-        CHECK(t_gpair.Contiguous());
-        histogram_builder_[t].BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(),
-                                        nodes_to_build, nodes_to_sub, t_gpair.Values());
-      }
-      i++;
-    }
+    histogram_builder_->BuildHistLeftRight(p_fmat, p_tree, partitioner_, valid_candidates, gpair,
+                                           HistBatch(param_));
     monitor_->Stop(__func__);
   }
 
   void EvaluateSplits(DMatrix *p_fmat, RegTree const *p_tree,
                       std::vector<MultiExpandEntry> *best_splits) {
     monitor_->Start(__func__);
-    std::vector<common::HistCollection const *> hists;
+    std::vector<BoundedHistCollection const *> hists;
     for (bst_target_t t{0}; t < p_tree->NumTargets(); ++t) {
-      hists.push_back(&histogram_builder_[t].Histogram());
+      hists.push_back(&(*histogram_builder_).Histogram(t));
     }
     for (auto const &gmat : p_fmat->GetBatches<GHistIndexMatrix>(ctx_, HistBatch(param_))) {
       evaluator_->EvaluateSplits(*p_tree, hists, gmat.cut, best_splits);
@@ -349,7 +305,7 @@ class HistUpdater {
   const RegTree *p_last_tree_{nullptr};
   DMatrix const *const p_last_fmat_{nullptr};
 
-  std::unique_ptr<HistogramBuilder<CPUExpandEntry>> histogram_builder_;
+  std::unique_ptr<MultiHistogramBuilder> histogram_builder_;
   ObjInfo const *task_{nullptr};
   // Context for number of threads
   Context const *ctx_{nullptr};
@@ -364,7 +320,7 @@ class HistUpdater {
         col_sampler_{std::move(column_sampler)},
         evaluator_{std::make_unique<HistEvaluator>(ctx, param, fmat->Info(), col_sampler_)},
         p_last_fmat_(fmat),
-        histogram_builder_{new HistogramBuilder<CPUExpandEntry>},
+        histogram_builder_{new MultiHistogramBuilder},
         task_{task},
         ctx_{ctx} {
     monitor_->Init(__func__);
@@ -387,7 +343,6 @@ class HistUpdater {
   // initialize temp data structure
   void InitData(DMatrix *fmat, RegTree const *p_tree) {
     monitor_->Start(__func__);
-    std::size_t page_id{0};
     bst_bin_t n_total_bins{0};
     partitioner_.clear();
     for (auto const &page : fmat->GetBatches<GHistIndexMatrix>(ctx_, HistBatch(param_))) {
@@ -398,10 +353,9 @@ class HistUpdater {
       }
       partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid,
                                 fmat->Info().IsColumnSplit());
-      ++page_id;
     }
-    histogram_builder_->Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id,
-                              collective::IsDistributed(), fmat->Info().IsColumnSplit());
+    histogram_builder_->Reset(ctx_, n_total_bins, 1, HistBatch(param_), collective::IsDistributed(),
+                              fmat->Info().IsColumnSplit(), hist_param_);
     evaluator_ = std::make_unique<HistEvaluator>(ctx_, this->param_, fmat->Info(), col_sampler_);
     p_last_tree_ = p_tree;
     monitor_->Stop(__func__);
@@ -410,7 +364,7 @@ class HistUpdater {
   void EvaluateSplits(DMatrix *p_fmat, RegTree const *p_tree,
                       std::vector<CPUExpandEntry> *best_splits) {
     monitor_->Start(__func__);
-    auto const &histograms = histogram_builder_->Histogram();
+    auto const &histograms = histogram_builder_->Histogram(0);
     auto ft = p_fmat->Info().feature_types.ConstHostSpan();
     for (auto const &gmat : p_fmat->GetBatches<GHistIndexMatrix>(ctx_, HistBatch(param_))) {
       evaluator_->EvaluateSplits(histograms, gmat.cut, ft, *p_tree, best_splits);
@@ -428,16 +382,8 @@ class HistUpdater {
     monitor_->Start(__func__);
     CPUExpandEntry node(RegTree::kRoot, p_tree->GetDepth(0));
 
-    std::size_t page_id = 0;
-    auto space = ConstructHistSpace(partitioner_, {node});
-    for (auto const &gidx : p_fmat->GetBatches<GHistIndexMatrix>(ctx_, HistBatch(param_))) {
-      std::vector<CPUExpandEntry> nodes_to_build{node};
-      std::vector<CPUExpandEntry> nodes_to_sub;
-      this->histogram_builder_->BuildHist(page_id, space, gidx, p_tree,
-                                          partitioner_.at(page_id).Partitions(), nodes_to_build,
-                                          nodes_to_sub, gpair.Slice(linalg::All(), 0).Values());
-      ++page_id;
-    }
+    this->histogram_builder_->BuildRootHist(p_fmat, p_tree, partitioner_, gpair, node,
+                                            HistBatch(param_));
 
     {
       GradientPairPrecise grad_stat;
@@ -451,7 +397,7 @@ class HistUpdater {
         CHECK_GE(row_ptr.size(), 2);
         std::uint32_t const ibegin = row_ptr[0];
         std::uint32_t const iend = row_ptr[1];
-        auto hist = this->histogram_builder_->Histogram()[RegTree::kRoot];
+        auto hist = this->histogram_builder_->Histogram(0)[RegTree::kRoot];
         auto begin = hist.data();
         for (std::uint32_t i = ibegin; i < iend; ++i) {
           GradientPairPrecise const &et = begin[i];
@@ -474,7 +420,7 @@ class HistUpdater {
     monitor_->Start("EvaluateSplits");
     auto ft = p_fmat->Info().feature_types.ConstHostSpan();
     for (auto const &gmat : p_fmat->GetBatches<GHistIndexMatrix>(ctx_, HistBatch(param_))) {
-      evaluator_->EvaluateSplits(histogram_builder_->Histogram(), gmat.cut, ft, *p_tree,
+      evaluator_->EvaluateSplits(histogram_builder_->Histogram(0), gmat.cut, ft, *p_tree,
                                  &entries);
       break;
     }
@@ -490,33 +436,8 @@ class HistUpdater {
                       std::vector<CPUExpandEntry> const &valid_candidates,
                       linalg::MatrixView<GradientPair const> gpair) {
     monitor_->Start(__func__);
-    std::vector<CPUExpandEntry> nodes_to_build(valid_candidates.size());
-    std::vector<CPUExpandEntry> nodes_to_sub(valid_candidates.size());
-
-    std::size_t n_idx = 0;
-    for (auto const &c : valid_candidates) {
-      auto left_nidx = (*p_tree)[c.nid].LeftChild();
-      auto right_nidx = (*p_tree)[c.nid].RightChild();
-      auto fewer_right = c.split.right_sum.GetHess() < c.split.left_sum.GetHess();
-
-      auto build_nidx = left_nidx;
-      auto subtract_nidx = right_nidx;
-      if (fewer_right) {
-        std::swap(build_nidx, subtract_nidx);
-      }
-      nodes_to_build[n_idx] = CPUExpandEntry{build_nidx, p_tree->GetDepth(build_nidx), {}};
-      nodes_to_sub[n_idx] = CPUExpandEntry{subtract_nidx, p_tree->GetDepth(subtract_nidx), {}};
-      n_idx++;
-    }
-
-    std::size_t page_id{0};
-    auto space = ConstructHistSpace(partitioner_, nodes_to_build);
-    for (auto const &gidx : p_fmat->GetBatches<GHistIndexMatrix>(ctx_, HistBatch(param_))) {
-      histogram_builder_->BuildHist(page_id, space, gidx, p_tree,
-                                    partitioner_.at(page_id).Partitions(), nodes_to_build,
-                                    nodes_to_sub, gpair.Values());
-      ++page_id;
-    }
+    this->histogram_builder_->BuildHistLeftRight(p_fmat, p_tree, partitioner_, valid_candidates,
+                                                 gpair, HistBatch(param_));
     monitor_->Stop(__func__);
   }
 
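Note (not part of the patch): with a single MultiHistogramBuilder, the multi-target evaluator still receives one collection pointer per target, which the hunks above obtain as `&(*histogram_builder_).Histogram(t)`. An equivalent, slightly more direct spelling of that loop, shown only as a usage sketch:

// Sketch: collect per-target histogram pointers for HistMultiEvaluator.
std::vector<BoundedHistCollection const *> hists;
for (bst_target_t t = 0; t < p_tree->NumTargets(); ++t) {
  hists.push_back(&histogram_builder_->Histogram(t));  // -> reads the same as (*builder).Histogram(t)
}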
@@ -27,8 +27,8 @@ void ParallelGHistBuilderReset() {
 
   for(size_t inode = 0; inode < kNodesExtended; inode++) {
     collection.AddHistRow(inode);
+    collection.AllocateData(inode);
   }
-  collection.AllocateAllData();
   ParallelGHistBuilder hist_builder;
   hist_builder.Init(kBins);
   std::vector<GHistRow> target_hist(kNodes);
@@ -83,8 +83,8 @@ void ParallelGHistBuilderReduceHist(){
 
   for(size_t inode = 0; inode < kNodes; inode++) {
     collection.AddHistRow(inode);
+    collection.AllocateData(inode);
   }
-  collection.AllocateAllData();
   ParallelGHistBuilder hist_builder;
   hist_builder.Init(kBins);
   std::vector<GHistRow> target_hist(kNodes);
@@ -129,7 +129,7 @@ TEST(CutsBuilder, SearchGroupInd) {
 
   auto p_mat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
 
-  std::vector<bst_int> group(kNumGroups);
+  std::vector<bst_group_t> group(kNumGroups);
   group[0] = 2;
   group[1] = 3;
   group[2] = 7;
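Note (not part of the patch): the histogram-collection tests above now allocate storage per node instead of calling a single AllocateAllData() after registering every row. The pattern, isolated for reference; `collection`, `kNodes`, and `kBins` are the same test fixtures as above:

// Sketch mirroring the updated test setup: register a node, then allocate
// its bins immediately rather than in one bulk pass at the end.
for (std::size_t inode = 0; inode < kNodes; ++inode) {
  collection.AddHistRow(inode);    // register the node in the collection
  collection.AllocateData(inode);  // allocate kBins entries for that node
}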
@@ -92,7 +92,7 @@ TEST(Learner, CheckGroup) {
 
   std::shared_ptr<DMatrix> p_mat{RandomDataGenerator{kNumRows, kNumCols, 0.0f}.GenerateDMatrix()};
   std::vector<bst_float> weight(kNumGroups, 1);
-  std::vector<bst_int> group(kNumGroups);
+  std::vector<bst_group_t> group(kNumGroups);
   group[0] = 2;
   group[1] = 3;
   group[2] = 7;
@@ -4,13 +4,13 @@
 #include "../test_evaluate_splits.h"
 
 #include <gtest/gtest.h>
 #include <xgboost/base.h>        // for GradientPairPrecise, Args, Gradie...
 #include <xgboost/context.h>     // for Context
 #include <xgboost/data.h>        // for FeatureType, DMatrix, MetaInfo
 #include <xgboost/logging.h>     // for CHECK_EQ
 #include <xgboost/tree_model.h>  // for RegTree, RTreeNodeStat
 
 #include <memory>  // for make_shared, shared_ptr, addressof
 
 #include "../../../../src/common/hist_util.h"  // for HistCollection, HistogramCuts
 #include "../../../../src/common/random.h"     // for ColumnSampler
@@ -18,6 +18,8 @@
 #include "../../../../src/data/gradient_index.h"        // for GHistIndexMatrix
 #include "../../../../src/tree/hist/evaluate_splits.h"  // for HistEvaluator
 #include "../../../../src/tree/hist/expand_entry.h"     // for CPUExpandEntry
+#include "../../../../src/tree/hist/hist_cache.h"       // for BoundedHistCollection
+#include "../../../../src/tree/hist/param.h"            // for HistMakerTrainParam
 #include "../../../../src/tree/param.h"                 // for GradStats, TrainParam
 #include "../../helpers.h"                              // for RandomDataGenerator, AllThreadsFo...
@@ -34,7 +36,7 @@ void TestEvaluateSplits(bool force_read_by_column) {
   auto dmat = RandomDataGenerator(kRows, kCols, 0).Seed(3).GenerateDMatrix();
 
   auto evaluator = HistEvaluator{&ctx, &param, dmat->Info(), sampler};
-  common::HistCollection hist;
+  BoundedHistCollection hist;
   std::vector<GradientPair> row_gpairs = {
       {1.23f, 0.24f}, {0.24f, 0.25f}, {0.26f, 0.27f}, {2.27f, 0.28f},
       {0.27f, 0.29f}, {0.37f, 0.39f}, {-0.47f, 0.49f}, {0.57f, 0.59f}};
@@ -48,9 +50,9 @@ void TestEvaluateSplits(bool force_read_by_column) {
   std::iota(row_indices.begin(), row_indices.end(), 0);
   row_set_collection.Init();
 
-  hist.Init(gmat.cut.Ptrs().back());
-  hist.AddHistRow(0);
-  hist.AllocateAllData();
+  HistMakerTrainParam hist_param;
+  hist.Reset(gmat.cut.Ptrs().back(), hist_param.internal_max_cached_hist_node);
+  hist.AllocateHistograms({0});
   common::BuildHist<false>(row_gpairs, row_set_collection[0], gmat, hist[0], force_read_by_column);
 
   // Compute total gradient for all data points
@@ -111,13 +113,13 @@ TEST(HistMultiEvaluator, Evaluate) {
       RandomDataGenerator{n_samples, n_features, 0.5}.Targets(n_targets).GenerateDMatrix(true);
 
   HistMultiEvaluator evaluator{&ctx, p_fmat->Info(), &param, sampler};
-  std::vector<common::HistCollection> histogram(n_targets);
+  HistMakerTrainParam hist_param;
+  std::vector<BoundedHistCollection> histogram(n_targets);
   linalg::Vector<GradientPairPrecise> root_sum({2}, Context::kCpuId);
   for (bst_target_t t{0}; t < n_targets; ++t) {
     auto &hist = histogram[t];
-    hist.Init(n_bins * n_features);
-    hist.AddHistRow(0);
-    hist.AllocateAllData();
+    hist.Reset(n_bins * n_features, hist_param.internal_max_cached_hist_node);
+    hist.AllocateHistograms({0});
     auto node_hist = hist[0];
     node_hist[0] = {-0.5, 0.5};
     node_hist[1] = {2.0, 0.5};
@@ -143,7 +145,7 @@ TEST(HistMultiEvaluator, Evaluate) {
 
   std::vector<MultiExpandEntry> entries(1, {/*nidx=*/0, /*depth=*/0});
 
-  std::vector<common::HistCollection const *> ptrs;
+  std::vector<BoundedHistCollection const *> ptrs;
   std::transform(histogram.cbegin(), histogram.cend(), std::back_inserter(ptrs),
                  [](auto const &h) { return std::addressof(h); });
 
@@ -225,16 +227,16 @@ auto CompareOneHotAndPartition(bool onehot) {
   auto sampler = std::make_shared<common::ColumnSampler>();
   auto evaluator = HistEvaluator{&ctx, &param, dmat->Info(), sampler};
   std::vector<CPUExpandEntry> entries(1);
+  HistMakerTrainParam hist_param;
 
   for (auto const &gmat : dmat->GetBatches<GHistIndexMatrix>(&ctx, {32, param.sparse_threshold})) {
-    common::HistCollection hist;
+    BoundedHistCollection hist;
 
     entries.front().nid = 0;
     entries.front().depth = 0;
 
-    hist.Init(gmat.cut.TotalBins());
-    hist.AddHistRow(0);
-    hist.AllocateAllData();
+    hist.Reset(gmat.cut.TotalBins(), hist_param.internal_max_cached_hist_node);
+    hist.AllocateHistograms({0});
     auto node_hist = hist[0];
 
     CHECK_EQ(node_hist.size(), n_cats);
@@ -261,10 +263,10 @@ TEST(HistEvaluator, Categorical) {
 }
 
 TEST_F(TestCategoricalSplitWithMissing, HistEvaluator) {
-  common::HistCollection hist;
-  hist.Init(cuts_.TotalBins());
-  hist.AddHistRow(0);
-  hist.AllocateAllData();
+  BoundedHistCollection hist;
+  HistMakerTrainParam hist_param;
+  hist.Reset(cuts_.TotalBins(), hist_param.internal_max_cached_hist_node);
+  hist.AllocateHistograms({0});
   auto node_hist = hist[0];
   ASSERT_EQ(node_hist.size(), feature_histogram_.size());
   std::copy(feature_histogram_.cbegin(), feature_histogram_.cend(), node_hist.begin());
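Note (not part of the patch): every evaluator test above migrates from common::HistCollection (Init/AddHistRow/AllocateAllData) to the new BoundedHistCollection (Reset with a node limit, then AllocateHistograms for explicit node ids). The setup, isolated into one grounded sketch; `total_bins` stands in for whatever bin count the surrounding test computes:

// Sketch of the new bounded-cache test setup used throughout the hunks above.
HistMakerTrainParam hist_param;               // supplies the node limit (default 1 << 16)
BoundedHistCollection hist;
hist.Reset(total_bins, hist_param.internal_max_cached_hist_node);
hist.AllocateHistograms({0});                 // allocate the root node explicitly
auto root_hist = hist[0];                     // GHistRow spanning total_bins entries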
@@ -2,16 +2,38 @@
  * Copyright 2018-2023 by Contributors
  */
 #include <gtest/gtest.h>
-#include <xgboost/context.h>  // Context
+#include <xgboost/base.h>                // for bst_node_t, bst_bin_t, Gradient...
+#include <xgboost/context.h>             // for Context
+#include <xgboost/data.h>                // for BatchIterator, BatchSet, DMatrix
+#include <xgboost/host_device_vector.h>  // for HostDeviceVector
+#include <xgboost/linalg.h>              // for MakeTensorView
+#include <xgboost/logging.h>             // for Error, LogCheck_EQ, LogCheck_LT
+#include <xgboost/span.h>                // for Span, operator!=
+#include <xgboost/tree_model.h>          // for RegTree
 
-#include <limits>
+#include <algorithm>   // for max
+#include <cstddef>     // for size_t
+#include <cstdint>     // for int32_t, uint32_t
+#include <functional>  // for function
+#include <iterator>    // for back_inserter
+#include <limits>      // for numeric_limits
+#include <memory>      // for shared_ptr, allocator, unique_ptr
+#include <numeric>     // for iota, accumulate
+#include <vector>      // for vector
 
-#include "../../../../src/common/categorical.h"
-#include "../../../../src/common/row_set.h"
-#include "../../../../src/tree/hist/expand_entry.h"
-#include "../../../../src/tree/hist/histogram.h"
-#include "../../categorical_helpers.h"
-#include "../../helpers.h"
+#include "../../../../src/collective/communicator-inl.h"  // for GetRank, GetWorldSize
+#include "../../../../src/common/hist_util.h"             // for GHistRow, HistogramCuts, Sketch...
+#include "../../../../src/common/ref_resource_view.h"     // for RefResourceView
+#include "../../../../src/common/row_set.h"               // for RowSetCollection
+#include "../../../../src/common/threading_utils.h"       // for BlockedSpace2d
+#include "../../../../src/data/gradient_index.h"          // for GHistIndexMatrix
+#include "../../../../src/tree/common_row_partitioner.h"  // for CommonRowPartitioner
+#include "../../../../src/tree/hist/expand_entry.h"       // for CPUExpandEntry
+#include "../../../../src/tree/hist/hist_cache.h"         // for BoundedHistCollection
+#include "../../../../src/tree/hist/histogram.h"          // for HistogramBuilder
+#include "../../../../src/tree/hist/param.h"              // for HistMakerTrainParam
+#include "../../categorical_helpers.h"                    // for OneHotEncodeFeature
+#include "../../helpers.h"                                // for RandomDataGenerator, GenerateRa...
 
 namespace xgboost::tree {
 namespace {
@@ -25,9 +47,8 @@ void InitRowPartitionForTest(common::RowSetCollection *row_set, size_t n_samples
 
 void TestAddHistRows(bool is_distributed) {
   Context ctx;
-  std::vector<CPUExpandEntry> nodes_for_explicit_hist_build_;
-  std::vector<CPUExpandEntry> nodes_for_subtraction_trick_;
-  int starting_index = std::numeric_limits<int>::max();
+  std::vector<bst_node_t> nodes_to_build;
+  std::vector<bst_node_t> nodes_to_sub;
 
   size_t constexpr kNRows = 8, kNCols = 16;
   int32_t constexpr kMaxBins = 4;
@@ -40,24 +61,22 @@ void TestAddHistRows(bool is_distributed) {
   tree.ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
   tree.ExpandNode(tree[0].LeftChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
   tree.ExpandNode(tree[0].RightChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
-  nodes_for_explicit_hist_build_.emplace_back(3, tree.GetDepth(3));
-  nodes_for_explicit_hist_build_.emplace_back(4, tree.GetDepth(4));
-  nodes_for_subtraction_trick_.emplace_back(5, tree.GetDepth(5));
-  nodes_for_subtraction_trick_.emplace_back(6, tree.GetDepth(6));
+  nodes_to_build.emplace_back(3);
+  nodes_to_build.emplace_back(4);
+  nodes_to_sub.emplace_back(5);
+  nodes_to_sub.emplace_back(6);
 
-  HistogramBuilder<CPUExpandEntry> histogram_builder;
-  histogram_builder.Reset(gmat.cut.TotalBins(), {kMaxBins, 0.5}, omp_get_max_threads(), 1,
-                          is_distributed, false);
-  histogram_builder.AddHistRows(&starting_index, nodes_for_explicit_hist_build_,
-                                nodes_for_subtraction_trick_);
+  HistMakerTrainParam hist_param;
+  HistogramBuilder histogram_builder;
+  histogram_builder.Reset(&ctx, gmat.cut.TotalBins(), {kMaxBins, 0.5}, is_distributed, false,
+                          &hist_param);
+  histogram_builder.AddHistRows(&tree, &nodes_to_build, &nodes_to_sub, false);
 
-  ASSERT_EQ(starting_index, 3);
-  for (const CPUExpandEntry &node : nodes_for_explicit_hist_build_) {
-    ASSERT_EQ(histogram_builder.Histogram().RowExists(node.nid), true);
+  for (bst_node_t const &nidx : nodes_to_build) {
+    ASSERT_TRUE(histogram_builder.Histogram().HistogramExists(nidx));
   }
-  for (const CPUExpandEntry &node : nodes_for_subtraction_trick_) {
-    ASSERT_EQ(histogram_builder.Histogram().RowExists(node.nid), true);
+  for (bst_node_t const &nidx : nodes_to_sub) {
+    ASSERT_TRUE(histogram_builder.Histogram().HistogramExists(nidx));
   }
 }
@@ -68,83 +87,77 @@ TEST(CPUHistogram, AddRows) {
 }
 
 void TestSyncHist(bool is_distributed) {
-  size_t constexpr kNRows = 8, kNCols = 16;
-  int32_t constexpr kMaxBins = 4;
+  std::size_t constexpr kNRows = 8, kNCols = 16;
+  bst_bin_t constexpr kMaxBins = 4;
   Context ctx;
 
-  std::vector<CPUExpandEntry> nodes_for_explicit_hist_build_;
-  std::vector<CPUExpandEntry> nodes_for_subtraction_trick_;
-  int starting_index = std::numeric_limits<int>::max();
+  std::vector<bst_bin_t> nodes_for_explicit_hist_build;
+  std::vector<bst_bin_t> nodes_for_subtraction_trick;
   RegTree tree;
 
   auto p_fmat = RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix();
   auto const &gmat =
       *(p_fmat->GetBatches<GHistIndexMatrix>(&ctx, BatchParam{kMaxBins, 0.5}).begin());
 
-  HistogramBuilder<CPUExpandEntry> histogram;
+  HistogramBuilder histogram;
   uint32_t total_bins = gmat.cut.Ptrs().back();
-  histogram.Reset(total_bins, {kMaxBins, 0.5}, omp_get_max_threads(), 1, is_distributed, false);
+  HistMakerTrainParam hist_param;
+  histogram.Reset(&ctx, total_bins, {kMaxBins, 0.5}, is_distributed, false, &hist_param);
 
-  common::RowSetCollection row_set_collection_;
+  common::RowSetCollection row_set_collection;
   {
-    row_set_collection_.Clear();
-    std::vector<size_t> &row_indices = *row_set_collection_.Data();
+    row_set_collection.Clear();
+    std::vector<size_t> &row_indices = *row_set_collection.Data();
     row_indices.resize(kNRows);
     std::iota(row_indices.begin(), row_indices.end(), 0);
-    row_set_collection_.Init();
+    row_set_collection.Init();
   }
 
   // level 0
-  nodes_for_explicit_hist_build_.emplace_back(0, tree.GetDepth(0));
-  histogram.AddHistRows(&starting_index, nodes_for_explicit_hist_build_,
-                        nodes_for_subtraction_trick_);
+  nodes_for_explicit_hist_build.emplace_back(0);
+  histogram.AddHistRows(&tree, &nodes_for_explicit_hist_build, &nodes_for_subtraction_trick, false);
 
   tree.ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
-  nodes_for_explicit_hist_build_.clear();
-  nodes_for_subtraction_trick_.clear();
+  nodes_for_explicit_hist_build.clear();
+  nodes_for_subtraction_trick.clear();
 
   // level 1
-  nodes_for_explicit_hist_build_.emplace_back(tree[0].LeftChild(), tree.GetDepth(1));
-  nodes_for_subtraction_trick_.emplace_back(tree[0].RightChild(), tree.GetDepth(2));
+  nodes_for_explicit_hist_build.emplace_back(tree[0].LeftChild());
+  nodes_for_subtraction_trick.emplace_back(tree[0].RightChild());
 
-  histogram.AddHistRows(&starting_index, nodes_for_explicit_hist_build_,
-                        nodes_for_subtraction_trick_);
+  histogram.AddHistRows(&tree, &nodes_for_explicit_hist_build, &nodes_for_subtraction_trick, false);
 
   tree.ExpandNode(tree[0].LeftChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
   tree.ExpandNode(tree[0].RightChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
 
-  nodes_for_explicit_hist_build_.clear();
-  nodes_for_subtraction_trick_.clear();
+  nodes_for_explicit_hist_build.clear();
+  nodes_for_subtraction_trick.clear();
   // level 2
-  nodes_for_explicit_hist_build_.emplace_back(3, tree.GetDepth(3));
-  nodes_for_subtraction_trick_.emplace_back(4, tree.GetDepth(4));
-  nodes_for_explicit_hist_build_.emplace_back(5, tree.GetDepth(5));
-  nodes_for_subtraction_trick_.emplace_back(6, tree.GetDepth(6));
+  nodes_for_explicit_hist_build.emplace_back(3);
+  nodes_for_subtraction_trick.emplace_back(4);
+  nodes_for_explicit_hist_build.emplace_back(5);
+  nodes_for_subtraction_trick.emplace_back(6);
 
-  histogram.AddHistRows(&starting_index, nodes_for_explicit_hist_build_,
-                        nodes_for_subtraction_trick_);
+  histogram.AddHistRows(&tree, &nodes_for_explicit_hist_build, &nodes_for_subtraction_trick, false);
 
-  const size_t n_nodes = nodes_for_explicit_hist_build_.size();
+  const size_t n_nodes = nodes_for_explicit_hist_build.size();
   ASSERT_EQ(n_nodes, 2ul);
-  row_set_collection_.AddSplit(0, tree[0].LeftChild(), tree[0].RightChild(), 4,
-                               4);
-  row_set_collection_.AddSplit(1, tree[1].LeftChild(), tree[1].RightChild(), 2,
-                               2);
-  row_set_collection_.AddSplit(2, tree[2].LeftChild(), tree[2].RightChild(), 2,
-                               2);
+  row_set_collection.AddSplit(0, tree[0].LeftChild(), tree[0].RightChild(), 4, 4);
+  row_set_collection.AddSplit(1, tree[1].LeftChild(), tree[1].RightChild(), 2, 2);
+  row_set_collection.AddSplit(2, tree[2].LeftChild(), tree[2].RightChild(), 2, 2);
 
   common::BlockedSpace2d space(
       n_nodes,
-      [&](size_t node) {
-        const int32_t nid = nodes_for_explicit_hist_build_[node].nid;
-        return row_set_collection_[nid].Size();
+      [&](std::size_t nidx_in_set) {
+        bst_node_t nidx = nodes_for_explicit_hist_build[nidx_in_set];
+        return row_set_collection[nidx].Size();
       },
       256);
 
   std::vector<common::GHistRow> target_hists(n_nodes);
-  for (size_t i = 0; i < nodes_for_explicit_hist_build_.size(); ++i) {
-    const int32_t nid = nodes_for_explicit_hist_build_[i].nid;
-    target_hists[i] = histogram.Histogram()[nid];
+  for (size_t i = 0; i < nodes_for_explicit_hist_build.size(); ++i) {
+    bst_node_t nidx = nodes_for_explicit_hist_build[i];
+    target_hists[i] = histogram.Histogram()[nidx];
   }
 
   // set values to specific nodes hist
@@ -168,8 +181,7 @@ void TestSyncHist(bool is_distributed) {
 
   histogram.Buffer().Reset(1, n_nodes, space, target_hists);
   // sync hist
-  histogram.SyncHistogram(&tree, nodes_for_explicit_hist_build_,
-                          nodes_for_subtraction_trick_, starting_index);
+  histogram.SyncHistogram(&tree, nodes_for_explicit_hist_build, nodes_for_subtraction_trick);
 
   using GHistRowT = common::GHistRow;
   auto check_hist = [](const GHistRowT parent, const GHistRowT left, const GHistRowT right,
@@ -182,11 +194,10 @@ void TestSyncHist(bool is_distributed) {
     }
   };
   size_t node_id = 0;
-  for (const CPUExpandEntry &node : nodes_for_explicit_hist_build_) {
-    auto this_hist = histogram.Histogram()[node.nid];
-    const size_t parent_id = tree[node.nid].Parent();
-    const size_t subtraction_node_id =
-        nodes_for_subtraction_trick_[node_id].nid;
+  for (auto const &nidx : nodes_for_explicit_hist_build) {
+    auto this_hist = histogram.Histogram()[nidx];
+    const size_t parent_id = tree[nidx].Parent();
+    const size_t subtraction_node_id = nodes_for_subtraction_trick[node_id];
     auto parent_hist = histogram.Histogram()[parent_id];
     auto sibling_hist = histogram.Histogram()[subtraction_node_id];
 
@@ -194,11 +205,10 @@ void TestSyncHist(bool is_distributed) {
     ++node_id;
   }
   node_id = 0;
-  for (const CPUExpandEntry &node : nodes_for_subtraction_trick_) {
-    auto this_hist = histogram.Histogram()[node.nid];
-    const size_t parent_id = tree[node.nid].Parent();
-    const size_t subtraction_node_id =
-        nodes_for_explicit_hist_build_[node_id].nid;
+  for (auto const &nidx : nodes_for_subtraction_trick) {
+    auto this_hist = histogram.Histogram()[nidx];
+    const size_t parent_id = tree[nidx].Parent();
+    const size_t subtraction_node_id = nodes_for_explicit_hist_build[node_id];
     auto parent_hist = histogram.Histogram()[parent_id];
     auto sibling_hist = histogram.Histogram()[subtraction_node_id];
 
@@ -232,9 +242,9 @@ void TestBuildHistogram(bool is_distributed, bool force_read_by_column, bool is_
                                           {0.27f, 0.29f}, {0.37f, 0.39f}, {0.47f, 0.49f}, {0.57f, 0.59f}};
 
   bst_node_t nid = 0;
-  HistogramBuilder<CPUExpandEntry> histogram;
-  histogram.Reset(total_bins, {kMaxBins, 0.5}, omp_get_max_threads(), 1, is_distributed,
-                  is_col_split);
+  HistogramBuilder histogram;
+  HistMakerTrainParam hist_param;
+  histogram.Reset(&ctx, total_bins, {kMaxBins, 0.5}, is_distributed, is_col_split, &hist_param);
 
   RegTree tree;
 
@@ -246,12 +256,17 @@ void TestBuildHistogram(bool is_distributed, bool force_read_by_column, bool is_
   row_set_collection.Init();
 
   CPUExpandEntry node{RegTree::kRoot, tree.GetDepth(0)};
-  std::vector<CPUExpandEntry> nodes_for_explicit_hist_build;
-  nodes_for_explicit_hist_build.push_back(node);
+  std::vector<bst_node_t> nodes_to_build{node.nid};
+  std::vector<bst_node_t> dummy_sub;
+
+  histogram.AddHistRows(&tree, &nodes_to_build, &dummy_sub, false);
+  common::BlockedSpace2d space{
+      1, [&](std::size_t nidx_in_set) { return row_set_collection[nidx_in_set].Size(); }, 256};
   for (auto const &gidx : p_fmat->GetBatches<GHistIndexMatrix>(&ctx, {kMaxBins, 0.5})) {
-    histogram.BuildHist(0, gidx, &tree, row_set_collection, nodes_for_explicit_hist_build, {},
-                        gpair, force_read_by_column);
+    histogram.BuildHist(0, space, gidx, row_set_collection, nodes_to_build,
+                        linalg::MakeTensorView(&ctx, gpair, gpair.size()), force_read_by_column);
   }
+  histogram.SyncHistogram(&tree, nodes_to_build, {});
 
   // Check if number of histogram bins is correct
   ASSERT_EQ(histogram.Histogram()[nid].size(), gmat.cut.Ptrs().back());
@@ -312,18 +327,18 @@ void ValidateCategoricalHistogram(size_t n_categories,
 
 void TestHistogramCategorical(size_t n_categories, bool force_read_by_column) {
   size_t constexpr kRows = 340;
-  int32_t constexpr kBins = 256;
+  bst_bin_t constexpr kBins = 256;
   auto x = GenerateRandomCategoricalSingleColumn(kRows, n_categories);
   auto cat_m = GetDMatrixFromData(x, kRows, 1);
   cat_m->Info().feature_types.HostVector().push_back(FeatureType::kCategorical);
   Context ctx;
 
-  BatchParam batch_param{0, static_cast<int32_t>(kBins)};
+  BatchParam batch_param{0, kBins};
 
   RegTree tree;
-  CPUExpandEntry node{RegTree::kRoot, tree.GetDepth(0)};
-  std::vector<CPUExpandEntry> nodes_for_explicit_hist_build;
-  nodes_for_explicit_hist_build.push_back(node);
+  CPUExpandEntry node{RegTree::kRoot, tree.GetDepth(RegTree::kRoot)};
+  std::vector<bst_node_t> nodes_to_build;
+  nodes_to_build.push_back(node.nid);
 
   auto gpair = GenerateRandomGradients(kRows, 0, 2);
 
@@ -333,30 +348,41 @@ void TestHistogramCategorical(size_t n_categories, bool force_read_by_column) {
   row_indices.resize(kRows);
   std::iota(row_indices.begin(), row_indices.end(), 0);
   row_set_collection.Init();
+  HistMakerTrainParam hist_param;
+  std::vector<bst_node_t> dummy_sub;
+
+  common::BlockedSpace2d space{
+      1, [&](std::size_t nidx_in_set) { return row_set_collection[nidx_in_set].Size(); }, 256};
 
   /**
    * Generate hist with cat data.
    */
-  HistogramBuilder<CPUExpandEntry> cat_hist;
+  HistogramBuilder cat_hist;
|
||||||
for (auto const &gidx : cat_m->GetBatches<GHistIndexMatrix>(&ctx, {kBins, 0.5})) {
|
for (auto const &gidx : cat_m->GetBatches<GHistIndexMatrix>(&ctx, {kBins, 0.5})) {
|
||||||
auto total_bins = gidx.cut.TotalBins();
|
auto total_bins = gidx.cut.TotalBins();
|
||||||
cat_hist.Reset(total_bins, {kBins, 0.5}, omp_get_max_threads(), 1, false, false);
|
cat_hist.Reset(&ctx, total_bins, {kBins, 0.5}, false, false, &hist_param);
|
||||||
cat_hist.BuildHist(0, gidx, &tree, row_set_collection, nodes_for_explicit_hist_build, {},
|
cat_hist.AddHistRows(&tree, &nodes_to_build, &dummy_sub, false);
|
||||||
gpair.HostVector(), force_read_by_column);
|
cat_hist.BuildHist(0, space, gidx, row_set_collection, nodes_to_build,
|
||||||
|
linalg::MakeTensorView(&ctx, gpair.ConstHostSpan(), gpair.Size()),
|
||||||
|
force_read_by_column);
|
||||||
}
|
}
|
||||||
|
cat_hist.SyncHistogram(&tree, nodes_to_build, {});
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Generate hist with one hot encoded data.
|
* Generate hist with one hot encoded data.
|
||||||
*/
|
*/
|
||||||
auto x_encoded = OneHotEncodeFeature(x, n_categories);
|
auto x_encoded = OneHotEncodeFeature(x, n_categories);
|
||||||
auto encode_m = GetDMatrixFromData(x_encoded, kRows, n_categories);
|
auto encode_m = GetDMatrixFromData(x_encoded, kRows, n_categories);
|
||||||
HistogramBuilder<CPUExpandEntry> onehot_hist;
|
HistogramBuilder onehot_hist;
|
||||||
for (auto const &gidx : encode_m->GetBatches<GHistIndexMatrix>(&ctx, {kBins, 0.5})) {
|
for (auto const &gidx : encode_m->GetBatches<GHistIndexMatrix>(&ctx, {kBins, 0.5})) {
|
||||||
auto total_bins = gidx.cut.TotalBins();
|
auto total_bins = gidx.cut.TotalBins();
|
||||||
onehot_hist.Reset(total_bins, {kBins, 0.5}, omp_get_max_threads(), 1, false, false);
|
onehot_hist.Reset(&ctx, total_bins, {kBins, 0.5}, false, false, &hist_param);
|
||||||
onehot_hist.BuildHist(0, gidx, &tree, row_set_collection, nodes_for_explicit_hist_build, {},
|
onehot_hist.AddHistRows(&tree, &nodes_to_build, &dummy_sub, false);
|
||||||
gpair.HostVector(), force_read_by_column);
|
onehot_hist.BuildHist(0, space, gidx, row_set_collection, nodes_to_build,
|
||||||
|
linalg::MakeTensorView(&ctx, gpair.ConstHostSpan(), gpair.Size()),
|
||||||
|
force_read_by_column);
|
||||||
}
|
}
|
||||||
|
onehot_hist.SyncHistogram(&tree, nodes_to_build, {});
|
||||||
|
|
||||||
auto cat = cat_hist.Histogram()[0];
|
auto cat = cat_hist.Histogram()[0];
|
||||||
auto onehot = onehot_hist.Histogram()[0];
|
auto onehot = onehot_hist.Histogram()[0];
|
||||||
@ -383,19 +409,22 @@ void TestHistogramExternalMemory(Context const *ctx, BatchParam batch_param, boo
|
|||||||
batch_param.hess = hess;
|
batch_param.hess = hess;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<size_t> partition_size(1, 0);
|
std::vector<std::size_t> partition_size(1, 0);
|
||||||
size_t total_bins{0};
|
bst_bin_t total_bins{0};
|
||||||
size_t n_samples{0};
|
bst_row_t n_samples{0};
|
||||||
|
|
||||||
auto gpair = GenerateRandomGradients(m->Info().num_row_, 0.0, 1.0);
|
auto gpair = GenerateRandomGradients(m->Info().num_row_, 0.0, 1.0);
|
||||||
auto const &h_gpair = gpair.HostVector();
|
auto const &h_gpair = gpair.HostVector();
|
||||||
|
|
||||||
RegTree tree;
|
RegTree tree;
|
||||||
std::vector<CPUExpandEntry> nodes;
|
std::vector<bst_node_t> nodes{RegTree::kRoot};
|
||||||
nodes.emplace_back(0, tree.GetDepth(0));
|
common::BlockedSpace2d space{
|
||||||
|
1, [&](std::size_t nidx_in_set) { return partition_size.at(nidx_in_set); }, 256};
|
||||||
|
|
||||||
common::GHistRow multi_page;
|
common::GHistRow multi_page;
|
||||||
HistogramBuilder<CPUExpandEntry> multi_build;
|
HistogramBuilder multi_build;
|
||||||
|
HistMakerTrainParam hist_param;
|
||||||
|
std::vector<bst_node_t> dummy_sub;
|
||||||
{
|
{
|
||||||
/**
|
/**
|
||||||
* Multi page
|
* Multi page
|
||||||
@ -413,23 +442,21 @@ void TestHistogramExternalMemory(Context const *ctx, BatchParam batch_param, boo
|
|||||||
}
|
}
|
||||||
ASSERT_EQ(n_samples, m->Info().num_row_);
|
ASSERT_EQ(n_samples, m->Info().num_row_);
|
||||||
|
|
||||||
common::BlockedSpace2d space{
|
multi_build.Reset(ctx, total_bins, batch_param, false, false, &hist_param);
|
||||||
1, [&](size_t nidx_in_set) { return partition_size.at(nidx_in_set); },
|
multi_build.AddHistRows(&tree, &nodes, &dummy_sub, false);
|
||||||
256};
|
std::size_t page_idx{0};
|
||||||
|
|
||||||
multi_build.Reset(total_bins, batch_param, ctx->Threads(), rows_set.size(), false, false);
|
|
||||||
|
|
||||||
size_t page_idx{0};
|
|
||||||
for (auto const &page : m->GetBatches<GHistIndexMatrix>(ctx, batch_param)) {
|
for (auto const &page : m->GetBatches<GHistIndexMatrix>(ctx, batch_param)) {
|
||||||
multi_build.BuildHist(page_idx, space, page, &tree, rows_set.at(page_idx), nodes, {}, h_gpair,
|
multi_build.BuildHist(page_idx, space, page, rows_set[page_idx], nodes,
|
||||||
|
linalg::MakeTensorView(ctx, h_gpair, h_gpair.size()),
|
||||||
force_read_by_column);
|
force_read_by_column);
|
||||||
++page_idx;
|
++page_idx;
|
||||||
}
|
}
|
||||||
ASSERT_EQ(page_idx, 2);
|
multi_build.SyncHistogram(&tree, nodes, {});
|
||||||
multi_page = multi_build.Histogram()[0];
|
|
||||||
|
multi_page = multi_build.Histogram()[RegTree::kRoot];
|
||||||
}
|
}
|
||||||
|
|
||||||
HistogramBuilder<CPUExpandEntry> single_build;
|
HistogramBuilder single_build;
|
||||||
common::GHistRow single_page;
|
common::GHistRow single_page;
|
||||||
{
|
{
|
||||||
/**
|
/**
|
||||||
@ -438,18 +465,24 @@ void TestHistogramExternalMemory(Context const *ctx, BatchParam batch_param, boo
|
|||||||
common::RowSetCollection row_set_collection;
|
common::RowSetCollection row_set_collection;
|
||||||
InitRowPartitionForTest(&row_set_collection, n_samples);
|
InitRowPartitionForTest(&row_set_collection, n_samples);
|
||||||
|
|
||||||
single_build.Reset(total_bins, batch_param, ctx->Threads(), 1, false, false);
|
single_build.Reset(ctx, total_bins, batch_param, false, false, &hist_param);
|
||||||
SparsePage concat;
|
SparsePage concat;
|
||||||
std::vector<float> hess(m->Info().num_row_, 1.0f);
|
std::vector<float> hess(m->Info().num_row_, 1.0f);
|
||||||
for (auto const& page : m->GetBatches<SparsePage>()) {
|
for (auto const &page : m->GetBatches<SparsePage>()) {
|
||||||
concat.Push(page);
|
concat.Push(page);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto cut = common::SketchOnDMatrix(ctx, m.get(), batch_param.max_bin, false, hess);
|
auto cut = common::SketchOnDMatrix(ctx, m.get(), batch_param.max_bin, false, hess);
|
||||||
GHistIndexMatrix gmat(concat, {}, cut, batch_param.max_bin, false,
|
GHistIndexMatrix gmat(concat, {}, cut, batch_param.max_bin, false,
|
||||||
std::numeric_limits<double>::quiet_NaN(), ctx->Threads());
|
std::numeric_limits<double>::quiet_NaN(), ctx->Threads());
|
||||||
single_build.BuildHist(0, gmat, &tree, row_set_collection, nodes, {}, h_gpair, force_read_by_column);
|
|
||||||
single_page = single_build.Histogram()[0];
|
single_build.AddHistRows(&tree, &nodes, &dummy_sub, false);
|
||||||
|
single_build.BuildHist(0, space, gmat, row_set_collection, nodes,
|
||||||
|
linalg::MakeTensorView(ctx, h_gpair, h_gpair.size()),
|
||||||
|
force_read_by_column);
|
||||||
|
single_build.SyncHistogram(&tree, nodes, {});
|
||||||
|
|
||||||
|
single_page = single_build.Histogram()[RegTree::kRoot];
|
||||||
}
|
}
|
||||||
|
|
||||||
for (size_t i = 0; i < single_page.size(); ++i) {
|
for (size_t i = 0; i < single_page.size(); ++i) {
|
||||||
@ -473,4 +506,108 @@ TEST(CPUHistogram, ExternalMemory) {
|
|||||||
TestHistogramExternalMemory(&ctx, {kBins, sparse_thresh}, false, false);
|
TestHistogramExternalMemory(&ctx, {kBins, sparse_thresh}, false, false);
|
||||||
TestHistogramExternalMemory(&ctx, {kBins, sparse_thresh}, false, true);
|
TestHistogramExternalMemory(&ctx, {kBins, sparse_thresh}, false, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class OverflowTest : public ::testing::TestWithParam<std::tuple<bool, bool>> {
|
||||||
|
public:
|
||||||
|
std::vector<GradientPairPrecise> TestOverflow(bool limit, bool is_distributed,
|
||||||
|
bool is_col_split) {
|
||||||
|
bst_bin_t constexpr kBins = 256;
|
||||||
|
Context ctx;
|
||||||
|
HistMakerTrainParam hist_param;
|
||||||
|
if (limit) {
|
||||||
|
hist_param.Init(Args{{"internal_max_cached_hist_node", "1"}});
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<DMatrix> Xy =
|
||||||
|
is_col_split ? RandomDataGenerator{8192, 16, 0.5}.GenerateDMatrix(true)
|
||||||
|
: RandomDataGenerator{8192, 16, 0.5}.Bins(kBins).GenerateQuantileDMatrix(true);
|
||||||
|
if (is_col_split) {
|
||||||
|
Xy =
|
||||||
|
std::shared_ptr<DMatrix>{Xy->SliceCol(collective::GetWorldSize(), collective::GetRank())};
|
||||||
|
}
|
||||||
|
|
||||||
|
double sparse_thresh{TrainParam::DftSparseThreshold()};
|
||||||
|
auto batch = BatchParam{kBins, sparse_thresh};
|
||||||
|
bst_bin_t n_total_bins{0};
|
||||||
|
float split_cond{0};
|
||||||
|
for (auto const &page : Xy->GetBatches<GHistIndexMatrix>(&ctx, batch)) {
|
||||||
|
n_total_bins = page.cut.TotalBins();
|
||||||
|
// use a cut point in the second column for split
|
||||||
|
split_cond = page.cut.Values()[kBins + kBins / 2];
|
||||||
|
}
|
||||||
|
|
||||||
|
RegTree tree;
|
||||||
|
MultiHistogramBuilder hist_builder;
|
||||||
|
CHECK_EQ(Xy->Info().IsColumnSplit(), is_col_split);
|
||||||
|
|
||||||
|
hist_builder.Reset(&ctx, n_total_bins, tree.NumTargets(), batch, is_distributed,
|
||||||
|
Xy->Info().IsColumnSplit(), &hist_param);
|
||||||
|
|
||||||
|
std::vector<CommonRowPartitioner> partitioners;
|
||||||
|
partitioners.emplace_back(&ctx, Xy->Info().num_row_, /*base_rowid=*/0,
|
||||||
|
Xy->Info().IsColumnSplit());
|
||||||
|
|
||||||
|
auto gpair = GenerateRandomGradients(Xy->Info().num_row_, 0.0, 1.0);
|
||||||
|
|
||||||
|
CPUExpandEntry best;
|
||||||
|
hist_builder.BuildRootHist(Xy.get(), &tree, partitioners,
|
||||||
|
linalg::MakeTensorView(&ctx, gpair.ConstHostSpan(), gpair.Size(), 1),
|
||||||
|
best, batch);
|
||||||
|
|
||||||
|
best.split.Update(1.0f, 1, split_cond, false, false, GradStats{1.0, 1.0}, GradStats{1.0, 1.0});
|
||||||
|
tree.ExpandNode(best.nid, best.split.SplitIndex(), best.split.split_value, false,
|
||||||
|
/*base_weight=*/2.0f,
|
||||||
|
/*left_leaf_weight=*/1.0f, /*right_leaf_weight=*/1.0f, best.GetLossChange(),
|
||||||
|
/*sum_hess=*/2.0f, best.split.left_sum.GetHess(),
|
||||||
|
best.split.right_sum.GetHess());
|
||||||
|
|
||||||
|
std::vector<CPUExpandEntry> valid_candidates{best};
|
||||||
|
for (auto const &page : Xy->GetBatches<GHistIndexMatrix>(&ctx, batch)) {
|
||||||
|
partitioners.front().UpdatePosition(&ctx, page, valid_candidates, &tree);
|
||||||
|
}
|
||||||
|
CHECK_NE(partitioners.front()[tree.LeftChild(best.nid)].Size(), 0);
|
||||||
|
CHECK_NE(partitioners.front()[tree.RightChild(best.nid)].Size(), 0);
|
||||||
|
|
||||||
|
hist_builder.BuildHistLeftRight(
|
||||||
|
Xy.get(), &tree, partitioners, valid_candidates,
|
||||||
|
linalg::MakeTensorView(&ctx, gpair.ConstHostSpan(), gpair.Size(), 1), batch);
|
||||||
|
|
||||||
|
if (limit) {
|
||||||
|
CHECK(!hist_builder.Histogram(0).HistogramExists(best.nid));
|
||||||
|
} else {
|
||||||
|
CHECK(hist_builder.Histogram(0).HistogramExists(best.nid));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<GradientPairPrecise> result;
|
||||||
|
auto hist = hist_builder.Histogram(0)[tree.LeftChild(best.nid)];
|
||||||
|
std::copy(hist.cbegin(), hist.cend(), std::back_inserter(result));
|
||||||
|
hist = hist_builder.Histogram(0)[tree.RightChild(best.nid)];
|
||||||
|
std::copy(hist.cbegin(), hist.cend(), std::back_inserter(result));
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
void RunTest() {
|
||||||
|
auto param = GetParam();
|
||||||
|
auto res0 = this->TestOverflow(false, std::get<0>(param), std::get<1>(param));
|
||||||
|
auto res1 = this->TestOverflow(true, std::get<0>(param), std::get<1>(param));
|
||||||
|
ASSERT_EQ(res0, res1);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
auto MakeParamsForTest() {
|
||||||
|
std::vector<std::tuple<bool, bool>> configs;
|
||||||
|
for (auto i : {true, false}) {
|
||||||
|
for (auto j : {true, false}) {
|
||||||
|
configs.emplace_back(i, j);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return configs;
|
||||||
|
}
|
||||||
|
} // anonymous namespace
|
||||||
|
|
||||||
|
TEST_P(OverflowTest, Overflow) { this->RunTest(); }
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(CPUHistogram, OverflowTest, ::testing::ValuesIn(MakeParamsForTest()));
|
||||||
} // namespace xgboost::tree
|
} // namespace xgboost::tree
|
||||||
|
|||||||
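The new limit is surfaced to Python as a regular booster parameter in the test changes further down. A minimal sketch of how that could look outside the test-suite, assuming the name used by HistMakerTrainParam above, internal_max_cached_hist_node, is forwarded by xgb.train like any other parameter:

import numpy as np
import xgboost as xgb

rng = np.random.default_rng(2023)
X = rng.normal(size=(4096, 16))
y = X[:, 0] + rng.normal(scale=0.1, size=X.shape[0])

# QuantileDMatrix pre-bins the data, matching the hist tree method.
dtrain = xgb.QuantileDMatrix(X, label=y, max_bin=256)

params = {
    "tree_method": "hist",
    "max_depth": 8,
    # Assumption: the value is passed through to HistMakerTrainParam; a bound of 1
    # forces deep nodes to recycle histogram buffers, as in OverflowTest above.
    "internal_max_cached_hist_node": 1,
}
booster = xgb.train(params, dtrain, num_boost_round=8)

The hypothesis-based tests below do essentially the same thing by merging the drawn cache_param dictionary into the training parameters.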
@@ -2,22 +2,24 @@
  * Copyright 2022-2023 by XGBoost Contributors
  */
 #include <gtest/gtest.h>
 #include <xgboost/base.h>                       // for GradientPairInternal, GradientPairPrecise
 #include <xgboost/data.h>                       // for MetaInfo
 #include <xgboost/host_device_vector.h>         // for HostDeviceVector
 #include <xgboost/span.h>                       // for operator!=, Span, SpanIterator
 
 #include <algorithm>                            // for max, max_element, next_permutation, copy
 #include <cmath>                                // for isnan
 #include <cstddef>                              // for size_t
 #include <cstdint>                              // for int32_t, uint64_t, uint32_t
 #include <limits>                               // for numeric_limits
 #include <numeric>                              // for iota
 #include <tuple>                                // for make_tuple, tie, tuple
 #include <utility>                              // for pair
 #include <vector>                               // for vector
 
 #include "../../../src/common/hist_util.h"      // for HistogramCuts, HistCollection, GHistRow
+#include "../../../src/tree/hist/hist_cache.h"  // for HistogramCollection
+#include "../../../src/tree/hist/param.h"       // for HistMakerTrainParam
 #include "../../../src/tree/param.h"            // for TrainParam, GradStats
 #include "../../../src/tree/split_evaluator.h"  // for TreeEvaluator
 #include "../helpers.h"                         // for SimpleLCG, SimpleRealUniformDistribution
@@ -35,7 +37,7 @@ class TestPartitionBasedSplit : public ::testing::Test {
   MetaInfo info_;
   float best_score_{-std::numeric_limits<float>::infinity()};
   common::HistogramCuts cuts_;
-  common::HistCollection hist_;
+  BoundedHistCollection hist_;
   GradientPairPrecise total_gpair_;
 
   void SetUp() override {
@@ -56,9 +58,9 @@ class TestPartitionBasedSplit : public ::testing::Test {
 
     cuts_.min_vals_.Resize(1);
 
-    hist_.Init(cuts_.TotalBins());
-    hist_.AddHistRow(0);
-    hist_.AllocateAllData();
+    HistMakerTrainParam hist_param;
+    hist_.Reset(cuts_.TotalBins(), hist_param.internal_max_cached_hist_node);
+    hist_.AllocateHistograms({0});
     auto node_hist = hist_[0];
 
     SimpleLCG lcg;
@@ -7,6 +7,7 @@ from hypothesis import given, settings, strategies
 import xgboost as xgb
 from xgboost import testing as tm
 from xgboost.testing.data import check_inf
+from xgboost.testing.data_iter import run_mixed_sparsity
 
 sys.path.append("tests/python")
 import test_quantile_dmatrix as tqd
@@ -232,3 +233,6 @@ class TestQuantileDMatrix:
 
         rng = cp.random.default_rng(1994)
         check_inf(rng)
+
+    def test_mixed_sparsity(self) -> None:
+        run_mixed_sparsity("cuda")
@@ -16,6 +16,7 @@ from xgboost.testing import (
     predictor_equal,
 )
 from xgboost.testing.data import check_inf, np_dtypes
+from xgboost.testing.data_iter import run_mixed_sparsity
 
 
 class TestQuantileDMatrix:
@@ -334,3 +335,6 @@ class TestQuantileDMatrix:
 
         with pytest.raises(ValueError, match="consistent"):
             xgb.train({}, Xy, num_boost_round=2, xgb_model=booster)
+
+    def test_mixed_sparsity(self) -> None:
+        run_mixed_sparsity("cpu")
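The run_mixed_sparsity helper lives in the new xgboost.testing.data_iter module and its body is not part of this diff. A rough sketch of what such a check could look like for the CPU path, assuming it streams batches of differing sparsity through a DataIter and compares the result with a single concatenated matrix (the class and function names below are illustrative, not the actual implementation):

import numpy as np
import xgboost


class _MixedSparsityIter(xgboost.DataIter):
    """Yields batches whose sparsity differs from batch to batch."""

    def __init__(self, batches):
        self._batches = batches
        self._it = 0
        super().__init__()

    def next(self, input_data):
        if self._it == len(self._batches):
            return 0  # no more batches
        X, y = self._batches[self._it]
        input_data(data=X, label=y)
        self._it += 1
        return 1

    def reset(self):
        self._it = 0


def check_mixed_sparsity_cpu():
    rng = np.random.default_rng(0)
    batches = []
    for sparsity in (0.1, 0.8, 0.3):
        X = rng.normal(size=(256, 8))
        X[rng.random(X.shape) < sparsity] = np.nan  # missing values create sparsity
        y = rng.normal(size=256)
        batches.append((X, y))

    from_iter = xgboost.QuantileDMatrix(_MixedSparsityIter(batches), max_bin=64)

    X_all = np.concatenate([b[0] for b in batches])
    y_all = np.concatenate([b[1] for b in batches])
    from_array = xgboost.QuantileDMatrix(X_all, label=y_all, max_bin=64)

    assert from_iter.num_row() == from_array.num_row()
    assert from_iter.num_col() == from_array.num_col()

The CUDA variant would build the batches with cupy arrays instead of numpy; the iterator protocol is identical.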
@@ -11,6 +11,7 @@ from xgboost import testing as tm
 from xgboost.testing.params import (
     cat_parameter_strategy,
     exact_parameter_strategy,
+    hist_cache_strategy,
     hist_multi_parameter_strategy,
     hist_parameter_strategy,
 )
@@ -40,14 +41,22 @@ class TestTreeMethodMulti:
     @given(
         exact_parameter_strategy,
         hist_parameter_strategy,
+        hist_cache_strategy,
         strategies.integers(1, 20),
         tm.multi_dataset_strategy,
     )
     @settings(deadline=None, print_blob=True)
-    def test_approx(self, param, hist_param, num_rounds, dataset):
+    def test_approx(
+        self, param: Dict[str, Any],
+        hist_param: Dict[str, Any],
+        cache_param: Dict[str, Any],
+        num_rounds: int,
+        dataset: tm.TestDataset,
+    ) -> None:
         param["tree_method"] = "approx"
         param = dataset.set_params(param)
         param.update(hist_param)
+        param.update(cache_param)
         result = train_result(param, dataset.get_dmat(), num_rounds)
         note(result)
         assert tm.non_increasing(result["train"][dataset.metric])
@@ -55,18 +64,25 @@ class TestTreeMethodMulti:
     @given(
         exact_parameter_strategy,
         hist_multi_parameter_strategy,
+        hist_cache_strategy,
         strategies.integers(1, 20),
         tm.multi_dataset_strategy,
     )
     @settings(deadline=None, print_blob=True)
     def test_hist(
-        self, param: dict, hist_param: dict, num_rounds: int, dataset: tm.TestDataset
+        self,
+        param: Dict[str, Any],
+        hist_param: Dict[str, Any],
+        cache_param: Dict[str, Any],
+        num_rounds: int,
+        dataset: tm.TestDataset,
     ) -> None:
         if dataset.name.endswith("-l1"):
             return
         param["tree_method"] = "hist"
         param = dataset.set_params(param)
         param.update(hist_param)
+        param.update(cache_param)
         result = train_result(param, dataset.get_dmat(), num_rounds)
         note(result)
         assert tm.non_increasing(result["train"][dataset.metric])
@@ -91,14 +107,23 @@ class TestTreeMethod:
     @given(
         exact_parameter_strategy,
         hist_parameter_strategy,
+        hist_cache_strategy,
         strategies.integers(1, 20),
         tm.make_dataset_strategy(),
     )
     @settings(deadline=None, print_blob=True)
-    def test_approx(self, param, hist_param, num_rounds, dataset):
+    def test_approx(
+        self,
+        param: Dict[str, Any],
+        hist_param: Dict[str, Any],
+        cache_param: Dict[str, Any],
+        num_rounds: int,
+        dataset: tm.TestDataset,
+    ) -> None:
         param["tree_method"] = "approx"
         param = dataset.set_params(param)
         param.update(hist_param)
+        param.update(cache_param)
         result = train_result(param, dataset.get_dmat(), num_rounds)
         note(result)
         assert tm.non_increasing(result["train"][dataset.metric])
@@ -130,17 +155,25 @@ class TestTreeMethod:
     @given(
         exact_parameter_strategy,
         hist_parameter_strategy,
+        hist_cache_strategy,
         strategies.integers(1, 20),
         tm.make_dataset_strategy()
     )
     @settings(deadline=None, print_blob=True)
-    def test_hist(self, param: dict, hist_param: dict, num_rounds: int, dataset: tm.TestDataset) -> None:
-        param['tree_method'] = 'hist'
+    def test_hist(
+        self, param: Dict[str, Any],
+        hist_param: Dict[str, Any],
+        cache_param: Dict[str, Any],
+        num_rounds: int,
+        dataset: tm.TestDataset,
+    ) -> None:
+        param["tree_method"] = "hist"
         param = dataset.set_params(param)
         param.update(hist_param)
+        param.update(cache_param)
         result = train_result(param, dataset.get_dmat(), num_rounds)
         note(result)
-        assert tm.non_increasing(result['train'][dataset.metric])
+        assert tm.non_increasing(result["train"][dataset.metric])
 
     def test_hist_categorical(self):
         # hist must be same as exact on all-categorial data
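The hist_cache_strategy drawn by these tests is defined in xgboost.testing.params and is not shown in this hunk; conceptually it only needs to draw a value for internal_max_cached_hist_node. An illustrative stand-in, which may differ from the real definition:

from hypothesis import strategies

# Illustrative only: draw the histogram-cache bound, covering the degenerate
# limit (1), a small bound, and an effectively unbounded cache.
hist_cache_strategy = strategies.fixed_dictionaries(
    {"internal_max_cached_hist_node": strategies.sampled_from([1, 4, 1024, 2**31 - 1])}
)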
@@ -24,7 +24,7 @@ from sklearn.datasets import make_classification, make_regression
 import xgboost as xgb
 from xgboost import testing as tm
 from xgboost.data import _is_cudf_df
-from xgboost.testing.params import hist_parameter_strategy
+from xgboost.testing.params import hist_cache_strategy, hist_parameter_strategy
 from xgboost.testing.shared import (
     get_feature_weights,
     validate_data_initialization,
@@ -1512,14 +1512,23 @@ class TestWithDask:
         else:
             assert history[-1] < history[0]
 
-    @given(params=hist_parameter_strategy, dataset=tm.make_dataset_strategy())
+    @given(
+        params=hist_parameter_strategy,
+        cache_param=hist_cache_strategy,
+        dataset=tm.make_dataset_strategy(),
+    )
     @settings(
         deadline=None, max_examples=10, suppress_health_check=suppress, print_blob=True
     )
     def test_hist(
-        self, params: Dict, dataset: tm.TestDataset, client: "Client"
+        self,
+        params: Dict[str, Any],
+        cache_param: Dict[str, Any],
+        dataset: tm.TestDataset,
+        client: "Client",
     ) -> None:
         num_rounds = 10
+        params.update(cache_param)
         self.run_updater_test(client, params, num_rounds, dataset, "hist")
 
     def test_quantile_dmatrix(self, client: Client) -> None:
@@ -1579,14 +1588,23 @@ class TestWithDask:
         rmse = result["history"]["Valid"]["rmse"][-1]
         assert rmse < 32.0
 
-    @given(params=hist_parameter_strategy, dataset=tm.make_dataset_strategy())
+    @given(
+        params=hist_parameter_strategy,
+        cache_param=hist_cache_strategy,
+        dataset=tm.make_dataset_strategy()
+    )
     @settings(
         deadline=None, max_examples=10, suppress_health_check=suppress, print_blob=True
     )
     def test_approx(
-        self, client: "Client", params: Dict, dataset: tm.TestDataset
+        self,
+        client: "Client",
+        params: Dict,
+        cache_param: Dict[str, Any],
+        dataset: tm.TestDataset,
     ) -> None:
         num_rounds = 10
+        params.update(cache_param)
         self.run_updater_test(client, params, num_rounds, dataset, "approx")
 
     def test_adaptive(self) -> None:
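The distributed tests above simply merge the drawn cache_param into the booster parameters before calling run_updater_test. Outside the test-suite the same thing can be done directly with xgb.dask.train; a sketch under the assumption that the parameter name from the C++ tests is accepted unchanged:

from dask import array as da
from distributed import Client, LocalCluster

import xgboost as xgb


def main() -> None:
    with LocalCluster(n_workers=2, threads_per_worker=2) as cluster:
        with Client(cluster) as client:
            X = da.random.random((50_000, 16), chunks=(12_500, 16))
            y = da.random.random(50_000, chunks=12_500)
            dtrain = xgb.dask.DaskQuantileDMatrix(client, X, y)
            out = xgb.dask.train(
                client,
                {
                    "tree_method": "hist",
                    # Assumed: the cache bound is just another booster parameter.
                    "internal_max_cached_hist_node": 512,
                },
                dtrain,
                num_boost_round=10,
            )
            print(out["history"])


if __name__ == "__main__":
    main()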
@@ -2239,7 +2257,7 @@ async def test_worker_left(c, s, a, b):
     )
     await async_poll_for(lambda: len(s.workers) == 2, timeout=5)
     with pytest.raises(RuntimeError, match="Missing"):
         await xgb.dask.train(
             c,
             {},
             d_train,
@@ -2256,7 +2274,7 @@ async def test_worker_restarted(c, s, a, b):
     )
     await c.restart_workers([a.worker_address])
     with pytest.raises(RuntimeError, match="Missing"):
         await xgb.dask.train(
             c,
             {},
             d_train,