Single precision histograms on GPU (#3965)

* Allow single precision histogram summation in gpu_hist

* Add python test, reduce run-time of gpu_hist tests

* Update documentation
This commit is contained in:
Rory Mitchell
2018-12-10 10:55:30 +13:00
committed by GitHub
parent 9af6b689d6
commit 93f9ce9ef9
10 changed files with 351 additions and 212 deletions

View File

@@ -116,19 +116,19 @@ struct GPUSketcher {
n_rows_(row_end - row_begin), param_(std::move(param)) {
}
void Init(const SparsePage& row_batch, const MetaInfo& info) {
void Init(const SparsePage& row_batch, const MetaInfo& info, int gpu_batch_nrows) {
num_cols_ = info.num_col_;
has_weights_ = info.weights_.Size() > 0;
// find the batch size
if (param_.gpu_batch_nrows == 0) {
if (gpu_batch_nrows == 0) {
// By default, use no more than 1/16th of GPU memory
gpu_batch_nrows_ = dh::TotalMemory(device_) /
(16 * num_cols_ * sizeof(Entry));
} else if (param_.gpu_batch_nrows == -1) {
} else if (gpu_batch_nrows == -1) {
gpu_batch_nrows_ = n_rows_;
} else {
gpu_batch_nrows_ = param_.gpu_batch_nrows;
gpu_batch_nrows_ = gpu_batch_nrows;
}
if (gpu_batch_nrows_ > n_rows_) {
gpu_batch_nrows_ = n_rows_;
@@ -346,7 +346,8 @@ struct GPUSketcher {
}
};
void Sketch(const SparsePage& batch, const MetaInfo& info, HistCutMatrix* hmat) {
void Sketch(const SparsePage& batch, const MetaInfo& info,
HistCutMatrix* hmat, int gpu_batch_nrows) {
// create device shards
shards_.resize(dist_.Devices().Size());
dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr<DeviceShard>& shard) {
@@ -358,10 +359,11 @@ struct GPUSketcher {
});
// compute sketches for each shard
dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
shard->Init(batch, info);
shard->Sketch(batch, info);
});
dh::ExecuteIndexShards(&shards_,
[&](int idx, std::unique_ptr<DeviceShard>& shard) {
shard->Init(batch, info, gpu_batch_nrows);
shard->Sketch(batch, info);
});
// merge the sketches from all shards
// TODO(canonizer): do it in a tree-like reduction
@@ -390,9 +392,9 @@ struct GPUSketcher {
void DeviceSketch
(const SparsePage& batch, const MetaInfo& info,
const tree::TrainParam& param, HistCutMatrix* hmat) {
const tree::TrainParam& param, HistCutMatrix* hmat, int gpu_batch_nrows) {
GPUSketcher sketcher(param, info.num_row_);
sketcher.Sketch(batch, info, hmat);
sketcher.Sketch(batch, info, hmat, gpu_batch_nrows);
}
} // namespace common