Cleanup GPU Hist tests. (#10677)
* Cleanup GPU Hist tests. - Remove GPU Hist gradient sampling test. The same properties are tested in the gradient sampler test suite. - Move basic histogram tests into the histogram test suite. - Remove the header inclusion of the `updater_gpu_hist.cu` in tests.
This commit is contained in:
parent
6ccf116601
commit
cc3b56fc37
@ -1,12 +1,12 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2021-2022 by XGBoost Contributors
|
* Copyright 2021-2024, XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#ifndef XGBOOST_TASK_H_
|
#ifndef XGBOOST_TASK_H_
|
||||||
#define XGBOOST_TASK_H_
|
#define XGBOOST_TASK_H_
|
||||||
|
|
||||||
#include <xgboost/base.h>
|
#include <xgboost/base.h>
|
||||||
|
|
||||||
#include <cinttypes>
|
#include <cstdint> // for uint8_t
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
/*!
|
/*!
|
||||||
@ -23,7 +23,7 @@ namespace xgboost {
|
|||||||
*/
|
*/
|
||||||
struct ObjInfo {
|
struct ObjInfo {
|
||||||
// What kind of problem are we trying to solve
|
// What kind of problem are we trying to solve
|
||||||
enum Task : uint8_t {
|
enum Task : std::uint8_t {
|
||||||
kRegression = 0,
|
kRegression = 0,
|
||||||
kBinary = 1,
|
kBinary = 1,
|
||||||
kClassification = 2,
|
kClassification = 2,
|
||||||
|
|||||||
@ -45,9 +45,7 @@
|
|||||||
#include "xgboost/tree_model.h"
|
#include "xgboost/tree_model.h"
|
||||||
|
|
||||||
namespace xgboost::tree {
|
namespace xgboost::tree {
|
||||||
#if !defined(GTEST_TEST)
|
|
||||||
DMLC_REGISTRY_FILE_TAG(updater_gpu_hist);
|
DMLC_REGISTRY_FILE_TAG(updater_gpu_hist);
|
||||||
#endif // !defined(GTEST_TEST)
|
|
||||||
|
|
||||||
// Manage memory for a single GPU
|
// Manage memory for a single GPU
|
||||||
struct GPUHistMakerDevice {
|
struct GPUHistMakerDevice {
|
||||||
@ -831,13 +829,11 @@ class GPUHistMaker : public TreeUpdater {
|
|||||||
std::shared_ptr<common::ColumnSampler> column_sampler_;
|
std::shared_ptr<common::ColumnSampler> column_sampler_;
|
||||||
};
|
};
|
||||||
|
|
||||||
#if !defined(GTEST_TEST)
|
|
||||||
XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist")
|
XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist")
|
||||||
.describe("Grow tree with GPU.")
|
.describe("Grow tree with GPU.")
|
||||||
.set_body([](Context const* ctx, ObjInfo const* task) {
|
.set_body([](Context const* ctx, ObjInfo const* task) {
|
||||||
return new GPUHistMaker(ctx, task);
|
return new GPUHistMaker(ctx, task);
|
||||||
});
|
});
|
||||||
#endif // !defined(GTEST_TEST)
|
|
||||||
|
|
||||||
class GPUGlobalApproxMaker : public TreeUpdater {
|
class GPUGlobalApproxMaker : public TreeUpdater {
|
||||||
public:
|
public:
|
||||||
@ -960,11 +956,9 @@ class GPUGlobalApproxMaker : public TreeUpdater {
|
|||||||
common::Monitor monitor_;
|
common::Monitor monitor_;
|
||||||
};
|
};
|
||||||
|
|
||||||
#if !defined(GTEST_TEST)
|
|
||||||
XGBOOST_REGISTER_TREE_UPDATER(GPUApproxMaker, "grow_gpu_approx")
|
XGBOOST_REGISTER_TREE_UPDATER(GPUApproxMaker, "grow_gpu_approx")
|
||||||
.describe("Grow tree with GPU.")
|
.describe("Grow tree with GPU.")
|
||||||
.set_body([](Context const* ctx, ObjInfo const* task) {
|
.set_body([](Context const* ctx, ObjInfo const* task) {
|
||||||
return new GPUGlobalApproxMaker(ctx, task);
|
return new GPUGlobalApproxMaker(ctx, task);
|
||||||
});
|
});
|
||||||
#endif // !defined(GTEST_TEST)
|
|
||||||
} // namespace xgboost::tree
|
} // namespace xgboost::tree
|
||||||
|
|||||||
@ -10,9 +10,7 @@
|
|||||||
#include "../../filesystem.h" // dmlc::TemporaryDirectory
|
#include "../../filesystem.h" // dmlc::TemporaryDirectory
|
||||||
#include "../../helpers.h"
|
#include "../../helpers.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost::tree {
|
||||||
namespace tree {
|
|
||||||
|
|
||||||
void VerifySampling(size_t page_size,
|
void VerifySampling(size_t page_size,
|
||||||
float subsample,
|
float subsample,
|
||||||
int sampling_method,
|
int sampling_method,
|
||||||
@ -151,6 +149,4 @@ TEST(GradientBasedSampler, GradientBasedSamplingExternalMemory) {
|
|||||||
constexpr bool kFixedSizeSampling = false;
|
constexpr bool kFixedSizeSampling = false;
|
||||||
VerifySampling(kPageSize, kSubsample, kSamplingMethod, kFixedSizeSampling);
|
VerifySampling(kPageSize, kSubsample, kSamplingMethod, kFixedSizeSampling);
|
||||||
}
|
}
|
||||||
|
}; // namespace xgboost::tree
|
||||||
}; // namespace tree
|
|
||||||
}; // namespace xgboost
|
|
||||||
|
|||||||
@ -12,6 +12,7 @@
|
|||||||
#include "../../../../src/tree/param.h" // for TrainParam
|
#include "../../../../src/tree/param.h" // for TrainParam
|
||||||
#include "../../categorical_helpers.h" // for OneHotEncodeFeature
|
#include "../../categorical_helpers.h" // for OneHotEncodeFeature
|
||||||
#include "../../helpers.h"
|
#include "../../helpers.h"
|
||||||
|
#include "../../histogram_helpers.h" // for BuildEllpackPage
|
||||||
|
|
||||||
namespace xgboost::tree {
|
namespace xgboost::tree {
|
||||||
TEST(Histogram, DeviceHistogramStorage) {
|
TEST(Histogram, DeviceHistogramStorage) {
|
||||||
@ -54,6 +55,83 @@ TEST(Histogram, DeviceHistogramStorage) {
|
|||||||
EXPECT_ANY_THROW(histogram.AllocateHistograms(&ctx, {kNNodes + 1}););
|
EXPECT_ANY_THROW(histogram.AllocateHistograms(&ctx, {kNNodes + 1}););
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<GradientPairPrecise> GetHostHistGpair() {
|
||||||
|
// 24 bins, 3 bins for each feature (column).
|
||||||
|
std::vector<GradientPairPrecise> hist_gpair = {
|
||||||
|
{0.8314f, 0.7147f}, {1.7989f, 3.7312f}, {3.3846f, 3.4598f},
|
||||||
|
{2.9277f, 3.5886f}, {1.8429f, 2.4152f}, {1.2443f, 1.9019f},
|
||||||
|
{1.6380f, 2.9174f}, {1.5657f, 2.5107f}, {2.8111f, 2.4776f},
|
||||||
|
{2.1322f, 3.0651f}, {3.2927f, 3.8540f}, {0.5899f, 0.9866f},
|
||||||
|
{1.5185f, 1.6263f}, {2.0686f, 3.1844f}, {2.4278f, 3.0950f},
|
||||||
|
{1.5105f, 2.1403f}, {2.6922f, 4.2217f}, {1.8122f, 1.5437f},
|
||||||
|
{0.0000f, 0.0000f}, {4.3245f, 5.7955f}, {1.6903f, 2.1103f},
|
||||||
|
{2.4012f, 4.4754f}, {3.6136f, 3.4303f}, {0.0000f, 0.0000f}
|
||||||
|
};
|
||||||
|
return hist_gpair;
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestBuildHist(bool use_shared_memory_histograms) {
|
||||||
|
int const kNRows = 16, kNCols = 8;
|
||||||
|
Context ctx{MakeCUDACtx(0)};
|
||||||
|
|
||||||
|
TrainParam param;
|
||||||
|
Args args{
|
||||||
|
{"max_depth", "6"},
|
||||||
|
{"max_leaves", "0"},
|
||||||
|
};
|
||||||
|
param.Init(args);
|
||||||
|
|
||||||
|
auto page = BuildEllpackPage(&ctx, kNRows, kNCols);
|
||||||
|
BatchParam batch_param{};
|
||||||
|
|
||||||
|
xgboost::SimpleLCG gen;
|
||||||
|
xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f);
|
||||||
|
HostDeviceVector<GradientPair> gpair(kNRows);
|
||||||
|
for (auto& gp : gpair.HostVector()) {
|
||||||
|
float grad = dist(&gen);
|
||||||
|
float hess = dist(&gen);
|
||||||
|
gp = GradientPair{grad, hess};
|
||||||
|
}
|
||||||
|
gpair.SetDevice(ctx.Device());
|
||||||
|
|
||||||
|
auto row_partitioner = std::make_unique<RowPartitioner>();
|
||||||
|
row_partitioner->Reset(&ctx, kNRows, 0);
|
||||||
|
|
||||||
|
auto quantiser = std::make_unique<GradientQuantiser>(&ctx, gpair.ConstDeviceSpan(), MetaInfo());
|
||||||
|
auto shm_size = use_shared_memory_histograms ? dh::MaxSharedMemoryOptin(ctx.Ordinal()) : 0;
|
||||||
|
FeatureGroups feature_groups(page->Cuts(), page->is_dense, shm_size, sizeof(GradientPairInt64));
|
||||||
|
|
||||||
|
DeviceHistogramStorage hist;
|
||||||
|
hist.Init(ctx.Device(), page->Cuts().TotalBins());
|
||||||
|
hist.AllocateHistograms(&ctx, {0});
|
||||||
|
|
||||||
|
DeviceHistogramBuilder builder;
|
||||||
|
builder.Reset(&ctx, feature_groups.DeviceAccessor(ctx.Device()), !use_shared_memory_histograms);
|
||||||
|
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
|
||||||
|
feature_groups.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(),
|
||||||
|
row_partitioner->GetRows(0), hist.GetNodeHistogram(0), *quantiser);
|
||||||
|
|
||||||
|
auto node_histogram = hist.GetNodeHistogram(0);
|
||||||
|
|
||||||
|
std::vector<GradientPairInt64> h_result(node_histogram.size());
|
||||||
|
dh::CopyDeviceSpanToVector(&h_result, node_histogram);
|
||||||
|
|
||||||
|
std::vector<GradientPairPrecise> solution = GetHostHistGpair();
|
||||||
|
for (size_t i = 0; i < h_result.size(); ++i) {
|
||||||
|
auto result = quantiser->ToFloatingPoint(h_result[i]);
|
||||||
|
ASSERT_NEAR(result.GetGrad(), solution[i].GetGrad(), 0.01f);
|
||||||
|
ASSERT_NEAR(result.GetHess(), solution[i].GetHess(), 0.01f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Histogram, BuildHistGlobalMem) {
|
||||||
|
TestBuildHist(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Histogram, BuildHistSharedMem) {
|
||||||
|
TestBuildHist(true);
|
||||||
|
}
|
||||||
|
|
||||||
void TestDeterministicHistogram(bool is_dense, int shm_size, bool force_global) {
|
void TestDeterministicHistogram(bool is_dense, int shm_size, bool force_global) {
|
||||||
Context ctx = MakeCUDACtx(0);
|
Context ctx = MakeCUDACtx(0);
|
||||||
size_t constexpr kBins = 256, kCols = 120, kRows = 16384, kRounds = 16;
|
size_t constexpr kBins = 256, kCols = 120, kRows = 16384, kRounds = 16;
|
||||||
|
|||||||
@ -2,173 +2,26 @@
|
|||||||
* Copyright 2017-2024, XGBoost contributors
|
* Copyright 2017-2024, XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include <thrust/device_vector.h>
|
#include <xgboost/base.h> // for Args
|
||||||
#include <thrust/host_vector.h>
|
#include <xgboost/context.h> // for Context
|
||||||
#include <xgboost/base.h>
|
#include <xgboost/host_device_vector.h> // for HostDeviceVector
|
||||||
|
#include <xgboost/json.h> // for Jons
|
||||||
|
#include <xgboost/task.h> // for ObjInfo
|
||||||
|
#include <xgboost/tree_model.h> // for RegTree
|
||||||
|
#include <xgboost/tree_updater.h> // for TreeUpdater
|
||||||
|
|
||||||
#include <string>
|
#include <memory> // for unique_ptr
|
||||||
#include <vector>
|
#include <string> // for string
|
||||||
|
#include <vector> // for vector
|
||||||
|
|
||||||
#include "../../../src/common/common.h"
|
#include "../../../src/common/random.h" // for GlobalRandom
|
||||||
#include "../../../src/data/ellpack_page.cuh" // for EllpackPageImpl
|
#include "../../../src/data/ellpack_page.h" // for EllpackPage
|
||||||
#include "../../../src/data/ellpack_page.h" // for EllpackPage
|
#include "../../../src/tree/param.h" // for TrainParam
|
||||||
#include "../../../src/tree/param.h" // for TrainParam
|
#include "../collective/test_worker.h" // for BaseMGPUTest
|
||||||
#include "../../../src/tree/updater_gpu_hist.cu"
|
#include "../filesystem.h" // dmlc::TemporaryDirectory
|
||||||
#include "../collective/test_worker.h" // for BaseMGPUTest
|
|
||||||
#include "../filesystem.h" // dmlc::TemporaryDirectory
|
|
||||||
#include "../helpers.h"
|
#include "../helpers.h"
|
||||||
#include "../histogram_helpers.h"
|
|
||||||
#include "xgboost/context.h"
|
|
||||||
#include "xgboost/json.h"
|
|
||||||
|
|
||||||
namespace xgboost::tree {
|
namespace xgboost::tree {
|
||||||
std::vector<GradientPairPrecise> GetHostHistGpair() {
|
|
||||||
// 24 bins, 3 bins for each feature (column).
|
|
||||||
std::vector<GradientPairPrecise> hist_gpair = {
|
|
||||||
{0.8314f, 0.7147f}, {1.7989f, 3.7312f}, {3.3846f, 3.4598f},
|
|
||||||
{2.9277f, 3.5886f}, {1.8429f, 2.4152f}, {1.2443f, 1.9019f},
|
|
||||||
{1.6380f, 2.9174f}, {1.5657f, 2.5107f}, {2.8111f, 2.4776f},
|
|
||||||
{2.1322f, 3.0651f}, {3.2927f, 3.8540f}, {0.5899f, 0.9866f},
|
|
||||||
{1.5185f, 1.6263f}, {2.0686f, 3.1844f}, {2.4278f, 3.0950f},
|
|
||||||
{1.5105f, 2.1403f}, {2.6922f, 4.2217f}, {1.8122f, 1.5437f},
|
|
||||||
{0.0000f, 0.0000f}, {4.3245f, 5.7955f}, {1.6903f, 2.1103f},
|
|
||||||
{2.4012f, 4.4754f}, {3.6136f, 3.4303f}, {0.0000f, 0.0000f}
|
|
||||||
};
|
|
||||||
return hist_gpair;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename GradientSumT>
|
|
||||||
void TestBuildHist(bool use_shared_memory_histograms) {
|
|
||||||
int const kNRows = 16, kNCols = 8;
|
|
||||||
Context ctx{MakeCUDACtx(0)};
|
|
||||||
|
|
||||||
TrainParam param;
|
|
||||||
Args args{
|
|
||||||
{"max_depth", "6"},
|
|
||||||
{"max_leaves", "0"},
|
|
||||||
};
|
|
||||||
param.Init(args);
|
|
||||||
|
|
||||||
auto page = BuildEllpackPage(&ctx, kNRows, kNCols);
|
|
||||||
BatchParam batch_param{};
|
|
||||||
auto cs = std::make_shared<common::ColumnSampler>(0);
|
|
||||||
GPUHistMakerDevice maker(&ctx, /*is_external_memory=*/false, {}, kNRows, param, cs, kNCols,
|
|
||||||
batch_param, MetaInfo());
|
|
||||||
xgboost::SimpleLCG gen;
|
|
||||||
xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f);
|
|
||||||
HostDeviceVector<GradientPair> gpair(kNRows);
|
|
||||||
for (auto& gp : gpair.HostVector()) {
|
|
||||||
float grad = dist(&gen);
|
|
||||||
float hess = dist(&gen);
|
|
||||||
gp = GradientPair{grad, hess};
|
|
||||||
}
|
|
||||||
gpair.SetDevice(ctx.Device());
|
|
||||||
|
|
||||||
maker.row_partitioner = std::make_unique<RowPartitioner>();
|
|
||||||
maker.row_partitioner->Reset(&ctx, kNRows, 0);
|
|
||||||
|
|
||||||
maker.hist.Init(ctx.Device(), page->Cuts().TotalBins());
|
|
||||||
maker.hist.AllocateHistograms(&ctx, {0});
|
|
||||||
|
|
||||||
maker.gpair = gpair.DeviceSpan();
|
|
||||||
maker.quantiser = std::make_unique<GradientQuantiser>(&ctx, maker.gpair, MetaInfo());
|
|
||||||
maker.page = page.get();
|
|
||||||
|
|
||||||
maker.InitFeatureGroupsOnce();
|
|
||||||
|
|
||||||
DeviceHistogramBuilder builder;
|
|
||||||
builder.Reset(&ctx, maker.feature_groups->DeviceAccessor(ctx.Device()),
|
|
||||||
!use_shared_memory_histograms);
|
|
||||||
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
|
|
||||||
maker.feature_groups->DeviceAccessor(ctx.Device()), gpair.DeviceSpan(),
|
|
||||||
maker.row_partitioner->GetRows(0), maker.hist.GetNodeHistogram(0),
|
|
||||||
*maker.quantiser);
|
|
||||||
|
|
||||||
DeviceHistogramStorage<>& d_hist = maker.hist;
|
|
||||||
|
|
||||||
auto node_histogram = d_hist.GetNodeHistogram(0);
|
|
||||||
// d_hist.data stored in float, not gradient pair
|
|
||||||
thrust::host_vector<GradientPairInt64> h_result (node_histogram.size());
|
|
||||||
dh::safe_cuda(cudaMemcpy(h_result.data(), node_histogram.data(), node_histogram.size_bytes(),
|
|
||||||
cudaMemcpyDeviceToHost));
|
|
||||||
|
|
||||||
std::vector<GradientPairPrecise> solution = GetHostHistGpair();
|
|
||||||
for (size_t i = 0; i < h_result.size(); ++i) {
|
|
||||||
auto result = maker.quantiser->ToFloatingPoint(h_result[i]);
|
|
||||||
ASSERT_NEAR(result.GetGrad(), solution[i].GetGrad(), 0.01f);
|
|
||||||
ASSERT_NEAR(result.GetHess(), solution[i].GetHess(), 0.01f);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(GpuHist, BuildHistGlobalMem) {
|
|
||||||
TestBuildHist<GradientPairPrecise>(false);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(GpuHist, BuildHistSharedMem) {
|
|
||||||
TestBuildHist<GradientPairPrecise>(true);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::shared_ptr<detail::HistogramCutsWrapper> GetHostCutMatrix () {
|
|
||||||
auto cmat = std::make_shared<detail::HistogramCutsWrapper>();
|
|
||||||
cmat->SetPtrs({0, 3, 6, 9, 12, 15, 18, 21, 24});
|
|
||||||
cmat->SetMins({0.1f, 0.2f, 0.3f, 0.1f, 0.2f, 0.3f, 0.2f, 0.2f});
|
|
||||||
// 24 cut fields, 3 cut fields for each feature (column).
|
|
||||||
// Each row of the cut represents the cuts for a data column.
|
|
||||||
cmat->SetValues({0.30f, 0.67f, 1.64f,
|
|
||||||
0.32f, 0.77f, 1.95f,
|
|
||||||
0.29f, 0.70f, 1.80f,
|
|
||||||
0.32f, 0.75f, 1.85f,
|
|
||||||
0.18f, 0.59f, 1.69f,
|
|
||||||
0.25f, 0.74f, 2.00f,
|
|
||||||
0.26f, 0.74f, 1.98f,
|
|
||||||
0.26f, 0.71f, 1.83f});
|
|
||||||
return cmat;
|
|
||||||
}
|
|
||||||
|
|
||||||
void TestHistogramIndexImpl() {
|
|
||||||
// Test if the compressed histogram index matches when using a sparse
|
|
||||||
// dmatrix with and without using external memory
|
|
||||||
|
|
||||||
int constexpr kNRows = 1000, kNCols = 10;
|
|
||||||
|
|
||||||
// Build 2 matrices and build a histogram maker with that
|
|
||||||
Context ctx(MakeCUDACtx(0));
|
|
||||||
ObjInfo task{ObjInfo::kRegression};
|
|
||||||
tree::GPUHistMaker hist_maker{&ctx, &task}, hist_maker_ext{&ctx, &task};
|
|
||||||
std::unique_ptr<DMatrix> hist_maker_dmat(
|
|
||||||
CreateSparsePageDMatrixWithRC(kNRows, kNCols, 0, true));
|
|
||||||
|
|
||||||
dmlc::TemporaryDirectory tempdir;
|
|
||||||
std::unique_ptr<DMatrix> hist_maker_ext_dmat(
|
|
||||||
CreateSparsePageDMatrixWithRC(kNRows, kNCols, 128UL, true, tempdir));
|
|
||||||
|
|
||||||
Args training_params = {{"max_depth", "10"}, {"max_leaves", "0"}};
|
|
||||||
TrainParam param;
|
|
||||||
param.UpdateAllowUnknown(training_params);
|
|
||||||
|
|
||||||
hist_maker.Configure(training_params);
|
|
||||||
hist_maker.InitDataOnce(¶m, hist_maker_dmat.get());
|
|
||||||
hist_maker_ext.Configure(training_params);
|
|
||||||
hist_maker_ext.InitDataOnce(¶m, hist_maker_ext_dmat.get());
|
|
||||||
|
|
||||||
// Extract the device maker from the histogram makers and from that its compressed
|
|
||||||
// histogram index
|
|
||||||
const auto &maker = hist_maker.maker;
|
|
||||||
auto grad = GenerateRandomGradients(kNRows);
|
|
||||||
grad.SetDevice(DeviceOrd::CUDA(0));
|
|
||||||
maker->Reset(&grad, hist_maker_dmat.get(), kNCols);
|
|
||||||
|
|
||||||
const auto &maker_ext = hist_maker_ext.maker;
|
|
||||||
maker_ext->Reset(&grad, hist_maker_ext_dmat.get(), kNCols);
|
|
||||||
|
|
||||||
ASSERT_EQ(maker->page->Cuts().TotalBins(), maker_ext->page->Cuts().TotalBins());
|
|
||||||
ASSERT_EQ(maker->page->gidx_buffer.size(), maker_ext->page->gidx_buffer.size());
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(GpuHist, TestHistogramIndex) {
|
|
||||||
TestHistogramIndexImpl();
|
|
||||||
}
|
|
||||||
|
|
||||||
void UpdateTree(Context const* ctx, linalg::Matrix<GradientPair>* gpair, DMatrix* dmat,
|
void UpdateTree(Context const* ctx, linalg::Matrix<GradientPair>* gpair, DMatrix* dmat,
|
||||||
size_t gpu_page_size, RegTree* tree, HostDeviceVector<bst_float>* preds,
|
size_t gpu_page_size, RegTree* tree, HostDeviceVector<bst_float>* preds,
|
||||||
float subsample = 1.0f, const std::string& sampling_method = "uniform",
|
float subsample = 1.0f, const std::string& sampling_method = "uniform",
|
||||||
@ -200,14 +53,14 @@ void UpdateTree(Context const* ctx, linalg::Matrix<GradientPair>* gpair, DMatrix
|
|||||||
param.UpdateAllowUnknown(args);
|
param.UpdateAllowUnknown(args);
|
||||||
|
|
||||||
ObjInfo task{ObjInfo::kRegression};
|
ObjInfo task{ObjInfo::kRegression};
|
||||||
tree::GPUHistMaker hist_maker{ctx, &task};
|
std::unique_ptr<TreeUpdater> hist_maker{TreeUpdater::Create("grow_gpu_hist", ctx, &task)};
|
||||||
hist_maker.Configure(Args{});
|
hist_maker->Configure(Args{});
|
||||||
|
|
||||||
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
||||||
hist_maker.Update(¶m, gpair, dmat, common::Span<HostDeviceVector<bst_node_t>>{position},
|
hist_maker->Update(¶m, gpair, dmat, common::Span<HostDeviceVector<bst_node_t>>{position},
|
||||||
{tree});
|
{tree});
|
||||||
auto cache = linalg::MakeTensorView(ctx, preds->DeviceSpan(), preds->Size(), 1);
|
auto cache = linalg::MakeTensorView(ctx, preds->DeviceSpan(), preds->Size(), 1);
|
||||||
hist_maker.UpdatePredictionCache(dmat, cache);
|
hist_maker->UpdatePredictionCache(dmat, cache);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(GpuHist, UniformSampling) {
|
TEST(GpuHist, UniformSampling) {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user