This commit is contained in:
tqchen 2014-11-19 15:28:09 -08:00
parent 55e62a7120
commit 7c3a392136
8 changed files with 136 additions and 133 deletions

View File

@ -19,7 +19,7 @@ SLIB = wrapper/libxgboostwrapper.so
.PHONY: clean all mpi python Rpack .PHONY: clean all mpi python Rpack
all: $(BIN) $(OBJ) $(SLIB) all: $(BIN) $(OBJ) $(SLIB) mpi
mpi: $(MPIBIN) mpi: $(MPIBIN)
python: wrapper/libxgboostwrapper.so python: wrapper/libxgboostwrapper.so

View File

@ -5,7 +5,7 @@ then
exit -1 exit -1
fi fi
rm -rf train.col* rm -rf train.col* *.model
k=$1 k=$1
# split the lib svm file into k subfiles # split the lib svm file into k subfiles

View File

@ -11,6 +11,10 @@
#include "../utils/io.h" #include "../utils/io.h"
#include <string> #include <string>
namespace MPI {
// forward declaration of MPI::Datatype: only the name is introduced here,
// the class definition itself is not pulled in
class Datatype;
} // namespace MPI
namespace xgboost { namespace xgboost {
/*! \brief syncrhonizer module that minimumly wraps interface of MPI */ /*! \brief syncrhonizer module that minimumly wraps interface of MPI */
namespace sync { namespace sync {
@ -62,23 +66,31 @@ void Bcast(std::string *sendrecv_data, int root);
class ReduceHandle { class ReduceHandle {
public: public:
// reduce function // reduce function
typedef void (ReduceFunction) (const void *src, void *dst, int len); typedef void (ReduceFunction) (const void *src, void *dst, int len, const MPI::Datatype &dtype);
// constructor // constructor
ReduceHandle(void); ReduceHandle(void);
// destructor // destructor
~ReduceHandle(void); ~ReduceHandle(void);
// initialize the reduce function /*!
void Init(ReduceFunction redfunc, bool commute = true); * \brief initialize the reduce function, with the type the reduce function need to deal with
*/
void Init(ReduceFunction redfunc, size_t type_n4bytes, bool commute = true);
/*! /*!
* \brief customized in-place all reduce operation * \brief customized in-place all reduce operation
* \param sendrecvbuf the in place send-recv buffer * \param sendrecvbuf the in place send-recv buffer
* \param n4bytes number of nbytes send through all reduce * \param type_n4bytes unit size of the type, in terms of 4bytes
* \param count number of elements to send
*/ */
void AllReduce(void *sendrecvbuf, size_t n4bytes); void AllReduce(void *sendrecvbuf, size_t type_n4bytes, size_t count);
/*! \return the number of bytes occupied by the type */
static int TypeSize(const MPI::Datatype &dtype);
private: private:
// handle data field // handle data field
void *handle; void *handle;
// handle to the type field
void *htype;
// the created type in 4 bytes
size_t created_type_n4bytes;
}; };
// ----- extensions for ease of use ------ // ----- extensions for ease of use ------
@ -92,7 +104,7 @@ template<typename DType>
class Reducer { class Reducer {
public: public:
Reducer(void) { Reducer(void) {
handle.Init(ReduceInner); handle.Init(ReduceInner, kUnit);
utils::Assert(sizeof(DType) % sizeof(int) == 0, "struct must be multiple of int"); utils::Assert(sizeof(DType) % sizeof(int) == 0, "struct must be multiple of int");
} }
/*! /*!
@ -102,24 +114,23 @@ class Reducer {
* \param reducer the reducer function * \param reducer the reducer function
*/ */
inline void AllReduce(DType *sendrecvbuf, size_t count) { inline void AllReduce(DType *sendrecvbuf, size_t count) {
handle.AllReduce(sendrecvbuf, count * kUnit); handle.AllReduce(sendrecvbuf, kUnit, count);
} }
private: private:
// unit size // unit size
static const size_t kUnit = sizeof(DType) / sizeof(int); static const size_t kUnit = sizeof(DType) / sizeof(int);
// inner implementation of reducer // inner implementation of reducer
inline static void ReduceInner(const void *src_, void *dst_, int len_) { inline static void ReduceInner(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) {
const int *psrc = reinterpret_cast<const int*>(src_); const int *psrc = reinterpret_cast<const int*>(src_);
int *pdst = reinterpret_cast<int*>(dst_); int *pdst = reinterpret_cast<int*>(dst_);
DType tdst, tsrc; DType tdst, tsrc;
utils::Assert(len_ % kUnit == 0, "length not divide by size"); for (size_t i = 0; i < len_; ++i) {
for (size_t i = 0; i < len_; i += kUnit) {
// use memcpy to avoid alignment issue // use memcpy to avoid alignment issue
std::memcpy(&tdst, pdst + i, sizeof(tdst)); std::memcpy(&tdst, pdst + i * kUnit, sizeof(tdst));
std::memcpy(&tsrc, psrc + i, sizeof(tsrc)); std::memcpy(&tsrc, psrc + i * kUnit, sizeof(tsrc));
tdst.Reduce(tsrc); tdst.Reduce(tsrc);
std::memcpy(pdst + i, &tdst, sizeof(tdst)); std::memcpy(pdst + i * kUnit, &tdst, sizeof(tdst));
} }
} }
// function handle // function handle
@ -135,10 +146,10 @@ class Reducer {
* (1) Save(IStream &fs) (2) Load(IStream &fs) (3) Reduce(const DType &d); * (1) Save(IStream &fs) (2) Load(IStream &fs) (3) Reduce(const DType &d);
*/ */
template<typename DType> template<typename DType>
class ComplexReducer { class SerializeReducer {
public: public:
ComplexReducer(void) { SerializeReducer(void) {
handle.Init(ReduceInner); handle.Init(ReduceInner, 0);
} }
/*! /*!
* \brief customized in-place all reduce operation * \brief customized in-place all reduce operation
@ -146,28 +157,37 @@ class ComplexReducer {
* \param max_n4byte maximum amount of memory needed in 4byte * \param max_n4byte maximum amount of memory needed in 4byte
* \param reducer the reducer function * \param reducer the reducer function
*/ */
inline void AllReduce(DType *sendrecvobj, size_t max_n4byte) { inline void AllReduce(DType *sendrecvobj, size_t max_n4byte, size_t count) {
buffer.resize(max_n4byte); buffer.resize(max_n4byte * count);
utils::MemoryFixSizeBuffer fs(BeginPtr(buffer), max_n4byte * 4); for (size_t i = 0; i < count; ++i) {
sendrecvobj->Save(fs); utils::MemoryFixSizeBuffer fs(BeginPtr(buffer) + i * max_n4byte * 4, max_n4byte * 4);
handle.AllReduce(BeginPtr(buffer), max_n4byte); sendrecvobj[i]->Save(fs);
fs.Seek(0); }
sendrecvobj->Load(fs); handle.AllReduce(BeginPtr(buffer), max_n4byte, count);
for (size_t i = 0; i < count; ++i) {
utils::MemoryFixSizeBuffer fs(BeginPtr(buffer) + i * max_n4byte * 4, max_n4byte * 4);
sendrecvobj[i]->Load(fs);
}
} }
private: private:
// unit size // unit size
// inner implementation of reducer // inner implementation of reducer
inline static void ReduceInner(const void *src_, void *dst_, int len_) { inline static void ReduceInner(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) {
utils::MemoryFixSizeBuffer fsrc((void*)(src_), len_); int nbytes = ReduceHandle::TypeSize(dtype);
utils::MemoryFixSizeBuffer fdst(dst_, len_);
// temp space // temp space
DType tsrc, tdst; DType tsrc, tdst;
tsrc.Load(fsrc); tdst.Load(fdst); for (int i = 0; i < len_; ++i) {
utils::MemoryFixSizeBuffer fsrc((void*)(src_) + i * nbytes, nbytes);
utils::MemoryFixSizeBuffer fdst(dst_ + i * nbytes, nbytes);
tsrc.Load(fsrc);
tdst.Load(fdst);
// govern const check // govern const check
tdst.Reduce(static_cast<const DType &>(tsrc)); tdst.Reduce(static_cast<const DType &>(tsrc));
fdst.Seek(0);
tdst.Save(fdst); tdst.Save(fdst);
} }
}
// function handle // function handle
ReduceHandle handle; ReduceHandle handle;
// reduce buffer // reduce buffer

View File

@ -38,8 +38,8 @@ void Bcast(std::string *sendrecv_data, int root) {
ReduceHandle::ReduceHandle(void) : handle(NULL) {} ReduceHandle::ReduceHandle(void) : handle(NULL) {}
ReduceHandle::~ReduceHandle(void) {} ReduceHandle::~ReduceHandle(void) {}
void ReduceHandle::Init(ReduceFunction redfunc, bool commute) {} void ReduceHandle::Init(ReduceFunction redfunc, size_t type_n4bytes, bool commute) {}
void ReduceHandle::AllReduce(void *sendrecvbuf, size_t n4byte) {} void ReduceHandle::AllReduce(void *sendrecvbuf, size_t type_n4bytes, size_t n4byte) {}
} // namespace sync } // namespace sync
} // namespace xgboost } // namespace xgboost

View File

@ -1,6 +1,7 @@
#include "./sync.h" #include "./sync.h"
#include "../utils/utils.h" #include "../utils/utils.h"
#include "mpi.h" #include <mpi.h>
// use MPI to implement sync // use MPI to implement sync
namespace xgboost { namespace xgboost {
namespace sync { namespace sync {
@ -60,7 +61,7 @@ void Bcast(std::string *sendrecv_data, int root) {
} }
// code for reduce handle // code for reduce handle
ReduceHandle::ReduceHandle(void) : handle(NULL) { ReduceHandle::ReduceHandle(void) : handle(NULL), htype(NULL) {
} }
ReduceHandle::~ReduceHandle(void) { ReduceHandle::~ReduceHandle(void) {
if (handle != NULL) { if (handle != NULL) {
@ -68,19 +69,42 @@ ReduceHandle::~ReduceHandle(void) {
op->Free(); op->Free();
delete op; delete op;
} }
if (htype != NULL) {
MPI::Datatype *dtype = reinterpret_cast<MPI::Datatype*>(htype);
dtype->Free();
delete dtype;
} }
void ReduceHandle::Init(ReduceFunction redfunc, bool commute) { }
/*! \return the number of bytes occupied by the type, as reported by dtype.Get_size() */
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
  const int nbytes = dtype.Get_size();
  return nbytes;
}
void ReduceHandle::Init(ReduceFunction redfunc, size_t type_n4bytes, bool commute) {
utils::Assert(handle == NULL, "cannot initialize reduce handle twice"); utils::Assert(handle == NULL, "cannot initialize reduce handle twice");
if (type_n4bytes != 0) {
MPI::Datatype *dtype = new MPI::Datatype();
*dtype = MPI::INT.Create_contiguous(type_n4bytes);
dtype->Commit();
created_type_n4bytes = type_n4bytes;
htype = dtype;
}
MPI::Op *op = new MPI::Op(); MPI::Op *op = new MPI::Op();
MPI::User_function *pf = reinterpret_cast<MPI::User_function*>(redfunc); MPI::User_function *pf = redfunc;
op->Init(pf, commute); op->Init(pf, commute);
handle = op; handle = op;
} }
void ReduceHandle::AllReduce(void *sendrecvbuf, size_t n4byte) { void ReduceHandle::AllReduce(void *sendrecvbuf, size_t type_n4bytes, size_t count) {
utils::Assert(handle != NULL, "must intialize handle to call AllReduce"); utils::Assert(handle != NULL, "must intialize handle to call AllReduce");
MPI::Op *op = reinterpret_cast<MPI::Op*>(handle); MPI::Op *op = reinterpret_cast<MPI::Op*>(handle);
MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf, n4byte, MPI_INT, *op); MPI::Datatype *dtype = reinterpret_cast<MPI::Datatype*>(htype);
}
// (Re)build the contiguous MPI datatype when none exists yet or when the
// requested unit size differs from the one it was built for.
// The NULL test comes first: Init() leaves htype NULL when it is called with
// type_n4bytes == 0, and in that case created_type_n4bytes has never been set,
// so it must not be read and dtype must not be dereferenced.
if (dtype == NULL || created_type_n4bytes != type_n4bytes) {
  if (dtype == NULL) {
    // first use: allocate the holder and record it so the destructor frees it
    dtype = new MPI::Datatype();
    htype = dtype;
  } else {
    // release the previously committed type before replacing it
    dtype->Free();
  }
  *dtype = MPI::INT.Create_contiguous(type_n4bytes);
  dtype->Commit();
  created_type_n4bytes = type_n4bytes;
}
MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf, count, *dtype, *op);
}
} // namespace sync } // namespace sync
} // namespace xgboost } // namespace xgboost

View File

@ -7,7 +7,7 @@
#include "./updater_refresh-inl.hpp" #include "./updater_refresh-inl.hpp"
#include "./updater_colmaker-inl.hpp" #include "./updater_colmaker-inl.hpp"
#include "./updater_distcol-inl.hpp" #include "./updater_distcol-inl.hpp"
#include "./updater_skmaker-inl.hpp" //#include "./updater_skmaker-inl.hpp"
#include "./updater_histmaker-inl.hpp" #include "./updater_histmaker-inl.hpp"
namespace xgboost { namespace xgboost {
@ -18,8 +18,8 @@ IUpdater* CreateUpdater(const char *name) {
if (!strcmp(name, "sync")) return new TreeSyncher(); if (!strcmp(name, "sync")) return new TreeSyncher();
if (!strcmp(name, "refresh")) return new TreeRefresher<GradStats>(); if (!strcmp(name, "refresh")) return new TreeRefresher<GradStats>();
if (!strcmp(name, "grow_colmaker")) return new ColMaker<GradStats>(); if (!strcmp(name, "grow_colmaker")) return new ColMaker<GradStats>();
if (!strcmp(name, "grow_histmaker")) return new CQHistMaker<GradStats>(); //if (!strcmp(name, "grow_histmaker")) return new CQHistMaker<GradStats>();
if (!strcmp(name, "grow_skmaker")) return new SketchMaker(); //if (!strcmp(name, "grow_skmaker")) return new SketchMaker();
if (!strcmp(name, "distcol")) return new DistColMaker<GradStats>(); if (!strcmp(name, "distcol")) return new DistColMaker<GradStats>();
utils::Error("unknown updater:%s", name); utils::Error("unknown updater:%s", name);

View File

@ -306,6 +306,7 @@ class CQHistMaker: public HistMaker<TStats> {
hist.data[istart].Add(gpair, info, ridx); hist.data[istart].Add(gpair, info, ridx);
} }
}; };
typedef utils::WXQuantileSketch<bst_float, bst_float> WXQSketch;
virtual void CreateHist(const std::vector<bst_gpair> &gpair, virtual void CreateHist(const std::vector<bst_gpair> &gpair,
IFMatrix *p_fmat, IFMatrix *p_fmat,
const BoosterInfo &info, const BoosterInfo &info,
@ -371,21 +372,22 @@ class CQHistMaker: public HistMaker<TStats> {
// setup maximum size // setup maximum size
unsigned max_size = this->param.max_sketch_size(); unsigned max_size = this->param.max_sketch_size();
// synchronize sketch // synchronize sketch
summary_array.Init(sketchs.size(), max_size); summary_array.resize(sketchs.size());
for (size_t i = 0; i < sketchs.size(); ++i) { for (size_t i = 0; i < sketchs.size(); ++i) {
utils::WXQuantileSketch<bst_float, bst_float>::SummaryContainer out; utils::WXQuantileSketch<bst_float, bst_float>::SummaryContainer out;
sketchs[i].GetSummary(&out); sketchs[i].GetSummary(&out);
summary_array.Set(i, out); summary_array[i].Reserve(max_size);
summary_array[i].SetPrune(out, max_size);
} }
size_t n4bytes = (summary_array.MemSize() + 3) / 4; size_t n4bytes = (WXQSketch::SummaryContainer::CalcMemCost(max_size) + 3) / 4;
sreducer.AllReduce(&summary_array, n4bytes); sreducer.AllReduce(BeginPtr(summary_array), n4bytes, summary_array.size());
// now we get the final result of sketch, setup the cut // now we get the final result of sketch, setup the cut
this->wspace.cut.clear(); this->wspace.cut.clear();
this->wspace.rptr.clear(); this->wspace.rptr.clear();
this->wspace.rptr.push_back(0); this->wspace.rptr.push_back(0);
for (size_t wid = 0; wid < this->qexpand.size(); ++wid) { for (size_t wid = 0; wid < this->qexpand.size(); ++wid) {
for (int fid = 0; fid < tree.param.num_feature; ++fid) { for (int fid = 0; fid < tree.param.num_feature; ++fid) {
const WXQSketch::Summary a = summary_array[wid * tree.param.num_feature + fid]; const WXQSketch::Summary &a = summary_array[wid * tree.param.num_feature + fid];
for (size_t i = 1; i < a.size; ++i) { for (size_t i = 1; i < a.size; ++i) {
bst_float cpt = a.data[i].value - rt_eps; bst_float cpt = a.data[i].value - rt_eps;
if (i == 1 || cpt > this->wspace.cut.back()) { if (i == 1 || cpt > this->wspace.cut.back()) {
@ -496,7 +498,6 @@ class CQHistMaker: public HistMaker<TStats> {
} }
} }
typedef utils::WXQuantileSketch<bst_float, bst_float> WXQSketch;
// thread temp data // thread temp data
std::vector< std::vector<BaseMaker::SketchEntry> > thread_sketch; std::vector< std::vector<BaseMaker::SketchEntry> > thread_sketch;
// used to hold statistics // used to hold statistics
@ -506,9 +507,9 @@ class CQHistMaker: public HistMaker<TStats> {
// node statistics // node statistics
std::vector<TStats> node_stats; std::vector<TStats> node_stats;
// summary array // summary array
WXQSketch::SummaryArray summary_array; std::vector< WXQSketch::SummaryContainer> summary_array;
// reducer for summary // reducer for summary
sync::ComplexReducer<WXQSketch::SummaryArray> sreducer; sync::SerializeReducer<WXQSketch::SummaryContainer> sreducer;
// per node, per feature sketch // per node, per feature sketch
std::vector< utils::WXQuantileSketch<bst_float, bst_float> > sketchs; std::vector< utils::WXQuantileSketch<bst_float, bst_float> > sketchs;
}; };
@ -580,23 +581,24 @@ class QuantileHistMaker: public HistMaker<TStats> {
} }
} }
// setup maximum size // setup maximum size
size_t max_size = static_cast<size_t>(this->param.sketch_ratio / this->param.sketch_eps); unsigned max_size = this->param.max_sketch_size();
// synchronize sketch // synchronize sketch
summary_array.Init(sketchs.size(), max_size); summary_array.resize(sketchs.size());
for (size_t i = 0; i < sketchs.size(); ++i) { for (size_t i = 0; i < sketchs.size(); ++i) {
utils::WQuantileSketch<bst_float, bst_float>::SummaryContainer out; utils::WQuantileSketch<bst_float, bst_float>::SummaryContainer out;
sketchs[i].GetSummary(&out); sketchs[i].GetSummary(&out);
summary_array.Set(i, out); summary_array[i].Reserve(max_size);
summary_array[i].SetPrune(out, max_size);
} }
size_t n4bytes = (summary_array.MemSize() + 3) / 4; size_t n4bytes = (WXQSketch::SummaryContainer::CalcMemCost(max_size) + 3) / 4;
sreducer.AllReduce(&summary_array, n4bytes); sreducer.AllReduce(BeginPtr(summary_array), n4bytes, summary_array.size());
// now we get the final result of sketch, setup the cut // now we get the final result of sketch, setup the cut
this->wspace.cut.clear(); this->wspace.cut.clear();
this->wspace.rptr.clear(); this->wspace.rptr.clear();
this->wspace.rptr.push_back(0); this->wspace.rptr.push_back(0);
for (size_t wid = 0; wid < this->qexpand.size(); ++wid) { for (size_t wid = 0; wid < this->qexpand.size(); ++wid) {
for (int fid = 0; fid < tree.param.num_feature; ++fid) { for (int fid = 0; fid < tree.param.num_feature; ++fid) {
const WXQSketch::Summary a = summary_array[wid * tree.param.num_feature + fid]; const WXQSketch::Summary &a = summary_array[wid * tree.param.num_feature + fid];
for (size_t i = 1; i < a.size; ++i) { for (size_t i = 1; i < a.size; ++i) {
bst_float cpt = a.data[i].value - rt_eps; bst_float cpt = a.data[i].value - rt_eps;
if (i == 1 || cpt > this->wspace.cut.back()) { if (i == 1 || cpt > this->wspace.cut.back()) {
@ -624,9 +626,9 @@ class QuantileHistMaker: public HistMaker<TStats> {
private: private:
typedef utils::WXQuantileSketch<bst_float, bst_float> WXQSketch; typedef utils::WXQuantileSketch<bst_float, bst_float> WXQSketch;
// summary array // summary array
WXQSketch::SummaryArray summary_array; std::vector< WXQSketch::SummaryContainer> summary_array;
// reducer for summary // reducer for summary
sync::ComplexReducer<WXQSketch::SummaryArray> sreducer; sync::SerializeReducer<WXQSketch::SummaryContainer> sreducer;
// local temp column data structure // local temp column data structure
std::vector<size_t> col_ptr; std::vector<size_t> col_ptr;
// local storage of column data // local storage of column data

View File

@ -224,6 +224,12 @@ struct WQSummary {
*/ */
inline void SetCombine(const WQSummary &sa, inline void SetCombine(const WQSummary &sa,
const WQSummary &sb) { const WQSummary &sb) {
if (sa.size == 0) {
this->CopyFrom(sb); return;
}
if (sb.size == 0) {
this->CopyFrom(sa); return;
}
utils::Assert(sa.size > 0 && sb.size > 0, "invalid input for merge"); utils::Assert(sa.size > 0 && sb.size > 0, "invalid input for merge");
const Entry *a = sa.data, *a_end = sa.data + sa.size; const Entry *a = sa.data, *a_end = sa.data + sa.size;
const Entry *b = sb.data, *b_end = sb.data + sb.size; const Entry *b = sb.data, *b_end = sb.data + sb.size;
@ -453,6 +459,12 @@ struct GKSummary {
} }
inline void SetCombine(const GKSummary &sa, inline void SetCombine(const GKSummary &sa,
const GKSummary &sb) { const GKSummary &sb) {
if (sa.size == 0) {
this->CopyFrom(sb); return;
}
if (sb.size == 0) {
this->CopyFrom(sa); return;
}
utils::Assert(sa.size > 0 && sb.size > 0, "invalid input for merge"); utils::Assert(sa.size > 0 && sb.size > 0, "invalid input for merge");
const Entry *a = sa.data, *a_end = sa.data + sa.size; const Entry *a = sa.data, *a_end = sa.data + sa.size;
const Entry *b = sb.data, *b_end = sb.data + sb.size; const Entry *b = sb.data, *b_end = sb.data + sb.size;
@ -544,89 +556,34 @@ class QuantileSketchTemplate {
this->SetCombine(lhs, rhs); this->SetCombine(lhs, rhs);
} }
} }
};
/*!
* \brief represent an array of summary
* each contains fixed maximum size summary
*/
class SummaryArray {
public:
/*!
* \brief intialize the SummaryArray
* \param num_summary number of summary in the array
* \param max_size maximum number of elements in each summary
*/
inline void Init(unsigned num_summary, unsigned max_size) {
this->num_summary = num_summary;
this->max_size = max_size;
sizes.resize(num_summary);
data.resize(num_summary * max_size);
}
/*!
* \brief set i-th element of array to be the src summary,
* the summary can be pruned if it does not fit into max_size
* \param the index in the array
* \param src the source summary
* \tparam the type if source summary
*/
template<typename TSrc>
inline void Set(size_t i, const TSrc &src) {
Summary dst = (*this)[i];
dst.SetPrune(src, max_size);
this->sizes[i] = dst.size;
}
/*!
* \brief get i-th summary of the array, only use this for read purpose
*/
inline const Summary operator[](size_t i) const {
return Summary((Entry*)BeginPtr(data) + i * max_size, sizes[i]);
}
/*! /*!
* \brief do elementwise combination of summary array * \brief do elementwise combination of summary array
* this[i] = combine(this[i], src[i]) for each i * this[i] = combine(this[i], src[i]) for each i
* \param src the source summary * \param src the source summary
* \param max_nbyte, maximum number of byte allowed in here
*/ */
inline void Reduce(const SummaryArray &src) { inline void Reduce(const Summary &src, size_t max_nbyte) {
utils::Check(num_summary == src.num_summary && this->Reserve((max_nbyte - sizeof(this->size)) / sizeof(Entry));
max_size == src.max_size, "array shape mismatch in reduce");
SummaryContainer temp; SummaryContainer temp;
temp.Reserve(max_size * 2); temp.Reserve(this->size + src.size);
for (unsigned i = 0; i < num_summary; ++i) { temp.SetCombine(*this, src);
temp.SetCombine((*this)[i], src[i]); this->SetPrune(temp, space.size());
this->Set(i, temp);
}
} }
/*! \brief return the number of bytes this data structure cost in serialization */ /*! \brief return the number of bytes this data structure cost in serialization */
inline size_t MemSize(void) const { inline static size_t CalcMemCost(size_t nentry) {
return sizeof(num_summary) + sizeof(max_size) return sizeof(size_t) + sizeof(Entry) * nentry;
+ data.size() * sizeof(Entry) + sizes.size() * sizeof(unsigned);
} }
/*! \brief save the data structure into stream */ /*! \brief save the data structure into stream */
inline void Save(IStream &fo) const { inline void Save(IStream &fo) const {
fo.Write(&num_summary, sizeof(num_summary)); fo.Write(&(this->size), sizeof(this->size));
fo.Write(&max_size, sizeof(max_size)); fo.Write(data, this->size * sizeof(Entry));
fo.Write(BeginPtr(sizes), sizes.size() * sizeof(unsigned));
fo.Write(BeginPtr(data), data.size() * sizeof(Entry));
} }
/*! \brief load data structure from input stream */ /*! \brief load data structure from input stream */
inline void Load(IStream &fi) { inline void Load(IStream &fi) {
utils::Check(fi.Read(&num_summary, sizeof(num_summary)) != 0, "invalid SummaryArray"); utils::Check(fi.Read(&this->size, sizeof(this->size)) != 0, "invalid SummaryArray 1");
utils::Check(fi.Read(&max_size, sizeof(max_size)) != 0, "invalid SummaryArray"); this->Reserve(this->size);
sizes.resize(num_summary); utils::Check(fi.Read(data, this->size * sizeof(Entry)) != 0, "invalid SummaryArray 2");
data.resize(num_summary * max_size);
utils::Check(fi.Read(BeginPtr(sizes), sizes.size() * sizeof(unsigned)) != 0, "invalid SummaryArray");
utils::Check(fi.Read(BeginPtr(data), data.size() * sizeof(Entry)) != 0, "invalid SummaryArray");
} }
private:
/*! \brief number of summaries in the group */
unsigned num_summary;
/*! \brief maximum size of each summary */
unsigned max_size;
/*! \brief the current size of each summary */
std::vector<unsigned> sizes;
/*! \brief the data content */
std::vector<Entry> data;
}; };
/*! /*!
* \brief intialize the quantile sketch, given the performance specification * \brief intialize the quantile sketch, given the performance specification