Use the new DeviceOrd in the linalg module. (#9527)

This commit is contained in:
Jiaming Yuan
2023-08-29 13:37:29 +08:00
committed by GitHub
parent 942b957eef
commit ddf2e68821
43 changed files with 252 additions and 273 deletions

View File

@@ -366,7 +366,7 @@ MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
// Groups is maintained by a higher level Python function. We should aim at deprecating
// the slice function.
if (this->labels.Size() != this->num_row_) {
auto t_labels = this->labels.View(this->labels.Data()->DeviceIdx());
auto t_labels = this->labels.View(this->labels.Data()->Device());
out.labels.Reshape(ridxs.size(), labels.Shape(1));
out.labels.Data()->HostVector() =
Gather(this->labels.Data()->HostVector(), ridxs, t_labels.Stride(0));
@@ -394,7 +394,7 @@ MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
if (this->base_margin_.Size() != this->num_row_) {
CHECK_EQ(this->base_margin_.Size() % this->num_row_, 0)
<< "Incorrect size of base margin vector.";
auto t_margin = this->base_margin_.View(this->base_margin_.Data()->DeviceIdx());
auto t_margin = this->base_margin_.View(this->base_margin_.Data()->Device());
out.base_margin_.Reshape(ridxs.size(), t_margin.Shape(1));
out.base_margin_.Data()->HostVector() =
Gather(this->base_margin_.Data()->HostVector(), ridxs, t_margin.Stride(0));
@@ -445,7 +445,7 @@ void CopyTensorInfoImpl(Context const& ctx, Json arr_interface, linalg::Tensor<T
return;
}
p_out->Reshape(array.shape);
auto t_out = p_out->View(Context::kCpuId);
auto t_out = p_out->View(DeviceOrd::CPU());
CHECK(t_out.CContiguous());
auto const shape = t_out.Shape();
DispatchDType(array, DeviceOrd::CPU(), [&](auto&& in) {
@@ -564,7 +564,7 @@ void MetaInfo::SetInfo(Context const& ctx, const char* key, const void* dptr, Da
CHECK(key);
auto proc = [&](auto cast_d_ptr) {
using T = std::remove_pointer_t<decltype(cast_d_ptr)>;
auto t = linalg::TensorView<T, 1>(common::Span<T>{cast_d_ptr, num}, {num}, Context::kCpuId);
auto t = linalg::TensorView<T, 1>(common::Span<T>{cast_d_ptr, num}, {num}, DeviceOrd::CPU());
CHECK(t.CContiguous());
Json interface {
linalg::ArrayInterface(t)
@@ -739,8 +739,7 @@ void MetaInfo::SynchronizeNumberOfColumns() {
namespace {
template <typename T>
void CheckDevice(std::int32_t device, HostDeviceVector<T> const& v) {
bool valid =
v.DeviceIdx() == Context::kCpuId || device == Context::kCpuId || v.DeviceIdx() == device;
bool valid = v.Device().IsCPU() || device == Context::kCpuId || v.DeviceIdx() == device;
if (!valid) {
LOG(FATAL) << "Invalid device ordinal. Data is associated with a different device ordinal than "
"the booster. The device ordinal of the data is: "

View File

@@ -50,7 +50,7 @@ void CopyTensorInfoImpl(CUDAContext const* ctx, Json arr_interface, linalg::Tens
return;
}
p_out->Reshape(array.shape);
auto t = p_out->View(ptr_device);
auto t = p_out->View(DeviceOrd::CUDA(ptr_device));
linalg::ElementWiseTransformDevice(
t,
[=] __device__(size_t i, T) {