// this is a test case to test whether rabit can recover model when // facing an exception #include #include #include #include #include #include #include #include using namespace rabit; class Model : public rabit::utils::ISerializable { public: std::vector data; // load from stream virtual void Load(rabit::utils::IStream &fi) { fi.Read(&data); } /*! \brief save the model to the stream */ virtual void Save(rabit::utils::IStream &fo) const { fo.Write(data); } virtual void InitModel(int k, int d) { data.resize(k * d + k, 0.0f); } }; inline void KMeans(int ntrial, int iter, int k, int d, std::vector& data, Model *model) { int rank = rabit::GetRank(); int nproc = rabit::GetWorldSize(); /* const int z = iter + 111; std::vector ndata(model->data.size()); for (size_t i = 0; i < ndata.size(); ++i) { ndata[i] = (i * (rank+1)) % z + model->data[i]; } rabit::Allreduce(&ndata[0], ndata.size()); if (ntrial == iter && rank == 3) { //throw MockException(); } for (size_t i = 0; i < ndata.size(); ++i) { float rmax = (i * 1) % z + model->data[i]; for (int r = 0; r < nproc; ++r) { rmax = std::max(rmax, (float)((i * (r+1)) % z) + model->data[i]); } utils::Check(rmax == ndata[i], "[%d] TestMax check failure\n", rank); } model->data = ndata; */ } inline void ReadData(char* data_dir, int d, std::vector* data) { int rank = rabit::GetRank(); std::stringstream ss; ss << data_dir << rank; const char* file = ss.str().c_str(); std::ifstream ifs(file); utils::Check(ifs.good(), "[%d] File %s does not exist\n", rank, file); float v = 0.0f; while(!ifs.eof()) { ifs >> v; data->push_back(v); } utils::Check(data->size() % d == 0, "[%d] Invalid data size. %d instead of %d\n", rank, data->size(), d); } inline void InitCentroids(int k, int d, std::vector& data, Model* model) { int rank = rabit::GetRank(); int nproc = rabit::GetWorldSize(); std::vector candidate_centroids(model->data.size() - k); int elements = data.size() / d; for (size_t i = 0; i < k; ++i) { int index = rand() % elements; int start = index * d; int end = start + d; int cstart = i * d; //utils::LogPrintf("[%d] index=%d,start=%d\n", rank, index, start); for (size_t j = start, l = cstart; j < end; ++j, ++l) { candidate_centroids[l] = data[j]; } } for (size_t i = 0; i < k; ++i) { int proc = rand() % nproc; //utils::LogPrintf("[%d] proc=%d\n", rank, proc); std::string tmp_str; int start = i * d; if (proc == rank) { std::ostringstream tmp; for (size_t j = start, l = 0; l < d ; ++j, ++l) { tmp << candidate_centroids[j]; if (l != d-1) tmp << " "; } tmp_str = tmp.str(); //utils::LogPrintf("[%d] centroid %s\n", rank, tmp_str.c_str()); rabit::Bcast(&tmp_str, proc); } else { rabit::Bcast(&tmp_str, proc); } std::stringstream tmp; tmp.str(tmp_str); float val = 0.0f; int j = start; while(tmp >> val) { model->data[j++] = val; //utils::LogPrintf("[%d] model[%d]=%.5f\n", rank, j-1, model->data[j-1]); } //count model->data[j] = 0; } } int main(int argc, char *argv[]) { if (argc < 4) { printf("Usage: \n"); return 0; } int k = atoi(argv[1]); int d = atoi(argv[2]); int max_itr = atoi(argv[3]); rabit::Init(argc, argv); int rank = rabit::GetRank(); int nproc = rabit::GetWorldSize(); std::string name = rabit::GetProcessorName(); srand(0); int ntrial = 0; Model model; std::vector data; int iter = rabit::LoadCheckPoint(&model); if (iter == 0) { ReadData(argv[4], d, &data); model.InitModel(k, d); InitCentroids(k, d, data, &model); } else { utils::LogPrintf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter); } for (int r = iter; r < max_itr; ++r) { KMeans(ntrial, r, k, d, data, &model); } rabit::Finalize(); return 0; }