Implement unified update prediction cache for (gpu_)hist. (#6860)
* Implement utilities for linalg. * Unify the update prediction cache functions. * Implement update prediction cache for multi-class gpu hist.
This commit is contained in:
@@ -273,9 +273,9 @@ struct GPUHistMakerDevice {
|
||||
if (d_gpair.size() != dh_gpair->Size()) {
|
||||
d_gpair.resize(dh_gpair->Size());
|
||||
}
|
||||
thrust::copy(thrust::device, dh_gpair->ConstDevicePointer(),
|
||||
dh_gpair->ConstDevicePointer() + dh_gpair->Size(),
|
||||
d_gpair.begin());
|
||||
dh::safe_cuda(cudaMemcpyAsync(
|
||||
d_gpair.data().get(), dh_gpair->ConstDevicePointer(),
|
||||
dh_gpair->Size() * sizeof(GradientPair), cudaMemcpyDeviceToDevice));
|
||||
auto sample = sampler->Sample(dh::ToSpan(d_gpair), dmat);
|
||||
page = sample.page;
|
||||
gpair = sample.gpair;
|
||||
@@ -528,8 +528,9 @@ struct GPUHistMakerDevice {
|
||||
});
|
||||
}
|
||||
|
||||
void UpdatePredictionCache(common::Span<bst_float> out_preds_d) {
|
||||
void UpdatePredictionCache(VectorView<float> out_preds_d) {
|
||||
dh::safe_cuda(cudaSetDevice(device_id));
|
||||
CHECK_EQ(out_preds_d.DeviceIdx(), device_id);
|
||||
auto d_ridx = row_partitioner->GetRows();
|
||||
|
||||
GPUTrainingParam param_d(param);
|
||||
@@ -543,14 +544,14 @@ struct GPUHistMakerDevice {
|
||||
auto d_node_sum_gradients = device_node_sum_gradients.data().get();
|
||||
auto evaluator = tree_evaluator.GetEvaluator<GPUTrainingParam>();
|
||||
|
||||
dh::LaunchN(
|
||||
device_id, out_preds_d.size(), [=] __device__(int local_idx) {
|
||||
int pos = d_position[local_idx];
|
||||
bst_float weight = evaluator.CalcWeight(pos, param_d,
|
||||
GradStats{d_node_sum_gradients[pos]});
|
||||
out_preds_d[d_ridx[local_idx]] +=
|
||||
weight * param_d.learning_rate;
|
||||
});
|
||||
dh::LaunchN(device_id, d_ridx.size(), [=] __device__(int local_idx) {
|
||||
int pos = d_position[local_idx];
|
||||
bst_float weight = evaluator.CalcWeight(
|
||||
pos, param_d, GradStats{d_node_sum_gradients[pos]});
|
||||
static_assert(!std::is_const<decltype(out_preds_d)>::value, "");
|
||||
auto v_predt = out_preds_d; // for some reaon out_preds_d is const by both nvcc and clang.
|
||||
v_predt[d_ridx[local_idx]] += weight * param_d.learning_rate;
|
||||
});
|
||||
row_partitioner.reset();
|
||||
}
|
||||
|
||||
@@ -862,13 +863,12 @@ class GPUHistMakerSpecialised {
|
||||
maker->UpdateTree(gpair, p_fmat, p_tree, &reducer_);
|
||||
}
|
||||
|
||||
bool UpdatePredictionCache(const DMatrix* data, HostDeviceVector<bst_float>* p_out_preds) {
|
||||
bool UpdatePredictionCache(const DMatrix* data, VectorView<bst_float> p_out_preds) {
|
||||
if (maker == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) {
|
||||
return false;
|
||||
}
|
||||
monitor_.Start("UpdatePredictionCache");
|
||||
p_out_preds->SetDevice(device_);
|
||||
maker->UpdatePredictionCache(p_out_preds->DeviceSpan());
|
||||
maker->UpdatePredictionCache(p_out_preds);
|
||||
monitor_.Stop("UpdatePredictionCache");
|
||||
return true;
|
||||
}
|
||||
@@ -947,8 +947,8 @@ class GPUHistMaker : public TreeUpdater {
|
||||
}
|
||||
}
|
||||
|
||||
bool UpdatePredictionCache(
|
||||
const DMatrix* data, HostDeviceVector<bst_float>* p_out_preds) override {
|
||||
bool UpdatePredictionCache(const DMatrix *data,
|
||||
VectorView<bst_float> p_out_preds) override {
|
||||
if (hist_maker_param_.single_precision_histogram) {
|
||||
return float_maker_->UpdatePredictionCache(data, p_out_preds);
|
||||
} else {
|
||||
|
||||
@@ -110,7 +110,7 @@ void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
|
||||
}
|
||||
|
||||
bool QuantileHistMaker::UpdatePredictionCache(
|
||||
const DMatrix* data, HostDeviceVector<bst_float>* out_preds) {
|
||||
const DMatrix* data, VectorView<float> out_preds) {
|
||||
if (hist_maker_param_.single_precision_histogram && float_builder_) {
|
||||
return float_builder_->UpdatePredictionCache(data, out_preds);
|
||||
} else if (double_builder_) {
|
||||
@@ -120,19 +120,6 @@ bool QuantileHistMaker::UpdatePredictionCache(
|
||||
}
|
||||
}
|
||||
|
||||
bool QuantileHistMaker::UpdatePredictionCacheMulticlass(
|
||||
const DMatrix* data,
|
||||
HostDeviceVector<bst_float>* out_preds, const int gid, const int ngroup) {
|
||||
if (hist_maker_param_.single_precision_histogram && float_builder_) {
|
||||
return float_builder_->UpdatePredictionCache(data, out_preds, gid, ngroup);
|
||||
} else if (double_builder_) {
|
||||
return double_builder_->UpdatePredictionCache(data, out_preds, gid, ngroup);
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename GradientSumT>
|
||||
void BatchHistSynchronizer<GradientSumT>::SyncHistograms(BuilderT *builder,
|
||||
int,
|
||||
@@ -629,7 +616,7 @@ void QuantileHistMaker::Builder<GradientSumT>::Update(
|
||||
template<typename GradientSumT>
|
||||
bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache(
|
||||
const DMatrix* data,
|
||||
HostDeviceVector<bst_float>* p_out_preds, const int gid, const int ngroup) {
|
||||
VectorView<float> out_preds) {
|
||||
// p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in
|
||||
// conjunction with Update().
|
||||
if (!p_last_fmat_ || !p_last_tree_ || data != p_last_fmat_ ||
|
||||
@@ -638,16 +625,14 @@ bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache(
|
||||
}
|
||||
builder_monitor_.Start("UpdatePredictionCache");
|
||||
|
||||
std::vector<bst_float>& out_preds = p_out_preds->HostVector();
|
||||
|
||||
CHECK_GT(out_preds.size(), 0U);
|
||||
CHECK_GT(out_preds.Size(), 0U);
|
||||
|
||||
size_t n_nodes = row_set_collection_.end() - row_set_collection_.begin();
|
||||
|
||||
common::BlockedSpace2d space(n_nodes, [&](size_t node) {
|
||||
return row_set_collection_[node].Size();
|
||||
}, 1024);
|
||||
|
||||
CHECK_EQ(out_preds.DeviceIdx(), GenericParameter::kCpuId);
|
||||
common::ParallelFor2d(space, this->nthread_, [&](size_t node, common::Range1d r) {
|
||||
const RowSetCollection::Elem rowset = row_set_collection_[node];
|
||||
if (rowset.begin != nullptr && rowset.end != nullptr) {
|
||||
@@ -664,7 +649,7 @@ bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache(
|
||||
leaf_value = (*p_last_tree_)[nid].LeafValue();
|
||||
|
||||
for (const size_t* it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) {
|
||||
out_preds[*it * ngroup + gid] += leaf_value;
|
||||
out_preds[*it] += leaf_value;
|
||||
}
|
||||
}
|
||||
});
|
||||
@@ -687,7 +672,7 @@ bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache(
|
||||
const size_t row_num = unused_rows_[block_id] + batch.base_rowid;
|
||||
const int lid = feats.HasMissing() ? p_last_tree_->GetLeafIndex<true>(feats) :
|
||||
p_last_tree_->GetLeafIndex<false>(feats);
|
||||
out_preds[row_num * ngroup + gid] += (*p_last_tree_)[lid].LeafValue();
|
||||
out_preds[row_num] += (*p_last_tree_)[lid].LeafValue();
|
||||
|
||||
feats.Drop(inst);
|
||||
});
|
||||
|
||||
@@ -118,11 +118,8 @@ class QuantileHistMaker: public TreeUpdater {
|
||||
DMatrix* dmat,
|
||||
const std::vector<RegTree*>& trees) override;
|
||||
|
||||
bool UpdatePredictionCache(const DMatrix* data,
|
||||
HostDeviceVector<bst_float>* out_preds) override;
|
||||
bool UpdatePredictionCacheMulticlass(const DMatrix* data,
|
||||
HostDeviceVector<bst_float>* out_preds,
|
||||
const int gid, const int ngroup) override;
|
||||
bool UpdatePredictionCache(const DMatrix *data,
|
||||
VectorView<float> out_preds) override;
|
||||
|
||||
void LoadConfig(Json const& in) override {
|
||||
auto const& config = get<Object const>(in);
|
||||
@@ -245,8 +242,7 @@ class QuantileHistMaker: public TreeUpdater {
|
||||
}
|
||||
|
||||
bool UpdatePredictionCache(const DMatrix* data,
|
||||
HostDeviceVector<bst_float>* p_out_preds,
|
||||
const int gid = 0, const int ngroup = 1);
|
||||
VectorView<float> out_preds);
|
||||
|
||||
void SetHistSynchronizer(HistSynchronizer<GradientSumT>* sync);
|
||||
void SetHistRowsAdder(HistRowsAdder<GradientSumT>* adder);
|
||||
|
||||
Reference in New Issue
Block a user