Overload device memory allocation (#4532)
* Group source files, include headers in source files * Overload device memory allocation
This commit is contained in:
@@ -67,3 +67,6 @@ if (USE_OPENMP)
|
||||
target_compile_options(testxgboost PRIVATE $<$<COMPILE_LANGUAGE:CXX>:${OpenMP_CXX_FLAGS}>)
|
||||
endif (USE_OPENMP)
|
||||
set_output_directory(testxgboost ${PROJECT_BINARY_DIR})
|
||||
|
||||
# This grouping organises source files nicely in visual studio
|
||||
auto_source_group("${TEST_SOURCES}")
|
||||
|
||||
@@ -338,8 +338,6 @@ TEST(GpuHist, EvaluateSplits) {
|
||||
}
|
||||
|
||||
TEST(GpuHist, ApplySplit) {
|
||||
GPUHistMakerSpecialised<GradientPairPrecise> hist_maker =
|
||||
GPUHistMakerSpecialised<GradientPairPrecise>();
|
||||
int constexpr kNId = 0;
|
||||
int constexpr kNRows = 16;
|
||||
int constexpr kNCols = 8;
|
||||
@@ -353,11 +351,9 @@ TEST(GpuHist, ApplySplit) {
|
||||
param.monotone_constraints.emplace_back(0);
|
||||
}
|
||||
|
||||
hist_maker.shards_.resize(1);
|
||||
hist_maker.shards_[0].reset(
|
||||
new DeviceShard<GradientPairPrecise>(0, 0, 0, kNRows, param, kNCols));
|
||||
std::unique_ptr<DeviceShard<GradientPairPrecise>> shard{
|
||||
new DeviceShard<GradientPairPrecise>(0, 0, 0, kNRows, param, kNCols)};
|
||||
|
||||
auto& shard = hist_maker.shards_.at(0);
|
||||
shard->ridx_segments.resize(3); // 3 nodes.
|
||||
shard->node_sum_gradients.resize(3);
|
||||
|
||||
@@ -368,8 +364,6 @@ TEST(GpuHist, ApplySplit) {
|
||||
thrust::sequence(
|
||||
thrust::device_pointer_cast(shard->ridx.Current()),
|
||||
thrust::device_pointer_cast(shard->ridx.Current() + shard->ridx.Size()));
|
||||
// Initialize GPUHistMaker
|
||||
hist_maker.param_ = param;
|
||||
RegTree tree;
|
||||
|
||||
DeviceSplitCandidate candidate;
|
||||
@@ -382,7 +376,6 @@ TEST(GpuHist, ApplySplit) {
|
||||
|
||||
// Used to get bin_id in update position.
|
||||
common::HistCutMatrix cmat = GetHostCutMatrix();
|
||||
hist_maker.hmat_ = cmat;
|
||||
|
||||
MetaInfo info;
|
||||
info.num_row_ = kNRows;
|
||||
@@ -421,7 +414,6 @@ TEST(GpuHist, ApplySplit) {
|
||||
shard->ellpack_matrix.gidx_iter = common::CompressedIterator<uint32_t>(
|
||||
shard->gidx_buffer.data(), num_symbols);
|
||||
|
||||
hist_maker.info_ = &info;
|
||||
shard->ApplySplit(candidate_entry, &tree);
|
||||
shard->UpdatePosition(candidate_entry.nid, tree[candidate_entry.nid]);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user