From dcea64c8386f2ea3b5037aa3787de618daedd5ea Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 1 Dec 2014 21:41:37 -0800 Subject: [PATCH] check in model recover --- test/test_model_recover.cpp | 141 ++++++++++++++++++++++++++++++++++++ 1 file changed, 141 insertions(+) create mode 100644 test/test_model_recover.cpp diff --git a/test/test_model_recover.cpp b/test/test_model_recover.cpp new file mode 100644 index 000000000..a7f4d7677 --- /dev/null +++ b/test/test_model_recover.cpp @@ -0,0 +1,141 @@ +// this is a test case to test whether rabit can recover model when +// facing an exception +#include +#include +#include +#include +#include +#include + +using namespace rabit; + +struct MockException { +}; + +// dummy model +class Model : public rabit::utils::ISerializable { + public: + // iterations + std::vector data; + // load from stream + virtual void Load(rabit::utils::IStream &fi) { + fi.Read(&data); + } + /*! \brief save the model to the stream */ + virtual void Save(rabit::utils::IStream &fo) const { + fo.Write(data); + } + virtual void InitModel(size_t n) { + data.resize(n, 1.0f); + } +}; + +inline void TestMax(test::Mock &mock, Model *model, int ntrial, int iter) { + int rank = rabit::GetRank(); + int nproc = rabit::GetWorldSize(); + const int z = iter + 111; + + std::vector ndata(model->data.size()); + for (size_t i = 0; i < ndata.size(); ++i) { + ndata[i] = (i * (rank+1)) % z + model->data[i]; + } + mock.AllReduce(&ndata[0], ndata.size()); + if (ntrial == iter && rank == 3) { + throw MockException(); + } + for (size_t i = 0; i < ndata.size(); ++i) { + float rmax = (i * 1) % z + model->data[i]; + for (int r = 0; r < nproc; ++r) { + rmax = std::max(rmax, (float)((i * (r+1)) % z) + model->data[i]); + } + utils::Check(rmax == ndata[i], "[%d] TestMax check failure", rank); + } + model->data = ndata; +} + +inline void TestSum(test::Mock &mock, Model *model, int ntrial, int iter) { + int rank = rabit::GetRank(); + int nproc = rabit::GetWorldSize(); + const int z = 131 + iter; + + std::vector ndata(model->data.size()); + for (size_t i = 0; i < ndata.size(); ++i) { + ndata[i] = (i * (rank+1)) % z + model->data[i]; + } + mock.AllReduce(&ndata[0], ndata.size()); + + if (ntrial == iter && rank == 0) { + throw MockException(); + } + + for (size_t i = 0; i < ndata.size(); ++i) { + float rsum = model->data[i] * nproc; + for (int r = 0; r < nproc; ++r) { + rsum += (float)((i * (r+1)) % z); + } + utils::Check(fabsf(rsum - ndata[i]) < 1e-5 , + "[%d] TestSum check failure, local=%g, allreduce=%g", rank, rsum, ndata[i]); + } + model->data = ndata; +} + +inline void TestBcast(test::Mock &mock, size_t n, int root, int ntrial) { + int rank = rabit::GetRank(); + std::string s; s.resize(n); + for (size_t i = 0; i < n; ++i) { + s[i] = char(i % 126 + 1); + } + std::string res; + if (root == rank) { + res = s; + mock.Broadcast(&res, root); + } else { + mock.Broadcast(&res, root); + } + utils::Check(res == s, "[%d] TestBcast fail", rank); +} + +int main(int argc, char *argv[]) { + if (argc < 3) { + printf("Usage: \n"); + return 0; + } + int n = atoi(argv[1]); + rabit::Init(argc, argv); + int rank = rabit::GetRank(); + int nproc = rabit::GetWorldSize(); + std::string name = rabit::GetProcessorName(); + test::Mock mock(rank, argv[2], argv[3]); + Model model; + srand(0); + int ntrial = 0; + while (true) { + try { + int iter = rabit::LoadCheckPoint(&model); + if (iter == 0) { + model.InitModel(n); + } else { + utils::LogPrintf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter); + } + for (int r = iter; r < 3; ++r) { + TestMax(mock, &model, ntrial, r); + utils::LogPrintf("[%d] !!!TestMax pass, iter=%d\n", rank, r); + int step = std::max(nproc / 3, 1); + for (int i = 0; i < nproc; i += step) { + TestBcast(mock, n, i, ntrial); + } + utils::LogPrintf("[%d] !!!TestBcast pass, iter=%d\n", rank, r); + TestSum(mock, &model, ntrial, r); + utils::LogPrintf("[%d] !!!TestSum pass, iter=%d\n", rank, r); + rabit::CheckPoint(model); + utils::LogPrintf("[%d] !!!CheckPont pass, iter=%d\n", rank, r); + } + break; + } catch (MockException &e) { + //rabit::engine::GetEngine()->InitAfterException(); + ++ntrial; + } + } + rabit::Finalize(); + return 0; +}