Use the new DeviceOrd in the linalg module. (#9527)
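The diff below replaces raw integer device ordinals (Context::kCpuId, DeviceIdx()) with the typed DeviceOrd in the linalg code paths. As a rough mental model only (the real xgboost header has more members and details), DeviceOrd can be pictured as a small value type along these lines; only the calls visible in this diff (CPU(), CUDA(ordinal), IsCPU()) are mirrored in the sketch:

#include <cstdint>

// Illustrative sketch, not the actual xgboost definition: a tagged device
// identifier instead of a bare int where -1 means CPU.
struct DeviceOrdSketch {
  enum class Kind : std::int8_t { kCPU, kCUDA };
  Kind device{Kind::kCPU};
  std::int32_t ordinal{-1};  // CUDA ordinal; unused for CPU

  static DeviceOrdSketch CPU() { return DeviceOrdSketch{Kind::kCPU, -1}; }
  static DeviceOrdSketch CUDA(std::int32_t ord) { return DeviceOrdSketch{Kind::kCUDA, ord}; }
  bool IsCPU() const { return device == Kind::kCPU; }
  bool IsCUDA() const { return device == Kind::kCUDA; }
};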
@@ -366,7 +366,7 @@ MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
   // Groups is maintained by a higher level Python function. We should aim at deprecating
   // the slice function.
   if (this->labels.Size() != this->num_row_) {
-    auto t_labels = this->labels.View(this->labels.Data()->DeviceIdx());
+    auto t_labels = this->labels.View(this->labels.Data()->Device());
     out.labels.Reshape(ridxs.size(), labels.Shape(1));
     out.labels.Data()->HostVector() =
         Gather(this->labels.Data()->HostVector(), ridxs, t_labels.Stride(0));
@@ -394,7 +394,7 @@ MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
   if (this->base_margin_.Size() != this->num_row_) {
     CHECK_EQ(this->base_margin_.Size() % this->num_row_, 0)
         << "Incorrect size of base margin vector.";
-    auto t_margin = this->base_margin_.View(this->base_margin_.Data()->DeviceIdx());
+    auto t_margin = this->base_margin_.View(this->base_margin_.Data()->Device());
     out.base_margin_.Reshape(ridxs.size(), t_margin.Shape(1));
     out.base_margin_.Data()->HostVector() =
         Gather(this->base_margin_.Data()->HostVector(), ridxs, t_margin.Stride(0));
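In both Slice hunks above the only functional change is that the view is created from HostDeviceVector::Device(), which returns a DeviceOrd, instead of the integer DeviceIdx(). A minimal sketch of the pattern, assuming the linalg tensor types from this repository (the helper name is made up):

// Hypothetical helper: view a tensor on whichever device currently holds its data.
// Old call site: tensor.View(tensor.Data()->DeviceIdx());  // int ordinal, -1 == CPU
// New call site: tensor.View(tensor.Data()->Device());     // typed DeviceOrd
template <typename TensorT>
auto ViewOnDataDevice(TensorT& tensor) {
  return tensor.View(tensor.Data()->Device());
}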
@@ -445,7 +445,7 @@ void CopyTensorInfoImpl(Context const& ctx, Json arr_interface, linalg::Tensor<T
     return;
   }
   p_out->Reshape(array.shape);
-  auto t_out = p_out->View(Context::kCpuId);
+  auto t_out = p_out->View(DeviceOrd::CPU());
   CHECK(t_out.CContiguous());
   auto const shape = t_out.Shape();
   DispatchDType(array, DeviceOrd::CPU(), [&](auto&& in) {
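On the host path, CopyTensorInfoImpl now requests its CPU view with DeviceOrd::CPU() rather than the Context::kCpuId sentinel; DispatchDType in the same hunk already accepted a DeviceOrd. A hedged usage sketch, assuming linalg::Tensor from this repository and an arbitrary, made-up 3x4 shape:

// Sketch: reshape the output tensor, then take a C-contiguous host view of it.
linalg::Tensor<float, 2> out;
out.Reshape(3, 4);                        // illustrative shape
auto t_out = out.View(DeviceOrd::CPU());  // was: out.View(Context::kCpuId)
CHECK(t_out.CContiguous());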
@@ -564,7 +564,7 @@ void MetaInfo::SetInfo(Context const& ctx, const char* key, const void* dptr, Da
   CHECK(key);
   auto proc = [&](auto cast_d_ptr) {
     using T = std::remove_pointer_t<decltype(cast_d_ptr)>;
-    auto t = linalg::TensorView<T, 1>(common::Span<T>{cast_d_ptr, num}, {num}, Context::kCpuId);
+    auto t = linalg::TensorView<T, 1>(common::Span<T>{cast_d_ptr, num}, {num}, DeviceOrd::CPU());
     CHECK(t.CContiguous());
     Json interface {
       linalg::ArrayInterface(t)
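In SetInfo the 1-D TensorView wrapping the caller's raw buffer is now constructed with DeviceOrd::CPU() in place of Context::kCpuId; the span, the shape, and the surrounding ArrayInterface call are untouched. A sketch of that constructor call, assuming a host buffer of num floats (both values are placeholders here):

// Sketch: wrap an existing host buffer in a 1-D CPU tensor view.
float* cast_d_ptr = nullptr;  // placeholder for the caller-provided host pointer
std::size_t num = 0;          // placeholder for the number of elements
auto t = linalg::TensorView<float, 1>(
    common::Span<float>{cast_d_ptr, num}, {num}, DeviceOrd::CPU());
CHECK(t.CContiguous());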
@@ -739,8 +739,7 @@ void MetaInfo::SynchronizeNumberOfColumns() {
 namespace {
 template <typename T>
 void CheckDevice(std::int32_t device, HostDeviceVector<T> const& v) {
-  bool valid =
-      v.DeviceIdx() == Context::kCpuId || device == Context::kCpuId || v.DeviceIdx() == device;
+  bool valid = v.Device().IsCPU() || device == Context::kCpuId || v.DeviceIdx() == device;
   if (!valid) {
     LOG(FATAL) << "Invalid device ordinal. Data is associated with a different device ordinal than "
                   "the booster. The device ordinal of the data is: "
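CheckDevice still receives the booster's device as a plain ordinal, so the refactor is only partial here: the data side of the comparison now goes through Device().IsCPU(), while the other two clauses keep the integer comparison. A sketch of the mixed check, assuming HostDeviceVector exposes both Device() and DeviceIdx() during the transition (the function name is made up):

// Sketch: valid when the data lives on the CPU, the booster is on the CPU,
// or the data and the booster share the same device ordinal.
template <typename T>
bool DeviceMatchesSketch(std::int32_t device, HostDeviceVector<T> const& v) {
  return v.Device().IsCPU() || device == Context::kCpuId || v.DeviceIdx() == device;
}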
@@ -50,7 +50,7 @@ void CopyTensorInfoImpl(CUDAContext const* ctx, Json arr_interface, linalg::Tens
     return;
   }
   p_out->Reshape(array.shape);
-  auto t = p_out->View(ptr_device);
+  auto t = p_out->View(DeviceOrd::CUDA(ptr_device));
   linalg::ElementWiseTransformDevice(
       t,
       [=] __device__(size_t i, T) {
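On the CUDA path the ptr_device ordinal taken from the incoming array interface is wrapped in DeviceOrd::CUDA() before the device view is created; the element-wise transform itself is unchanged. A short sketch, assuming ptr_device holds a valid CUDA ordinal and p_out is the linalg::Tensor being filled:

// Sketch: obtain a device view on the GPU that owns the incoming data.
std::int32_t ptr_device = 0;  // illustrative ordinal parsed from the array interface
auto t = p_out->View(DeviceOrd::CUDA(ptr_device));  // was: p_out->View(ptr_device)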