diff --git a/amalgamation/xgboost-all0.cc b/amalgamation/xgboost-all0.cc index 9c22d837f..efa43fd00 100644 --- a/amalgamation/xgboost-all0.cc +++ b/amalgamation/xgboost-all0.cc @@ -42,6 +42,7 @@ #include "../src/tree/tree_model.cc" #include "../src/tree/tree_updater.cc" #include "../src/tree/updater_colmaker.cc" +#include "../src/tree/updater_fast_hist.cc" #include "../src/tree/updater_prune.cc" #include "../src/tree/updater_refresh.cc" #include "../src/tree/updater_sync.cc" @@ -52,6 +53,7 @@ #include "../src/learner.cc" #include "../src/logging.cc" #include "../src/common/common.cc" +#include "../src/common/hist_util.cc" // c_api #include "../src/c_api/c_api.cc" diff --git a/include/xgboost/base.h b/include/xgboost/base.h index 985121791..87c0ef4ae 100644 --- a/include/xgboost/base.h +++ b/include/xgboost/base.h @@ -39,6 +39,15 @@ #define XGBOOST_CUSTOMIZE_GLOBAL_PRNG XGBOOST_STRICT_R_MODE #endif +/*! + * \brief Check if alignas(*) keyword is supported. (g++ 4.8 or higher) + */ +#if defined(__GNUC__) && ((__GNUC__ == 4 && __GNUC_MINOR__ >= 8) || __GNUC__ > 4) +#define XGBOOST_ALIGNAS(X) alignas(X) +#else +#define XGBOOST_ALIGNAS(X) +#endif + /*! \brief namespace of xgboo st*/ namespace xgboost { /*! diff --git a/include/xgboost/data.h b/include/xgboost/data.h index 61bd5176a..42596f344 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -127,6 +127,7 @@ struct SparseBatch { /*! \brief length of the instance */ bst_uint length; /*! \brief constructor */ + Inst() : data(0), length(0) {} Inst(const Entry *data, bst_uint length) : data(data), length(length) {} /*! \brief get i-th pair in the sparse vector*/ inline const Entry& operator[](size_t i) const { diff --git a/src/common/hist_util.cc b/src/common/hist_util.cc new file mode 100644 index 000000000..fa61bfc79 --- /dev/null +++ b/src/common/hist_util.cc @@ -0,0 +1,227 @@ +/*! 
+ * Copyright 2017 by Contributors + * \file hist_util.cc + * \brief Utilities to store histograms + * \author Philip Cho, Tianqi Chen + */ +#include +#include +#include "./sync.h" +#include "./hist_util.h" +#include "./quantile.h" + +namespace xgboost { +namespace common { + +void HistCutMatrix::Init(DMatrix* p_fmat, size_t max_num_bins) { + typedef common::WXQuantileSketch WXQSketch; + const MetaInfo& info = p_fmat->info(); + + // safe factor for better accuracy + const int kFactor = 8; + std::vector sketchs; + + int nthread; + #pragma omp parallel + { + nthread = omp_get_num_threads(); + } + nthread = std::max(nthread / 2, 1); + + unsigned nstep = (info.num_col + nthread - 1) / nthread; + unsigned ncol = static_cast(info.num_col); + sketchs.resize(info.num_col); + for (auto& s : sketchs) { + s.Init(info.num_row, 1.0 / (max_num_bins * kFactor)); + } + + dmlc::DataIter* iter = p_fmat->RowIterator(); + iter->BeforeFirst(); + while (iter->Next()) { + const RowBatch& batch = iter->Value(); + #pragma omp parallel num_threads(nthread) + { + CHECK_EQ(nthread, omp_get_num_threads()); + unsigned tid = static_cast(omp_get_thread_num()); + unsigned begin = std::min(nstep * tid, ncol); + unsigned end = std::min(nstep * (tid + 1), ncol); + for (size_t i = 0; i < batch.size; ++i) { // NOLINT(*) + bst_uint ridx = static_cast(batch.base_rowid + i); + RowBatch::Inst inst = batch[i]; + for (bst_uint j = 0; j < inst.length; ++j) { + if (inst[j].index >= begin && inst[j].index < end) { + sketchs[inst[j].index].Push(inst[j].fvalue, info.GetWeight(ridx)); + } + } + } + } + } + + // gather the histogram data + rabit::SerializeReducer sreducer; + std::vector summary_array; + summary_array.resize(sketchs.size()); + for (size_t i = 0; i < sketchs.size(); ++i) { + WXQSketch::SummaryContainer out; + sketchs[i].GetSummary(&out); + summary_array[i].Reserve(max_num_bins * kFactor); + summary_array[i].SetPrune(out, max_num_bins * kFactor); + } + size_t nbytes = 
WXQSketch::SummaryContainer::CalcMemCost(max_num_bins * kFactor); + sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size()); + + this->min_val.resize(info.num_col); + row_ptr.push_back(0); + for (size_t fid = 0; fid < summary_array.size(); ++fid) { + WXQSketch::SummaryContainer a; + a.Reserve(max_num_bins); + a.SetPrune(summary_array[fid], max_num_bins); + const bst_float mval = a.data[0].value; + this->min_val[fid] = mval - fabs(mval); + if (a.size > 1 && a.size <= 16) { + /* specialized code for categorical / ordinal data -- use midpoints */ + for (size_t i = 1; i < a.size; ++i) { + bst_float cpt = (a.data[i].value + a.data[i - 1].value) / 2.0; + if (i == 1 || cpt > cut.back()) { + cut.push_back(cpt); + } + } + } else { + for (size_t i = 2; i < a.size; ++i) { + bst_float cpt = a.data[i - 1].value; + if (i == 2 || cpt > cut.back()) { + cut.push_back(cpt); + } + } + } + // push a value that is greater than anything + if (a.size != 0) { + bst_float cpt = a.data[a.size - 1].value; + // this must be bigger than last value in a scale + bst_float last = cpt + fabs(cpt); + cut.push_back(last); + } + row_ptr.push_back(cut.size()); + } +} + + +void GHistIndexMatrix::Init(DMatrix* p_fmat) { + CHECK(cut != nullptr); + dmlc::DataIter* iter = p_fmat->RowIterator(); + hit_count.resize(cut->row_ptr.back(), 0); + + int nthread; + #pragma omp parallel + { + nthread = omp_get_num_threads(); + } + nthread = std::max(nthread / 2, 1); + + iter->BeforeFirst(); + row_ptr.push_back(0); + while (iter->Next()) { + const RowBatch& batch = iter->Value(); + size_t rbegin = row_ptr.size() - 1; + for (size_t i = 0; i < batch.size; ++i) { + row_ptr.push_back(batch[i].length + row_ptr.back()); + } + index.resize(row_ptr.back()); + + CHECK_GT(cut->cut.size(), 0); + CHECK_EQ(cut->row_ptr.back(), cut->cut.size()); + + omp_ulong bsize = static_cast(batch.size); + #pragma omp parallel for num_threads(nthread) schedule(static) + for (omp_ulong i = 0; i < bsize; ++i) { // NOLINT(*) + 
size_t ibegin = row_ptr[rbegin + i]; + size_t iend = row_ptr[rbegin + i + 1]; + RowBatch::Inst inst = batch[i]; + CHECK_EQ(ibegin + inst.length, iend); + for (bst_uint j = 0; j < inst.length; ++j) { + unsigned fid = inst[j].index; + auto cbegin = cut->cut.begin() + cut->row_ptr[fid]; + auto cend = cut->cut.begin() + cut->row_ptr[fid + 1]; + CHECK(cbegin != cend); + auto it = std::upper_bound(cbegin, cend, inst[j].fvalue); + if (it == cend) it = cend - 1; + unsigned idx = static_cast(it - cut->cut.begin()); + index[ibegin + j] = idx; + } + std::sort(index.begin() + ibegin, index.begin() + iend); + } + } +} + +void GHistBuilder::BuildHist(const std::vector& gpair, + const RowSetCollection::Elem row_indices, + const GHistIndexMatrix& gmat, + GHistRow hist) { + CHECK(!data_.empty()) << "GHistBuilder must be initialized"; + CHECK_EQ(data_.size(), nbins_ * nthread_) << "invalid dimensions for temp buffer"; + + std::fill(data_.begin(), data_.end(), GHistEntry()); + + const int K = 8; // loop unrolling factor + const bst_omp_uint nthread = static_cast(this->nthread_); + const bst_omp_uint nrows = row_indices.end - row_indices.begin; + const bst_omp_uint rest = nrows % K; + + #pragma omp parallel for num_threads(nthread) schedule(static) + for (bst_omp_uint i = 0; i < nrows - rest; i += K) { + const bst_omp_uint tid = omp_get_thread_num(); + const size_t off = tid * nbins_; + bst_uint rid[K]; + bst_gpair stat[K]; + size_t ibegin[K], iend[K]; + for (int k = 0; k < K; ++k) { + rid[k] = row_indices.begin[i + k]; + } + for (int k = 0; k < K; ++k) { + stat[k] = gpair[rid[k]]; + } + for (int k = 0; k < K; ++k) { + ibegin[k] = static_cast(gmat.row_ptr[rid[k]]); + iend[k] = static_cast(gmat.row_ptr[rid[k] + 1]); + } + for (int k = 0; k < K; ++k) { + for (size_t j = ibegin[k]; j < iend[k]; ++j) { + const size_t bin = gmat.index[j]; + data_[off + bin].Add(stat[k]); + } + } + } + for (bst_omp_uint i = nrows - rest; i < nrows; ++i) { + const bst_uint rid = row_indices.begin[i]; + const 
bst_gpair stat = gpair[rid]; + const size_t ibegin = static_cast(gmat.row_ptr[rid]); + const size_t iend = static_cast(gmat.row_ptr[rid + 1]); + for (size_t j = ibegin; j < iend; ++j) { + const size_t bin = gmat.index[j]; + data_[bin].Add(stat); + } + } + + /* reduction */ + const bst_omp_uint nbins = static_cast(nbins_); + #pragma omp parallel for num_threads(nthread) schedule(static) + for (bst_omp_uint bin_id = 0; bin_id < nbins; ++bin_id) { + for (bst_omp_uint tid = 0; tid < nthread; ++tid) { + hist.begin[bin_id].Add(data_[tid * nbins_ + bin_id]); + } + } +} + +void GHistBuilder::SubtractionTrick(GHistRow self, + GHistRow sibling, + GHistRow parent) { + const bst_omp_uint nthread = static_cast(this->nthread_); + const bst_omp_uint nbins = static_cast(nbins_); + #pragma omp parallel for num_threads(nthread) schedule(static) + for (bst_omp_uint bin_id = 0; bin_id < nbins; ++bin_id) { + self.begin[bin_id].SetSubtract(parent.begin[bin_id], sibling.begin[bin_id]); + } +} + +} // namespace common +} // namespace xgboost diff --git a/src/common/hist_util.h b/src/common/hist_util.h new file mode 100644 index 000000000..6dce1d7c8 --- /dev/null +++ b/src/common/hist_util.h @@ -0,0 +1,214 @@ +/*! + * Copyright 2017 by Contributors + * \file hist_util.h + * \brief Utility for fast histogram aggregation + * \author Philip Cho, Tianqi Chen + */ +#ifndef XGBOOST_COMMON_HIST_UTIL_H_ +#define XGBOOST_COMMON_HIST_UTIL_H_ + +#include +#include +#include +#include "row_set.h" + +namespace xgboost { +namespace common { + +/*! \brief sums of gradient statistics corresponding to a histogram bin */ +struct GHistEntry { + /*! \brief sum of first-order gradient statistics */ + double sum_grad; + /*! \brief sum of second-order gradient statistics */ + double sum_hess; + + GHistEntry() : sum_grad(0), sum_hess(0) {} + + /*! \brief add a bst_gpair to the sum */ + inline void Add(const bst_gpair& e) { + sum_grad += e.grad; + sum_hess += e.hess; + } + + /*! 
\brief add a GHistEntry to the sum */ + inline void Add(const GHistEntry& e) { + sum_grad += e.sum_grad; + sum_hess += e.sum_hess; + } + + /*! \brief set sum to be difference of two GHistEntry's */ + inline void SetSubtract(const GHistEntry& a, const GHistEntry& b) { + sum_grad = a.sum_grad - b.sum_grad; + sum_hess = a.sum_hess - b.sum_hess; + } +}; + + +/*! \brief Cut configuration for one feature */ +struct HistCutUnit { + /*! \brief the index pointer of each histunit */ + const bst_float* cut; + /*! \brief number of cutting point, containing the maximum point */ + size_t size; + // default constructor + HistCutUnit() {} + // constructor + HistCutUnit(const bst_float* cut, unsigned size) + : cut(cut), size(size) {} +}; + +/*! \brief cut configuration for all the features */ +struct HistCutMatrix { + /*! \brief actual unit pointer */ + std::vector row_ptr; + /*! \brief minimum value of each feature */ + std::vector min_val; + /*! \brief the cut field */ + std::vector cut; + /*! \brief Get histogram bound for fid */ + inline HistCutUnit operator[](unsigned fid) const { + return HistCutUnit(dmlc::BeginPtr(cut) + row_ptr[fid], + row_ptr[fid + 1] - row_ptr[fid]); + } + // create histogram cut matrix given statistics from data + // using approximate quantile sketch approach + void Init(DMatrix* p_fmat, size_t max_num_bins); +}; + + +/*! + * \brief A single row in global histogram index. + * Directly represent the global index in the histogram entry. + */ +struct GHistIndexRow { + /*! \brief The index of the histogram */ + const unsigned* index; + /*! \brief The size of the histogram */ + unsigned size; + GHistIndexRow() {} + GHistIndexRow(const unsigned* index, unsigned size) + : index(index), size(size) {} +}; + +/*! + * \brief preprocessed global index matrix, in CSR format + * Transform floating values to integer index in histogram + * This is a global histogram index. + */ +struct GHistIndexMatrix { + /*! \brief row pointer */ + std::vector row_ptr; + /*! 
\brief The index data */ + std::vector index; + /*! \brief hit count of each index */ + std::vector hit_count; + /*! \brief optional remap index from outer row_id -> internal row_id*/ + std::vector remap_index; + /*! \brief The corresponding cuts */ + const HistCutMatrix* cut; + // Create a global histogram matrix, given cut + void Init(DMatrix* p_fmat); + // build remap + void Remap(); + // get i-th row + inline GHistIndexRow operator[](bst_uint i) const { + return GHistIndexRow(&index[0] + row_ptr[i], row_ptr[i + 1] - row_ptr[i]); + } +}; + +/*! + * \brief histogram of gradient statistics for a single node. + * Consists of multiple GHistEntry's, each entry showing total gradient statistics + * for that particular bin + * Uses global bin id so as to represent all features simultaneously + */ +struct GHistRow { + /*! \brief base pointer to first entry */ + GHistEntry* begin; + /*! \brief number of entries */ + unsigned size; + + GHistRow() {} + GHistRow(GHistEntry* begin, unsigned size) + : begin(begin), size(size) {} +}; + +/*! + * \brief histogram of gradient statistics for multiple nodes + */ +class HistCollection { + public: + // access histogram for i-th node + inline GHistRow operator[](bst_uint nid) const { + const size_t kMax = std::numeric_limits::max(); + CHECK_NE(row_ptr_[nid], kMax); + return GHistRow(const_cast(dmlc::BeginPtr(data_) + row_ptr_[nid]), nbins_); + } + + // have we computed a histogram for i-th node? 
+ inline bool RowExists(bst_uint nid) const { + const size_t kMax = std::numeric_limits::max(); + return (nid < row_ptr_.size() && row_ptr_[nid] != kMax); + } + + // initialize histogram collection + inline void Init(size_t nbins) { + nbins_ = nbins; + row_ptr_.clear(); + data_.clear(); + } + + // create an empty histogram for i-th node + inline void AddHistRow(bst_uint nid) { + const size_t kMax = std::numeric_limits::max(); + if (nid >= row_ptr_.size()) { + row_ptr_.resize(nid + 1, kMax); + } + CHECK_EQ(row_ptr_[nid], kMax); + + row_ptr_[nid] = data_.size(); + data_.resize(data_.size() + nbins_); + } + + private: + /*! \brief number of all bins over all features */ + size_t nbins_; + + std::vector data_; + + /*! \brief row_ptr_[nid] locates bin for histogram of node nid */ + std::vector row_ptr_; +}; + +/*! + * \brief builder for histograms of gradient statistics + */ +class GHistBuilder { + public: + // initialize builder + inline void Init(size_t nthread, size_t nbins) { + nthread_ = nthread; + nbins_ = nbins; + data_.resize(nthread * nbins, GHistEntry()); + } + + // construct a histogram via histogram aggregation + void BuildHist(const std::vector& gpair, + const RowSetCollection::Elem row_indices, + const GHistIndexMatrix& gmat, + GHistRow hist); + // construct a histogram via subtraction trick + void SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent); + + private: + /*! \brief number of threads for parallel computation */ + size_t nthread_; + /*! 
\brief number of all bins over all features */ + size_t nbins_; + std::vector data_; +}; + + +} // namespace common +} // namespace xgboost +#endif // XGBOOST_COMMON_HIST_UTIL_H_ diff --git a/src/common/quantile.h b/src/common/quantile.h index 9c427470f..409279bd9 100644 --- a/src/common/quantile.h +++ b/src/common/quantile.h @@ -348,10 +348,12 @@ struct WXQSummary : public WQSummary { this->CopyFrom(src); return; } RType begin = src.data[0].rmax; - size_t n = maxsize - 1, nbig = 0; + // n is the number of points excluding the min/max points + size_t n = maxsize - 2, nbig = 0; + // this is the range of the data excluding the min/max points RType range = src.data[src.size - 1].rmin - begin; // prune off zero weights - if (range == 0.0f) { + if (range == 0.0f || maxsize <= 2) { // special case, contain only two effective data pts this->data[0] = src.data[0]; this->data[1] = src.data[src.size - 1]; @@ -360,16 +362,21 @@ struct WXQSummary : public WQSummary { } else { range = std::max(range, static_cast(1e-3f)); } + // Get a big enough chunk size, bigger than range / n + // (multiply by 2 is a safe factor) const RType chunk = 2 * range / n; // minimized range RType mrange = 0; { // first scan, grab all the big chunk - // moving block index + // moving block index, excluding the two ends. size_t bid = 0; - for (size_t i = 1; i < src.size; ++i) { + for (size_t i = 1; i < src.size - 1; ++i) { + // detect big chunk data point in the middle + // always save these data points. 
if (CheckLarge(src.data[i], chunk)) { if (bid != i - 1) { + // accumulate the range of the rest points mrange += src.data[i].rmax_prev() - src.data[bid].rmin_next(); } bid = i; ++nbig; @@ -379,17 +386,18 @@ struct WXQSummary : public WQSummary { mrange += src.data[src.size-1].rmax_prev() - src.data[bid].rmin_next(); } } - if (nbig >= n - 1) { + // assert: there cannot be more than n big data points + if (nbig >= n) { // see what was the case LOG(INFO) << " check quantile stats, nbig=" << nbig << ", n=" << n; LOG(INFO) << " srcsize=" << src.size << ", maxsize=" << maxsize << ", range=" << range << ", chunk=" << chunk; src.Print(); - CHECK(nbig < n - 1) << "quantile: too many large chunk"; + CHECK(nbig < n) << "quantile: too many large chunk"; } this->data[0] = src.data[0]; this->size = 1; - // use smaller size + // The counter on the rest of points, to be selected equally from small chunks. n = n - nbig; // find the rest of point size_t bid = 0, k = 1, lastidx = 0; diff --git a/src/common/row_set.h b/src/common/row_set.h new file mode 100644 index 000000000..58103e664 --- /dev/null +++ b/src/common/row_set.h @@ -0,0 +1,104 @@ +/*! + * Copyright 2017 by Contributors + * \file row_set.h + * \brief Quick Utility to compute subset of rows + * \author Philip Cho, Tianqi Chen + */ +#ifndef XGBOOST_COMMON_ROW_SET_H_ +#define XGBOOST_COMMON_ROW_SET_H_ + +#include +#include +#include + +namespace xgboost { +namespace common { + +/*! \brief collection of rowset */ +class RowSetCollection { + public: + /*! \brief subset of rows */ + struct Elem { + const bst_uint* begin; + const bst_uint* end; + Elem(void) + : begin(nullptr), end(nullptr) {} + Elem(const bst_uint* begin, + const bst_uint* end) + : begin(begin), end(end) {} + + inline size_t size() const { + return end - begin; + } + }; + /* \brief specifies how to split a rowset into two */ + struct Split { + std::vector left; + std::vector right; + }; + /*! 
\brief return corresponding element set given the node_id */ + inline const Elem& operator[](unsigned node_id) const { + const Elem& e = elem_of_each_node_[node_id]; + CHECK(e.begin != nullptr) + << "access element that is not in the set"; + return e; + } + // clear up things + inline void Clear() { + row_indices_.clear(); + elem_of_each_node_.clear(); + } + // initialize node id 0->everything + inline void Init() { + CHECK_EQ(elem_of_each_node_.size(), 0); + const bst_uint* begin = dmlc::BeginPtr(row_indices_); + const bst_uint* end = dmlc::BeginPtr(row_indices_) + row_indices_.size(); + elem_of_each_node_.emplace_back(Elem(begin, end)); + } + // split rowset into two + inline void AddSplit(unsigned node_id, + const std::vector& row_split_tloc, + unsigned left_node_id, + unsigned right_node_id) { + const Elem e = elem_of_each_node_[node_id]; + const unsigned nthread = row_split_tloc.size(); + CHECK(e.begin != nullptr); + bst_uint* all_begin = dmlc::BeginPtr(row_indices_); + bst_uint* begin = all_begin + (e.begin - all_begin); + + bst_uint* it = begin; + // TODO(hcho3): parallelize this section + for (bst_omp_uint tid = 0; tid < nthread; ++tid) { + std::copy(row_split_tloc[tid].left.begin(), row_split_tloc[tid].left.end(), it); + it += row_split_tloc[tid].left.size(); + } + bst_uint* split_pt = it; + for (bst_omp_uint tid = 0; tid < nthread; ++tid) { + std::copy(row_split_tloc[tid].right.begin(), row_split_tloc[tid].right.end(), it); + it += row_split_tloc[tid].right.size(); + } + + if (left_node_id >= elem_of_each_node_.size()) { + elem_of_each_node_.resize(left_node_id + 1, Elem(nullptr, nullptr)); + } + if (right_node_id >= elem_of_each_node_.size()) { + elem_of_each_node_.resize(right_node_id + 1, Elem(nullptr, nullptr)); + } + + elem_of_each_node_[left_node_id] = Elem(begin, split_pt); + elem_of_each_node_[right_node_id] = Elem(split_pt, e.end); + elem_of_each_node_[node_id] = Elem(nullptr, nullptr); + } + + // stores the row indices in the set + std::vector 
row_indices_; + + private: + // vector: node_id -> elements + std::vector elem_of_each_node_; +}; + +} // namespace common +} // namespace xgboost + +#endif // XGBOOST_COMMON_ROW_SET_H_ diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index 3d8c2a9db..af96732dc 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -6,6 +6,7 @@ */ #include #include +#include #include #include #include @@ -369,7 +370,7 @@ class GBTree : public GradientBooster { const int nthread = omp_get_max_threads(); CHECK_EQ(num_group, mparam.num_output_group); InitThreadTemp(nthread); - std::vector &preds = *out_preds; + std::vector& preds = *out_preds; CHECK_EQ(mparam.size_leaf_vector, 0) << "size_leaf_vector is enforced to 0 so far"; CHECK_EQ(preds.size(), p_fmat->info().num_row * num_group); @@ -380,17 +381,38 @@ class GBTree : public GradientBooster { while (iter->Next()) { const RowBatch &batch = iter->Value(); // parallel over local batch + const int K = 8; const bst_omp_uint nsize = static_cast(batch.size); + const bst_omp_uint rest = nsize % K; #pragma omp parallel for schedule(static) - for (bst_omp_uint i = 0; i < nsize; ++i) { + for (bst_omp_uint i = 0; i < nsize - rest; i += K) { const int tid = omp_get_thread_num(); - RegTree::FVec &feats = thread_temp[tid]; - int64_t ridx = static_cast(batch.base_rowid + i); - CHECK_LT(static_cast(ridx), info.num_row); + RegTree::FVec& feats = thread_temp[tid]; + int64_t ridx[K]; + RowBatch::Inst inst[K]; + for (int k = 0; k < K; ++k) { + ridx[k] = static_cast(batch.base_rowid + i + k); + } + for (int k = 0; k < K; ++k) { + inst[k] = batch[i + k]; + } + for (int k = 0; k < K; ++k) { + for (int gid = 0; gid < num_group; ++gid) { + const size_t offset = ridx[k] * num_group + gid; + preds[offset] += + self->PredValue(inst[k], gid, info.GetRoot(ridx[k]), + &feats, tree_begin, tree_end); + } + } + } + for (bst_omp_uint i = nsize - rest; i < nsize; ++i) { + RegTree::FVec& feats = thread_temp[0]; + const int64_t ridx = static_cast(batch.base_rowid + 
i); + const RowBatch::Inst inst = batch[i]; for (int gid = 0; gid < num_group; ++gid) { - size_t offset = ridx * num_group + gid; + const size_t offset = ridx * num_group + gid; preds[offset] += - self->PredValue(batch[i], gid, info.GetRoot(ridx), + self->PredValue(inst, gid, info.GetRoot(ridx), &feats, tree_begin, tree_end); } } diff --git a/src/learner.cc b/src/learner.cc index 55651c7f4..f9ebc2afc 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -99,6 +99,7 @@ struct LearnerTrainParam .add_enum("auto", 0) .add_enum("approx", 1) .add_enum("exact", 2) + .add_enum("hist", 3) .describe("Choice of tree construction method."); DMLC_DECLARE_FIELD(test_flag).set_default("") .describe("Internal test flag"); @@ -167,7 +168,31 @@ class LearnerImpl : public Learner { cfg_["max_delta_step"] = "0.7"; } - if (cfg_.count("updater") == 0) { + if (tparam.tree_method == 3) { + /* histogram-based algorithm */ + if (cfg_.count("updater") == 0) { + LOG(CONSOLE) << "Tree method is selected to be \'hist\', " + << "which uses histogram aggregation for faster training. " + << "Using default sequence of updaters: grow_fast_histmaker,prune"; + cfg_["updater"] = "grow_fast_histmaker,prune"; + } else { + const std::string first_str = "grow_fast_histmaker"; + if (first_str.length() <= cfg_["updater"].length() + && std::equal(first_str.begin(), first_str.end(), cfg_["updater"].begin())) { + // updater sequence starts with "grow_fast_histmaker" + LOG(CONSOLE) << "Tree method is selected to be \'hist\', " + << "which uses histogram aggregation for faster training. " + << "Using custom sequence of updaters: " << cfg_["updater"]; + } else { + // updater sequence does not start with "grow_fast_histmaker" + LOG(CONSOLE) << "Tree method is selected to be \'hist\', but the given " + << "sequence of updaters is not compatible; " + << "grow_fast_histmaker must run first. 
" + << "Using default sequence of updaters: grow_fast_histmaker,prune"; + cfg_["updater"] = "grow_fast_histmaker,prune"; + } + } + } else if (cfg_.count("updater") == 0) { if (tparam.dsplit == 1) { cfg_["updater"] = "distcol"; } else if (tparam.dsplit == 2) { @@ -379,8 +404,8 @@ class LearnerImpl : public Learner { protected: // check if p_train is ready to used by training. // if not, initialize the column access. - inline void LazyInitDMatrix(DMatrix *p_train) { - if (!p_train->HaveColAccess()) { + inline void LazyInitDMatrix(DMatrix* p_train) { + if (tparam.tree_method != 3 && !p_train->HaveColAccess()) { int ncol = static_cast(p_train->info().num_col); std::vector enabled(ncol, true); // set max row per batch to limited value diff --git a/src/tree/param.h b/src/tree/param.h index f1c0ad60b..8fde1d796 100644 --- a/src/tree/param.h +++ b/src/tree/param.h @@ -31,6 +31,14 @@ struct TrainParam : public dmlc::Parameter { float min_split_loss; // maximum depth of a tree int max_depth; + // maximum number of leaves + int max_leaves; + // if using histogram based algorithm, maximum number of bins per feature + int max_bin; + // growing policy + enum TreeGrowPolicy { kDepthWise = 0, kLossGuide = 1 }; + int grow_policy; + int verbose; //----- the rest parameters are less important ---- // minimum amount of hessian(weight) allowed in a child float min_child_weight; @@ -77,11 +85,32 @@ struct TrainParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(min_split_loss) .set_lower_bound(0.0f) .set_default(0.0f) - .describe("Minimum loss reduction required to make a further partition."); + .describe( + "Minimum loss reduction required to make a further partition."); + DMLC_DECLARE_FIELD(verbose) + .set_lower_bound(0) + .set_default(0) + .describe( + "Setting verbose flag with a positive value causes the updater " + "to print out *detailed* list of tasks and their runtime"); DMLC_DECLARE_FIELD(max_depth) .set_lower_bound(0) .set_default(6) - .describe("Maximum depth of the tree."); 
+ .describe( + "Maximum depth of the tree; 0 indicates no limit; a limit is required " + "for depthwise policy"); + DMLC_DECLARE_FIELD(max_leaves).set_lower_bound(0).set_default(0).describe( + "Maximum number of leaves; 0 indicates no limit."); + DMLC_DECLARE_FIELD(max_bin).set_lower_bound(2).set_default(256).describe( + "if using histogram-based algorithm, maximum number of bins per feature"); + DMLC_DECLARE_FIELD(grow_policy) + .set_default(kDepthWise) + .add_enum("depthwise", kDepthWise) + .add_enum("lossguide", kLossGuide) + .describe( + "Tree growing policy. 0: favor splitting at nodes closest to the node, " + "i.e. grow depth-wise. 1: favor splitting at nodes with highest loss " + "change. (cf. LightGBM)"); DMLC_DECLARE_FIELD(min_child_weight) .set_lower_bound(0.0f) .set_default(1.0f) @@ -258,7 +287,7 @@ XGB_DEVICE inline T CalcWeight(const TrainingParams &p, T sum_grad, } /*! \brief core statistics used for tree construction */ -struct GradStats { +struct XGBOOST_ALIGNAS(16) GradStats { /*! \brief sum gradient statistics */ double sum_grad; /*! \brief sum hessian statistics */ @@ -269,11 +298,11 @@ struct GradStats { */ static const int kSimpleStats = 1; /*! \brief constructor, the object must be cleared during construction */ - explicit GradStats(const TrainParam ¶m) { this->Clear(); } + explicit GradStats(const TrainParam& param) { this->Clear(); } /*! \brief clear the statistics */ inline void Clear() { sum_grad = sum_hess = 0.0f; } /*! \brief check if necessary information is ready */ - inline static void CheckInfo(const MetaInfo &info) {} + inline static void CheckInfo(const MetaInfo& info) {} /*! 
* \brief accumulate statistics * \param p the gradient pair @@ -285,34 +314,37 @@ struct GradStats { * \param info the additional information * \param ridx instance index of this instance */ - inline void Add(const std::vector &gpair, const MetaInfo &info, + inline void Add(const std::vector& gpair, const MetaInfo& info, bst_uint ridx) { - const bst_gpair &b = gpair[ridx]; + const bst_gpair& b = gpair[ridx]; this->Add(b.grad, b.hess); } /*! \brief calculate leaf weight */ - inline double CalcWeight(const TrainParam ¶m) const { + inline double CalcWeight(const TrainParam& param) const { return xgboost::tree::CalcWeight(param, sum_grad, sum_hess); } /*! \brief calculate gain of the solution */ - inline double CalcGain(const TrainParam ¶m) const { + inline double CalcGain(const TrainParam& param) const { return xgboost::tree::CalcGain(param, sum_grad, sum_hess); } /*! \brief add statistics to the data */ - inline void Add(const GradStats &b) { this->Add(b.sum_grad, b.sum_hess); } + inline void Add(const GradStats& b) { + sum_grad += b.sum_grad; + sum_hess += b.sum_hess; + } /*! \brief same as add, reduce is used in All Reduce */ - inline static void Reduce(GradStats &a, const GradStats &b) { // NOLINT(*) + inline static void Reduce(GradStats& a, const GradStats& b) { // NOLINT(*) a.Add(b); } /*! \brief set current value to a - b */ - inline void SetSubstract(const GradStats &a, const GradStats &b) { + inline void SetSubstract(const GradStats& a, const GradStats& b) { sum_grad = a.sum_grad - b.sum_grad; sum_hess = a.sum_hess - b.sum_hess; } /*! \return whether the statistics is not used yet */ inline bool Empty() const { return sum_hess == 0.0; } /*! \brief set leaf vector value based on statistics */ - inline void SetLeafVec(const TrainParam ¶m, bst_float *vec) const {} + inline void SetLeafVec(const TrainParam& param, bst_float* vec) const {} // constructor to allow inheritance GradStats() {} /*! 
\brief add statistics to the data */ diff --git a/src/tree/updater_fast_hist.cc b/src/tree/updater_fast_hist.cc new file mode 100644 index 000000000..dac1e740f --- /dev/null +++ b/src/tree/updater_fast_hist.cc @@ -0,0 +1,725 @@ +/*! + * Copyright 2017 by Contributors + * \file updater_fast_hist.cc + * \brief use quantized feature values to construct a tree + * \author Philip Cho, Tianqi Chen + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include "./param.h" +#include "../common/random.h" +#include "../common/bitmap.h" +#include "../common/sync.h" +#include "../common/hist_util.h" +#include "../common/row_set.h" + +namespace xgboost { +namespace tree { + +using xgboost::common::HistCutMatrix; +using xgboost::common::GHistIndexMatrix; +using xgboost::common::GHistIndexRow; +using xgboost::common::GHistEntry; +using xgboost::common::HistCollection; +using xgboost::common::RowSetCollection; +using xgboost::common::GHistRow; +using xgboost::common::GHistBuilder; + +DMLC_REGISTRY_FILE_TAG(updater_fast_hist); + +/*! 
\brief construct a tree using quantized feature values */ +template +class FastHistMaker: public TreeUpdater { + public: + void Init(const std::vector >& args) override { + param.InitAllowUnknown(args); + is_gmat_initialized_ = false; + } + + void Update(const std::vector& gpair, + DMatrix* dmat, + const std::vector& trees) override { + TStats::CheckInfo(dmat->info()); + if (is_gmat_initialized_ == false) { + double tstart = dmlc::GetTime(); + hmat_.Init(dmat, param.max_bin); + gmat_.cut = &hmat_; + gmat_.Init(dmat); + is_gmat_initialized_ = true; + if (param.verbose > 0) { + LOG(INFO) << "Generating gmat: " << dmlc::GetTime() - tstart << " sec"; + } + } + // rescale learning rate according to size of trees + float lr = param.learning_rate; + param.learning_rate = lr / trees.size(); + TConstraint::Init(¶m, dmat->info().num_col); + // build tree + if (!builder_) { + builder_.reset(new Builder(param)); + } + for (size_t i = 0; i < trees.size(); ++i) { + builder_->Update(gmat_, gpair, dmat, trees[i]); + } + param.learning_rate = lr; + } + + protected: + // training parameter + TrainParam param; + // data sketch + HistCutMatrix hmat_; + GHistIndexMatrix gmat_; + bool is_gmat_initialized_; + + // data structure + /*! \brief per thread x per node entry to store tmp data */ + struct ThreadEntry { + /*! \brief statistics of data */ + TStats stats; + /*! \brief extra statistics of data */ + TStats stats_extra; + /*! \brief last feature value scanned */ + float last_fvalue; + /*! \brief first feature value scanned */ + float first_fvalue; + /*! \brief current best solution */ + SplitEntry best; + // constructor + explicit ThreadEntry(const TrainParam& param) + : stats(param), stats_extra(param) { + } + }; + struct NodeEntry { + /*! \brief statics for node entry */ + TStats stats; + /*! \brief loss of this node, without split */ + bst_float root_gain; + /*! \brief weight calculated related to current data */ + float weight; + /*! 
\brief current best solution */ + SplitEntry best; + // constructor + explicit NodeEntry(const TrainParam& param) + : stats(param), root_gain(0.0f), weight(0.0f) { + } + }; + // actual builder that runs the algorithm + + struct Builder { + public: + // constructor + explicit Builder(const TrainParam& param) : param(param) { + } + // update one tree, growing + virtual void Update(const GHistIndexMatrix& gmat, + const std::vector& gpair, + DMatrix* p_fmat, + RegTree* p_tree) { + double gstart = dmlc::GetTime(); + + std::vector feat_set(p_fmat->info().num_col); + std::iota(feat_set.begin(), feat_set.end(), 0); + int num_leaves = 0; + unsigned timestamp = 0; + + double tstart; + double time_init_data = 0; + double time_init_new_node = 0; + double time_build_hist = 0; + double time_evaluate_split = 0; + double time_apply_split = 0; + + tstart = dmlc::GetTime(); + this->InitData(gmat, gpair, *p_fmat, *p_tree); + time_init_data = dmlc::GetTime() - tstart; + for (int nid = 0; nid < p_tree->param.num_roots; ++nid) { + tstart = dmlc::GetTime(); + hist_.AddHistRow(nid); + builder_.BuildHist(gpair, row_set_collection_[nid], gmat, hist_[nid]); + time_build_hist += dmlc::GetTime() - tstart; + + tstart = dmlc::GetTime(); + this->InitNewNode(nid, gmat, gpair, *p_fmat, *p_tree); + time_init_new_node += dmlc::GetTime() - tstart; + + tstart = dmlc::GetTime(); + this->EvaluateSplit(nid, gmat, hist_, *p_fmat, *p_tree, feat_set); + time_evaluate_split += dmlc::GetTime() - tstart; + qexpand_->push(ExpandEntry(nid, p_tree->GetDepth(nid), + snode[nid].best.loss_chg, + timestamp++)); + ++num_leaves; + } + + while (!qexpand_->empty()) { + const ExpandEntry candidate = qexpand_->top(); + const int nid = candidate.nid; + qexpand_->pop(); + if (candidate.loss_chg <= rt_eps + || (param.max_depth > 0 && candidate.depth == param.max_depth) + || (param.max_leaves > 0 && num_leaves == param.max_leaves) ) { + (*p_tree)[nid].set_leaf(snode[nid].weight * param.learning_rate); + } else { + tstart = 
dmlc::GetTime(); + this->ApplySplit(nid, gmat, hist_, *p_fmat, p_tree); + time_apply_split += dmlc::GetTime() - tstart; + + tstart = dmlc::GetTime(); + const int cleft = (*p_tree)[nid].cleft(); + const int cright = (*p_tree)[nid].cright(); + hist_.AddHistRow(cleft); + hist_.AddHistRow(cright); + if (row_set_collection_[cleft].size() < row_set_collection_[cright].size()) { + builder_.BuildHist(gpair, row_set_collection_[cleft], gmat, hist_[cleft]); + builder_.SubtractionTrick(hist_[cright], hist_[cleft], hist_[nid]); + } else { + builder_.BuildHist(gpair, row_set_collection_[cright], gmat, hist_[cright]); + builder_.SubtractionTrick(hist_[cleft], hist_[cright], hist_[nid]); + } + time_build_hist += dmlc::GetTime() - tstart; + + tstart = dmlc::GetTime(); + this->InitNewNode(cleft, gmat, gpair, *p_fmat, *p_tree); + this->InitNewNode(cright, gmat, gpair, *p_fmat, *p_tree); + time_init_new_node += dmlc::GetTime() - tstart; + + tstart = dmlc::GetTime(); + this->EvaluateSplit(cleft, gmat, hist_, *p_fmat, *p_tree, feat_set); + this->EvaluateSplit(cright, gmat, hist_, *p_fmat, *p_tree, feat_set); + time_evaluate_split += dmlc::GetTime() - tstart; + + qexpand_->push(ExpandEntry(cleft, p_tree->GetDepth(cleft), + snode[cleft].best.loss_chg, + timestamp++)); + qexpand_->push(ExpandEntry(cright, p_tree->GetDepth(cright), + snode[cright].best.loss_chg, + timestamp++)); + + ++num_leaves; // give two and take one, as parent is no longer a leaf + } + } + + // set all the rest expanding nodes to leaf + // This post condition is not needed in current code, but may be necessary + // when there are stopping rule that leaves qexpand non-empty + while (!qexpand_->empty()) { + const int nid = qexpand_->top().nid; + qexpand_->pop(); + (*p_tree)[nid].set_leaf(snode[nid].weight * param.learning_rate); + } + // remember auxiliary statistics in the tree node + for (int nid = 0; nid < p_tree->param.num_nodes; ++nid) { + p_tree->stat(nid).loss_chg = snode[nid].best.loss_chg; + 
p_tree->stat(nid).base_weight = snode[nid].weight; + p_tree->stat(nid).sum_hess = static_cast(snode[nid].stats.sum_hess); + snode[nid].stats.SetLeafVec(param, p_tree->leafvec(nid)); + } + + if (param.verbose > 0) { + double total_time = dmlc::GetTime() - gstart; + LOG(INFO) << "\nInitData: " + << std::fixed << std::setw(4) << std::setprecision(2) << time_init_data + << " (" << std::fixed << std::setw(5) << std::setprecision(2) + << time_init_data / total_time * 100 << "%)\n" + << "InitNewNode: " + << std::fixed << std::setw(4) << std::setprecision(2) << time_init_new_node + << " (" << std::fixed << std::setw(5) << std::setprecision(2) + << time_init_new_node / total_time * 100 << "%)\n" + << "BuildHist: " + << std::fixed << std::setw(4) << std::setprecision(2) << time_build_hist + << " (" << std::fixed << std::setw(5) << std::setprecision(2) + << time_build_hist / total_time * 100 << "%)\n" + << "EvaluateSplit: " + << std::fixed << std::setw(4) << std::setprecision(2) << time_evaluate_split + << " (" << std::fixed << std::setw(5) << std::setprecision(2) + << time_evaluate_split / total_time * 100 << "%)\n" + << "ApplySplit: " + << std::fixed << std::setw(4) << std::setprecision(2) << time_apply_split + << " (" << std::fixed << std::setw(5) << std::setprecision(2) + << time_apply_split / total_time * 100 << "%)\n" + << "========================================\n" + << "Total: " + << std::fixed << std::setw(4) << std::setprecision(2) << total_time; + } + } + + protected: + // initialize temp data structure + inline void InitData(const GHistIndexMatrix& gmat, + const std::vector& gpair, + const DMatrix& fmat, + const RegTree& tree) { + CHECK_EQ(tree.param.num_nodes, tree.param.num_roots) + << "ColMakerHist: can only grow new tree"; + CHECK((param.max_depth > 0 || param.max_leaves > 0)) + << "max_depth or max_leaves cannot be both 0 (unlimited); " + << "at least one should be a positive quantity."; + if (param.grow_policy == TrainParam::kDepthWise) { + 
CHECK(param.max_depth > 0) << "max_depth cannot be 0 (unlimited) " + << "when grow_policy is depthwise."; + } + const auto& info = fmat.info(); + + { + // initialize the row set + row_set_collection_.Clear(); + // initialize histogram collection + size_t nbins = gmat.cut->row_ptr.back(); + hist_.Init(nbins); + + #pragma omp parallel + { + this->nthread = omp_get_num_threads(); + } + builder_.Init(this->nthread, nbins); + + CHECK_EQ(info.root_index.size(), 0); + std::vector& row_indices = row_set_collection_.row_indices_; + // mark subsample and build list of member rows + if (param.subsample < 1.0f) { + std::bernoulli_distribution coin_flip(param.subsample); + auto& rnd = common::GlobalRandom(); + for (bst_uint i = 0; i < info.num_row; ++i) { + if (gpair[i].hess >= 0.0f && coin_flip(rnd)) { + row_indices.push_back(i); + } + } + } else { + for (bst_uint i = 0; i < info.num_row; ++i) { + if (gpair[i].hess >= 0.0f) { + row_indices.push_back(i); + } + } + } + row_set_collection_.Init(); + } + + { + // initialize feature index + unsigned ncol = static_cast(info.num_col); + feat_index.clear(); + for (unsigned i = 0; i < ncol; ++i) { + feat_index.push_back(i); + } + unsigned n = static_cast(param.colsample_bytree * feat_index.size()); + std::shuffle(feat_index.begin(), feat_index.end(), common::GlobalRandom()); + CHECK_GT(n, 0) + << "colsample_bytree=" << param.colsample_bytree + << " is too small that no feature can be included"; + feat_index.resize(n); + } + { + /* determine layout of data */ + const auto nrow = info.num_row; + const auto ncol = info.num_col; + const auto nnz = info.num_nonzero; + // number of discrete bins for feature 0 + const unsigned nbins_f0 = gmat.cut->row_ptr[1] - gmat.cut->row_ptr[0]; + if (nrow * ncol == nnz) { + // dense data with zero-based indexing + data_layout_ = kDenseDataZeroBased; + } else if (nbins_f0 == 0 && nrow * (ncol - 1) == nnz) { + // dense data with one-based indexing + data_layout_ = kDenseDataOneBased; + } else { + // sparse 
data + data_layout_ = kSparseData; + } + } + if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) { + /* specialized code for dense data: + choose the column that has a least positive number of discrete bins. + For dense data (with no missing value), + the sum of gradient histogram is equal to snode[nid] */ + const std::vector& row_ptr = gmat.cut->row_ptr; + const size_t nfeature = row_ptr.size() - 1; + size_t min_nbins_per_feature = 0; + for (size_t i = 0; i < nfeature; ++i) { + const unsigned nbins = row_ptr[i + 1] - row_ptr[i]; + if (nbins > 0) { + if (min_nbins_per_feature == 0 || min_nbins_per_feature > nbins) { + min_nbins_per_feature = nbins; + fid_least_bins_ = i; + } + } + } + CHECK_GT(min_nbins_per_feature, 0); + } + { + snode.reserve(256); + snode.clear(); + } + { + if (param.grow_policy == TrainParam::kLossGuide) { + qexpand_.reset(new ExpandQueue(loss_guide)); + } else { + qexpand_.reset(new ExpandQueue(depth_wise)); + } + } + } + + inline void EvaluateSplit(int nid, + const GHistIndexMatrix& gmat, + const HistCollection& hist, + const DMatrix& fmat, + const RegTree& tree, + const std::vector& feat_set) { + // start enumeration + const MetaInfo& info = fmat.info(); + for (int fid : feat_set) { + this->EnumerateSplit(-1, gmat, hist[nid], snode[nid], constraints_[nid], info, + &snode[nid].best, fid); + this->EnumerateSplit(+1, gmat, hist[nid], snode[nid], constraints_[nid], info, + &snode[nid].best, fid); + } + } + + inline void ApplySplit(int nid, + const GHistIndexMatrix& gmat, + const HistCollection& hist, + const DMatrix& fmat, + RegTree* p_tree) { + // TODO(hcho3): support feature sampling by levels + + /* 1. 
Create child nodes */ + NodeEntry& e = snode[nid]; + + p_tree->AddChilds(nid); + (*p_tree)[nid].set_split(e.best.split_index(), e.best.split_value, e.best.default_left()); + // mark right child as 0, to indicate fresh leaf + int cleft = (*p_tree)[nid].cleft(); + int cright = (*p_tree)[nid].cright(); + (*p_tree)[cleft].set_leaf(0.0f, 0); + (*p_tree)[cright].set_leaf(0.0f, 0); + + /* 2. Categorize member rows */ + const bst_omp_uint nthread = static_cast(this->nthread); + row_split_tloc_.resize(nthread); + for (bst_omp_uint i = 0; i < nthread; ++i) { + row_split_tloc_[i].left.clear(); + row_split_tloc_[i].right.clear(); + } + const bool default_left = (*p_tree)[nid].default_left(); + const bst_uint fid = (*p_tree)[nid].split_index(); + const bst_float split_pt = (*p_tree)[nid].split_cond(); + const bst_uint lower_bound = gmat.cut->row_ptr[fid]; + const bst_uint upper_bound = gmat.cut->row_ptr[fid + 1]; + // set the split condition correctly + bst_uint split_cond = 0; + // set the condition + for (unsigned i = gmat.cut->row_ptr[fid]; i < gmat.cut->row_ptr[fid + 1]; ++i) { + if (split_pt == gmat.cut->cut[i]) split_cond = i; + } + + const auto& rowset = row_set_collection_[nid]; + if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) { + /* specialized code for dense data */ + const size_t column_offset = (data_layout_ == kDenseDataOneBased) ? 
(fid - 1): fid; + ApplySplitDenseData(rowset, gmat, &row_split_tloc_, column_offset, split_cond); + } else { + ApplySplitSparseData(rowset, gmat, &row_split_tloc_, lower_bound, upper_bound, + split_cond, default_left); + } + row_set_collection_.AddSplit( + nid, row_split_tloc_, (*p_tree)[nid].cleft(), (*p_tree)[nid].cright()); + } + + inline void ApplySplitDenseData(const RowSetCollection::Elem rowset, + const GHistIndexMatrix& gmat, + std::vector* p_row_split_tloc, + size_t column_offset, + bst_uint split_cond) { + std::vector& row_split_tloc = *p_row_split_tloc; + const int K = 8; // loop unrolling factor + const bst_omp_uint nrows = rowset.end - rowset.begin; + const bst_omp_uint rest = nrows % K; + #pragma omp parallel for num_threads(nthread) schedule(static) + for (bst_omp_uint i = 0; i < nrows - rest; i += K) { + bst_uint rid[K]; + unsigned rbin[K]; + bst_uint tid = omp_get_thread_num(); + auto& left = row_split_tloc[tid].left; + auto& right = row_split_tloc[tid].right; + for (int k = 0; k < K; ++k) { + rid[k] = rowset.begin[i + k]; + } + for (int k = 0; k < K; ++k) { + rbin[k] = gmat[rid[k]].index[column_offset]; + } + for (int k = 0; k < K; ++k) { + if (rbin[k] <= split_cond) { + left.push_back(rid[k]); + } else { + right.push_back(rid[k]); + } + } + } + for (bst_omp_uint i = nrows - rest; i < nrows; ++i) { + const bst_uint rid = rowset.begin[i]; + const unsigned rbin = gmat[rid].index[column_offset]; + if (rbin <= split_cond) { + row_split_tloc[0].left.push_back(rid); + } else { + row_split_tloc[0].right.push_back(rid); + } + } + } + + inline void ApplySplitSparseData(const RowSetCollection::Elem rowset, + const GHistIndexMatrix& gmat, + std::vector* p_row_split_tloc, + bst_uint lower_bound, + bst_uint upper_bound, + bst_uint split_cond, + bool default_left) { + std::vector& row_split_tloc = *p_row_split_tloc; + const int K = 8; // loop unrolling factor + const bst_omp_uint nrows = rowset.end - rowset.begin; + const bst_omp_uint rest = nrows % K; + 
#pragma omp parallel for num_threads(nthread) schedule(static) + for (bst_omp_uint i = 0; i < nrows - rest; i += K) { + bst_uint rid[K]; + GHistIndexRow row[K]; + const unsigned* p[K]; + bst_uint tid = omp_get_thread_num(); + auto& left = row_split_tloc[tid].left; + auto& right = row_split_tloc[tid].right; + for (int k = 0; k < K; ++k) { + rid[k] = rowset.begin[i + k]; + } + for (int k = 0; k < K; ++k) { + row[k] = gmat[rid[k]]; + } + for (int k = 0; k < K; ++k) { + p[k] = std::lower_bound(row[k].index, row[k].index + row[k].size, lower_bound); + } + for (int k = 0; k < K; ++k) { + if (p[k] != row[k].index + row[k].size && *p[k] < upper_bound) { + if (*p[k] <= split_cond) { + left.push_back(rid[k]); + } else { + right.push_back(rid[k]); + } + } else { + if (default_left) { + left.push_back(rid[k]); + } else { + right.push_back(rid[k]); + } + } + } + } + for (bst_omp_uint i = nrows - rest; i < nrows; ++i) { + const bst_uint rid = rowset.begin[i]; + const auto row = gmat[rid]; + const auto p = std::lower_bound(row.index, row.index + row.size, lower_bound); + auto& left = row_split_tloc[0].left; + auto& right = row_split_tloc[0].right; + if (p != row.index + row.size && *p < upper_bound) { + if (*p <= split_cond) { + left.push_back(rid); + } else { + right.push_back(rid); + } + } else { + if (default_left) { + left.push_back(rid); + } else { + right.push_back(rid); + } + } + } + } + + inline void InitNewNode(int nid, + const GHistIndexMatrix& gmat, + const std::vector& gpair, + const DMatrix& fmat, + const RegTree& tree) { + { + snode.resize(tree.param.num_nodes, NodeEntry(param)); + constraints_.resize(tree.param.num_nodes); + } + + // setup constraints before calculating the weight + { + auto& stats = snode[nid].stats; + if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) { + /* specialized code for dense data + For dense data (with no missing value), + the sum of gradient histogram is equal to snode[nid] */ + GHistRow hist = hist_[nid]; + 
const std::vector& row_ptr = gmat.cut->row_ptr; + + const size_t ibegin = row_ptr[fid_least_bins_]; + const size_t iend = row_ptr[fid_least_bins_ + 1]; + for (size_t i = ibegin; i < iend; ++i) { + const GHistEntry et = hist.begin[i]; + stats.Add(et.sum_grad, et.sum_hess); + } + } else { + const RowSetCollection::Elem e = row_set_collection_[nid]; + for (const bst_uint* it = e.begin; it < e.end; ++it) { + stats.Add(gpair[*it]); + } + } + if (!tree[nid].is_root()) { + const int pid = tree[nid].parent(); + constraints_[pid].SetChild(param, tree[pid].split_index(), + snode[tree[pid].cleft()].stats, + snode[tree[pid].cright()].stats, + &constraints_[tree[pid].cleft()], + &constraints_[tree[pid].cright()]); + } + } + + // calculating the weights + { + snode[nid].root_gain = static_cast( + constraints_[nid].CalcGain(param, snode[nid].stats)); + snode[nid].weight = static_cast( + constraints_[nid].CalcWeight(param, snode[nid].stats)); + } + } + + // enumerate the split values of specific feature + inline void EnumerateSplit(int d_step, + const GHistIndexMatrix& gmat, + const GHistRow& hist, + const NodeEntry& snode, + const TConstraint& constraint, + const MetaInfo& info, + SplitEntry* p_best, + int fid) { + CHECK(d_step == +1 || d_step == -1); + + // aliases + const std::vector& cut_ptr = gmat.cut->row_ptr; + const std::vector& cut_val = gmat.cut->cut; + + // statistics on both sides of split + TStats c(param); + TStats e(param); + // best split so far + SplitEntry best; + + // bin boundaries + // imin: index (offset) of the minimum value for feature fid + // need this for backward enumeration + const int imin = cut_ptr[fid]; + // ibegin, iend: smallest/largest cut points for feature fid + int ibegin, iend; + if (d_step > 0) { + ibegin = cut_ptr[fid]; + iend = cut_ptr[fid + 1]; + } else { + ibegin = cut_ptr[fid + 1] - 1; + iend = cut_ptr[fid] - 1; + } + + for (int i = ibegin; i != iend; i += d_step) { + // start working + // try to find a split + 
e.Add(hist.begin[i].sum_grad, hist.begin[i].sum_hess); + if (e.sum_hess >= param.min_child_weight) { + c.SetSubstract(snode.stats, e); + if (c.sum_hess >= param.min_child_weight) { + bst_float loss_chg; + bst_float split_pt; + if (d_step > 0) { + // forward enumeration: split at right bound of each bin + loss_chg = static_cast( + constraint.CalcSplitGain(param, fid, e, c) - + snode.root_gain); + split_pt = cut_val[i]; + } else { + // backward enumeration: split at left bound of each bin + loss_chg = static_cast( + constraint.CalcSplitGain(param, fid, c, e) - + snode.root_gain); + if (i == imin) { + // for leftmost bin, left bound is the smallest feature value + split_pt = gmat.cut->min_val[fid]; + } else { + split_pt = cut_val[i - 1]; + } + } + best.Update(loss_chg, fid, split_pt, d_step == -1); + } + } + } + p_best->Update(best); + } + + /* tree growing policies */ + struct ExpandEntry { + int nid; + int depth; + bst_float loss_chg; + unsigned timestamp; + ExpandEntry(int nid, int depth, bst_float loss_chg, unsigned tstmp) + : nid(nid), depth(depth), loss_chg(loss_chg), timestamp(tstmp) {} + }; + inline static bool depth_wise(ExpandEntry lhs, ExpandEntry rhs) { + if (lhs.depth == rhs.depth) { + return lhs.timestamp > rhs.timestamp; // favor small timestamp + } else { + return lhs.depth > rhs.depth; // favor small depth + } + } + inline static bool loss_guide(ExpandEntry lhs, ExpandEntry rhs) { + if (lhs.loss_chg == rhs.loss_chg) { + return lhs.timestamp > rhs.timestamp; // favor small timestamp + } else { + return lhs.loss_chg < rhs.loss_chg; // favor large loss_chg + } + } + + // --data fields-- + const TrainParam& param; + // number of omp thread used during training + int nthread; + // Per feature: shuffle index of each feature index + std::vector feat_index; + // the internal row sets + RowSetCollection row_set_collection_; + // the temp space for split + std::vector row_split_tloc_; + /*! 
\brief TreeNode Data: statistics for each constructed node */ + std::vector snode; + /*! \brief cumulative histogram of gradients. */ + HistCollection hist_; + size_t fid_least_bins_; + + GHistBuilder builder_; + + // constraint value + std::vector constraints_; + + typedef std::priority_queue, + std::function> ExpandQueue; + std::unique_ptr qexpand_; + + enum DataLayout { kDenseDataZeroBased, kDenseDataOneBased, kSparseData }; + DataLayout data_layout_; + }; + + std::unique_ptr builder_; +}; + +XGBOOST_REGISTER_TREE_UPDATER(FastHistMaker, "grow_fast_histmaker") +.describe("Grow tree using quantized histogram.") +.set_body([]() { + return new FastHistMaker(); + }); + +} // namespace tree +} // namespace xgboost diff --git a/tests/python/test_fast_hist.py b/tests/python/test_fast_hist.py new file mode 100644 index 000000000..a79f402a6 --- /dev/null +++ b/tests/python/test_fast_hist.py @@ -0,0 +1,107 @@ +import xgboost as xgb +import testing as tm +import numpy as np +import unittest + +rng = np.random.RandomState(1994) + + +class TestFastHist(unittest.TestCase): + def test_fast_hist(self): + tm._skip_if_no_sklearn() + from sklearn.datasets import load_digits + from sklearn.cross_validation import train_test_split + + digits = load_digits(2) + X = digits['data'] + y = digits['target'] + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) + dtrain = xgb.DMatrix(X_train, y_train) + dtest = xgb.DMatrix(X_test, y_test) + + param = {'objective': 'binary:logistic', + 'tree_method': 'hist', + 'grow_policy': 'depthwise', + 'max_depth': 3, + 'eval_metric': 'auc'} + res = {} + xgb.train(param, dtrain, 10, [(dtrain, 'train'), (dtest, 'test')], + evals_result=res) + assert self.non_decreasing(res['train']['auc']) + assert self.non_decreasing(res['test']['auc']) + + param2 = {'objective': 'binary:logistic', + 'tree_method': 'hist', + 'grow_policy': 'lossguide', + 'max_depth': 0, + 'max_leaves': 8, + 'eval_metric': 'auc'} + res = {} + xgb.train(param2, 
dtrain, 10, [(dtrain, 'train'), (dtest, 'test')], + evals_result=res) + assert self.non_decreasing(res['train']['auc']) + assert self.non_decreasing(res['test']['auc']) + + param3 = {'objective': 'binary:logistic', + 'tree_method': 'hist', + 'grow_policy': 'lossguide', + 'max_depth': 0, + 'max_leaves': 8, + 'max_bin': 16, + 'eval_metric': 'auc'} + res = {} + xgb.train(param3, dtrain, 10, [(dtrain, 'train'), (dtest, 'test')], + evals_result=res) + assert self.non_decreasing(res['train']['auc']) + + # fail-safe test for dense data + from sklearn.datasets import load_svmlight_file + dpath = 'demo/data/' + X2, y2 = load_svmlight_file(dpath + 'agaricus.txt.train') + X2 = X2.toarray() + dtrain2 = xgb.DMatrix(X2, label=y2) + + param = {'objective': 'binary:logistic', + 'tree_method': 'hist', + 'grow_policy': 'depthwise', + 'max_depth': 2, + 'eval_metric': 'auc'} + res = {} + xgb.train(param, dtrain2, 10, [(dtrain2, 'train')], evals_result=res) + assert self.non_decreasing(res['train']['auc']) + assert res['train']['auc'][0] >= 0.85 + + for j in range(X2.shape[1]): + for i in np.random.choice(X2.shape[0], size=10, replace=False): + X2[i, j] = 2 + + dtrain3 = xgb.DMatrix(X2, label=y2) + res = {} + xgb.train(param, dtrain3, 10, [(dtrain3, 'train')], evals_result=res) + assert self.non_decreasing(res['train']['auc']) + assert res['train']['auc'][0] >= 0.85 + + for j in range(X2.shape[1]): + for i in np.random.choice(X2.shape[0], size=10, replace=False): + X2[i, j] = 3 + + dtrain4 = xgb.DMatrix(X2, label=y2) + res = {} + xgb.train(param, dtrain4, 10, [(dtrain4, 'train')], evals_result=res) + assert self.non_decreasing(res['train']['auc']) + assert res['train']['auc'][0] >= 0.85 + + # fail-safe test for max_bin=2 + param = {'objective': 'binary:logistic', + 'tree_method': 'hist', + 'grow_policy': 'depthwise', + 'max_depth': 2, + 'eval_metric': 'auc', + 'max_bin': 2} + res = {} + xgb.train(param, dtrain2, 10, [(dtrain2, 'train')], evals_result=res) + assert 
self.non_decreasing(res['train']['auc']) + assert res['train']['auc'][0] >= 0.85 + + def non_decreasing(self, L): + return all(x <= y for x, y in zip(L, L[1:])) diff --git a/tests/travis/run_test.sh b/tests/travis/run_test.sh index 2c8a14194..0b04e359b 100755 --- a/tests/travis/run_test.sh +++ b/tests/travis/run_test.sh @@ -18,7 +18,9 @@ make -f dmlc-core/scripts/packages.mk lz4 if [ ${TRAVIS_OS_NAME} == "osx" ]; then - echo "USE_OPENMP=0" >> config.mk + echo 'USE_OPENMP=0' >> config.mk + echo 'TMPVAR := $(XGB_PLUGINS)' >> config.mk + echo 'XGB_PLUGINS = $(filter-out plugin/lz4/plugin.mk, $(TMPVAR))' >> config.mk fi if [ ${TASK} == "python_test" ]; then