Add range-based slicing to tensor view. (#7453)

This commit is contained in:
Jiaming Yuan
2021-11-27 13:42:36 +08:00
committed by GitHub
parent 6f38f5affa
commit 85cbd32c5a
10 changed files with 361 additions and 132 deletions

View File

@@ -85,9 +85,8 @@ double MultiClassOVR(common::Span<float const> predts, MetaInfo const &info,
auto const &labels = info.labels_.ConstHostVector();
std::vector<double> results_storage(n_classes * 3, 0);
linalg::TensorView<double> results(results_storage,
{n_classes, static_cast<size_t>(3)},
GenericParameter::kCpuId);
linalg::TensorView<double, 2> results(results_storage, {n_classes, static_cast<size_t>(3)},
GenericParameter::kCpuId);
auto local_area = results.Slice(linalg::All(), 0);
auto tp = results.Slice(linalg::All(), 1);
auto auc = results.Slice(linalg::All(), 2);