[TREE] Enable updater registry

This commit is contained in:
tqchen
2016-01-01 03:32:40 -08:00
parent a62a66d545
commit c8ccb61b9e
13 changed files with 172 additions and 189 deletions

20
src/common/io.h Normal file
View File

@@ -0,0 +1,20 @@
/*!
* Copyright 2014 by Contributors
* \file io.h
* \brief general stream interface for serialization, I/O
* \author Tianqi Chen
*/
#ifndef XGBOOST_COMMON_IO_H_
#define XGBOOST_COMMON_IO_H_
#include <dmlc/io.h>
#include "./sync.h"
namespace xgboost {
namespace common {
/*! \brief exposes rabit::utils::MemoryFixSizeBuffer as xgboost::common::MemoryFixSizeBuffer */
using MemoryFixSizeBuffer = rabit::utils::MemoryFixSizeBuffer;
/*! \brief exposes rabit::utils::MemoryBufferStream as xgboost::common::MemoryBufferStream */
using MemoryBufferStream = rabit::utils::MemoryBufferStream;
}  // namespace common
}  // namespace xgboost
#endif // XGBOOST_COMMON_IO_H_

13
src/common/sync.h Normal file
View File

@@ -0,0 +1,13 @@
/*!
* Copyright 2014 by Contributors
* \file sync.h
* \brief the synchronization module of rabit
* redirects to rabit header
* \author Tianqi Chen
*/
#ifndef XGBOOST_COMMON_SYNC_H_
#define XGBOOST_COMMON_SYNC_H_
// this header has no declarations of its own; it only forwards to rabit
#include <rabit.h>
#endif // XGBOOST_COMMON_SYNC_H_

View File

@@ -5,11 +5,12 @@
*/
#include <xgboost/objective.h>
#include <xgboost/metric.h>
#include <xgboost/tree_model.h>
#include <xgboost/tree_updater.h>
namespace dmlc {
// Enable one dmlc registry per factory entry type used by xgboost.
// NOTE(review): DMLC_REGISTRY_ENABLE presumably has to be invoked inside
// namespace dmlc, once per entry type -- confirm against dmlc/registry.h.
DMLC_REGISTRY_ENABLE(::xgboost::ObjFunctionReg);
DMLC_REGISTRY_ENABLE(::xgboost::MetricReg);
// registry queried by TreeUpdater::Create in this file
DMLC_REGISTRY_ENABLE(::xgboost::TreeUpdaterReg);
} // namespace dmlc
namespace xgboost {
@@ -42,8 +43,14 @@ Metric* Metric::Create(const char* name) {
}
}
void test() {
RegTree tree;
// implement factory functions
/*!
 * \brief find the tree updater registered under the given name and run its factory body.
 * \param name key to look up in the TreeUpdaterReg registry.
 * \return whatever the registered factory body returns.
 *  NOTE(review): the result is presumably a heap object owned by the caller -- confirm.
 *
 *  A fatal error that names the requested updater is logged when the
 *  registry has no entry for it.
 */
TreeUpdater* TreeUpdater::Create(const char* name) {
  auto* entry = ::dmlc::Registry< ::xgboost::TreeUpdaterReg>::Get()->Find(name);
  if (!entry) {
    LOG(FATAL) << "Unknown tree updater " << name;
  }
  return entry->body();
}
} // namespace xgboost

View File

@@ -5,9 +5,9 @@
* \author Kailong Chen, Tianqi Chen
*/
#include <xgboost/metric.h>
#include <xgboost/sync.h>
#include <cmath>
#include "../common/math.h"
#include "../common/sync.h"
namespace xgboost {
namespace metric {

View File

@@ -5,8 +5,8 @@
* \author Kailong Chen, Tianqi Chen
*/
#include <xgboost/metric.h>
#include <xgboost/sync.h>
#include <cmath>
#include "../common/sync.h"
#include "../common/math.h"
namespace xgboost {

View File

@@ -5,8 +5,8 @@
* \author Kailong Chen, Tianqi Chen
*/
#include <xgboost/metric.h>
#include <xgboost/sync.h>
#include <cmath>
#include "../common/sync.h"
#include "../common/math.h"
namespace xgboost {

View File

@@ -6,8 +6,9 @@
#include <xgboost/tree_model.h>
#include <sstream>
namespace xgboost {
// register tree parameter
// NOTE(review): DMLC_REGISTER_PARAMETER is provided by dmlc's parameter module;
// it presumably must be invoked once per parameter struct, in a single
// translation unit -- confirm against dmlc/parameter.h.
DMLC_REGISTER_PARAMETER(TreeParam);
// internal function to dump regression tree to text
void DumpRegTree2Text(std::stringstream& fo, // NOLINT(*)

50
src/tree/updater_sync.cc Normal file
View File

@@ -0,0 +1,50 @@
/*!
* Copyright 2014 by Contributors
* \file updater_sync.cc
* \brief synchronize the tree in all distributed nodes
*/
#include <xgboost/tree_updater.h>
#include <vector>
#include <string>
#include <limits>
#include "../common/sync.h"
#include "../common/io.h"
namespace xgboost {
namespace tree {
/*!
* \brief syncher that synchronize the tree in all distributed nodes
* can implement various strategies, so far it is always set to node 0's tree
*/
class TreeSyncher: public TreeUpdater {
public:
void Init(const std::vector<std::pair<std::string, std::string> >& args) override {}
void Update(const std::vector<bst_gpair> &gpair,
DMatrix* dmat,
const std::vector<RegTree*> &trees) override {
if (rabit::GetWorldSize() == 1) return;
std::string s_model;
common::MemoryBufferStream fs(&s_model);
int rank = rabit::GetRank();
if (rank == 0) {
for (size_t i = 0; i < trees.size(); ++i) {
trees[i]->SaveModel(&fs);
}
}
fs.Seek(0);
rabit::Broadcast(&s_model, 0);
for (size_t i = 0; i < trees.size(); ++i) {
trees[i]->LoadModel(&fs);
}
}
};
// make TreeSyncher available from the tree updater registry under the name "sync"
XGBOOST_REGISTER_TREE_UPDATER(TreeSyncher, "sync")
.describe("Syncher that synchronize the tree in all distributed nodes.")
.set_body([]() -> TreeUpdater* {
    // each call of the factory body yields a newly allocated syncher
    return new TreeSyncher();
  });
} // namespace tree
} // namespace xgboost