diff --git a/src/sync/sync.h b/src/sync/sync.h
index 8d83ab5fb..fe34983ef 100644
--- a/src/sync/sync.h
+++ b/src/sync/sync.h
@@ -8,6 +8,7 @@
 #include
 #include
 #include "../utils/utils.h"
+#include "../utils/io.h"
 #include
 
 namespace xgboost {
@@ -125,6 +126,64 @@ class Reducer {
   ReduceHandle handle;
 };
 
+/*!
+ * \brief template class to make customized reduce, complex reducer handles all the data structure that can be
+ *  serialized/deserialized into fixed size buffer
+ *  Do not use reducer directly in the function you call Finalize, because the destructor can happen after Finalize
+ *
+ * \tparam DType data type that to be reduced, DType must contain following functions:
+ *  (1) Save(IStream &fs)  (2) Load(IStream &fs) (3) Reduce(const DType &d);
+ */
+template<typename DType>
+class ComplexReducer {
+ public:
+  /*! \brief constructor, hands ReduceInner to the reduce handle */
+  ComplexReducer(void) {
+    handle.Init(ReduceInner);
+  }
+  /*!
+   * \brief customized in-place all reduce operation:
+   *  the object is saved into the internal buffer, the buffer is reduced through the handle,
+   *  and the object is loaded back from the start of the buffer
+   * \param sendrecvobj pointer to the object to be reduced
+   * \param max_n4byte maximum amount of memory needed in 4byte
+   */
+  inline void AllReduce(DType *sendrecvobj, size_t max_n4byte) {
+    buffer.resize(max_n4byte);
+    utils::MemoryFixSizeBuffer fs(BeginPtr(buffer), max_n4byte * 4);
+    sendrecvobj->Save(fs);
+    handle.AllReduce(BeginPtr(buffer), max_n4byte);
+    fs.Seek(0);
+    sendrecvobj->Load(fs);
+  }
+
+ private:
+  /*!
+   * \brief inner implementation of reducer: loads one object from each buffer,
+   *  reduces the source object into the destination object and saves the result into dst_
+   *  NOTE(review): len_ is used as the byte size of both buffers, while AllReduce passes
+   *  max_n4byte (a count of 4-byte units) to the handle; confirm the unit ReduceHandle reports
+   */
+  inline static void ReduceInner(const void *src_, void *dst_, int len_) {
+    utils::MemoryFixSizeBuffer fsrc(const_cast<void*>(src_), len_);
+    utils::MemoryFixSizeBuffer fdst(dst_, len_);
+    // temp space
+    DType tsrc, tdst;
+    tsrc.Load(fsrc); tdst.Load(fdst);
+    // govern const check
+    tdst.Reduce(static_cast<const DType &>(tsrc));
+    // loading advanced the position of fdst, rewind it so that the result is
+    // written from the start of dst_ instead of behind the bytes just read
+    fdst.Seek(0);
+    tdst.Save(fdst);
+  }
+  // function handle
+  ReduceHandle handle;
+  // reduce buffer, each element is one 4-byte unit
+  // NOTE(review): the element type was lost in this copy of the patch, int assumed; confirm
+  std::vector<int> buffer;
+};
+
 }  // namespace sync
 }  // namespace xgboost
 #endif
diff --git a/src/tree/param.h b/src/tree/param.h
index 47d31df1e..6402ef76a 100644
--- 
a/src/tree/param.h
+++ b/src/tree/param.h
@@ -38,6 +38,8 @@ struct TrainParam{
   float opt_dense_col;
   // accuracy of sketch
   float sketch_eps;
+  // size multiplier of sketch: a synchronized summary keeps at most sketch_ratio / sketch_eps entries
+  float sketch_ratio;
   // leaf vector size
   int size_leaf_vector;
   // option for parallelization
@@ -61,6 +63,7 @@ struct TrainParam{
     size_leaf_vector = 0;
     parallel_option = 2;
     sketch_eps = 0.1f;
+    sketch_ratio = 1.4f;
   }
   /*!
    * \brief set parameters from outside
@@ -83,6 +86,7 @@ struct TrainParam{
     if (!strcmp(name, "colsample_bylevel")) colsample_bylevel = static_cast(atof(val));
     if (!strcmp(name, "colsample_bytree")) colsample_bytree = static_cast(atof(val));
     if (!strcmp(name, "sketch_eps")) sketch_eps = static_cast(atof(val));
+    if (!strcmp(name, "sketch_ratio")) sketch_ratio = static_cast<float>(atof(val));
     if (!strcmp(name, "opt_dense_col")) opt_dense_col = static_cast(atof(val));
     if (!strcmp(name, "size_leaf_vector")) size_leaf_vector = atoi(val);
     if (!strcmp(name, "max_depth")) max_depth = atoi(val);
diff --git a/src/tree/updater_histmaker-inl.hpp b/src/tree/updater_histmaker-inl.hpp
index 40c4a5497..97e4d0aea 100644
--- a/src/tree/updater_histmaker-inl.hpp
+++ b/src/tree/updater_histmaker-inl.hpp
@@ -124,8 +124,7 @@ class HistMaker: public IUpdater {
   /*! 
\brief map active node to is working index offset in qexpand*/
   std::vector node2workindex;
   // reducer for histogram
-  sync::Reducer histred;
-
+  sync::Reducer histred;
   // helper function to get to next level of the tree
   // must work on non-leaf node
   inline static int NextLevel(const SparseBatch::Inst &inst, const RegTree &tree, int nid) {
@@ -142,7 +141,6 @@ class HistMaker: public IUpdater {
     }
     return n.cdefault();
   }
-
   // this function does two jobs
   // (1) reset the position in array position, to be the latest leaf id
   // (2) propose a set of candidate cuts and set wspace.rptr wspace.cut correctly
@@ -416,15 +414,40 @@ class QuantileHistMaker: public HistMaker {
         }
       }
     }
+    // setup maximum size: number of entries kept per synchronized summary
+    size_t max_size = static_cast<size_t>(this->param.sketch_ratio / this->param.sketch_eps);
     // synchronize sketch
-
-
-    // now we have all the results in the sketchs, try to setup the cut point
+    summary_array.Init(sketchs.size(), max_size);
+    // NOTE(review): nothing visible in this hunk copies the local sketchs into summary_array
+    // between Init and AllReduce (no SummaryArray::Set call); confirm the summaries are filled
+    size_t n4bytes = (summary_array.MemSize() + 3) / 4;
+    sreducer.AllReduce(&summary_array, n4bytes);
+    // now we get the final result of sketch, setup the cut
+    for (size_t wid = 0; wid < this->qexpand.size(); ++wid) {
+      for (size_t fid = 0; fid < tree.param.num_feature; ++fid) {
+        const WXQSketch::Summary a = summary_array[wid * tree.param.num_feature + fid];
+        for (size_t i = 0; i < a.size; ++i) {
+          bst_float cpt = a.data[i].value + rt_eps;
+          if (i == 0 || cpt > this->wspace.cut.back()) {
+            this->wspace.cut.push_back(cpt);
+          }
+        }
+        this->wspace.rptr.push_back(this->wspace.cut.size());
+      }
+      // reserve last value for global statistics
+      this->wspace.cut.push_back(0.0f);
+      this->wspace.rptr.push_back(this->wspace.cut.size());
+    }
+    utils::Assert(this->wspace.rptr.size() ==
+                  (tree.param.num_feature + 1) * this->qexpand.size(), "cut space inconsistent");
   }
 
  private:
-  //
-
+  typedef utils::WXQuantileSketch<bst_float, bst_float> WXQSketch;  // NOTE(review): template arguments inferred, confirm
+  // summary array, holds one fixed size summary per sketch
+  WXQSketch::SummaryArray summary_array;
+  // reducer for summary
+  sync::ComplexReducer<WXQSketch::SummaryArray> sreducer;
   // local temp 
column data structure
   std::vector col_ptr;
   // local storage of column data
diff --git a/src/utils/io.h b/src/utils/io.h
index 7dd550dc8..1a748feab 100644
--- a/src/utils/io.h
+++ b/src/utils/io.h
@@ -97,6 +97,61 @@ class ISeekStream: public IStream {
   virtual size_t Tell(void) = 0;
 };
 
+/*!
+ * \brief fixed size memory buffer: stream view over a caller provided memory region,
+ *  writes must fit into the region and the region is not released by this class
+ */
+struct MemoryFixSizeBuffer : public ISeekStream {
+ public:
+  /*!
+   * \brief constructor, the position starts at 0
+   * \param p_buffer start of the memory region to read from / write to
+   * \param buffer_size size of the region in bytes
+   */
+  MemoryFixSizeBuffer(void *p_buffer, size_t buffer_size)
+      : p_buffer_(reinterpret_cast<char*>(p_buffer)), buffer_size_(buffer_size) {
+    curr_ptr_ = 0;
+  }
+  virtual ~MemoryFixSizeBuffer(void) {}
+  /*!
+   * \brief read at most size bytes from the current position
+   * \return number of bytes copied into ptr, smaller than size when the end of buffer is reached
+   */
+  virtual size_t Read(void *ptr, size_t size) {
+    utils::Assert(curr_ptr_ <= buffer_size_,
+                  "read can not have position excceed buffer length");
+    size_t nread = std::min(buffer_size_ - curr_ptr_, size);
+    if (nread != 0) memcpy(ptr, p_buffer_ + curr_ptr_, nread);
+    curr_ptr_ += nread;
+    return nread;
+  }
+  /*! \brief write size bytes at the current position, the write must fit into the buffer */
+  virtual void Write(const void *ptr, size_t size) {
+    if (size == 0) return;
+    // compare against the remaining space, curr_ptr_ + size could wrap around
+    utils::Assert(curr_ptr_ <= buffer_size_ && size <= buffer_size_ - curr_ptr_,
+                  "write position exceed fixed buffer size");
+    memcpy(p_buffer_ + curr_ptr_, ptr, size);
+    curr_ptr_ += size;
+  }
+  /*! \brief move the current position to pos, the bound is checked by the next Read/Write */
+  virtual void Seek(size_t pos) {
+    curr_ptr_ = static_cast<size_t>(pos);
+  }
+  /*! \brief get the current position */
+  virtual size_t Tell(void) {
+    return curr_ptr_;
+  }
+
+ private:
+  /*! \brief in memory buffer */
+  char *p_buffer_;
+  /*! \brief size of the buffer in bytes */
+  size_t buffer_size_;
+  /*! \brief current pointer */
+  size_t curr_ptr_;
+};  // class MemoryFixSizeBuffer
+
 /*! \brief a in memory buffer that can be read and write as stream interface */
 struct MemoryBufferStream : public ISeekStream {
  public:
diff --git a/src/utils/quantile.h b/src/utils/quantile.h
index 62dc36e6c..c27fa9bfe 100644
--- a/src/utils/quantile.h
+++ b/src/utils/quantile.h
@@ -515,13 +515,95 @@ class QuantileSketchTemplate {
       }
     }
   };
+  /*!
+   * \brief represent an array of summary
+   *  each contains fixed maximum size summary
+   */
+  class SummaryArray {
+   public:
+    /*! 
+     * \brief initialize the SummaryArray, every summary starts with size 0
+     * \param num_summary number of summary in the array
+     * \param max_size maximum number of elements in each summary
+     */
+    inline void Init(unsigned num_summary, unsigned max_size) {
+      this->num_summary = num_summary;
+      this->max_size = max_size;
+      sizes.assign(num_summary, 0);  // drop sizes of an earlier use, they do not match the new layout
+      data.resize(static_cast<size_t>(num_summary) * max_size);  // widen first, avoid 32-bit wrap
+    }
+    /*!
+     * \brief set i-th element of array to be the src summary,
+     *  the summary can be pruned if it does not fit into max_size
+     * \param i the index in the array
+     * \param src the source summary
+     * \tparam TSrc the type of source summary
+     */
+    template<typename TSrc>
+    inline void Set(size_t i, const TSrc &src) {
+      Summary dst = (*this)[i];
+      dst.SetPrune(src, max_size);
+      this->sizes[i] = dst.size;
+    }
+    /*!
+     * \brief get i-th summary of the array (entries start at i * max_size), only use this for read purpose
+     */
+    inline const Summary operator[](size_t i) const {
+      return Summary(const_cast<Entry*>(BeginPtr(data)) + i * max_size, sizes[i]);
+    }
+    /*!
+     * \brief do elementwise combination of summary array
+     *  this[i] = combine(this[i], src[i]) for each i
+     * \param src the source summary array, must have the same num_summary and max_size
+     */
+    inline void Reduce(const SummaryArray &src) {
+      utils::Check(num_summary == src.num_summary &&
+                   max_size == src.max_size, "array shape mismatch in reduce");
+      SummaryContainer temp;
+      temp.Reserve(max_size * 2);
+      for (unsigned i = 0; i < num_summary; ++i) {
+        temp.SetCombine((*this)[i], src[i]);
+        this->Set(i, temp);
+      }
+    }
+    /*! \brief return the number of bytes this data structure cost in serialization */
+    inline size_t MemSize(void) const {
+      return sizeof(num_summary) + sizeof(max_size)
+          + data.size() * sizeof(Entry) + sizes.size() * sizeof(unsigned);
+    }
+    /*! \brief save the data structure into stream: num_summary, max_size, sizes, then data */
+    inline void Save(IStream &fo) const {
+      fo.Write(&num_summary, sizeof(num_summary));
+      fo.Write(&max_size, sizeof(max_size));
+      fo.Write(BeginPtr(sizes), sizes.size() * sizeof(unsigned));
+      fo.Write(BeginPtr(data), data.size() * sizeof(Entry));
+    }
+    /*! 
\brief load data structure from input stream, an array without entries is accepted */
+    inline void Load(IStream &fi) {
+      utils::Check(fi.Read(&num_summary, sizeof(num_summary)) != 0, "invalid SummaryArray");
+      utils::Check(fi.Read(&max_size, sizeof(max_size)) != 0, "invalid SummaryArray");
+      sizes.resize(num_summary);
+      data.resize(static_cast<size_t>(num_summary) * max_size);  // widen first, avoid 32-bit wrap
+      utils::Check(sizes.empty() || fi.Read(BeginPtr(sizes), sizes.size() * sizeof(unsigned)) != 0, "invalid SummaryArray");
+      utils::Check(data.empty() || fi.Read(BeginPtr(data), data.size() * sizeof(Entry)) != 0, "invalid SummaryArray");
+    }
+
+   private:
+    /*! \brief number of summaries in the group */
+    unsigned num_summary;
+    /*! \brief maximum size of each summary */
+    unsigned max_size;
+    /*! \brief the current size of each summary */
+    std::vector<unsigned> sizes;
+    /*! \brief the data content, summary i starts at entry i * max_size */
+    std::vector<Entry> data;
+  };
   /*!
    * \brief intialize the quantile sketch, given the performance specification
    * \param maxn maximum number of data points can be feed into sketch
    * \param eps accuracy level of summary
    */
   inline void Init(size_t maxn, double eps) {
-    //nlevel = std::max(log2(ceil(maxn * eps)) - 2.0, 1.0);
     nlevel = 1;
     while (true) {
       limit_size = ceil(nlevel / eps) + 1;