Small cleanup to jvm iter adapter. (#9616)

- Remove header dependency on c_api
- Remove remaining code for arrow.
This commit is contained in:
Jiaming Yuan 2023-09-29 00:39:07 +08:00 committed by GitHub
parent 417c3ba47e
commit d95be1c38d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 43 additions and 76 deletions

View File

@ -62,6 +62,7 @@ OBJECTS= \
$(PKGROOT)/src/gbm/gbtree_model.o \ $(PKGROOT)/src/gbm/gbtree_model.o \
$(PKGROOT)/src/gbm/gblinear.o \ $(PKGROOT)/src/gbm/gblinear.o \
$(PKGROOT)/src/gbm/gblinear_model.o \ $(PKGROOT)/src/gbm/gblinear_model.o \
$(PKGROOT)/src/data/adapter.o \
$(PKGROOT)/src/data/simple_dmatrix.o \ $(PKGROOT)/src/data/simple_dmatrix.o \
$(PKGROOT)/src/data/data.o \ $(PKGROOT)/src/data/data.o \
$(PKGROOT)/src/data/sparse_page_raw_format.o \ $(PKGROOT)/src/data/sparse_page_raw_format.o \

View File

@ -62,6 +62,7 @@ OBJECTS= \
$(PKGROOT)/src/gbm/gbtree_model.o \ $(PKGROOT)/src/gbm/gbtree_model.o \
$(PKGROOT)/src/gbm/gblinear.o \ $(PKGROOT)/src/gbm/gblinear.o \
$(PKGROOT)/src/gbm/gblinear_model.o \ $(PKGROOT)/src/gbm/gblinear_model.o \
$(PKGROOT)/src/data/adapter.o \
$(PKGROOT)/src/data/simple_dmatrix.o \ $(PKGROOT)/src/data/simple_dmatrix.o \
$(PKGROOT)/src/data/data.o \ $(PKGROOT)/src/data/data.o \
$(PKGROOT)/src/data/sparse_page_raw_format.o \ $(PKGROOT)/src/data/sparse_page_raw_format.o \

View File

@ -269,8 +269,8 @@ XGB_DLL int XGDMatrixCreateFromDataIter(
if (cache_info != nullptr) { if (cache_info != nullptr) {
scache = cache_info; scache = cache_info;
} }
xgboost::data::IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext, xgboost::data::IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext, XGBoostBatchCSR> adapter(
XGBoostBatchCSR> adapter(data_handle, callback); data_handle, callback);
xgboost_CHECK_C_ARG_PTR(out); xgboost_CHECK_C_ARG_PTR(out);
*out = new std::shared_ptr<DMatrix> { *out = new std::shared_ptr<DMatrix> {
DMatrix::Create( DMatrix::Create(

28
src/data/adapter.cc Normal file
View File

@ -0,0 +1,28 @@
/**
* Copyright 2019-2023, XGBoost Contributors
*/
#include "adapter.h"
#include "../c_api/c_api_error.h" // for API_BEGIN, API_END
#include "xgboost/c_api.h"
namespace xgboost::data {
/**
 * @brief Request the next batch from the user-provided iterator callback.
 *
 * The callback is invoked with the opaque data handle, a setter function and
 * `this`; the setter forwards the batch it receives to `SetData` on the adapter.
 *
 * @return true when the callback returns a non-zero value (the adapter is then
 *         marked as no longer being at its first position), false when it
 *         returns zero.
 */
template <typename DataIterHandle, typename XGBCallbackDataIterNext, typename XGBoostBatchCSR>
bool IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext, XGBoostBatchCSR>::Next() {
  // Capture-less trampoline handed to the callback: it recovers the adapter from
  // the opaque pointer and stores the batch. The lambda has no explicit return
  // statement, so API_BEGIN/API_END (from c_api_error.h) must supply the int
  // result -- presumably by mapping exceptions to a status code; confirm there.
  auto set_data = [](void *self, XGBoostBatchCSR batch) -> int {
    API_BEGIN();
    static_cast<IteratorAdapter *>(self)->SetData(batch);
    API_END();
  };

  auto const status = (*next_callback_)(data_handle_, set_data, this);
  if (status == 0) {
    return false;
  }

  at_first_ = false;
  return true;
}

// Next() is defined in this translation unit rather than in the header, so the
// specialisation that is actually used has to be instantiated explicitly here.
// NOTE(review): the argument names below are expected to be the types declared
// in xgboost/c_api.h (included above) -- confirm.
template class IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext, XGBoostBatchCSR>;
}  // namespace xgboost::data

View File

@ -1,5 +1,5 @@
/*! /**
* Copyright (c) 2019~2021 by Contributors * Copyright 2019-2023, XGBoost Contributors
* \file adapter.h * \file adapter.h
*/ */
#ifndef XGBOOST_DATA_ADAPTER_H_ #ifndef XGBOOST_DATA_ADAPTER_H_
@ -16,7 +16,6 @@
#include <utility> // std::move #include <utility> // std::move
#include <vector> #include <vector>
#include "../c_api/c_api_error.h"
#include "../common/error_msg.h" // for MaxFeatureSize #include "../common/error_msg.h" // for MaxFeatureSize
#include "../common/math.h" #include "../common/math.h"
#include "array_interface.h" #include "array_interface.h"
@ -742,8 +741,10 @@ class FileAdapter : dmlc::DataIter<FileAdapterBatch> {
dmlc::Parser<uint32_t>* parser_; dmlc::Parser<uint32_t>* parser_;
}; };
/*! \brief Data iterator that takes callback to return data, used in JVM package for /**
* accepting data iterator. */ * @brief Data iterator that takes callback to return data, used in JVM package for accepting data
* iterator.
*/
template <typename DataIterHandle, typename XGBCallbackDataIterNext, typename XGBoostBatchCSR> template <typename DataIterHandle, typename XGBCallbackDataIterNext, typename XGBoostBatchCSR>
class IteratorAdapter : public dmlc::DataIter<FileAdapterBatch> { class IteratorAdapter : public dmlc::DataIter<FileAdapterBatch> {
public: public:
@ -757,23 +758,9 @@ class IteratorAdapter : public dmlc::DataIter<FileAdapterBatch> {
CHECK(at_first_) << "Cannot reset IteratorAdapter"; CHECK(at_first_) << "Cannot reset IteratorAdapter";
} }
bool Next() override { [[nodiscard]] bool Next() override;
if ((*next_callback_)(
data_handle_,
[](void *handle, XGBoostBatchCSR batch) -> int {
API_BEGIN();
static_cast<IteratorAdapter *>(handle)->SetData(batch);
API_END();
},
this) != 0) {
at_first_ = false;
return true;
} else {
return false;
}
}
FileAdapterBatch const& Value() const override { [[nodiscard]] FileAdapterBatch const& Value() const override {
return *batch_.get(); return *batch_.get();
} }
@ -821,12 +808,12 @@ class IteratorAdapter : public dmlc::DataIter<FileAdapterBatch> {
block_.index = dmlc::BeginPtr(index_); block_.index = dmlc::BeginPtr(index_);
block_.value = dmlc::BeginPtr(value_); block_.value = dmlc::BeginPtr(value_);
batch_.reset(new FileAdapterBatch(&block_, row_offset_)); batch_ = std::make_unique<FileAdapterBatch>(&block_, row_offset_);
row_offset_ += offset_.size() - 1; row_offset_ += offset_.size() - 1;
} }
size_t NumColumns() const { return columns_; } [[nodiscard]] std::size_t NumColumns() const { return columns_; }
size_t NumRows() const { return kAdapterUnknownSize; } [[nodiscard]] std::size_t NumRows() const { return kAdapterUnknownSize; }
private: private:
std::vector<size_t> offset_; std::vector<size_t> offset_;
@ -848,56 +835,6 @@ class IteratorAdapter : public dmlc::DataIter<FileAdapterBatch> {
std::unique_ptr<FileAdapterBatch> batch_; std::unique_ptr<FileAdapterBatch> batch_;
}; };
/**
 * @brief Tag naming the element type of a column, stored in a single byte.
 *
 * Enumerator values are consecutive, starting at 0 for kUnknown.
 */
enum ColumnDType : uint8_t {
  kUnknown = 0,
  kInt8 = 1,
  kUInt8 = 2,
  kInt16 = 3,
  kUInt16 = 4,
  kInt32 = 5,
  kUInt32 = 6,
  kInt64 = 7,
  kUInt64 = 8,
  kFloat = 9,
  kDouble = 10
};
/**
 * @brief Abstract base for a single column with an optional validity bitmap.
 *
 * Stores the column index, the number of elements, a null count and a pointer
 * to a validity bitmap. The bitmap is not owned by this class (nothing here
 * allocates or frees it). Element access and conversion are left to derived
 * classes. Objects are neither copyable nor movable.
 */
class Column {
 public:
  /** @brief Construct an empty column: zero length and no validity bitmap. */
  Column() = default;

  /**
   * @param col_idx    Index of this column.
   * @param length     Number of elements in the column.
   * @param null_count Number of null elements (stored only; not used in this class).
   * @param bitmap     Validity bitmap, or nullptr when there is none.
   */
  Column(size_t col_idx, size_t length, size_t null_count, const uint8_t* bitmap)
      : col_idx_{col_idx}, length_{length}, null_count_{null_count}, bitmap_{bitmap} {}

  virtual ~Column() = default;

  Column(const Column&) = delete;
  Column& operator=(const Column&) = delete;
  Column(Column&&) = delete;
  Column& operator=(Column&&) = delete;

  /**
   * @brief Whether the valid bit is set for this element.
   *
   * Without a bitmap every row is reported as valid. Otherwise the bit for
   * `row_idx` is bit `row_idx % 8` (counting from the least significant bit) of
   * byte `row_idx / 8`. No bounds check is performed on `row_idx`.
   */
  bool IsValid(size_t row_idx) const {
    return (!bitmap_ || (bitmap_[row_idx / 8] & (1 << (row_idx % 8))));
  }

  // Element access and conversions; semantics are defined by the derived classes.
  virtual COOTuple GetElement(size_t row_idx) const = 0;
  virtual bool IsValidElement(size_t row_idx) const = 0;
  virtual std::vector<float> AsFloatVector() const = 0;
  virtual std::vector<uint64_t> AsUint64Vector() const = 0;

  /** @brief Number of elements in the column. */
  size_t Length() const { return length_; }

 protected:
  // Default member initializers keep a default-constructed object in a
  // well-defined state; previously these were left indeterminate by
  // `Column() = default`, making IsValid()/Length() undefined behaviour.
  size_t col_idx_{0};
  size_t length_{0};
  size_t null_count_{0};
  const uint8_t* bitmap_{nullptr};
};
class SparsePageAdapterBatch { class SparsePageAdapterBatch {
HostSparsePageView page_; HostSparsePageView page_;