Fix compiler warnings. (#8022)
- Remove/fix unused parameters - Remove deprecated code in rabit. - Update dmlc-core.
This commit is contained in:
@@ -144,74 +144,13 @@ class AllreduceBase : public IEngine {
|
||||
"Broadcast failed");
|
||||
}
|
||||
/*!
|
||||
* \brief load latest check point
|
||||
* \param global_model pointer to the globally shared model/state
|
||||
* when calling this function, the caller need to guarantees that global_model
|
||||
* is the same in all nodes
|
||||
* \param local_model pointer to local model, that is specific to current node/rank
|
||||
* this can be NULL when no local model is needed
|
||||
*
|
||||
* \return the version number of check point loaded
|
||||
* if returned version == 0, this means no model has been CheckPointed
|
||||
* the p_model is not touched, user should do necessary initialization by themselves
|
||||
*
|
||||
* Common usage example:
|
||||
* int iter = rabit::LoadCheckPoint(&model);
|
||||
* if (iter == 0) model.InitParameters();
|
||||
* for (i = iter; i < max_iter; ++i) {
|
||||
* do many things, include allreduce
|
||||
* rabit::CheckPoint(model);
|
||||
* }
|
||||
*
|
||||
* \brief deprecated
|
||||
* \sa CheckPoint, VersionNumber
|
||||
*/
|
||||
int LoadCheckPoint(Serializable *global_model,
|
||||
Serializable *local_model = nullptr) override {
|
||||
return 0;
|
||||
}
|
||||
/*!
|
||||
* \brief checkpoint the model, meaning we finished a stage of execution
|
||||
* every time we call check point, there is a version number which will increase by one
|
||||
*
|
||||
* \param global_model pointer to the globally shared model/state
|
||||
* when calling this function, the caller need to guarantees that global_model
|
||||
* is the same in all nodes
|
||||
* \param local_model pointer to local model, that is specific to current node/rank
|
||||
* this can be NULL when no local state is needed
|
||||
*
|
||||
* NOTE: local_model requires explicit replication of the model for fault-tolerance, which will
|
||||
* bring replication cost in CheckPoint function. global_model do not need explicit replication.
|
||||
* So only CheckPoint with global_model if possible
|
||||
*
|
||||
* \sa LoadCheckPoint, VersionNumber
|
||||
*/
|
||||
void CheckPoint(const Serializable *global_model,
|
||||
const Serializable *local_model = nullptr) override {
|
||||
version_number += 1;
|
||||
}
|
||||
/*!
|
||||
* \brief This function can be used to replace CheckPoint for global_model only,
|
||||
* when certain condition is met(see detailed explanation).
|
||||
*
|
||||
* This is a "lazy" checkpoint such that only the pointer to global_model is
|
||||
* remembered and no memory copy is taken. To use this function, the user MUST ensure that:
|
||||
* The global_model must remain unchanged until the last call of Allreduce/Broadcast in current version finishes.
|
||||
* In another words, global_model model can be changed only between last call of
|
||||
* Allreduce/Broadcast and LazyCheckPoint in current version
|
||||
*
|
||||
* For example, suppose the calling sequence is:
|
||||
* LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint
|
||||
*
|
||||
* If user can only changes global_model in code3, then LazyCheckPoint can be used to
|
||||
* improve efficiency of the program.
|
||||
* \param global_model pointer to the globally shared model/state
|
||||
* when calling this function, the caller need to guarantees that global_model
|
||||
* is the same in all nodes
|
||||
* \sa LoadCheckPoint, CheckPoint, VersionNumber
|
||||
*/
|
||||
void LazyCheckPoint(const Serializable *global_model) override {
|
||||
version_number += 1;
|
||||
}
|
||||
int LoadCheckPoint() override { return 0; }
|
||||
|
||||
// deprecated, increase internal version number
|
||||
void CheckPoint() override { version_number += 1; }
|
||||
/*!
|
||||
* \return version number of current stored model,
|
||||
* which means how many calls to CheckPoint we made so far
|
||||
|
||||
@@ -65,31 +65,21 @@ class AllreduceMock : public AllreduceBase {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "Broadcast");
|
||||
AllreduceBase::Broadcast(sendrecvbuf_, total_size, root);
|
||||
}
|
||||
int LoadCheckPoint(Serializable *global_model,
|
||||
Serializable *local_model) override {
|
||||
int LoadCheckPoint() override {
|
||||
tsum_allreduce_ = 0.0;
|
||||
tsum_allgather_ = 0.0;
|
||||
time_checkpoint_ = dmlc::GetTime();
|
||||
if (force_local_ == 0) {
|
||||
return AllreduceBase::LoadCheckPoint(global_model, local_model);
|
||||
return AllreduceBase::LoadCheckPoint();
|
||||
} else {
|
||||
DummySerializer dum;
|
||||
ComboSerializer com(global_model, local_model);
|
||||
return AllreduceBase::LoadCheckPoint(&dum, &com);
|
||||
return AllreduceBase::LoadCheckPoint();
|
||||
}
|
||||
}
|
||||
void CheckPoint(const Serializable *global_model,
|
||||
const Serializable *local_model) override {
|
||||
void CheckPoint() override {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "CheckPoint");
|
||||
double tstart = dmlc::GetTime();
|
||||
double tbet_chkpt = tstart - time_checkpoint_;
|
||||
if (force_local_ == 0) {
|
||||
AllreduceBase::CheckPoint(global_model, local_model);
|
||||
} else {
|
||||
DummySerializer dum;
|
||||
ComboSerializer com(global_model, local_model);
|
||||
AllreduceBase::CheckPoint(&dum, &com);
|
||||
}
|
||||
AllreduceBase::CheckPoint();
|
||||
time_checkpoint_ = dmlc::GetTime();
|
||||
double tcost = dmlc::GetTime() - tstart;
|
||||
if (report_stats_ != 0 && rank == 0) {
|
||||
@@ -105,11 +95,6 @@ class AllreduceMock : public AllreduceBase {
|
||||
tsum_allgather_ = 0.0;
|
||||
}
|
||||
|
||||
void LazyCheckPoint(const Serializable *global_model) override {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "LazyCheckPoint");
|
||||
AllreduceBase::LazyCheckPoint(global_model);
|
||||
}
|
||||
|
||||
protected:
|
||||
// force checkpoint to local
|
||||
int force_local_;
|
||||
@@ -122,30 +107,6 @@ class AllreduceMock : public AllreduceBase {
|
||||
double time_checkpoint_;
|
||||
|
||||
private:
|
||||
struct DummySerializer : public Serializable {
|
||||
void Load(Stream *fi) override {}
|
||||
void Save(Stream *fo) const override {}
|
||||
};
|
||||
struct ComboSerializer : public Serializable {
|
||||
Serializable *lhs;
|
||||
Serializable *rhs;
|
||||
const Serializable *c_lhs;
|
||||
const Serializable *c_rhs;
|
||||
ComboSerializer(Serializable *lhs, Serializable *rhs)
|
||||
: lhs(lhs), rhs(rhs), c_lhs(lhs), c_rhs(rhs) {
|
||||
}
|
||||
ComboSerializer(const Serializable *lhs, const Serializable *rhs)
|
||||
: lhs(nullptr), rhs(nullptr), c_lhs(lhs), c_rhs(rhs) {
|
||||
}
|
||||
void Load(Stream *fi) override {
|
||||
if (lhs != nullptr) lhs->Load(fi);
|
||||
if (rhs != nullptr) rhs->Load(fi);
|
||||
}
|
||||
void Save(Stream *fo) const override {
|
||||
if (c_lhs != nullptr) c_lhs->Save(fo);
|
||||
if (c_rhs != nullptr) c_rhs->Save(fo);
|
||||
}
|
||||
};
|
||||
// key to identify the mock stage
|
||||
struct MockKey {
|
||||
int rank;
|
||||
|
||||
@@ -100,8 +100,7 @@ void Allreduce_(void *sendrecvbuf, // NOLINT
|
||||
mpi::OpType ,
|
||||
IEngine::PreprocFunction prepare_fun,
|
||||
void *prepare_arg) {
|
||||
GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count, red, prepare_fun,
|
||||
prepare_arg);
|
||||
GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count, red, prepare_fun, prepare_arg);
|
||||
}
|
||||
} // namespace engine
|
||||
} // namespace rabit
|
||||
|
||||
@@ -120,6 +120,7 @@ void Allreduce(void *sendrecvbuf,
|
||||
default: utils::Error("unknown enum_op");
|
||||
}
|
||||
}
|
||||
|
||||
void Allgather(void *sendrecvbuf_,
|
||||
size_t total_size,
|
||||
size_t beginIndex,
|
||||
@@ -298,46 +299,6 @@ RABIT_DLL int RabitAllreduce(void *sendrecvbuf, size_t count, int enum_dtype,
|
||||
API_END()
|
||||
}
|
||||
|
||||
RABIT_DLL int RabitLoadCheckPoint(char **out_global_model,
|
||||
rbt_ulong *out_global_len,
|
||||
char **out_local_model,
|
||||
rbt_ulong *out_local_len) {
|
||||
// no-op as XGBoost 1.3
|
||||
using rabit::BeginPtr;
|
||||
using namespace rabit::c_api; // NOLINT(*)
|
||||
static std::string global_buffer;
|
||||
static std::string local_buffer;
|
||||
|
||||
ReadWrapper sg(&global_buffer);
|
||||
ReadWrapper sl(&local_buffer);
|
||||
int version;
|
||||
|
||||
if (out_local_model == nullptr) {
|
||||
version = rabit::LoadCheckPoint(&sg, nullptr);
|
||||
*out_global_model = BeginPtr(global_buffer);
|
||||
*out_global_len = static_cast<rbt_ulong>(global_buffer.length());
|
||||
} else {
|
||||
version = rabit::LoadCheckPoint(&sg, &sl);
|
||||
*out_global_model = BeginPtr(global_buffer);
|
||||
*out_global_len = static_cast<rbt_ulong>(global_buffer.length());
|
||||
*out_local_model = BeginPtr(local_buffer);
|
||||
*out_local_len = static_cast<rbt_ulong>(local_buffer.length());
|
||||
}
|
||||
return version;
|
||||
}
|
||||
|
||||
RABIT_DLL void RabitCheckPoint(const char *global_model, rbt_ulong global_len,
|
||||
const char *local_model, rbt_ulong local_len) {
|
||||
using namespace rabit::c_api; // NOLINT(*)
|
||||
WriteWrapper sg(global_model, global_len);
|
||||
WriteWrapper sl(local_model, local_len);
|
||||
if (local_model == nullptr) {
|
||||
rabit::CheckPoint(&sg, nullptr);
|
||||
} else {
|
||||
rabit::CheckPoint(&sg, &sl);
|
||||
}
|
||||
}
|
||||
|
||||
RABIT_DLL int RabitVersionNumber() {
|
||||
return rabit::VersionNumber();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user