Remove internal use of gpu_id. (#9568)
This commit is contained in:
@@ -212,7 +212,7 @@ void GBTree::DoBoost(DMatrix* p_fmat, linalg::Matrix<GradientPair>* in_gpair,
|
||||
bst_target_t const n_groups = model_.learner_model_param->OutputLength();
|
||||
monitor_.Start("BoostNewTrees");
|
||||
|
||||
predt->predictions.SetDevice(ctx_->Ordinal());
|
||||
predt->predictions.SetDevice(ctx_->Device());
|
||||
auto out = linalg::MakeTensorView(ctx_, &predt->predictions, p_fmat->Info().num_row_,
|
||||
model_.learner_model_param->OutputLength());
|
||||
CHECK_NE(n_groups, 0);
|
||||
@@ -248,7 +248,7 @@ void GBTree::DoBoost(DMatrix* p_fmat, linalg::Matrix<GradientPair>* in_gpair,
|
||||
} else {
|
||||
CHECK_EQ(in_gpair->Size() % n_groups, 0U) << "must have exactly ngroup * nrow gpairs";
|
||||
linalg::Matrix<GradientPair> tmp{{in_gpair->Shape(0), static_cast<std::size_t>(1ul)},
|
||||
ctx_->Ordinal()};
|
||||
ctx_->Device()};
|
||||
bool update_predict = true;
|
||||
for (bst_target_t gid = 0; gid < n_groups; ++gid) {
|
||||
node_position.clear();
|
||||
@@ -736,7 +736,7 @@ class Dart : public GBTree {
|
||||
|
||||
PredictionCacheEntry predts; // temporary storage for prediction
|
||||
if (ctx_->IsCUDA()) {
|
||||
predts.predictions.SetDevice(ctx_->gpu_id);
|
||||
predts.predictions.SetDevice(ctx_->Device());
|
||||
}
|
||||
predts.predictions.Resize(p_fmat->Info().num_row_ * n_groups, 0);
|
||||
// multi-target is not yet supported.
|
||||
@@ -761,8 +761,8 @@ class Dart : public GBTree {
|
||||
CHECK_EQ(p_out_preds->predictions.Size(), predts.predictions.Size());
|
||||
|
||||
size_t n_rows = p_fmat->Info().num_row_;
|
||||
if (predts.predictions.DeviceIdx() != Context::kCpuId) {
|
||||
p_out_preds->predictions.SetDevice(predts.predictions.DeviceIdx());
|
||||
if (predts.predictions.Device().IsCUDA()) {
|
||||
p_out_preds->predictions.SetDevice(predts.predictions.Device());
|
||||
GPUDartPredictInc(p_out_preds->predictions.DeviceSpan(),
|
||||
predts.predictions.DeviceSpan(), w, n_rows, n_groups,
|
||||
group);
|
||||
@@ -801,8 +801,8 @@ class Dart : public GBTree {
|
||||
|
||||
StringView msg{"Unsupported data type for inplace predict."};
|
||||
PredictionCacheEntry predts;
|
||||
if (ctx_->gpu_id != Context::kCpuId) {
|
||||
predts.predictions.SetDevice(ctx_->gpu_id);
|
||||
if (ctx_->IsCUDA()) {
|
||||
predts.predictions.SetDevice(ctx_->Device());
|
||||
}
|
||||
predts.predictions.Resize(p_fmat->Info().num_row_ * n_groups, 0);
|
||||
|
||||
@@ -838,8 +838,8 @@ class Dart : public GBTree {
|
||||
CHECK_EQ(predts.predictions.Size(), p_out_preds->predictions.Size());
|
||||
|
||||
size_t n_rows = p_fmat->Info().num_row_;
|
||||
if (predts.predictions.DeviceIdx() != Context::kCpuId) {
|
||||
p_out_preds->predictions.SetDevice(predts.predictions.DeviceIdx());
|
||||
if (predts.predictions.Device().IsCUDA()) {
|
||||
p_out_preds->predictions.SetDevice(predts.predictions.Device());
|
||||
auto base_score = model_.learner_model_param->BaseScore(predts.predictions.Device());
|
||||
GPUDartInplacePredictInc(p_out_preds->predictions.DeviceSpan(),
|
||||
predts.predictions.DeviceSpan(), w, n_rows, base_score, n_groups,
|
||||
|
||||
Reference in New Issue
Block a user