Avoid including c_api.h in header files. (#5782)
This commit is contained in:
parent
3028fa6b42
commit
306e38ff31
@ -38,14 +38,14 @@ XGB_DLL void XGBoostVersion(int* major, int* minor, int* patch) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int XGBRegisterLogCallback(void (*callback)(const char*)) {
|
XGB_DLL int XGBRegisterLogCallback(void (*callback)(const char*)) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
LogCallbackRegistry* registry = LogCallbackRegistryStore::Get();
|
LogCallbackRegistry* registry = LogCallbackRegistryStore::Get();
|
||||||
registry->Register(callback);
|
registry->Register(callback);
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
|
|
||||||
int XGDMatrixCreateFromFile(const char *fname,
|
XGB_DLL int XGDMatrixCreateFromFile(const char *fname,
|
||||||
int silent,
|
int silent,
|
||||||
DMatrixHandle *out) {
|
DMatrixHandle *out) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
@ -60,7 +60,7 @@ int XGDMatrixCreateFromFile(const char *fname,
|
|||||||
}
|
}
|
||||||
|
|
||||||
XGB_DLL int XGDMatrixCreateFromDataIter(
|
XGB_DLL int XGDMatrixCreateFromDataIter(
|
||||||
void *data_handle, // a Java interator
|
void *data_handle, // a Java iterator
|
||||||
XGBCallbackDataIterNext *callback, // C++ callback defined in xgboost4j.cpp
|
XGBCallbackDataIterNext *callback, // C++ callback defined in xgboost4j.cpp
|
||||||
const char *cache_info, DMatrixHandle *out) {
|
const char *cache_info, DMatrixHandle *out) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
@ -69,7 +69,8 @@ XGB_DLL int XGDMatrixCreateFromDataIter(
|
|||||||
if (cache_info != nullptr) {
|
if (cache_info != nullptr) {
|
||||||
scache = cache_info;
|
scache = cache_info;
|
||||||
}
|
}
|
||||||
xgboost::data::IteratorAdapter adapter(data_handle, callback);
|
xgboost::data::IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext,
|
||||||
|
XGBoostBatchCSR> adapter(data_handle, callback);
|
||||||
*out = new std::shared_ptr<DMatrix> {
|
*out = new std::shared_ptr<DMatrix> {
|
||||||
DMatrix::Create(
|
DMatrix::Create(
|
||||||
&adapter, std::numeric_limits<float>::quiet_NaN(),
|
&adapter, std::numeric_limits<float>::quiet_NaN(),
|
||||||
|
|||||||
@ -4,6 +4,7 @@
|
|||||||
* \brief C error handling
|
* \brief C error handling
|
||||||
*/
|
*/
|
||||||
#include <dmlc/thread_local.h>
|
#include <dmlc/thread_local.h>
|
||||||
|
#include "xgboost/c_api.h"
|
||||||
#include "./c_api_error.h"
|
#include "./c_api_error.h"
|
||||||
|
|
||||||
struct XGBAPIErrorEntry {
|
struct XGBAPIErrorEntry {
|
||||||
@ -12,7 +13,7 @@ struct XGBAPIErrorEntry {
|
|||||||
|
|
||||||
using XGBAPIErrorStore = dmlc::ThreadLocalStore<XGBAPIErrorEntry>;
|
using XGBAPIErrorStore = dmlc::ThreadLocalStore<XGBAPIErrorEntry>;
|
||||||
|
|
||||||
const char *XGBGetLastError() {
|
XGB_DLL const char *XGBGetLastError() {
|
||||||
return XGBAPIErrorStore::Get()->last_error.c_str();
|
return XGBAPIErrorStore::Get()->last_error.c_str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -8,7 +8,6 @@
|
|||||||
|
|
||||||
#include <dmlc/base.h>
|
#include <dmlc/base.h>
|
||||||
#include <dmlc/logging.h>
|
#include <dmlc/logging.h>
|
||||||
#include <xgboost/c_api.h>
|
|
||||||
|
|
||||||
/*! \brief macro to guard beginning and end section of all functions */
|
/*! \brief macro to guard beginning and end section of all functions */
|
||||||
#define API_BEGIN() try {
|
#define API_BEGIN() try {
|
||||||
|
|||||||
@ -18,8 +18,8 @@
|
|||||||
#include "xgboost/base.h"
|
#include "xgboost/base.h"
|
||||||
#include "xgboost/data.h"
|
#include "xgboost/data.h"
|
||||||
#include "xgboost/span.h"
|
#include "xgboost/span.h"
|
||||||
#include "xgboost/c_api.h"
|
|
||||||
|
|
||||||
|
#include "array_interface.h"
|
||||||
#include "../c_api/c_api_error.h"
|
#include "../c_api/c_api_error.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
@ -496,6 +496,7 @@ class FileAdapter : dmlc::DataIter<FileAdapterBatch> {
|
|||||||
|
|
||||||
/*! \brief Data iterator that takes callback to return data, used in JVM package for
|
/*! \brief Data iterator that takes callback to return data, used in JVM package for
|
||||||
* accepting data iterator. */
|
* accepting data iterator. */
|
||||||
|
template <typename DataIterHandle, typename XGBCallbackDataIterNext, typename XGBoostBatchCSR>
|
||||||
class IteratorAdapter : public dmlc::DataIter<FileAdapterBatch> {
|
class IteratorAdapter : public dmlc::DataIter<FileAdapterBatch> {
|
||||||
public:
|
public:
|
||||||
IteratorAdapter(DataIterHandle data_handle,
|
IteratorAdapter(DataIterHandle data_handle,
|
||||||
|
|||||||
@ -7,6 +7,7 @@
|
|||||||
|
|
||||||
#include "dmlc/io.h"
|
#include "dmlc/io.h"
|
||||||
#include "xgboost/data.h"
|
#include "xgboost/data.h"
|
||||||
|
#include "xgboost/c_api.h"
|
||||||
#include "xgboost/host_device_vector.h"
|
#include "xgboost/host_device_vector.h"
|
||||||
#include "xgboost/logging.h"
|
#include "xgboost/logging.h"
|
||||||
#include "xgboost/version_config.h"
|
#include "xgboost/version_config.h"
|
||||||
@ -533,7 +534,7 @@ DMatrix* DMatrix::Load(const std::string& uri,
|
|||||||
|
|
||||||
template <typename AdapterT>
|
template <typename AdapterT>
|
||||||
DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread,
|
DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread,
|
||||||
const std::string& cache_prefix, size_t page_size ) {
|
const std::string& cache_prefix, size_t page_size) {
|
||||||
if (cache_prefix.length() == 0) {
|
if (cache_prefix.length() == 0) {
|
||||||
// Data split mode is fixed to be row right now.
|
// Data split mode is fixed to be row right now.
|
||||||
return new data::SimpleDMatrix(adapter, missing, nthread);
|
return new data::SimpleDMatrix(adapter, missing, nthread);
|
||||||
@ -563,9 +564,11 @@ template DMatrix* DMatrix::Create<data::DataTableAdapter>(
|
|||||||
template DMatrix* DMatrix::Create<data::FileAdapter>(
|
template DMatrix* DMatrix::Create<data::FileAdapter>(
|
||||||
data::FileAdapter* adapter, float missing, int nthread,
|
data::FileAdapter* adapter, float missing, int nthread,
|
||||||
const std::string& cache_prefix, size_t page_size);
|
const std::string& cache_prefix, size_t page_size);
|
||||||
template DMatrix* DMatrix::Create<data::IteratorAdapter>(
|
template DMatrix *
|
||||||
data::IteratorAdapter* adapter, float missing, int nthread,
|
DMatrix::Create(data::IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext,
|
||||||
const std::string& cache_prefix, size_t page_size);
|
XGBoostBatchCSR> *adapter,
|
||||||
|
float missing, int nthread, const std::string &cache_prefix,
|
||||||
|
size_t page_size);
|
||||||
|
|
||||||
SparsePage SparsePage::GetTranspose(int num_columns) const {
|
SparsePage SparsePage::GetTranspose(int num_columns) const {
|
||||||
SparsePage transpose;
|
SparsePage transpose;
|
||||||
|
|||||||
@ -10,7 +10,7 @@
|
|||||||
#include "array_interface.h"
|
#include "array_interface.h"
|
||||||
#include "../common/device_helpers.cuh"
|
#include "../common/device_helpers.cuh"
|
||||||
#include "device_adapter.cuh"
|
#include "device_adapter.cuh"
|
||||||
#include "device_dmatrix.h"
|
#include "simple_dmatrix.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
|
|
||||||
|
|||||||
@ -4,8 +4,14 @@
|
|||||||
* \brief the input data structure for gradient boosting
|
* \brief the input data structure for gradient boosting
|
||||||
* \author Tianqi Chen
|
* \author Tianqi Chen
|
||||||
*/
|
*/
|
||||||
#include "./simple_dmatrix.h"
|
#include <vector>
|
||||||
#include <xgboost/data.h>
|
#include <limits>
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
|
#include "xgboost/data.h"
|
||||||
|
#include "xgboost/c_api.h"
|
||||||
|
|
||||||
|
#include "simple_dmatrix.h"
|
||||||
#include "./simple_batch_iterator.h"
|
#include "./simple_batch_iterator.h"
|
||||||
#include "../common/random.h"
|
#include "../common/random.h"
|
||||||
#include "adapter.h"
|
#include "adapter.h"
|
||||||
@ -195,7 +201,9 @@ template SimpleDMatrix::SimpleDMatrix(DataTableAdapter* adapter, float missing,
|
|||||||
int nthread);
|
int nthread);
|
||||||
template SimpleDMatrix::SimpleDMatrix(FileAdapter* adapter, float missing,
|
template SimpleDMatrix::SimpleDMatrix(FileAdapter* adapter, float missing,
|
||||||
int nthread);
|
int nthread);
|
||||||
template SimpleDMatrix::SimpleDMatrix(IteratorAdapter* adapter, float missing,
|
template SimpleDMatrix::SimpleDMatrix(
|
||||||
int nthread);
|
IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext, XGBoostBatchCSR>
|
||||||
|
*adapter,
|
||||||
|
float missing, int nthread);
|
||||||
} // namespace data
|
} // namespace data
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -68,7 +68,7 @@ TEST(Adapter, CSCAdapterColsMoreThanRows) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// A mock for JVM data iterator.
|
// A mock for JVM data iterator.
|
||||||
class DataIterForTest {
|
class CSRIterForTest {
|
||||||
std::vector<float> data_ {1, 2, 3, 4, 5};
|
std::vector<float> data_ {1, 2, 3, 4, 5};
|
||||||
std::vector<std::remove_pointer<decltype(std::declval<XGBoostBatchCSR>().index)>::type>
|
std::vector<std::remove_pointer<decltype(std::declval<XGBoostBatchCSR>().index)>::type>
|
||||||
feature_idx_ {0, 1, 0, 1, 1};
|
feature_idx_ {0, 1, 0, 1, 1};
|
||||||
@ -100,16 +100,16 @@ class DataIterForTest {
|
|||||||
size_t Iter() const { return iter_; }
|
size_t Iter() const { return iter_; }
|
||||||
};
|
};
|
||||||
|
|
||||||
size_t constexpr DataIterForTest::kCols;
|
size_t constexpr CSRIterForTest::kCols;
|
||||||
|
|
||||||
int SetDataNextForTest(DataIterHandle data_handle,
|
int CSRSetDataNextForTest(DataIterHandle data_handle,
|
||||||
XGBCallbackSetData *set_function,
|
XGBCallbackSetData *set_function,
|
||||||
DataHolderHandle set_function_handle) {
|
DataHolderHandle set_function_handle) {
|
||||||
size_t constexpr kIters { 2 };
|
size_t constexpr kIters { 2 };
|
||||||
auto iter = static_cast<DataIterForTest *>(data_handle);
|
auto iter = static_cast<CSRIterForTest *>(data_handle);
|
||||||
if (iter->Iter() < kIters) {
|
if (iter->Iter() < kIters) {
|
||||||
auto batch = iter->Next();
|
auto batch = iter->Next();
|
||||||
batch.columns = DataIterForTest::kCols;
|
batch.columns = CSRIterForTest::kCols;
|
||||||
set_function(set_function_handle, batch);
|
set_function(set_function_handle, batch);
|
||||||
return 1;
|
return 1;
|
||||||
} else {
|
} else {
|
||||||
@ -118,15 +118,15 @@ int SetDataNextForTest(DataIterHandle data_handle,
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(Adapter, IteratorAdaper) {
|
TEST(Adapter, IteratorAdaper) {
|
||||||
DataIterForTest iter;
|
CSRIterForTest iter;
|
||||||
data::IteratorAdapter adapter{&iter, SetDataNextForTest};
|
data::IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext,
|
||||||
|
XGBoostBatchCSR> adapter{&iter, CSRSetDataNextForTest};
|
||||||
constexpr size_t kRows { 6 };
|
constexpr size_t kRows { 6 };
|
||||||
|
|
||||||
std::unique_ptr<DMatrix> data {
|
std::unique_ptr<DMatrix> data {
|
||||||
DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 1)
|
DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 1)
|
||||||
};
|
};
|
||||||
ASSERT_EQ(data->Info().num_col_, DataIterForTest::kCols);
|
ASSERT_EQ(data->Info().num_col_, CSRIterForTest::kCols);
|
||||||
ASSERT_EQ(data->Info().num_row_, kRows);
|
ASSERT_EQ(data->Info().num_row_, kRows);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user