Implement unified update prediction cache for (gpu_)hist. (#6860)
* Implement utilites for linalg. * Unify the update prediction cache functions. * Implement update prediction cache for multi-class gpu hist.
This commit is contained in:
parent
1b26a2a561
commit
556a83022d
105
include/xgboost/linalg.h
Normal file
105
include/xgboost/linalg.h
Normal file
@ -0,0 +1,105 @@
|
||||
/*!
|
||||
* Copyright 2021 by Contributors
|
||||
* \file linalg.h
|
||||
* \brief Linear algebra related utilities.
|
||||
*/
|
||||
#ifndef XGBOOST_LINALG_H_
|
||||
#define XGBOOST_LINALG_H_
|
||||
|
||||
#include <xgboost/span.h>
|
||||
#include <xgboost/host_device_vector.h>
|
||||
#include <xgboost/generic_parameters.h>
|
||||
|
||||
#include <array>
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
|
||||
namespace xgboost {
|
||||
/*!
|
||||
* \brief A veiw over a matrix on contigious storage.
|
||||
*
|
||||
* \tparam T data type of matrix
|
||||
*/
|
||||
template <typename T> class MatrixView {
|
||||
int32_t device_;
|
||||
common::Span<T> values_;
|
||||
size_t strides_[2];
|
||||
size_t shape_[2];
|
||||
|
||||
template <typename Vec> static auto InferValues(Vec *vec, int32_t device) {
|
||||
return device == GenericParameter::kCpuId ? vec->HostSpan()
|
||||
: vec->DeviceSpan();
|
||||
}
|
||||
|
||||
public:
|
||||
/*!
|
||||
* \param vec storage.
|
||||
* \param strides Strides for matrix.
|
||||
* \param shape Rows anc columns.
|
||||
* \param device Where the data is stored in.
|
||||
*/
|
||||
MatrixView(HostDeviceVector<T> *vec, std::array<size_t, 2> strides,
|
||||
std::array<size_t, 2> shape, int32_t device)
|
||||
: device_{device}, values_{InferValues(vec, device)} {
|
||||
std::copy(strides.cbegin(), strides.cend(), strides_);
|
||||
std::copy(shape.cbegin(), shape.cend(), shape_);
|
||||
}
|
||||
MatrixView(HostDeviceVector<std::remove_const_t<T>> const *vec,
|
||||
std::array<size_t, 2> strides, std::array<size_t, 2> shape,
|
||||
int32_t device)
|
||||
: device_{device}, values_{InferValues(vec, device)} {
|
||||
std::copy(strides.cbegin(), strides.cend(), strides_);
|
||||
std::copy(shape.cbegin(), shape.cend(), shape_);
|
||||
}
|
||||
/*! \brief Row major constructor. */
|
||||
MatrixView(HostDeviceVector<T> *vec, std::array<size_t, 2> shape,
|
||||
int32_t device)
|
||||
: device_{device}, values_{InferValues(vec, device)} {
|
||||
std::copy(shape.cbegin(), shape.cend(), shape_);
|
||||
strides_[0] = shape[1];
|
||||
strides_[1] = 1;
|
||||
}
|
||||
MatrixView(HostDeviceVector<std::remove_const_t<T>> const *vec,
|
||||
std::array<size_t, 2> shape, int32_t device)
|
||||
: device_{device}, values_{InferValues(vec, device)} {
|
||||
std::copy(shape.cbegin(), shape.cend(), shape_);
|
||||
strides_[0] = shape[1];
|
||||
strides_[1] = 1;
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE T const &operator()(size_t r, size_t c) const {
|
||||
return values_[strides_[0] * r + strides_[1] * c];
|
||||
}
|
||||
XGBOOST_DEVICE T &operator()(size_t r, size_t c) {
|
||||
return values_[strides_[0] * r + strides_[1] * c];
|
||||
}
|
||||
|
||||
auto Strides() const { return strides_; }
|
||||
auto Shape() const { return shape_; }
|
||||
auto Values() const { return values_; }
|
||||
auto Size() const { return shape_[0] * shape_[1]; }
|
||||
auto DeviceIdx() const { return device_; }
|
||||
};
|
||||
|
||||
/*! \brief A slice for 1 column of MatrixView. Can be extended to row if needed. */
|
||||
template <typename T> class VectorView {
|
||||
MatrixView<T> matrix_;
|
||||
size_t column_;
|
||||
|
||||
public:
|
||||
explicit VectorView(MatrixView<T> matrix, size_t column)
|
||||
: matrix_{std::move(matrix)}, column_{column} {}
|
||||
|
||||
XGBOOST_DEVICE T &operator[](size_t i) {
|
||||
return matrix_(i, column_);
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE T const &operator[](size_t i) const {
|
||||
return matrix_(i, column_);
|
||||
}
|
||||
|
||||
size_t Size() { return matrix_.Shape()[0]; }
|
||||
int32_t DeviceIdx() const { return matrix_.DeviceIdx(); }
|
||||
};
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_LINALG_H_
|
||||
@ -15,6 +15,7 @@
|
||||
#include <xgboost/generic_parameters.h>
|
||||
#include <xgboost/host_device_vector.h>
|
||||
#include <xgboost/model.h>
|
||||
#include <xgboost/linalg.h>
|
||||
|
||||
#include <functional>
|
||||
#include <vector>
|
||||
@ -70,14 +71,8 @@ class TreeUpdater : public Configurable {
|
||||
* the prediction cache. If true, the prediction cache will have been
|
||||
* updated by the time this function returns.
|
||||
*/
|
||||
virtual bool UpdatePredictionCache(const DMatrix* /*data*/,
|
||||
HostDeviceVector<bst_float>* /*out_preds*/) {
|
||||
return false;
|
||||
}
|
||||
|
||||
virtual bool UpdatePredictionCacheMulticlass(const DMatrix* /*data*/,
|
||||
HostDeviceVector<bst_float>* /*out_preds*/,
|
||||
const int /*gid*/, const int /*ngroup*/) {
|
||||
virtual bool UpdatePredictionCache(const DMatrix * /*data*/,
|
||||
VectorView<float> /*out_preds*/) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
@ -190,6 +190,32 @@ void GBTree::ConfigureUpdaters() {
|
||||
}
|
||||
}
|
||||
|
||||
void GPUCopyGradient(HostDeviceVector<GradientPair> const *in_gpair,
|
||||
bst_group_t n_groups, bst_group_t group_id,
|
||||
HostDeviceVector<GradientPair> *out_gpair)
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
; // NOLINT
|
||||
#else
|
||||
{
|
||||
common::AssertGPUSupport();
|
||||
}
|
||||
#endif
|
||||
|
||||
void CopyGradient(HostDeviceVector<GradientPair> const *in_gpair,
|
||||
bst_group_t n_groups, bst_group_t group_id,
|
||||
HostDeviceVector<GradientPair> *out_gpair) {
|
||||
if (in_gpair->DeviceIdx() != GenericParameter::kCpuId) {
|
||||
GPUCopyGradient(in_gpair, n_groups, group_id, out_gpair);
|
||||
} else {
|
||||
std::vector<GradientPair> &tmp_h = out_gpair->HostVector();
|
||||
auto nsize = static_cast<bst_omp_uint>(out_gpair->Size());
|
||||
const auto &gpair_h = in_gpair->ConstHostVector();
|
||||
common::ParallelFor(nsize, [&](bst_omp_uint i) {
|
||||
tmp_h[i] = gpair_h[i * n_groups + group_id];
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void GBTree::DoBoost(DMatrix* p_fmat,
|
||||
HostDeviceVector<GradientPair>* in_gpair,
|
||||
PredictionCacheEntry* predt) {
|
||||
@ -197,39 +223,44 @@ void GBTree::DoBoost(DMatrix* p_fmat,
|
||||
const int ngroup = model_.learner_model_param->num_output_group;
|
||||
ConfigureWithKnownData(this->cfg_, p_fmat);
|
||||
monitor_.Start("BoostNewTrees");
|
||||
auto* out = &predt->predictions;
|
||||
// Weird case that tree method is cpu-based but gpu_id is set. Ideally we should let
|
||||
// `gpu_id` be the single source of determining what algorithms to run, but that will
|
||||
// break a lots of existing code.
|
||||
auto device = tparam_.tree_method != TreeMethod::kGPUHist
|
||||
? GenericParameter::kCpuId
|
||||
: in_gpair->DeviceIdx();
|
||||
auto out = MatrixView<float>(
|
||||
&predt->predictions,
|
||||
{p_fmat->Info().num_row_, static_cast<size_t>(ngroup)}, device);
|
||||
CHECK_NE(ngroup, 0);
|
||||
if (ngroup == 1) {
|
||||
std::vector<std::unique_ptr<RegTree> > ret;
|
||||
std::vector<std::unique_ptr<RegTree>> ret;
|
||||
BoostNewTrees(in_gpair, p_fmat, 0, &ret);
|
||||
const size_t num_new_trees = ret.size();
|
||||
new_trees.push_back(std::move(ret));
|
||||
if (updaters_.size() > 0 && num_new_trees == 1 && out->Size() > 0 &&
|
||||
updaters_.back()->UpdatePredictionCache(p_fmat, out)) {
|
||||
auto v_predt = VectorView<float>{out, 0};
|
||||
if (updaters_.size() > 0 && num_new_trees == 1 &&
|
||||
predt->predictions.Size() > 0 &&
|
||||
updaters_.back()->UpdatePredictionCache(p_fmat, v_predt)) {
|
||||
predt->Update(1);
|
||||
}
|
||||
} else {
|
||||
CHECK_EQ(in_gpair->Size() % ngroup, 0U)
|
||||
<< "must have exactly ngroup * nrow gpairs";
|
||||
// TODO(canonizer): perform this on GPU if HostDeviceVector has device set.
|
||||
HostDeviceVector<GradientPair> tmp(in_gpair->Size() / ngroup,
|
||||
GradientPair(),
|
||||
in_gpair->DeviceIdx());
|
||||
const auto& gpair_h = in_gpair->ConstHostVector();
|
||||
auto nsize = static_cast<bst_omp_uint>(tmp.Size());
|
||||
bool update_predict = true;
|
||||
for (int gid = 0; gid < ngroup; ++gid) {
|
||||
std::vector<GradientPair>& tmp_h = tmp.HostVector();
|
||||
common::ParallelFor(nsize, [&](bst_omp_uint i) {
|
||||
tmp_h[i] = gpair_h[i * ngroup + gid];
|
||||
});
|
||||
CopyGradient(in_gpair, ngroup, gid, &tmp);
|
||||
std::vector<std::unique_ptr<RegTree> > ret;
|
||||
BoostNewTrees(&tmp, p_fmat, gid, &ret);
|
||||
const size_t num_new_trees = ret.size();
|
||||
new_trees.push_back(std::move(ret));
|
||||
auto* out = &predt->predictions;
|
||||
if (!(updaters_.size() > 0 && out->Size() > 0 && num_new_trees == 1 &&
|
||||
updaters_.back()->UpdatePredictionCacheMulticlass(p_fmat, out, gid, ngroup))) {
|
||||
auto v_predt = VectorView<float>{out, static_cast<size_t>(gid)};
|
||||
if (!(updaters_.size() > 0 && predt->predictions.Size() > 0 &&
|
||||
num_new_trees == 1 &&
|
||||
updaters_.back()->UpdatePredictionCache(p_fmat, v_predt))) {
|
||||
update_predict = false;
|
||||
}
|
||||
}
|
||||
|
||||
@ -2,10 +2,28 @@
|
||||
* Copyright 2021 by Contributors
|
||||
*/
|
||||
#include "xgboost/span.h"
|
||||
#include "xgboost/generic_parameters.h"
|
||||
#include "xgboost/linalg.h"
|
||||
#include "../common/device_helpers.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
namespace gbm {
|
||||
|
||||
void GPUCopyGradient(HostDeviceVector<GradientPair> const *in_gpair,
|
||||
bst_group_t n_groups, bst_group_t group_id,
|
||||
HostDeviceVector<GradientPair> *out_gpair) {
|
||||
MatrixView<GradientPair const> in{
|
||||
in_gpair,
|
||||
{n_groups, 1ul},
|
||||
{in_gpair->Size() / n_groups, static_cast<size_t>(n_groups)},
|
||||
in_gpair->DeviceIdx()};
|
||||
auto v_in = VectorView<GradientPair const>{in, group_id};
|
||||
out_gpair->Resize(v_in.Size());
|
||||
auto d_out = out_gpair->DeviceSpan();
|
||||
dh::LaunchN(dh::CurrentDevice(), v_in.Size(),
|
||||
[=] __device__(size_t i) { d_out[i] = v_in[i]; });
|
||||
}
|
||||
|
||||
void GPUDartPredictInc(common::Span<float> out_predts,
|
||||
common::Span<float> predts, float tree_w, size_t n_rows,
|
||||
bst_group_t n_groups, bst_group_t group) {
|
||||
|
||||
@ -273,9 +273,9 @@ struct GPUHistMakerDevice {
|
||||
if (d_gpair.size() != dh_gpair->Size()) {
|
||||
d_gpair.resize(dh_gpair->Size());
|
||||
}
|
||||
thrust::copy(thrust::device, dh_gpair->ConstDevicePointer(),
|
||||
dh_gpair->ConstDevicePointer() + dh_gpair->Size(),
|
||||
d_gpair.begin());
|
||||
dh::safe_cuda(cudaMemcpyAsync(
|
||||
d_gpair.data().get(), dh_gpair->ConstDevicePointer(),
|
||||
dh_gpair->Size() * sizeof(GradientPair), cudaMemcpyDeviceToDevice));
|
||||
auto sample = sampler->Sample(dh::ToSpan(d_gpair), dmat);
|
||||
page = sample.page;
|
||||
gpair = sample.gpair;
|
||||
@ -528,8 +528,9 @@ struct GPUHistMakerDevice {
|
||||
});
|
||||
}
|
||||
|
||||
void UpdatePredictionCache(common::Span<bst_float> out_preds_d) {
|
||||
void UpdatePredictionCache(VectorView<float> out_preds_d) {
|
||||
dh::safe_cuda(cudaSetDevice(device_id));
|
||||
CHECK_EQ(out_preds_d.DeviceIdx(), device_id);
|
||||
auto d_ridx = row_partitioner->GetRows();
|
||||
|
||||
GPUTrainingParam param_d(param);
|
||||
@ -543,13 +544,13 @@ struct GPUHistMakerDevice {
|
||||
auto d_node_sum_gradients = device_node_sum_gradients.data().get();
|
||||
auto evaluator = tree_evaluator.GetEvaluator<GPUTrainingParam>();
|
||||
|
||||
dh::LaunchN(
|
||||
device_id, out_preds_d.size(), [=] __device__(int local_idx) {
|
||||
dh::LaunchN(device_id, d_ridx.size(), [=] __device__(int local_idx) {
|
||||
int pos = d_position[local_idx];
|
||||
bst_float weight = evaluator.CalcWeight(pos, param_d,
|
||||
GradStats{d_node_sum_gradients[pos]});
|
||||
out_preds_d[d_ridx[local_idx]] +=
|
||||
weight * param_d.learning_rate;
|
||||
bst_float weight = evaluator.CalcWeight(
|
||||
pos, param_d, GradStats{d_node_sum_gradients[pos]});
|
||||
static_assert(!std::is_const<decltype(out_preds_d)>::value, "");
|
||||
auto v_predt = out_preds_d; // for some reaon out_preds_d is const by both nvcc and clang.
|
||||
v_predt[d_ridx[local_idx]] += weight * param_d.learning_rate;
|
||||
});
|
||||
row_partitioner.reset();
|
||||
}
|
||||
@ -862,13 +863,12 @@ class GPUHistMakerSpecialised {
|
||||
maker->UpdateTree(gpair, p_fmat, p_tree, &reducer_);
|
||||
}
|
||||
|
||||
bool UpdatePredictionCache(const DMatrix* data, HostDeviceVector<bst_float>* p_out_preds) {
|
||||
bool UpdatePredictionCache(const DMatrix* data, VectorView<bst_float> p_out_preds) {
|
||||
if (maker == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) {
|
||||
return false;
|
||||
}
|
||||
monitor_.Start("UpdatePredictionCache");
|
||||
p_out_preds->SetDevice(device_);
|
||||
maker->UpdatePredictionCache(p_out_preds->DeviceSpan());
|
||||
maker->UpdatePredictionCache(p_out_preds);
|
||||
monitor_.Stop("UpdatePredictionCache");
|
||||
return true;
|
||||
}
|
||||
@ -947,8 +947,8 @@ class GPUHistMaker : public TreeUpdater {
|
||||
}
|
||||
}
|
||||
|
||||
bool UpdatePredictionCache(
|
||||
const DMatrix* data, HostDeviceVector<bst_float>* p_out_preds) override {
|
||||
bool UpdatePredictionCache(const DMatrix *data,
|
||||
VectorView<bst_float> p_out_preds) override {
|
||||
if (hist_maker_param_.single_precision_histogram) {
|
||||
return float_maker_->UpdatePredictionCache(data, p_out_preds);
|
||||
} else {
|
||||
|
||||
@ -110,7 +110,7 @@ void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
|
||||
}
|
||||
|
||||
bool QuantileHistMaker::UpdatePredictionCache(
|
||||
const DMatrix* data, HostDeviceVector<bst_float>* out_preds) {
|
||||
const DMatrix* data, VectorView<float> out_preds) {
|
||||
if (hist_maker_param_.single_precision_histogram && float_builder_) {
|
||||
return float_builder_->UpdatePredictionCache(data, out_preds);
|
||||
} else if (double_builder_) {
|
||||
@ -120,19 +120,6 @@ bool QuantileHistMaker::UpdatePredictionCache(
|
||||
}
|
||||
}
|
||||
|
||||
bool QuantileHistMaker::UpdatePredictionCacheMulticlass(
|
||||
const DMatrix* data,
|
||||
HostDeviceVector<bst_float>* out_preds, const int gid, const int ngroup) {
|
||||
if (hist_maker_param_.single_precision_histogram && float_builder_) {
|
||||
return float_builder_->UpdatePredictionCache(data, out_preds, gid, ngroup);
|
||||
} else if (double_builder_) {
|
||||
return double_builder_->UpdatePredictionCache(data, out_preds, gid, ngroup);
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename GradientSumT>
|
||||
void BatchHistSynchronizer<GradientSumT>::SyncHistograms(BuilderT *builder,
|
||||
int,
|
||||
@ -629,7 +616,7 @@ void QuantileHistMaker::Builder<GradientSumT>::Update(
|
||||
template<typename GradientSumT>
|
||||
bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache(
|
||||
const DMatrix* data,
|
||||
HostDeviceVector<bst_float>* p_out_preds, const int gid, const int ngroup) {
|
||||
VectorView<float> out_preds) {
|
||||
// p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in
|
||||
// conjunction with Update().
|
||||
if (!p_last_fmat_ || !p_last_tree_ || data != p_last_fmat_ ||
|
||||
@ -638,16 +625,14 @@ bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache(
|
||||
}
|
||||
builder_monitor_.Start("UpdatePredictionCache");
|
||||
|
||||
std::vector<bst_float>& out_preds = p_out_preds->HostVector();
|
||||
|
||||
CHECK_GT(out_preds.size(), 0U);
|
||||
CHECK_GT(out_preds.Size(), 0U);
|
||||
|
||||
size_t n_nodes = row_set_collection_.end() - row_set_collection_.begin();
|
||||
|
||||
common::BlockedSpace2d space(n_nodes, [&](size_t node) {
|
||||
return row_set_collection_[node].Size();
|
||||
}, 1024);
|
||||
|
||||
CHECK_EQ(out_preds.DeviceIdx(), GenericParameter::kCpuId);
|
||||
common::ParallelFor2d(space, this->nthread_, [&](size_t node, common::Range1d r) {
|
||||
const RowSetCollection::Elem rowset = row_set_collection_[node];
|
||||
if (rowset.begin != nullptr && rowset.end != nullptr) {
|
||||
@ -664,7 +649,7 @@ bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache(
|
||||
leaf_value = (*p_last_tree_)[nid].LeafValue();
|
||||
|
||||
for (const size_t* it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) {
|
||||
out_preds[*it * ngroup + gid] += leaf_value;
|
||||
out_preds[*it] += leaf_value;
|
||||
}
|
||||
}
|
||||
});
|
||||
@ -687,7 +672,7 @@ bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache(
|
||||
const size_t row_num = unused_rows_[block_id] + batch.base_rowid;
|
||||
const int lid = feats.HasMissing() ? p_last_tree_->GetLeafIndex<true>(feats) :
|
||||
p_last_tree_->GetLeafIndex<false>(feats);
|
||||
out_preds[row_num * ngroup + gid] += (*p_last_tree_)[lid].LeafValue();
|
||||
out_preds[row_num] += (*p_last_tree_)[lid].LeafValue();
|
||||
|
||||
feats.Drop(inst);
|
||||
});
|
||||
|
||||
@ -118,11 +118,8 @@ class QuantileHistMaker: public TreeUpdater {
|
||||
DMatrix* dmat,
|
||||
const std::vector<RegTree*>& trees) override;
|
||||
|
||||
bool UpdatePredictionCache(const DMatrix* data,
|
||||
HostDeviceVector<bst_float>* out_preds) override;
|
||||
bool UpdatePredictionCacheMulticlass(const DMatrix* data,
|
||||
HostDeviceVector<bst_float>* out_preds,
|
||||
const int gid, const int ngroup) override;
|
||||
bool UpdatePredictionCache(const DMatrix *data,
|
||||
VectorView<float> out_preds) override;
|
||||
|
||||
void LoadConfig(Json const& in) override {
|
||||
auto const& config = get<Object const>(in);
|
||||
@ -245,8 +242,7 @@ class QuantileHistMaker: public TreeUpdater {
|
||||
}
|
||||
|
||||
bool UpdatePredictionCache(const DMatrix* data,
|
||||
HostDeviceVector<bst_float>* p_out_preds,
|
||||
const int gid = 0, const int ngroup = 1);
|
||||
VectorView<float> out_preds);
|
||||
|
||||
void SetHistSynchronizer(HistSynchronizer<GradientSumT>* sync);
|
||||
void SetHistRowsAdder(HistRowsAdder<GradientSumT>* adder);
|
||||
|
||||
38
tests/cpp/common/test_linalg.cc
Normal file
38
tests/cpp/common/test_linalg.cc
Normal file
@ -0,0 +1,38 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/linalg.h>
|
||||
#include <numeric>
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
auto MakeMatrixFromTest(HostDeviceVector<float> *storage, size_t n_rows, size_t n_cols) {
|
||||
storage->Resize(n_rows * n_cols);
|
||||
auto& h_storage = storage->HostVector();
|
||||
|
||||
std::iota(h_storage.begin(), h_storage.end(), 0);
|
||||
|
||||
auto m = MatrixView<float>{storage, {n_cols, 1}, {n_rows, n_cols}, -1};
|
||||
return m;
|
||||
|
||||
}
|
||||
|
||||
TEST(Linalg, Matrix) {
|
||||
size_t kRows = 31, kCols = 77;
|
||||
HostDeviceVector<float> storage;
|
||||
auto m = MakeMatrixFromTest(&storage, kRows, kCols);
|
||||
ASSERT_EQ(m.DeviceIdx(), GenericParameter::kCpuId);
|
||||
ASSERT_EQ(m(0, 0), 0);
|
||||
ASSERT_EQ(m(kRows - 1, kCols - 1), storage.Size() - 1);
|
||||
}
|
||||
|
||||
TEST(Linalg, Vector) {
|
||||
size_t kRows = 31, kCols = 77;
|
||||
HostDeviceVector<float> storage;
|
||||
auto m = MakeMatrixFromTest(&storage, kRows, kCols);
|
||||
auto v = VectorView<float>(m, 3);
|
||||
for (size_t i = 0; i < v.Size(); ++i) {
|
||||
ASSERT_EQ(v[i], m(i, 3));
|
||||
}
|
||||
|
||||
ASSERT_EQ(v[0], 3);
|
||||
}
|
||||
} // namespace xgboost
|
||||
@ -294,6 +294,13 @@ TEST(Learner, GPUConfiguration) {
|
||||
learner->UpdateOneIter(0, p_dmat);
|
||||
ASSERT_EQ(learner->GetGenericParameter().gpu_id, 0);
|
||||
}
|
||||
{
|
||||
std::unique_ptr<Learner> learner {Learner::Create(mat)};
|
||||
learner->SetParams({Arg{"tree_method", "gpu_hist"},
|
||||
Arg{"gpu_id", "-1"}});
|
||||
learner->UpdateOneIter(0, p_dmat);
|
||||
ASSERT_EQ(learner->GetGenericParameter().gpu_id, 0);
|
||||
}
|
||||
{
|
||||
// with CPU algorithm
|
||||
std::unique_ptr<Learner> learner {Learner::Create(mat)};
|
||||
|
||||
@ -390,7 +390,10 @@ void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
|
||||
hist_maker.Configure(args, &generic_param);
|
||||
|
||||
hist_maker.Update(gpair, dmat, {tree});
|
||||
hist_maker.UpdatePredictionCache(dmat, preds);
|
||||
hist_maker.UpdatePredictionCache(
|
||||
dmat,
|
||||
VectorView<float>{
|
||||
MatrixView<float>(preds, {preds->Size(), 1}, preds->DeviceIdx()), 0});
|
||||
}
|
||||
|
||||
TEST(GpuHist, UniformSampling) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user