Overload device memory allocation (#4532)

* Group source files, include headers in source files

* Overload device memory allocation
This commit is contained in:
Rory Mitchell
2019-06-10 11:35:13 +12:00
committed by GitHub
parent da21ac0cc2
commit 9683fd433e
9 changed files with 140 additions and 49 deletions

View File

@@ -67,3 +67,6 @@ if (USE_OPENMP)
target_compile_options(testxgboost PRIVATE $<$<COMPILE_LANGUAGE:CXX>:${OpenMP_CXX_FLAGS}>)
endif (USE_OPENMP)
set_output_directory(testxgboost ${PROJECT_BINARY_DIR})
# This grouping organises source files nicely in visual studio
auto_source_group("${TEST_SOURCES}")

View File

@@ -338,8 +338,6 @@ TEST(GpuHist, EvaluateSplits) {
}
TEST(GpuHist, ApplySplit) {
GPUHistMakerSpecialised<GradientPairPrecise> hist_maker =
GPUHistMakerSpecialised<GradientPairPrecise>();
int constexpr kNId = 0;
int constexpr kNRows = 16;
int constexpr kNCols = 8;
@@ -353,11 +351,9 @@ TEST(GpuHist, ApplySplit) {
param.monotone_constraints.emplace_back(0);
}
hist_maker.shards_.resize(1);
hist_maker.shards_[0].reset(
new DeviceShard<GradientPairPrecise>(0, 0, 0, kNRows, param, kNCols));
std::unique_ptr<DeviceShard<GradientPairPrecise>> shard{
new DeviceShard<GradientPairPrecise>(0, 0, 0, kNRows, param, kNCols)};
auto& shard = hist_maker.shards_.at(0);
shard->ridx_segments.resize(3); // 3 nodes.
shard->node_sum_gradients.resize(3);
@@ -368,8 +364,6 @@ TEST(GpuHist, ApplySplit) {
thrust::sequence(
thrust::device_pointer_cast(shard->ridx.Current()),
thrust::device_pointer_cast(shard->ridx.Current() + shard->ridx.Size()));
// Initialize GPUHistMaker
hist_maker.param_ = param;
RegTree tree;
DeviceSplitCandidate candidate;
@@ -382,7 +376,6 @@ TEST(GpuHist, ApplySplit) {
// Used to get bin_id in update position.
common::HistCutMatrix cmat = GetHostCutMatrix();
hist_maker.hmat_ = cmat;
MetaInfo info;
info.num_row_ = kNRows;
@@ -421,7 +414,6 @@ TEST(GpuHist, ApplySplit) {
shard->ellpack_matrix.gidx_iter = common::CompressedIterator<uint32_t>(
shard->gidx_buffer.data(), num_symbols);
hist_maker.info_ = &info;
shard->ApplySplit(candidate_entry, &tree);
shard->UpdatePosition(candidate_entry.nid, tree[candidate_entry.nid]);