Span class. (#3548)
* Add basic Span class based on ISO++20. * Use Span<Entry const> instead of Inst in SparsePage. * Add DeviceSpan in HostDeviceVector, use it in regression obj.
This commit is contained in:
committed by
Rory Mitchell
parent
2b7a1c5780
commit
2c502784ff
@@ -15,6 +15,7 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "./base.h"
|
||||
#include "../../src/common/span.h"
|
||||
|
||||
namespace xgboost {
|
||||
// forward declare learner.
|
||||
@@ -133,7 +134,7 @@ struct Entry {
|
||||
/*!
|
||||
* \brief constructor with index and value
|
||||
* \param index The feature or row index.
|
||||
* \param fvalue THe feature value.
|
||||
* \param fvalue The feature value.
|
||||
*/
|
||||
Entry(bst_uint index, bst_float fvalue) : index(index), fvalue(fvalue) {}
|
||||
/*! \brief reversely compare feature values */
|
||||
@@ -155,24 +156,14 @@ class SparsePage {
|
||||
std::vector<Entry> data;
|
||||
|
||||
size_t base_rowid;
|
||||
|
||||
/*! \brief an instance of sparse vector in the batch */
|
||||
struct Inst {
|
||||
/*! \brief pointer to the elements*/
|
||||
const Entry *data{nullptr};
|
||||
/*! \brief length of the instance */
|
||||
bst_uint length{0};
|
||||
/*! \brief constructor */
|
||||
Inst() = default;
|
||||
Inst(const Entry *data, bst_uint length) : data(data), length(length) {}
|
||||
/*! \brief get i-th pair in the sparse vector*/
|
||||
inline const Entry& operator[](size_t i) const {
|
||||
return data[i];
|
||||
}
|
||||
};
|
||||
using Inst = common::Span<Entry const>;
|
||||
|
||||
/*! \brief get i-th row from the batch */
|
||||
inline Inst operator[](size_t i) const {
|
||||
return {data.data() + offset[i], static_cast<bst_uint>(offset[i + 1] - offset[i])};
|
||||
return {data.data() + offset[i],
|
||||
static_cast<Inst::index_type>(offset[i + 1] - offset[i])};
|
||||
}
|
||||
|
||||
/*! \brief constructor */
|
||||
@@ -234,12 +225,12 @@ class SparsePage {
|
||||
* \param inst an instance row
|
||||
*/
|
||||
inline void Push(const Inst &inst) {
|
||||
offset.push_back(offset.back() + inst.length);
|
||||
offset.push_back(offset.back() + inst.size());
|
||||
size_t begin = data.size();
|
||||
data.resize(begin + inst.length);
|
||||
if (inst.length != 0) {
|
||||
std::memcpy(dmlc::BeginPtr(data) + begin, inst.data,
|
||||
sizeof(Entry) * inst.length);
|
||||
data.resize(begin + inst.size());
|
||||
if (inst.size() != 0) {
|
||||
std::memcpy(dmlc::BeginPtr(data) + begin, inst.data(),
|
||||
sizeof(Entry) * inst.size());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -328,7 +319,7 @@ class DMatrix {
|
||||
* \brief check if column access is supported, if not, initialize column access.
|
||||
* \param max_row_perbatch auxiliary information, maximum row used in each column batch.
|
||||
* this is a hint information that can be ignored by the implementation.
|
||||
* \param sorted If column features should be in sorted order
|
||||
* \param sorted If column features should be in sorted order
|
||||
* \return Number of column blocks in the column access.
|
||||
*/
|
||||
virtual void InitColAccess(size_t max_row_perbatch, bool sorted) = 0;
|
||||
|
||||
@@ -574,14 +574,14 @@ inline void RegTree::FVec::Init(size_t size) {
|
||||
}
|
||||
|
||||
inline void RegTree::FVec::Fill(const SparsePage::Inst& inst) {
|
||||
for (bst_uint i = 0; i < inst.length; ++i) {
|
||||
for (bst_uint i = 0; i < inst.size(); ++i) {
|
||||
if (inst[i].index >= data_.size()) continue;
|
||||
data_[inst[i].index].fvalue = inst[i].fvalue;
|
||||
}
|
||||
}
|
||||
|
||||
inline void RegTree::FVec::Drop(const SparsePage::Inst& inst) {
|
||||
for (bst_uint i = 0; i < inst.length; ++i) {
|
||||
for (bst_uint i = 0; i < inst.size(); ++i) {
|
||||
if (inst[i].index >= data_.size()) continue;
|
||||
data_[inst[i].index].flag = -1;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user