remove device shards (#4867)
Collapse the DeviceShard wrapper inside GPUPredictor into the predictor itself: the shard's model buffers, InitModel, and PredictInternal become GPUPredictor members, and ConfigureShard becomes ConfigureDevice.
@@ -195,77 +195,52 @@ __global__ void PredictKernel(common::Span<const DevicePredictionNode> d_nodes,
 class GPUPredictor : public xgboost::Predictor {
  private:
-  struct DeviceShard {
-    DeviceShard() : device_{-1} {}
-
-    ~DeviceShard() {
-      if (device_ >= 0) {
-        dh::safe_cuda(cudaSetDevice(device_));
-      }
-    }
-
-    void Init(int device) {
-      this->device_ = device;
-      max_shared_memory_bytes_ = dh::MaxSharedMemory(this->device_);
-    }
-
-    void InitModel(const gbm::GBTreeModel& model,
-                   const thrust::host_vector<size_t>& h_tree_segments,
-                   const thrust::host_vector<DevicePredictionNode>& h_nodes,
-                   size_t tree_begin, size_t tree_end) {
-      dh::safe_cuda(cudaSetDevice(device_));
-      nodes_.resize(h_nodes.size());
-      dh::safe_cuda(cudaMemcpyAsync(nodes_.data().get(), h_nodes.data(),
-                                    sizeof(DevicePredictionNode) * h_nodes.size(),
-                                    cudaMemcpyHostToDevice));
-      tree_segments_.resize(h_tree_segments.size());
-      dh::safe_cuda(cudaMemcpyAsync(tree_segments_.data().get(), h_tree_segments.data(),
-                                    sizeof(size_t) * h_tree_segments.size(),
-                                    cudaMemcpyHostToDevice));
-      tree_group_.resize(model.tree_info.size());
-      dh::safe_cuda(cudaMemcpyAsync(tree_group_.data().get(), model.tree_info.data(),
-                                    sizeof(int) * model.tree_info.size(),
-                                    cudaMemcpyHostToDevice));
-      this->tree_begin_ = tree_begin;
-      this->tree_end_ = tree_end;
-      this->num_group_ = model.param.num_output_group;
-    }
-
-    void PredictInternal(const SparsePage& batch,
-                         size_t num_features,
-                         HostDeviceVector<bst_float>* predictions,
-                         size_t batch_offset) {
-      dh::safe_cuda(cudaSetDevice(device_));
-      const int BLOCK_THREADS = 128;
-      size_t num_rows = batch.Size();
-      const int GRID_SIZE = static_cast<int>(common::DivRoundUp(num_rows, BLOCK_THREADS));
-
-      int shared_memory_bytes = static_cast<int>
-          (sizeof(float) * num_features * BLOCK_THREADS);
-      bool use_shared = true;
-      if (shared_memory_bytes > max_shared_memory_bytes_) {
-        shared_memory_bytes = 0;
-        use_shared = false;
-      }
-      size_t entry_start = 0;
-
-      PredictKernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS, shared_memory_bytes>>>
-          (dh::ToSpan(nodes_), predictions->DeviceSpan().subspan(batch_offset),
-           dh::ToSpan(tree_segments_), dh::ToSpan(tree_group_), batch.offset.DeviceSpan(),
-           batch.data.DeviceSpan(), this->tree_begin_, this->tree_end_, num_features, num_rows,
-           entry_start, use_shared, this->num_group_);
-    }
-
-   private:
-    int device_;
-    dh::device_vector<DevicePredictionNode> nodes_;
-    dh::device_vector<size_t> tree_segments_;
-    dh::device_vector<int> tree_group_;
-    size_t max_shared_memory_bytes_;
-    size_t tree_begin_;
-    size_t tree_end_;
-    int num_group_;
-  };
+  void InitModel(const gbm::GBTreeModel& model,
+                 const thrust::host_vector<size_t>& h_tree_segments,
+                 const thrust::host_vector<DevicePredictionNode>& h_nodes,
+                 size_t tree_begin, size_t tree_end) {
+    dh::safe_cuda(cudaSetDevice(device_));
+    nodes_.resize(h_nodes.size());
+    dh::safe_cuda(cudaMemcpyAsync(nodes_.data().get(), h_nodes.data(),
+                                  sizeof(DevicePredictionNode) * h_nodes.size(),
+                                  cudaMemcpyHostToDevice));
+    tree_segments_.resize(h_tree_segments.size());
+    dh::safe_cuda(cudaMemcpyAsync(tree_segments_.data().get(), h_tree_segments.data(),
+                                  sizeof(size_t) * h_tree_segments.size(),
+                                  cudaMemcpyHostToDevice));
+    tree_group_.resize(model.tree_info.size());
+    dh::safe_cuda(cudaMemcpyAsync(tree_group_.data().get(), model.tree_info.data(),
+                                  sizeof(int) * model.tree_info.size(),
+                                  cudaMemcpyHostToDevice));
+    this->tree_begin_ = tree_begin;
+    this->tree_end_ = tree_end;
+    this->num_group_ = model.param.num_output_group;
+  }
+
+  void PredictInternal(const SparsePage& batch,
+                       size_t num_features,
+                       HostDeviceVector<bst_float>* predictions,
+                       size_t batch_offset) {
+    dh::safe_cuda(cudaSetDevice(device_));
+    const int BLOCK_THREADS = 128;
+    size_t num_rows = batch.Size();
+    const int GRID_SIZE = static_cast<int>(common::DivRoundUp(num_rows, BLOCK_THREADS));
+
+    int shared_memory_bytes = static_cast<int>
+        (sizeof(float) * num_features * BLOCK_THREADS);
+    bool use_shared = true;
+    if (shared_memory_bytes > max_shared_memory_bytes_) {
+      shared_memory_bytes = 0;
+      use_shared = false;
+    }
+    size_t entry_start = 0;
+
+    PredictKernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS, shared_memory_bytes>>>
+        (dh::ToSpan(nodes_), predictions->DeviceSpan().subspan(batch_offset),
+         dh::ToSpan(tree_segments_), dh::ToSpan(tree_group_), batch.offset.DeviceSpan(),
+         batch.data.DeviceSpan(), this->tree_begin_, this->tree_end_, num_features, num_rows,
+         entry_start, use_shared, this->num_group_);
+  }
 
   void InitModel(const gbm::GBTreeModel& model, size_t tree_begin, size_t tree_end) {
     CHECK_EQ(model.param.size_leaf_vector, 0);
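Note on the shared-memory branch in PredictInternal above: each thread caches one dense row in dynamic shared memory, so the request grows as sizeof(float) * num_features * BLOCK_THREADS, and wide datasets overflow the per-block limit. A minimal standalone sketch of the same sizing check (illustrative names such as kBlockThreads, not xgboost API):

#include <cuda_runtime.h>
#include <cstddef>
#include <cstdio>

constexpr int kBlockThreads = 128;  // mirrors BLOCK_THREADS in the diff

int main() {
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, 0);  // per-block shared memory limit of device 0

  std::size_t num_features = 1000;  // a wide dataset
  // One cached row of floats per thread in the block:
  std::size_t wanted = sizeof(float) * num_features * kBlockThreads;  // 512 KB here
  bool use_shared = wanted <= prop.sharedMemPerBlock;  // limit is often 48 KB
  // When the rows do not fit, the kernel launches with 0 dynamic shared memory
  // and reads feature values straight from global memory instead.
  std::printf("requested %zu bytes, use_shared = %d\n", wanted, use_shared ? 1 : 0);
  return 0;
}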
@@ -285,7 +260,7 @@ class GPUPredictor : public xgboost::Predictor {
       std::copy(src_nodes.begin(), src_nodes.end(),
                 h_nodes.begin() + h_tree_segments[tree_idx - tree_begin]);
     }
-    shard_.InitModel(model, h_tree_segments, h_nodes, tree_begin, tree_end);
+    InitModel(model, h_tree_segments, h_nodes, tree_begin, tree_end);
   }
 
   void DevicePredictInternal(DMatrix* dmat,
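The h_tree_segments / h_nodes pair passed along here is a CSR-style flattening: every tree's nodes are concatenated into one array, with a prefix sum of node counts marking where each tree starts. A sketch of that layout (hypothetical Node type, not the DevicePredictionNode from the diff):

#include <cstddef>
#include <vector>

struct Node { int fidx; float fvalue; int left, right; };  // hypothetical

// After flattening, tree i occupies h_nodes[segments[i] .. segments[i+1]).
std::vector<std::size_t> BuildSegments(const std::vector<std::vector<Node>>& trees) {
  std::vector<std::size_t> segments{0};
  for (const auto& tree : trees) {
    segments.push_back(segments.back() + tree.size());  // running node count
  }
  return segments;
}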
@@ -301,7 +276,7 @@ class GPUPredictor : public xgboost::Predictor {
     for (auto &batch : dmat->GetBatches<SparsePage>()) {
       batch.offset.SetDevice(device_);
       batch.data.SetDevice(device_);
-      shard_.PredictInternal(batch, model.param.num_feature, out_preds, batch_offset);
+      PredictInternal(batch, model.param.num_feature, out_preds, batch_offset);
       batch_offset += batch.Size() * model.param.num_output_group;
     }
 
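The batch_offset arithmetic above implies a row-major [row][group] output layout: a batch of n rows with g output groups fills n * g consecutive floats, so successive external-memory batches write disjoint slices of out_preds. Worked example with assumed batch sizes:

#include <cstddef>
#include <vector>

int main() {
  std::size_t num_output_group = 3;                   // e.g. 3-class softmax
  std::vector<std::size_t> batch_rows{100, 100, 50};  // assumed batch sizes
  std::size_t batch_offset = 0;
  for (std::size_t n : batch_rows) {
    // this batch writes predictions[batch_offset .. batch_offset + n * g)
    batch_offset += n * num_output_group;
  }
  return 0;  // batch_offset == 750 == total rows * groups
}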
@@ -309,14 +284,20 @@ class GPUPredictor : public xgboost::Predictor {
   }
 
  public:
-  GPUPredictor() : device_{-1} {};
+  GPUPredictor() : device_{-1} {}
+
+  ~GPUPredictor() override {
+    if (device_ >= 0) {
+      dh::safe_cuda(cudaSetDevice(device_));
+    }
+  }
 
   void PredictBatch(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds,
                     const gbm::GBTreeModel& model, int tree_begin,
                     unsigned ntree_limit = 0) override {
     int device = learner_param_->gpu_id;
     CHECK_GE(device, 0);
-    ConfigureShard(device);
+    ConfigureDevice(device);
 
     if (this->PredictFromCache(dmat, out_preds, model, ntree_limit)) {
       return;
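The new ~GPUPredictor takes over what ~DeviceShard did: make the owning GPU current before the dh::device_vector members are destroyed, so deallocation runs with the same device current as allocation did. A minimal illustration of the pattern (assumed semantics, plain Thrust rather than xgboost's dh wrappers):

#include <thrust/device_vector.h>
#include <cuda_runtime.h>

class Holder {  // hypothetical stand-in for GPUPredictor's ownership pattern
 public:
  explicit Holder(int device) : device_{device} {
    cudaSetDevice(device_);
    data_.resize(1 << 20);  // allocated while `device_` is current
  }
  ~Holder() {
    // data_'s destructor runs right after this body; set the owning device
    // first so its storage is freed in the same device context.
    if (device_ >= 0) { cudaSetDevice(device_); }
  }
 private:
  int device_;
  thrust::device_vector<float> data_;
};

int main() { Holder h{0}; return 0; }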
@@ -433,22 +414,29 @@ class GPUPredictor : public xgboost::Predictor {
 
   int device = learner_param_->gpu_id;
   if (device >= 0) {
-    ConfigureShard(device);
+    ConfigureDevice(device);
   }
 }
 
  private:
-  /*! \brief Reconfigure the shard when GPU is changed. */
-  void ConfigureShard(int device) {
+  /*! \brief Reconfigure the device when GPU is changed. */
+  void ConfigureDevice(int device) {
     if (device_ == device) return;
 
     device_ = device;
-    shard_.Init(device_);
+    if (device_ >= 0) {
+      max_shared_memory_bytes_ = dh::MaxSharedMemory(device_);
+    }
   }
 
-  DeviceShard shard_;
   int device_;
   common::Monitor monitor_;
+  dh::device_vector<DevicePredictionNode> nodes_;
+  dh::device_vector<size_t> tree_segments_;
+  dh::device_vector<int> tree_group_;
+  size_t max_shared_memory_bytes_;
+  size_t tree_begin_;
+  size_t tree_end_;
+  int num_group_;
 };
 
 XGBOOST_REGISTER_PREDICTOR(GPUPredictor, "gpu_predictor")
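ConfigureDevice above is a memoized setter: repeated PredictBatch calls on the same GPU take the early-return fast path, and the device-property query runs only once per device change. A sketch of the same pattern outside the class (illustrative names, plain CUDA runtime instead of dh::MaxSharedMemory):

#include <cuda_runtime.h>
#include <cstddef>

struct DeviceState {  // hypothetical mirror of the memoized ConfigureDevice
  int device = -1;
  std::size_t max_shared_memory_bytes = 0;

  void Configure(int new_device) {
    if (device == new_device) return;  // fast path on repeat calls
    device = new_device;
    if (device >= 0) {
      cudaDeviceProp prop;
      cudaGetDeviceProperties(&prop, device);
      max_shared_memory_bytes = prop.sharedMemPerBlock;  // queried once per switch
    }
  }
};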