More explicit sharding methods for device memory (#4396)
* Rename the Reshard method to Shard
* Add a new Reshard method for sharding a vector that is already sharded
This commit is contained in:
@@ -111,9 +111,9 @@ class ElementWiseMetricsReduction {
|
||||
allocators_.clear();
|
||||
allocators_.resize(devices.Size());
|
||||
}
|
||||
preds.Reshard(devices);
|
||||
labels.Reshard(devices);
|
||||
weights.Reshard(devices);
|
||||
preds.Shard(devices);
|
||||
labels.Shard(devices);
|
||||
weights.Shard(devices);
|
||||
std::vector<PackedReduceResult> res_per_device(devices.Size());
|
||||
|
||||
#pragma omp parallel for schedule(static, 1) if (devices.Size() > 1)
|
||||
|
||||
@@ -134,9 +134,9 @@ class MultiClassMetricsReduction {
|
||||
allocators_.clear();
|
||||
allocators_.resize(devices.Size());
|
||||
}
|
||||
preds.Reshard(GPUDistribution::Granular(devices, n_class));
|
||||
labels.Reshard(devices);
|
||||
weights.Reshard(devices);
|
||||
preds.Shard(GPUDistribution::Granular(devices, n_class));
|
||||
labels.Shard(devices);
|
||||
weights.Shard(devices);
|
||||
std::vector<PackedReduceResult> res_per_device(devices.Size());
|
||||
|
||||
#pragma omp parallel for schedule(static, 1) if (devices.Size() > 1)
|
||||
|
||||
Reference in New Issue
Block a user