[POC] Experimental support for l1 error. (#7812)

Support adaptive trees, a feature available in both sklearn and LightGBM. The tree leaves are recomputed from the residuals between labels and predictions after the tree is constructed.

For L1 error, the optimal leaf value is the median (the 50th percentile) of the residuals.
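
For reference, the standard argument (not part of this patch): with residuals r_i = y_i - \hat{y}_i, the leaf value w minimizes \sum_i |r_i - w|, and

    \frac{d}{dw} \sum_i |r_i - w| = \#\{i : r_i < w\} - \#\{i : r_i > w\} = 0

holds exactly when w is a median of the r_i.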

This is marked as experimental support for the following reasons:
- The value is not well defined for distributed training, where some leaves may be empty on local workers. For now, the original leaf value is used when computing the average with the other workers, which might cause significant error.
- Some follow-ups are required: the exact tree method, the pruner, and optimization of the quantile function. We also need to calculate the initial estimation.
Jiaming Yuan, 2022-04-26 21:41:55 +08:00, committed by GitHub
parent ad06172c6b, commit fdf533f2b9
64 changed files with 1727 additions and 336 deletions

@@ -24,6 +24,7 @@
 #include "../src/objective/rank_obj.cc"
 #include "../src/objective/hinge.cc"
 #include "../src/objective/aft_obj.cc"
+#include "../src/objective/adaptive.cc"
 // gbms
 #include "../src/gbm/gbm.cc"

@@ -400,7 +400,6 @@
         "reg_loss_param"
       ]
     },
-
     {
       "type": "object",
       "properties": {
@@ -433,6 +432,14 @@
         "tweedie_regression_param"
       ]
     },
+    {
+      "properties": {
+        "name": {
+          "const": "reg:absoluteerror"
+        }
+      },
+      "type": "object"
+    },
     {
       "type": "object",
       "properties": {

@@ -349,6 +349,7 @@ Specify the learning task and the corresponding learning objective. The objectiv
 - ``reg:squaredlogerror``: regression with squared log loss :math:`\frac{1}{2}[log(pred + 1) - log(label + 1)]^2`. All input labels are required to be greater than -1. Also, see metric ``rmsle`` for possible issue with this objective.
 - ``reg:logistic``: logistic regression.
 - ``reg:pseudohubererror``: regression with Pseudo Huber loss, a twice differentiable alternative to absolute loss.
+- ``reg:absoluteerror``: Regression with L1 error. When tree model is used, leaf value is refreshed after tree construction.
 - ``binary:logistic``: logistic regression for binary classification, output probability
 - ``binary:logitraw``: logistic regression for binary classification, output score before logistic transformation
 - ``binary:hinge``: hinge loss for binary classification. This makes predictions of 0 or 1, rather than producing probabilities.

@@ -90,9 +90,8 @@ class GradientBooster : public Model, public Configurable {
    * \param prediction The output prediction cache entry that needs to be updated.
    *        the booster may change content of gpair
    */
-  virtual void DoBoost(DMatrix* p_fmat,
-                       HostDeviceVector<GradientPair>* in_gpair,
-                       PredictionCacheEntry*) = 0;
+  virtual void DoBoost(DMatrix* p_fmat, HostDeviceVector<GradientPair>* in_gpair,
+                       PredictionCacheEntry*, ObjFunction const* obj) = 0;
   /*!
    * \brief generate predictions for given feature matrix

@@ -670,9 +670,13 @@ class Tensor {
    * See \ref TensorView for parameters of this constructor.
    */
   template <typename I, int32_t D>
-  explicit Tensor(I const (&shape)[D], int32_t device) {
+  explicit Tensor(I const (&shape)[D], int32_t device)
+      : Tensor{common::Span<I const, D>{shape}, device} {}
+
+  template <typename I, size_t D>
+  explicit Tensor(common::Span<I const, D> shape, int32_t device) {
     // No device unroll as this is a host only function.
-    std::copy(shape, shape + D, shape_);
+    std::copy(shape.data(), shape.data() + D, shape_);
     for (auto i = D; i < kDim; ++i) {
       shape_[i] = 1;
     }

@@ -1,5 +1,5 @@
 /*!
- * Copyright 2014-2019 by Contributors
+ * Copyright 2014-2022 by Contributors
  * \file objective.h
  * \brief interface of objective function used by xgboost.
  * \author Tianqi Chen, Kailong Chen
@@ -22,6 +22,8 @@
 namespace xgboost {
 
+class RegTree;
+
 /*! \brief interface of objective function */
 class ObjFunction : public Configurable {
  protected:
@@ -88,6 +90,22 @@ class ObjFunction : public Configurable {
     return 1;
   }
 
+  /**
+   * \brief Update the leaf values after a tree is built. Needed for objectives with 0
+   *        hessian.
+   *
+   *   Note that the leaf update is not well defined for distributed training as XGBoost
+   *   computes only an average of quantiles between workers. This breaks when some
+   *   leaves have no sample assigned on a local worker.
+   *
+   * \param position The leaf index for each row.
+   * \param info MetaInfo providing labels and weights.
+   * \param prediction Model prediction after transformation.
+   * \param p_tree Tree that needs to be updated.
+   */
+  virtual void UpdateTreeLeaf(HostDeviceVector<bst_node_t> const& position, MetaInfo const& info,
+                              HostDeviceVector<float> const& prediction, RegTree* p_tree) const {}
+
   /*!
    * \brief Create an objective function according to name.
    * \param tparam Generic parameters.

@@ -33,13 +33,18 @@ struct ObjInfo {
   } task;
   // Does the objective have constant hessian value?
   bool const_hess{false};
+  bool zero_hess{false};
 
-  explicit ObjInfo(Task t) : task{t} {}
-  ObjInfo(Task t, bool khess) : task{t}, const_hess{khess} {}
+  ObjInfo(Task t) : task{t} {}  // NOLINT
+  ObjInfo(Task t, bool khess, bool zhess) : task{t}, const_hess{khess}, zero_hess(zhess) {}
 
   XGBOOST_DEVICE bool UseOneHot() const {
     return (task != ObjInfo::kRegression && task != ObjInfo::kBinary);
   }
+  /**
+   * \brief Use adaptive tree if the objective doesn't have valid hessian value.
+   */
+  XGBOOST_DEVICE bool UpdateTreeLeaf() const { return zero_hess; }
 };
 }  // namespace xgboost
 #endif  // XGBOOST_TASK_H_

@@ -49,18 +49,25 @@ class TreeUpdater : public Configurable {
    *        existing trees.
    */
   virtual bool CanModifyTree() const { return false; }
+  /*!
+   * \brief Whether the out_position in `Update` is valid. This determines whether
+   *        adaptive tree can be used.
+   */
+  virtual bool HasNodePosition() const { return false; }
   /*!
    * \brief perform update to the tree models
    * \param gpair the gradient pair statistics of the data
    * \param data The data matrix passed to the updater.
-   * \param trees references the trees to be updated, updater will change the content of trees
+   * \param out_position The leaf index for each row. The index is negated if that row is
+   *                     removed during sampling. So the 3rd node is ~3.
+   * \param out_trees references the trees to be updated, updater will change the content of trees
    * note: all the trees in the vector are updated, with the same statistics,
    *       but maybe different random seeds, usually one tree is passed in at a time,
    *       there can be multiple trees when we train random forest style model
    */
-  virtual void Update(HostDeviceVector<GradientPair>* gpair,
-                      DMatrix* data,
-                      const std::vector<RegTree*>& trees) = 0;
+  virtual void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* data,
+                      common::Span<HostDeviceVector<bst_node_t>> out_position,
+                      const std::vector<RegTree*>& out_trees) = 0;
   /*!
    * \brief determines whether updater has enough knowledge about a given dataset

@@ -1,5 +1,5 @@
 /*!
- * Copyright 2015-2019 by Contributors
+ * Copyright 2015-2022 by Contributors
  * \file custom_metric.cc
  * \brief This is an example to define plugin of xgboost.
  *  This plugin defines the additional metric function.
@@ -31,13 +31,9 @@ DMLC_REGISTER_PARAMETER(MyLogisticParam);
 // Implement the interface.
 class MyLogistic : public ObjFunction {
  public:
-  void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
-    param_.UpdateAllowUnknown(args);
-  }
+  void Configure(const Args& args) override { param_.UpdateAllowUnknown(args); }
 
-  struct ObjInfo Task() const override {
-    return {ObjInfo::kRegression, false};
-  }
+  ObjInfo Task() const override { return ObjInfo::kRegression; }
 
   void GetGradient(const HostDeviceVector<bst_float> &preds,
                    const MetaInfo &info,

@@ -1,5 +1,5 @@
 /*!
- * Copyright 2015-2018 by Contributors
+ * Copyright 2015-2022 by XGBoost Contributors
  * \file common.h
  * \brief Common utilities
 */
@@ -14,12 +14,12 @@
 #include <exception>
 #include <functional>
 #include <limits>
-#include <type_traits>
-#include <vector>
-#include <string>
-#include <sstream>
 #include <numeric>
+#include <sstream>
+#include <string>
+#include <type_traits>
 #include <utility>
+#include <vector>
 
 #if defined(__CUDACC__)
 #include <thrust/system/cuda/error.h>
@@ -164,6 +164,67 @@ class Range {
   Iterator end_;
 };
 
+/**
+ * \brief Transform iterator that takes an index and calls transform operator.
+ *
+ *   This is CPU-only right now as taking host device function as operator complicates
+ *   the code. For device side one can use `thrust::transform_iterator` instead.
+ */
+template <typename Fn>
+class IndexTransformIter {
+  size_t iter_{0};
+  Fn fn_;
+
+ public:
+  using iterator_category = std::random_access_iterator_tag;  // NOLINT
+  using value_type = std::result_of_t<Fn(size_t)>;            // NOLINT
+  using difference_type = detail::ptrdiff_t;                  // NOLINT
+  using reference = std::add_lvalue_reference_t<value_type>;  // NOLINT
+  using pointer = std::add_pointer_t<value_type>;             // NOLINT
+
+ public:
+  /**
+   * \param op Transform operator, takes a size_t index as input.
+   */
+  explicit IndexTransformIter(Fn &&op) : fn_{op} {}
+  IndexTransformIter(IndexTransformIter const &) = default;
+
+  value_type operator*() const { return fn_(iter_); }
+
+  auto operator-(IndexTransformIter const &that) const { return iter_ - that.iter_; }
+
+  IndexTransformIter &operator++() {
+    iter_++;
+    return *this;
+  }
+  IndexTransformIter operator++(int) {
+    auto ret = *this;
+    ++(*this);
+    return ret;
+  }
+  IndexTransformIter &operator+=(difference_type n) {
+    iter_ += n;
+    return *this;
+  }
+  IndexTransformIter &operator-=(difference_type n) {
+    (*this) += -n;
+    return *this;
+  }
+  IndexTransformIter operator+(difference_type n) const {
+    auto ret = *this;
+    return ret += n;
+  }
+  IndexTransformIter operator-(difference_type n) const {
+    auto ret = *this;
+    return ret -= n;
+  }
+};
+
+template <typename Fn>
+auto MakeIndexTransformIter(Fn&& fn) {
+  return IndexTransformIter<Fn>(std::forward<Fn>(fn));
+}
+
 int AllVisibleGPUs();
 
 inline void AssertGPUSupport() {
@@ -191,13 +252,39 @@ std::vector<Idx> ArgSort(Container const &array, Comp comp = std::less<V>{}) {
 struct OptionalWeights {
   Span<float const> weights;
-  float dft{1.0f};
+  float dft{1.0f};  // fixme: make this compile time constant
 
   explicit OptionalWeights(Span<float const> w) : weights{w} {}
   explicit OptionalWeights(float w) : dft{w} {}
 
   XGBOOST_DEVICE float operator[](size_t i) const { return weights.empty() ? dft : weights[i]; }
 };
+
+/**
+ * Last index of a group in a CSR style of index pointer.
+ */
+template <typename Indexable>
+XGBOOST_DEVICE size_t LastOf(size_t group, Indexable const &indptr) {
+  return indptr[group + 1] - 1;
+}
+
+/**
+ * \brief Run length encode on CPU, input must be sorted.
+ */
+template <typename Iter, typename Idx>
+void RunLengthEncode(Iter begin, Iter end, std::vector<Idx> *p_out) {
+  auto &out = *p_out;
+  out = std::vector<Idx>{0};
+  size_t n = std::distance(begin, end);
+  for (size_t i = 1; i < n; ++i) {
+    if (begin[i] != begin[i - 1]) {
+      out.push_back(i);
+    }
+  }
+  if (out.back() != n) {
+    out.push_back(n);
+  }
+}
 }  // namespace common
 }  // namespace xgboost
 #endif  // XGBOOST_COMMON_COMMON_H_
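
Outside the diff, a minimal host-side sketch of how the new iterator is meant to be used (the residual functor mirrors the objective code later in this commit; the include path and values are illustrative):

#include <cstddef>
#include <vector>

#include "common.h"  // the header patched above; include path is assumed

int main() {
  std::vector<float> label{1.f, 2.f, 3.f};
  std::vector<float> predt{0.5f, 2.5f, 2.f};
  // Iterate over residuals lazily: the functor is invoked with an index.
  auto it = xgboost::common::MakeIndexTransformIter(
      [&](std::size_t i) { return label[i] - predt[i]; });
  float r0 = *it;        // 0.5f
  float r2 = *(it + 2);  // 1.0f, via random-access arithmetic
  return static_cast<int>(r0 + r2 != 1.5f);
}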

@@ -1,5 +1,5 @@
 /*!
- * Copyright 2017-2021 XGBoost contributors
+ * Copyright 2017-2022 XGBoost contributors
 */
 #pragma once
 #include <thrust/device_ptr.h>
@@ -1537,6 +1537,43 @@ void SegmentedArgSort(xgboost::common::Span<U> values,
                                sorted_idx.size_bytes(), cudaMemcpyDeviceToDevice));
 }
 
+/**
+ * \brief Different from the above one, this one can handle cases where segment doesn't
+ *        start from 0, but as a result it uses comparison sort.
+ */
+template <typename SegIt, typename ValIt>
+void SegmentedArgSort(SegIt seg_begin, SegIt seg_end, ValIt val_begin, ValIt val_end,
+                      dh::device_vector<size_t> *p_sorted_idx) {
+  using Tup = thrust::tuple<int32_t, float>;
+  auto &sorted_idx = *p_sorted_idx;
+  size_t n = std::distance(val_begin, val_end);
+  sorted_idx.resize(n);
+  dh::Iota(dh::ToSpan(sorted_idx));
+  dh::device_vector<Tup> keys(sorted_idx.size());
+  auto key_it = dh::MakeTransformIterator<Tup>(thrust::make_counting_iterator(0ul),
+                                               [=] XGBOOST_DEVICE(size_t i) -> Tup {
+                                                 int32_t leaf_idx;
+                                                 if (i < *seg_begin) {
+                                                   leaf_idx = -1;
+                                                 } else {
+                                                   leaf_idx = dh::SegmentId(seg_begin, seg_end, i);
+                                                 }
+                                                 auto residue = val_begin[i];
+                                                 return thrust::make_tuple(leaf_idx, residue);
+                                               });
+  dh::XGBCachingDeviceAllocator<char> caching;
+  thrust::copy(thrust::cuda::par(caching), key_it, key_it + keys.size(), keys.begin());
+
+  dh::XGBDeviceAllocator<char> alloc;
+  thrust::stable_sort_by_key(thrust::cuda::par(alloc), keys.begin(), keys.end(), sorted_idx.begin(),
+                             [=] XGBOOST_DEVICE(Tup const &l, Tup const &r) {
+                               if (thrust::get<0>(l) != thrust::get<0>(r)) {
+                                 return thrust::get<0>(l) < thrust::get<0>(r);  // segment index
+                               }
+                               return thrust::get<1>(l) < thrust::get<1>(r);  // residue
+                             });
+}
+
 class CUDAStreamView;
 
 class CUDAEvent {
@@ -1600,5 +1637,6 @@ class CUDAStream {
   }
 
   CUDAStreamView View() const { return CUDAStreamView{stream_}; }
+  void Sync() { this->View().Sync(); }
 };
 }  // namespace dh

@@ -13,6 +13,7 @@ namespace xgboost {
 namespace linalg {
 template <typename T, int32_t D, typename Fn>
 void ElementWiseKernelDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
+  dh::safe_cuda(cudaSetDevice(t.DeviceIdx()));
   static_assert(std::is_void<std::result_of_t<Fn(size_t, T&)>>::value,
                 "For function with return, use transform instead.");
   if (t.Contiguous()) {
@@ -40,7 +41,7 @@ void ElementWiseTransformDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_
 }
 
 template <typename T, int32_t D, typename Fn>
-void ElementWiseKernel(GenericParameter const* ctx, linalg::TensorView<T, D> t, Fn&& fn) {
+void ElementWiseKernel(Context const* ctx, linalg::TensorView<T, D> t, Fn&& fn) {
   ctx->IsCPU() ? ElementWiseKernelHost(t, ctx->Threads(), fn) : ElementWiseKernelDevice(t, fn);
 }
 }  // namespace linalg

@@ -12,10 +12,12 @@
 #include <algorithm>
 #include <memory>
 #include <utility>
+#include <limits>
 #include <vector>
 
 #include "categorical.h"
 #include "column_matrix.h"
+#include "xgboost/generic_parameters.h"
 #include "xgboost/tree_model.h"
 
 namespace xgboost {
@@ -254,7 +256,7 @@
       n_left += mem_blocks_[j]->n_left;
     }
     size_t n_right = 0;
-    for (size_t j = blocks_offsets_[i]; j < blocks_offsets_[i+1]; ++j) {
+    for (size_t j = blocks_offsets_[i]; j < blocks_offsets_[i + 1]; ++j) {
       mem_blocks_[j]->n_offset_right = n_left + n_right;
       n_right += mem_blocks_[j]->n_right;
     }
@@ -279,6 +281,30 @@
     return blocks_offsets_[nid] + begin / BlockSize;
   }
 
+  // Copy row partitions into global cache for reuse in objective
+  template <typename Sampledp>
+  void LeafPartition(Context const* ctx, RegTree const& tree, RowSetCollection const& row_set,
+                     std::vector<bst_node_t>* p_position, Sampledp sampledp) const {
+    auto& h_pos = *p_position;
+    h_pos.resize(row_set.Data()->size(), std::numeric_limits<bst_node_t>::max());
+
+    auto p_begin = row_set.Data()->data();
+    ParallelFor(row_set.Size(), ctx->Threads(), [&](size_t i) {
+      auto const& node = row_set[i];
+      if (node.node_id < 0) {
+        return;
+      }
+      CHECK(tree[node.node_id].IsLeaf());
+      if (node.begin) {  // guard for empty node.
+        size_t ptr_offset = node.end - p_begin;
+        CHECK_LE(ptr_offset, row_set.Data()->size()) << node.node_id;
+        for (auto idx = node.begin; idx != node.end; ++idx) {
+          h_pos[*idx] = sampledp(*idx) ? ~node.node_id : node.node_id;
+        }
+      }
+    });
+  }
+
  protected:
   struct BlockInfo{
     size_t n_left;

@@ -1,5 +1,5 @@
 /*!
- * Copyright 2017 by Contributors
+ * Copyright 2017-2022 by Contributors
  * \file row_set.h
  * \brief Quick Utility to compute subset of rows
  * \author Philip Cho, Tianqi Chen
@@ -15,10 +15,15 @@
 namespace xgboost {
 namespace common {
 
 /*! \brief collection of rowset */
 class RowSetCollection {
  public:
+  RowSetCollection() = default;
+  RowSetCollection(RowSetCollection const&) = delete;
+  RowSetCollection(RowSetCollection&&) = default;
+  RowSetCollection& operator=(RowSetCollection const&) = delete;
+  RowSetCollection& operator=(RowSetCollection&&) = default;
+
   /*! \brief data structure to store an instance set, a subset of
    *  rows (instances) associated with a particular node in a decision
    *  tree. */
@@ -38,20 +43,17 @@
       return end - begin;
     }
   };
 
-  /* \brief specifies how to split a rowset into two */
-  struct Split {
-    std::vector<size_t> left;
-    std::vector<size_t> right;
-  };
-
-  inline std::vector<Elem>::const_iterator begin() const {  // NOLINT
+  std::vector<Elem>::const_iterator begin() const {  // NOLINT
    return elem_of_each_node_.begin();
  }
 
-  inline std::vector<Elem>::const_iterator end() const {  // NOLINT
+  std::vector<Elem>::const_iterator end() const {  // NOLINT
    return elem_of_each_node_.end();
  }
 
+  size_t Size() const { return std::distance(begin(), end()); }
+
   /*! \brief return corresponding element set given the node_id */
   inline const Elem& operator[](unsigned node_id) const {
     const Elem& e = elem_of_each_node_[node_id];
@@ -86,6 +88,8 @@
   }
 
   std::vector<size_t>* Data() { return &row_indices_; }
+  std::vector<size_t> const* Data() const { return &row_indices_; }
+
   // split rowset into two
   inline void AddSplit(unsigned node_id, unsigned left_node_id, unsigned right_node_id,
                        size_t n_left, size_t n_right) {
@@ -123,7 +127,6 @@
   // vector: node_id -> elements
   std::vector<Elem> elem_of_each_node_;
 };
-
 }  // namespace common
 }  // namespace xgboost

src/common/stats.cuh (new file, 127 lines)

/*!
 * Copyright 2022 by XGBoost Contributors
 */
#ifndef XGBOOST_COMMON_STATS_CUH_
#define XGBOOST_COMMON_STATS_CUH_

#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/permutation_iterator.h>

#include <iterator>  // std::distance

#include "device_helpers.cuh"
#include "linalg_op.cuh"
#include "xgboost/generic_parameters.h"
#include "xgboost/linalg.h"
#include "xgboost/tree_model.h"

namespace xgboost {
namespace common {

/**
 * \brief Compute segmented quantile on GPU.
 *
 * \tparam SegIt Iterator for CSR style segments indptr
 * \tparam ValIt Iterator for values
 *
 * \param alpha The quantile we want to compute, in range [0, 1].
 *
 *   std::distance(seg_begin, seg_end) should be equal to n_segments + 1
 */
template <typename SegIt, typename ValIt>
void SegmentedQuantile(Context const* ctx, double alpha, SegIt seg_begin, SegIt seg_end,
                       ValIt val_begin, ValIt val_end, HostDeviceVector<float>* quantiles) {
  CHECK(alpha >= 0 && alpha <= 1);

  dh::device_vector<size_t> sorted_idx;
  using Tup = thrust::tuple<size_t, float>;
  dh::SegmentedArgSort(seg_begin, seg_end, val_begin, val_end, &sorted_idx);
  auto n_segments = std::distance(seg_begin, seg_end) - 1;
  if (n_segments <= 0) {
    return;
  }

  quantiles->SetDevice(ctx->gpu_id);
  quantiles->Resize(n_segments);
  auto d_results = quantiles->DeviceSpan();
  auto d_sorted_idx = dh::ToSpan(sorted_idx);

  auto val = thrust::make_permutation_iterator(val_begin, dh::tcbegin(d_sorted_idx));

  dh::LaunchN(n_segments, [=] XGBOOST_DEVICE(size_t i) {
    // each segment is the index of a leaf.
    size_t seg_idx = i;
    size_t begin = seg_begin[seg_idx];
    auto n = static_cast<double>(seg_begin[seg_idx + 1] - begin);
    if (n == 0) {
      d_results[i] = std::numeric_limits<float>::quiet_NaN();
      return;
    }

    if (alpha <= (1 / (n + 1))) {
      d_results[i] = val[begin];
      return;
    }
    if (alpha >= (n / (n + 1))) {
      d_results[i] = val[common::LastOf(seg_idx, seg_begin)];
      return;
    }

    double x = alpha * static_cast<double>(n + 1);
    double k = std::floor(x) - 1;
    double d = (x - 1) - k;
    auto v0 = val[begin + static_cast<size_t>(k)];
    auto v1 = val[begin + static_cast<size_t>(k) + 1];
    d_results[seg_idx] = v0 + d * (v1 - v0);
  });
}

template <typename SegIt, typename ValIt, typename WIter>
void SegmentedWeightedQuantile(Context const* ctx, double alpha, SegIt seg_beg, SegIt seg_end,
                               ValIt val_begin, ValIt val_end, WIter w_begin, WIter w_end,
                               HostDeviceVector<float>* quantiles) {
  CHECK(alpha >= 0 && alpha <= 1);
  dh::device_vector<size_t> sorted_idx;
  dh::SegmentedArgSort(seg_beg, seg_end, val_begin, val_end, &sorted_idx);
  auto d_sorted_idx = dh::ToSpan(sorted_idx);
  size_t n_weights = std::distance(w_begin, w_end);
  dh::device_vector<float> weights_cdf(n_weights);

  dh::XGBCachingDeviceAllocator<char> caching;
  auto scan_key = dh::MakeTransformIterator<size_t>(
      thrust::make_counting_iterator(0ul),
      [=] XGBOOST_DEVICE(size_t i) { return dh::SegmentId(seg_beg, seg_end, i); });
  auto scan_val = dh::MakeTransformIterator<float>(
      thrust::make_counting_iterator(0ul),
      [=] XGBOOST_DEVICE(size_t i) { return w_begin[d_sorted_idx[i]]; });
  thrust::inclusive_scan_by_key(thrust::cuda::par(caching), scan_key, scan_key + n_weights,
                                scan_val, weights_cdf.begin());

  auto n_segments = std::distance(seg_beg, seg_end) - 1;
  quantiles->SetDevice(ctx->gpu_id);
  quantiles->Resize(n_segments);
  auto d_results = quantiles->DeviceSpan();
  auto d_weight_cdf = dh::ToSpan(weights_cdf);

  dh::LaunchN(n_segments, [=] XGBOOST_DEVICE(size_t i) {
    size_t seg_idx = i;
    size_t begin = seg_beg[seg_idx];
    auto n = static_cast<double>(seg_beg[seg_idx + 1] - begin);
    if (n == 0) {
      d_results[i] = std::numeric_limits<float>::quiet_NaN();
      return;
    }
    auto leaf_cdf = d_weight_cdf.subspan(begin, static_cast<size_t>(n));
    auto leaf_sorted_idx = d_sorted_idx.subspan(begin, static_cast<size_t>(n));
    float thresh = leaf_cdf.back() * alpha;

    size_t idx = thrust::lower_bound(thrust::seq, leaf_cdf.data(),
                                     leaf_cdf.data() + leaf_cdf.size(), thresh) -
                 leaf_cdf.data();
    idx = std::min(idx, static_cast<size_t>(n - 1));
    d_results[i] = val_begin[leaf_sorted_idx[idx]];
  });
}
}  // namespace common
}  // namespace xgboost
#endif  // XGBOOST_COMMON_STATS_CUH_

src/common/stats.h (new file, 95 lines)

/*!
 * Copyright 2022 by XGBoost Contributors
 */
#ifndef XGBOOST_COMMON_STATS_H_
#define XGBOOST_COMMON_STATS_H_

#include <algorithm>
#include <iterator>
#include <limits>
#include <vector>

#include "common.h"
#include "xgboost/linalg.h"

namespace xgboost {
namespace common {

/**
 * \brief Percentile with masked array using linear interpolation.
 *
 *   https://www.itl.nist.gov/div898/handbook/prc/section2/prc262.htm
 *
 * \param alpha Percentile, must be in range [0, 1].
 * \param begin Iterator begin for input array.
 * \param end   Iterator end for input array.
 *
 * \return The result of interpolation.
 */
template <typename Iter>
float Quantile(double alpha, Iter const& begin, Iter const& end) {
  CHECK(alpha >= 0 && alpha <= 1);
  auto n = static_cast<double>(std::distance(begin, end));
  if (n == 0) {
    return std::numeric_limits<float>::quiet_NaN();
  }

  std::vector<size_t> sorted_idx(n);
  std::iota(sorted_idx.begin(), sorted_idx.end(), 0);
  std::stable_sort(sorted_idx.begin(), sorted_idx.end(),
                   [&](size_t l, size_t r) { return *(begin + l) < *(begin + r); });

  auto val = [&](size_t i) { return *(begin + sorted_idx[i]); };
  static_assert(std::is_same<decltype(val(0)), float>::value, "");

  if (alpha <= (1 / (n + 1))) {
    return val(0);
  }
  if (alpha >= (n / (n + 1))) {
    return val(sorted_idx.size() - 1);
  }
  assert(n != 0 && "The number of rows in a leaf can not be zero.");
  double x = alpha * static_cast<double>((n + 1));
  double k = std::floor(x) - 1;
  CHECK_GE(k, 0);

  double d = (x - 1) - k;
  auto v0 = val(static_cast<size_t>(k));
  auto v1 = val(static_cast<size_t>(k) + 1);
  return v0 + d * (v1 - v0);
}

/**
 * \brief Calculate the weighted quantile with step function. Unlike the unweighted
 *        version, no interpolation is used.
 *
 *   See https://aakinshin.net/posts/weighted-quantiles/ for some discussion on computing
 *   weighted quantile with interpolation.
 */
template <typename Iter, typename WeightIter>
float WeightedQuantile(double alpha, Iter begin, Iter end, WeightIter weights) {
  auto n = static_cast<double>(std::distance(begin, end));
  if (n == 0) {
    return std::numeric_limits<float>::quiet_NaN();
  }
  std::vector<size_t> sorted_idx(n);
  std::iota(sorted_idx.begin(), sorted_idx.end(), 0);
  std::stable_sort(sorted_idx.begin(), sorted_idx.end(),
                   [&](size_t l, size_t r) { return *(begin + l) < *(begin + r); });

  auto val = [&](size_t i) { return *(begin + sorted_idx[i]); };

  std::vector<float> weight_cdf(n);  // S_n
  // weighted cdf is sorted during construction
  weight_cdf[0] = *(weights + sorted_idx[0]);
  for (size_t i = 1; i < n; ++i) {
    weight_cdf[i] = weight_cdf[i - 1] + *(weights + sorted_idx[i]);
  }
  float thresh = weight_cdf.back() * alpha;
  size_t idx =
      std::lower_bound(weight_cdf.cbegin(), weight_cdf.cend(), thresh) - weight_cdf.cbegin();
  idx = std::min(idx, static_cast<size_t>(n - 1));
  return val(idx);
}
}  // namespace common
}  // namespace xgboost
#endif  // XGBOOST_COMMON_STATS_H_
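
As a quick sanity check of the interpolation above (a sketch, not part of the commit; include path is assumed): with n = 4 and alpha = 0.5, x = 0.5 * 5 = 2.5, k = 1, d = 0.5, so the result lands halfway between the 2nd and 3rd order statistics.

#include <cassert>
#include <vector>

#include "stats.h"  // the new header above

int main() {
  std::vector<float> residuals{4.f, 1.f, 3.f, 2.f};  // deliberately unsorted
  float q = xgboost::common::Quantile(0.5, residuals.cbegin(), residuals.cend());
  assert(q == 2.5f);  // interpolated between the order statistics 2 and 3
  return 0;
}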

@@ -512,16 +512,7 @@ void MetaInfo::SetInfoFromHost(Context const& ctx, StringView key, Json arr) {
     }
   }
   CHECK(non_dec) << "`qid` must be sorted in non-decreasing order along with data.";
-  group_ptr_.clear();
-  group_ptr_.push_back(0);
-  for (size_t i = 1; i < query_ids.size(); ++i) {
-    if (query_ids[i] != query_ids[i - 1]) {
-      group_ptr_.push_back(i);
-    }
-  }
-  if (group_ptr_.back() != query_ids.size()) {
-    group_ptr_.push_back(query_ids.size());
-  }
+  common::RunLengthEncode(query_ids.cbegin(), query_ids.cend(), &group_ptr_);
   data::ValidateQueryGroup(group_ptr_);
   return;
 }
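
A tiny host-only illustration of the helper now used here (hypothetical qids; include path is assumed):

#include <cassert>
#include <cstddef>
#include <vector>

#include "../common/common.h"  // relative to src/data, as in the file above

int main() {
  std::vector<int> qid{0, 0, 1, 1, 1, 3};  // must be sorted; runs may be unequal
  std::vector<std::size_t> ptr;
  xgboost::common::RunLengthEncode(qid.cbegin(), qid.cend(), &ptr);
  // CSR-style offsets: group i spans [ptr[i], ptr[i + 1]).
  assert((ptr == std::vector<std::size_t>{0, 2, 5, 6}));
  return 0;
}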

@@ -68,7 +68,7 @@ class IterativeDeviceDMatrix : public DMatrix {
   BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) override;
 
-  bool SingleColBlock() const override { return false; }
+  bool SingleColBlock() const override { return true; }
 
   MetaInfo& Info() override { return info_; }
   MetaInfo const& Info() const override { return info_; }

@@ -134,9 +134,8 @@ class GBLinear : public GradientBooster {
     this->updater_->SaveConfig(&j_updater);
   }
 
-  void DoBoost(DMatrix *p_fmat,
-               HostDeviceVector<GradientPair> *in_gpair,
-               PredictionCacheEntry*) override {
+  void DoBoost(DMatrix* p_fmat, HostDeviceVector<GradientPair>* in_gpair, PredictionCacheEntry*,
+               ObjFunction const*) override {
     monitor_.Start("DoBoost");
     model_.LazyInitModel();

@@ -1,33 +1,34 @@
 /*!
- * Copyright 2014-2021 by Contributors
+ * Copyright 2014-2022 by Contributors
  * \file gbtree.cc
  * \brief gradient boosted tree implementation.
  * \author Tianqi Chen
 */
+#include "gbtree.h"
+
 #include <dmlc/omp.h>
 #include <dmlc/parameter.h>
 
-#include <vector>
-#include <memory>
-#include <utility>
-#include <string>
-#include <limits>
 #include <algorithm>
+#include <limits>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
 
-#include "xgboost/data.h"
-#include "xgboost/gbm.h"
-#include "xgboost/logging.h"
-#include "xgboost/json.h"
-#include "xgboost/predictor.h"
-#include "xgboost/tree_updater.h"
-#include "xgboost/host_device_vector.h"
-
-#include "gbtree.h"
-#include "gbtree_model.h"
 #include "../common/common.h"
 #include "../common/random.h"
-#include "../common/timer.h"
 #include "../common/threading_utils.h"
+#include "../common/timer.h"
+#include "gbtree_model.h"
+#include "xgboost/data.h"
+#include "xgboost/gbm.h"
+#include "xgboost/host_device_vector.h"
+#include "xgboost/json.h"
+#include "xgboost/logging.h"
+#include "xgboost/objective.h"
+#include "xgboost/predictor.h"
+#include "xgboost/tree_updater.h"
 
 namespace xgboost {
 namespace gbm {
@@ -216,53 +217,68 @@ void CopyGradient(HostDeviceVector<GradientPair> const* in_gpair, int32_t n_thre
   }
 }
 
-void GBTree::DoBoost(DMatrix* p_fmat,
-                     HostDeviceVector<GradientPair>* in_gpair,
-                     PredictionCacheEntry* predt) {
-  std::vector<std::vector<std::unique_ptr<RegTree> > > new_trees;
+void GBTree::UpdateTreeLeaf(DMatrix const* p_fmat, HostDeviceVector<float> const& predictions,
+                            ObjFunction const* obj, size_t gidx,
+                            std::vector<std::unique_ptr<RegTree>>* p_trees) {
+  CHECK(!updaters_.empty());
+  if (!updaters_.back()->HasNodePosition()) {
+    return;
+  }
+  if (!obj || !obj->Task().UpdateTreeLeaf()) {
+    return;
+  }
+  auto& trees = *p_trees;
+  for (size_t tree_idx = 0; tree_idx < trees.size(); ++tree_idx) {
+    auto const& position = this->node_position_.at(tree_idx);
+    obj->UpdateTreeLeaf(position, p_fmat->Info(), predictions, trees[tree_idx].get());
  }
+}
+
+void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector<GradientPair>* in_gpair,
+                     PredictionCacheEntry* predt, ObjFunction const* obj) {
+  std::vector<std::vector<std::unique_ptr<RegTree>>> new_trees;
   const int ngroup = model_.learner_model_param->num_output_group;
   ConfigureWithKnownData(this->cfg_, p_fmat);
   monitor_.Start("BoostNewTrees");
+
   // Weird case that tree method is cpu-based but gpu_id is set. Ideally we should let
   // `gpu_id` be the single source of determining what algorithms to run, but that will
   // break a lots of existing code.
-  auto device = tparam_.tree_method != TreeMethod::kGPUHist
-                    ? GenericParameter::kCpuId
-                    : ctx_->gpu_id;
+  auto device = tparam_.tree_method != TreeMethod::kGPUHist ? Context::kCpuId : ctx_->gpu_id;
   auto out = linalg::TensorView<float, 2>{
-      device == GenericParameter::kCpuId ? predt->predictions.HostSpan()
-                                         : predt->predictions.DeviceSpan(),
-      {static_cast<size_t>(p_fmat->Info().num_row_),
-       static_cast<size_t>(ngroup)},
+      device == Context::kCpuId ? predt->predictions.HostSpan() : predt->predictions.DeviceSpan(),
+      {static_cast<size_t>(p_fmat->Info().num_row_), static_cast<size_t>(ngroup)},
      device};
   CHECK_NE(ngroup, 0);
+
+  if (!p_fmat->SingleColBlock() && obj->Task().UpdateTreeLeaf()) {
+    LOG(FATAL) << "Current objective doesn't support external memory.";
+  }
+
   if (ngroup == 1) {
     std::vector<std::unique_ptr<RegTree>> ret;
     BoostNewTrees(in_gpair, p_fmat, 0, &ret);
+    UpdateTreeLeaf(p_fmat, predt->predictions, obj, 0, &ret);
     const size_t num_new_trees = ret.size();
     new_trees.push_back(std::move(ret));
     auto v_predt = out.Slice(linalg::All(), 0);
-    if (updaters_.size() > 0 && num_new_trees == 1 &&
-        predt->predictions.Size() > 0 &&
+    if (updaters_.size() > 0 && num_new_trees == 1 && predt->predictions.Size() > 0 &&
         updaters_.back()->UpdatePredictionCache(p_fmat, v_predt)) {
       predt->Update(1);
     }
   } else {
-    CHECK_EQ(in_gpair->Size() % ngroup, 0U)
-        << "must have exactly ngroup * nrow gpairs";
-    HostDeviceVector<GradientPair> tmp(in_gpair->Size() / ngroup,
-                                       GradientPair(),
+    CHECK_EQ(in_gpair->Size() % ngroup, 0U) << "must have exactly ngroup * nrow gpairs";
+    HostDeviceVector<GradientPair> tmp(in_gpair->Size() / ngroup, GradientPair(),
                                        in_gpair->DeviceIdx());
     bool update_predict = true;
     for (int gid = 0; gid < ngroup; ++gid) {
       CopyGradient(in_gpair, ctx_->Threads(), ngroup, gid, &tmp);
-      std::vector<std::unique_ptr<RegTree> > ret;
+      std::vector<std::unique_ptr<RegTree>> ret;
       BoostNewTrees(&tmp, p_fmat, gid, &ret);
+      UpdateTreeLeaf(p_fmat, predt->predictions, obj, gid, &ret);
       const size_t num_new_trees = ret.size();
       new_trees.push_back(std::move(ret));
       auto v_predt = out.Slice(linalg::All(), gid);
-      if (!(updaters_.size() > 0 && predt->predictions.Size() > 0 &&
-            num_new_trees == 1 &&
+      if (!(updaters_.size() > 0 && predt->predictions.Size() > 0 && num_new_trees == 1 &&
            updaters_.back()->UpdatePredictionCache(p_fmat, v_predt))) {
        update_predict = false;
      }
@@ -271,6 +287,7 @@ void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector<GradientPair>* in_gpair,
       predt->Update(1);
     }
   }
+
   monitor_.Stop("BoostNewTrees");
   this->CommitModel(std::move(new_trees), p_fmat, predt);
 }
@@ -316,10 +333,8 @@ void GBTree::InitUpdater(Args const& cfg) {
   }
 }
 
-void GBTree::BoostNewTrees(HostDeviceVector<GradientPair>* gpair,
-                           DMatrix *p_fmat,
-                           int bst_group,
-                           std::vector<std::unique_ptr<RegTree> >* ret) {
+void GBTree::BoostNewTrees(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat, int bst_group,
+                           std::vector<std::unique_ptr<RegTree>>* ret) {
   std::vector<RegTree*> new_trees;
   ret->clear();
   // create the trees
@@ -338,9 +353,9 @@ void GBTree::BoostNewTrees(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fm
   } else if (tparam_.process_type == TreeProcessType::kUpdate) {
     for (auto const& up : updaters_) {
       CHECK(up->CanModifyTree())
-        << "Updater: `" << up->Name() << "` "
-        << "can not be used to modify existing trees. "
-        << "Set `process_type` to `default` if you want to build new trees.";
+          << "Updater: `" << up->Name() << "` "
+          << "can not be used to modify existing trees. "
+          << "Set `process_type` to `default` if you want to build new trees.";
    }
     CHECK_LT(model_.trees.size(), model_.trees_to_update.size())
         << "No more tree left for updating. For updating existing trees, "
@@ -356,8 +371,10 @@ void GBTree::BoostNewTrees(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fm
   CHECK_EQ(gpair->Size(), p_fmat->Info().num_row_)
       << "Mismatching size between number of rows from input data and size of "
          "gradient vector.";
+  node_position_.resize(new_trees.size());
   for (auto& up : updaters_) {
-    up->Update(gpair, p_fmat, new_trees);
+    up->Update(gpair, p_fmat, common::Span<HostDeviceVector<bst_node_t>>{node_position_},
+               new_trees);
   }
 }

@@ -1,5 +1,5 @@
 /*!
- * Copyright 2014-2021 by Contributors
+ * Copyright 2014-2022 by Contributors
  * \file gbtree.cc
  * \brief gradient boosted tree implementation.
  * \author Tianqi Chen
 */
@@ -202,10 +202,16 @@ class GBTree : public GradientBooster {
   void ConfigureUpdaters();
   void ConfigureWithKnownData(Args const& cfg, DMatrix* fmat);
 
+  /**
+   * \brief Optionally update the leaf value.
+   */
+  void UpdateTreeLeaf(DMatrix const* p_fmat, HostDeviceVector<float> const& predictions,
+                      ObjFunction const* obj, size_t gidx,
+                      std::vector<std::unique_ptr<RegTree>>* p_trees);
+
   /*! \brief Carry out one iteration of boosting */
-  void DoBoost(DMatrix* p_fmat,
-               HostDeviceVector<GradientPair>* in_gpair,
-               PredictionCacheEntry* predt) override;
+  void DoBoost(DMatrix* p_fmat, HostDeviceVector<GradientPair>* in_gpair,
+               PredictionCacheEntry* predt, ObjFunction const* obj) override;
 
   bool UseGPU() const override {
     return
@@ -435,6 +441,9 @@ class GBTree : public GradientBooster {
   Args cfg_;
   // the updaters that can be applied to each of tree
   std::vector<std::unique_ptr<TreeUpdater>> updaters_;
+  // The node position for each row, 1 HDV for each tree in the forest. Note that the
+  // position is negated if the row is sampled out.
+  std::vector<HostDeviceVector<bst_node_t>> node_position_;
   // Predictors
   std::unique_ptr<Predictor> cpu_predictor_;
 #if defined(XGBOOST_USE_CUDA)

@@ -1169,7 +1169,7 @@ class LearnerImpl : public LearnerIO {
     monitor_.Stop("GetGradient");
     TrainingObserver::Instance().Observe(gpair_, "Gradients");
 
-    gbm_->DoBoost(train.get(), &gpair_, &predt);
+    gbm_->DoBoost(train.get(), &gpair_, &predt, obj_.get());
     monitor_.Stop("UpdateOneIter");
   }
 
@@ -1186,7 +1186,7 @@ class LearnerImpl : public LearnerIO {
     auto local_cache = this->GetPredictionCache();
     local_cache->Cache(train, generic_parameters_.gpu_id);
 
-    gbm_->DoBoost(train.get(), in_gpair, &local_cache->Entry(train.get()));
+    gbm_->DoBoost(train.get(), in_gpair, &local_cache->Entry(train.get()), obj_.get());
     monitor_.Stop("BoostOneIter");
   }

@@ -1,5 +1,5 @@
 /*!
- * Copyright 2021 by XGBoost Contributors
+ * Copyright 2021-2022 by XGBoost Contributors
 */
 #include <thrust/scan.h>
 #include <cub/cub.cuh>
@@ -201,14 +201,6 @@ void Transpose(common::Span<float const> in, common::Span<float> out, size_t m,
   });
 }
 
-/**
- * Last index of a group in a CSR style of index pointer.
- */
-template <typename Idx>
-XGBOOST_DEVICE size_t LastOf(size_t group, common::Span<Idx> indptr) {
-  return indptr[group + 1] - 1;
-}
-
 double ScaleClasses(common::Span<double> results,
                     common::Span<double> local_area, common::Span<double> fp,
                     common::Span<double> tp, common::Span<double> auc,
@@ -300,9 +292,9 @@ void SegmentedReduceAUC(common::Span<size_t const> d_unique_idx,
     double fp, tp, fp_prev, tp_prev;
     if (i == d_unique_class_ptr[class_id]) {
       // first item is ignored, we use this thread to calculate the last item
-      thrust::tie(fp, tp) = d_fptp[LastOf(class_id, d_class_ptr)];
+      thrust::tie(fp, tp) = d_fptp[common::LastOf(class_id, d_class_ptr)];
       thrust::tie(fp_prev, tp_prev) =
-          d_neg_pos[d_unique_idx[LastOf(class_id, d_unique_class_ptr)]];
+          d_neg_pos[d_unique_idx[common::LastOf(class_id, d_unique_class_ptr)]];
     } else {
       thrust::tie(fp, tp) = d_fptp[d_unique_idx[i] - 1];
       thrust::tie(fp_prev, tp_prev) = d_neg_pos[d_unique_idx[i - 1]];
@@ -413,10 +405,10 @@ double GPUMultiClassAUCOVR(common::Span<float const> predts,
     }
     uint32_t class_id = d_unique_idx[i] / n_samples;
     d_neg_pos[d_unique_idx[i]] = d_fptp[d_unique_idx[i] - 1];
-    if (i == LastOf(class_id, d_unique_class_ptr)) {
+    if (i == common::LastOf(class_id, d_unique_class_ptr)) {
       // last one needs to be included.
-      size_t last = d_unique_idx[LastOf(class_id, d_unique_class_ptr)];
-      d_neg_pos[LastOf(class_id, d_class_ptr)] = d_fptp[last - 1];
+      size_t last = d_unique_idx[common::LastOf(class_id, d_unique_class_ptr)];
+      d_neg_pos[common::LastOf(class_id, d_class_ptr)] = d_fptp[last - 1];
       return;
     }
   });
@@ -592,7 +584,7 @@ GPURankingAUC(common::Span<float const> predts, MetaInfo const &info,
     auto data_group_begin = d_group_ptr[group_id];
     size_t n_samples = d_group_ptr[group_id + 1] - data_group_begin;
     // last item of current group
-    if (item.idx == LastOf(group_id, d_threads_group_ptr)) {
+    if (item.idx == common::LastOf(group_id, d_threads_group_ptr)) {
       if (item.w > 0) {
         s_d_auc[group_id] = item.predt / item.w;
       } else {
@@ -797,10 +789,10 @@ GPURankingPRAUCImpl(common::Span<float const> predts, MetaInfo const &info,
     }
     auto group_idx = dh::SegmentId(d_group_ptr, d_unique_idx[i]);
     d_neg_pos[d_unique_idx[i]] = d_fptp[d_unique_idx[i] - 1];
-    if (i == LastOf(group_idx, d_unique_class_ptr)) {
+    if (i == common::LastOf(group_idx, d_unique_class_ptr)) {
       // last one needs to be included.
-      size_t last = d_unique_idx[LastOf(group_idx, d_unique_class_ptr)];
-      d_neg_pos[LastOf(group_idx, d_group_ptr)] = d_fptp[last - 1];
+      size_t last = d_unique_idx[common::LastOf(group_idx, d_unique_class_ptr)];
+      d_neg_pos[common::LastOf(group_idx, d_group_ptr)] = d_fptp[last - 1];
       return;
     }
   });
@@ -821,7 +813,7 @@ GPURankingPRAUCImpl(common::Span<float const> predts, MetaInfo const &info,
   auto it = dh::MakeTransformIterator<thrust::pair<double, uint32_t>>(
       thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t g) {
         double fp, tp;
-        thrust::tie(fp, tp) = d_fptp[LastOf(g, d_group_ptr)];
+        thrust::tie(fp, tp) = d_fptp[common::LastOf(g, d_group_ptr)];
         double area = fp * tp;
         auto n_documents = d_group_ptr[g + 1] - d_group_ptr[g];
         if (area > 0 && n_documents >= 2) {

src/objective/adaptive.cc (new file, 126 lines)

/*!
 * Copyright 2022 by XGBoost Contributors
 */
#include "adaptive.h"

#include <limits>
#include <vector>

#include "../common/common.h"
#include "../common/stats.h"
#include "../common/threading_utils.h"
#include "xgboost/tree_model.h"

namespace xgboost {
namespace obj {
namespace detail {
void EncodeTreeLeafHost(RegTree const& tree, std::vector<bst_node_t> const& position,
                        std::vector<size_t>* p_nptr, std::vector<bst_node_t>* p_nidx,
                        std::vector<size_t>* p_ridx) {
  auto& nptr = *p_nptr;
  auto& nidx = *p_nidx;
  auto& ridx = *p_ridx;
  ridx = common::ArgSort<size_t>(position);
  std::vector<bst_node_t> sorted_pos(position);
  // permutation
  for (size_t i = 0; i < position.size(); ++i) {
    sorted_pos[i] = position[ridx[i]];
  }
  // find the first non-sampled row
  auto begin_pos =
      std::distance(sorted_pos.cbegin(), std::find_if(sorted_pos.cbegin(), sorted_pos.cend(),
                                                      [](bst_node_t nidx) { return nidx >= 0; }));
  CHECK_LE(begin_pos, sorted_pos.size());

  std::vector<bst_node_t> leaf;
  tree.WalkTree([&](bst_node_t nidx) {
    if (tree[nidx].IsLeaf()) {
      leaf.push_back(nidx);
    }
    return true;
  });

  if (begin_pos == sorted_pos.size()) {
    nidx = leaf;
    return;
  }

  auto beg_it = sorted_pos.begin() + begin_pos;
  common::RunLengthEncode(beg_it, sorted_pos.end(), &nptr);
  CHECK_GT(nptr.size(), 0);
  // skip the sampled rows in indptr
  std::transform(nptr.begin(), nptr.end(), nptr.begin(),
                 [begin_pos](size_t ptr) { return ptr + begin_pos; });

  size_t n_leaf = nptr.size() - 1;
  auto n_unique = std::unique(beg_it, sorted_pos.end()) - beg_it;
  CHECK_EQ(n_unique, n_leaf);
  nidx.resize(n_leaf);
  std::copy(beg_it, beg_it + n_unique, nidx.begin());

  if (n_leaf != leaf.size()) {
    FillMissingLeaf(leaf, &nidx, &nptr);
  }
}

void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& position,
                        MetaInfo const& info, HostDeviceVector<float> const& predt, float alpha,
                        RegTree* p_tree) {
  auto& tree = *p_tree;

  std::vector<bst_node_t> nidx;
  std::vector<size_t> nptr;
  std::vector<size_t> ridx;
  EncodeTreeLeafHost(*p_tree, position, &nptr, &nidx, &ridx);
  size_t n_leaf = nidx.size();
  if (nptr.empty()) {
    std::vector<float> quantiles;
    UpdateLeafValues(&quantiles, nidx, p_tree);
    return;
  }

  CHECK(!position.empty());
  std::vector<float> quantiles(n_leaf, 0);
  std::vector<int32_t> n_valids(n_leaf, 0);

  auto const& h_node_idx = nidx;
  auto const& h_node_ptr = nptr;
  CHECK_LE(h_node_ptr.back(), info.num_row_);
  // loop over each leaf
  common::ParallelFor(quantiles.size(), ctx->Threads(), [&](size_t k) {
    auto nidx = h_node_idx[k];
    CHECK(tree[nidx].IsLeaf());
    CHECK_LT(k + 1, h_node_ptr.size());
    size_t n = h_node_ptr[k + 1] - h_node_ptr[k];
    auto h_row_set = common::Span<size_t const>{ridx}.subspan(h_node_ptr[k], n);
    // multi-target not yet supported.
    auto h_labels = info.labels.HostView().Slice(linalg::All(), 0);
    auto const& h_predt = predt.ConstHostVector();
    auto h_weights = linalg::MakeVec(&info.weights_);

    auto iter = common::MakeIndexTransformIter([&](size_t i) -> float {
      auto row_idx = h_row_set[i];
      return h_labels(row_idx) - h_predt[row_idx];
    });
    auto w_it = common::MakeIndexTransformIter([&](size_t i) -> float {
      auto row_idx = h_row_set[i];
      return h_weights(row_idx);
    });

    float q{0};
    if (info.weights_.Empty()) {
      q = common::Quantile(alpha, iter, iter + h_row_set.size());
    } else {
      q = common::WeightedQuantile(alpha, iter, iter + h_row_set.size(), w_it);
    }
    if (std::isnan(q)) {
      CHECK(h_row_set.empty());
    }
    quantiles.at(k) = q;
  });

  UpdateLeafValues(&quantiles, nidx, p_tree);
}
}  // namespace detail
}  // namespace obj
}  // namespace xgboost

182
src/objective/adaptive.cu Normal file
View File

@ -0,0 +1,182 @@
/*!
* Copyright 2022 by XGBoost Contributors
*/
#include <thrust/sort.h>
#include <cub/cub.cuh>
#include "../common/device_helpers.cuh"
#include "../common/stats.cuh"
#include "adaptive.h"
namespace xgboost {
namespace obj {
namespace detail {
void EncodeTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> position,
dh::device_vector<size_t>* p_ridx, HostDeviceVector<size_t>* p_nptr,
HostDeviceVector<bst_node_t>* p_nidx, RegTree const& tree) {
// copy position to buffer
dh::safe_cuda(cudaSetDevice(ctx->gpu_id));
size_t n_samples = position.size();
dh::XGBDeviceAllocator<char> alloc;
dh::device_vector<bst_node_t> sorted_position(position.size());
dh::safe_cuda(cudaMemcpyAsync(sorted_position.data().get(), position.data(),
position.size_bytes(), cudaMemcpyDeviceToDevice));
p_ridx->resize(position.size());
dh::Iota(dh::ToSpan(*p_ridx));
// sort row index according to node index
thrust::stable_sort_by_key(thrust::cuda::par(alloc), sorted_position.begin(),
sorted_position.begin() + n_samples, p_ridx->begin());
dh::XGBCachingDeviceAllocator<char> caching;
auto beg_pos =
thrust::find_if(thrust::cuda::par(caching), sorted_position.cbegin(), sorted_position.cend(),
[] XGBOOST_DEVICE(bst_node_t nidx) { return nidx >= 0; }) -
sorted_position.cbegin();
if (beg_pos == sorted_position.size()) {
auto& leaf = p_nidx->HostVector();
tree.WalkTree([&](bst_node_t nidx) {
if (tree[nidx].IsLeaf()) {
leaf.push_back(nidx);
}
return true;
});
return;
}
size_t n_leaf = tree.GetNumLeaves();
size_t max_n_unique = n_leaf;
dh::caching_device_vector<size_t> counts_out(max_n_unique + 1, 0);
auto d_counts_out = dh::ToSpan(counts_out).subspan(0, max_n_unique);
auto d_num_runs_out = dh::ToSpan(counts_out).subspan(max_n_unique, 1);
dh::caching_device_vector<bst_node_t> unique_out(max_n_unique, 0);
auto d_unique_out = dh::ToSpan(unique_out);
size_t nbytes;
auto begin_it = sorted_position.begin() + beg_pos;
cub::DeviceRunLengthEncode::Encode(nullptr, nbytes, begin_it, unique_out.data().get(),
counts_out.data().get(), d_num_runs_out.data(),
n_samples - beg_pos);
dh::TemporaryArray<char> temp(nbytes);
cub::DeviceRunLengthEncode::Encode(temp.data().get(), nbytes, begin_it, unique_out.data().get(),
counts_out.data().get(), d_num_runs_out.data(),
n_samples - beg_pos);
dh::PinnedMemory pinned_pool;
auto pinned = pinned_pool.GetSpan<char>(sizeof(size_t) + sizeof(bst_node_t));
dh::CUDAStream copy_stream;
size_t* h_num_runs = reinterpret_cast<size_t*>(pinned.subspan(0, sizeof(size_t)).data());
// flag for whether there's ignored position
bst_node_t* h_first_unique =
reinterpret_cast<bst_node_t*>(pinned.subspan(sizeof(size_t), sizeof(bst_node_t)).data());
dh::safe_cuda(cudaMemcpyAsync(h_num_runs, d_num_runs_out.data(), sizeof(size_t),
cudaMemcpyDeviceToHost, copy_stream.View()));
dh::safe_cuda(cudaMemcpyAsync(h_first_unique, d_unique_out.data(), sizeof(bst_node_t),
cudaMemcpyDeviceToHost, copy_stream.View()));
/**
* copy node index (leaf index)
*/
auto& nidx = *p_nidx;
auto& nptr = *p_nptr;
nidx.SetDevice(ctx->gpu_id);
nidx.Resize(n_leaf);
auto d_node_idx = nidx.DeviceSpan();
nptr.SetDevice(ctx->gpu_id);
nptr.Resize(n_leaf + 1, 0);
auto d_node_ptr = nptr.DeviceSpan();
dh::LaunchN(n_leaf, [=] XGBOOST_DEVICE(size_t i) {
if (i >= d_num_runs_out[0]) {
// d_num_runs_out <= max_n_unique
// this omits all the leaf that are empty. A leaf can be empty when there's
// missing data, which can be caused by sparse input and distributed training.
return;
}
d_node_idx[i] = d_unique_out[i];
d_node_ptr[i + 1] = d_counts_out[i];
if (i == 0) {
d_node_ptr[0] = beg_pos;
}
});
thrust::inclusive_scan(thrust::cuda::par(caching), dh::tbegin(d_node_ptr), dh::tend(d_node_ptr),
dh::tbegin(d_node_ptr));
copy_stream.View().Sync();
CHECK_GT(*h_num_runs, 0);
CHECK_LE(*h_num_runs, n_leaf);
if (*h_num_runs < n_leaf) {
// shrink to omit the sampled nodes.
nptr.Resize(*h_num_runs + 1);
nidx.Resize(*h_num_runs);
std::vector<bst_node_t> leaves;
tree.WalkTree([&](bst_node_t nidx) {
if (tree[nidx].IsLeaf()) {
leaves.push_back(nidx);
}
return true;
});
CHECK_EQ(leaves.size(), n_leaf);
// Fill all the leaves that don't have any sample. This is hacky and inefficient. An
// alternative is to leave the objective to handle missing leaf, which is more messy
// as we need to take other distributed workers into account.
auto& h_nidx = nidx.HostVector();
auto& h_nptr = nptr.HostVector();
FillMissingLeaf(leaves, &h_nidx, &h_nptr);
nidx.DevicePointer();
nptr.DevicePointer();
}
CHECK_EQ(nidx.Size(), n_leaf);
CHECK_EQ(nptr.Size(), n_leaf + 1);
}
void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> position,
MetaInfo const& info, HostDeviceVector<float> const& predt, float alpha,
RegTree* p_tree) {
dh::safe_cuda(cudaSetDevice(ctx->gpu_id));
dh::device_vector<size_t> ridx;
HostDeviceVector<size_t> nptr;
HostDeviceVector<bst_node_t> nidx;
EncodeTreeLeafDevice(ctx, position, &ridx, &nptr, &nidx, *p_tree);
if (nptr.Empty()) {
std::vector<float> quantiles;
UpdateLeafValues(&quantiles, nidx.ConstHostVector(), p_tree);
}
HostDeviceVector<float> quantiles;
predt.SetDevice(ctx->gpu_id);
auto d_predt = predt.ConstDeviceSpan();
auto d_labels = info.labels.View(ctx->gpu_id);
auto d_row_index = dh::ToSpan(ridx);
auto seg_beg = nptr.DevicePointer();
auto seg_end = seg_beg + nptr.Size();
auto val_beg = dh::MakeTransformIterator<float>(thrust::make_counting_iterator(0ul),
[=] XGBOOST_DEVICE(size_t i) {
auto predt = d_predt[d_row_index[i]];
auto y = d_labels(d_row_index[i]);
return y - predt;
});
auto val_end = val_beg + d_labels.Size();
CHECK_EQ(nidx.Size() + 1, nptr.Size());
if (info.weights_.Empty()) {
common::SegmentedQuantile(ctx, alpha, seg_beg, seg_end, val_beg, val_end, &quantiles);
} else {
info.weights_.SetDevice(ctx->gpu_id);
auto d_weights = info.weights_.ConstDeviceSpan();
CHECK_EQ(d_weights.size(), d_row_index.size());
auto w_it = thrust::make_permutation_iterator(dh::tcbegin(d_weights), dh::tcbegin(d_row_index));
common::SegmentedWeightedQuantile(ctx, alpha, seg_beg, seg_end, val_beg, val_end, w_it,
w_it + d_weights.size(), &quantiles);
}
UpdateLeafValues(&quantiles.HostVector(), nidx.ConstHostVector(), p_tree);
}
} // namespace detail
} // namespace obj
} // namespace xgboost
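In essence, the device path above computes a per-leaf quantile of the residuals label - prediction and writes it back into the tree. A minimal standalone sketch of that idea follows; the names are hypothetical, rows are assumed unweighted, and a simple nth_element-based quantile stands in for the interpolated common::SegmentedQuantile used here:

#include <algorithm>
#include <cstddef>
#include <vector>

// Illustration only: recompute one leaf as the alpha-quantile of the residuals
// of the rows that landed in it; alpha = 0.5 gives the median used by l1 error.
// Assumes residuals is non-empty.
float LeafQuantile(std::vector<float> residuals, double alpha) {
  auto k = static_cast<std::size_t>(alpha * (residuals.size() - 1));
  std::nth_element(residuals.begin(), residuals.begin() + k, residuals.end());
  return residuals[k];
}

// rows_in_leaf: row indices grouped per leaf, as produced by the partitioner.
std::vector<float> AdaptiveLeaves(std::vector<std::vector<std::size_t>> const& rows_in_leaf,
                                  std::vector<float> const& y, std::vector<float> const& predt,
                                  double alpha) {
  std::vector<float> leaf_values;
  for (auto const& rows : rows_in_leaf) {
    std::vector<float> residuals;
    residuals.reserve(rows.size());
    for (auto r : rows) residuals.push_back(y[r] - predt[r]);
    leaf_values.push_back(LeafQuantile(std::move(residuals), alpha));
  }
  return leaf_values;
}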

83
src/objective/adaptive.h Normal file
View File

@ -0,0 +1,83 @@
/*!
* Copyright 2022 by XGBoost Contributors
*/
#pragma once
#include <algorithm>
#include <limits>
#include <vector>
#include "rabit/rabit.h"
#include "xgboost/generic_parameters.h"
#include "xgboost/host_device_vector.h"
#include "xgboost/tree_model.h"
namespace xgboost {
namespace obj {
namespace detail {
inline void FillMissingLeaf(std::vector<bst_node_t> const& maybe_missing,
std::vector<bst_node_t>* p_nidx, std::vector<size_t>* p_nptr) {
auto& h_node_idx = *p_nidx;
auto& h_node_ptr = *p_nptr;
for (auto leaf : maybe_missing) {
if (std::binary_search(h_node_idx.cbegin(), h_node_idx.cend(), leaf)) {
continue;
}
auto it = std::upper_bound(h_node_idx.cbegin(), h_node_idx.cend(), leaf);
auto pos = it - h_node_idx.cbegin();
h_node_idx.insert(h_node_idx.cbegin() + pos, leaf);
h_node_ptr.insert(h_node_ptr.cbegin() + pos, h_node_ptr[pos]);
}
}
inline void UpdateLeafValues(std::vector<float>* p_quantiles, std::vector<bst_node_t> const nidx,
RegTree* p_tree) {
auto& tree = *p_tree;
auto& quantiles = *p_quantiles;
auto const& h_node_idx = nidx;
size_t n_leaf{h_node_idx.size()};
rabit::Allreduce<rabit::op::Max>(&n_leaf, 1);
CHECK(quantiles.empty() || quantiles.size() == n_leaf);
if (quantiles.empty()) {
quantiles.resize(n_leaf, std::numeric_limits<float>::quiet_NaN());
}
// number of workers that have valid quantiles
std::vector<int32_t> n_valids(quantiles.size());
std::transform(quantiles.cbegin(), quantiles.cend(), n_valids.begin(),
[](float q) { return static_cast<int32_t>(!std::isnan(q)); });
rabit::Allreduce<rabit::op::Sum>(n_valids.data(), n_valids.size());
// convert to 0 for all reduce
std::replace_if(
quantiles.begin(), quantiles.end(), [](float q) { return std::isnan(q); }, 0.f);
// use the mean value
rabit::Allreduce<rabit::op::Sum>(quantiles.data(), quantiles.size());
for (size_t i = 0; i < n_leaf; ++i) {
if (n_valids[i] > 0) {
quantiles[i] /= static_cast<float>(n_valids[i]);
} else {
// Use original leaf value if no worker can provide the quantile.
quantiles[i] = tree[h_node_idx[i]].LeafValue();
}
}
for (size_t i = 0; i < nidx.size(); ++i) {
auto nidx = h_node_idx[i];
auto q = quantiles[i];
CHECK(tree[nidx].IsLeaf());
tree[nidx].SetLeaf(q);
}
}
void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> position,
MetaInfo const& info, HostDeviceVector<float> const& predt, float alpha,
RegTree* p_tree);
void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& position,
MetaInfo const& info, HostDeviceVector<float> const& predt, float alpha,
RegTree* p_tree);
} // namespace detail
} // namespace obj
} // namespace xgboost
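The effect of FillMissingLeaf above is easiest to see on a toy input: a leaf with no local rows is inserted into the sorted node-index list with an empty segment (its boundary value is duplicated), so every worker ends up with one segment per tree leaf. A self-contained copy of the same logic, renamed so it is not mistaken for the project's API:

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

void FillMissing(std::vector<std::int32_t> const& leaves, std::vector<std::int32_t>* p_nidx,
                 std::vector<std::size_t>* p_nptr) {
  auto& nidx = *p_nidx;
  auto& nptr = *p_nptr;
  for (auto leaf : leaves) {
    if (std::binary_search(nidx.cbegin(), nidx.cend(), leaf)) continue;
    auto pos = std::upper_bound(nidx.cbegin(), nidx.cend(), leaf) - nidx.cbegin();
    nidx.insert(nidx.cbegin() + pos, leaf);
    nptr.insert(nptr.cbegin() + pos, nptr[pos]);  // empty segment for the new leaf
  }
}

int main() {
  // This worker has rows only for leaves 1 and 4; the tree also has leaf 3.
  std::vector<std::int32_t> nidx{1, 4};
  std::vector<std::size_t> nptr{0, 10, 16};
  FillMissing({1, 3, 4}, &nidx, &nptr);
  assert((nidx == std::vector<std::int32_t>{1, 3, 4}));
  assert((nptr == std::vector<std::size_t>{0, 10, 10, 16}));  // leaf 3 owns [10, 10)
  return 0;
}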

View File

@ -34,11 +34,11 @@ DMLC_REGISTRY_FILE_TAG(aft_obj_gpu);
class AFTObj : public ObjFunction {
 public:
-  void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
+  void Configure(Args const& args) override {
    param_.UpdateAllowUnknown(args);
  }
-  ObjInfo Task() const override { return {ObjInfo::kSurvival, false}; }
+  ObjInfo Task() const override { return ObjInfo::kSurvival; }
  template <typename Distribution>
  void GetGradientImpl(const HostDeviceVector<bst_float> &preds,

View File

@ -24,10 +24,8 @@ class HingeObj : public ObjFunction {
 public:
  HingeObj() = default;
-  void Configure(
-      const std::vector<std::pair<std::string, std::string> > &args) override {}
-  ObjInfo Task() const override { return {ObjInfo::kRegression, false}; }
+  void Configure(Args const&) override {}
+  ObjInfo Task() const override { return ObjInfo::kRegression; }
  void GetGradient(const HostDeviceVector<bst_float> &preds,
                   const MetaInfo &info,

View File

@ -46,7 +46,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
    param_.UpdateAllowUnknown(args);
  }
-  ObjInfo Task() const override { return {ObjInfo::kClassification, false}; }
+  ObjInfo Task() const override { return ObjInfo::kClassification; }
  void GetGradient(const HostDeviceVector<bst_float>& preds,
                   const MetaInfo& info,

View File

@ -1,5 +1,5 @@
/*!
- * Copyright 2015-2019 XGBoost contributors
+ * Copyright 2015-2022 XGBoost contributors
 */
#include <dmlc/omp.h>
#include <dmlc/timer.h>
@ -750,11 +750,8 @@ class SortedLabelList : dh::SegmentSorter<float> {
template <typename LambdaWeightComputerT>
class LambdaRankObj : public ObjFunction {
 public:
-  void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
-    param_.UpdateAllowUnknown(args);
-  }
-  ObjInfo Task() const override { return {ObjInfo::kRanking, false}; }
+  void Configure(Args const &args) override { param_.UpdateAllowUnknown(args); }
+  ObjInfo Task() const override { return ObjInfo::kRanking; }
  void GetGradient(const HostDeviceVector<bst_float>& preds,
                   const MetaInfo& info,

View File

@ -1,5 +1,5 @@
/*!
- * Copyright 2017-2019 XGBoost contributors
+ * Copyright 2017-2022 XGBoost contributors
 */
#ifndef XGBOOST_OBJECTIVE_REGRESSION_LOSS_H_
#define XGBOOST_OBJECTIVE_REGRESSION_LOSS_H_
@ -38,7 +38,7 @@ struct LinearSquareLoss {
  static const char* DefaultEvalMetric() { return "rmse"; }
  static const char* Name() { return "reg:squarederror"; }
-  static ObjInfo Info() { return {ObjInfo::kRegression, true}; }
+  static ObjInfo Info() { return {ObjInfo::kRegression, true, false}; }
};
struct SquaredLogError {
@ -65,7 +65,7 @@ struct SquaredLogError {
  static const char* Name() { return "reg:squaredlogerror"; }
-  static ObjInfo Info() { return {ObjInfo::kRegression, false}; }
+  static ObjInfo Info() { return ObjInfo::kRegression; }
};
// logistic loss for probability regression task
@ -102,14 +102,14 @@ struct LogisticRegression {
  static const char* Name() { return "reg:logistic"; }
-  static ObjInfo Info() { return {ObjInfo::kRegression, false}; }
+  static ObjInfo Info() { return ObjInfo::kRegression; }
};
// logistic loss for binary classification task
struct LogisticClassification : public LogisticRegression {
  static const char* DefaultEvalMetric() { return "logloss"; }
  static const char* Name() { return "binary:logistic"; }
-  static ObjInfo Info() { return {ObjInfo::kBinary, false}; }
+  static ObjInfo Info() { return ObjInfo::kBinary; }
};
// logistic loss, but predict un-transformed margin
@ -146,7 +146,7 @@ struct LogisticRaw : public LogisticRegression {
  static const char* Name() { return "binary:logitraw"; }
-  static ObjInfo Info() { return {ObjInfo::kRegression, false}; }
+  static ObjInfo Info() { return ObjInfo::kRegression; }
};
}  // namespace obj

View File

@ -4,10 +4,10 @@
 * \brief Definition of single-value regression and classification objectives.
 * \author Tianqi Chen, Kailong Chen
 */
#include <dmlc/omp.h>
#include <xgboost/logging.h>
#include <xgboost/objective.h>
-#include <xgboost/tree_model.h>
#include <cmath>
#include <memory>
@ -19,12 +19,18 @@
#include "../common/threading_utils.h"
#include "../common/transform.h"
#include "./regression_loss.h"
+#include "adaptive.h"
+#include "xgboost/base.h"
+#include "xgboost/data.h"
+#include "xgboost/generic_parameters.h"
#include "xgboost/host_device_vector.h"
#include "xgboost/json.h"
+#include "xgboost/linalg.h"
#include "xgboost/parameter.h"
#include "xgboost/span.h"
#if defined(XGBOOST_USE_CUDA)
+#include "../common/device_helpers.cuh"
#include "../common/linalg_op.cuh"
#endif  // defined(XGBOOST_USE_CUDA)
@ -67,9 +73,7 @@ class RegLossObj : public ObjFunction {
    param_.UpdateAllowUnknown(args);
  }
-  struct ObjInfo Task() const override {
-    return Loss::Info();
-  }
+  ObjInfo Task() const override { return Loss::Info(); }
  uint32_t Targets(MetaInfo const& info) const override {
    // Multi-target regression.
@ -209,7 +213,7 @@ class PseudoHuberRegression : public ObjFunction {
 public:
  void Configure(Args const& args) override { param_.UpdateAllowUnknown(args); }
-  struct ObjInfo Task() const override { return {ObjInfo::kRegression, false}; }
+  ObjInfo Task() const override { return ObjInfo::kRegression; }
  uint32_t Targets(MetaInfo const& info) const override {
    return std::max(static_cast<size_t>(1), info.labels.Shape(1));
  }
@ -286,9 +290,7 @@ class PoissonRegression : public ObjFunction {
    param_.UpdateAllowUnknown(args);
  }
-  struct ObjInfo Task() const override {
-    return {ObjInfo::kRegression, false};
-  }
+  ObjInfo Task() const override { return ObjInfo::kRegression; }
  void GetGradient(const HostDeviceVector<bst_float>& preds,
                   const MetaInfo &info, int,
@ -378,12 +380,8 @@ XGBOOST_REGISTER_OBJECTIVE(PoissonRegression, "count:poisson")
// cox regression for survival data (negative values mean they are censored)
class CoxRegression : public ObjFunction {
 public:
-  void Configure(
-      const std::vector<std::pair<std::string, std::string> >&) override {}
-  struct ObjInfo Task() const override {
-    return {ObjInfo::kRegression, false};
-  }
+  void Configure(Args const&) override {}
+  ObjInfo Task() const override { return ObjInfo::kRegression; }
  void GetGradient(const HostDeviceVector<bst_float>& preds,
                   const MetaInfo &info, int,
@ -479,12 +477,8 @@ XGBOOST_REGISTER_OBJECTIVE(CoxRegression, "survival:cox")
// gamma regression
class GammaRegression : public ObjFunction {
 public:
-  void Configure(
-      const std::vector<std::pair<std::string, std::string> >&) override {}
-  struct ObjInfo Task() const override {
-    return {ObjInfo::kRegression, false};
-  }
+  void Configure(Args const&) override {}
+  ObjInfo Task() const override { return ObjInfo::kRegression; }
  void GetGradient(const HostDeviceVector<bst_float> &preds,
                   const MetaInfo &info, int,
@ -582,9 +576,7 @@ class TweedieRegression : public ObjFunction {
    metric_ = os.str();
  }
-  struct ObjInfo Task() const override {
-    return {ObjInfo::kRegression, false};
-  }
+  ObjInfo Task() const override { return ObjInfo::kRegression; }
  void GetGradient(const HostDeviceVector<bst_float>& preds,
                   const MetaInfo &info, int,
@ -675,5 +667,65 @@ XGBOOST_REGISTER_OBJECTIVE(TweedieRegression, "reg:tweedie")
    .describe("Tweedie regression for insurance data.")
    .set_body([]() { return new TweedieRegression(); });
class MeanAbsoluteError : public ObjFunction {
public:
void Configure(Args const&) override {}
ObjInfo Task() const override { return {ObjInfo::kRegression, true, true}; }
void GetGradient(HostDeviceVector<bst_float> const& preds, const MetaInfo& info, int iter,
HostDeviceVector<GradientPair>* out_gpair) override {
CheckRegInputs(info, preds);
auto labels = info.labels.View(ctx_->gpu_id);
out_gpair->SetDevice(ctx_->gpu_id);
out_gpair->Resize(info.labels.Size());
auto gpair = linalg::MakeVec(out_gpair);
preds.SetDevice(ctx_->gpu_id);
auto predt = linalg::MakeVec(&preds);
info.weights_.SetDevice(ctx_->gpu_id);
common::OptionalWeights weight{ctx_->IsCPU() ? info.weights_.ConstHostSpan()
: info.weights_.ConstDeviceSpan()};
linalg::ElementWiseKernel(ctx_, labels, [=] XGBOOST_DEVICE(size_t i, float const y) mutable {
auto sign = [](auto x) {
return (x > static_cast<decltype(x)>(0)) - (x < static_cast<decltype(x)>(0));
};
auto sample_id = std::get<0>(linalg::UnravelIndex(i, labels.Shape()));
auto grad = sign(predt(i) - y) * weight[sample_id];
auto hess = weight[sample_id];
gpair(i) = GradientPair{grad, hess};
});
}
void UpdateTreeLeaf(HostDeviceVector<bst_node_t> const& position, MetaInfo const& info,
HostDeviceVector<float> const& prediction, RegTree* p_tree) const override {
if (ctx_->IsCPU()) {
auto const& h_position = position.ConstHostVector();
detail::UpdateTreeLeafHost(ctx_, h_position, info, prediction, 0.5, p_tree);
} else {
#if defined(XGBOOST_USE_CUDA)
position.SetDevice(ctx_->gpu_id);
auto d_position = position.ConstDeviceSpan();
detail::UpdateTreeLeafDevice(ctx_, d_position, info, prediction, 0.5, p_tree);
#else
common::AssertGPUSupport();
#endif // defined(XGBOOST_USE_CUDA)
}
}
const char* DefaultEvalMetric() const override { return "mae"; }
void SaveConfig(Json* p_out) const override {
auto& out = *p_out;
out["name"] = String("reg:absoluteerror");
}
void LoadConfig(Json const& in) override {}
};
XGBOOST_REGISTER_OBJECTIVE(MeanAbsoluteError, "reg:absoluteerror")
.describe("Mean absoluate error.")
.set_body([]() { return new MeanAbsoluteError(); });
}  // namespace obj
}  // namespace xgboost
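For a single unweighted target, the gradient kernel in MeanAbsoluteError reduces to a sign gradient with a constant placeholder Hessian; the leaf values derived from these statistics are then overwritten by UpdateTreeLeaf. A standalone restatement of that arithmetic (not the ElementWiseKernel API):

#include <utility>

// Per-row gradient pair for l1 error: the true second derivative of |y - p| is
// zero almost everywhere, so a constant (the sample weight, 1 here) stands in
// for the Hessian, and the leaves are corrected after tree construction.
std::pair<float, float> L1Gradient(float predt, float y) {
  float sign = (predt > y) ? 1.0f : (predt < y) ? -1.0f : 0.0f;
  return {sign, 1.0f};  // {grad, hess}
}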

View File

@ -1,13 +1,17 @@
/*!
- * Copyright 2017-2019 XGBoost contributors
+ * Copyright 2017-2022 XGBoost contributors
 */
#pragma once
+#include <limits>
+#include <vector>
#include "xgboost/base.h"
#include "../../common/device_helpers.cuh"
+#include "xgboost/generic_parameters.h"
+#include "xgboost/task.h"
+#include "xgboost/tree_model.h"
namespace xgboost {
namespace tree {
/*! \brief Count how many rows are assigned to left node. */
__forceinline__ __device__ void AtomicIncrement(int64_t* d_count, bool increment) {
#if __CUDACC_VER_MAJOR__ > 8
@ -149,23 +153,48 @@ class RowPartitioner {
  }
  /**
-   * \brief Finalise the position of all training instances after tree
-   * construction is complete. Does not update any other meta information in
-   * this data structure, so should only be used at the end of training.
+   * \brief Finalise the position of all training instances after tree construction is
+   * complete. Does not update any other meta information in this data structure, so
+   * should only be used at the end of training.
   *
-   * \param op Device lambda. Should provide the row index and current
-   * position as an argument and return the new position for this training
-   * instance.
+   * When the task requires update leaf, this function will copy the node index into
+   * p_out_position. The index is negated if it's being sampled in current iteration.
+   *
+   * \param p_out_position Node index for each row.
+   * \param op Device lambda. Should provide the row index and current position as an
+   * argument and return the new position for this training instance.
+   * \param sampled A device lambda to inform the partitioner whether a row is sampled.
   */
-  template <typename FinalisePositionOpT>
-  void FinalisePosition(FinalisePositionOpT op) {
+  template <typename FinalisePositionOpT, typename Sampledp>
+  void FinalisePosition(Context const* ctx, ObjInfo task,
+                        HostDeviceVector<bst_node_t>* p_out_position, FinalisePositionOpT op,
+                        Sampledp sampledp) {
    auto d_position = position_.Current();
    const auto d_ridx = ridx_.Current();
+    if (!task.UpdateTreeLeaf()) {
+      dh::LaunchN(position_.Size(), [=] __device__(size_t idx) {
+        auto position = d_position[idx];
+        RowIndexT ridx = d_ridx[idx];
+        bst_node_t new_position = op(ridx, position);
+        if (new_position == kIgnoredTreePosition) {
+          return;
+        }
+        d_position[idx] = new_position;
+      });
+      return;
+    }
+    p_out_position->SetDevice(ctx->gpu_id);
+    p_out_position->Resize(position_.Size());
+    auto sorted_position = p_out_position->DevicePointer();
    dh::LaunchN(position_.Size(), [=] __device__(size_t idx) {
      auto position = d_position[idx];
      RowIndexT ridx = d_ridx[idx];
      bst_node_t new_position = op(ridx, position);
-      if (new_position == kIgnoredTreePosition) return;
+      sorted_position[ridx] = sampledp(ridx) ? ~new_position : new_position;
+      if (new_position == kIgnoredTreePosition) {
+        return;
+      }
      d_position[idx] = new_position;
    });
  }
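The negated index written by the second kernel doubles as a per-row flag: non-negative values are ordinary leaf positions, negative values mark rows flagged by the sampling predicate, and bitwise complement recovers the leaf in both cases. A standalone round-trip sketch (names are illustrative):

#include <cassert>
#include <cstdint>

using bst_node_t = std::int32_t;

bst_node_t Encode(bst_node_t nidx, bool sampled_out) { return sampled_out ? ~nidx : nidx; }

bst_node_t Decode(bst_node_t stored, bool* sampled_out) {
  *sampled_out = stored < 0;               // ~n < 0 for any n >= 0
  return *sampled_out ? ~stored : stored;  // complement again to recover the leaf
}

int main() {
  bool flag;
  assert(Decode(Encode(5, true), &flag) == 5 && flag);
  assert(Decode(Encode(5, false), &flag) == 5 && !flag);
  return 0;
}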

View File

@ -390,7 +390,6 @@ void UpdatePredictionCacheImpl(GenericParameter const *ctx, RegTree const *p_las
  CHECK(p_last_tree);
  auto const &tree = *p_last_tree;
-  auto const &snode = hist_evaluator.Stats();
  auto evaluator = hist_evaluator.Evaluator();
  CHECK_EQ(out_preds.DeviceIdx(), GenericParameter::kCpuId);
  size_t n_nodes = p_last_tree->GetNodes().size();
@ -401,9 +400,7 @@ void UpdatePredictionCacheImpl(GenericParameter const *ctx, RegTree const *p_las
  common::ParallelFor2d(space, ctx->Threads(), [&](size_t nidx, common::Range1d r) {
    if (!tree[nidx].IsDeleted() && tree[nidx].IsLeaf()) {
      auto const &rowset = part[nidx];
-      auto const &stats = snode[nidx];
-      auto leaf_value =
-          evaluator.CalcWeight(nidx, param, GradStats{stats.stats}) * param.learning_rate;
+      auto leaf_value = tree[nidx].LeafValue();
      for (const size_t *it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) {
        out_preds(*it) += leaf_value;
      }

View File

@ -19,6 +19,7 @@
#include "param.h" #include "param.h"
#include "xgboost/base.h" #include "xgboost/base.h"
#include "xgboost/json.h" #include "xgboost/json.h"
#include "xgboost/tree_model.h"
#include "xgboost/tree_updater.h" #include "xgboost/tree_updater.h"
namespace xgboost { namespace xgboost {
@ -154,6 +155,18 @@ class GloablApproxBuilder {
monitor_->Stop(__func__); monitor_->Stop(__func__);
} }
void LeafPartition(RegTree const &tree, common::Span<float> hess,
std::vector<bst_node_t> *p_out_position) {
monitor_->Start(__func__);
if (!evaluator_.Task().UpdateTreeLeaf()) {
return;
}
for (auto const &part : partitioner_) {
part.LeafPartition(ctx_, tree, hess, p_out_position);
}
monitor_->Stop(__func__);
}
public: public:
explicit GloablApproxBuilder(TrainParam param, MetaInfo const &info, GenericParameter const *ctx, explicit GloablApproxBuilder(TrainParam param, MetaInfo const &info, GenericParameter const *ctx,
std::shared_ptr<common::ColumnSampler> column_sampler, ObjInfo task, std::shared_ptr<common::ColumnSampler> column_sampler, ObjInfo task,
@ -164,8 +177,8 @@ class GloablApproxBuilder {
ctx_{ctx}, ctx_{ctx},
monitor_{monitor} {} monitor_{monitor} {}
void UpdateTree(RegTree *p_tree, std::vector<GradientPair> const &gpair, common::Span<float> hess, void UpdateTree(DMatrix *p_fmat, std::vector<GradientPair> const &gpair, common::Span<float> hess,
DMatrix *p_fmat) { RegTree *p_tree, HostDeviceVector<bst_node_t> *p_out_position) {
p_last_tree_ = p_tree; p_last_tree_ = p_tree;
this->InitData(p_fmat, hess); this->InitData(p_fmat, hess);
@ -231,6 +244,9 @@ class GloablApproxBuilder {
driver.Push(best_splits.begin(), best_splits.end()); driver.Push(best_splits.begin(), best_splits.end());
expand_set = driver.Pop(); expand_set = driver.Pop();
} }
auto &h_position = p_out_position->HostVector();
this->LeafPartition(tree, hess, &h_position);
} }
}; };
@ -275,6 +291,7 @@ class GlobalApproxUpdater : public TreeUpdater {
sampled->resize(h_gpair.size()); sampled->resize(h_gpair.size());
std::copy(h_gpair.cbegin(), h_gpair.cend(), sampled->begin()); std::copy(h_gpair.cbegin(), h_gpair.cend(), sampled->begin());
auto &rnd = common::GlobalRandom(); auto &rnd = common::GlobalRandom();
if (param.subsample != 1.0) { if (param.subsample != 1.0) {
CHECK(param.sampling_method != TrainParam::kGradientBased) CHECK(param.sampling_method != TrainParam::kGradientBased)
<< "Gradient based sampling is not supported for approx tree method."; << "Gradient based sampling is not supported for approx tree method.";
@ -292,6 +309,7 @@ class GlobalApproxUpdater : public TreeUpdater {
char const *Name() const override { return "grow_histmaker"; } char const *Name() const override { return "grow_histmaker"; }
void Update(HostDeviceVector<GradientPair> *gpair, DMatrix *m, void Update(HostDeviceVector<GradientPair> *gpair, DMatrix *m,
common::Span<HostDeviceVector<bst_node_t>> out_position,
const std::vector<RegTree *> &trees) override { const std::vector<RegTree *> &trees) override {
float lr = param_.learning_rate; float lr = param_.learning_rate;
param_.learning_rate = lr / trees.size(); param_.learning_rate = lr / trees.size();
@ -313,12 +331,14 @@ class GlobalApproxUpdater : public TreeUpdater {
cached_ = m; cached_ = m;
size_t t_idx = 0;
for (auto p_tree : trees) { for (auto p_tree : trees) {
if (hist_param_.single_precision_histogram) { if (hist_param_.single_precision_histogram) {
this->f32_impl_->UpdateTree(p_tree, h_gpair, hess, m); this->f32_impl_->UpdateTree(m, h_gpair, hess, p_tree, &out_position[t_idx]);
} else { } else {
this->f64_impl_->UpdateTree(p_tree, h_gpair, hess, m); this->f64_impl_->UpdateTree(m, h_gpair, hess, p_tree, &out_position[t_idx]);
} }
++t_idx;
} }
param_.learning_rate = lr; param_.learning_rate = lr;
} }
@ -335,6 +355,8 @@ class GlobalApproxUpdater : public TreeUpdater {
} }
return true; return true;
} }
bool HasNodePosition() const override { return true; }
}; };
DMLC_REGISTRY_FILE_TAG(grow_histmaker); DMLC_REGISTRY_FILE_TAG(grow_histmaker);

View File

@ -1,5 +1,5 @@
/*!
- * Copyright 2021 XGBoost contributors
+ * Copyright 2021-2022 XGBoost contributors
 *
 * \brief Implementation for the approx tree method.
 */
@ -18,6 +18,7 @@
#include "hist/expand_entry.h"
#include "hist/param.h"
#include "param.h"
+#include "xgboost/generic_parameters.h"
#include "xgboost/json.h"
#include "xgboost/tree_updater.h"
@ -122,6 +123,12 @@ class ApproxRowPartitioner {
  auto const &Partitions() const { return row_set_collection_; }
+  void LeafPartition(Context const *ctx, RegTree const &tree, common::Span<float const> hess,
+                     std::vector<bst_node_t> *p_out_position) const {
+    partition_builder_.LeafPartition(ctx, tree, this->Partitions(), p_out_position,
+                                     [&](size_t idx) -> bool { return hess[idx] - .0f == .0f; });
+  }
  auto operator[](bst_node_t nidx) { return row_set_collection_[nidx]; }
  auto const &operator[](bst_node_t nidx) const { return row_set_collection_[nidx]; }

View File

@ -96,9 +96,9 @@ class ColMaker: public TreeUpdater {
    }
  }
-  void Update(HostDeviceVector<GradientPair> *gpair,
-              DMatrix* dmat,
-              const std::vector<RegTree*> &trees) override {
+  void Update(HostDeviceVector<GradientPair> *gpair, DMatrix *dmat,
+              common::Span<HostDeviceVector<bst_node_t>> out_position,
+              const std::vector<RegTree *> &trees) override {
    if (rabit::IsDistributed()) {
      LOG(FATAL) << "Updater `grow_colmaker` or `exact` tree method doesn't "
                    "support distributed training.";

View File

@ -11,6 +11,9 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "xgboost/base.h"
#include "xgboost/data.h"
#include "xgboost/generic_parameters.h"
#include "xgboost/host_device_vector.h" #include "xgboost/host_device_vector.h"
#include "xgboost/parameter.h" #include "xgboost/parameter.h"
#include "xgboost/span.h" #include "xgboost/span.h"
@ -35,6 +38,8 @@
#include "gpu_hist/histogram.cuh" #include "gpu_hist/histogram.cuh"
#include "gpu_hist/evaluate_splits.cuh" #include "gpu_hist/evaluate_splits.cuh"
#include "gpu_hist/expand_entry.cuh" #include "gpu_hist/expand_entry.cuh"
#include "xgboost/task.h"
#include "xgboost/tree_model.h"
namespace xgboost { namespace xgboost {
namespace tree { namespace tree {
@ -161,9 +166,9 @@ template <typename GradientSumT>
struct GPUHistMakerDevice {
 private:
  GPUHistEvaluator<GradientSumT> evaluator_;
+  Context const* ctx_;
 public:
-  int device_id;
  EllpackPageImpl const* page;
  common::Span<FeatureType const> feature_types;
  BatchParam batch_param;
@ -195,12 +200,12 @@ struct GPUHistMakerDevice {
  // Storing split categories for last node.
  dh::caching_device_vector<uint32_t> node_categories;
-  GPUHistMakerDevice(int _device_id, EllpackPageImpl const* _page,
+  GPUHistMakerDevice(Context const* ctx, EllpackPageImpl const* _page,
                     common::Span<FeatureType const> _feature_types, bst_uint _n_rows,
                     TrainParam _param, uint32_t column_sampler_seed, uint32_t n_features,
                     BatchParam _batch_param)
-      : evaluator_{_param, n_features, _device_id},
-        device_id(_device_id),
+      : evaluator_{_param, n_features, ctx->gpu_id},
+        ctx_(ctx),
        page(_page),
        feature_types{_feature_types},
        param(std::move(_param)),
@ -216,14 +221,15 @@ struct GPUHistMakerDevice {
    node_sum_gradients.resize(param.MaxNodes());
    // Init histogram
-    hist.Init(device_id, page->Cuts().TotalBins());
-    monitor.Init(std::string("GPUHistMakerDevice") + std::to_string(device_id));
-    feature_groups.reset(new FeatureGroups(
-        page->Cuts(), page->is_dense, dh::MaxSharedMemoryOptin(device_id), sizeof(GradientSumT)));
+    hist.Init(ctx_->gpu_id, page->Cuts().TotalBins());
+    monitor.Init(std::string("GPUHistMakerDevice") + std::to_string(ctx_->gpu_id));
+    feature_groups.reset(new FeatureGroups(page->Cuts(), page->is_dense,
+                                           dh::MaxSharedMemoryOptin(ctx_->gpu_id),
+                                           sizeof(GradientSumT)));
  }
  ~GPUHistMakerDevice() {  // NOLINT
-    dh::safe_cuda(cudaSetDevice(device_id));
+    dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
  }
  // Reset values for each update iteration
@ -235,10 +241,10 @@ struct GPUHistMakerDevice {
    this->column_sampler.Init(num_columns, info.feature_weights.HostVector(),
                              param.colsample_bynode, param.colsample_bylevel,
                              param.colsample_bytree);
-    dh::safe_cuda(cudaSetDevice(device_id));
+    dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
    this->evaluator_.Reset(page->Cuts(), feature_types, task, dmat->Info().num_col_, param,
-                           device_id);
+                           ctx_->gpu_id);
    this->interaction_constraints.Reset();
    std::fill(node_sum_gradients.begin(), node_sum_gradients.end(), GradientPairPrecise{});
@ -256,7 +262,7 @@ struct GPUHistMakerDevice {
    histogram_rounding = CreateRoundingFactor<GradientSumT>(this->gpair);
    row_partitioner.reset();  // Release the device memory first before reallocating
-    row_partitioner.reset(new RowPartitioner(device_id, sample.sample_rows));
+    row_partitioner.reset(new RowPartitioner(ctx_->gpu_id, sample.sample_rows));
    hist.Reset();
  }
@ -264,10 +270,10 @@ struct GPUHistMakerDevice {
    int nidx = RegTree::kRoot;
    GPUTrainingParam gpu_param(param);
    auto sampled_features = column_sampler.GetFeatureSet(0);
-    sampled_features->SetDevice(device_id);
+    sampled_features->SetDevice(ctx_->gpu_id);
    common::Span<bst_feature_t> feature_set =
        interaction_constraints.Query(sampled_features->DeviceSpan(), nidx);
-    auto matrix = page->GetDeviceAccessor(device_id);
+    auto matrix = page->GetDeviceAccessor(ctx_->gpu_id);
    EvaluateSplitInputs<GradientSumT> inputs{nidx,
                                             root_sum,
                                             gpu_param,
@ -287,14 +293,14 @@ struct GPUHistMakerDevice {
    dh::TemporaryArray<DeviceSplitCandidate> splits_out(2);
    GPUTrainingParam gpu_param(param);
    auto left_sampled_features = column_sampler.GetFeatureSet(tree.GetDepth(left_nidx));
-    left_sampled_features->SetDevice(device_id);
+    left_sampled_features->SetDevice(ctx_->gpu_id);
    common::Span<bst_feature_t> left_feature_set =
        interaction_constraints.Query(left_sampled_features->DeviceSpan(), left_nidx);
    auto right_sampled_features = column_sampler.GetFeatureSet(tree.GetDepth(right_nidx));
-    right_sampled_features->SetDevice(device_id);
+    right_sampled_features->SetDevice(ctx_->gpu_id);
    common::Span<bst_feature_t> right_feature_set =
        interaction_constraints.Query(right_sampled_features->DeviceSpan(), left_nidx);
-    auto matrix = page->GetDeviceAccessor(device_id);
+    auto matrix = page->GetDeviceAccessor(ctx_->gpu_id);
    EvaluateSplitInputs<GradientSumT> left{left_nidx,
                                           candidate.split.left_sum,
@ -325,8 +331,8 @@ struct GPUHistMakerDevice {
    hist.AllocateHistogram(nidx);
    auto d_node_hist = hist.GetNodeHistogram(nidx);
    auto d_ridx = row_partitioner->GetRows(nidx);
-    BuildGradientHistogram(page->GetDeviceAccessor(device_id),
-                           feature_groups->DeviceAccessor(device_id), gpair,
+    BuildGradientHistogram(page->GetDeviceAccessor(ctx_->gpu_id),
+                           feature_groups->DeviceAccessor(ctx_->gpu_id), gpair,
                           d_ridx, d_node_hist, histogram_rounding);
  }
@ -351,7 +357,7 @@ struct GPUHistMakerDevice {
  void UpdatePosition(int nidx, RegTree* p_tree) {
    RegTree::Node split_node = (*p_tree)[nidx];
    auto split_type = p_tree->NodeSplitType(nidx);
-    auto d_matrix = page->GetDeviceAccessor(device_id);
+    auto d_matrix = page->GetDeviceAccessor(ctx_->gpu_id);
    auto node_cats = dh::ToSpan(node_categories);
    row_partitioner->UpdatePosition(
@ -384,7 +390,8 @@ struct GPUHistMakerDevice {
  // After tree update is finished, update the position of all training
  // instances to their final leaf. This information is used later to update the
  // prediction cache
-  void FinalisePosition(RegTree const* p_tree, DMatrix* p_fmat) {
+  void FinalisePosition(RegTree const* p_tree, DMatrix* p_fmat, ObjInfo task,
+                        HostDeviceVector<bst_node_t>* p_out_position) {
    dh::TemporaryArray<RegTree::Node> d_nodes(p_tree->GetNodes().size());
    dh::safe_cuda(cudaMemcpyAsync(d_nodes.data().get(), p_tree->GetNodes().data(),
                                  d_nodes.size() * sizeof(RegTree::Node),
@ -405,17 +412,21 @@ struct GPUHistMakerDevice {
    if (row_partitioner->GetRows().size() != p_fmat->Info().num_row_) {
      row_partitioner.reset();  // Release the device memory first before reallocating
-      row_partitioner.reset(new RowPartitioner(device_id, p_fmat->Info().num_row_));
+      row_partitioner.reset(new RowPartitioner(ctx_->gpu_id, p_fmat->Info().num_row_));
    }
+    if (task.UpdateTreeLeaf() && !p_fmat->SingleColBlock() && param.subsample != 1.0) {
+      // see comment in the `FinalisePositionInPage`.
+      LOG(FATAL) << "Current objective function can not be used with subsampled external memory.";
+    }
    if (page->n_rows == p_fmat->Info().num_row_) {
-      FinalisePositionInPage(page, dh::ToSpan(d_nodes),
-                             dh::ToSpan(d_split_types), dh::ToSpan(d_categories),
-                             dh::ToSpan(d_categories_segments));
+      FinalisePositionInPage(page, dh::ToSpan(d_nodes), dh::ToSpan(d_split_types),
+                             dh::ToSpan(d_categories), dh::ToSpan(d_categories_segments), task,
+                             p_out_position);
    } else {
-      for (auto& batch : p_fmat->GetBatches<EllpackPage>(batch_param)) {
-        FinalisePositionInPage(batch.Impl(), dh::ToSpan(d_nodes),
-                               dh::ToSpan(d_split_types), dh::ToSpan(d_categories),
-                               dh::ToSpan(d_categories_segments));
+      for (auto const& batch : p_fmat->GetBatches<EllpackPage>(batch_param)) {
+        FinalisePositionInPage(batch.Impl(), dh::ToSpan(d_nodes), dh::ToSpan(d_split_types),
+                               dh::ToSpan(d_categories), dh::ToSpan(d_categories_segments), task,
+                               p_out_position);
      }
    }
  }
@ -424,9 +435,13 @@ struct GPUHistMakerDevice {
const common::Span<RegTree::Node> d_nodes, const common::Span<RegTree::Node> d_nodes,
common::Span<FeatureType const> d_feature_types, common::Span<FeatureType const> d_feature_types,
common::Span<uint32_t const> categories, common::Span<uint32_t const> categories,
common::Span<RegTree::Segment> categories_segments) { common::Span<RegTree::Segment> categories_segments,
auto d_matrix = page->GetDeviceAccessor(device_id); ObjInfo task,
HostDeviceVector<bst_node_t>* p_out_position) {
auto d_matrix = page->GetDeviceAccessor(ctx_->gpu_id);
auto d_gpair = this->gpair;
row_partitioner->FinalisePosition( row_partitioner->FinalisePosition(
ctx_, task, p_out_position,
[=] __device__(size_t row_id, int position) { [=] __device__(size_t row_id, int position) {
// What happens if user prune the tree? // What happens if user prune the tree?
if (!d_matrix.IsInRange(row_id)) { if (!d_matrix.IsInRange(row_id)) {
@ -457,13 +472,20 @@ struct GPUHistMakerDevice {
} }
node = d_nodes[position]; node = d_nodes[position];
} }
return position; return position;
},
[d_gpair] __device__(size_t ridx) {
// FIXME(jiamingy): Doesn't work when sampling is used with external memory as
// the sampler compacts the gradient vector.
return d_gpair[ridx].GetHess() - .0f == 0.f;
}); });
} }
-  void UpdatePredictionCache(linalg::VectorView<float> out_preds_d) {
-    dh::safe_cuda(cudaSetDevice(device_id));
-    CHECK_EQ(out_preds_d.DeviceIdx(), device_id);
+  void UpdatePredictionCache(linalg::VectorView<float> out_preds_d, RegTree const* p_tree) {
+    CHECK(p_tree);
+    dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
+    CHECK_EQ(out_preds_d.DeviceIdx(), ctx_->gpu_id);
    auto d_ridx = row_partitioner->GetRows();
    GPUTrainingParam param_d(param);
@ -476,12 +498,15 @@ struct GPUHistMakerDevice {
    auto d_node_sum_gradients = device_node_sum_gradients.data().get();
    auto tree_evaluator = evaluator_.GetEvaluator();
-    dh::LaunchN(d_ridx.size(), [=, out_preds_d = out_preds_d] __device__(int local_idx) mutable {
-      int pos = d_position[local_idx];
-      bst_float weight =
-          tree_evaluator.CalcWeight(pos, param_d, GradStats{d_node_sum_gradients[pos]});
-      static_assert(!std::is_const<decltype(out_preds_d)>::value, "");
-      out_preds_d(d_ridx[local_idx]) += weight * param_d.learning_rate;
+    auto const& h_nodes = p_tree->GetNodes();
+    dh::caching_device_vector<RegTree::Node> nodes(h_nodes.size());
+    dh::safe_cuda(cudaMemcpyAsync(nodes.data().get(), h_nodes.data(),
+                                  h_nodes.size() * sizeof(RegTree::Node), cudaMemcpyHostToDevice));
+    auto d_nodes = dh::ToSpan(nodes);
+    dh::LaunchN(d_ridx.size(), [=] XGBOOST_DEVICE(size_t idx) mutable {
+      bst_node_t nidx = d_position[idx];
+      auto weight = d_nodes[nidx].LeafValue();
+      out_preds_d(d_ridx[idx]) += weight;
    });
    row_partitioner.reset();
  }
@ -610,7 +635,8 @@ struct GPUHistMakerDevice {
  }
  void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat, ObjInfo task,
-                  RegTree* p_tree, dh::AllReducer* reducer) {
+                  RegTree* p_tree, dh::AllReducer* reducer,
+                  HostDeviceVector<bst_node_t>* p_out_position) {
    auto& tree = *p_tree;
    Driver<GPUExpandEntry> driver(static_cast<TrainParam::TreeGrowPolicy>(param.grow_policy));
@ -641,7 +667,7 @@ struct GPUHistMakerDevice {
      int left_child_nidx = tree[candidate.nid].LeftChild();
      int right_child_nidx = tree[candidate.nid].RightChild();
      // Only create child entries if needed
      if (GPUExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx),
                                       num_leaves)) {
        monitor.Start("UpdatePosition");
@ -671,7 +697,7 @@ struct GPUHistMakerDevice {
    }
    monitor.Start("FinalisePosition");
-    this->FinalisePosition(p_tree, p_fmat);
+    this->FinalisePosition(p_tree, p_fmat, task, p_out_position);
    monitor.Stop("FinalisePosition");
  }
};
@ -682,7 +708,7 @@ class GPUHistMakerSpecialised {
  explicit GPUHistMakerSpecialised(ObjInfo task) : task_{task} {};
  void Configure(const Args& args, GenericParameter const* generic_param) {
    param_.UpdateAllowUnknown(args);
-    generic_param_ = generic_param;
+    ctx_ = generic_param;
    hist_maker_param_.UpdateAllowUnknown(args);
    dh::CheckComputeCapability();
@ -694,20 +720,24 @@ class GPUHistMakerSpecialised {
  }
  void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
+              common::Span<HostDeviceVector<bst_node_t>> out_position,
              const std::vector<RegTree*>& trees) {
    monitor_.Start("Update");
    // rescale learning rate according to size of trees
    float lr = param_.learning_rate;
    param_.learning_rate = lr / trees.size();
    // build tree
    try {
+      size_t t_idx{0};
      for (xgboost::RegTree* tree : trees) {
-        this->UpdateTree(gpair, dmat, tree);
+        this->UpdateTree(gpair, dmat, tree, &out_position[t_idx]);
        if (hist_maker_param_.debug_synchronize) {
          this->CheckTreesSynchronized(tree);
        }
+        ++t_idx;
      }
      dh::safe_cuda(cudaGetLastError());
    } catch (const std::exception& e) {
@ -719,41 +749,36 @@ class GPUHistMakerSpecialised {
  }
  void InitDataOnce(DMatrix* dmat) {
-    device_ = generic_param_->gpu_id;
-    CHECK_GE(device_, 0) << "Must have at least one device";
+    CHECK_GE(ctx_->gpu_id, 0) << "Must have at least one device";
    info_ = &dmat->Info();
-    reducer_.Init({device_});  // NOLINT
+    reducer_.Init({ctx_->gpu_id});  // NOLINT
    // Synchronise the column sampling seed
    uint32_t column_sampling_seed = common::GlobalRandom()();
    rabit::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0);
    BatchParam batch_param{
-      device_,
+      ctx_->gpu_id,
      param_.max_bin,
    };
    auto page = (*dmat->GetBatches<EllpackPage>(batch_param).begin()).Impl();
-    dh::safe_cuda(cudaSetDevice(device_));
-    info_->feature_types.SetDevice(device_);
-    maker.reset(new GPUHistMakerDevice<GradientSumT>(device_,
-                                                     page,
-                                                     info_->feature_types.ConstDeviceSpan(),
-                                                     info_->num_row_,
-                                                     param_,
-                                                     column_sampling_seed,
-                                                     info_->num_col_,
-                                                     batch_param));
+    dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
+    info_->feature_types.SetDevice(ctx_->gpu_id);
+    maker.reset(new GPUHistMakerDevice<GradientSumT>(
+        ctx_, page, info_->feature_types.ConstDeviceSpan(), info_->num_row_, param_,
+        column_sampling_seed, info_->num_col_, batch_param));
    p_last_fmat_ = dmat;
    initialised_ = true;
  }
-  void InitData(DMatrix* dmat) {
+  void InitData(DMatrix* dmat, RegTree const* p_tree) {
    if (!initialised_) {
      monitor_.Start("InitDataOnce");
      this->InitDataOnce(dmat);
      monitor_.Stop("InitDataOnce");
    }
+    p_last_tree_ = p_tree;
  }
  // Only call this method for testing
@ -771,13 +796,14 @@ class GPUHistMakerSpecialised {
    CHECK(*local_tree == reference_tree);
  }
-  void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat, RegTree* p_tree) {
+  void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat, RegTree* p_tree,
+                  HostDeviceVector<bst_node_t>* p_out_position) {
    monitor_.Start("InitData");
-    this->InitData(p_fmat);
+    this->InitData(p_fmat, p_tree);
    monitor_.Stop("InitData");
-    gpair->SetDevice(device_);
-    maker->UpdateTree(gpair, p_fmat, task_, p_tree, &reducer_);
+    gpair->SetDevice(ctx_->gpu_id);
+    maker->UpdateTree(gpair, p_fmat, task_, p_tree, &reducer_, p_out_position);
  }
  bool UpdatePredictionCache(const DMatrix *data,
@ -786,7 +812,7 @@ class GPUHistMakerSpecialised {
      return false;
    }
    monitor_.Start("UpdatePredictionCache");
-    maker->UpdatePredictionCache(p_out_preds);
+    maker->UpdatePredictionCache(p_out_preds, p_last_tree_);
    monitor_.Stop("UpdatePredictionCache");
    return true;
  }
@ -800,12 +826,12 @@ class GPUHistMakerSpecialised {
  bool initialised_ { false };
  GPUHistMakerTrainParam hist_maker_param_;
-  GenericParameter const* generic_param_;
+  Context const* ctx_;
  dh::AllReducer reducer_;
  DMatrix* p_last_fmat_ { nullptr };
-  int device_{-1};
+  RegTree const* p_last_tree_{nullptr};
  ObjInfo task_;
  common::Monitor monitor_;
@ -859,17 +885,17 @@ class GPUHistMaker : public TreeUpdater {
  }
  void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
+              common::Span<HostDeviceVector<bst_node_t>> out_position,
              const std::vector<RegTree*>& trees) override {
    if (hist_maker_param_.single_precision_histogram) {
-      float_maker_->Update(gpair, dmat, trees);
+      float_maker_->Update(gpair, dmat, out_position, trees);
    } else {
-      double_maker_->Update(gpair, dmat, trees);
+      double_maker_->Update(gpair, dmat, out_position, trees);
    }
  }
-  bool
-  UpdatePredictionCache(const DMatrix *data,
-                        linalg::VectorView<bst_float> p_out_preds) override {
+  bool UpdatePredictionCache(const DMatrix* data,
+                             linalg::VectorView<bst_float> p_out_preds) override {
    if (hist_maker_param_.single_precision_histogram) {
      return float_maker_->UpdatePredictionCache(data, p_out_preds);
    } else {
@ -881,6 +907,8 @@ class GPUHistMaker : public TreeUpdater {
    return "grow_gpu_hist";
  }
+  bool HasNodePosition() const override { return true; }
 private:
  GPUHistMakerTrainParam hist_maker_param_;
  ObjInfo task_;

View File

@ -24,9 +24,9 @@ DMLC_REGISTRY_FILE_TAG(updater_histmaker);
class HistMaker: public BaseMaker {
 public:
-  void Update(HostDeviceVector<GradientPair> *gpair,
-              DMatrix *p_fmat,
-              const std::vector<RegTree*> &trees) override {
+  void Update(HostDeviceVector<GradientPair> *gpair, DMatrix *p_fmat,
+              common::Span<HostDeviceVector<bst_node_t>> out_position,
+              const std::vector<RegTree *> &trees) override {
    interaction_constraints_.Configure(param_, p_fmat->Info().num_col_);
    // rescale learning rate according to size of trees
    float lr = param_.learning_rate;

View File

@ -50,9 +50,9 @@ class TreePruner: public TreeUpdater {
  }
  // update the tree, do pruning
-  void Update(HostDeviceVector<GradientPair> *gpair,
-              DMatrix *p_fmat,
-              const std::vector<RegTree*> &trees) override {
+  void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat,
+              common::Span<HostDeviceVector<bst_node_t>> out_position,
+              const std::vector<RegTree*>& trees) override {
    pruner_monitor_.Start("PrunerUpdate");
    // rescale learning rate according to size of trees
    float lr = param_.learning_rate;
@ -61,7 +61,7 @@ class TreePruner: public TreeUpdater {
      this->DoPrune(tree);
    }
    param_.learning_rate = lr;
-    syncher_->Update(gpair, p_fmat, trees);
+    syncher_->Update(gpair, p_fmat, out_position, trees);
    pruner_monitor_.Stop("PrunerUpdate");
  }

View File

@ -36,6 +36,7 @@ void QuantileHistMaker::Configure(const Args &args) {
}
void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair, DMatrix *dmat,
+                               common::Span<HostDeviceVector<bst_node_t>> out_position,
                               const std::vector<RegTree *> &trees) {
  // rescale learning rate according to size of trees
  float lr = param_.learning_rate;
@ -53,12 +54,15 @@ void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair, DMatrix *d
    }
  }
+  size_t t_idx{0};
  for (auto p_tree : trees) {
+    auto &t_row_position = out_position[t_idx];
    if (hist_maker_param_.single_precision_histogram) {
-      this->float_builder_->UpdateTree(gpair, dmat, p_tree);
+      this->float_builder_->UpdateTree(gpair, dmat, p_tree, &t_row_position);
    } else {
-      this->double_builder_->UpdateTree(gpair, dmat, p_tree);
+      this->double_builder_->UpdateTree(gpair, dmat, p_tree, &t_row_position);
    }
+    ++t_idx;
  }
  param_.learning_rate = lr;
@ -169,13 +173,29 @@ void QuantileHistMaker::Builder<GradientSumT>::BuildHistogram(
  }
}
+template <typename GradientSumT>
+void QuantileHistMaker::Builder<GradientSumT>::LeafPartition(
+    RegTree const &tree, common::Span<GradientPair const> gpair,
+    std::vector<bst_node_t> *p_out_position) {
+  monitor_->Start(__func__);
+  if (!evaluator_->Task().UpdateTreeLeaf()) {
+    return;
+  }
+  for (auto const &part : partitioner_) {
+    part.LeafPartition(ctx_, tree, gpair, p_out_position);
+  }
+  monitor_->Stop(__func__);
+}
template <typename GradientSumT>
void QuantileHistMaker::Builder<GradientSumT>::ExpandTree(
-    DMatrix *p_fmat, RegTree *p_tree, const std::vector<GradientPair> &gpair_h) {
+    DMatrix *p_fmat, RegTree *p_tree, const std::vector<GradientPair> &gpair_h,
+    HostDeviceVector<bst_node_t> *p_out_position) {
  monitor_->Start(__func__);
  Driver<CPUExpandEntry> driver(static_cast<TrainParam::TreeGrowPolicy>(param_.grow_policy));
  driver.Push(this->InitRoot(p_fmat, p_tree, gpair_h));
+  auto const &tree = *p_tree;
  bst_node_t num_leaves{1};
  auto expand_set = driver.Pop();
@ -208,7 +228,6 @@ void QuantileHistMaker::Builder<GradientSumT>::ExpandTree(
    std::vector<CPUExpandEntry> best_splits;
    if (!valid_candidates.empty()) {
      this->BuildHistogram(p_fmat, p_tree, valid_candidates, gpair_h);
-      auto const &tree = *p_tree;
      for (auto const &candidate : valid_candidates) {
        int left_child_nidx = tree[candidate.nid].LeftChild();
        int right_child_nidx = tree[candidate.nid].RightChild();
@ -228,12 +247,15 @@ void QuantileHistMaker::Builder<GradientSumT>::ExpandTree(
    expand_set = driver.Pop();
  }
+  auto &h_out_position = p_out_position->HostVector();
+  this->LeafPartition(tree, gpair_h, &h_out_position);
  monitor_->Stop(__func__);
}
template <typename GradientSumT>
-void QuantileHistMaker::Builder<GradientSumT>::UpdateTree(HostDeviceVector<GradientPair> *gpair,
-                                                          DMatrix *p_fmat, RegTree *p_tree) {
+void QuantileHistMaker::Builder<GradientSumT>::UpdateTree(
+    HostDeviceVector<GradientPair> *gpair, DMatrix *p_fmat, RegTree *p_tree,
+    HostDeviceVector<bst_node_t> *p_out_position) {
  monitor_->Start(__func__);
  std::vector<GradientPair> *gpair_ptr = &(gpair->HostVector());
@ -246,8 +268,7 @@ void QuantileHistMaker::Builder<GradientSumT>::UpdateTree(HostDeviceVector<Gradi
  this->InitData(p_fmat, *p_tree, gpair_ptr);
-  ExpandTree(p_fmat, p_tree, *gpair_ptr);
+  ExpandTree(p_fmat, p_tree, *gpair_ptr, p_out_position);
  monitor_->Stop(__func__);
}

View File

@ -17,6 +17,7 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "xgboost/base.h"
#include "xgboost/data.h" #include "xgboost/data.h"
#include "xgboost/json.h" #include "xgboost/json.h"
@ -214,6 +215,15 @@ class HistRowPartitioner {
  size_t Size() const {
    return std::distance(row_set_collection_.begin(), row_set_collection_.end());
  }
void LeafPartition(Context const* ctx, RegTree const& tree,
common::Span<GradientPair const> gpair,
std::vector<bst_node_t>* p_out_position) const {
partition_builder_.LeafPartition(
ctx, tree, this->Partitions(), p_out_position,
[&](size_t idx) -> bool { return gpair[idx].GetHess() - .0f == .0f; });
}
auto& operator[](bst_node_t nidx) { return row_set_collection_[nidx]; } auto& operator[](bst_node_t nidx) { return row_set_collection_[nidx]; }
auto const& operator[](bst_node_t nidx) const { return row_set_collection_[nidx]; } auto const& operator[](bst_node_t nidx) const { return row_set_collection_[nidx]; }
}; };
@ -228,8 +238,8 @@ class QuantileHistMaker: public TreeUpdater {
explicit QuantileHistMaker(ObjInfo task) : task_{task} {} explicit QuantileHistMaker(ObjInfo task) : task_{task} {}
void Configure(const Args& args) override; void Configure(const Args& args) override;
void Update(HostDeviceVector<GradientPair>* gpair, void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
DMatrix* dmat, common::Span<HostDeviceVector<bst_node_t>> out_position,
const std::vector<RegTree*>& trees) override; const std::vector<RegTree*>& trees) override;
bool UpdatePredictionCache(const DMatrix *data, bool UpdatePredictionCache(const DMatrix *data,
@ -266,6 +276,8 @@ class QuantileHistMaker: public TreeUpdater {
return "grow_quantile_histmaker"; return "grow_quantile_histmaker";
} }
bool HasNodePosition() const override { return true; }
protected: protected:
CPUHistMakerTrainParam hist_maker_param_; CPUHistMakerTrainParam hist_maker_param_;
// training parameter // training parameter
@ -289,7 +301,8 @@ class QuantileHistMaker: public TreeUpdater {
monitor_->Init("Quantile::Builder"); monitor_->Init("Quantile::Builder");
} }
// update one tree, growing // update one tree, growing
void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat, RegTree* p_tree); void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat, RegTree* p_tree,
HostDeviceVector<bst_node_t>* p_out_position);
bool UpdatePredictionCache(DMatrix const* data, linalg::VectorView<float> out_preds) const; bool UpdatePredictionCache(DMatrix const* data, linalg::VectorView<float> out_preds) const;
@ -308,7 +321,11 @@ class QuantileHistMaker: public TreeUpdater {
std::vector<CPUExpandEntry> const& valid_candidates, std::vector<CPUExpandEntry> const& valid_candidates,
std::vector<GradientPair> const& gpair); std::vector<GradientPair> const& gpair);
void ExpandTree(DMatrix* p_fmat, RegTree* p_tree, const std::vector<GradientPair>& gpair_h); void LeafPartition(RegTree const& tree, common::Span<GradientPair const> gpair,
std::vector<bst_node_t>* p_out_position);
void ExpandTree(DMatrix* p_fmat, RegTree* p_tree, const std::vector<GradientPair>& gpair_h,
HostDeviceVector<bst_node_t>* p_out_position);
private: private:
const size_t n_trees_; const size_t n_trees_;
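The `GetHess() - .0f == .0f` predicate above flags rows dropped by (sub)sampling. Matching the tests later in this commit, which expect `~i` for sampled-out rows, the position encoding can be sketched as follows (helper names are hypothetical):

// Kept rows store the leaf index directly; sampled-out rows store its
// bit-negation, which is always negative but still recoverable.
inline int EncodePosition(int nidx, bool sampled_out) {
  return sampled_out ? ~nidx : nidx;
}
inline bool IsSampledOut(int pos) { return pos < 0; }
inline int DecodeLeaf(int pos) { return pos < 0 ? ~pos : pos; }

One convenient property of this encoding is that a single flat array carries both the row-to-leaf mapping and the sampling mask.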
View File
@@ -42,9 +42,9 @@ class TreeRefresher: public TreeUpdater {
     return true;
   }
   // update the tree by refreshing leaf statistics
-  void Update(HostDeviceVector<GradientPair> *gpair,
-              DMatrix *p_fmat,
-              const std::vector<RegTree*> &trees) override {
+  void Update(HostDeviceVector<GradientPair> *gpair, DMatrix *p_fmat,
+              common::Span<HostDeviceVector<bst_node_t>> out_position,
+              const std::vector<RegTree *> &trees) override {
     if (trees.size() == 0) return;
     const std::vector<GradientPair> &gpair_h = gpair->ConstHostVector();
     // thread temporary space
View File
@@ -31,9 +31,9 @@ class TreeSyncher: public TreeUpdater {
     return "prune";
   }
-  void Update(HostDeviceVector<GradientPair>* ,
-              DMatrix*,
-              const std::vector<RegTree*> &trees) override {
+  void Update(HostDeviceVector<GradientPair>*, DMatrix*,
+              common::Span<HostDeviceVector<bst_node_t>> out_position,
+              const std::vector<RegTree*>& trees) override {
     if (rabit::GetWorldSize() == 1) return;
     std::string s_model;
     common::MemoryBufferStream fs(&s_model);
View File
@ -0,0 +1,58 @@
/*!
* Copyright 2022 by XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <xgboost/generic_parameters.h>
#include "../../../src/common/stats.h"
namespace xgboost {
namespace common {
TEST(Stats, Quantile) {
{
linalg::Tensor<float, 1> arr({20.f, 0.f, 15.f, 50.f, 40.f, 0.f, 35.f}, {7}, Context::kCpuId);
std::vector<size_t> index{0, 2, 3, 4, 6};
auto h_arr = arr.HostView();
auto beg = MakeIndexTransformIter([&](size_t i) { return h_arr(index[i]); });
auto end = beg + index.size();
auto q = Quantile(0.40f, beg, end);
ASSERT_EQ(q, 26.0);
q = Quantile(0.20f, beg, end);
ASSERT_EQ(q, 16.0);
q = Quantile(0.10f, beg, end);
ASSERT_EQ(q, 15.0);
}
{
std::vector<float> vec{1., 2., 3., 4., 5.};
auto beg = MakeIndexTransformIter([&](size_t i) { return vec[i]; });
auto end = beg + vec.size();
auto q = Quantile(0.5f, beg, end);
ASSERT_EQ(q, 3.);
}
}
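// A minimal sketch (not the implementation under test) of the quantile
// convention the expectations above assume: fractional rank
// h = alpha * (n + 1), linearly interpolated between order statistics and
// clamped to the ends. E.g. alpha = 0.4 over {15, 20, 35, 40, 50} gives
// h = 1.4, hence 20 + 0.4 * (35 - 20) = 26. Requires <algorithm>.
inline double QuantileSketch(double alpha, std::vector<double> v) {
  std::sort(v.begin(), v.end());
  double h = alpha * (v.size() + 1) - 1;  // zero-based fractional rank
  if (h <= 0.0) return v.front();
  if (h >= static_cast<double>(v.size() - 1)) return v.back();
  auto lo = static_cast<std::size_t>(h);
  return v[lo] + (h - lo) * (v[lo + 1] - v[lo]);
}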
TEST(Stats, WeightedQuantile) {
linalg::Tensor<float, 1> arr({1.f, 2.f, 3.f, 4.f, 5.f}, {5}, Context::kCpuId);
linalg::Tensor<float, 1> weight({1.f, 1.f, 1.f, 1.f, 1.f}, {5}, Context::kCpuId);
auto h_arr = arr.HostView();
auto h_weight = weight.HostView();
auto beg = MakeIndexTransformIter([&](size_t i) { return h_arr(i); });
auto end = beg + arr.Size();
auto w = MakeIndexTransformIter([&](size_t i) { return h_weight(i); });
auto q = WeightedQuantile(0.50f, beg, end, w);
ASSERT_EQ(q, 3);
q = WeightedQuantile(0.0, beg, end, w);
ASSERT_EQ(q, 1);
q = WeightedQuantile(1.0, beg, end, w);
ASSERT_EQ(q, 5);
}
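// A sketch of one weighted definition consistent with the checks above
// (assumed, not the tested code): the smallest value whose cumulative
// weight reaches alpha * total_weight. With unit weights over {1..5} and
// alpha = 0.5 the threshold is 2.5, first reached at the value 3.
inline double WeightedQuantileSketch(double alpha, std::vector<double> const& v,
                                     std::vector<double> const& w) {
  double total = 0.0;  // assumes v sorted ascending, w[i] the weight of v[i]
  for (double x : w) total += x;
  double threshold = alpha * total, cum = 0.0;
  for (std::size_t i = 0; i < v.size(); ++i) {
    cum += w[i];
    if (cum >= threshold) return v[i];
  }
  return v.back();
}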
} // namespace common
} // namespace xgboost
View File
@ -0,0 +1,77 @@
/*!
* Copyright 2022 by XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <utility>
#include <vector>
#include "../../../src/common/stats.cuh"
#include "xgboost/base.h"
#include "xgboost/generic_parameters.h"
#include "xgboost/host_device_vector.h"
#include "xgboost/linalg.h"
namespace xgboost {
namespace common {
namespace {
class StatsGPU : public ::testing::Test {
private:
linalg::Tensor<float, 1> arr_{
{1.f, 2.f, 3.f, 4.f, 5.f,
2.f, 4.f, 5.f, 3.f, 1.f},
{10}, 0};
linalg::Tensor<size_t, 1> indptr_{{0, 5, 10}, {3}, 0};
HostDeviceVector<float> results_;
using TestSet = std::vector<std::pair<float, float>>;
Context ctx_;
void Check(float expected) {
auto const& h_results = results_.HostVector();
ASSERT_EQ(h_results.size(), indptr_.Size() - 1);
ASSERT_EQ(h_results.front(), expected);
EXPECT_EQ(h_results.back(), expected);
}
public:
void SetUp() override { ctx_.gpu_id = 0; }
void Weighted() {
auto d_arr = arr_.View(0);
auto d_key = indptr_.View(0);
auto key_it = dh::MakeTransformIterator<size_t>(thrust::make_counting_iterator(0ul),
[=] __device__(size_t i) { return d_key(i); });
auto val_it = dh::MakeTransformIterator<float>(
thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) { return d_arr(i); });
linalg::Tensor<float, 1> weights{{10}, 0};
linalg::ElementWiseTransformDevice(weights.View(0),
[=] XGBOOST_DEVICE(size_t, float) { return 1.0; });
auto w_it = weights.Data()->ConstDevicePointer();
for (auto const& pair : TestSet{{0.0f, 1.0f}, {0.5f, 3.0f}, {1.0f, 5.0f}}) {
SegmentedWeightedQuantile(&ctx_, pair.first, key_it, key_it + indptr_.Size(), val_it,
val_it + arr_.Size(), w_it, w_it + weights.Size(), &results_);
this->Check(pair.second);
}
}
void NonWeighted() {
auto d_arr = arr_.View(0);
auto d_key = indptr_.View(0);
auto key_it = dh::MakeTransformIterator<size_t>(thrust::make_counting_iterator(0ul),
[=] __device__(size_t i) { return d_key(i); });
auto val_it = dh::MakeTransformIterator<float>(
thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) { return d_arr(i); });
for (auto const& pair : TestSet{{0.0f, 1.0f}, {0.5f, 3.0f}, {1.0f, 5.0f}}) {
SegmentedQuantile(&ctx_, pair.first, key_it, key_it + indptr_.Size(), val_it,
val_it + arr_.Size(), &results_);
this->Check(pair.second);
}
}
};
} // anonymous namespace
TEST_F(StatsGPU, Quantile) { this->NonWeighted(); }
TEST_F(StatsGPU, WeightedQuantile) { this->Weighted(); }
} // namespace common
} // namespace xgboost
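The segmented variants exercised above produce one quantile per `indptr` segment, which is why the checks expect `results.size() == indptr.Size() - 1`. A CPU sketch of that assumed contract, reusing the scalar `QuantileSketch` convention sketched earlier (illustrative only):

std::vector<double> SegmentedQuantileSketch(double alpha,
                                            std::vector<std::size_t> const& indptr,
                                            std::vector<double> const& values) {
  std::vector<double> results;
  for (std::size_t i = 0; i + 1 < indptr.size(); ++i) {
    // segment i covers values[indptr[i]] .. values[indptr[i + 1] - 1]
    std::vector<double> seg(values.begin() + indptr[i],
                            values.begin() + indptr[i + 1]);
    results.push_back(QuantileSketch(alpha, std::move(seg)));
  }
  return results;
}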
View File
@@ -1,5 +1,5 @@
 /*!
- * Copyright 2019-2021 XGBoost contributors
+ * Copyright 2019-2022 XGBoost contributors
  */
 #include <gtest/gtest.h>
 #include <dmlc/filesystem.h>
@@ -69,13 +69,13 @@ TEST(GBTree, PredictionCache) {
   auto p_m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix();
   auto gpair = GenerateRandomGradients(kRows);
   PredictionCacheEntry out_predictions;
-  gbtree.DoBoost(p_m.get(), &gpair, &out_predictions);
+  gbtree.DoBoost(p_m.get(), &gpair, &out_predictions, nullptr);
   gbtree.PredictBatch(p_m.get(), &out_predictions, false, 0, 0);
   ASSERT_EQ(1, out_predictions.version);
   std::vector<float> first_iter = out_predictions.predictions.HostVector();
   // Add 1 more boosted round
-  gbtree.DoBoost(p_m.get(), &gpair, &out_predictions);
+  gbtree.DoBoost(p_m.get(), &gpair, &out_predictions, nullptr);
   gbtree.PredictBatch(p_m.get(), &out_predictions, false, 0, 0);
   ASSERT_EQ(2, out_predictions.version);
   // Update the cache for all rounds
@@ -83,7 +83,7 @@ TEST(GBTree, PredictionCache) {
   gbtree.PredictBatch(p_m.get(), &out_predictions, false, 0, 0);
   ASSERT_EQ(2, out_predictions.version);
-  gbtree.DoBoost(p_m.get(), &gpair, &out_predictions);
+  gbtree.DoBoost(p_m.get(), &gpair, &out_predictions, nullptr);
   // drop the cache.
   gbtree.PredictBatch(p_m.get(), &out_predictions, false, 1, 2);
   ASSERT_EQ(0, out_predictions.version);
View File
@@ -548,7 +548,7 @@ std::unique_ptr<GradientBooster> CreateTrainedGBM(
   PredictionCacheEntry predts;
-  gbm->DoBoost(p_dmat.get(), &gpair, &predts);
+  gbm->DoBoost(p_dmat.get(), &gpair, &predts, nullptr);
   return gbm;
 }
View File
@@ -1,11 +1,14 @@
 /*!
- * Copyright 2017-2021 XGBoost contributors
+ * Copyright 2017-2022 XGBoost contributors
  */
 #include <gtest/gtest.h>
-#include <xgboost/objective.h>
 #include <xgboost/generic_parameters.h>
 #include <xgboost/json.h>
+#include <xgboost/objective.h>
+
+#include "../../../src/objective/adaptive.h"
 #include "../helpers.h"
 namespace xgboost {
 TEST(Objective, DeclareUnifiedTest(LinearRegressionGPair)) {
@@ -378,4 +381,113 @@ TEST(Objective, CoxRegressionGPair) {
                   { 0, 0, 0, 0.160f, 0.186f, 0.348f, 0.610f, 0.639f});
 }
 #endif
TEST(Objective, DeclareUnifiedTest(AbsoluteError)) {
Context ctx = CreateEmptyGenericParam(GPUIDX);
std::unique_ptr<ObjFunction> obj{ObjFunction::Create("reg:absoluteerror", &ctx)};
obj->Configure({});
CheckConfigReload(obj, "reg:absoluteerror");
MetaInfo info;
std::vector<float> labels{0.f, 3.f, 2.f, 5.f, 4.f, 7.f};
info.labels.Reshape(6, 1);
info.labels.Data()->HostVector() = labels;
info.num_row_ = labels.size();
HostDeviceVector<float> predt{1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
info.weights_.HostVector() = {1.f, 1.f, 1.f, 1.f, 1.f, 1.f};
CheckObjFunction(obj, predt.HostVector(), labels, info.weights_.HostVector(),
{1.f, -1.f, 1.f, -1.f, 1.f, -1.f}, info.weights_.HostVector());
RegTree tree;
tree.ExpandNode(0, /*split_index=*/1, 2, true, 0.0f, 2.f, 3.f, 4.f, 2.f, 1.f, 1.f);
HostDeviceVector<bst_node_t> position(labels.size(), 0);
auto& h_position = position.HostVector();
for (size_t i = 0; i < labels.size(); ++i) {
if (i < labels.size() / 2) {
h_position[i] = 1; // left
} else {
h_position[i] = 2; // right
}
}
auto& h_predt = predt.HostVector();
for (size_t i = 0; i < h_predt.size(); ++i) {
h_predt[i] = labels[i] + i;
}
obj->UpdateTreeLeaf(position, info, predt, &tree);
ASSERT_EQ(tree[1].LeafValue(), -1);
ASSERT_EQ(tree[2].LeafValue(), -4);
}
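// Worked check of the expected leaf values: the residuals are
// label - predt = {0, -1, -2, -3, -4, -5}; rows 0..2 land in leaf 1 and
// rows 3..5 in leaf 2, and reg:absoluteerror refreshes each leaf to the
// median of its residuals:
//   leaf 1: median{ 0, -1, -2} = -1
//   leaf 2: median{-3, -4, -5} = -4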
TEST(Objective, DeclareUnifiedTest(AbsoluteErrorLeaf)) {
Context ctx = CreateEmptyGenericParam(GPUIDX);
std::unique_ptr<ObjFunction> obj{ObjFunction::Create("reg:absoluteerror", &ctx)};
obj->Configure({});
MetaInfo info;
info.labels.Reshape(16, 1);
info.num_row_ = info.labels.Size();
CHECK_EQ(info.num_row_, 16);
auto h_labels = info.labels.HostView().Values();
std::iota(h_labels.begin(), h_labels.end(), 0);
HostDeviceVector<float> predt(h_labels.size());
auto& h_predt = predt.HostVector();
for (size_t i = 0; i < h_predt.size(); ++i) {
h_predt[i] = h_labels[i] + i;
}
HostDeviceVector<bst_node_t> position(info.labels.Size(), 0);
auto& h_position = position.HostVector();
for (int32_t i = 0; i < 3; ++i) {
h_position[i] = ~i;  // bit-negation marks sampled-out rows
}
for (size_t i = 3; i < 8; ++i) {
h_position[i] = 3;
}
// empty leaf for node 4
for (size_t i = 8; i < 13; ++i) {
h_position[i] = 5;
}
for (size_t i = 13; i < h_labels.size(); ++i) {
h_position[i] = 6;
}
RegTree tree;
tree.ExpandNode(0, /*split_index=*/1, 2, true, 0.0f, 2.f, 3.f, 4.f, 2.f, 1.f, 1.f);
tree.ExpandNode(1, /*split_index=*/1, 2, true, 0.0f, 2.f, 3.f, 4.f, 2.f, 1.f, 1.f);
tree.ExpandNode(2, /*split_index=*/1, 2, true, 0.0f, 2.f, 3.f, 4.f, 2.f, 1.f, 1.f);
ASSERT_EQ(tree.GetNumLeaves(), 4);
auto empty_leaf = tree[4].LeafValue();
obj->UpdateTreeLeaf(position, info, predt, &tree);
ASSERT_EQ(tree[3].LeafValue(), -5);
ASSERT_EQ(tree[4].LeafValue(), empty_leaf);
ASSERT_EQ(tree[5].LeafValue(), -10);
ASSERT_EQ(tree[6].LeafValue(), -14);
}
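// Worked check: labels are iota(0..15) and predt[i] = labels[i] + i = 2 * i,
// so the residual of row i is -i. Rows 0..2 are marked sampled-out via ~i:
//   leaf 3 <- rows 3..7:   median{-3, ..., -7}   = -5
//   leaf 5 <- rows 8..12:  median{-8, ..., -12}  = -10
//   leaf 6 <- rows 13..15: median{-13, -14, -15} = -14
// Leaf 4 receives no rows and keeps its original value.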
TEST(Adaptive, DeclareUnifiedTest(MissingLeaf)) {
std::vector<bst_node_t> missing{1, 3};
std::vector<bst_node_t> h_nidx = {2, 4, 5};
std::vector<size_t> h_nptr = {0, 4, 8, 16};
obj::detail::FillMissingLeaf(missing, &h_nidx, &h_nptr);
ASSERT_EQ(h_nidx[0], missing[0]);
ASSERT_EQ(h_nidx[2], missing[1]);
ASSERT_EQ(h_nidx[1], 2);
ASSERT_EQ(h_nidx[3], 4);
ASSERT_EQ(h_nidx[4], 5);
ASSERT_EQ(h_nptr[0], 0);
ASSERT_EQ(h_nptr[1], 0); // empty
ASSERT_EQ(h_nptr[2], 4);
ASSERT_EQ(h_nptr[3], 4); // empty
ASSERT_EQ(h_nptr[4], 8);
ASSERT_EQ(h_nptr[5], 16);
}
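// A sketch of the behaviour the assertions above imply for FillMissingLeaf
// (inferred, not the actual implementation): ids of empty leaves are
// inserted in sorted order and given zero-length [x, x) segments in the row
// pointer, so every worker reports an identical set of leaves. Requires
// <algorithm> and <iterator>.
inline void FillMissingLeafSketch(std::vector<bst_node_t> const& missing,
                                  std::vector<bst_node_t>* nidx,
                                  std::vector<std::size_t>* nptr) {
  for (auto m : missing) {
    auto it = std::lower_bound(nidx->begin(), nidx->end(), m);
    auto pos = std::distance(nidx->begin(), it);
    nidx->insert(it, m);
    auto boundary = (*nptr)[pos];  // copy before insert invalidates references
    nptr->insert(nptr->begin() + pos + 1, boundary);
  }
}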
} // namespace xgboost
View File
@@ -1,5 +1,5 @@
 /*!
- * Copyright 2017-2020 XGBoost contributors
+ * Copyright 2017-2022 XGBoost contributors
  */
 #include <dmlc/filesystem.h>
 #include <gtest/gtest.h>
@@ -222,7 +222,7 @@ void TestUpdatePredictionCache(bool use_subsampling) {
   PredictionCacheEntry prediction_cache;
   prediction_cache.predictions.Resize(kRows * kClasses, 0);
   // after one training iteration the cache holds the values cached by QuantileHistMaker::Builder
-  gbm->DoBoost(dmat.get(), &gpair, &prediction_cache);
+  gbm->DoBoost(dmat.get(), &gpair, &prediction_cache, nullptr);
   PredictionCacheEntry out_predictions;
   // perform a fair prediction on the same input data; it should equal the cached result
View File
@@ -1,7 +1,8 @@
 /*!
- * Copyright 2019-2021 by XGBoost Contributors
+ * Copyright 2019-2022 by XGBoost Contributors
  */
 #include <gtest/gtest.h>
+#include <algorithm>
 #include <vector>
 #include <thrust/device_vector.h>
@@ -10,6 +11,10 @@
 #include "../../../../src/tree/gpu_hist/row_partitioner.cuh"
 #include "../../helpers.h"
+#include "xgboost/base.h"
+#include "xgboost/generic_parameters.h"
+#include "xgboost/task.h"
+#include "xgboost/tree_model.h"
 namespace xgboost {
 namespace tree {
@@ -103,17 +108,58 @@ TEST(RowPartitioner, Basic) { TestUpdatePosition(); }
 void TestFinalise() {
   const int kNumRows = 10;
+  ObjInfo task{ObjInfo::kRegression, false, false};
+  HostDeviceVector<bst_node_t> position;
+  Context ctx;
+  ctx.gpu_id = 0;
+  {
+    RowPartitioner rp(0, kNumRows);
+    rp.FinalisePosition(
+        &ctx, task, &position,
+        [=] __device__(RowPartitioner::RowIndexT ridx, int position) { return 7; },
+        [] XGBOOST_DEVICE(size_t idx) { return false; });
+    auto position = rp.GetPositionHost();
+    for (auto p : position) {
+      EXPECT_EQ(p, 7);
+    }
+  }
+  /**
+   * Test for sampling.
+   */
+  dh::device_vector<float> hess(kNumRows);
+  for (size_t i = 0; i < hess.size(); ++i) {
+    // removed rows: 0, 3, 6, 9
+    if (i % 3 == 0) {
+      hess[i] = 0;
+    } else {
+      hess[i] = i;
+    }
+  }
+  auto d_hess = dh::ToSpan(hess);
   RowPartitioner rp(0, kNumRows);
-  rp.FinalisePosition([=]__device__(RowPartitioner::RowIndexT ridx, int position)
-  {
-    return 7;
-  });
-  auto position = rp.GetPositionHost();
-  for(auto p:position)
-  {
-    EXPECT_EQ(p, 7);
+  rp.FinalisePosition(
+      &ctx, task, &position,
+      [] __device__(RowPartitioner::RowIndexT ridx, bst_node_t position) {
+        return ridx % 2 == 0 ? 1 : 2;
+      },
+      [d_hess] __device__(size_t ridx) { return d_hess[ridx] - 0.f == 0.f; });
+  auto const& h_position = position.ConstHostVector();
+  for (size_t ridx = 0; ridx < h_position.size(); ++ridx) {
+    if (ridx % 3 == 0) {
+      ASSERT_LT(h_position[ridx], 0);
+    } else {
+      ASSERT_EQ(h_position[ridx], ridx % 2 == 0 ? 1 : 2);
+    }
   }
 }
 TEST(RowPartitioner, Finalise) { TestFinalise(); }
 void TestIncorrectRow() {
View File
@@ -26,7 +26,7 @@ TEST(Approx, Partitioner) {
   std::transform(grad.HostVector().cbegin(), grad.HostVector().cend(), hess.begin(),
                  [](auto gpair) { return gpair.GetHess(); });
-  for (auto const &page : Xy->GetBatches<GHistIndexMatrix>({64, hess, true})) {
+  for (auto const& page : Xy->GetBatches<GHistIndexMatrix>({64, hess, true})) {
     bst_feature_t const split_ind = 0;
     {
       auto min_value = page.cut.MinValues()[split_ind];
@@ -44,9 +44,9 @@ TEST(Approx, Partitioner) {
       float split_value = page.cut.Values().at(ptr / 2);
       RegTree tree;
       GetSplit(&tree, split_value, &candidates);
-      auto left_nidx = tree[RegTree::kRoot].LeftChild();
       partitioner.UpdatePosition(&ctx, page, candidates, &tree);
+      auto left_nidx = tree[RegTree::kRoot].LeftChild();
       auto elem = partitioner[left_nidx];
       ASSERT_LT(elem.Size(), n_samples);
       ASSERT_GT(elem.Size(), 1);
@@ -54,6 +54,7 @@ TEST(Approx, Partitioner) {
         auto value = page.cut.Values().at(page.index[*it]);
         ASSERT_LE(value, split_value);
       }
+
       auto right_nidx = tree[RegTree::kRoot].RightChild();
       elem = partitioner[right_nidx];
       for (auto it = elem.begin; it != elem.end; ++it) {
@@ -63,5 +64,78 @@ TEST(Approx, Partitioner) {
     }
   }
 }
namespace {
void TestLeafPartition(size_t n_samples) {
size_t const n_features = 2, base_rowid = 0;
common::RowSetCollection row_set;
ApproxRowPartitioner partitioner{n_samples, base_rowid};
auto Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true);
GenericParameter ctx;
std::vector<CPUExpandEntry> candidates{{0, 0, 0.4}};
RegTree tree;
std::vector<float> hess(n_samples, 0);
// emulate sampling
auto not_sampled = [](size_t i) {
size_t const kSampleFactor{3};
return i % kSampleFactor != 0;
};
size_t n{0};
for (size_t i = 0; i < hess.size(); ++i) {
if (not_sampled(i)) {
hess[i] = 1.0f;
++n;
}
}
std::vector<size_t> h_nptr;
float split_value{0};
for (auto const& page : Xy->GetBatches<GHistIndexMatrix>({Context::kCpuId, 64})) {
bst_feature_t const split_ind = 0;
auto ptr = page.cut.Ptrs()[split_ind + 1];
split_value = page.cut.Values().at(ptr / 2);
GetSplit(&tree, split_value, &candidates);
partitioner.UpdatePosition(&ctx, page, candidates, &tree);
std::vector<bst_node_t> position;
partitioner.LeafPartition(&ctx, tree, hess, &position);
std::sort(position.begin(), position.end());
size_t beg = std::distance(
position.begin(),
std::find_if(position.begin(), position.end(), [&](bst_node_t nidx) { return nidx >= 0; }));
std::vector<size_t> nptr;
common::RunLengthEncode(position.cbegin() + beg, position.cend(), &nptr);
std::transform(nptr.begin(), nptr.end(), nptr.begin(), [&](size_t x) { return x + beg; });
auto n_uniques = std::unique(position.begin() + beg, position.end()) - (position.begin() + beg);
ASSERT_EQ(nptr.size(), n_uniques + 1);
ASSERT_EQ(nptr[0], beg);
ASSERT_EQ(nptr.back(), n_samples);
h_nptr = nptr;
}
if (h_nptr.front() == n_samples) {
return;
}
ASSERT_GE(h_nptr.size(), 2);
for (auto const& page : Xy->GetBatches<SparsePage>()) {
auto batch = page.GetView();
size_t left{0};
for (size_t i = 0; i < batch.Size(); ++i) {
if (not_sampled(i) && batch[i].front().fvalue < split_value) {
left++;
}
}
ASSERT_EQ(left, h_nptr[1] - h_nptr[0]); // equal to number of sampled assigned to left
}
}
} // anonymous namespace
TEST(Approx, LeafPartition) {
for (auto n_samples : {0ul, 1ul, 128ul, 256ul}) {
TestLeafPartition(n_samples);
}
}
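// The leaf-partition check above relies on run-length encoding of the sorted
// positions. A sketch of the assumed contract of common::RunLengthEncode
// (illustrative only): for a sorted range, emit the offset of the first
// element of each run plus the total length, so run i is [out[i], out[i+1]).
template <typename Iter>
void RunLengthEncodeSketch(Iter first, Iter last, std::vector<std::size_t>* p_out) {
  auto& out = *p_out;
  out.clear();
  std::size_t n = 0;
  for (Iter it = first; it != last; ++it, ++n) {
    if (it == first || *it != *std::prev(it)) {
      out.push_back(n);  // a new run starts here
    }
  }
  out.push_back(n);  // sentinel: total number of elements
}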
} // namespace tree
} // namespace xgboost
View File
@@ -1,5 +1,5 @@
 /*!
- * Copyright 2017-2021 XGBoost contributors
+ * Copyright 2017-2022 XGBoost contributors
  */
 #include <gtest/gtest.h>
 #include <thrust/device_vector.h>
@@ -13,6 +13,7 @@
 #include "../helpers.h"
 #include "../histogram_helpers.h"
+#include "xgboost/generic_parameters.h"
 #include "xgboost/json.h"
 #include "../../../src/data/sparse_page_source.h"
 #include "../../../src/tree/updater_gpu_hist.cu"
@@ -22,7 +23,6 @@
 namespace xgboost {
 namespace tree {
-
 TEST(GpuHist, DeviceHistogram) {
   // Ensures that node allocates correctly after reaching `kStopGrowingSize`.
   dh::safe_cuda(cudaSetDevice(0));
@@ -81,8 +81,9 @@ void TestBuildHist(bool use_shared_memory_histograms) {
   param.Init(args);
   auto page = BuildEllpackPage(kNRows, kNCols);
   BatchParam batch_param{};
-  GPUHistMakerDevice<GradientSumT> maker(0, page.get(), {}, kNRows, param,
-                                         kNCols, kNCols, batch_param);
+  Context ctx{CreateEmptyGenericParam(0)};
+  GPUHistMakerDevice<GradientSumT> maker(&ctx, page.get(), {}, kNRows, param, kNCols, kNCols,
+                                         batch_param);
   xgboost::SimpleLCG gen;
   xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f);
   HostDeviceVector<GradientPair> gpair(kNRows);
@@ -158,14 +159,14 @@ TEST(GpuHist, ApplySplit) {
   BatchParam bparam;
   bparam.gpu_id = 0;
   bparam.max_bin = 3;
+  Context ctx{CreateEmptyGenericParam(0)};
   for (auto& ellpack : m->GetBatches<EllpackPage>(bparam)){
     auto impl = ellpack.Impl();
     HostDeviceVector<FeatureType> feature_types(10, FeatureType::kCategorical);
     feature_types.SetDevice(bparam.gpu_id);
     tree::GPUHistMakerDevice<GradientPairPrecise> updater(
-        0, impl, feature_types.ConstDeviceSpan(), n_rows, tparam, 0, n_cols,
-        bparam);
+        &ctx, impl, feature_types.ConstDeviceSpan(), n_rows, tparam, 0, n_cols, bparam);
     updater.ApplySplit(candidate, &tree);
     ASSERT_EQ(tree.GetSplitTypes().size(), 3);
@@ -224,8 +225,9 @@ TEST(GpuHist, EvaluateRootSplit) {
   // Initialize GPUHistMakerDevice
   auto page = BuildEllpackPage(kNRows, kNCols);
   BatchParam batch_param{};
-  GPUHistMakerDevice<GradientPairPrecise> maker(
-      0, page.get(), {}, kNRows, param, kNCols, kNCols, batch_param);
+  Context ctx{CreateEmptyGenericParam(0)};
+  GPUHistMakerDevice<GradientPairPrecise> maker(&ctx, page.get(), {}, kNRows, param, kNCols, kNCols,
+                                                batch_param);
   // Initialize GPUHistMakerDevice::node_sum_gradients
   maker.node_sum_gradients = {};
@@ -348,7 +350,8 @@ void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
   GenericParameter generic_param(CreateEmptyGenericParam(0));
   hist_maker.Configure(args, &generic_param);
-  hist_maker.Update(gpair, dmat, {tree});
+  std::vector<HostDeviceVector<bst_node_t>> position(1);
+  hist_maker.Update(gpair, dmat, common::Span<HostDeviceVector<bst_node_t>>{position}, {tree});
   auto cache = linalg::VectorView<float>{preds->DeviceSpan(), {preds->Size()}, 0};
   hist_maker.UpdatePredictionCache(dmat, cache);
 }
@@ -483,7 +486,7 @@ TEST(GpuHist, ExternalMemoryWithSampling) {
   auto preds_h = preds.ConstHostVector();
   auto preds_ext_h = preds_ext.ConstHostVector();
   for (int i = 0; i < kRows; i++) {
-    EXPECT_NEAR(preds_h[i], preds_ext_h[i], 1e-3);
+    ASSERT_NEAR(preds_h[i], preds_ext_h[i], 1e-3);
   }
 }
View File
@@ -39,7 +39,8 @@ TEST(GrowHistMaker, InteractionConstraint) {
   updater->Configure(Args{
       {"interaction_constraints", "[[0, 1]]"},
      {"num_feature", std::to_string(kCols)}});
-  updater->Update(&gradients, p_dmat.get(), {&tree});
+  std::vector<HostDeviceVector<bst_node_t>> position(1);
+  updater->Update(&gradients, p_dmat.get(), position, {&tree});
   ASSERT_EQ(tree.NumExtraNodes(), 4);
   ASSERT_EQ(tree[0].SplitIndex(), 1);
@@ -55,7 +56,8 @@ TEST(GrowHistMaker, InteractionConstraint) {
   std::unique_ptr<TreeUpdater> updater{
       TreeUpdater::Create("grow_histmaker", &param, ObjInfo{ObjInfo::kRegression})};
   updater->Configure(Args{{"num_feature", std::to_string(kCols)}});
-  updater->Update(&gradients, p_dmat.get(), {&tree});
+  std::vector<HostDeviceVector<bst_node_t>> position(1);
+  updater->Update(&gradients, p_dmat.get(), position, {&tree});
   ASSERT_EQ(tree.NumExtraNodes(), 10);
   ASSERT_EQ(tree[0].SplitIndex(), 1);
View File
@@ -77,7 +77,8 @@ class TestPredictionCache : public ::testing::Test {
     std::vector<RegTree *> trees{&tree};
     auto gpair = GenerateRandomGradients(n_samples_);
     updater->Configure(Args{{"max_bin", "64"}});
-    updater->Update(&gpair, Xy_.get(), trees);
+    std::vector<HostDeviceVector<bst_node_t>> position(1);
+    updater->Update(&gpair, Xy_.get(), position, trees);
     HostDeviceVector<float> out_prediction_cached;
     out_prediction_cached.SetDevice(ctx.gpu_id);
     out_prediction_cached.Resize(n_samples_);
View File
@@ -43,22 +43,23 @@ TEST(Updater, Prune) {
   pruner->Configure(cfg);
   // loss_chg < min_split_loss;
+  std::vector<HostDeviceVector<bst_node_t>> position(trees.size());
   tree.ExpandNode(0, 0, 0, true, 0.0f, 0.3f, 0.4f, 0.0f, 0.0f,
                   /*left_sum=*/0.0f, /*right_sum=*/0.0f);
-  pruner->Update(&gpair, p_dmat.get(), trees);
+  pruner->Update(&gpair, p_dmat.get(), position, trees);
   ASSERT_EQ(tree.NumExtraNodes(), 0);
   // loss_chg > min_split_loss;
   tree.ExpandNode(0, 0, 0, true, 0.0f, 0.3f, 0.4f, 11.0f, 0.0f,
                   /*left_sum=*/0.0f, /*right_sum=*/0.0f);
-  pruner->Update(&gpair, p_dmat.get(), trees);
+  pruner->Update(&gpair, p_dmat.get(), position, trees);
   ASSERT_EQ(tree.NumExtraNodes(), 2);
   // loss_chg == min_split_loss;
   tree.Stat(0).loss_chg = 10;
-  pruner->Update(&gpair, p_dmat.get(), trees);
+  pruner->Update(&gpair, p_dmat.get(), position, trees);
   ASSERT_EQ(tree.NumExtraNodes(), 2);
@@ -74,7 +75,7 @@ TEST(Updater, Prune) {
                   /*left_sum=*/0.0f, /*right_sum=*/0.0f);
   cfg.emplace_back(std::make_pair("max_depth", "1"));
   pruner->Configure(cfg);
-  pruner->Update(&gpair, p_dmat.get(), trees);
+  pruner->Update(&gpair, p_dmat.get(), position, trees);
   ASSERT_EQ(tree.NumExtraNodes(), 2);
@@ -84,7 +85,7 @@ TEST(Updater, Prune) {
                   /*left_sum=*/0.0f, /*right_sum=*/0.0f);
   cfg.emplace_back(std::make_pair("min_split_loss", "0"));
   pruner->Configure(cfg);
-  pruner->Update(&gpair, p_dmat.get(), trees);
+  pruner->Update(&gpair, p_dmat.get(), position, trees);
   ASSERT_EQ(tree.NumExtraNodes(), 2);
 }
 } // namespace tree
View File
@@ -44,7 +44,8 @@ TEST(Updater, Refresh) {
   tree.Stat(cright).base_weight = 1.3;
   refresher->Configure(cfg);
-  refresher->Update(&gpair, p_dmat.get(), trees);
+  std::vector<HostDeviceVector<bst_node_t>> position;
+  refresher->Update(&gpair, p_dmat.get(), position, trees);
   bst_float constexpr kEps = 1e-6;
   ASSERT_NEAR(-0.183392, tree[cright].LeafValue(), kEps);
View File
@@ -27,7 +27,8 @@ class UpdaterTreeStatTest : public ::testing::Test {
     up->Configure(Args{});
     RegTree tree;
     tree.param.num_feature = kCols;
-    up->Update(&gpairs_, p_dmat_.get(), {&tree});
+    std::vector<HostDeviceVector<bst_node_t>> position(1);
+    up->Update(&gpairs_, p_dmat_.get(), position, {&tree});
     tree.WalkTree([&tree](bst_node_t nidx) {
       if (tree[nidx].IsLeaf()) {
@@ -87,13 +88,15 @@ class UpdaterEtaTest : public ::testing::Test {
     RegTree tree_0;
     {
       tree_0.param.num_feature = kCols;
-      up_0->Update(&gpairs_, p_dmat_.get(), {&tree_0});
+      std::vector<HostDeviceVector<bst_node_t>> position(1);
+      up_0->Update(&gpairs_, p_dmat_.get(), position, {&tree_0});
     }
     RegTree tree_1;
     {
       tree_1.param.num_feature = kCols;
-      up_1->Update(&gpairs_, p_dmat_.get(), {&tree_1});
+      std::vector<HostDeviceVector<bst_node_t>> position(1);
+      up_1->Update(&gpairs_, p_dmat_.get(), position, {&tree_1});
     }
     tree_0.WalkTree([&](bst_node_t nidx) {
       if (tree_0[nidx].IsLeaf()) {
@@ -149,7 +152,8 @@ class TestMinSplitLoss : public ::testing::Test {
     up->Configure(args);
     RegTree tree;
-    up->Update(&gpair_, dmat_.get(), {&tree});
+    std::vector<HostDeviceVector<bst_node_t>> position(1);
+    up->Update(&gpair_, dmat_.get(), position, {&tree});
     auto n_nodes = tree.NumExtraNodes();
     return n_nodes;
View File
@@ -249,6 +249,8 @@ class TestGPUPredict:
            tm.dataset_strategy, shap_parameter_strategy)
     @settings(deadline=None, print_blob=True)
     def test_shap(self, num_rounds, dataset, param):
+        if dataset.name.endswith("-l1"):  # not supported by the exact tree method
+            return
         param.update({"predictor": "gpu_predictor", "gpu_id": 0})
         param = dataset.set_params(param)
         dmat = dataset.get_dmat()
@@ -263,6 +265,8 @@ class TestGPUPredict:
            tm.dataset_strategy, shap_parameter_strategy)
     @settings(deadline=None, max_examples=20, print_blob=True)
     def test_shap_interactions(self, num_rounds, dataset, param):
+        if dataset.name.endswith("-l1"):  # not supported by the exact tree method
+            return
         param.update({"predictor": "gpu_predictor", "gpu_id": 0})
         param = dataset.set_params(param)
         dmat = dataset.get_dmat()
View File
@@ -90,6 +90,8 @@ class TestGPUUpdaters:
            tm.dataset_strategy)
     @settings(deadline=None, print_blob=True)
     def test_external_memory(self, param, num_rounds, dataset):
+        if dataset.name.endswith("-l1"):
+            return
         # We cannot handle empty dataset yet
         assume(len(dataset.y) > 0)
         param['tree_method'] = 'gpu_hist'
View File
@@ -1,7 +1,7 @@
 """Copyright 2019-2022 XGBoost contributors"""
 import sys
 import os
-from typing import Type, TypeVar, Any, Dict, List, Tuple
+from typing import Type, TypeVar, Any, Dict, List
 import pytest
 import numpy as np
 import asyncio
@@ -198,9 +198,19 @@ def run_gpu_hist(
         dtrain=m,
         num_boost_round=num_rounds,
         evals=[(m, "train")],
-    )["history"]
+    )["history"]["train"][dataset.metric]
     note(history)
-    assert tm.non_increasing(history["train"][dataset.metric])
+
+    # See note on `ObjFunction::UpdateTreeLeaf`.
+    update_leaf = dataset.name.endswith("-l1")
+    if update_leaf and len(history) == 2:
+        assert history[0] + 1e-2 >= history[-1]
+        return
+    if update_leaf and len(history) > 2:
+        assert history[0] >= history[-1]
+        return
+    assert tm.non_increasing(history)
@@ -305,8 +315,7 @@ class TestDistributedGPU:
     def test_empty_dmatrix(self, local_cuda_cluster: LocalCUDACluster) -> None:
         with Client(local_cuda_cluster) as client:
-            parameters = {'tree_method': 'gpu_hist',
-                          'debug_synchronize': True}
+            parameters = {'tree_method': 'gpu_hist', 'debug_synchronize': True}
             run_empty_dmatrix_reg(client, parameters)
             run_empty_dmatrix_cls(client, parameters)
View File
@@ -40,6 +40,8 @@ class TestTreeMethod:
            tm.dataset_strategy)
     @settings(deadline=None, print_blob=True)
     def test_exact(self, param, num_rounds, dataset):
+        if dataset.name.endswith("-l1"):
+            return
         param['tree_method'] = 'exact'
         param = dataset.set_params(param)
         result = train_result(param, dataset.get_dmat(), num_rounds)
View File
@@ -35,6 +35,7 @@ import dask.dataframe as dd
 import dask.array as da
 from xgboost.dask import DaskDMatrix

+dask.config.set({"distributed.scheduler.allowed-failures": False})

 if hasattr(HealthCheck, 'function_scoped_fixture'):
     suppress = [HealthCheck.function_scoped_fixture]
@@ -673,7 +674,8 @@ def test_empty_dmatrix_training_continuation(client: "Client") -> None:
 def run_empty_dmatrix_reg(client: "Client", parameters: dict) -> None:
     def _check_outputs(out: xgb.dask.TrainReturnT, predictions: np.ndarray) -> None:
         assert isinstance(out['booster'], xgb.dask.Booster)
-        assert len(out['history']['validation']['rmse']) == 2
+        for _, v in out['history']['validation'].items():
+            assert len(v) == 2
         assert isinstance(predictions, np.ndarray)
         assert predictions.shape[0] == 1
@@ -866,6 +868,8 @@ def test_empty_dmatrix(tree_method) -> None:
         parameters = {'tree_method': tree_method}
         run_empty_dmatrix_reg(client, parameters)
         run_empty_dmatrix_cls(client, parameters)
+        parameters = {'tree_method': tree_method, "objective": "reg:absoluteerror"}
+        run_empty_dmatrix_reg(client, parameters)

 async def run_from_dask_array_asyncio(scheduler_address: str) -> xgb.dask.TrainReturnT:
@@ -1284,7 +1288,12 @@ class TestWithDask:
         def minimum_bin():
             return "max_bin" in params and params["max_bin"] == 2

-        if minimum_bin() and is_stump():
+        # See note on `ObjFunction::UpdateTreeLeaf`.
+        update_leaf = dataset.name.endswith("-l1")
+        if update_leaf and len(history) >= 2:
+            assert history[0] >= history[-1]
+            return
+        elif minimum_bin() and is_stump():
             assert tm.non_increasing(history, tolerance=1e-3)
         else:
             assert tm.non_increasing(history)
@@ -1304,7 +1313,7 @@ class TestWithDask:
                           dataset=tm.dataset_strategy)
     @settings(deadline=None, suppress_health_check=suppress, print_blob=True)
     def test_approx(
-            self, client: "Client", params: Dict, dataset: tm.TestDataset
+        self, client: "Client", params: Dict, dataset: tm.TestDataset
     ) -> None:
         num_rounds = 30
         self.run_updater_test(client, params, num_rounds, dataset, 'approx')
View File
@@ -327,6 +327,9 @@ _unweighted_datasets_strategy = strategies.sampled_from(
         TestDataset(
             "calif_housing", get_california_housing, "reg:squarederror", "rmse"
         ),
+        TestDataset(
+            "calif_housing-l1", get_california_housing, "reg:absoluteerror", "mae"
+        ),
         TestDataset("digits", get_digits, "multi:softmax", "mlogloss"),
         TestDataset("cancer", get_cancer, "binary:logistic", "logloss"),
         TestDataset(
@@ -336,6 +339,7 @@ _unweighted_datasets_strategy = strategies.sampled_from(
             "rmse",
         ),
         TestDataset("sparse", get_sparse, "reg:squarederror", "rmse"),
+        TestDataset("sparse-l1", get_sparse, "reg:absoluteerror", "mae"),
         TestDataset(
             "empty",
             lambda: (np.empty((0, 100)), np.empty(0)),