Merge commit '3d11f56880521c1d45504c965ae12886e9b72ace'

This commit is contained in:
tqchen
2015-04-08 17:39:45 -07:00
29 changed files with 225 additions and 228 deletions

View File

@@ -19,7 +19,7 @@ namespace solver {
* to remember the state parameters that might need to remember
*/
template<typename DType>
class IObjFunction : public rabit::ISerializable {
class IObjFunction : public rabit::Serializable {
public:
// destructor
virtual ~IObjFunction(void){}
@@ -463,7 +463,7 @@ class LBFGSSolver {
}
}
// global solver state
struct GlobalState : public rabit::ISerializable {
struct GlobalState : public rabit::Serializable {
public:
// memory size of L-BFGS
size_t size_memory;
@@ -514,28 +514,28 @@ class LBFGSSolver {
MapIndex(j, offset_, size_memory)];
}
// load the shift array
virtual void Load(rabit::IStream &fi) {
fi.Read(&size_memory, sizeof(size_memory));
fi.Read(&num_iteration, sizeof(num_iteration));
fi.Read(&num_dim, sizeof(num_dim));
fi.Read(&init_objval, sizeof(init_objval));
fi.Read(&old_objval, sizeof(old_objval));
fi.Read(&offset_, sizeof(offset_));
fi.Read(&data);
virtual void Load(rabit::Stream *fi) {
fi->Read(&size_memory, sizeof(size_memory));
fi->Read(&num_iteration, sizeof(num_iteration));
fi->Read(&num_dim, sizeof(num_dim));
fi->Read(&init_objval, sizeof(init_objval));
fi->Read(&old_objval, sizeof(old_objval));
fi->Read(&offset_, sizeof(offset_));
fi->Read(&data);
this->AllocSpace();
fi.Read(weight, sizeof(DType) * num_dim);
fi->Read(weight, sizeof(DType) * num_dim);
obj->Load(fi);
}
// save the shift array
virtual void Save(rabit::IStream &fo) const {
fo.Write(&size_memory, sizeof(size_memory));
fo.Write(&num_iteration, sizeof(num_iteration));
fo.Write(&num_dim, sizeof(num_dim));
fo.Write(&init_objval, sizeof(init_objval));
fo.Write(&old_objval, sizeof(old_objval));
fo.Write(&offset_, sizeof(offset_));
fo.Write(data);
fo.Write(weight, sizeof(DType) * num_dim);
virtual void Save(rabit::Stream *fo) const {
fo->Write(&size_memory, sizeof(size_memory));
fo->Write(&num_iteration, sizeof(num_iteration));
fo->Write(&num_dim, sizeof(num_dim));
fo->Write(&init_objval, sizeof(init_objval));
fo->Write(&old_objval, sizeof(old_objval));
fo->Write(&offset_, sizeof(offset_));
fo->Write(data);
fo->Write(weight, sizeof(DType) * num_dim);
obj->Save(fo);
}
inline void Shift(void) {
@@ -556,7 +556,7 @@ class LBFGSSolver {
}
};
/*! \brief rolling array that carries history information */
struct HistoryArray : public rabit::ISerializable {
struct HistoryArray : public rabit::Serializable {
public:
HistoryArray(void) : dptr_(NULL) {
num_useful_ = 0;
@@ -609,26 +609,26 @@ class LBFGSSolver {
num_useful_ = num_useful;
}
// load the shift array
virtual void Load(rabit::IStream &fi) {
fi.Read(&num_col_, sizeof(num_col_));
fi.Read(&stride_, sizeof(stride_));
fi.Read(&size_memory_, sizeof(size_memory_));
fi.Read(&num_useful_, sizeof(num_useful_));
virtual void Load(rabit::Stream *fi) {
fi->Read(&num_col_, sizeof(num_col_));
fi->Read(&stride_, sizeof(stride_));
fi->Read(&size_memory_, sizeof(size_memory_));
fi->Read(&num_useful_, sizeof(num_useful_));
this->Init(num_col_, size_memory_);
for (size_t i = 0; i < num_useful_; ++i) {
fi.Read((*this)[i], num_col_ * sizeof(DType));
fi.Read((*this)[i + size_memory_], num_col_ * sizeof(DType));
fi->Read((*this)[i], num_col_ * sizeof(DType));
fi->Read((*this)[i + size_memory_], num_col_ * sizeof(DType));
}
}
// save the shift array
virtual void Save(rabit::IStream &fi) const {
fi.Write(&num_col_, sizeof(num_col_));
fi.Write(&stride_, sizeof(stride_));
fi.Write(&size_memory_, sizeof(size_memory_));
fi.Write(&num_useful_, sizeof(num_useful_));
virtual void Save(rabit::Stream *fo) const {
fo->Write(&num_col_, sizeof(num_col_));
fo->Write(&stride_, sizeof(stride_));
fo->Write(&size_memory_, sizeof(size_memory_));
fo->Write(&num_useful_, sizeof(num_useful_));
for (size_t i = 0; i < num_useful_; ++i) {
fi.Write((*this)[i], num_col_ * sizeof(DType));
fi.Write((*this)[i + size_memory_], num_col_ * sizeof(DType));
fo->Write((*this)[i], num_col_ * sizeof(DType));
fo->Write((*this)[i + size_memory_], num_col_ * sizeof(DType));
}
}