Span class. (#3548)

* Add basic Span class based on ISO++20.

* Use Span<Entry const> instead of Inst in SparsePage.

* Add DeviceSpan in HostDeviceVector, use it in regression obj.
This commit is contained in:
trivialfis
2018-08-14 13:58:11 +08:00
committed by Rory Mitchell
parent 2b7a1c5780
commit 2c502784ff
28 changed files with 1927 additions and 138 deletions

View File

@@ -15,6 +15,7 @@
#include <string>
#include <vector>
#include "./base.h"
#include "../../src/common/span.h"
namespace xgboost {
// forward declare learner.
@@ -133,7 +134,7 @@ struct Entry {
/*!
* \brief constructor with index and value
* \param index The feature or row index.
* \param fvalue THe feature value.
* \param fvalue The feature value.
*/
Entry(bst_uint index, bst_float fvalue) : index(index), fvalue(fvalue) {}
/*! \brief reversely compare feature values */
@@ -155,24 +156,14 @@ class SparsePage {
std::vector<Entry> data;
size_t base_rowid;
/*! \brief an instance of sparse vector in the batch */
struct Inst {
/*! \brief pointer to the elements*/
const Entry *data{nullptr};
/*! \brief length of the instance */
bst_uint length{0};
/*! \brief constructor */
Inst() = default;
Inst(const Entry *data, bst_uint length) : data(data), length(length) {}
/*! \brief get i-th pair in the sparse vector*/
inline const Entry& operator[](size_t i) const {
return data[i];
}
};
using Inst = common::Span<Entry const>;
/*! \brief get i-th row from the batch */
inline Inst operator[](size_t i) const {
return {data.data() + offset[i], static_cast<bst_uint>(offset[i + 1] - offset[i])};
return {data.data() + offset[i],
static_cast<Inst::index_type>(offset[i + 1] - offset[i])};
}
/*! \brief constructor */
@@ -234,12 +225,12 @@ class SparsePage {
* \param inst an instance row
*/
inline void Push(const Inst &inst) {
offset.push_back(offset.back() + inst.length);
offset.push_back(offset.back() + inst.size());
size_t begin = data.size();
data.resize(begin + inst.length);
if (inst.length != 0) {
std::memcpy(dmlc::BeginPtr(data) + begin, inst.data,
sizeof(Entry) * inst.length);
data.resize(begin + inst.size());
if (inst.size() != 0) {
std::memcpy(dmlc::BeginPtr(data) + begin, inst.data(),
sizeof(Entry) * inst.size());
}
}
@@ -328,7 +319,7 @@ class DMatrix {
* \brief check if column access is supported, if not, initialize column access.
* \param max_row_perbatch auxiliary information, maximum row used in each column batch.
* this is a hint information that can be ignored by the implementation.
* \param sorted If column features should be in sorted order
* \param sorted If column features should be in sorted order
* \return Number of column blocks in the column access.
*/
virtual void InitColAccess(size_t max_row_perbatch, bool sorted) = 0;

View File

@@ -574,14 +574,14 @@ inline void RegTree::FVec::Init(size_t size) {
}
inline void RegTree::FVec::Fill(const SparsePage::Inst& inst) {
for (bst_uint i = 0; i < inst.length; ++i) {
for (bst_uint i = 0; i < inst.size(); ++i) {
if (inst[i].index >= data_.size()) continue;
data_[inst[i].index].fvalue = inst[i].fvalue;
}
}
inline void RegTree::FVec::Drop(const SparsePage::Inst& inst) {
for (bst_uint i = 0; i < inst.length; ++i) {
for (bst_uint i = 0; i < inst.size(); ++i) {
if (inst[i].index >= data_.size()) continue;
data_[inst[i].index].flag = -1;
}