Use ellpack for prediction only when sparsepage doesn't exist. (#5504)
This commit is contained in:
@@ -22,6 +22,7 @@
|
||||
|
||||
#include "gblinear_model.h"
|
||||
#include "../common/timer.h"
|
||||
#include "../common/common.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace gbm {
|
||||
@@ -68,7 +69,7 @@ class GBLinear : public GradientBooster {
|
||||
updater_->Configure(cfg);
|
||||
monitor_.Init("GBLinear");
|
||||
if (param_.updater == "gpu_coord_descent") {
|
||||
this->AssertGPUSupport();
|
||||
common::AssertGPUSupport();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -172,7 +172,7 @@ void GBTree::ConfigureUpdaters() {
|
||||
tparam_.updater_seq = "grow_quantile_histmaker";
|
||||
break;
|
||||
case TreeMethod::kGPUHist: {
|
||||
this->AssertGPUSupport();
|
||||
common::AssertGPUSupport();
|
||||
tparam_.updater_seq = "grow_gpu_hist";
|
||||
break;
|
||||
}
|
||||
@@ -391,17 +391,21 @@ GBTree::GetPredictor(HostDeviceVector<float> const *out_pred,
|
||||
CHECK(gpu_predictor_);
|
||||
return gpu_predictor_;
|
||||
#else
|
||||
this->AssertGPUSupport();
|
||||
common::AssertGPUSupport();
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
}
|
||||
CHECK(cpu_predictor_);
|
||||
return cpu_predictor_;
|
||||
}
|
||||
|
||||
auto on_device =
|
||||
f_dmat &&
|
||||
(f_dmat->PageExists<EllpackPage>() ||
|
||||
(*(f_dmat->GetBatches<SparsePage>().begin())).data.DeviceCanRead());
|
||||
// Data comes from Device DMatrix.
|
||||
auto is_ellpack = f_dmat && f_dmat->PageExists<EllpackPage>() &&
|
||||
!f_dmat->PageExists<SparsePage>();
|
||||
// Data comes from device memory, like CuDF or CuPy.
|
||||
auto is_from_device =
|
||||
f_dmat && f_dmat->PageExists<SparsePage>() &&
|
||||
(*(f_dmat->GetBatches<SparsePage>().begin())).data.DeviceCanRead();
|
||||
auto on_device = is_ellpack || is_from_device;
|
||||
|
||||
// Use GPU Predictor if data is already on device and gpu_id is set.
|
||||
if (on_device && generic_param_->gpu_id >= 0) {
|
||||
@@ -434,7 +438,7 @@ GBTree::GetPredictor(HostDeviceVector<float> const *out_pred,
|
||||
CHECK(gpu_predictor_);
|
||||
return gpu_predictor_;
|
||||
#else
|
||||
this->AssertGPUSupport();
|
||||
common::AssertGPUSupport();
|
||||
return cpu_predictor_;
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user