Avoid including c_api.h in header files. (#5782)

This commit is contained in:
Jiaming Yuan 2020-06-12 16:24:24 +08:00 committed by GitHub
parent 3028fa6b42
commit 306e38ff31
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 42 additions and 29 deletions

View File

@ -38,16 +38,16 @@ XGB_DLL void XGBoostVersion(int* major, int* minor, int* patch) {
} }
} }
int XGBRegisterLogCallback(void (*callback)(const char*)) { XGB_DLL int XGBRegisterLogCallback(void (*callback)(const char*)) {
API_BEGIN(); API_BEGIN();
LogCallbackRegistry* registry = LogCallbackRegistryStore::Get(); LogCallbackRegistry* registry = LogCallbackRegistryStore::Get();
registry->Register(callback); registry->Register(callback);
API_END(); API_END();
} }
int XGDMatrixCreateFromFile(const char *fname, XGB_DLL int XGDMatrixCreateFromFile(const char *fname,
int silent, int silent,
DMatrixHandle *out) { DMatrixHandle *out) {
API_BEGIN(); API_BEGIN();
bool load_row_split = false; bool load_row_split = false;
if (rabit::IsDistributed()) { if (rabit::IsDistributed()) {
@ -60,7 +60,7 @@ int XGDMatrixCreateFromFile(const char *fname,
} }
XGB_DLL int XGDMatrixCreateFromDataIter( XGB_DLL int XGDMatrixCreateFromDataIter(
void *data_handle, // a Java interator void *data_handle, // a Java iterator
XGBCallbackDataIterNext *callback, // C++ callback defined in xgboost4j.cpp XGBCallbackDataIterNext *callback, // C++ callback defined in xgboost4j.cpp
const char *cache_info, DMatrixHandle *out) { const char *cache_info, DMatrixHandle *out) {
API_BEGIN(); API_BEGIN();
@ -69,7 +69,8 @@ XGB_DLL int XGDMatrixCreateFromDataIter(
if (cache_info != nullptr) { if (cache_info != nullptr) {
scache = cache_info; scache = cache_info;
} }
xgboost::data::IteratorAdapter adapter(data_handle, callback); xgboost::data::IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext,
XGBoostBatchCSR> adapter(data_handle, callback);
*out = new std::shared_ptr<DMatrix> { *out = new std::shared_ptr<DMatrix> {
DMatrix::Create( DMatrix::Create(
&adapter, std::numeric_limits<float>::quiet_NaN(), &adapter, std::numeric_limits<float>::quiet_NaN(),

View File

@ -4,6 +4,7 @@
* \brief C error handling * \brief C error handling
*/ */
#include <dmlc/thread_local.h> #include <dmlc/thread_local.h>
#include "xgboost/c_api.h"
#include "./c_api_error.h" #include "./c_api_error.h"
struct XGBAPIErrorEntry { struct XGBAPIErrorEntry {
@ -12,7 +13,7 @@ struct XGBAPIErrorEntry {
using XGBAPIErrorStore = dmlc::ThreadLocalStore<XGBAPIErrorEntry>; using XGBAPIErrorStore = dmlc::ThreadLocalStore<XGBAPIErrorEntry>;
const char *XGBGetLastError() { XGB_DLL const char *XGBGetLastError() {
return XGBAPIErrorStore::Get()->last_error.c_str(); return XGBAPIErrorStore::Get()->last_error.c_str();
} }

View File

@ -8,7 +8,6 @@
#include <dmlc/base.h> #include <dmlc/base.h>
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <xgboost/c_api.h>
/*! \brief macro to guard beginning and end section of all functions */ /*! \brief macro to guard beginning and end section of all functions */
#define API_BEGIN() try { #define API_BEGIN() try {

View File

@ -18,8 +18,8 @@
#include "xgboost/base.h" #include "xgboost/base.h"
#include "xgboost/data.h" #include "xgboost/data.h"
#include "xgboost/span.h" #include "xgboost/span.h"
#include "xgboost/c_api.h"
#include "array_interface.h"
#include "../c_api/c_api_error.h" #include "../c_api/c_api_error.h"
namespace xgboost { namespace xgboost {
@ -496,6 +496,7 @@ class FileAdapter : dmlc::DataIter<FileAdapterBatch> {
/*! \brief Data iterator that takes callback to return data, used in JVM package for /*! \brief Data iterator that takes callback to return data, used in JVM package for
* accepting data iterator. */ * accepting data iterator. */
template <typename DataIterHandle, typename XGBCallbackDataIterNext, typename XGBoostBatchCSR>
class IteratorAdapter : public dmlc::DataIter<FileAdapterBatch> { class IteratorAdapter : public dmlc::DataIter<FileAdapterBatch> {
public: public:
IteratorAdapter(DataIterHandle data_handle, IteratorAdapter(DataIterHandle data_handle,

View File

@ -7,6 +7,7 @@
#include "dmlc/io.h" #include "dmlc/io.h"
#include "xgboost/data.h" #include "xgboost/data.h"
#include "xgboost/c_api.h"
#include "xgboost/host_device_vector.h" #include "xgboost/host_device_vector.h"
#include "xgboost/logging.h" #include "xgboost/logging.h"
#include "xgboost/version_config.h" #include "xgboost/version_config.h"
@ -533,7 +534,7 @@ DMatrix* DMatrix::Load(const std::string& uri,
template <typename AdapterT> template <typename AdapterT>
DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread, DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread,
const std::string& cache_prefix, size_t page_size ) { const std::string& cache_prefix, size_t page_size) {
if (cache_prefix.length() == 0) { if (cache_prefix.length() == 0) {
// Data split mode is fixed to be row right now. // Data split mode is fixed to be row right now.
return new data::SimpleDMatrix(adapter, missing, nthread); return new data::SimpleDMatrix(adapter, missing, nthread);
@ -563,9 +564,11 @@ template DMatrix* DMatrix::Create<data::DataTableAdapter>(
template DMatrix* DMatrix::Create<data::FileAdapter>( template DMatrix* DMatrix::Create<data::FileAdapter>(
data::FileAdapter* adapter, float missing, int nthread, data::FileAdapter* adapter, float missing, int nthread,
const std::string& cache_prefix, size_t page_size); const std::string& cache_prefix, size_t page_size);
template DMatrix* DMatrix::Create<data::IteratorAdapter>( template DMatrix *
data::IteratorAdapter* adapter, float missing, int nthread, DMatrix::Create(data::IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext,
const std::string& cache_prefix, size_t page_size); XGBoostBatchCSR> *adapter,
float missing, int nthread, const std::string &cache_prefix,
size_t page_size);
SparsePage SparsePage::GetTranspose(int num_columns) const { SparsePage SparsePage::GetTranspose(int num_columns) const {
SparsePage transpose; SparsePage transpose;

View File

@ -10,7 +10,7 @@
#include "array_interface.h" #include "array_interface.h"
#include "../common/device_helpers.cuh" #include "../common/device_helpers.cuh"
#include "device_adapter.cuh" #include "device_adapter.cuh"
#include "device_dmatrix.h" #include "simple_dmatrix.h"
namespace xgboost { namespace xgboost {

View File

@ -4,8 +4,14 @@
* \brief the input data structure for gradient boosting * \brief the input data structure for gradient boosting
* \author Tianqi Chen * \author Tianqi Chen
*/ */
#include "./simple_dmatrix.h" #include <vector>
#include <xgboost/data.h> #include <limits>
#include <algorithm>
#include "xgboost/data.h"
#include "xgboost/c_api.h"
#include "simple_dmatrix.h"
#include "./simple_batch_iterator.h" #include "./simple_batch_iterator.h"
#include "../common/random.h" #include "../common/random.h"
#include "adapter.h" #include "adapter.h"
@ -195,7 +201,9 @@ template SimpleDMatrix::SimpleDMatrix(DataTableAdapter* adapter, float missing,
int nthread); int nthread);
template SimpleDMatrix::SimpleDMatrix(FileAdapter* adapter, float missing, template SimpleDMatrix::SimpleDMatrix(FileAdapter* adapter, float missing,
int nthread); int nthread);
template SimpleDMatrix::SimpleDMatrix(IteratorAdapter* adapter, float missing, template SimpleDMatrix::SimpleDMatrix(
int nthread); IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext, XGBoostBatchCSR>
*adapter,
float missing, int nthread);
} // namespace data } // namespace data
} // namespace xgboost } // namespace xgboost

View File

@ -68,7 +68,7 @@ TEST(Adapter, CSCAdapterColsMoreThanRows) {
} }
// A mock for JVM data iterator. // A mock for JVM data iterator.
class DataIterForTest { class CSRIterForTest {
std::vector<float> data_ {1, 2, 3, 4, 5}; std::vector<float> data_ {1, 2, 3, 4, 5};
std::vector<std::remove_pointer<decltype(std::declval<XGBoostBatchCSR>().index)>::type> std::vector<std::remove_pointer<decltype(std::declval<XGBoostBatchCSR>().index)>::type>
feature_idx_ {0, 1, 0, 1, 1}; feature_idx_ {0, 1, 0, 1, 1};
@ -100,16 +100,16 @@ class DataIterForTest {
size_t Iter() const { return iter_; } size_t Iter() const { return iter_; }
}; };
size_t constexpr DataIterForTest::kCols; size_t constexpr CSRIterForTest::kCols;
int SetDataNextForTest(DataIterHandle data_handle, int CSRSetDataNextForTest(DataIterHandle data_handle,
XGBCallbackSetData *set_function, XGBCallbackSetData *set_function,
DataHolderHandle set_function_handle) { DataHolderHandle set_function_handle) {
size_t constexpr kIters { 2 }; size_t constexpr kIters { 2 };
auto iter = static_cast<DataIterForTest *>(data_handle); auto iter = static_cast<CSRIterForTest *>(data_handle);
if (iter->Iter() < kIters) { if (iter->Iter() < kIters) {
auto batch = iter->Next(); auto batch = iter->Next();
batch.columns = DataIterForTest::kCols; batch.columns = CSRIterForTest::kCols;
set_function(set_function_handle, batch); set_function(set_function_handle, batch);
return 1; return 1;
} else { } else {
@ -118,15 +118,15 @@ int SetDataNextForTest(DataIterHandle data_handle,
} }
TEST(Adapter, IteratorAdaper) { TEST(Adapter, IteratorAdaper) {
DataIterForTest iter; CSRIterForTest iter;
data::IteratorAdapter adapter{&iter, SetDataNextForTest}; data::IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext,
XGBoostBatchCSR> adapter{&iter, CSRSetDataNextForTest};
constexpr size_t kRows { 6 }; constexpr size_t kRows { 6 };
std::unique_ptr<DMatrix> data { std::unique_ptr<DMatrix> data {
DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 1) DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 1)
}; };
ASSERT_EQ(data->Info().num_col_, DataIterForTest::kCols); ASSERT_EQ(data->Info().num_col_, CSRIterForTest::kCols);
ASSERT_EQ(data->Info().num_row_, kRows); ASSERT_EQ(data->Info().num_row_, kRows);
} }
} // namespace xgboost } // namespace xgboost