Implement fit stump. (#8607)

This commit is contained in:
Jiaming Yuan
2023-01-04 04:14:51 +08:00
committed by GitHub
parent 20e6087579
commit 8d545ab2a2
23 changed files with 421 additions and 60 deletions

View File

@@ -134,6 +134,8 @@ using bst_row_t = std::size_t; // NOLINT
using bst_node_t = int32_t; // NOLINT
/*! \brief Type for ranking group index. */
using bst_group_t = uint32_t; // NOLINT
/*! \brief Type for indexing target variables. */
using bst_target_t = std::size_t; // NOLINT
namespace detail {
/*! \brief Implementation of gradient statistics pair. Template specialisation

View File

@@ -15,6 +15,7 @@
#include <algorithm>
#include <cassert>
#include <cinttypes> // std::int32_t
#include <limits>
#include <string>
#include <tuple>
@@ -388,9 +389,9 @@ class TensorView {
* \brief Create a tensor with data, shape and strides. Don't use this constructor if
* stride can be calculated from shape.
*/
template <typename I, int32_t D>
template <typename I, std::int32_t D>
LINALG_HD TensorView(common::Span<T> data, I const (&shape)[D], I const (&stride)[D],
int32_t device)
std::int32_t device)
: data_{data}, ptr_{data_.data()}, device_{device} {
static_assert(D == kDim, "Invalid shape & stride.");
detail::UnrollLoop<D>([&](auto i) {
@@ -833,6 +834,27 @@ class Tensor {
int32_t DeviceIdx() const { return data_.DeviceIdx(); }
};
template <typename T>
using Vector = Tensor<T, 1>;
template <typename T, typename... Index>
auto Constant(Context const *ctx, T v, Index &&...index) {
Tensor<T, sizeof...(Index)> t;
t.SetDevice(ctx->gpu_id);
t.Reshape(index...);
t.Data()->Fill(std::move(v));
return t;
}
/**
* \brief Like `np.zeros`, return a new array of given shape and type, filled with zeros.
*/
template <typename T, typename... Index>
auto Zeros(Context const *ctx, Index &&...index) {
return Constant(ctx, static_cast<T>(0), index...);
}
// Only first axis is supported for now.
template <typename T, int32_t D>
void Stack(Tensor<T, D> *l, Tensor<T, D> const &r) {

View File

@@ -93,7 +93,7 @@ class ObjFunction : public Configurable {
* \brief Return number of targets for input matrix. Right now XGBoost supports only
* multi-target regression.
*/
virtual uint32_t Targets(MetaInfo const& info) const {
virtual bst_target_t Targets(MetaInfo const& info) const {
if (info.labels.Shape(1) > 1) {
LOG(FATAL) << "multioutput is not supported by current objective function";
}