Use the new DeviceOrd in the linalg module. (#9527)

This commit is contained in:
Jiaming Yuan
2023-08-29 13:37:29 +08:00
committed by GitHub
parent 942b957eef
commit ddf2e68821
43 changed files with 252 additions and 273 deletions

View File

@@ -14,10 +14,10 @@
#include "../../../src/data/array_interface.h"
namespace xgboost {
inline void TestMetaInfoStridedData(int32_t device) {
inline void TestMetaInfoStridedData(DeviceOrd device) {
MetaInfo info;
Context ctx;
ctx.UpdateAllowUnknown(Args{{"gpu_id", std::to_string(device)}});
ctx.UpdateAllowUnknown(Args{{"device", device.Name()}});
{
// labels
linalg::Tensor<float, 3> labels;
@@ -28,9 +28,9 @@ inline void TestMetaInfoStridedData(int32_t device) {
ASSERT_EQ(t_labels.Shape().size(), 2);
info.SetInfo(ctx, "label", StringView{ArrayInterfaceStr(t_labels)});
auto const& h_result = info.labels.View(-1);
auto const& h_result = info.labels.View(DeviceOrd::CPU());
ASSERT_EQ(h_result.Shape().size(), 2);
auto in_labels = labels.View(-1);
auto in_labels = labels.View(DeviceOrd::CPU());
linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(), [&](size_t i, float& v_0) {
auto tup = linalg::UnravelIndex(i, h_result.Shape());
auto i0 = std::get<0>(tup);
@@ -62,9 +62,9 @@ inline void TestMetaInfoStridedData(int32_t device) {
ASSERT_EQ(t_margin.Shape().size(), 2);
info.SetInfo(ctx, "base_margin", StringView{ArrayInterfaceStr(t_margin)});
auto const& h_result = info.base_margin_.View(-1);
auto const& h_result = info.base_margin_.View(DeviceOrd::CPU());
ASSERT_EQ(h_result.Shape().size(), 2);
auto in_margin = base_margin.View(-1);
auto in_margin = base_margin.View(DeviceOrd::CPU());
linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(), [&](size_t i, float v_0) {
auto tup = linalg::UnravelIndex(i, h_result.Shape());
auto i0 = std::get<0>(tup);