Use the new DeviceOrd in the linalg module. (#9527)

This commit is contained in:
Jiaming Yuan
2023-08-29 13:37:29 +08:00
committed by GitHub
parent 942b957eef
commit ddf2e68821
43 changed files with 252 additions and 273 deletions

View File

@@ -68,7 +68,7 @@ void CopyGradientFromCUDAArrays(Context const *ctx, ArrayInterface<2, false> con
auto &gpair = *out_gpair;
gpair.SetDevice(grad_dev);
gpair.Reshape(grad.Shape(0), grad.Shape(1));
auto d_gpair = gpair.View(grad_dev);
auto d_gpair = gpair.View(DeviceOrd::CUDA(grad_dev));
auto cuctx = ctx->CUDACtx();
DispatchDType(grad, DeviceOrd::CUDA(grad_dev), [&](auto &&t_grad) {