Extend array interface to handle ndarray. (#7434)
* Extend array interface to handle ndarray. The `ArrayInterface` class is extended to support multi-dim array inputs. Previously this class handles only 2-dim (vector is also matrix). This PR specifies the expected dimension at compile-time and the array interface can perform various checks automatically for input data. Also, adapters like CSR are more rigorous about their input. Lastly, row vector and column vector are handled without intervention from the caller.
This commit is contained in:
@@ -254,20 +254,20 @@ class ArrayAdapterBatch : public detail::NoMetaInfo {
|
||||
static constexpr bool kIsRowMajor = true;
|
||||
|
||||
private:
|
||||
ArrayInterface array_interface_;
|
||||
ArrayInterface<2> array_interface_;
|
||||
|
||||
class Line {
|
||||
ArrayInterface array_interface_;
|
||||
ArrayInterface<2> array_interface_;
|
||||
size_t ridx_;
|
||||
|
||||
public:
|
||||
Line(ArrayInterface array_interface, size_t ridx)
|
||||
Line(ArrayInterface<2> array_interface, size_t ridx)
|
||||
: array_interface_{std::move(array_interface)}, ridx_{ridx} {}
|
||||
|
||||
size_t Size() const { return array_interface_.num_cols; }
|
||||
size_t Size() const { return array_interface_.Shape(1); }
|
||||
|
||||
COOTuple GetElement(size_t idx) const {
|
||||
return {ridx_, idx, array_interface_.GetElement(ridx_, idx)};
|
||||
return {ridx_, idx, array_interface_(ridx_, idx)};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -277,11 +277,11 @@ class ArrayAdapterBatch : public detail::NoMetaInfo {
|
||||
return Line{array_interface_, idx};
|
||||
}
|
||||
|
||||
size_t NumRows() const { return array_interface_.num_rows; }
|
||||
size_t NumCols() const { return array_interface_.num_cols; }
|
||||
size_t NumRows() const { return array_interface_.Shape(0); }
|
||||
size_t NumCols() const { return array_interface_.Shape(1); }
|
||||
size_t Size() const { return this->NumRows(); }
|
||||
|
||||
explicit ArrayAdapterBatch(ArrayInterface array_interface)
|
||||
explicit ArrayAdapterBatch(ArrayInterface<2> array_interface)
|
||||
: array_interface_{std::move(array_interface)} {}
|
||||
};
|
||||
|
||||
@@ -294,43 +294,42 @@ class ArrayAdapter : public detail::SingleBatchDataIter<ArrayAdapterBatch> {
|
||||
public:
|
||||
explicit ArrayAdapter(StringView array_interface) {
|
||||
auto j = Json::Load(array_interface);
|
||||
array_interface_ = ArrayInterface(get<Object const>(j));
|
||||
array_interface_ = ArrayInterface<2>(get<Object const>(j));
|
||||
batch_ = ArrayAdapterBatch{array_interface_};
|
||||
}
|
||||
ArrayAdapterBatch const& Value() const override { return batch_; }
|
||||
size_t NumRows() const { return array_interface_.num_rows; }
|
||||
size_t NumColumns() const { return array_interface_.num_cols; }
|
||||
size_t NumRows() const { return array_interface_.Shape(0); }
|
||||
size_t NumColumns() const { return array_interface_.Shape(1); }
|
||||
|
||||
private:
|
||||
ArrayAdapterBatch batch_;
|
||||
ArrayInterface array_interface_;
|
||||
ArrayInterface<2> array_interface_;
|
||||
};
|
||||
|
||||
class CSRArrayAdapterBatch : public detail::NoMetaInfo {
|
||||
ArrayInterface indptr_;
|
||||
ArrayInterface indices_;
|
||||
ArrayInterface values_;
|
||||
ArrayInterface<1> indptr_;
|
||||
ArrayInterface<1> indices_;
|
||||
ArrayInterface<1> values_;
|
||||
bst_feature_t n_features_;
|
||||
|
||||
class Line {
|
||||
ArrayInterface indices_;
|
||||
ArrayInterface values_;
|
||||
ArrayInterface<1> indices_;
|
||||
ArrayInterface<1> values_;
|
||||
size_t ridx_;
|
||||
size_t offset_;
|
||||
|
||||
public:
|
||||
Line(ArrayInterface indices, ArrayInterface values, size_t ridx,
|
||||
Line(ArrayInterface<1> indices, ArrayInterface<1> values, size_t ridx,
|
||||
size_t offset)
|
||||
: indices_{std::move(indices)}, values_{std::move(values)}, ridx_{ridx},
|
||||
offset_{offset} {}
|
||||
|
||||
COOTuple GetElement(size_t idx) const {
|
||||
return {ridx_, indices_.GetElement<size_t>(offset_ + idx, 0),
|
||||
values_.GetElement(offset_ + idx, 0)};
|
||||
return {ridx_, TypedIndex<size_t, 1>{indices_}(offset_ + idx), values_(offset_ + idx)};
|
||||
}
|
||||
|
||||
size_t Size() const {
|
||||
return values_.num_rows * values_.num_cols;
|
||||
return values_.Shape(0);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -339,17 +338,16 @@ class CSRArrayAdapterBatch : public detail::NoMetaInfo {
|
||||
|
||||
public:
|
||||
CSRArrayAdapterBatch() = default;
|
||||
CSRArrayAdapterBatch(ArrayInterface indptr, ArrayInterface indices,
|
||||
ArrayInterface values, bst_feature_t n_features)
|
||||
: indptr_{std::move(indptr)}, indices_{std::move(indices)},
|
||||
values_{std::move(values)}, n_features_{n_features} {
|
||||
indptr_.AsColumnVector();
|
||||
values_.AsColumnVector();
|
||||
indices_.AsColumnVector();
|
||||
CSRArrayAdapterBatch(ArrayInterface<1> indptr, ArrayInterface<1> indices,
|
||||
ArrayInterface<1> values, bst_feature_t n_features)
|
||||
: indptr_{std::move(indptr)},
|
||||
indices_{std::move(indices)},
|
||||
values_{std::move(values)},
|
||||
n_features_{n_features} {
|
||||
}
|
||||
|
||||
size_t NumRows() const {
|
||||
size_t size = indptr_.num_rows * indptr_.num_cols;
|
||||
size_t size = indptr_.Shape(0);
|
||||
size = size == 0 ? 0 : size - 1;
|
||||
return size;
|
||||
}
|
||||
@@ -357,19 +355,19 @@ class CSRArrayAdapterBatch : public detail::NoMetaInfo {
|
||||
size_t Size() const { return this->NumRows(); }
|
||||
|
||||
Line const GetLine(size_t idx) const {
|
||||
auto begin_offset = indptr_.GetElement<size_t>(idx, 0);
|
||||
auto end_offset = indptr_.GetElement<size_t>(idx + 1, 0);
|
||||
auto begin_no_stride = TypedIndex<size_t, 1>{indptr_}(idx);
|
||||
auto end_no_stride = TypedIndex<size_t, 1>{indptr_}(idx + 1);
|
||||
|
||||
auto indices = indices_;
|
||||
auto values = values_;
|
||||
// Slice indices and values, stride remains unchanged since this is slicing by
|
||||
// specific index.
|
||||
auto offset = indices.strides[0] * begin_no_stride;
|
||||
|
||||
values.num_cols = end_offset - begin_offset;
|
||||
values.num_rows = 1;
|
||||
indices.shape[0] = end_no_stride - begin_no_stride;
|
||||
values.shape[0] = end_no_stride - begin_no_stride;
|
||||
|
||||
indices.num_cols = values.num_cols;
|
||||
indices.num_rows = values.num_rows;
|
||||
|
||||
return Line{indices, values, idx, begin_offset};
|
||||
return Line{indices, values, idx, offset};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -391,7 +389,7 @@ class CSRArrayAdapter : public detail::SingleBatchDataIter<CSRArrayAdapterBatch>
|
||||
return batch_;
|
||||
}
|
||||
size_t NumRows() const {
|
||||
size_t size = indptr_.num_cols * indptr_.num_rows;
|
||||
size_t size = indptr_.Shape(0);
|
||||
size = size == 0 ? 0 : size - 1;
|
||||
return size;
|
||||
}
|
||||
@@ -399,9 +397,9 @@ class CSRArrayAdapter : public detail::SingleBatchDataIter<CSRArrayAdapterBatch>
|
||||
|
||||
private:
|
||||
CSRArrayAdapterBatch batch_;
|
||||
ArrayInterface indptr_;
|
||||
ArrayInterface indices_;
|
||||
ArrayInterface values_;
|
||||
ArrayInterface<1> indptr_;
|
||||
ArrayInterface<1> indices_;
|
||||
ArrayInterface<1> values_;
|
||||
size_t num_cols_;
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user