Implement unified update prediction cache for (gpu_)hist. (#6860)

* Implement utilites for linalg.
* Unify the update prediction cache functions.
* Implement update prediction cache for multi-class gpu hist.
This commit is contained in:
Jiaming Yuan 2021-04-17 00:29:34 +08:00 committed by GitHub
parent 1b26a2a561
commit 556a83022d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 246 additions and 68 deletions

105
include/xgboost/linalg.h Normal file
View File

@ -0,0 +1,105 @@
/*!
* Copyright 2021 by Contributors
* \file linalg.h
* \brief Linear algebra related utilities.
*/
#ifndef XGBOOST_LINALG_H_
#define XGBOOST_LINALG_H_
#include <xgboost/span.h>
#include <xgboost/host_device_vector.h>
#include <xgboost/generic_parameters.h>
#include <array>
#include <algorithm>
#include <utility>
namespace xgboost {
/*!
* \brief A veiw over a matrix on contigious storage.
*
* \tparam T data type of matrix
*/
template <typename T> class MatrixView {
int32_t device_;
common::Span<T> values_;
size_t strides_[2];
size_t shape_[2];
template <typename Vec> static auto InferValues(Vec *vec, int32_t device) {
return device == GenericParameter::kCpuId ? vec->HostSpan()
: vec->DeviceSpan();
}
public:
/*!
* \param vec storage.
* \param strides Strides for matrix.
* \param shape Rows anc columns.
* \param device Where the data is stored in.
*/
MatrixView(HostDeviceVector<T> *vec, std::array<size_t, 2> strides,
std::array<size_t, 2> shape, int32_t device)
: device_{device}, values_{InferValues(vec, device)} {
std::copy(strides.cbegin(), strides.cend(), strides_);
std::copy(shape.cbegin(), shape.cend(), shape_);
}
MatrixView(HostDeviceVector<std::remove_const_t<T>> const *vec,
std::array<size_t, 2> strides, std::array<size_t, 2> shape,
int32_t device)
: device_{device}, values_{InferValues(vec, device)} {
std::copy(strides.cbegin(), strides.cend(), strides_);
std::copy(shape.cbegin(), shape.cend(), shape_);
}
/*! \brief Row major constructor. */
MatrixView(HostDeviceVector<T> *vec, std::array<size_t, 2> shape,
int32_t device)
: device_{device}, values_{InferValues(vec, device)} {
std::copy(shape.cbegin(), shape.cend(), shape_);
strides_[0] = shape[1];
strides_[1] = 1;
}
MatrixView(HostDeviceVector<std::remove_const_t<T>> const *vec,
std::array<size_t, 2> shape, int32_t device)
: device_{device}, values_{InferValues(vec, device)} {
std::copy(shape.cbegin(), shape.cend(), shape_);
strides_[0] = shape[1];
strides_[1] = 1;
}
XGBOOST_DEVICE T const &operator()(size_t r, size_t c) const {
return values_[strides_[0] * r + strides_[1] * c];
}
XGBOOST_DEVICE T &operator()(size_t r, size_t c) {
return values_[strides_[0] * r + strides_[1] * c];
}
auto Strides() const { return strides_; }
auto Shape() const { return shape_; }
auto Values() const { return values_; }
auto Size() const { return shape_[0] * shape_[1]; }
auto DeviceIdx() const { return device_; }
};
/*! \brief A slice for 1 column of MatrixView. Can be extended to row if needed. */
template <typename T> class VectorView {
MatrixView<T> matrix_;
size_t column_;
public:
explicit VectorView(MatrixView<T> matrix, size_t column)
: matrix_{std::move(matrix)}, column_{column} {}
XGBOOST_DEVICE T &operator[](size_t i) {
return matrix_(i, column_);
}
XGBOOST_DEVICE T const &operator[](size_t i) const {
return matrix_(i, column_);
}
size_t Size() { return matrix_.Shape()[0]; }
int32_t DeviceIdx() const { return matrix_.DeviceIdx(); }
};
} // namespace xgboost
#endif // XGBOOST_LINALG_H_

View File

@ -15,6 +15,7 @@
#include <xgboost/generic_parameters.h> #include <xgboost/generic_parameters.h>
#include <xgboost/host_device_vector.h> #include <xgboost/host_device_vector.h>
#include <xgboost/model.h> #include <xgboost/model.h>
#include <xgboost/linalg.h>
#include <functional> #include <functional>
#include <vector> #include <vector>
@ -70,14 +71,8 @@ class TreeUpdater : public Configurable {
* the prediction cache. If true, the prediction cache will have been * the prediction cache. If true, the prediction cache will have been
* updated by the time this function returns. * updated by the time this function returns.
*/ */
virtual bool UpdatePredictionCache(const DMatrix* /*data*/, virtual bool UpdatePredictionCache(const DMatrix * /*data*/,
HostDeviceVector<bst_float>* /*out_preds*/) { VectorView<float> /*out_preds*/) {
return false;
}
virtual bool UpdatePredictionCacheMulticlass(const DMatrix* /*data*/,
HostDeviceVector<bst_float>* /*out_preds*/,
const int /*gid*/, const int /*ngroup*/) {
return false; return false;
} }

View File

@ -190,6 +190,32 @@ void GBTree::ConfigureUpdaters() {
} }
} }
void GPUCopyGradient(HostDeviceVector<GradientPair> const *in_gpair,
bst_group_t n_groups, bst_group_t group_id,
HostDeviceVector<GradientPair> *out_gpair)
#if defined(XGBOOST_USE_CUDA)
; // NOLINT
#else
{
common::AssertGPUSupport();
}
#endif
void CopyGradient(HostDeviceVector<GradientPair> const *in_gpair,
bst_group_t n_groups, bst_group_t group_id,
HostDeviceVector<GradientPair> *out_gpair) {
if (in_gpair->DeviceIdx() != GenericParameter::kCpuId) {
GPUCopyGradient(in_gpair, n_groups, group_id, out_gpair);
} else {
std::vector<GradientPair> &tmp_h = out_gpair->HostVector();
auto nsize = static_cast<bst_omp_uint>(out_gpair->Size());
const auto &gpair_h = in_gpair->ConstHostVector();
common::ParallelFor(nsize, [&](bst_omp_uint i) {
tmp_h[i] = gpair_h[i * n_groups + group_id];
});
}
}
void GBTree::DoBoost(DMatrix* p_fmat, void GBTree::DoBoost(DMatrix* p_fmat,
HostDeviceVector<GradientPair>* in_gpair, HostDeviceVector<GradientPair>* in_gpair,
PredictionCacheEntry* predt) { PredictionCacheEntry* predt) {
@ -197,39 +223,44 @@ void GBTree::DoBoost(DMatrix* p_fmat,
const int ngroup = model_.learner_model_param->num_output_group; const int ngroup = model_.learner_model_param->num_output_group;
ConfigureWithKnownData(this->cfg_, p_fmat); ConfigureWithKnownData(this->cfg_, p_fmat);
monitor_.Start("BoostNewTrees"); monitor_.Start("BoostNewTrees");
auto* out = &predt->predictions; // Weird case that tree method is cpu-based but gpu_id is set. Ideally we should let
// `gpu_id` be the single source of determining what algorithms to run, but that will
// break a lots of existing code.
auto device = tparam_.tree_method != TreeMethod::kGPUHist
? GenericParameter::kCpuId
: in_gpair->DeviceIdx();
auto out = MatrixView<float>(
&predt->predictions,
{p_fmat->Info().num_row_, static_cast<size_t>(ngroup)}, device);
CHECK_NE(ngroup, 0); CHECK_NE(ngroup, 0);
if (ngroup == 1) { if (ngroup == 1) {
std::vector<std::unique_ptr<RegTree> > ret; std::vector<std::unique_ptr<RegTree>> ret;
BoostNewTrees(in_gpair, p_fmat, 0, &ret); BoostNewTrees(in_gpair, p_fmat, 0, &ret);
const size_t num_new_trees = ret.size(); const size_t num_new_trees = ret.size();
new_trees.push_back(std::move(ret)); new_trees.push_back(std::move(ret));
if (updaters_.size() > 0 && num_new_trees == 1 && out->Size() > 0 && auto v_predt = VectorView<float>{out, 0};
updaters_.back()->UpdatePredictionCache(p_fmat, out)) { if (updaters_.size() > 0 && num_new_trees == 1 &&
predt->predictions.Size() > 0 &&
updaters_.back()->UpdatePredictionCache(p_fmat, v_predt)) {
predt->Update(1); predt->Update(1);
} }
} else { } else {
CHECK_EQ(in_gpair->Size() % ngroup, 0U) CHECK_EQ(in_gpair->Size() % ngroup, 0U)
<< "must have exactly ngroup * nrow gpairs"; << "must have exactly ngroup * nrow gpairs";
// TODO(canonizer): perform this on GPU if HostDeviceVector has device set.
HostDeviceVector<GradientPair> tmp(in_gpair->Size() / ngroup, HostDeviceVector<GradientPair> tmp(in_gpair->Size() / ngroup,
GradientPair(), GradientPair(),
in_gpair->DeviceIdx()); in_gpair->DeviceIdx());
const auto& gpair_h = in_gpair->ConstHostVector();
auto nsize = static_cast<bst_omp_uint>(tmp.Size());
bool update_predict = true; bool update_predict = true;
for (int gid = 0; gid < ngroup; ++gid) { for (int gid = 0; gid < ngroup; ++gid) {
std::vector<GradientPair>& tmp_h = tmp.HostVector(); CopyGradient(in_gpair, ngroup, gid, &tmp);
common::ParallelFor(nsize, [&](bst_omp_uint i) {
tmp_h[i] = gpair_h[i * ngroup + gid];
});
std::vector<std::unique_ptr<RegTree> > ret; std::vector<std::unique_ptr<RegTree> > ret;
BoostNewTrees(&tmp, p_fmat, gid, &ret); BoostNewTrees(&tmp, p_fmat, gid, &ret);
const size_t num_new_trees = ret.size(); const size_t num_new_trees = ret.size();
new_trees.push_back(std::move(ret)); new_trees.push_back(std::move(ret));
auto* out = &predt->predictions; auto v_predt = VectorView<float>{out, static_cast<size_t>(gid)};
if (!(updaters_.size() > 0 && out->Size() > 0 && num_new_trees == 1 && if (!(updaters_.size() > 0 && predt->predictions.Size() > 0 &&
updaters_.back()->UpdatePredictionCacheMulticlass(p_fmat, out, gid, ngroup))) { num_new_trees == 1 &&
updaters_.back()->UpdatePredictionCache(p_fmat, v_predt))) {
update_predict = false; update_predict = false;
} }
} }

View File

@ -2,10 +2,28 @@
* Copyright 2021 by Contributors * Copyright 2021 by Contributors
*/ */
#include "xgboost/span.h" #include "xgboost/span.h"
#include "xgboost/generic_parameters.h"
#include "xgboost/linalg.h"
#include "../common/device_helpers.cuh" #include "../common/device_helpers.cuh"
namespace xgboost { namespace xgboost {
namespace gbm { namespace gbm {
void GPUCopyGradient(HostDeviceVector<GradientPair> const *in_gpair,
bst_group_t n_groups, bst_group_t group_id,
HostDeviceVector<GradientPair> *out_gpair) {
MatrixView<GradientPair const> in{
in_gpair,
{n_groups, 1ul},
{in_gpair->Size() / n_groups, static_cast<size_t>(n_groups)},
in_gpair->DeviceIdx()};
auto v_in = VectorView<GradientPair const>{in, group_id};
out_gpair->Resize(v_in.Size());
auto d_out = out_gpair->DeviceSpan();
dh::LaunchN(dh::CurrentDevice(), v_in.Size(),
[=] __device__(size_t i) { d_out[i] = v_in[i]; });
}
void GPUDartPredictInc(common::Span<float> out_predts, void GPUDartPredictInc(common::Span<float> out_predts,
common::Span<float> predts, float tree_w, size_t n_rows, common::Span<float> predts, float tree_w, size_t n_rows,
bst_group_t n_groups, bst_group_t group) { bst_group_t n_groups, bst_group_t group) {

View File

@ -273,9 +273,9 @@ struct GPUHistMakerDevice {
if (d_gpair.size() != dh_gpair->Size()) { if (d_gpair.size() != dh_gpair->Size()) {
d_gpair.resize(dh_gpair->Size()); d_gpair.resize(dh_gpair->Size());
} }
thrust::copy(thrust::device, dh_gpair->ConstDevicePointer(), dh::safe_cuda(cudaMemcpyAsync(
dh_gpair->ConstDevicePointer() + dh_gpair->Size(), d_gpair.data().get(), dh_gpair->ConstDevicePointer(),
d_gpair.begin()); dh_gpair->Size() * sizeof(GradientPair), cudaMemcpyDeviceToDevice));
auto sample = sampler->Sample(dh::ToSpan(d_gpair), dmat); auto sample = sampler->Sample(dh::ToSpan(d_gpair), dmat);
page = sample.page; page = sample.page;
gpair = sample.gpair; gpair = sample.gpair;
@ -528,8 +528,9 @@ struct GPUHistMakerDevice {
}); });
} }
void UpdatePredictionCache(common::Span<bst_float> out_preds_d) { void UpdatePredictionCache(VectorView<float> out_preds_d) {
dh::safe_cuda(cudaSetDevice(device_id)); dh::safe_cuda(cudaSetDevice(device_id));
CHECK_EQ(out_preds_d.DeviceIdx(), device_id);
auto d_ridx = row_partitioner->GetRows(); auto d_ridx = row_partitioner->GetRows();
GPUTrainingParam param_d(param); GPUTrainingParam param_d(param);
@ -543,14 +544,14 @@ struct GPUHistMakerDevice {
auto d_node_sum_gradients = device_node_sum_gradients.data().get(); auto d_node_sum_gradients = device_node_sum_gradients.data().get();
auto evaluator = tree_evaluator.GetEvaluator<GPUTrainingParam>(); auto evaluator = tree_evaluator.GetEvaluator<GPUTrainingParam>();
dh::LaunchN( dh::LaunchN(device_id, d_ridx.size(), [=] __device__(int local_idx) {
device_id, out_preds_d.size(), [=] __device__(int local_idx) { int pos = d_position[local_idx];
int pos = d_position[local_idx]; bst_float weight = evaluator.CalcWeight(
bst_float weight = evaluator.CalcWeight(pos, param_d, pos, param_d, GradStats{d_node_sum_gradients[pos]});
GradStats{d_node_sum_gradients[pos]}); static_assert(!std::is_const<decltype(out_preds_d)>::value, "");
out_preds_d[d_ridx[local_idx]] += auto v_predt = out_preds_d; // for some reaon out_preds_d is const by both nvcc and clang.
weight * param_d.learning_rate; v_predt[d_ridx[local_idx]] += weight * param_d.learning_rate;
}); });
row_partitioner.reset(); row_partitioner.reset();
} }
@ -862,13 +863,12 @@ class GPUHistMakerSpecialised {
maker->UpdateTree(gpair, p_fmat, p_tree, &reducer_); maker->UpdateTree(gpair, p_fmat, p_tree, &reducer_);
} }
bool UpdatePredictionCache(const DMatrix* data, HostDeviceVector<bst_float>* p_out_preds) { bool UpdatePredictionCache(const DMatrix* data, VectorView<bst_float> p_out_preds) {
if (maker == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) { if (maker == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) {
return false; return false;
} }
monitor_.Start("UpdatePredictionCache"); monitor_.Start("UpdatePredictionCache");
p_out_preds->SetDevice(device_); maker->UpdatePredictionCache(p_out_preds);
maker->UpdatePredictionCache(p_out_preds->DeviceSpan());
monitor_.Stop("UpdatePredictionCache"); monitor_.Stop("UpdatePredictionCache");
return true; return true;
} }
@ -947,8 +947,8 @@ class GPUHistMaker : public TreeUpdater {
} }
} }
bool UpdatePredictionCache( bool UpdatePredictionCache(const DMatrix *data,
const DMatrix* data, HostDeviceVector<bst_float>* p_out_preds) override { VectorView<bst_float> p_out_preds) override {
if (hist_maker_param_.single_precision_histogram) { if (hist_maker_param_.single_precision_histogram) {
return float_maker_->UpdatePredictionCache(data, p_out_preds); return float_maker_->UpdatePredictionCache(data, p_out_preds);
} else { } else {

View File

@ -110,7 +110,7 @@ void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
} }
bool QuantileHistMaker::UpdatePredictionCache( bool QuantileHistMaker::UpdatePredictionCache(
const DMatrix* data, HostDeviceVector<bst_float>* out_preds) { const DMatrix* data, VectorView<float> out_preds) {
if (hist_maker_param_.single_precision_histogram && float_builder_) { if (hist_maker_param_.single_precision_histogram && float_builder_) {
return float_builder_->UpdatePredictionCache(data, out_preds); return float_builder_->UpdatePredictionCache(data, out_preds);
} else if (double_builder_) { } else if (double_builder_) {
@ -120,19 +120,6 @@ bool QuantileHistMaker::UpdatePredictionCache(
} }
} }
bool QuantileHistMaker::UpdatePredictionCacheMulticlass(
const DMatrix* data,
HostDeviceVector<bst_float>* out_preds, const int gid, const int ngroup) {
if (hist_maker_param_.single_precision_histogram && float_builder_) {
return float_builder_->UpdatePredictionCache(data, out_preds, gid, ngroup);
} else if (double_builder_) {
return double_builder_->UpdatePredictionCache(data, out_preds, gid, ngroup);
} else {
return false;
}
}
template <typename GradientSumT> template <typename GradientSumT>
void BatchHistSynchronizer<GradientSumT>::SyncHistograms(BuilderT *builder, void BatchHistSynchronizer<GradientSumT>::SyncHistograms(BuilderT *builder,
int, int,
@ -629,7 +616,7 @@ void QuantileHistMaker::Builder<GradientSumT>::Update(
template<typename GradientSumT> template<typename GradientSumT>
bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache( bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache(
const DMatrix* data, const DMatrix* data,
HostDeviceVector<bst_float>* p_out_preds, const int gid, const int ngroup) { VectorView<float> out_preds) {
// p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in // p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in
// conjunction with Update(). // conjunction with Update().
if (!p_last_fmat_ || !p_last_tree_ || data != p_last_fmat_ || if (!p_last_fmat_ || !p_last_tree_ || data != p_last_fmat_ ||
@ -638,16 +625,14 @@ bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache(
} }
builder_monitor_.Start("UpdatePredictionCache"); builder_monitor_.Start("UpdatePredictionCache");
std::vector<bst_float>& out_preds = p_out_preds->HostVector(); CHECK_GT(out_preds.Size(), 0U);
CHECK_GT(out_preds.size(), 0U);
size_t n_nodes = row_set_collection_.end() - row_set_collection_.begin(); size_t n_nodes = row_set_collection_.end() - row_set_collection_.begin();
common::BlockedSpace2d space(n_nodes, [&](size_t node) { common::BlockedSpace2d space(n_nodes, [&](size_t node) {
return row_set_collection_[node].Size(); return row_set_collection_[node].Size();
}, 1024); }, 1024);
CHECK_EQ(out_preds.DeviceIdx(), GenericParameter::kCpuId);
common::ParallelFor2d(space, this->nthread_, [&](size_t node, common::Range1d r) { common::ParallelFor2d(space, this->nthread_, [&](size_t node, common::Range1d r) {
const RowSetCollection::Elem rowset = row_set_collection_[node]; const RowSetCollection::Elem rowset = row_set_collection_[node];
if (rowset.begin != nullptr && rowset.end != nullptr) { if (rowset.begin != nullptr && rowset.end != nullptr) {
@ -664,7 +649,7 @@ bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache(
leaf_value = (*p_last_tree_)[nid].LeafValue(); leaf_value = (*p_last_tree_)[nid].LeafValue();
for (const size_t* it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) { for (const size_t* it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) {
out_preds[*it * ngroup + gid] += leaf_value; out_preds[*it] += leaf_value;
} }
} }
}); });
@ -687,7 +672,7 @@ bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache(
const size_t row_num = unused_rows_[block_id] + batch.base_rowid; const size_t row_num = unused_rows_[block_id] + batch.base_rowid;
const int lid = feats.HasMissing() ? p_last_tree_->GetLeafIndex<true>(feats) : const int lid = feats.HasMissing() ? p_last_tree_->GetLeafIndex<true>(feats) :
p_last_tree_->GetLeafIndex<false>(feats); p_last_tree_->GetLeafIndex<false>(feats);
out_preds[row_num * ngroup + gid] += (*p_last_tree_)[lid].LeafValue(); out_preds[row_num] += (*p_last_tree_)[lid].LeafValue();
feats.Drop(inst); feats.Drop(inst);
}); });

View File

@ -118,11 +118,8 @@ class QuantileHistMaker: public TreeUpdater {
DMatrix* dmat, DMatrix* dmat,
const std::vector<RegTree*>& trees) override; const std::vector<RegTree*>& trees) override;
bool UpdatePredictionCache(const DMatrix* data, bool UpdatePredictionCache(const DMatrix *data,
HostDeviceVector<bst_float>* out_preds) override; VectorView<float> out_preds) override;
bool UpdatePredictionCacheMulticlass(const DMatrix* data,
HostDeviceVector<bst_float>* out_preds,
const int gid, const int ngroup) override;
void LoadConfig(Json const& in) override { void LoadConfig(Json const& in) override {
auto const& config = get<Object const>(in); auto const& config = get<Object const>(in);
@ -245,8 +242,7 @@ class QuantileHistMaker: public TreeUpdater {
} }
bool UpdatePredictionCache(const DMatrix* data, bool UpdatePredictionCache(const DMatrix* data,
HostDeviceVector<bst_float>* p_out_preds, VectorView<float> out_preds);
const int gid = 0, const int ngroup = 1);
void SetHistSynchronizer(HistSynchronizer<GradientSumT>* sync); void SetHistSynchronizer(HistSynchronizer<GradientSumT>* sync);
void SetHistRowsAdder(HistRowsAdder<GradientSumT>* adder); void SetHistRowsAdder(HistRowsAdder<GradientSumT>* adder);

View File

@ -0,0 +1,38 @@
#include <gtest/gtest.h>
#include <xgboost/linalg.h>
#include <numeric>
namespace xgboost {
auto MakeMatrixFromTest(HostDeviceVector<float> *storage, size_t n_rows, size_t n_cols) {
storage->Resize(n_rows * n_cols);
auto& h_storage = storage->HostVector();
std::iota(h_storage.begin(), h_storage.end(), 0);
auto m = MatrixView<float>{storage, {n_cols, 1}, {n_rows, n_cols}, -1};
return m;
}
TEST(Linalg, Matrix) {
size_t kRows = 31, kCols = 77;
HostDeviceVector<float> storage;
auto m = MakeMatrixFromTest(&storage, kRows, kCols);
ASSERT_EQ(m.DeviceIdx(), GenericParameter::kCpuId);
ASSERT_EQ(m(0, 0), 0);
ASSERT_EQ(m(kRows - 1, kCols - 1), storage.Size() - 1);
}
TEST(Linalg, Vector) {
size_t kRows = 31, kCols = 77;
HostDeviceVector<float> storage;
auto m = MakeMatrixFromTest(&storage, kRows, kCols);
auto v = VectorView<float>(m, 3);
for (size_t i = 0; i < v.Size(); ++i) {
ASSERT_EQ(v[i], m(i, 3));
}
ASSERT_EQ(v[0], 3);
}
} // namespace xgboost

View File

@ -294,6 +294,13 @@ TEST(Learner, GPUConfiguration) {
learner->UpdateOneIter(0, p_dmat); learner->UpdateOneIter(0, p_dmat);
ASSERT_EQ(learner->GetGenericParameter().gpu_id, 0); ASSERT_EQ(learner->GetGenericParameter().gpu_id, 0);
} }
{
std::unique_ptr<Learner> learner {Learner::Create(mat)};
learner->SetParams({Arg{"tree_method", "gpu_hist"},
Arg{"gpu_id", "-1"}});
learner->UpdateOneIter(0, p_dmat);
ASSERT_EQ(learner->GetGenericParameter().gpu_id, 0);
}
{ {
// with CPU algorithm // with CPU algorithm
std::unique_ptr<Learner> learner {Learner::Create(mat)}; std::unique_ptr<Learner> learner {Learner::Create(mat)};

View File

@ -390,7 +390,10 @@ void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
hist_maker.Configure(args, &generic_param); hist_maker.Configure(args, &generic_param);
hist_maker.Update(gpair, dmat, {tree}); hist_maker.Update(gpair, dmat, {tree});
hist_maker.UpdatePredictionCache(dmat, preds); hist_maker.UpdatePredictionCache(
dmat,
VectorView<float>{
MatrixView<float>(preds, {preds->Size(), 1}, preds->DeviceIdx()), 0});
} }
TEST(GpuHist, UniformSampling) { TEST(GpuHist, UniformSampling) {