Accept numpy array view. (#4147)
* Accept array view (slice) in metainfo.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2015 by Contributors
|
||||
* Copyright 2015-2019 by Contributors
|
||||
* \file data.cc
|
||||
*/
|
||||
#include <xgboost/data.h>
|
||||
@@ -100,45 +100,70 @@ inline bool MetaTryLoadFloatInfo(const std::string& fname,
|
||||
#define DISPATCH_CONST_PTR(dtype, old_ptr, cast_ptr, proc) \
|
||||
switch (dtype) { \
|
||||
case kFloat32: { \
|
||||
auto cast_ptr = reinterpret_cast<const float*>(old_ptr); proc; break; \
|
||||
auto cast_ptr = reinterpret_cast<const float*>(old_ptr); proc; \
|
||||
break; \
|
||||
} \
|
||||
case kDouble: { \
|
||||
auto cast_ptr = reinterpret_cast<const double*>(old_ptr); proc; break; \
|
||||
auto cast_ptr = reinterpret_cast<const double*>(old_ptr); proc; \
|
||||
break; \
|
||||
} \
|
||||
case kUInt32: { \
|
||||
auto cast_ptr = reinterpret_cast<const uint32_t*>(old_ptr); proc; break; \
|
||||
auto cast_ptr = reinterpret_cast<const uint32_t*>(old_ptr); proc; \
|
||||
break; \
|
||||
} \
|
||||
case kUInt64: { \
|
||||
auto cast_ptr = reinterpret_cast<const uint64_t*>(old_ptr); proc; break; \
|
||||
auto cast_ptr = reinterpret_cast<const uint64_t*>(old_ptr); proc; \
|
||||
break; \
|
||||
} \
|
||||
default: LOG(FATAL) << "Unknown data type" << dtype; \
|
||||
} \
|
||||
|
||||
|
||||
void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t num) {
|
||||
this->SetInfo(key, dptr, dtype, 1, num);
|
||||
}
|
||||
|
||||
template <typename IterIn, typename IterOut>
|
||||
void StridedCopy(IterIn in_beg, IterIn in_end, IterOut out_beg, size_t stride) {
|
||||
if (stride != 1) {
|
||||
IterOut out_iter = out_beg;
|
||||
for (IterIn in_iter = in_beg; in_iter < in_end; in_iter += stride) {
|
||||
*out_iter = *in_iter;
|
||||
out_iter++;
|
||||
}
|
||||
} else {
|
||||
// There can be builtin optimization in std::copy
|
||||
std::copy(in_beg, in_end, out_beg);
|
||||
}
|
||||
}
|
||||
|
||||
void MetaInfo::SetInfo(
|
||||
const char* key, const void* dptr, DataType dtype, size_t stride, size_t num) {
|
||||
size_t view_length =
|
||||
static_cast<size_t>(std::ceil(static_cast<bst_float>(num) / stride));
|
||||
if (!std::strcmp(key, "root_index")) {
|
||||
root_index_.resize(num);
|
||||
root_index_.resize(view_length);
|
||||
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
|
||||
std::copy(cast_dptr, cast_dptr + num, root_index_.begin()));
|
||||
StridedCopy(cast_dptr, cast_dptr + num, root_index_.begin(), stride));
|
||||
} else if (!std::strcmp(key, "label")) {
|
||||
auto& labels = labels_.HostVector();
|
||||
labels.resize(num);
|
||||
labels.resize(view_length);
|
||||
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
|
||||
std::copy(cast_dptr, cast_dptr + num, labels.begin()));
|
||||
StridedCopy(cast_dptr, cast_dptr + num, labels.begin(), stride));
|
||||
} else if (!std::strcmp(key, "weight")) {
|
||||
auto& weights = weights_.HostVector();
|
||||
weights.resize(num);
|
||||
weights.resize(view_length);
|
||||
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
|
||||
std::copy(cast_dptr, cast_dptr + num, weights.begin()));
|
||||
StridedCopy(cast_dptr, cast_dptr + num, weights.begin(), stride));
|
||||
} else if (!std::strcmp(key, "base_margin")) {
|
||||
auto& base_margin = base_margin_.HostVector();
|
||||
base_margin.resize(num);
|
||||
base_margin.resize(view_length);
|
||||
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
|
||||
std::copy(cast_dptr, cast_dptr + num, base_margin.begin()));
|
||||
StridedCopy(cast_dptr, cast_dptr + num, base_margin.begin(), stride));
|
||||
} else if (!std::strcmp(key, "group")) {
|
||||
group_ptr_.resize(num + 1);
|
||||
group_ptr_.resize(view_length+1);
|
||||
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
|
||||
std::copy(cast_dptr, cast_dptr + num, group_ptr_.begin() + 1));
|
||||
StridedCopy(cast_dptr, cast_dptr + num, group_ptr_.begin() + 1, stride));
|
||||
group_ptr_[0] = 0;
|
||||
for (size_t i = 1; i < group_ptr_.size(); ++i) {
|
||||
group_ptr_[i] = group_ptr_[i - 1] + group_ptr_[i];
|
||||
@@ -146,7 +171,6 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
DMatrix* DMatrix::Load(const std::string& uri,
|
||||
bool silent,
|
||||
bool load_row_split,
|
||||
|
||||
Reference in New Issue
Block a user