Remove internal use of gpu_id. (#9568)

This commit is contained in:
Jiaming Yuan
2023-09-20 23:29:51 +08:00
committed by GitHub
parent 38ac52dd87
commit 8c676c889d
121 changed files with 1012 additions and 1044 deletions

View File

@@ -212,7 +212,7 @@ void GBTree::DoBoost(DMatrix* p_fmat, linalg::Matrix<GradientPair>* in_gpair,
bst_target_t const n_groups = model_.learner_model_param->OutputLength();
monitor_.Start("BoostNewTrees");
predt->predictions.SetDevice(ctx_->Ordinal());
predt->predictions.SetDevice(ctx_->Device());
auto out = linalg::MakeTensorView(ctx_, &predt->predictions, p_fmat->Info().num_row_,
model_.learner_model_param->OutputLength());
CHECK_NE(n_groups, 0);
@@ -248,7 +248,7 @@ void GBTree::DoBoost(DMatrix* p_fmat, linalg::Matrix<GradientPair>* in_gpair,
} else {
CHECK_EQ(in_gpair->Size() % n_groups, 0U) << "must have exactly ngroup * nrow gpairs";
linalg::Matrix<GradientPair> tmp{{in_gpair->Shape(0), static_cast<std::size_t>(1ul)},
ctx_->Ordinal()};
ctx_->Device()};
bool update_predict = true;
for (bst_target_t gid = 0; gid < n_groups; ++gid) {
node_position.clear();
@@ -736,7 +736,7 @@ class Dart : public GBTree {
PredictionCacheEntry predts; // temporary storage for prediction
if (ctx_->IsCUDA()) {
predts.predictions.SetDevice(ctx_->gpu_id);
predts.predictions.SetDevice(ctx_->Device());
}
predts.predictions.Resize(p_fmat->Info().num_row_ * n_groups, 0);
// multi-target is not yet supported.
@@ -761,8 +761,8 @@ class Dart : public GBTree {
CHECK_EQ(p_out_preds->predictions.Size(), predts.predictions.Size());
size_t n_rows = p_fmat->Info().num_row_;
if (predts.predictions.DeviceIdx() != Context::kCpuId) {
p_out_preds->predictions.SetDevice(predts.predictions.DeviceIdx());
if (predts.predictions.Device().IsCUDA()) {
p_out_preds->predictions.SetDevice(predts.predictions.Device());
GPUDartPredictInc(p_out_preds->predictions.DeviceSpan(),
predts.predictions.DeviceSpan(), w, n_rows, n_groups,
group);
@@ -801,8 +801,8 @@ class Dart : public GBTree {
StringView msg{"Unsupported data type for inplace predict."};
PredictionCacheEntry predts;
if (ctx_->gpu_id != Context::kCpuId) {
predts.predictions.SetDevice(ctx_->gpu_id);
if (ctx_->IsCUDA()) {
predts.predictions.SetDevice(ctx_->Device());
}
predts.predictions.Resize(p_fmat->Info().num_row_ * n_groups, 0);
@@ -838,8 +838,8 @@ class Dart : public GBTree {
CHECK_EQ(predts.predictions.Size(), p_out_preds->predictions.Size());
size_t n_rows = p_fmat->Info().num_row_;
if (predts.predictions.DeviceIdx() != Context::kCpuId) {
p_out_preds->predictions.SetDevice(predts.predictions.DeviceIdx());
if (predts.predictions.Device().IsCUDA()) {
p_out_preds->predictions.SetDevice(predts.predictions.Device());
auto base_score = model_.learner_model_param->BaseScore(predts.predictions.Device());
GPUDartInplacePredictInc(p_out_preds->predictions.DeviceSpan(),
predts.predictions.DeviceSpan(), w, n_rows, base_score, n_groups,