[TREE] Enable updater registry
This commit is contained in:
20
src/common/io.h
Normal file
20
src/common/io.h
Normal file
@@ -0,0 +1,20 @@
|
||||
/*!
|
||||
* Copyright 2014 by Contributors
|
||||
* \file io.h
|
||||
* \brief general stream interface for serialization, I/O
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
|
||||
#ifndef XGBOOST_COMMON_IO_H_
|
||||
#define XGBOOST_COMMON_IO_H_
|
||||
|
||||
#include <dmlc/io.h>
|
||||
#include "./sync.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
typedef rabit::utils::MemoryFixSizeBuffer MemoryFixSizeBuffer;
|
||||
typedef rabit::utils::MemoryBufferStream MemoryBufferStream;
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_COMMON_IO_H_
|
||||
13
src/common/sync.h
Normal file
13
src/common/sync.h
Normal file
@@ -0,0 +1,13 @@
|
||||
/*!
|
||||
* Copyright 2014 by Contributors
|
||||
* \file sync.h
|
||||
* \brief the synchronization module of rabit
|
||||
* redirects to rabit header
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#ifndef XGBOOST_COMMON_SYNC_H_
|
||||
#define XGBOOST_COMMON_SYNC_H_
|
||||
|
||||
#include <rabit.h>
|
||||
|
||||
#endif // XGBOOST_SYNC_H_
|
||||
@@ -5,11 +5,12 @@
|
||||
*/
|
||||
#include <xgboost/objective.h>
|
||||
#include <xgboost/metric.h>
|
||||
#include <xgboost/tree_model.h>
|
||||
#include <xgboost/tree_updater.h>
|
||||
|
||||
namespace dmlc {
|
||||
DMLC_REGISTRY_ENABLE(::xgboost::ObjFunctionReg);
|
||||
DMLC_REGISTRY_ENABLE(::xgboost::MetricReg);
|
||||
DMLC_REGISTRY_ENABLE(::xgboost::TreeUpdaterReg);
|
||||
} // namespace dmlc
|
||||
|
||||
namespace xgboost {
|
||||
@@ -42,8 +43,14 @@ Metric* Metric::Create(const char* name) {
|
||||
}
|
||||
}
|
||||
|
||||
void test() {
|
||||
RegTree tree;
|
||||
// implement factory functions
|
||||
TreeUpdater* TreeUpdater::Create(const char* name) {
|
||||
auto *e = ::dmlc::Registry< ::xgboost::TreeUpdaterReg>::Get()->Find(name);
|
||||
if (e == nullptr) {
|
||||
LOG(FATAL) << "Unknown tree updater " << name;
|
||||
}
|
||||
return (e->body)();
|
||||
}
|
||||
|
||||
} // namespace xgboost
|
||||
|
||||
|
||||
@@ -5,9 +5,9 @@
|
||||
* \author Kailong Chen, Tianqi Chen
|
||||
*/
|
||||
#include <xgboost/metric.h>
|
||||
#include <xgboost/sync.h>
|
||||
#include <cmath>
|
||||
#include "../common/math.h"
|
||||
#include "../common/sync.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace metric {
|
||||
|
||||
@@ -5,8 +5,8 @@
|
||||
* \author Kailong Chen, Tianqi Chen
|
||||
*/
|
||||
#include <xgboost/metric.h>
|
||||
#include <xgboost/sync.h>
|
||||
#include <cmath>
|
||||
#include "../common/sync.h"
|
||||
#include "../common/math.h"
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
@@ -5,8 +5,8 @@
|
||||
* \author Kailong Chen, Tianqi Chen
|
||||
*/
|
||||
#include <xgboost/metric.h>
|
||||
#include <xgboost/sync.h>
|
||||
#include <cmath>
|
||||
#include "../common/sync.h"
|
||||
#include "../common/math.h"
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
@@ -6,8 +6,9 @@
|
||||
#include <xgboost/tree_model.h>
|
||||
#include <sstream>
|
||||
|
||||
|
||||
namespace xgboost {
|
||||
// register tree parameter
|
||||
DMLC_REGISTER_PARAMETER(TreeParam);
|
||||
|
||||
// internal function to dump regression tree to text
|
||||
void DumpRegTree2Text(std::stringstream& fo, // NOLINT(*)
|
||||
|
||||
50
src/tree/updater_sync.cc
Normal file
50
src/tree/updater_sync.cc
Normal file
@@ -0,0 +1,50 @@
|
||||
/*!
|
||||
* Copyright 2014 by Contributors
|
||||
* \file updater_sync.cc
|
||||
* \brief synchronize the tree in all distributed nodes
|
||||
*/
|
||||
#include <xgboost/tree_updater.h>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <limits>
|
||||
#include "../common/sync.h"
|
||||
#include "../common/io.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
/*!
|
||||
* \brief syncher that synchronize the tree in all distributed nodes
|
||||
* can implement various strategies, so far it is always set to node 0's tree
|
||||
*/
|
||||
class TreeSyncher: public TreeUpdater {
|
||||
public:
|
||||
void Init(const std::vector<std::pair<std::string, std::string> >& args) override {}
|
||||
|
||||
void Update(const std::vector<bst_gpair> &gpair,
|
||||
DMatrix* dmat,
|
||||
const std::vector<RegTree*> &trees) override {
|
||||
if (rabit::GetWorldSize() == 1) return;
|
||||
std::string s_model;
|
||||
common::MemoryBufferStream fs(&s_model);
|
||||
int rank = rabit::GetRank();
|
||||
if (rank == 0) {
|
||||
for (size_t i = 0; i < trees.size(); ++i) {
|
||||
trees[i]->SaveModel(&fs);
|
||||
}
|
||||
}
|
||||
fs.Seek(0);
|
||||
rabit::Broadcast(&s_model, 0);
|
||||
for (size_t i = 0; i < trees.size(); ++i) {
|
||||
trees[i]->LoadModel(&fs);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
XGBOOST_REGISTER_TREE_UPDATER(TreeSyncher, "sync")
|
||||
.describe("Syncher that synchronize the tree in all distributed nodes.")
|
||||
.set_body([]() {
|
||||
return new TreeSyncher();
|
||||
});
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user