Implement device-to-device reshard. (#3721)

* Force clearing device memory before Reshard.
* Remove the calculation of row_segments for gpu_hist and gpu_sketch.
* Guard against changing device.
This commit is contained in:
trivialfis
2018-09-28 17:40:23 +12:00
committed by Rory Mitchell
parent 0b7fd74138
commit 5a7f7e7d49
11 changed files with 179 additions and 96 deletions

View File

@@ -347,15 +347,13 @@ struct GPUSketcher {
};
void Sketch(const SparsePage& batch, const MetaInfo& info, HistCutMatrix* hmat) {
// partition input matrix into row segments
std::vector<size_t> row_segments;
dh::RowSegments(info.num_row_, devices_.Size(), &row_segments);
// create device shards
shards_.resize(devices_.Size());
shards_.resize(dist_.Devices().Size());
dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr<DeviceShard>& shard) {
size_t start = dist_.ShardStart(info.num_row_, i);
size_t size = dist_.ShardSize(info.num_row_, i);
shard = std::unique_ptr<DeviceShard>
(new DeviceShard(devices_[i], row_segments[i], row_segments[i + 1], param_));
(new DeviceShard(dist_.Devices()[i], start, start + size, param_));
});
// compute sketches for each shard
@@ -381,12 +379,13 @@ struct GPUSketcher {
}
GPUSketcher(tree::TrainParam param, size_t n_rows) : param_(std::move(param)) {
devices_ = GPUSet::All(param_.n_gpus, n_rows).Normalised(param_.gpu_id);
dist_ = GPUDistribution::Block(GPUSet::All(param_.n_gpus, n_rows).
Normalised(param_.gpu_id));
}
std::vector<std::unique_ptr<DeviceShard>> shards_;
tree::TrainParam param_;
GPUSet devices_;
GPUDistribution dist_;
};
void DeviceSketch