Merge commit '3d11f56880521c1d45504c965ae12886e9b72ace'

This commit is contained in:
tqchen
2015-04-08 17:39:45 -07:00
29 changed files with 225 additions and 228 deletions

View File

@@ -7,22 +7,54 @@
using namespace rabit;
// simple dense matrix, mshadow or Eigen matrix was better
// this was was OK
struct Matrix {
inline void Init(size_t nrow, size_t ncol, float v = 0.0f) {
this->nrow = nrow;
this->ncol = ncol;
data.resize(nrow * ncol);
std::fill(data.begin(), data.end(), v);
}
inline float *operator[](size_t i) {
return &data[0] + i * ncol;
}
inline const float *operator[](size_t i) const {
return &data[0] + i * ncol;
}
inline void Print(utils::Stream *fo) {
for (size_t i = 0; i < data.size(); ++i) {
std::ostringstream ss;
ss << data[i];
if ((i+1) % ncol == 0) {
ss << '\n';
} else {
ss << ' ';
}
}
std::string s = ss.str();
}
// number of data
size_t nrow, ncol;
std::vector<float> data;
};
// kmeans model
class Model : public rabit::ISerializable {
class Model : public rabit::Serializable {
public:
// matrix of centroids
Matrix centroids;
// load from stream
virtual void Load(rabit::IStream &fi) {
fi.Read(&centroids.nrow, sizeof(centroids.nrow));
fi.Read(&centroids.ncol, sizeof(centroids.ncol));
fi.Read(&centroids.data);
virtual void Load(rabit::Stream *fi) {
fi->Read(&centroids.nrow, sizeof(centroids.nrow));
fi->Read(&centroids.ncol, sizeof(centroids.ncol));
fi->Read(&centroids.data);
}
/*! \brief save the model to the stream */
virtual void Save(rabit::IStream &fo) const {
fo.Write(&centroids.nrow, sizeof(centroids.nrow));
fo.Write(&centroids.ncol, sizeof(centroids.ncol));
fo.Write(centroids.data);
virtual void Save(rabit::Stream *fo) const {
fo->Write(&centroids.nrow, sizeof(centroids.nrow));
fo->Write(&centroids.ncol, sizeof(centroids.ncol));
fo->Write(centroids.data);
}
virtual void InitModel(unsigned num_cluster, unsigned feat_dim) {
centroids.Init(num_cluster, feat_dim);
@@ -153,7 +185,7 @@ int main(int argc, char *argv[]) {
}
}
model.Normalize();
rabit::CheckPoint(&model);
rabit::LazyCheckPoint(&model);
}
// output the model file to somewhere
if (rabit::GetRank() == 0) {