More explicit sharding methods for device memory (#4396)

* Rename the Reshard method to Shard

* Add a new Reshard method for sharding a vector that's already sharded
This commit is contained in:
Rong Ou
2019-04-30 16:47:23 -07:00
committed by Rory Mitchell
parent 797ba8e72d
commit eaab364a63
12 changed files with 154 additions and 77 deletions

View File

@@ -39,7 +39,7 @@ struct SoftmaxMultiClassParam : public dmlc::Parameter<SoftmaxMultiClassParam> {
.describe("gpu to use for objective function evaluation");
}
};
// TODO(trivialfis): Currently the resharding in softmax is less than ideal
// TODO(trivialfis): Currently the sharding in softmax is less than ideal
// due to repeated copying data between CPU and GPUs. Maybe we just use single
// GPU?
class SoftmaxMultiClassObj : public ObjFunction {
@@ -63,11 +63,11 @@ class SoftmaxMultiClassObj : public ObjFunction {
const int nclass = param_.num_class;
const auto ndata = static_cast<int64_t>(preds.Size() / nclass);
out_gpair->Reshard(GPUDistribution::Granular(devices_, nclass));
info.labels_.Reshard(GPUDistribution::Block(devices_));
info.weights_.Reshard(GPUDistribution::Block(devices_));
preds.Reshard(GPUDistribution::Granular(devices_, nclass));
label_correct_.Reshard(GPUDistribution::Block(devices_));
out_gpair->Shard(GPUDistribution::Granular(devices_, nclass));
info.labels_.Shard(GPUDistribution::Block(devices_));
info.weights_.Shard(GPUDistribution::Block(devices_));
preds.Shard(GPUDistribution::Granular(devices_, nclass));
label_correct_.Shard(GPUDistribution::Block(devices_));
out_gpair->Resize(preds.Size());
label_correct_.Fill(1);
@@ -136,8 +136,8 @@ class SoftmaxMultiClassObj : public ObjFunction {
common::Range{0, ndata}, GPUDistribution::Granular(devices_, nclass))
.Eval(io_preds);
} else {
io_preds->Reshard(GPUDistribution::Granular(devices_, nclass));
max_preds_.Reshard(GPUDistribution::Block(devices_));
io_preds->Shard(GPUDistribution::Granular(devices_, nclass));
max_preds_.Shard(GPUDistribution::Block(devices_));
common::Transform<>::Init(
[=] XGBOOST_DEVICE(size_t _idx,
common::Span<const bst_float> _preds,