Optimize array interface input. (#9090)
This commit is contained in:
@@ -427,10 +427,13 @@ void CopyTensorInfoImpl(Context const& ctx, Json arr_interface, linalg::Tensor<T
|
||||
return;
|
||||
}
|
||||
p_out->Reshape(array.shape);
|
||||
auto t = p_out->View(Context::kCpuId);
|
||||
CHECK(t.CContiguous());
|
||||
linalg::ElementWiseTransformHost(t, ctx.Threads(), [&](auto i, auto) {
|
||||
return linalg::detail::Apply(TypedIndex<T, D>{array}, linalg::UnravelIndex<D>(i, t.Shape()));
|
||||
auto t_out = p_out->View(Context::kCpuId);
|
||||
CHECK(t_out.CContiguous());
|
||||
auto const shape = t_out.Shape();
|
||||
DispatchDType(array, Context::kCpuId, [&](auto&& in) {
|
||||
linalg::ElementWiseTransformHost(t_out, ctx.Threads(), [&](auto i, auto) {
|
||||
return std::apply(in, linalg::UnravelIndex<D>(i, shape));
|
||||
});
|
||||
});
|
||||
}
|
||||
} // namespace
|
||||
|
||||
Reference in New Issue
Block a user