Implement fit stump. (#8607)
This commit is contained in:
@@ -13,7 +13,7 @@
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
namespace cuda {
|
||||
namespace cuda_impl {
|
||||
float Median(Context const* ctx, linalg::TensorView<float const, 2> t,
|
||||
common::OptionalWeights weights) {
|
||||
HostDeviceVector<size_t> segments{0, t.Size()};
|
||||
@@ -42,6 +42,17 @@ float Median(Context const* ctx, linalg::TensorView<float const, 2> t,
|
||||
CHECK_EQ(quantile.Size(), 1);
|
||||
return quantile.HostVector().front();
|
||||
}
|
||||
} // namespace cuda
|
||||
|
||||
void Mean(Context const* ctx, linalg::VectorView<float const> v, linalg::VectorView<float> out) {
|
||||
float n = v.Size();
|
||||
auto it = dh::MakeTransformIterator<float>(
|
||||
thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t i) { return v(i) / n; });
|
||||
std::size_t bytes;
|
||||
CHECK_EQ(out.Size(), 1);
|
||||
cub::DeviceReduce::Sum(nullptr, bytes, it, out.Values().data(), v.Size());
|
||||
dh::TemporaryArray<char> temp{bytes};
|
||||
cub::DeviceReduce::Sum(temp.data().get(), bytes, it, out.Values().data(), v.Size());
|
||||
}
|
||||
} // namespace cuda_impl
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user