Overload device memory allocation (#4532)
* Group source files, include headers in source files * Overload device memory allocation
This commit is contained in:
@@ -261,15 +261,15 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
size_t tree_begin, size_t tree_end) {
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
nodes_.resize(h_nodes.size());
|
||||
dh::safe_cuda(cudaMemcpyAsync(dh::Raw(nodes_), h_nodes.data(),
|
||||
dh::safe_cuda(cudaMemcpyAsync(nodes_.data().get(), h_nodes.data(),
|
||||
sizeof(DevicePredictionNode) * h_nodes.size(),
|
||||
cudaMemcpyHostToDevice));
|
||||
tree_segments_.resize(h_tree_segments.size());
|
||||
dh::safe_cuda(cudaMemcpyAsync(dh::Raw(tree_segments_), h_tree_segments.data(),
|
||||
dh::safe_cuda(cudaMemcpyAsync(tree_segments_.data().get(), h_tree_segments.data(),
|
||||
sizeof(size_t) * h_tree_segments.size(),
|
||||
cudaMemcpyHostToDevice));
|
||||
tree_group_.resize(model.tree_info.size());
|
||||
dh::safe_cuda(cudaMemcpyAsync(dh::Raw(tree_group_), model.tree_info.data(),
|
||||
dh::safe_cuda(cudaMemcpyAsync(tree_group_.data().get(), model.tree_info.data(),
|
||||
sizeof(int) * model.tree_info.size(),
|
||||
cudaMemcpyHostToDevice));
|
||||
this->tree_begin_ = tree_begin;
|
||||
@@ -306,9 +306,9 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
|
||||
private:
|
||||
int device_;
|
||||
thrust::device_vector<DevicePredictionNode> nodes_;
|
||||
thrust::device_vector<size_t> tree_segments_;
|
||||
thrust::device_vector<int> tree_group_;
|
||||
dh::device_vector<DevicePredictionNode> nodes_;
|
||||
dh::device_vector<size_t> tree_segments_;
|
||||
dh::device_vector<int> tree_group_;
|
||||
size_t max_shared_memory_bytes_;
|
||||
size_t tree_begin_;
|
||||
size_t tree_end_;
|
||||
@@ -373,7 +373,7 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
}
|
||||
|
||||
public:
|
||||
GPUPredictor()
|
||||
GPUPredictor() // NOLINT
|
||||
: cpu_predictor_(Predictor::Create("cpu_predictor", learner_param_)) {}
|
||||
|
||||
void PredictBatch(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds,
|
||||
|
||||
Reference in New Issue
Block a user