parent
573f1c7db4
commit
39ddf40a8d
@ -508,7 +508,7 @@ class RegTree : public Model {
|
|||||||
* \brief drop the trace after fill, must be called after fill.
|
* \brief drop the trace after fill, must be called after fill.
|
||||||
* \param inst The sparse instance to drop.
|
* \param inst The sparse instance to drop.
|
||||||
*/
|
*/
|
||||||
void Drop(const SparsePage::Inst& inst);
|
void Drop();
|
||||||
/*!
|
/*!
|
||||||
* \brief returns the size of the feature vector
|
* \brief returns the size of the feature vector
|
||||||
* \return the size of the feature vector
|
* \return the size of the feature vector
|
||||||
@ -709,13 +709,10 @@ inline void RegTree::FVec::Fill(const SparsePage::Inst& inst) {
|
|||||||
has_missing_ = data_.size() != feature_count;
|
has_missing_ = data_.size() != feature_count;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void RegTree::FVec::Drop(const SparsePage::Inst& inst) {
|
inline void RegTree::FVec::Drop() {
|
||||||
for (auto const& entry : inst) {
|
Entry e{};
|
||||||
if (entry.index >= data_.size()) {
|
e.flag = -1;
|
||||||
continue;
|
std::fill_n(data_.data(), data_.size(), e);
|
||||||
}
|
|
||||||
data_[entry.index].flag = -1;
|
|
||||||
}
|
|
||||||
has_missing_ = true;
|
has_missing_ = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -149,10 +149,28 @@ common::ColumnMatrix const &GHistIndexMatrix::Transpose() const {
|
|||||||
return *columns_;
|
return *columns_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bst_bin_t GHistIndexMatrix::GetGindex(size_t ridx, size_t fidx) const {
|
||||||
|
auto begin = RowIdx(ridx);
|
||||||
|
if (IsDense()) {
|
||||||
|
return static_cast<bst_bin_t>(index[begin + fidx]);
|
||||||
|
}
|
||||||
|
auto end = RowIdx(ridx + 1);
|
||||||
|
auto const& cut_ptrs = cut.Ptrs();
|
||||||
|
auto f_begin = cut_ptrs[fidx];
|
||||||
|
auto f_end = cut_ptrs[fidx + 1];
|
||||||
|
return BinarySearchBin(begin, end, index, f_begin, f_end);
|
||||||
|
}
|
||||||
|
|
||||||
float GHistIndexMatrix::GetFvalue(size_t ridx, size_t fidx, bool is_cat) const {
|
float GHistIndexMatrix::GetFvalue(size_t ridx, size_t fidx, bool is_cat) const {
|
||||||
auto const &values = cut.Values();
|
auto const &values = cut.Values();
|
||||||
auto const &mins = cut.MinValues();
|
auto const &mins = cut.MinValues();
|
||||||
auto const &ptrs = cut.Ptrs();
|
auto const &ptrs = cut.Ptrs();
|
||||||
|
return this->GetFvalue(ptrs, values, mins, ridx, fidx, is_cat);
|
||||||
|
}
|
||||||
|
|
||||||
|
float GHistIndexMatrix::GetFvalue(std::vector<std::uint32_t> const &ptrs,
|
||||||
|
std::vector<float> const &values, std::vector<float> const &mins,
|
||||||
|
bst_row_t ridx, bst_feature_t fidx, bool is_cat) const {
|
||||||
if (is_cat) {
|
if (is_cat) {
|
||||||
auto f_begin = ptrs[fidx];
|
auto f_begin = ptrs[fidx];
|
||||||
auto f_end = ptrs[fidx + 1];
|
auto f_end = ptrs[fidx + 1];
|
||||||
@ -172,8 +190,8 @@ float GHistIndexMatrix::GetFvalue(size_t ridx, size_t fidx, bool is_cat) const {
|
|||||||
}
|
}
|
||||||
return common::HistogramCuts::NumericBinValue(ptrs, values, mins, fidx, bin_idx);
|
return common::HistogramCuts::NumericBinValue(ptrs, values, mins, fidx, bin_idx);
|
||||||
};
|
};
|
||||||
|
switch (columns_->GetColumnType(fidx)) {
|
||||||
if (columns_->GetColumnType(fidx) == common::kDenseColumn) {
|
case common::kDenseColumn: {
|
||||||
if (columns_->AnyMissing()) {
|
if (columns_->AnyMissing()) {
|
||||||
return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
|
return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
|
||||||
auto column = columns_->DenseColumn<decltype(dtype), true>(fidx);
|
auto column = columns_->DenseColumn<decltype(dtype), true>(fidx);
|
||||||
@ -182,15 +200,18 @@ float GHistIndexMatrix::GetFvalue(size_t ridx, size_t fidx, bool is_cat) const {
|
|||||||
} else {
|
} else {
|
||||||
return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
|
return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
|
||||||
auto column = columns_->DenseColumn<decltype(dtype), false>(fidx);
|
auto column = columns_->DenseColumn<decltype(dtype), false>(fidx);
|
||||||
return get_bin_val(column);
|
auto bin_idx = column[ridx];
|
||||||
|
return common::HistogramCuts::NumericBinValue(ptrs, values, mins, fidx, bin_idx);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
} else {
|
}
|
||||||
|
case common::kSparseColumn: {
|
||||||
return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
|
return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
|
||||||
auto column = columns_->SparseColumn<decltype(dtype)>(fidx, 0);
|
auto column = columns_->SparseColumn<decltype(dtype)>(fidx, 0);
|
||||||
return get_bin_val(column);
|
return get_bin_val(column);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
SPAN_CHECK(false);
|
SPAN_CHECK(false);
|
||||||
return std::numeric_limits<float>::quiet_NaN();
|
return std::numeric_limits<float>::quiet_NaN();
|
||||||
|
|||||||
@ -227,7 +227,12 @@ class GHistIndexMatrix {
|
|||||||
|
|
||||||
common::ColumnMatrix const& Transpose() const;
|
common::ColumnMatrix const& Transpose() const;
|
||||||
|
|
||||||
|
bst_bin_t GetGindex(size_t ridx, size_t fidx) const;
|
||||||
|
|
||||||
float GetFvalue(size_t ridx, size_t fidx, bool is_cat) const;
|
float GetFvalue(size_t ridx, size_t fidx, bool is_cat) const;
|
||||||
|
float GetFvalue(std::vector<std::uint32_t> const& ptrs, std::vector<float> const& values,
|
||||||
|
std::vector<float> const& mins, bst_row_t ridx, bst_feature_t fidx,
|
||||||
|
bool is_cat) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::unique_ptr<common::ColumnMatrix> columns_;
|
std::unique_ptr<common::ColumnMatrix> columns_;
|
||||||
|
|||||||
@ -63,7 +63,7 @@ bst_float PredValue(const SparsePage::Inst &inst,
|
|||||||
psum += (*trees[i])[nidx].LeafValue();
|
psum += (*trees[i])[nidx].LeafValue();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
p_feats->Drop(inst);
|
p_feats->Drop();
|
||||||
return psum;
|
return psum;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -116,13 +116,11 @@ void FVecFill(const size_t block_size, const size_t batch_offset, const int num_
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename DataView>
|
void FVecDrop(std::size_t const block_size, std::size_t const fvec_offset,
|
||||||
void FVecDrop(const size_t block_size, const size_t batch_offset, DataView* batch,
|
std::vector<RegTree::FVec> *p_feats) {
|
||||||
const size_t fvec_offset, std::vector<RegTree::FVec>* p_feats) {
|
|
||||||
for (size_t i = 0; i < block_size; ++i) {
|
for (size_t i = 0; i < block_size; ++i) {
|
||||||
RegTree::FVec &feats = (*p_feats)[fvec_offset + i];
|
RegTree::FVec &feats = (*p_feats)[fvec_offset + i];
|
||||||
const SparsePage::Inst inst = (*batch)[batch_offset + i];
|
feats.Drop();
|
||||||
feats.Drop(inst);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -142,11 +140,15 @@ struct SparsePageView {
|
|||||||
struct GHistIndexMatrixView {
|
struct GHistIndexMatrixView {
|
||||||
private:
|
private:
|
||||||
GHistIndexMatrix const &page_;
|
GHistIndexMatrix const &page_;
|
||||||
uint64_t n_features_;
|
std::uint64_t const n_features_;
|
||||||
common::Span<FeatureType const> ft_;
|
common::Span<FeatureType const> ft_;
|
||||||
common::Span<Entry> workspace_;
|
common::Span<Entry> workspace_;
|
||||||
std::vector<size_t> current_unroll_;
|
std::vector<size_t> current_unroll_;
|
||||||
|
|
||||||
|
std::vector<std::uint32_t> const& ptrs_;
|
||||||
|
std::vector<float> const& mins_;
|
||||||
|
std::vector<float> const& values_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
size_t base_rowid;
|
size_t base_rowid;
|
||||||
|
|
||||||
@ -159,6 +161,9 @@ struct GHistIndexMatrixView {
|
|||||||
ft_{ft},
|
ft_{ft},
|
||||||
workspace_{workplace},
|
workspace_{workplace},
|
||||||
current_unroll_(n_threads > 0 ? n_threads : 1, 0),
|
current_unroll_(n_threads > 0 ? n_threads : 1, 0),
|
||||||
|
ptrs_{_page.cut.Ptrs()},
|
||||||
|
mins_{_page.cut.MinValues()},
|
||||||
|
values_{_page.cut.Values()},
|
||||||
base_rowid{_page.base_rowid} {}
|
base_rowid{_page.base_rowid} {}
|
||||||
|
|
||||||
SparsePage::Inst operator[](size_t r) {
|
SparsePage::Inst operator[](size_t r) {
|
||||||
@ -167,7 +172,7 @@ struct GHistIndexMatrixView {
|
|||||||
size_t non_missing{beg};
|
size_t non_missing{beg};
|
||||||
|
|
||||||
for (bst_feature_t c = 0; c < n_features_; ++c) {
|
for (bst_feature_t c = 0; c < n_features_; ++c) {
|
||||||
float f = page_.GetFvalue(r, c, common::IsCat(ft_, c));
|
float f = page_.GetFvalue(ptrs_, values_, mins_, r, c, common::IsCat(ft_, c));
|
||||||
if (!common::CheckNAN(f)) {
|
if (!common::CheckNAN(f)) {
|
||||||
workspace_[non_missing] = Entry{c, f};
|
workspace_[non_missing] = Entry{c, f};
|
||||||
++non_missing;
|
++non_missing;
|
||||||
@ -250,10 +255,9 @@ void PredictBatchByBlockOfRowsKernel(
|
|||||||
FVecFill(block_size, batch_offset, num_feature, &batch, fvec_offset,
|
FVecFill(block_size, batch_offset, num_feature, &batch, fvec_offset,
|
||||||
p_thread_temp);
|
p_thread_temp);
|
||||||
// process block of rows through all trees to keep cache locality
|
// process block of rows through all trees to keep cache locality
|
||||||
PredictByAllTrees(model, tree_begin, tree_end, out_preds,
|
PredictByAllTrees(model, tree_begin, tree_end, out_preds, batch_offset + batch.base_rowid,
|
||||||
batch_offset + batch.base_rowid, num_group, thread_temp,
|
num_group, thread_temp, fvec_offset, block_size);
|
||||||
fvec_offset, block_size);
|
FVecDrop(block_size, fvec_offset, p_thread_temp);
|
||||||
FVecDrop(block_size, batch_offset, &batch, fvec_offset, p_thread_temp);
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -470,7 +474,7 @@ class CPUPredictor : public Predictor {
|
|||||||
bst_node_t tid = GetLeafIndex<true, true>(tree, feats, cats);
|
bst_node_t tid = GetLeafIndex<true, true>(tree, feats, cats);
|
||||||
preds[ridx * ntree_limit + j] = static_cast<bst_float>(tid);
|
preds[ridx * ntree_limit + j] = static_cast<bst_float>(tid);
|
||||||
}
|
}
|
||||||
feats.Drop(page[i]);
|
feats.Drop();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -544,7 +548,7 @@ class CPUPredictor : public Predictor {
|
|||||||
(tree_weights == nullptr ? 1 : (*tree_weights)[j]);
|
(tree_weights == nullptr ? 1 : (*tree_weights)[j]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
feats.Drop(page[i]);
|
feats.Drop();
|
||||||
// add base margin to BIAS
|
// add base margin to BIAS
|
||||||
if (base_margin.Size() != 0) {
|
if (base_margin.Size() != 0) {
|
||||||
CHECK_EQ(base_margin.Shape(1), ngroup);
|
CHECK_EQ(base_margin.Shape(1), ngroup);
|
||||||
|
|||||||
@ -89,7 +89,7 @@ class TreeRefresher : public TreeUpdater {
|
|||||||
dmlc::BeginPtr(stemp[tid]) + offset);
|
dmlc::BeginPtr(stemp[tid]) + offset);
|
||||||
offset += tree->param.num_nodes;
|
offset += tree->param.num_nodes;
|
||||||
}
|
}
|
||||||
feats.Drop(inst);
|
feats.Drop();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
// aggregate the statistics
|
// aggregate the statistics
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user