diff --git a/R-package/src/Makevars.in b/R-package/src/Makevars.in index f42c94501..e3af418e3 100644 --- a/R-package/src/Makevars.in +++ b/R-package/src/Makevars.in @@ -62,6 +62,7 @@ OBJECTS= \ $(PKGROOT)/src/gbm/gbtree_model.o \ $(PKGROOT)/src/gbm/gblinear.o \ $(PKGROOT)/src/gbm/gblinear_model.o \ + $(PKGROOT)/src/data/adapter.o \ $(PKGROOT)/src/data/simple_dmatrix.o \ $(PKGROOT)/src/data/data.o \ $(PKGROOT)/src/data/sparse_page_raw_format.o \ diff --git a/R-package/src/Makevars.win b/R-package/src/Makevars.win index 1b620751f..8f003403f 100644 --- a/R-package/src/Makevars.win +++ b/R-package/src/Makevars.win @@ -62,6 +62,7 @@ OBJECTS= \ $(PKGROOT)/src/gbm/gbtree_model.o \ $(PKGROOT)/src/gbm/gblinear.o \ $(PKGROOT)/src/gbm/gblinear_model.o \ + $(PKGROOT)/src/data/adapter.o \ $(PKGROOT)/src/data/simple_dmatrix.o \ $(PKGROOT)/src/data/data.o \ $(PKGROOT)/src/data/sparse_page_raw_format.o \ diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index f8b0aa3de..8cead56d8 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -269,8 +269,8 @@ XGB_DLL int XGDMatrixCreateFromDataIter( if (cache_info != nullptr) { scache = cache_info; } - xgboost::data::IteratorAdapter adapter(data_handle, callback); + xgboost::data::IteratorAdapter adapter( + data_handle, callback); xgboost_CHECK_C_ARG_PTR(out); *out = new std::shared_ptr { DMatrix::Create( diff --git a/src/data/adapter.cc b/src/data/adapter.cc new file mode 100644 index 000000000..4fa171c9d --- /dev/null +++ b/src/data/adapter.cc @@ -0,0 +1,28 @@ +/** + * Copyright 2019-2023, XGBoost Contributors + */ +#include "adapter.h" + +#include "../c_api/c_api_error.h" // for API_BEGIN, API_END +#include "xgboost/c_api.h" + +namespace xgboost::data { +template +bool IteratorAdapter::Next() { + if ((*next_callback_)( + data_handle_, + [](void *handle, XGBoostBatchCSR batch) -> int { + API_BEGIN(); + static_cast(handle)->SetData(batch); + API_END(); + }, + this) != 0) { + at_first_ = false; + return true; + } else { + return false; + } +} + +template class IteratorAdapter; +} // namespace xgboost::data diff --git a/src/data/adapter.h b/src/data/adapter.h index e7eaa372f..9e7058aba 100644 --- a/src/data/adapter.h +++ b/src/data/adapter.h @@ -1,5 +1,5 @@ -/*! - * Copyright (c) 2019~2021 by Contributors +/** + * Copyright 2019-2023, XGBoost Contributors * \file adapter.h */ #ifndef XGBOOST_DATA_ADAPTER_H_ @@ -16,7 +16,6 @@ #include // std::move #include -#include "../c_api/c_api_error.h" #include "../common/error_msg.h" // for MaxFeatureSize #include "../common/math.h" #include "array_interface.h" @@ -742,8 +741,10 @@ class FileAdapter : dmlc::DataIter { dmlc::Parser* parser_; }; -/*! \brief Data iterator that takes callback to return data, used in JVM package for - * accepting data iterator. */ +/** + * @brief Data iterator that takes callback to return data, used in JVM package for accepting data + * iterator. + */ template class IteratorAdapter : public dmlc::DataIter { public: @@ -757,23 +758,9 @@ class IteratorAdapter : public dmlc::DataIter { CHECK(at_first_) << "Cannot reset IteratorAdapter"; } - bool Next() override { - if ((*next_callback_)( - data_handle_, - [](void *handle, XGBoostBatchCSR batch) -> int { - API_BEGIN(); - static_cast(handle)->SetData(batch); - API_END(); - }, - this) != 0) { - at_first_ = false; - return true; - } else { - return false; - } - } + [[nodiscard]] bool Next() override; - FileAdapterBatch const& Value() const override { + [[nodiscard]] FileAdapterBatch const& Value() const override { return *batch_.get(); } @@ -821,12 +808,12 @@ class IteratorAdapter : public dmlc::DataIter { block_.index = dmlc::BeginPtr(index_); block_.value = dmlc::BeginPtr(value_); - batch_.reset(new FileAdapterBatch(&block_, row_offset_)); + batch_ = std::make_unique(&block_, row_offset_); row_offset_ += offset_.size() - 1; } - size_t NumColumns() const { return columns_; } - size_t NumRows() const { return kAdapterUnknownSize; } + [[nodiscard]] std::size_t NumColumns() const { return columns_; } + [[nodiscard]] std::size_t NumRows() const { return kAdapterUnknownSize; } private: std::vector offset_; @@ -848,56 +835,6 @@ class IteratorAdapter : public dmlc::DataIter { std::unique_ptr batch_; }; -enum ColumnDType : uint8_t { - kUnknown, - kInt8, - kUInt8, - kInt16, - kUInt16, - kInt32, - kUInt32, - kInt64, - kUInt64, - kFloat, - kDouble -}; - -class Column { - public: - Column() = default; - - Column(size_t col_idx, size_t length, size_t null_count, const uint8_t* bitmap) - : col_idx_{col_idx}, length_{length}, null_count_{null_count}, bitmap_{bitmap} {} - - virtual ~Column() = default; - - Column(const Column&) = delete; - Column& operator=(const Column&) = delete; - Column(Column&&) = delete; - Column& operator=(Column&&) = delete; - - // whether the valid bit is set for this element - bool IsValid(size_t row_idx) const { - return (!bitmap_ || (bitmap_[row_idx/8] & (1 << (row_idx%8)))); - } - - virtual COOTuple GetElement(size_t row_idx) const = 0; - - virtual bool IsValidElement(size_t row_idx) const = 0; - - virtual std::vector AsFloatVector() const = 0; - - virtual std::vector AsUint64Vector() const = 0; - - size_t Length() const { return length_; } - - protected: - size_t col_idx_; - size_t length_; - size_t null_count_; - const uint8_t* bitmap_; -}; - class SparsePageAdapterBatch { HostSparsePageView page_;