Use the new DeviceOrd in the linalg module. (#9527)

This commit is contained in:
Jiaming Yuan
2023-08-29 13:37:29 +08:00
committed by GitHub
parent 942b957eef
commit ddf2e68821
43 changed files with 252 additions and 273 deletions

View File

@@ -74,7 +74,7 @@ void FitStump(Context const* ctx, MetaInfo const& info, linalg::Matrix<GradientP
gpair.SetDevice(ctx->Device());
auto gpair_t = gpair.View(ctx->Device());
ctx->IsCPU() ? cpu_impl::FitStump(ctx, info, gpair_t, out->HostView())
- : cuda_impl::FitStump(ctx, gpair_t, out->View(ctx->gpu_id));
+ : cuda_impl::FitStump(ctx, gpair_t, out->View(ctx->Device()));
}
} // namespace tree
} // namespace xgboost

View File

@@ -1,5 +1,5 @@
/**
- * Copyright 2022 by XGBoost Contributors
+ * Copyright 2022-2023 by XGBoost Contributors
*
* \brief Utilities for estimating initial score.
*/
@@ -41,7 +41,7 @@ void FitStump(Context const* ctx, linalg::TensorView<GradientPair const, 2> gpai
auto sample = i % gpair.Shape(0);
return GradientPairPrecise{gpair(sample, target)};
});
- auto d_sum = sum.View(ctx->gpu_id);
+ auto d_sum = sum.View(ctx->Device());
CHECK(d_sum.CContiguous());
dh::XGBCachingDeviceAllocator<char> alloc;

View File

@@ -774,7 +774,7 @@ void UpdatePredictionCacheImpl(Context const *ctx, RegTree const *p_last_tree,
std::vector<Partitioner> const &partitioner,
linalg::VectorView<float> out_preds) {
auto const &tree = *p_last_tree;
- CHECK_EQ(out_preds.DeviceIdx(), Context::kCpuId);
+ CHECK(out_preds.Device().IsCPU());
size_t n_nodes = p_last_tree->GetNodes().size();
for (auto &part : partitioner) {
CHECK_EQ(part.Size(), n_nodes);
@@ -809,7 +809,7 @@ void UpdatePredictionCacheImpl(Context const *ctx, RegTree const *p_last_tree,
auto n_nodes = mttree->Size();
auto n_targets = tree.NumTargets();
CHECK_EQ(out_preds.Shape(1), n_targets);
- CHECK_EQ(out_preds.DeviceIdx(), Context::kCpuId);
+ CHECK(out_preds.Device().IsCPU());
for (auto &part : partitioner) {
CHECK_EQ(part.Size(), n_nodes);

View File

@@ -516,9 +516,10 @@ struct GPUHistMakerDevice {
}
CHECK(p_tree);
- dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
- CHECK_EQ(out_preds_d.DeviceIdx(), ctx_->gpu_id);
+ CHECK(out_preds_d.Device().IsCUDA());
+ CHECK_EQ(out_preds_d.Device().ordinal, ctx_->Ordinal());
+ dh::safe_cuda(cudaSetDevice(ctx_->Ordinal()));
auto d_position = dh::ToSpan(positions);
CHECK_EQ(out_preds_d.Size(), d_position.size());