Implement fit stump. (#8607)

This commit is contained in:
Jiaming Yuan
2023-01-04 04:14:51 +08:00
committed by GitHub
parent 20e6087579
commit 8d545ab2a2
23 changed files with 421 additions and 60 deletions

View File

@@ -3,8 +3,10 @@
*/
#include <gtest/gtest.h>
#include <xgboost/context.h>
#include <xgboost/linalg.h> // Tensor,Vector
#include "../../../src/common/stats.h"
#include "../../../src/common/transform_iterator.h" // common::MakeIndexTransformIter
namespace xgboost {
namespace common {
@@ -69,5 +71,35 @@ TEST(Stats, Median) {
ASSERT_EQ(m, .5f);
#endif // defined(XGBOOST_USE_CUDA)
}
namespace {
void TestMean(Context const* ctx) {
std::size_t n{128};
linalg::Vector<float> data({n}, ctx->gpu_id);
auto h_v = data.HostView().Values();
std::iota(h_v.begin(), h_v.end(), .0f);
auto nf = static_cast<float>(n);
float mean = nf * (nf - 1) / 2 / n;
linalg::Vector<float> res{{1}, ctx->gpu_id};
Mean(ctx, data, &res);
auto h_res = res.HostView();
ASSERT_EQ(h_res.Size(), 1);
ASSERT_EQ(mean, h_res(0));
}
} // anonymous namespace
TEST(Stats, Mean) {
Context ctx;
TestMean(&ctx);
}
#if defined(XGBOOST_USE_CUDA)
TEST(Stats, GPUMean) {
Context ctx;
ctx.UpdateAllowUnknown(Args{{"gpu_id", "0"}});
TestMean(&ctx);
}
#endif // defined(XGBOOST_USE_CUDA)
} // namespace common
} // namespace xgboost

View File

@@ -7,6 +7,7 @@
#include <vector>
#include "../../../src/common/stats.cuh"
#include "../../../src/common/stats.h"
#include "xgboost/base.h"
#include "xgboost/context.h"
#include "xgboost/host_device_vector.h"