sync Mar 29

This commit is contained in:
amdsc21
2023-03-30 00:46:50 +02:00
20 changed files with 335 additions and 115 deletions

View File

@@ -6,11 +6,12 @@
#include "../../src/common/linalg_op.h"
#include "../../src/tree/fit_stump.h"
#include "../helpers.h"
namespace xgboost {
namespace tree {
namespace {
void TestFitStump(Context const *ctx) {
void TestFitStump(Context const *ctx, DataSplitMode split = DataSplitMode::kRow) {
std::size_t constexpr kRows = 16, kTargets = 2;
HostDeviceVector<GradientPair> gpair;
auto &h_gpair = gpair.HostVector();
@@ -22,6 +23,7 @@ void TestFitStump(Context const *ctx) {
}
linalg::Vector<float> out;
MetaInfo info;
info.data_split_mode = split;
FitStump(ctx, info, gpair, kTargets, &out);
auto h_out = out.HostView();
for (auto it = linalg::cbegin(h_out); it != linalg::cend(h_out); ++it) {
@@ -45,5 +47,12 @@ TEST(InitEstimation, GPUFitStump) {
TestFitStump(&ctx);
}
#endif // defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
TEST(InitEstimation, FitStumpColumnSplit) {
Context ctx;
auto constexpr kWorldSize{3};
RunWithInMemoryCommunicator(kWorldSize, &TestFitStump, &ctx, DataSplitMode::kCol);
}
} // namespace tree
} // namespace xgboost