From 556a83022db51688c585ebb619cc52bfa85d8123 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sat, 17 Apr 2021 00:29:34 +0800 Subject: [PATCH] Implement unified update prediction cache for (gpu_)hist. (#6860) * Implement utilites for linalg. * Unify the update prediction cache functions. * Implement update prediction cache for multi-class gpu hist. --- include/xgboost/linalg.h | 105 ++++++++++++++++++++++++++++++ include/xgboost/tree_updater.h | 11 +--- src/gbm/gbtree.cc | 59 +++++++++++++---- src/gbm/gbtree.cu | 18 +++++ src/tree/updater_gpu_hist.cu | 34 +++++----- src/tree/updater_quantile_hist.cc | 27 ++------ src/tree/updater_quantile_hist.h | 10 +-- tests/cpp/common/test_linalg.cc | 38 +++++++++++ tests/cpp/test_learner.cc | 7 ++ tests/cpp/tree/test_gpu_hist.cu | 5 +- 10 files changed, 246 insertions(+), 68 deletions(-) create mode 100644 include/xgboost/linalg.h create mode 100644 tests/cpp/common/test_linalg.cc diff --git a/include/xgboost/linalg.h b/include/xgboost/linalg.h new file mode 100644 index 000000000..bdc4dbe60 --- /dev/null +++ b/include/xgboost/linalg.h @@ -0,0 +1,105 @@ +/*! + * Copyright 2021 by Contributors + * \file linalg.h + * \brief Linear algebra related utilities. + */ +#ifndef XGBOOST_LINALG_H_ +#define XGBOOST_LINALG_H_ + +#include +#include +#include + +#include +#include +#include + +namespace xgboost { +/*! + * \brief A veiw over a matrix on contigious storage. + * + * \tparam T data type of matrix + */ +template class MatrixView { + int32_t device_; + common::Span values_; + size_t strides_[2]; + size_t shape_[2]; + + template static auto InferValues(Vec *vec, int32_t device) { + return device == GenericParameter::kCpuId ? vec->HostSpan() + : vec->DeviceSpan(); + } + + public: + /*! + * \param vec storage. + * \param strides Strides for matrix. + * \param shape Rows anc columns. + * \param device Where the data is stored in. + */ + MatrixView(HostDeviceVector *vec, std::array strides, + std::array shape, int32_t device) + : device_{device}, values_{InferValues(vec, device)} { + std::copy(strides.cbegin(), strides.cend(), strides_); + std::copy(shape.cbegin(), shape.cend(), shape_); + } + MatrixView(HostDeviceVector> const *vec, + std::array strides, std::array shape, + int32_t device) + : device_{device}, values_{InferValues(vec, device)} { + std::copy(strides.cbegin(), strides.cend(), strides_); + std::copy(shape.cbegin(), shape.cend(), shape_); + } + /*! \brief Row major constructor. */ + MatrixView(HostDeviceVector *vec, std::array shape, + int32_t device) + : device_{device}, values_{InferValues(vec, device)} { + std::copy(shape.cbegin(), shape.cend(), shape_); + strides_[0] = shape[1]; + strides_[1] = 1; + } + MatrixView(HostDeviceVector> const *vec, + std::array shape, int32_t device) + : device_{device}, values_{InferValues(vec, device)} { + std::copy(shape.cbegin(), shape.cend(), shape_); + strides_[0] = shape[1]; + strides_[1] = 1; + } + + XGBOOST_DEVICE T const &operator()(size_t r, size_t c) const { + return values_[strides_[0] * r + strides_[1] * c]; + } + XGBOOST_DEVICE T &operator()(size_t r, size_t c) { + return values_[strides_[0] * r + strides_[1] * c]; + } + + auto Strides() const { return strides_; } + auto Shape() const { return shape_; } + auto Values() const { return values_; } + auto Size() const { return shape_[0] * shape_[1]; } + auto DeviceIdx() const { return device_; } +}; + +/*! \brief A slice for 1 column of MatrixView. Can be extended to row if needed. */ +template class VectorView { + MatrixView matrix_; + size_t column_; + + public: + explicit VectorView(MatrixView matrix, size_t column) + : matrix_{std::move(matrix)}, column_{column} {} + + XGBOOST_DEVICE T &operator[](size_t i) { + return matrix_(i, column_); + } + + XGBOOST_DEVICE T const &operator[](size_t i) const { + return matrix_(i, column_); + } + + size_t Size() { return matrix_.Shape()[0]; } + int32_t DeviceIdx() const { return matrix_.DeviceIdx(); } +}; +} // namespace xgboost +#endif // XGBOOST_LINALG_H_ diff --git a/include/xgboost/tree_updater.h b/include/xgboost/tree_updater.h index 5f57cd353..f36005a9a 100644 --- a/include/xgboost/tree_updater.h +++ b/include/xgboost/tree_updater.h @@ -15,6 +15,7 @@ #include #include #include +#include #include #include @@ -70,14 +71,8 @@ class TreeUpdater : public Configurable { * the prediction cache. If true, the prediction cache will have been * updated by the time this function returns. */ - virtual bool UpdatePredictionCache(const DMatrix* /*data*/, - HostDeviceVector* /*out_preds*/) { - return false; - } - - virtual bool UpdatePredictionCacheMulticlass(const DMatrix* /*data*/, - HostDeviceVector* /*out_preds*/, - const int /*gid*/, const int /*ngroup*/) { + virtual bool UpdatePredictionCache(const DMatrix * /*data*/, + VectorView /*out_preds*/) { return false; } diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index 9fc522e03..cc7e1d9d6 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -190,6 +190,32 @@ void GBTree::ConfigureUpdaters() { } } +void GPUCopyGradient(HostDeviceVector const *in_gpair, + bst_group_t n_groups, bst_group_t group_id, + HostDeviceVector *out_gpair) +#if defined(XGBOOST_USE_CUDA) +; // NOLINT +#else +{ + common::AssertGPUSupport(); +} +#endif + +void CopyGradient(HostDeviceVector const *in_gpair, + bst_group_t n_groups, bst_group_t group_id, + HostDeviceVector *out_gpair) { + if (in_gpair->DeviceIdx() != GenericParameter::kCpuId) { + GPUCopyGradient(in_gpair, n_groups, group_id, out_gpair); + } else { + std::vector &tmp_h = out_gpair->HostVector(); + auto nsize = static_cast(out_gpair->Size()); + const auto &gpair_h = in_gpair->ConstHostVector(); + common::ParallelFor(nsize, [&](bst_omp_uint i) { + tmp_h[i] = gpair_h[i * n_groups + group_id]; + }); + } +} + void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector* in_gpair, PredictionCacheEntry* predt) { @@ -197,39 +223,44 @@ void GBTree::DoBoost(DMatrix* p_fmat, const int ngroup = model_.learner_model_param->num_output_group; ConfigureWithKnownData(this->cfg_, p_fmat); monitor_.Start("BoostNewTrees"); - auto* out = &predt->predictions; + // Weird case that tree method is cpu-based but gpu_id is set. Ideally we should let + // `gpu_id` be the single source of determining what algorithms to run, but that will + // break a lots of existing code. + auto device = tparam_.tree_method != TreeMethod::kGPUHist + ? GenericParameter::kCpuId + : in_gpair->DeviceIdx(); + auto out = MatrixView( + &predt->predictions, + {p_fmat->Info().num_row_, static_cast(ngroup)}, device); CHECK_NE(ngroup, 0); if (ngroup == 1) { - std::vector > ret; + std::vector> ret; BoostNewTrees(in_gpair, p_fmat, 0, &ret); const size_t num_new_trees = ret.size(); new_trees.push_back(std::move(ret)); - if (updaters_.size() > 0 && num_new_trees == 1 && out->Size() > 0 && - updaters_.back()->UpdatePredictionCache(p_fmat, out)) { + auto v_predt = VectorView{out, 0}; + if (updaters_.size() > 0 && num_new_trees == 1 && + predt->predictions.Size() > 0 && + updaters_.back()->UpdatePredictionCache(p_fmat, v_predt)) { predt->Update(1); } } else { CHECK_EQ(in_gpair->Size() % ngroup, 0U) << "must have exactly ngroup * nrow gpairs"; - // TODO(canonizer): perform this on GPU if HostDeviceVector has device set. HostDeviceVector tmp(in_gpair->Size() / ngroup, GradientPair(), in_gpair->DeviceIdx()); - const auto& gpair_h = in_gpair->ConstHostVector(); - auto nsize = static_cast(tmp.Size()); bool update_predict = true; for (int gid = 0; gid < ngroup; ++gid) { - std::vector& tmp_h = tmp.HostVector(); - common::ParallelFor(nsize, [&](bst_omp_uint i) { - tmp_h[i] = gpair_h[i * ngroup + gid]; - }); + CopyGradient(in_gpair, ngroup, gid, &tmp); std::vector > ret; BoostNewTrees(&tmp, p_fmat, gid, &ret); const size_t num_new_trees = ret.size(); new_trees.push_back(std::move(ret)); - auto* out = &predt->predictions; - if (!(updaters_.size() > 0 && out->Size() > 0 && num_new_trees == 1 && - updaters_.back()->UpdatePredictionCacheMulticlass(p_fmat, out, gid, ngroup))) { + auto v_predt = VectorView{out, static_cast(gid)}; + if (!(updaters_.size() > 0 && predt->predictions.Size() > 0 && + num_new_trees == 1 && + updaters_.back()->UpdatePredictionCache(p_fmat, v_predt))) { update_predict = false; } } diff --git a/src/gbm/gbtree.cu b/src/gbm/gbtree.cu index 52ac90501..33bca68c3 100644 --- a/src/gbm/gbtree.cu +++ b/src/gbm/gbtree.cu @@ -2,10 +2,28 @@ * Copyright 2021 by Contributors */ #include "xgboost/span.h" +#include "xgboost/generic_parameters.h" +#include "xgboost/linalg.h" #include "../common/device_helpers.cuh" namespace xgboost { namespace gbm { + +void GPUCopyGradient(HostDeviceVector const *in_gpair, + bst_group_t n_groups, bst_group_t group_id, + HostDeviceVector *out_gpair) { + MatrixView in{ + in_gpair, + {n_groups, 1ul}, + {in_gpair->Size() / n_groups, static_cast(n_groups)}, + in_gpair->DeviceIdx()}; + auto v_in = VectorView{in, group_id}; + out_gpair->Resize(v_in.Size()); + auto d_out = out_gpair->DeviceSpan(); + dh::LaunchN(dh::CurrentDevice(), v_in.Size(), + [=] __device__(size_t i) { d_out[i] = v_in[i]; }); +} + void GPUDartPredictInc(common::Span out_predts, common::Span predts, float tree_w, size_t n_rows, bst_group_t n_groups, bst_group_t group) { diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 8b46f7468..6ac8a92f3 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -273,9 +273,9 @@ struct GPUHistMakerDevice { if (d_gpair.size() != dh_gpair->Size()) { d_gpair.resize(dh_gpair->Size()); } - thrust::copy(thrust::device, dh_gpair->ConstDevicePointer(), - dh_gpair->ConstDevicePointer() + dh_gpair->Size(), - d_gpair.begin()); + dh::safe_cuda(cudaMemcpyAsync( + d_gpair.data().get(), dh_gpair->ConstDevicePointer(), + dh_gpair->Size() * sizeof(GradientPair), cudaMemcpyDeviceToDevice)); auto sample = sampler->Sample(dh::ToSpan(d_gpair), dmat); page = sample.page; gpair = sample.gpair; @@ -528,8 +528,9 @@ struct GPUHistMakerDevice { }); } - void UpdatePredictionCache(common::Span out_preds_d) { + void UpdatePredictionCache(VectorView out_preds_d) { dh::safe_cuda(cudaSetDevice(device_id)); + CHECK_EQ(out_preds_d.DeviceIdx(), device_id); auto d_ridx = row_partitioner->GetRows(); GPUTrainingParam param_d(param); @@ -543,14 +544,14 @@ struct GPUHistMakerDevice { auto d_node_sum_gradients = device_node_sum_gradients.data().get(); auto evaluator = tree_evaluator.GetEvaluator(); - dh::LaunchN( - device_id, out_preds_d.size(), [=] __device__(int local_idx) { - int pos = d_position[local_idx]; - bst_float weight = evaluator.CalcWeight(pos, param_d, - GradStats{d_node_sum_gradients[pos]}); - out_preds_d[d_ridx[local_idx]] += - weight * param_d.learning_rate; - }); + dh::LaunchN(device_id, d_ridx.size(), [=] __device__(int local_idx) { + int pos = d_position[local_idx]; + bst_float weight = evaluator.CalcWeight( + pos, param_d, GradStats{d_node_sum_gradients[pos]}); + static_assert(!std::is_const::value, ""); + auto v_predt = out_preds_d; // for some reaon out_preds_d is const by both nvcc and clang. + v_predt[d_ridx[local_idx]] += weight * param_d.learning_rate; + }); row_partitioner.reset(); } @@ -862,13 +863,12 @@ class GPUHistMakerSpecialised { maker->UpdateTree(gpair, p_fmat, p_tree, &reducer_); } - bool UpdatePredictionCache(const DMatrix* data, HostDeviceVector* p_out_preds) { + bool UpdatePredictionCache(const DMatrix* data, VectorView p_out_preds) { if (maker == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) { return false; } monitor_.Start("UpdatePredictionCache"); - p_out_preds->SetDevice(device_); - maker->UpdatePredictionCache(p_out_preds->DeviceSpan()); + maker->UpdatePredictionCache(p_out_preds); monitor_.Stop("UpdatePredictionCache"); return true; } @@ -947,8 +947,8 @@ class GPUHistMaker : public TreeUpdater { } } - bool UpdatePredictionCache( - const DMatrix* data, HostDeviceVector* p_out_preds) override { + bool UpdatePredictionCache(const DMatrix *data, + VectorView p_out_preds) override { if (hist_maker_param_.single_precision_histogram) { return float_maker_->UpdatePredictionCache(data, p_out_preds); } else { diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index d79084f92..2bd09875b 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -110,7 +110,7 @@ void QuantileHistMaker::Update(HostDeviceVector *gpair, } bool QuantileHistMaker::UpdatePredictionCache( - const DMatrix* data, HostDeviceVector* out_preds) { + const DMatrix* data, VectorView out_preds) { if (hist_maker_param_.single_precision_histogram && float_builder_) { return float_builder_->UpdatePredictionCache(data, out_preds); } else if (double_builder_) { @@ -120,19 +120,6 @@ bool QuantileHistMaker::UpdatePredictionCache( } } -bool QuantileHistMaker::UpdatePredictionCacheMulticlass( - const DMatrix* data, - HostDeviceVector* out_preds, const int gid, const int ngroup) { - if (hist_maker_param_.single_precision_histogram && float_builder_) { - return float_builder_->UpdatePredictionCache(data, out_preds, gid, ngroup); - } else if (double_builder_) { - return double_builder_->UpdatePredictionCache(data, out_preds, gid, ngroup); - } else { - return false; - } -} - - template void BatchHistSynchronizer::SyncHistograms(BuilderT *builder, int, @@ -629,7 +616,7 @@ void QuantileHistMaker::Builder::Update( template bool QuantileHistMaker::Builder::UpdatePredictionCache( const DMatrix* data, - HostDeviceVector* p_out_preds, const int gid, const int ngroup) { + VectorView out_preds) { // p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in // conjunction with Update(). if (!p_last_fmat_ || !p_last_tree_ || data != p_last_fmat_ || @@ -638,16 +625,14 @@ bool QuantileHistMaker::Builder::UpdatePredictionCache( } builder_monitor_.Start("UpdatePredictionCache"); - std::vector& out_preds = p_out_preds->HostVector(); - - CHECK_GT(out_preds.size(), 0U); + CHECK_GT(out_preds.Size(), 0U); size_t n_nodes = row_set_collection_.end() - row_set_collection_.begin(); common::BlockedSpace2d space(n_nodes, [&](size_t node) { return row_set_collection_[node].Size(); }, 1024); - + CHECK_EQ(out_preds.DeviceIdx(), GenericParameter::kCpuId); common::ParallelFor2d(space, this->nthread_, [&](size_t node, common::Range1d r) { const RowSetCollection::Elem rowset = row_set_collection_[node]; if (rowset.begin != nullptr && rowset.end != nullptr) { @@ -664,7 +649,7 @@ bool QuantileHistMaker::Builder::UpdatePredictionCache( leaf_value = (*p_last_tree_)[nid].LeafValue(); for (const size_t* it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) { - out_preds[*it * ngroup + gid] += leaf_value; + out_preds[*it] += leaf_value; } } }); @@ -687,7 +672,7 @@ bool QuantileHistMaker::Builder::UpdatePredictionCache( const size_t row_num = unused_rows_[block_id] + batch.base_rowid; const int lid = feats.HasMissing() ? p_last_tree_->GetLeafIndex(feats) : p_last_tree_->GetLeafIndex(feats); - out_preds[row_num * ngroup + gid] += (*p_last_tree_)[lid].LeafValue(); + out_preds[row_num] += (*p_last_tree_)[lid].LeafValue(); feats.Drop(inst); }); diff --git a/src/tree/updater_quantile_hist.h b/src/tree/updater_quantile_hist.h index 36019beae..7d6c5db97 100644 --- a/src/tree/updater_quantile_hist.h +++ b/src/tree/updater_quantile_hist.h @@ -118,11 +118,8 @@ class QuantileHistMaker: public TreeUpdater { DMatrix* dmat, const std::vector& trees) override; - bool UpdatePredictionCache(const DMatrix* data, - HostDeviceVector* out_preds) override; - bool UpdatePredictionCacheMulticlass(const DMatrix* data, - HostDeviceVector* out_preds, - const int gid, const int ngroup) override; + bool UpdatePredictionCache(const DMatrix *data, + VectorView out_preds) override; void LoadConfig(Json const& in) override { auto const& config = get(in); @@ -245,8 +242,7 @@ class QuantileHistMaker: public TreeUpdater { } bool UpdatePredictionCache(const DMatrix* data, - HostDeviceVector* p_out_preds, - const int gid = 0, const int ngroup = 1); + VectorView out_preds); void SetHistSynchronizer(HistSynchronizer* sync); void SetHistRowsAdder(HistRowsAdder* adder); diff --git a/tests/cpp/common/test_linalg.cc b/tests/cpp/common/test_linalg.cc new file mode 100644 index 000000000..935e82dde --- /dev/null +++ b/tests/cpp/common/test_linalg.cc @@ -0,0 +1,38 @@ +#include +#include +#include + +namespace xgboost { + +auto MakeMatrixFromTest(HostDeviceVector *storage, size_t n_rows, size_t n_cols) { + storage->Resize(n_rows * n_cols); + auto& h_storage = storage->HostVector(); + + std::iota(h_storage.begin(), h_storage.end(), 0); + + auto m = MatrixView{storage, {n_cols, 1}, {n_rows, n_cols}, -1}; + return m; + +} + +TEST(Linalg, Matrix) { + size_t kRows = 31, kCols = 77; + HostDeviceVector storage; + auto m = MakeMatrixFromTest(&storage, kRows, kCols); + ASSERT_EQ(m.DeviceIdx(), GenericParameter::kCpuId); + ASSERT_EQ(m(0, 0), 0); + ASSERT_EQ(m(kRows - 1, kCols - 1), storage.Size() - 1); +} + +TEST(Linalg, Vector) { + size_t kRows = 31, kCols = 77; + HostDeviceVector storage; + auto m = MakeMatrixFromTest(&storage, kRows, kCols); + auto v = VectorView(m, 3); + for (size_t i = 0; i < v.Size(); ++i) { + ASSERT_EQ(v[i], m(i, 3)); + } + + ASSERT_EQ(v[0], 3); +} +} // namespace xgboost diff --git a/tests/cpp/test_learner.cc b/tests/cpp/test_learner.cc index 703af54f2..ce910efed 100644 --- a/tests/cpp/test_learner.cc +++ b/tests/cpp/test_learner.cc @@ -294,6 +294,13 @@ TEST(Learner, GPUConfiguration) { learner->UpdateOneIter(0, p_dmat); ASSERT_EQ(learner->GetGenericParameter().gpu_id, 0); } + { + std::unique_ptr learner {Learner::Create(mat)}; + learner->SetParams({Arg{"tree_method", "gpu_hist"}, + Arg{"gpu_id", "-1"}}); + learner->UpdateOneIter(0, p_dmat); + ASSERT_EQ(learner->GetGenericParameter().gpu_id, 0); + } { // with CPU algorithm std::unique_ptr learner {Learner::Create(mat)}; diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index 5d4f6e864..5f71a67a7 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -390,7 +390,10 @@ void UpdateTree(HostDeviceVector* gpair, DMatrix* dmat, hist_maker.Configure(args, &generic_param); hist_maker.Update(gpair, dmat, {tree}); - hist_maker.UpdatePredictionCache(dmat, preds); + hist_maker.UpdatePredictionCache( + dmat, + VectorView{ + MatrixView(preds, {preds->Size(), 1}, preds->DeviceIdx()), 0}); } TEST(GpuHist, UniformSampling) {