Histogram Optimized Tree Grower (#1940)

* Support histogram-based algorithm + multiple tree growing strategy

* Add a brand new updater to support histogram-based algorithm, which buckets
  continuous features into discrete bins to speed up training. To use it, set
  `tree_method = fast_hist` to configuration.
* Support multiple tree growing strategies. For now, two policies are supported:
  * `grow_policy=depthwise` (default):  favor splitting at nodes closest to the
    root, i.e. grow depth-wise.
  * `grow_policy=lossguide`: favor splitting at nodes with highest loss change
* Improve single-threaded performance
  * Unroll critical loops
  * Introduce specialized code for dense data (i.e. no missing values)
* Additional training parameters: `max_leaves`, `max_bin`, `grow_policy`, `verbose`

* Adding a small test for hist method

* Fix memory error in row_set.h

When std::vector is resized, a reference to one of its element may become
stale. Any such reference must be updated as well.

* Resolve cross-platform compilation issues

* Versions of g++ older than 4.8 lacks support for a few C++11 features, e.g.
  alignas(*) and new initializer syntax. To support g++ 4.6, use pre-C++11
  initializer and remove alignas(*).
* Versions of MSVC older than 2015 does not support alignas(*). To support
  MSVC 2012, remove alignas(*).
* For g++ 4.8 and newer, alignas(*) is enabled for performance benefits.
* Some old compilers (MSVC 2012, g++ 4.6) do not support template aliases
  (which uses `using` to declate type aliases). So always use `typedef`.

* Fix a host of CI issues

* Remove dependency for libz on osx
* Fix heading for hist_util
* Fix minor style issues
* Add missing #include
* Remove extraneous logging

* Enable tree_method=hist in R

* Renaming HistMaker to GHistBuilder to avoid confusion

* Fix R integration

* Respond to style comments

* Consistent tie-breaking for priority queue using timestamps

* Last-minute style fixes

* Fix issuecomment-271977647

The way we quantize data is broken. The agaricus data consists of all
categorical values. When NAs are converted into 0's,
`HistCutMatrix::Init` assign both 0's and 1's to the same single bin.

Why? gmat only the smallest value (0) and an upper bound (2), which is twice
the maximum value (1). Add the maximum value itself to gmat to fix the issue.

* Fix issuecomment-272266358

* Remove padding from cut values for the continuous case
* For categorical/ordinal values, use midpoints as bin boundaries to be safe

* Fix CI issue -- do not use xrange(*)

* Fix corner case in quantile sketch

Signed-off-by: Philip Cho <chohyu01@cs.washington.edu>

* Adding a test for an edge case in quantile sketcher

max_bin=2 used to cause an exception.

* Fix fast_hist test

The test used to require a strictly increasing Test AUC for all examples.
One of them exhibits a small blip in Test AUC before achieving a Test AUC
of 1. (See bottom.)

Solution: do not require monotonic increase for this particular example.

[0] train-auc:0.99989 test-auc:0.999497
[1] train-auc:1 test-auc:0.999749
[2] train-auc:1 test-auc:0.999749
[3] train-auc:1 test-auc:0.999749
[4] train-auc:1 test-auc:0.999749
[5] train-auc:1 test-auc:0.999497
[6] train-auc:1 test-auc:1
[7] train-auc:1 test-auc:1
[8] train-auc:1 test-auc:1
[9] train-auc:1 test-auc:1
This commit is contained in:
Philip Cho 2017-01-13 09:25:55 -08:00 committed by Tianqi Chen
parent ef8d92fc52
commit aeb4e76118
13 changed files with 1509 additions and 31 deletions

View File

@ -42,6 +42,7 @@
#include "../src/tree/tree_model.cc"
#include "../src/tree/tree_updater.cc"
#include "../src/tree/updater_colmaker.cc"
#include "../src/tree/updater_fast_hist.cc"
#include "../src/tree/updater_prune.cc"
#include "../src/tree/updater_refresh.cc"
#include "../src/tree/updater_sync.cc"
@ -52,6 +53,7 @@
#include "../src/learner.cc"
#include "../src/logging.cc"
#include "../src/common/common.cc"
#include "../src/common/hist_util.cc"
// c_api
#include "../src/c_api/c_api.cc"

View File

@ -39,6 +39,15 @@
#define XGBOOST_CUSTOMIZE_GLOBAL_PRNG XGBOOST_STRICT_R_MODE
#endif
/*!
* \brief Check if alignas(*) keyword is supported. (g++ 4.8 or higher)
*/
#if defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ >= 8
#define XGBOOST_ALIGNAS(X) alignas(X)
#else
#define XGBOOST_ALIGNAS(X)
#endif
/*! \brief namespace of xgboo st*/
namespace xgboost {
/*!

View File

@ -127,6 +127,7 @@ struct SparseBatch {
/*! \brief length of the instance */
bst_uint length;
/*! \brief constructor */
Inst() : data(0), length(0) {}
Inst(const Entry *data, bst_uint length) : data(data), length(length) {}
/*! \brief get i-th pair in the sparse vector*/
inline const Entry& operator[](size_t i) const {

227
src/common/hist_util.cc Normal file
View File

@ -0,0 +1,227 @@
/*!
* Copyright 2017 by Contributors
* \file hist_util.h
* \brief Utilities to store histograms
* \author Philip Cho, Tianqi Chen
*/
#include <dmlc/omp.h>
#include <vector>
#include "./sync.h"
#include "./hist_util.h"
#include "./quantile.h"
namespace xgboost {
namespace common {
void HistCutMatrix::Init(DMatrix* p_fmat, size_t max_num_bins) {
typedef common::WXQuantileSketch<bst_float, bst_float> WXQSketch;
const MetaInfo& info = p_fmat->info();
// safe factor for better accuracy
const int kFactor = 8;
std::vector<WXQSketch> sketchs;
int nthread;
#pragma omp parallel
{
nthread = omp_get_num_threads();
}
nthread = std::max(nthread / 2, 1);
unsigned nstep = (info.num_col + nthread - 1) / nthread;
unsigned ncol = static_cast<unsigned>(info.num_col);
sketchs.resize(info.num_col);
for (auto& s : sketchs) {
s.Init(info.num_row, 1.0 / (max_num_bins * kFactor));
}
dmlc::DataIter<RowBatch>* iter = p_fmat->RowIterator();
iter->BeforeFirst();
while (iter->Next()) {
const RowBatch& batch = iter->Value();
#pragma omp parallel num_threads(nthread)
{
CHECK_EQ(nthread, omp_get_num_threads());
unsigned tid = static_cast<unsigned>(omp_get_thread_num());
unsigned begin = std::min(nstep * tid, ncol);
unsigned end = std::min(nstep * (tid + 1), ncol);
for (size_t i = 0; i < batch.size; ++i) { // NOLINT(*)
bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i);
RowBatch::Inst inst = batch[i];
for (bst_uint j = 0; j < inst.length; ++j) {
if (inst[j].index >= begin && inst[j].index < end) {
sketchs[inst[j].index].Push(inst[j].fvalue, info.GetWeight(ridx));
}
}
}
}
}
// gather the histogram data
rabit::SerializeReducer<WXQSketch::SummaryContainer> sreducer;
std::vector<WXQSketch::SummaryContainer> summary_array;
summary_array.resize(sketchs.size());
for (size_t i = 0; i < sketchs.size(); ++i) {
WXQSketch::SummaryContainer out;
sketchs[i].GetSummary(&out);
summary_array[i].Reserve(max_num_bins * kFactor);
summary_array[i].SetPrune(out, max_num_bins * kFactor);
}
size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_num_bins * kFactor);
sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size());
this->min_val.resize(info.num_col);
row_ptr.push_back(0);
for (size_t fid = 0; fid < summary_array.size(); ++fid) {
WXQSketch::SummaryContainer a;
a.Reserve(max_num_bins);
a.SetPrune(summary_array[fid], max_num_bins);
const bst_float mval = a.data[0].value;
this->min_val[fid] = mval - fabs(mval);
if (a.size > 1 && a.size <= 16) {
/* specialized code categorial / ordinal data -- use midpoints */
for (size_t i = 1; i < a.size; ++i) {
bst_float cpt = (a.data[i].value + a.data[i - 1].value) / 2.0;
if (i == 1 || cpt > cut.back()) {
cut.push_back(cpt);
}
}
} else {
for (size_t i = 2; i < a.size; ++i) {
bst_float cpt = a.data[i - 1].value;
if (i == 2 || cpt > cut.back()) {
cut.push_back(cpt);
}
}
}
// push a value that is greater than anything
if (a.size != 0) {
bst_float cpt = a.data[a.size - 1].value;
// this must be bigger than last value in a scale
bst_float last = cpt + fabs(cpt);
cut.push_back(last);
}
row_ptr.push_back(cut.size());
}
}
void GHistIndexMatrix::Init(DMatrix* p_fmat) {
CHECK(cut != nullptr);
dmlc::DataIter<RowBatch>* iter = p_fmat->RowIterator();
hit_count.resize(cut->row_ptr.back(), 0);
int nthread;
#pragma omp parallel
{
nthread = omp_get_num_threads();
}
nthread = std::max(nthread / 2, 1);
iter->BeforeFirst();
row_ptr.push_back(0);
while (iter->Next()) {
const RowBatch& batch = iter->Value();
size_t rbegin = row_ptr.size() - 1;
for (size_t i = 0; i < batch.size; ++i) {
row_ptr.push_back(batch[i].length + row_ptr.back());
}
index.resize(row_ptr.back());
CHECK_GT(cut->cut.size(), 0);
CHECK_EQ(cut->row_ptr.back(), cut->cut.size());
omp_ulong bsize = static_cast<omp_ulong>(batch.size);
#pragma omp parallel for num_threads(nthread) schedule(static)
for (omp_ulong i = 0; i < bsize; ++i) { // NOLINT(*)
size_t ibegin = row_ptr[rbegin + i];
size_t iend = row_ptr[rbegin + i + 1];
RowBatch::Inst inst = batch[i];
CHECK_EQ(ibegin + inst.length, iend);
for (bst_uint j = 0; j < inst.length; ++j) {
unsigned fid = inst[j].index;
auto cbegin = cut->cut.begin() + cut->row_ptr[fid];
auto cend = cut->cut.begin() + cut->row_ptr[fid + 1];
CHECK(cbegin != cend);
auto it = std::upper_bound(cbegin, cend, inst[j].fvalue);
if (it == cend) it = cend - 1;
unsigned idx = static_cast<unsigned>(it - cut->cut.begin());
index[ibegin + j] = idx;
}
std::sort(index.begin() + ibegin, index.begin() + iend);
}
}
}
void GHistBuilder::BuildHist(const std::vector<bst_gpair>& gpair,
const RowSetCollection::Elem row_indices,
const GHistIndexMatrix& gmat,
GHistRow hist) {
CHECK(!data_.empty()) << "GHistBuilder must be initialized";
CHECK_EQ(data_.size(), nbins_ * nthread_) << "invalid dimensions for temp buffer";
std::fill(data_.begin(), data_.end(), GHistEntry());
const int K = 8; // loop unrolling factor
const bst_omp_uint nthread = static_cast<bst_omp_uint>(this->nthread_);
const bst_omp_uint nrows = row_indices.end - row_indices.begin;
const bst_omp_uint rest = nrows % K;
#pragma omp parallel for num_threads(nthread) schedule(static)
for (bst_omp_uint i = 0; i < nrows - rest; i += K) {
const bst_omp_uint tid = omp_get_thread_num();
const size_t off = tid * nbins_;
bst_uint rid[K];
bst_gpair stat[K];
size_t ibegin[K], iend[K];
for (int k = 0; k < K; ++k) {
rid[k] = row_indices.begin[i + k];
}
for (int k = 0; k < K; ++k) {
stat[k] = gpair[rid[k]];
}
for (int k = 0; k < K; ++k) {
ibegin[k] = static_cast<size_t>(gmat.row_ptr[rid[k]]);
iend[k] = static_cast<size_t>(gmat.row_ptr[rid[k] + 1]);
}
for (int k = 0; k < K; ++k) {
for (size_t j = ibegin[k]; j < iend[k]; ++j) {
const size_t bin = gmat.index[j];
data_[off + bin].Add(stat[k]);
}
}
}
for (bst_omp_uint i = nrows - rest; i < nrows; ++i) {
const bst_uint rid = row_indices.begin[i];
const bst_gpair stat = gpair[rid];
const size_t ibegin = static_cast<size_t>(gmat.row_ptr[rid]);
const size_t iend = static_cast<size_t>(gmat.row_ptr[rid + 1]);
for (size_t j = ibegin; j < iend; ++j) {
const size_t bin = gmat.index[j];
data_[bin].Add(stat);
}
}
/* reduction */
const bst_omp_uint nbins = static_cast<bst_omp_uint>(nbins_);
#pragma omp parallel for num_threads(nthread) schedule(static)
for (bst_omp_uint bin_id = 0; bin_id < nbins; ++bin_id) {
for (bst_omp_uint tid = 0; tid < nthread; ++tid) {
hist.begin[bin_id].Add(data_[tid * nbins_ + bin_id]);
}
}
}
void GHistBuilder::SubtractionTrick(GHistRow self,
GHistRow sibling,
GHistRow parent) {
const bst_omp_uint nthread = static_cast<bst_omp_uint>(this->nthread_);
const bst_omp_uint nbins = static_cast<bst_omp_uint>(nbins_);
#pragma omp parallel for num_threads(nthread) schedule(static)
for (bst_omp_uint bin_id = 0; bin_id < nbins; ++bin_id) {
self.begin[bin_id].SetSubtract(parent.begin[bin_id], sibling.begin[bin_id]);
}
}
} // namespace common
} // namespace xgboost

214
src/common/hist_util.h Normal file
View File

@ -0,0 +1,214 @@
/*!
* Copyright 2017 by Contributors
* \file hist_util.h
* \brief Utility for fast histogram aggregation
* \author Philip Cho, Tianqi Chen
*/
#ifndef XGBOOST_COMMON_HIST_UTIL_H_
#define XGBOOST_COMMON_HIST_UTIL_H_
#include <xgboost/data.h>
#include <limits>
#include <vector>
#include "row_set.h"
namespace xgboost {
namespace common {
/*! \brief sums of gradient statistics corresponding to a histogram bin */
struct GHistEntry {
/*! \brief sum of first-order gradient statistics */
double sum_grad;
/*! \brief sum of second-order gradient statistics */
double sum_hess;
GHistEntry() : sum_grad(0), sum_hess(0) {}
/*! \brief add a bst_gpair to the sum */
inline void Add(const bst_gpair& e) {
sum_grad += e.grad;
sum_hess += e.hess;
}
/*! \brief add a GHistEntry to the sum */
inline void Add(const GHistEntry& e) {
sum_grad += e.sum_grad;
sum_hess += e.sum_hess;
}
/*! \brief set sum to be difference of two GHistEntry's */
inline void SetSubtract(const GHistEntry& a, const GHistEntry& b) {
sum_grad = a.sum_grad - b.sum_grad;
sum_hess = a.sum_hess - b.sum_hess;
}
};
/*! \brief Cut configuration for one feature */
struct HistCutUnit {
/*! \brief the index pointer of each histunit */
const bst_float* cut;
/*! \brief number of cutting point, containing the maximum point */
size_t size;
// default constructor
HistCutUnit() {}
// constructor
HistCutUnit(const bst_float* cut, unsigned size)
: cut(cut), size(size) {}
};
/*! \brief cut configuration for all the features */
struct HistCutMatrix {
/*! \brief actual unit pointer */
std::vector<unsigned> row_ptr;
/*! \brief minimum value of each feature */
std::vector<bst_float> min_val;
/*! \brief the cut field */
std::vector<bst_float> cut;
/*! \brief Get histogram bound for fid */
inline HistCutUnit operator[](unsigned fid) const {
return HistCutUnit(dmlc::BeginPtr(cut) + row_ptr[fid],
row_ptr[fid + 1] - row_ptr[fid]);
}
// create histogram cut matrix given statistics from data
// using approximate quantile sketch approach
void Init(DMatrix* p_fmat, size_t max_num_bins);
};
/*!
* \brief A single row in global histogram index.
* Directly represent the global index in the histogram entry.
*/
struct GHistIndexRow {
/*! \brief The index of the histogram */
const unsigned* index;
/*! \brief The size of the histogram */
unsigned size;
GHistIndexRow() {}
GHistIndexRow(const unsigned* index, unsigned size)
: index(index), size(size) {}
};
/*!
* \brief preprocessed global index matrix, in CSR format
* Transform floating values to integer index in histogram
* This is a global histogram index.
*/
struct GHistIndexMatrix {
/*! \brief row pointer */
std::vector<unsigned> row_ptr;
/*! \brief The index data */
std::vector<unsigned> index;
/*! \brief hit count of each index */
std::vector<unsigned> hit_count;
/*! \brief optional remap index from outter row_id -> internal row_id*/
std::vector<unsigned> remap_index;
/*! \brief The corresponding cuts */
const HistCutMatrix* cut;
// Create a global histogram matrix, given cut
void Init(DMatrix* p_fmat);
// build remap
void Remap();
// get i-th row
inline GHistIndexRow operator[](bst_uint i) const {
return GHistIndexRow(&index[0] + row_ptr[i], row_ptr[i + 1] - row_ptr[i]);
}
};
/*!
* \brief histogram of graident statistics for a single node.
* Consists of multiple GHistEntry's, each entry showing total graident statistics
* for that particular bin
* Uses global bin id so as to represent all features simultaneously
*/
struct GHistRow {
/*! \brief base pointer to first entry */
GHistEntry* begin;
/*! \brief number of entries */
unsigned size;
GHistRow() {}
GHistRow(GHistEntry* begin, unsigned size)
: begin(begin), size(size) {}
};
/*!
* \brief histogram of gradient statistics for multiple nodes
*/
class HistCollection {
public:
// access histogram for i-th node
inline GHistRow operator[](bst_uint nid) const {
const size_t kMax = std::numeric_limits<size_t>::max();
CHECK_NE(row_ptr_[nid], kMax);
return GHistRow(const_cast<GHistEntry*>(dmlc::BeginPtr(data_) + row_ptr_[nid]), nbins_);
}
// have we computed a histogram for i-th node?
inline bool RowExists(bst_uint nid) const {
const size_t kMax = std::numeric_limits<size_t>::max();
return (nid < row_ptr_.size() && row_ptr_[nid] != kMax);
}
// initialize histogram collection
inline void Init(size_t nbins) {
nbins_ = nbins;
row_ptr_.clear();
data_.clear();
}
// create an empty histogram for i-th node
inline void AddHistRow(bst_uint nid) {
const size_t kMax = std::numeric_limits<size_t>::max();
if (nid >= row_ptr_.size()) {
row_ptr_.resize(nid + 1, kMax);
}
CHECK_EQ(row_ptr_[nid], kMax);
row_ptr_[nid] = data_.size();
data_.resize(data_.size() + nbins_);
}
private:
/*! \brief number of all bins over all features */
size_t nbins_;
std::vector<GHistEntry> data_;
/*! \brief row_ptr_[nid] locates bin for historgram of node nid */
std::vector<size_t> row_ptr_;
};
/*!
* \brief builder for histograms of gradient statistics
*/
class GHistBuilder {
public:
// initialize builder
inline void Init(size_t nthread, size_t nbins) {
nthread_ = nthread;
nbins_ = nbins;
data_.resize(nthread * nbins, GHistEntry());
}
// construct a histogram via histogram aggregation
void BuildHist(const std::vector<bst_gpair>& gpair,
const RowSetCollection::Elem row_indices,
const GHistIndexMatrix& gmat,
GHistRow hist);
// construct a histogram via subtraction trick
void SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent);
private:
/*! \brief number of threads for parallel computation */
size_t nthread_;
/*! \brief number of all bins over all features */
size_t nbins_;
std::vector<GHistEntry> data_;
};
} // namespace common
} // namespace xgboost
#endif // XGBOOST_COMMON_HIST_UTIL_H_

View File

@ -348,10 +348,12 @@ struct WXQSummary : public WQSummary<DType, RType> {
this->CopyFrom(src); return;
}
RType begin = src.data[0].rmax;
size_t n = maxsize - 1, nbig = 0;
// n is number of points exclude the min/max points
size_t n = maxsize - 2, nbig = 0;
// these is the range of data exclude the min/max point
RType range = src.data[src.size - 1].rmin - begin;
// prune off zero weights
if (range == 0.0f) {
if (range == 0.0f || maxsize <= 2) {
// special case, contain only two effective data pts
this->data[0] = src.data[0];
this->data[1] = src.data[src.size - 1];
@ -360,16 +362,21 @@ struct WXQSummary : public WQSummary<DType, RType> {
} else {
range = std::max(range, static_cast<RType>(1e-3f));
}
// Get a big enough chunk size, bigger than range / n
// (multiply by 2 is a safe factor)
const RType chunk = 2 * range / n;
// minimized range
RType mrange = 0;
{
// first scan, grab all the big chunk
// moving block index
// moving block index, exclude the two ends.
size_t bid = 0;
for (size_t i = 1; i < src.size; ++i) {
for (size_t i = 1; i < src.size - 1; ++i) {
// detect big chunk data point in the middle
// always save these data points.
if (CheckLarge(src.data[i], chunk)) {
if (bid != i - 1) {
// accumulate the range of the rest points
mrange += src.data[i].rmax_prev() - src.data[bid].rmin_next();
}
bid = i; ++nbig;
@ -379,17 +386,18 @@ struct WXQSummary : public WQSummary<DType, RType> {
mrange += src.data[src.size-1].rmax_prev() - src.data[bid].rmin_next();
}
}
if (nbig >= n - 1) {
// assert: there cannot be more than n big data points
if (nbig >= n) {
// see what was the case
LOG(INFO) << " check quantile stats, nbig=" << nbig << ", n=" << n;
LOG(INFO) << " srcsize=" << src.size << ", maxsize=" << maxsize
<< ", range=" << range << ", chunk=" << chunk;
src.Print();
CHECK(nbig < n - 1) << "quantile: too many large chunk";
CHECK(nbig < n) << "quantile: too many large chunk";
}
this->data[0] = src.data[0];
this->size = 1;
// use smaller size
// The counter on the rest of points, to be selected equally from small chunks.
n = n - nbig;
// find the rest of point
size_t bid = 0, k = 1, lastidx = 0;

104
src/common/row_set.h Normal file
View File

@ -0,0 +1,104 @@
/*!
* Copyright 2017 by Contributors
* \file row_set.h
* \brief Quick Utility to compute subset of rows
* \author Philip Cho, Tianqi Chen
*/
#ifndef XGBOOST_COMMON_ROW_SET_H_
#define XGBOOST_COMMON_ROW_SET_H_
#include <xgboost/data.h>
#include <algorithm>
#include <vector>
namespace xgboost {
namespace common {
/*! \brief collection of rowset */
class RowSetCollection {
public:
/*! \brief subset of rows */
struct Elem {
const bst_uint* begin;
const bst_uint* end;
Elem(void)
: begin(nullptr), end(nullptr) {}
Elem(const bst_uint* begin,
const bst_uint* end)
: begin(begin), end(end) {}
inline size_t size() const {
return end - begin;
}
};
/* \brief specifies how to split a rowset into two */
struct Split {
std::vector<bst_uint> left;
std::vector<bst_uint> right;
};
/*! \brief return corresponding element set given the node_id */
inline const Elem& operator[](unsigned node_id) const {
const Elem& e = elem_of_each_node_[node_id];
CHECK(e.begin != nullptr)
<< "access element that is not in the set";
return e;
}
// clear up things
inline void Clear() {
row_indices_.clear();
elem_of_each_node_.clear();
}
// initialize node id 0->everything
inline void Init() {
CHECK_EQ(elem_of_each_node_.size(), 0);
const bst_uint* begin = dmlc::BeginPtr(row_indices_);
const bst_uint* end = dmlc::BeginPtr(row_indices_) + row_indices_.size();
elem_of_each_node_.emplace_back(Elem(begin, end));
}
// split rowset into two
inline void AddSplit(unsigned node_id,
const std::vector<Split>& row_split_tloc,
unsigned left_node_id,
unsigned right_node_id) {
const Elem e = elem_of_each_node_[node_id];
const unsigned nthread = row_split_tloc.size();
CHECK(e.begin != nullptr);
bst_uint* all_begin = dmlc::BeginPtr(row_indices_);
bst_uint* begin = all_begin + (e.begin - all_begin);
bst_uint* it = begin;
// TODO(hcho3): parallelize this section
for (bst_omp_uint tid = 0; tid < nthread; ++tid) {
std::copy(row_split_tloc[tid].left.begin(), row_split_tloc[tid].left.end(), it);
it += row_split_tloc[tid].left.size();
}
bst_uint* split_pt = it;
for (bst_omp_uint tid = 0; tid < nthread; ++tid) {
std::copy(row_split_tloc[tid].right.begin(), row_split_tloc[tid].right.end(), it);
it += row_split_tloc[tid].right.size();
}
if (left_node_id >= elem_of_each_node_.size()) {
elem_of_each_node_.resize(left_node_id + 1, Elem(nullptr, nullptr));
}
if (right_node_id >= elem_of_each_node_.size()) {
elem_of_each_node_.resize(right_node_id + 1, Elem(nullptr, nullptr));
}
elem_of_each_node_[left_node_id] = Elem(begin, split_pt);
elem_of_each_node_[right_node_id] = Elem(split_pt, e.end);
elem_of_each_node_[node_id] = Elem(nullptr, nullptr);
}
// stores the row indices in the set
std::vector<bst_uint> row_indices_;
private:
// vector: node_id -> elements
std::vector<Elem> elem_of_each_node_;
};
} // namespace common
} // namespace xgboost
#endif // XGBOOST_COMMON_ROW_SET_H_

View File

@ -6,6 +6,7 @@
*/
#include <dmlc/omp.h>
#include <dmlc/parameter.h>
#include <dmlc/timer.h>
#include <xgboost/logging.h>
#include <xgboost/gbm.h>
#include <xgboost/tree_updater.h>
@ -369,7 +370,7 @@ class GBTree : public GradientBooster {
const int nthread = omp_get_max_threads();
CHECK_EQ(num_group, mparam.num_output_group);
InitThreadTemp(nthread);
std::vector<bst_float> &preds = *out_preds;
std::vector<bst_float>& preds = *out_preds;
CHECK_EQ(mparam.size_leaf_vector, 0)
<< "size_leaf_vector is enforced to 0 so far";
CHECK_EQ(preds.size(), p_fmat->info().num_row * num_group);
@ -380,17 +381,38 @@ class GBTree : public GradientBooster {
while (iter->Next()) {
const RowBatch &batch = iter->Value();
// parallel over local batch
const int K = 8;
const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
const bst_omp_uint rest = nsize % K;
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < nsize; ++i) {
for (bst_omp_uint i = 0; i < nsize - rest; i += K) {
const int tid = omp_get_thread_num();
RegTree::FVec &feats = thread_temp[tid];
int64_t ridx = static_cast<int64_t>(batch.base_rowid + i);
CHECK_LT(static_cast<size_t>(ridx), info.num_row);
RegTree::FVec& feats = thread_temp[tid];
int64_t ridx[K];
RowBatch::Inst inst[K];
for (int k = 0; k < K; ++k) {
ridx[k] = static_cast<int64_t>(batch.base_rowid + i + k);
}
for (int k = 0; k < K; ++k) {
inst[k] = batch[i + k];
}
for (int k = 0; k < K; ++k) {
for (int gid = 0; gid < num_group; ++gid) {
const size_t offset = ridx[k] * num_group + gid;
preds[offset] +=
self->PredValue(inst[k], gid, info.GetRoot(ridx[k]),
&feats, tree_begin, tree_end);
}
}
}
for (bst_omp_uint i = nsize - rest; i < nsize; ++i) {
RegTree::FVec& feats = thread_temp[0];
const int64_t ridx = static_cast<int64_t>(batch.base_rowid + i);
const RowBatch::Inst inst = batch[i];
for (int gid = 0; gid < num_group; ++gid) {
size_t offset = ridx * num_group + gid;
const size_t offset = ridx * num_group + gid;
preds[offset] +=
self->PredValue(batch[i], gid, info.GetRoot(ridx),
self->PredValue(inst, gid, info.GetRoot(ridx),
&feats, tree_begin, tree_end);
}
}

View File

@ -99,6 +99,7 @@ struct LearnerTrainParam
.add_enum("auto", 0)
.add_enum("approx", 1)
.add_enum("exact", 2)
.add_enum("hist", 3)
.describe("Choice of tree construction method.");
DMLC_DECLARE_FIELD(test_flag).set_default("")
.describe("Internal test flag");
@ -167,7 +168,31 @@ class LearnerImpl : public Learner {
cfg_["max_delta_step"] = "0.7";
}
if (cfg_.count("updater") == 0) {
if (tparam.tree_method == 3) {
/* histogram-based algorithm */
if (cfg_.count("updater") == 0) {
LOG(CONSOLE) << "Tree method is selected to be \'hist\', "
<< "which uses histogram aggregation for faster training. "
<< "Using default sequence of updaters: grow_fast_histmaker,prune";
cfg_["updater"] = "grow_fast_histmaker,prune";
} else {
const std::string first_str = "grow_fast_histmaker";
if (first_str.length() <= cfg_["updater"].length()
&& std::equal(first_str.begin(), first_str.end(), cfg_["updater"].begin())) {
// updater sequence starts with "grow_fast_histmaker"
LOG(CONSOLE) << "Tree method is selected to be \'hist\', "
<< "which uses histogram aggregation for faster training. "
<< "Using custom sequence of updaters: " << cfg_["updater"];
} else {
// updater sequence does not start with "grow_fast_histmaker"
LOG(CONSOLE) << "Tree method is selected to be \'hist\', but the given "
<< "sequence of updaters is not compatible; "
<< "grow_fast_histmaker must run first. "
<< "Using default sequence of updaters: grow_fast_histmaker,prune";
cfg_["updater"] = "grow_fast_histmaker,prune";
}
}
} else if (cfg_.count("updater") == 0) {
if (tparam.dsplit == 1) {
cfg_["updater"] = "distcol";
} else if (tparam.dsplit == 2) {
@ -379,8 +404,8 @@ class LearnerImpl : public Learner {
protected:
// check if p_train is ready to used by training.
// if not, initialize the column access.
inline void LazyInitDMatrix(DMatrix *p_train) {
if (!p_train->HaveColAccess()) {
inline void LazyInitDMatrix(DMatrix* p_train) {
if (tparam.tree_method != 3 && !p_train->HaveColAccess()) {
int ncol = static_cast<int>(p_train->info().num_col);
std::vector<bool> enabled(ncol, true);
// set max row per batch to limited value

View File

@ -31,6 +31,14 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
float min_split_loss;
// maximum depth of a tree
int max_depth;
// maximum number of leaves
int max_leaves;
// if using histogram based algorithm, maximum number of bins per feature
int max_bin;
// growing policy
enum TreeGrowPolicy { kDepthWise = 0, kLossGuide = 1 };
int grow_policy;
int verbose;
//----- the rest parameters are less important ----
// minimum amount of hessian(weight) allowed in a child
float min_child_weight;
@ -77,11 +85,32 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
DMLC_DECLARE_FIELD(min_split_loss)
.set_lower_bound(0.0f)
.set_default(0.0f)
.describe("Minimum loss reduction required to make a further partition.");
.describe(
"Minimum loss reduction required to make a further partition.");
DMLC_DECLARE_FIELD(verbose)
.set_lower_bound(0)
.set_default(0)
.describe(
"Setting verbose flag with a positive value causes the updater "
"to print out *detailed* list of tasks and their runtime");
DMLC_DECLARE_FIELD(max_depth)
.set_lower_bound(0)
.set_default(6)
.describe("Maximum depth of the tree.");
.describe(
"Maximum depth of the tree; 0 indicates no limit; a limit is required "
"for depthwise policy");
DMLC_DECLARE_FIELD(max_leaves).set_lower_bound(0).set_default(0).describe(
"Maximum number of leaves; 0 indicates no limit.");
DMLC_DECLARE_FIELD(max_bin).set_lower_bound(2).set_default(256).describe(
"if using histogram-based algorithm, maximum number of bins per feature");
DMLC_DECLARE_FIELD(grow_policy)
.set_default(kDepthWise)
.add_enum("depthwise", kDepthWise)
.add_enum("lossguide", kLossGuide)
.describe(
"Tree growing policy. 0: favor splitting at nodes closest to the node, "
"i.e. grow depth-wise. 1: favor splitting at nodes with highest loss "
"change. (cf. LightGBM)");
DMLC_DECLARE_FIELD(min_child_weight)
.set_lower_bound(0.0f)
.set_default(1.0f)
@ -258,7 +287,7 @@ XGB_DEVICE inline T CalcWeight(const TrainingParams &p, T sum_grad,
}
/*! \brief core statistics used for tree construction */
struct GradStats {
struct XGBOOST_ALIGNAS(16) GradStats {
/*! \brief sum gradient statistics */
double sum_grad;
/*! \brief sum hessian statistics */
@ -269,11 +298,11 @@ struct GradStats {
*/
static const int kSimpleStats = 1;
/*! \brief constructor, the object must be cleared during construction */
explicit GradStats(const TrainParam &param) { this->Clear(); }
explicit GradStats(const TrainParam& param) { this->Clear(); }
/*! \brief clear the statistics */
inline void Clear() { sum_grad = sum_hess = 0.0f; }
/*! \brief check if necessary information is ready */
inline static void CheckInfo(const MetaInfo &info) {}
inline static void CheckInfo(const MetaInfo& info) {}
/*!
* \brief accumulate statistics
* \param p the gradient pair
@ -285,34 +314,37 @@ struct GradStats {
* \param info the additional information
* \param ridx instance index of this instance
*/
inline void Add(const std::vector<bst_gpair> &gpair, const MetaInfo &info,
inline void Add(const std::vector<bst_gpair>& gpair, const MetaInfo& info,
bst_uint ridx) {
const bst_gpair &b = gpair[ridx];
const bst_gpair& b = gpair[ridx];
this->Add(b.grad, b.hess);
}
/*! \brief calculate leaf weight */
inline double CalcWeight(const TrainParam &param) const {
inline double CalcWeight(const TrainParam& param) const {
return xgboost::tree::CalcWeight(param, sum_grad, sum_hess);
}
/*! \brief calculate gain of the solution */
inline double CalcGain(const TrainParam &param) const {
inline double CalcGain(const TrainParam& param) const {
return xgboost::tree::CalcGain(param, sum_grad, sum_hess);
}
/*! \brief add statistics to the data */
inline void Add(const GradStats &b) { this->Add(b.sum_grad, b.sum_hess); }
inline void Add(const GradStats& b) {
sum_grad += b.sum_grad;
sum_hess += b.sum_hess;
}
/*! \brief same as add, reduce is used in All Reduce */
inline static void Reduce(GradStats &a, const GradStats &b) { // NOLINT(*)
inline static void Reduce(GradStats& a, const GradStats& b) { // NOLINT(*)
a.Add(b);
}
/*! \brief set current value to a - b */
inline void SetSubstract(const GradStats &a, const GradStats &b) {
inline void SetSubstract(const GradStats& a, const GradStats& b) {
sum_grad = a.sum_grad - b.sum_grad;
sum_hess = a.sum_hess - b.sum_hess;
}
/*! \return whether the statistics is not used yet */
inline bool Empty() const { return sum_hess == 0.0; }
/*! \brief set leaf vector value based on statistics */
inline void SetLeafVec(const TrainParam &param, bst_float *vec) const {}
inline void SetLeafVec(const TrainParam& param, bst_float* vec) const {}
// constructor to allow inheritance
GradStats() {}
/*! \brief add statistics to the data */

View File

@ -0,0 +1,725 @@
/*!
* Copyright 2017 by Contributors
* \file updater_fast_hist.cc
* \brief use quantized feature values to construct a tree
* \author Philip Cho, Tianqi Checn
*/
#include <dmlc/timer.h>
#include <xgboost/tree_updater.h>
#include <cmath>
#include <vector>
#include <algorithm>
#include <queue>
#include <iomanip>
#include <numeric>
#include "./param.h"
#include "../common/random.h"
#include "../common/bitmap.h"
#include "../common/sync.h"
#include "../common/hist_util.h"
#include "../common/row_set.h"
namespace xgboost {
namespace tree {
using xgboost::common::HistCutMatrix;
using xgboost::common::GHistIndexMatrix;
using xgboost::common::GHistIndexRow;
using xgboost::common::GHistEntry;
using xgboost::common::HistCollection;
using xgboost::common::RowSetCollection;
using xgboost::common::GHistRow;
using xgboost::common::GHistBuilder;
DMLC_REGISTRY_FILE_TAG(updater_fast_hist);
/*! \brief construct a tree using quantized feature values */
template<typename TStats, typename TConstraint>
class FastHistMaker: public TreeUpdater {
public:
void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
param.InitAllowUnknown(args);
is_gmat_initialized_ = false;
}
void Update(const std::vector<bst_gpair>& gpair,
DMatrix* dmat,
const std::vector<RegTree*>& trees) override {
TStats::CheckInfo(dmat->info());
if (is_gmat_initialized_ == false) {
double tstart = dmlc::GetTime();
hmat_.Init(dmat, param.max_bin);
gmat_.cut = &hmat_;
gmat_.Init(dmat);
is_gmat_initialized_ = true;
if (param.verbose > 0) {
LOG(INFO) << "Generating gmat: " << dmlc::GetTime() - tstart << " sec";
}
}
// rescale learning rate according to size of trees
float lr = param.learning_rate;
param.learning_rate = lr / trees.size();
TConstraint::Init(&param, dmat->info().num_col);
// build tree
if (!builder_) {
builder_.reset(new Builder(param));
}
for (size_t i = 0; i < trees.size(); ++i) {
builder_->Update(gmat_, gpair, dmat, trees[i]);
}
param.learning_rate = lr;
}
protected:
// training parameter
TrainParam param;
// data sketch
HistCutMatrix hmat_;
GHistIndexMatrix gmat_;
bool is_gmat_initialized_;
// data structure
/*! \brief per thread x per node entry to store tmp data */
struct ThreadEntry {
/*! \brief statistics of data */
TStats stats;
/*! \brief extra statistics of data */
TStats stats_extra;
/*! \brief last feature value scanned */
float last_fvalue;
/*! \brief first feature value scanned */
float first_fvalue;
/*! \brief current best solution */
SplitEntry best;
// constructor
explicit ThreadEntry(const TrainParam& param)
: stats(param), stats_extra(param) {
}
};
struct NodeEntry {
/*! \brief statics for node entry */
TStats stats;
/*! \brief loss of this node, without split */
bst_float root_gain;
/*! \brief weight calculated related to current data */
float weight;
/*! \brief current best solution */
SplitEntry best;
// constructor
explicit NodeEntry(const TrainParam& param)
: stats(param), root_gain(0.0f), weight(0.0f) {
}
};
// actual builder that runs the algorithm
struct Builder {
public:
// constructor
explicit Builder(const TrainParam& param) : param(param) {
}
// update one tree, growing
virtual void Update(const GHistIndexMatrix& gmat,
const std::vector<bst_gpair>& gpair,
DMatrix* p_fmat,
RegTree* p_tree) {
double gstart = dmlc::GetTime();
std::vector<int> feat_set(p_fmat->info().num_col);
std::iota(feat_set.begin(), feat_set.end(), 0);
int num_leaves = 0;
unsigned timestamp = 0;
double tstart;
double time_init_data = 0;
double time_init_new_node = 0;
double time_build_hist = 0;
double time_evaluate_split = 0;
double time_apply_split = 0;
tstart = dmlc::GetTime();
this->InitData(gmat, gpair, *p_fmat, *p_tree);
time_init_data = dmlc::GetTime() - tstart;
for (int nid = 0; nid < p_tree->param.num_roots; ++nid) {
tstart = dmlc::GetTime();
hist_.AddHistRow(nid);
builder_.BuildHist(gpair, row_set_collection_[nid], gmat, hist_[nid]);
time_build_hist += dmlc::GetTime() - tstart;
tstart = dmlc::GetTime();
this->InitNewNode(nid, gmat, gpair, *p_fmat, *p_tree);
time_init_new_node += dmlc::GetTime() - tstart;
tstart = dmlc::GetTime();
this->EvaluateSplit(nid, gmat, hist_, *p_fmat, *p_tree, feat_set);
time_evaluate_split += dmlc::GetTime() - tstart;
qexpand_->push(ExpandEntry(nid, p_tree->GetDepth(nid),
snode[nid].best.loss_chg,
timestamp++));
++num_leaves;
}
while (!qexpand_->empty()) {
const ExpandEntry candidate = qexpand_->top();
const int nid = candidate.nid;
qexpand_->pop();
if (candidate.loss_chg <= rt_eps
|| (param.max_depth > 0 && candidate.depth == param.max_depth)
|| (param.max_leaves > 0 && num_leaves == param.max_leaves) ) {
(*p_tree)[nid].set_leaf(snode[nid].weight * param.learning_rate);
} else {
tstart = dmlc::GetTime();
this->ApplySplit(nid, gmat, hist_, *p_fmat, p_tree);
time_apply_split += dmlc::GetTime() - tstart;
tstart = dmlc::GetTime();
const int cleft = (*p_tree)[nid].cleft();
const int cright = (*p_tree)[nid].cright();
hist_.AddHistRow(cleft);
hist_.AddHistRow(cright);
if (row_set_collection_[cleft].size() < row_set_collection_[cright].size()) {
builder_.BuildHist(gpair, row_set_collection_[cleft], gmat, hist_[cleft]);
builder_.SubtractionTrick(hist_[cright], hist_[cleft], hist_[nid]);
} else {
builder_.BuildHist(gpair, row_set_collection_[cright], gmat, hist_[cright]);
builder_.SubtractionTrick(hist_[cleft], hist_[cright], hist_[nid]);
}
time_build_hist += dmlc::GetTime() - tstart;
tstart = dmlc::GetTime();
this->InitNewNode(cleft, gmat, gpair, *p_fmat, *p_tree);
this->InitNewNode(cright, gmat, gpair, *p_fmat, *p_tree);
time_init_new_node += dmlc::GetTime() - tstart;
tstart = dmlc::GetTime();
this->EvaluateSplit(cleft, gmat, hist_, *p_fmat, *p_tree, feat_set);
this->EvaluateSplit(cright, gmat, hist_, *p_fmat, *p_tree, feat_set);
time_evaluate_split += dmlc::GetTime() - tstart;
qexpand_->push(ExpandEntry(cleft, p_tree->GetDepth(cleft),
snode[cleft].best.loss_chg,
timestamp++));
qexpand_->push(ExpandEntry(cright, p_tree->GetDepth(cright),
snode[cright].best.loss_chg,
timestamp++));
++num_leaves; // give two and take one, as parent is no longer a leaf
}
}
// set all the rest expanding nodes to leaf
// This post condition is not needed in current code, but may be necessary
// when there are stopping rule that leaves qexpand non-empty
while (!qexpand_->empty()) {
const int nid = qexpand_->top().nid;
qexpand_->pop();
(*p_tree)[nid].set_leaf(snode[nid].weight * param.learning_rate);
}
// remember auxiliary statistics in the tree node
for (int nid = 0; nid < p_tree->param.num_nodes; ++nid) {
p_tree->stat(nid).loss_chg = snode[nid].best.loss_chg;
p_tree->stat(nid).base_weight = snode[nid].weight;
p_tree->stat(nid).sum_hess = static_cast<float>(snode[nid].stats.sum_hess);
snode[nid].stats.SetLeafVec(param, p_tree->leafvec(nid));
}
if (param.verbose > 0) {
double total_time = dmlc::GetTime() - gstart;
LOG(INFO) << "\nInitData: "
<< std::fixed << std::setw(4) << std::setprecision(2) << time_init_data
<< " (" << std::fixed << std::setw(5) << std::setprecision(2)
<< time_init_data / total_time * 100 << "%)\n"
<< "InitNewNode: "
<< std::fixed << std::setw(4) << std::setprecision(2) << time_init_new_node
<< " (" << std::fixed << std::setw(5) << std::setprecision(2)
<< time_init_new_node / total_time * 100 << "%)\n"
<< "BuildHist: "
<< std::fixed << std::setw(4) << std::setprecision(2) << time_build_hist
<< " (" << std::fixed << std::setw(5) << std::setprecision(2)
<< time_build_hist / total_time * 100 << "%)\n"
<< "EvaluateSplit: "
<< std::fixed << std::setw(4) << std::setprecision(2) << time_evaluate_split
<< " (" << std::fixed << std::setw(5) << std::setprecision(2)
<< time_evaluate_split / total_time * 100 << "%)\n"
<< "ApplySplit: "
<< std::fixed << std::setw(4) << std::setprecision(2) << time_apply_split
<< " (" << std::fixed << std::setw(5) << std::setprecision(2)
<< time_apply_split / total_time * 100 << "%)\n"
<< "========================================\n"
<< "Total: "
<< std::fixed << std::setw(4) << std::setprecision(2) << total_time;
}
}
protected:
// initialize temp data structure
inline void InitData(const GHistIndexMatrix& gmat,
const std::vector<bst_gpair>& gpair,
const DMatrix& fmat,
const RegTree& tree) {
CHECK_EQ(tree.param.num_nodes, tree.param.num_roots)
<< "ColMakerHist: can only grow new tree";
CHECK((param.max_depth > 0 || param.max_leaves > 0))
<< "max_depth or max_leaves cannot be both 0 (unlimited); "
<< "at least one should be a positive quantity.";
if (param.grow_policy == TrainParam::kDepthWise) {
CHECK(param.max_depth > 0) << "max_depth cannot be 0 (unlimited) "
<< "when grow_policy is depthwise.";
}
const auto& info = fmat.info();
{
// initialize the row set
row_set_collection_.Clear();
// initialize histogram collection
size_t nbins = gmat.cut->row_ptr.back();
hist_.Init(nbins);
#pragma omp parallel
{
this->nthread = omp_get_num_threads();
}
builder_.Init(this->nthread, nbins);
CHECK_EQ(info.root_index.size(), 0);
std::vector<bst_uint>& row_indices = row_set_collection_.row_indices_;
// mark subsample and build list of member rows
if (param.subsample < 1.0f) {
std::bernoulli_distribution coin_flip(param.subsample);
auto& rnd = common::GlobalRandom();
for (bst_uint i = 0; i < info.num_row; ++i) {
if (gpair[i].hess >= 0.0f && coin_flip(rnd)) {
row_indices.push_back(i);
}
}
} else {
for (bst_uint i = 0; i < info.num_row; ++i) {
if (gpair[i].hess >= 0.0f) {
row_indices.push_back(i);
}
}
}
row_set_collection_.Init();
}
{
// initialize feature index
unsigned ncol = static_cast<unsigned>(info.num_col);
feat_index.clear();
for (unsigned i = 0; i < ncol; ++i) {
feat_index.push_back(i);
}
unsigned n = static_cast<unsigned>(param.colsample_bytree * feat_index.size());
std::shuffle(feat_index.begin(), feat_index.end(), common::GlobalRandom());
CHECK_GT(n, 0)
<< "colsample_bytree=" << param.colsample_bytree
<< " is too small that no feature can be included";
feat_index.resize(n);
}
{
/* determine layout of data */
const auto nrow = info.num_row;
const auto ncol = info.num_col;
const auto nnz = info.num_nonzero;
// number of discrete bins for feature 0
const unsigned nbins_f0 = gmat.cut->row_ptr[1] - gmat.cut->row_ptr[0];
if (nrow * ncol == nnz) {
// dense data with zero-based indexing
data_layout_ = kDenseDataZeroBased;
} else if (nbins_f0 == 0 && nrow * (ncol - 1) == nnz) {
// dense data with one-based indexing
data_layout_ = kDenseDataOneBased;
} else {
// sparse data
data_layout_ = kSparseData;
}
}
if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) {
/* specialized code for dense data:
choose the column that has a least positive number of discrete bins.
For dense data (with no missing value),
the sum of gradient histogram is equal to snode[nid] */
const std::vector<unsigned>& row_ptr = gmat.cut->row_ptr;
const size_t nfeature = row_ptr.size() - 1;
size_t min_nbins_per_feature = 0;
for (size_t i = 0; i < nfeature; ++i) {
const unsigned nbins = row_ptr[i + 1] - row_ptr[i];
if (nbins > 0) {
if (min_nbins_per_feature == 0 || min_nbins_per_feature > nbins) {
min_nbins_per_feature = nbins;
fid_least_bins_ = i;
}
}
}
CHECK_GT(min_nbins_per_feature, 0);
}
{
snode.reserve(256);
snode.clear();
}
{
if (param.grow_policy == TrainParam::kLossGuide) {
qexpand_.reset(new ExpandQueue(loss_guide));
} else {
qexpand_.reset(new ExpandQueue(depth_wise));
}
}
}
inline void EvaluateSplit(int nid,
const GHistIndexMatrix& gmat,
const HistCollection& hist,
const DMatrix& fmat,
const RegTree& tree,
const std::vector<int>& feat_set) {
// start enumeration
const MetaInfo& info = fmat.info();
for (int fid : feat_set) {
this->EnumerateSplit(-1, gmat, hist[nid], snode[nid], constraints_[nid], info,
&snode[nid].best, fid);
this->EnumerateSplit(+1, gmat, hist[nid], snode[nid], constraints_[nid], info,
&snode[nid].best, fid);
}
}
inline void ApplySplit(int nid,
const GHistIndexMatrix& gmat,
const HistCollection& hist,
const DMatrix& fmat,
RegTree* p_tree) {
// TODO(hcho3): support feature sampling by levels
/* 1. Create child nodes */
NodeEntry& e = snode[nid];
p_tree->AddChilds(nid);
(*p_tree)[nid].set_split(e.best.split_index(), e.best.split_value, e.best.default_left());
// mark right child as 0, to indicate fresh leaf
int cleft = (*p_tree)[nid].cleft();
int cright = (*p_tree)[nid].cright();
(*p_tree)[cleft].set_leaf(0.0f, 0);
(*p_tree)[cright].set_leaf(0.0f, 0);
/* 2. Categorize member rows */
const bst_omp_uint nthread = static_cast<bst_omp_uint>(this->nthread);
row_split_tloc_.resize(nthread);
for (bst_omp_uint i = 0; i < nthread; ++i) {
row_split_tloc_[i].left.clear();
row_split_tloc_[i].right.clear();
}
const bool default_left = (*p_tree)[nid].default_left();
const bst_uint fid = (*p_tree)[nid].split_index();
const bst_float split_pt = (*p_tree)[nid].split_cond();
const bst_uint lower_bound = gmat.cut->row_ptr[fid];
const bst_uint upper_bound = gmat.cut->row_ptr[fid + 1];
// set the split condition correctly
bst_uint split_cond = 0;
// set the condition
for (unsigned i = gmat.cut->row_ptr[fid]; i < gmat.cut->row_ptr[fid + 1]; ++i) {
if (split_pt == gmat.cut->cut[i]) split_cond = i;
}
const auto& rowset = row_set_collection_[nid];
if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) {
/* specialized code for dense data */
const size_t column_offset = (data_layout_ == kDenseDataOneBased) ? (fid - 1): fid;
ApplySplitDenseData(rowset, gmat, &row_split_tloc_, column_offset, split_cond);
} else {
ApplySplitSparseData(rowset, gmat, &row_split_tloc_, lower_bound, upper_bound,
split_cond, default_left);
}
row_set_collection_.AddSplit(
nid, row_split_tloc_, (*p_tree)[nid].cleft(), (*p_tree)[nid].cright());
}
inline void ApplySplitDenseData(const RowSetCollection::Elem rowset,
const GHistIndexMatrix& gmat,
std::vector<RowSetCollection::Split>* p_row_split_tloc,
size_t column_offset,
bst_uint split_cond) {
std::vector<RowSetCollection::Split>& row_split_tloc = *p_row_split_tloc;
const int K = 8; // loop unrolling factor
const bst_omp_uint nrows = rowset.end - rowset.begin;
const bst_omp_uint rest = nrows % K;
#pragma omp parallel for num_threads(nthread) schedule(static)
for (bst_omp_uint i = 0; i < nrows - rest; i += K) {
bst_uint rid[K];
unsigned rbin[K];
bst_uint tid = omp_get_thread_num();
auto& left = row_split_tloc[tid].left;
auto& right = row_split_tloc[tid].right;
for (int k = 0; k < K; ++k) {
rid[k] = rowset.begin[i + k];
}
for (int k = 0; k < K; ++k) {
rbin[k] = gmat[rid[k]].index[column_offset];
}
for (int k = 0; k < K; ++k) {
if (rbin[k] <= split_cond) {
left.push_back(rid[k]);
} else {
right.push_back(rid[k]);
}
}
}
for (bst_omp_uint i = nrows - rest; i < nrows; ++i) {
const bst_uint rid = rowset.begin[i];
const unsigned rbin = gmat[rid].index[column_offset];
if (rbin <= split_cond) {
row_split_tloc[0].left.push_back(rid);
} else {
row_split_tloc[0].right.push_back(rid);
}
}
}
inline void ApplySplitSparseData(const RowSetCollection::Elem rowset,
const GHistIndexMatrix& gmat,
std::vector<RowSetCollection::Split>* p_row_split_tloc,
bst_uint lower_bound,
bst_uint upper_bound,
bst_uint split_cond,
bool default_left) {
std::vector<RowSetCollection::Split>& row_split_tloc = *p_row_split_tloc;
const int K = 8; // loop unrolling factor
const bst_omp_uint nrows = rowset.end - rowset.begin;
const bst_omp_uint rest = nrows % K;
#pragma omp parallel for num_threads(nthread) schedule(static)
for (bst_omp_uint i = 0; i < nrows - rest; i += K) {
bst_uint rid[K];
GHistIndexRow row[K];
const unsigned* p[K];
bst_uint tid = omp_get_thread_num();
auto& left = row_split_tloc[tid].left;
auto& right = row_split_tloc[tid].right;
for (int k = 0; k < K; ++k) {
rid[k] = rowset.begin[i + k];
}
for (int k = 0; k < K; ++k) {
row[k] = gmat[rid[k]];
}
for (int k = 0; k < K; ++k) {
p[k] = std::lower_bound(row[k].index, row[k].index + row[k].size, lower_bound);
}
for (int k = 0; k < K; ++k) {
if (p[k] != row[k].index + row[k].size && *p[k] < upper_bound) {
if (*p[k] <= split_cond) {
left.push_back(rid[k]);
} else {
right.push_back(rid[k]);
}
} else {
if (default_left) {
left.push_back(rid[k]);
} else {
right.push_back(rid[k]);
}
}
}
}
for (bst_omp_uint i = nrows - rest; i < nrows; ++i) {
const bst_uint rid = rowset.begin[i];
const auto row = gmat[rid];
const auto p = std::lower_bound(row.index, row.index + row.size, lower_bound);
auto& left = row_split_tloc[0].left;
auto& right = row_split_tloc[0].right;
if (p != row.index + row.size && *p < upper_bound) {
if (*p <= split_cond) {
left.push_back(rid);
} else {
right.push_back(rid);
}
} else {
if (default_left) {
left.push_back(rid);
} else {
right.push_back(rid);
}
}
}
}
inline void InitNewNode(int nid,
const GHistIndexMatrix& gmat,
const std::vector<bst_gpair>& gpair,
const DMatrix& fmat,
const RegTree& tree) {
{
snode.resize(tree.param.num_nodes, NodeEntry(param));
constraints_.resize(tree.param.num_nodes);
}
// setup constraints before calculating the weight
{
auto& stats = snode[nid].stats;
if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) {
/* specialized code for dense data
For dense data (with no missing value),
the sum of gradient histogram is equal to snode[nid] */
GHistRow hist = hist_[nid];
const std::vector<unsigned>& row_ptr = gmat.cut->row_ptr;
const size_t ibegin = row_ptr[fid_least_bins_];
const size_t iend = row_ptr[fid_least_bins_ + 1];
for (size_t i = ibegin; i < iend; ++i) {
const GHistEntry et = hist.begin[i];
stats.Add(et.sum_grad, et.sum_hess);
}
} else {
const RowSetCollection::Elem e = row_set_collection_[nid];
for (const bst_uint* it = e.begin; it < e.end; ++it) {
stats.Add(gpair[*it]);
}
}
if (!tree[nid].is_root()) {
const int pid = tree[nid].parent();
constraints_[pid].SetChild(param, tree[pid].split_index(),
snode[tree[pid].cleft()].stats,
snode[tree[pid].cright()].stats,
&constraints_[tree[pid].cleft()],
&constraints_[tree[pid].cright()]);
}
}
// calculating the weights
{
snode[nid].root_gain = static_cast<float>(
constraints_[nid].CalcGain(param, snode[nid].stats));
snode[nid].weight = static_cast<float>(
constraints_[nid].CalcWeight(param, snode[nid].stats));
}
}
// enumerate the split values of specific feature
inline void EnumerateSplit(int d_step,
const GHistIndexMatrix& gmat,
const GHistRow& hist,
const NodeEntry& snode,
const TConstraint& constraint,
const MetaInfo& info,
SplitEntry* p_best,
int fid) {
CHECK(d_step == +1 || d_step == -1);
// aliases
const std::vector<unsigned>& cut_ptr = gmat.cut->row_ptr;
const std::vector<bst_float>& cut_val = gmat.cut->cut;
// statistics on both sides of split
TStats c(param);
TStats e(param);
// best split so far
SplitEntry best;
// bin boundaries
// imin: index (offset) of the minimum value for feature fid
// need this for backward enumeration
const int imin = cut_ptr[fid];
// ibegin, iend: smallest/largest cut points for feature fid
int ibegin, iend;
if (d_step > 0) {
ibegin = cut_ptr[fid];
iend = cut_ptr[fid + 1];
} else {
ibegin = cut_ptr[fid + 1] - 1;
iend = cut_ptr[fid] - 1;
}
for (int i = ibegin; i != iend; i += d_step) {
// start working
// try to find a split
e.Add(hist.begin[i].sum_grad, hist.begin[i].sum_hess);
if (e.sum_hess >= param.min_child_weight) {
c.SetSubstract(snode.stats, e);
if (c.sum_hess >= param.min_child_weight) {
bst_float loss_chg;
bst_float split_pt;
if (d_step > 0) {
// forward enumeration: split at right bound of each bin
loss_chg = static_cast<bst_float>(
constraint.CalcSplitGain(param, fid, e, c) -
snode.root_gain);
split_pt = cut_val[i];
} else {
// backward enumeration: split at left bound of each bin
loss_chg = static_cast<bst_float>(
constraint.CalcSplitGain(param, fid, c, e) -
snode.root_gain);
if (i == imin) {
// for leftmost bin, left bound is the smallest feature value
split_pt = gmat.cut->min_val[fid];
} else {
split_pt = cut_val[i - 1];
}
}
best.Update(loss_chg, fid, split_pt, d_step == -1);
}
}
}
p_best->Update(best);
}
/* tree growing policies */
struct ExpandEntry {
int nid;
int depth;
bst_float loss_chg;
unsigned timestamp;
ExpandEntry(int nid, int depth, bst_float loss_chg, unsigned tstmp)
: nid(nid), depth(depth), loss_chg(loss_chg), timestamp(tstmp) {}
};
inline static bool depth_wise(ExpandEntry lhs, ExpandEntry rhs) {
if (lhs.depth == rhs.depth) {
return lhs.timestamp > rhs.timestamp; // favor small timestamp
} else {
return lhs.depth > rhs.depth; // favor small depth
}
}
inline static bool loss_guide(ExpandEntry lhs, ExpandEntry rhs) {
if (lhs.loss_chg == rhs.loss_chg) {
return lhs.timestamp > rhs.timestamp; // favor small timestamp
} else {
return lhs.loss_chg < rhs.loss_chg; // favor large loss_chg
}
}
// --data fields--
const TrainParam& param;
// number of omp thread used during training
int nthread;
// Per feature: shuffle index of each feature index
std::vector<bst_uint> feat_index;
// the internal row sets
RowSetCollection row_set_collection_;
// the temp space for split
std::vector<RowSetCollection::Split> row_split_tloc_;
/*! \brief TreeNode Data: statistics for each constructed node */
std::vector<NodeEntry> snode;
/*! \brief culmulative histogram of gradients. */
HistCollection hist_;
size_t fid_least_bins_;
GHistBuilder builder_;
// constraint value
std::vector<TConstraint> constraints_;
typedef std::priority_queue<ExpandEntry,
std::vector<ExpandEntry>,
std::function<bool(ExpandEntry, ExpandEntry)>> ExpandQueue;
std::unique_ptr<ExpandQueue> qexpand_;
enum DataLayout { kDenseDataZeroBased, kDenseDataOneBased, kSparseData };
DataLayout data_layout_;
};
std::unique_ptr<Builder> builder_;
};
XGBOOST_REGISTER_TREE_UPDATER(FastHistMaker, "grow_fast_histmaker")
.describe("Grow tree using quantized histogram.")
.set_body([]() {
return new FastHistMaker<GradStats, NoConstraint>();
});
} // namespace tree
} // namespace xgboost

View File

@ -0,0 +1,107 @@
import xgboost as xgb
import testing as tm
import numpy as np
import unittest
rng = np.random.RandomState(1994)
class TestFastHist(unittest.TestCase):
def test_fast_hist(self):
tm._skip_if_no_sklearn()
from sklearn.datasets import load_digits
from sklearn.cross_validation import train_test_split
digits = load_digits(2)
X = digits['data']
y = digits['target']
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
dtrain = xgb.DMatrix(X_train, y_train)
dtest = xgb.DMatrix(X_test, y_test)
param = {'objective': 'binary:logistic',
'tree_method': 'hist',
'grow_policy': 'depthwise',
'max_depth': 3,
'eval_metric': 'auc'}
res = {}
xgb.train(param, dtrain, 10, [(dtrain, 'train'), (dtest, 'test')],
evals_result=res)
assert self.non_decreasing(res['train']['auc'])
assert self.non_decreasing(res['test']['auc'])
param2 = {'objective': 'binary:logistic',
'tree_method': 'hist',
'grow_policy': 'lossguide',
'max_depth': 0,
'max_leaves': 8,
'eval_metric': 'auc'}
res = {}
xgb.train(param2, dtrain, 10, [(dtrain, 'train'), (dtest, 'test')],
evals_result=res)
assert self.non_decreasing(res['train']['auc'])
assert self.non_decreasing(res['test']['auc'])
param3 = {'objective': 'binary:logistic',
'tree_method': 'hist',
'grow_policy': 'lossguide',
'max_depth': 0,
'max_leaves': 8,
'max_bin': 16,
'eval_metric': 'auc'}
res = {}
xgb.train(param3, dtrain, 10, [(dtrain, 'train'), (dtest, 'test')],
evals_result=res)
assert self.non_decreasing(res['train']['auc'])
# fail-safe test for dense data
from sklearn.datasets import load_svmlight_file
dpath = 'demo/data/'
X2, y2 = load_svmlight_file(dpath + 'agaricus.txt.train')
X2 = X2.toarray()
dtrain2 = xgb.DMatrix(X2, label=y2)
param = {'objective': 'binary:logistic',
'tree_method': 'hist',
'grow_policy': 'depthwise',
'max_depth': 2,
'eval_metric': 'auc'}
res = {}
xgb.train(param, dtrain2, 10, [(dtrain2, 'train')], evals_result=res)
assert self.non_decreasing(res['train']['auc'])
assert res['train']['auc'][0] >= 0.85
for j in range(X2.shape[1]):
for i in np.random.choice(X2.shape[0], size=10, replace=False):
X2[i, j] = 2
dtrain3 = xgb.DMatrix(X2, label=y2)
res = {}
xgb.train(param, dtrain3, 10, [(dtrain3, 'train')], evals_result=res)
assert self.non_decreasing(res['train']['auc'])
assert res['train']['auc'][0] >= 0.85
for j in range(X2.shape[1]):
for i in np.random.choice(X2.shape[0], size=10, replace=False):
X2[i, j] = 3
dtrain4 = xgb.DMatrix(X2, label=y2)
res = {}
xgb.train(param, dtrain4, 10, [(dtrain4, 'train')], evals_result=res)
assert self.non_decreasing(res['train']['auc'])
assert res['train']['auc'][0] >= 0.85
# fail-safe test for max_bin=2
param = {'objective': 'binary:logistic',
'tree_method': 'hist',
'grow_policy': 'depthwise',
'max_depth': 2,
'eval_metric': 'auc',
'max_bin': 2}
res = {}
xgb.train(param, dtrain2, 10, [(dtrain2, 'train')], evals_result=res)
assert self.non_decreasing(res['train']['auc'])
assert res['train']['auc'][0] >= 0.85
def non_decreasing(self, L):
return all(x <= y for x, y in zip(L, L[1:]))

View File

@ -18,7 +18,9 @@ make -f dmlc-core/scripts/packages.mk lz4
if [ ${TRAVIS_OS_NAME} == "osx" ]; then
echo "USE_OPENMP=0" >> config.mk
echo 'USE_OPENMP=0' >> config.mk
echo 'TMPVAR := $(XGB_PLUGINS)' >> config.mk
echo 'XGB_PLUGINS = $(filter-out plugin/lz4/plugin.mk, $(TMPVAR))' >> config.mk
fi
if [ ${TASK} == "python_test" ]; then