add kmeans

This commit is contained in:
tqchen 2014-12-03 18:23:58 -08:00
parent 69af79d45d
commit a1a1a8895e
2 changed files with 26 additions and 2 deletions

View File

@ -81,8 +81,8 @@ inline size_t GetCluster(const Matrix &centroids,
}
int main(int argc, char *argv[]) {
if (argc < 4) {
printf("Usage: <data_dir> num_cluster max_iter\n");
if (argc < 5) {
printf("Usage: <data_dir> num_cluster max_iter <out_model>\n");
return 0;
}
srand(0);
@ -131,6 +131,11 @@ int main(int argc, char *argv[]) {
model.Normalize();
rabit::CheckPoint(model);
}
// output the model file to somewhere
if (rabit::GetRank() == 0) {
model.centroids.Print(argv[4]);
}
rabit::Finalize();
return 0;
}

View File

@ -77,6 +77,25 @@ struct Matrix {
inline const float *operator[](size_t i) const {
return &data[0] + i * ncol;
}
inline void Print(const char *fname) {
FILE *fo;
if (!strcmp(fname, "stdout")) {
fo = stdout;
} else {
fo = utils::FopenCheck(fname, "r");
}
fprintf(fo, "%lu %lu\n", nrow, ncol);
for (size_t i = 0; i < data.size(); ++i) {
fprintf(fo, "%g", data[i]);
if ((i+1) % ncol == 0) {
fprintf(fo, "\n");
} else {
fprintf(fo, " ");
}
}
// close the filed
if (fo != stdout) fclose(fo);
}
// number of data
size_t nrow, ncol;
std::vector<float> data;