Support column major array. (#6765)
This commit is contained in:
@@ -234,6 +234,7 @@ class ArrayAdapterBatch : public detail::NoMetaInfo {
|
||||
class Line {
|
||||
ArrayInterface array_interface_;
|
||||
size_t ridx_;
|
||||
|
||||
public:
|
||||
Line(ArrayInterface array_interface, size_t ridx)
|
||||
: array_interface_{std::move(array_interface)}, ridx_{ridx} {}
|
||||
@@ -241,15 +242,14 @@ class ArrayAdapterBatch : public detail::NoMetaInfo {
|
||||
size_t Size() const { return array_interface_.num_cols; }
|
||||
|
||||
COOTuple GetElement(size_t idx) const {
|
||||
return {ridx_, idx, array_interface_.GetElement(idx)};
|
||||
return {ridx_, idx, array_interface_.GetElement(ridx_, idx)};
|
||||
}
|
||||
};
|
||||
|
||||
public:
|
||||
ArrayAdapterBatch() = default;
|
||||
Line const GetLine(size_t idx) const {
|
||||
auto line = array_interface_.SliceRow(idx);
|
||||
return Line{line, idx};
|
||||
return Line{array_interface_, idx};
|
||||
}
|
||||
|
||||
explicit ArrayAdapterBatch(ArrayInterface array_interface)
|
||||
@@ -286,14 +286,19 @@ class CSRArrayAdapterBatch : public detail::NoMetaInfo {
|
||||
ArrayInterface indices_;
|
||||
ArrayInterface values_;
|
||||
size_t ridx_;
|
||||
size_t offset_;
|
||||
|
||||
public:
|
||||
Line(ArrayInterface indices, ArrayInterface values, size_t ridx)
|
||||
: indices_{std::move(indices)}, values_{std::move(values)}, ridx_{ridx} {}
|
||||
Line(ArrayInterface indices, ArrayInterface values, size_t ridx,
|
||||
size_t offset)
|
||||
: indices_{std::move(indices)}, values_{std::move(values)}, ridx_{ridx},
|
||||
offset_{offset} {}
|
||||
|
||||
COOTuple GetElement(size_t idx) const {
|
||||
return {ridx_, indices_.GetElement<size_t>(idx), values_.GetElement(idx)};
|
||||
return {ridx_, indices_.GetElement<size_t>(offset_ + idx, 0),
|
||||
values_.GetElement(offset_ + idx, 0)};
|
||||
}
|
||||
|
||||
size_t Size() const {
|
||||
return values_.num_rows * values_.num_cols;
|
||||
}
|
||||
@@ -304,7 +309,11 @@ class CSRArrayAdapterBatch : public detail::NoMetaInfo {
|
||||
CSRArrayAdapterBatch(ArrayInterface indptr, ArrayInterface indices,
|
||||
ArrayInterface values)
|
||||
: indptr_{std::move(indptr)}, indices_{std::move(indices)},
|
||||
values_{std::move(values)} {}
|
||||
values_{std::move(values)} {
|
||||
indptr_.AsColumnVector();
|
||||
values_.AsColumnVector();
|
||||
indices_.AsColumnVector();
|
||||
}
|
||||
|
||||
size_t Size() const {
|
||||
size_t size = indptr_.num_rows * indptr_.num_cols;
|
||||
@@ -313,15 +322,19 @@ class CSRArrayAdapterBatch : public detail::NoMetaInfo {
|
||||
}
|
||||
|
||||
Line const GetLine(size_t idx) const {
|
||||
auto begin_offset = indptr_.GetElement<size_t>(idx);
|
||||
auto end_offset = indptr_.GetElement<size_t>(idx + 1);
|
||||
auto indices = indices_.SliceOffset(begin_offset);
|
||||
auto values = values_.SliceOffset(begin_offset);
|
||||
auto begin_offset = indptr_.GetElement<size_t>(idx, 0);
|
||||
auto end_offset = indptr_.GetElement<size_t>(idx + 1, 0);
|
||||
|
||||
auto indices = indices_;
|
||||
auto values = values_;
|
||||
|
||||
values.num_cols = end_offset - begin_offset;
|
||||
values.num_rows = 1;
|
||||
|
||||
indices.num_cols = values.num_cols;
|
||||
indices.num_rows = values.num_rows;
|
||||
return Line{indices, values, idx};
|
||||
|
||||
return Line{indices, values, idx, begin_offset};
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user