/*!
 * Copyright by Contributors 2017
 */
#include <thrust/copy.h>
#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/fill.h>
#include <thrust/host_vector.h>
#include <xgboost/data.h>
#include <xgboost/predictor.h>
#include <xgboost/tree_model.h>
#include <xgboost/tree_updater.h>
#include <memory>
#include "../common/device_helpers.cuh"
#include "../common/host_device_vector.h"

namespace xgboost {
namespace predictor {

DMLC_REGISTRY_FILE_TAG(gpu_predictor);

/*! \brief prediction parameters */
struct GPUPredictionParam : public dmlc::Parameter<GPUPredictionParam> {
  int gpu_id;
  int n_gpus;
  bool silent;
  // declare parameters
  DMLC_DECLARE_PARAMETER(GPUPredictionParam) {
    DMLC_DECLARE_FIELD(gpu_id).set_default(0).describe(
        "Device ordinal for GPU prediction.");
    DMLC_DECLARE_FIELD(n_gpus).set_default(1).describe(
        "Number of devices to use for prediction (NOT IMPLEMENTED).");
    DMLC_DECLARE_FIELD(silent).set_default(false).describe(
        "Do not print information during training.");
  }
};
DMLC_REGISTER_PARAMETER(GPUPredictionParam);

template <typename iter_t>
void increment_offset(iter_t begin_itr, iter_t end_itr, size_t amount) {
  thrust::transform(begin_itr, end_itr, begin_itr,
                    [=] __device__(size_t elem) { return elem + amount; });
}

/**
 * \struct  DeviceMatrix
 *
 * \brief A csr representation of the input matrix allocated on the device.
 */
struct DeviceMatrix {
  DMatrix* p_mat;  // Pointer to the original matrix on the host
  dh::bulk_allocator<dh::memory_type::DEVICE> ba;
  dh::dvec<size_t> row_ptr;
  dh::dvec<SparseBatch::Entry> data;
  thrust::device_vector<float> predictions;

  DeviceMatrix(DMatrix* dmat, int device_idx, bool silent) : p_mat(dmat) {
    dh::safe_cuda(cudaSetDevice(device_idx));
    auto info = dmat->info();
    ba.allocate(device_idx, silent, &row_ptr, info.num_row + 1, &data,
                info.num_nonzero);
    auto iter = dmat->RowIterator();
    iter->BeforeFirst();
    size_t data_offset = 0;
    while (iter->Next()) {
      auto batch = iter->Value();
      // Copy row ptr
      thrust::copy(batch.ind_ptr, batch.ind_ptr + batch.size + 1,
                   row_ptr.tbegin() + batch.base_rowid);
      if (batch.base_rowid > 0) {
        auto begin_itr = row_ptr.tbegin() + batch.base_rowid;
        auto end_itr = begin_itr + batch.size + 1;
        increment_offset(begin_itr, end_itr, batch.base_rowid);
      }
      // Copy data
      thrust::copy(batch.data_ptr, batch.data_ptr + batch.ind_ptr[batch.size],
                   data.tbegin() + data_offset);
      data_offset += batch.ind_ptr[batch.size];
    }
  }
};

/**
 * \struct  DevicePredictionNode
 *
 * \brief Packed 16 byte representation of a tree node for use in device
 * prediction
 */
struct DevicePredictionNode {
  XGBOOST_DEVICE DevicePredictionNode()
      : fidx(-1), left_child_idx(-1), right_child_idx(-1) {}

  union NodeValue {
    float leaf_weight;
    float fvalue;
  };

  int fidx;
  int left_child_idx;
  int right_child_idx;
  NodeValue val;

  DevicePredictionNode(const RegTree::Node& n) {  // NOLINT
    this->left_child_idx = n.cleft();
    this->right_child_idx = n.cright();
    this->fidx = n.split_index();
    if (n.default_left()) {
      fidx |= (1U << 31);
    }

    if (n.is_leaf()) {
      this->val.leaf_weight = n.leaf_value();
    } else {
      this->val.fvalue = n.split_cond();
    }
  }

  XGBOOST_DEVICE bool IsLeaf() const { return left_child_idx == -1; }

  XGBOOST_DEVICE int GetFidx() const { return fidx & ((1U << 31) - 1U); }

  XGBOOST_DEVICE bool MissingLeft() const { return (fidx >> 31) != 0; }

  XGBOOST_DEVICE int MissingIdx() const {
    if (MissingLeft()) {
      return this->left_child_idx;
    } else {
      return this->right_child_idx;
    }
  }

  XGBOOST_DEVICE float GetFvalue() const { return val.fvalue; }

  XGBOOST_DEVICE float GetWeight() const { return val.leaf_weight; }
};

struct ElementLoader {
  bool use_shared;
  size_t* d_row_ptr;
  SparseBatch::Entry* d_data;
  int num_features;
  float* smem;
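  // Stages each thread's feature values into shared memory when the rows fit;
  // otherwise GetFvalue() falls back to a binary search over the CSR row.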
  __device__ ElementLoader(bool use_shared, size_t* row_ptr,
                           SparseBatch::Entry* entry, int num_features,
                           float* smem, int num_rows)
      : use_shared(use_shared),
        d_row_ptr(row_ptr),
        d_data(entry),
        num_features(num_features),
        smem(smem) {
    // Copy instances
    if (use_shared) {
      bst_uint global_idx = blockDim.x * blockIdx.x + threadIdx.x;
      int shared_elements = blockDim.x * num_features;
      dh::block_fill(smem, shared_elements, nanf(""));
      __syncthreads();
      if (global_idx < num_rows) {
        bst_uint elem_begin = d_row_ptr[global_idx];
        bst_uint elem_end = d_row_ptr[global_idx + 1];
        for (bst_uint elem_idx = elem_begin; elem_idx < elem_end; elem_idx++) {
          SparseBatch::Entry elem = d_data[elem_idx];
          smem[threadIdx.x * num_features + elem.index] = elem.fvalue;
        }
      }
      __syncthreads();
    }
  }

  __device__ float GetFvalue(int ridx, int fidx) {
    if (use_shared) {
      return smem[threadIdx.x * num_features + fidx];
    } else {
      // Binary search
      auto begin_ptr = d_data + d_row_ptr[ridx];
      auto end_ptr = d_data + d_row_ptr[ridx + 1];
      SparseBatch::Entry* previous_middle = nullptr;
      while (end_ptr != begin_ptr) {
        auto middle = begin_ptr + (end_ptr - begin_ptr) / 2;
        if (middle == previous_middle) {
          break;
        } else {
          previous_middle = middle;
        }

        if (middle->index == fidx) {
          return middle->fvalue;
        } else if (middle->index < fidx) {
          begin_ptr = middle;
        } else {
          end_ptr = middle;
        }
      }
      // Value is missing
      return nanf("");
    }
  }
};

__device__ float GetLeafWeight(bst_uint ridx, const DevicePredictionNode* tree,
                               ElementLoader* loader) {
  DevicePredictionNode n = tree[0];
  while (!n.IsLeaf()) {
    float fvalue = loader->GetFvalue(ridx, n.GetFidx());
    // Missing value
    if (isnan(fvalue)) {
      n = tree[n.MissingIdx()];
    } else {
      if (fvalue < n.GetFvalue()) {
        n = tree[n.left_child_idx];
      } else {
        n = tree[n.right_child_idx];
      }
    }
  }
  return n.GetWeight();
}

template <int BLOCK_THREADS>
__global__ void PredictKernel(const DevicePredictionNode* d_nodes,
                              float* d_out_predictions,
                              size_t* d_tree_segments, int* d_tree_group,
                              size_t* d_row_ptr, SparseBatch::Entry* d_data,
                              size_t tree_begin, size_t tree_end,
                              size_t num_features, size_t num_rows,
                              bool use_shared, int num_group) {
  extern __shared__ float smem[];
  bst_uint global_idx = blockDim.x * blockIdx.x + threadIdx.x;
  ElementLoader loader(use_shared, d_row_ptr, d_data, num_features, smem,
                       num_rows);
  if (global_idx >= num_rows) return;
  if (num_group == 1) {
    float sum = 0;
    for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
      const DevicePredictionNode* d_tree =
          d_nodes + d_tree_segments[tree_idx - tree_begin];
      sum += GetLeafWeight(global_idx, d_tree, &loader);
    }
    d_out_predictions[global_idx] += sum;
  } else {
    for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
      int tree_group = d_tree_group[tree_idx];
      const DevicePredictionNode* d_tree =
          d_nodes + d_tree_segments[tree_idx - tree_begin];
      bst_uint out_prediction_idx = global_idx * num_group + tree_group;
      d_out_predictions[out_prediction_idx] +=
          GetLeafWeight(global_idx, d_tree, &loader);
    }
  }
}

class GPUPredictor : public xgboost::Predictor {
 protected:
  struct DevicePredictionCacheEntry {
    std::shared_ptr<DMatrix> data;
    HostDeviceVector<bst_float> predictions;
  };

 private:
  void DevicePredictInternal(DMatrix* dmat,
                             HostDeviceVector<bst_float>* out_preds,
                             const gbm::GBTreeModel& model, size_t tree_begin,
                             size_t tree_end) {
    if (tree_end - tree_begin == 0) {
      return;
    }

    // Add dmatrix to device if not seen before
    if (this->device_matrix_cache_.find(dmat) ==
        this->device_matrix_cache_.end()) {
      this->device_matrix_cache_.emplace(
          dmat, std::unique_ptr<DeviceMatrix>(
                    new DeviceMatrix(dmat, param.gpu_id, param.silent)));
    }
    DeviceMatrix* device_matrix = device_matrix_cache_.find(dmat)->second.get();

    dh::safe_cuda(cudaSetDevice(param.gpu_id));
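    // The forest is flattened into a single contiguous array of
    // DevicePredictionNode; h_tree_segments records where each tree's nodes
    // start, so the kernel can address tree t at
    // d_nodes + d_tree_segments[t - tree_begin].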
    CHECK_EQ(model.param.size_leaf_vector, 0);
    // Copy decision trees to device
    thrust::host_vector<size_t> h_tree_segments;
    h_tree_segments.reserve((tree_end - tree_begin) + 1);
    size_t sum = 0;
    h_tree_segments.push_back(sum);
    for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
      sum += model.trees[tree_idx]->GetNodes().size();
      h_tree_segments.push_back(sum);
    }

    thrust::host_vector<DevicePredictionNode> h_nodes(h_tree_segments.back());
    for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
      auto& src_nodes = model.trees[tree_idx]->GetNodes();
      std::copy(src_nodes.begin(), src_nodes.end(),
                h_nodes.begin() + h_tree_segments[tree_idx - tree_begin]);
    }

    nodes.resize(h_nodes.size());
    thrust::copy(h_nodes.begin(), h_nodes.end(), nodes.begin());
    tree_segments.resize(h_tree_segments.size());
    thrust::copy(h_tree_segments.begin(), h_tree_segments.end(),
                 tree_segments.begin());
    tree_group.resize(model.tree_info.size());
    thrust::copy(model.tree_info.begin(), model.tree_info.end(),
                 tree_group.begin());

    device_matrix->predictions.resize(out_preds->size());
    thrust::copy(out_preds->tbegin(param.gpu_id),
                 out_preds->tend(param.gpu_id),
                 device_matrix->predictions.begin());

    const int BLOCK_THREADS = 128;
    const int GRID_SIZE = static_cast<int>(
        dh::div_round_up(device_matrix->row_ptr.size() - 1, BLOCK_THREADS));

    int shared_memory_bytes = static_cast<int>(
        sizeof(float) * device_matrix->p_mat->info().num_col * BLOCK_THREADS);
    bool use_shared = true;
    if (shared_memory_bytes > max_shared_memory_bytes) {
      shared_memory_bytes = 0;
      use_shared = false;
    }

    PredictKernel<BLOCK_THREADS>
        <<<GRID_SIZE, BLOCK_THREADS, shared_memory_bytes>>>(
            dh::raw(nodes), dh::raw(device_matrix->predictions),
            dh::raw(tree_segments), dh::raw(tree_group),
            device_matrix->row_ptr.data(), device_matrix->data.data(),
            tree_begin, tree_end, device_matrix->p_mat->info().num_col,
            device_matrix->p_mat->info().num_row, use_shared,
            model.param.num_output_group);

    dh::safe_cuda(cudaDeviceSynchronize());
    thrust::copy(device_matrix->predictions.begin(),
                 device_matrix->predictions.end(),
                 out_preds->tbegin(param.gpu_id));
  }

 public:
  GPUPredictor() : cpu_predictor(Predictor::Create("cpu_predictor")) {}

  void PredictBatch(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds,
                    const gbm::GBTreeModel& model, int tree_begin,
                    unsigned ntree_limit = 0) override {
    if (this->PredictFromCache(dmat, out_preds, model, ntree_limit)) {
      return;
    }
    this->InitOutPredictions(dmat->info(), out_preds, model);

    int tree_end = ntree_limit * model.param.num_output_group;

    if (ntree_limit == 0 || ntree_limit > model.trees.size()) {
      tree_end = static_cast<int>(model.trees.size());
    }

    DevicePredictInternal(dmat, out_preds, model, tree_begin, tree_end);
  }

 protected:
  void InitOutPredictions(const MetaInfo& info,
                          HostDeviceVector<bst_float>* out_preds,
                          const gbm::GBTreeModel& model) const {
    size_t n = model.param.num_output_group * info.num_row;
    const std::vector<bst_float>& base_margin = info.base_margin;
    out_preds->resize(n, 0.0f, param.gpu_id);
    if (base_margin.size() != 0) {
      CHECK_EQ(out_preds->size(), n);
      thrust::copy(base_margin.begin(), base_margin.end(),
                   out_preds->tbegin(param.gpu_id));
    } else {
      thrust::fill(out_preds->tbegin(param.gpu_id),
                   out_preds->tend(param.gpu_id), model.base_margin);
    }
  }

  bool PredictFromCache(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds,
                        const gbm::GBTreeModel& model, unsigned ntree_limit) {
    if (ntree_limit == 0 ||
        ntree_limit * model.param.num_output_group >= model.trees.size()) {
      auto it = cache_.find(dmat);
      if (it != cache_.end()) {
        HostDeviceVector<bst_float>& y = it->second.predictions;
        if (y.size() != 0) {
          dh::safe_cuda(cudaSetDevice(param.gpu_id));
          out_preds->resize(y.size(), 0.0f, param.gpu_id);
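          // Cached predictions already live on the device; copy them
          // device-to-device into out_preds without a host round trip.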
          dh::safe_cuda(cudaMemcpy(out_preds->ptr_d(param.gpu_id),
                                   y.ptr_d(param.gpu_id),
                                   out_preds->size() * sizeof(bst_float),
                                   cudaMemcpyDefault));
          return true;
        }
      }
    }
    return false;
  }

  void UpdatePredictionCache(
      const gbm::GBTreeModel& model,
      std::vector<std::unique_ptr<TreeUpdater>>* updaters,
      int num_new_trees) override {
    auto old_ntree = model.trees.size() - num_new_trees;
    // update cache entry
    for (auto& kv : cache_) {
      PredictionCacheEntry& e = kv.second;
      DMatrix* dmat = kv.first;
      HostDeviceVector<bst_float>& predictions = e.predictions;

      if (predictions.size() == 0) {
        // ensure that the device in predictions is correct
        predictions.resize(0, 0.0f, param.gpu_id);
        cpu_predictor->PredictBatch(dmat, &predictions, model, 0,
                                    static_cast<unsigned>(model.trees.size()));
      } else if (model.param.num_output_group == 1 && updaters->size() > 0 &&
                 num_new_trees == 1 &&
                 updaters->back()->UpdatePredictionCache(e.data.get(),
                                                         &predictions)) {
        // do nothing
      } else {
        DevicePredictInternal(dmat, &predictions, model, old_ntree,
                              model.trees.size());
      }
    }
  }

  void PredictInstance(const SparseBatch::Inst& inst,
                       std::vector<bst_float>* out_preds,
                       const gbm::GBTreeModel& model, unsigned ntree_limit,
                       unsigned root_index) override {
    cpu_predictor->PredictInstance(inst, out_preds, model, root_index);
  }

  void PredictLeaf(DMatrix* p_fmat, std::vector<bst_float>* out_preds,
                   const gbm::GBTreeModel& model,
                   unsigned ntree_limit) override {
    cpu_predictor->PredictLeaf(p_fmat, out_preds, model, ntree_limit);
  }

  void PredictContribution(DMatrix* p_fmat,
                           std::vector<bst_float>* out_contribs,
                           const gbm::GBTreeModel& model, unsigned ntree_limit,
                           bool approximate, int condition,
                           unsigned condition_feature) override {
    cpu_predictor->PredictContribution(p_fmat, out_contribs, model, ntree_limit,
                                       approximate, condition,
                                       condition_feature);
  }

  void PredictInteractionContributions(DMatrix* p_fmat,
                                       std::vector<bst_float>* out_contribs,
                                       const gbm::GBTreeModel& model,
                                       unsigned ntree_limit,
                                       bool approximate) override {
    cpu_predictor->PredictInteractionContributions(p_fmat, out_contribs, model,
                                                   ntree_limit, approximate);
  }

  void Init(const std::vector<std::pair<std::string, std::string>>& cfg,
            const std::vector<std::shared_ptr<DMatrix>>& cache) override {
    Predictor::Init(cfg, cache);
    cpu_predictor->Init(cfg, cache);
    param.InitAllowUnknown(cfg);
    max_shared_memory_bytes = dh::max_shared_memory(param.gpu_id);
  }

 private:
  GPUPredictionParam param;
  std::unique_ptr<Predictor> cpu_predictor;
  std::unordered_map<DMatrix*, std::unique_ptr<DeviceMatrix>>
      device_matrix_cache_;
  thrust::device_vector<DevicePredictionNode> nodes;
  thrust::device_vector<size_t> tree_segments;
  thrust::device_vector<int> tree_group;
  size_t max_shared_memory_bytes;
};

XGBOOST_REGISTER_PREDICTOR(GPUPredictor, "gpu_predictor")
    .describe("Make predictions using GPU.")
    .set_body([]() { return new GPUPredictor(); });
}  // namespace predictor
}  // namespace xgboost
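// Usage sketch (illustrative only, not part of the registration above): like
// any other predictor, this one is obtained from the registry by name and
// configured with the GPUPredictionParam fields as string key/value pairs:
//
//   std::unique_ptr<xgboost::Predictor> pred(
//       xgboost::Predictor::Create("gpu_predictor"));
//   pred->Init({{"gpu_id", "0"}, {"silent", "1"}}, {});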