Support column major array. (#6765)

This commit is contained in:
Jiaming Yuan
2021-03-20 05:19:46 +08:00
committed by GitHub
parent f6fe15d11f
commit 4ee8340e79
9 changed files with 181 additions and 151 deletions

View File

@@ -234,6 +234,7 @@ class ArrayAdapterBatch : public detail::NoMetaInfo {
class Line {
ArrayInterface array_interface_;
size_t ridx_;
public:
Line(ArrayInterface array_interface, size_t ridx)
: array_interface_{std::move(array_interface)}, ridx_{ridx} {}
@@ -241,15 +242,14 @@ class ArrayAdapterBatch : public detail::NoMetaInfo {
size_t Size() const { return array_interface_.num_cols; }
COOTuple GetElement(size_t idx) const {
return {ridx_, idx, array_interface_.GetElement(idx)};
return {ridx_, idx, array_interface_.GetElement(ridx_, idx)};
}
};
public:
ArrayAdapterBatch() = default;
Line const GetLine(size_t idx) const {
auto line = array_interface_.SliceRow(idx);
return Line{line, idx};
return Line{array_interface_, idx};
}
explicit ArrayAdapterBatch(ArrayInterface array_interface)
@@ -286,14 +286,19 @@ class CSRArrayAdapterBatch : public detail::NoMetaInfo {
ArrayInterface indices_;
ArrayInterface values_;
size_t ridx_;
size_t offset_;
public:
Line(ArrayInterface indices, ArrayInterface values, size_t ridx)
: indices_{std::move(indices)}, values_{std::move(values)}, ridx_{ridx} {}
Line(ArrayInterface indices, ArrayInterface values, size_t ridx,
size_t offset)
: indices_{std::move(indices)}, values_{std::move(values)}, ridx_{ridx},
offset_{offset} {}
COOTuple GetElement(size_t idx) const {
return {ridx_, indices_.GetElement<size_t>(idx), values_.GetElement(idx)};
return {ridx_, indices_.GetElement<size_t>(offset_ + idx, 0),
values_.GetElement(offset_ + idx, 0)};
}
size_t Size() const {
return values_.num_rows * values_.num_cols;
}
@@ -304,7 +309,11 @@ class CSRArrayAdapterBatch : public detail::NoMetaInfo {
CSRArrayAdapterBatch(ArrayInterface indptr, ArrayInterface indices,
ArrayInterface values)
: indptr_{std::move(indptr)}, indices_{std::move(indices)},
values_{std::move(values)} {}
values_{std::move(values)} {
indptr_.AsColumnVector();
values_.AsColumnVector();
indices_.AsColumnVector();
}
size_t Size() const {
size_t size = indptr_.num_rows * indptr_.num_cols;
@@ -313,15 +322,19 @@ class CSRArrayAdapterBatch : public detail::NoMetaInfo {
}
Line const GetLine(size_t idx) const {
auto begin_offset = indptr_.GetElement<size_t>(idx);
auto end_offset = indptr_.GetElement<size_t>(idx + 1);
auto indices = indices_.SliceOffset(begin_offset);
auto values = values_.SliceOffset(begin_offset);
auto begin_offset = indptr_.GetElement<size_t>(idx, 0);
auto end_offset = indptr_.GetElement<size_t>(idx + 1, 0);
auto indices = indices_;
auto values = values_;
values.num_cols = end_offset - begin_offset;
values.num_rows = 1;
indices.num_cols = values.num_cols;
indices.num_rows = values.num_rows;
return Line{indices, values, idx};
return Line{indices, values, idx, begin_offset};
}
};