Fix gpu_hist apply_split test. (#4158)
This commit is contained in:
parent
2e618af743
commit
e1240413c9
@ -328,7 +328,7 @@ class RegTree {
|
|||||||
nodes_[node.LeftChild()].SetParent(nid, true);
|
nodes_[node.LeftChild()].SetParent(nid, true);
|
||||||
nodes_[node.RightChild()].SetParent(nid, false);
|
nodes_[node.RightChild()].SetParent(nid, false);
|
||||||
node.SetSplit(split_index, split_value,
|
node.SetSplit(split_index, split_value,
|
||||||
default_left);
|
default_left);
|
||||||
// mark right child as 0, to indicate fresh leaf
|
// mark right child as 0, to indicate fresh leaf
|
||||||
nodes_[pleft].SetLeaf(left_leaf_weight, 0);
|
nodes_[pleft].SetLeaf(left_leaf_weight, 0);
|
||||||
nodes_[pright].SetLeaf(right_leaf_weight, 0);
|
nodes_[pright].SetLeaf(right_leaf_weight, 0);
|
||||||
|
|||||||
@ -90,6 +90,7 @@ TEST(gpu_predictor, Test) {
|
|||||||
delete dmat;
|
delete dmat;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_NCCL)
|
||||||
// Test whether pickling preserves predictor parameters
|
// Test whether pickling preserves predictor parameters
|
||||||
TEST(gpu_predictor, MGPU_PicklingTest) {
|
TEST(gpu_predictor, MGPU_PicklingTest) {
|
||||||
int ngpu;
|
int ngpu;
|
||||||
@ -163,7 +164,9 @@ TEST(gpu_predictor, MGPU_PicklingTest) {
|
|||||||
|
|
||||||
CheckCAPICall(XGBoosterFree(bst2));
|
CheckCAPICall(XGBoosterFree(bst2));
|
||||||
}
|
}
|
||||||
|
#endif // defined(XGBOOST_USE_NCCL)
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_NCCL)
|
||||||
// multi-GPU predictor test
|
// multi-GPU predictor test
|
||||||
TEST(gpu_predictor, MGPU_Test) {
|
TEST(gpu_predictor, MGPU_Test) {
|
||||||
std::unique_ptr<Predictor> gpu_predictor =
|
std::unique_ptr<Predictor> gpu_predictor =
|
||||||
@ -202,6 +205,6 @@ TEST(gpu_predictor, MGPU_Test) {
|
|||||||
delete dmat;
|
delete dmat;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#endif // defined(XGBOOST_USE_NCCL)
|
||||||
} // namespace predictor
|
} // namespace predictor
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
@ -5,6 +5,9 @@
|
|||||||
#include <thrust/device_vector.h>
|
#include <thrust/device_vector.h>
|
||||||
#include <xgboost/base.h>
|
#include <xgboost/base.h>
|
||||||
#include <random>
|
#include <random>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "../helpers.h"
|
#include "../helpers.h"
|
||||||
#include "gtest/gtest.h"
|
#include "gtest/gtest.h"
|
||||||
|
|
||||||
@ -315,6 +318,8 @@ TEST(GpuHist, ApplySplit) {
|
|||||||
int constexpr n_cols = 8;
|
int constexpr n_cols = 8;
|
||||||
|
|
||||||
TrainParam param;
|
TrainParam param;
|
||||||
|
std::vector<std::pair<std::string, std::string>> args = {};
|
||||||
|
param.InitAllowUnknown(args);
|
||||||
|
|
||||||
// Initialize shard
|
// Initialize shard
|
||||||
for (size_t i = 0; i < n_cols; ++i) {
|
for (size_t i = 0; i < n_cols; ++i) {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user