More explict sharding methods for device memory (#4396)

* Rename the Reshard method to Shard

* Add a new Reshard method for sharding a vector that's already sharded
This commit is contained in:
Rong Ou 2019-04-30 16:47:23 -07:00 committed by Rory Mitchell
parent 797ba8e72d
commit eaab364a63
12 changed files with 154 additions and 77 deletions

View File

@ -154,10 +154,13 @@ bool HostDeviceVector<T>::DeviceCanAccess(int device, GPUAccess access) const {
}
template <typename T>
void HostDeviceVector<T>::Reshard(const GPUDistribution& distribution) const { }
void HostDeviceVector<T>::Shard(const GPUDistribution& distribution) const { }
template <typename T>
void HostDeviceVector<T>::Reshard(GPUSet devices) const { }
void HostDeviceVector<T>::Shard(GPUSet devices) const { }
template <typename T>
void Reshard(const GPUDistribution &distribution) { }
// explicit instantiations are required, as HostDeviceVector isn't header-only
template class HostDeviceVector<bst_float>;

View File

@ -318,7 +318,7 @@ struct HostDeviceVectorImpl {
// Data is on device;
if (distribution_ != other->distribution_) {
distribution_ = GPUDistribution();
Reshard(other->Distribution());
Shard(other->Distribution());
size_d_ = other->size_d_;
}
dh::ExecuteIndexShards(&shards_, [&](int i, DeviceShard& shard) {
@ -358,19 +358,24 @@ struct HostDeviceVectorImpl {
return data_h_;
}
void Reshard(const GPUDistribution& distribution) {
void Shard(const GPUDistribution& distribution) {
if (distribution_ == distribution) { return; }
CHECK(distribution_.IsEmpty() || distribution.IsEmpty());
if (distribution.IsEmpty()) {
LazySyncHost(GPUAccess::kWrite);
}
CHECK(distribution_.IsEmpty());
distribution_ = distribution;
InitShards();
}
void Reshard(GPUSet new_devices) {
void Shard(GPUSet new_devices) {
if (distribution_.Devices() == new_devices) { return; }
Reshard(GPUDistribution::Block(new_devices));
Shard(GPUDistribution::Block(new_devices));
}
void Reshard(const GPUDistribution &distribution) {
if (distribution_ == distribution) { return; }
LazySyncHost(GPUAccess::kWrite);
distribution_ = distribution;
shards_.clear();
InitShards();
}
void Resize(size_t new_size, T v) {
@ -586,12 +591,17 @@ bool HostDeviceVector<T>::DeviceCanAccess(int device, GPUAccess access) const {
}
template <typename T>
void HostDeviceVector<T>::Reshard(GPUSet new_devices) const {
impl_->Reshard(new_devices);
void HostDeviceVector<T>::Shard(GPUSet new_devices) const {
impl_->Shard(new_devices);
}
template <typename T>
void HostDeviceVector<T>::Reshard(const GPUDistribution& distribution) const {
void HostDeviceVector<T>::Shard(const GPUDistribution &distribution) const {
impl_->Shard(distribution);
}
template <typename T>
void HostDeviceVector<T>::Reshard(const GPUDistribution &distribution) {
impl_->Reshard(distribution);
}

View File

@ -14,7 +14,7 @@
* Initialization/Allocation:<br/>
* One can choose to initialize the vector on CPU or GPU during constructor.
* (use the 'devices' argument) Or, can choose to use the 'Resize' method to
* allocate/resize memory explicitly, and use the 'Reshard' method
* allocate/resize memory explicitly, and use the 'Shard' method
* to specify the devices.
*
* Accessing underlying data:<br/>
@ -98,6 +98,8 @@ class GPUDistribution {
offsets_(std::move(offsets)) {}
public:
static GPUDistribution Empty() { return GPUDistribution(); }
static GPUDistribution Block(GPUSet devices) { return GPUDistribution(devices); }
static GPUDistribution Overlap(GPUSet devices, int overlap) {
@ -250,11 +252,15 @@ class HostDeviceVector {
/*!
* \brief Specify memory distribution.
*
* If GPUSet::Empty() is used, all data will be drawn back to CPU.
*/
void Reshard(const GPUDistribution& distribution) const;
void Reshard(GPUSet devices) const;
void Shard(const GPUDistribution &distribution) const;
void Shard(GPUSet devices) const;
/*!
* \brief Change memory distribution.
*/
void Reshard(const GPUDistribution &distribution);
void Resize(size_t new_size, T v = T());
private:

View File

@ -57,13 +57,13 @@ class Transform {
template <typename Functor>
struct Evaluator {
public:
Evaluator(Functor func, Range range, GPUSet devices, bool reshard) :
Evaluator(Functor func, Range range, GPUSet devices, bool shard) :
func_(func), range_{std::move(range)},
reshard_{reshard},
shard_{shard},
distribution_{std::move(GPUDistribution::Block(devices))} {}
Evaluator(Functor func, Range range, GPUDistribution dist,
bool reshard) :
func_(func), range_{std::move(range)}, reshard_{reshard},
bool shard) :
func_(func), range_{std::move(range)}, shard_{shard},
distribution_{std::move(dist)} {}
/*!
@ -106,25 +106,25 @@ class Transform {
return Span<T const> {_vec->ConstHostPointer(),
static_cast<typename Span<T>::index_type>(_vec->Size())};
}
// Recursive unpack for Reshard.
// Recursive unpack for Shard.
template <typename T>
void UnpackReshard(GPUDistribution dist, const HostDeviceVector<T>* vector) const {
vector->Reshard(dist);
void UnpackShard(GPUDistribution dist, const HostDeviceVector<T> *vector) const {
vector->Shard(dist);
}
template <typename Head, typename... Rest>
void UnpackReshard(GPUDistribution dist,
void UnpackShard(GPUDistribution dist,
const HostDeviceVector<Head> *_vector,
const HostDeviceVector<Rest> *... _vectors) const {
_vector->Reshard(dist);
UnpackReshard(dist, _vectors...);
_vector->Shard(dist);
UnpackShard(dist, _vectors...);
}
#if defined(__CUDACC__)
template <typename std::enable_if<CompiledWithCuda>::type* = nullptr,
typename... HDV>
void LaunchCUDA(Functor _func, HDV*... _vectors) const {
if (reshard_)
UnpackReshard(distribution_, _vectors...);
if (shard_)
UnpackShard(distribution_, _vectors...);
GPUSet devices = distribution_.Devices();
size_t range_size = *range_.end() - *range_.begin();
@ -170,8 +170,8 @@ class Transform {
Functor func_;
/*! \brief Range object specifying parallel threads index range. */
Range range_;
/*! \brief Whether resharding for vectors is required. */
bool reshard_;
/*! \brief Whether sharding for vectors is required. */
bool shard_;
GPUDistribution distribution_;
};
@ -187,19 +187,19 @@ class Transform {
* \param range Range object specifying parallel threads index range.
* \param devices GPUSet specifying GPUs to use, when compiling for CPU,
* this should be GPUSet::Empty().
* \param reshard Whether Reshard for HostDeviceVector is needed.
* \param shard Whether Shard for HostDeviceVector is needed.
*/
template <typename Functor>
static Evaluator<Functor> Init(Functor func, Range const range,
GPUSet const devices,
bool const reshard = true) {
return Evaluator<Functor> {func, std::move(range), std::move(devices), reshard};
bool const shard = true) {
return Evaluator<Functor> {func, std::move(range), std::move(devices), shard};
}
template <typename Functor>
static Evaluator<Functor> Init(Functor func, Range const range,
GPUDistribution const dist,
bool const reshard = true) {
return Evaluator<Functor> {func, std::move(range), std::move(dist), reshard};
bool const shard = true) {
return Evaluator<Functor> {func, std::move(range), std::move(dist), shard};
}
};

View File

@ -111,9 +111,9 @@ class ElementWiseMetricsReduction {
allocators_.clear();
allocators_.resize(devices.Size());
}
preds.Reshard(devices);
labels.Reshard(devices);
weights.Reshard(devices);
preds.Shard(devices);
labels.Shard(devices);
weights.Shard(devices);
std::vector<PackedReduceResult> res_per_device(devices.Size());
#pragma omp parallel for schedule(static, 1) if (devices.Size() > 1)

View File

@ -134,9 +134,9 @@ class MultiClassMetricsReduction {
allocators_.clear();
allocators_.resize(devices.Size());
}
preds.Reshard(GPUDistribution::Granular(devices, n_class));
labels.Reshard(devices);
weights.Reshard(devices);
preds.Shard(GPUDistribution::Granular(devices, n_class));
labels.Shard(devices);
weights.Shard(devices);
std::vector<PackedReduceResult> res_per_device(devices.Size());
#pragma omp parallel for schedule(static, 1) if (devices.Size() > 1)

View File

@ -39,7 +39,7 @@ struct SoftmaxMultiClassParam : public dmlc::Parameter<SoftmaxMultiClassParam> {
.describe("gpu to use for objective function evaluation");
}
};
// TODO(trivialfis): Currently the resharding in softmax is less than ideal
// TODO(trivialfis): Currently the sharding in softmax is less than ideal
// due to repeated copying data between CPU and GPUs. Maybe we just use single
// GPU?
class SoftmaxMultiClassObj : public ObjFunction {
@ -63,11 +63,11 @@ class SoftmaxMultiClassObj : public ObjFunction {
const int nclass = param_.num_class;
const auto ndata = static_cast<int64_t>(preds.Size() / nclass);
out_gpair->Reshard(GPUDistribution::Granular(devices_, nclass));
info.labels_.Reshard(GPUDistribution::Block(devices_));
info.weights_.Reshard(GPUDistribution::Block(devices_));
preds.Reshard(GPUDistribution::Granular(devices_, nclass));
label_correct_.Reshard(GPUDistribution::Block(devices_));
out_gpair->Shard(GPUDistribution::Granular(devices_, nclass));
info.labels_.Shard(GPUDistribution::Block(devices_));
info.weights_.Shard(GPUDistribution::Block(devices_));
preds.Shard(GPUDistribution::Granular(devices_, nclass));
label_correct_.Shard(GPUDistribution::Block(devices_));
out_gpair->Resize(preds.Size());
label_correct_.Fill(1);
@ -136,8 +136,8 @@ class SoftmaxMultiClassObj : public ObjFunction {
common::Range{0, ndata}, GPUDistribution::Granular(devices_, nclass))
.Eval(io_preds);
} else {
io_preds->Reshard(GPUDistribution::Granular(devices_, nclass));
max_preds_.Reshard(GPUDistribution::Block(devices_));
io_preds->Shard(GPUDistribution::Granular(devices_, nclass));
max_preds_.Shard(GPUDistribution::Block(devices_));
common::Transform<>::Init(
[=] XGBOOST_DEVICE(size_t _idx,
common::Span<const bst_float> _preds,

View File

@ -327,11 +327,11 @@ class GPUPredictor : public xgboost::Predictor {
for (const auto &batch : dmat->GetRowBatches()) {
CHECK_EQ(i_batch, 0) << "External memory not supported";
// out_preds have been resharded and resized in InitOutPredictions()
batch.offset.Reshard(GPUDistribution::Overlap(devices_, 1));
// out_preds have been sharded and resized in InitOutPredictions()
batch.offset.Shard(GPUDistribution::Overlap(devices_, 1));
std::vector<size_t> device_offsets;
DeviceOffsets(batch.offset, &device_offsets);
batch.data.Reshard(GPUDistribution::Explicit(devices_, device_offsets));
batch.data.Shard(GPUDistribution::Explicit(devices_, device_offsets));
dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) {
shard.PredictInternal(batch, dmat->Info(), out_preds, model,
h_tree_segments, h_nodes, tree_begin, tree_end);
@ -373,7 +373,7 @@ class GPUPredictor : public xgboost::Predictor {
size_t n_classes = model.param.num_output_group;
size_t n = n_classes * info.num_row_;
const HostDeviceVector<bst_float>& base_margin = info.base_margin_;
out_preds->Reshard(GPUDistribution::Granular(devices_, n_classes));
out_preds->Shard(GPUDistribution::Granular(devices_, n_classes));
out_preds->Resize(n);
if (base_margin.Size() != 0) {
CHECK_EQ(out_preds->Size(), n);
@ -392,7 +392,7 @@ class GPUPredictor : public xgboost::Predictor {
const HostDeviceVector<bst_float>& y = it->second.predictions;
if (y.Size() != 0) {
monitor_.StartCuda("PredictFromCache");
out_preds->Reshard(y.Distribution());
out_preds->Shard(y.Distribution());
out_preds->Resize(y.Size());
out_preds->Copy(y);
monitor_.StopCuda("PredictFromCache");

View File

@ -566,7 +566,7 @@ class GPUMaker : public TreeUpdater {
int maxNodes_;
int maxLeaves_;
// devices are only used for resharding the HostDeviceVector passed as a parameter;
// devices are only used for sharding the HostDeviceVector passed as a parameter;
// the algorithm works with a single GPU only
GPUSet devices_;
@ -594,7 +594,7 @@ class GPUMaker : public TreeUpdater {
float lr = param_.learning_rate;
param_.learning_rate = lr / trees.size();
gpair->Reshard(devices_);
gpair->Shard(devices_);
try {
// build tree

View File

@ -836,7 +836,7 @@ struct DeviceShard {
for (auto i = 0ull; i < nidxs.size(); i++) {
auto nidx = nidxs[i];
auto p_feature_set = column_sampler.GetFeatureSet(tree.GetDepth(nidx));
p_feature_set->Reshard(GPUSet(device_id, 1));
p_feature_set->Shard(GPUSet(device_id, 1));
auto d_feature_set = p_feature_set->DeviceSpan(device_id);
auto d_split_candidates =
d_split_candidates_all.subspan(i * num_columns, d_feature_set.size());
@ -1527,7 +1527,7 @@ class GPUHistMakerSpecialised{
return false;
}
monitor_.StartCuda("UpdatePredictionCache");
p_out_preds->Reshard(dist_.Devices());
p_out_preds->Shard(dist_.Devices());
dh::ExecuteIndexShards(
&shards_,
[&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {

View File

@ -23,7 +23,7 @@ void InitHostDeviceVector(size_t n, const GPUDistribution& distribution,
HostDeviceVector<int> *v) {
// create the vector
GPUSet devices = distribution.Devices();
v->Reshard(distribution);
v->Shard(distribution);
v->Resize(n);
ASSERT_EQ(v->Size(), n);
@ -178,6 +178,27 @@ TEST(HostDeviceVector, TestCopy) {
SetCudaSetDeviceHandler(nullptr);
}
TEST(HostDeviceVector, Shard) {
std::vector<int> h_vec (2345);
for (size_t i = 0; i < h_vec.size(); ++i) {
h_vec[i] = i;
}
HostDeviceVector<int> vec (h_vec);
auto devices = GPUSet::Range(0, 1);
vec.Shard(devices);
ASSERT_EQ(vec.DeviceSize(0), h_vec.size());
ASSERT_EQ(vec.Size(), h_vec.size());
auto span = vec.DeviceSpan(0); // sync to device
vec.Reshard(GPUDistribution::Empty()); // pull back to cpu, empty devices.
ASSERT_EQ(vec.Size(), h_vec.size());
ASSERT_TRUE(vec.Devices().IsEmpty());
auto h_vec_1 = vec.HostVector();
ASSERT_TRUE(std::equal(h_vec_1.cbegin(), h_vec_1.cend(), h_vec.cbegin()));
}
TEST(HostDeviceVector, Reshard) {
std::vector<int> h_vec (2345);
for (size_t i = 0; i < h_vec.size(); ++i) {
@ -186,22 +207,24 @@ TEST(HostDeviceVector, Reshard) {
HostDeviceVector<int> vec (h_vec);
auto devices = GPUSet::Range(0, 1);
vec.Reshard(devices);
vec.Shard(devices);
ASSERT_EQ(vec.DeviceSize(0), h_vec.size());
ASSERT_EQ(vec.Size(), h_vec.size());
auto span = vec.DeviceSpan(0); // sync to device
PlusOne(&vec);
vec.Reshard(GPUSet::Empty()); // pull back to cpu, empty devices.
vec.Reshard(GPUDistribution::Empty());
ASSERT_EQ(vec.Size(), h_vec.size());
ASSERT_TRUE(vec.Devices().IsEmpty());
auto h_vec_1 = vec.HostVector();
ASSERT_TRUE(std::equal(h_vec_1.cbegin(), h_vec_1.cend(), h_vec.cbegin()));
for (size_t i = 0; i < h_vec_1.size(); ++i) {
ASSERT_EQ(h_vec_1.at(i), i + 1);
}
}
TEST(HostDeviceVector, Span) {
HostDeviceVector<float> vec {1.0f, 2.0f, 3.0f, 4.0f};
vec.Reshard(GPUSet{0, 1});
vec.Shard(GPUSet{0, 1});
auto span = vec.DeviceSpan(0);
ASSERT_EQ(vec.DeviceSize(0), span.size());
ASSERT_EQ(vec.DevicePointer(0), span.data());
@ -212,7 +235,7 @@ TEST(HostDeviceVector, Span) {
// Multi-GPUs' test
#if defined(XGBOOST_USE_NCCL)
TEST(HostDeviceVector, MGPU_Reshard) {
TEST(HostDeviceVector, MGPU_Shard) {
auto devices = GPUSet::AllVisible();
if (devices.Size() < 2) {
LOG(WARNING) << "Not testing in multi-gpu environment.";
@ -229,7 +252,7 @@ TEST(HostDeviceVector, MGPU_Reshard) {
std::vector<size_t> devices_size (devices.Size());
// From CPU to GPUs.
vec.Reshard(devices);
vec.Shard(devices);
size_t total_size = 0;
for (size_t i = 0; i < devices.Size(); ++i) {
total_size += vec.DeviceSize(i);
@ -238,16 +261,16 @@ TEST(HostDeviceVector, MGPU_Reshard) {
ASSERT_EQ(total_size, h_vec.size());
ASSERT_EQ(total_size, vec.Size());
// Reshard from devices to devices with different distribution.
// Shard from devices to devices with different distribution.
EXPECT_ANY_THROW(
vec.Reshard(GPUDistribution::Granular(devices, 12)));
vec.Shard(GPUDistribution::Granular(devices, 12)));
// All data is drawn back to CPU
vec.Reshard(GPUSet::Empty());
vec.Reshard(GPUDistribution::Empty());
ASSERT_TRUE(vec.Devices().IsEmpty());
ASSERT_EQ(vec.Size(), h_vec.size());
vec.Reshard(GPUDistribution::Granular(devices, 12));
vec.Shard(GPUDistribution::Granular(devices, 12));
total_size = 0;
for (size_t i = 0; i < devices.Size(); ++i) {
total_size += vec.DeviceSize(i);
@ -256,6 +279,41 @@ TEST(HostDeviceVector, MGPU_Reshard) {
ASSERT_EQ(total_size, h_vec.size());
ASSERT_EQ(total_size, vec.Size());
}
TEST(HostDeviceVector, MGPU_Reshard) {
auto devices = GPUSet::AllVisible();
if (devices.Size() < 2) {
LOG(WARNING) << "Not testing in multi-gpu environment.";
return;
}
size_t n = 1001;
int n_devices = 2;
auto distribution = GPUDistribution::Block(GPUSet::Range(0, n_devices));
std::vector<size_t> starts{0, 501};
std::vector<size_t> sizes{501, 500};
HostDeviceVector<int> v;
InitHostDeviceVector(n, distribution, &v);
CheckDevice(&v, starts, sizes, 0, GPUAccess::kRead);
PlusOne(&v);
CheckDevice(&v, starts, sizes, 1, GPUAccess::kWrite);
CheckHost(&v, GPUAccess::kRead);
CheckHost(&v, GPUAccess::kWrite);
auto distribution1 = GPUDistribution::Overlap(GPUSet::Range(0, n_devices), 1);
v.Reshard(distribution1);
for (size_t i = 0; i < n_devices; ++i) {
auto span = v.DeviceSpan(i); // sync to device
}
std::vector<size_t> starts1{0, 500};
std::vector<size_t> sizes1{501, 501};
CheckDevice(&v, starts1, sizes1, 1, GPUAccess::kWrite);
CheckHost(&v, GPUAccess::kRead);
CheckHost(&v, GPUAccess::kWrite);
}
#endif
} // namespace common

View File

@ -22,10 +22,10 @@ TEST(Transform, MGPU_Basic) {
GPUDistribution::Block(GPUSet::Empty())};
out_vec.Fill(0);
in_vec.Reshard(GPUDistribution::Granular(devices, 8));
out_vec.Reshard(GPUDistribution::Block(devices));
in_vec.Shard(GPUDistribution::Granular(devices, 8));
out_vec.Shard(GPUDistribution::Block(devices));
// Granularity is different, resharding will throw.
// Granularity is different, sharding will throw.
EXPECT_ANY_THROW(
Transform<>::Init(TestTransformRange<bst_float>{}, Range{0, size}, devices)
.Eval(&out_vec, &in_vec));