sync Mar 29
This commit is contained in:
@@ -6,11 +6,12 @@
|
||||
|
||||
#include "../../src/common/linalg_op.h"
|
||||
#include "../../src/tree/fit_stump.h"
|
||||
#include "../helpers.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
namespace {
|
||||
void TestFitStump(Context const *ctx) {
|
||||
void TestFitStump(Context const *ctx, DataSplitMode split = DataSplitMode::kRow) {
|
||||
std::size_t constexpr kRows = 16, kTargets = 2;
|
||||
HostDeviceVector<GradientPair> gpair;
|
||||
auto &h_gpair = gpair.HostVector();
|
||||
@@ -22,6 +23,7 @@ void TestFitStump(Context const *ctx) {
|
||||
}
|
||||
linalg::Vector<float> out;
|
||||
MetaInfo info;
|
||||
info.data_split_mode = split;
|
||||
FitStump(ctx, info, gpair, kTargets, &out);
|
||||
auto h_out = out.HostView();
|
||||
for (auto it = linalg::cbegin(h_out); it != linalg::cend(h_out); ++it) {
|
||||
@@ -45,5 +47,12 @@ TEST(InitEstimation, GPUFitStump) {
|
||||
TestFitStump(&ctx);
|
||||
}
|
||||
#endif // defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
|
||||
|
||||
TEST(InitEstimation, FitStumpColumnSplit) {
|
||||
Context ctx;
|
||||
auto constexpr kWorldSize{3};
|
||||
RunWithInMemoryCommunicator(kWorldSize, &TestFitStump, &ctx, DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user