diff --git a/include/xgboost/tree_model.h b/include/xgboost/tree_model.h index 419a1de47..ed22d7caf 100644 --- a/include/xgboost/tree_model.h +++ b/include/xgboost/tree_model.h @@ -332,8 +332,8 @@ class TreeModel { CHECK_EQ(param.num_nodes, static_cast(stats.size())); fo->Write(¶m, sizeof(TreeParam)); CHECK_NE(param.num_nodes, 0); - fo->Write(BeginPtr(nodes), sizeof(Node) * nodes.size()); - fo->Write(BeginPtr(stats), sizeof(NodeStat) * nodes.size()); + fo->Write(dmlc::BeginPtr(nodes), sizeof(Node) * nodes.size()); + fo->Write(dmlc::BeginPtr(stats), sizeof(NodeStat) * nodes.size()); if (param.size_leaf_vector != 0) fo->Write(leaf_vector); } /*! diff --git a/include/xgboost/tree_updater.h b/include/xgboost/tree_updater.h new file mode 100644 index 000000000..7efc20289 --- /dev/null +++ b/include/xgboost/tree_updater.h @@ -0,0 +1,83 @@ +/*! + * Copyright 2014 by Contributors + * \file tree_updater.h + * \brief General primitive for tree learning, + * Updating a collection of trees given the information. + * \author Tianqi Chen + */ +#ifndef XGBOOST_TREE_UPDATER_H_ +#define XGBOOST_TREE_UPDATER_H_ + +#include +#include +#include +#include "./base.h" +#include "./data.h" +#include "./tree_model.h" + +namespace xgboost { +/*! + * \brief interface of tree update module, that performs update of a tree. + */ +class TreeUpdater { + public: + /*! \brief virtual destructor */ + virtual ~TreeUpdater() {} + /*! + * \brief Initialize the updater with given arguments. + * \param args arguments to the tree updater. + */ + virtual void Init(const std::vector >& args) = 0; + /*! + * \brief perform update to the tree models + * \param gpair the gradient pair statistics of the data + * \param data The data matrix passed to the updater. 
+ * \param trees references the trees to be updated, updater will change the content of trees + * note: all the trees in the vector are updated, with the same statistics, + * but maybe different random seeds, usually one tree is passed in at a time, + * there can be multiple trees when we train random forest style model + */ + virtual void Update(const std::vector& gpair, + DMatrix* data, + const std::vector& trees) = 0; + /*! + * \brief optional hook that exists purely as a performance optimization + * it asks the updater for the leaf position of each instance from the previously performed update, + * if the updater cached it; if it is not available, nullptr is returned + * \return array of leaf position of each instance in the last updated tree, or nullptr (the default implementation always returns nullptr) + */ + virtual const int* GetLeafPosition() const { + return nullptr; + } + /*! + * \brief Create a tree updater given its registered name + * \param name Name of the tree updater. + */ + static TreeUpdater* Create(const char* name); +}; + +/*! + * \brief Registry entry for tree updater. + */ +struct TreeUpdaterReg + : public dmlc::FunctionRegEntryBase > { +}; + +/*! + * \brief Macro to register tree updater; Name is a string literal used unchanged as the registry key. + * + * \code + * // example of registering the column based tree maker + * XGBOOST_REGISTER_TREE_UPDATER(ColMaker, "colmaker") + * .describe("Column based tree maker.") + * .set_body([]() { + * return new ColMaker(); + * }); + * \endcode + */ +#define XGBOOST_REGISTER_TREE_UPDATER(UniqueId, Name) \ + static ::xgboost::TreeUpdaterReg & __make_ ## TreeUpdaterReg ## _ ## UniqueId ## __ = \ + ::dmlc::Registry< ::xgboost::TreeUpdaterReg>::Get()->__REGISTER__(Name) +} // namespace xgboost +#endif // XGBOOST_TREE_UPDATER_H_ diff --git a/old_src/tree/updater.h b/old_src/tree/updater.h deleted file mode 100644 index ff4da5e98..000000000 --- a/old_src/tree/updater.h +++ /dev/null @@ -1,63 +0,0 @@ -/*! 
- * Copyright 2014 by Contributors - * \file updater.h - * \brief interface to update the tree - * \author Tianqi Chen - */ -#ifndef XGBOOST_TREE_UPDATER_H_ -#define XGBOOST_TREE_UPDATER_H_ - -#include - -#include "../data.h" -#include "./model.h" - -namespace xgboost { -namespace tree { -/*! - * \brief interface of tree update module, that performs update of a tree - */ -class IUpdater { - public: - /*! - * \brief set parameters from outside - * \param name name of the parameter - * \param val value of the parameter - */ - virtual void SetParam(const char *name, const char *val) = 0; - /*! - * \brief perform update to the tree models - * \param gpair the gradient pair statistics of the data - * \param p_fmat feature matrix that provide access to features - * \param info extra side information that may be need, such as root index - * \param trees references the trees to be updated, updater will change the content of trees - * note: all the trees in the vector are updated, with the same statistics, - * but maybe different random seeds, usually one tree is passed in at a time, - * there can be multiple trees when we train random forest style model - */ - virtual void Update(const std::vector &gpair, - IFMatrix *p_fmat, - const BoosterInfo &info, - const std::vector &trees) = 0; - - /*! - * \brief this is simply a function for optimizing performance - * this function asks the updater to return the leaf position of each instance in the p_fmat, - * if it is cached in the updater, if it is not available, return NULL - * \return array of leaf position of each instance in the last updated tree - */ - virtual const int* GetLeafPosition(void) const { - return NULL; - } - // destructor - virtual ~IUpdater(void) {} -}; -/*! 
- * \brief create an updater based on name - * \param name name of updater - * \return return the updater instance - */ -IUpdater* CreateUpdater(const char *name); -} // namespace tree -} // namespace xgboost -#endif // XGBOOST_TREE_UPDATER_H_ diff --git a/old_src/tree/updater_sync-inl.hpp b/old_src/tree/updater_sync-inl.hpp deleted file mode 100644 index e76d1f76d..000000000 --- a/old_src/tree/updater_sync-inl.hpp +++ /dev/null @@ -1,56 +0,0 @@ -/*! - * Copyright 2014 by Contributors - * \file updater_sync-inl.hpp - * \brief synchronize the tree in all distributed nodes - * \author Tianqi Chen - */ -#ifndef XGBOOST_TREE_UPDATER_SYNC_INL_HPP_ -#define XGBOOST_TREE_UPDATER_SYNC_INL_HPP_ - -#include -#include -#include -#include "../sync/sync.h" -#include "./updater.h" - -namespace xgboost { -namespace tree { -/*! - * \brief syncher that synchronize the tree in all distributed nodes - * can implement various strategies, so far it is always set to node 0's tree - */ -class TreeSyncher: public IUpdater { - public: - virtual ~TreeSyncher(void) {} - virtual void SetParam(const char *name, const char *val) { - } - // update the tree, do pruning - virtual void Update(const std::vector &gpair, - IFMatrix *p_fmat, - const BoosterInfo &info, - const std::vector &trees) { - this->SyncTrees(trees); - } - - private: - // synchronize the trees in different nodes, take tree from rank 0 - inline void SyncTrees(const std::vector &trees) { - if (rabit::GetWorldSize() == 1) return; - std::string s_model; - utils::MemoryBufferStream fs(&s_model); - int rank = rabit::GetRank(); - if (rank == 0) { - for (size_t i = 0; i < trees.size(); ++i) { - trees[i]->SaveModel(fs); - } - } - fs.Seek(0); - rabit::Broadcast(&s_model, 0); - for (size_t i = 0; i < trees.size(); ++i) { - trees[i]->LoadModel(fs); - } - } -}; -} // namespace tree -} // namespace xgboost -#endif // XGBOOST_TREE_UPDATER_SYNC_INL_HPP_ diff --git a/old_src/utils/io.h b/old_src/utils/io.h deleted file mode 100644 index 
1fd09310e..000000000 --- a/old_src/utils/io.h +++ /dev/null @@ -1,59 +0,0 @@ -/*! - * Copyright 2014 by Contributors - * \file io.h - * \brief general stream interface for serialization, I/O - * \author Tianqi Chen - */ - -#ifndef XGBOOST_UTILS_IO_H_ -#define XGBOOST_UTILS_IO_H_ -#include -#include -#include -#include -#include "./utils.h" -#include "../sync/sync.h" - -namespace xgboost { -namespace utils { -// reuse the definitions of streams -typedef rabit::Stream IStream; -typedef rabit::utils::SeekStream ISeekStream; -typedef rabit::utils::MemoryFixSizeBuffer MemoryFixSizeBuffer; -typedef rabit::utils::MemoryBufferStream MemoryBufferStream; - -/*! \brief implementation of file i/o stream */ -class FileStream : public ISeekStream { - public: - explicit FileStream(std::FILE *fp) : fp(fp) {} - FileStream(void) { - this->fp = NULL; - } - virtual size_t Read(void *ptr, size_t size) { - return std::fread(ptr, size, 1, fp); - } - virtual void Write(const void *ptr, size_t size) { - Check(std::fwrite(ptr, size, 1, fp) == 1, "FileStream::Write: fwrite error!"); - } - virtual void Seek(size_t pos) { - std::fseek(fp, static_cast(pos), SEEK_SET); // NOLINT(*) - } - virtual size_t Tell(void) { - return std::ftell(fp); - } - virtual bool AtEnd(void) const { - return std::feof(fp) != 0; - } - inline void Close(void) { - if (fp != NULL) { - std::fclose(fp); fp = NULL; - } - } - - private: - std::FILE *fp; -}; -} // namespace utils -} // namespace xgboost -#include "./base64-inl.h" -#endif // XGBOOST_UTILS_IO_H_ diff --git a/src/common/io.h b/src/common/io.h new file mode 100644 index 000000000..86d3bf1ab --- /dev/null +++ b/src/common/io.h @@ -0,0 +1,20 @@ +/*! 
+ * Copyright 2014 by Contributors + * \file io.h + * \brief stream aliases used for serialization; redirects to the rabit stream classes + * \author Tianqi Chen + */ + +#ifndef XGBOOST_COMMON_IO_H_ +#define XGBOOST_COMMON_IO_H_ + +#include +#include "./sync.h" + +namespace xgboost { +namespace common { +typedef rabit::utils::MemoryFixSizeBuffer MemoryFixSizeBuffer; +typedef rabit::utils::MemoryBufferStream MemoryBufferStream; +} // namespace common +} // namespace xgboost +#endif // XGBOOST_COMMON_IO_H_ diff --git a/include/xgboost/sync.h b/src/common/sync.h similarity index 77% rename from include/xgboost/sync.h rename to src/common/sync.h index d0f18c8bb..c8cc9be14 100644 --- a/include/xgboost/sync.h +++ b/src/common/sync.h @@ -5,8 +5,8 @@ * redirects to rabit header * \author Tianqi Chen */ -#ifndef XGBOOST_SYNC_H_ -#define XGBOOST_SYNC_H_ +#ifndef XGBOOST_COMMON_SYNC_H_ +#define XGBOOST_COMMON_SYNC_H_ #include diff --git a/src/global.cc b/src/global.cc index 5a7b03ac0..7b9a41cc2 100644 --- a/src/global.cc +++ b/src/global.cc @@ -5,11 +5,12 @@ */ #include #include -#include +#include namespace dmlc { DMLC_REGISTRY_ENABLE(::xgboost::ObjFunctionReg); DMLC_REGISTRY_ENABLE(::xgboost::MetricReg); +DMLC_REGISTRY_ENABLE(::xgboost::TreeUpdaterReg); } // namespace dmlc namespace xgboost { @@ -42,8 +43,14 @@ Metric* Metric::Create(const char* name) { } } -void test() { - RegTree tree; +// factory function: finds the registry entry for name and invokes its body; reports LOG(FATAL) when the name is not registered +TreeUpdater* TreeUpdater::Create(const char* name) { + auto *e = ::dmlc::Registry< ::xgboost::TreeUpdaterReg>::Get()->Find(name); + if (e == nullptr) { + LOG(FATAL) << "Unknown tree updater " << name; + } + return (e->body)(); } + } // namespace xgboost diff --git a/src/metric/elementwise_metric.cc b/src/metric/elementwise_metric.cc index bd0bb51d8..bccee0ddf 100644 --- a/src/metric/elementwise_metric.cc +++ b/src/metric/elementwise_metric.cc @@ -5,9 +5,9 @@ * \author Kailong Chen, Tianqi Chen */ #include -#include #include #include "../common/math.h" +#include "../common/sync.h" 
namespace xgboost { namespace metric { diff --git a/src/metric/multiclass_metric.cc b/src/metric/multiclass_metric.cc index 51073b105..cd10168f9 100644 --- a/src/metric/multiclass_metric.cc +++ b/src/metric/multiclass_metric.cc @@ -5,8 +5,8 @@ * \author Kailong Chen, Tianqi Chen */ #include -#include #include +#include "../common/sync.h" #include "../common/math.h" namespace xgboost { diff --git a/src/metric/rank_metric.cc b/src/metric/rank_metric.cc index f74b5aa8c..ee2a0c948 100644 --- a/src/metric/rank_metric.cc +++ b/src/metric/rank_metric.cc @@ -5,8 +5,8 @@ * \author Kailong Chen, Tianqi Chen */ #include -#include #include +#include "../common/sync.h" #include "../common/math.h" namespace xgboost { diff --git a/src/tree/tree_model.cc b/src/tree/tree_model.cc index 4e2f0c3c6..cb0d9d132 100644 --- a/src/tree/tree_model.cc +++ b/src/tree/tree_model.cc @@ -6,8 +6,9 @@ #include #include - namespace xgboost { +// register tree parameter +DMLC_REGISTER_PARAMETER(TreeParam); // internal function to dump regression tree to text void DumpRegTree2Text(std::stringstream& fo, // NOLINT(*) diff --git a/src/tree/updater_sync.cc b/src/tree/updater_sync.cc new file mode 100644 index 000000000..a620833f1 --- /dev/null +++ b/src/tree/updater_sync.cc @@ -0,0 +1,50 @@ +/*! + * Copyright 2014 by Contributors + * \file updater_sync.cc + * \brief synchronizes the trees across all distributed nodes + */ +#include +#include +#include +#include +#include "../common/sync.h" +#include "../common/io.h" + +namespace xgboost { +namespace tree { +/*! 
+ * \brief syncher that synchronizes the trees across all distributed nodes; does nothing when the world size is 1 + * can implement various strategies, so far rank 0 serializes its trees, broadcasts them, and every node loads that copy + */ +class TreeSyncher: public TreeUpdater { + public: + void Init(const std::vector >& args) override {} + + void Update(const std::vector &gpair, + DMatrix* dmat, + const std::vector &trees) override { + if (rabit::GetWorldSize() == 1) return; + std::string s_model; + common::MemoryBufferStream fs(&s_model); + int rank = rabit::GetRank(); + if (rank == 0) { + for (size_t i = 0; i < trees.size(); ++i) { + trees[i]->SaveModel(&fs); + } + } + fs.Seek(0); + rabit::Broadcast(&s_model, 0); + for (size_t i = 0; i < trees.size(); ++i) { + trees[i]->LoadModel(&fs); + } + } +}; + +XGBOOST_REGISTER_TREE_UPDATER(TreeSyncher, "sync") +.describe("Syncher that synchronize the tree in all distributed nodes.") +.set_body([]() { + return new TreeSyncher(); + }); +} // namespace tree +} // namespace xgboost +