add kmeans
This commit is contained in:
parent
69af79d45d
commit
a1a1a8895e
@ -81,8 +81,8 @@ inline size_t GetCluster(const Matrix ¢roids,
|
||||
}
|
||||
|
||||
int main(int argc, char *argv[]) {
|
||||
if (argc < 4) {
|
||||
printf("Usage: <data_dir> num_cluster max_iter\n");
|
||||
if (argc < 5) {
|
||||
printf("Usage: <data_dir> num_cluster max_iter <out_model>\n");
|
||||
return 0;
|
||||
}
|
||||
srand(0);
|
||||
@ -131,6 +131,11 @@ int main(int argc, char *argv[]) {
|
||||
model.Normalize();
|
||||
rabit::CheckPoint(model);
|
||||
}
|
||||
// output the model file to somewhere
|
||||
if (rabit::GetRank() == 0) {
|
||||
model.centroids.Print(argv[4]);
|
||||
}
|
||||
rabit::Finalize();
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
@ -77,6 +77,25 @@ struct Matrix {
|
||||
inline const float *operator[](size_t i) const {
|
||||
return &data[0] + i * ncol;
|
||||
}
|
||||
inline void Print(const char *fname) {
|
||||
FILE *fo;
|
||||
if (!strcmp(fname, "stdout")) {
|
||||
fo = stdout;
|
||||
} else {
|
||||
fo = utils::FopenCheck(fname, "r");
|
||||
}
|
||||
fprintf(fo, "%lu %lu\n", nrow, ncol);
|
||||
for (size_t i = 0; i < data.size(); ++i) {
|
||||
fprintf(fo, "%g", data[i]);
|
||||
if ((i+1) % ncol == 0) {
|
||||
fprintf(fo, "\n");
|
||||
} else {
|
||||
fprintf(fo, " ");
|
||||
}
|
||||
}
|
||||
// close the filed
|
||||
if (fo != stdout) fclose(fo);
|
||||
}
|
||||
// number of data
|
||||
size_t nrow, ncol;
|
||||
std::vector<float> data;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user