Change the default behavior to behave normally
This commit is contained in:
@@ -38,7 +38,7 @@ class AllreduceBase : public IEngine {
|
||||
AllreduceBase(void);
|
||||
virtual ~AllreduceBase(void) {}
|
||||
// initialize the manager
|
||||
void Init(void);
|
||||
virtual void Init(void);
|
||||
// shutdown the engine
|
||||
virtual void Shutdown(void);
|
||||
/*!
|
||||
|
||||
@@ -20,10 +20,16 @@ namespace rabit {
|
||||
namespace engine {
|
||||
/*! \brief constructor: sets the default replica counts and resets the checkpoint bookkeeping fields */
AllreduceRobust::AllreduceRobust(void) {
|
||||
// 0 means no local replica was configured; LoadCheckPoint/CheckPoint
// replace it with default_local_replica when a local_model is passed
num_local_replica = 0;
|
||||
// default number of replica for the global state/model
num_global_replica = 5;
|
||||
// fallback used for num_local_replica when it is still 0
default_local_replica = 2;
|
||||
seq_counter = 0;
|
||||
local_chkpt_version = 0;
|
||||
// initial value only; Init() recomputes this from world_size
result_buffer_round = 1;
|
||||
}
|
||||
/*!
 * \brief initialize the engine: run the base-class initialization, then
 *  derive result_buffer_round from world_size and num_global_replica
 */
void AllreduceRobust::Init(void) {
|
||||
// presumably this establishes world_size before it is read below -- TODO confirm
AllreduceBase::Init();
|
||||
// integer division, clamped so the round count is never below 1
// NOTE(review): num_global_replica comes from atoi() in SetParam; a value of 0
// would make this a division by zero -- confirm whether validation is needed
result_buffer_round = std::max(world_size / num_global_replica, 1);
|
||||
}
|
||||
/*! \brief shutdown the engine */
|
||||
void AllreduceRobust::Shutdown(void) {
|
||||
// need to sync the execution before we shut down, so do a pseudo checkpoint
|
||||
@@ -44,10 +50,7 @@ void AllreduceRobust::Shutdown(void) {
|
||||
*/
|
||||
void AllreduceRobust::SetParam(const char *name, const char *val) {
|
||||
AllreduceBase::SetParam(name, val);
|
||||
if (!strcmp(name, "rabit_buffer_round")) result_buffer_round = atoi(val);
|
||||
if (!strcmp(name, "rabit_global_replica")) {
|
||||
result_buffer_round = std::max(world_size / atoi(val), 1);
|
||||
}
|
||||
if (!strcmp(name, "rabit_global_replica")) num_global_replica = atoi(val);
|
||||
if (!strcmp(name, "rabit_local_replica")) {
|
||||
num_local_replica = atoi(val);
|
||||
}
|
||||
@@ -151,9 +154,12 @@ int AllreduceRobust::LoadCheckPoint(ISerializable *global_model,
|
||||
ISerializable *local_model) {
|
||||
// skip action in single node
|
||||
if (world_size == 1) return 0;
|
||||
if (local_model != NULL && num_local_replica == 0) {
|
||||
num_local_replica = default_local_replica;
|
||||
}
|
||||
if (num_local_replica == 0) {
|
||||
utils::Check(local_model == NULL,
|
||||
"need to set num_local_replica larger than 1 to checkpoint local_model");
|
||||
"need to set rabit_local_replica larger than 1 to checkpoint local_model");
|
||||
}
|
||||
// check whether we succeeded
|
||||
if (RecoverExec(NULL, 0, ActionSummary::kLoadCheck, ActionSummary::kSpecialOp)) {
|
||||
@@ -214,9 +220,12 @@ void AllreduceRobust::CheckPoint(const ISerializable *global_model,
|
||||
if (world_size == 1) {
|
||||
version_number += 1; return;
|
||||
}
|
||||
if (local_model != NULL && num_local_replica == 0) {
|
||||
num_local_replica = default_local_replica;
|
||||
}
|
||||
if (num_local_replica == 0) {
|
||||
utils::Check(local_model == NULL,
|
||||
"need to set num_local_replica larger than 1 to checkpoint local_model");
|
||||
"need to set rabit_local_replica larger than 1 to checkpoint local_model");
|
||||
}
|
||||
if (num_local_replica != 0) {
|
||||
while (true) {
|
||||
|
||||
@@ -23,6 +23,8 @@ class AllreduceRobust : public AllreduceBase {
|
||||
public:
|
||||
AllreduceRobust(void);
|
||||
virtual ~AllreduceRobust(void) {}
|
||||
// initialize the manager
|
||||
virtual void Init(void);
|
||||
/*! \brief shutdown the engine */
|
||||
virtual void Shutdown(void);
|
||||
/*!
|
||||
@@ -468,6 +470,10 @@ o * the input state must exactly one saved state(local state of current node)
|
||||
std::string global_checkpoint;
|
||||
// number of replica for local state/model
|
||||
int num_local_replica;
|
||||
// number of default local replica
|
||||
int default_local_replica;
|
||||
// number of replica for global state/model
|
||||
int num_global_replica;
|
||||
// --- recovery data structure for local checkpoint
|
||||
// there are two versions of the data structure,
|
||||
// at one time one version is valid and another is used as temp memory
|
||||
|
||||
Reference in New Issue
Block a user