get multinode in
This commit is contained in:
@@ -160,13 +160,13 @@ class SerializeReducer {
|
||||
inline void AllReduce(DType *sendrecvobj, size_t max_n4byte, size_t count) {
|
||||
buffer.resize(max_n4byte * count);
|
||||
for (size_t i = 0; i < count; ++i) {
|
||||
utils::MemoryFixSizeBuffer fs(BeginPtr(buffer) + i * max_n4byte * 4, max_n4byte * 4);
|
||||
sendrecvobj[i]->Save(fs);
|
||||
utils::MemoryFixSizeBuffer fs(BeginPtr(buffer) + i * max_n4byte, max_n4byte * 4);
|
||||
sendrecvobj[i].Save(fs);
|
||||
}
|
||||
handle.AllReduce(BeginPtr(buffer), max_n4byte, count);
|
||||
for (size_t i = 0; i < count; ++i) {
|
||||
utils::MemoryFixSizeBuffer fs(BeginPtr(buffer) + i * max_n4byte * 4, max_n4byte * 4);
|
||||
sendrecvobj[i]->Load(fs);
|
||||
utils::MemoryFixSizeBuffer fs(BeginPtr(buffer) + i * max_n4byte, max_n4byte * 4);
|
||||
sendrecvobj[i].Load(fs);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -178,12 +178,12 @@ class SerializeReducer {
|
||||
// temp space
|
||||
DType tsrc, tdst;
|
||||
for (int i = 0; i < len_; ++i) {
|
||||
utils::MemoryFixSizeBuffer fsrc((void*)(src_) + i * nbytes, nbytes);
|
||||
utils::MemoryFixSizeBuffer fdst(dst_ + i * nbytes, nbytes);
|
||||
utils::MemoryFixSizeBuffer fsrc((char*)(src_) + i * nbytes, nbytes);
|
||||
utils::MemoryFixSizeBuffer fdst((char*)(dst_) + i * nbytes, nbytes);
|
||||
tsrc.Load(fsrc);
|
||||
tdst.Load(fdst);
|
||||
// govern const check
|
||||
tdst.Reduce(static_cast<const DType &>(tsrc));
|
||||
tdst.Reduce(static_cast<const DType &>(tsrc), nbytes);
|
||||
fdst.Seek(0);
|
||||
tdst.Save(fdst);
|
||||
}
|
||||
|
||||
@@ -38,6 +38,9 @@ void Bcast(std::string *sendrecv_data, int root) {
|
||||
|
||||
ReduceHandle::ReduceHandle(void) : handle(NULL) {}
|
||||
ReduceHandle::~ReduceHandle(void) {}
|
||||
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
|
||||
return 0;
|
||||
}
|
||||
void ReduceHandle::Init(ReduceFunction redfunc, size_t type_n4bytes, bool commute) {}
|
||||
void ReduceHandle::AllReduce(void *sendrecvbuf, size_t type_n4bytes, size_t n4byte) {}
|
||||
} // namespace sync
|
||||
|
||||
@@ -97,9 +97,12 @@ void ReduceHandle::AllReduce(void *sendrecvbuf, size_t type_n4bytes, size_t coun
|
||||
utils::Assert(handle != NULL, "must intialize handle to call AllReduce");
|
||||
MPI::Op *op = reinterpret_cast<MPI::Op*>(handle);
|
||||
MPI::Datatype *dtype = reinterpret_cast<MPI::Datatype*>(htype);
|
||||
|
||||
if (created_type_n4bytes != type_n4bytes || htype == NULL) {
|
||||
dtype->Free();
|
||||
if (created_type_n4bytes != type_n4bytes || dtype == NULL) {
|
||||
if (dtype == NULL) {
|
||||
dtype = new MPI::Datatype();
|
||||
} else {
|
||||
dtype->Free();
|
||||
}
|
||||
*dtype = MPI::INT.Create_contiguous(type_n4bytes);
|
||||
dtype->Commit();
|
||||
created_type_n4bytes = type_n4bytes;
|
||||
|
||||
Reference in New Issue
Block a user