[TREE] finish move of updater

tqchen 2016-01-01 20:41:20 -08:00
parent 4adc4cf0b9
commit d4677b6561
14 changed files with 358 additions and 343 deletions

View File

@ -11,6 +11,7 @@
#include <dmlc/registry.h> #include <dmlc/registry.h>
#include <vector> #include <vector>
#include <utility> #include <utility>
#include <string>
#include "./base.h" #include "./base.h"
#include "./data.h" #include "./data.h"
#include "./tree_model.h" #include "./tree_model.h"

View File

@ -5,16 +5,17 @@
* base64 is easier to store and pass as text format in mapreduce * base64 is easier to store and pass as text format in mapreduce
* \author Tianqi Chen * \author Tianqi Chen
*/ */
#ifndef XGBOOST_UTILS_BASE64_INL_H_ #ifndef XGBOOST_COMMON_BASE64_H_
#define XGBOOST_UTILS_BASE64_INL_H_ #define XGBOOST_COMMON_BASE64_H_
#include <dmlc/logging.h>
#include <cctype> #include <cctype>
#include <cstdio> #include <cstdio>
#include <string> #include <string>
#include "./io.h" #include "./io.h"
namespace xgboost { namespace xgboost {
namespace utils { namespace common {
/*! \brief buffer reader of the stream that allows you to get */ /*! \brief buffer reader of the stream that allows you to get */
class StreamBufferReader { class StreamBufferReader {
public: public:
@ -26,7 +27,7 @@ class StreamBufferReader {
/*! /*!
* \brief set input stream * \brief set input stream
*/ */
inline void set_stream(IStream *stream) { inline void set_stream(dmlc::Stream *stream) {
stream_ = stream; stream_ = stream;
read_len_ = read_ptr_ = 1; read_len_ = read_ptr_ = 1;
} }
@ -51,7 +52,7 @@ class StreamBufferReader {
private: private:
/*! \brief the underlying stream */ /*! \brief the underlying stream */
IStream *stream_; dmlc::Stream *stream_;
/*! \brief buffer to hold data */ /*! \brief buffer to hold data */
std::string buffer_; std::string buffer_;
/*! \brief length of valid data in buffer */ /*! \brief length of valid data in buffer */
@ -80,9 +81,9 @@ static const char EncodeTable[] =
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
} // namespace base64 } // namespace base64
/*! \brief the stream that reads from base64, note we take from file pointers */ /*! \brief the stream that reads from base64, note we take from file pointers */
class Base64InStream: public IStream { class Base64InStream: public dmlc::Stream {
public: public:
explicit Base64InStream(IStream *fs) : reader_(256) { explicit Base64InStream(dmlc::Stream *fs) : reader_(256) {
reader_.set_stream(fs); reader_.set_stream(fs);
num_prev = 0; tmp_ch = 0; num_prev = 0; tmp_ch = 0;
} }
@ -134,20 +135,22 @@ class Base64InStream: public IStream {
nvalue = DecodeTable[tmp_ch] << 18; nvalue = DecodeTable[tmp_ch] << 18;
{ {
// second byte // second byte
utils::Check((tmp_ch = reader_.GetChar(), tmp_ch != EOF && !isspace(tmp_ch)), tmp_ch = reader_.GetChar();
"invalid base64 format"); CHECK(tmp_ch != EOF && !isspace(tmp_ch)) << "invalid base64 format";
nvalue |= DecodeTable[tmp_ch] << 12; nvalue |= DecodeTable[tmp_ch] << 12;
*cptr++ = (nvalue >> 16) & 0xFF; --tlen; *cptr++ = (nvalue >> 16) & 0xFF; --tlen;
} }
{ {
// third byte // third byte
utils::Check((tmp_ch = reader_.GetChar(), tmp_ch != EOF && !isspace(tmp_ch)), tmp_ch = reader_.GetChar();
"invalid base64 format"); CHECK(tmp_ch != EOF && !isspace(tmp_ch)) << "invalid base64 format";
// handle termination // handle termination
if (tmp_ch == '=') { if (tmp_ch == '=') {
utils::Check((tmp_ch = reader_.GetChar(), tmp_ch == '='), "invalid base64 format"); tmp_ch = reader_.GetChar();
utils::Check((tmp_ch = reader_.GetChar(), tmp_ch == EOF || isspace(tmp_ch)), CHECK(tmp_ch == '=') << "invalid base64 format";
"invalid base64 format"); tmp_ch = reader_.GetChar();
CHECK(tmp_ch == EOF || isspace(tmp_ch))
<< "invalid base64 format";
break; break;
} }
nvalue |= DecodeTable[tmp_ch] << 6; nvalue |= DecodeTable[tmp_ch] << 6;
@ -159,11 +162,13 @@ class Base64InStream: public IStream {
} }
{ {
// fourth byte // fourth byte
utils::Check((tmp_ch = reader_.GetChar(), tmp_ch != EOF && !isspace(tmp_ch)), tmp_ch = reader_.GetChar();
"invalid base64 format"); CHECK(tmp_ch != EOF && !isspace(tmp_ch))
<< "invalid base64 format";
if (tmp_ch == '=') { if (tmp_ch == '=') {
utils::Check((tmp_ch = reader_.GetChar(), tmp_ch == EOF || isspace(tmp_ch)), tmp_ch = reader_.GetChar();
"invalid base64 format"); CHECK(tmp_ch == EOF || isspace(tmp_ch))
<< "invalid base64 format";
break; break;
} }
nvalue |= DecodeTable[tmp_ch]; nvalue |= DecodeTable[tmp_ch];
@ -177,12 +182,12 @@ class Base64InStream: public IStream {
tmp_ch = reader_.GetChar(); tmp_ch = reader_.GetChar();
} }
if (kStrictCheck) { if (kStrictCheck) {
utils::Check(tlen == 0, "Base64InStream: read incomplete"); CHECK_EQ(tlen, 0) << "Base64InStream: read incomplete";
} }
return size - tlen; return size - tlen;
} }
virtual void Write(const void *ptr, size_t size) { virtual void Write(const void *ptr, size_t size) {
utils::Error("Base64InStream do not support write"); LOG(FATAL) << "Base64InStream do not support write";
} }
private: private:
@ -194,9 +199,9 @@ class Base64InStream: public IStream {
static const bool kStrictCheck = false; static const bool kStrictCheck = false;
}; };
/*! \brief the stream that write to base64, note we take from file pointers */ /*! \brief the stream that write to base64, note we take from file pointers */
class Base64OutStream: public IStream { class Base64OutStream: public dmlc::Stream {
public: public:
explicit Base64OutStream(IStream *fp) : fp(fp) { explicit Base64OutStream(dmlc::Stream *fp) : fp(fp) {
buf_top = 0; buf_top = 0;
} }
virtual void Write(const void *ptr, size_t size) { virtual void Write(const void *ptr, size_t size) {
@ -218,7 +223,7 @@ class Base64OutStream: public IStream {
} }
} }
virtual size_t Read(void *ptr, size_t size) { virtual size_t Read(void *ptr, size_t size) {
utils::Error("Base64OutStream do not support read"); LOG(FATAL) << "Base64OutStream do not support read";
return 0; return 0;
} }
/*! /*!
@ -245,7 +250,7 @@ class Base64OutStream: public IStream {
} }
private: private:
IStream *fp; dmlc::Stream *fp;
int buf_top; int buf_top;
unsigned char buf[4]; unsigned char buf[4];
std::string out_buf; std::string out_buf;
@ -262,6 +267,6 @@ class Base64OutStream: public IStream {
} }
} }
}; };
} // namespace utils } // namespace common
} // namespace xgboost } // namespace xgboost
#endif // XGBOOST_UTILS_BASE64_INL_H_ #endif // XGBOOST_COMMON_BASE64_H_
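For illustration only: a minimal sketch of driving the relocated streams through the dmlc::Stream interface they now inherit. It reuses common::MemoryBufferStream (the string-backed stream referenced later in this commit); the Finish() call that flushes the trailing padding group is an assumption about the class's existing API, not something shown in this hunk.

#include <string>
#include "common/base64.h"
#include "common/io.h"

std::string EncodeBase64(const std::string& raw) {
  std::string encoded;
  xgboost::common::MemoryBufferStream buf(&encoded);  // dmlc::Stream backed by a std::string
  xgboost::common::Base64OutStream b64(&buf);
  b64.Write(raw.data(), raw.size());                  // Write() as declared above
  b64.Finish();                                       // assumed: emits the trailing '=' padding
  return encoded;
}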

View File

@ -8,8 +8,8 @@
#ifndef XGBOOST_COMMON_BITMAP_H_ #ifndef XGBOOST_COMMON_BITMAP_H_
#define XGBOOST_COMMON_BITMAP_H_ #define XGBOOST_COMMON_BITMAP_H_
#include <vector>
#include <dmlc/omp.h> #include <dmlc/omp.h>
#include <vector>
namespace xgboost { namespace xgboost {
namespace common { namespace common {

View File

@ -4,18 +4,17 @@
* \brief helper class to load in configures from file * \brief helper class to load in configures from file
* \author Tianqi Chen * \author Tianqi Chen
*/ */
#ifndef XGBOOST_UTILS_CONFIG_H_ #ifndef XGBOOST_COMMON_CONFIG_H_
#define XGBOOST_UTILS_CONFIG_H_ #define XGBOOST_COMMON_CONFIG_H_
#include <cstdio> #include <cstdio>
#include <cstring> #include <cstring>
#include <string> #include <string>
#include <istream> #include <istream>
#include <fstream> #include <fstream>
#include "./utils.h"
namespace xgboost { namespace xgboost {
namespace utils { namespace common {
/*! /*!
* \brief base implementation of config reader * \brief base implementation of config reader
*/ */
@ -79,11 +78,11 @@ class ConfigReaderBase {
case '\\': *tok += this->GetChar(); break; case '\\': *tok += this->GetChar(); break;
case '\"': return; case '\"': return;
case '\r': case '\r':
case '\n': Error("ConfigReader: unterminated string"); case '\n': LOG(FATAL)<< "ConfigReader: unterminated string";
default: *tok += ch_buf; default: *tok += ch_buf;
} }
} }
Error("ConfigReader: unterminated string"); LOG(FATAL) << "ConfigReader: unterminated string";
} }
inline void ParseStrML(std::string *tok) { inline void ParseStrML(std::string *tok) {
while ((ch_buf = this->GetChar()) != EOF) { while ((ch_buf = this->GetChar()) != EOF) {
@ -93,7 +92,7 @@ class ConfigReaderBase {
default: *tok += ch_buf; default: *tok += ch_buf;
} }
} }
Error("unterminated string"); LOG(FATAL) << "unterminated string";
} }
// return newline // return newline
inline bool GetNextToken(std::string *tok) { inline bool GetNextToken(std::string *tok) {
@ -106,13 +105,13 @@ class ConfigReaderBase {
if (tok->length() == 0) { if (tok->length() == 0) {
ParseStr(tok); ch_buf = this->GetChar(); return new_line; ParseStr(tok); ch_buf = this->GetChar(); return new_line;
} else { } else {
Error("ConfigReader: token followed directly by string"); LOG(FATAL) << "ConfigReader: token followed directly by string";
} }
case '\'': case '\'':
if (tok->length() == 0) { if (tok->length() == 0) {
ParseStrML(tok); ch_buf = this->GetChar(); return new_line; ParseStrML(tok); ch_buf = this->GetChar(); return new_line;
} else { } else {
Error("ConfigReader: token followed directly by string"); LOG(FATAL) << "ConfigReader: token followed directly by string";
} }
case '=': case '=':
if (tok->length() == 0) { if (tok->length() == 0) {
@ -177,7 +176,7 @@ class ConfigIterator: public ConfigStreamReader {
explicit ConfigIterator(const char *fname) : ConfigStreamReader(fi) { explicit ConfigIterator(const char *fname) : ConfigStreamReader(fi) {
fi.open(fname); fi.open(fname);
if (fi.fail()) { if (fi.fail()) {
utils::Error("cannot open file %s", fname); LOG(FATAL) << "cannot open file " << fname;
} }
ConfigReaderBase::Init(); ConfigReaderBase::Init();
} }
@ -189,6 +188,6 @@ class ConfigIterator: public ConfigStreamReader {
private: private:
std::ifstream fi; std::ifstream fi;
}; };
} // namespace utils } // namespace common
} // namespace xgboost } // namespace xgboost
#endif // XGBOOST_UTILS_CONFIG_H_ #endif // XGBOOST_COMMON_CONFIG_H_
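A hedged usage sketch of the relocated ConfigIterator. The constructor and its LOG(FATAL) on open failure are shown above; the Next()/name()/val() accessors are assumed to be carried over unchanged from the old utils:: version.

#include <dmlc/logging.h>
#include "common/config.h"

void DumpConfig(const char* fname) {
  xgboost::common::ConfigIterator itr(fname);  // opens fname, fatal error if it fails
  while (itr.Next()) {                         // assumed accessor: advance to the next key=value pair
    LOG(INFO) << itr.name() << " = " << itr.val();
  }
}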

View File

@ -4,19 +4,19 @@
* \brief util to compute quantiles * \brief util to compute quantiles
* \author Tianqi Chen * \author Tianqi Chen
*/ */
#ifndef XGBOOST_UTILS_QUANTILE_H_ #ifndef XGBOOST_COMMON_QUANTILE_H_
#define XGBOOST_UTILS_QUANTILE_H_ #define XGBOOST_COMMON_QUANTILE_H_
#include <dmlc/base.h>
#include <dmlc/logging.h>
#include <cmath> #include <cmath>
#include <vector> #include <vector>
#include <cstring> #include <cstring>
#include <algorithm> #include <algorithm>
#include <iostream> #include <iostream>
#include "./io.h"
#include "./utils.h"
namespace xgboost { namespace xgboost {
namespace utils { namespace common {
/*! /*!
* \brief experimental wsummary * \brief experimental wsummary
* \tparam DType type of data content * \tparam DType type of data content
@ -35,7 +35,7 @@ struct WQSummary {
/*! \brief the value of data */ /*! \brief the value of data */
DType value; DType value;
// constructor // constructor
Entry(void) {} Entry() {}
// constructor // constructor
Entry(RType rmin, RType rmax, RType wmin, DType value) Entry(RType rmin, RType rmax, RType wmin, DType value)
: rmin(rmin), rmax(rmax), wmin(wmin), value(value) {} : rmin(rmin), rmax(rmax), wmin(wmin), value(value) {}
@ -44,15 +44,15 @@ struct WQSummary {
* \param eps the tolerate level for violating the relation * \param eps the tolerate level for violating the relation
*/ */
inline void CheckValid(RType eps = 0) const { inline void CheckValid(RType eps = 0) const {
utils::Assert(rmin >= 0 && rmax >= 0 && wmin >= 0, "nonneg constraint"); CHECK(rmin >= 0 && rmax >= 0 && wmin >= 0) << "nonneg constraint";
utils::Assert(rmax- rmin - wmin > -eps, "relation constraint: min/max"); CHECK(rmax- rmin - wmin > -eps) << "relation constraint: min/max";
} }
/*! \return rmin estimation for v strictly bigger than value */ /*! \return rmin estimation for v strictly bigger than value */
inline RType rmin_next(void) const { inline RType rmin_next() const {
return rmin + wmin; return rmin + wmin;
} }
/*! \return rmax estimation for v strictly smaller than value */ /*! \return rmax estimation for v strictly smaller than value */
inline RType rmax_prev(void) const { inline RType rmax_prev() const {
return rmax - wmin; return rmax - wmin;
} }
}; };
@ -65,7 +65,7 @@ struct WQSummary {
// weight of instance // weight of instance
RType weight; RType weight;
// default constructor // default constructor
QEntry(void) {} QEntry() {}
// constructor // constructor
QEntry(DType value, RType weight) QEntry(DType value, RType weight)
: value(value), weight(weight) {} : value(value), weight(weight) {}
@ -113,7 +113,7 @@ struct WQSummary {
/*! /*!
* \return the maximum error of the Summary * \return the maximum error of the Summary
*/ */
inline RType MaxError(void) const { inline RType MaxError() const {
RType res = data[0].rmax - data[0].rmin - data[0].wmin; RType res = data[0].rmax - data[0].rmin - data[0].wmin;
for (size_t i = 1; i < size; ++i) { for (size_t i = 1; i < size; ++i) {
res = std::max(data[i].rmax_prev() - data[i - 1].rmin_next(), res); res = std::max(data[i].rmax_prev() - data[i - 1].rmin_next(), res);
@ -147,7 +147,7 @@ struct WQSummary {
} }
} }
/*! \return maximum rank in the summary */ /*! \return maximum rank in the summary */
inline RType MaxRank(void) const { inline RType MaxRank() const {
return data[size - 1].rmax; return data[size - 1].rmax;
} }
/*! /*!
@ -168,8 +168,8 @@ struct WQSummary {
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
data[i].CheckValid(eps); data[i].CheckValid(eps);
if (i != 0) { if (i != 0) {
utils::Assert(data[i].rmin >= data[i - 1].rmin + data[i - 1].wmin, "rmin range constraint"); CHECK(data[i].rmin >= data[i - 1].rmin + data[i - 1].wmin) << "rmin range constraint";
utils::Assert(data[i].rmax >= data[i - 1].rmax + data[i].wmin, "rmax range constraint"); CHECK(data[i].rmax >= data[i - 1].rmax + data[i].wmin) << "rmax range constraint";
} }
} }
} }
@ -196,7 +196,7 @@ struct WQSummary {
// find first i such that d < (rmax[i+1] + rmin[i+1]) / 2 // find first i such that d < (rmax[i+1] + rmin[i+1]) / 2
while (i < src.size - 1 while (i < src.size - 1
&& dx2 >= src.data[i + 1].rmax + src.data[i + 1].rmin) ++i; && dx2 >= src.data[i + 1].rmax + src.data[i + 1].rmin) ++i;
utils::Assert(i != src.size - 1, "this cannot happen"); CHECK(i != src.size - 1);
if (dx2 < src.data[i].rmin_next() + src.data[i + 1].rmax_prev()) { if (dx2 < src.data[i].rmin_next() + src.data[i + 1].rmax_prev()) {
if (i != lastidx) { if (i != lastidx) {
data[size++] = src.data[i]; lastidx = i; data[size++] = src.data[i]; lastidx = i;
@ -224,7 +224,7 @@ struct WQSummary {
if (sb.size == 0) { if (sb.size == 0) {
this->CopyFrom(sa); return; this->CopyFrom(sa); return;
} }
utils::Assert(sa.size > 0 && sb.size > 0, "invalid input for merge"); CHECK(sa.size > 0 && sb.size > 0);
const Entry *a = sa.data, *a_end = sa.data + sa.size; const Entry *a = sa.data, *a_end = sa.data + sa.size;
const Entry *b = sb.data, *b_end = sb.data + sb.size; const Entry *b = sb.data, *b_end = sb.data + sb.size;
// extended rmin value // extended rmin value
@ -272,18 +272,19 @@ struct WQSummary {
RType err_mingap, err_maxgap, err_wgap; RType err_mingap, err_maxgap, err_wgap;
this->FixError(&err_mingap, &err_maxgap, &err_wgap); this->FixError(&err_mingap, &err_maxgap, &err_wgap);
if (err_mingap > tol || err_maxgap > tol || err_wgap > tol) { if (err_mingap > tol || err_maxgap > tol || err_wgap > tol) {
utils::Printf("INFO: mingap=%g, maxgap=%g, wgap=%g\n", LOG(INFO) << "mingap=" << err_mingap
err_mingap, err_maxgap, err_wgap); << ", maxgap=" << err_maxgap
<< ", wgap=" << err_wgap;
} }
utils::Assert(size <= sa.size + sb.size, "bug in combine"); CHECK(size <= sa.size + sb.size) << "bug in combine";
} }
// helper function to print the current content of sketch // helper function to print the current content of sketch
inline void Print() const { inline void Print() const {
for (size_t i = 0; i < this->size; ++i) { for (size_t i = 0; i < this->size; ++i) {
utils::Printf("[%lu] rmin=%g, rmax=%g, wmin=%g, v=%g\n", LOG(INFO) << "[" << i << "] rmin=" << data[i].rmin
i, data[i].rmin, data[i].rmax, << ", rmax=" << data[i].rmax
data[i].wmin, data[i].value); << ", wmin=" << data[i].wmin
<< ", v=" << data[i].value;
} }
} }
// try to fix rounding error // try to fix rounding error
@ -320,7 +321,7 @@ struct WQSummary {
for (size_t i = 0; i < this->size; ++i) { for (size_t i = 0; i < this->size; ++i) {
if (data[i].rmin + data[i].wmin > data[i].rmax + tol || if (data[i].rmin + data[i].wmin > data[i].rmax + tol ||
data[i].rmin < -1e-6f || data[i].rmax < -1e-6f) { data[i].rmin < -1e-6f || data[i].rmax < -1e-6f) {
utils::Printf("----%s: Check not Pass------\n", msg); LOG(INFO) << "----------check not pass----------";
this->Print(); this->Print();
return false; return false;
} }
@ -380,12 +381,11 @@ struct WXQSummary : public WQSummary<DType, RType> {
} }
if (nbig >= n - 1) { if (nbig >= n - 1) {
// see what was the case // see what was the case
utils::Printf("LOG: check quantile stats, nbig=%lu, n=%lu\n", nbig, n); LOG(INFO) << " check quantile stats, nbig=" << nbig << ", n=" << n;
utils::Printf("LOG: srcsize=%lu, maxsize=%lu, range=%g, chunk=%g\n", LOG(INFO) << " srcsize=" << src.size << ", maxsize=" << maxsize
src.size, maxsize, static_cast<double>(range), << ", range=" << range << ", chunk=" << chunk;
static_cast<double>(chunk));
src.Print(); src.Print();
utils::Assert(nbig < n - 1, "quantile: too many large chunk"); CHECK(nbig < n - 1) << "quantile: too many large chunk";
} }
this->data[0] = src.data[0]; this->data[0] = src.data[0];
this->size = 1; this->size = 1;
@ -440,7 +440,7 @@ struct GKSummary {
/*! \brief the value of data */ /*! \brief the value of data */
DType value; DType value;
// constructor // constructor
Entry(void) {} Entry() {}
// constructor // constructor
Entry(RType rmin, RType rmax, DType value) Entry(RType rmin, RType rmax, DType value)
: rmin(rmin), rmax(rmax), value(value) {} : rmin(rmin), rmax(rmax), value(value) {}
@ -470,7 +470,7 @@ struct GKSummary {
GKSummary(Entry *data, size_t size) GKSummary(Entry *data, size_t size)
: data(data), size(size) {} : data(data), size(size) {}
/*! \brief the maximum error of the summary */ /*! \brief the maximum error of the summary */
inline RType MaxError(void) const { inline RType MaxError() const {
RType res = 0; RType res = 0;
for (size_t i = 1; i < size; ++i) { for (size_t i = 1; i < size; ++i) {
res = std::max(data[i].rmax - data[i-1].rmin, res); res = std::max(data[i].rmax - data[i-1].rmin, res);
@ -478,7 +478,7 @@ struct GKSummary {
return res; return res;
} }
/*! \return maximum rank in the summary */ /*! \return maximum rank in the summary */
inline RType MaxRank(void) const { inline RType MaxRank() const {
return data[size - 1].rmax; return data[size - 1].rmax;
} }
/*! /*!
@ -493,7 +493,7 @@ struct GKSummary {
// assume always valid // assume always valid
} }
/*! \brief used for debug purpose, print the summary */ /*! \brief used for debug purpose, print the summary */
inline void Print(void) const { inline void Print() const {
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
std::cout << "x=" << data[i].value << "\t" std::cout << "x=" << data[i].value << "\t"
<< "[" << data[i].rmin << "," << data[i].rmax << "]" << "[" << data[i].rmin << "," << data[i].rmax << "]"
@ -536,7 +536,7 @@ struct GKSummary {
if (sb.size == 0) { if (sb.size == 0) {
this->CopyFrom(sa); return; this->CopyFrom(sa); return;
} }
utils::Assert(sa.size > 0 && sb.size > 0, "invalid input for merge"); CHECK(sa.size > 0 && sb.size > 0) << "invalid input for merge";
const Entry *a = sa.data, *a_end = sa.data + sa.size; const Entry *a = sa.data, *a_end = sa.data + sa.size;
const Entry *b = sb.data, *b_end = sb.data + sb.size; const Entry *b = sb.data, *b_end = sb.data + sb.size;
this->size = sa.size + sb.size; this->size = sa.size + sb.size;
@ -569,7 +569,7 @@ struct GKSummary {
++dst; ++b; ++dst; ++b;
} while (b != b_end); } while (b != b_end);
} }
utils::Assert(dst == data + size, "bug in combine"); CHECK(dst == data + size) << "bug in combine";
} }
}; };
@ -592,15 +592,15 @@ class QuantileSketchTemplate {
std::vector<Entry> space; std::vector<Entry> space;
SummaryContainer(const SummaryContainer &src) : Summary(NULL, src.size) { SummaryContainer(const SummaryContainer &src) : Summary(NULL, src.size) {
this->space = src.space; this->space = src.space;
this->data = BeginPtr(this->space); this->data = dmlc::BeginPtr(this->space);
} }
SummaryContainer(void) : Summary(NULL, 0) { SummaryContainer() : Summary(NULL, 0) {
} }
/*! \brief reserve space for summary */ /*! \brief reserve space for summary */
inline void Reserve(size_t size) { inline void Reserve(size_t size) {
if (size > space.size()) { if (size > space.size()) {
space.resize(size); space.resize(size);
this->data = BeginPtr(space); this->data = dmlc::BeginPtr(space);
} }
} }
/*! /*!
@ -610,7 +610,7 @@ class QuantileSketchTemplate {
*/ */
inline void SetMerge(const Summary *begin, inline void SetMerge(const Summary *begin,
const Summary *end) { const Summary *end) {
utils::Assert(begin < end, "can not set combine to empty instance"); CHECK(begin < end) << "can not set combine to empty instance";
size_t len = end - begin; size_t len = end - begin;
if (len == 1) { if (len == 1) {
this->Reserve(begin[0].size); this->Reserve(begin[0].size);
@ -655,11 +655,10 @@ class QuantileSketchTemplate {
/*! \brief load data structure from input stream */ /*! \brief load data structure from input stream */
template<typename TStream> template<typename TStream>
inline void Load(TStream &fi) { // NOLINT(*) inline void Load(TStream &fi) { // NOLINT(*)
utils::Check(fi.Read(&this->size, sizeof(this->size)) != 0, "invalid SummaryArray 1"); CHECK_EQ(fi.Read(&this->size, sizeof(this->size)), sizeof(this->size));
this->Reserve(this->size); this->Reserve(this->size);
if (this->size != 0) { if (this->size != 0) {
utils::Check(fi.Read(this->data, this->size * sizeof(Entry)) != 0, CHECK_EQ(fi.Read(this->data, this->size * sizeof(Entry)), this->size * sizeof(Entry));
"invalid SummaryArray 2");
} }
} }
}; };
@ -678,8 +677,8 @@ class QuantileSketchTemplate {
} }
// check invariant // check invariant
size_t n = (1UL << nlevel); size_t n = (1UL << nlevel);
utils::Assert(n * limit_size >= maxn, "invalid init parameter"); CHECK(n * limit_size >= maxn) << "invalid init parameter";
utils::Assert(nlevel <= limit_size * eps, "invalid init parameter"); CHECK(nlevel <= limit_size * eps) << "invalid init parameter";
// lazy reserve the space, if there is only one value, no need to allocate space // lazy reserve the space, if there is only one value, no need to allocate space
inqueue.queue.resize(1); inqueue.queue.resize(1);
inqueue.qtail = 0; inqueue.qtail = 0;
@ -707,7 +706,7 @@ class QuantileSketchTemplate {
inqueue.Push(x, w); inqueue.Push(x, w);
} }
/*! \brief push up temp */ /*! \brief push up temp */
inline void PushTemp(void) { inline void PushTemp() {
temp.Reserve(limit_size * 2); temp.Reserve(limit_size * 2);
for (size_t l = 1; true; ++l) { for (size_t l = 1; true; ++l) {
this->InitLevel(l + 1); this->InitLevel(l + 1);
@ -769,7 +768,7 @@ class QuantileSketchTemplate {
data.resize(limit_size * nlevel); data.resize(limit_size * nlevel);
level.resize(nlevel, Summary(NULL, 0)); level.resize(nlevel, Summary(NULL, 0));
for (size_t l = 0; l < level.size(); ++l) { for (size_t l = 0; l < level.size(); ++l) {
level[l].data = BeginPtr(data) + l * limit_size; level[l].data = dmlc::BeginPtr(data) + l * limit_size;
} }
} }
// input data queue // input data queue
@ -793,7 +792,7 @@ class QuantileSketchTemplate {
*/ */
template<typename DType, typename RType = unsigned> template<typename DType, typename RType = unsigned>
class WQuantileSketch : class WQuantileSketch :
public QuantileSketchTemplate<DType, RType, WQSummary<DType, RType> >{ public QuantileSketchTemplate<DType, RType, WQSummary<DType, RType> > {
}; };
/*! /*!
@ -803,7 +802,7 @@ class WQuantileSketch :
*/ */
template<typename DType, typename RType = unsigned> template<typename DType, typename RType = unsigned>
class WXQuantileSketch : class WXQuantileSketch :
public QuantileSketchTemplate<DType, RType, WXQSummary<DType, RType> >{ public QuantileSketchTemplate<DType, RType, WXQSummary<DType, RType> > {
}; };
/*! /*!
* \brief Quantile sketch use WQSummary * \brief Quantile sketch use WQSummary
@ -812,9 +811,8 @@ class WXQuantileSketch :
*/ */
template<typename DType, typename RType = unsigned> template<typename DType, typename RType = unsigned>
class GKQuantileSketch : class GKQuantileSketch :
public QuantileSketchTemplate<DType, RType, GKSummary<DType, RType> >{ public QuantileSketchTemplate<DType, RType, GKSummary<DType, RType> > {
}; };
} // namespace utils } // namespace common
} // namespace xgboost } // namespace xgboost
#endif // XGBOOST_UTILS_QUANTILE_H_ #endif // XGBOOST_COMMON_QUANTILE_H_
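A small, hedged sketch of the sketch front-end as it is used elsewhere in this commit (Init() and GetSummary() appear in the histogram updater below); Push(value, weight) is assumed to be the public insertion method of QuantileSketchTemplate.

#include <vector>
#include "common/quantile.h"

void SummarizeWeighted(const std::vector<float>& values) {
  xgboost::common::WQuantileSketch<float, float> sketch;
  sketch.Init(values.size(), 0.01f);            // maxn, eps
  for (float v : values) {
    sketch.Push(v, 1.0f);                       // assumed: weighted insert with unit weight
  }
  xgboost::common::WQuantileSketch<float, float>::SummaryContainer out;
  sketch.GetSummary(&out);
  // out.data[0 .. out.size) now holds (rmin, rmax, wmin, value) entries within eps rank error
}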

View File

@ -10,4 +10,4 @@
#include <rabit.h> #include <rabit.h>
#endif // XGBOOST_SYNC_H_ #endif // XGBOOST_COMMON_SYNC_H_
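The sync header wraps the rabit include. As a reminder of the call pattern it enables (used several times further down in this diff), here is a hedged one-function wrapper; dmlc::BeginPtr comes from dmlc/base.h.

#include <dmlc/base.h>
#include <rabit.h>
#include <vector>

void SyncMax(std::vector<float>* vals) {
  // element-wise max across all distributed workers, in place
  rabit::Allreduce<rabit::op::Max>(dmlc::BeginPtr(*vals), vals->size());
}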

View File

@ -7,6 +7,7 @@
#include <xgboost/metric.h> #include <xgboost/metric.h>
#include <xgboost/tree_updater.h> #include <xgboost/tree_updater.h>
#include "./common/random.h" #include "./common/random.h"
#include "./common/base64.h"
namespace dmlc { namespace dmlc {
DMLC_REGISTRY_ENABLE(::xgboost::ObjFunctionReg); DMLC_REGISTRY_ENABLE(::xgboost::ObjFunctionReg);
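Since this file wires the dmlc registries together, a hedged sketch of how a tree updater would register itself under the new interface may help; the XGBOOST_REGISTER_TREE_UPDATER macro and the set_body signature are assumptions based on xgboost/tree_updater.h, and NoopUpdater / "noop" are made-up names.

#include <xgboost/tree_updater.h>
#include <string>
#include <utility>
#include <vector>

namespace xgboost {
namespace tree {

// A do-nothing updater, purely to illustrate the interface overridden by BaseMaker below.
class NoopUpdater : public TreeUpdater {
 public:
  void Init(const std::vector<std::pair<std::string, std::string> >& args) override {}
  void Update(const std::vector<bst_gpair>& gpair,
              DMatrix* p_fmat,
              const std::vector<RegTree*>& trees) override {
    // intentionally empty
  }
};

XGBOOST_REGISTER_TREE_UPDATER(NoopUpdater, "noop")
.describe("Example registration of a do-nothing updater.")
.set_body([]() { return new NoopUpdater(); });

}  // namespace tree
}  // namespace xgboost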

View File

@ -55,7 +55,8 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
// number of threads to be used for tree construction, // number of threads to be used for tree construction,
// if OpenMP is enabled, if equals 0, use system default // if OpenMP is enabled, if equals 0, use system default
int nthread; int nthread;
// whether to not print info during training.
bool silent;
// declare the parameters // declare the parameters
DMLC_DECLARE_PARAMETER(TrainParam) { DMLC_DECLARE_PARAMETER(TrainParam) {
DMLC_DECLARE_FIELD(eta).set_lower_bound(0.0f).set_default(0.3f) DMLC_DECLARE_FIELD(eta).set_lower_bound(0.0f).set_default(0.3f)
@ -98,6 +99,8 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
.describe("EXP Param: Cache aware optimization."); .describe("EXP Param: Cache aware optimization.");
DMLC_DECLARE_FIELD(nthread).set_default(0) DMLC_DECLARE_FIELD(nthread).set_default(0)
.describe("Number of threads used for training."); .describe("Number of threads used for training.");
DMLC_DECLARE_FIELD(silent).set_default(false)
.describe("Not print information during trainig.");
} }
// calculate the cost of loss function // calculate the cost of loss function
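To show how the new silent flag surfaces through the dmlc::Parameter machinery, a hedged sketch follows; the key/value pairs are made up and the "tree/param.h" include path is an assumption about where this struct lives.

#include <dmlc/logging.h>
#include <string>
#include <utility>
#include <vector>
#include "tree/param.h"

void ConfigureTrainParam() {
  xgboost::tree::TrainParam param;
  std::vector<std::pair<std::string, std::string> > kwargs = {
      {"eta", "0.1"}, {"silent", "1"}};
  param.Init(kwargs);  // dmlc::Parameter<TrainParam>::Init parses and range-checks the fields
  if (!param.silent) {
    LOG(INFO) << "training with eta=" << param.eta;
  }
}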

View File

@ -1,18 +1,24 @@
/*! /*!
* Copyright 2014 by Contributors * Copyright 2014 by Contributors
* \file updater_basemaker-inl.hpp * \file updater_basemaker-inl.h
* \brief implement a common tree constructor * \brief implement a common tree constructor
* \author Tianqi Chen * \author Tianqi Chen
*/ */
#ifndef XGBOOST_TREE_UPDATER_BASEMAKER_INL_HPP_ #ifndef XGBOOST_TREE_UPDATER_BASEMAKER_INL_H_
#define XGBOOST_TREE_UPDATER_BASEMAKER_INL_HPP_ #define XGBOOST_TREE_UPDATER_BASEMAKER_INL_H_
#include <xgboost/base.h>
#include <xgboost/tree_updater.h>
#include <vector> #include <vector>
#include <algorithm> #include <algorithm>
#include <string> #include <string>
#include <limits> #include <limits>
#include "../sync/sync.h" #include <utility>
#include "../utils/random.h" #include "./param.h"
#include "../utils/quantile.h" #include "../common/sync.h"
#include "../common/io.h"
#include "../common/random.h"
#include "../common/quantile.h"
namespace xgboost { namespace xgboost {
namespace tree { namespace tree {
@ -20,13 +26,10 @@ namespace tree {
* \brief base tree maker class that defines common operation * \brief base tree maker class that defines common operation
* needed in tree making * needed in tree making
*/ */
class BaseMaker: public IUpdater { class BaseMaker: public TreeUpdater {
public: public:
// destructor void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
virtual ~BaseMaker(void) {} param.Init(args);
// set training parameter
virtual void SetParam(const char *name, const char *val) {
param.SetParam(name, val);
} }
protected: protected:
@ -34,31 +37,31 @@ class BaseMaker: public IUpdater {
struct FMetaHelper { struct FMetaHelper {
public: public:
/*! \brief find type of each feature, use column format */ /*! \brief find type of each feature, use column format */
inline void InitByCol(IFMatrix *p_fmat, inline void InitByCol(DMatrix* p_fmat,
const RegTree &tree) { const RegTree& tree) {
fminmax.resize(tree.param.num_feature * 2); fminmax.resize(tree.param.num_feature * 2);
std::fill(fminmax.begin(), fminmax.end(), std::fill(fminmax.begin(), fminmax.end(),
-std::numeric_limits<bst_float>::max()); -std::numeric_limits<bst_float>::max());
// start accumulating statistics // start accumulating statistics
utils::IIterator<ColBatch> *iter = p_fmat->ColIterator(); dmlc::DataIter<ColBatch>* iter = p_fmat->ColIterator();
iter->BeforeFirst(); iter->BeforeFirst();
while (iter->Next()) { while (iter->Next()) {
const ColBatch &batch = iter->Value(); const ColBatch& batch = iter->Value();
for (bst_uint i = 0; i < batch.size; ++i) { for (bst_uint i = 0; i < batch.size; ++i) {
const bst_uint fid = batch.col_index[i]; const bst_uint fid = batch.col_index[i];
const ColBatch::Inst &c = batch[i]; const ColBatch::Inst& c = batch[i];
if (c.length != 0) { if (c.length != 0) {
fminmax[fid * 2 + 0] = std::max(-c[0].fvalue, fminmax[fid * 2 + 0]); fminmax[fid * 2 + 0] = std::max(-c[0].fvalue, fminmax[fid * 2 + 0]);
fminmax[fid * 2 + 1] = std::max(c[c.length - 1].fvalue, fminmax[fid * 2 + 1]); fminmax[fid * 2 + 1] = std::max(c[c.length - 1].fvalue, fminmax[fid * 2 + 1]);
} }
} }
} }
rabit::Allreduce<rabit::op::Max>(BeginPtr(fminmax), fminmax.size()); rabit::Allreduce<rabit::op::Max>(dmlc::BeginPtr(fminmax), fminmax.size());
} }
// get feature type, 0:empty 1:binary 2:real // get feature type, 0:empty 1:binary 2:real
inline int Type(bst_uint fid) const { inline int Type(bst_uint fid) const {
utils::Assert(fid * 2 + 1 < fminmax.size(), CHECK_LT(fid * 2 + 1, fminmax.size())
"FeatHelper fid exceed query bound "); << "FeatHelper fid exceed query bound ";
bst_float a = fminmax[fid * 2]; bst_float a = fminmax[fid * 2];
bst_float b = fminmax[fid * 2 + 1]; bst_float b = fminmax[fid * 2 + 1];
if (a == -std::numeric_limits<bst_float>::max()) return 0; if (a == -std::numeric_limits<bst_float>::max()) return 0;
@ -79,12 +82,12 @@ class BaseMaker: public IUpdater {
if (this->Type(fid) != 0) findex.push_back(fid); if (this->Type(fid) != 0) findex.push_back(fid);
} }
unsigned n = static_cast<unsigned>(p * findex.size()); unsigned n = static_cast<unsigned>(p * findex.size());
random::Shuffle(findex); std::shuffle(findex.begin(), findex.end(), common::GlobalRandom());
findex.resize(n); findex.resize(n);
// sync the findex if it is subsample // sync the findex if it is subsample
std::string s_cache; std::string s_cache;
utils::MemoryBufferStream fc(&s_cache); common::MemoryBufferStream fc(&s_cache);
utils::IStream &fs = fc; dmlc::Stream& fs = fc;
if (rabit::GetRank() == 0) { if (rabit::GetRank() == 0) {
fs.Write(findex); fs.Write(findex);
} }
@ -113,7 +116,7 @@ class BaseMaker: public IUpdater {
return n.cdefault(); return n.cdefault();
} }
/*! \brief get number of omp thread in current context */ /*! \brief get number of omp thread in current context */
inline static int get_nthread(void) { inline static int get_nthread() {
int nthread; int nthread;
#pragma omp parallel #pragma omp parallel
{ {
@ -124,11 +127,11 @@ class BaseMaker: public IUpdater {
// ------class member helpers--------- // ------class member helpers---------
/*! \brief initialize temp data structure */ /*! \brief initialize temp data structure */
inline void InitData(const std::vector<bst_gpair> &gpair, inline void InitData(const std::vector<bst_gpair> &gpair,
const IFMatrix &fmat, const DMatrix &fmat,
const std::vector<unsigned> &root_index,
const RegTree &tree) { const RegTree &tree) {
utils::Assert(tree.param.num_nodes == tree.param.num_roots, CHECK_EQ(tree.param.num_nodes, tree.param.num_roots)
"TreeMaker: can only grow new tree"); << "TreeMaker: can only grow new tree";
const std::vector<unsigned> &root_index = fmat.info().root_index;
{ {
// setup position // setup position
position.resize(gpair.size()); position.resize(gpair.size());
@ -137,8 +140,8 @@ class BaseMaker: public IUpdater {
} else { } else {
for (size_t i = 0; i < position.size(); ++i) { for (size_t i = 0; i < position.size(); ++i) {
position[i] = root_index[i]; position[i] = root_index[i];
utils::Assert(root_index[i] < (unsigned)tree.param.num_roots, CHECK_LT(root_index[i], (unsigned)tree.param.num_roots)
"root index exceed setting"); << "root index exceed setting";
} }
} }
// mark delete for the deleted datas // mark delete for the deleted datas
@ -147,9 +150,11 @@ class BaseMaker: public IUpdater {
} }
// mark subsample // mark subsample
if (param.subsample < 1.0f) { if (param.subsample < 1.0f) {
std::bernoulli_distribution coin_flip(param.subsample);
auto& rnd = common::GlobalRandom();
for (size_t i = 0; i < position.size(); ++i) { for (size_t i = 0; i < position.size(); ++i) {
if (gpair[i].hess < 0.0f) continue; if (gpair[i].hess < 0.0f) continue;
if (random::SampleBinary(param.subsample) == 0) position[i] = ~position[i]; if (!coin_flip(rnd)) position[i] = ~position[i];
} }
} }
} }
@ -197,7 +202,8 @@ class BaseMaker: public IUpdater {
* \param tree the regression tree structure * \param tree the regression tree structure
*/ */
inline void ResetPositionCol(const std::vector<int> &nodes, inline void ResetPositionCol(const std::vector<int> &nodes,
IFMatrix *p_fmat, const RegTree &tree) { DMatrix *p_fmat,
const RegTree &tree) {
// set the positions in the nondefault // set the positions in the nondefault
this->SetNonDefaultPositionCol(nodes, p_fmat, tree); this->SetNonDefaultPositionCol(nodes, p_fmat, tree);
// set rest of instances to default position // set rest of instances to default position
@ -234,7 +240,8 @@ class BaseMaker: public IUpdater {
* \param tree the regression tree structure * \param tree the regression tree structure
*/ */
virtual void SetNonDefaultPositionCol(const std::vector<int> &nodes, virtual void SetNonDefaultPositionCol(const std::vector<int> &nodes,
IFMatrix *p_fmat, const RegTree &tree) { DMatrix *p_fmat,
const RegTree &tree) {
// step 1, classify the non-default data into right places // step 1, classify the non-default data into right places
std::vector<unsigned> fsplits; std::vector<unsigned> fsplits;
for (size_t i = 0; i < nodes.size(); ++i) { for (size_t i = 0; i < nodes.size(); ++i) {
@ -246,7 +253,7 @@ class BaseMaker: public IUpdater {
std::sort(fsplits.begin(), fsplits.end()); std::sort(fsplits.begin(), fsplits.end());
fsplits.resize(std::unique(fsplits.begin(), fsplits.end()) - fsplits.begin()); fsplits.resize(std::unique(fsplits.begin(), fsplits.end()) - fsplits.begin());
utils::IIterator<ColBatch> *iter = p_fmat->ColIterator(fsplits); dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator(fsplits);
while (iter->Next()) { while (iter->Next()) {
const ColBatch &batch = iter->Value(); const ColBatch &batch = iter->Value();
for (size_t i = 0; i < batch.size; ++i) { for (size_t i = 0; i < batch.size; ++i) {
@ -273,12 +280,12 @@ class BaseMaker: public IUpdater {
/*! \brief helper function to get statistics from a tree */ /*! \brief helper function to get statistics from a tree */
template<typename TStats> template<typename TStats>
inline void GetNodeStats(const std::vector<bst_gpair> &gpair, inline void GetNodeStats(const std::vector<bst_gpair> &gpair,
const IFMatrix &fmat, const DMatrix &fmat,
const RegTree &tree, const RegTree &tree,
const BoosterInfo &info,
std::vector< std::vector<TStats> > *p_thread_temp, std::vector< std::vector<TStats> > *p_thread_temp,
std::vector<TStats> *p_node_stats) { std::vector<TStats> *p_node_stats) {
std::vector< std::vector<TStats> > &thread_temp = *p_thread_temp; std::vector< std::vector<TStats> > &thread_temp = *p_thread_temp;
const MetaInfo &info = fmat.info();
thread_temp.resize(this->get_nthread()); thread_temp.resize(this->get_nthread());
p_node_stats->resize(tree.param.num_nodes); p_node_stats->resize(tree.param.num_nodes);
#pragma omp parallel #pragma omp parallel
@ -323,7 +330,7 @@ class BaseMaker: public IUpdater {
/*! \brief current size of sketch */ /*! \brief current size of sketch */
double next_goal; double next_goal;
// pointer to the sketch to put things in // pointer to the sketch to put things in
utils::WXQuantileSketch<bst_float, bst_float> *sketch; common::WXQuantileSketch<bst_float, bst_float> *sketch;
// initialize the space // initialize the space
inline void Init(unsigned max_size) { inline void Init(unsigned max_size) {
next_goal = -1.0f; next_goal = -1.0f;
@ -351,13 +358,13 @@ class BaseMaker: public IUpdater {
last_fvalue > sketch->temp.data[sketch->temp.size-1].value) { last_fvalue > sketch->temp.data[sketch->temp.size-1].value) {
// push to sketch // push to sketch
sketch->temp.data[sketch->temp.size] = sketch->temp.data[sketch->temp.size] =
utils::WXQuantileSketch<bst_float, bst_float>:: common::WXQuantileSketch<bst_float, bst_float>::
Entry(static_cast<bst_float>(rmin), Entry(static_cast<bst_float>(rmin),
static_cast<bst_float>(rmax), static_cast<bst_float>(rmax),
static_cast<bst_float>(wmin), last_fvalue); static_cast<bst_float>(wmin), last_fvalue);
utils::Assert(sketch->temp.size < max_size, CHECK_LT(sketch->temp.size, max_size)
"invalid maximum size max_size=%u, stemp.size=%lu\n", << "invalid maximum size max_size=" << max_size
max_size, sketch->temp.size); << ", stemp.size=" << sketch->temp.size;
++sketch->temp.size; ++sketch->temp.size;
} }
if (sketch->temp.size == max_size) { if (sketch->temp.size == max_size) {
@ -382,12 +389,12 @@ class BaseMaker: public IUpdater {
inline void Finalize(unsigned max_size) { inline void Finalize(unsigned max_size) {
double rmax = rmin + wmin; double rmax = rmin + wmin;
if (sketch->temp.size == 0 || last_fvalue > sketch->temp.data[sketch->temp.size-1].value) { if (sketch->temp.size == 0 || last_fvalue > sketch->temp.data[sketch->temp.size-1].value) {
utils::Assert(sketch->temp.size <= max_size, CHECK_LE(sketch->temp.size, max_size)
"Finalize: invalid maximum size, max_size=%u, stemp.size=%lu", << "Finalize: invalid maximum size, max_size=" << max_size
sketch->temp.size, max_size); << ", stemp.size=" << sketch->temp.size;
// push to sketch // push to sketch
sketch->temp.data[sketch->temp.size] = sketch->temp.data[sketch->temp.size] =
utils::WXQuantileSketch<bst_float, bst_float>:: common::WXQuantileSketch<bst_float, bst_float>::
Entry(static_cast<bst_float>(rmin), Entry(static_cast<bst_float>(rmin),
static_cast<bst_float>(rmax), static_cast<bst_float>(rmax),
static_cast<bst_float>(wmin), last_fvalue); static_cast<bst_float>(wmin), last_fvalue);
@ -424,4 +431,4 @@ class BaseMaker: public IUpdater {
}; };
} // namespace tree } // namespace tree
} // namespace xgboost } // namespace xgboost
#endif // XGBOOST_TREE_UPDATER_BASEMAKER_INL_HPP_ #endif // XGBOOST_TREE_UPDATER_BASEMAKER_INL_H_
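The subsampling branch above replaces random::SampleBinary with a std::bernoulli_distribution drawn from common::GlobalRandom(). A standalone, hedged restatement of that idiom follows; the deleted-row check on the gradient pair is omitted for brevity, and any std::mt19937-compatible engine stands in for GlobalRandom().

#include <random>
#include <vector>

// Rows follow the same convention as BaseMaker: a dropped row is marked by
// flipping its node position to the bitwise complement.
void MarkSubsample(std::vector<int>* positions, float subsample, std::mt19937* rnd) {
  std::bernoulli_distribution coin_flip(subsample);  // P(keep row) = subsample
  for (int& pos : *positions) {
    if (!coin_flip(*rnd)) pos = ~pos;
  }
}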

View File

@ -747,7 +747,7 @@ class DistColMaker : public ColMaker<TStats> {
// update position after the tree is pruned // update position after the tree is pruned
builder.UpdatePosition(dmat, *trees[0]); builder.UpdatePosition(dmat, *trees[0]);
} }
virtual const int* GetLeafPosition() const { const int* GetLeafPosition() const override {
return builder.GetLeafPosition(); return builder.GetLeafPosition();
} }
@ -771,7 +771,7 @@ class DistColMaker : public ColMaker<TStats> {
this->position[ridx] = nid; this->position[ridx] = nid;
} }
} }
const int* GetLeafPosition() const override { inline const int* GetLeafPosition() const {
return dmlc::BeginPtr(this->position); return dmlc::BeginPtr(this->position);
} }

View File

@ -1,38 +1,35 @@
/*! /*!
* Copyright 2014 by Contributors * Copyright 2014 by Contributors
* \file updater_histmaker-inl.hpp * \file updater_histmaker.cc
* \brief use histogram counting to construct a tree * \brief use histogram counting to construct a tree
* \author Tianqi Chen * \author Tianqi Chen
*/ */
#ifndef XGBOOST_TREE_UPDATER_HISTMAKER_INL_HPP_ #include <xgboost/base.h>
#define XGBOOST_TREE_UPDATER_HISTMAKER_INL_HPP_ #include <xgboost/tree_updater.h>
#include <vector> #include <vector>
#include <algorithm> #include <algorithm>
#include "../sync/sync.h" #include "../common/sync.h"
#include "../utils/quantile.h" #include "../common/quantile.h"
#include "../utils/group_data.h" #include "../common/group_data.h"
#include "./updater_basemaker-inl.hpp" #include "./updater_basemaker-inl.h"
namespace xgboost { namespace xgboost {
namespace tree { namespace tree {
template<typename TStats> template<typename TStats>
class HistMaker: public BaseMaker { class HistMaker: public BaseMaker {
public: public:
virtual ~HistMaker(void) {} void Update(const std::vector<bst_gpair> &gpair,
virtual void Update(const std::vector<bst_gpair> &gpair, DMatrix *p_fmat,
IFMatrix *p_fmat, const std::vector<RegTree*> &trees) override {
const BoosterInfo &info, TStats::CheckInfo(p_fmat->info());
const std::vector<RegTree*> &trees) {
TStats::CheckInfo(info);
// rescale learning rate according to size of trees // rescale learning rate according to size of trees
float lr = param.learning_rate; float lr = param.eta;
param.learning_rate = lr / trees.size(); param.eta = lr / trees.size();
// build tree // build tree
for (size_t i = 0; i < trees.size(); ++i) { for (size_t i = 0; i < trees.size(); ++i) {
this->Update(gpair, p_fmat, info, trees[i]); this->Update(gpair, p_fmat, trees[i]);
} }
param.learning_rate = lr; param.eta = lr;
} }
protected: protected:
@ -45,19 +42,18 @@ class HistMaker: public BaseMaker {
/*! \brief size of histogram */ /*! \brief size of histogram */
unsigned size; unsigned size;
// default constructor // default constructor
HistUnit(void) {} HistUnit() {}
// constructor // constructor
HistUnit(const bst_float *cut, TStats *data, unsigned size) HistUnit(const bst_float *cut, TStats *data, unsigned size)
: cut(cut), data(data), size(size) {} : cut(cut), data(data), size(size) {}
/*! \brief add a histogram to data */ /*! \brief add a histogram to data */
inline void Add(bst_float fv, inline void Add(bst_float fv,
const std::vector<bst_gpair> &gpair, const std::vector<bst_gpair> &gpair,
const BoosterInfo &info, const MetaInfo &info,
const bst_uint ridx) { const bst_uint ridx) {
unsigned i = std::upper_bound(cut, cut + size, fv) - cut; unsigned i = std::upper_bound(cut, cut + size, fv) - cut;
utils::Assert(size != 0, "try insert into size=0"); CHECK_NE(size, 0) << "try insert into size=0";
utils::Assert(i < size, CHECK_LT(i, size);
"maximum value must be in cut, fv = %g, cutmax=%g", fv, cut[size-1]);
data[i].Add(gpair, info, ridx); data[i].Add(gpair, info, ridx);
} }
}; };
@ -92,13 +88,13 @@ class HistMaker: public BaseMaker {
for (size_t i = 0; i < hset[tid].data.size(); ++i) { for (size_t i = 0; i < hset[tid].data.size(); ++i) {
hset[tid].data[i].Clear(); hset[tid].data[i].Clear();
} }
hset[tid].rptr = BeginPtr(rptr); hset[tid].rptr = dmlc::BeginPtr(rptr);
hset[tid].cut = BeginPtr(cut); hset[tid].cut = dmlc::BeginPtr(cut);
hset[tid].data.resize(cut.size(), TStats(param)); hset[tid].data.resize(cut.size(), TStats(param));
} }
} }
// aggregate all statistics to hset[0] // aggregate all statistics to hset[0]
inline void Aggregate(void) { inline void Aggregate() {
bst_omp_uint nsize = static_cast<bst_omp_uint>(cut.size()); bst_omp_uint nsize = static_cast<bst_omp_uint>(cut.size());
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < nsize; ++i) { for (bst_omp_uint i = 0; i < nsize; ++i) {
@ -108,11 +104,11 @@ class HistMaker: public BaseMaker {
} }
} }
/*! \brief clear the workspace */ /*! \brief clear the workspace */
inline void Clear(void) { inline void Clear() {
cut.clear(); rptr.resize(1); rptr[0] = 0; cut.clear(); rptr.resize(1); rptr[0] = 0;
} }
/*! \brief total size */ /*! \brief total size */
inline size_t Size(void) const { inline size_t Size() const {
return rptr.size() - 1; return rptr.size() - 1;
} }
}; };
@ -124,18 +120,17 @@ class HistMaker: public BaseMaker {
std::vector<bst_uint> fwork_set; std::vector<bst_uint> fwork_set;
// update function implementation // update function implementation
virtual void Update(const std::vector<bst_gpair> &gpair, virtual void Update(const std::vector<bst_gpair> &gpair,
IFMatrix *p_fmat, DMatrix *p_fmat,
const BoosterInfo &info,
RegTree *p_tree) { RegTree *p_tree) {
this->InitData(gpair, *p_fmat, info.root_index, *p_tree); this->InitData(gpair, *p_fmat, *p_tree);
this->InitWorkSet(p_fmat, *p_tree, &fwork_set); this->InitWorkSet(p_fmat, *p_tree, &fwork_set);
for (int depth = 0; depth < param.max_depth; ++depth) { for (int depth = 0; depth < param.max_depth; ++depth) {
// reset and propose candidate split // reset and propose candidate split
this->ResetPosAndPropose(gpair, p_fmat, info, fwork_set, *p_tree); this->ResetPosAndPropose(gpair, p_fmat, fwork_set, *p_tree);
// create histogram // create histogram
this->CreateHist(gpair, p_fmat, info, fwork_set, *p_tree); this->CreateHist(gpair, p_fmat, fwork_set, *p_tree);
// find split based on histogram statistics // find split based on histogram statistics
this->FindSplit(depth, gpair, p_fmat, info, fwork_set, p_tree); this->FindSplit(depth, gpair, p_fmat, fwork_set, p_tree);
// reset position after split // reset position after split
this->ResetPositionAfterSplit(p_fmat, *p_tree); this->ResetPositionAfterSplit(p_fmat, *p_tree);
this->UpdateQueueExpand(*p_tree); this->UpdateQueueExpand(*p_tree);
@ -144,19 +139,18 @@ class HistMaker: public BaseMaker {
} }
for (size_t i = 0; i < qexpand.size(); ++i) { for (size_t i = 0; i < qexpand.size(); ++i) {
const int nid = qexpand[i]; const int nid = qexpand[i];
(*p_tree)[nid].set_leaf(p_tree->stat(nid).base_weight * param.learning_rate); (*p_tree)[nid].set_leaf(p_tree->stat(nid).base_weight * param.eta);
} }
} }
// this function does two jobs // this function does two jobs
// (1) reset the position in array position, to be the latest leaf id // (1) reset the position in array position, to be the latest leaf id
// (2) propose a set of candidate cuts and set wspace.rptr wspace.cut correctly // (2) propose a set of candidate cuts and set wspace.rptr wspace.cut correctly
virtual void ResetPosAndPropose(const std::vector<bst_gpair> &gpair, virtual void ResetPosAndPropose(const std::vector<bst_gpair> &gpair,
IFMatrix *p_fmat, DMatrix *p_fmat,
const BoosterInfo &info,
const std::vector <bst_uint> &fset, const std::vector <bst_uint> &fset,
const RegTree &tree) = 0; const RegTree &tree) = 0;
// initialize the current working set of features in this round // initialize the current working set of features in this round
virtual void InitWorkSet(IFMatrix *p_fmat, virtual void InitWorkSet(DMatrix *p_fmat,
const RegTree &tree, const RegTree &tree,
std::vector<bst_uint> *p_fset) { std::vector<bst_uint> *p_fset) {
p_fset->resize(tree.param.num_feature); p_fset->resize(tree.param.num_feature);
@ -165,12 +159,11 @@ class HistMaker: public BaseMaker {
} }
} }
// reset position after split, this is not a must, depending on implementation // reset position after split, this is not a must, depending on implementation
virtual void ResetPositionAfterSplit(IFMatrix *p_fmat, virtual void ResetPositionAfterSplit(DMatrix *p_fmat,
const RegTree &tree) { const RegTree &tree) {
} }
virtual void CreateHist(const std::vector<bst_gpair> &gpair, virtual void CreateHist(const std::vector<bst_gpair> &gpair,
IFMatrix *p_fmat, DMatrix *p_fmat,
const BoosterInfo &info,
const std::vector <bst_uint> &fset, const std::vector <bst_uint> &fset,
const RegTree &tree) = 0; const RegTree &tree) = 0;
@ -212,8 +205,7 @@ class HistMaker: public BaseMaker {
} }
inline void FindSplit(int depth, inline void FindSplit(int depth,
const std::vector<bst_gpair> &gpair, const std::vector<bst_gpair> &gpair,
IFMatrix *p_fmat, DMatrix *p_fmat,
const BoosterInfo &info,
const std::vector <bst_uint> &fset, const std::vector <bst_uint> &fset,
RegTree *p_tree) { RegTree *p_tree) {
const size_t num_feature = fset.size(); const size_t num_feature = fset.size();
@ -224,8 +216,7 @@ class HistMaker: public BaseMaker {
#pragma omp parallel for schedule(dynamic, 1) #pragma omp parallel for schedule(dynamic, 1)
for (bst_omp_uint wid = 0; wid < nexpand; ++wid) { for (bst_omp_uint wid = 0; wid < nexpand; ++wid) {
const int nid = qexpand[wid]; const int nid = qexpand[wid];
utils::Assert(node2workindex[nid] == static_cast<int>(wid), CHECK_EQ(node2workindex[nid], static_cast<int>(wid));
"node2workindex inconsistent");
SplitEntry &best = sol[wid]; SplitEntry &best = sol[wid];
TStats &node_sum = wspace.hset[0][num_feature + wid * (num_feature + 1)].data[0]; TStats &node_sum = wspace.hset[0][num_feature + wid * (num_feature + 1)].data[0];
for (size_t i = 0; i < fset.size(); ++i) { for (size_t i = 0; i < fset.size(); ++i) {
@ -255,7 +246,7 @@ class HistMaker: public BaseMaker {
this->SetStats(p_tree, (*p_tree)[nid].cleft(), left_sum[wid]); this->SetStats(p_tree, (*p_tree)[nid].cleft(), left_sum[wid]);
this->SetStats(p_tree, (*p_tree)[nid].cright(), right_sum); this->SetStats(p_tree, (*p_tree)[nid].cright(), right_sum);
} else { } else {
(*p_tree)[nid].set_leaf(p_tree->stat(nid).base_weight * param.learning_rate); (*p_tree)[nid].set_leaf(p_tree->stat(nid).base_weight * param.eta);
} }
} }
} }
@ -279,10 +270,10 @@ class CQHistMaker: public HistMaker<TStats> {
*/ */
inline void Add(bst_float fv, inline void Add(bst_float fv,
const std::vector<bst_gpair> &gpair, const std::vector<bst_gpair> &gpair,
const BoosterInfo &info, const MetaInfo &info,
const bst_uint ridx) { const bst_uint ridx) {
while (istart < hist.size && !(fv < hist.cut[istart])) ++istart; while (istart < hist.size && !(fv < hist.cut[istart])) ++istart;
utils::Assert(istart != hist.size, "the bound variable must be max"); CHECK_NE(istart, hist.size);
hist.data[istart].Add(gpair, info, ridx); hist.data[istart].Add(gpair, info, ridx);
} }
/*! /*!
@ -292,25 +283,25 @@ class CQHistMaker: public HistMaker<TStats> {
inline void Add(bst_float fv, inline void Add(bst_float fv,
bst_gpair gstats) { bst_gpair gstats) {
while (istart < hist.size && !(fv < hist.cut[istart])) ++istart; while (istart < hist.size && !(fv < hist.cut[istart])) ++istart;
utils::Assert(istart != hist.size, "the bound variable must be max"); CHECK_NE(istart, hist.size);
hist.data[istart].Add(gstats); hist.data[istart].Add(gstats);
} }
}; };
// sketch type used for this // sketch type used for this
typedef utils::WXQuantileSketch<bst_float, bst_float> WXQSketch; typedef common::WXQuantileSketch<bst_float, bst_float> WXQSketch;
// initialize the work set of tree // initialize the work set of tree
virtual void InitWorkSet(IFMatrix *p_fmat, void InitWorkSet(DMatrix *p_fmat,
const RegTree &tree, const RegTree &tree,
std::vector<bst_uint> *p_fset) { std::vector<bst_uint> *p_fset) override {
feat_helper.InitByCol(p_fmat, tree); feat_helper.InitByCol(p_fmat, tree);
feat_helper.SampleCol(this->param.colsample_bytree, p_fset); feat_helper.SampleCol(this->param.colsample_bytree, p_fset);
} }
// code to create histogram // code to create histogram
virtual void CreateHist(const std::vector<bst_gpair> &gpair, void CreateHist(const std::vector<bst_gpair> &gpair,
IFMatrix *p_fmat, DMatrix *p_fmat,
const BoosterInfo &info, const std::vector<bst_uint> &fset,
const std::vector<bst_uint> &fset, const RegTree &tree) override {
const RegTree &tree) { const MetaInfo &info = p_fmat->info();
// fill in reverse map // fill in reverse map
feat2workindex.resize(tree.param.num_feature); feat2workindex.resize(tree.param.num_feature);
std::fill(feat2workindex.begin(), feat2workindex.end(), -1); std::fill(feat2workindex.begin(), feat2workindex.end(), -1);
@ -327,7 +318,7 @@ class CQHistMaker: public HistMaker<TStats> {
{ {
thread_hist.resize(this->get_nthread()); thread_hist.resize(this->get_nthread());
// start accumulating statistics // start accumulating statistics
utils::IIterator<ColBatch> *iter = p_fmat->ColIterator(fset); dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator(fset);
iter->BeforeFirst(); iter->BeforeFirst();
while (iter->Next()) { while (iter->Next()) {
const ColBatch &batch = iter->Value(); const ColBatch &batch = iter->Value();
@ -353,21 +344,22 @@ class CQHistMaker: public HistMaker<TStats> {
// sync the histogram // sync the histogram
// if it is C++11, use lazy evaluation for Allreduce // if it is C++11, use lazy evaluation for Allreduce
#if __cplusplus >= 201103L #if __cplusplus >= 201103L
this->histred.Allreduce(BeginPtr(this->wspace.hset[0].data), this->histred.Allreduce(dmlc::BeginPtr(this->wspace.hset[0].data),
this->wspace.hset[0].data.size(), lazy_get_hist); this->wspace.hset[0].data.size(), lazy_get_hist);
#else #else
this->histred.Allreduce(BeginPtr(this->wspace.hset[0].data), this->wspace.hset[0].data.size()); this->histred.Allreduce(dmlc::BeginPtr(this->wspace.hset[0].data),
this->wspace.hset[0].data.size());
#endif #endif
} }
virtual void ResetPositionAfterSplit(IFMatrix *p_fmat, void ResetPositionAfterSplit(DMatrix *p_fmat,
const RegTree &tree) { const RegTree &tree) override {
this->ResetPositionCol(this->qexpand, p_fmat, tree); this->ResetPositionCol(this->qexpand, p_fmat, tree);
} }
virtual void ResetPosAndPropose(const std::vector<bst_gpair> &gpair, void ResetPosAndPropose(const std::vector<bst_gpair> &gpair,
IFMatrix *p_fmat, DMatrix *p_fmat,
const BoosterInfo &info, const std::vector<bst_uint> &fset,
const std::vector<bst_uint> &fset, const RegTree &tree) override {
const RegTree &tree) { const MetaInfo &info = p_fmat->info();
// fill in reverse map // fill in reverse map
feat2workindex.resize(tree.param.num_feature); feat2workindex.resize(tree.param.num_feature);
std::fill(feat2workindex.begin(), feat2workindex.end(), -1); std::fill(feat2workindex.begin(), feat2workindex.end(), -1);
@ -380,7 +372,7 @@ class CQHistMaker: public HistMaker<TStats> {
feat2workindex[fset[i]] = -2; feat2workindex[fset[i]] = -2;
} }
} }
this->GetNodeStats(gpair, *p_fmat, tree, info, this->GetNodeStats(gpair, *p_fmat, tree,
&thread_stats, &node_stats); &thread_stats, &node_stats);
sketchs.resize(this->qexpand.size() * freal_set.size()); sketchs.resize(this->qexpand.size() * freal_set.size());
for (size_t i = 0; i < sketchs.size(); ++i) { for (size_t i = 0; i < sketchs.size(); ++i) {
@ -403,7 +395,7 @@ class CQHistMaker: public HistMaker<TStats> {
// number of rows in // number of rows in
const size_t nrows = p_fmat->buffered_rowset().size(); const size_t nrows = p_fmat->buffered_rowset().size();
// start accumulating statistics // start accumulating statistics
utils::IIterator<ColBatch> *iter = p_fmat->ColIterator(freal_set); dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator(freal_set);
iter->BeforeFirst(); iter->BeforeFirst();
while (iter->Next()) { while (iter->Next()) {
const ColBatch &batch = iter->Value(); const ColBatch &batch = iter->Value();
@ -422,18 +414,19 @@ class CQHistMaker: public HistMaker<TStats> {
} }
} }
for (size_t i = 0; i < sketchs.size(); ++i) { for (size_t i = 0; i < sketchs.size(); ++i) {
utils::WXQuantileSketch<bst_float, bst_float>::SummaryContainer out; common::WXQuantileSketch<bst_float, bst_float>::SummaryContainer out;
sketchs[i].GetSummary(&out); sketchs[i].GetSummary(&out);
summary_array[i].SetPrune(out, max_size); summary_array[i].SetPrune(out, max_size);
} }
utils::Assert(summary_array.size() == sketchs.size(), "shape mismatch"); CHECK_EQ(summary_array.size(), sketchs.size());
}; };
if (summary_array.size() != 0) { if (summary_array.size() != 0) {
size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size); size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size);
#if __cplusplus >= 201103L #if __cplusplus >= 201103L
sreducer.Allreduce(BeginPtr(summary_array), nbytes, summary_array.size(), lazy_get_summary); sreducer.Allreduce(dmlc::BeginPtr(summary_array),
nbytes, summary_array.size(), lazy_get_summary);
#else #else
sreducer.Allreduce(BeginPtr(summary_array), nbytes, summary_array.size()); sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size());
#endif #endif
} }
// now we get the final result of sketch, setup the cut // now we get the final result of sketch, setup the cut
@ -460,7 +453,7 @@ class CQHistMaker: public HistMaker<TStats> {
} }
this->wspace.rptr.push_back(static_cast<unsigned>(this->wspace.cut.size())); this->wspace.rptr.push_back(static_cast<unsigned>(this->wspace.cut.size()));
} else { } else {
utils::Assert(offset == -2, "BUG in mark"); CHECK_EQ(offset, -2);
bst_float cpt = feat_helper.MaxValue(fset[i]); bst_float cpt = feat_helper.MaxValue(fset[i]);
this->wspace.cut.push_back(cpt + fabs(cpt) + rt_eps); this->wspace.cut.push_back(cpt + fabs(cpt) + rt_eps);
this->wspace.rptr.push_back(static_cast<unsigned>(this->wspace.cut.size())); this->wspace.rptr.push_back(static_cast<unsigned>(this->wspace.cut.size()));
@ -470,15 +463,14 @@ class CQHistMaker: public HistMaker<TStats> {
this->wspace.cut.push_back(0.0f); this->wspace.cut.push_back(0.0f);
this->wspace.rptr.push_back(static_cast<unsigned>(this->wspace.cut.size())); this->wspace.rptr.push_back(static_cast<unsigned>(this->wspace.cut.size()));
} }
utils::Assert(this->wspace.rptr.size() == CHECK_EQ(this->wspace.rptr.size(),
(fset.size() + 1) * this->qexpand.size() + 1, (fset.size() + 1) * this->qexpand.size() + 1);
"cut space inconsistent");
} }
private: private:
inline void UpdateHistCol(const std::vector<bst_gpair> &gpair, inline void UpdateHistCol(const std::vector<bst_gpair> &gpair,
const ColBatch::Inst &c, const ColBatch::Inst &c,
const BoosterInfo &info, const MetaInfo &info,
const RegTree &tree, const RegTree &tree,
const std::vector<bst_uint> &fset, const std::vector<bst_uint> &fset,
bst_uint fid_offset, bst_uint fid_offset,
@ -623,11 +615,11 @@ class CQHistMaker: public HistMaker<TStats> {
// set of indices from fset that are real-valued // set of indices from fset that are real-valued
std::vector<bst_uint> freal_set; std::vector<bst_uint> freal_set;
// thread temp data // thread temp data
std::vector< std::vector<BaseMaker::SketchEntry> > thread_sketch; std::vector<std::vector<BaseMaker::SketchEntry> > thread_sketch;
// used to hold statistics // used to hold statistics
std::vector< std::vector<TStats> > thread_stats; std::vector<std::vector<TStats> > thread_stats;
// used to hold start pointer // used to hold start pointer
std::vector< std::vector<HistEntry> > thread_hist; std::vector<std::vector<HistEntry> > thread_hist;
// node statistics // node statistics
std::vector<TStats> node_stats; std::vector<TStats> node_stats;
// summary array // summary array
@ -635,18 +627,18 @@ class CQHistMaker: public HistMaker<TStats> {
// reducer for summary // reducer for summary
rabit::SerializeReducer<WXQSketch::SummaryContainer> sreducer; rabit::SerializeReducer<WXQSketch::SummaryContainer> sreducer;
// per node, per feature sketch // per node, per feature sketch
std::vector< utils::WXQuantileSketch<bst_float, bst_float> > sketchs; std::vector<common::WXQuantileSketch<bst_float, bst_float> > sketchs;
}; };
template<typename TStats> template<typename TStats>
class QuantileHistMaker: public HistMaker<TStats> { class QuantileHistMaker: public HistMaker<TStats> {
protected: protected:
typedef utils::WXQuantileSketch<bst_float, bst_float> WXQSketch; typedef common::WXQuantileSketch<bst_float, bst_float> WXQSketch;
virtual void ResetPosAndPropose(const std::vector<bst_gpair> &gpair, void ResetPosAndPropose(const std::vector<bst_gpair> &gpair,
IFMatrix *p_fmat, DMatrix *p_fmat,
const BoosterInfo &info, const std::vector <bst_uint> &fset,
const std::vector <bst_uint> &fset, const RegTree &tree) override {
const RegTree &tree) { const MetaInfo &info = p_fmat->info();
// initialize the data structure // initialize the data structure
int nthread = BaseMaker::get_nthread(); int nthread = BaseMaker::get_nthread();
sketchs.resize(this->qexpand.size() * tree.param.num_feature); sketchs.resize(this->qexpand.size() * tree.param.num_feature);
@ -654,12 +646,13 @@ class QuantileHistMaker: public HistMaker<TStats> {
sketchs[i].Init(info.num_row, this->param.sketch_eps); sketchs[i].Init(info.num_row, this->param.sketch_eps);
} }
// start accumulating statistics // start accumulating statistics
utils::IIterator<RowBatch> *iter = p_fmat->RowIterator(); dmlc::DataIter<RowBatch> *iter = p_fmat->RowIterator();
iter->BeforeFirst(); iter->BeforeFirst();
while (iter->Next()) { while (iter->Next()) {
const RowBatch &batch = iter->Value(); const RowBatch &batch = iter->Value();
// convert to column-major format in parallel // convert to column-major format in parallel
utils::ParallelGroupBuilder<SparseBatch::Entry> builder(&col_ptr, &col_data, &thread_col_ptr); common::ParallelGroupBuilder<SparseBatch::Entry>
builder(&col_ptr, &col_data, &thread_col_ptr);
builder.InitBudget(tree.param.num_feature, nthread); builder.InitBudget(tree.param.num_feature, nthread);
const bst_omp_uint nbatch = static_cast<bst_omp_uint>(batch.size); const bst_omp_uint nbatch = static_cast<bst_omp_uint>(batch.size);
@ -711,14 +704,14 @@ class QuantileHistMaker: public HistMaker<TStats> {
// synchronize sketch // synchronize sketch
summary_array.resize(sketchs.size()); summary_array.resize(sketchs.size());
for (size_t i = 0; i < sketchs.size(); ++i) { for (size_t i = 0; i < sketchs.size(); ++i) {
utils::WQuantileSketch<bst_float, bst_float>::SummaryContainer out; common::WQuantileSketch<bst_float, bst_float>::SummaryContainer out;
sketchs[i].GetSummary(&out); sketchs[i].GetSummary(&out);
summary_array[i].Reserve(max_size); summary_array[i].Reserve(max_size);
summary_array[i].SetPrune(out, max_size); summary_array[i].SetPrune(out, max_size);
} }
size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size); size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size);
sreducer.Allreduce(BeginPtr(summary_array), nbytes, summary_array.size()); sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size());
// now that we have the final sketch result, set up the cut points // now that we have the final sketch result, set up the cut points
this->wspace.cut.clear(); this->wspace.cut.clear();
this->wspace.rptr.clear(); this->wspace.rptr.clear();
@ -745,9 +738,8 @@ class QuantileHistMaker: public HistMaker<TStats> {
this->wspace.cut.push_back(0.0f); this->wspace.cut.push_back(0.0f);
this->wspace.rptr.push_back(this->wspace.cut.size()); this->wspace.rptr.push_back(this->wspace.cut.size());
} }
utils::Assert(this->wspace.rptr.size() == CHECK_EQ(this->wspace.rptr.size(),
(tree.param.num_feature + 1) * this->qexpand.size() + 1, (tree.param.num_feature + 1) * this->qexpand.size() + 1);
"cut space inconsistent");
} }
private: private:
@ -759,11 +751,15 @@ class QuantileHistMaker: public HistMaker<TStats> {
std::vector<size_t> col_ptr; std::vector<size_t> col_ptr;
// local storage of column data // local storage of column data
std::vector<SparseBatch::Entry> col_data; std::vector<SparseBatch::Entry> col_data;
std::vector< std::vector<size_t> > thread_col_ptr; std::vector<std::vector<size_t> > thread_col_ptr;
// per node, per feature sketch // per node, per feature sketch
std::vector< utils::WQuantileSketch<bst_float, bst_float> > sketchs; std::vector<common::WQuantileSketch<bst_float, bst_float> > sketchs;
}; };
XGBOOST_REGISTER_TREE_UPDATER(HistMaker, "grow_histmaker")
.describe("Tree constructor that uses approximate histogram construction.")
.set_body([]() {
return new CQHistMaker<GradStats>();
});
} // namespace tree } // namespace tree
} // namespace xgboost } // namespace xgboost
#endif // XGBOOST_TREE_UPDATER_HISTMAKER_INL_HPP_
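The hunks above port the histogram makers to the new TreeUpdater interface and register CQHistMaker<GradStats> under the name "grow_histmaker". A minimal sketch of how a caller might drive a registered updater through that interface; the Create, Init, and Update shapes are taken from this diff, while the helper name RunOneUpdate, the parameter values, and the surrounding includes and namespace are assumptions for illustration only.
// Sketch only: assumes it is compiled inside namespace xgboost::tree with
// <xgboost/tree_updater.h>, <memory>, <string>, <utility>, <vector> available.
inline void RunOneUpdate(DMatrix* dmat,
                         const std::vector<bst_gpair>& gpair,
                         RegTree* tree) {
  std::unique_ptr<TreeUpdater> up(TreeUpdater::Create("grow_histmaker"));
  std::vector<std::pair<std::string, std::string> > args;
  args.push_back(std::make_pair("max_depth", "6"));  // illustrative settings
  args.push_back(std::make_pair("eta", "0.3"));
  up->Init(args);
  std::vector<RegTree*> trees;
  trees.push_back(tree);
  up->Update(gpair, dmat, trees);  // signature as introduced in this commit
}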

@ -1,43 +1,42 @@
/*! /*!
* Copyright 2014 by Contributors * Copyright 2014 by Contributors
* \file updater_prune-inl.hpp * \file updater_prune.cc
* \brief prune a tree given the statistics * \brief prune a tree given the statistics
* \author Tianqi Chen * \author Tianqi Chen
*/ */
#ifndef XGBOOST_TREE_UPDATER_PRUNE_INL_HPP_
#define XGBOOST_TREE_UPDATER_PRUNE_INL_HPP_
#include <vector> #include <xgboost/tree_updater.h>
#include <string>
#include <memory>
#include "./param.h" #include "./param.h"
#include "./updater.h" #include "../common/sync.h"
#include "./updater_sync-inl.hpp" #include "../common/io.h"
namespace xgboost { namespace xgboost {
namespace tree { namespace tree {
/*! \brief pruner that prunes a tree after growing finishes */ /*! \brief pruner that prunes a tree after growing finishes */
class TreePruner: public IUpdater { class TreePruner: public TreeUpdater {
public: public:
virtual ~TreePruner(void) {} TreePruner() {
syncher.reset(TreeUpdater::Create("sync"));
}
// set training parameter // set training parameter
virtual void SetParam(const char *name, const char *val) { void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
using namespace std; param.Init(args);
param.SetParam(name, val); syncher->Init(args);
syncher.SetParam(name, val);
if (!strcmp(name, "silent")) silent = atoi(val);
} }
// update the tree, do pruning // update the tree, do pruning
virtual void Update(const std::vector<bst_gpair> &gpair, void Update(const std::vector<bst_gpair> &gpair,
IFMatrix *p_fmat, DMatrix *p_fmat,
const BoosterInfo &info, const std::vector<RegTree*> &trees) override {
const std::vector<RegTree*> &trees) {
// rescale the learning rate according to the number of trees // rescale the learning rate according to the number of trees
float lr = param.learning_rate; float lr = param.eta;
param.learning_rate = lr / trees.size(); param.eta = lr / trees.size();
for (size_t i = 0; i < trees.size(); ++i) { for (size_t i = 0; i < trees.size(); ++i) {
this->DoPrune(*trees[i]); this->DoPrune(*trees[i]);
} }
param.learning_rate = lr; param.eta = lr;
syncher.Update(gpair, p_fmat, info, trees); syncher->Update(gpair, p_fmat, trees);
} }
private: private:
@ -49,9 +48,9 @@ class TreePruner: public IUpdater {
++s.leaf_child_cnt; ++s.leaf_child_cnt;
if (s.leaf_child_cnt >= 2 && param.need_prune(s.loss_chg, depth - 1)) { if (s.leaf_child_cnt >= 2 && param.need_prune(s.loss_chg, depth - 1)) {
// need to be pruned // need to be pruned
tree.ChangeToLeaf(pid, param.learning_rate * s.base_weight); tree.ChangeToLeaf(pid, param.eta * s.base_weight);
// tail recursion // tail recursion
return this->TryPruneLeaf(tree, pid, depth - 1, npruned+2); return this->TryPruneLeaf(tree, pid, depth - 1, npruned + 2);
} else { } else {
return npruned; return npruned;
} }
@ -68,20 +67,24 @@ class TreePruner: public IUpdater {
npruned = this->TryPruneLeaf(tree, nid, tree.GetDepth(nid), npruned); npruned = this->TryPruneLeaf(tree, nid, tree.GetDepth(nid), npruned);
} }
} }
if (silent == 0) { if (!param.silent) {
utils::Printf("tree pruning end, %d roots, %d extra nodes, %d pruned nodes, max_depth=%d\n", LOG(INFO) << "tree pruning end, " << tree.param.num_roots << " roots, "
tree.param.num_roots, tree.num_extra_nodes(), npruned, tree.MaxDepth()); << tree.num_extra_nodes() << " extra nodes, " << npruned
<< " pruned nodes, max_depth=" << tree.MaxDepth();
} }
} }
private: private:
// synchronizer // synchronizer
TreeSyncher syncher; std::unique_ptr<TreeUpdater> syncher;
// shutup
int silent;
// training parameter // training parameter
TrainParam param; TrainParam param;
}; };
XGBOOST_REGISTER_TREE_UPDATER(TreePruner, "prune")
.describe("Pruner that prune the tree according to statistics.")
.set_body([]() {
return new TreePruner();
});
} // namespace tree } // namespace tree
} // namespace xgboost } // namespace xgboost
#endif // XGBOOST_TREE_UPDATER_PRUNE_INL_HPP_
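In the rewritten pruner, and again in the refresher and sketch maker below, eta is divided by the number of trees before the per-tree pass and restored afterwards, so pruning a whole forest in one call writes each converted leaf with eta / K of its base weight. A worked restatement of that bookkeeping with concrete numbers; the values are illustrative and not from the commit.
// With eta = 0.3 and trees.size() == 3, each ChangeToLeaf above uses 0.1.
float lr = param.eta;             // e.g. 0.3f
param.eta = lr / trees.size();    // 0.1f for a three-tree forest
for (size_t i = 0; i < trees.size(); ++i) {
  this->DoPrune(*trees[i]);       // leaves written with the scaled eta
}
param.eta = lr;                   // restore so later updaters see 0.3f again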

@ -1,39 +1,34 @@
/*! /*!
* Copyright 2014 by Contributors * Copyright 2014 by Contributors
* \file updater_refresh-inl.hpp * \file updater_refresh.cc
* \brief refresh the statistics and leaf values of a tree using the dataset * \brief refresh the statistics and leaf values of a tree using the dataset
* \author Tianqi Chen * \author Tianqi Chen
*/ */
#ifndef XGBOOST_TREE_UPDATER_REFRESH_INL_HPP_
#define XGBOOST_TREE_UPDATER_REFRESH_INL_HPP_
#include <xgboost/tree_updater.h>
#include <vector> #include <vector>
#include <limits> #include <limits>
#include "../sync/sync.h"
#include "./param.h" #include "./param.h"
#include "./updater.h" #include "../common/sync.h"
#include "../utils/omp.h" #include "../common/io.h"
namespace xgboost { namespace xgboost {
namespace tree { namespace tree {
/*! \brief refresher that refreshes the weight and statistics according to data */ /*! \brief refresher that refreshes the weight and statistics according to data */
template<typename TStats> template<typename TStats>
class TreeRefresher: public IUpdater { class TreeRefresher: public TreeUpdater {
public: public:
virtual ~TreeRefresher(void) {} void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
// set training parameter param.Init(args);
virtual void SetParam(const char *name, const char *val) {
param.SetParam(name, val);
} }
// update the trees, refresh their statistics and leaf values // update the trees, refresh their statistics and leaf values
virtual void Update(const std::vector<bst_gpair> &gpair, void Update(const std::vector<bst_gpair> &gpair,
IFMatrix *p_fmat, DMatrix *p_fmat,
const BoosterInfo &info, const std::vector<RegTree*> &trees) {
const std::vector<RegTree*> &trees) {
if (trees.size() == 0) return; if (trees.size() == 0) return;
// number of threads // number of threads
// per-thread temporary space // per-thread temporary space
std::vector< std::vector<TStats> > stemp; std::vector<std::vector<TStats> > stemp;
std::vector<RegTree::FVec> fvec_temp; std::vector<RegTree::FVec> fvec_temp;
// setup temp space for each thread // setup temp space for each thread
int nthread; int nthread;
@ -60,13 +55,13 @@ class TreeRefresher: public IUpdater {
auto lazy_get_stats = [&]() auto lazy_get_stats = [&]()
#endif #endif
{ {
const MetaInfo &info = p_fmat->info();
// start accumulating statistics // start accumulating statistics
utils::IIterator<RowBatch> *iter = p_fmat->RowIterator(); dmlc::DataIter<RowBatch> *iter = p_fmat->RowIterator();
iter->BeforeFirst(); iter->BeforeFirst();
while (iter->Next()) { while (iter->Next()) {
const RowBatch &batch = iter->Value(); const RowBatch &batch = iter->Value();
utils::Check(batch.size < std::numeric_limits<unsigned>::max(), CHECK_LT(batch.size, std::numeric_limits<unsigned>::max());
"too large batch size ");
const bst_omp_uint nbatch = static_cast<bst_omp_uint>(batch.size); const bst_omp_uint nbatch = static_cast<bst_omp_uint>(batch.size);
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < nbatch; ++i) { for (bst_omp_uint i = 0; i < nbatch; ++i) {
@ -78,7 +73,7 @@ class TreeRefresher: public IUpdater {
int offset = 0; int offset = 0;
for (size_t j = 0; j < trees.size(); ++j) { for (size_t j = 0; j < trees.size(); ++j) {
AddStats(*trees[j], feats, gpair, info, ridx, AddStats(*trees[j], feats, gpair, info, ridx,
BeginPtr(stemp[tid]) + offset); dmlc::BeginPtr(stemp[tid]) + offset);
offset += trees[j]->param.num_nodes; offset += trees[j]->param.num_nodes;
} }
feats.Drop(inst); feats.Drop(inst);
@ -94,29 +89,29 @@ class TreeRefresher: public IUpdater {
} }
}; };
#if __cplusplus >= 201103L #if __cplusplus >= 201103L
reducer.Allreduce(BeginPtr(stemp[0]), stemp[0].size(), lazy_get_stats); reducer.Allreduce(dmlc::BeginPtr(stemp[0]), stemp[0].size(), lazy_get_stats);
#else #else
reducer.Allreduce(BeginPtr(stemp[0]), stemp[0].size()); reducer.Allreduce(dmlc::BeginPtr(stemp[0]), stemp[0].size());
#endif #endif
// rescale the learning rate according to the number of trees // rescale the learning rate according to the number of trees
float lr = param.learning_rate; float lr = param.eta;
param.learning_rate = lr / trees.size(); param.eta = lr / trees.size();
int offset = 0; int offset = 0;
for (size_t i = 0; i < trees.size(); ++i) { for (size_t i = 0; i < trees.size(); ++i) {
for (int rid = 0; rid < trees[i]->param.num_roots; ++rid) { for (int rid = 0; rid < trees[i]->param.num_roots; ++rid) {
this->Refresh(BeginPtr(stemp[0]) + offset, rid, trees[i]); this->Refresh(dmlc::BeginPtr(stemp[0]) + offset, rid, trees[i]);
} }
offset += trees[i]->param.num_nodes; offset += trees[i]->param.num_nodes;
} }
// set learning rate back // set learning rate back
param.learning_rate = lr; param.eta = lr;
} }
private: private:
inline static void AddStats(const RegTree &tree, inline static void AddStats(const RegTree &tree,
const RegTree::FVec &feat, const RegTree::FVec &feat,
const std::vector<bst_gpair> &gpair, const std::vector<bst_gpair> &gpair,
const BoosterInfo &info, const MetaInfo &info,
const bst_uint ridx, const bst_uint ridx,
TStats *gstats) { TStats *gstats) {
// start from groups that belong to current data // start from groups that belong to current data
@ -136,7 +131,7 @@ class TreeRefresher: public IUpdater {
tree.stat(nid).sum_hess = static_cast<float>(gstats[nid].sum_hess); tree.stat(nid).sum_hess = static_cast<float>(gstats[nid].sum_hess);
gstats[nid].SetLeafVec(param, tree.leafvec(nid)); gstats[nid].SetLeafVec(param, tree.leafvec(nid));
if (tree[nid].is_leaf()) { if (tree[nid].is_leaf()) {
tree[nid].set_leaf(tree.stat(nid).base_weight * param.learning_rate); tree[nid].set_leaf(tree.stat(nid).base_weight * param.eta);
} else { } else {
tree.stat(nid).loss_chg = static_cast<float>( tree.stat(nid).loss_chg = static_cast<float>(
gstats[tree[nid].cleft()].CalcGain(param) + gstats[tree[nid].cleft()].CalcGain(param) +
@ -152,6 +147,10 @@ class TreeRefresher: public IUpdater {
rabit::Reducer<TStats, TStats::Reduce> reducer; rabit::Reducer<TStats, TStats::Reduce> reducer;
}; };
XGBOOST_REGISTER_TREE_UPDATER(TreeRefresher, "refresh")
.describe("Refresher that refreshes the weight and statistics according to data.")
.set_body([]() {
return new TreeRefresher<GradStats>();
});
} // namespace tree } // namespace tree
} // namespace xgboost } // namespace xgboost
#endif // XGBOOST_TREE_UPDATER_REFRESH_INL_HPP_
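The refresher's C++11 branch hands the row-scanning lambda to the rabit reducer instead of running it unconditionally: lazy_get_stats is only invoked when the allreduce actually needs fresh local statistics, so the accumulation pass can be skipped when the reduced result is already available. The restatement below annotates the call from the hunk above; the comments are added interpretation, not part of the commit.
#if __cplusplus >= 201103L
// lazy path: rabit calls lazy_get_stats right before it needs the buffer,
// and may skip it when the reduction does not need the local statistics
reducer.Allreduce(dmlc::BeginPtr(stemp[0]), stemp[0].size(), lazy_get_stats);
#else
// pre-C++11 there is no lambda, so the guarded block above runs eagerly
// and the statistics are already in stemp[0] when this call is reached
reducer.Allreduce(dmlc::BeginPtr(stemp[0]), stemp[0].size());
#endif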

@ -1,57 +1,56 @@
/*! /*!
* Copyright 2014 by Contributors * Copyright 2014 by Contributors
* \file updater_skmaker-inl.hpp * \file updater_skmaker.cc
* \brief use an approximate sketch to construct a tree, * \brief use an approximate sketch to construct a tree,
a refresh is needed to make the statistics exactly correct a refresh is needed to make the statistics exactly correct
* \author Tianqi Chen * \author Tianqi Chen
*/ */
#ifndef XGBOOST_TREE_UPDATER_SKMAKER_INL_HPP_
#define XGBOOST_TREE_UPDATER_SKMAKER_INL_HPP_
#include <xgboost/base.h>
#include <xgboost/tree_updater.h>
#include <vector> #include <vector>
#include <algorithm> #include <algorithm>
#include "../sync/sync.h" #include "../common/sync.h"
#include "../utils/quantile.h" #include "../common/quantile.h"
#include "./updater_basemaker-inl.hpp" #include "../common/group_data.h"
#include "./updater_basemaker-inl.h"
namespace xgboost { namespace xgboost {
namespace tree { namespace tree {
class SketchMaker: public BaseMaker { class SketchMaker: public BaseMaker {
public: public:
virtual ~SketchMaker(void) {} void Update(const std::vector<bst_gpair> &gpair,
virtual void Update(const std::vector<bst_gpair> &gpair, DMatrix *p_fmat,
IFMatrix *p_fmat, const std::vector<RegTree*> &trees) override {
const BoosterInfo &info,
const std::vector<RegTree*> &trees) {
// rescale the learning rate according to the number of trees // rescale the learning rate according to the number of trees
float lr = param.learning_rate; float lr = param.eta;
param.learning_rate = lr / trees.size(); param.eta = lr / trees.size();
// build tree // build tree
for (size_t i = 0; i < trees.size(); ++i) { for (size_t i = 0; i < trees.size(); ++i) {
this->Update(gpair, p_fmat, info, trees[i]); this->Update(gpair, p_fmat, trees[i]);
} }
param.learning_rate = lr; param.eta = lr;
} }
protected: protected:
inline void Update(const std::vector<bst_gpair> &gpair, inline void Update(const std::vector<bst_gpair> &gpair,
IFMatrix *p_fmat, DMatrix *p_fmat,
const BoosterInfo &info, RegTree *p_tree) {
RegTree *p_tree) { this->InitData(gpair, *p_fmat, *p_tree);
this->InitData(gpair, *p_fmat, info.root_index, *p_tree);
for (int depth = 0; depth < param.max_depth; ++depth) { for (int depth = 0; depth < param.max_depth; ++depth) {
this->GetNodeStats(gpair, *p_fmat, *p_tree, info, this->GetNodeStats(gpair, *p_fmat, *p_tree,
&thread_stats, &node_stats); &thread_stats, &node_stats);
this->BuildSketch(gpair, p_fmat, info, *p_tree); this->BuildSketch(gpair, p_fmat, *p_tree);
this->SyncNodeStats(); this->SyncNodeStats();
this->FindSplit(depth, gpair, p_fmat, info, p_tree); this->FindSplit(depth, gpair, p_fmat, p_tree);
this->ResetPositionCol(qexpand, p_fmat, *p_tree); this->ResetPositionCol(qexpand, p_fmat, *p_tree);
this->UpdateQueueExpand(*p_tree); this->UpdateQueueExpand(*p_tree);
// if nothing is left to expand, break // if nothing is left to expand, break
if (qexpand.size() == 0) break; if (qexpand.size() == 0) break;
} }
if (qexpand.size() != 0) { if (qexpand.size() != 0) {
this->GetNodeStats(gpair, *p_fmat, *p_tree, info, this->GetNodeStats(gpair, *p_fmat, *p_tree,
&thread_stats, &node_stats); &thread_stats, &node_stats);
this->SyncNodeStats(); this->SyncNodeStats();
} }
@ -68,11 +67,11 @@ class SketchMaker: public BaseMaker {
// set the remaining expand nodes to leaves // set the remaining expand nodes to leaves
for (size_t i = 0; i < qexpand.size(); ++i) { for (size_t i = 0; i < qexpand.size(); ++i) {
const int nid = qexpand[i]; const int nid = qexpand[i];
(*p_tree)[nid].set_leaf(p_tree->stat(nid).base_weight * param.learning_rate); (*p_tree)[nid].set_leaf(p_tree->stat(nid).base_weight * param.eta);
} }
} }
// define the sketch we want to use // define the sketch we want to use
typedef utils::WXQuantileSketch<bst_float, bst_float> WXQSketch; typedef common::WXQuantileSketch<bst_float, bst_float> WXQSketch;
private: private:
// statistics needed in the gradient calculation // statistics needed in the gradient calculation
@ -94,7 +93,7 @@ class SketchMaker: public BaseMaker {
} }
// accumulate statistics // accumulate statistics
inline void Add(const std::vector<bst_gpair> &gpair, inline void Add(const std::vector<bst_gpair> &gpair,
const BoosterInfo &info, const MetaInfo &info,
bst_uint ridx) { bst_uint ridx) {
const bst_gpair &b = gpair[ridx]; const bst_gpair &b = gpair[ridx];
if (b.grad >= 0.0f) { if (b.grad >= 0.0f) {
@ -133,9 +132,9 @@ class SketchMaker: public BaseMaker {
} }
}; };
inline void BuildSketch(const std::vector<bst_gpair> &gpair, inline void BuildSketch(const std::vector<bst_gpair> &gpair,
IFMatrix *p_fmat, DMatrix *p_fmat,
const BoosterInfo &info,
const RegTree &tree) { const RegTree &tree) {
const MetaInfo& info = p_fmat->info();
sketchs.resize(this->qexpand.size() * tree.param.num_feature * 3); sketchs.resize(this->qexpand.size() * tree.param.num_feature * 3);
for (size_t i = 0; i < sketchs.size(); ++i) { for (size_t i = 0; i < sketchs.size(); ++i) {
sketchs[i].Init(info.num_row, this->param.sketch_eps); sketchs[i].Init(info.num_row, this->param.sketch_eps);
@ -144,7 +143,7 @@ class SketchMaker: public BaseMaker {
// number of rows in the buffered rowset // number of rows in the buffered rowset
const size_t nrows = p_fmat->buffered_rowset().size(); const size_t nrows = p_fmat->buffered_rowset().size();
// start accumulating statistics // start accumulating statistics
utils::IIterator<ColBatch> *iter = p_fmat->ColIterator(); dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator();
iter->BeforeFirst(); iter->BeforeFirst();
while (iter->Next()) { while (iter->Next()) {
const ColBatch &batch = iter->Value(); const ColBatch &batch = iter->Value();
@ -164,13 +163,13 @@ class SketchMaker: public BaseMaker {
// synchronize sketch // synchronize sketch
summary_array.resize(sketchs.size()); summary_array.resize(sketchs.size());
for (size_t i = 0; i < sketchs.size(); ++i) { for (size_t i = 0; i < sketchs.size(); ++i) {
utils::WXQuantileSketch<bst_float, bst_float>::SummaryContainer out; common::WXQuantileSketch<bst_float, bst_float>::SummaryContainer out;
sketchs[i].GetSummary(&out); sketchs[i].GetSummary(&out);
summary_array[i].Reserve(max_size); summary_array[i].Reserve(max_size);
summary_array[i].SetPrune(out, max_size); summary_array[i].SetPrune(out, max_size);
} }
size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size); size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size);
sketch_reducer.Allreduce(BeginPtr(summary_array), nbytes, summary_array.size()); sketch_reducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size());
} }
// update sketch information in column fid // update sketch information in column fid
inline void UpdateSketchCol(const std::vector<bst_gpair> &gpair, inline void UpdateSketchCol(const std::vector<bst_gpair> &gpair,
@ -256,20 +255,19 @@ class SketchMaker: public BaseMaker {
} }
} }
inline void SyncNodeStats(void) { inline void SyncNodeStats(void) {
utils::Assert(qexpand.size() != 0, "qexpand must not be empty"); CHECK_NE(qexpand.size(), 0);
std::vector<SKStats> tmp(qexpand.size()); std::vector<SKStats> tmp(qexpand.size());
for (size_t i = 0; i < qexpand.size(); ++i) { for (size_t i = 0; i < qexpand.size(); ++i) {
tmp[i] = node_stats[qexpand[i]]; tmp[i] = node_stats[qexpand[i]];
} }
stats_reducer.Allreduce(BeginPtr(tmp), tmp.size()); stats_reducer.Allreduce(dmlc::BeginPtr(tmp), tmp.size());
for (size_t i = 0; i < qexpand.size(); ++i) { for (size_t i = 0; i < qexpand.size(); ++i) {
node_stats[qexpand[i]] = tmp[i]; node_stats[qexpand[i]] = tmp[i];
} }
} }
inline void FindSplit(int depth, inline void FindSplit(int depth,
const std::vector<bst_gpair> &gpair, const std::vector<bst_gpair> &gpair,
IFMatrix *p_fmat, DMatrix *p_fmat,
const BoosterInfo &info,
RegTree *p_tree) { RegTree *p_tree) {
const bst_uint num_feature = p_tree->param.num_feature; const bst_uint num_feature = p_tree->param.num_feature;
// get the best split condition for each node // get the best split condition for each node
@ -278,8 +276,7 @@ class SketchMaker: public BaseMaker {
#pragma omp parallel for schedule(dynamic, 1) #pragma omp parallel for schedule(dynamic, 1)
for (bst_omp_uint wid = 0; wid < nexpand; ++wid) { for (bst_omp_uint wid = 0; wid < nexpand; ++wid) {
const int nid = qexpand[wid]; const int nid = qexpand[wid];
utils::Assert(node2workindex[nid] == static_cast<int>(wid), CHECK_EQ(node2workindex[nid], static_cast<int>(wid));
"node2workindex inconsistent");
SplitEntry &best = sol[wid]; SplitEntry &best = sol[wid];
for (bst_uint fid = 0; fid < num_feature; ++fid) { for (bst_uint fid = 0; fid < num_feature; ++fid) {
unsigned base = (wid * p_tree->param.num_feature + fid) * 3; unsigned base = (wid * p_tree->param.num_feature + fid) * 3;
@ -305,7 +302,7 @@ class SketchMaker: public BaseMaker {
(*p_tree)[(*p_tree)[nid].cleft()].set_leaf(0.0f, 0); (*p_tree)[(*p_tree)[nid].cleft()].set_leaf(0.0f, 0);
(*p_tree)[(*p_tree)[nid].cright()].set_leaf(0.0f, 0); (*p_tree)[(*p_tree)[nid].cright()].set_leaf(0.0f, 0);
} else { } else {
(*p_tree)[nid].set_leaf(p_tree->stat(nid).base_weight * param.learning_rate); (*p_tree)[nid].set_leaf(p_tree->stat(nid).base_weight * param.eta);
} }
} }
} }
@ -380,9 +377,9 @@ class SketchMaker: public BaseMaker {
// thread temp data // thread temp data
// used to hold temporary sketches // used to hold temporary sketches
std::vector< std::vector<SketchEntry> > thread_sketch; std::vector<std::vector<SketchEntry> > thread_sketch;
// used to hold statistics // used to hold statistics
std::vector< std::vector<SKStats> > thread_stats; std::vector<std::vector<SKStats> > thread_stats;
// node statistics // node statistics
std::vector<SKStats> node_stats; std::vector<SKStats> node_stats;
// summary array // summary array
@ -392,8 +389,14 @@ class SketchMaker: public BaseMaker {
// reducer for summary // reducer for summary
rabit::SerializeReducer<WXQSketch::SummaryContainer> sketch_reducer; rabit::SerializeReducer<WXQSketch::SummaryContainer> sketch_reducer;
// per node, per feature sketch // per node, per feature sketch
std::vector< utils::WXQuantileSketch<bst_float, bst_float> > sketchs; std::vector<common::WXQuantileSketch<bst_float, bst_float> > sketchs;
}; };
XGBOOST_REGISTER_TREE_UPDATER(SketchMaker, "grow_skmaker")
.describe("Approximate sketching maker.")
.set_body([]() {
return new SketchMaker();
});
} // namespace tree } // namespace tree
} // namespace xgboost } // namespace xgboost
#endif // XGBOOST_TREE_UPDATER_SKMAKER_INL_HPP_
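SketchMaker allocates three quantile sketches per (expanded node, feature): the "* 3" in the resize above and the "base = (wid * num_feature + fid) * 3" offset in FindSplit. Judging from the sign test on b.grad in SKStats::Add, the three slots appear to hold positive-gradient, negative-gradient, and hessian summaries separately; that reading is an inference from the hunks, not stated in the commit. A small indexing sketch with the formula copied from FindSplit; the helper name and the example numbers are illustrative.
// With num_feature = 10, node work-index wid = 2 and fid = 4:
//   base = (2 * 10 + 4) * 3 = 72, so this (node, feature) pair owns
//   sketchs[72], sketchs[73], sketchs[74].
inline unsigned SketchBase(unsigned wid, unsigned fid, unsigned num_feature) {
  return (wid * num_feature + fid) * 3;
}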