Implement a general array view. (#7365)
* Replace existing matrix and vector view. This is to prepare for handling higher dimension data and prediction when we support multi-target models.
This commit is contained in:
@@ -496,7 +496,7 @@ struct GPUHistMakerDevice {
|
||||
});
|
||||
}
|
||||
|
||||
void UpdatePredictionCache(VectorView<float> out_preds_d) {
|
||||
void UpdatePredictionCache(linalg::VectorView<float> out_preds_d) {
|
||||
dh::safe_cuda(cudaSetDevice(device_id));
|
||||
CHECK_EQ(out_preds_d.DeviceIdx(), device_id);
|
||||
auto d_ridx = row_partitioner->GetRows();
|
||||
@@ -512,13 +512,13 @@ struct GPUHistMakerDevice {
|
||||
auto d_node_sum_gradients = device_node_sum_gradients.data().get();
|
||||
auto evaluator = tree_evaluator.GetEvaluator<GPUTrainingParam>();
|
||||
|
||||
dh::LaunchN(d_ridx.size(), [=] __device__(int local_idx) {
|
||||
dh::LaunchN(d_ridx.size(), [=, out_preds_d = out_preds_d] __device__(
|
||||
int local_idx) mutable {
|
||||
int pos = d_position[local_idx];
|
||||
bst_float weight = evaluator.CalcWeight(
|
||||
pos, param_d, GradStats{d_node_sum_gradients[pos]});
|
||||
static_assert(!std::is_const<decltype(out_preds_d)>::value, "");
|
||||
auto v_predt = out_preds_d; // for some reason out_preds_d is const by both nvcc and clang.
|
||||
v_predt[d_ridx[local_idx]] += weight * param_d.learning_rate;
|
||||
out_preds_d(d_ridx[local_idx]) += weight * param_d.learning_rate;
|
||||
});
|
||||
row_partitioner.reset();
|
||||
}
|
||||
@@ -834,7 +834,8 @@ class GPUHistMakerSpecialised {
|
||||
maker->UpdateTree(gpair, p_fmat, p_tree, &reducer_);
|
||||
}
|
||||
|
||||
bool UpdatePredictionCache(const DMatrix* data, VectorView<bst_float> p_out_preds) {
|
||||
bool UpdatePredictionCache(const DMatrix *data,
|
||||
linalg::VectorView<bst_float> p_out_preds) {
|
||||
if (maker == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) {
|
||||
return false;
|
||||
}
|
||||
@@ -920,8 +921,9 @@ class GPUHistMaker : public TreeUpdater {
|
||||
}
|
||||
}
|
||||
|
||||
bool UpdatePredictionCache(const DMatrix *data,
|
||||
VectorView<bst_float> p_out_preds) override {
|
||||
bool
|
||||
UpdatePredictionCache(const DMatrix *data,
|
||||
linalg::VectorView<bst_float> p_out_preds) override {
|
||||
if (hist_maker_param_.single_precision_histogram) {
|
||||
return float_maker_->UpdatePredictionCache(data, p_out_preds);
|
||||
} else {
|
||||
|
||||
@@ -105,7 +105,7 @@ void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
|
||||
}
|
||||
|
||||
bool QuantileHistMaker::UpdatePredictionCache(
|
||||
const DMatrix* data, VectorView<float> out_preds) {
|
||||
const DMatrix* data, linalg::VectorView<float> out_preds) {
|
||||
if (hist_maker_param_.single_precision_histogram && float_builder_) {
|
||||
return float_builder_->UpdatePredictionCache(data, out_preds);
|
||||
} else if (double_builder_) {
|
||||
@@ -319,7 +319,7 @@ void QuantileHistMaker::Builder<GradientSumT>::Update(
|
||||
template<typename GradientSumT>
|
||||
bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache(
|
||||
const DMatrix* data,
|
||||
VectorView<float> out_preds) {
|
||||
linalg::VectorView<float> out_preds) {
|
||||
// p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in
|
||||
// conjunction with Update().
|
||||
if (!p_last_fmat_ || !p_last_tree_ || data != p_last_fmat_ ||
|
||||
@@ -352,7 +352,7 @@ bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache(
|
||||
leaf_value = (*p_last_tree_)[nid].LeafValue();
|
||||
|
||||
for (const size_t* it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) {
|
||||
out_preds[*it] += leaf_value;
|
||||
out_preds(*it) += leaf_value;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
@@ -105,7 +105,7 @@ class QuantileHistMaker: public TreeUpdater {
|
||||
const std::vector<RegTree*>& trees) override;
|
||||
|
||||
bool UpdatePredictionCache(const DMatrix *data,
|
||||
VectorView<float> out_preds) override;
|
||||
linalg::VectorView<float> out_preds) override;
|
||||
|
||||
void LoadConfig(Json const& in) override {
|
||||
auto const& config = get<Object const>(in);
|
||||
@@ -174,7 +174,7 @@ class QuantileHistMaker: public TreeUpdater {
|
||||
RegTree* p_tree);
|
||||
|
||||
bool UpdatePredictionCache(const DMatrix* data,
|
||||
VectorView<float> out_preds);
|
||||
linalg::VectorView<float> out_preds);
|
||||
|
||||
protected:
|
||||
// initialize temp data structure
|
||||
|
||||
Reference in New Issue
Block a user