// This test case checks whether rabit can recover the model when
// facing an exception.
#include <rabit.h>
|
|
#include <rabit/utils.h>
|
|
#include "./toolkit_util.h"
|
|
#include <time.h>
|
|
|
|
using namespace rabit;
|
|
|
|
// kmeans model
|
|
class Model : public rabit::ISerializable {
|
|
public:
|
|
// matrix of centroids
|
|
Matrix centroids;
|
|
// load from stream
|
|
virtual void Load(rabit::IStream &fi) {
|
|
fi.Read(¢roids.nrow, sizeof(centroids.nrow));
|
|
fi.Read(¢roids.ncol, sizeof(centroids.ncol));
|
|
fi.Read(¢roids.data);
|
|
}
|
|
/*! \brief save the model to the stream */
|
|
virtual void Save(rabit::IStream &fo) const {
|
|
fo.Write(¢roids.nrow, sizeof(centroids.nrow));
|
|
fo.Write(¢roids.ncol, sizeof(centroids.ncol));
|
|
fo.Write(centroids.data);
|
|
}
|
|
virtual void InitModel(unsigned num_cluster, unsigned feat_dim) {
|
|
centroids.Init(num_cluster, feat_dim);
|
|
}
|
|
// normalize L2 norm
|
|
inline void Normalize(void) {
|
|
for (size_t i = 0; i < centroids.nrow; ++i) {
|
|
float *row = centroids[i];
|
|
double wsum = 0.0;
|
|
for (size_t j = 0; j < centroids.ncol; ++j) {
|
|
wsum += row[j] * row[j];
|
|
}
|
|
wsum = sqrt(wsum);
|
|
if (wsum < 1e-6) return;
|
|
float winv = 1.0 / wsum;
|
|
for (size_t j = 0; j < centroids.ncol; ++j) {
|
|
row[j] *= winv;
|
|
}
|
|
}
|
|
}
|
|
};
|
|
inline void InitCentroids(const SparseMat &data, Matrix *centroids) {
|
|
int num_cluster = centroids->nrow;
|
|
for (int i = 0; i < num_cluster; ++i) {
|
|
int index = Random(data.NumRow());
|
|
SparseMat::Vector v = data[index];
|
|
for (unsigned j = 0; j < v.length; ++j) {
|
|
(*centroids)[i][v[j].findex] = v[j].fvalue;
|
|
}
|
|
}
|
|
for (int i = 0; i < num_cluster; ++i) {
|
|
int proc = Random(rabit::GetWorldSize());
|
|
rabit::Broadcast((*centroids)[i], centroids->ncol * sizeof(float), proc);
|
|
}
|
|
}
|
|
|
|
inline double Cos(const float *row,
|
|
const SparseMat::Vector &v) {
|
|
double rdot = 0.0, rnorm = 0.0;
|
|
for (unsigned i = 0; i < v.length; ++i) {
|
|
rdot += row[v[i].findex] * v[i].fvalue;
|
|
rnorm += v[i].fvalue * v[i].fvalue;
|
|
}
|
|
return rdot / sqrt(rnorm);
|
|
}
|
|
inline size_t GetCluster(const Matrix ¢roids,
|
|
const SparseMat::Vector &v) {
|
|
size_t imin = 0;
|
|
double dmin = Cos(centroids[0], v);
|
|
for (size_t k = 1; k < centroids.nrow; ++k) {
|
|
double dist = Cos(centroids[k], v);
|
|
if (dist > dmin) {
|
|
dmin = dist; imin = k;
|
|
}
|
|
}
|
|
return imin;
|
|
}
|
|
|
|
int main(int argc, char *argv[]) {
|
|
if (argc < 5) {
|
|
if (rabit::GetRank() == 0) {
|
|
rabit::TrackerPrintf("Usage: <data_dir> num_cluster max_iter <out_model>\n");
|
|
}
|
|
return 0;
|
|
}
|
|
clock_t tStart = clock();
|
|
|
|
srand(0);
|
|
// load the data
|
|
SparseMat data;
|
|
data.Load(argv[1]);
|
|
// set the parameters
|
|
int num_cluster = atoi(argv[2]);
|
|
int max_iter = atoi(argv[3]);
|
|
// intialize rabit engine
|
|
rabit::Init(argc, argv);
|
|
// load model
|
|
Model model;
|
|
int iter = rabit::LoadCheckPoint(&model);
|
|
if (iter == 0) {
|
|
rabit::Allreduce<op::Max>(&data.feat_dim, 1);
|
|
model.InitModel(num_cluster, data.feat_dim);
|
|
InitCentroids(data, &model.centroids);
|
|
model.Normalize();
|
|
rabit::TrackerPrintf("[%d] start at %s\n",
|
|
rabit::GetRank(), rabit::GetProcessorName().c_str());
|
|
} else {
|
|
rabit::TrackerPrintf("[%d] restart iter=%d\n", rabit::GetRank(), iter);
|
|
}
|
|
const unsigned num_feat = data.feat_dim;
|
|
// matrix to store the result
|
|
Matrix temp;
|
|
for (int r = iter; r < max_iter; ++r) {
|
|
temp.Init(num_cluster, num_feat + 1, 0.0f);
|
|
#if __cplusplus >= 201103L
|
|
auto lazy_get_centroid = [&]()
|
|
#endif
|
|
{
|
|
// lambda function used to calculate the data if necessary
|
|
// this function may not be called when the result can be directly recovered
|
|
const size_t ndata = data.NumRow();
|
|
for (size_t i = 0; i < ndata; ++i) {
|
|
SparseMat::Vector v = data[i];
|
|
size_t k = GetCluster(model.centroids, v);
|
|
// temp[k] += v
|
|
for (size_t j = 0; j < v.length; ++j) {
|
|
temp[k][v[j].findex] += v[j].fvalue;
|
|
}
|
|
// use last column to record counts
|
|
temp[k][num_feat] += 1.0f;
|
|
}
|
|
};
|
|
// call allreduce
|
|
#if __cplusplus >= 201103L
|
|
rabit::Allreduce<op::Sum>(&temp.data[0], temp.data.size(), lazy_get_centroid);
|
|
#else
|
|
rabit::Allreduce<op::Sum>(&temp.data[0], temp.data.size());
|
|
#endif
|
|
// set number
|
|
for (int k = 0; k < num_cluster; ++k) {
|
|
float cnt = temp[k][num_feat];
|
|
utils::Check(cnt != 0.0f, "get zero sized cluster");
|
|
for (unsigned i = 0; i < num_feat; ++i) {
|
|
model.centroids[k][i] = temp[k][i] / cnt;
|
|
}
|
|
}
|
|
model.Normalize();
|
|
rabit::CheckPoint(&model);
|
|
}
|
|
// output the model file to somewhere
|
|
if (rabit::GetRank() == 0) {
|
|
model.centroids.Print(argv[4]);
|
|
}
|
|
rabit::TrackerPrintf("[%d] Time taken: %f seconds\n", rabit::GetRank(), static_cast<float>(clock() - tStart) / CLOCKS_PER_SEC);
|
|
rabit::Finalize();
|
|
return 0;
|
|
}
|