From a1a1a8895e3c9d5301c678c217d3100985c4ee96 Mon Sep 17 00:00:00 2001
From: tqchen
Date: Wed, 3 Dec 2014 18:23:58 -0800
Subject: [PATCH] add kmeans

---
 toolkit/kmeans.cpp     |  9 +++++++--
 toolkit/toolkit_util.h | 21 +++++++++++++++++++++
 2 files changed, 28 insertions(+), 2 deletions(-)

diff --git a/toolkit/kmeans.cpp b/toolkit/kmeans.cpp
index 5811f7fc8..c8884417c 100644
--- a/toolkit/kmeans.cpp
+++ b/toolkit/kmeans.cpp
@@ -81,8 +81,8 @@ inline size_t GetCluster(const Matrix &centroids,
 }
 
 int main(int argc, char *argv[]) {
-  if (argc < 4) {
-    printf("Usage: num_cluster max_iter\n");
+  if (argc < 5) {
+    printf("Usage: num_cluster max_iter \n");
     return 0;
   }
   srand(0);
@@ -131,6 +131,11 @@ int main(int argc, char *argv[]) {
     model.Normalize();
     rabit::CheckPoint(model);
   }
+  // only rank 0 writes the centroids, to the destination named by argv[4]
+  if (rabit::GetRank() == 0) {
+    model.centroids.Print(argv[4]);
+  }
   rabit::Finalize();
   return 0;
 }
+
diff --git a/toolkit/toolkit_util.h b/toolkit/toolkit_util.h
index 71bf888d0..e1ccc7003 100644
--- a/toolkit/toolkit_util.h
+++ b/toolkit/toolkit_util.h
@@ -77,6 +77,27 @@ struct Matrix {
   inline const float *operator[](size_t i) const {
     return &data[0] + i * ncol;
   }
+  // write the matrix as text to fname (the name "stdout" prints to standard output):
+  // first line is "nrow ncol", then one row per line with values separated by spaces
+  inline void Print(const char *fname) const {
+    FILE *fo;
+    if (!strcmp(fname, "stdout")) {
+      fo = stdout;
+    } else {
+      fo = utils::FopenCheck(fname, "w");  // write mode: this stream is only written to
+    }
+    fprintf(fo, "%lu %lu\n", nrow, ncol);
+    for (size_t i = 0; i < data.size(); ++i) {
+      fprintf(fo, "%g", data[i]);
+      if ((i+1) % ncol == 0) {
+        fprintf(fo, "\n");
+      } else {
+        fprintf(fo, " ");
+      }
+    }
+    // close the file, unless we were writing to stdout
+    if (fo != stdout) fclose(fo);
+  }
   // number of data
   size_t nrow, ncol;
   std::vector<float> data;