Accept numpy array view. (#4147)
* Accept array view (slice) in metainfo.
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
// Copyright (c) 2014 by Contributors
|
||||
// Copyright (c) 2014-2019 by Contributors
|
||||
|
||||
#include <xgboost/data.h>
|
||||
#include <xgboost/learner.h>
|
||||
@@ -768,9 +768,9 @@ XGB_DLL int XGDMatrixSaveBinary(DMatrixHandle handle,
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle,
|
||||
const char* field,
|
||||
const bst_float* info,
|
||||
xgboost::bst_ulong len) {
|
||||
const char* field,
|
||||
const xgboost::bst_float* info,
|
||||
xgboost::bst_ulong len) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
static_cast<std::shared_ptr<DMatrix>*>(handle)
|
||||
@@ -778,14 +778,38 @@ XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle,
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle,
|
||||
const char* field,
|
||||
const unsigned* info,
|
||||
xgboost::bst_ulong len) {
|
||||
XGB_DLL int XGDMatrixSetFloatInfoStrided(DMatrixHandle handle,
|
||||
const char* field,
|
||||
const xgboost::bst_float* info,
|
||||
const xgboost::bst_ulong stride,
|
||||
xgboost::bst_ulong len) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
static_cast<std::shared_ptr<DMatrix>*>(handle)
|
||||
->get()->Info().SetInfo(field, info, kUInt32, len);
|
||||
->get()->Info().SetInfo(field, info, kFloat32, stride, len);
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle,
|
||||
const char* field,
|
||||
const unsigned* array,
|
||||
xgboost::bst_ulong len) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
static_cast<std::shared_ptr<DMatrix>*>(handle)
|
||||
->get()->Info().SetInfo(field, array, kUInt32, len);
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixSetUIntInfoStrided(DMatrixHandle handle,
|
||||
const char* field,
|
||||
const unsigned* array,
|
||||
const xgboost::bst_ulong stride,
|
||||
xgboost::bst_ulong len) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
static_cast<std::shared_ptr<DMatrix>*>(handle)
|
||||
->get()->Info().SetInfo(field, array, kUInt32, stride, len);
|
||||
API_END();
|
||||
}
|
||||
|
||||
@@ -864,8 +888,8 @@ XGB_DLL int XGDMatrixNumCol(const DMatrixHandle handle,
|
||||
|
||||
// xgboost implementation
|
||||
XGB_DLL int XGBoosterCreate(const DMatrixHandle dmats[],
|
||||
xgboost::bst_ulong len,
|
||||
BoosterHandle *out) {
|
||||
xgboost::bst_ulong len,
|
||||
BoosterHandle *out) {
|
||||
API_BEGIN();
|
||||
std::vector<std::shared_ptr<DMatrix> > mats;
|
||||
for (xgboost::bst_ulong i = 0; i < len; ++i) {
|
||||
|
||||
Reference in New Issue
Block a user