Use the new DeviceOrd in the linalg module. (#9527)
This commit is contained in:
@@ -68,7 +68,7 @@ void CopyGradientFromCUDAArrays(Context const *ctx, ArrayInterface<2, false> con
|
||||
auto &gpair = *out_gpair;
|
||||
gpair.SetDevice(grad_dev);
|
||||
gpair.Reshape(grad.Shape(0), grad.Shape(1));
|
||||
auto d_gpair = gpair.View(grad_dev);
|
||||
auto d_gpair = gpair.View(DeviceOrd::CUDA(grad_dev));
|
||||
auto cuctx = ctx->CUDACtx();
|
||||
|
||||
DispatchDType(grad, DeviceOrd::CUDA(grad_dev), [&](auto &&t_grad) {
|
||||
|
||||
Reference in New Issue
Block a user