[SYC]. Implementation of HostDeviceVector (#10842)

This commit is contained in:
Dmitry Razdoburdin
2024-09-24 22:45:17 +02:00
committed by GitHub
parent bc69a3e877
commit 2179baa50c
25 changed files with 937 additions and 282 deletions

View File

@@ -277,7 +277,7 @@ class Predictor : public xgboost::Predictor {
void PredictBatch(DMatrix *dmat, PredictionCacheEntry *predts,
const gbm::GBTreeModel &model, uint32_t tree_begin,
uint32_t tree_end = 0) const override {
::sycl::queue qu = device_manager.GetQueue(ctx_->Device());
::sycl::queue* qu = device_manager.GetQueue(ctx_->Device());
// TODO(razdoburdin): remove temporary workaround after cache fix
sycl::DeviceMatrix device_matrix;
device_matrix.Init(qu, dmat);
@@ -290,9 +290,9 @@ class Predictor : public xgboost::Predictor {
if (tree_begin < tree_end) {
const bool any_missing = !(dmat->IsDense());
if (any_missing) {
DevicePredictInternal<true>(&qu, device_matrix, out_preds, model, tree_begin, tree_end);
DevicePredictInternal<true>(qu, device_matrix, out_preds, model, tree_begin, tree_end);
} else {
DevicePredictInternal<false>(&qu, device_matrix, out_preds, model, tree_begin, tree_end);
DevicePredictInternal<false>(qu, device_matrix, out_preds, model, tree_begin, tree_end);
}
}
}