Improve multi-threaded performance (#2104)
* Add UpdatePredictionCache() option to updaters Some updaters (e.g. fast_hist) has enough information to quickly compute prediction cache for the training data. Each updater may override UpdaterPredictionCache() method to update the prediction cache. Note: this trick does not apply to validation data. * Respond to code review * Disable some debug messages by default * Document UpdatePredictionCache() interface * Remove base_margin logic from UpdatePredictionCache() implementation * Do not take pointer to cfg, as reference may get stale * Improve multi-threaded performance * Use columnwise accessor to accelerate ApplySplit() step, with support for a compressed representation * Parallel sort for evaluation step * Inline BuildHist() function * Cache gradient pairs when building histograms in BuildHist() * Add missing #if macro * Respond to code review * Use wrapper to enable parallel sort on Linux * Fix C++ compatibility issues * MSVC doesn't support unsigned in OpenMP loops * gcc 4.6 doesn't support using keyword * Fix lint issues * Respond to code review * Fix bug in ApplySplitSparseData() * Attempting to read beyond the end of a sparse column * Mishandling the case where an entire range of rows have missing values * Fix training continuation bug Disable UpdatePredictionCache() in the first iteration. This way, we can accomodate the scenario where we build off of an existing (nonempty) ensemble. * Add regression test for fast_hist * Respond to code review * Add back old version of ApplySplitSparseData
This commit is contained in:
parent
332aea26a3
commit
14fba01b5a
@ -48,6 +48,15 @@
|
||||
#define XGBOOST_ALIGNAS(X)
|
||||
#endif
|
||||
|
||||
#if defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ >= 8
|
||||
#include <parallel/algorithm>
|
||||
#define XGBOOST_PARALLEL_SORT(X, Y, Z) __gnu_parallel::sort((X), (Y), (Z))
|
||||
#define XGBOOST_PARALLEL_STABLE_SORT(X, Y, Z) __gnu_parallel::stable_sort((X), (Y), (Z))
|
||||
#else
|
||||
#define XGBOOST_PARALLEL_SORT(X, Y, Z) std::sort((X), (Y), (Z))
|
||||
#define XGBOOST_PARALLEL_STABLE_SORT(X, Y, Z) std::stable_sort((X), (Y), (Z))
|
||||
#endif
|
||||
|
||||
/*! \brief namespace of xgboo st*/
|
||||
namespace xgboost {
|
||||
/*!
|
||||
|
||||
@ -45,14 +45,20 @@ class TreeUpdater {
|
||||
virtual void Update(const std::vector<bst_gpair>& gpair,
|
||||
DMatrix* data,
|
||||
const std::vector<RegTree*>& trees) = 0;
|
||||
|
||||
/*!
|
||||
* \brief this is simply a function for optimizing performance
|
||||
* this function asks the updater to return the leaf position of each instance in the previous performed update.
|
||||
* if it is cached in the updater, if it is not available, return nullptr
|
||||
* \return array of leaf position of each instance in the last updated tree
|
||||
* \brief determines whether updater has enough knowledge about a given dataset
|
||||
* to quickly update prediction cache its training data and performs the
|
||||
* update if possible.
|
||||
* \param data: data matrix
|
||||
* \param out_preds: prediction cache to be updated
|
||||
* \return boolean indicating whether updater has capability to update
|
||||
* the prediction cache. If true, the prediction cache will have been
|
||||
* updated by the time this function returns.
|
||||
*/
|
||||
virtual const int* GetLeafPosition() const {
|
||||
return nullptr;
|
||||
virtual bool UpdatePredictionCache(const DMatrix* data,
|
||||
std::vector<bst_float>* out_preds) const {
|
||||
return false;
|
||||
}
|
||||
/*!
|
||||
* \brief Create a tree updater given name
|
||||
|
||||
@ -155,6 +155,7 @@ struct CLIParam : public dmlc::Parameter<CLIParam> {
|
||||
DMLC_REGISTER_PARAMETER(CLIParam);
|
||||
|
||||
void CLITrain(const CLIParam& param) {
|
||||
const double tstart_data_load = dmlc::GetTime();
|
||||
if (rabit::IsDistributed()) {
|
||||
std::string pname = rabit::GetProcessorName();
|
||||
LOG(CONSOLE) << "start " << pname << ":" << rabit::GetRank();
|
||||
@ -193,6 +194,9 @@ void CLITrain(const CLIParam& param) {
|
||||
learner->InitModel();
|
||||
}
|
||||
}
|
||||
if (param.silent == 0) {
|
||||
LOG(INFO) << "Loading data: " << dmlc::GetTime() - tstart_data_load << " sec";
|
||||
}
|
||||
// start training.
|
||||
const double start = dmlc::GetTime();
|
||||
for (int i = version / 2; i < param.num_round; ++i) {
|
||||
|
||||
231
src/common/column_matrix.h
Normal file
231
src/common/column_matrix.h
Normal file
@ -0,0 +1,231 @@
|
||||
/*!
|
||||
* Copyright 2017 by Contributors
|
||||
* \file column_matrix.h
|
||||
* \brief Utility for fast column-wise access
|
||||
* \author Philip Cho
|
||||
*/
|
||||
|
||||
#ifndef XGBOOST_COMMON_COLUMN_MATRIX_H_
|
||||
#define XGBOOST_COMMON_COLUMN_MATRIX_H_
|
||||
|
||||
#define XGBOOST_TYPE_SWITCH(dtype, OP) \
|
||||
switch (dtype) { \
|
||||
case xgboost::common::uint32 : { \
|
||||
typedef uint32_t DType; \
|
||||
OP; break; \
|
||||
} \
|
||||
case xgboost::common::uint16 : { \
|
||||
typedef uint16_t DType; \
|
||||
OP; break; \
|
||||
} \
|
||||
case xgboost::common::uint8 : { \
|
||||
typedef uint8_t DType; \
|
||||
OP; break; \
|
||||
default: LOG(FATAL) << "don't recognize type flag" << dtype; \
|
||||
} \
|
||||
}
|
||||
|
||||
#include <type_traits>
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
#include "hist_util.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
|
||||
/*! \brief indicator of data type used for storing bin id's in a column. */
|
||||
enum DataType {
|
||||
uint8 = 1,
|
||||
uint16 = 2,
|
||||
uint32 = 4
|
||||
};
|
||||
|
||||
/*! \brief column type */
|
||||
enum ColumnType {
|
||||
kDenseColumn,
|
||||
kSparseColumn
|
||||
};
|
||||
|
||||
/*! \brief a column storage, to be used with ApplySplit. Note that each
|
||||
bin id is stored as index[i] + index_base. */
|
||||
template<typename T>
|
||||
class Column {
|
||||
public:
|
||||
ColumnType type;
|
||||
const T* index;
|
||||
uint32_t index_base;
|
||||
const uint32_t* row_ind;
|
||||
size_t len;
|
||||
};
|
||||
|
||||
/*! \brief a collection of columns, with support for construction from
|
||||
GHistIndexMatrix. */
|
||||
class ColumnMatrix {
|
||||
public:
|
||||
// get number of features
|
||||
inline uint32_t GetNumFeature() const {
|
||||
return type_.size();
|
||||
}
|
||||
|
||||
// construct column matrix from GHistIndexMatrix
|
||||
inline void Init(const GHistIndexMatrix& gmat, DataType dtype) {
|
||||
this->dtype = dtype;
|
||||
/* if dtype is smaller than uint32_t, multiple bin_id's will be stored in each
|
||||
slot of internal buffer. */
|
||||
packing_factor_ = sizeof(uint32_t) / static_cast<size_t>(this->dtype);
|
||||
|
||||
const uint32_t nfeature = gmat.cut->row_ptr.size() - 1;
|
||||
const omp_ulong nrow = static_cast<omp_ulong>(gmat.row_ptr.size() - 1);
|
||||
|
||||
// identify type of each column
|
||||
feature_counts_.resize(nfeature);
|
||||
type_.resize(nfeature);
|
||||
std::fill(feature_counts_.begin(), feature_counts_.end(), 0);
|
||||
|
||||
uint32_t max_val = 0;
|
||||
XGBOOST_TYPE_SWITCH(this->dtype, {
|
||||
max_val = static_cast<uint32_t>(std::numeric_limits<DType>::max());
|
||||
});
|
||||
for (uint32_t fid = 0; fid < nfeature; ++fid) {
|
||||
CHECK_LE(gmat.cut->row_ptr[fid + 1] - gmat.cut->row_ptr[fid], max_val);
|
||||
}
|
||||
|
||||
gmat.GetFeatureCounts(&feature_counts_[0]);
|
||||
// classify features
|
||||
for (uint32_t fid = 0; fid < nfeature; ++fid) {
|
||||
if (static_cast<double>(feature_counts_[fid]) < 0.5*nrow) {
|
||||
type_[fid] = kSparseColumn;
|
||||
} else {
|
||||
type_[fid] = kDenseColumn;
|
||||
}
|
||||
}
|
||||
|
||||
// want to compute storage boundary for each feature
|
||||
// using variants of prefix sum scan
|
||||
boundary_.resize(nfeature);
|
||||
bst_uint accum_index_ = 0;
|
||||
bst_uint accum_row_ind_ = 0;
|
||||
for (uint32_t fid = 0; fid < nfeature; ++fid) {
|
||||
boundary_[fid].index_begin = accum_index_;
|
||||
boundary_[fid].row_ind_begin = accum_row_ind_;
|
||||
if (type_[fid] == kDenseColumn) {
|
||||
accum_index_ += nrow;
|
||||
} else {
|
||||
accum_index_ += feature_counts_[fid];
|
||||
accum_row_ind_ += feature_counts_[fid];
|
||||
}
|
||||
boundary_[fid].index_end = accum_index_;
|
||||
boundary_[fid].row_ind_end = accum_row_ind_;
|
||||
}
|
||||
|
||||
index_.resize((boundary_[nfeature - 1].index_end
|
||||
+ (packing_factor_ - 1)) / packing_factor_);
|
||||
row_ind_.resize(boundary_[nfeature - 1].row_ind_end);
|
||||
|
||||
// store least bin id for each feature
|
||||
index_base_.resize(nfeature);
|
||||
for (uint32_t fid = 0; fid < nfeature; ++fid) {
|
||||
index_base_[fid] = gmat.cut->row_ptr[fid];
|
||||
}
|
||||
|
||||
// fill index_ for dense columns
|
||||
for (uint32_t fid = 0; fid < nfeature; ++fid) {
|
||||
if (type_[fid] == kDenseColumn) {
|
||||
const uint32_t ibegin = boundary_[fid].index_begin;
|
||||
XGBOOST_TYPE_SWITCH(this->dtype, {
|
||||
const size_t block_offset = ibegin / packing_factor_;
|
||||
const size_t elem_offset = ibegin % packing_factor_;
|
||||
DType* begin = reinterpret_cast<DType*>(&index_[block_offset]) + elem_offset;
|
||||
DType* end = begin + nrow;
|
||||
std::fill(begin, end, std::numeric_limits<DType>::max());
|
||||
// max() indicates missing values
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// loop over all rows and fill column entries
|
||||
// num_nonzeros[fid] = how many nonzeros have this feature accumulated so far?
|
||||
std::vector<uint32_t> num_nonzeros;
|
||||
num_nonzeros.resize(nfeature);
|
||||
std::fill(num_nonzeros.begin(), num_nonzeros.end(), 0);
|
||||
for (uint32_t rid = 0; rid < nrow; ++rid) {
|
||||
const size_t ibegin = static_cast<size_t>(gmat.row_ptr[rid]);
|
||||
const size_t iend = static_cast<size_t>(gmat.row_ptr[rid + 1]);
|
||||
size_t fid = 0;
|
||||
for (size_t i = ibegin; i < iend; ++i) {
|
||||
const size_t bin_id = gmat.index[i];
|
||||
while (bin_id >= gmat.cut->row_ptr[fid + 1]) {
|
||||
++fid;
|
||||
}
|
||||
if (type_[fid] == kDenseColumn) {
|
||||
XGBOOST_TYPE_SWITCH(this->dtype, {
|
||||
const size_t block_offset = boundary_[fid].index_begin / packing_factor_;
|
||||
const size_t elem_offset = boundary_[fid].index_begin % packing_factor_;
|
||||
DType* begin = reinterpret_cast<DType*>(&index_[block_offset]) + elem_offset;
|
||||
begin[rid] = bin_id - index_base_[fid];
|
||||
});
|
||||
} else {
|
||||
XGBOOST_TYPE_SWITCH(this->dtype, {
|
||||
const size_t block_offset = boundary_[fid].index_begin / packing_factor_;
|
||||
const size_t elem_offset = boundary_[fid].index_begin % packing_factor_;
|
||||
DType* begin = reinterpret_cast<DType*>(&index_[block_offset]) + elem_offset;
|
||||
begin[num_nonzeros[fid]] = bin_id - index_base_[fid];
|
||||
});
|
||||
row_ind_[boundary_[fid].row_ind_begin + num_nonzeros[fid]] = rid;
|
||||
++num_nonzeros[fid];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Fetch an individual column. This code should be used with XGBOOST_TYPE_SWITCH
|
||||
to determine type of bin id's */
|
||||
template<typename T>
|
||||
inline Column<T> GetColumn(unsigned fid) const {
|
||||
const bool valid_type = std::is_same<T, uint32_t>::value
|
||||
|| std::is_same<T, uint16_t>::value
|
||||
|| std::is_same<T, uint8_t>::value;
|
||||
CHECK(valid_type);
|
||||
|
||||
Column<T> c;
|
||||
|
||||
c.type = type_[fid];
|
||||
const size_t block_offset = boundary_[fid].index_begin / packing_factor_;
|
||||
const size_t elem_offset = boundary_[fid].index_begin % packing_factor_;
|
||||
c.index = reinterpret_cast<const T*>(&index_[block_offset]) + elem_offset;
|
||||
c.index_base = index_base_[fid];
|
||||
c.row_ind = &row_ind_[boundary_[fid].row_ind_begin];
|
||||
c.len = boundary_[fid].index_end - boundary_[fid].index_begin;
|
||||
|
||||
return c;
|
||||
}
|
||||
|
||||
public:
|
||||
DataType dtype;
|
||||
|
||||
private:
|
||||
struct ColumnBoundary {
|
||||
// indicate where each column's index and row_ind is stored.
|
||||
// index_begin and index_end are logical offsets, so they should be converted to
|
||||
// actual offsets by scaling with packing_factor_
|
||||
unsigned index_begin;
|
||||
unsigned index_end;
|
||||
unsigned row_ind_begin;
|
||||
unsigned row_ind_end;
|
||||
};
|
||||
|
||||
std::vector<bst_uint> feature_counts_;
|
||||
std::vector<ColumnType> type_;
|
||||
std::vector<uint32_t> index_; // index_: may store smaller integers; needs padding
|
||||
std::vector<uint32_t> row_ind_;
|
||||
std::vector<ColumnBoundary> boundary_;
|
||||
|
||||
size_t packing_factor_; // how many integers are stored in each slot of index_
|
||||
|
||||
// index_base_[fid]: least bin id for feature fid
|
||||
std::vector<uint32_t> index_base_;
|
||||
};
|
||||
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_COMMON_COLUMN_MATRIX_H_
|
||||
@ -8,6 +8,7 @@
|
||||
#include <vector>
|
||||
#include "./sync.h"
|
||||
#include "./hist_util.h"
|
||||
#include "./column_matrix.h"
|
||||
#include "./quantile.h"
|
||||
|
||||
namespace xgboost {
|
||||
@ -21,12 +22,7 @@ void HistCutMatrix::Init(DMatrix* p_fmat, size_t max_num_bins) {
|
||||
const int kFactor = 8;
|
||||
std::vector<WXQSketch> sketchs;
|
||||
|
||||
int nthread;
|
||||
#pragma omp parallel
|
||||
{
|
||||
nthread = omp_get_num_threads();
|
||||
}
|
||||
nthread = std::max(nthread / 2, 1);
|
||||
const int nthread = omp_get_max_threads();
|
||||
|
||||
unsigned nstep = (info.num_col + nthread - 1) / nthread;
|
||||
unsigned ncol = static_cast<unsigned>(info.num_col);
|
||||
@ -105,18 +101,14 @@ void HistCutMatrix::Init(DMatrix* p_fmat, size_t max_num_bins) {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void GHistIndexMatrix::Init(DMatrix* p_fmat) {
|
||||
CHECK(cut != nullptr);
|
||||
dmlc::DataIter<RowBatch>* iter = p_fmat->RowIterator();
|
||||
hit_count.resize(cut->row_ptr.back(), 0);
|
||||
|
||||
int nthread;
|
||||
#pragma omp parallel
|
||||
{
|
||||
nthread = omp_get_num_threads();
|
||||
}
|
||||
nthread = std::max(nthread / 2, 1);
|
||||
const int nthread = omp_get_max_threads();
|
||||
const unsigned nbins = cut->row_ptr.back();
|
||||
hit_count.resize(nbins, 0);
|
||||
hit_count_tloc_.resize(nthread * nbins, 0);
|
||||
|
||||
iter->BeforeFirst();
|
||||
row_ptr.push_back(0);
|
||||
@ -134,6 +126,7 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat) {
|
||||
omp_ulong bsize = static_cast<omp_ulong>(batch.size);
|
||||
#pragma omp parallel for num_threads(nthread) schedule(static)
|
||||
for (omp_ulong i = 0; i < bsize; ++i) { // NOLINT(*)
|
||||
const int tid = omp_get_thread_num();
|
||||
size_t ibegin = row_ptr[rbegin + i];
|
||||
size_t iend = row_ptr[rbegin + i + 1];
|
||||
RowBatch::Inst inst = batch[i];
|
||||
@ -147,20 +140,28 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat) {
|
||||
if (it == cend) it = cend - 1;
|
||||
unsigned idx = static_cast<unsigned>(it - cut->cut.begin());
|
||||
index[ibegin + j] = idx;
|
||||
++hit_count_tloc_[tid * nbins + idx];
|
||||
}
|
||||
std::sort(index.begin() + ibegin, index.begin() + iend);
|
||||
}
|
||||
|
||||
#pragma omp parallel for num_threads(nthread) schedule(static)
|
||||
for (omp_ulong idx = 0; idx < nbins; ++idx) {
|
||||
for (int tid = 0; tid < nthread; ++tid) {
|
||||
hit_count[idx] += hit_count_tloc_[tid * nbins + idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void GHistBuilder::BuildHist(const std::vector<bst_gpair>& gpair,
|
||||
const RowSetCollection::Elem row_indices,
|
||||
const GHistIndexMatrix& gmat,
|
||||
const std::vector<bst_uint>& feat_set,
|
||||
GHistRow hist) {
|
||||
CHECK(!data_.empty()) << "GHistBuilder must be initialized";
|
||||
CHECK_EQ(data_.size(), nbins_ * nthread_) << "invalid dimensions for temp buffer";
|
||||
|
||||
data_.resize(nbins_ * nthread_, GHistEntry());
|
||||
std::fill(data_.begin(), data_.end(), GHistEntry());
|
||||
stat_buf_.resize(row_indices.size());
|
||||
|
||||
const int K = 8; // loop unrolling factor
|
||||
const bst_omp_uint nthread = static_cast<bst_omp_uint>(this->nthread_);
|
||||
@ -169,21 +170,42 @@ void GHistBuilder::BuildHist(const std::vector<bst_gpair>& gpair,
|
||||
|
||||
#pragma omp parallel for num_threads(nthread) schedule(static)
|
||||
for (bst_omp_uint i = 0; i < nrows - rest; i += K) {
|
||||
const bst_omp_uint tid = omp_get_thread_num();
|
||||
const size_t off = tid * nbins_;
|
||||
bst_uint rid[K];
|
||||
bst_gpair stat[K];
|
||||
size_t ibegin[K], iend[K];
|
||||
for (int k = 0; k < K; ++k) {
|
||||
rid[k] = row_indices.begin[i + k];
|
||||
}
|
||||
for (int k = 0; k < K; ++k) {
|
||||
stat[k] = gpair[rid[k]];
|
||||
}
|
||||
for (int k = 0; k < K; ++k) {
|
||||
stat_buf_[i + k] = stat[k];
|
||||
}
|
||||
}
|
||||
for (bst_omp_uint i = nrows - rest; i < nrows; ++i) {
|
||||
const bst_uint rid = row_indices.begin[i];
|
||||
const bst_gpair stat = gpair[rid];
|
||||
stat_buf_[i] = stat;
|
||||
}
|
||||
|
||||
#pragma omp parallel for num_threads(nthread) schedule(dynamic)
|
||||
for (bst_omp_uint i = 0; i < nrows - rest; i += K) {
|
||||
const bst_omp_uint tid = omp_get_thread_num();
|
||||
const size_t off = tid * nbins_;
|
||||
bst_uint rid[K];
|
||||
size_t ibegin[K];
|
||||
size_t iend[K];
|
||||
bst_gpair stat[K];
|
||||
for (int k = 0; k < K; ++k) {
|
||||
rid[k] = row_indices.begin[i + k];
|
||||
}
|
||||
for (int k = 0; k < K; ++k) {
|
||||
ibegin[k] = static_cast<size_t>(gmat.row_ptr[rid[k]]);
|
||||
iend[k] = static_cast<size_t>(gmat.row_ptr[rid[k] + 1]);
|
||||
}
|
||||
for (int k = 0; k < K; ++k) {
|
||||
stat[k] = stat_buf_[i + k];
|
||||
}
|
||||
for (int k = 0; k < K; ++k) {
|
||||
for (size_t j = ibegin[k]; j < iend[k]; ++j) {
|
||||
const size_t bin = gmat.index[j];
|
||||
@ -193,9 +215,9 @@ void GHistBuilder::BuildHist(const std::vector<bst_gpair>& gpair,
|
||||
}
|
||||
for (bst_omp_uint i = nrows - rest; i < nrows; ++i) {
|
||||
const bst_uint rid = row_indices.begin[i];
|
||||
const bst_gpair stat = gpair[rid];
|
||||
const size_t ibegin = static_cast<size_t>(gmat.row_ptr[rid]);
|
||||
const size_t iend = static_cast<size_t>(gmat.row_ptr[rid + 1]);
|
||||
const bst_gpair stat = stat_buf_[i];
|
||||
for (size_t j = ibegin; j < iend; ++j) {
|
||||
const size_t bin = gmat.index[j];
|
||||
data_[bin].Add(stat);
|
||||
@ -212,13 +234,26 @@ void GHistBuilder::BuildHist(const std::vector<bst_gpair>& gpair,
|
||||
}
|
||||
}
|
||||
|
||||
void GHistBuilder::SubtractionTrick(GHistRow self,
|
||||
GHistRow sibling,
|
||||
GHistRow parent) {
|
||||
void GHistBuilder::SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent) {
|
||||
const bst_omp_uint nthread = static_cast<bst_omp_uint>(this->nthread_);
|
||||
const bst_omp_uint nbins = static_cast<bst_omp_uint>(nbins_);
|
||||
const int K = 8;
|
||||
const bst_omp_uint rest = nbins % K;
|
||||
#pragma omp parallel for num_threads(nthread) schedule(static)
|
||||
for (bst_omp_uint bin_id = 0; bin_id < nbins; ++bin_id) {
|
||||
for (bst_omp_uint bin_id = 0; bin_id < nbins - rest; bin_id += K) {
|
||||
GHistEntry pb[K];
|
||||
GHistEntry sb[K];
|
||||
for (int k = 0; k < K; ++k) {
|
||||
pb[k] = parent.begin[bin_id + k];
|
||||
}
|
||||
for (int k = 0; k < K; ++k) {
|
||||
sb[k] = sibling.begin[bin_id + k];
|
||||
}
|
||||
for (int k = 0; k < K; ++k) {
|
||||
self.begin[bin_id + k].SetSubtract(pb[k], sb[k]);
|
||||
}
|
||||
}
|
||||
for (bst_omp_uint bin_id = nbins - rest; bin_id < nbins; ++bin_id) {
|
||||
self.begin[bin_id].SetSubtract(parent.begin[bin_id], sibling.begin[bin_id]);
|
||||
}
|
||||
}
|
||||
|
||||
@ -102,18 +102,27 @@ struct GHistIndexMatrix {
|
||||
std::vector<unsigned> index;
|
||||
/*! \brief hit count of each index */
|
||||
std::vector<unsigned> hit_count;
|
||||
/*! \brief optional remap index from outter row_id -> internal row_id*/
|
||||
std::vector<unsigned> remap_index;
|
||||
/*! \brief The corresponding cuts */
|
||||
const HistCutMatrix* cut;
|
||||
// Create a global histogram matrix, given cut
|
||||
void Init(DMatrix* p_fmat);
|
||||
// build remap
|
||||
void Remap();
|
||||
// get i-th row
|
||||
inline GHistIndexRow operator[](bst_uint i) const {
|
||||
return GHistIndexRow(&index[0] + row_ptr[i], row_ptr[i + 1] - row_ptr[i]);
|
||||
}
|
||||
inline void GetFeatureCounts(bst_uint* counts) const {
|
||||
const unsigned nfeature = cut->row_ptr.size() - 1;
|
||||
for (unsigned fid = 0; fid < nfeature; ++fid) {
|
||||
const unsigned ibegin = cut->row_ptr[fid];
|
||||
const unsigned iend = cut->row_ptr[fid + 1];
|
||||
for (unsigned i = ibegin; i < iend; ++i) {
|
||||
counts[fid] += hit_count[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<unsigned> hit_count_tloc_;
|
||||
};
|
||||
|
||||
/*!
|
||||
@ -189,13 +198,13 @@ class GHistBuilder {
|
||||
inline void Init(size_t nthread, size_t nbins) {
|
||||
nthread_ = nthread;
|
||||
nbins_ = nbins;
|
||||
data_.resize(nthread * nbins, GHistEntry());
|
||||
}
|
||||
|
||||
// construct a histogram via histogram aggregation
|
||||
void BuildHist(const std::vector<bst_gpair>& gpair,
|
||||
const RowSetCollection::Elem row_indices,
|
||||
const GHistIndexMatrix& gmat,
|
||||
const std::vector<bst_uint>& feat_set,
|
||||
GHistRow hist);
|
||||
// construct a histogram via subtraction trick
|
||||
void SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent);
|
||||
@ -206,6 +215,7 @@ class GHistBuilder {
|
||||
/*! \brief number of all bins over all features */
|
||||
size_t nbins_;
|
||||
std::vector<GHistEntry> data_;
|
||||
std::vector<bst_gpair> stat_buf_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
@ -17,15 +17,20 @@ namespace common {
|
||||
/*! \brief collection of rowset */
|
||||
class RowSetCollection {
|
||||
public:
|
||||
/*! \brief subset of rows */
|
||||
/*! \brief data structure to store an instance set, a subset of
|
||||
* rows (instances) associated with a particular node in a decision
|
||||
* tree. */
|
||||
struct Elem {
|
||||
const bst_uint* begin;
|
||||
const bst_uint* end;
|
||||
int node_id;
|
||||
// id of node associated with this instance set; -1 means uninitialized
|
||||
Elem(void)
|
||||
: begin(nullptr), end(nullptr) {}
|
||||
: begin(nullptr), end(nullptr), node_id(-1) {}
|
||||
Elem(const bst_uint* begin,
|
||||
const bst_uint* end)
|
||||
: begin(begin), end(end) {}
|
||||
const bst_uint* end,
|
||||
int node_id)
|
||||
: begin(begin), end(end), node_id(node_id) {}
|
||||
|
||||
inline size_t size() const {
|
||||
return end - begin;
|
||||
@ -36,6 +41,15 @@ class RowSetCollection {
|
||||
std::vector<bst_uint> left;
|
||||
std::vector<bst_uint> right;
|
||||
};
|
||||
|
||||
inline std::vector<Elem>::const_iterator begin() const {
|
||||
return elem_of_each_node_.begin();
|
||||
}
|
||||
|
||||
inline std::vector<Elem>::const_iterator end() const {
|
||||
return elem_of_each_node_.end();
|
||||
}
|
||||
|
||||
/*! \brief return corresponding element set given the node_id */
|
||||
inline const Elem& operator[](unsigned node_id) const {
|
||||
const Elem& e = elem_of_each_node_[node_id];
|
||||
@ -53,7 +67,7 @@ class RowSetCollection {
|
||||
CHECK_EQ(elem_of_each_node_.size(), 0U);
|
||||
const bst_uint* begin = dmlc::BeginPtr(row_indices_);
|
||||
const bst_uint* end = dmlc::BeginPtr(row_indices_) + row_indices_.size();
|
||||
elem_of_each_node_.emplace_back(Elem(begin, end));
|
||||
elem_of_each_node_.emplace_back(Elem(begin, end, 0));
|
||||
}
|
||||
// split rowset into two
|
||||
inline void AddSplit(unsigned node_id,
|
||||
@ -79,15 +93,15 @@ class RowSetCollection {
|
||||
}
|
||||
|
||||
if (left_node_id >= elem_of_each_node_.size()) {
|
||||
elem_of_each_node_.resize(left_node_id + 1, Elem(nullptr, nullptr));
|
||||
elem_of_each_node_.resize(left_node_id + 1, Elem(nullptr, nullptr, -1));
|
||||
}
|
||||
if (right_node_id >= elem_of_each_node_.size()) {
|
||||
elem_of_each_node_.resize(right_node_id + 1, Elem(nullptr, nullptr));
|
||||
elem_of_each_node_.resize(right_node_id + 1, Elem(nullptr, nullptr, -1));
|
||||
}
|
||||
|
||||
elem_of_each_node_[left_node_id] = Elem(begin, split_pt);
|
||||
elem_of_each_node_[right_node_id] = Elem(split_pt, e.end);
|
||||
elem_of_each_node_[node_id] = Elem(nullptr, nullptr);
|
||||
elem_of_each_node_[left_node_id] = Elem(begin, split_pt, left_node_id);
|
||||
elem_of_each_node_[right_node_id] = Elem(split_pt, e.end, right_node_id);
|
||||
elem_of_each_node_[node_id] = Elem(nullptr, nullptr, -1);
|
||||
}
|
||||
|
||||
// stores the row indices in the set
|
||||
|
||||
@ -44,6 +44,8 @@ struct GBTreeTrainParam : public dmlc::Parameter<GBTreeTrainParam> {
|
||||
std::string updater_seq;
|
||||
/*! \brief type of boosting process to run */
|
||||
int process_type;
|
||||
// flag to print out detailed breakdown of runtime
|
||||
int debug_verbose;
|
||||
// declare parameters
|
||||
DMLC_DECLARE_PARAMETER(GBTreeTrainParam) {
|
||||
DMLC_DECLARE_FIELD(num_parallel_tree)
|
||||
@ -60,6 +62,10 @@ struct GBTreeTrainParam : public dmlc::Parameter<GBTreeTrainParam> {
|
||||
.add_enum("update", kUpdate)
|
||||
.describe("Whether to run the normal boosting process that creates new trees,"\
|
||||
" or to update the trees in an existing model.");
|
||||
DMLC_DECLARE_FIELD(debug_verbose)
|
||||
.set_lower_bound(0)
|
||||
.set_default(0)
|
||||
.describe("flag to print out detailed breakdown of runtime");
|
||||
// add alias
|
||||
DMLC_DECLARE_ALIAS(updater_seq, updater);
|
||||
}
|
||||
@ -260,9 +266,13 @@ class GBTree : public GradientBooster {
|
||||
new_trees.push_back(std::move(ret));
|
||||
}
|
||||
}
|
||||
double tstart = dmlc::GetTime();
|
||||
for (int gid = 0; gid < mparam.num_output_group; ++gid) {
|
||||
this->CommitModel(std::move(new_trees[gid]), gid);
|
||||
}
|
||||
if (tparam.debug_verbose > 0) {
|
||||
LOG(INFO) << "CommitModel(): " << dmlc::GetTime() - tstart << " sec";
|
||||
}
|
||||
}
|
||||
|
||||
void Predict(DMatrix* p_fmat,
|
||||
@ -474,14 +484,20 @@ class GBTree : public GradientBooster {
|
||||
// update cache entry
|
||||
for (auto &kv : cache_) {
|
||||
CacheEntry& e = kv.second;
|
||||
|
||||
if (e.predictions.size() == 0) {
|
||||
PredLoopInternal<GBTree>(
|
||||
e.data.get(), &(e.predictions),
|
||||
0, trees.size(), true);
|
||||
} else {
|
||||
PredLoopInternal<GBTree>(
|
||||
e.data.get(), &(e.predictions),
|
||||
old_ntree, trees.size(), false);
|
||||
if (mparam.num_output_group == 1 && updaters.size() > 0 && new_trees.size() == 1
|
||||
&& updaters.back()->UpdatePredictionCache(e.data.get(), &(e.predictions)) ) {
|
||||
{} // do nothing
|
||||
} else {
|
||||
PredLoopInternal<GBTree>(
|
||||
e.data.get(), &(e.predictions),
|
||||
old_ntree, trees.size(), false);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -6,6 +6,7 @@
|
||||
*/
|
||||
#include <xgboost/logging.h>
|
||||
#include <xgboost/learner.h>
|
||||
#include <dmlc/timer.h>
|
||||
#include <dmlc/io.h>
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
@ -83,6 +84,8 @@ struct LearnerTrainParam
|
||||
// number of threads to use if OpenMP is enabled
|
||||
// if equals 0, use system default
|
||||
int nthread;
|
||||
// flag to print out detailed breakdown of runtime
|
||||
int debug_verbose;
|
||||
// declare parameters
|
||||
DMLC_DECLARE_PARAMETER(LearnerTrainParam) {
|
||||
DMLC_DECLARE_FIELD(seed).set_default(0)
|
||||
@ -109,6 +112,10 @@ struct LearnerTrainParam
|
||||
.describe("maximum row per batch.");
|
||||
DMLC_DECLARE_FIELD(nthread).set_default(0)
|
||||
.describe("Number of threads to use.");
|
||||
DMLC_DECLARE_FIELD(debug_verbose)
|
||||
.set_lower_bound(0)
|
||||
.set_default(0)
|
||||
.describe("flag to print out detailed breakdown of runtime");
|
||||
}
|
||||
};
|
||||
|
||||
@ -170,28 +177,9 @@ class LearnerImpl : public Learner {
|
||||
|
||||
if (tparam.tree_method == 3) {
|
||||
/* histogram-based algorithm */
|
||||
if (cfg_.count("updater") == 0) {
|
||||
LOG(CONSOLE) << "Tree method is selected to be \'hist\', "
|
||||
<< "which uses histogram aggregation for faster training. "
|
||||
<< "Using default sequence of updaters: grow_fast_histmaker,prune";
|
||||
cfg_["updater"] = "grow_fast_histmaker,prune";
|
||||
} else {
|
||||
const std::string first_str = "grow_fast_histmaker";
|
||||
if (first_str.length() <= cfg_["updater"].length()
|
||||
&& std::equal(first_str.begin(), first_str.end(), cfg_["updater"].begin())) {
|
||||
// updater sequence starts with "grow_fast_histmaker"
|
||||
LOG(CONSOLE) << "Tree method is selected to be \'hist\', "
|
||||
<< "which uses histogram aggregation for faster training. "
|
||||
<< "Using custom sequence of updaters: " << cfg_["updater"];
|
||||
} else {
|
||||
// updater sequence does not start with "grow_fast_histmaker"
|
||||
LOG(CONSOLE) << "Tree method is selected to be \'hist\', but the given "
|
||||
<< "sequence of updaters is not compatible; "
|
||||
<< "grow_fast_histmaker must run first. "
|
||||
<< "Using default sequence of updaters: grow_fast_histmaker,prune";
|
||||
cfg_["updater"] = "grow_fast_histmaker,prune";
|
||||
}
|
||||
}
|
||||
LOG(CONSOLE) << "Tree method is selected to be \'hist\', which uses a single updater "
|
||||
<< "grow_fast_histmaker.";
|
||||
cfg_["updater"] = "grow_fast_histmaker";
|
||||
} else if (cfg_.count("updater") == 0) {
|
||||
if (tparam.dsplit == 1) {
|
||||
cfg_["updater"] = "distcol";
|
||||
@ -333,6 +321,7 @@ class LearnerImpl : public Learner {
|
||||
std::string EvalOneIter(int iter,
|
||||
const std::vector<DMatrix*>& data_sets,
|
||||
const std::vector<std::string>& data_names) override {
|
||||
double tstart = dmlc::GetTime();
|
||||
std::ostringstream os;
|
||||
os << '[' << iter << ']'
|
||||
<< std::setiosflags(std::ios::fixed);
|
||||
@ -347,6 +336,10 @@ class LearnerImpl : public Learner {
|
||||
<< ev->Eval(preds_, data_sets[i]->info(), tparam.dsplit == 2);
|
||||
}
|
||||
}
|
||||
|
||||
if (tparam.debug_verbose > 0) {
|
||||
LOG(INFO) << "EvalOneIter(): " << dmlc::GetTime() - tstart << " sec";
|
||||
}
|
||||
return os.str();
|
||||
}
|
||||
|
||||
|
||||
@ -97,44 +97,40 @@ struct EvalAuc : public Metric {
|
||||
// sum statistics
|
||||
bst_float sum_auc = 0.0f;
|
||||
int auc_error = 0;
|
||||
#pragma omp parallel reduction(+:sum_auc)
|
||||
{
|
||||
// each thread takes a local rec
|
||||
std::vector< std::pair<bst_float, unsigned> > rec;
|
||||
#pragma omp for schedule(static)
|
||||
for (bst_omp_uint k = 0; k < ngroup; ++k) {
|
||||
rec.clear();
|
||||
for (unsigned j = gptr[k]; j < gptr[k + 1]; ++j) {
|
||||
rec.push_back(std::make_pair(preds[j], j));
|
||||
}
|
||||
std::sort(rec.begin(), rec.end(), common::CmpFirst);
|
||||
// calculate AUC
|
||||
double sum_pospair = 0.0;
|
||||
double sum_npos = 0.0, sum_nneg = 0.0, buf_pos = 0.0, buf_neg = 0.0;
|
||||
for (size_t j = 0; j < rec.size(); ++j) {
|
||||
const bst_float wt = info.GetWeight(rec[j].second);
|
||||
const bst_float ctr = info.labels[rec[j].second];
|
||||
// keep bucketing predictions in same bucket
|
||||
if (j != 0 && rec[j].first != rec[j - 1].first) {
|
||||
sum_pospair += buf_neg * (sum_npos + buf_pos *0.5);
|
||||
sum_npos += buf_pos;
|
||||
sum_nneg += buf_neg;
|
||||
buf_neg = buf_pos = 0.0f;
|
||||
}
|
||||
buf_pos += ctr * wt;
|
||||
buf_neg += (1.0f - ctr) * wt;
|
||||
}
|
||||
sum_pospair += buf_neg * (sum_npos + buf_pos *0.5);
|
||||
sum_npos += buf_pos;
|
||||
sum_nneg += buf_neg;
|
||||
// check weird conditions
|
||||
if (sum_npos <= 0.0 || sum_nneg <= 0.0) {
|
||||
auc_error = 1;
|
||||
continue;
|
||||
}
|
||||
// this is the AUC
|
||||
sum_auc += sum_pospair / (sum_npos*sum_nneg);
|
||||
// each thread takes a local rec
|
||||
std::vector< std::pair<bst_float, unsigned> > rec;
|
||||
for (bst_omp_uint k = 0; k < ngroup; ++k) {
|
||||
rec.clear();
|
||||
for (unsigned j = gptr[k]; j < gptr[k + 1]; ++j) {
|
||||
rec.push_back(std::make_pair(preds[j], j));
|
||||
}
|
||||
XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst);
|
||||
// calculate AUC
|
||||
double sum_pospair = 0.0;
|
||||
double sum_npos = 0.0, sum_nneg = 0.0, buf_pos = 0.0, buf_neg = 0.0;
|
||||
for (size_t j = 0; j < rec.size(); ++j) {
|
||||
const bst_float wt = info.GetWeight(rec[j].second);
|
||||
const bst_float ctr = info.labels[rec[j].second];
|
||||
// keep bucketing predictions in same bucket
|
||||
if (j != 0 && rec[j].first != rec[j - 1].first) {
|
||||
sum_pospair += buf_neg * (sum_npos + buf_pos *0.5);
|
||||
sum_npos += buf_pos;
|
||||
sum_nneg += buf_neg;
|
||||
buf_neg = buf_pos = 0.0f;
|
||||
}
|
||||
buf_pos += ctr * wt;
|
||||
buf_neg += (1.0f - ctr) * wt;
|
||||
}
|
||||
sum_pospair += buf_neg * (sum_npos + buf_pos *0.5);
|
||||
sum_npos += buf_pos;
|
||||
sum_nneg += buf_neg;
|
||||
// check weird conditions
|
||||
if (sum_npos <= 0.0 || sum_nneg <= 0.0) {
|
||||
auc_error = 1;
|
||||
continue;
|
||||
}
|
||||
// this is the AUC
|
||||
sum_auc += sum_pospair / (sum_npos*sum_nneg);
|
||||
}
|
||||
CHECK(!auc_error)
|
||||
<< "AUC: the dataset only contains pos or neg samples";
|
||||
@ -262,9 +258,9 @@ struct EvalNDCG : public EvalRankList{
|
||||
return sumdcg;
|
||||
}
|
||||
virtual bst_float EvalMetric(std::vector<std::pair<bst_float, unsigned> > &rec) const { // NOLINT(*)
|
||||
std::stable_sort(rec.begin(), rec.end(), common::CmpFirst);
|
||||
XGBOOST_PARALLEL_STABLE_SORT(rec.begin(), rec.end(), common::CmpFirst);
|
||||
bst_float dcg = this->CalcDCG(rec);
|
||||
std::stable_sort(rec.begin(), rec.end(), common::CmpSecond);
|
||||
XGBOOST_PARALLEL_STABLE_SORT(rec.begin(), rec.end(), common::CmpSecond);
|
||||
bst_float idcg = this->CalcDCG(rec);
|
||||
if (idcg == 0.0f) {
|
||||
if (minus_) {
|
||||
|
||||
@ -35,9 +35,12 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
|
||||
int max_leaves;
|
||||
// if using histogram based algorithm, maximum number of bins per feature
|
||||
int max_bin;
|
||||
enum class DataType { uint8 = 1, uint16 = 2, uint32 = 4 };
|
||||
int colmat_dtype;
|
||||
// growing policy
|
||||
enum TreeGrowPolicy { kDepthWise = 0, kLossGuide = 1 };
|
||||
int grow_policy;
|
||||
// flag to print out detailed breakdown of runtime
|
||||
int debug_verbose;
|
||||
//----- the rest parameters are less important ----
|
||||
// minimum amount of hessian(weight) allowed in a child
|
||||
@ -90,9 +93,7 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
|
||||
DMLC_DECLARE_FIELD(debug_verbose)
|
||||
.set_lower_bound(0)
|
||||
.set_default(0)
|
||||
.describe(
|
||||
"Setting verbose flag with a positive value causes the updater "
|
||||
"to print out *detailed* list of tasks and their runtime");
|
||||
.describe("flag to print out detailed breakdown of runtime");
|
||||
DMLC_DECLARE_FIELD(max_depth)
|
||||
.set_lower_bound(0)
|
||||
.set_default(6)
|
||||
@ -111,6 +112,14 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
|
||||
"Tree growing policy. 0: favor splitting at nodes closest to the node, "
|
||||
"i.e. grow depth-wise. 1: favor splitting at nodes with highest loss "
|
||||
"change. (cf. LightGBM)");
|
||||
DMLC_DECLARE_FIELD(colmat_dtype)
|
||||
.set_default(static_cast<int>(DataType::uint32))
|
||||
.add_enum("uint8", static_cast<int>(DataType::uint8))
|
||||
.add_enum("uint16", static_cast<int>(DataType::uint16))
|
||||
.add_enum("uint32", static_cast<int>(DataType::uint32))
|
||||
.describe("Integral data type to be used with columnar data storage."
|
||||
"May carry marginal performance implications. Reserved for "
|
||||
"advanced use");
|
||||
DMLC_DECLARE_FIELD(min_child_weight)
|
||||
.set_lower_bound(0.0f)
|
||||
.set_default(1.0f)
|
||||
|
||||
@ -792,9 +792,6 @@ class DistColMaker : public ColMaker<TStats, TConstraint> {
|
||||
// update position after the tree is pruned
|
||||
builder.UpdatePosition(dmat, *trees[0]);
|
||||
}
|
||||
const int* GetLeafPosition() const override {
|
||||
return builder.GetLeafPosition();
|
||||
}
|
||||
|
||||
private:
|
||||
struct Builder : public ColMaker<TStats, TConstraint>::Builder {
|
||||
@ -951,11 +948,6 @@ class TreeUpdaterSwitch : public TreeUpdater {
|
||||
inner_->Update(gpair, data, trees);
|
||||
}
|
||||
|
||||
const int* GetLeafPosition() const override {
|
||||
CHECK(inner_ != nullptr);
|
||||
return inner_->GetLeafPosition();
|
||||
}
|
||||
|
||||
private:
|
||||
// monotone constraints
|
||||
bool monotone_;
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
#include "../common/sync.h"
|
||||
#include "../common/hist_util.h"
|
||||
#include "../common/row_set.h"
|
||||
#include "../common/column_matrix.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
@ -30,6 +31,8 @@ using xgboost::common::HistCollection;
|
||||
using xgboost::common::RowSetCollection;
|
||||
using xgboost::common::GHistRow;
|
||||
using xgboost::common::GHistBuilder;
|
||||
using xgboost::common::ColumnMatrix;
|
||||
using xgboost::common::Column;
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(updater_fast_hist);
|
||||
|
||||
@ -38,6 +41,11 @@ template<typename TStats, typename TConstraint>
|
||||
class FastHistMaker: public TreeUpdater {
|
||||
public:
|
||||
void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||
// initialize pruner
|
||||
if (!pruner_) {
|
||||
pruner_.reset(TreeUpdater::Create("prune"));
|
||||
}
|
||||
pruner_->Init(args);
|
||||
param.InitAllowUnknown(args);
|
||||
is_gmat_initialized_ = false;
|
||||
}
|
||||
@ -51,6 +59,7 @@ class FastHistMaker: public TreeUpdater {
|
||||
hmat_.Init(dmat, param.max_bin);
|
||||
gmat_.cut = &hmat_;
|
||||
gmat_.Init(dmat);
|
||||
column_matrix_.Init(gmat_, static_cast<xgboost::common::DataType>(param.colmat_dtype));
|
||||
is_gmat_initialized_ = true;
|
||||
if (param.debug_verbose > 0) {
|
||||
LOG(INFO) << "Generating gmat: " << dmlc::GetTime() - tstart << " sec";
|
||||
@ -62,20 +71,31 @@ class FastHistMaker: public TreeUpdater {
|
||||
TConstraint::Init(¶m, dmat->info().num_col);
|
||||
// build tree
|
||||
if (!builder_) {
|
||||
builder_.reset(new Builder(param));
|
||||
builder_.reset(new Builder(param, std::move(pruner_)));
|
||||
}
|
||||
for (size_t i = 0; i < trees.size(); ++i) {
|
||||
builder_->Update(gmat_, gpair, dmat, trees[i]);
|
||||
builder_->Update(gmat_, column_matrix_, gpair, dmat, trees[i]);
|
||||
}
|
||||
param.learning_rate = lr;
|
||||
}
|
||||
|
||||
bool UpdatePredictionCache(const DMatrix* data,
|
||||
std::vector<bst_float>* out_preds) const override {
|
||||
if (!builder_ || param.subsample < 1.0f) {
|
||||
return false;
|
||||
} else {
|
||||
return builder_->UpdatePredictionCache(data, out_preds);
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
// training parameter
|
||||
TrainParam param;
|
||||
// data sketch
|
||||
HistCutMatrix hmat_;
|
||||
GHistIndexMatrix gmat_;
|
||||
// column accessor
|
||||
ColumnMatrix column_matrix_;
|
||||
bool is_gmat_initialized_;
|
||||
|
||||
// data structure
|
||||
@ -115,17 +135,18 @@ class FastHistMaker: public TreeUpdater {
|
||||
struct Builder {
|
||||
public:
|
||||
// constructor
|
||||
explicit Builder(const TrainParam& param) : param(param) {
|
||||
}
|
||||
explicit Builder(const TrainParam& param,
|
||||
std::unique_ptr<TreeUpdater> pruner)
|
||||
: param(param), pruner_(std::move(pruner)),
|
||||
p_last_tree_(nullptr), p_last_fmat_(nullptr) {}
|
||||
// update one tree, growing
|
||||
virtual void Update(const GHistIndexMatrix& gmat,
|
||||
const ColumnMatrix& column_matrix,
|
||||
const std::vector<bst_gpair>& gpair,
|
||||
DMatrix* p_fmat,
|
||||
RegTree* p_tree) {
|
||||
double gstart = dmlc::GetTime();
|
||||
|
||||
std::vector<int> feat_set(p_fmat->info().num_col);
|
||||
std::iota(feat_set.begin(), feat_set.end(), 0);
|
||||
int num_leaves = 0;
|
||||
unsigned timestamp = 0;
|
||||
|
||||
@ -138,14 +159,16 @@ class FastHistMaker: public TreeUpdater {
|
||||
|
||||
tstart = dmlc::GetTime();
|
||||
this->InitData(gmat, gpair, *p_fmat, *p_tree);
|
||||
std::vector<bst_uint> feat_set = feat_index;
|
||||
time_init_data = dmlc::GetTime() - tstart;
|
||||
|
||||
// FIXME(hcho3): this code is broken when param.num_roots > 1. Please fix it
|
||||
CHECK_EQ(p_tree->param.num_roots, 1)
|
||||
<< "tree_method=hist does not support multiple roots at this moment";
|
||||
for (int nid = 0; nid < p_tree->param.num_roots; ++nid) {
|
||||
tstart = dmlc::GetTime();
|
||||
hist_.AddHistRow(nid);
|
||||
builder_.BuildHist(gpair, row_set_collection_[nid], gmat, hist_[nid]);
|
||||
builder_.BuildHist(gpair, row_set_collection_[nid], gmat, feat_set, hist_[nid]);
|
||||
time_build_hist += dmlc::GetTime() - tstart;
|
||||
|
||||
tstart = dmlc::GetTime();
|
||||
@ -171,7 +194,7 @@ class FastHistMaker: public TreeUpdater {
|
||||
(*p_tree)[nid].set_leaf(snode[nid].weight * param.learning_rate);
|
||||
} else {
|
||||
tstart = dmlc::GetTime();
|
||||
this->ApplySplit(nid, gmat, hist_, *p_fmat, p_tree);
|
||||
this->ApplySplit(nid, gmat, column_matrix, hist_, *p_fmat, p_tree);
|
||||
time_apply_split += dmlc::GetTime() - tstart;
|
||||
|
||||
tstart = dmlc::GetTime();
|
||||
@ -180,10 +203,12 @@ class FastHistMaker: public TreeUpdater {
|
||||
hist_.AddHistRow(cleft);
|
||||
hist_.AddHistRow(cright);
|
||||
if (row_set_collection_[cleft].size() < row_set_collection_[cright].size()) {
|
||||
builder_.BuildHist(gpair, row_set_collection_[cleft], gmat, hist_[cleft]);
|
||||
builder_.BuildHist(gpair, row_set_collection_[cleft], gmat, feat_set,
|
||||
hist_[cleft]);
|
||||
builder_.SubtractionTrick(hist_[cright], hist_[cleft], hist_[nid]);
|
||||
} else {
|
||||
builder_.BuildHist(gpair, row_set_collection_[cright], gmat, hist_[cright]);
|
||||
builder_.BuildHist(gpair, row_set_collection_[cright], gmat, feat_set,
|
||||
hist_[cright]);
|
||||
builder_.SubtractionTrick(hist_[cleft], hist_[cright], hist_[nid]);
|
||||
}
|
||||
time_build_hist += dmlc::GetTime() - tstart;
|
||||
@ -225,34 +250,76 @@ class FastHistMaker: public TreeUpdater {
|
||||
snode[nid].stats.SetLeafVec(param, p_tree->leafvec(nid));
|
||||
}
|
||||
|
||||
pruner_->Update(gpair, p_fmat, std::vector<RegTree*>{p_tree});
|
||||
|
||||
if (param.debug_verbose > 0) {
|
||||
double total_time = dmlc::GetTime() - gstart;
|
||||
LOG(INFO) << "\nInitData: "
|
||||
<< std::fixed << std::setw(4) << std::setprecision(2) << time_init_data
|
||||
<< std::fixed << std::setw(6) << std::setprecision(4) << time_init_data
|
||||
<< " (" << std::fixed << std::setw(5) << std::setprecision(2)
|
||||
<< time_init_data / total_time * 100 << "%)\n"
|
||||
<< "InitNewNode: "
|
||||
<< std::fixed << std::setw(4) << std::setprecision(2) << time_init_new_node
|
||||
<< std::fixed << std::setw(6) << std::setprecision(4) << time_init_new_node
|
||||
<< " (" << std::fixed << std::setw(5) << std::setprecision(2)
|
||||
<< time_init_new_node / total_time * 100 << "%)\n"
|
||||
<< "BuildHist: "
|
||||
<< std::fixed << std::setw(4) << std::setprecision(2) << time_build_hist
|
||||
<< "BuildHist: "
|
||||
<< std::fixed << std::setw(6) << std::setprecision(4) << time_build_hist
|
||||
<< " (" << std::fixed << std::setw(5) << std::setprecision(2)
|
||||
<< time_build_hist / total_time * 100 << "%)\n"
|
||||
<< "EvaluateSplit: "
|
||||
<< std::fixed << std::setw(4) << std::setprecision(2) << time_evaluate_split
|
||||
<< std::fixed << std::setw(6) << std::setprecision(4) << time_evaluate_split
|
||||
<< " (" << std::fixed << std::setw(5) << std::setprecision(2)
|
||||
<< time_evaluate_split / total_time * 100 << "%)\n"
|
||||
<< "ApplySplit: "
|
||||
<< std::fixed << std::setw(4) << std::setprecision(2) << time_apply_split
|
||||
<< std::fixed << std::setw(6) << std::setprecision(4) << time_apply_split
|
||||
<< " (" << std::fixed << std::setw(5) << std::setprecision(2)
|
||||
<< time_apply_split / total_time * 100 << "%)\n"
|
||||
<< "========================================\n"
|
||||
<< "Total: "
|
||||
<< std::fixed << std::setw(4) << std::setprecision(2) << total_time;
|
||||
<< std::fixed << std::setw(6) << std::setprecision(4) << total_time;
|
||||
}
|
||||
}
|
||||
|
||||
inline bool UpdatePredictionCache(const DMatrix* data,
|
||||
std::vector<bst_float>* p_out_preds) {
|
||||
std::vector<bst_float>& out_preds = *p_out_preds;
|
||||
|
||||
// p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in
|
||||
// conjunction with Update().
|
||||
if (!p_last_fmat_ || !p_last_tree_ || data != p_last_fmat_) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (leaf_value_cache_.empty()) {
|
||||
leaf_value_cache_.resize(p_last_tree_->param.num_nodes,
|
||||
std::numeric_limits<float>::infinity());
|
||||
}
|
||||
|
||||
CHECK_GT(out_preds.size(), 0);
|
||||
|
||||
for (const RowSetCollection::Elem rowset : row_set_collection_) {
|
||||
if (rowset.begin != nullptr && rowset.end != nullptr) {
|
||||
int nid = rowset.node_id;
|
||||
bst_float leaf_value;
|
||||
// if a node is marked as deleted by the pruner, traverse upward to locate
|
||||
// a non-deleted leaf.
|
||||
if ((*p_last_tree_)[nid].is_deleted()) {
|
||||
while ((*p_last_tree_)[nid].is_deleted()) {
|
||||
nid = (*p_last_tree_)[nid].parent();
|
||||
}
|
||||
CHECK((*p_last_tree_)[nid].is_leaf());
|
||||
}
|
||||
leaf_value = (*p_last_tree_)[nid].leaf_value();
|
||||
|
||||
for (const bst_uint* it = rowset.begin; it < rowset.end; ++it) {
|
||||
out_preds[*it] += leaf_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
// initialize temp data structure
|
||||
inline void InitData(const GHistIndexMatrix& gmat,
|
||||
@ -273,10 +340,13 @@ class FastHistMaker: public TreeUpdater {
|
||||
{
|
||||
// initialize the row set
|
||||
row_set_collection_.Clear();
|
||||
// clear local prediction cache
|
||||
leaf_value_cache_.clear();
|
||||
// initialize histogram collection
|
||||
size_t nbins = gmat.cut->row_ptr.back();
|
||||
hist_.Init(nbins);
|
||||
|
||||
// initialize histogram builder
|
||||
#pragma omp parallel
|
||||
{
|
||||
this->nthread = omp_get_num_threads();
|
||||
@ -305,11 +375,21 @@ class FastHistMaker: public TreeUpdater {
|
||||
}
|
||||
|
||||
{
|
||||
// store a pointer to the tree
|
||||
p_last_tree_ = &tree;
|
||||
// store a pointer to training data
|
||||
p_last_fmat_ = &fmat;
|
||||
// initialize feature index
|
||||
unsigned ncol = static_cast<unsigned>(info.num_col);
|
||||
feat_index.clear();
|
||||
for (unsigned i = 0; i < ncol; ++i) {
|
||||
feat_index.push_back(i);
|
||||
if (data_layout_ == kDenseDataOneBased) {
|
||||
for (unsigned i = 1; i < ncol; ++i) {
|
||||
feat_index.push_back(i);
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < ncol; ++i) {
|
||||
feat_index.push_back(i);
|
||||
}
|
||||
}
|
||||
unsigned n = static_cast<unsigned>(param.colsample_bytree * feat_index.size());
|
||||
std::shuffle(feat_index.begin(), feat_index.end(), common::GlobalRandom());
|
||||
@ -373,22 +453,48 @@ class FastHistMaker: public TreeUpdater {
|
||||
const HistCollection& hist,
|
||||
const DMatrix& fmat,
|
||||
const RegTree& tree,
|
||||
const std::vector<int>& feat_set) {
|
||||
const std::vector<bst_uint>& feat_set) {
|
||||
// start enumeration
|
||||
const MetaInfo& info = fmat.info();
|
||||
for (int fid : feat_set) {
|
||||
const bst_omp_uint nfeature = feat_set.size();
|
||||
const bst_omp_uint nthread = static_cast<bst_omp_uint>(this->nthread);
|
||||
best_split_tloc_.resize(nthread);
|
||||
#pragma omp parallel for schedule(static) num_threads(nthread)
|
||||
for (bst_omp_uint tid = 0; tid < nthread; ++tid) {
|
||||
best_split_tloc_[tid] = snode[nid].best;
|
||||
}
|
||||
#pragma omp parallel for schedule(dynamic) num_threads(nthread)
|
||||
for (bst_omp_uint i = 0; i < nfeature; ++i) {
|
||||
const bst_uint fid = feat_set[i];
|
||||
const unsigned tid = omp_get_thread_num();
|
||||
this->EnumerateSplit(-1, gmat, hist[nid], snode[nid], constraints_[nid], info,
|
||||
&snode[nid].best, fid);
|
||||
&best_split_tloc_[tid], fid);
|
||||
this->EnumerateSplit(+1, gmat, hist[nid], snode[nid], constraints_[nid], info,
|
||||
&snode[nid].best, fid);
|
||||
&best_split_tloc_[tid], fid);
|
||||
}
|
||||
for (unsigned tid = 0; tid < nthread; ++tid) {
|
||||
snode[nid].best.Update(best_split_tloc_[tid]);
|
||||
}
|
||||
}
|
||||
|
||||
inline void ApplySplit(int nid,
|
||||
const GHistIndexMatrix& gmat,
|
||||
const ColumnMatrix& column_matrix,
|
||||
const HistCollection& hist,
|
||||
const DMatrix& fmat,
|
||||
RegTree* p_tree) {
|
||||
XGBOOST_TYPE_SWITCH(column_matrix.dtype, {
|
||||
ApplySplit_<DType>(nid, gmat, column_matrix, hist, fmat, p_tree);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void ApplySplit_(int nid,
|
||||
const GHistIndexMatrix& gmat,
|
||||
const ColumnMatrix& column_matrix,
|
||||
const HistCollection& hist,
|
||||
const DMatrix& fmat,
|
||||
RegTree* p_tree) {
|
||||
// TODO(hcho3): support feature sampling by levels
|
||||
|
||||
/* 1. Create child nodes */
|
||||
@ -422,66 +528,89 @@ class FastHistMaker: public TreeUpdater {
|
||||
}
|
||||
|
||||
const auto& rowset = row_set_collection_[nid];
|
||||
if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) {
|
||||
/* specialized code for dense data */
|
||||
const size_t column_offset = (data_layout_ == kDenseDataOneBased) ? (fid - 1): fid;
|
||||
ApplySplitDenseData(rowset, gmat, &row_split_tloc_, column_offset, split_cond);
|
||||
|
||||
Column<T> column = column_matrix.GetColumn<T>(fid);
|
||||
if (column.type == xgboost::common::kDenseColumn) {
|
||||
ApplySplitDenseData(rowset, gmat, &row_split_tloc_, column, split_cond,
|
||||
default_left);
|
||||
} else {
|
||||
ApplySplitSparseData(rowset, gmat, &row_split_tloc_, lower_bound, upper_bound,
|
||||
split_cond, default_left);
|
||||
ApplySplitSparseData(rowset, gmat, &row_split_tloc_, column, lower_bound,
|
||||
upper_bound, split_cond, default_left);
|
||||
}
|
||||
|
||||
row_set_collection_.AddSplit(
|
||||
nid, row_split_tloc_, (*p_tree)[nid].cleft(), (*p_tree)[nid].cright());
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
inline void ApplySplitDenseData(const RowSetCollection::Elem rowset,
|
||||
const GHistIndexMatrix& gmat,
|
||||
std::vector<RowSetCollection::Split>* p_row_split_tloc,
|
||||
size_t column_offset,
|
||||
bst_uint split_cond) {
|
||||
const Column<T>& column,
|
||||
bst_uint split_cond,
|
||||
bool default_left) {
|
||||
std::vector<RowSetCollection::Split>& row_split_tloc = *p_row_split_tloc;
|
||||
const int K = 8; // loop unrolling factor
|
||||
const bst_omp_uint nrows = rowset.end - rowset.begin;
|
||||
const bst_omp_uint rest = nrows % K;
|
||||
|
||||
#pragma omp parallel for num_threads(nthread) schedule(static)
|
||||
for (bst_omp_uint i = 0; i < nrows - rest; i += K) {
|
||||
bst_uint rid[K];
|
||||
unsigned rbin[K];
|
||||
bst_uint tid = omp_get_thread_num();
|
||||
const bst_uint tid = omp_get_thread_num();
|
||||
auto& left = row_split_tloc[tid].left;
|
||||
auto& right = row_split_tloc[tid].right;
|
||||
bst_uint rid[K];
|
||||
T rbin[K];
|
||||
for (int k = 0; k < K; ++k) {
|
||||
rid[k] = rowset.begin[i + k];
|
||||
}
|
||||
for (int k = 0; k < K; ++k) {
|
||||
rbin[k] = gmat[rid[k]].index[column_offset];
|
||||
rbin[k] = column.index[rid[k]];
|
||||
}
|
||||
for (int k = 0; k < K; ++k) {
|
||||
if (rbin[k] <= split_cond) {
|
||||
left.push_back(rid[k]);
|
||||
if (rbin[k] == std::numeric_limits<T>::max()) { // missing value
|
||||
if (default_left) {
|
||||
left.push_back(rid[k]);
|
||||
} else {
|
||||
right.push_back(rid[k]);
|
||||
}
|
||||
} else {
|
||||
right.push_back(rid[k]);
|
||||
if (rbin[k] + column.index_base <= split_cond) {
|
||||
left.push_back(rid[k]);
|
||||
} else {
|
||||
right.push_back(rid[k]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (bst_omp_uint i = nrows - rest; i < nrows; ++i) {
|
||||
auto& left = row_split_tloc[nthread-1].left;
|
||||
auto& right = row_split_tloc[nthread-1].right;
|
||||
const bst_uint rid = rowset.begin[i];
|
||||
const unsigned rbin = gmat[rid].index[column_offset];
|
||||
if (rbin <= split_cond) {
|
||||
row_split_tloc[0].left.push_back(rid);
|
||||
const T rbin = column.index[rid];
|
||||
if (rbin == std::numeric_limits<T>::max()) { // missing value
|
||||
if (default_left) {
|
||||
left.push_back(rid);
|
||||
} else {
|
||||
right.push_back(rid);
|
||||
}
|
||||
} else {
|
||||
row_split_tloc[0].right.push_back(rid);
|
||||
if (rbin + column.index_base <= split_cond) {
|
||||
left.push_back(rid);
|
||||
} else {
|
||||
right.push_back(rid);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline void ApplySplitSparseData(const RowSetCollection::Elem rowset,
|
||||
const GHistIndexMatrix& gmat,
|
||||
std::vector<RowSetCollection::Split>* p_row_split_tloc,
|
||||
bst_uint lower_bound,
|
||||
bst_uint upper_bound,
|
||||
bst_uint split_cond,
|
||||
bool default_left) {
|
||||
inline void ApplySplitSparseDataOld(const RowSetCollection::Elem rowset,
|
||||
const GHistIndexMatrix& gmat,
|
||||
std::vector<RowSetCollection::Split>* p_row_split_tloc,
|
||||
bst_uint lower_bound,
|
||||
bst_uint upper_bound,
|
||||
bst_uint split_cond,
|
||||
bool default_left) {
|
||||
std::vector<RowSetCollection::Split>& row_split_tloc = *p_row_split_tloc;
|
||||
const int K = 8; // loop unrolling factor
|
||||
const bst_omp_uint nrows = rowset.end - rowset.begin;
|
||||
@ -541,6 +670,73 @@ class FastHistMaker: public TreeUpdater {
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
inline void ApplySplitSparseData(const RowSetCollection::Elem rowset,
|
||||
const GHistIndexMatrix& gmat,
|
||||
std::vector<RowSetCollection::Split>* p_row_split_tloc,
|
||||
const Column<T>& column,
|
||||
bst_uint lower_bound,
|
||||
bst_uint upper_bound,
|
||||
bst_uint split_cond,
|
||||
bool default_left) {
|
||||
std::vector<RowSetCollection::Split>& row_split_tloc = *p_row_split_tloc;
|
||||
const bst_omp_uint nrows = rowset.end - rowset.begin;
|
||||
|
||||
#pragma omp parallel num_threads(nthread)
|
||||
{
|
||||
const bst_uint tid = omp_get_thread_num();
|
||||
const bst_omp_uint ibegin = tid * nrows / nthread;
|
||||
const bst_omp_uint iend = (tid + 1) * nrows / nthread;
|
||||
// search first nonzero row with index >= rowset[ibegin]
|
||||
const uint32_t* p = std::lower_bound(column.row_ind,
|
||||
column.row_ind + column.len,
|
||||
rowset.begin[ibegin]);
|
||||
|
||||
auto& left = row_split_tloc[tid].left;
|
||||
auto& right = row_split_tloc[tid].right;
|
||||
if (p != column.row_ind + column.len && *p <= rowset.begin[iend - 1]) {
|
||||
bst_omp_uint cursor = p - column.row_ind;
|
||||
|
||||
for (bst_omp_uint i = ibegin; i < iend; ++i) {
|
||||
const bst_uint rid = rowset.begin[i];
|
||||
while (cursor < column.len
|
||||
&& column.row_ind[cursor] < rid
|
||||
&& column.row_ind[cursor] <= rowset.begin[iend - 1]) {
|
||||
++cursor;
|
||||
}
|
||||
if (cursor < column.len && column.row_ind[cursor] == rid) {
|
||||
const T rbin = column.index[cursor];
|
||||
if (rbin + column.index_base <= split_cond) {
|
||||
left.push_back(rid);
|
||||
} else {
|
||||
right.push_back(rid);
|
||||
}
|
||||
++cursor;
|
||||
} else {
|
||||
// missing value
|
||||
if (default_left) {
|
||||
left.push_back(rid);
|
||||
} else {
|
||||
right.push_back(rid);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else { // all rows in [ibegin, iend) have missing values
|
||||
if (default_left) {
|
||||
for (bst_omp_uint i = ibegin; i < iend; ++i) {
|
||||
const bst_uint rid = rowset.begin[i];
|
||||
left.push_back(rid);
|
||||
}
|
||||
} else {
|
||||
for (bst_omp_uint i = ibegin; i < iend; ++i) {
|
||||
const bst_uint rid = rowset.begin[i];
|
||||
right.push_back(rid);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline void InitNewNode(int nid,
|
||||
const GHistIndexMatrix& gmat,
|
||||
const std::vector<bst_gpair>& gpair,
|
||||
@ -600,7 +796,7 @@ class FastHistMaker: public TreeUpdater {
|
||||
const TConstraint& constraint,
|
||||
const MetaInfo& info,
|
||||
SplitEntry* p_best,
|
||||
int fid) {
|
||||
bst_uint fid) {
|
||||
CHECK(d_step == +1 || d_step == -1);
|
||||
|
||||
// aliases
|
||||
@ -695,13 +891,23 @@ class FastHistMaker: public TreeUpdater {
|
||||
RowSetCollection row_set_collection_;
|
||||
// the temp space for split
|
||||
std::vector<RowSetCollection::Split> row_split_tloc_;
|
||||
std::vector<SplitEntry> best_split_tloc_;
|
||||
/*! \brief TreeNode Data: statistics for each constructed node */
|
||||
std::vector<NodeEntry> snode;
|
||||
/*! \brief culmulative histogram of gradients. */
|
||||
HistCollection hist_;
|
||||
/*! \brief feature with least # of bins. to be used for dense specialization
|
||||
of InitNewNode() */
|
||||
size_t fid_least_bins_;
|
||||
/*! \brief local prediction cache; maps node id to leaf value */
|
||||
std::vector<float> leaf_value_cache_;
|
||||
|
||||
GHistBuilder builder_;
|
||||
std::unique_ptr<TreeUpdater> pruner_;
|
||||
|
||||
// back pointers to tree and data matrix
|
||||
const RegTree* p_last_tree_;
|
||||
const DMatrix* p_last_fmat_;
|
||||
|
||||
// constraint value
|
||||
std::vector<TConstraint> constraints_;
|
||||
@ -716,6 +922,7 @@ class FastHistMaker: public TreeUpdater {
|
||||
};
|
||||
|
||||
std::unique_ptr<Builder> builder_;
|
||||
std::unique_ptr<TreeUpdater> pruner_;
|
||||
};
|
||||
|
||||
XGBOOST_REGISTER_TREE_UPDATER(FastHistMaker, "grow_fast_histmaker")
|
||||
|
||||
@ -15,6 +15,32 @@ class TestFastHist(unittest.TestCase):
|
||||
except:
|
||||
from sklearn.cross_validation import train_test_split
|
||||
|
||||
# regression test --- hist must be same as exact on all-categorial data
|
||||
dpath = 'demo/data/'
|
||||
ag_dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||
ag_dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
|
||||
ag_param = {'max_depth': 2,
|
||||
'tree_method': 'exact',
|
||||
'eta': 1,
|
||||
'silent': 1,
|
||||
'objective': 'binary:logistic',
|
||||
'eval_metric': 'auc'}
|
||||
ag_param2 = {'max_depth': 2,
|
||||
'tree_method': 'hist',
|
||||
'eta': 1,
|
||||
'silent': 1,
|
||||
'objective': 'binary:logistic',
|
||||
'eval_metric': 'auc'}
|
||||
ag_res = {}
|
||||
ag_res2 = {}
|
||||
|
||||
xgb.train(ag_param, ag_dtrain, 10, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
|
||||
evals_result=ag_res)
|
||||
xgb.train(ag_param2, ag_dtrain, 10, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
|
||||
evals_result=ag_res2)
|
||||
assert ag_res['train']['auc'] == ag_res2['train']['auc']
|
||||
assert ag_res['test']['auc'] == ag_res2['test']['auc']
|
||||
|
||||
digits = load_digits(2)
|
||||
X = digits['data']
|
||||
y = digits['target']
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user