Implement a general array view. (#7365)

* Replace existing matrix and vector view.

This is to prepare for handling higher dimension data and prediction when we support multi-target models.
This commit is contained in:
Jiaming Yuan
2021-11-05 04:16:11 +08:00
committed by GitHub
parent 232144ca09
commit b06040b6d0
11 changed files with 418 additions and 146 deletions

View File

@@ -496,7 +496,7 @@ struct GPUHistMakerDevice {
});
}
void UpdatePredictionCache(VectorView<float> out_preds_d) {
void UpdatePredictionCache(linalg::VectorView<float> out_preds_d) {
dh::safe_cuda(cudaSetDevice(device_id));
CHECK_EQ(out_preds_d.DeviceIdx(), device_id);
auto d_ridx = row_partitioner->GetRows();
@@ -512,13 +512,13 @@ struct GPUHistMakerDevice {
auto d_node_sum_gradients = device_node_sum_gradients.data().get();
auto evaluator = tree_evaluator.GetEvaluator<GPUTrainingParam>();
dh::LaunchN(d_ridx.size(), [=] __device__(int local_idx) {
dh::LaunchN(d_ridx.size(), [=, out_preds_d = out_preds_d] __device__(
int local_idx) mutable {
int pos = d_position[local_idx];
bst_float weight = evaluator.CalcWeight(
pos, param_d, GradStats{d_node_sum_gradients[pos]});
static_assert(!std::is_const<decltype(out_preds_d)>::value, "");
auto v_predt = out_preds_d; // for some reason out_preds_d is const by both nvcc and clang.
v_predt[d_ridx[local_idx]] += weight * param_d.learning_rate;
out_preds_d(d_ridx[local_idx]) += weight * param_d.learning_rate;
});
row_partitioner.reset();
}
@@ -834,7 +834,8 @@ class GPUHistMakerSpecialised {
maker->UpdateTree(gpair, p_fmat, p_tree, &reducer_);
}
bool UpdatePredictionCache(const DMatrix* data, VectorView<bst_float> p_out_preds) {
bool UpdatePredictionCache(const DMatrix *data,
linalg::VectorView<bst_float> p_out_preds) {
if (maker == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) {
return false;
}
@@ -920,8 +921,9 @@ class GPUHistMaker : public TreeUpdater {
}
}
bool UpdatePredictionCache(const DMatrix *data,
VectorView<bst_float> p_out_preds) override {
bool
UpdatePredictionCache(const DMatrix *data,
linalg::VectorView<bst_float> p_out_preds) override {
if (hist_maker_param_.single_precision_histogram) {
return float_maker_->UpdatePredictionCache(data, p_out_preds);
} else {