Refactor tests with data generator. (#5439)
This commit is contained in:
@@ -313,22 +313,21 @@ TEST(GpuHist, MinSplitLoss) {
|
||||
constexpr size_t kRows = 32;
|
||||
constexpr size_t kCols = 16;
|
||||
constexpr float kSparsity = 0.6;
|
||||
auto dmat = CreateDMatrix(kRows, kCols, kSparsity, 3);
|
||||
auto dmat = RandomDataGenerator(kRows, kCols, kSparsity).Seed(3).GenerateDMatix();
|
||||
auto gpair = GenerateRandomGradients(kRows);
|
||||
|
||||
{
|
||||
int32_t n_nodes = TestMinSplitLoss((*dmat).get(), 0.01, &gpair);
|
||||
int32_t n_nodes = TestMinSplitLoss(dmat.get(), 0.01, &gpair);
|
||||
// This is not strictly verified, meaning the numeber `2` is whatever GPU_Hist retured
|
||||
// when writing this test, and only used for testing larger gamma (below) does prevent
|
||||
// building tree.
|
||||
ASSERT_EQ(n_nodes, 2);
|
||||
}
|
||||
{
|
||||
int32_t n_nodes = TestMinSplitLoss((*dmat).get(), 100.0, &gpair);
|
||||
int32_t n_nodes = TestMinSplitLoss(dmat.get(), 100.0, &gpair);
|
||||
// No new nodes with gamma == 100.
|
||||
ASSERT_EQ(n_nodes, static_cast<decltype(n_nodes)>(0));
|
||||
}
|
||||
delete dmat;
|
||||
}
|
||||
|
||||
void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
|
||||
|
||||
Reference in New Issue
Block a user