Accept numpy array view. (#4147)

* Accept array view (slice) in metainfo.
This commit is contained in:
Jiaming Yuan
2019-02-18 22:21:34 +08:00
committed by GitHub
parent 0ff84d950e
commit a985a99cf0
6 changed files with 152 additions and 43 deletions

View File

@@ -1,4 +1,4 @@
// Copyright (c) 2014 by Contributors
// Copyright (c) 2014-2019 by Contributors
#include <xgboost/data.h>
#include <xgboost/learner.h>
@@ -768,9 +768,9 @@ XGB_DLL int XGDMatrixSaveBinary(DMatrixHandle handle,
}
XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle,
const char* field,
const bst_float* info,
xgboost::bst_ulong len) {
const char* field,
const xgboost::bst_float* info,
xgboost::bst_ulong len) {
API_BEGIN();
CHECK_HANDLE();
static_cast<std::shared_ptr<DMatrix>*>(handle)
@@ -778,14 +778,38 @@ XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle,
API_END();
}
XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle,
const char* field,
const unsigned* info,
xgboost::bst_ulong len) {
XGB_DLL int XGDMatrixSetFloatInfoStrided(DMatrixHandle handle,
const char* field,
const xgboost::bst_float* info,
const xgboost::bst_ulong stride,
xgboost::bst_ulong len) {
API_BEGIN();
CHECK_HANDLE();
static_cast<std::shared_ptr<DMatrix>*>(handle)
->get()->Info().SetInfo(field, info, kUInt32, len);
->get()->Info().SetInfo(field, info, kFloat32, stride, len);
API_END();
}
XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle,
const char* field,
const unsigned* array,
xgboost::bst_ulong len) {
API_BEGIN();
CHECK_HANDLE();
static_cast<std::shared_ptr<DMatrix>*>(handle)
->get()->Info().SetInfo(field, array, kUInt32, len);
API_END();
}
XGB_DLL int XGDMatrixSetUIntInfoStrided(DMatrixHandle handle,
const char* field,
const unsigned* array,
const xgboost::bst_ulong stride,
xgboost::bst_ulong len) {
API_BEGIN();
CHECK_HANDLE();
static_cast<std::shared_ptr<DMatrix>*>(handle)
->get()->Info().SetInfo(field, array, kUInt32, stride, len);
API_END();
}
@@ -864,8 +888,8 @@ XGB_DLL int XGDMatrixNumCol(const DMatrixHandle handle,
// xgboost implementation
XGB_DLL int XGBoosterCreate(const DMatrixHandle dmats[],
xgboost::bst_ulong len,
BoosterHandle *out) {
xgboost::bst_ulong len,
BoosterHandle *out) {
API_BEGIN();
std::vector<std::shared_ptr<DMatrix> > mats;
for (xgboost::bst_ulong i = 0; i < len; ++i) {