Fix GPU RF (#6755)

* Fix sampling.
This commit is contained in:
Jiaming Yuan
2021-03-17 06:23:35 +08:00
committed by GitHub
parent 1a73a28511
commit 4f75f514ce
4 changed files with 26 additions and 9 deletions

View File

@@ -503,12 +503,15 @@ TEST(GpuHist, ExternalMemoryWithSampling) {
auto gpair = GenerateRandomGradients(kRows);
// Build a tree using the in-memory DMatrix.
auto rng = common::GlobalRandom();
RegTree tree;
HostDeviceVector<bst_float> preds(kRows, 0.0, 0);
UpdateTree(&gpair, dmat.get(), 0, &tree, &preds, kSubsample, kSamplingMethod,
kRows);
// Build another tree using multiple ELLPACK pages.
common::GlobalRandom() = rng;
RegTree tree_ext;
HostDeviceVector<bst_float> preds_ext(kRows, 0.0, 0);
UpdateTree(&gpair, dmat_ext.get(), kPageSize, &tree_ext, &preds_ext,
@@ -518,7 +521,7 @@ TEST(GpuHist, ExternalMemoryWithSampling) {
auto preds_h = preds.ConstHostVector();
auto preds_ext_h = preds_ext.ConstHostVector();
for (int i = 0; i < kRows; i++) {
EXPECT_NEAR(preds_h[i], preds_ext_h[i], 2e-3);
EXPECT_NEAR(preds_h[i], preds_ext_h[i], 1e-3);
}
}