From 532575b75257a52df7a624a844b37f39551468d0 Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 13 Jan 2015 14:41:37 -0800 Subject: [PATCH] ok --- test/test_local_recover.py | 25 +++++++++++++++++++++++++ wrapper/rabit.py | 2 +- 2 files changed, 26 insertions(+), 1 deletion(-) create mode 100755 test/test_local_recover.py diff --git a/test/test_local_recover.py b/test/test_local_recover.py new file mode 100755 index 000000000..02dcc3e7f --- /dev/null +++ b/test/test_local_recover.py @@ -0,0 +1,25 @@ +#!/usr/bin/python +import rabit +import numpy as np + +rabit.init(with_mock = True) +rank = rabit.get_rank() +n = 10 +nround = 3 +data = np.ones(n) * rank + +version, model, local = rabit.load_checkpoint(True) +if version == 0: + model = np.zeros(n) + local = np.ones(n) +else: + print '[%d] restart from version %d' % (rank, version) + +for i in xrange(version, nround): + res = rabit.allreduce(data + model+local, rabit.SUM) + print '[%d] iter=%d: %s' % (rank, i, str(res)) + model = res + local[:] = i + rabit.checkpoint(model, local) + +rabit.finalize() diff --git a/wrapper/rabit.py b/wrapper/rabit.py index cd380bf2f..a4932a0a4 100644 --- a/wrapper/rabit.py +++ b/wrapper/rabit.py @@ -223,7 +223,7 @@ def load_model__(ptr, length): length: int the length of buffer """ - data = (ctypes.c_char * length).from_address(addressof(ptr.contents)) + data = (ctypes.c_char * length).from_address(ctypes.addressof(ptr.contents)) return pickle.loads(data.raw) def load_checkpoint(with_local = False):