Merge commit '3d11f56880521c1d45504c965ae12886e9b72ace'
This commit is contained in:
@@ -7,22 +7,54 @@
|
||||
|
||||
using namespace rabit;
|
||||
|
||||
// simple dense matrix, mshadow or Eigen matrix was better
|
||||
// this was was OK
|
||||
struct Matrix {
|
||||
inline void Init(size_t nrow, size_t ncol, float v = 0.0f) {
|
||||
this->nrow = nrow;
|
||||
this->ncol = ncol;
|
||||
data.resize(nrow * ncol);
|
||||
std::fill(data.begin(), data.end(), v);
|
||||
}
|
||||
inline float *operator[](size_t i) {
|
||||
return &data[0] + i * ncol;
|
||||
}
|
||||
inline const float *operator[](size_t i) const {
|
||||
return &data[0] + i * ncol;
|
||||
}
|
||||
inline void Print(utils::Stream *fo) {
|
||||
for (size_t i = 0; i < data.size(); ++i) {
|
||||
std::ostringstream ss;
|
||||
ss << data[i];
|
||||
if ((i+1) % ncol == 0) {
|
||||
ss << '\n';
|
||||
} else {
|
||||
ss << ' ';
|
||||
}
|
||||
}
|
||||
std::string s = ss.str();
|
||||
}
|
||||
// number of data
|
||||
size_t nrow, ncol;
|
||||
std::vector<float> data;
|
||||
};
|
||||
|
||||
// kmeans model
|
||||
class Model : public rabit::ISerializable {
|
||||
class Model : public rabit::Serializable {
|
||||
public:
|
||||
// matrix of centroids
|
||||
Matrix centroids;
|
||||
// load from stream
|
||||
virtual void Load(rabit::IStream &fi) {
|
||||
fi.Read(¢roids.nrow, sizeof(centroids.nrow));
|
||||
fi.Read(¢roids.ncol, sizeof(centroids.ncol));
|
||||
fi.Read(¢roids.data);
|
||||
virtual void Load(rabit::Stream *fi) {
|
||||
fi->Read(¢roids.nrow, sizeof(centroids.nrow));
|
||||
fi->Read(¢roids.ncol, sizeof(centroids.ncol));
|
||||
fi->Read(¢roids.data);
|
||||
}
|
||||
/*! \brief save the model to the stream */
|
||||
virtual void Save(rabit::IStream &fo) const {
|
||||
fo.Write(¢roids.nrow, sizeof(centroids.nrow));
|
||||
fo.Write(¢roids.ncol, sizeof(centroids.ncol));
|
||||
fo.Write(centroids.data);
|
||||
virtual void Save(rabit::Stream *fo) const {
|
||||
fo->Write(¢roids.nrow, sizeof(centroids.nrow));
|
||||
fo->Write(¢roids.ncol, sizeof(centroids.ncol));
|
||||
fo->Write(centroids.data);
|
||||
}
|
||||
virtual void InitModel(unsigned num_cluster, unsigned feat_dim) {
|
||||
centroids.Init(num_cluster, feat_dim);
|
||||
@@ -153,7 +185,7 @@ int main(int argc, char *argv[]) {
|
||||
}
|
||||
}
|
||||
model.Normalize();
|
||||
rabit::CheckPoint(&model);
|
||||
rabit::LazyCheckPoint(&model);
|
||||
}
|
||||
// output the model file to somewhere
|
||||
if (rabit::GetRank() == 0) {
|
||||
|
||||
Reference in New Issue
Block a user