Use array interface for CSC matrix. (#8672)
* Use array interface for CSC matrix. Use array interface for CSC matrix and align the interface with CSR and dense. - Fix nthread issue in the R package DMatrix. - Unify the behavior of handling `missing` with other inputs. - Unify the behavior of handling `missing` around R, Python, Java, and Scala DMatrix. - Expose `num_non_missing` to the JVM interface. - Deprecate old CSR and CSC constructors.
This commit is contained in:
@@ -6,25 +6,25 @@
|
||||
#define XGBOOST_DATA_ADAPTER_H_
|
||||
#include <dmlc/data.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <algorithm>
|
||||
#include <cstddef> // std::size_t
|
||||
#include <functional>
|
||||
#include <limits>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <utility> // std::move
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <algorithm>
|
||||
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/data.h"
|
||||
#include "xgboost/span.h"
|
||||
|
||||
#include "array_interface.h"
|
||||
#include "../c_api/c_api_error.h"
|
||||
#include "../common/math.h"
|
||||
#include "array_interface.h"
|
||||
#include "arrow-cdi.h"
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/data.h"
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/span.h"
|
||||
#include "xgboost/string_view.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
@@ -472,6 +472,84 @@ class CSCAdapter : public detail::SingleBatchDataIter<CSCAdapterBatch> {
|
||||
size_t num_columns_;
|
||||
};
|
||||
|
||||
class CSCArrayAdapterBatch : public detail::NoMetaInfo {
|
||||
ArrayInterface<1> indptr_;
|
||||
ArrayInterface<1> indices_;
|
||||
ArrayInterface<1> values_;
|
||||
bst_row_t n_rows_;
|
||||
|
||||
class Line {
|
||||
std::size_t column_idx_;
|
||||
ArrayInterface<1> row_idx_;
|
||||
ArrayInterface<1> values_;
|
||||
std::size_t offset_;
|
||||
|
||||
public:
|
||||
Line(std::size_t idx, ArrayInterface<1> row_idx, ArrayInterface<1> values, std::size_t offset)
|
||||
: column_idx_{idx},
|
||||
row_idx_{std::move(row_idx)},
|
||||
values_{std::move(values)},
|
||||
offset_{offset} {}
|
||||
|
||||
std::size_t Size() const { return values_.Shape(0); }
|
||||
COOTuple GetElement(std::size_t idx) const {
|
||||
return {TypedIndex<std::size_t, 1>{row_idx_}(offset_ + idx), column_idx_,
|
||||
values_(offset_ + idx)};
|
||||
}
|
||||
};
|
||||
|
||||
public:
|
||||
static constexpr bool kIsRowMajor = false;
|
||||
|
||||
CSCArrayAdapterBatch(ArrayInterface<1> indptr, ArrayInterface<1> indices,
|
||||
ArrayInterface<1> values, bst_row_t n_rows)
|
||||
: indptr_{std::move(indptr)},
|
||||
indices_{std::move(indices)},
|
||||
values_{std::move(values)},
|
||||
n_rows_{n_rows} {}
|
||||
|
||||
std::size_t Size() const { return indptr_.n - 1; }
|
||||
Line GetLine(std::size_t idx) const {
|
||||
auto begin_no_stride = TypedIndex<std::size_t, 1>{indptr_}(idx);
|
||||
auto end_no_stride = TypedIndex<std::size_t, 1>{indptr_}(idx + 1);
|
||||
|
||||
auto indices = indices_;
|
||||
auto values = values_;
|
||||
// Slice indices and values, stride remains unchanged since this is slicing by
|
||||
// specific index.
|
||||
auto offset = indices.strides[0] * begin_no_stride;
|
||||
indices.shape[0] = end_no_stride - begin_no_stride;
|
||||
values.shape[0] = end_no_stride - begin_no_stride;
|
||||
|
||||
return Line{idx, indices, values, offset};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* \brief CSC adapter with support for array interface.
|
||||
*/
|
||||
class CSCArrayAdapter : public detail::SingleBatchDataIter<CSCArrayAdapterBatch> {
|
||||
ArrayInterface<1> indptr_;
|
||||
ArrayInterface<1> indices_;
|
||||
ArrayInterface<1> values_;
|
||||
size_t num_rows_;
|
||||
CSCArrayAdapterBatch batch_;
|
||||
|
||||
public:
|
||||
CSCArrayAdapter(StringView indptr, StringView indices, StringView values, std::size_t num_rows)
|
||||
: indptr_{indptr},
|
||||
indices_{indices},
|
||||
values_{values},
|
||||
num_rows_{num_rows},
|
||||
batch_{
|
||||
CSCArrayAdapterBatch{indptr_, indices_, values_, static_cast<bst_row_t>(num_rows_)}} {}
|
||||
|
||||
// JVM package sends 0 as unknown
|
||||
size_t NumRows() const { return num_rows_ == 0 ? kAdapterUnknownSize : num_rows_; }
|
||||
size_t NumColumns() const { return indptr_.n - 1; }
|
||||
const CSCArrayAdapterBatch& Value() const override { return batch_; }
|
||||
};
|
||||
|
||||
class DataTableAdapterBatch : public detail::NoMetaInfo {
|
||||
enum class DTType : std::uint8_t {
|
||||
kFloat32 = 0,
|
||||
|
||||
Reference in New Issue
Block a user