Implement unified update prediction cache for (gpu_)hist. (#6860)

* Implement utilites for linalg.
* Unify the update prediction cache functions.
* Implement update prediction cache for multi-class gpu hist.
This commit is contained in:
Jiaming Yuan
2021-04-17 00:29:34 +08:00
committed by GitHub
parent 1b26a2a561
commit 556a83022d
10 changed files with 246 additions and 68 deletions

View File

@@ -390,7 +390,10 @@ void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
hist_maker.Configure(args, &generic_param);
hist_maker.Update(gpair, dmat, {tree});
hist_maker.UpdatePredictionCache(dmat, preds);
hist_maker.UpdatePredictionCache(
dmat,
VectorView<float>{
MatrixView<float>(preds, {preds->Size(), 1}, preds->DeviceIdx()), 0});
}
TEST(GpuHist, UniformSampling) {