Fix specifying gpu_id, add tests. (#3851)
* Rewrite gpu_id related code. * Remove normalised/unnormalised operations. * Address difference between `Index' and `Device ID'. * Modify doc for `gpu_id'. * Better LOG for GPUSet. * Check specified n_gpus. * Remove inappropriate `device_idx' term. * Clarify GpuIdType and size_t.
This commit is contained in:
@@ -251,15 +251,15 @@ struct DeviceHistogram {
|
||||
thrust::device_vector<GradientPairSumT::ValueT> data;
|
||||
const size_t kStopGrowingSize = 1 << 26; // Do not grow beyond this size
|
||||
int n_bins;
|
||||
int device_idx;
|
||||
int device_id_;
|
||||
|
||||
void Init(int device_idx, int n_bins) {
|
||||
void Init(int device_id, int n_bins) {
|
||||
this->n_bins = n_bins;
|
||||
this->device_idx = device_idx;
|
||||
this->device_id_ = device_id;
|
||||
}
|
||||
|
||||
void Reset() {
|
||||
dh::safe_cuda(cudaSetDevice(device_idx));
|
||||
dh::safe_cuda(cudaSetDevice(device_id_));
|
||||
data.resize(0);
|
||||
nidx_map.clear();
|
||||
}
|
||||
@@ -281,7 +281,7 @@ struct DeviceHistogram {
|
||||
} else {
|
||||
// Append new node histogram
|
||||
nidx_map[nidx] = data.size();
|
||||
dh::safe_cuda(cudaSetDevice(device_idx));
|
||||
dh::safe_cuda(cudaSetDevice(device_id_));
|
||||
// x 2: Hess and Grad.
|
||||
data.resize(data.size() + (n_bins * 2));
|
||||
}
|
||||
@@ -396,13 +396,12 @@ struct DeviceShard;
|
||||
struct GPUHistBuilderBase {
|
||||
public:
|
||||
virtual void Build(DeviceShard* shard, int idx) = 0;
|
||||
virtual ~GPUHistBuilderBase() = default;
|
||||
};
|
||||
|
||||
// Manage memory for a single GPU
|
||||
struct DeviceShard {
|
||||
int device_idx;
|
||||
/*! \brief Device index counting from param.gpu_id */
|
||||
int normalised_device_idx;
|
||||
int device_id_;
|
||||
dh::BulkAllocator<dh::MemoryType::kDevice> ba;
|
||||
|
||||
/*! \brief HistCutMatrix stored in device. */
|
||||
@@ -463,10 +462,9 @@ struct DeviceShard {
|
||||
std::unique_ptr<GPUHistBuilderBase> hist_builder;
|
||||
|
||||
// TODO(canonizer): do add support multi-batch DMatrix here
|
||||
DeviceShard(int device_idx, int normalised_device_idx,
|
||||
DeviceShard(int device_id,
|
||||
bst_uint row_begin, bst_uint row_end, TrainParam _param) :
|
||||
device_idx(device_idx),
|
||||
normalised_device_idx(normalised_device_idx),
|
||||
device_id_(device_id),
|
||||
row_begin_idx(row_begin),
|
||||
row_end_idx(row_end),
|
||||
row_stride(0),
|
||||
@@ -479,7 +477,7 @@ struct DeviceShard {
|
||||
|
||||
/* Init row_ptrs and row_stride */
|
||||
void InitRowPtrs(const SparsePage& row_batch) {
|
||||
dh::safe_cuda(cudaSetDevice(device_idx));
|
||||
dh::safe_cuda(cudaSetDevice(device_id_));
|
||||
const auto& offset_vec = row_batch.offset.HostVector();
|
||||
row_ptrs.resize(n_rows + 1);
|
||||
thrust::copy(offset_vec.data() + row_begin_idx,
|
||||
@@ -537,7 +535,7 @@ struct DeviceShard {
|
||||
|
||||
// Reset values for each update iteration
|
||||
void Reset(HostDeviceVector<GradientPair>* dh_gpair) {
|
||||
dh::safe_cuda(cudaSetDevice(device_idx));
|
||||
dh::safe_cuda(cudaSetDevice(device_id_));
|
||||
position.CurrentDVec().Fill(0);
|
||||
std::fill(node_sum_gradients.begin(), node_sum_gradients.end(),
|
||||
GradientPair());
|
||||
@@ -546,7 +544,8 @@ struct DeviceShard {
|
||||
|
||||
std::fill(ridx_segments.begin(), ridx_segments.end(), Segment(0, 0));
|
||||
ridx_segments.front() = Segment(0, ridx.Size());
|
||||
this->gpair.copy(dh_gpair->tcbegin(device_idx), dh_gpair->tcend(device_idx));
|
||||
this->gpair.copy(dh_gpair->tcbegin(device_id_),
|
||||
dh_gpair->tcend(device_id_));
|
||||
SubsampleGradientPair(&gpair, param.subsample, row_begin_idx);
|
||||
hist.Reset();
|
||||
}
|
||||
@@ -562,7 +561,7 @@ struct DeviceShard {
|
||||
auto d_node_hist_histogram = hist.GetHistPtr(nidx_histogram);
|
||||
auto d_node_hist_subtraction = hist.GetHistPtr(nidx_subtraction);
|
||||
|
||||
dh::LaunchN(device_idx, hist.n_bins, [=] __device__(size_t idx) {
|
||||
dh::LaunchN(device_id_, hist.n_bins, [=] __device__(size_t idx) {
|
||||
d_node_hist_subtraction[idx] =
|
||||
d_node_hist_parent[idx] - d_node_hist_histogram[idx];
|
||||
});
|
||||
@@ -589,7 +588,7 @@ struct DeviceShard {
|
||||
int64_t split_gidx, bool default_dir_left, bool is_dense,
|
||||
int fidx_begin, // cut.row_ptr[fidx]
|
||||
int fidx_end) { // cut.row_ptr[fidx + 1]
|
||||
dh::safe_cuda(cudaSetDevice(device_idx));
|
||||
dh::safe_cuda(cudaSetDevice(device_id_));
|
||||
temp_memory.LazyAllocate(sizeof(int64_t));
|
||||
int64_t* d_left_count = temp_memory.Pointer<int64_t>();
|
||||
dh::safe_cuda(cudaMemset(d_left_count, 0, sizeof(int64_t)));
|
||||
@@ -600,7 +599,7 @@ struct DeviceShard {
|
||||
size_t row_stride = this->row_stride;
|
||||
// Launch 1 thread for each row
|
||||
dh::LaunchN<1, 512>(
|
||||
device_idx, segment.Size(), [=] __device__(bst_uint idx) {
|
||||
device_id_, segment.Size(), [=] __device__(bst_uint idx) {
|
||||
idx += segment.begin;
|
||||
bst_uint ridx = d_ridx[idx];
|
||||
auto row_begin = row_stride * ridx;
|
||||
@@ -669,7 +668,7 @@ struct DeviceShard {
|
||||
}
|
||||
|
||||
void UpdatePredictionCache(bst_float* out_preds_d) {
|
||||
dh::safe_cuda(cudaSetDevice(device_idx));
|
||||
dh::safe_cuda(cudaSetDevice(device_id_));
|
||||
if (!prediction_cache_initialised) {
|
||||
dh::safe_cuda(cudaMemcpy(
|
||||
prediction_cache.Data(), out_preds_d,
|
||||
@@ -689,7 +688,7 @@ struct DeviceShard {
|
||||
auto d_prediction_cache = prediction_cache.Data();
|
||||
|
||||
dh::LaunchN(
|
||||
device_idx, prediction_cache.Size(), [=] __device__(int local_idx) {
|
||||
device_id_, prediction_cache.Size(), [=] __device__(int local_idx) {
|
||||
int pos = d_position[local_idx];
|
||||
bst_float weight = CalcWeight(param_d, d_node_sum_gradients[pos]);
|
||||
d_prediction_cache[d_ridx[local_idx]] +=
|
||||
@@ -723,7 +722,7 @@ struct SharedMemHistBuilder : public GPUHistBuilderBase {
|
||||
if (grid_size <= 0) {
|
||||
return;
|
||||
}
|
||||
dh::safe_cuda(cudaSetDevice(shard->device_idx));
|
||||
dh::safe_cuda(cudaSetDevice(shard->device_id_));
|
||||
sharedMemHistKernel<<<grid_size, block_threads, smem_size>>>
|
||||
(shard->row_stride, d_ridx, d_gidx, null_gidx_value, d_node_hist, d_gpair,
|
||||
segment_begin, n_elements);
|
||||
@@ -742,7 +741,7 @@ struct GlobalMemHistBuilder : public GPUHistBuilderBase {
|
||||
size_t const row_stride = shard->row_stride;
|
||||
int const null_gidx_value = shard->null_gidx_value;
|
||||
|
||||
dh::LaunchN(shard->device_idx, n_elements, [=] __device__(size_t idx) {
|
||||
dh::LaunchN(shard->device_id_, n_elements, [=] __device__(size_t idx) {
|
||||
int ridx = d_ridx[(idx / row_stride) + segment.begin];
|
||||
// lookup the index (bin) of histogram.
|
||||
int gidx = d_gidx[ridx * row_stride + idx % row_stride];
|
||||
@@ -762,7 +761,7 @@ inline void DeviceShard::InitCompressedData(
|
||||
int max_nodes =
|
||||
param.max_leaves > 0 ? param.max_leaves * 2 : MaxNodesDepth(param.max_depth);
|
||||
|
||||
ba.Allocate(device_idx, param.silent,
|
||||
ba.Allocate(device_id_, param.silent,
|
||||
&gpair, n_rows,
|
||||
&ridx, n_rows,
|
||||
&position, n_rows,
|
||||
@@ -780,7 +779,7 @@ inline void DeviceShard::InitCompressedData(
|
||||
node_sum_gradients.resize(max_nodes);
|
||||
ridx_segments.resize(max_nodes);
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_idx));
|
||||
dh::safe_cuda(cudaSetDevice(device_id_));
|
||||
|
||||
// allocate compressed bin data
|
||||
int num_symbols = n_bins + 1;
|
||||
@@ -792,7 +791,7 @@ inline void DeviceShard::InitCompressedData(
|
||||
CHECK(!(param.max_leaves == 0 && param.max_depth == 0))
|
||||
<< "Max leaves and max depth cannot both be unconstrained for "
|
||||
"gpu_hist.";
|
||||
ba.Allocate(device_idx, param.silent, &gidx_buffer, compressed_size_bytes);
|
||||
ba.Allocate(device_id_, param.silent, &gidx_buffer, compressed_size_bytes);
|
||||
gidx_buffer.Fill(0);
|
||||
|
||||
int nbits = common::detail::SymbolBits(num_symbols);
|
||||
@@ -804,7 +803,7 @@ inline void DeviceShard::InitCompressedData(
|
||||
// check if we can use shared memory for building histograms
|
||||
// (assuming atleast we need 2 CTAs per SM to maintain decent latency hiding)
|
||||
auto histogram_size = sizeof(GradientPairSumT) * null_gidx_value;
|
||||
auto max_smem = dh::MaxSharedMemory(device_idx);
|
||||
auto max_smem = dh::MaxSharedMemory(device_id_);
|
||||
if (histogram_size <= max_smem) {
|
||||
hist_builder.reset(new SharedMemHistBuilder);
|
||||
} else {
|
||||
@@ -812,7 +811,7 @@ inline void DeviceShard::InitCompressedData(
|
||||
}
|
||||
|
||||
// Init histogram
|
||||
hist.Init(device_idx, hmat.row_ptr.back());
|
||||
hist.Init(device_id_, hmat.row_ptr.back());
|
||||
|
||||
dh::safe_cuda(cudaMallocHost(&tmp_pinned, sizeof(int64_t)));
|
||||
}
|
||||
@@ -820,9 +819,10 @@ inline void DeviceShard::InitCompressedData(
|
||||
inline void DeviceShard::CreateHistIndices(const SparsePage& row_batch) {
|
||||
int num_symbols = n_bins + 1;
|
||||
// bin and compress entries in batches of rows
|
||||
size_t gpu_batch_nrows = std::min
|
||||
(dh::TotalMemory(device_idx) / (16 * row_stride * sizeof(Entry)),
|
||||
static_cast<size_t>(n_rows));
|
||||
size_t gpu_batch_nrows =
|
||||
std::min
|
||||
(dh::TotalMemory(device_id_) / (16 * row_stride * sizeof(Entry)),
|
||||
static_cast<size_t>(n_rows));
|
||||
const std::vector<Entry>& data_vec = row_batch.data.HostVector();
|
||||
|
||||
thrust::device_vector<Entry> entries_d(gpu_batch_nrows * row_stride);
|
||||
@@ -876,8 +876,7 @@ class GPUHistMaker : public TreeUpdater {
|
||||
param_.InitAllowUnknown(args);
|
||||
CHECK(param_.n_gpus != 0) << "Must have at least one device";
|
||||
n_devices_ = param_.n_gpus;
|
||||
dist_ = GPUDistribution::Block(GPUSet::All(param_.n_gpus)
|
||||
.Normalised(param_.gpu_id));
|
||||
dist_ = GPUDistribution::Block(GPUSet::All(param_.gpu_id, param_.n_gpus));
|
||||
|
||||
dh::CheckComputeCapability();
|
||||
|
||||
@@ -914,12 +913,12 @@ class GPUHistMaker : public TreeUpdater {
|
||||
void InitDataOnce(DMatrix* dmat) {
|
||||
info_ = &dmat->Info();
|
||||
|
||||
int n_devices = GPUSet::All(param_.n_gpus, info_->num_row_).Size();
|
||||
int n_devices = dist_.Devices().Size();
|
||||
|
||||
device_list_.resize(n_devices);
|
||||
for (int d_idx = 0; d_idx < n_devices; ++d_idx) {
|
||||
int device_idx = GPUSet::GetDeviceIdx(param_.gpu_id + d_idx);
|
||||
device_list_[d_idx] = device_idx;
|
||||
for (int index = 0; index < n_devices; ++index) {
|
||||
int device_id = dist_.Devices().DeviceId(index);
|
||||
device_list_[index] = device_id;
|
||||
}
|
||||
|
||||
reducer_.Init(device_list_);
|
||||
@@ -932,8 +931,8 @@ class GPUHistMaker : public TreeUpdater {
|
||||
size_t start = dist_.ShardStart(info_->num_row_, i);
|
||||
size_t size = dist_.ShardSize(info_->num_row_, i);
|
||||
shard = std::unique_ptr<DeviceShard>
|
||||
(new DeviceShard(device_list_.at(i), i,
|
||||
start, start + size, param_));
|
||||
(new DeviceShard(dist_.Devices().DeviceId(i),
|
||||
start, start + size, param_));
|
||||
shard->InitRowPtrs(batch);
|
||||
});
|
||||
|
||||
@@ -979,7 +978,7 @@ class GPUHistMaker : public TreeUpdater {
|
||||
for (auto& shard : shards_) {
|
||||
auto d_node_hist = shard->hist.GetHistPtr(nidx);
|
||||
reducer_.AllReduceSum(
|
||||
shard->normalised_device_idx,
|
||||
dist_.Devices().Index(shard->device_id_),
|
||||
reinterpret_cast<GradientPairSumT::ValueT*>(d_node_hist),
|
||||
reinterpret_cast<GradientPairSumT::ValueT*>(d_node_hist),
|
||||
n_bins_ * (sizeof(GradientPairSumT) / sizeof(GradientPairSumT::ValueT)));
|
||||
@@ -1050,7 +1049,7 @@ class GPUHistMaker : public TreeUpdater {
|
||||
// FIXME: Multi-gpu support?
|
||||
// Use first device
|
||||
auto& shard = shards_.front();
|
||||
dh::safe_cuda(cudaSetDevice(shard->device_idx));
|
||||
dh::safe_cuda(cudaSetDevice(shard->device_id_));
|
||||
shard->temp_memory.LazyAllocate(candidates_size_bytes);
|
||||
auto d_split = shard->temp_memory.Pointer<DeviceSplitCandidate>();
|
||||
|
||||
@@ -1063,7 +1062,7 @@ class GPUHistMaker : public TreeUpdater {
|
||||
int depth = p_tree->GetDepth(nidx);
|
||||
|
||||
HostDeviceVector<int>& feature_set = column_sampler_.GetFeatureSet(depth);
|
||||
feature_set.Reshard(GPUSet::Range(shard->device_idx, 1));
|
||||
feature_set.Reshard(GPUSet::Range(shard->device_id_, 1));
|
||||
auto& h_feature_set = feature_set.HostVector();
|
||||
// One block for each feature
|
||||
int constexpr BLOCK_THREADS = 256;
|
||||
@@ -1071,7 +1070,7 @@ class GPUHistMaker : public TreeUpdater {
|
||||
<<<uint32_t(feature_set.Size()), BLOCK_THREADS, 0, streams[i]>>>(
|
||||
shard->hist.GetHistPtr(nidx),
|
||||
info_->num_col_,
|
||||
feature_set.DevicePointer(shard->device_idx),
|
||||
feature_set.DevicePointer(shard->device_id_),
|
||||
node,
|
||||
shard->cut_.feature_segments.Data(),
|
||||
shard->cut_.min_fvalue.Data(),
|
||||
@@ -1105,7 +1104,7 @@ class GPUHistMaker : public TreeUpdater {
|
||||
std::vector<GradientPair> tmp_sums(shards_.size());
|
||||
|
||||
dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr<DeviceShard>& shard) {
|
||||
dh::safe_cuda(cudaSetDevice(shard->device_idx));
|
||||
dh::safe_cuda(cudaSetDevice(shard->device_id_));
|
||||
tmp_sums[i] =
|
||||
dh::SumReduction(shard->temp_memory, shard->gpair.Data(),
|
||||
shard->gpair.Size());
|
||||
@@ -1265,7 +1264,8 @@ class GPUHistMaker : public TreeUpdater {
|
||||
return false;
|
||||
p_out_preds->Reshard(dist_.Devices());
|
||||
dh::ExecuteShards(&shards_, [&](std::unique_ptr<DeviceShard>& shard) {
|
||||
shard->UpdatePredictionCache(p_out_preds->DevicePointer(shard->device_idx));
|
||||
shard->UpdatePredictionCache(
|
||||
p_out_preds->DevicePointer(shard->device_id_));
|
||||
});
|
||||
monitor_.Stop("UpdatePredictionCache", dist_.Devices());
|
||||
return true;
|
||||
@@ -1336,6 +1336,7 @@ class GPUHistMaker : public TreeUpdater {
|
||||
common::Monitor monitor_;
|
||||
dh::AllReducer reducer_;
|
||||
std::vector<ValueConstraint> node_value_constraints_;
|
||||
/*! List storing device id. */
|
||||
std::vector<int> device_list_;
|
||||
|
||||
DMatrix* p_last_fmat_;
|
||||
|
||||
Reference in New Issue
Block a user