Support host data in proxy DMatrix. (#7087)
This commit is contained in:
@@ -257,7 +257,10 @@ class ArrayAdapterBatch : public detail::NoMetaInfo {
|
||||
Line const GetLine(size_t idx) const {
|
||||
return Line{array_interface_, idx};
|
||||
}
|
||||
size_t Size() const { return array_interface_.num_rows; }
|
||||
|
||||
size_t NumRows() const { return array_interface_.num_rows; }
|
||||
size_t NumCols() const { return array_interface_.num_cols; }
|
||||
size_t Size() const { return this->NumRows(); }
|
||||
|
||||
explicit ArrayAdapterBatch(ArrayInterface array_interface)
|
||||
: array_interface_{std::move(array_interface)} {}
|
||||
@@ -288,6 +291,7 @@ class CSRArrayAdapterBatch : public detail::NoMetaInfo {
|
||||
ArrayInterface indptr_;
|
||||
ArrayInterface indices_;
|
||||
ArrayInterface values_;
|
||||
bst_feature_t n_features_;
|
||||
|
||||
class Line {
|
||||
ArrayInterface indices_;
|
||||
@@ -311,23 +315,27 @@ class CSRArrayAdapterBatch : public detail::NoMetaInfo {
|
||||
}
|
||||
};
|
||||
|
||||
public:
|
||||
static constexpr bool kIsRowMajor = true;
|
||||
|
||||
public:
|
||||
CSRArrayAdapterBatch() = default;
|
||||
CSRArrayAdapterBatch(ArrayInterface indptr, ArrayInterface indices,
|
||||
ArrayInterface values)
|
||||
ArrayInterface values, bst_feature_t n_features)
|
||||
: indptr_{std::move(indptr)}, indices_{std::move(indices)},
|
||||
values_{std::move(values)} {
|
||||
values_{std::move(values)}, n_features_{n_features} {
|
||||
indptr_.AsColumnVector();
|
||||
values_.AsColumnVector();
|
||||
indices_.AsColumnVector();
|
||||
}
|
||||
|
||||
size_t Size() const {
|
||||
size_t NumRows() const {
|
||||
size_t size = indptr_.num_rows * indptr_.num_cols;
|
||||
size = size == 0 ? 0 : size - 1;
|
||||
return size;
|
||||
}
|
||||
static constexpr bool kIsRowMajor = true;
|
||||
size_t NumCols() const { return n_features_; }
|
||||
size_t Size() const { return this->NumRows(); }
|
||||
|
||||
Line const GetLine(size_t idx) const {
|
||||
auto begin_offset = indptr_.GetElement<size_t>(idx, 0);
|
||||
@@ -356,7 +364,8 @@ class CSRArrayAdapter : public detail::SingleBatchDataIter<CSRArrayAdapterBatch>
|
||||
CSRArrayAdapter(StringView indptr, StringView indices, StringView values,
|
||||
size_t num_cols)
|
||||
: indptr_{indptr}, indices_{indices}, values_{values}, num_cols_{num_cols} {
|
||||
batch_ = CSRArrayAdapterBatch{indptr_, indices_, values_};
|
||||
batch_ = CSRArrayAdapterBatch{indptr_, indices_, values_,
|
||||
static_cast<bst_feature_t>(num_cols_)};
|
||||
}
|
||||
|
||||
CSRArrayAdapterBatch const& Value() const override {
|
||||
|
||||
Reference in New Issue
Block a user