Optimize array interface input. (#9090)

This commit is contained in:
Jiaming Yuan
2023-04-28 18:01:58 +08:00
committed by GitHub
parent fb941262b4
commit 17ff471616
3 changed files with 89 additions and 10 deletions

View File

@@ -427,10 +427,13 @@ void CopyTensorInfoImpl(Context const& ctx, Json arr_interface, linalg::Tensor<T
return;
}
p_out->Reshape(array.shape);
auto t = p_out->View(Context::kCpuId);
CHECK(t.CContiguous());
linalg::ElementWiseTransformHost(t, ctx.Threads(), [&](auto i, auto) {
return linalg::detail::Apply(TypedIndex<T, D>{array}, linalg::UnravelIndex<D>(i, t.Shape()));
auto t_out = p_out->View(Context::kCpuId);
CHECK(t_out.CContiguous());
auto const shape = t_out.Shape();
DispatchDType(array, Context::kCpuId, [&](auto&& in) {
linalg::ElementWiseTransformHost(t_out, ctx.Threads(), [&](auto i, auto) {
return std::apply(in, linalg::UnravelIndex<D>(i, shape));
});
});
}
} // namespace