Multi-target support for L1 error. (#8652)

- Add matrix support to the median function.
- Iterate through each target for quantile computation.

parent badeff1d74
commit cfa994d57f
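The heart of the change is generalizing the median from a single label column to one median per target column. As a rough standalone sketch of that idea — unweighted, plain STL, independent of XGBoost's linalg::Tensor types and its quantile machinery — column-wise medians of a row-major matrix can be computed like this:

    #include <algorithm>  // std::max_element, std::nth_element
    #include <cstddef>    // std::size_t
    #include <vector>

    // Median of every column of a row-major (n_rows x n_cols) matrix.
    // Assumes n_rows > 0; the actual patch additionally supports sample
    // weights via common::WeightedQuantile.
    std::vector<float> ColumnMedians(std::vector<float> const& data,
                                     std::size_t n_rows, std::size_t n_cols) {
      std::vector<float> medians(n_cols);
      for (std::size_t c = 0; c < n_cols; ++c) {
        std::vector<float> column(n_rows);
        for (std::size_t r = 0; r < n_rows; ++r) {
          column[r] = data[r * n_cols + c];  // gather one target column
        }
        auto mid = column.begin() + n_rows / 2;
        std::nth_element(column.begin(), mid, column.end());
        float q = *mid;
        if (n_rows % 2 == 0) {  // even count: average the two middle values
          q = (q + *std::max_element(column.begin(), mid)) / 2.0f;
        }
        medians[c] = q;
      }
      return medians;
    }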
@@ -1,5 +1,5 @@
-/*!
- * Copyright 2014-2022 by Contributors
+/**
+ * Copyright 2014-2023 by XGBoost Contributors
  * \file objective.h
  * \brief interface of objective function used by xgboost.
  * \author Tianqi Chen, Kailong Chen
@@ -14,6 +14,7 @@
 #include <xgboost/model.h>
 #include <xgboost/task.h>
 
+#include <cstdint>  // std::int32_t
 #include <functional>
 #include <string>
 #include <utility>
@@ -111,12 +112,13 @@ class ObjFunction : public Configurable {
   * \param position The leaf index for each rows.
   * \param info MetaInfo providing labels and weights.
   * \param prediction Model prediction after transformation.
+  * \param group_idx The group index for this tree, 0 when it's not multi-target or multi-class.
   * \param p_tree Tree that needs to be updated.
   */
  virtual void UpdateTreeLeaf(HostDeviceVector<bst_node_t> const& /*position*/,
                              MetaInfo const& /*info*/,
                              HostDeviceVector<float> const& /*prediction*/,
-                             RegTree* /*p_tree*/) const {}
+                             std::int32_t /*group_idx*/, RegTree* /*p_tree*/) const {}
 
  /*!
   * \brief Create an objective function according to name.
@@ -317,13 +317,13 @@ class TestDataset:
             enable_categorical=True,
         )
 
-    def get_device_dmat(self) -> xgb.DeviceQuantileDMatrix:
+    def get_device_dmat(self) -> xgb.QuantileDMatrix:
         import cupy as cp
 
         w = None if self.w is None else cp.array(self.w)
         X = cp.array(self.X, dtype=np.float32)
         y = cp.array(self.y, dtype=np.float32)
-        return xgb.DeviceQuantileDMatrix(X, y, w, base_margin=self.margin)
+        return xgb.QuantileDMatrix(X, y, weight=w, base_margin=self.margin)
 
     def get_external_dmat(self) -> xgb.DMatrix:
         n_samples = self.X.shape[0]
@@ -726,10 +726,16 @@ _unweighted_datasets_strategy = strategies.sampled_from(
         TestDataset("cancer", get_cancer, "binary:logistic", "logloss"),
         TestDataset(
             "mtreg",
-            lambda: datasets.make_regression(n_samples=128, n_targets=3),
+            lambda: datasets.make_regression(n_samples=128, n_features=2, n_targets=3),
             "reg:squarederror",
             "rmse",
         ),
+        TestDataset(
+            "mtreg-l1",
+            lambda: datasets.make_regression(n_samples=128, n_features=2, n_targets=3),
+            "reg:absoluteerror",
+            "mae",
+        ),
         TestDataset("sparse", get_sparse, "reg:squarederror", "rmse"),
         TestDataset("sparse-l1", get_sparse, "reg:absoluteerror", "mae"),
         TestDataset(
@@ -753,7 +759,7 @@ def _dataset_weight_margin(draw: Callable) -> TestDataset:
     num_class = 1
     if data.objective == "multi:softmax":
         num_class = int(np.max(data.y) + 1)
-    elif data.name == "mtreg":
+    elif data.name.startswith("mtreg"):
         num_class = data.y.shape[1]
 
     data.margin = draw(
@@ -1,5 +1,5 @@
-/*!
- * Copyright 2022 XGBoost contributors
+/**
+ * Copyright 2022-2023 by XGBoost contributors
  */
 #pragma once
 #include <rabit/rabit.h>
@@ -1,11 +1,13 @@
-/*!
- * Copyright 2022 by XGBoost Contributors
+/**
+ * Copyright 2022-2023 by XGBoost Contributors
  */
 #include "stats.h"
 
+#include <cstddef>  // std::size_t
 #include <numeric>  // std::accumulate
 
 #include "common.h"  // OptionalWeights
+#include "linalg_op.h"
 #include "threading_utils.h"  // ParallelFor, MemStackAllocator
 #include "transform_iterator.h"  // MakeIndexTransformIter
 #include "xgboost/context.h"  // Context
@@ -15,32 +17,32 @@
 
 namespace xgboost {
 namespace common {
-float Median(Context const* ctx, linalg::Tensor<float, 2> const& t,
-             HostDeviceVector<float> const& weights) {
-  CHECK_LE(t.Shape(1), 1) << "Matrix is not yet supported.";
+void Median(Context const* ctx, linalg::Tensor<float, 2> const& t,
+            HostDeviceVector<float> const& weights, linalg::Tensor<float, 1>* out) {
   if (!ctx->IsCPU()) {
     weights.SetDevice(ctx->gpu_id);
     auto opt_weights = OptionalWeights(weights.ConstDeviceSpan());
     auto t_v = t.View(ctx->gpu_id);
-    return cuda_impl::Median(ctx, t_v, opt_weights);
+    cuda_impl::Median(ctx, t_v, opt_weights, out);
   }
 
   auto opt_weights = OptionalWeights(weights.ConstHostSpan());
   auto t_v = t.HostView();
-  auto iter = common::MakeIndexTransformIter(
-      [&](size_t i) { return linalg::detail::Apply(t_v, linalg::UnravelIndex(i, t_v.Shape())); });
-  float q{0};
-  if (opt_weights.Empty()) {
-    q = common::Quantile(0.5, iter, iter + t_v.Size());
-  } else {
-    CHECK_NE(t_v.Shape(1), 0);
-    auto w_it = common::MakeIndexTransformIter([&](size_t i) {
-      auto sample_idx = i / t_v.Shape(1);
-      return opt_weights[sample_idx];
-    });
-    q = common::WeightedQuantile(0.5, iter, iter + t_v.Size(), w_it);
+  out->Reshape(t.Shape(1));
+  auto h_out = out->HostView();
+  for (std::size_t i{0}; i < t.Shape(1); ++i) {
+    auto ti_v = t_v.Slice(linalg::All(), i);
+    auto iter = linalg::cbegin(ti_v);
+    float q{0};
+    if (opt_weights.Empty()) {
+      q = common::Quantile(0.5, iter, iter + ti_v.Size());
+    } else {
+      CHECK_NE(t_v.Shape(1), 0);
+      auto w_it = common::MakeIndexTransformIter([&](std::size_t i) { return opt_weights[i]; });
+      q = common::WeightedQuantile(0.5, iter, iter + ti_v.Size(), w_it);
+    }
+    h_out(i) = q;
   }
-  return q;
 }
 
 void Mean(Context const* ctx, linalg::Vector<float> const& v, linalg::Vector<float>* out) {
@@ -1,46 +1,52 @@
-/*!
- * Copyright 2022 by XGBoost Contributors
+/**
+ * Copyright 2022-2023 by XGBoost Contributors
  */
 
 #include <thrust/iterator/counting_iterator.h>  // thrust::make_counting_iterator
 
-#include "common.h"            // common::OptionalWeights
-#include "device_helpers.cuh"  // dh::MakeTransformIterator, tcbegin, tcend
-#include "stats.cuh"  // common::SegmentedQuantile, common::SegmentedWeightedQuantile
-#include "xgboost/context.h"   // Context
+#include <cstddef>  // size_t
+
+#include "common.h"            // common::OptionalWeights
+#include "cuda_context.cuh"    // CUDAContext
+#include "device_helpers.cuh"  // dh::MakeTransformIterator, tcbegin, tcend
+#include "stats.cuh"  // common::SegmentedQuantile, common::SegmentedWeightedQuantile
+#include "xgboost/base.h"      // XGBOOST_DEVICE
+#include "xgboost/context.h"   // Context
 #include "xgboost/host_device_vector.h"  // HostDeviceVector
 #include "xgboost/linalg.h"  // linalg::TensorView, UnravelIndex, Apply
 
 namespace xgboost {
 namespace common {
 namespace cuda_impl {
-float Median(Context const* ctx, linalg::TensorView<float const, 2> t,
-             common::OptionalWeights weights) {
-  HostDeviceVector<size_t> segments{0, t.Size()};
+void Median(Context const* ctx, linalg::TensorView<float const, 2> t,
+            common::OptionalWeights weights, linalg::Tensor<float, 1>* out) {
+  CHECK_GE(t.Shape(1), 1);
+  HostDeviceVector<std::size_t> segments(t.Shape(1) + 1, 0);
   segments.SetDevice(ctx->gpu_id);
-  auto d_segments = segments.ConstDeviceSpan();
+  auto d_segments = segments.DeviceSpan();
+  dh::LaunchN(d_segments.size(), ctx->CUDACtx()->Stream(),
+              [=] XGBOOST_DEVICE(std::size_t i) { d_segments[i] = t.Shape(0) * i; });
   auto val_it = dh::MakeTransformIterator<float>(
       thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) {
        return linalg::detail::Apply(t, linalg::UnravelIndex(i, t.Shape()));
      });
 
-  HostDeviceVector<float> quantile{0};
-  quantile.SetDevice(ctx->gpu_id);
+  out->SetDevice(ctx->gpu_id);
+  out->Reshape(t.Shape(1));
   if (weights.Empty()) {
     common::SegmentedQuantile(ctx, 0.5, dh::tcbegin(d_segments), dh::tcend(d_segments), val_it,
-                              val_it + t.Size(), &quantile);
+                              val_it + t.Size(), out->Data());
   } else {
     CHECK_NE(t.Shape(1), 0);
     auto w_it = dh::MakeTransformIterator<float>(thrust::make_counting_iterator(0ul),
-                                                 [=] XGBOOST_DEVICE(size_t i) {
+                                                 [=] XGBOOST_DEVICE(std::size_t i) {
                                                    auto sample_idx = i / t.Shape(1);
                                                    return weights[sample_idx];
                                                  });
     common::SegmentedWeightedQuantile(ctx, 0.5, dh::tcbegin(d_segments), dh::tcend(d_segments),
-                                      val_it, val_it + t.Size(), w_it, w_it + t.Size(), &quantile);
+                                      val_it, val_it + t.Size(), w_it, w_it + t.Size(),
+                                      out->Data());
   }
-  CHECK_EQ(quantile.Size(), 1);
-  return quantile.HostVector().front();
 }
 
 void Mean(Context const* ctx, linalg::VectorView<float const> v, linalg::VectorView<float> out) {
@@ -49,9 +55,10 @@ void Mean(Context const* ctx, linalg::VectorView<float const> v, linalg::VectorV
       thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t i) { return v(i) / n; });
   std::size_t bytes;
   CHECK_EQ(out.Size(), 1);
-  cub::DeviceReduce::Sum(nullptr, bytes, it, out.Values().data(), v.Size());
+  auto s = ctx->CUDACtx()->Stream();
+  cub::DeviceReduce::Sum(nullptr, bytes, it, out.Values().data(), v.Size(), s);
   dh::TemporaryArray<char> temp{bytes};
-  cub::DeviceReduce::Sum(temp.data().get(), bytes, it, out.Values().data(), v.Size());
+  cub::DeviceReduce::Sum(temp.data().get(), bytes, it, out.Values().data(), v.Size(), s);
 }
 }  // namespace cuda_impl
 }  // namespace common
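The GPU path above avoids a per-target loop entirely: a small kernel writes one segment boundary per target (`d_segments[i] = t.Shape(0) * i`), so SegmentedQuantile computes all column medians in a single pass. A host-side sketch of the same boundary computation, with a worked example:

    #include <cstddef>
    #include <vector>

    // One segment of n_rows values per target column. For example,
    // n_rows = 4 and n_targets = 2 gives {0, 4, 8}: the 0.5-quantile is
    // then taken over flattened values [0, 4) and [4, 8) independently.
    std::vector<std::size_t> MakeSegments(std::size_t n_rows, std::size_t n_targets) {
      std::vector<std::size_t> segments(n_targets + 1);
      for (std::size_t i = 0; i < segments.size(); ++i) {
        segments[i] = n_rows * i;
      }
      return segments;
    }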
@@ -1,5 +1,5 @@
-/*!
- * Copyright 2022 by XGBoost Contributors
+/**
+ * Copyright 2022-2023 by XGBoost Contributors
  */
 #ifndef XGBOOST_COMMON_STATS_H_
 #define XGBOOST_COMMON_STATS_H_
@@ -95,13 +95,15 @@ float WeightedQuantile(double alpha, Iter begin, Iter end, WeightIter weights) {
 }
 
 namespace cuda_impl {
-float Median(Context const* ctx, linalg::TensorView<float const, 2> t, OptionalWeights weights);
+void Median(Context const* ctx, linalg::TensorView<float const, 2> t, OptionalWeights weights,
+            linalg::Tensor<float, 1>* out);
 
 void Mean(Context const* ctx, linalg::VectorView<float const> v, linalg::VectorView<float> out);
 
 #if !defined(XGBOOST_USE_CUDA)
-inline float Median(Context const*, linalg::TensorView<float const, 2>, OptionalWeights) {
+inline void Median(Context const*, linalg::TensorView<float const, 2>, OptionalWeights,
+                   linalg::Tensor<float, 1>*) {
   common::AssertGPUSupport();
-  return 0;
 }
 inline void Mean(Context const*, linalg::VectorView<float const>, linalg::VectorView<float>) {
   common::AssertGPUSupport();
@@ -109,8 +111,11 @@ inline void Mean(Context const*, linalg::VectorView<float const>, linalg::Vector
 #endif  // !defined(XGBOOST_USE_CUDA)
 }  // namespace cuda_impl
 
-float Median(Context const* ctx, linalg::Tensor<float, 2> const& t,
-             HostDeviceVector<float> const& weights);
+/**
+ * \brief Calculate medians for each column of the input matrix.
+ */
+void Median(Context const* ctx, linalg::Tensor<float, 2> const& t,
+            HostDeviceVector<float> const& weights, linalg::Tensor<float, 1>* out);
 
 void Mean(Context const* ctx, linalg::Vector<float> const& v, linalg::Vector<float>* out);
 }  // namespace common
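With the new declaration, callers hand in an output vector and receive one median per column instead of a scalar. A hypothetical call site (the function and variable names here are illustrative, not from the patch; the internal header common/stats.h is assumed reachable from the include path):

    #include "common/stats.h"                // internal header, path build-dependent
    #include "xgboost/context.h"             // Context
    #include "xgboost/host_device_vector.h"  // HostDeviceVector
    #include "xgboost/linalg.h"              // linalg::Tensor

    void ExampleColumnMedians(xgboost::Context const* ctx,
                              xgboost::linalg::Tensor<float, 2> const& labels,
                              xgboost::HostDeviceVector<float> const& weights) {
      xgboost::linalg::Tensor<float, 1> medians;
      xgboost::common::Median(ctx, labels, weights, &medians);
      // medians now has shape (labels.Shape(1),): entry i is the
      // (optionally weighted) median of target column i.
    }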
@@ -1,5 +1,5 @@
-/*!
- * Copyright 2014-2022 by Contributors
+/**
+ * Copyright 2014-2023 by Contributors
  * \file gbtree.cc
  * \brief gradient boosted tree implementation.
  * \author Tianqi Chen
@@ -21,6 +21,7 @@
 #include "../common/threading_utils.h"
 #include "../common/timer.h"
 #include "gbtree_model.h"
+#include "xgboost/base.h"
 #include "xgboost/data.h"
 #include "xgboost/gbm.h"
 #include "xgboost/host_device_vector.h"
@@ -219,6 +220,8 @@ void CopyGradient(HostDeviceVector<GradientPair> const* in_gpair, int32_t n_thre
 
 void GBTree::UpdateTreeLeaf(DMatrix const* p_fmat, HostDeviceVector<float> const& predictions,
                             ObjFunction const* obj,
+                            std::int32_t group_idx,
+                            std::vector<HostDeviceVector<bst_node_t>> const& node_position,
                             std::vector<std::unique_ptr<RegTree>>* p_trees) {
   CHECK(!updaters_.empty());
   if (!updaters_.back()->HasNodePosition()) {
@@ -227,10 +230,14 @@ void GBTree::UpdateTreeLeaf(DMatrix const* p_fmat, HostDeviceVector<float> const
   if (!obj || !obj->Task().UpdateTreeLeaf()) {
     return;
   }
 
   auto& trees = *p_trees;
-  for (size_t tree_idx = 0; tree_idx < trees.size(); ++tree_idx) {
-    auto const& position = this->node_position_.at(tree_idx);
-    obj->UpdateTreeLeaf(position, p_fmat->Info(), predictions, trees[tree_idx].get());
+  CHECK_EQ(model_.param.num_parallel_tree, trees.size());
+  CHECK_EQ(model_.param.num_parallel_tree, 1)
+      << "Boosting random forest is not supported for current objective.";
+  for (std::size_t tree_idx = 0; tree_idx < trees.size(); ++tree_idx) {
+    auto const& position = node_position.at(tree_idx);
+    obj->UpdateTreeLeaf(position, p_fmat->Info(), predictions, group_idx, trees[tree_idx].get());
   }
 }
 
@@ -254,10 +261,14 @@ void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector<GradientPair>* in_gpair,
     LOG(FATAL) << "Current objective doesn't support external memory.";
   }
 
+  // The node position for each row, 1 HDV for each tree in the forest. Note that the
+  // position is negated if the row is sampled out.
+  std::vector<HostDeviceVector<bst_node_t>> node_position;
+
   if (ngroup == 1) {
     std::vector<std::unique_ptr<RegTree>> ret;
-    BoostNewTrees(in_gpair, p_fmat, 0, &ret);
-    UpdateTreeLeaf(p_fmat, predt->predictions, obj, &ret);
+    BoostNewTrees(in_gpair, p_fmat, 0, &node_position, &ret);
+    UpdateTreeLeaf(p_fmat, predt->predictions, obj, 0, node_position, &ret);
     const size_t num_new_trees = ret.size();
     new_trees.push_back(std::move(ret));
     auto v_predt = out.Slice(linalg::All(), 0);
@@ -271,10 +282,11 @@ void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector<GradientPair>* in_gpair,
                          in_gpair->DeviceIdx());
     bool update_predict = true;
     for (int gid = 0; gid < ngroup; ++gid) {
+      node_position.clear();
       CopyGradient(in_gpair, ctx_->Threads(), ngroup, gid, &tmp);
       std::vector<std::unique_ptr<RegTree>> ret;
-      BoostNewTrees(&tmp, p_fmat, gid, &ret);
-      UpdateTreeLeaf(p_fmat, predt->predictions, obj, &ret);
+      BoostNewTrees(&tmp, p_fmat, gid, &node_position, &ret);
+      UpdateTreeLeaf(p_fmat, predt->predictions, obj, gid, node_position, &ret);
       const size_t num_new_trees = ret.size();
       new_trees.push_back(std::move(ret));
       auto v_predt = out.Slice(linalg::All(), gid);
@@ -334,6 +346,7 @@ void GBTree::InitUpdater(Args const& cfg) {
 }
 
 void GBTree::BoostNewTrees(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat, int bst_group,
+                           std::vector<HostDeviceVector<bst_node_t>>* out_position,
                            std::vector<std::unique_ptr<RegTree>>* ret) {
   std::vector<RegTree*> new_trees;
   ret->clear();
@@ -367,14 +380,16 @@ void GBTree::BoostNewTrees(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fma
       ret->push_back(std::move(t));
     }
   }
 
   // update the trees
   CHECK_EQ(gpair->Size(), p_fmat->Info().num_row_)
       << "Mismatching size between number of rows from input data and size of "
         "gradient vector.";
-  node_position_.resize(new_trees.size());
+  CHECK(out_position);
+  out_position->resize(new_trees.size());
  for (auto& up : updaters_) {
-    up->Update(gpair, p_fmat, common::Span<HostDeviceVector<bst_node_t>>{node_position_},
-               new_trees);
+    up->Update(gpair, p_fmat, common::Span<HostDeviceVector<bst_node_t>>{*out_position}, new_trees);
  }
 }
 
@@ -1,5 +1,5 @@
-/*!
- * Copyright 2014-2022 by Contributors
+/**
+ * Copyright 2014-2023 by Contributors
  * \file gbtree.cc
  * \brief gradient boosted tree implementation.
  * \author Tianqi Chen
@@ -10,26 +10,26 @@
 #include <dmlc/omp.h>
 
 #include <algorithm>
-#include <vector>
+#include <cstdint>  // std::int32_t
 #include <map>
 #include <memory>
-#include <utility>
 #include <string>
 #include <unordered_map>
+#include <utility>
+#include <vector>
 
-#include "xgboost/base.h"
-#include "xgboost/data.h"
-#include "xgboost/logging.h"
-#include "xgboost/gbm.h"
-#include "xgboost/predictor.h"
-#include "xgboost/tree_updater.h"
-#include "xgboost/parameter.h"
-#include "xgboost/json.h"
-#include "xgboost/host_device_vector.h"
-
-#include "gbtree_model.h"
 #include "../common/common.h"
 #include "../common/timer.h"
+#include "gbtree_model.h"
+#include "xgboost/base.h"
+#include "xgboost/data.h"
+#include "xgboost/gbm.h"
+#include "xgboost/host_device_vector.h"
+#include "xgboost/json.h"
+#include "xgboost/logging.h"
+#include "xgboost/parameter.h"
+#include "xgboost/predictor.h"
+#include "xgboost/tree_updater.h"
 
 namespace xgboost {
 enum class TreeMethod : int {
@@ -205,7 +205,10 @@ class GBTree : public GradientBooster {
   * \brief Optionally update the leaf value.
   */
  void UpdateTreeLeaf(DMatrix const* p_fmat, HostDeviceVector<float> const& predictions,
-                     ObjFunction const* obj, std::vector<std::unique_ptr<RegTree>>* p_trees);
+                     ObjFunction const* obj,
+                     std::int32_t group_idx,
+                     std::vector<HostDeviceVector<bst_node_t>> const& node_position,
+                     std::vector<std::unique_ptr<RegTree>>* p_trees);
 
  /*! \brief Carry out one iteration of boosting */
  void DoBoost(DMatrix* p_fmat, HostDeviceVector<GradientPair>* in_gpair,
@@ -411,11 +414,9 @@ class GBTree : public GradientBooster {
  // initialize updater before using them
  void InitUpdater(Args const& cfg);
 
-  // do group specific group
-  void BoostNewTrees(HostDeviceVector<GradientPair>* gpair,
-                     DMatrix *p_fmat,
-                     int bst_group,
-                     std::vector<std::unique_ptr<RegTree> >* ret);
+  void BoostNewTrees(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat, int bst_group,
+                     std::vector<HostDeviceVector<bst_node_t>>* out_position,
+                     std::vector<std::unique_ptr<RegTree>>* ret);
 
  std::unique_ptr<Predictor> const& GetPredictor(HostDeviceVector<float> const* out_pred = nullptr,
                                                 DMatrix* f_dmat = nullptr) const;
@@ -435,9 +436,6 @@ class GBTree : public GradientBooster {
  Args cfg_;
  // the updaters that can be applied to each of tree
  std::vector<std::unique_ptr<TreeUpdater>> updaters_;
-  // The node position for each row, 1 HDV for each tree in the forest. Note that the
-  // position is negated if the row is sampled out.
-  std::vector<HostDeviceVector<bst_node_t>> node_position_;
  // Predictors
  std::unique_ptr<Predictor> cpu_predictor_;
 #if defined(XGBOOST_USE_CUDA)
@@ -1,5 +1,5 @@
-/*!
- * Copyright 2014-2022 by Contributors
+/**
+ * Copyright 2014-2023 by XGBoost Contributors
  * \file learner.cc
  * \brief Implementation of learning algorithm.
  * \author Tianqi Chen
@@ -412,6 +412,7 @@ class LearnerConfiguration : public Learner {
     // We estimate it from input data.
     linalg::Tensor<float, 1> base_score;
     UsePtr(obj_)->InitEstimation(info, &base_score);
+    CHECK_EQ(base_score.Size(), 1);
     mparam_.base_score = base_score(0);
     CHECK(!std::isnan(mparam_.base_score));
   }
@@ -1,5 +1,5 @@
-/*!
- * Copyright 2022 by XGBoost Contributors
+/**
+ * Copyright 2022-2023 by XGBoost Contributors
  */
 #include "adaptive.h"
 
@@ -11,6 +11,7 @@
 #include "../common/stats.h"
 #include "../common/threading_utils.h"
 #include "../common/transform_iterator.h"  // MakeIndexTransformIter
+#include "xgboost/linalg.h"
 #include "xgboost/tree_model.h"
 
 namespace xgboost {
@@ -66,8 +67,8 @@ void EncodeTreeLeafHost(RegTree const& tree, std::vector<bst_node_t> const& posi
 }
 
 void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& position,
-                        MetaInfo const& info, HostDeviceVector<float> const& predt, float alpha,
-                        RegTree* p_tree) {
+                        std::int32_t group_idx, MetaInfo const& info,
+                        HostDeviceVector<float> const& predt, float alpha, RegTree* p_tree) {
   auto& tree = *p_tree;
 
   std::vector<bst_node_t> nidx;
@@ -88,6 +89,9 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& posit
   auto const& h_node_idx = nidx;
   auto const& h_node_ptr = nptr;
   CHECK_LE(h_node_ptr.back(), info.num_row_);
+  auto h_predt = linalg::MakeTensorView(predt.ConstHostSpan(),
+                                        {info.num_row_, predt.Size() / info.num_row_}, ctx->gpu_id);
+
   // loop over each leaf
   common::ParallelFor(quantiles.size(), ctx->Threads(), [&](size_t k) {
     auto nidx = h_node_idx[k];
@@ -95,14 +99,13 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& posit
     CHECK_LT(k + 1, h_node_ptr.size());
     size_t n = h_node_ptr[k + 1] - h_node_ptr[k];
     auto h_row_set = common::Span<size_t const>{ridx}.subspan(h_node_ptr[k], n);
-    // multi-target not yet supported.
-    auto h_labels = info.labels.HostView().Slice(linalg::All(), 0);
-    auto const& h_predt = predt.ConstHostVector();
+    CHECK_LE(group_idx, info.labels.Shape(1));
+    auto h_labels = info.labels.HostView().Slice(linalg::All(), group_idx);
    auto h_weights = linalg::MakeVec(&info.weights_);
 
    auto iter = common::MakeIndexTransformIter([&](size_t i) -> float {
      auto row_idx = h_row_set[i];
-      return h_labels(row_idx) - h_predt[row_idx];
+      return h_labels(row_idx) - h_predt(row_idx, group_idx);
    });
    auto w_it = common::MakeIndexTransformIter([&](size_t i) -> float {
      auto row_idx = h_row_set[i];
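UpdateTreeLeafHost now views the flat prediction vector as an (num_row x n_groups) matrix and reads the column selected by group_idx. The indexing that MakeTensorView plus Slice performs boils down to ordinary row-major arithmetic; a standalone sketch, assuming the row-major layout XGBoost uses for per-group predictions:

    #include <cstddef>
    #include <vector>

    // Equivalent of h_predt(row_idx, group_idx) on the sliced tensor view:
    // predictions for one row are stored contiguously, one value per group.
    float PredictionAt(std::vector<float> const& flat_predt, std::size_t n_groups,
                       std::size_t row_idx, std::size_t group_idx) {
      return flat_predt[row_idx * n_groups + group_idx];
    }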
@@ -1,13 +1,16 @@
-/*!
- * Copyright 2022 by XGBoost Contributors
+/**
+ * Copyright 2022-2023 by XGBoost Contributors
  */
 #include <thrust/sort.h>
 
+#include <cstdint>  // std::int32_t
 #include <cub/cub.cuh>
 
+#include "../common/cuda_context.cuh"  // CUDAContext
 #include "../common/device_helpers.cuh"
 #include "../common/stats.cuh"
 #include "adaptive.h"
+#include "xgboost/context.h"
 
 namespace xgboost {
 namespace obj {
@@ -55,13 +58,13 @@ void EncodeTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
 
   size_t nbytes{0};
   auto begin_it = sorted_position.begin() + beg_pos;
-  dh::safe_cuda(cub::DeviceRunLengthEncode::Encode(nullptr, nbytes, begin_it,
-                                                   unique_out.data().get(), counts_out.data().get(),
-                                                   d_num_runs_out.data(), n_samples - beg_pos));
+  dh::safe_cuda(cub::DeviceRunLengthEncode::Encode(
+      nullptr, nbytes, begin_it, unique_out.data().get(), counts_out.data().get(),
+      d_num_runs_out.data(), n_samples - beg_pos, ctx->CUDACtx()->Stream()));
   dh::TemporaryArray<char> temp(nbytes);
-  dh::safe_cuda(cub::DeviceRunLengthEncode::Encode(temp.data().get(), nbytes, begin_it,
-                                                   unique_out.data().get(), counts_out.data().get(),
-                                                   d_num_runs_out.data(), n_samples - beg_pos));
+  dh::safe_cuda(cub::DeviceRunLengthEncode::Encode(
+      temp.data().get(), nbytes, begin_it, unique_out.data().get(), counts_out.data().get(),
+      d_num_runs_out.data(), n_samples - beg_pos, ctx->CUDACtx()->Stream()));
 
   dh::PinnedMemory pinned_pool;
   auto pinned = pinned_pool.GetSpan<char>(sizeof(size_t) + sizeof(bst_node_t));
@@ -138,8 +141,8 @@ void EncodeTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
 }
 
 void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> position,
-                          MetaInfo const& info, HostDeviceVector<float> const& predt, float alpha,
-                          RegTree* p_tree) {
+                          std::int32_t group_idx, MetaInfo const& info,
+                          HostDeviceVector<float> const& predt, float alpha, RegTree* p_tree) {
   dh::safe_cuda(cudaSetDevice(ctx->gpu_id));
   dh::device_vector<size_t> ridx;
   HostDeviceVector<size_t> nptr;
@@ -154,19 +157,24 @@ void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
 
   HostDeviceVector<float> quantiles;
   predt.SetDevice(ctx->gpu_id);
-  auto d_predt = predt.ConstDeviceSpan();
-  auto d_labels = info.labels.View(ctx->gpu_id);
+  auto d_predt = linalg::MakeTensorView(predt.ConstDeviceSpan(),
+                                        {info.num_row_, predt.Size() / info.num_row_}, ctx->gpu_id);
+  CHECK_LT(group_idx, d_predt.Shape(1));
+  auto t_predt = d_predt.Slice(linalg::All(), group_idx);
+  auto d_labels = info.labels.View(ctx->gpu_id).Slice(linalg::All(), group_idx);
 
   auto d_row_index = dh::ToSpan(ridx);
   auto seg_beg = nptr.DevicePointer();
   auto seg_end = seg_beg + nptr.Size();
   auto val_beg = dh::MakeTransformIterator<float>(thrust::make_counting_iterator(0ul),
                                                   [=] XGBOOST_DEVICE(size_t i) {
-                                                    auto predt = d_predt[d_row_index[i]];
+                                                    float p = t_predt(d_row_index[i]);
                                                     auto y = d_labels(d_row_index[i]);
-                                                    return y - predt;
+                                                    return y - p;
                                                   });
-  auto val_end = val_beg + d_labels.Size();
+  CHECK_EQ(d_labels.Shape(0), position.size());
+  auto val_end = val_beg + d_labels.Shape(0);
   CHECK_EQ(nidx.Size() + 1, nptr.Size());
   if (info.weights_.Empty()) {
     common::SegmentedQuantile(ctx, alpha, seg_beg, seg_end, val_beg, val_end, &quantiles);
@@ -1,9 +1,10 @@
-/*!
- * Copyright 2022 by XGBoost Contributors
+/**
+ * Copyright 2022-2023 by XGBoost Contributors
  */
 #pragma once
 
 #include <algorithm>
+#include <cstdint>  // std::int32_t
 #include <limits>
 #include <vector>
 
@@ -32,7 +33,7 @@ inline void FillMissingLeaf(std::vector<bst_node_t> const& maybe_missing,
   }
 }
 
-inline void UpdateLeafValues(std::vector<float>* p_quantiles, std::vector<bst_node_t> const nidx,
+inline void UpdateLeafValues(std::vector<float>* p_quantiles, std::vector<bst_node_t> const& nidx,
                              RegTree* p_tree) {
   auto& tree = *p_tree;
   auto& quantiles = *p_quantiles;
@@ -73,12 +74,12 @@ inline void UpdateLeafValues(std::vector<float>* p_quantiles, std::vector<bst_no
 }
 
 void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> position,
-                          MetaInfo const& info, HostDeviceVector<float> const& predt, float alpha,
-                          RegTree* p_tree);
+                          std::int32_t group_idx, MetaInfo const& info,
+                          HostDeviceVector<float> const& predt, float alpha, RegTree* p_tree);
 
 void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& position,
-                        MetaInfo const& info, HostDeviceVector<float> const& predt, float alpha,
-                        RegTree* p_tree);
+                        std::int32_t group_idx, MetaInfo const& info,
+                        HostDeviceVector<float> const& predt, float alpha, RegTree* p_tree);
 }  // namespace detail
 }  // namespace obj
 }  // namespace xgboost
@@ -1,15 +1,14 @@
-/*!
- * Copyright 2015-2022 by XGBoost Contributors
+/**
+ * Copyright 2015-2023 by XGBoost Contributors
  * \file regression_obj.cu
  * \brief Definition of single-value regression and classification objectives.
  * \author Tianqi Chen, Kailong Chen
  */
 #include <dmlc/omp.h>
-#include <xgboost/logging.h>
-#include <xgboost/objective.h>
-#include <xgboost/tree_model.h>
 
+#include <algorithm>
 #include <cmath>
+#include <cstdint>  // std::int32_t
 #include <memory>
 #include <vector>
 
@@ -25,12 +24,15 @@
 #include "adaptive.h"
 #include "xgboost/base.h"
 #include "xgboost/context.h"
-#include "xgboost/data.h"
+#include "xgboost/data.h"  // MetaInfo
 #include "xgboost/host_device_vector.h"
 #include "xgboost/json.h"
 #include "xgboost/linalg.h"
+#include "xgboost/logging.h"
+#include "xgboost/objective.h"  // ObjFunction
 #include "xgboost/parameter.h"
 #include "xgboost/span.h"
+#include "xgboost/tree_model.h"  // RegTree
 
 #if defined(XGBOOST_USE_CUDA)
 #include "../common/device_helpers.cuh"
@@ -703,6 +705,9 @@ class MeanAbsoluteError : public ObjFunction {
  public:
  void Configure(Args const&) override {}
  ObjInfo Task() const override { return {ObjInfo::kRegression, true, true}; }
+  bst_target_t Targets(MetaInfo const& info) const override {
+    return std::max(static_cast<size_t>(1), info.labels.Shape(1));
+  }
 
  void GetGradient(HostDeviceVector<bst_float> const& preds, const MetaInfo& info, int /*iter*/,
                   HostDeviceVector<GradientPair>* out_gpair) override {
@@ -724,7 +729,7 @@ class MeanAbsoluteError : public ObjFunction {
        return (x > static_cast<decltype(x)>(0)) - (x < static_cast<decltype(x)>(0));
      };
      auto sample_id = std::get<0>(linalg::UnravelIndex(i, labels.Shape()));
-      auto grad = sign(predt(i) - y) * weight[i];
+      auto grad = sign(predt(i) - y) * weight[sample_id];
      auto hess = weight[sample_id];
      gpair(i) = GradientPair{grad, hess};
    });
@@ -732,8 +737,7 @@ class MeanAbsoluteError : public ObjFunction {
 
  void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_margin) const override {
    CheckInitInputs(info);
-    base_margin->Reshape(1);
-    auto out = base_margin->HostView();
+    base_margin->Reshape(this->Targets(info));
 
    double w{0.0};
    if (info.weights_.Empty()) {
@@ -743,11 +747,18 @@ class MeanAbsoluteError : public ObjFunction {
    }
 
    if (info.num_row_ == 0) {
+      auto out = base_margin->HostView();
      out(0) = 0;
    } else {
-      // weighted avg
-      out(0) = common::Median(ctx_, info.labels, info.weights_) * w;
+      linalg::Vector<float> temp;
+      common::Median(ctx_, info.labels, info.weights_, &temp);
+      common::Mean(ctx_, temp, base_margin);
    }
+    CHECK_EQ(base_margin->Size(), 1);
+    auto out = base_margin->HostView();
+    // weighted avg
+    std::transform(linalg::cbegin(out), linalg::cend(out), linalg::begin(out),
+                   [w](float v) { return v * w; });
 
    collective::Allreduce<collective::Operation::kSum>(out.Values().data(), out.Values().size());
    collective::Allreduce<collective::Operation::kSum>(&w, 1);
@@ -763,15 +774,16 @@ class MeanAbsoluteError : public ObjFunction {
  }
 
  void UpdateTreeLeaf(HostDeviceVector<bst_node_t> const& position, MetaInfo const& info,
-                     HostDeviceVector<float> const& prediction, RegTree* p_tree) const override {
+                     HostDeviceVector<float> const& prediction, std::int32_t group_idx,
+                     RegTree* p_tree) const override {
    if (ctx_->IsCPU()) {
      auto const& h_position = position.ConstHostVector();
-      detail::UpdateTreeLeafHost(ctx_, h_position, info, prediction, 0.5, p_tree);
+      detail::UpdateTreeLeafHost(ctx_, h_position, group_idx, info, prediction, 0.5, p_tree);
    } else {
 #if defined(XGBOOST_USE_CUDA)
      position.SetDevice(ctx_->gpu_id);
      auto d_position = position.ConstDeviceSpan();
-      detail::UpdateTreeLeafDevice(ctx_, d_position, info, prediction, 0.5, p_tree);
+      detail::UpdateTreeLeafDevice(ctx_, d_position, group_idx, info, prediction, 0.5, p_tree);
 #else
      common::AssertGPUSupport();
 #endif  // defined(XGBOOST_USE_CUDA)
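For reg:absoluteerror the intercept is now estimated by taking the median of every target column and then collapsing those medians to a single scalar with Mean — the booster still stores one global base_score, hence the CHECK_EQ(base_margin->Size(), 1). A simplified sketch of that reduction, unweighted and without the distributed allreduce step:

    #include <algorithm>  // std::nth_element
    #include <numeric>    // std::accumulate
    #include <vector>

    // One (upper) median per target, then the mean of the medians.
    // Assumes at least one target with at least one label each.
    float EstimateBaseScore(std::vector<std::vector<float>> per_target_labels) {
      std::vector<float> medians;
      for (auto& labels : per_target_labels) {  // nth_element mutates the copy
        auto mid = labels.begin() + labels.size() / 2;
        std::nth_element(labels.begin(), mid, labels.end());
        medians.push_back(*mid);
      }
      return std::accumulate(medians.begin(), medians.end(), 0.0f) /
             static_cast<float>(medians.size());
    }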
@@ -58,7 +58,7 @@ class GloablApproxBuilder {
     monitor_->Start(__func__);
 
     n_batches_ = 0;
-    int32_t n_total_bins = 0;
+    bst_bin_t n_total_bins = 0;
     partitioner_.clear();
     // Generating the GHistIndexMatrix is quite slow, is there a way to speed it up?
     for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(BatchSpec(param_, hess, task_))) {
@@ -1,5 +1,5 @@
-/*!
- * Copyright 2022 by XGBoost Contributors
+/**
+ * Copyright 2022-2023 by XGBoost Contributors
  */
 #include <gtest/gtest.h>
 #include <xgboost/context.h>
@@ -58,19 +58,44 @@ TEST(Stats, WeightedQuantile) {
 }
 
 TEST(Stats, Median) {
-  linalg::Tensor<float, 2> values{{.0f, .0f, 1.f, 2.f}, {4}, Context::kCpuId};
   Context ctx;
-  HostDeviceVector<float> weights;
-  auto m = Median(&ctx, values, weights);
-  ASSERT_EQ(m, .5f);
+  {
+    linalg::Tensor<float, 2> values{{.0f, .0f, 1.f, 2.f}, {4}, Context::kCpuId};
+    HostDeviceVector<float> weights;
+    linalg::Tensor<float, 1> out;
+    Median(&ctx, values, weights, &out);
+    auto m = out(0);
+    ASSERT_EQ(m, .5f);
 
 #if defined(XGBOOST_USE_CUDA)
-  ctx.gpu_id = 0;
-  ASSERT_FALSE(ctx.IsCPU());
-  m = Median(&ctx, values, weights);
-  ASSERT_EQ(m, .5f);
+    ctx.gpu_id = 0;
+    ASSERT_FALSE(ctx.IsCPU());
+    Median(&ctx, values, weights, &out);
+    m = out(0);
+    ASSERT_EQ(m, .5f);
 #endif  // defined(XGBOOST_USE_CUDA)
+  }
+
+  {
+    ctx.gpu_id = Context::kCpuId;
+    // 4x2 matrix
+    linalg::Tensor<float, 2> values{{0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 2.f, 2.f}, {4, 2}, ctx.gpu_id};
+    HostDeviceVector<float> weights;
+    linalg::Tensor<float, 1> out;
+    Median(&ctx, values, weights, &out);
+    ASSERT_EQ(out(0), .5f);
+    ASSERT_EQ(out(1), .5f);
+
+#if defined(XGBOOST_USE_CUDA)
+    ctx.gpu_id = 0;
+    Median(&ctx, values, weights, &out);
+    ASSERT_EQ(out(0), .5f);
+    ASSERT_EQ(out(1), .5f);
+#endif  // defined(XGBOOST_USE_CUDA)
+  }
 }
 
 namespace {
 void TestMean(Context const* ctx) {
   std::size_t n{128};
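A quick arithmetic check of the expected values in this test: in both cases every column holds {0, 0, 1, 2}, and with an even number of values the 0.5-quantile interpolates the two middle order statistics, (0 + 1) / 2 = 0.5, which is exactly what the assertions expect.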
@@ -1,5 +1,5 @@
-/*!
- * Copyright 2016-2022 by XGBoost contributors
+/**
+ * Copyright 2016-2023 by XGBoost contributors
  */
 #include "helpers.h"
 
@@ -335,30 +335,30 @@ void RandomDataGenerator::GenerateCSR(
   CHECK_EQ(columns->Size(), value->Size());
 }
 
-std::shared_ptr<DMatrix>
-RandomDataGenerator::GenerateDMatrix(bool with_label, bool float_label,
-                                     size_t classes) const {
+std::shared_ptr<DMatrix> RandomDataGenerator::GenerateDMatrix(bool with_label, bool float_label,
+                                                              size_t classes) const {
   HostDeviceVector<float> data;
   HostDeviceVector<bst_row_t> rptrs;
   HostDeviceVector<bst_feature_t> columns;
   this->GenerateCSR(&data, &rptrs, &columns);
-  data::CSRAdapter adapter(rptrs.HostPointer(), columns.HostPointer(),
-                           data.HostPointer(), rows_, data.Size(), cols_);
+  data::CSRAdapter adapter(rptrs.HostPointer(), columns.HostPointer(), data.HostPointer(), rows_,
+                           data.Size(), cols_);
   std::shared_ptr<DMatrix> out{
       DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 1)};
 
   if (with_label) {
-    RandomDataGenerator gen(rows_, 1, 0);
+    RandomDataGenerator gen{rows_, n_targets_, 0.0f};
     if (!float_label) {
       gen.Lower(0).Upper(classes).GenerateDense(out->Info().labels.Data());
-      out->Info().labels.Reshape(this->rows_);
+      out->Info().labels.Reshape(this->rows_, this->n_targets_);
      auto& h_labels = out->Info().labels.Data()->HostVector();
      for (auto& v : h_labels) {
        v = static_cast<float>(static_cast<uint32_t>(v));
      }
    } else {
      gen.GenerateDense(out->Info().labels.Data());
-      out->Info().labels.Reshape(this->rows_);
+      CHECK_EQ(out->Info().labels.Size(), this->rows_ * this->n_targets_);
+      out->Info().labels.Reshape(this->rows_, this->n_targets_);
    }
  }
  if (device_ >= 0) {
@@ -1,5 +1,5 @@
-/*!
- * Copyright 2016-2019 XGBoost contributors
+/**
+ * Copyright 2016-2023 by XGBoost contributors
  */
 #ifndef XGBOOST_TESTS_CPP_HELPERS_H_
 #define XGBOOST_TESTS_CPP_HELPERS_H_
@@ -214,26 +214,26 @@ class RandomDataGenerator {
  size_t cols_;
  float sparsity_;
 
-  float lower_;
-  float upper_;
+  float lower_{0.0f};
+  float upper_{1.0f};
 
-  int32_t device_;
-  uint64_t seed_;
+  bst_target_t n_targets_{1};
+
+  std::int32_t device_{Context::kCpuId};
+  std::uint64_t seed_{0};
  SimpleLCG lcg_;
 
-  size_t bins_;
+  std::size_t bins_{0};
  std::vector<FeatureType> ft_;
  bst_cat_t max_cat_;
 
-  Json ArrayInterfaceImpl(HostDeviceVector<float> *storage, size_t rows,
-                          size_t cols) const;
+  Json ArrayInterfaceImpl(HostDeviceVector<float>* storage, size_t rows, size_t cols) const;
 
 public:
  RandomDataGenerator(bst_row_t rows, size_t cols, float sparsity)
-      : rows_{rows}, cols_{cols}, sparsity_{sparsity}, lower_{0.0f}, upper_{1.0f},
-        device_{-1}, seed_{0}, lcg_{seed_}, bins_{0} {}
+      : rows_{rows}, cols_{cols}, sparsity_{sparsity}, lcg_{seed_} {}
 
-  RandomDataGenerator &Lower(float v) {
+  RandomDataGenerator& Lower(float v) {
    lower_ = v;
    return *this;
  }
@@ -264,6 +264,10 @@ class RandomDataGenerator {
    max_cat_ = cat;
    return *this;
  }
+  RandomDataGenerator& Targets(bst_target_t n_targets) {
+    n_targets_ = n_targets;
+    return *this;
+  }
 
  void GenerateDense(HostDeviceVector<float>* out) const;
 
@@ -279,18 +283,15 @@ class RandomDataGenerator {
   * a single JSON string representing the consecutive memory as a whole
   * (combining all the batches).
   */
-  std::pair<std::vector<std::string>, std::string>
-  GenerateArrayInterfaceBatch(HostDeviceVector<float> *storage,
-                              size_t batches) const;
+  std::pair<std::vector<std::string>, std::string> GenerateArrayInterfaceBatch(
+      HostDeviceVector<float>* storage, size_t batches) const;
 
-  std::string GenerateColumnarArrayInterface(
-      std::vector<HostDeviceVector<float>> *data) const;
+  std::string GenerateColumnarArrayInterface(std::vector<HostDeviceVector<float>>* data) const;
 
  void GenerateCSR(HostDeviceVector<float>* value, HostDeviceVector<bst_row_t>* row_ptr,
                   HostDeviceVector<bst_feature_t>* columns) const;
 
-  std::shared_ptr<DMatrix> GenerateDMatrix(bool with_label = false,
-                                           bool float_label = true,
+  std::shared_ptr<DMatrix> GenerateDMatrix(bool with_label = false, bool float_label = true,
                                           size_t classes = 1) const;
 #if defined(XGBOOST_USE_CUDA)
  std::shared_ptr<DMatrix> GenerateDeviceDMatrix();
@ -1,13 +1,17 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2017-2022 XGBoost contributors
|
* Copyright 2017-2023 by XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include <xgboost/context.h>
|
#include <xgboost/context.h>
|
||||||
#include <xgboost/json.h>
|
#include <xgboost/json.h>
|
||||||
#include <xgboost/objective.h>
|
#include <xgboost/objective.h>
|
||||||
|
|
||||||
|
#include "../../../src/common/linalg_op.h" // begin,end
|
||||||
#include "../../../src/objective/adaptive.h"
|
#include "../../../src/objective/adaptive.h"
|
||||||
#include "../helpers.h"
|
#include "../helpers.h"
|
||||||
|
#include "xgboost/base.h"
|
||||||
|
#include "xgboost/data.h"
|
||||||
|
#include "xgboost/linalg.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
|
|
||||||
@ -404,56 +408,61 @@ TEST(Objective, DeclareUnifiedTest(AbsoluteError)) {
     h_predt[i] = labels[i] + i;
   }
 
-  obj->UpdateTreeLeaf(position, info, predt, &tree);
+  obj->UpdateTreeLeaf(position, info, predt, 0, &tree);
   ASSERT_EQ(tree[1].LeafValue(), -1);
   ASSERT_EQ(tree[2].LeafValue(), -4);
 }
 
 TEST(Objective, DeclareUnifiedTest(AbsoluteErrorLeaf)) {
   Context ctx = CreateEmptyGenericParam(GPUIDX);
+  bst_target_t constexpr kTargets = 3, kRows = 16;
   std::unique_ptr<ObjFunction> obj{ObjFunction::Create("reg:absoluteerror", &ctx)};
   obj->Configure({});
 
   MetaInfo info;
-  info.labels.Reshape(16, 1);
-  info.num_row_ = info.labels.Size();
-  CHECK_EQ(info.num_row_, 16);
-  auto h_labels = info.labels.HostView().Values();
-  std::iota(h_labels.begin(), h_labels.end(), 0);
-  HostDeviceVector<float> predt(h_labels.size());
-  auto& h_predt = predt.HostVector();
-  for (size_t i = 0; i < h_predt.size(); ++i) {
-    h_predt[i] = h_labels[i] + i;
-  }
-
-  HostDeviceVector<bst_node_t> position(info.labels.Size(), 0);
-  auto& h_position = position.HostVector();
-  for (int32_t i = 0; i < 3; ++i) {
-    h_position[i] = ~i;  // negation for sampled nodes.
-  }
-  for (size_t i = 3; i < 8; ++i) {
-    h_position[i] = 3;
-  }
-  // empty leaf for node 4
-  for (size_t i = 8; i < 13; ++i) {
-    h_position[i] = 5;
-  }
-  for (size_t i = 13; i < h_labels.size(); ++i) {
-    h_position[i] = 6;
-  }
-
-  RegTree tree;
-  tree.ExpandNode(0, /*split_index=*/1, 2, true, 0.0f, 2.f, 3.f, 4.f, 2.f, 1.f, 1.f);
-  tree.ExpandNode(1, /*split_index=*/1, 2, true, 0.0f, 2.f, 3.f, 4.f, 2.f, 1.f, 1.f);
-  tree.ExpandNode(2, /*split_index=*/1, 2, true, 0.0f, 2.f, 3.f, 4.f, 2.f, 1.f, 1.f);
-  ASSERT_EQ(tree.GetNumLeaves(), 4);
-
-  auto empty_leaf = tree[4].LeafValue();
-  obj->UpdateTreeLeaf(position, info, predt, &tree);
-  ASSERT_EQ(tree[3].LeafValue(), -5);
-  ASSERT_EQ(tree[4].LeafValue(), empty_leaf);
-  ASSERT_EQ(tree[5].LeafValue(), -10);
-  ASSERT_EQ(tree[6].LeafValue(), -14);
+  info.num_row_ = kRows;
+  info.labels.Reshape(16, kTargets);
+  HostDeviceVector<float> predt(info.labels.Size());
+
+  for (bst_target_t t{0}; t < kTargets; ++t) {
+    auto h_labels = info.labels.HostView().Slice(linalg::All(), t);
+    std::iota(linalg::begin(h_labels), linalg::end(h_labels), 0);
+
+    auto h_predt = linalg::MakeTensorView(predt.HostSpan(), {kRows, kTargets}, Context::kCpuId)
+                       .Slice(linalg::All(), t);
+    for (size_t i = 0; i < h_predt.Size(); ++i) {
+      h_predt(i) = h_labels(i) + i;
+    }
+
+    HostDeviceVector<bst_node_t> position(h_labels.Size(), 0);
+    auto& h_position = position.HostVector();
+    for (int32_t i = 0; i < 3; ++i) {
+      h_position[i] = ~i;  // negation for sampled nodes.
+    }
+    for (size_t i = 3; i < 8; ++i) {
+      h_position[i] = 3;
+    }
+    // empty leaf for node 4
+    for (size_t i = 8; i < 13; ++i) {
+      h_position[i] = 5;
+    }
+    for (size_t i = 13; i < h_labels.Size(); ++i) {
+      h_position[i] = 6;
+    }
+
+    RegTree tree;
+    tree.ExpandNode(0, /*split_index=*/1, 2, true, 0.0f, 2.f, 3.f, 4.f, 2.f, 1.f, 1.f);
+    tree.ExpandNode(1, /*split_index=*/1, 2, true, 0.0f, 2.f, 3.f, 4.f, 2.f, 1.f, 1.f);
+    tree.ExpandNode(2, /*split_index=*/1, 2, true, 0.0f, 2.f, 3.f, 4.f, 2.f, 1.f, 1.f);
+    ASSERT_EQ(tree.GetNumLeaves(), 4);
+
+    auto empty_leaf = tree[4].LeafValue();
+    obj->UpdateTreeLeaf(position, info, predt, t, &tree);
+    ASSERT_EQ(tree[3].LeafValue(), -5);
+    ASSERT_EQ(tree[4].LeafValue(), empty_leaf);
+    ASSERT_EQ(tree[5].LeafValue(), -10);
+    ASSERT_EQ(tree[6].LeafValue(), -14);
+  }
 }
 
 TEST(Adaptive, DeclareUnifiedTest(MissingLeaf)) {
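The rewritten test exercises the updated `UpdateTreeLeaf()` one target at a time. A condensed sketch of the calling convention introduced by the new `group_idx` argument (here `p_trees`, a per-target tree collection, is hypothetical; in the test a fresh tree is built inside the loop):

// Sketch: the objective slices labels/predictions by the given target index
// when computing the quantile for each leaf.
bst_target_t const n_targets = info.labels.Shape(1);
for (bst_target_t t{0}; t < n_targets; ++t) {
  obj->UpdateTreeLeaf(position, info, predt, t, p_trees[t]);  // p_trees: hypothetical
}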
tests/cpp/test_multi_target.cc (new file, 120 lines added)
@ -0,0 +1,120 @@
/**
 * Copyright 2023 by XGBoost Contributors
 */
#include <gtest/gtest.h>
#include <xgboost/base.h>     // bst_target_t
#include <xgboost/data.h>     // DMatrix
#include <xgboost/json.h>     // Json,Object,Number,get
#include <xgboost/learner.h>  // Learner

#include <cstddef>  // size_t
#include <memory>   // shared_ptr,unique_ptr
#include <numeric>
#include <string>  // stod
#include <vector>

#include "../../src/common/linalg_op.h"  // cbegin,cend
#include "../../src/common/stats.h"      // Median
#include "helpers.h"                     // RandomDataGenerator
#include "xgboost/linalg.h"

namespace xgboost {
class TestL1MultiTarget : public ::testing::Test {
  std::shared_ptr<DMatrix> Xy_;
  std::shared_ptr<DMatrix> Xyw_;
  std::vector<std::shared_ptr<DMatrix>> single_;
  std::vector<std::shared_ptr<DMatrix>> single_w_;

 public:
  void SetUp() override {
    std::size_t constexpr kRows{256}, kCols{5}, kTargets{3};
    auto make_fmat = [&](bool weighted) {
      if (weighted) {
        auto p_fmat =
            RandomDataGenerator{kRows, kCols, 0.5f}.Targets(kTargets).GenerateDMatrix(true);
        p_fmat->Info().weights_.Resize(kRows);
        RandomDataGenerator{kRows, 1, 0.0f}.GenerateDense(&p_fmat->Info().weights_);
        return p_fmat;
      } else {
        return RandomDataGenerator{kRows, kCols, 0.5f}.Targets(kTargets).GenerateDMatrix(true);
      }
    };

    Xy_ = make_fmat(false);
    Xyw_ = make_fmat(true);
    ASSERT_EQ(Xy_->Info().labels.Shape(1), kTargets);
    ASSERT_EQ(Xyw_->Info().labels.Shape(1), kTargets);

    single_.clear();
    single_w_.clear();
    for (bst_target_t t{0}; t < kTargets; ++t) {
      {
        single_.emplace_back(make_fmat(false));
        single_[t]->Info().labels.Reshape(kRows, 1);
        auto h_labels = single_[t]->Info().labels.HostView();
        auto in_labels = Xy_->Info().labels.HostView().Slice(linalg::All(), t);
        std::copy(linalg::cbegin(in_labels), linalg::cend(in_labels), linalg::begin(h_labels));
      }
      {
        single_w_.emplace_back(make_fmat(true));
        single_w_[t]->Info().labels.Reshape(kRows, 1);
        auto h_labels = single_w_[t]->Info().labels.HostView();
        auto in_labels = Xyw_->Info().labels.HostView().Slice(linalg::All(), t);
        std::copy(linalg::cbegin(in_labels), linalg::cend(in_labels), linalg::begin(h_labels));
      }
    }
  }

  void RunTest(std::string const& tree_method, bool weight) {
    auto p_fmat = weight ? Xyw_ : Xy_;
    std::unique_ptr<Learner> learner{Learner::Create({p_fmat})};
    learner->SetParams(Args{{"tree_method", tree_method}, {"objective", "reg:absoluteerror"}});
    learner->Configure();
    for (auto i = 0; i < 4; ++i) {
      learner->UpdateOneIter(i, p_fmat);
    }
    ASSERT_EQ(learner->Groups(), 3);

    Json config{Object{}};
    learner->SaveConfig(&config);
    auto base_score =
        std::stod(get<String const>(config["learner"]["learner_model_param"]["base_score"]));

    std::vector<float> base_scores;
    for (bst_target_t t{0}; t < p_fmat->Info().labels.Shape(1); ++t) {
      auto t_Xy = weight ? single_w_[t] : single_[t];
      std::unique_ptr<Learner> sl{Learner::Create({t_Xy})};
      sl->SetParams(Args{{"tree_method", tree_method}, {"objective", "reg:absoluteerror"}});
      sl->Configure();
      sl->UpdateOneIter(0, t_Xy);
      Json s_config{Object{}};
      sl->SaveConfig(&s_config);
      auto s_base_score =
          std::stod(get<String const>(s_config["learner"]["learner_model_param"]["base_score"]));
      linalg::Vector<float> out;
      common::Median(sl->Ctx(), t_Xy->Info().labels, t_Xy->Info().weights_, &out);
      ASSERT_FLOAT_EQ(s_base_score, out(0));
      base_scores.push_back(s_base_score);
    }
    auto mean = std::accumulate(base_scores.cbegin(), base_scores.cend(), .0f) /
                static_cast<float>(base_scores.size());
    ASSERT_FLOAT_EQ(mean, base_score);
  }

  void RunTest(std::string const& tree_method) {
    this->RunTest(tree_method, false);
    this->RunTest(tree_method, true);
  }
};

TEST_F(TestL1MultiTarget, Hist) { this->RunTest("hist"); }

TEST_F(TestL1MultiTarget, Exact) { this->RunTest("exact"); }

TEST_F(TestL1MultiTarget, Approx) { this->RunTest("approx"); }

#if defined(XGBOOST_USE_CUDA)
TEST_F(TestL1MultiTarget, GpuHist) { this->RunTest("gpu_hist"); }
#endif  // defined(XGBOOST_USE_CUDA)
}  // namespace xgboost
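The fixture above pins down the intended semantics: with `reg:absoluteerror`, each target's intercept is the (weighted) median of that target's label column, and the single stored `base_score` is the mean of those medians. A condensed cross-check, assuming the matrix overload of `common::Median` added by this commit writes one entry per target:

// Sketch: verify base_score == mean of per-target label medians.
linalg::Vector<float> medians;
common::Median(learner->Ctx(), p_fmat->Info().labels, p_fmat->Info().weights_, &medians);
float mean{0.0f};
for (std::size_t t = 0; t < medians.Size(); ++t) {
  mean += medians(t);  // one median per target column (assumed layout)
}
mean /= static_cast<float>(medians.Size());
// Compare against base_score parsed from the saved learner config, as in RunTest().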