Calculate base_score based on input labels for mae. (#8107)

Fit an intercept as base score for abs loss.
This commit is contained in:
Jiaming Yuan 2022-09-20 20:53:54 +08:00 committed by GitHub
parent 4f42aa5f12
commit fffb1fca52
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
42 changed files with 999 additions and 343 deletions

View File

@ -75,19 +75,20 @@
#include "../src/collective/communicator.cc" #include "../src/collective/communicator.cc"
// common // common
#include "../src/common/common.cc"
#include "../src/common/column_matrix.cc"
#include "../src/common/random.cc"
#include "../src/common/charconv.cc" #include "../src/common/charconv.cc"
#include "../src/common/timer.cc" #include "../src/common/column_matrix.cc"
#include "../src/common/quantile.cc" #include "../src/common/common.cc"
#include "../src/common/host_device_vector.cc"
#include "../src/common/hist_util.cc" #include "../src/common/hist_util.cc"
#include "../src/common/host_device_vector.cc"
#include "../src/common/io.cc" #include "../src/common/io.cc"
#include "../src/common/json.cc" #include "../src/common/json.cc"
#include "../src/common/numeric.cc"
#include "../src/common/pseudo_huber.cc" #include "../src/common/pseudo_huber.cc"
#include "../src/common/quantile.cc"
#include "../src/common/random.cc"
#include "../src/common/survival_util.cc" #include "../src/common/survival_util.cc"
#include "../src/common/threading_utils.cc" #include "../src/common/threading_utils.cc"
#include "../src/common/timer.cc"
#include "../src/common/version.cc" #include "../src/common/version.cc"
// c_api // c_api

View File

@ -8,10 +8,9 @@
#ifndef XGBOOST_LEARNER_H_ #ifndef XGBOOST_LEARNER_H_
#define XGBOOST_LEARNER_H_ #define XGBOOST_LEARNER_H_
#include <dmlc/any.h>
#include <xgboost/base.h> #include <xgboost/base.h>
#include <xgboost/feature_map.h> #include <xgboost/feature_map.h>
#include <xgboost/generic_parameters.h> #include <xgboost/generic_parameters.h> // Context
#include <xgboost/host_device_vector.h> #include <xgboost/host_device_vector.h>
#include <xgboost/model.h> #include <xgboost/model.h>
#include <xgboost/predictor.h> #include <xgboost/predictor.h>
@ -274,7 +273,7 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
/** /**
* \brief Return the context object of this Booster. * \brief Return the context object of this Booster.
*/ */
virtual GenericParameter const* Ctx() const = 0; virtual Context const* Ctx() const = 0;
/*! /*!
* \brief Get configuration arguments currently stored by the learner * \brief Get configuration arguments currently stored by the learner
* \return Key-value pairs representing configuration arguments * \return Key-value pairs representing configuration arguments
@ -289,7 +288,7 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
/*! \brief The evaluation metrics used to evaluate the model. */ /*! \brief The evaluation metrics used to evaluate the model. */
std::vector<std::unique_ptr<Metric> > metrics_; std::vector<std::unique_ptr<Metric> > metrics_;
/*! \brief Training parameter. */ /*! \brief Training parameter. */
GenericParameter generic_parameters_; Context ctx_;
}; };
struct LearnerModelParamLegacy; struct LearnerModelParamLegacy;
@ -298,8 +297,14 @@ struct LearnerModelParamLegacy;
* \brief Basic Model Parameters, used to describe the booster. * \brief Basic Model Parameters, used to describe the booster.
*/ */
struct LearnerModelParam { struct LearnerModelParam {
/* \brief global bias */ private:
bst_float base_score { 0.5f }; /**
* \brief Global bias, this is just a scalar value but can be extended to vector when we
* support multi-class and multi-target.
*/
linalg::Tensor<float, 1> base_score_;
public:
/* \brief number of features */ /* \brief number of features */
uint32_t num_feature { 0 }; uint32_t num_feature { 0 };
/* \brief number of classes, if it is multi-class classification */ /* \brief number of classes, if it is multi-class classification */
@ -310,7 +315,18 @@ struct LearnerModelParam {
LearnerModelParam() = default; LearnerModelParam() = default;
// As the old `LearnerModelParamLegacy` is still used by binary IO, we keep // As the old `LearnerModelParamLegacy` is still used by binary IO, we keep
// this one as an immutable copy. // this one as an immutable copy.
LearnerModelParam(LearnerModelParamLegacy const& user_param, float base_margin, ObjInfo t); LearnerModelParam(Context const* ctx, LearnerModelParamLegacy const& user_param,
linalg::Tensor<float, 1> base_margin, ObjInfo t);
LearnerModelParam(LearnerModelParamLegacy const& user_param, ObjInfo t);
LearnerModelParam(bst_feature_t n_features, linalg::Tensor<float, 1> base_margin,
uint32_t n_groups)
: base_score_{std::move(base_margin)}, num_feature{n_features}, num_output_group{n_groups} {}
linalg::TensorView<float const, 1> BaseScore(Context const* ctx) const;
linalg::TensorView<float const, 1> BaseScore(int32_t device) const;
void Copy(LearnerModelParam const& that);
/* \brief Whether this parameter is initialized with LearnerModelParamLegacy. */ /* \brief Whether this parameter is initialized with LearnerModelParamLegacy. */
bool Initialized() const { return num_feature != 0; } bool Initialized() const { return num_feature != 0; }
}; };

View File

@ -8,6 +8,7 @@
#include <dmlc/endian.h> #include <dmlc/endian.h>
#include <xgboost/base.h> #include <xgboost/base.h>
#include <xgboost/generic_parameters.h>
#include <xgboost/host_device_vector.h> #include <xgboost/host_device_vector.h>
#include <xgboost/json.h> #include <xgboost/json.h>
#include <xgboost/span.h> #include <xgboost/span.h>
@ -16,6 +17,7 @@
#include <cassert> #include <cassert>
#include <limits> #include <limits>
#include <string> #include <string>
#include <tuple>
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
#include <vector> #include <vector>
@ -213,6 +215,22 @@ LINALG_HD decltype(auto) constexpr Apply(Fn &&f, Tup &&t) {
constexpr auto kSize = std::tuple_size<Tup>::value; constexpr auto kSize = std::tuple_size<Tup>::value;
return Apply(std::forward<Fn>(f), std::forward<Tup>(t), std::make_index_sequence<kSize>{}); return Apply(std::forward<Fn>(f), std::forward<Tup>(t), std::make_index_sequence<kSize>{});
} }
/**
 * Backport of C++17 std::conjunction, plus helpers used to constrain parameter packs of
 * indices to integral types.
 */
// An empty list of traits is vacuously true.
template <class... Traits>
struct Conjunction : std::true_type {};
// A single trait: the conjunction is that trait itself.
template <class Head>
struct Conjunction<Head> : Head {};
// Short-circuit: keep evaluating the tail only while the head holds, otherwise the
// result is the first failing trait.
template <class Head, class... Tail>
struct Conjunction<Head, Tail...>
    : std::conditional<static_cast<bool>(Head::value), Conjunction<Tail...>, Head>::type {};
// True when every type in the pack is integral after stripping references.
template <typename... Index>
using IsAllIntegral =
    Conjunction<std::is_integral<typename std::remove_reference<Index>::type>...>;
// SFINAE helper: names `void` only when all types in the pack are integral.
template <typename... Index>
using EnableIfIntegral = typename std::enable_if<IsAllIntegral<Index...>::value>::type;
} // namespace detail } // namespace detail
/** /**
@ -406,7 +424,7 @@ class TensorView {
* *
* \endcode * \endcode
*/ */
template <typename... Index> template <typename... Index, detail::EnableIfIntegral<Index...> * = nullptr>
LINALG_HD T &operator()(Index &&...index) { LINALG_HD T &operator()(Index &&...index) {
static_assert(sizeof...(index) <= kDim, "Invalid index."); static_assert(sizeof...(index) <= kDim, "Invalid index.");
size_t offset = detail::Offset<0ul>(stride_, 0ul, std::forward<Index>(index)...); size_t offset = detail::Offset<0ul>(stride_, 0ul, std::forward<Index>(index)...);
@ -416,7 +434,7 @@ class TensorView {
/** /**
* \brief Index the tensor to obtain a scalar value. * \brief Index the tensor to obtain a scalar value.
*/ */
template <typename... Index> template <typename... Index, detail::EnableIfIntegral<Index...> * = nullptr>
LINALG_HD T const &operator()(Index &&...index) const { LINALG_HD T const &operator()(Index &&...index) const {
static_assert(sizeof...(index) <= kDim, "Invalid index."); static_assert(sizeof...(index) <= kDim, "Invalid index.");
size_t offset = detail::Offset<0ul>(stride_, 0ul, std::forward<Index>(index)...); size_t offset = detail::Offset<0ul>(stride_, 0ul, std::forward<Index>(index)...);
@ -656,7 +674,7 @@ class Tensor {
} }
if (device >= 0) { if (device >= 0) {
data_.SetDevice(device); data_.SetDevice(device);
data_.DevicePointer(); // Pull to device; data_.ConstDevicePointer(); // Pull to device;
} }
CHECK_EQ(data_.Size(), detail::CalcSize(shape_)); CHECK_EQ(data_.Size(), detail::CalcSize(shape_));
} }
@ -702,12 +720,29 @@ class Tensor {
} }
template <typename I, int32_t D> template <typename I, int32_t D>
explicit Tensor(std::initializer_list<T> data, I const (&shape)[D], int32_t device) { explicit Tensor(std::initializer_list<T> data, I const (&shape)[D],
int32_t device = Context::kCpuId) {
auto &h_vec = data_.HostVector(); auto &h_vec = data_.HostVector();
h_vec = data; h_vec = data;
// shape // shape
this->Initialize(shape, device); this->Initialize(shape, device);
} }
/**
* \brief Index operator. Not thread safe, should not be used in performance critical
* region. For more efficient indexing, consider getting a view first.
*/
template <typename... Index>
T &operator()(Index &&...idx) {
return this->HostView()(std::forward<Index>(idx)...);
}
/**
* \brief Index operator. Not thread safe, should not be used in performance critical
* region. For more efficient indexing, consider getting a view first.
*/
template <typename... Index>
T const &operator()(Index &&...idx) const {
return this->HostView()(std::forward<Index>(idx)...);
}
/** /**
* \brief Get a \ref TensorView for this tensor. * \brief Get a \ref TensorView for this tensor.
@ -761,7 +796,7 @@ class Tensor {
* *
* If the total size is changed, then data in this tensor is no longer valid. * If the total size is changed, then data in this tensor is no longer valid.
*/ */
template <typename... S> template <typename... S, detail::EnableIfIntegral<S...> * = nullptr>
void Reshape(S &&...s) { void Reshape(S &&...s) {
static_assert(sizeof...(S) <= kDim, "Invalid shape."); static_assert(sizeof...(S) <= kDim, "Invalid shape.");
detail::ReshapeImpl<0>(shape_, std::forward<S>(s)...); detail::ReshapeImpl<0>(shape_, std::forward<S>(s)...);
@ -777,15 +812,20 @@ class Tensor {
* *
* If the total size is changed, then data in this tensor is no longer valid. * If the total size is changed, then data in this tensor is no longer valid.
*/ */
template <int32_t D> template <size_t D>
void Reshape(size_t (&shape)[D]) { void Reshape(common::Span<size_t const, D> shape) {
static_assert(D <= kDim, "Invalid shape."); static_assert(D <= kDim, "Invalid shape.");
std::copy(shape, shape + D, this->shape_); std::copy(shape.data(), shape.data() + D, this->shape_);
std::fill(shape_ + D, shape_ + kDim, 1); std::fill(shape_ + D, shape_ + kDim, 1);
auto n = detail::CalcSize(shape_); auto n = detail::CalcSize(shape_);
data_.Resize(n); data_.Resize(n);
} }
template <size_t D>
void Reshape(size_t (&shape)[D]) {
this->Reshape(common::Span<size_t const, D>{shape});
}
/** /**
* \brief Set device ordinal for this tensor. * \brief Set device ordinal for this tensor.
*/ */

View File

@ -27,7 +27,10 @@ class RegTree;
/*! \brief interface of objective function */ /*! \brief interface of objective function */
class ObjFunction : public Configurable { class ObjFunction : public Configurable {
protected: protected:
GenericParameter const* ctx_; Context const* ctx_;
public:
static constexpr float DefaultBaseScore() { return 0.5f; }
public: public:
/*! \brief virtual destructor */ /*! \brief virtual destructor */
@ -75,6 +78,13 @@ class ObjFunction : public Configurable {
virtual bst_float ProbToMargin(bst_float base_score) const { virtual bst_float ProbToMargin(bst_float base_score) const {
return base_score; return base_score;
} }
/**
* \brief Make initial estimation of prediction.
*
* \param info MetaInfo that contains label.
* \param base_score Output estimation.
*/
virtual void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) const;
/*! /*!
* \brief Return task of this objective. * \brief Return task of this objective.
*/ */

View File

@ -102,13 +102,10 @@ class PredictionContainer {
*/ */
class Predictor { class Predictor {
protected: protected:
/* Context const* ctx_;
* \brief Runtime parameters.
*/
GenericParameter const* ctx_;
public: public:
explicit Predictor(GenericParameter const* ctx) : ctx_{ctx} {} explicit Predictor(Context const* ctx) : ctx_{ctx} {}
virtual ~Predictor() = default; virtual ~Predictor() = default;

View File

@ -1,7 +1,8 @@
/*! /*!
* Copyright 2022 by XGBoost Contributors * Copyright 2022 by XGBoost Contributors
*/ */
#pragma once #ifndef XGBOOST_COMMON_ALGORITHM_H_
#define XGBOOST_COMMON_ALGORITHM_H_
#include <algorithm> // std::upper_bound #include <algorithm> // std::upper_bound
#include <cinttypes> // std::size_t #include <cinttypes> // std::size_t
@ -14,3 +15,4 @@ auto SegmentId(It first, It last, Idx idx) {
} }
} // namespace common } // namespace common
} // namespace xgboost } // namespace xgboost
#endif // XGBOOST_COMMON_ALGORITHM_H_

View File

@ -265,6 +265,7 @@ struct OptionalWeights {
explicit OptionalWeights(float w) : dft{w} {} explicit OptionalWeights(float w) : dft{w} {}
XGBOOST_DEVICE float operator[](size_t i) const { return weights.empty() ? dft : weights[i]; } XGBOOST_DEVICE float operator[](size_t i) const { return weights.empty() ? dft : weights[i]; }
auto Empty() const { return weights.empty(); }
}; };
/** /**
@ -276,7 +277,7 @@ XGBOOST_DEVICE size_t LastOf(size_t group, Indexable const &indptr) {
} }
/** /**
* @brief A CRTP (curiously recurring template pattern) helper function. * \brief A CRTP (curiously recurring template pattern) helper function.
* *
* https://www.fluentcpp.com/2017/05/19/crtp-helper/ * https://www.fluentcpp.com/2017/05/19/crtp-helper/
* *
@ -284,7 +285,7 @@ XGBOOST_DEVICE size_t LastOf(size_t group, Indexable const &indptr) {
* 1. Makes "crtp" explicit in the inheritance structure of a CRTP base class. * 1. Makes "crtp" explicit in the inheritance structure of a CRTP base class.
* 2. Avoids having to `static_cast` in a lot of places. * 2. Avoids having to `static_cast` in a lot of places.
* *
* @tparam T The derived class in a CRTP hierarchy. * \tparam T The derived class in a CRTP hierarchy.
*/ */
template <typename T> template <typename T>
struct Crtp { struct Crtp {
@ -292,6 +293,13 @@ struct Crtp {
T const &Underlying() const { return static_cast<T const &>(*this); } T const &Underlying() const { return static_cast<T const &>(*this); }
}; };
/**
 * \brief Backport of C++17 std::as_const: view an lvalue through a const reference.
 *
 * \tparam T Type of the referenced object.
 * \param v  Object to view; it is not copied.
 * \return Const-qualified reference to \p v.
 */
template <typename T>
std::add_const_t<T> &AsConst(T &v) noexcept {  // NOLINT(runtime/references)
  std::add_const_t<T> &const_ref = v;
  return const_ref;
}
} // namespace common } // namespace common
} // namespace xgboost } // namespace xgboost
#endif // XGBOOST_COMMON_COMMON_H_ #endif // XGBOOST_COMMON_COMMON_H_

View File

@ -4,6 +4,7 @@
#ifndef XGBOOST_COMMON_LINALG_OP_H_ #ifndef XGBOOST_COMMON_LINALG_OP_H_
#define XGBOOST_COMMON_LINALG_OP_H_ #define XGBOOST_COMMON_LINALG_OP_H_
#include <type_traits> #include <type_traits>
#include <cstdint> // std::int32_t
#include "common.h" #include "common.h"
#include "threading_utils.h" #include "threading_utils.h"
@ -59,6 +60,31 @@ void ElementWiseKernel(GenericParameter const* ctx, linalg::TensorView<T, D> t,
ElementWiseKernelHost(t, ctx->Threads(), fn); ElementWiseKernelHost(t, ctx->Threads(), fn);
} }
#endif // !defined(XGBOOST_USE_CUDA) #endif // !defined(XGBOOST_USE_CUDA)
/**
 * \brief Read-only iterator positioned at the first element of a tensor view.
 *
 * Element `i` of the iteration is the view indexed with `UnravelIndex(i, v.Shape())`.
 *
 * \param v View to iterate over. The view object is copied into the iterator; the
 *          referenced elements are not.
 * \return Index-transform iterator yielding const references to the elements.
 */
template <typename T, std::int32_t kDim>
auto cbegin(TensorView<T, kDim> v) {  // NOLINT
  // Capture the view by value. `v` is a by-value parameter that is destroyed when this
  // function returns, so capturing it by reference would leave the returned iterator
  // with a dangling reference.
  auto it = common::MakeIndexTransformIter([v](size_t i) -> std::remove_cv_t<T> const& {
    return linalg::detail::Apply(v, linalg::UnravelIndex(i, v.Shape()));
  });
  return it;
}
/**
 * \brief Read-only past-the-end iterator for a tensor view.
 *
 * \param v View to iterate over.
 * \return `cbegin(v)` advanced by the number of elements in \p v.
 */
template <typename T, std::int32_t kDim>
auto cend(TensorView<T, kDim> v) {  // NOLINT
  auto const n_elements = v.Size();
  return cbegin(v) + n_elements;
}
/**
 * \brief Iterator positioned at the first element of a tensor view.
 *
 * Element `i` of the iteration is the view indexed with `UnravelIndex(i, v.Shape())`.
 *
 * \param v View to iterate over. The view object is copied into the iterator; the
 *          referenced elements are not.
 * \return Index-transform iterator yielding `T&` references to the elements.
 */
template <typename T, std::int32_t kDim>
auto begin(TensorView<T, kDim> v) {  // NOLINT
  // Capture the view by value. `v` is a by-value parameter that is destroyed when this
  // function returns, so capturing it by reference would leave the returned iterator
  // with a dangling reference.
  auto it = common::MakeIndexTransformIter([v](size_t i) -> T& {
    // The captured view is const inside the closure. Index through a local non-const
    // copy so that the non-const call operator is used and `T&` is returned. The
    // reference points at the viewed data, not at this temporary copy.
    auto view = v;
    return linalg::detail::Apply(view, linalg::UnravelIndex(i, view.Shape()));
  });
  return it;
}
/**
 * \brief Past-the-end iterator for a tensor view.
 *
 * \param v View to iterate over.
 * \return `begin(v)` advanced by the number of elements in \p v.
 */
template <typename T, std::int32_t kDim>
auto end(TensorView<T, kDim> v) {  // NOLINT
  auto const n_elements = v.Size();
  return begin(v) + n_elements;
}
} // namespace linalg } // namespace linalg
} // namespace xgboost } // namespace xgboost
#endif // XGBOOST_COMMON_LINALG_OP_H_ #endif // XGBOOST_COMMON_LINALG_OP_H_

28
src/common/numeric.cc Normal file
View File

@ -0,0 +1,28 @@
/*!
* Copyright 2022 by XGBoost Contributors
*/
#include "numeric.h"
#include <numeric> // std::accumulate
#include <type_traits> // std::is_same
#include "threading_utils.h" // MemStackAllocator, ParallelFor, DefaultMaxThreads
#include "xgboost/generic_parameters.h" // Context
#include "xgboost/host_device_vector.h" // HostDeviceVector
namespace xgboost {
namespace common {
double Reduce(Context const* ctx, HostDeviceVector<float> const& values) {
if (ctx->IsCPU()) {
auto const& h_values = values.ConstHostVector();
MemStackAllocator<double, DefaultMaxThreads()> result_tloc(ctx->Threads(), 0);
ParallelFor(h_values.size(), ctx->Threads(),
[&](auto i) { result_tloc[omp_get_thread_num()] += h_values[i]; });
auto result = std::accumulate(result_tloc.cbegin(), result_tloc.cend(), 0.0);
static_assert(std::is_same<decltype(result), double>::value, "");
return result;
}
return cuda::Reduce(ctx, values);
}
} // namespace common
} // namespace xgboost

25
src/common/numeric.cu Normal file
View File

@ -0,0 +1,25 @@
/*!
* Copyright 2022 by XGBoost Contributors
*/
#include <thrust/execution_policy.h>
#include <thrust/functional.h> // thrust::plus
#include "device_helpers.cuh" // dh::Reduce, safe_cuda, dh::XGBCachingDeviceAllocator
#include "numeric.h"
#include "xgboost/generic_parameters.h" // Context
#include "xgboost/host_device_vector.h" // HostDeviceVector
namespace xgboost {
namespace common {
namespace cuda {
double Reduce(Context const* ctx, HostDeviceVector<float> const& values) {
values.SetDevice(ctx->gpu_id);
auto const d_values = values.ConstDeviceSpan();
dh::XGBCachingDeviceAllocator<char> alloc;
auto res = dh::Reduce(thrust::cuda::par(alloc), d_values.data(),
d_values.data() + d_values.size(), 0.0, thrust::plus<double>{});
return res;
}
} // namespace cuda
} // namespace common
} // namespace xgboost

View File

@ -8,8 +8,10 @@
#include <iterator> // std::iterator_traits #include <iterator> // std::iterator_traits
#include <vector> #include <vector>
#include "threading_utils.h" #include "common.h" // AssertGPUSupport
#include "xgboost/generic_parameters.h" #include "threading_utils.h" // MemStackAllocator, DefaultMaxThreads
#include "xgboost/generic_parameters.h" // Context
#include "xgboost/host_device_vector.h" // HostDeviceVector
namespace xgboost { namespace xgboost {
namespace common { namespace common {
@ -18,8 +20,8 @@ namespace common {
* \brief Run length encode on CPU, input must be sorted. * \brief Run length encode on CPU, input must be sorted.
*/ */
template <typename Iter, typename Idx> template <typename Iter, typename Idx>
void RunLengthEncode(Iter begin, Iter end, std::vector<Idx> *p_out) { void RunLengthEncode(Iter begin, Iter end, std::vector<Idx>* p_out) {
auto &out = *p_out; auto& out = *p_out;
out = std::vector<Idx>{0}; out = std::vector<Idx>{0};
size_t n = std::distance(begin, end); size_t n = std::distance(begin, end);
for (size_t i = 1; i < n; ++i) { for (size_t i = 1; i < n; ++i) {
@ -45,7 +47,7 @@ void PartialSum(int32_t n_threads, InIt begin, InIt end, T init, OutIt out_it) {
auto n = static_cast<size_t>(std::distance(begin, end)); auto n = static_cast<size_t>(std::distance(begin, end));
const size_t batch_threads = const size_t batch_threads =
std::max(static_cast<size_t>(1), std::min(n, static_cast<size_t>(n_threads))); std::max(static_cast<size_t>(1), std::min(n, static_cast<size_t>(n_threads)));
common::MemStackAllocator<T, 128> partial_sums(batch_threads); MemStackAllocator<T, DefaultMaxThreads()> partial_sums(batch_threads);
size_t block_size = n / batch_threads; size_t block_size = n / batch_threads;
@ -90,6 +92,20 @@ void PartialSum(int32_t n_threads, InIt begin, InIt end, T init, OutIt out_it) {
} }
exc.Rethrow(); exc.Rethrow();
} }
namespace cuda {
/**
 * \brief Device-side summation of \p values.
 *
 * \param ctx    Runtime context.
 * \param values Values to sum.
 */
double Reduce(Context const* ctx, HostDeviceVector<float> const& values);
#if !defined(XGBOOST_USE_CUDA)
// Stub for builds without CUDA: calls AssertGPUSupport() and returns a placeholder.
inline double Reduce(Context const* /*ctx*/, HostDeviceVector<float> const& /*values*/) {
  AssertGPUSupport();
  return 0.0;
}
#endif  // !defined(XGBOOST_USE_CUDA)
}  // namespace cuda
/**
* \brief Reduction with summation.
*/
double Reduce(Context const* ctx, HostDeviceVector<float> const& values);
} // namespace common } // namespace common
} // namespace xgboost } // namespace xgboost

47
src/common/stats.cu Normal file
View File

@ -0,0 +1,47 @@
/*!
* Copyright 2022 by XGBoost Contributors
*/
#include <thrust/iterator/counting_iterator.h> // thrust::make_counting_iterator
#include "common.h" // common::OptionalWeights
#include "device_helpers.cuh" // dh::MakeTransformIterator, tcbegin, tcend
#include "stats.cuh" // common::SegmentedQuantile, common::SegmentedWeightedQuantile
#include "xgboost/generic_parameters.h" // Context
#include "xgboost/host_device_vector.h" // HostDeviceVector
#include "xgboost/linalg.h" // linalg::TensorView, UnravelIndex, Apply
namespace xgboost {
namespace common {
namespace cuda {
/**
 * \brief Compute the 0.5 quantile (median) over all elements of a 2-dim tensor view
 *        using the segmented quantile routines.
 *
 * \param ctx     Runtime context; `ctx->gpu_id` is set as the device of the temporary buffers.
 * \param t       View of the values.
 * \param weights Optional weights. When not empty, element `i` of the flattened view
 *                uses `weights[i / t.Shape(1)]`.
 * \return The single quantile value produced by the segmented routine.
 */
float Median(Context const* ctx, linalg::TensorView<float const, 2> t,
common::OptionalWeights weights) {
// Segment boundaries {0, t.Size()}: one segment spanning every element of the view.
HostDeviceVector<size_t> segments{0, t.Size()};
segments.SetDevice(ctx->gpu_id);
auto d_segments = segments.ConstDeviceSpan();
// Maps a flat index to the corresponding element of the view.
auto val_it = dh::MakeTransformIterator<float>(
thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) {
return linalg::detail::Apply(t, linalg::UnravelIndex(i, t.Shape()));
});
// Output buffer passed to the quantile routines; checked below to hold one value.
HostDeviceVector<float> quantile{0};
quantile.SetDevice(ctx->gpu_id);
if (weights.Empty()) {
common::SegmentedQuantile(ctx, 0.5, dh::tcbegin(d_segments), dh::tcend(d_segments), val_it,
val_it + t.Size(), &quantile);
} else {
// Shape(1) is used as a divisor below, so it must not be zero.
CHECK_NE(t.Shape(1), 0);
// Weight iterator aligned with val_it: the flat index is divided by t.Shape(1) to
// pick the weight.
auto w_it = dh::MakeTransformIterator<float>(thrust::make_counting_iterator(0ul),
[=] XGBOOST_DEVICE(size_t i) {
auto sample_idx = i / t.Shape(1);
return weights[sample_idx];
});
common::SegmentedWeightedQuantile(ctx, 0.5, dh::tcbegin(d_segments), dh::tcend(d_segments),
val_it, val_it + t.Size(), w_it, w_it + t.Size(), &quantile);
}
CHECK_EQ(quantile.Size(), 1);
// Read the result through the host vector.
return quantile.HostVector().front();
}
} // namespace cuda
} // namespace common
} // namespace xgboost

View File

@ -8,7 +8,8 @@
#include <limits> #include <limits>
#include <vector> #include <vector>
#include "common.h" #include "common.h" // AssertGPUSupport
#include "xgboost/generic_parameters.h"
#include "xgboost/linalg.h" #include "xgboost/linalg.h"
namespace xgboost { namespace xgboost {
@ -90,6 +91,44 @@ float WeightedQuantile(double alpha, Iter begin, Iter end, WeightIter weights) {
idx = std::min(idx, static_cast<size_t>(n - 1)); idx = std::min(idx, static_cast<size_t>(n - 1));
return val(idx); return val(idx);
} }
namespace cuda {
/**
 * \brief Device-side median of the values in \p t, optionally weighted.
 *
 * \param ctx     Runtime context.
 * \param t       View of the values.
 * \param weights Optional weights.
 */
float Median(Context const* ctx, linalg::TensorView<float const, 2> t,
             common::OptionalWeights weights);
#if !defined(XGBOOST_USE_CUDA)
// Stub for builds without CUDA: calls AssertGPUSupport() and returns a placeholder.
inline float Median(Context const* /*ctx*/, linalg::TensorView<float const, 2> /*t*/,
                    common::OptionalWeights /*weights*/) {
  AssertGPUSupport();
  return 0.0f;
}
#endif  // !defined(XGBOOST_USE_CUDA)
}  // namespace cuda
/**
 * \brief Median (0.5 quantile) over all elements of a 2-dim tensor, optionally weighted.
 *
 * When the context is not a CPU context the computation is forwarded to cuda::Median.
 * On CPU, element `i` of the flattened tensor uses `weights[i / Shape(1)]` when weights
 * are present.
 *
 * \param ctx     Runtime context; selects CPU or GPU.
 * \param t       Values.
 * \param weights Weights; an empty vector selects the unweighted quantile.
 * \return The median value.
 */
inline float Median(Context const* ctx, linalg::Tensor<float, 2> const& t,
                    HostDeviceVector<float> const& weights) {
  if (!ctx->IsCPU()) {
    // Device path: set the weights' device, then take device views of both inputs.
    weights.SetDevice(ctx->gpu_id);
    auto d_weights = OptionalWeights(weights.ConstDeviceSpan());
    auto d_view = t.View(ctx->gpu_id);
    return cuda::Median(ctx, d_view, d_weights);
  }

  auto h_weights = OptionalWeights(weights.ConstHostSpan());
  auto h_view = t.HostView();
  // Maps a flat index to the corresponding element of the host view.
  auto val_it = common::MakeIndexTransformIter([&](size_t i) {
    return linalg::detail::Apply(h_view, linalg::UnravelIndex(i, h_view.Shape()));
  });

  if (h_weights.Empty()) {
    float unweighted = common::Quantile(0.5, val_it, val_it + h_view.Size());
    return unweighted;
  }

  // Shape(1) is used as a divisor below, so it must not be zero.
  CHECK_NE(h_view.Shape(1), 0);
  auto weight_it = common::MakeIndexTransformIter([&](size_t i) {
    auto sample_idx = i / h_view.Shape(1);
    return h_weights[sample_idx];
  });
  float weighted = common::WeightedQuantile(0.5, val_it, val_it + h_view.Size(), weight_it);
  return weighted;
}
} // namespace common } // namespace common
} // namespace xgboost } // namespace xgboost
#endif // XGBOOST_COMMON_STATS_H_ #endif // XGBOOST_COMMON_STATS_H_

View File

@ -8,6 +8,7 @@
#include <dmlc/omp.h> #include <dmlc/omp.h>
#include <algorithm> #include <algorithm>
#include <cstdint> // std::int32_t
#include <limits> #include <limits>
#include <type_traits> // std::is_signed #include <type_traits> // std::is_signed
#include <vector> #include <vector>
@ -253,7 +254,7 @@ inline int32_t OmpGetNumThreads(int32_t n_threads) {
* MaxStackSize, it will be allocated inside the stack. Otherwise, it will be * MaxStackSize, it will be allocated inside the stack. Otherwise, it will be
* heap-allocated. * heap-allocated.
*/ */
template <typename T, size_t MaxStackSize> template <typename T, std::size_t MaxStackSize>
class MemStackAllocator { class MemStackAllocator {
public: public:
explicit MemStackAllocator(size_t required_size) : required_size_(required_size) { explicit MemStackAllocator(size_t required_size) : required_size_(required_size) {
@ -278,11 +279,23 @@ class MemStackAllocator {
T& operator[](size_t i) { return ptr_[i]; } T& operator[](size_t i) { return ptr_[i]; }
T const& operator[](size_t i) const { return ptr_[i]; } T const& operator[](size_t i) const { return ptr_[i]; }
auto data() const { return ptr_; } // NOLINT
auto data() { return ptr_; } // NOLINT
std::size_t size() const { return required_size_; } // NOLINT
auto cbegin() const { return data(); } // NOLINT
auto cend() const { return data() + size(); } // NOLINT
private: private:
T* ptr_ = nullptr; T* ptr_ = nullptr;
size_t required_size_; size_t required_size_;
T stack_mem_[MaxStackSize]; T stack_mem_[MaxStackSize];
}; };
/**
 * \brief Constant that can be used for initializing static thread local memory.
 *
 * \return The fixed capacity, 128.
 */
constexpr std::int32_t DefaultMaxThreads() {
  constexpr std::int32_t kMaxThreads = 128;
  return kMaxThreads;
}
} // namespace common } // namespace common
} // namespace xgboost } // namespace xgboost

View File

@ -345,8 +345,8 @@ struct ToDType<int64_t> {
}; };
#if !defined(XGBOOST_USE_CUDA) #if !defined(XGBOOST_USE_CUDA)
inline void ArrayInterfaceHandler::SyncCudaStream(int64_t stream) { common::AssertGPUSupport(); } inline void ArrayInterfaceHandler::SyncCudaStream(int64_t) { common::AssertGPUSupport(); }
inline bool ArrayInterfaceHandler::IsCudaPtr(void const *ptr) { return false; } inline bool ArrayInterfaceHandler::IsCudaPtr(void const *) { return false; }
#endif // !defined(XGBOOST_USE_CUDA) #endif // !defined(XGBOOST_USE_CUDA)
/** /**

View File

@ -161,9 +161,10 @@ class GBLinear : public GradientBooster {
uint32_t layer_begin, uint32_t) override { uint32_t layer_begin, uint32_t) override {
LinearCheckLayer(layer_begin); LinearCheckLayer(layer_begin);
const int ngroup = model_.learner_model_param->num_output_group; const int ngroup = model_.learner_model_param->num_output_group;
auto base_score = learner_model_param_->BaseScore(ctx_);
for (int gid = 0; gid < ngroup; ++gid) { for (int gid = 0; gid < ngroup; ++gid) {
this->Pred(inst, dmlc::BeginPtr(*out_preds), gid, this->Pred(inst, dmlc::BeginPtr(*out_preds), gid, base_score(0));
learner_model_param_->base_score);
} }
} }
@ -184,6 +185,7 @@ class GBLinear : public GradientBooster {
contribs.resize(p_fmat->Info().num_row_ * ncolumns * ngroup); contribs.resize(p_fmat->Info().num_row_ * ncolumns * ngroup);
// make sure contributions is zeroed, we could be reusing a previously allocated one // make sure contributions is zeroed, we could be reusing a previously allocated one
std::fill(contribs.begin(), contribs.end(), 0); std::fill(contribs.begin(), contribs.end(), 0);
auto base_score = learner_model_param_->BaseScore(ctx_);
// start collecting the contributions // start collecting the contributions
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) { for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
// parallel over local batch // parallel over local batch
@ -202,8 +204,8 @@ class GBLinear : public GradientBooster {
} }
// add base margin to BIAS // add base margin to BIAS
p_contribs[ncolumns - 1] = p_contribs[ncolumns - 1] =
model_.Bias()[gid] + ((base_margin.Size() != 0) ? base_margin(row_idx, gid) model_.Bias()[gid] +
: learner_model_param_->base_score); ((base_margin.Size() != 0) ? base_margin(row_idx, gid) : base_score(0));
} }
}); });
} }
@ -268,10 +270,12 @@ class GBLinear : public GradientBooster {
monitor_.Start("PredictBatchInternal"); monitor_.Start("PredictBatchInternal");
model_.LazyInitModel(); model_.LazyInitModel();
std::vector<bst_float> &preds = *out_preds; std::vector<bst_float> &preds = *out_preds;
auto base_margin = p_fmat->Info().base_margin_.View(GenericParameter::kCpuId); auto base_margin = p_fmat->Info().base_margin_.View(Context::kCpuId);
// start collecting the prediction // start collecting the prediction
const int ngroup = model_.learner_model_param->num_output_group; const int ngroup = model_.learner_model_param->num_output_group;
preds.resize(p_fmat->Info().num_row_ * ngroup); preds.resize(p_fmat->Info().num_row_ * ngroup);
auto base_score = learner_model_param_->BaseScore(Context::kCpuId);
for (const auto &page : p_fmat->GetBatches<SparsePage>()) { for (const auto &page : p_fmat->GetBatches<SparsePage>()) {
auto const& batch = page.GetView(); auto const& batch = page.GetView();
// output convention: nrow * k, where nrow is number of rows // output convention: nrow * k, where nrow is number of rows
@ -285,8 +289,7 @@ class GBLinear : public GradientBooster {
const size_t ridx = page.base_rowid + i; const size_t ridx = page.base_rowid + i;
// loop over output groups // loop over output groups
for (int gid = 0; gid < ngroup; ++gid) { for (int gid = 0; gid < ngroup; ++gid) {
float margin = float margin = (base_margin.Size() != 0) ? base_margin(ridx, gid) : base_score(0);
(base_margin.Size() != 0) ? base_margin(ridx, gid) : learner_model_param_->base_score;
this->Pred(batch[i], &preds[ridx * ngroup], gid, margin); this->Pred(batch[i], &preds[ridx * ngroup], gid, margin);
} }
}); });

View File

@ -638,13 +638,12 @@ void GPUDartPredictInc(common::Span<float> out_predts,
} }
#endif #endif
void GPUDartInplacePredictInc(common::Span<float> out_predts, void GPUDartInplacePredictInc(common::Span<float> /*out_predts*/, common::Span<float> /*predts*/,
common::Span<float> predts, float tree_w, float /*tree_w*/, size_t /*n_rows*/,
size_t n_rows, float base_score, linalg::TensorView<float const, 1> /*base_score*/,
bst_group_t n_groups, bst_group_t /*n_groups*/, bst_group_t /*group*/)
bst_group_t group)
#if defined(XGBOOST_USE_CUDA) #if defined(XGBOOST_USE_CUDA)
; // NOLINT ; // NOLINT
#else #else
{ {
common::AssertGPUSupport(); common::AssertGPUSupport();
@ -850,15 +849,17 @@ class Dart : public GBTree {
size_t n_rows = p_fmat->Info().num_row_; size_t n_rows = p_fmat->Info().num_row_;
if (predts.predictions.DeviceIdx() != Context::kCpuId) { if (predts.predictions.DeviceIdx() != Context::kCpuId) {
p_out_preds->predictions.SetDevice(predts.predictions.DeviceIdx()); p_out_preds->predictions.SetDevice(predts.predictions.DeviceIdx());
auto base_score = model_.learner_model_param->BaseScore(predts.predictions.DeviceIdx());
GPUDartInplacePredictInc(p_out_preds->predictions.DeviceSpan(), GPUDartInplacePredictInc(p_out_preds->predictions.DeviceSpan(),
predts.predictions.DeviceSpan(), w, n_rows, predts.predictions.DeviceSpan(), w, n_rows, base_score, n_groups,
model_.learner_model_param->base_score, n_groups, group); group);
} else { } else {
auto base_score = model_.learner_model_param->BaseScore(Context::kCpuId);
auto& h_predts = predts.predictions.HostVector(); auto& h_predts = predts.predictions.HostVector();
auto& h_out_predts = p_out_preds->predictions.HostVector(); auto& h_out_predts = p_out_preds->predictions.HostVector();
common::ParallelFor(n_rows, ctx_->Threads(), [&](auto ridx) { common::ParallelFor(n_rows, ctx_->Threads(), [&](auto ridx) {
const size_t offset = ridx * n_groups + group; const size_t offset = ridx * n_groups + group;
h_out_predts[offset] += (h_predts[offset] - model_.learner_model_param->base_score) * w; h_out_predts[offset] += (h_predts[offset] - base_score(0)) * w;
}); });
} }
} }

View File

@ -31,13 +31,14 @@ void GPUDartPredictInc(common::Span<float> out_predts,
}); });
} }
void GPUDartInplacePredictInc(common::Span<float> out_predts, void GPUDartInplacePredictInc(common::Span<float> out_predts, common::Span<float> predts,
common::Span<float> predts, float tree_w, float tree_w, size_t n_rows,
size_t n_rows, float base_score, linalg::TensorView<float const, 1> base_score, bst_group_t n_groups,
bst_group_t n_groups, bst_group_t group) { bst_group_t group) {
CHECK_EQ(base_score.Size(), 1);
dh::LaunchN(n_rows, [=] XGBOOST_DEVICE(size_t ridx) { dh::LaunchN(n_rows, [=] XGBOOST_DEVICE(size_t ridx) {
const size_t offset = ridx * n_groups + group; const size_t offset = ridx * n_groups + group;
out_predts[offset] += (predts[offset] - base_score) * tree_w; out_predts[offset] += (predts[offset] - base_score(0)) * tree_w;
}); });
} }
} // namespace gbm } // namespace gbm

View File

@ -4,47 +4,48 @@
* \brief Implementation of learning algorithm. * \brief Implementation of learning algorithm.
* \author Tianqi Chen * \author Tianqi Chen
*/ */
#include "xgboost/learner.h"
#include <dmlc/any.h>
#include <dmlc/io.h> #include <dmlc/io.h>
#include <dmlc/parameter.h> #include <dmlc/parameter.h>
#include <dmlc/thread_local.h> #include <dmlc/thread_local.h>
#include <atomic>
#include <mutex>
#include <algorithm> #include <algorithm>
#include <atomic>
#include <iomanip> #include <iomanip>
#include <limits> #include <limits> // std::numeric_limits
#include <memory> #include <memory>
#include <mutex>
#include <sstream> #include <sstream>
#include <string>
#include <stack> #include <stack>
#include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "dmlc/any.h" #include "common/charconv.h"
#include "common/common.h"
#include "common/io.h"
#include "common/linalg_op.h"
#include "common/observer.h"
#include "common/random.h"
#include "common/threading_utils.h"
#include "common/timer.h"
#include "common/version.h"
#include "xgboost/base.h" #include "xgboost/base.h"
#include "xgboost/c_api.h" #include "xgboost/c_api.h"
#include "xgboost/data.h" #include "xgboost/data.h"
#include "xgboost/model.h"
#include "xgboost/predictor.h"
#include "xgboost/feature_map.h" #include "xgboost/feature_map.h"
#include "xgboost/gbm.h" #include "xgboost/gbm.h"
#include "xgboost/generic_parameters.h" #include "xgboost/generic_parameters.h"
#include "xgboost/host_device_vector.h" #include "xgboost/host_device_vector.h"
#include "xgboost/json.h" #include "xgboost/json.h"
#include "xgboost/learner.h"
#include "xgboost/logging.h" #include "xgboost/logging.h"
#include "xgboost/metric.h" #include "xgboost/metric.h"
#include "xgboost/model.h"
#include "xgboost/objective.h" #include "xgboost/objective.h"
#include "xgboost/parameter.h" #include "xgboost/parameter.h"
#include "xgboost/predictor.h"
#include "common/common.h"
#include "common/io.h"
#include "common/observer.h"
#include "common/random.h"
#include "common/timer.h"
#include "common/charconv.h"
#include "common/version.h"
#include "common/threading_utils.h"
namespace { namespace {
@ -85,26 +86,29 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
uint32_t minor_version; uint32_t minor_version;
uint32_t num_target{1}; uint32_t num_target{1};
int32_t base_score_estimated{0};
/*! \brief reserved field */ /*! \brief reserved field */
int reserved[26]; int reserved[25];
/*! \brief constructor */ /*! \brief constructor */
LearnerModelParamLegacy() { LearnerModelParamLegacy() {
std::memset(this, 0, sizeof(LearnerModelParamLegacy)); std::memset(this, 0, sizeof(LearnerModelParamLegacy));
base_score = 0.5f; base_score = ObjFunction::DefaultBaseScore();
num_target = 1; num_target = 1;
major_version = std::get<0>(Version::Self()); major_version = std::get<0>(Version::Self());
minor_version = std::get<1>(Version::Self()); minor_version = std::get<1>(Version::Self());
base_score_estimated = 0;
static_assert(sizeof(LearnerModelParamLegacy) == 136, static_assert(sizeof(LearnerModelParamLegacy) == 136,
"Do not change the size of this struct, as it will break binary IO."); "Do not change the size of this struct, as it will break binary IO.");
} }
// Skip other legacy fields. // Skip other legacy fields.
Json ToJson() const { Json ToJson() const {
Object obj; Object obj;
char floats[NumericLimits<float>::kToCharsSize]; char floats[NumericLimits<float>::kToCharsSize];
auto ret = to_chars(floats, floats + NumericLimits<float>::kToCharsSize, base_score); auto ret = to_chars(floats, floats + NumericLimits<float>::kToCharsSize, base_score);
CHECK(ret.ec == std::errc()); CHECK(ret.ec == std::errc{});
obj["base_score"] = obj["base_score"] = std::string{floats, static_cast<size_t>(std::distance(floats, ret.ptr))};
std::string{floats, static_cast<size_t>(std::distance(floats, ret.ptr))};
char integers[NumericLimits<int64_t>::kToCharsSize]; char integers[NumericLimits<int64_t>::kToCharsSize];
ret = to_chars(integers, integers + NumericLimits<int64_t>::kToCharsSize, ret = to_chars(integers, integers + NumericLimits<int64_t>::kToCharsSize,
@ -136,10 +140,14 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
} }
this->Init(m); this->Init(m);
std::string str = get<String const>(j_param.at("base_score")); std::string str = get<String const>(j_param.at("base_score"));
from_chars(str.c_str(), str.c_str() + str.size(), base_score); from_chars(str.c_str(), str.c_str() + str.size(), base_score);
// It can only be estimated during the first training, we consider it estimated afterward
base_score_estimated = 1;
} }
inline LearnerModelParamLegacy ByteSwap() const {
LearnerModelParamLegacy ByteSwap() const {
LearnerModelParamLegacy x = *this; LearnerModelParamLegacy x = *this;
dmlc::ByteSwap(&x.base_score, sizeof(x.base_score), 1); dmlc::ByteSwap(&x.base_score, sizeof(x.base_score), 1);
dmlc::ByteSwap(&x.num_feature, sizeof(x.num_feature), 1); dmlc::ByteSwap(&x.num_feature, sizeof(x.num_feature), 1);
@ -149,14 +157,30 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
dmlc::ByteSwap(&x.major_version, sizeof(x.major_version), 1); dmlc::ByteSwap(&x.major_version, sizeof(x.major_version), 1);
dmlc::ByteSwap(&x.minor_version, sizeof(x.minor_version), 1); dmlc::ByteSwap(&x.minor_version, sizeof(x.minor_version), 1);
dmlc::ByteSwap(&x.num_target, sizeof(x.num_target), 1); dmlc::ByteSwap(&x.num_target, sizeof(x.num_target), 1);
dmlc::ByteSwap(&x.base_score_estimated, sizeof(x.base_score_estimated), 1);
dmlc::ByteSwap(x.reserved, sizeof(x.reserved[0]), sizeof(x.reserved) / sizeof(x.reserved[0])); dmlc::ByteSwap(x.reserved, sizeof(x.reserved[0]), sizeof(x.reserved) / sizeof(x.reserved[0]));
return x; return x;
} }
/*!
 * \brief Update the parameter from key-value pairs, tracking whether the user
 *        supplied `base_score` themselves.
 *
 * \param kwargs Container of (key, value) pairs.
 * \return Whatever the underlying dmlc parameter update returns.
 */
template <typename Container>
Args UpdateAllowUnknown(Container const& kwargs) {
  // True when `key` appears among the supplied arguments.
  auto contains = [&kwargs](char const* key) {
    return std::any_of(kwargs.cbegin(), kwargs.cend(),
                       [key](auto const& kv) { return kv.first == key; });
  };

  // A user-provided base score counts as already estimated.
  if (contains("base_score")) {
    base_score_estimated = 1;
  }
  // The flag itself is internal state and must not be set through parameters.
  if (contains("base_score_estimated")) {
    LOG(FATAL) << "`base_score_estimated` cannot be specified as hyper-parameter.";
  }
  return dmlc::Parameter<LearnerModelParamLegacy>::UpdateAllowUnknown(kwargs);
}
// declare parameters // declare parameters
DMLC_DECLARE_PARAMETER(LearnerModelParamLegacy) { DMLC_DECLARE_PARAMETER(LearnerModelParamLegacy) {
DMLC_DECLARE_FIELD(base_score) DMLC_DECLARE_FIELD(base_score)
.set_default(0.5f) .set_default(ObjFunction::DefaultBaseScore())
.describe("Global bias of the model."); .describe("Global bias of the model.");
DMLC_DECLARE_FIELD(num_feature) DMLC_DECLARE_FIELD(num_feature)
.set_default(0) .set_default(0)
@ -170,12 +194,12 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
.set_default(1) .set_default(1)
.set_lower_bound(1) .set_lower_bound(1)
.describe("Number of target for multi-target regression."); .describe("Number of target for multi-target regression.");
DMLC_DECLARE_FIELD(base_score_estimated).set_default(0);
} }
}; };
LearnerModelParam::LearnerModelParam(LearnerModelParamLegacy const& user_param, float base_margin, LearnerModelParam::LearnerModelParam(LearnerModelParamLegacy const& user_param, ObjInfo t)
ObjInfo t) : num_feature{user_param.num_feature}, task{t} {
: base_score{base_margin}, num_feature{user_param.num_feature}, task{t} {
auto n_classes = std::max(static_cast<uint32_t>(user_param.num_class), 1u); auto n_classes = std::max(static_cast<uint32_t>(user_param.num_class), 1u);
auto n_targets = user_param.num_target; auto n_targets = user_param.num_target;
num_output_group = std::max(n_classes, n_targets); num_output_group = std::max(n_classes, n_targets);
@ -185,6 +209,53 @@ LearnerModelParam::LearnerModelParam(LearnerModelParamLegacy const& user_param,
<< ", n_targets:" << n_targets; << ", n_targets:" << n_targets;
} }
// Construct the model parameter together with its base score tensor.
//
// Delegates to the (user_param, t) constructor for the remaining fields, then
// takes over the contents of `base_margin` by swapping it into `base_score_`.
//
// \param ctx         Context; consulted for IsCPU() and gpu_id.
// \param user_param  User-facing parameter forwarded to the delegated constructor.
// \param base_margin Tensor whose contents become `base_score_`.
// \param t           Objective task info forwarded to the delegated constructor.
LearnerModelParam::LearnerModelParam(Context const* ctx, LearnerModelParamLegacy const& user_param,
linalg::Tensor<float, 1> base_margin, ObjInfo t)
: LearnerModelParam{user_param, t} {
std::swap(base_score_, base_margin);
// Make sure read access everywhere for thread-safe prediction.
// A const host view is requested first; when the context is not CPU a const
// view on ctx->gpu_id is requested as well. The returned views are discarded.
// NOTE(review): presumably only the access-state side effect is wanted - confirm.
common::AsConst(base_score_).HostView();
if (!ctx->IsCPU()) {
common::AsConst(base_score_).View(ctx->gpu_id);
}
// The host side must still be readable after the device view was requested.
CHECK(common::AsConst(base_score_).Data()->HostCanRead());
}
/*!
 * \brief Obtain a read-only view of the base score for the given device.
 *
 * \param device Device ordinal, or Context::kCpuId for the host.
 * \return Const view over the single-element base score.
 */
linalg::TensorView<float const, 1> LearnerModelParam::BaseScore(int32_t device) const {
  // multi-class is not yet supported.
  CHECK_EQ(base_score_.Size(), 1);

  bool const want_host = (device == Context::kCpuId);
  if (!want_host) {
    // Read access on the device must already be in place to avoid a race.
    CHECK(base_score_.Data()->DeviceCanRead());
    auto d_view = base_score_.View(device);
    // Taking the device view must not have revoked host read access.
    CHECK(base_score_.Data()->HostCanRead());
    return d_view;
  }

  // Read access on the host must already be in place to avoid a race.
  CHECK(base_score_.Data()->HostCanRead());
  return base_score_.HostView();
}
/*!
 * \brief Convenience overload that picks the device from the context's gpu_id.
 */
linalg::TensorView<float const, 1> LearnerModelParam::BaseScore(Context const* ctx) const {
  auto const device = ctx->gpu_id;
  return this->BaseScore(device);
}
// Copy the contents of `that` into this object.
//
// The base score tensor is reshaped to match, placed on the same device index
// as the source and then copied element-wise; the scalar fields (num_feature,
// num_output_group, task) are assigned afterwards.
//
// \param that Source model parameter.
void LearnerModelParam::Copy(LearnerModelParam const& that) {
base_score_.Reshape(that.base_score_.Shape());
base_score_.Data()->SetDevice(that.base_score_.DeviceIdx());
base_score_.Data()->Copy(*that.base_score_.Data());
// Request a const host view, and a const device view when the source lives on
// a device. The returned views are discarded.
// NOTE(review): presumably done to establish read access on both sides - confirm.
common::AsConst(base_score_).HostView();
if (that.base_score_.DeviceIdx() != Context::kCpuId) {
common::AsConst(base_score_).View(that.base_score_.DeviceIdx());
}
// The copy must end up with the same device readability as the source, and
// must be readable from the host.
CHECK_EQ(base_score_.Data()->DeviceCanRead(), that.base_score_.Data()->DeviceCanRead());
CHECK(base_score_.Data()->HostCanRead());
num_feature = that.num_feature;
num_output_group = that.num_output_group;
task = that.task;
}
struct LearnerTrainParam : public XGBoostParameter<LearnerTrainParam> { struct LearnerTrainParam : public XGBoostParameter<LearnerTrainParam> {
// data split mode, can be row, col, or none. // data split mode, can be row, col, or none.
DataSplitMode dsplit {DataSplitMode::kAuto}; DataSplitMode dsplit {DataSplitMode::kAuto};
@ -308,8 +379,61 @@ class LearnerConfiguration : public Learner {
LearnerModelParamLegacy mparam_; LearnerModelParamLegacy mparam_;
LearnerModelParam learner_model_param_; LearnerModelParam learner_model_param_;
LearnerTrainParam tparam_; LearnerTrainParam tparam_;
// Initial prediction.
std::vector<std::string> metric_names_; std::vector<std::string> metric_names_;
/**
 * \brief Fill in `base_score` unless it has already been estimated or supplied.
 *
 * When a DMatrix is given, the objective estimates the score from its meta
 * info; otherwise the objective-independent default is used. Either way the
 * score is marked as estimated and the shared model parameter is rebuilt.
 *
 * \param p_fmat The training DMatrix used to estimate the base score, may be null.
 */
void InitBaseScore(DMatrix const* p_fmat) {
  // Historical note: before 1.0.0 `base_score` was saved into the binary model
  // as a value already transformed by the objective. Since 1.0.0 the value
  // provided by the user is saved and kept immutable, and the binary LoadModel
  // handles the old case instead of configuration.
  //
  // The transformation is omitted only when base_score comes from an old
  // binary model. The remaining situations are:
  //
  // - model loaded from new binary or JSON.
  // - model is created from scratch.
  // - model is configured second time due to change of parameter
  CHECK(obj_);
  if (mparam_.base_score_estimated) {
    return;  // nothing to do, the score is already fixed
  }

  if (p_fmat == nullptr) {
    // No data to look at, fall back to the default.
    mparam_.base_score = ObjFunction::DefaultBaseScore();
  } else {
    // Let the objective estimate the score from the input data.
    linalg::Tensor<float, 1> intercept;
    obj_->InitEstimation(p_fmat->Info(), &intercept);
    mparam_.base_score = intercept(0);
    CHECK(!std::isnan(mparam_.base_score));
  }

  mparam_.base_score_estimated = true;
  // Propagate the new score to the shared model parameter.
  this->ConfigureModelParam();
}
// Convert mparam to learner_model_param
void ConfigureModelParam() {
this->ConfigureTargets();
CHECK(obj_);
auto task = obj_->Task();
linalg::Tensor<float, 1> base_score({1}, Ctx()->gpu_id);
auto h_base_score = base_score.HostView();
// transform to margin
h_base_score(0) = obj_->ProbToMargin(mparam_.base_score);
// move it to model param, which is shared with all other components.
learner_model_param_ = LearnerModelParam(Ctx(), mparam_, std::move(base_score), task);
CHECK(learner_model_param_.Initialized());
CHECK_NE(learner_model_param_.BaseScore(Ctx()).Size(), 0);
}
public: public:
explicit LearnerConfiguration(std::vector<std::shared_ptr<DMatrix> > cache) explicit LearnerConfiguration(std::vector<std::shared_ptr<DMatrix> > cache)
: need_configuration_{true} { : need_configuration_{true} {
@ -329,22 +453,24 @@ class LearnerConfiguration : public Learner {
// Configuration before data is known. // Configuration before data is known.
void Configure() override { void Configure() override {
// Variant of double checked lock
if (!this->need_configuration_) { return; } if (!this->need_configuration_) {
return;
}
std::lock_guard<std::mutex> guard(config_lock_); std::lock_guard<std::mutex> guard(config_lock_);
if (!this->need_configuration_) { return; } if (!this->need_configuration_) {
return;
}
monitor_.Start("Configure"); monitor_.Start("Configure");
auto old_tparam = tparam_; auto old_tparam = tparam_;
Args args = {cfg_.cbegin(), cfg_.cend()}; Args args = {cfg_.cbegin(), cfg_.cend()};
tparam_.UpdateAllowUnknown(args); tparam_.UpdateAllowUnknown(args);
auto mparam_backup = mparam_;
mparam_.UpdateAllowUnknown(args); mparam_.UpdateAllowUnknown(args);
auto initialized = generic_parameters_.GetInitialised(); auto initialized = ctx_.GetInitialised();
auto old_seed = generic_parameters_.seed; auto old_seed = ctx_.seed;
generic_parameters_.UpdateAllowUnknown(args); ctx_.UpdateAllowUnknown(args);
ConsoleLogger::Configure(args); ConsoleLogger::Configure(args);
@ -355,8 +481,8 @@ class LearnerConfiguration : public Learner {
} }
// set seed only before the model is initialized // set seed only before the model is initialized
if (!initialized || generic_parameters_.seed != old_seed) { if (!initialized || ctx_.seed != old_seed) {
common::GlobalRandom().seed(generic_parameters_.seed); common::GlobalRandom().seed(ctx_.seed);
} }
// must precede configure gbm since num_features is required for gbm // must precede configure gbm since num_features is required for gbm
@ -364,31 +490,15 @@ class LearnerConfiguration : public Learner {
args = {cfg_.cbegin(), cfg_.cend()}; // renew args = {cfg_.cbegin(), cfg_.cend()}; // renew
this->ConfigureObjective(old_tparam, &args); this->ConfigureObjective(old_tparam, &args);
auto task = this->ConfigureTargets(); learner_model_param_.task = obj_->Task(); // required by gbm configuration.
// Before 1.0.0, we save `base_score` into binary as a transformed value by objective.
// After 1.0.0 we save the value provided by user and keep it immutable instead. To
// keep the stability, we initialize it in binary LoadModel instead of configuration.
// Under what condition should we omit the transformation:
//
// - base_score is loaded from old binary model.
//
// What are the other possible conditions:
//
// - model loaded from new binary or JSON.
// - model is created from scratch.
// - model is configured second time due to change of parameter
if (!learner_model_param_.Initialized() || mparam_.base_score != mparam_backup.base_score) {
learner_model_param_ =
LearnerModelParam(mparam_, obj_->ProbToMargin(mparam_.base_score), task);
}
this->ConfigureGBM(old_tparam, args); this->ConfigureGBM(old_tparam, args);
generic_parameters_.ConfigureGpuId(this->gbm_->UseGPU()); ctx_.ConfigureGpuId(this->gbm_->UseGPU());
this->ConfigureModelParam();
this->ConfigureMetrics(args); this->ConfigureMetrics(args);
this->need_configuration_ = false; this->need_configuration_ = false;
if (generic_parameters_.validate_parameters) { if (ctx_.validate_parameters) {
this->ValidateParameters(); this->ValidateParameters();
} }
@ -396,6 +506,11 @@ class LearnerConfiguration : public Learner {
monitor_.Stop("Configure"); monitor_.Stop("Configure");
} }
// Abort unless the shared model parameter is initialized and has a non-empty
// base score.
void CheckModelInitialized() const {
  auto const& model_param = learner_model_param_;
  CHECK(model_param.Initialized()) << "Model not yet initialized.";
  auto const n_base_scores = model_param.BaseScore(this->Ctx()).Size();
  CHECK_NE(n_base_scores, 0);
}
// Return the prediction container stored for this learner in the thread-local
// prediction cache map (keyed by `this`).
virtual PredictionContainer* GetPredictionCache() const {
  return &((*ThreadLocalPredictionCache::Get())[this]);
}
@ -417,7 +532,7 @@ class LearnerConfiguration : public Learner {
auto const& objective_fn = learner_parameters.at("objective"); auto const& objective_fn = learner_parameters.at("objective");
if (!obj_) { if (!obj_) {
obj_.reset(ObjFunction::Create(tparam_.objective, &generic_parameters_)); obj_.reset(ObjFunction::Create(tparam_.objective, &ctx_));
} }
obj_->LoadConfig(objective_fn); obj_->LoadConfig(objective_fn);
learner_model_param_.task = obj_->Task(); learner_model_param_.task = obj_->Task();
@ -425,7 +540,7 @@ class LearnerConfiguration : public Learner {
tparam_.booster = get<String>(gradient_booster["name"]); tparam_.booster = get<String>(gradient_booster["name"]);
if (!gbm_) { if (!gbm_) {
gbm_.reset(GradientBooster::Create(tparam_.booster, gbm_.reset(GradientBooster::Create(tparam_.booster,
&generic_parameters_, &learner_model_param_)); &ctx_, &learner_model_param_));
} }
gbm_->LoadConfig(gradient_booster); gbm_->LoadConfig(gradient_booster);
@ -441,15 +556,15 @@ class LearnerConfiguration : public Learner {
} else { } else {
metric_names_[i] = get<String>(j_metrics[i]["name"]); metric_names_[i] = get<String>(j_metrics[i]["name"]);
} }
metrics_[i] = std::unique_ptr<Metric>(Metric::Create(metric_names_[i], &generic_parameters_)); metrics_[i] = std::unique_ptr<Metric>(Metric::Create(metric_names_[i], &ctx_));
if (!old_serialization) { if (!old_serialization) {
metrics_[i]->LoadConfig(j_metrics[i]); metrics_[i]->LoadConfig(j_metrics[i]);
} }
} }
FromJson(learner_parameters.at("generic_param"), &generic_parameters_); FromJson(learner_parameters.at("generic_param"), &ctx_);
// make sure the GPU ID is valid in new environment before start running configure. // make sure the GPU ID is valid in new environment before start running configure.
generic_parameters_.ConfigureGpuId(false); ctx_.ConfigureGpuId(false);
this->need_configuration_ = true; this->need_configuration_ = true;
} }
@ -478,7 +593,7 @@ class LearnerConfiguration : public Learner {
} }
learner_parameters["metrics"] = Array(std::move(metrics)); learner_parameters["metrics"] = Array(std::move(metrics));
learner_parameters["generic_param"] = ToJson(generic_parameters_); learner_parameters["generic_param"] = ToJson(ctx_);
} }
void SetParam(const std::string& key, const std::string& value) override { void SetParam(const std::string& key, const std::string& value) override {
@ -551,7 +666,7 @@ class LearnerConfiguration : public Learner {
return cfg_; return cfg_;
} }
// Read-only access to the learner's context.
Context const* Ctx() const override { return &ctx_; }
private: private:
void ValidateParameters() { void ValidateParameters() {
@ -654,7 +769,7 @@ class LearnerConfiguration : public Learner {
void ConfigureGBM(LearnerTrainParam const& old, Args const& args) { void ConfigureGBM(LearnerTrainParam const& old, Args const& args) {
if (gbm_ == nullptr || old.booster != tparam_.booster) { if (gbm_ == nullptr || old.booster != tparam_.booster) {
gbm_.reset(GradientBooster::Create(tparam_.booster, &generic_parameters_, gbm_.reset(GradientBooster::Create(tparam_.booster, &ctx_,
&learner_model_param_)); &learner_model_param_));
} }
gbm_->Configure(args); gbm_->Configure(args);
@ -678,7 +793,7 @@ class LearnerConfiguration : public Learner {
cfg_["max_delta_step"] = kMaxDeltaStepDefaultValue; cfg_["max_delta_step"] = kMaxDeltaStepDefaultValue;
} }
if (obj_ == nullptr || tparam_.objective != old.objective) { if (obj_ == nullptr || tparam_.objective != old.objective) {
obj_.reset(ObjFunction::Create(tparam_.objective, &generic_parameters_)); obj_.reset(ObjFunction::Create(tparam_.objective, &ctx_));
} }
auto& args = *p_args; auto& args = *p_args;
args = {cfg_.cbegin(), cfg_.cend()}; // renew args = {cfg_.cbegin(), cfg_.cend()}; // renew
@ -691,7 +806,7 @@ class LearnerConfiguration : public Learner {
return m->Name() != name; return m->Name() != name;
}; };
if (std::all_of(metrics_.begin(), metrics_.end(), DupCheck)) { if (std::all_of(metrics_.begin(), metrics_.end(), DupCheck)) {
metrics_.emplace_back(std::unique_ptr<Metric>(Metric::Create(name, &generic_parameters_))); metrics_.emplace_back(std::unique_ptr<Metric>(Metric::Create(name, &ctx_)));
mparam_.contain_eval_metrics = 1; mparam_.contain_eval_metrics = 1;
} }
} }
@ -703,7 +818,7 @@ class LearnerConfiguration : public Learner {
/** /**
* Get number of targets from objective function. * Get number of targets from objective function.
*/ */
ObjInfo ConfigureTargets() { void ConfigureTargets() {
CHECK(this->obj_); CHECK(this->obj_);
auto const& cache = this->GetPredictionCache()->Container(); auto const& cache = this->GetPredictionCache()->Container();
size_t n_targets = 1; size_t n_targets = 1;
@ -722,7 +837,6 @@ class LearnerConfiguration : public Learner {
} else { } else {
mparam_.num_target = n_targets; mparam_.num_target = n_targets;
} }
return this->obj_->Task();
} }
}; };
@ -754,14 +868,14 @@ class LearnerIO : public LearnerConfiguration {
std::string name = get<String>(objective_fn["name"]); std::string name = get<String>(objective_fn["name"]);
tparam_.UpdateAllowUnknown(Args{{"objective", name}}); tparam_.UpdateAllowUnknown(Args{{"objective", name}});
obj_.reset(ObjFunction::Create(name, &generic_parameters_)); obj_.reset(ObjFunction::Create(name, &ctx_));
obj_->LoadConfig(objective_fn); obj_->LoadConfig(objective_fn);
auto const& gradient_booster = learner.at("gradient_booster"); auto const& gradient_booster = learner.at("gradient_booster");
name = get<String>(gradient_booster["name"]); name = get<String>(gradient_booster["name"]);
tparam_.UpdateAllowUnknown(Args{{"booster", name}}); tparam_.UpdateAllowUnknown(Args{{"booster", name}});
gbm_.reset( gbm_.reset(
GradientBooster::Create(tparam_.booster, &generic_parameters_, &learner_model_param_)); GradientBooster::Create(tparam_.booster, &ctx_, &learner_model_param_));
gbm_->LoadModel(gradient_booster); gbm_->LoadModel(gradient_booster);
auto const& j_attributes = get<Object const>(learner.at("attributes")); auto const& j_attributes = get<Object const>(learner.at("attributes"));
@ -791,6 +905,7 @@ class LearnerIO : public LearnerConfiguration {
void SaveModel(Json* p_out) const override { void SaveModel(Json* p_out) const override {
CHECK(!this->need_configuration_) << "Call Configure before saving model."; CHECK(!this->need_configuration_) << "Call Configure before saving model.";
this->CheckModelInitialized();
Version::Save(p_out); Version::Save(p_out);
Json& out { *p_out }; Json& out { *p_out };
@ -826,7 +941,7 @@ class LearnerIO : public LearnerConfiguration {
// About to be deprecated by JSON format // About to be deprecated by JSON format
void LoadModel(dmlc::Stream* fi) override { void LoadModel(dmlc::Stream* fi) override {
generic_parameters_.UpdateAllowUnknown(Args{}); ctx_.UpdateAllowUnknown(Args{});
tparam_.Init(std::vector<std::pair<std::string, std::string>>{}); tparam_.Init(std::vector<std::pair<std::string, std::string>>{});
// TODO(tqchen) mark deprecation of old format. // TODO(tqchen) mark deprecation of old format.
common::PeekableInStream fp(fi); common::PeekableInStream fp(fi);
@ -881,8 +996,8 @@ class LearnerIO : public LearnerConfiguration {
CHECK(fi->Read(&tparam_.objective)) << "BoostLearner: wrong model format"; CHECK(fi->Read(&tparam_.objective)) << "BoostLearner: wrong model format";
CHECK(fi->Read(&tparam_.booster)) << "BoostLearner: wrong model format"; CHECK(fi->Read(&tparam_.booster)) << "BoostLearner: wrong model format";
obj_.reset(ObjFunction::Create(tparam_.objective, &generic_parameters_)); obj_.reset(ObjFunction::Create(tparam_.objective, &ctx_));
gbm_.reset(GradientBooster::Create(tparam_.booster, &generic_parameters_, gbm_.reset(GradientBooster::Create(tparam_.booster, &ctx_,
&learner_model_param_)); &learner_model_param_));
gbm_->Load(fi); gbm_->Load(fi);
if (mparam_.contain_extra_attrs != 0) { if (mparam_.contain_extra_attrs != 0) {
@ -925,7 +1040,14 @@ class LearnerIO : public LearnerConfiguration {
} }
learner_model_param_ = learner_model_param_ =
LearnerModelParam(mparam_, obj_->ProbToMargin(mparam_.base_score), obj_->Task()); LearnerModelParam(&ctx_, mparam_,
linalg::Tensor<float, 1>{{std::isnan(mparam_.base_score)
? std::numeric_limits<float>::quiet_NaN()
: obj_->ProbToMargin(mparam_.base_score)},
{1},
Context::kCpuId},
obj_->Task());
if (attributes_.find("objective") != attributes_.cend()) { if (attributes_.find("objective") != attributes_.cend()) {
auto obj_str = attributes_.at("objective"); auto obj_str = attributes_.at("objective");
auto j_obj = Json::Load({obj_str.c_str(), obj_str.size()}); auto j_obj = Json::Load({obj_str.c_str(), obj_str.size()});
@ -969,6 +1091,8 @@ class LearnerIO : public LearnerConfiguration {
// Save model into binary format. The code is about to be deprecated by more robust // Save model into binary format. The code is about to be deprecated by more robust
// JSON serialization format. // JSON serialization format.
void SaveModel(dmlc::Stream* fo) const override { void SaveModel(dmlc::Stream* fo) const override {
this->CheckModelInitialized();
LearnerModelParamLegacy mparam = mparam_; // make a copy to potentially modify LearnerModelParamLegacy mparam = mparam_; // make a copy to potentially modify
std::vector<std::pair<std::string, std::string> > extra_attr; std::vector<std::pair<std::string, std::string> > extra_attr;
mparam.contain_extra_attrs = 1; mparam.contain_extra_attrs = 1;
@ -1000,6 +1124,7 @@ class LearnerIO : public LearnerConfiguration {
} }
extra_attr.emplace_back("metrics", os.str()); extra_attr.emplace_back("metrics", os.str());
} }
std::string header {"binf"}; std::string header {"binf"};
fo->Write(header.data(), 4); fo->Write(header.data(), 4);
if (DMLC_IO_NO_ENDIAN_SWAP) { if (DMLC_IO_NO_ENDIAN_SWAP) {
@ -1022,6 +1147,8 @@ class LearnerIO : public LearnerConfiguration {
} }
void Save(dmlc::Stream* fo) const override { void Save(dmlc::Stream* fo) const override {
this->CheckModelInitialized();
Json memory_snapshot{Object()}; Json memory_snapshot{Object()};
memory_snapshot["Model"] = Object(); memory_snapshot["Model"] = Object();
auto& model = memory_snapshot["Model"]; auto& model = memory_snapshot["Model"];
@ -1108,28 +1235,30 @@ class LearnerImpl : public LearnerIO {
} }
} }
std::vector<std::string> DumpModel(const FeatureMap& fmap, std::vector<std::string> DumpModel(const FeatureMap& fmap, bool with_stats,
bool with_stats,
std::string format) override { std::string format) override {
this->Configure(); this->Configure();
this->CheckModelInitialized();
return gbm_->DumpModel(fmap, with_stats, format); return gbm_->DumpModel(fmap, with_stats, format);
} }
Learner *Slice(int32_t begin_layer, int32_t end_layer, int32_t step, Learner* Slice(int32_t begin_layer, int32_t end_layer, int32_t step,
bool *out_of_bound) override { bool* out_of_bound) override {
this->Configure(); this->Configure();
this->CheckModelInitialized();
CHECK_NE(this->learner_model_param_.num_feature, 0); CHECK_NE(this->learner_model_param_.num_feature, 0);
CHECK_GE(begin_layer, 0); CHECK_GE(begin_layer, 0);
auto *out_impl = new LearnerImpl({}); auto* out_impl = new LearnerImpl({});
out_impl->learner_model_param_ = this->learner_model_param_; out_impl->learner_model_param_.Copy(this->learner_model_param_);
out_impl->generic_parameters_ = this->generic_parameters_; out_impl->ctx_ = this->ctx_;
auto gbm = std::unique_ptr<GradientBooster>(GradientBooster::Create( auto gbm = std::unique_ptr<GradientBooster>(GradientBooster::Create(
this->tparam_.booster, &out_impl->generic_parameters_, this->tparam_.booster, &out_impl->ctx_, &out_impl->learner_model_param_));
&out_impl->learner_model_param_));
this->gbm_->Slice(begin_layer, end_layer, step, gbm.get(), out_of_bound); this->gbm_->Slice(begin_layer, end_layer, step, gbm.get(), out_of_bound);
out_impl->gbm_ = std::move(gbm); out_impl->gbm_ = std::move(gbm);
Json config { Object() }; Json config{Object()};
this->SaveConfig(&config); this->SaveConfig(&config);
out_impl->mparam_ = this->mparam_; out_impl->mparam_ = this->mparam_;
out_impl->attributes_ = this->attributes_; out_impl->attributes_ = this->attributes_;
@ -1156,15 +1285,17 @@ class LearnerImpl : public LearnerIO {
monitor_.Start("UpdateOneIter"); monitor_.Start("UpdateOneIter");
TrainingObserver::Instance().Update(iter); TrainingObserver::Instance().Update(iter);
this->Configure(); this->Configure();
if (generic_parameters_.seed_per_iteration) { this->InitBaseScore(train.get());
common::GlobalRandom().seed(generic_parameters_.seed * kRandSeedMagic + iter);
if (ctx_.seed_per_iteration) {
common::GlobalRandom().seed(ctx_.seed * kRandSeedMagic + iter);
} }
this->CheckDataSplitMode(); this->CheckDataSplitMode();
this->ValidateDMatrix(train.get(), true); this->ValidateDMatrix(train.get(), true);
auto local_cache = this->GetPredictionCache(); auto local_cache = this->GetPredictionCache();
auto& predt = local_cache->Cache(train, generic_parameters_.gpu_id); auto& predt = local_cache->Cache(train, ctx_.gpu_id);
monitor_.Start("PredictRaw"); monitor_.Start("PredictRaw");
this->PredictRaw(train.get(), &predt, true, 0, 0); this->PredictRaw(train.get(), &predt, true, 0, 0);
@ -1184,14 +1315,18 @@ class LearnerImpl : public LearnerIO {
HostDeviceVector<GradientPair>* in_gpair) override { HostDeviceVector<GradientPair>* in_gpair) override {
monitor_.Start("BoostOneIter"); monitor_.Start("BoostOneIter");
this->Configure(); this->Configure();
if (generic_parameters_.seed_per_iteration) { // Should have been set to default in the first prediction.
common::GlobalRandom().seed(generic_parameters_.seed * kRandSeedMagic + iter); CHECK(mparam_.base_score_estimated);
if (ctx_.seed_per_iteration) {
common::GlobalRandom().seed(ctx_.seed * kRandSeedMagic + iter);
} }
this->CheckDataSplitMode(); this->CheckDataSplitMode();
this->ValidateDMatrix(train.get(), true); this->ValidateDMatrix(train.get(), true);
auto local_cache = this->GetPredictionCache(); auto local_cache = this->GetPredictionCache();
local_cache->Cache(train, generic_parameters_.gpu_id); local_cache->Cache(train, ctx_.gpu_id);
gbm_->DoBoost(train.get(), in_gpair, &local_cache->Entry(train.get()), obj_.get()); gbm_->DoBoost(train.get(), in_gpair, &local_cache->Entry(train.get()), obj_.get());
monitor_.Stop("BoostOneIter"); monitor_.Stop("BoostOneIter");
@ -1202,23 +1337,24 @@ class LearnerImpl : public LearnerIO {
const std::vector<std::string>& data_names) override { const std::vector<std::string>& data_names) override {
monitor_.Start("EvalOneIter"); monitor_.Start("EvalOneIter");
this->Configure(); this->Configure();
this->CheckModelInitialized();
std::ostringstream os; std::ostringstream os;
os.precision(std::numeric_limits<double>::max_digits10); os.precision(std::numeric_limits<double>::max_digits10);
os << '[' << iter << ']' << std::setiosflags(std::ios::fixed); os << '[' << iter << ']' << std::setiosflags(std::ios::fixed);
if (metrics_.size() == 0 && tparam_.disable_default_eval_metric <= 0) { if (metrics_.size() == 0 && tparam_.disable_default_eval_metric <= 0) {
metrics_.emplace_back(Metric::Create(obj_->DefaultEvalMetric(), &generic_parameters_)); metrics_.emplace_back(Metric::Create(obj_->DefaultEvalMetric(), &ctx_));
metrics_.back()->Configure({cfg_.begin(), cfg_.end()}); metrics_.back()->Configure({cfg_.begin(), cfg_.end()});
} }
auto local_cache = this->GetPredictionCache(); auto local_cache = this->GetPredictionCache();
for (size_t i = 0; i < data_sets.size(); ++i) { for (size_t i = 0; i < data_sets.size(); ++i) {
std::shared_ptr<DMatrix> m = data_sets[i]; std::shared_ptr<DMatrix> m = data_sets[i];
auto &predt = local_cache->Cache(m, generic_parameters_.gpu_id); auto &predt = local_cache->Cache(m, ctx_.gpu_id);
this->ValidateDMatrix(m.get(), false); this->ValidateDMatrix(m.get(), false);
this->PredictRaw(m.get(), &predt, false, 0, 0); this->PredictRaw(m.get(), &predt, false, 0, 0);
auto &out = output_predictions_.Cache(m, generic_parameters_.gpu_id).predictions; auto &out = output_predictions_.Cache(m, ctx_.gpu_id).predictions;
out.Resize(predt.predictions.Size()); out.Resize(predt.predictions.Size());
out.Copy(predt.predictions); out.Copy(predt.predictions);
@ -1241,6 +1377,9 @@ class LearnerImpl : public LearnerIO {
static_cast<int>(pred_interactions) + static_cast<int>(pred_interactions) +
static_cast<int>(pred_contribs); static_cast<int>(pred_contribs);
this->Configure(); this->Configure();
this->InitBaseScore(nullptr);
this->CheckModelInitialized();
CHECK_LE(multiple_predictions, 1) << "Perform one kind of prediction at a time."; CHECK_LE(multiple_predictions, 1) << "Perform one kind of prediction at a time.";
if (pred_contribs) { if (pred_contribs) {
gbm_->PredictContribution(data.get(), out_preds, layer_begin, layer_end, approx_contribs); gbm_->PredictContribution(data.get(), out_preds, layer_begin, layer_end, approx_contribs);
@ -1251,10 +1390,10 @@ class LearnerImpl : public LearnerIO {
gbm_->PredictLeaf(data.get(), out_preds, layer_begin, layer_end); gbm_->PredictLeaf(data.get(), out_preds, layer_begin, layer_end);
} else { } else {
auto local_cache = this->GetPredictionCache(); auto local_cache = this->GetPredictionCache();
auto& prediction = local_cache->Cache(data, generic_parameters_.gpu_id); auto& prediction = local_cache->Cache(data, ctx_.gpu_id);
this->PredictRaw(data.get(), &prediction, training, layer_begin, layer_end); this->PredictRaw(data.get(), &prediction, training, layer_begin, layer_end);
// Copy the prediction cache to output prediction. out_preds comes from C API // Copy the prediction cache to output prediction. out_preds comes from C API
out_preds->SetDevice(generic_parameters_.gpu_id); out_preds->SetDevice(ctx_.gpu_id);
out_preds->Resize(prediction.predictions.Size()); out_preds->Resize(prediction.predictions.Size());
out_preds->Copy(prediction.predictions); out_preds->Copy(prediction.predictions);
if (!output_margin) { if (!output_margin) {
@ -1268,8 +1407,10 @@ class LearnerImpl : public LearnerIO {
CHECK(!this->need_configuration_); CHECK(!this->need_configuration_);
return this->gbm_->BoostedRounds(); return this->gbm_->BoostedRounds();
} }
uint32_t Groups() const override { uint32_t Groups() const override {
CHECK(!this->need_configuration_); CHECK(!this->need_configuration_);
this->CheckModelInitialized();
return this->learner_model_param_.num_output_group; return this->learner_model_param_.num_output_group;
} }
@ -1281,6 +1422,9 @@ class LearnerImpl : public LearnerIO {
HostDeviceVector<bst_float>** out_preds, uint32_t iteration_begin, HostDeviceVector<bst_float>** out_preds, uint32_t iteration_begin,
uint32_t iteration_end) override { uint32_t iteration_end) override {
this->Configure(); this->Configure();
this->InitBaseScore(nullptr);
this->CheckModelInitialized();
auto& out_predictions = this->GetThreadLocal().prediction_entry; auto& out_predictions = this->GetThreadLocal().prediction_entry;
this->gbm_->InplacePredict(p_m, missing, &out_predictions, iteration_begin, iteration_end); this->gbm_->InplacePredict(p_m, missing, &out_predictions, iteration_begin, iteration_end);
if (type == PredictionType::kValue) { if (type == PredictionType::kValue) {
@ -1296,6 +1440,8 @@ class LearnerImpl : public LearnerIO {
void CalcFeatureScore(std::string const& importance_type, common::Span<int32_t const> trees, void CalcFeatureScore(std::string const& importance_type, common::Span<int32_t const> trees,
std::vector<bst_feature_t>* features, std::vector<float>* scores) override { std::vector<bst_feature_t>* features, std::vector<float>* scores) override {
this->Configure(); this->Configure();
this->CheckModelInitialized();
gbm_->FeatureScore(importance_type, trees, features, scores); gbm_->FeatureScore(importance_type, trees, features, scores);
} }
@ -1315,17 +1461,17 @@ class LearnerImpl : public LearnerIO {
void PredictRaw(DMatrix *data, PredictionCacheEntry *out_preds, bool training, void PredictRaw(DMatrix *data, PredictionCacheEntry *out_preds, bool training,
unsigned layer_begin, unsigned layer_end) const { unsigned layer_begin, unsigned layer_end) const {
CHECK(gbm_ != nullptr) << "Predict must happen after Load or configuration"; CHECK(gbm_ != nullptr) << "Predict must happen after Load or configuration";
this->CheckModelInitialized();
this->ValidateDMatrix(data, false); this->ValidateDMatrix(data, false);
gbm_->PredictBatch(data, out_preds, training, layer_begin, layer_end); gbm_->PredictBatch(data, out_preds, training, layer_begin, layer_end);
} }
void ValidateDMatrix(DMatrix* p_fmat, bool is_training) const { void ValidateDMatrix(DMatrix* p_fmat, bool is_training) const {
MetaInfo const& info = p_fmat->Info(); MetaInfo const& info = p_fmat->Info();
info.Validate(generic_parameters_.gpu_id); info.Validate(ctx_.gpu_id);
auto const row_based_split = [this]() { auto const row_based_split = [this]() {
return tparam_.dsplit == DataSplitMode::kRow || return tparam_.dsplit == DataSplitMode::kRow || tparam_.dsplit == DataSplitMode::kAuto;
tparam_.dsplit == DataSplitMode::kAuto;
}; };
if (row_based_split()) { if (row_based_split()) {
if (is_training) { if (is_training) {

View File

@ -7,6 +7,7 @@
#include <limits> #include <limits>
#include <vector> #include <vector>
#include "../common/common.h"
#include "rabit/rabit.h" #include "rabit/rabit.h"
#include "xgboost/generic_parameters.h" #include "xgboost/generic_parameters.h"
#include "xgboost/host_device_vector.h" #include "xgboost/host_device_vector.h"

View File

@ -1,10 +1,10 @@
/*! /*!
* Copyright 2015 by Contributors * Copyright 2015-2022 by Contributors
* \file objective.cc * \file objective.cc
* \brief Registry of all objective functions. * \brief Registry of all objective functions.
*/ */
#include <xgboost/objective.h>
#include <dmlc/registry.h> #include <dmlc/registry.h>
#include <xgboost/objective.h>
#include <sstream> #include <sstream>
@ -31,6 +31,11 @@ ObjFunction* ObjFunction::Create(const std::string& name, GenericParameter const
return pobj; return pobj;
} }
// Default base score estimation: ignores the meta info and stores
// DefaultBaseScore() as the only element of a length-1 tensor.
//
// \param base_score  Output tensor; must not be null.  It is reshaped to a
//                    single element before being written.
void ObjFunction::InitEstimation(MetaInfo const&, linalg::Tensor<float, 1>* base_score) const {
  CHECK(base_score);
  auto& estimate = *base_score;
  estimate.Reshape(1);
  estimate(0) = DefaultBaseScore();
}
} // namespace xgboost } // namespace xgboost
namespace xgboost { namespace xgboost {

View File

@ -15,7 +15,9 @@
#include "../common/common.h" #include "../common/common.h"
#include "../common/linalg_op.h" #include "../common/linalg_op.h"
#include "../common/numeric.h" // Reduce
#include "../common/pseudo_huber.h" #include "../common/pseudo_huber.h"
#include "../common/stats.h"
#include "../common/threading_utils.h" #include "../common/threading_utils.h"
#include "../common/transform.h" #include "../common/transform.h"
#include "./regression_loss.h" #include "./regression_loss.h"
@ -37,14 +39,18 @@
namespace xgboost { namespace xgboost {
namespace obj { namespace obj {
namespace { namespace {
void CheckRegInputs(MetaInfo const& info, HostDeviceVector<bst_float> const& preds) { void CheckInitInputs(MetaInfo const& info) {
CHECK_EQ(info.labels.Shape(0), info.num_row_) << "Invalid shape of labels."; CHECK_EQ(info.labels.Shape(0), info.num_row_) << "Invalid shape of labels.";
CHECK_EQ(info.labels.Size(), preds.Size()) << "Invalid shape of labels.";
if (!info.weights_.Empty()) { if (!info.weights_.Empty()) {
CHECK_EQ(info.weights_.Size(), info.num_row_) CHECK_EQ(info.weights_.Size(), info.num_row_)
<< "Number of weights should be equal to number of data points."; << "Number of weights should be equal to number of data points.";
} }
} }
void CheckRegInputs(MetaInfo const& info, HostDeviceVector<bst_float> const& preds) {
CheckInitInputs(info);
CHECK_EQ(info.labels.Size(), preds.Size()) << "Invalid shape of labels.";
}
} // anonymous namespace } // anonymous namespace
#if defined(XGBOOST_USE_CUDA) #if defined(XGBOOST_USE_CUDA)
@ -698,6 +704,33 @@ class MeanAbsoluteError : public ObjFunction {
}); });
} }
/**
 * \brief Estimate the initial base score for absolute error from the labels.
 *
 * Each worker takes common::Median of its labels (with the sample weights),
 * scales it by its local weight sum, and the scaled values and weight sums are
 * summed across workers with Allreduce.  The result is the weight-averaged
 * median over all workers.
 *
 * When the global weight sum is not positive (every worker has an empty
 * dataset, or all weights are zero) there is nothing to average and the
 * division would yield NaN, so DefaultBaseScore() is used instead.
 *
 * \param info         Meta info providing labels, optional weights and row count.
 * \param base_margin  Output tensor; must not be null.  Reshaped to one element.
 */
void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_margin) const override {
  CheckInitInputs(info);
  CHECK(base_margin);
  base_margin->Reshape(1);
  auto out = base_margin->HostView();

  // Local weight sum: the row count when no weights are given.
  double w{0.0};
  if (info.weights_.Empty()) {
    w = static_cast<double>(info.num_row_);
  } else {
    w = common::Reduce(ctx_, info.weights_);
  }

  if (info.num_row_ == 0) {
    // An empty worker contributes nothing to the weighted sum.
    out(0) = 0;
  } else {
    // Local median scaled by the local weight, to be averaged after the reduce.
    out(0) = common::Median(ctx_, info.labels, info.weights_) * w;
  }

  // Weighted average base score across all workers.  Every worker must reach
  // both collectives, so the zero-weight guard comes only after them.
  rabit::Allreduce<rabit::op::Sum>(out.Values().data(), out.Values().size());
  rabit::Allreduce<rabit::op::Sum>(&w, 1);

  if (w <= 0.0) {
    // `w` is identical on all workers after the reduce, so they all take this
    // branch together.
    out(0) = ObjFunction::DefaultBaseScore();
    return;
  }

  std::transform(linalg::cbegin(out), linalg::cend(out), linalg::begin(out),
                 [w](float v) { return v / w; });
}
void UpdateTreeLeaf(HostDeviceVector<bst_node_t> const& position, MetaInfo const& info, void UpdateTreeLeaf(HostDeviceVector<bst_node_t> const& position, MetaInfo const& info,
HostDeviceVector<float> const& prediction, RegTree* p_tree) const override { HostDeviceVector<float> const& prediction, RegTree* p_tree) const override {
if (ctx_->IsCPU()) { if (ctx_->IsCPU()) {

View File

@ -429,11 +429,12 @@ class CPUPredictor : public Predictor {
} }
out_preds->resize(model.learner_model_param->num_output_group * out_preds->resize(model.learner_model_param->num_output_group *
(model.param.size_leaf_vector + 1)); (model.param.size_leaf_vector + 1));
auto base_score = model.learner_model_param->BaseScore(ctx_)(0);
// loop over output groups // loop over output groups
for (uint32_t gid = 0; gid < model.learner_model_param->num_output_group; ++gid) { for (uint32_t gid = 0; gid < model.learner_model_param->num_output_group; ++gid) {
(*out_preds)[gid] = PredValue(inst, model.trees, model.tree_info, gid, (*out_preds)[gid] =
&feat_vecs[0], 0, ntree_limit) + PredValue(inst, model.trees, model.tree_info, gid, &feat_vecs[0], 0, ntree_limit) +
model.learner_model_param->base_score; base_score;
} }
} }
@ -504,7 +505,8 @@ class CPUPredictor : public Predictor {
common::ParallelFor(ntree_limit, n_threads, [&](bst_omp_uint i) { common::ParallelFor(ntree_limit, n_threads, [&](bst_omp_uint i) {
FillNodeMeanValues(model.trees[i].get(), &(mean_values[i])); FillNodeMeanValues(model.trees[i].get(), &(mean_values[i]));
}); });
auto base_margin = info.base_margin_.View(GenericParameter::kCpuId); auto base_margin = info.base_margin_.View(Context::kCpuId);
auto base_score = model.learner_model_param->BaseScore(Context::kCpuId)(0);
// start collecting the contributions // start collecting the contributions
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) { for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
auto page = batch.GetView(); auto page = batch.GetView();
@ -548,7 +550,7 @@ class CPUPredictor : public Predictor {
CHECK_EQ(base_margin.Shape(1), ngroup); CHECK_EQ(base_margin.Shape(1), ngroup);
p_contribs[ncolumns - 1] += base_margin(row_idx, gid); p_contribs[ncolumns - 1] += base_margin(row_idx, gid);
} else { } else {
p_contribs[ncolumns - 1] += model.learner_model_param->base_score; p_contribs[ncolumns - 1] += base_score;
} }
} }
}); });

View File

@ -511,7 +511,7 @@ void ExtractPaths(
n = d_nodes[n.Parent() + tree_offset]; n = d_nodes[n.Parent() + tree_offset];
path_length++; path_length++;
} }
return PathInfo{int64_t(idx), path_length, tree_idx}; return PathInfo{static_cast<int64_t>(idx), path_length, tree_idx};
}); });
auto end = thrust::copy_if( auto end = thrust::copy_if(
thrust::cuda::par(alloc), nodes_transform, thrust::cuda::par(alloc), nodes_transform,
@ -859,13 +859,13 @@ class GPUPredictor : public xgboost::Predictor {
// Add the base margin term to last column // Add the base margin term to last column
p_fmat->Info().base_margin_.SetDevice(ctx_->gpu_id); p_fmat->Info().base_margin_.SetDevice(ctx_->gpu_id);
const auto margin = p_fmat->Info().base_margin_.Data()->ConstDeviceSpan(); const auto margin = p_fmat->Info().base_margin_.Data()->ConstDeviceSpan();
float base_score = model.learner_model_param->base_score;
dh::LaunchN( auto base_score = model.learner_model_param->BaseScore(ctx_);
p_fmat->Info().num_row_ * model.learner_model_param->num_output_group, dh::LaunchN(p_fmat->Info().num_row_ * model.learner_model_param->num_output_group,
[=] __device__(size_t idx) { [=] __device__(size_t idx) {
phis[(idx + 1) * contributions_columns - 1] += phis[(idx + 1) * contributions_columns - 1] +=
margin.empty() ? base_score : margin[idx]; margin.empty() ? base_score(0) : margin[idx];
}); });
} }
void PredictInteractionContributions(DMatrix* p_fmat, void PredictInteractionContributions(DMatrix* p_fmat,
@ -918,17 +918,17 @@ class GPUPredictor : public xgboost::Predictor {
// Add the base margin term to last column // Add the base margin term to last column
p_fmat->Info().base_margin_.SetDevice(ctx_->gpu_id); p_fmat->Info().base_margin_.SetDevice(ctx_->gpu_id);
const auto margin = p_fmat->Info().base_margin_.Data()->ConstDeviceSpan(); const auto margin = p_fmat->Info().base_margin_.Data()->ConstDeviceSpan();
float base_score = model.learner_model_param->base_score;
auto base_score = model.learner_model_param->BaseScore(ctx_);
size_t n_features = model.learner_model_param->num_feature; size_t n_features = model.learner_model_param->num_feature;
dh::LaunchN( dh::LaunchN(p_fmat->Info().num_row_ * model.learner_model_param->num_output_group,
p_fmat->Info().num_row_ * model.learner_model_param->num_output_group, [=] __device__(size_t idx) {
[=] __device__(size_t idx) { size_t group = idx % ngroup;
size_t group = idx % ngroup; size_t row_idx = idx / ngroup;
size_t row_idx = idx / ngroup; phis[gpu_treeshap::IndexPhiInteractions(row_idx, ngroup, group, n_features,
phis[gpu_treeshap::IndexPhiInteractions( n_features, n_features)] +=
row_idx, ngroup, group, n_features, n_features, n_features)] += margin.empty() ? base_score(0) : margin[idx];
margin.empty() ? base_score : margin[idx]; });
});
} }
void PredictInstance(const SparsePage::Inst&, void PredictInstance(const SparsePage::Inst&,

View File

@ -80,14 +80,15 @@ void Predictor::InitOutPredictions(const MetaInfo& info, HostDeviceVector<bst_fl
if (ctx_->gpu_id >= 0) { if (ctx_->gpu_id >= 0) {
out_preds->SetDevice(ctx_->gpu_id); out_preds->SetDevice(ctx_->gpu_id);
} }
if (base_margin->Size() != 0) { if (!base_margin->Empty()) {
out_preds->Resize(n); out_preds->Resize(n);
ValidateBaseMarginShape(info.base_margin_, info.num_row_, n_classes); ValidateBaseMarginShape(info.base_margin_, info.num_row_, n_classes);
out_preds->Copy(*base_margin); out_preds->Copy(*base_margin);
} else { } else {
out_preds->Resize(n);
// cannot rely on the Resize to fill as it might skip if the size is already correct. // cannot rely on the Resize to fill as it might skip if the size is already correct.
out_preds->Fill(model.learner_model_param->base_score); out_preds->Resize(n);
auto base_score = model.learner_model_param->BaseScore(Context::kCpuId)(0);
out_preds->Fill(base_score);
} }
} }
} // namespace xgboost } // namespace xgboost

View File

@ -29,5 +29,15 @@ TEST(Numeric, PartialSum) {
ASSERT_EQ(sol, result); ASSERT_EQ(sol, result);
} }
} }
// Reduce over the sequence 0, 1, ..., 19 on a CPU context must equal the
// closed-form sum (n - 1) * n / 2.
TEST(Numeric, Reduce) {
  Context ctx;
  ASSERT_TRUE(ctx.IsCPU());

  HostDeviceVector<float> values(20);
  auto& host = values.HostVector();
  std::iota(host.begin(), host.end(), 0.0f);

  auto const expected = (values.Size() - 1) * values.Size() / 2;
  auto total = Reduce(&ctx, values);
  ASSERT_EQ(total, expected);
}
} // namespace common } // namespace common
} // namespace xgboost } // namespace xgboost

View File

@ -54,5 +54,20 @@ TEST(Stats, WeightedQuantile) {
q = WeightedQuantile(1.0, beg, end, w); q = WeightedQuantile(1.0, beg, end, w);
ASSERT_EQ(q, 5); ASSERT_EQ(q, 5);
} }
// Median of {0, 0, 1, 2} with an empty weight vector is expected to be 0.5,
// on a default context and, in CUDA builds, again with gpu_id set to 0.
TEST(Stats, Median) {
  Context ctx;
  HostDeviceVector<float> weights;  // intentionally left empty
  linalg::Tensor<float, 2> values{{.0f, .0f, 1.f, 2.f}, {4}, Context::kCpuId};

  auto median = Median(&ctx, values, weights);
  ASSERT_EQ(median, .5f);

#if defined(XGBOOST_USE_CUDA)
  ctx.gpu_id = 0;
  ASSERT_FALSE(ctx.IsCPU());
  median = Median(&ctx, values, weights);
  ASSERT_EQ(median, .5f);
#endif  // defined(XGBOOST_USE_CUDA)
}
} // namespace common } // namespace common
} // namespace xgboost } // namespace xgboost

View File

@ -19,15 +19,11 @@ namespace gbm {
TEST(GBLinear, JsonIO) { TEST(GBLinear, JsonIO) {
size_t constexpr kRows = 16, kCols = 16; size_t constexpr kRows = 16, kCols = 16;
LearnerModelParam param; Context ctx;
param.num_feature = kCols; LearnerModelParam mparam{MakeMP(kCols, .5, 1)};
param.num_output_group = 1;
GenericParameter gparam; std::unique_ptr<GradientBooster> gbm{
gparam.Init(Args{}); CreateTrainedGBM("gblinear", Args{}, kRows, kCols, &mparam, &ctx)};
std::unique_ptr<GradientBooster> gbm {
CreateTrainedGBM("gblinear", Args{}, kRows, kCols, &param, &gparam) };
Json model { Object() }; Json model { Object() };
gbm->SaveModel(&model); gbm->SaveModel(&model);
ASSERT_TRUE(IsA<Object>(model)); ASSERT_TRUE(IsA<Object>(model));

View File

@ -18,15 +18,11 @@ namespace xgboost {
TEST(GBTree, SelectTreeMethod) { TEST(GBTree, SelectTreeMethod) {
size_t constexpr kCols = 10; size_t constexpr kCols = 10;
GenericParameter generic_param; Context ctx;
generic_param.UpdateAllowUnknown(Args{}); LearnerModelParam mparam{MakeMP(kCols, .5, 1)};
LearnerModelParam mparam;
mparam.base_score = 0.5;
mparam.num_feature = kCols;
mparam.num_output_group = 1;
std::unique_ptr<GradientBooster> p_gbm { std::unique_ptr<GradientBooster> p_gbm {
GradientBooster::Create("gbtree", &generic_param, &mparam)}; GradientBooster::Create("gbtree", &ctx, &mparam)};
auto& gbtree = dynamic_cast<gbm::GBTree&> (*p_gbm); auto& gbtree = dynamic_cast<gbm::GBTree&> (*p_gbm);
// Test if `tree_method` can be set // Test if `tree_method` can be set
@ -45,7 +41,7 @@ TEST(GBTree, SelectTreeMethod) {
ASSERT_EQ(tparam.updater_seq, "grow_quantile_histmaker"); ASSERT_EQ(tparam.updater_seq, "grow_quantile_histmaker");
#ifdef XGBOOST_USE_CUDA #ifdef XGBOOST_USE_CUDA
generic_param.UpdateAllowUnknown(Args{{"gpu_id", "0"}}); ctx.UpdateAllowUnknown(Args{{"gpu_id", "0"}});
gbtree.Configure({{"tree_method", "gpu_hist"}}); gbtree.Configure({{"tree_method", "gpu_hist"}});
ASSERT_EQ(tparam.updater_seq, "grow_gpu_hist"); ASSERT_EQ(tparam.updater_seq, "grow_gpu_hist");
gbtree.Configure({{"booster", "dart"}, {"tree_method", "gpu_hist"}}); gbtree.Configure({{"booster", "dart"}, {"tree_method", "gpu_hist"}});
@ -55,15 +51,11 @@ TEST(GBTree, SelectTreeMethod) {
TEST(GBTree, PredictionCache) { TEST(GBTree, PredictionCache) {
size_t constexpr kRows = 100, kCols = 10; size_t constexpr kRows = 100, kCols = 10;
GenericParameter generic_param; Context ctx;
generic_param.UpdateAllowUnknown(Args{}); LearnerModelParam mparam{MakeMP(kCols, .5, 1)};
LearnerModelParam mparam;
mparam.base_score = 0.5;
mparam.num_feature = kCols;
mparam.num_output_group = 1;
std::unique_ptr<GradientBooster> p_gbm { std::unique_ptr<GradientBooster> p_gbm {
GradientBooster::Create("gbtree", &generic_param, &mparam)}; GradientBooster::Create("gbtree", &ctx, &mparam)};
auto& gbtree = dynamic_cast<gbm::GBTree&> (*p_gbm); auto& gbtree = dynamic_cast<gbm::GBTree&> (*p_gbm);
gbtree.Configure({{"tree_method", "hist"}}); gbtree.Configure({{"tree_method", "hist"}});
@ -176,16 +168,11 @@ TEST(GBTree, ChoosePredictor) {
TEST(GBTree, JsonIO) { TEST(GBTree, JsonIO) {
size_t constexpr kRows = 16, kCols = 16; size_t constexpr kRows = 16, kCols = 16;
LearnerModelParam mparam; Context ctx;
mparam.num_feature = kCols; LearnerModelParam mparam{MakeMP(kCols, .5, 1)};
mparam.num_output_group = 1;
mparam.base_score = 0.5;
GenericParameter gparam;
gparam.Init(Args{});
std::unique_ptr<GradientBooster> gbm { std::unique_ptr<GradientBooster> gbm {
CreateTrainedGBM("gbtree", Args{}, kRows, kCols, &mparam, &gparam) }; CreateTrainedGBM("gbtree", Args{}, kRows, kCols, &mparam, &ctx) };
Json model {Object()}; Json model {Object()};
model["model"] = Object(); model["model"] = Object();
@ -215,16 +202,11 @@ TEST(GBTree, JsonIO) {
TEST(Dart, JsonIO) { TEST(Dart, JsonIO) {
size_t constexpr kRows = 16, kCols = 16; size_t constexpr kRows = 16, kCols = 16;
LearnerModelParam mparam; Context ctx;
mparam.num_feature = kCols; LearnerModelParam mparam{MakeMP(kCols, .5, 1)};
mparam.base_score = 0.5;
mparam.num_output_group = 1;
GenericParameter gparam; std::unique_ptr<GradientBooster> gbm{
gparam.Init(Args{}); CreateTrainedGBM("dart", Args{}, kRows, kCols, &mparam, &ctx)};
std::unique_ptr<GradientBooster> gbm {
CreateTrainedGBM("dart", Args{}, kRows, kCols, &mparam, &gparam) };
Json model {Object()}; Json model {Object()};
model["model"] = Object(); model["model"] = Object();

View File

@ -451,5 +451,16 @@ class RMMAllocator;
using RMMAllocatorPtr = std::unique_ptr<RMMAllocator, void(*)(RMMAllocator*)>; using RMMAllocatorPtr = std::unique_ptr<RMMAllocator, void(*)(RMMAllocator*)>;
RMMAllocatorPtr SetUpRMMResourceForCppTests(int argc, char** argv); RMMAllocatorPtr SetUpRMMResourceForCppTests(int argc, char** argv);
/*
 * \brief Make learner model param for tests.
 *
 * \param n_features  Number of features, passed as the first constructor argument.
 * \param base_score  Scalar base score, wrapped in a one-element tensor.
 * \param n_groups    Number of output groups, passed as the last constructor argument.
 * \param device      Device ordinal for the base score tensor, CPU by default.
 *
 * \return The constructed LearnerModelParam.
 */
inline LearnerModelParam MakeMP(bst_feature_t n_features, float base_score, uint32_t n_groups,
                                int32_t device = Context::kCpuId) {
  size_t shape[1]{1};
  linalg::Tensor<float, 1> score{{base_score}, shape, device};
  LearnerModelParam mparam(n_features, std::move(score), n_groups);
  return mparam;
}
} // namespace xgboost } // namespace xgboost
#endif #endif

View File

@ -18,10 +18,7 @@ TEST(Linear, Shotgun) {
auto p_fmat = xgboost::RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(); auto p_fmat = xgboost::RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
auto lparam = xgboost::CreateEmptyGenericParam(GPUIDX); auto lparam = xgboost::CreateEmptyGenericParam(GPUIDX);
LearnerModelParam mparam; LearnerModelParam mparam{MakeMP(kCols, .5, 1)};
mparam.num_feature = kCols;
mparam.num_output_group = 1;
mparam.base_score = 0.5;
{ {
auto updater = std::unique_ptr<xgboost::LinearUpdater>( auto updater = std::unique_ptr<xgboost::LinearUpdater>(
@ -54,10 +51,7 @@ TEST(Linear, coordinate) {
auto p_fmat = xgboost::RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(); auto p_fmat = xgboost::RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
auto lparam = xgboost::CreateEmptyGenericParam(GPUIDX); auto lparam = xgboost::CreateEmptyGenericParam(GPUIDX);
LearnerModelParam mparam; LearnerModelParam mparam{MakeMP(kCols, .5, 1)};
mparam.num_feature = kCols;
mparam.num_output_group = 1;
mparam.base_score = 0.5;
auto updater = std::unique_ptr<xgboost::LinearUpdater>( auto updater = std::unique_ptr<xgboost::LinearUpdater>(
xgboost::LinearUpdater::Create("coord_descent", &lparam)); xgboost::LinearUpdater::Create("coord_descent", &lparam));

View File

@ -13,15 +13,11 @@ TEST(Linear, GPUCoordinate) {
size_t constexpr kCols = 10; size_t constexpr kCols = 10;
auto mat = xgboost::RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(); auto mat = xgboost::RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
auto lparam = CreateEmptyGenericParam(GPUIDX); auto ctx = CreateEmptyGenericParam(GPUIDX);
LearnerModelParam mparam;
mparam.num_feature = kCols;
mparam.num_output_group = 1;
mparam.base_score = 0.5;
LearnerModelParam mparam{MakeMP(kCols, .5, 1)};
auto updater = std::unique_ptr<xgboost::LinearUpdater>( auto updater = std::unique_ptr<xgboost::LinearUpdater>(
xgboost::LinearUpdater::Create("gpu_coord_descent", &lparam)); xgboost::LinearUpdater::Create("gpu_coord_descent", &ctx));
updater->Configure({{"eta", "1."}}); updater->Configure({{"eta", "1."}});
xgboost::HostDeviceVector<xgboost::GradientPair> gpair( xgboost::HostDeviceVector<xgboost::GradientPair> gpair(
mat->Info().num_row_, xgboost::GradientPair(-5, 1.0)); mat->Info().num_row_, xgboost::GradientPair(-5, 1.0));

View File

@ -21,14 +21,11 @@ TEST(CpuPredictor, Basic) {
size_t constexpr kRows = 5; size_t constexpr kRows = 5;
size_t constexpr kCols = 5; size_t constexpr kCols = 5;
LearnerModelParam param; LearnerModelParam mparam{MakeMP(kCols, .0, 1)};
param.num_feature = kCols;
param.base_score = 0.0;
param.num_output_group = 1;
GenericParameter ctx; GenericParameter ctx;
ctx.UpdateAllowUnknown(Args{}); ctx.UpdateAllowUnknown(Args{});
gbm::GBTreeModel model = CreateTestModel(&param, &ctx); gbm::GBTreeModel model = CreateTestModel(&mparam, &ctx);
auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(); auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
@ -104,14 +101,11 @@ TEST(CpuPredictor, ExternalMemory) {
std::unique_ptr<Predictor> cpu_predictor = std::unique_ptr<Predictor> cpu_predictor =
std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor", &lparam)); std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor", &lparam));
LearnerModelParam param; LearnerModelParam mparam{MakeMP(dmat->Info().num_col_, .0, 1)};
param.base_score = 0;
param.num_feature = dmat->Info().num_col_;
param.num_output_group = 1;
GenericParameter ctx; GenericParameter ctx;
ctx.UpdateAllowUnknown(Args{}); ctx.UpdateAllowUnknown(Args{});
gbm::GBTreeModel model = CreateTestModel(&param, &ctx); gbm::GBTreeModel model = CreateTestModel(&mparam, &ctx);
// Test predict batch // Test predict batch
PredictionCacheEntry out_predictions; PredictionCacheEntry out_predictions;
@ -201,16 +195,11 @@ TEST(CpuPredictor, InplacePredict) {
void TestUpdatePredictionCache(bool use_subsampling) { void TestUpdatePredictionCache(bool use_subsampling) {
size_t constexpr kRows = 64, kCols = 16, kClasses = 4; size_t constexpr kRows = 64, kCols = 16, kClasses = 4;
LearnerModelParam mparam; LearnerModelParam mparam{MakeMP(kCols, .0, kClasses)};
mparam.num_feature = kCols; Context ctx;
mparam.num_output_group = kClasses;
mparam.base_score = 0;
GenericParameter gparam;
gparam.Init(Args{});
std::unique_ptr<gbm::GBTree> gbm; std::unique_ptr<gbm::GBTree> gbm;
gbm.reset(static_cast<gbm::GBTree*>(GradientBooster::Create("gbtree", &gparam, &mparam))); gbm.reset(static_cast<gbm::GBTree*>(GradientBooster::Create("gbtree", &ctx, &mparam)));
std::map<std::string, std::string> cfg; std::map<std::string, std::string> cfg;
cfg["tree_method"] = "hist"; cfg["tree_method"] = "hist";
cfg["predictor"] = "cpu_predictor"; cfg["predictor"] = "cpu_predictor";

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2017-2020 XGBoost contributors * Copyright 2017-2022 XGBoost contributors
*/ */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <xgboost/c_api.h> #include <xgboost/c_api.h>
@ -34,14 +34,10 @@ TEST(GPUPredictor, Basic) {
int n_row = i, n_col = i; int n_row = i, n_col = i;
auto dmat = RandomDataGenerator(n_row, n_col, 0).GenerateDMatrix(); auto dmat = RandomDataGenerator(n_row, n_col, 0).GenerateDMatrix();
LearnerModelParam param; Context ctx;
param.num_feature = n_col; ctx.gpu_id = 0;
param.num_output_group = 1; LearnerModelParam mparam{MakeMP(n_col, .5, 1, ctx.gpu_id)};
param.base_score = 0.5; gbm::GBTreeModel model = CreateTestModel(&mparam, &ctx);
GenericParameter ctx;
ctx.UpdateAllowUnknown(Args{});
gbm::GBTreeModel model = CreateTestModel(&param, &ctx);
// Test predict batch // Test predict batch
PredictionCacheEntry gpu_out_predictions; PredictionCacheEntry gpu_out_predictions;
@ -93,15 +89,12 @@ TEST(GPUPredictor, ExternalMemoryTest) {
std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor", &lparam)); std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor", &lparam));
gpu_predictor->Configure({}); gpu_predictor->Configure({});
LearnerModelParam param;
param.num_feature = 5;
const int n_classes = 3; const int n_classes = 3;
param.num_output_group = n_classes; Context ctx;
param.base_score = 0.5; ctx.gpu_id = 0;
LearnerModelParam mparam{MakeMP(5, .5, n_classes, ctx.gpu_id)};
GenericParameter ctx; gbm::GBTreeModel model = CreateTestModel(&mparam, &ctx, n_classes);
ctx.UpdateAllowUnknown(Args{});
gbm::GBTreeModel model = CreateTestModel(&param, &ctx, n_classes);
std::vector<std::unique_ptr<DMatrix>> dmats; std::vector<std::unique_ptr<DMatrix>> dmats;
dmats.push_back(CreateSparsePageDMatrix(400)); dmats.push_back(CreateSparsePageDMatrix(400));
@ -171,15 +164,10 @@ TEST(GpuPredictor, LesserFeatures) {
TEST(GPUPredictor, ShapStump) { TEST(GPUPredictor, ShapStump) {
cudaSetDevice(0); cudaSetDevice(0);
LearnerModelParam param; Context ctx;
param.num_feature = 1; ctx.gpu_id = 0;
param.num_output_group = 1; LearnerModelParam mparam{MakeMP(1, .5, 1, ctx.gpu_id)};
param.base_score = 0.5; gbm::GBTreeModel model(&mparam, &ctx);
GenericParameter ctx;
ctx.UpdateAllowUnknown(Args{});
gbm::GBTreeModel model(&param, &ctx);
std::vector<std::unique_ptr<RegTree>> trees; std::vector<std::unique_ptr<RegTree>> trees;
trees.push_back(std::unique_ptr<RegTree>(new RegTree)); trees.push_back(std::unique_ptr<RegTree>(new RegTree));
@ -193,24 +181,20 @@ TEST(GPUPredictor, ShapStump) {
auto dmat = RandomDataGenerator(3, 1, 0).GenerateDMatrix(); auto dmat = RandomDataGenerator(3, 1, 0).GenerateDMatrix();
gpu_predictor->PredictContribution(dmat.get(), &predictions, model); gpu_predictor->PredictContribution(dmat.get(), &predictions, model);
auto& phis = predictions.HostVector(); auto& phis = predictions.HostVector();
auto base_score = mparam.BaseScore(Context::kCpuId)(0);
EXPECT_EQ(phis[0], 0.0); EXPECT_EQ(phis[0], 0.0);
EXPECT_EQ(phis[1], param.base_score); EXPECT_EQ(phis[1], base_score);
EXPECT_EQ(phis[2], 0.0); EXPECT_EQ(phis[2], 0.0);
EXPECT_EQ(phis[3], param.base_score); EXPECT_EQ(phis[3], base_score);
EXPECT_EQ(phis[4], 0.0); EXPECT_EQ(phis[4], 0.0);
EXPECT_EQ(phis[5], param.base_score); EXPECT_EQ(phis[5], base_score);
} }
TEST(GPUPredictor, Shap) { TEST(GPUPredictor, Shap) {
LearnerModelParam param; Context ctx;
param.num_feature = 1; ctx.gpu_id = 0;
param.num_output_group = 1; LearnerModelParam mparam{MakeMP(1, .5, 1, ctx.gpu_id)};
param.base_score = 0.5; gbm::GBTreeModel model(&mparam, &ctx);
GenericParameter ctx;
ctx.UpdateAllowUnknown(Args{});
gbm::GBTreeModel model(&param, &ctx);
std::vector<std::unique_ptr<RegTree>> trees; std::vector<std::unique_ptr<RegTree>> trees;
trees.push_back(std::unique_ptr<RegTree>(new RegTree)); trees.push_back(std::unique_ptr<RegTree>(new RegTree));
@ -258,14 +242,9 @@ TEST(GPUPredictor, PredictLeafBasic) {
std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor", &lparam)); std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor", &lparam));
gpu_predictor->Configure({}); gpu_predictor->Configure({});
LearnerModelParam param; LearnerModelParam mparam{MakeMP(kCols, .0, 1)};
param.num_feature = kCols; Context ctx;
param.base_score = 0.0; gbm::GBTreeModel model = CreateTestModel(&mparam, &ctx);
param.num_output_group = 1;
GenericParameter ctx;
ctx.UpdateAllowUnknown(Args{});
gbm::GBTreeModel model = CreateTestModel(&param, &ctx);
HostDeviceVector<float> leaf_out_predictions; HostDeviceVector<float> leaf_out_predictions;
gpu_predictor->PredictLeaf(dmat.get(), &leaf_out_predictions, model); gpu_predictor->PredictLeaf(dmat.get(), &leaf_out_predictions, model);

View File

@ -210,11 +210,7 @@ void TestCategoricalPrediction(std::string name) {
size_t constexpr kCols = 10; size_t constexpr kCols = 10;
PredictionCacheEntry out_predictions; PredictionCacheEntry out_predictions;
LearnerModelParam param; LearnerModelParam mparam{MakeMP(kCols, .5, 1)};
param.num_feature = kCols;
param.num_output_group = 1;
param.base_score = 0.5;
uint32_t split_ind = 3; uint32_t split_ind = 3;
bst_cat_t split_cat = 4; bst_cat_t split_cat = 4;
float left_weight = 1.3f; float left_weight = 1.3f;
@ -222,7 +218,7 @@ void TestCategoricalPrediction(std::string name) {
GenericParameter ctx; GenericParameter ctx;
ctx.UpdateAllowUnknown(Args{}); ctx.UpdateAllowUnknown(Args{});
gbm::GBTreeModel model(&param, &ctx); gbm::GBTreeModel model(&mparam, &ctx);
GBTreeModelForTest(&model, split_ind, split_cat, left_weight, right_weight); GBTreeModelForTest(&model, split_ind, split_cat, left_weight, right_weight);
ctx.UpdateAllowUnknown(Args{{"gpu_id", "0"}}); ctx.UpdateAllowUnknown(Args{{"gpu_id", "0"}});
@ -237,27 +233,24 @@ void TestCategoricalPrediction(std::string name) {
predictor->InitOutPredictions(m->Info(), &out_predictions.predictions, model); predictor->InitOutPredictions(m->Info(), &out_predictions.predictions, model);
predictor->PredictBatch(m.get(), &out_predictions, model, 0); predictor->PredictBatch(m.get(), &out_predictions, model, 0);
auto score = mparam.BaseScore(Context::kCpuId)(0);
ASSERT_EQ(out_predictions.predictions.Size(), 1ul); ASSERT_EQ(out_predictions.predictions.Size(), 1ul);
ASSERT_EQ(out_predictions.predictions.HostVector()[0], ASSERT_EQ(out_predictions.predictions.HostVector()[0],
right_weight + param.base_score); // go to right for matching cat right_weight + score); // go to right for matching cat
row[split_ind] = split_cat + 1; row[split_ind] = split_cat + 1;
m = GetDMatrixFromData(row, 1, kCols); m = GetDMatrixFromData(row, 1, kCols);
out_predictions.version = 0; out_predictions.version = 0;
predictor->InitOutPredictions(m->Info(), &out_predictions.predictions, model); predictor->InitOutPredictions(m->Info(), &out_predictions.predictions, model);
predictor->PredictBatch(m.get(), &out_predictions, model, 0); predictor->PredictBatch(m.get(), &out_predictions, model, 0);
ASSERT_EQ(out_predictions.predictions.HostVector()[0], ASSERT_EQ(out_predictions.predictions.HostVector()[0], left_weight + score);
left_weight + param.base_score);
} }
void TestCategoricalPredictLeaf(StringView name) { void TestCategoricalPredictLeaf(StringView name) {
size_t constexpr kCols = 10; size_t constexpr kCols = 10;
PredictionCacheEntry out_predictions; PredictionCacheEntry out_predictions;
LearnerModelParam param; LearnerModelParam mparam{MakeMP(kCols, .5, 1)};
param.num_feature = kCols;
param.num_output_group = 1;
param.base_score = 0.5;
uint32_t split_ind = 3; uint32_t split_ind = 3;
bst_cat_t split_cat = 4; bst_cat_t split_cat = 4;
@ -267,7 +260,7 @@ void TestCategoricalPredictLeaf(StringView name) {
GenericParameter ctx; GenericParameter ctx;
ctx.UpdateAllowUnknown(Args{}); ctx.UpdateAllowUnknown(Args{});
gbm::GBTreeModel model(&param, &ctx); gbm::GBTreeModel model(&mparam, &ctx);
GBTreeModelForTest(&model, split_ind, split_cat, left_weight, right_weight); GBTreeModelForTest(&model, split_ind, split_cat, left_weight, right_weight);
ctx.gpu_id = 0; ctx.gpu_id = 0;

View File

@ -12,11 +12,7 @@ void TestPredictionFromGradientIndex(std::string name, size_t rows, size_t cols,
std::shared_ptr<DMatrix> p_hist) { std::shared_ptr<DMatrix> p_hist) {
constexpr size_t kClasses { 3 }; constexpr size_t kClasses { 3 };
LearnerModelParam param; LearnerModelParam mparam{MakeMP(cols, .5, kClasses)};
param.num_feature = cols;
param.num_output_group = kClasses;
param.base_score = 0.5;
auto lparam = CreateEmptyGenericParam(0); auto lparam = CreateEmptyGenericParam(0);
std::unique_ptr<Predictor> predictor = std::unique_ptr<Predictor> predictor =
@ -25,7 +21,7 @@ void TestPredictionFromGradientIndex(std::string name, size_t rows, size_t cols,
GenericParameter ctx; GenericParameter ctx;
ctx.UpdateAllowUnknown(Args{}); ctx.UpdateAllowUnknown(Args{});
gbm::GBTreeModel model = CreateTestModel(&param, &ctx, kClasses); gbm::GBTreeModel model = CreateTestModel(&mparam, &ctx, kClasses);
{ {
auto p_precise = RandomDataGenerator(rows, cols, 0).GenerateDMatrix(); auto p_precise = RandomDataGenerator(rows, cols, 0).GenerateDMatrix();

View File

@ -3,8 +3,10 @@
*/ */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <xgboost/learner.h> #include <xgboost/learner.h>
#include <xgboost/objective.h> // ObjFunction
#include <xgboost/version_config.h> #include <xgboost/version_config.h>
#include <string> // std::stof, std::string
#include <thread> #include <thread>
#include <vector> #include <vector>
@ -206,8 +208,7 @@ TEST(Learner, MultiThreadedPredict) {
p_dmat->Info().labels.Reshape(kRows); p_dmat->Info().labels.Reshape(kRows);
CHECK_NE(p_dmat->Info().num_col_, 0); CHECK_NE(p_dmat->Info().num_col_, 0);
std::shared_ptr<DMatrix> p_data{ std::shared_ptr<DMatrix> p_data{RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix()};
RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix()};
CHECK_NE(p_data->Info().num_col_, 0); CHECK_NE(p_data->Info().num_col_, 0);
std::shared_ptr<Learner> learner{Learner::Create({p_dmat})}; std::shared_ptr<Learner> learner{Learner::Create({p_dmat})};
@ -448,4 +449,77 @@ TEST(Learner, MultiTarget) {
EXPECT_THROW({ learner->Configure(); }, dmlc::Error); EXPECT_THROW({ learner->Configure(); }, dmlc::Error);
} }
} }
/**
 * Exercise the order in which the learner initialises its base score for
 * `reg:absoluteerror`:
 *   1. Configure() alone leaves the default base score in place.
 *   2. One boosting iteration replaces the default with a different value, and
 *      afterwards setting `base_score_estimated` by hand is rejected.
 *   3. A user supplied `base_score` is used as is and survives training.
 */
TEST(Learner, InitEstimation) {
  size_t constexpr kCols = 10;
  auto Xy = RandomDataGenerator{10, kCols, 0}.GenerateDMatrix(true);

  // Fresh learner bound to Xy with the absolute-error objective selected.
  auto make_learner = [&Xy]() {
    std::unique_ptr<Learner> made{Learner::Create({Xy})};
    made->SetParam("objective", "reg:absoluteerror");
    return made;
  };
  // Predict on the training matrix using the same arguments in every scenario.
  auto predict_on_train = [&Xy](Learner* p_learner, HostDeviceVector<float>* p_out) {
    p_learner->Predict(Xy, false, p_out, 0, 0);
  };
  // Read `base_score` back from the saved JSON configuration.
  auto saved_base_score = [](Learner* p_learner) {
    Json saved{Object{}};
    p_learner->SaveConfig(&saved);
    return std::stof(get<String const>(saved["learner"]["learner_model_param"]["base_score"]));
  };

  {
    // Scenario 1: configured but never trained.
    auto learner = make_learner();
    learner->Configure();

    HostDeviceVector<float> predt;
    predict_on_train(learner.get(), &predt);
    for (auto value : predt.ConstHostSpan()) {
      ASSERT_EQ(value, ObjFunction::DefaultBaseScore());
    }

    // No base score is estimated yet.
    ASSERT_EQ(saved_base_score(learner.get()), ObjFunction::DefaultBaseScore());
  }

  {
    // Scenario 2: a single boosting round without prior configuration.
    auto learner = make_learner();
    learner->UpdateOneIter(0, Xy);

    HostDeviceVector<float> predt;
    predict_on_train(learner.get(), &predt);
    for (auto value : predt.ConstHostSpan()) {
      ASSERT_NE(value, ObjFunction::DefaultBaseScore());
    }

    ASSERT_NE(saved_base_score(learner.get()), ObjFunction::DefaultBaseScore());

    // The estimation flag must not be settable from the outside.
    ASSERT_THROW(
        {
          learner->SetParam("base_score_estimated", "1");
          learner->Configure();
        },
        dmlc::Error);
  }

  {
    // Scenario 3: the user pins the base score explicitly.
    auto learner = make_learner();
    learner->SetParam("base_score", "1.3");
    learner->Configure();

    HostDeviceVector<float> predt;
    predict_on_train(learner.get(), &predt);
    for (auto value : predt.ConstHostSpan()) {
      ASSERT_FLOAT_EQ(value, 1.3);
    }

    learner->UpdateOneIter(0, Xy);
    // no change
    ASSERT_FLOAT_EQ(saved_base_score(learner.get()), 1.3);
  }
}
} // namespace xgboost } // namespace xgboost

View File

@ -418,6 +418,45 @@ TEST_F(SerializationTest, GPUCoordDescent) {
} }
#endif // defined(XGBOOST_USE_CUDA) #endif // defined(XGBOOST_USE_CUDA)
class L1SerializationTest : public SerializationTest {};
TEST_F(L1SerializationTest, Exact) {
TestLearnerSerialization({{"booster", "gbtree"},
{"objective", "reg:absoluteerror"},
{"seed", "0"},
{"max_depth", "2"},
{"tree_method", "exact"}},
fmap_, p_dmat_);
}
TEST_F(L1SerializationTest, Approx) {
TestLearnerSerialization({{"booster", "gbtree"},
{"objective", "reg:absoluteerror"},
{"seed", "0"},
{"max_depth", "2"},
{"tree_method", "approx"}},
fmap_, p_dmat_);
}
TEST_F(L1SerializationTest, Hist) {
TestLearnerSerialization({{"booster", "gbtree"},
{"objective", "reg:absoluteerror"},
{"seed", "0"},
{"max_depth", "2"},
{"tree_method", "hist"}},
fmap_, p_dmat_);
}
#if defined(XGBOOST_USE_CUDA)
TEST_F(L1SerializationTest, GpuHist) {
TestLearnerSerialization({{"booster", "gbtree"},
{"objective", "reg:absoluteerror"},
{"seed", "0"},
{"max_depth", "2"},
{"tree_method", "gpu_hist"}},
fmap_, p_dmat_);
}
#endif // defined(XGBOOST_USE_CUDA)
class LogitSerializationTest : public SerializationTest { class LogitSerializationTest : public SerializationTest {
protected: protected:

View File

@ -208,3 +208,8 @@ class TestGPUUpdaters:
param = dataset.set_params(param) param = dataset.set_params(param)
result = train_result(param, dataset.get_dmat(), 10) result = train_result(param, dataset.get_dmat(), 10)
assert tm.non_increasing(result['train'][dataset.metric]) assert tm.non_increasing(result['train'][dataset.metric])
@pytest.mark.skipif(**tm.no_sklearn())
@pytest.mark.parametrize("weighted", [True, False])
def test_adaptive(self, weighted) -> None:
    """Run the ``run_adaptive`` check held by ``self.cputest`` with ``gpu_hist``."""
    tree_method = "gpu_hist"
    self.cputest.run_adaptive(tree_method, weighted)

View File

@ -102,34 +102,38 @@ def run_scikit_model_check(name, path):
@pytest.mark.skipif(**tm.no_sklearn()) @pytest.mark.skipif(**tm.no_sklearn())
def test_model_compatibility(): def test_model_compatibility():
'''Test model compatibility, can only be run on CI as others don't """Test model compatibility, can only be run on CI as others don't
have the credentials. have the credentials.
''' """
path = os.path.dirname(os.path.abspath(__file__)) path = os.path.dirname(os.path.abspath(__file__))
path = os.path.join(path, 'models') path = os.path.join(path, "models")
zip_path, _ = urllib.request.urlretrieve('https://xgboost-ci-jenkins-artifacts.s3-us-west-2' + if not os.path.exists(path):
'.amazonaws.com/xgboost_model_compatibility_test.zip') zip_path, _ = urllib.request.urlretrieve(
with zipfile.ZipFile(zip_path, 'r') as z: "https://xgboost-ci-jenkins-artifacts.s3-us-west-2"
z.extractall(path) + ".amazonaws.com/xgboost_model_compatibility_test.zip"
)
with zipfile.ZipFile(zip_path, "r") as z:
z.extractall(path)
models = [ models = [
os.path.join(root, f) for root, subdir, files in os.walk(path) os.path.join(root, f)
for root, subdir, files in os.walk(path)
for f in files for f in files
if f != 'version' if f != "version"
] ]
assert models assert models
for path in models: for path in models:
name = os.path.basename(path) name = os.path.basename(path)
if name.startswith('xgboost-'): if name.startswith("xgboost-"):
booster = xgboost.Booster(model_file=path) booster = xgboost.Booster(model_file=path)
run_booster_check(booster, name) run_booster_check(booster, name)
# Do full serialization. # Do full serialization.
booster = copy.copy(booster) booster = copy.copy(booster)
run_booster_check(booster, name) run_booster_check(booster, name)
elif name.startswith('xgboost_scikit'): elif name.startswith("xgboost_scikit"):
run_scikit_model_check(name, path) run_scikit_model_check(name, path)
else: else:
assert False assert False

View File

@ -1,4 +1,4 @@
from random import choice import json
from string import ascii_lowercase from string import ascii_lowercase
from typing import Dict, Any from typing import Dict, Any
import testing as tm import testing as tm
@ -397,3 +397,72 @@ class TestTreeMethod:
def test_categorical_missing(self, rows, cols, cats): def test_categorical_missing(self, rows, cols, cats):
self.run_categorical_missing(rows, cols, cats, "approx") self.run_categorical_missing(rows, cols, cats, "approx")
self.run_categorical_missing(rows, cols, cats, "hist") self.run_categorical_missing(rows, cols, cats, "hist")
def run_adaptive(self, tree_method, weighted) -> None:
    """Check the base score handling of ``reg:absoluteerror``.

    The base score found by training without an explicit ``base_score`` must
    equal the (optionally weighted) median of the labels, must survive a
    round trip through both raw model formats, and an explicitly supplied
    ``base_score`` must be kept as given.

    Parameters
    ----------
    tree_method :
        Value passed as the ``tree_method`` training parameter.
    weighted :
        Whether to attach random non-negative sample weights to the data.
    """
    from sklearn.datasets import make_regression
    from sklearn.utils import stats

    rng = np.random.RandomState(1994)
    n_samples = 256
    X, y = make_regression(n_samples, 16, random_state=rng)

    if not weighted:
        Xy = xgb.DMatrix(X, y)
        base_score = np.median(y)
    else:
        # Shift the normal draws so that the smallest weight is exactly zero.
        w = rng.normal(size=n_samples)
        w -= w.min()
        Xy = xgb.DMatrix(X, y, weight=w)
        base_score = stats._weighted_percentile(y, w, percentile=50)

    def fit_one_round(**extra):
        # Key order mirrors the literal dicts: tree_method, extras, objective.
        params = {
            "tree_method": tree_method,
            **extra,
            "objective": "reg:absoluteerror",
        }
        return xgb.train(params, Xy, num_boost_round=1)

    def get_score(config: Dict) -> float:
        return float(config["learner"]["learner_model_param"]["base_score"])

    def load_config(booster) -> Dict:
        return json.loads(booster.save_config())

    # Explicit median versus the value estimated during training.
    booster_0 = fit_one_round(base_score=base_score)
    booster_1 = fit_one_round()
    config_0 = load_config(booster_0)
    config_1 = load_config(booster_1)
    assert get_score(config_0) == get_score(config_1)

    # The estimated value has to be preserved by both serialization formats.
    for raw_format in ("deprecated", "ubj"):
        raw_booster = booster_1.save_raw(raw_format=raw_format)
        booster_2 = xgb.Booster(model_file=raw_booster)
        config_2 = load_config(booster_2)
        assert get_score(config_1) == get_score(config_2)

    # A user supplied base score must not be overwritten by the estimate.
    booster_0 = fit_one_round(base_score=base_score + 1.0)
    config_0 = load_config(booster_0)
    np.testing.assert_allclose(get_score(config_0), get_score(config_1) + 1)
@pytest.mark.skipif(**tm.no_sklearn())
@pytest.mark.parametrize(
    "tree_method,weighted",
    [
        ("approx", False),
        ("hist", False),
        ("approx", True),
        ("hist", True),
    ],
)
def test_adaptive(self, tree_method, weighted) -> None:
    """Run ``run_adaptive`` for approx and hist, with and without weights."""
    self.run_adaptive(tree_method=tree_method, weighted=weighted)

View File

@ -1537,13 +1537,56 @@ class TestWithDask:
@pytest.mark.skipif(**tm.no_dask()) @pytest.mark.skipif(**tm.no_dask())
@pytest.mark.gtest @pytest.mark.gtest
def test_quantile_same_on_all_workers(self) -> None: def test_quantile_same_on_all_workers(self) -> None:
self.run_quantile('SameOnAllWorkers') self.run_quantile("SameOnAllWorkers")
def test_adaptive(self) -> None:
    """Train ``reg:absoluteerror`` on two workers holding different labels
    and check the base score stored in every worker's booster config."""

    def get_score(config: Dict) -> float:
        return float(config["learner"]["learner_model_param"]["base_score"])

    def local_test(rabit_args: List[bytes], worker_id: int) -> bool:
        with xgb.dask.RabitContext(rabit_args):
            # Worker 0 holds three rows labelled 0; any other worker holds a
            # single row labelled 1000.  The only feature is constant zero.
            n_rows, label = (3, 0.0) if worker_id == 0 else (1, 1000.0)
            y = np.array([label] * n_rows)
            x = np.array([[0.0]] * n_rows)

            booster = xgb.train(
                {"tree_method": "hist", "objective": "reg:absoluteerror"},
                xgb.DMatrix(x, y),
                num_boost_round=1,
            )
            base_score = get_score(json.loads(booster.save_config()))
            # NOTE(review): 250 == (3 * 0 + 1 * 1000) / 4, presumably the
            # row-weighted mean of the per-worker medians -- confirm against
            # the objective implementation.
            assert base_score == 250.0
            return True

    with LocalCluster(n_workers=2, dashboard_address=":0") as cluster:
        with Client(cluster) as client:
            workers = _get_client_workers(client)
            rabit_args = client.sync(
                xgb.dask._get_rabit_args, len(workers), None, client
            )
            # One task per worker; the enumeration index doubles as worker id.
            futures = [
                client.submit(local_test, rabit_args, i)
                for i, _ in enumerate(workers)
            ]
            results = client.gather(futures)
            assert all(results)
def test_n_workers(self) -> None: def test_n_workers(self) -> None:
with LocalCluster(n_workers=2, dashboard_address=":0") as cluster: with LocalCluster(n_workers=2, dashboard_address=":0") as cluster:
with Client(cluster) as client: with Client(cluster) as client:
workers = _get_client_workers(client) workers = _get_client_workers(client)
from sklearn.datasets import load_breast_cancer from sklearn.datasets import load_breast_cancer
X, y = load_breast_cancer(return_X_y=True) X, y = load_breast_cancer(return_X_y=True)
dX = client.submit(da.from_array, X, workers=[workers[0]]).result() dX = client.submit(da.from_array, X, workers=[workers[0]]).result()
dy = client.submit(da.from_array, y, workers=[workers[0]]).result() dy = client.submit(da.from_array, y, workers=[workers[0]]).result()