compile
This commit is contained in:
parent
55e62a7120
commit
7c3a392136
2
Makefile
2
Makefile
@ -19,7 +19,7 @@ SLIB = wrapper/libxgboostwrapper.so
|
|||||||
|
|
||||||
.PHONY: clean all mpi python Rpack
|
.PHONY: clean all mpi python Rpack
|
||||||
|
|
||||||
all: $(BIN) $(OBJ) $(SLIB)
|
all: $(BIN) $(OBJ) $(SLIB) mpi
|
||||||
mpi: $(MPIBIN)
|
mpi: $(MPIBIN)
|
||||||
|
|
||||||
python: wrapper/libxgboostwrapper.so
|
python: wrapper/libxgboostwrapper.so
|
||||||
|
|||||||
@ -5,7 +5,7 @@ then
|
|||||||
exit -1
|
exit -1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
rm -rf train.col*
|
rm -rf train.col* *.model
|
||||||
k=$1
|
k=$1
|
||||||
|
|
||||||
# split the lib svm file into k subfiles
|
# split the lib svm file into k subfiles
|
||||||
|
|||||||
@ -11,6 +11,10 @@
|
|||||||
#include "../utils/io.h"
|
#include "../utils/io.h"
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
namespace MPI {
|
||||||
|
// forward declaration of MPI::Datatype; the full definition is deliberately not included here
|
||||||
|
class Datatype;
|
||||||
|
};
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
/*! \brief synchronizer module that minimally wraps the interface of MPI */
|
/*! \brief synchronizer module that minimally wraps the interface of MPI */
|
||||||
namespace sync {
|
namespace sync {
|
||||||
@ -62,23 +66,31 @@ void Bcast(std::string *sendrecv_data, int root);
|
|||||||
class ReduceHandle {
|
class ReduceHandle {
|
||||||
public:
|
public:
|
||||||
// reduce function
|
// reduce function
|
||||||
typedef void (ReduceFunction) (const void *src, void *dst, int len);
|
typedef void (ReduceFunction) (const void *src, void *dst, int len, const MPI::Datatype &dtype);
|
||||||
// constructor
|
// constructor
|
||||||
ReduceHandle(void);
|
ReduceHandle(void);
|
||||||
// destructor
|
// destructor
|
||||||
~ReduceHandle(void);
|
~ReduceHandle(void);
|
||||||
// initialize the reduce function
|
/*!
|
||||||
void Init(ReduceFunction redfunc, bool commute = true);
|
* \brief initialize the reduce function, together with the type the reduce function needs to deal with
|
||||||
|
*/
|
||||||
|
void Init(ReduceFunction redfunc, size_t type_n4bytes, bool commute = true);
|
||||||
/*!
|
/*!
|
||||||
* \brief customized in-place all reduce operation
|
* \brief customized in-place all reduce operation
|
||||||
* \param sendrecvbuf the in place send-recv buffer
|
* \param sendrecvbuf the in place send-recv buffer
|
||||||
* \param n4bytes number of nbytes send through all reduce
|
* \param type_n4bytes unit size of the type, in terms of 4bytes
|
||||||
|
* \param count number of elements to send
|
||||||
*/
|
*/
|
||||||
void AllReduce(void *sendrecvbuf, size_t n4bytes);
|
void AllReduce(void *sendrecvbuf, size_t type_n4bytes, size_t count);
|
||||||
|
/*! \return the number of bytes occupied by the type */
|
||||||
|
static int TypeSize(const MPI::Datatype &dtype);
|
||||||
private:
|
private:
|
||||||
// handle data field
|
// handle data field
|
||||||
void *handle;
|
void *handle;
|
||||||
|
// handle to the type field
|
||||||
|
void *htype;
|
||||||
|
// size of the currently created type, in units of 4 bytes
|
||||||
|
size_t created_type_n4bytes;
|
||||||
};
|
};
|
||||||
|
|
||||||
// ----- extensions for ease of use ------
|
// ----- extensions for ease of use ------
|
||||||
@ -92,7 +104,7 @@ template<typename DType>
|
|||||||
class Reducer {
|
class Reducer {
|
||||||
public:
|
public:
|
||||||
Reducer(void) {
|
Reducer(void) {
|
||||||
handle.Init(ReduceInner);
|
handle.Init(ReduceInner, kUnit);
|
||||||
utils::Assert(sizeof(DType) % sizeof(int) == 0, "struct must be multiple of int");
|
utils::Assert(sizeof(DType) % sizeof(int) == 0, "struct must be multiple of int");
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
@ -102,24 +114,23 @@ class Reducer {
|
|||||||
* \param reducer the reducer function
|
* \param reducer the reducer function
|
||||||
*/
|
*/
|
||||||
inline void AllReduce(DType *sendrecvbuf, size_t count) {
|
inline void AllReduce(DType *sendrecvbuf, size_t count) {
|
||||||
handle.AllReduce(sendrecvbuf, count * kUnit);
|
handle.AllReduce(sendrecvbuf, kUnit, count);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// unit size
|
// unit size
|
||||||
static const size_t kUnit = sizeof(DType) / sizeof(int);
|
static const size_t kUnit = sizeof(DType) / sizeof(int);
|
||||||
// inner implementation of reducer
|
// inner implementation of reducer
|
||||||
inline static void ReduceInner(const void *src_, void *dst_, int len_) {
|
inline static void ReduceInner(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) {
|
||||||
const int *psrc = reinterpret_cast<const int*>(src_);
|
const int *psrc = reinterpret_cast<const int*>(src_);
|
||||||
int *pdst = reinterpret_cast<int*>(dst_);
|
int *pdst = reinterpret_cast<int*>(dst_);
|
||||||
DType tdst, tsrc;
|
DType tdst, tsrc;
|
||||||
utils::Assert(len_ % kUnit == 0, "length not divide by size");
|
for (size_t i = 0; i < len_; ++i) {
|
||||||
for (size_t i = 0; i < len_; i += kUnit) {
|
|
||||||
// use memcpy to avoid alignment issue
|
// use memcpy to avoid alignment issue
|
||||||
std::memcpy(&tdst, pdst + i, sizeof(tdst));
|
std::memcpy(&tdst, pdst + i * kUnit, sizeof(tdst));
|
||||||
std::memcpy(&tsrc, psrc + i, sizeof(tsrc));
|
std::memcpy(&tsrc, psrc + i * kUnit, sizeof(tsrc));
|
||||||
tdst.Reduce(tsrc);
|
tdst.Reduce(tsrc);
|
||||||
std::memcpy(pdst + i, &tdst, sizeof(tdst));
|
std::memcpy(pdst + i * kUnit, &tdst, sizeof(tdst));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// function handle
|
// function handle
|
||||||
@ -135,38 +146,47 @@ class Reducer {
|
|||||||
* (1) Save(IStream &fs) (2) Load(IStream &fs) (3) Reduce(const DType &d);
|
* (1) Save(IStream &fs) (2) Load(IStream &fs) (3) Reduce(const DType &d);
|
||||||
*/
|
*/
|
||||||
template<typename DType>
|
template<typename DType>
|
||||||
class ComplexReducer {
|
class SerializeReducer {
|
||||||
public:
|
public:
|
||||||
ComplexReducer(void) {
|
SerializeReducer(void) {
|
||||||
handle.Init(ReduceInner);
|
handle.Init(ReduceInner, 0);
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief customized in-place all reduce operation
|
* \brief customized in-place all reduce operation
|
||||||
* \param sendrecvobj pointer to the object to be reduced
|
* \param sendrecvobj pointer to the object to be reduced
|
||||||
* \param max_n4byte maximum amount of memory needed in 4byte
|
* \param max_n4byte maximum amount of memory needed in 4byte
|
||||||
* \param reducer the reducer function
|
* \param reducer the reducer function
|
||||||
*/
|
*/
|
||||||
inline void AllReduce(DType *sendrecvobj, size_t max_n4byte) {
|
inline void AllReduce(DType *sendrecvobj, size_t max_n4byte, size_t count) {
|
||||||
buffer.resize(max_n4byte);
|
buffer.resize(max_n4byte * count);
|
||||||
utils::MemoryFixSizeBuffer fs(BeginPtr(buffer), max_n4byte * 4);
|
for (size_t i = 0; i < count; ++i) {
|
||||||
sendrecvobj->Save(fs);
|
utils::MemoryFixSizeBuffer fs(BeginPtr(buffer) + i * max_n4byte * 4, max_n4byte * 4);
|
||||||
handle.AllReduce(BeginPtr(buffer), max_n4byte);
|
sendrecvobj[i]->Save(fs);
|
||||||
fs.Seek(0);
|
}
|
||||||
sendrecvobj->Load(fs);
|
handle.AllReduce(BeginPtr(buffer), max_n4byte, count);
|
||||||
|
for (size_t i = 0; i < count; ++i) {
|
||||||
|
utils::MemoryFixSizeBuffer fs(BeginPtr(buffer) + i * max_n4byte * 4, max_n4byte * 4);
|
||||||
|
sendrecvobj[i]->Load(fs);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// unit size
|
// unit size
|
||||||
// inner implementation of reducer
|
// inner implementation of reducer
|
||||||
inline static void ReduceInner(const void *src_, void *dst_, int len_) {
|
inline static void ReduceInner(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) {
|
||||||
utils::MemoryFixSizeBuffer fsrc((void*)(src_), len_);
|
int nbytes = ReduceHandle::TypeSize(dtype);
|
||||||
utils::MemoryFixSizeBuffer fdst(dst_, len_);
|
|
||||||
// temp space
|
// temp space
|
||||||
DType tsrc, tdst;
|
DType tsrc, tdst;
|
||||||
tsrc.Load(fsrc); tdst.Load(fdst);
|
for (int i = 0; i < len_; ++i) {
|
||||||
// govern const check
|
utils::MemoryFixSizeBuffer fsrc((void*)(src_) + i * nbytes, nbytes);
|
||||||
tdst.Reduce(static_cast<const DType &>(tsrc));
|
utils::MemoryFixSizeBuffer fdst(dst_ + i * nbytes, nbytes);
|
||||||
tdst.Save(fdst);
|
tsrc.Load(fsrc);
|
||||||
|
tdst.Load(fdst);
|
||||||
|
// govern const check
|
||||||
|
tdst.Reduce(static_cast<const DType &>(tsrc));
|
||||||
|
fdst.Seek(0);
|
||||||
|
tdst.Save(fdst);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// function handle
|
// function handle
|
||||||
ReduceHandle handle;
|
ReduceHandle handle;
|
||||||
|
|||||||
@ -38,8 +38,8 @@ void Bcast(std::string *sendrecv_data, int root) {
|
|||||||
|
|
||||||
ReduceHandle::ReduceHandle(void) : handle(NULL) {}
|
ReduceHandle::ReduceHandle(void) : handle(NULL) {}
|
||||||
ReduceHandle::~ReduceHandle(void) {}
|
ReduceHandle::~ReduceHandle(void) {}
|
||||||
void ReduceHandle::Init(ReduceFunction redfunc, bool commute) {}
|
void ReduceHandle::Init(ReduceFunction redfunc, size_t type_n4bytes, bool commute) {}
|
||||||
void ReduceHandle::AllReduce(void *sendrecvbuf, size_t n4byte) {}
|
void ReduceHandle::AllReduce(void *sendrecvbuf, size_t type_n4bytes, size_t n4byte) {}
|
||||||
} // namespace sync
|
} // namespace sync
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
#include "./sync.h"
|
#include "./sync.h"
|
||||||
#include "../utils/utils.h"
|
#include "../utils/utils.h"
|
||||||
#include "mpi.h"
|
#include <mpi.h>
|
||||||
|
|
||||||
// use MPI to implement sync
|
// use MPI to implement sync
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace sync {
|
namespace sync {
|
||||||
@ -60,7 +61,7 @@ void Bcast(std::string *sendrecv_data, int root) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// code for reduce handle
|
// code for reduce handle
|
||||||
ReduceHandle::ReduceHandle(void) : handle(NULL) {
|
ReduceHandle::ReduceHandle(void) : handle(NULL), htype(NULL) {
|
||||||
}
|
}
|
||||||
ReduceHandle::~ReduceHandle(void) {
|
ReduceHandle::~ReduceHandle(void) {
|
||||||
if (handle != NULL) {
|
if (handle != NULL) {
|
||||||
@ -68,19 +69,42 @@ ReduceHandle::~ReduceHandle(void) {
|
|||||||
op->Free();
|
op->Free();
|
||||||
delete op;
|
delete op;
|
||||||
}
|
}
|
||||||
|
if (htype != NULL) {
|
||||||
|
MPI::Datatype *dtype = reinterpret_cast<MPI::Datatype*>(htype);
|
||||||
|
dtype->Free();
|
||||||
|
delete dtype;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
void ReduceHandle::Init(ReduceFunction redfunc, bool commute) {
|
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
|
||||||
|
return dtype.Get_size();
|
||||||
|
}
|
||||||
|
void ReduceHandle::Init(ReduceFunction redfunc, size_t type_n4bytes, bool commute) {
|
||||||
utils::Assert(handle == NULL, "cannot initialize reduce handle twice");
|
utils::Assert(handle == NULL, "cannot initialize reduce handle twice");
|
||||||
|
if (type_n4bytes != 0) {
|
||||||
|
MPI::Datatype *dtype = new MPI::Datatype();
|
||||||
|
*dtype = MPI::INT.Create_contiguous(type_n4bytes);
|
||||||
|
dtype->Commit();
|
||||||
|
created_type_n4bytes = type_n4bytes;
|
||||||
|
htype = dtype;
|
||||||
|
}
|
||||||
|
|
||||||
MPI::Op *op = new MPI::Op();
|
MPI::Op *op = new MPI::Op();
|
||||||
MPI::User_function *pf = reinterpret_cast<MPI::User_function*>(redfunc);
|
MPI::User_function *pf = redfunc;
|
||||||
op->Init(pf, commute);
|
op->Init(pf, commute);
|
||||||
handle = op;
|
handle = op;
|
||||||
}
|
}
|
||||||
void ReduceHandle::AllReduce(void *sendrecvbuf, size_t n4byte) {
|
void ReduceHandle::AllReduce(void *sendrecvbuf, size_t type_n4bytes, size_t count) {
|
||||||
utils::Assert(handle != NULL, "must intialize handle to call AllReduce");
|
utils::Assert(handle != NULL, "must intialize handle to call AllReduce");
|
||||||
MPI::Op *op = reinterpret_cast<MPI::Op*>(handle);
|
MPI::Op *op = reinterpret_cast<MPI::Op*>(handle);
|
||||||
MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf, n4byte, MPI_INT, *op);
|
MPI::Datatype *dtype = reinterpret_cast<MPI::Datatype*>(htype);
|
||||||
}
|
|
||||||
|
|
||||||
|
if (created_type_n4bytes != type_n4bytes || htype == NULL) {
|
||||||
|
dtype->Free();
|
||||||
|
*dtype = MPI::INT.Create_contiguous(type_n4bytes);
|
||||||
|
dtype->Commit();
|
||||||
|
created_type_n4bytes = type_n4bytes;
|
||||||
|
}
|
||||||
|
MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf, count, *dtype, *op);
|
||||||
|
}
|
||||||
} // namespace sync
|
} // namespace sync
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -7,7 +7,7 @@
|
|||||||
#include "./updater_refresh-inl.hpp"
|
#include "./updater_refresh-inl.hpp"
|
||||||
#include "./updater_colmaker-inl.hpp"
|
#include "./updater_colmaker-inl.hpp"
|
||||||
#include "./updater_distcol-inl.hpp"
|
#include "./updater_distcol-inl.hpp"
|
||||||
#include "./updater_skmaker-inl.hpp"
|
//#include "./updater_skmaker-inl.hpp"
|
||||||
#include "./updater_histmaker-inl.hpp"
|
#include "./updater_histmaker-inl.hpp"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
@ -18,8 +18,8 @@ IUpdater* CreateUpdater(const char *name) {
|
|||||||
if (!strcmp(name, "sync")) return new TreeSyncher();
|
if (!strcmp(name, "sync")) return new TreeSyncher();
|
||||||
if (!strcmp(name, "refresh")) return new TreeRefresher<GradStats>();
|
if (!strcmp(name, "refresh")) return new TreeRefresher<GradStats>();
|
||||||
if (!strcmp(name, "grow_colmaker")) return new ColMaker<GradStats>();
|
if (!strcmp(name, "grow_colmaker")) return new ColMaker<GradStats>();
|
||||||
if (!strcmp(name, "grow_histmaker")) return new CQHistMaker<GradStats>();
|
//if (!strcmp(name, "grow_histmaker")) return new CQHistMaker<GradStats>();
|
||||||
if (!strcmp(name, "grow_skmaker")) return new SketchMaker();
|
//if (!strcmp(name, "grow_skmaker")) return new SketchMaker();
|
||||||
if (!strcmp(name, "distcol")) return new DistColMaker<GradStats>();
|
if (!strcmp(name, "distcol")) return new DistColMaker<GradStats>();
|
||||||
|
|
||||||
utils::Error("unknown updater:%s", name);
|
utils::Error("unknown updater:%s", name);
|
||||||
|
|||||||
@ -306,6 +306,7 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
hist.data[istart].Add(gpair, info, ridx);
|
hist.data[istart].Add(gpair, info, ridx);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
typedef utils::WXQuantileSketch<bst_float, bst_float> WXQSketch;
|
||||||
virtual void CreateHist(const std::vector<bst_gpair> &gpair,
|
virtual void CreateHist(const std::vector<bst_gpair> &gpair,
|
||||||
IFMatrix *p_fmat,
|
IFMatrix *p_fmat,
|
||||||
const BoosterInfo &info,
|
const BoosterInfo &info,
|
||||||
@ -371,21 +372,22 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
// setup maximum size
|
// setup maximum size
|
||||||
unsigned max_size = this->param.max_sketch_size();
|
unsigned max_size = this->param.max_sketch_size();
|
||||||
// synchronize sketch
|
// synchronize sketch
|
||||||
summary_array.Init(sketchs.size(), max_size);
|
summary_array.resize(sketchs.size());
|
||||||
for (size_t i = 0; i < sketchs.size(); ++i) {
|
for (size_t i = 0; i < sketchs.size(); ++i) {
|
||||||
utils::WXQuantileSketch<bst_float, bst_float>::SummaryContainer out;
|
utils::WXQuantileSketch<bst_float, bst_float>::SummaryContainer out;
|
||||||
sketchs[i].GetSummary(&out);
|
sketchs[i].GetSummary(&out);
|
||||||
summary_array.Set(i, out);
|
summary_array[i].Reserve(max_size);
|
||||||
|
summary_array[i].SetPrune(out, max_size);
|
||||||
}
|
}
|
||||||
size_t n4bytes = (summary_array.MemSize() + 3) / 4;
|
size_t n4bytes = (WXQSketch::SummaryContainer::CalcMemCost(max_size) + 3) / 4;
|
||||||
sreducer.AllReduce(&summary_array, n4bytes);
|
sreducer.AllReduce(BeginPtr(summary_array), n4bytes, summary_array.size());
|
||||||
// now we get the final result of sketch, setup the cut
|
// now we get the final result of sketch, setup the cut
|
||||||
this->wspace.cut.clear();
|
this->wspace.cut.clear();
|
||||||
this->wspace.rptr.clear();
|
this->wspace.rptr.clear();
|
||||||
this->wspace.rptr.push_back(0);
|
this->wspace.rptr.push_back(0);
|
||||||
for (size_t wid = 0; wid < this->qexpand.size(); ++wid) {
|
for (size_t wid = 0; wid < this->qexpand.size(); ++wid) {
|
||||||
for (int fid = 0; fid < tree.param.num_feature; ++fid) {
|
for (int fid = 0; fid < tree.param.num_feature; ++fid) {
|
||||||
const WXQSketch::Summary a = summary_array[wid * tree.param.num_feature + fid];
|
const WXQSketch::Summary &a = summary_array[wid * tree.param.num_feature + fid];
|
||||||
for (size_t i = 1; i < a.size; ++i) {
|
for (size_t i = 1; i < a.size; ++i) {
|
||||||
bst_float cpt = a.data[i].value - rt_eps;
|
bst_float cpt = a.data[i].value - rt_eps;
|
||||||
if (i == 1 || cpt > this->wspace.cut.back()) {
|
if (i == 1 || cpt > this->wspace.cut.back()) {
|
||||||
@ -407,7 +409,7 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
}
|
}
|
||||||
utils::Assert(this->wspace.rptr.size() ==
|
utils::Assert(this->wspace.rptr.size() ==
|
||||||
(tree.param.num_feature + 1) * this->qexpand.size() + 1,
|
(tree.param.num_feature + 1) * this->qexpand.size() + 1,
|
||||||
"cut space inconsistent");
|
"cut space inconsistent");
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -496,7 +498,6 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
typedef utils::WXQuantileSketch<bst_float, bst_float> WXQSketch;
|
|
||||||
// thread temp data
|
// thread temp data
|
||||||
std::vector< std::vector<BaseMaker::SketchEntry> > thread_sketch;
|
std::vector< std::vector<BaseMaker::SketchEntry> > thread_sketch;
|
||||||
// used to hold statistics
|
// used to hold statistics
|
||||||
@ -506,9 +507,9 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
// node statistics
|
// node statistics
|
||||||
std::vector<TStats> node_stats;
|
std::vector<TStats> node_stats;
|
||||||
// summary array
|
// summary array
|
||||||
WXQSketch::SummaryArray summary_array;
|
std::vector< WXQSketch::SummaryContainer> summary_array;
|
||||||
// reducer for summary
|
// reducer for summary
|
||||||
sync::ComplexReducer<WXQSketch::SummaryArray> sreducer;
|
sync::SerializeReducer<WXQSketch::SummaryContainer> sreducer;
|
||||||
// per node, per feature sketch
|
// per node, per feature sketch
|
||||||
std::vector< utils::WXQuantileSketch<bst_float, bst_float> > sketchs;
|
std::vector< utils::WXQuantileSketch<bst_float, bst_float> > sketchs;
|
||||||
};
|
};
|
||||||
@ -580,23 +581,24 @@ class QuantileHistMaker: public HistMaker<TStats> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
// setup maximum size
|
// setup maximum size
|
||||||
size_t max_size = static_cast<size_t>(this->param.sketch_ratio / this->param.sketch_eps);
|
unsigned max_size = this->param.max_sketch_size();
|
||||||
// synchronize sketch
|
// synchronize sketch
|
||||||
summary_array.Init(sketchs.size(), max_size);
|
summary_array.resize(sketchs.size());
|
||||||
for (size_t i = 0; i < sketchs.size(); ++i) {
|
for (size_t i = 0; i < sketchs.size(); ++i) {
|
||||||
utils::WQuantileSketch<bst_float, bst_float>::SummaryContainer out;
|
utils::WQuantileSketch<bst_float, bst_float>::SummaryContainer out;
|
||||||
sketchs[i].GetSummary(&out);
|
sketchs[i].GetSummary(&out);
|
||||||
summary_array.Set(i, out);
|
summary_array[i].Reserve(max_size);
|
||||||
|
summary_array[i].SetPrune(out, max_size);
|
||||||
}
|
}
|
||||||
size_t n4bytes = (summary_array.MemSize() + 3) / 4;
|
size_t n4bytes = (WXQSketch::SummaryContainer::CalcMemCost(max_size) + 3) / 4;
|
||||||
sreducer.AllReduce(&summary_array, n4bytes);
|
sreducer.AllReduce(BeginPtr(summary_array), n4bytes, summary_array.size());
|
||||||
// now we get the final result of sketch, setup the cut
|
// now we get the final result of sketch, setup the cut
|
||||||
this->wspace.cut.clear();
|
this->wspace.cut.clear();
|
||||||
this->wspace.rptr.clear();
|
this->wspace.rptr.clear();
|
||||||
this->wspace.rptr.push_back(0);
|
this->wspace.rptr.push_back(0);
|
||||||
for (size_t wid = 0; wid < this->qexpand.size(); ++wid) {
|
for (size_t wid = 0; wid < this->qexpand.size(); ++wid) {
|
||||||
for (int fid = 0; fid < tree.param.num_feature; ++fid) {
|
for (int fid = 0; fid < tree.param.num_feature; ++fid) {
|
||||||
const WXQSketch::Summary a = summary_array[wid * tree.param.num_feature + fid];
|
const WXQSketch::Summary &a = summary_array[wid * tree.param.num_feature + fid];
|
||||||
for (size_t i = 1; i < a.size; ++i) {
|
for (size_t i = 1; i < a.size; ++i) {
|
||||||
bst_float cpt = a.data[i].value - rt_eps;
|
bst_float cpt = a.data[i].value - rt_eps;
|
||||||
if (i == 1 || cpt > this->wspace.cut.back()) {
|
if (i == 1 || cpt > this->wspace.cut.back()) {
|
||||||
@ -624,9 +626,9 @@ class QuantileHistMaker: public HistMaker<TStats> {
|
|||||||
private:
|
private:
|
||||||
typedef utils::WXQuantileSketch<bst_float, bst_float> WXQSketch;
|
typedef utils::WXQuantileSketch<bst_float, bst_float> WXQSketch;
|
||||||
// summary array
|
// summary array
|
||||||
WXQSketch::SummaryArray summary_array;
|
std::vector< WXQSketch::SummaryContainer> summary_array;
|
||||||
// reducer for summary
|
// reducer for summary
|
||||||
sync::ComplexReducer<WXQSketch::SummaryArray> sreducer;
|
sync::SerializeReducer<WXQSketch::SummaryContainer> sreducer;
|
||||||
// local temp column data structure
|
// local temp column data structure
|
||||||
std::vector<size_t> col_ptr;
|
std::vector<size_t> col_ptr;
|
||||||
// local storage of column data
|
// local storage of column data
|
||||||
|
|||||||
@ -224,6 +224,12 @@ struct WQSummary {
|
|||||||
*/
|
*/
|
||||||
inline void SetCombine(const WQSummary &sa,
|
inline void SetCombine(const WQSummary &sa,
|
||||||
const WQSummary &sb) {
|
const WQSummary &sb) {
|
||||||
|
if (sa.size == 0) {
|
||||||
|
this->CopyFrom(sb); return;
|
||||||
|
}
|
||||||
|
if (sb.size == 0) {
|
||||||
|
this->CopyFrom(sa); return;
|
||||||
|
}
|
||||||
utils::Assert(sa.size > 0 && sb.size > 0, "invalid input for merge");
|
utils::Assert(sa.size > 0 && sb.size > 0, "invalid input for merge");
|
||||||
const Entry *a = sa.data, *a_end = sa.data + sa.size;
|
const Entry *a = sa.data, *a_end = sa.data + sa.size;
|
||||||
const Entry *b = sb.data, *b_end = sb.data + sb.size;
|
const Entry *b = sb.data, *b_end = sb.data + sb.size;
|
||||||
@ -453,6 +459,12 @@ struct GKSummary {
|
|||||||
}
|
}
|
||||||
inline void SetCombine(const GKSummary &sa,
|
inline void SetCombine(const GKSummary &sa,
|
||||||
const GKSummary &sb) {
|
const GKSummary &sb) {
|
||||||
|
if (sa.size == 0) {
|
||||||
|
this->CopyFrom(sb); return;
|
||||||
|
}
|
||||||
|
if (sb.size == 0) {
|
||||||
|
this->CopyFrom(sa); return;
|
||||||
|
}
|
||||||
utils::Assert(sa.size > 0 && sb.size > 0, "invalid input for merge");
|
utils::Assert(sa.size > 0 && sb.size > 0, "invalid input for merge");
|
||||||
const Entry *a = sa.data, *a_end = sa.data + sa.size;
|
const Entry *a = sa.data, *a_end = sa.data + sa.size;
|
||||||
const Entry *b = sb.data, *b_end = sb.data + sb.size;
|
const Entry *b = sb.data, *b_end = sb.data + sb.size;
|
||||||
@ -537,96 +549,41 @@ class QuantileSketchTemplate {
|
|||||||
this->SetMerge(begin[0], begin[1]);
|
this->SetMerge(begin[0], begin[1]);
|
||||||
} else {
|
} else {
|
||||||
// recursive merge
|
// recursive merge
|
||||||
SummaryContainer lhs, rhs;
|
SummaryContainer lhs, rhs;
|
||||||
lhs.SetCombine(begin, begin + len / 2);
|
lhs.SetCombine(begin, begin + len / 2);
|
||||||
rhs.SetCombine(begin + len / 2, end);
|
rhs.SetCombine(begin + len / 2, end);
|
||||||
this->Reserve(lhs.size + rhs.size);
|
this->Reserve(lhs.size + rhs.size);
|
||||||
this->SetCombine(lhs, rhs);
|
this->SetCombine(lhs, rhs);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
|
||||||
/*!
|
|
||||||
* \brief represent an array of summary
|
|
||||||
* each contains fixed maximum size summary
|
|
||||||
*/
|
|
||||||
class SummaryArray {
|
|
||||||
public:
|
|
||||||
/*!
|
|
||||||
* \brief intialize the SummaryArray
|
|
||||||
* \param num_summary number of summary in the array
|
|
||||||
* \param max_size maximum number of elements in each summary
|
|
||||||
*/
|
|
||||||
inline void Init(unsigned num_summary, unsigned max_size) {
|
|
||||||
this->num_summary = num_summary;
|
|
||||||
this->max_size = max_size;
|
|
||||||
sizes.resize(num_summary);
|
|
||||||
data.resize(num_summary * max_size);
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief set i-th element of array to be the src summary,
|
|
||||||
* the summary can be pruned if it does not fit into max_size
|
|
||||||
* \param the index in the array
|
|
||||||
* \param src the source summary
|
|
||||||
* \tparam the type if source summary
|
|
||||||
*/
|
|
||||||
template<typename TSrc>
|
|
||||||
inline void Set(size_t i, const TSrc &src) {
|
|
||||||
Summary dst = (*this)[i];
|
|
||||||
dst.SetPrune(src, max_size);
|
|
||||||
this->sizes[i] = dst.size;
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief get i-th summary of the array, only use this for read purpose
|
|
||||||
*/
|
|
||||||
inline const Summary operator[](size_t i) const {
|
|
||||||
return Summary((Entry*)BeginPtr(data) + i * max_size, sizes[i]);
|
|
||||||
}
|
|
||||||
/*!
|
/*!
|
||||||
* \brief do elementwise combination of summary array
|
* \brief do elementwise combination of summary array
|
||||||
* this[i] = combine(this[i], src[i]) for each i
|
* this[i] = combine(this[i], src[i]) for each i
|
||||||
* \param src the source summary
|
* \param src the source summary
|
||||||
|
* \param max_nbyte maximum number of bytes allowed here
|
||||||
*/
|
*/
|
||||||
inline void Reduce(const SummaryArray &src) {
|
inline void Reduce(const Summary &src, size_t max_nbyte) {
|
||||||
utils::Check(num_summary == src.num_summary &&
|
this->Reserve((max_nbyte - sizeof(this->size)) / sizeof(Entry));
|
||||||
max_size == src.max_size, "array shape mismatch in reduce");
|
|
||||||
SummaryContainer temp;
|
SummaryContainer temp;
|
||||||
temp.Reserve(max_size * 2);
|
temp.Reserve(this->size + src.size);
|
||||||
for (unsigned i = 0; i < num_summary; ++i) {
|
temp.SetCombine(*this, src);
|
||||||
temp.SetCombine((*this)[i], src[i]);
|
this->SetPrune(temp, space.size());
|
||||||
this->Set(i, temp);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
/*! \brief return the number of bytes this data structure cost in serialization */
|
/*! \brief return the number of bytes this data structure cost in serialization */
|
||||||
inline size_t MemSize(void) const {
|
inline static size_t CalcMemCost(size_t nentry) {
|
||||||
return sizeof(num_summary) + sizeof(max_size)
|
return sizeof(size_t) + sizeof(Entry) * nentry;
|
||||||
+ data.size() * sizeof(Entry) + sizes.size() * sizeof(unsigned);
|
|
||||||
}
|
}
|
||||||
/*! \brief save the data structure into stream */
|
/*! \brief save the data structure into stream */
|
||||||
inline void Save(IStream &fo) const {
|
inline void Save(IStream &fo) const {
|
||||||
fo.Write(&num_summary, sizeof(num_summary));
|
fo.Write(&(this->size), sizeof(this->size));
|
||||||
fo.Write(&max_size, sizeof(max_size));
|
fo.Write(data, this->size * sizeof(Entry));
|
||||||
fo.Write(BeginPtr(sizes), sizes.size() * sizeof(unsigned));
|
|
||||||
fo.Write(BeginPtr(data), data.size() * sizeof(Entry));
|
|
||||||
}
|
}
|
||||||
/*! \brief load data structure from input stream */
|
/*! \brief load data structure from input stream */
|
||||||
inline void Load(IStream &fi) {
|
inline void Load(IStream &fi) {
|
||||||
utils::Check(fi.Read(&num_summary, sizeof(num_summary)) != 0, "invalid SummaryArray");
|
utils::Check(fi.Read(&this->size, sizeof(this->size)) != 0, "invalid SummaryArray 1");
|
||||||
utils::Check(fi.Read(&max_size, sizeof(max_size)) != 0, "invalid SummaryArray");
|
this->Reserve(this->size);
|
||||||
sizes.resize(num_summary);
|
utils::Check(fi.Read(data, this->size * sizeof(Entry)) != 0, "invalid SummaryArray 2");
|
||||||
data.resize(num_summary * max_size);
|
|
||||||
utils::Check(fi.Read(BeginPtr(sizes), sizes.size() * sizeof(unsigned)) != 0, "invalid SummaryArray");
|
|
||||||
utils::Check(fi.Read(BeginPtr(data), data.size() * sizeof(Entry)) != 0, "invalid SummaryArray");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
|
||||||
/*! \brief number of summaries in the group */
|
|
||||||
unsigned num_summary;
|
|
||||||
/*! \brief maximum size of each summary */
|
|
||||||
unsigned max_size;
|
|
||||||
/*! \brief the current size of each summary */
|
|
||||||
std::vector<unsigned> sizes;
|
|
||||||
/*! \brief the data content */
|
|
||||||
std::vector<Entry> data;
|
|
||||||
};
|
};
|
||||||
/*!
|
/*!
|
||||||
* \brief initialize the quantile sketch, given the performance specification
|
* \brief initialize the quantile sketch, given the performance specification
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user