Merge remote-tracking branch 'upstream/master'

This commit is contained in:
El Potaeto 2015-03-11 22:14:35 +01:00
commit 09091884be
66 changed files with 1998 additions and 779 deletions

View File

@ -16,6 +16,13 @@ ifeq ($(cxx11),1)
 else
 endif
+ifeq ($(hdfs),1)
+	CFLAGS+= -DRABIT_USE_HDFS=1 -I$(HADOOP_HDFS_HOME)/include -I$(JAVA_HOME)/include
+	LDFLAGS+= -L$(HADOOP_HDFS_HOME)/lib/native -L$(JAVA_HOME)/jre/lib/amd64/server -lhdfs -ljvm
+else
+	CFLAGS+= -DRABIT_USE_HDFS=0
+endif
 # specify tensor path
 BIN = xgboost
 MOCKBIN = xgboost.mock
@ -83,8 +90,10 @@ Rpack:
 	cp -r src xgboost/src/src
 	mkdir xgboost/src/subtree
 	mkdir xgboost/src/subtree/rabit
+	mkdir xgboost/src/subtree/rabit/rabit-learn
 	cp -r subtree/rabit/include xgboost/src/subtree/rabit/include
 	cp -r subtree/rabit/src xgboost/src/subtree/rabit/src
+	cp -r subtree/rabit/rabit-learn/io xgboost/src/subtree/rabit/rabit-learn/io
 	rm -rf xgboost/src/subtree/rabit/src/*.o
 	mkdir xgboost/src/wrapper
 	cp wrapper/xgboost_wrapper.h xgboost/src/wrapper

View File

@ -24,6 +24,7 @@ Learning about the model: [Introduction to Boosted Trees](http://homes.cs.washin
 What's New
 =====
+* [Distributed XGBoost now runs on YARN](multi-node/hadoop)!
 * [xgboost user group](https://groups.google.com/forum/#!forum/xgboost-user/) for tracking changes, sharing your experience on xgboost
 * [Distributed XGBoost](multi-node) is now available!!
 * New features in the latest changes :)

View File

@ -6,7 +6,7 @@ require(methods)
 testsize <- 550000
 dtrain <- read.csv("data/training.csv", header=TRUE, nrows=350001)
+dtrain$Label = as.numeric(dtrain$Label=='s')
 # gbm.time = system.time({
 # gbm.model <- gbm(Label ~ ., data = dtrain[, -c(1,32)], n.trees = 120,
 #                  interaction.depth = 6, shrinkage = 0.1, bag.fraction = 1,
@ -15,8 +15,8 @@ dtrain <- read.csv("data/training.csv", header=TRUE, nrows=350001)
 # print(gbm.time)
 # Test result: 761.48 secs
-dtrain[33] <- dtrain[33] == "s"
-label <- as.numeric(dtrain[[33]])
+# dtrain[33] <- dtrain[33] == "s"
+# label <- as.numeric(dtrain[[33]])
 data <- as.matrix(dtrain[2:31])
 weight <- as.numeric(dtrain[[32]]) * testsize / length(label)
@ -51,21 +51,21 @@ for (i in 1:length(threads)){
 xgboost.time
 # [[1]]
 #    user  system elapsed
-#  444.98    1.96  450.22
+#  99.015   0.051  98.982
 #
 # [[2]]
 #    user  system elapsed
-#  188.15    0.82  102.41
+# 100.268   0.317  55.473
 #
 # [[3]]
 #    user  system elapsed
-#  143.29    0.79   44.18
+# 111.682   0.777  35.963
 #
 # [[4]]
 #    user  system elapsed
-#  176.60    1.45   34.04
+# 149.396   1.851  32.661
 #
 # [[5]]
 #    user  system elapsed
-#  180.15    2.85   35.26
+# 157.390   5.988  40.949

View File

@ -1,7 +1,7 @@
-Distributed XGBoost: Hadoop Version
+Distributed XGBoost: Hadoop Yarn Version
 ====
-* The script in this fold shows an example of how to run distributed xgboost on hadoop platform.
+* The script in this folder shows an example of how to run distributed xgboost on the hadoop platform with YARN.
-* It relies on [Rabit Library](https://github.com/tqchen/rabit) (Reliable Allreduce and Broadcast Interface) and Hadoop Streaming. Rabit provides an interface to aggregate gradient values and split statistics, that allow xgboost to run reliably on hadoop. You do not need to care how to update model in each iteration, just use the script ```rabit_hadoop.py```. For those who want to know how it exactly works, plz refer to the main page of [Rabit](https://github.com/tqchen/rabit).
+* It relies on [Rabit Library](https://github.com/tqchen/rabit) (Reliable Allreduce and Broadcast Interface) and Yarn. Rabit provides an interface to aggregate gradient values and split statistics, which allows xgboost to run reliably on hadoop. You do not need to care about how to update the model in each iteration; just use the script ```rabit_yarn.py```. For those who want to know how exactly it works, please refer to the main page of [Rabit](https://github.com/tqchen/rabit).
 * Quick start: run ```bash run_mushroom.sh <n_hadoop_workers> <n_thread_per_worker> <path_in_HDFS>```
   - This is the hadoop version of the binary classification example in the demo folder.
   - More info on the usage of xgboost can be found on the [wiki page](https://github.com/tqchen/xgboost/wiki)
@ -9,35 +9,32 @@ Distributed XGBoost: Hadoop Version
 Before you run the script
 ====
 * Make sure you have set up the hadoop environment.
-* If you want to only use single machine multi-threading, try single machine examples in the [demo folder](../../demo).
-* Build: run ```bash build.sh``` in the root folder, it will automatically download rabit and build xgboost.
-* Check whether the environment variable $HADOOP_HOME exists (e.g. run ```echo $HADOOP_HOME```). If not, please set up hadoop-streaming.jar path in rabit_hadoop.py.
+  - Check that the variable $HADOOP_PREFIX exists (e.g. run ```echo $HADOOP_PREFIX```)
+  - Compile xgboost with hdfs support by typing ```make hdfs=1```
 How to Use
 ====
 * Input data format: LIBSVM format. The example here uses generated data in the demo/data folder.
 * Put the training data in HDFS (hadoop distributed file system).
-* Use rabit ```rabit_hadoop.py``` to submit training task to hadoop, and save the final model file.
+* Use rabit ```rabit_yarn.py``` to submit the training task to YARN.
 * Get the final model file from HDFS, and locally do prediction as well as visualization of the model.
 Single machine vs Hadoop version
 ====
 If you have used xgboost (single machine version) before, this section will show you how to run xgboost on hadoop with a slight modification of the conf file.
-* Hadoop version needs to set up how many slave nodes/machines/workers you would like to use at first.
-* IO: instead of reading and writing file locally, hadoop version use "stdin" to read training file and use "stdout" to store the final model file. Therefore, you should change the parameters "data" and "model_out" in conf file to ```data=stdin``` and ```model_out=stdout```.
-* File cache: ```rabit_hadoop.py``` also provide several ways to cache necesary files, including binary file (xgboost), conf file, small size of dataset which used for eveluation during the training process, and so on.
-  - Any file used in config file, excluding stdin, should be cached in the script. ```rabit_hadoop.py``` will automatically cache files in the command line. For example, ```rabit_hadoop.py -n 3 -i $hdfsPath/agaricus.txt.train -o $hdfsPath/mushroom.final.model $localPath/xgboost mushroom.hadoop.conf``` will cache "xgboost" and "mushroom.hadoop.conf".
+* IO: instead of reading and writing files locally, we now use HDFS; put the ```hdfs://``` prefix on the address of any file you would like to access.
+* File cache: ```rabit_yarn.py``` also provides several ways to cache necessary files, including the binary file (xgboost) and the conf file.
+  - ```rabit_yarn.py``` will automatically cache files given in the command line. For example, ```rabit_yarn.py -n 3 $localPath/xgboost mushroom.hadoop.conf``` will cache "xgboost" and "mushroom.hadoop.conf".
   - You could also use "-f" to manually cache one or more files, like ```-f file1 -f file2``` or ```-f file1#file2``` (use "#" to split file names).
   - The local path of cached files in the command is "./".
-  - Since the cached files will be packaged and delivered to hadoop slave nodes, the cached file should not be large. For instance, trying to cache files of GB size may reduce the performance.
+  - Since the cached files will be packaged and delivered to hadoop slave nodes, the cached files should not be large.
 * The hadoop version also supports evaluating each training round. You just need to modify the parameter "eval_train".
-* More details of submission can be referred to the usage of ```rabit_hadoop.py```.
+* For more details of submission, refer to the usage of ```rabit_yarn.py```.
 * The model saved by the hadoop version is compatible with the single machine version.
 Notes
 ====
-* The code has been tested on MapReduce 1 (MRv1) and YARN.
-  - We recommend to run it on MapReduce 2 (MRv2, YARN) so that multi-threading can be enabled.
+* The code has been tested on YARN.
 * The code is optimized with multi-threading, so you will want to run one xgboost per node/worker for best performance.
   - You will want to set <n_thread_per_worker> to be the number of cores you have on each machine.
-  - You will need YARN to set specify number of cores of each worker
+* It is also possible to submit the job with hadoop streaming; however, YARN is highly recommended for efficiency reasons.

View File

@ -25,10 +25,10 @@ save_period = 0
 # eval[test] = "agaricus.txt.test"
 # Please do not modify the following parameters
-# The path of training data
-data = stdin
+# The path of training data, with prefix hdfs
+#data = hdfs:/data/
 # The path of model file
-model_out = stdout
+#model_out =
 # split pattern of xgboost
 dsplit = row
 # evaluate on training data as well each round

View File

@ -8,11 +8,16 @@ fi
 # put the local training file to HDFS
 hadoop fs -mkdir $3/data
 hadoop fs -put ../../demo/data/agaricus.txt.train $3/data
+hadoop fs -put ../../demo/data/agaricus.txt.test $3/data
-../../subtree/rabit/tracker/rabit_hadoop.py -n $1 -nt $2 -i $3/data/agaricus.txt.train -o $3/mushroom.final.model ../../xgboost mushroom.hadoop.conf nthread=$2
+# running rabit, pass address in hdfs
+../../subtree/rabit/tracker/rabit_yarn.py -n $1 --vcores $2 ../../xgboost mushroom.hadoop.conf nthread=$2\
+    data=hdfs://$3/data/agaricus.txt.train\
+    eval[test]=hdfs://$3/data/agaricus.txt.test\
+    model_out=hdfs://$3/mushroom.final.model
 # get the final model file
-hadoop fs -get $3/mushroom.final.model/part-00000 ./final.model
+hadoop fs -get $3/mushroom.final.model final.model
 # output prediction task=pred
 ../../xgboost mushroom.hadoop.conf task=pred model_in=final.model test:data=../../demo/data/agaricus.txt.test

View File

@ -69,11 +69,11 @@ class GBTree : public IGradBooster {
     trees[i]->SaveModel(fo);
   }
   if (tree_info.size() != 0) {
-    fo.Write(&tree_info[0], sizeof(int) * tree_info.size());
+    fo.Write(BeginPtr(tree_info), sizeof(int) * tree_info.size());
   }
   if (mparam.num_pbuffer != 0 && with_pbuffer) {
-    fo.Write(&pred_buffer[0], pred_buffer.size() * sizeof(float));
-    fo.Write(&pred_counter[0], pred_counter.size() * sizeof(unsigned));
+    fo.Write(BeginPtr(pred_buffer), pred_buffer.size() * sizeof(float));
+    fo.Write(BeginPtr(pred_counter), pred_counter.size() * sizeof(unsigned));
   }
 }
 // initialize the prediction buffer
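
The switch from `&vec[0]` to `BeginPtr(vec)` above matters because taking the address of element 0 of an empty `std::vector` is undefined behavior. `BeginPtr` itself is not shown in this commit; a minimal sketch of what such a helper usually looks like (an assumption, not the project's actual definition) is:

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper: return a pointer to the first element, or NULL for an
// empty vector, so (ptr, n * sizeof(T)) can be passed to Write/Read safely
// even when n == 0.
template<typename T>
inline T *BeginPtr(std::vector<T> &vec) {
  return vec.empty() ? NULL : &vec[0];
}

template<typename T>
inline const T *BeginPtr(const std::vector<T> &vec) {
  return vec.empty() ? NULL : &vec[0];
}
```

With a helper of this shape, a call such as `fo.Write(BeginPtr(tree_info), sizeof(int) * tree_info.size())` stays well defined even when the vector is empty.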

View File

@ -14,10 +14,11 @@
 namespace xgboost {
 namespace io {
-DataMatrix* LoadDataMatrix(const char *fname, bool silent, bool savebuffer) {
-  if (!std::strcmp(fname, "stdin")) {
+DataMatrix* LoadDataMatrix(const char *fname, bool silent,
+                           bool savebuffer, bool loadsplit) {
+  if (!std::strcmp(fname, "stdin") || loadsplit) {
     DMatrixSimple *dmat = new DMatrixSimple();
-    dmat->LoadText(fname, silent);
+    dmat->LoadText(fname, silent, loadsplit);
     return dmat;
   }
   int magic;

View File

@ -19,9 +19,14 @@ typedef learner::DMatrix DataMatrix;
  * \param fname file name to be loaded
  * \param silent whether print message during loading
  * \param savebuffer whether temporal buffer the file if the file is in text format
+ * \param loadsplit whether we only load a split of input files
+ *    such that each worker node get a split of the data
  * \return a loaded DMatrix
  */
-DataMatrix* LoadDataMatrix(const char *fname, bool silent = false, bool savebuffer = true);
+DataMatrix* LoadDataMatrix(const char *fname,
+                           bool silent,
+                           bool savebuffer,
+                           bool loadsplit);
 /*!
  * \brief save DataMatrix into stream,
  * note: the saved dmatrix format may not be in exactly same as input

View File

@ -11,12 +11,14 @@
 #include <string>
 #include <cstring>
 #include <vector>
+#include <sstream>
 #include <algorithm>
 #include "../data.h"
 #include "../utils/utils.h"
 #include "../learner/dmatrix.h"
 #include "./io.h"
 #include "./simple_fmatrix-inl.hpp"
+#include "../sync/sync.h"
 namespace xgboost {
 namespace io {
@ -77,63 +79,59 @@ class DMatrixSimple : public DataMatrix {
     return row_ptr_.size() - 2;
   }
   /*!
-   * \brief load from text file
-   * \param fname name of text data
+   * \brief load split of input, used in distributed mode
+   * \param uri the uri of input
+   * \param loadsplit whether loadsplit of data or all the data
    * \param silent whether print information or not
   */
-  inline void LoadText(const char* fname, bool silent = false) {
-    using namespace std;
+  inline void LoadText(const char *uri, bool silent = false, bool loadsplit = false) {
+    int rank = 0, npart = 1;
+    if (loadsplit) {
+      rank = rabit::GetRank();
+      npart = rabit::GetWorldSize();
+    }
+    rabit::io::InputSplit *in =
+        rabit::io::CreateInputSplit(uri, rank, npart);
     this->Clear();
-    FILE* file;
-    if (!strcmp(fname, "stdin")) {
-      file = stdin;
-    } else {
-      file = utils::FopenCheck(fname, "r");
-    }
-    float label; bool init = true;
-    char tmp[1024];
-    std::vector<RowBatch::Entry> feats;
-    while (fscanf(file, "%s", tmp) == 1) {
-      RowBatch::Entry e;
-      if (sscanf(tmp, "%u:%f", &e.index, &e.fvalue) == 2) {
-        feats.push_back(e);
-      } else {
-        if (!init) {
-          info.labels.push_back(label);
-          this->AddRow(feats);
-        }
-        feats.clear();
-        utils::Check(sscanf(tmp, "%f", &label) == 1, "invalid LibSVM format");
-        init = false;
-      }
-    }
-    info.labels.push_back(label);
-    this->AddRow(feats);
+    std::string line;
+    while (in->NextLine(&line)) {
+      float label;
+      std::istringstream ss(line);
+      std::vector<RowBatch::Entry> feats;
+      ss >> label;
+      while (!ss.eof()) {
+        RowBatch::Entry e;
+        if (!(ss >> e.index)) break;
+        ss.ignore(32, ':');
+        if (!(ss >> e.fvalue)) break;
+        feats.push_back(e);
+      }
+      info.labels.push_back(label);
+      this->AddRow(feats);
+    }
+    delete in;
     if (!silent) {
       utils::Printf("%lux%lu matrix with %lu entries is loaded from %s\n",
                     static_cast<unsigned long>(info.num_row()),
                     static_cast<unsigned long>(info.num_col()),
-                    static_cast<unsigned long>(row_data_.size()), fname);
+                    static_cast<unsigned long>(row_data_.size()), uri);
     }
-    if (file != stdin) {
-      fclose(file);
-    }
     // try to load in additional file
-    std::string name = fname;
-    std::string gname = name + ".group";
-    if (info.TryLoadGroup(gname.c_str(), silent)) {
-      utils::Check(info.group_ptr.back() == info.num_row(),
-                   "DMatrix: group data does not match the number of rows in features");
-    }
-    std::string wname = name + ".weight";
-    if (info.TryLoadFloatInfo("weight", wname.c_str(), silent)) {
-      utils::Check(info.weights.size() == info.num_row(),
-                   "DMatrix: weight data does not match the number of rows in features");
-    }
-    std::string mname = name + ".base_margin";
-    if (info.TryLoadFloatInfo("base_margin", mname.c_str(), silent)) {
-    }
+    if (!loadsplit) {
+      std::string name = uri;
+      std::string gname = name + ".group";
+      if (info.TryLoadGroup(gname.c_str(), silent)) {
+        utils::Check(info.group_ptr.back() == info.num_row(),
+                     "DMatrix: group data does not match the number of rows in features");
+      }
+      std::string wname = name + ".weight";
+      if (info.TryLoadFloatInfo("weight", wname.c_str(), silent)) {
+        utils::Check(info.weights.size() == info.num_row(),
+                     "DMatrix: weight data does not match the number of rows in features");
+      }
+      std::string mname = name + ".base_margin";
+      if (info.TryLoadFloatInfo("base_margin", mname.c_str(), silent)) {
+      }
+    }
   }
   /*!
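
The rewritten `LoadText` above pulls lines from a `rabit::io::InputSplit` and parses each LIBSVM record with `std::istringstream` instead of `fscanf`. The following self-contained sketch isolates just that parsing step; `RowEntry` and `ParseLibSVMLine` are illustrative names, not part of the codebase:

```cpp
#include <cstdio>
#include <sstream>
#include <string>
#include <vector>

// Illustrative stand-in for RowBatch::Entry.
struct RowEntry {
  unsigned index;
  float fvalue;
};

// Parse one LIBSVM line of the form "label idx:val idx:val ..." into a label
// and a list of (index, value) pairs, mirroring the istringstream logic above.
bool ParseLibSVMLine(const std::string &line, float *label,
                     std::vector<RowEntry> *feats) {
  std::istringstream ss(line);
  if (!(ss >> *label)) return false;   // first token is the label
  while (!ss.eof()) {
    RowEntry e;
    if (!(ss >> e.index)) break;       // feature index
    ss.ignore(32, ':');                // skip the ':' separator
    if (!(ss >> e.fvalue)) break;      // feature value
    feats->push_back(e);
  }
  return true;
}

int main() {
  float label;
  std::vector<RowEntry> feats;
  if (ParseLibSVMLine("1 3:0.5 7:1.25", &label, &feats)) {
    std::printf("label=%g, %u features\n",
                label, static_cast<unsigned>(feats.size()));
  }
  return 0;
}
```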

View File

@ -12,7 +12,6 @@
 #include <limits>
 #include "../sync/sync.h"
 #include "../utils/io.h"
-#include "../utils/base64.h"
 #include "./objective.h"
 #include "./evaluation.h"
 #include "../gbm/gbm.h"
@ -178,44 +177,37 @@ class BoostLearner : public rabit::ISerializable {
   }
   // rabit load model from rabit checkpoint
   virtual void Load(rabit::IStream &fi) {
-    RabitStreamAdapter fs(fi);
     // for row split, we should not keep pbuffer
-    this->LoadModel(fs, distributed_mode != 2, false);
+    this->LoadModel(fi, distributed_mode != 2, false);
   }
   // rabit save model to rabit checkpoint
   virtual void Save(rabit::IStream &fo) const {
-    RabitStreamAdapter fs(fo);
     // for row split, we should not keep pbuffer
-    this->SaveModel(fs, distributed_mode != 2);
+    this->SaveModel(fo, distributed_mode != 2);
   }
   /*!
   * \brief load model from file
   * \param fname file name
   */
   inline void LoadModel(const char *fname) {
-    FILE *fp = utils::FopenCheck(fname, "rb");
-    utils::FileStream fi(fp);
+    utils::IStream *fi = rabit::io::CreateStream(fname, "r");
    std::string header; header.resize(4);
     // check header for different binary encode
     // can be base64 or binary
-    if (fi.Read(&header[0], 4) != 0) {
-      // base64 format
-      if (header == "bs64") {
-        utils::Base64InStream bsin(fp);
-        bsin.InitPosition();
-        this->LoadModel(bsin);
-        fclose(fp);
-        return;
-      }
-      if (header == "binf") {
-        this->LoadModel(fi);
-        fclose(fp);
-        return;
-      }
-    }
-    fi.Seek(0);
-    this->LoadModel(fi);
-    fclose(fp);
+    utils::Check(fi->Read(&header[0], 4) != 0, "invalid model");
+    // base64 format
+    if (header == "bs64") {
+      utils::Base64InStream bsin(fi);
+      bsin.InitPosition();
+      this->LoadModel(bsin);
+    } else if (header == "binf") {
+      this->LoadModel(*fi);
+    } else {
+      delete fi;
+      fi = rabit::io::CreateStream(fname, "r");
+      this->LoadModel(*fi);
+    }
+    delete fi;
   }
   inline void SaveModel(utils::IStream &fo, bool with_pbuffer = true) const {
     fo.Write(&mparam, sizeof(ModelParam));
@ -226,33 +218,20 @@ class BoostLearner : public rabit::ISerializable {
   /*!
   * \brief save model into file
   * \param fname file name
+  * \param save_base64 whether save in base64 format
   */
-  inline void SaveModel(const char *fname) const {
-    FILE *fp;
-    bool use_stdout = false;;
-#ifndef XGBOOST_STRICT_CXX98_
-    if (!strcmp(fname, "stdout")) {
-      fp = stdout;
-      use_stdout = true;
-    } else
-#endif
-    {
-      fp = utils::FopenCheck(fname, "wb");
-    }
-    utils::FileStream fo(fp);
-    std::string header;
-    if (save_base64 != 0|| use_stdout) {
-      fo.Write("bs64\t", 5);
-      utils::Base64OutStream bout(fp);
+  inline void SaveModel(const char *fname, bool save_base64 = false) const {
+    utils::IStream *fo = rabit::io::CreateStream(fname, "w");
+    if (save_base64 != 0 || !strcmp(fname, "stdout")) {
+      fo->Write("bs64\t", 5);
+      utils::Base64OutStream bout(fo);
       this->SaveModel(bout);
       bout.Finish('\n');
     } else {
-      fo.Write("binf", 4);
-      this->SaveModel(fo);
-    }
-    if (!use_stdout) {
-      fclose(fp);
+      fo->Write("binf", 4);
+      this->SaveModel(*fo);
     }
+    delete fo;
   }
   /*!
   * \brief check if data matrix is ready to be used by training,
@ -512,23 +491,6 @@ class BoostLearner : public rabit::ISerializable {
   // data structure field
   /*! \brief the entries indicates that we have internal prediction cache */
   std::vector<CacheEntry> cache_;
- private:
-  // adapt rabit stream to utils stream
-  struct RabitStreamAdapter : public utils::IStream {
-    // rabit stream
-    rabit::IStream &fs;
-    // constructr
-    RabitStreamAdapter(rabit::IStream &fs) : fs(fs) {}
-    // destructor
-    virtual ~RabitStreamAdapter(void){}
-    virtual size_t Read(void *ptr, size_t size) {
-      return fs.Read(ptr, size);
-    }
-    virtual void Write(const void *ptr, size_t size) {
-      fs.Write(ptr, size);
-    }
-  };
 };
 } // namespace learner
 } // namespace xgboost

View File

@ -7,6 +7,7 @@
  * \author Tianqi Chen
  */
 #include "../../subtree/rabit/include/rabit.h"
+#include "../../subtree/rabit/rabit-learn/io/io.h"
 #endif  // XGBOOST_SYNC_H_

View File

@ -296,9 +296,10 @@ class TreeModel {
   utils::Check(fi.Read(&param, sizeof(Param)) > 0,
                "TreeModel: wrong format");
   nodes.resize(param.num_nodes); stats.resize(param.num_nodes);
-  utils::Check(fi.Read(&nodes[0], sizeof(Node) * nodes.size()) > 0,
+  utils::Assert(param.num_nodes != 0, "invalid model");
+  utils::Check(fi.Read(BeginPtr(nodes), sizeof(Node) * nodes.size()) > 0,
                "TreeModel: wrong format");
-  utils::Check(fi.Read(&stats[0], sizeof(NodeStat) * stats.size()) > 0,
+  utils::Check(fi.Read(BeginPtr(stats), sizeof(NodeStat) * stats.size()) > 0,
                "TreeModel: wrong format");
   if (param.size_leaf_vector != 0) {
     utils::Check(fi.Read(&leaf_vector), "TreeModel: wrong format");
@ -322,8 +323,9 @@ class TreeModel {
   utils::Assert(param.num_nodes == static_cast<int>(stats.size()),
                 "Tree::SaveModel");
   fo.Write(&param, sizeof(Param));
-  fo.Write(&nodes[0], sizeof(Node) * nodes.size());
-  fo.Write(&stats[0], sizeof(NodeStat) * nodes.size());
+  utils::Assert(param.num_nodes != 0, "invalid model");
+  fo.Write(BeginPtr(nodes), sizeof(Node) * nodes.size());
+  fo.Write(BeginPtr(stats), sizeof(NodeStat) * nodes.size());
   if (param.size_leaf_vector != 0) fo.Write(leaf_vector);
 }
 /*!

View File

@ -57,10 +57,10 @@ class SketchMaker: public BaseMaker {
 for (int nid = 0; nid < p_tree->param.num_nodes; ++nid) {
   this->SetStats(nid, node_stats[nid], p_tree);
   if (!(*p_tree)[nid].is_leaf()) {
-    p_tree->stat(nid).loss_chg =
+    p_tree->stat(nid).loss_chg = static_cast<float>(
       node_stats[(*p_tree)[nid].cleft()].CalcGain(param) +
       node_stats[(*p_tree)[nid].cright()].CalcGain(param) -
-      node_stats[nid].CalcGain(param);
+      node_stats[nid].CalcGain(param));
   }
 }
 // set left leaves
@ -207,9 +207,9 @@ class SketchMaker: public BaseMaker {
 } else {
   for (size_t i = 0; i < this->qexpand.size(); ++i) {
     const unsigned nid = this->qexpand[i];
-    sbuilder[3 * nid + 0].sum_total = nstats[nid].pos_grad;
-    sbuilder[3 * nid + 1].sum_total = nstats[nid].neg_grad;
-    sbuilder[3 * nid + 2].sum_total = nstats[nid].sum_hess;
+    sbuilder[3 * nid + 0].sum_total = static_cast<bst_float>(nstats[nid].pos_grad);
+    sbuilder[3 * nid + 1].sum_total = static_cast<bst_float>(nstats[nid].neg_grad);
+    sbuilder[3 * nid + 2].sum_total = static_cast<bst_float>(nstats[nid].sum_hess);
   }
 }
 // if only one value, no need to do second pass
@ -307,7 +307,7 @@ class SketchMaker: public BaseMaker {
 }
 // set statistics on ptree
 inline void SetStats(int nid, const SKStats &node_sum, RegTree *p_tree) {
-  p_tree->stat(nid).base_weight = node_sum.CalcWeight(param);
+  p_tree->stat(nid).base_weight = static_cast<float>(node_sum.CalcWeight(param));
   p_tree->stat(nid).sum_hess = static_cast<float>(node_sum.sum_hess);
   node_sum.SetLeafVec(param, p_tree->leafvec(nid));
 }
@ -350,7 +350,7 @@ class SketchMaker: public BaseMaker {
 if (s.sum_hess >= param.min_child_weight &&
     c.sum_hess >= param.min_child_weight) {
   double loss_chg = s.CalcGain(param) + c.CalcGain(param) - root_gain;
-  best->Update(loss_chg, fid, fsplits[i], false);
+  best->Update(static_cast<bst_float>(loss_chg), fid, fsplits[i], false);
 }
 // backward
 c.SetSubstract(feat_sum, s);
@ -358,7 +358,7 @@ class SketchMaker: public BaseMaker {
 if (s.sum_hess >= param.min_child_weight &&
     c.sum_hess >= param.min_child_weight) {
   double loss_chg = s.CalcGain(param) + c.CalcGain(param) - root_gain;
-  best->Update(loss_chg, fid, fsplits[i], true);
+  best->Update(static_cast<bst_float>(loss_chg), fid, fsplits[i], true);
 }
 }
 {// all including
@ -368,7 +368,7 @@ class SketchMaker: public BaseMaker {
     c.sum_hess >= param.min_child_weight) {
   bst_float cpt = fsplits.back();
   double loss_chg = s.CalcGain(param) + c.CalcGain(param) - root_gain;
-  best->Update(loss_chg, fid, cpt + fabsf(cpt) + 1.0f, false);
+  best->Update(static_cast<bst_float>(loss_chg), fid, cpt + fabsf(cpt) + 1.0f, false);
 }
 }
 }

View File

@ -1,205 +0,0 @@
#ifndef XGBOOST_UTILS_BASE64_H_
#define XGBOOST_UTILS_BASE64_H_
/*!
* \file base64.h
* \brief data stream support to input and output from/to base64 stream
* base64 is easier to store and pass as text format in mapreduce
* \author Tianqi Chen
*/
#include <cctype>
#include <cstdio>
#include "./utils.h"
#include "./io.h"
namespace xgboost {
namespace utils {
/*! \brief namespace of base64 decoding and encoding table */
namespace base64 {
const char DecodeTable[] = {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
62, // '+'
0, 0, 0,
63, // '/'
52, 53, 54, 55, 56, 57, 58, 59, 60, 61, // '0'-'9'
0, 0, 0, 0, 0, 0, 0,
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, // 'A'-'Z'
0, 0, 0, 0, 0, 0,
26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,
39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, // 'a'-'z'
};
static const char EncodeTable[] =
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
} // namespace base64
/*! \brief the stream that reads from base64, note we take from file pointers */
class Base64InStream: public IStream {
public:
explicit Base64InStream(FILE *fp) : fp(fp) {
num_prev = 0; tmp_ch = 0;
}
/*!
* \brief initialize the stream position to beginning of next base64 stream
* call this function before actually start read
*/
inline void InitPosition(void) {
// get a charater
do {
tmp_ch = fgetc(fp);
} while (isspace(tmp_ch));
}
/*! \brief whether current position is end of a base64 stream */
inline bool IsEOF(void) const {
return num_prev == 0 && (tmp_ch == EOF || isspace(tmp_ch));
}
virtual size_t Read(void *ptr, size_t size) {
using base64::DecodeTable;
if (size == 0) return 0;
// use tlen to record left size
size_t tlen = size;
unsigned char *cptr = static_cast<unsigned char*>(ptr);
// if anything left, load from previous buffered result
if (num_prev != 0) {
if (num_prev == 2) {
if (tlen >= 2) {
*cptr++ = buf_prev[0];
*cptr++ = buf_prev[1];
tlen -= 2;
num_prev = 0;
} else {
// assert tlen == 1
*cptr++ = buf_prev[0]; --tlen;
buf_prev[0] = buf_prev[1];
num_prev = 1;
}
} else {
// assert num_prev == 1
*cptr++ = buf_prev[0]; --tlen; num_prev = 0;
}
}
if (tlen == 0) return size;
int nvalue;
// note: everything goes with 4 bytes in Base64
// so we process 4 bytes a unit
while (tlen && tmp_ch != EOF && !isspace(tmp_ch)) {
// first byte
nvalue = DecodeTable[tmp_ch] << 18;
{
// second byte
Check((tmp_ch = fgetc(fp), tmp_ch != EOF && !isspace(tmp_ch)),
"invalid base64 format");
nvalue |= DecodeTable[tmp_ch] << 12;
*cptr++ = (nvalue >> 16) & 0xFF; --tlen;
}
{
// third byte
Check((tmp_ch = fgetc(fp), tmp_ch != EOF && !isspace(tmp_ch)),
"invalid base64 format");
// handle termination
if (tmp_ch == '=') {
Check((tmp_ch = fgetc(fp), tmp_ch == '='), "invalid base64 format");
Check((tmp_ch = fgetc(fp), tmp_ch == EOF || isspace(tmp_ch)),
"invalid base64 format");
break;
}
nvalue |= DecodeTable[tmp_ch] << 6;
if (tlen) {
*cptr++ = (nvalue >> 8) & 0xFF; --tlen;
} else {
buf_prev[num_prev++] = (nvalue >> 8) & 0xFF;
}
}
{
// fourth byte
Check((tmp_ch = fgetc(fp), tmp_ch != EOF && !isspace(tmp_ch)),
"invalid base64 format");
if (tmp_ch == '=') {
Check((tmp_ch = fgetc(fp), tmp_ch == EOF || isspace(tmp_ch)),
"invalid base64 format");
break;
}
nvalue |= DecodeTable[tmp_ch];
if (tlen) {
*cptr++ = nvalue & 0xFF; --tlen;
} else {
buf_prev[num_prev ++] = nvalue & 0xFF;
}
}
// get next char
tmp_ch = fgetc(fp);
}
if (kStrictCheck) {
Check(tlen == 0, "Base64InStream: read incomplete");
}
return size - tlen;
}
virtual void Write(const void *ptr, size_t size) {
utils::Error("Base64InStream do not support write");
}
private:
FILE *fp;
int tmp_ch;
int num_prev;
unsigned char buf_prev[2];
// whether we need to do strict check
static const bool kStrictCheck = false;
};
/*! \brief the stream that write to base64, note we take from file pointers */
class Base64OutStream: public IStream {
public:
explicit Base64OutStream(FILE *fp) : fp(fp) {
buf_top = 0;
}
virtual void Write(const void *ptr, size_t size) {
using base64::EncodeTable;
size_t tlen = size;
const unsigned char *cptr = static_cast<const unsigned char*>(ptr);
while (tlen) {
while (buf_top < 3 && tlen != 0) {
buf[++buf_top] = *cptr++; --tlen;
}
if (buf_top == 3) {
// flush 4 bytes out
fputc(EncodeTable[buf[1] >> 2], fp);
fputc(EncodeTable[((buf[1] << 4) | (buf[2] >> 4)) & 0x3F], fp);
fputc(EncodeTable[((buf[2] << 2) | (buf[3] >> 6)) & 0x3F], fp);
fputc(EncodeTable[buf[3] & 0x3F], fp);
buf_top = 0;
}
}
}
virtual size_t Read(void *ptr, size_t size) {
Error("Base64OutStream do not support read");
return 0;
}
/*!
* \brief finish writing of all current base64 stream, do some post processing
* \param endch charater to put to end of stream, if it is EOF, then nothing will be done
*/
inline void Finish(char endch = EOF) {
using base64::EncodeTable;
if (buf_top == 1) {
fputc(EncodeTable[buf[1] >> 2], fp);
fputc(EncodeTable[(buf[1] << 4) & 0x3F], fp);
fputc('=', fp);
fputc('=', fp);
}
if (buf_top == 2) {
fputc(EncodeTable[buf[1] >> 2], fp);
fputc(EncodeTable[((buf[1] << 4) | (buf[2] >> 4)) & 0x3F], fp);
fputc(EncodeTable[(buf[2] << 2) & 0x3F], fp);
fputc('=', fp);
}
buf_top = 0;
if (endch != EOF) fputc(endch, fp);
}
private:
FILE *fp;
int buf_top;
unsigned char buf[4];
};
} // namespace utils
} // namespace xgboost
#endif // XGBOOST_UTILS_BASE64_H_

View File

@ -5,6 +5,7 @@
 #include <string>
 #include <cstring>
 #include "./utils.h"
+#include "../sync/sync.h"
 /*!
  * \file io.h
  * \brief general stream interface for serialization, I/O
@ -12,168 +13,13 @@
  */
 namespace xgboost {
 namespace utils {
+// reuse the definitions of streams
+typedef rabit::IStream IStream;
+typedef rabit::utils::ISeekStream ISeekStream;
+typedef rabit::utils::MemoryFixSizeBuffer MemoryFixSizeBuffer;
+typedef rabit::utils::MemoryBufferStream MemoryBufferStream;
+typedef rabit::io::Base64InStream Base64InStream;
+typedef rabit::io::Base64OutStream Base64OutStream;
-/*!
- * \brief interface of stream I/O, used to serialize model
- */
-class IStream {
- public:
- /*!
-  * \brief read data from stream
* \param ptr pointer to memory buffer
* \param size size of block
* \return usually is the size of data readed
*/
virtual size_t Read(void *ptr, size_t size) = 0;
/*!
* \brief write data to stream
* \param ptr pointer to memory buffer
* \param size size of block
*/
virtual void Write(const void *ptr, size_t size) = 0;
/*! \brief virtual destructor */
virtual ~IStream(void) {}
public:
// helper functions to write various of data structures
/*!
* \brief binary serialize a vector
* \param vec vector to be serialized
*/
template<typename T>
inline void Write(const std::vector<T> &vec) {
uint64_t sz = static_cast<uint64_t>(vec.size());
this->Write(&sz, sizeof(sz));
if (sz != 0) {
this->Write(&vec[0], sizeof(T) * sz);
}
}
/*!
* \brief binary load a vector
* \param out_vec vector to be loaded
* \return whether load is successfull
*/
template<typename T>
inline bool Read(std::vector<T> *out_vec) {
uint64_t sz;
if (this->Read(&sz, sizeof(sz)) == 0) return false;
out_vec->resize(sz);
if (sz != 0) {
if (this->Read(&(*out_vec)[0], sizeof(T) * sz) == 0) return false;
}
return true;
}
/*!
* \brief binary serialize a string
* \param str the string to be serialized
*/
inline void Write(const std::string &str) {
uint64_t sz = static_cast<uint64_t>(str.length());
this->Write(&sz, sizeof(sz));
if (sz != 0) {
this->Write(&str[0], sizeof(char) * sz);
}
}
/*!
* \brief binary load a string
* \param out_str string to be loaded
* \return whether load is successful
*/
inline bool Read(std::string *out_str) {
uint64_t sz;
if (this->Read(&sz, sizeof(sz)) == 0) return false;
out_str->resize(sz);
if (sz != 0) {
if (this->Read(&(*out_str)[0], sizeof(char) * sz) == 0) return false;
}
return true;
}
};
/*! \brief interface of i/o stream that support seek */
class ISeekStream: public IStream {
public:
/*! \brief seek to certain position of the file */
virtual void Seek(size_t pos) = 0;
/*! \brief tell the position of the stream */
virtual size_t Tell(void) = 0;
};
/*! \brief fixed size memory buffer */
struct MemoryFixSizeBuffer : public ISeekStream {
public:
MemoryFixSizeBuffer(void *p_buffer, size_t buffer_size)
: p_buffer_(reinterpret_cast<char*>(p_buffer)), buffer_size_(buffer_size) {
curr_ptr_ = 0;
}
virtual ~MemoryFixSizeBuffer(void) {}
virtual size_t Read(void *ptr, size_t size) {
utils::Assert(curr_ptr_ + size <= buffer_size_,
"read can not have position excceed buffer length");
size_t nread = std::min(buffer_size_ - curr_ptr_, size);
if (nread != 0) std::memcpy(ptr, p_buffer_ + curr_ptr_, nread);
curr_ptr_ += nread;
return nread;
}
virtual void Write(const void *ptr, size_t size) {
if (size == 0) return;
utils::Assert(curr_ptr_ + size <= buffer_size_,
"write position exceed fixed buffer size");
std::memcpy(p_buffer_ + curr_ptr_, ptr, size);
curr_ptr_ += size;
}
virtual void Seek(size_t pos) {
curr_ptr_ = static_cast<size_t>(pos);
}
virtual size_t Tell(void) {
return curr_ptr_;
}
private:
/*! \brief in memory buffer */
char *p_buffer_;
/*! \brief current pointer */
size_t buffer_size_;
/*! \brief current pointer */
size_t curr_ptr_;
}; // class MemoryFixSizeBuffer
/*! \brief a in memory buffer that can be read and write as stream interface */
struct MemoryBufferStream : public ISeekStream {
public:
MemoryBufferStream(std::string *p_buffer)
: p_buffer_(p_buffer) {
curr_ptr_ = 0;
}
virtual ~MemoryBufferStream(void) {}
virtual size_t Read(void *ptr, size_t size) {
utils::Assert(curr_ptr_ <= p_buffer_->length(),
"read can not have position excceed buffer length");
size_t nread = std::min(p_buffer_->length() - curr_ptr_, size);
if (nread != 0) std::memcpy(ptr, &(*p_buffer_)[0] + curr_ptr_, nread);
curr_ptr_ += nread;
return nread;
}
virtual void Write(const void *ptr, size_t size) {
if (size == 0) return;
if (curr_ptr_ + size > p_buffer_->length()) {
p_buffer_->resize(curr_ptr_+size);
}
std::memcpy(&(*p_buffer_)[0] + curr_ptr_, ptr, size);
curr_ptr_ += size;
}
virtual void Seek(size_t pos) {
curr_ptr_ = static_cast<size_t>(pos);
}
virtual size_t Tell(void) {
return curr_ptr_;
}
private:
/*! \brief in memory buffer */
std::string *p_buffer_;
/*! \brief current pointer */
size_t curr_ptr_;
}; // class MemoryBufferStream
 /*! \brief implementation of file i/o stream */
 class FileStream : public ISeekStream {
@ -194,6 +40,9 @@ class FileStream : public ISeekStream {
   virtual size_t Tell(void) {
     return std::ftell(fp);
   }
+  virtual bool AtEnd(void) const {
+    return std::feof(fp) != 0;
+  }
   inline void Close(void) {
     if (fp != NULL){
       std::fclose(fp); fp = NULL;

View File

@ -36,14 +36,8 @@ class BoostLearnTask {
     this->SetParam("silent", "1");
     save_period = 0;
   }
-  // whether need data rank
-  bool need_data_rank = strchr(train_path.c_str(), '%') != NULL;
-  // if need data rank in loading, initialize rabit engine before load data
-  // otherwise, initialize rabit engine after loading data
-  // lazy initialization of rabit engine can be helpful in speculative execution
-  if (need_data_rank) rabit::Init(argc, argv);
-  this->InitData();
-  if (!need_data_rank) rabit::Init(argc, argv);
+  // initialized the result
+  rabit::Init(argc, argv);
   if (rabit::IsDistributed()) {
     std::string pname = rabit::GetProcessorName();
     fprintf(stderr, "start %s:%d\n", pname.c_str(), rabit::GetRank());
@ -54,6 +48,8 @@
   if (rabit::GetRank() != 0) {
     this->SetParam("silent", "2");
   }
+  this->InitData();
   if (task == "train") {
     // if task is training, will try recover from checkpoint
     this->TaskTrain();
@ -135,17 +131,22 @@
     train_path = s_tmp;
     load_part = 1;
   }
+  bool loadsplit = data_split == "row";
   if (name_fmap != "NULL") fmap.LoadText(name_fmap.c_str());
   if (task == "dump") return;
   if (task == "pred") {
-    data = io::LoadDataMatrix(test_path.c_str(), silent != 0, use_buffer != 0);
+    data = io::LoadDataMatrix(test_path.c_str(), silent != 0, use_buffer != 0, loadsplit);
   } else {
     // training
-    data = io::LoadDataMatrix(train_path.c_str(), silent != 0 && load_part == 0, use_buffer != 0);
+    data = io::LoadDataMatrix(train_path.c_str(),
+                              silent != 0 && load_part == 0,
+                              use_buffer != 0, loadsplit);
     utils::Assert(eval_data_names.size() == eval_data_paths.size(), "BUG");
     for (size_t i = 0; i < eval_data_names.size(); ++i) {
-      deval.push_back(io::LoadDataMatrix(eval_data_paths[i].c_str(), silent != 0, use_buffer != 0));
+      deval.push_back(io::LoadDataMatrix(eval_data_paths[i].c_str(),
+                                         silent != 0,
+                                         use_buffer != 0,
+                                         loadsplit));
       devalall.push_back(deval.back());
     }

View File

@ -2,7 +2,7 @@ ifndef CXX
 export CXX = g++
 endif
 export MPICXX = mpicxx
-export LDFLAGS= -Llib
+export LDFLAGS= -Llib -lrt
 export WARNFLAGS= -Wall -Wextra -Wno-unused-parameter -Wno-unknown-pragmas -pedantic
 export CFLAGS = -O3 -msse2 -fPIC $(WARNFLAGS)
@ -50,7 +50,7 @@ $(ALIB):
 	ar cr $@ $+
 $(SLIB) :
-	$(CXX) $(CFLAGS) -shared -o $@ $(filter %.cpp %.o %.c %.cc %.a, $^)
+	$(CXX) $(CFLAGS) -shared -o $@ $(filter %.cpp %.o %.c %.cc %.a, $^) $(LDFLAGS)
 clean:
 	$(RM) $(OBJ) $(MPIOBJ) $(ALIB) $(MPIALIB) *~ src/*~ include/*~ include/*/*~ wrapper/*~

View File

@ -13,7 +13,7 @@ All these features comes from the facts about small rabbit:)
 * Portable: rabit is light weight and runs everywhere
   - Rabit is a library instead of a framework; a program only needs to link the library to run
   - Rabit only relies on a mechanism to start programs, which is provided by most frameworks
-  - You can run rabit programs on many platforms, including Hadoop, MPI using the same code
+  - You can run rabit programs on many platforms, including Yarn(Hadoop) and MPI, using the same code
 * Scalable and Flexible: rabit runs fast
   * Rabit programs use Allreduce to communicate, and do not suffer the cost between iterations of the MapReduce abstraction.
   - Programs can call rabit functions in any order, as opposed to frameworks where callbacks are offered and called by the framework, i.e. the inversion of control principle.

View File

@ -341,12 +341,11 @@ Rabit is a portable library that can run on multiple platforms.
 * This script will restart the program when it exits with -2, so it can be used for [mock test](#link-against-mock-test-library)
 #### Running Rabit on Hadoop
-* You can use [../tracker/rabit_hadoop.py](../tracker/rabit_hadoop.py) to run rabit programs on hadoop
-* This will start n rabit programs as mappers of MapReduce
-* Each program can read its portion of data from stdin
-* Yarn(Hadoop 2.0 or higher) is highly recommended, since Yarn allows specifying number of cpus and memory of each mapper:
+* You can use [../tracker/rabit_yarn.py](../tracker/rabit_yarn.py) to run rabit programs as a Yarn application
+* This will start rabit programs as yarn applications
   - This allows multi-threading programs in each node, which can be more efficient
   - An easy multi-threading solution could be to use OpenMP with rabit code
+* It is also possible to run a rabit program via hadoop streaming; however, YARN is highly recommended.
 #### Running Rabit using MPI
 * You can submit rabit programs to an MPI cluster using [../tracker/rabit_mpi.py](../tracker/rabit_mpi.py).
@ -358,15 +357,15 @@ tracker scripts, such as [../tracker/rabit_hadoop.py](../tracker/rabit_hadoop.py
 You will need to implement a platform dependent submission function with the following definition
 ```python
-def fun_submit(nworkers, worker_args):
+def fun_submit(nworkers, worker_args, worker_envs):
     """
     customized submit script, that submits nslave jobs,
     each must contain args as parameter
     note this can be a lambda closure
     Parameters
       nworkers number of worker processes to start
-      worker_args tracker information which must be passed to the arguments
-        this usually includes the parameters of master_uri and port, etc.
+      worker_args additional arguments that need to be passed to the worker
+      worker_envs environment variables that need to be set for the worker
    """
 ```
 The submission function should start nworkers processes in the platform, and append worker_args to the end of the other arguments.
@ -374,7 +373,7 @@ Then you can simply call ```tracker.submit``` with fun_submit to submit jobs to
 Note that the current rabit tracker does not restart a worker when it dies; the restart of a node is done by the platform, otherwise we should write the fail-restart logic in the custom script.
 * Fail-restart is usually provided by most platforms.
-* For example, mapreduce will restart a mapper when it fails
+  - rabit-yarn provides such functionality in YARN
 Fault Tolerance
 =====

View File

@ -65,9 +65,8 @@ inline int GetRank(void);
 /*! \brief gets total number of processes */
 inline int GetWorldSize(void);
 /*! \brief whether rabit env is in distributed mode */
-inline bool IsDistributed(void) {
-  return GetWorldSize() != 1;
-}
+inline bool IsDistributed(void);
 /*! \brief gets processor's name */
 inline std::string GetProcessorName(void);
 /*!

View File

@ -145,6 +145,8 @@ class IEngine {
   virtual int GetRank(void) const = 0;
   /*! \brief gets total number of nodes */
   virtual int GetWorldSize(void) const = 0;
+  /*! \brief whether we run in distributed mode */
+  virtual bool IsDistributed(void) const = 0;
   /*! \brief gets the host name of the current node */
   virtual std::string GetHost(void) const = 0;
   /*!

View File

@ -23,6 +23,8 @@ class ISeekStream: public IStream {
   virtual void Seek(size_t pos) = 0;
   /*! \brief tell the position of the stream */
   virtual size_t Tell(void) = 0;
+  /*! \return whether we are at end of file */
+  virtual bool AtEnd(void) const = 0;
 };
 /*! \brief fixed size memory buffer */
@ -55,7 +57,9 @@ struct MemoryFixSizeBuffer : public ISeekStream {
   virtual size_t Tell(void) {
     return curr_ptr_;
   }
+  virtual bool AtEnd(void) const {
+    return curr_ptr_ == buffer_size_;
+  }
  private:
   /*! \brief in memory buffer */
   char *p_buffer_;
@ -95,7 +99,9 @@ struct MemoryBufferStream : public ISeekStream {
   virtual size_t Tell(void) {
     return curr_ptr_;
   }
+  virtual bool AtEnd(void) const {
+    return curr_ptr_ == p_buffer_->length();
+  }
  private:
   /*! \brief in memory buffer */
   std::string *p_buffer_;
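
The new `AtEnd()` hook lets a reader stop exactly at the end of an in-memory stream instead of tracking byte counts by hand. A usage sketch under stated assumptions (that `MemoryBufferStream` lives in `rabit::utils` and is reachable through a `rabit/io.h` header, as the typedefs elsewhere in this merge suggest):

```cpp
#include <cstdio>
#include <string>
#include <rabit/io.h>  // assumed include path for the rabit stream utilities

int main() {
  std::string buffer;
  rabit::utils::MemoryBufferStream stream(&buffer);
  // serialize three integers into the in-memory buffer
  for (int i = 0; i < 3; ++i) {
    stream.Write(&i, sizeof(i));
  }
  stream.Seek(0);            // rewind before reading back
  int v;
  while (!stream.AtEnd()) {  // stop exactly at the end of the buffer
    stream.Read(&v, sizeof(v));
    std::printf("read %d\n", v);
  }
  return 0;
}
```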

View File

@ -107,6 +107,10 @@ inline int GetRank(void) {
 inline int GetWorldSize(void) {
   return engine::GetEngine()->GetWorldSize();
 }
+// whether rabit is distributed
+inline bool IsDistributed(void) {
+  return engine::GetEngine()->IsDistributed();
+}
 // get the name of current processor
 inline std::string GetProcessorName(void) {
   return engine::GetEngine()->GetHost();

View File

@ -3,9 +3,13 @@
  * \brief This file defines the utils for timing
  * \author Tianqi Chen, Nacho, Tianyi
  */
-#ifndef RABIT_TIMER_H
-#define RABIT_TIMER_H
+#ifndef RABIT_TIMER_H_
+#define RABIT_TIMER_H_
 #include <time.h>
+#ifdef __MACH__
+#include <mach/clock.h>
+#include <mach/mach.h>
+#endif
 #include "./utils.h"
 namespace rabit {
@ -14,10 +18,19 @@ namespace utils {
  * \brief return time in seconds, not cross platform, avoid to use this in most places
  */
 inline double GetTime(void) {
+#ifdef __MACH__
+  clock_serv_t cclock;
+  mach_timespec_t mts;
+  host_get_clock_service(mach_host_self(), CALENDAR_CLOCK, &cclock);
+  utils::Check(clock_get_time(cclock, &mts) == 0, "failed to get time");
+  mach_port_deallocate(mach_task_self(), cclock);
+  return static_cast<double>(mts.tv_sec) + static_cast<double>(mts.tv_nsec) * 1e-9;
+#else
   timespec ts;
   utils::Check(clock_gettime(CLOCK_REALTIME, &ts) == 0, "failed to get time");
   return static_cast<double>(ts.tv_sec) + static_cast<double>(ts.tv_nsec) * 1e-9;
+#endif
 }
-}
-}
-#endif
+}  // namespace utils
+}  // namespace rabit
+#endif  // RABIT_TIMER_H_

View File

@ -5,15 +5,13 @@ It also contain links to the Machine Learning packages that uses rabit.
 * Contribution of toolkits, examples, benchmarks is more than welcomed!
 Toolkits
 ====
 * [KMeans Clustering](kmeans)
 * [Linear and Logistic Regression](linear)
 * [XGBoost: eXtreme Gradient Boosting](https://github.com/tqchen/xgboost/tree/master/multi-node)
   - xgboost is a very fast boosted tree (also known as GBDT) library that can run more than
     10 times faster than existing packages
   - Rabit carries xgboost to the distributed environment, inheriting all the benefits of the xgboost
     single node version, and scales it to even larger problems

View File

@ -1,5 +1,5 @@
-#ifndef RABIT_LEARN_UTILS_BASE64_H_
-#define RABIT_LEARN_UTILS_BASE64_H_
+#ifndef RABIT_LEARN_IO_BASE64_INL_H_
+#define RABIT_LEARN_IO_BASE64_INL_H_
 /*!
  * \file base64.h
  * \brief data stream support to input and output from/to base64 stream
@ -8,10 +8,11 @@
  */
 #include <cctype>
 #include <cstdio>
-#include <rabit/io.h>
+#include "./io.h"
+#include "./buffer_reader-inl.h"
 namespace rabit {
-namespace utils {
+namespace io {
 /*! \brief namespace of base64 decoding and encoding table */
 namespace base64 {
 const char DecodeTable[] = {
@ -34,7 +35,8 @@ static const char EncodeTable[] =
 /*! \brief the stream that reads from base64, note we take from file pointers */
 class Base64InStream: public IStream {
  public:
-  explicit Base64InStream(FILE *fp) : fp(fp) {
+  explicit Base64InStream(IStream *fs) : reader_(256) {
+    reader_.set_stream(fs);
     num_prev = 0; tmp_ch = 0;
   }
   /*!
@ -44,7 +46,7 @@ class Base64InStream: public IStream {
   inline void InitPosition(void) {
     // get a character
     do {
-      tmp_ch = fgetc(fp);
+      tmp_ch = reader_.GetChar();
     } while (isspace(tmp_ch));
   }
   /*! \brief whether current position is end of a base64 stream */
@ -85,19 +87,19 @@
     nvalue = DecodeTable[tmp_ch] << 18;
     {
       // second byte
-      Check((tmp_ch = fgetc(fp), tmp_ch != EOF && !isspace(tmp_ch)),
+      utils::Check((tmp_ch = reader_.GetChar(), tmp_ch != EOF && !isspace(tmp_ch)),
            "invalid base64 format");
      nvalue |= DecodeTable[tmp_ch] << 12;
      *cptr++ = (nvalue >> 16) & 0xFF; --tlen;
     }
     {
       // third byte
-      Check((tmp_ch = fgetc(fp), tmp_ch != EOF && !isspace(tmp_ch)),
+      utils::Check((tmp_ch = reader_.GetChar(), tmp_ch != EOF && !isspace(tmp_ch)),
            "invalid base64 format");
      // handle termination
      if (tmp_ch == '=') {
-        Check((tmp_ch = fgetc(fp), tmp_ch == '='), "invalid base64 format");
-        Check((tmp_ch = fgetc(fp), tmp_ch == EOF || isspace(tmp_ch)),
+        utils::Check((tmp_ch = reader_.GetChar(), tmp_ch == '='), "invalid base64 format");
+        utils::Check((tmp_ch = reader_.GetChar(), tmp_ch == EOF || isspace(tmp_ch)),
             "invalid base64 format");
        break;
      }
@ -110,10 +112,10 @@
     }
     {
       // fourth byte
-      Check((tmp_ch = fgetc(fp), tmp_ch != EOF && !isspace(tmp_ch)),
+      utils::Check((tmp_ch = reader_.GetChar(), tmp_ch != EOF && !isspace(tmp_ch)),
           "invalid base64 format");
      if (tmp_ch == '=') {
-        Check((tmp_ch = fgetc(fp), tmp_ch == EOF || isspace(tmp_ch)),
+        utils::Check((tmp_ch = reader_.GetChar(), tmp_ch == EOF || isspace(tmp_ch)),
            "invalid base64 format");
        break;
      }
@ -125,10 +127,10 @@
      }
    }
    // get next char
-    tmp_ch = fgetc(fp);
+    tmp_ch = reader_.GetChar();
   }
   if (kStrictCheck) {
-    Check(tlen == 0, "Base64InStream: read incomplete");
+    utils::Check(tlen == 0, "Base64InStream: read incomplete");
   }
   return size - tlen;
  }
@ -137,7 +139,7 @@
  }
  private:
-  FILE *fp;
+  StreamBufferReader reader_;
  int tmp_ch;
  int num_prev;
  unsigned char buf_prev[2];
@ -147,7 +149,7 @@
 /*! \brief the stream that write to base64, note we take from file pointers */
 class Base64OutStream: public IStream {
  public:
-  explicit Base64OutStream(FILE *fp) : fp(fp) {
+  explicit Base64OutStream(IStream *fp) : fp(fp) {
    buf_top = 0;
  }
  virtual void Write(const void *ptr, size_t size) {
@ -160,16 +162,16 @@
     }
     if (buf_top == 3) {
       // flush 4 bytes out
-      fputc(EncodeTable[buf[1] >> 2], fp);
-      fputc(EncodeTable[((buf[1] << 4) | (buf[2] >> 4)) & 0x3F], fp);
-      fputc(EncodeTable[((buf[2] << 2) | (buf[3] >> 6)) & 0x3F], fp);
-      fputc(EncodeTable[buf[3] & 0x3F], fp);
+      PutChar(EncodeTable[buf[1] >> 2]);
+      PutChar(EncodeTable[((buf[1] << 4) | (buf[2] >> 4)) & 0x3F]);
+      PutChar(EncodeTable[((buf[2] << 2) | (buf[3] >> 6)) & 0x3F]);
+      PutChar(EncodeTable[buf[3] & 0x3F]);
      buf_top = 0;
    }
   }
  }
  virtual size_t Read(void *ptr, size_t size) {
-    Error("Base64OutStream do not support read");
+    utils::Error("Base64OutStream do not support read");
    return 0;
  }
  /*!
@ -179,26 +181,38 @@
  inline void Finish(char endch = EOF) {
    using base64::EncodeTable;
    if (buf_top == 1) {
-      fputc(EncodeTable[buf[1] >> 2], fp);
-      fputc(EncodeTable[(buf[1] << 4) & 0x3F], fp);
-      fputc('=', fp);
-      fputc('=', fp);
+      PutChar(EncodeTable[buf[1] >> 2]);
+      PutChar(EncodeTable[(buf[1] << 4) & 0x3F]);
+      PutChar('=');
+      PutChar('=');
    }
    if (buf_top == 2) {
-      fputc(EncodeTable[buf[1] >> 2], fp);
-      fputc(EncodeTable[((buf[1] << 4) | (buf[2] >> 4)) & 0x3F], fp);
+      PutChar(EncodeTable[buf[1] >> 2]);
+      PutChar(EncodeTable[((buf[1] << 4) | (buf[2] >> 4)) & 0x3F]);
fputc(EncodeTable[(buf[2] << 2) & 0x3F], fp); PutChar(EncodeTable[(buf[2] << 2) & 0x3F]);
fputc('=', fp); PutChar('=');
} }
buf_top = 0; buf_top = 0;
if (endch != EOF) fputc(endch, fp); if (endch != EOF) PutChar(endch);
this->Flush();
} }
private: private:
FILE *fp; IStream *fp;
int buf_top; int buf_top;
unsigned char buf[4]; unsigned char buf[4];
std::string out_buf;
const static size_t kBufferSize = 256;
inline void PutChar(char ch) {
out_buf += ch;
if (out_buf.length() >= kBufferSize) Flush();
}
inline void Flush(void) {
fp->Write(BeginPtr(out_buf), out_buf.length());
out_buf.clear();
}
}; };
} // namespace utils } // namespace utils
} // namespace rabit } // namespace rabit
#endif // RABIT_LEARN_UTILS_BASE64_H_ #endif // RABIT_LEARN_IO_BASE64_INL_H_
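As a quick illustration of the refactored streams above, here is a minimal round-trip sketch that is not part of this commit: it encodes one value as base64 text and reads it back, layered on the FileStream added later in this change set. The include path and the file name are assumptions for illustration only.

```
// Minimal sketch (not from this commit): base64 round trip through rabit::io streams.
#include <string>
#include "io.h"  // assumed include path; pulls in base64-inl.h and file-inl.h

int main() {
  double weight = 0.5;
  {
    rabit::io::FileStream fo("weight.b64", "w");   // hypothetical output file
    rabit::io::Base64OutStream bout(&fo);
    bout.Write(&weight, sizeof(weight));           // buffered as base64 characters
    bout.Finish('\n');                             // flush the tail bytes, end the record
  }
  double loaded = 0.0;
  {
    rabit::io::FileStream fi("weight.b64", "r");
    rabit::io::Base64InStream bin(&fi);
    bin.InitPosition();                            // skip leading whitespace
    bin.Read(&loaded, sizeof(loaded));
  }
  return loaded == weight ? 0 : 1;
}
```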

View File

@ -0,0 +1,57 @@
#ifndef RABIT_LEARN_IO_BUFFER_READER_INL_H_
#define RABIT_LEARN_IO_BUFFER_READER_INL_H_
/*!
* \file buffer_reader-inl.h
* \brief implementation of stream buffer reader
* \author Tianqi Chen
*/
#include "./io.h"
namespace rabit {
namespace io {
/*! \brief buffer reader of the stream that allows you to get characters one at a time */
class StreamBufferReader {
public:
StreamBufferReader(size_t buffer_size)
:stream_(NULL),
read_len_(1), read_ptr_(1) {
buffer_.resize(buffer_size);
}
/*!
* \brief set input stream
*/
inline void set_stream(IStream *stream) {
stream_ = stream;
read_len_ = read_ptr_ = 1;
}
/*!
* \brief allows quick read using get char
*/
inline char GetChar(void) {
while (true) {
if (read_ptr_ < read_len_) {
return buffer_[read_ptr_++];
} else {
read_len_ = stream_->Read(&buffer_[0], buffer_.length());
if (read_len_ == 0) return EOF;
read_ptr_ = 0;
}
}
}
inline bool AtEnd(void) const {
return read_len_ == 0;
}
private:
/*! \brief the underlying stream */
IStream *stream_;
/*! \brief buffer to hold data */
std::string buffer_;
/*! \brief length of valid data in buffer */
size_t read_len_;
/*! \brief pointer in the buffer */
size_t read_ptr_;
};
} // namespace io
} // namespace rabit
#endif // RABIT_LEARN_IO_BUFFER_READER_INL_H_
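A tiny sketch, not part of this commit, of how StreamBufferReader is meant to be driven: GetChar refills the internal buffer on demand and AtEnd becomes true once the underlying Read returns 0. The include path and the input file name are assumptions.

```
// Minimal sketch (not from this commit): count characters through StreamBufferReader.
#include <cstdio>
#include "io.h"  // assumed include path; pulls in buffer_reader-inl.h and file-inl.h

int main() {
  rabit::io::FileStream fs("input.txt", "r");   // hypothetical input file
  rabit::io::StreamBufferReader reader(256);    // 256-byte internal buffer
  reader.set_stream(&fs);
  size_t nchar = 0;
  while (true) {
    char c = reader.GetChar();                  // refills the buffer when it runs dry
    if (reader.AtEnd()) break;                  // true once the stream is exhausted
    ++nchar;
    (void)c;
  }
  std::printf("read %lu characters\n", (unsigned long)nchar);
  return 0;
}
```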

View File

@ -0,0 +1,107 @@
#ifndef RABIT_LEARN_IO_FILE_INL_H_
#define RABIT_LEARN_IO_FILE_INL_H_
/*!
* \file file-inl.h
* \brief normal filesystem I/O
* \author Tianqi Chen
*/
#include <string>
#include <vector>
#include <cstdio>
#include "./io.h"
#include "./line_split-inl.h"
/*! \brief io interface */
namespace rabit {
namespace io {
/*! \brief implementation of file i/o stream */
class FileStream : public utils::ISeekStream {
public:
explicit FileStream(const char *fname, const char *mode)
: use_stdio(false) {
using namespace std;
#ifndef RABIT_STRICT_CXX98_
if (!strcmp(fname, "stdin")) {
use_stdio = true; fp = stdin;
}
if (!strcmp(fname, "stdout")) {
use_stdio = true; fp = stdout;
}
#endif
if (!strncmp(fname, "file://", 7)) fname += 7;
if (!use_stdio) {
std::string flag = mode;
if (flag == "w") flag = "wb";
if (flag == "r") flag = "rb";
fp = utils::FopenCheck(fname, flag.c_str());
}
}
virtual ~FileStream(void) {
this->Close();
}
virtual size_t Read(void *ptr, size_t size) {
return std::fread(ptr, 1, size, fp);
}
virtual void Write(const void *ptr, size_t size) {
std::fwrite(ptr, size, 1, fp);
}
virtual void Seek(size_t pos) {
std::fseek(fp, static_cast<long>(pos), SEEK_SET);
}
virtual size_t Tell(void) {
return std::ftell(fp);
}
virtual bool AtEnd(void) const {
return std::feof(fp) != 0;
}
inline void Close(void) {
if (fp != NULL && !use_stdio) {
std::fclose(fp); fp = NULL;
}
}
private:
std::FILE *fp;
bool use_stdio;
};
/*! \brief line split from normal file system */
class FileSplit : public LineSplitBase {
public:
explicit FileSplit(const char *uri, unsigned rank, unsigned nsplit) {
LineSplitBase::SplitNames(&fnames_, uri, "#");
std::vector<size_t> fsize;
for (size_t i = 0; i < fnames_.size(); ++i) {
if (!std::strncmp(fnames_[i].c_str(), "file://", 7)) {
std::string tmp = fnames_[i].c_str() + 7;
fnames_[i] = tmp;
}
fsize.push_back(GetFileSize(fnames_[i].c_str()));
}
LineSplitBase::Init(fsize, rank, nsplit);
}
virtual ~FileSplit(void) {}
protected:
virtual utils::ISeekStream *GetFile(size_t file_index) {
utils::Assert(file_index < fnames_.size(), "file index exceed bound");
return new FileStream(fnames_[file_index].c_str(), "rb");
}
// get file size
inline static size_t GetFileSize(const char *fname) {
std::FILE *fp = utils::FopenCheck(fname, "rb");
// NOTE: fseek may not be good, but serves as ok solution
std::fseek(fp, 0, SEEK_END);
size_t fsize = static_cast<size_t>(std::ftell(fp));
std::fclose(fp);
return fsize;
}
private:
// file names
std::vector<std::string> fnames_;
};
} // namespace io
} // namespace rabit
#endif // RABIT_LEARN_IO_FILE_INL_H_

View File

@ -0,0 +1,140 @@
#ifndef RABIT_LEARN_IO_HDFS_INL_H_
#define RABIT_LEARN_IO_HDFS_INL_H_
/*!
* \file hdfs-inl.h
* \brief HDFS I/O
* \author Tianqi Chen
*/
#include <string>
#include <vector>
#include <hdfs.h>
#include <errno.h>
#include "./io.h"
#include "./line_split-inl.h"
/*! \brief io interface */
namespace rabit {
namespace io {
class HDFSStream : public utils::ISeekStream {
public:
HDFSStream(hdfsFS fs, const char *fname, const char *mode)
: fs_(fs), at_end_(false) {
int flag;
if (!strcmp(mode, "r")) {
flag = O_RDONLY;
} else if (!strcmp(mode, "w")) {
flag = O_WRONLY;
} else if (!strcmp(mode, "a")) {
flag = O_WRONLY | O_APPEND;
} else {
utils::Error("HDFSStream: unknown flag %s", mode);
}
fp_ = hdfsOpenFile(fs_, fname, flag, 0, 0, 0);
utils::Check(fp_ != NULL,
"HDFSStream: fail to open %s", fname);
}
virtual ~HDFSStream(void) {
this->Close();
}
virtual size_t Read(void *ptr, size_t size) {
tSize nread = hdfsRead(fs_, fp_, ptr, size);
if (nread == -1) {
int errsv = errno;
utils::Error("HDFSStream.Read Error:%s", strerror(errsv));
}
if (nread == 0) {
at_end_ = true;
}
return static_cast<size_t>(nread);
}
virtual void Write(const void *ptr, size_t size) {
const char *buf = reinterpret_cast<const char*>(ptr);
while (size != 0) {
tSize nwrite = hdfsWrite(fs_, fp_, buf, size);
if (nwrite == -1) {
int errsv = errno;
utils::Error("HDFSStream.Write Error:%s", strerror(errsv));
}
size_t sz = static_cast<size_t>(nwrite);
buf += sz; size -= sz;
}
}
virtual void Seek(size_t pos) {
if (hdfsSeek(fs_, fp_, pos) != 0) {
int errsv = errno;
utils::Error("HDFSStream.Seek Error:%s", strerror(errsv));
}
}
virtual size_t Tell(void) {
tOffset offset = hdfsTell(fs_, fp_);
if (offset == -1) {
int errsv = errno;
utils::Error("HDFSStream.Tell Error:%s", strerror(errsv));
}
return static_cast<size_t>(offset);
}
virtual bool AtEnd(void) const {
return at_end_;
}
inline void Close(void) {
if (fp_ != NULL) {
if (hdfsCloseFile(fs_, fp_) == -1) {
int errsv = errno;
utils::Error("HDFSStream.Close Error:%s", strerror(errsv));
}
fp_ = NULL;
}
}
private:
hdfsFS fs_;
hdfsFile fp_;
bool at_end_;
};
/*! \brief line split from HDFS */
class HDFSSplit : public LineSplitBase {
public:
explicit HDFSSplit(const char *uri, unsigned rank, unsigned nsplit) {
fs_ = hdfsConnect("default", 0);
std::vector<std::string> paths;
LineSplitBase::SplitNames(&paths, uri, "#");
// get the files
std::vector<size_t> fsize;
for (size_t i = 0; i < paths.size(); ++i) {
hdfsFileInfo *info = hdfsGetPathInfo(fs_, paths[i].c_str());
if (info->mKind == 'D') {
int nentry;
hdfsFileInfo *files = hdfsListDirectory(fs_, info->mName, &nentry);
for (int i = 0; i < nentry; ++i) {
if (files[i].mKind == 'F') {
fsize.push_back(files[i].mSize);
fnames_.push_back(std::string(files[i].mName));
}
}
hdfsFreeFileInfo(files, nentry);
} else {
fsize.push_back(info->mSize);
fnames_.push_back(std::string(info->mName));
}
hdfsFreeFileInfo(info, 1);
}
LineSplitBase::Init(fsize, rank, nsplit);
}
virtual ~HDFSSplit(void) {}
protected:
virtual utils::ISeekStream *GetFile(size_t file_index) {
utils::Assert(file_index < fnames_.size(), "file index exceed bound");
return new HDFSStream(fs_, fnames_[file_index].c_str(), "r");
}
private:
// hdfs handle
hdfsFS fs_;
// file names
std::vector<std::string> fnames_;
};
} // namespace io
} // namespace rabit
#endif // RABIT_LEARN_IO_HDFS_INL_H_

View File

@ -0,0 +1,67 @@
#ifndef RABIT_LEARN_IO_IO_INL_H_
#define RABIT_LEARN_IO_IO_INL_H_
/*!
* \file io-inl.h
* \brief Input/Output utils that handle read/write
* of files in a distributed environment
* \author Tianqi Chen
*/
#include <cstring>
#include "./io.h"
#if RABIT_USE_HDFS
#include "./hdfs-inl.h"
#endif
#include "./file-inl.h"
namespace rabit {
namespace io {
/*!
* \brief create input split given a uri
* \param uri the uri of the input, can contain hdfs prefix
* \param part the part id of current input
* \param nsplit total number of splits
*/
inline InputSplit *CreateInputSplit(const char *uri,
unsigned part,
unsigned nsplit) {
using namespace std;
if (!strcmp(uri, "stdin")) {
return new SingleFileSplit(uri);
}
if (!strncmp(uri, "file://", 7)) {
return new FileSplit(uri, part, nsplit);
}
if (!strncmp(uri, "hdfs://", 7)) {
#if RABIT_USE_HDFS
return new HDFSSplit(uri, part, nsplit);
#else
utils::Error("Please compile with RABIT_USE_HDFS=1");
#endif
}
return new FileSplit(uri, part, nsplit);
}
/*!
* \brief create a stream, the stream must be able to close
* the underlying resources(files) when deleted
*
* \param uri the uri of the input, can contain hdfs prefix
* \param mode can be 'w' or 'r' for read or write
*/
inline IStream *CreateStream(const char *uri, const char *mode) {
using namespace std;
if (!strncmp(uri, "file://", 7)) {
return new FileStream(uri + 7, mode);
}
if (!strncmp(uri, "hdfs://", 7)) {
#if RABIT_USE_HDFS
return new HDFSStream(hdfsConnect("default", 0), uri, mode);
#else
utils::Error("Please compile with RABIT_USE_HDFS=1");
#endif
}
return new FileStream(uri, mode);
}
} // namespace io
} // namespace rabit
#endif // RABIT_LEARN_IO_IO_INL_H_

View File

@ -0,0 +1,61 @@
#ifndef RABIT_LEARN_IO_IO_H_
#define RABIT_LEARN_IO_IO_H_
/*!
* \file io.h
* \brief Input/Output utils that handle read/write
* of files in a distributed environment
* \author Tianqi Chen
*/
#include "../../include/rabit_serializable.h"
/*! \brief whether compile with HDFS support */
#ifndef RABIT_USE_HDFS
#define RABIT_USE_HDFS 0
#endif
/*! \brief io interface */
namespace rabit {
/*!
* \brief namespace to handle input split and filesystem interfacing
*/
namespace io {
typedef utils::ISeekStream ISeekStream;
/*!
* \brief user facing input split helper,
* can be used to get the partition of data used by current node
*/
class InputSplit {
public:
/*!
* \brief get next line, store into out_data
* \param out_data the string that stores the line data,
* \n is not included
* \return true if the next line was found, false if we have read all the lines
*/
virtual bool NextLine(std::string *out_data) = 0;
/*! \brief destructor*/
virtual ~InputSplit(void) {}
};
/*!
* \brief create input split given a uri
* \param uri the uri of the input, can contain hdfs prefix
* \param part the part id of current input
* \param nsplit total number of splits
*/
inline InputSplit *CreateInputSplit(const char *uri,
unsigned part,
unsigned nsplit);
/*!
* \brief create a stream, the stream must be able to close
* the underlying resources(files) when deleted
*
* \param uri the uri of the input, can contain hdfs prefix
* \param mode can be 'w' or 'r' for read or write
*/
inline IStream *CreateStream(const char *uri, const char *mode);
} // namespace io
} // namespace rabit
#include "./io-inl.h"
#include "./base64-inl.h"
#endif // RABIT_LEARN_IO_IO_H_
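A minimal usage sketch, not part of this commit, of the two factory functions declared above: each worker reads only its own partition of the input, and rank 0 writes output through a stream whose backend is picked from the URI prefix. The URIs and the include path are hypothetical; rabit::Init, GetRank, GetWorldSize and Finalize come from the main rabit header, and the hdfs:// paths only work when compiled with RABIT_USE_HDFS=1.

```
// Minimal sketch (not from this commit) of the rabit::io factory interface.
#include <string>
#include <rabit.h>
#include "io.h"  // assumed include path for the header above

int main(int argc, char *argv[]) {
  rabit::Init(argc, argv);
  // each worker only reads its own byte range of the input
  rabit::io::InputSplit *in = rabit::io::CreateInputSplit(
      "hdfs://tmp/agaricus.txt.train", rabit::GetRank(), rabit::GetWorldSize());
  std::string line;
  size_t nline = 0;
  while (in->NextLine(&line)) ++nline;
  delete in;
  // only rank 0 writes the result; CreateStream dispatches on the URI prefix
  if (rabit::GetRank() == 0) {
    rabit::IStream *fo = rabit::io::CreateStream("hdfs://tmp/line.count", "w");
    fo->Write(&nline, sizeof(nline));
    delete fo;
  }
  rabit::Finalize();
  return 0;
}
```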

View File

@ -0,0 +1,181 @@
#ifndef RABIT_LEARN_IO_LINE_SPLIT_INL_H_
#define RABIT_LEARN_IO_LINE_SPLIT_INL_H_
/*!
* \file line_split-inl.h
* \brief base implementation of the line splitter
* \author Tianqi Chen
*/
#include <vector>
#include <utility>
#include <cstring>
#include <string>
#include "../../include/rabit.h"
#include "./io.h"
#include "./buffer_reader-inl.h"
namespace rabit {
namespace io {
class LineSplitBase : public InputSplit {
public:
virtual ~LineSplitBase() {
if (fs_ != NULL) delete fs_;
}
virtual bool NextLine(std::string *out_data) {
if (file_ptr_ >= file_ptr_end_ &&
offset_curr_ >= offset_end_) return false;
out_data->clear();
while (true) {
char c = reader_.GetChar();
if (reader_.AtEnd()) {
if (out_data->length() != 0) return true;
file_ptr_ += 1;
if (offset_curr_ != file_offset_[file_ptr_]) {
utils::Error("warning:std::FILE size not calculated correctly\n");
offset_curr_ = file_offset_[file_ptr_];
}
if (offset_curr_ >= offset_end_) return false;
utils::Assert(file_ptr_ + 1 < file_offset_.size(),
"boundary check");
delete fs_;
fs_ = this->GetFile(file_ptr_);
reader_.set_stream(fs_);
} else {
++offset_curr_;
if (c != '\r' && c != '\n' && c != EOF) {
*out_data += c;
} else {
if (out_data->length() != 0) return true;
if (file_ptr_ >= file_ptr_end_ &&
offset_curr_ >= offset_end_) return false;
}
}
}
}
protected:
// constructor
LineSplitBase(void)
: fs_(NULL), reader_(kBufferSize) {
}
/*!
* \brief initialize the line splitter
* \param file_size size of each file
* \param rank the current rank of the data
* \param nsplit number of split we will divide the data into
*/
inline void Init(const std::vector<size_t> &file_size,
unsigned rank, unsigned nsplit) {
file_offset_.resize(file_size.size() + 1);
file_offset_[0] = 0;
for (size_t i = 0; i < file_size.size(); ++i) {
file_offset_[i + 1] = file_offset_[i] + file_size[i];
}
size_t ntotal = file_offset_.back();
size_t nstep = (ntotal + nsplit - 1) / nsplit;
offset_begin_ = std::min(nstep * rank, ntotal);
offset_end_ = std::min(nstep * (rank + 1), ntotal);
offset_curr_ = offset_begin_;
if (offset_begin_ == offset_end_) return;
file_ptr_ = std::upper_bound(file_offset_.begin(),
file_offset_.end(),
offset_begin_) - file_offset_.begin() - 1;
file_ptr_end_ = std::upper_bound(file_offset_.begin(),
file_offset_.end(),
offset_end_) - file_offset_.begin() - 1;
fs_ = GetFile(file_ptr_);
reader_.set_stream(fs_);
// try to set the starting position correctly
if (file_offset_[file_ptr_] != offset_begin_) {
fs_->Seek(offset_begin_ - file_offset_[file_ptr_]);
while (true) {
char c = reader_.GetChar();
if (!reader_.AtEnd()) ++offset_curr_;
if (c == '\n' || c == '\r' || c == EOF) return;
}
}
}
/*!
* \brief get the seek stream of given file_index
* \return the corresponding seek stream positioned at the head of the file
*/
virtual utils::ISeekStream *GetFile(size_t file_index) = 0;
/*!
* \brief split the given uri into names
* \param out_fname output file names
* \param uri_ the input uri
* \param dlm delimiter
*/
inline static void SplitNames(std::vector<std::string> *out_fname,
const char *uri_,
const char *dlm) {
std::string uri = uri_;
char *p = std::strtok(BeginPtr(uri), dlm);
while (p != NULL) {
out_fname->push_back(std::string(p));
p = std::strtok(NULL, dlm);
}
}
private:
/*! \brief current input stream */
utils::ISeekStream *fs_;
/*! \brief index of the file currently being read */
size_t file_ptr_;
/*! \brief index of the file where the end of the split lies */
size_t file_ptr_end_;
/*! \brief get the current offset */
size_t offset_curr_;
/*! \brief beginning of offset */
size_t offset_begin_;
/*! \brief end of the offset */
size_t offset_end_;
/*! \brief byte offset of each file */
std::vector<size_t> file_offset_;
/*! \brief buffer reader */
StreamBufferReader reader_;
/*! \brief buffer size */
const static size_t kBufferSize = 256;
};
/*! \brief line split from a single file */
class SingleFileSplit : public InputSplit {
public:
explicit SingleFileSplit(const char *fname) : use_stdin_(false) {
if (!std::strcmp(fname, "stdin")) {
#ifndef RABIT_STRICT_CXX98_
use_stdin_ = true; fp_ = stdin;
#endif
}
if (!use_stdin_) {
fp_ = utils::FopenCheck(fname, "r");
}
end_of_file_ = false;
}
virtual ~SingleFileSplit(void) {
if (!use_stdin_) std::fclose(fp_);
}
virtual bool NextLine(std::string *out_data) {
if (end_of_file_) return false;
out_data->clear();
while (true) {
char c = std::fgetc(fp_);
if (c == EOF) {
end_of_file_ = true;
}
if (c != '\r' && c != '\n' && c != EOF) {
*out_data += c;
} else {
if (out_data->length() != 0) return true;
if (end_of_file_) return false;
}
}
return false;
}
private:
std::FILE *fp_;
bool use_stdin_;
bool end_of_file_;
};
} // namespace io
} // namespace rabit
#endif // RABIT_LEARN_IO_LINE_SPLIT_INL_H_
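A small standalone sketch, not part of this commit, of the offset arithmetic used by LineSplitBase::Init above: the file sizes and worker count below are made-up numbers, chosen so that three files of 40, 25 and 35 bytes split across four workers give each worker a 25-byte range, which NextLine then rounds to line boundaries.

```
// Standalone sketch of the split arithmetic in LineSplitBase::Init (illustrative numbers).
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  size_t sizes[] = {40, 25, 35};                        // hypothetical file sizes
  std::vector<size_t> file_offset(1, 0);
  for (size_t i = 0; i < 3; ++i)
    file_offset.push_back(file_offset.back() + sizes[i]);  // 0, 40, 65, 100
  const unsigned nsplit = 4;
  size_t ntotal = file_offset.back();                   // 100 bytes in total
  size_t nstep = (ntotal + nsplit - 1) / nsplit;        // ceil(100 / 4) = 25
  for (unsigned rank = 0; rank < nsplit; ++rank) {
    size_t begin = std::min(nstep * rank, ntotal);
    size_t end = std::min(nstep * (rank + 1), ntotal);
    // rank 0 -> [0, 25), rank 1 -> [25, 50), rank 2 -> [50, 75), rank 3 -> [75, 100)
    std::printf("rank %u reads bytes [%lu, %lu)\n",
                rank, (unsigned long)begin, (unsigned long)end);
  }
  return 0;
}
```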

View File

@ -6,11 +6,10 @@ MPIBIN = kmeans.mpi
OBJ = kmeans.o OBJ = kmeans.o
# common build script for programs # common build script for programs
include ../common.mk include ../make/common.mk
# dependencies here # dependencies here
kmeans.rabit: kmeans.o lib kmeans.rabit: kmeans.o lib
kmeans.mock: kmeans.o lib kmeans.mock: kmeans.o lib
kmeans.mpi: kmeans.o libmpi kmeans.mpi: kmeans.o libmpi
kmeans.o: kmeans.cc ../../src/*.h kmeans.o: kmeans.cc ../../src/*.h

View File

@ -0,0 +1,2 @@
mushroom.row*
*.model

View File

@ -6,7 +6,8 @@ MPIBIN =
OBJ = linear.o OBJ = linear.o
# common build script for programs # common build script for programs
include ../common.mk include ../make/config.mk
include ../make/common.mk
CFLAGS+=-fopenmp CFLAGS+=-fopenmp
linear.o: linear.cc ../../src/*.h linear.h ../solver/*.h linear.o: linear.cc ../../src/*.h linear.h ../solver/*.h
# dependencies here # dependencies here

View File

@ -2,11 +2,24 @@ Linear and Logistic Regression
==== ====
* input format: LibSVM * input format: LibSVM
* Local Example: [run-linear.sh](run-linear.sh) * Local Example: [run-linear.sh](run-linear.sh)
* Running on Hadoop: [run-hadoop.sh](run-hadoop.sh) * Running on YARN: [run-yarn.sh](run-yarn.sh)
- Set input data to stdin, and model_out=stdout - You will need to have YARN set up
- Modify ```../make/config.mk``` to set USE_HDFS=1 to compile with HDFS support
- Run build.sh in [../../yarn](../../yarn) to build the yarn jar file
Multi-Threading Optimization
====
* The code can be multi-threaded; we encourage you to use it
- Simply add ```nthread=k``` where k is the number of threads you want to use
* If you submit with YARN
- Use ```--vcores``` and ```-mem``` to request CPU and memory resources
- Some YARN schedulers do not honor CPU requests; you can request more memory to grab working slots
* Multi-threading usually improves speed
- You can use fewer workers and assign more resources to each worker
- This usually means less communication overhead and faster running time
Parameters Parameters
=== ====
All the parameters can be set by param=value All the parameters can be set by param=value
#### Important Parameters #### Important Parameters

View File

@ -1,6 +1,5 @@
#include "./linear.h" #include "./linear.h"
#include "../utils/io.h" #include "../io/io.h"
#include "../utils/base64.h"
namespace rabit { namespace rabit {
namespace linear { namespace linear {
@ -55,7 +54,9 @@ class LinearObjFunction : public solver::IObjFunction<float> {
} }
if (task == "train") { if (task == "train") {
lbfgs.Run(); lbfgs.Run();
this->SaveModel(model_out.c_str(), lbfgs.GetWeight()); if (rabit::GetRank() == 0) {
this->SaveModel(model_out.c_str(), lbfgs.GetWeight());
}
} else if (task == "pred") { } else if (task == "pred") {
this->TaskPred(); this->TaskPred();
} else { } else {
@ -74,51 +75,37 @@ class LinearObjFunction : public solver::IObjFunction<float> {
printf("Finishing writing to %s\n", name_pred.c_str()); printf("Finishing writing to %s\n", name_pred.c_str());
} }
inline void LoadModel(const char *fname) { inline void LoadModel(const char *fname) {
FILE *fp = utils::FopenCheck(fname, "rb"); IStream *fi = io::CreateStream(fname, "r");
std::string header; header.resize(4); std::string header; header.resize(4);
// check header for different binary encode // check header for different binary encode
// can be base64 or binary // can be base64 or binary
utils::FileStream fi(fp); utils::Check(fi->Read(&header[0], 4) != 0, "invalid model");
utils::Check(fi.Read(&header[0], 4) != 0, "invalid model"); // base64 format
// base64 format
if (header == "bs64") { if (header == "bs64") {
utils::Base64InStream bsin(fp); io::Base64InStream bsin(fi);
bsin.InitPosition(); bsin.InitPosition();
model.Load(bsin); model.Load(bsin);
fclose(fp);
return;
} else if (header == "binf") { } else if (header == "binf") {
model.Load(fi); model.Load(*fi);
fclose(fp);
return;
} else { } else {
utils::Error("invalid model file"); utils::Error("invalid model file");
} }
delete fi;
} }
inline void SaveModel(const char *fname, inline void SaveModel(const char *fname,
const float *wptr, const float *wptr,
bool save_base64 = false) { bool save_base64 = false) {
FILE *fp; IStream *fo = io::CreateStream(fname, "w");
bool use_stdout = false; if (save_base64 != 0 || !strcmp(fname, "stdout")) {
if (!strcmp(fname, "stdout")) { fo->Write("bs64\t", 5);
fp = stdout; io::Base64OutStream bout(fo);
use_stdout = true;
} else {
fp = utils::FopenCheck(fname, "wb");
}
utils::FileStream fo(fp);
if (save_base64 != 0|| use_stdout) {
fo.Write("bs64\t", 5);
utils::Base64OutStream bout(fp);
model.Save(bout, wptr); model.Save(bout, wptr);
bout.Finish('\n'); bout.Finish('\n');
} else { } else {
fo.Write("binf", 4); fo->Write("binf", 4);
model.Save(fo, wptr); model.Save(*fo, wptr);
}
if (!use_stdout) {
fclose(fp);
} }
delete fo;
} }
inline void LoadData(const char *fname) { inline void LoadData(const char *fname) {
dtrain.Load(fname); dtrain.Load(fname);

View File

@ -12,7 +12,7 @@ hadoop fs -mkdir $2/data
hadoop fs -put ../data/agaricus.txt.train $2/data hadoop fs -put ../data/agaricus.txt.train $2/data
# submit to hadoop # submit to hadoop
../../tracker/rabit_hadoop.py --host_ip ip -n $1 -i $2/data/agaricus.txt.train -o $2/mushroom.linear.model linear.rabit stdin model_out=stdout "${*:3}" ../../tracker/rabit_hadoop_streaming.py -n $1 --vcores 1 -i $2/data/agaricus.txt.train -o $2/mushroom.linear.model linear.rabit stdin model_out=stdout "${*:3}"
# get the final model file # get the final model file
hadoop fs -get $2/mushroom.linear.model/part-00000 ./linear.model hadoop fs -get $2/mushroom.linear.model/part-00000 ./linear.model

View File

@ -5,11 +5,7 @@ then
exit -1 exit -1
fi fi
rm -rf mushroom.row* *.model rm -rf *.model
k=$1 k=$1
# split the lib svm file into k subfiles ../../tracker/rabit_demo.py -n $k linear.mock ../data/agaricus.txt.train "${*:2}" reg_L1=1 mock=0,1,1,0 mock=1,1,1,0 mock=0,2,1,1
python splitrows.py ../data/agaricus.txt.train mushroom $k
# run xgboost mpi
../../tracker/rabit_demo.py -n $k linear.mock mushroom.row\%d "${*:2}" reg_L1=1 mock=0,1,1,0 mock=1,1,1,0 mock=0,2,1,1

View File

@ -5,13 +5,10 @@ then
exit -1 exit -1
fi fi
rm -rf mushroom.row* *.model rm -rf *.model
k=$1 k=$1
# split the lib svm file into k subfiles # run linear model, the program will automatically split the inputs
python splitrows.py ../data/agaricus.txt.train mushroom $k ../../tracker/rabit_demo.py -n $k linear.rabit ../data/agaricus.txt.train reg_L1=1
# run xgboost mpi
../../tracker/rabit_demo.py -n $k linear.rabit mushroom.row\%d "${*:2}" reg_L1=1
./linear.rabit ../data/agaricus.txt.test task=pred model_in=final.model ./linear.rabit ../data/agaricus.txt.test task=pred model_in=final.model

View File

@ -0,0 +1,19 @@
#!/bin/bash
if [ "$#" -lt 3 ];
then
echo "Usage: <nworkers> <path_in_HDFS> [param=val]"
exit -1
fi
# put the local training file to HDFS
hadoop fs -rm -r -f $2/data
hadoop fs -rm -r -f $2/mushroom.linear.model
hadoop fs -mkdir $2/data
# submit to hadoop
../../tracker/rabit_yarn.py -n $1 --vcores 1 linear.rabit hdfs://$2/data/agaricus.txt.train model_out=hdfs://$2/mushroom.linear.model "${*:3}"
# get the final model file
hadoop fs -get $2/mushroom.linear.model ./linear.model
./linear.rabit ../data/agaricus.txt.test task=pred model_in=linear.model

View File

@ -1,24 +0,0 @@
#!/usr/bin/python
import sys
import random
# split libsvm file into different rows
if len(sys.argv) < 4:
print ('Usage:<fin> <fo> k')
exit(0)
random.seed(10)
k = int(sys.argv[3])
fi = open( sys.argv[1], 'r' )
fos = []
for i in range(k):
fos.append(open( sys.argv[2]+'.row%d' % i, 'w' ))
for l in open(sys.argv[1]):
i = random.randint(0, k-1)
fos[i].write(l)
for f in fos:
f.close()

View File

@ -1,13 +1,20 @@
# this is the common build script for rabit programs # this is the common build script for rabit programs
# you do not have to use it # you do not have to use it
export CC = gcc export LDFLAGS= -L../../lib -pthread -lm -lrt
export CXX = g++
export MPICXX = mpicxx
export LDFLAGS= -pthread -lm -L../../lib
export CFLAGS = -Wall -msse2 -Wno-unknown-pragmas -fPIC -I../../include export CFLAGS = -Wall -msse2 -Wno-unknown-pragmas -fPIC -I../../include
# setup HDFS support when USE_HDFS is set
ifeq ($(USE_HDFS),1)
CFLAGS+= -DRABIT_USE_HDFS=1 -I$(HADOOP_HDFS_HOME)/include -I$(JAVA_HOME)/include
LDFLAGS+= -L$(HADOOP_HDFS_HOME)/lib/native -L$(LIBJVM) -lhdfs -ljvm
else
CFLAGS+= -DRABIT_USE_HDFS=0
endif
.PHONY: clean all lib mpi .PHONY: clean all lib mpi
all: $(BIN) $(MOCKBIN) all: $(BIN) $(MOCKBIN)
mpi: $(MPIBIN) mpi: $(MPIBIN)
lib: lib:
@ -15,10 +22,12 @@ lib:
libmpi: libmpi:
cd ../..;make lib/librabit_mpi.a;cd - cd ../..;make lib/librabit_mpi.a;cd -
$(BIN) : $(BIN) :
$(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^) $(LDFLAGS) -lrabit $(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^) -lrabit $(LDFLAGS)
$(MOCKBIN) : $(MOCKBIN) :
$(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^) $(LDFLAGS) -lrabit_mock $(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^) -lrabit_mock $(LDFLAGS)
$(OBJ) : $(OBJ) :
$(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) ) $(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) )

View File

@ -0,0 +1,21 @@
#-----------------------------------------------------
# rabit-learn: the configuration compile script
#
# This is the default configuration setup for rabit-learn
# If you want to change configuration, do the following steps:
#
# - copy this file to the root of rabit-learn folder
# - modify the configuration you want
# - type make or make -j n for parallel build
#----------------------------------------------------
# choice of compiler
export CC = gcc
export CXX = g++
export MPICXX = mpicxx
# whether to compile with HDFS support
USE_HDFS = 1
# path to libjvm.so
LIBJVM=$(JAVA_HOME)/jre/lib/amd64/server

View File

@ -14,7 +14,9 @@
#include <cstring> #include <cstring>
#include <limits> #include <limits>
#include <cmath> #include <cmath>
#include <sstream>
#include <rabit.h> #include <rabit.h>
#include "../io/io.h"
namespace rabit { namespace rabit {
// typedef index type // typedef index type
@ -45,49 +47,37 @@ struct SparseMat {
} }
// load data from LibSVM format // load data from LibSVM format
inline void Load(const char *fname) { inline void Load(const char *fname) {
FILE *fi; io::InputSplit *in =
if (!strcmp(fname, "stdin")) { io::CreateInputSplit
fi = stdin; (fname, rabit::GetRank(),
} else { rabit::GetWorldSize());
if (strchr(fname, '%') != NULL) {
char s_tmp[256];
snprintf(s_tmp, sizeof(s_tmp), fname, rabit::GetRank());
fi = utils::FopenCheck(s_tmp, "r");
} else {
fi = utils::FopenCheck(fname, "r");
}
}
row_ptr.clear(); row_ptr.clear();
row_ptr.push_back(0); row_ptr.push_back(0);
data.clear(); data.clear();
feat_dim = 0; feat_dim = 0;
float label; bool init = true; std::string line;
char tmp[1024]; while (in->NextLine(&line)) {
while (fscanf(fi, "%s", tmp) == 1) { float label;
std::istringstream ss(line);
ss >> label;
Entry e; Entry e;
unsigned long fidx; unsigned long fidx;
if (sscanf(tmp, "%lu:%f", &fidx, &e.fvalue) == 2) { while (!ss.eof()) {
if (!(ss >> fidx)) break;
ss.ignore(32, ':');
if (!(ss >> e.fvalue)) break;
e.findex = static_cast<index_t>(fidx); e.findex = static_cast<index_t>(fidx);
data.push_back(e); data.push_back(e);
feat_dim = std::max(fidx, feat_dim); feat_dim = std::max(fidx, feat_dim);
} else {
if (!init) {
labels.push_back(label);
row_ptr.push_back(data.size());
}
utils::Check(sscanf(tmp, "%f", &label) == 1, "invalid LibSVM format");
init = false;
} }
labels.push_back(label);
row_ptr.push_back(data.size());
} }
// last row delete in;
labels.push_back(label);
row_ptr.push_back(data.size());
feat_dim += 1; feat_dim += 1;
utils::Check(feat_dim < std::numeric_limits<index_t>::max(), utils::Check(feat_dim < std::numeric_limits<index_t>::max(),
"feature dimension exceed limit of index_t"\ "feature dimension exceed limit of index_t"\
"consider change the index_t to unsigned long"); "consider change the index_t to unsigned long");
// close the filed
if (fi != stdin) fclose(fi);
} }
inline size_t NumRow(void) const { inline size_t NumRow(void) const {
return row_ptr.size() - 1; return row_ptr.size() - 1;
@ -98,6 +88,7 @@ struct SparseMat {
std::vector<Entry> data; std::vector<Entry> data;
std::vector<float> labels; std::vector<float> labels;
}; };
// dense matrix // dense matrix
struct Matrix { struct Matrix {
inline void Init(size_t nrow, size_t ncol, float v = 0.0f) { inline void Init(size_t nrow, size_t ncol, float v = 0.0f) {

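The rewritten Load above replaces the fscanf-based parser with an istringstream over each line from the input split. As a standalone illustration, not part of this commit, the tokenization on a made-up LibSVM line works like this:

```
// Standalone sketch of the istringstream-based LibSVM tokenization used in SparseMat::Load.
#include <iostream>
#include <sstream>
#include <string>

int main() {
  std::string line = "1 0:1.5 3:2.0 7:0.25";   // label followed by index:value pairs
  std::istringstream ss(line);
  float label, fvalue;
  unsigned long fidx;
  ss >> label;                                  // label = 1
  while (ss >> fidx) {                          // stop when no further index can be read
    ss.ignore(32, ':');                         // skip the ':' separator
    if (!(ss >> fvalue)) break;
    std::cout << "feature " << fidx << " = " << fvalue << "\n";
  }
  std::cout << "label = " << label << std::endl;
  return 0;
}
```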
View File

@ -1,40 +0,0 @@
#ifndef RABIT_LEARN_UTILS_IO_H_
#define RABIT_LEARN_UTILS_IO_H_
/*!
* \file io.h
* \brief additional stream interface
* \author Tianqi Chen
*/
namespace rabit {
namespace utils {
/*! \brief implementation of file i/o stream */
class FileStream : public ISeekStream {
public:
explicit FileStream(FILE *fp) : fp(fp) {}
explicit FileStream(void) {
this->fp = NULL;
}
virtual size_t Read(void *ptr, size_t size) {
return std::fread(ptr, size, 1, fp);
}
virtual void Write(const void *ptr, size_t size) {
std::fwrite(ptr, size, 1, fp);
}
virtual void Seek(size_t pos) {
std::fseek(fp, static_cast<long>(pos), SEEK_SET);
}
virtual size_t Tell(void) {
return std::ftell(fp);
}
inline void Close(void) {
if (fp != NULL){
std::fclose(fp); fp = NULL;
}
}
private:
FILE *fp;
};
} // namespace utils
} // namespace rabit
#endif // RABIT_LEARN_UTILS_IO_H_

View File

@ -29,11 +29,24 @@ AllreduceBase::AllreduceBase(void) {
task_id = "NULL"; task_id = "NULL";
err_link = NULL; err_link = NULL;
this->SetParam("rabit_reduce_buffer", "256MB"); this->SetParam("rabit_reduce_buffer", "256MB");
// setup possible environment variables of interest
env_vars.push_back("rabit_task_id");
env_vars.push_back("rabit_num_trial");
env_vars.push_back("rabit_reduce_buffer");
env_vars.push_back("rabit_tracker_uri");
env_vars.push_back("rabit_tracker_port");
} }
// initialization function // initialization function
void AllreduceBase::Init(void) { void AllreduceBase::Init(void) {
// setup from environment variables // setup from environment variables
// handler to get variables from env
for (size_t i = 0; i < env_vars.size(); ++i) {
const char *value = getenv(env_vars[i].c_str());
if (value != NULL) {
this->SetParam(env_vars[i].c_str(), value);
}
}
{ {
// handling for hadoop // handling for hadoop
const char *task_id = getenv("mapred_tip_id"); const char *task_id = getenv("mapred_tip_id");

View File

@ -63,6 +63,10 @@ class AllreduceBase : public IEngine {
if (world_size == -1) return 1; if (world_size == -1) return 1;
return world_size; return world_size;
} }
/*! \brief whether is distributed or not */
virtual bool IsDistributed(void) const {
return tracker_uri != "NULL";
}
/*! \brief get rank */ /*! \brief get rank */
virtual std::string GetHost(void) const { virtual std::string GetHost(void) const {
return host_uri; return host_uri;
@ -413,6 +417,8 @@ class AllreduceBase : public IEngine {
// pointer to links in the ring // pointer to links in the ring
LinkRecord *ring_prev, *ring_next; LinkRecord *ring_prev, *ring_next;
//----- meta information----- //----- meta information-----
// list of environment variables that are of possible interest
std::vector<std::string> env_vars;
// unique identifier of the possible job this process is doing // unique identifier of the possible job this process is doing
// used to assign ranks, optional, default to NULL // used to assign ranks, optional, default to NULL
std::string task_id; std::string task_id;

View File

@ -28,6 +28,8 @@ AllreduceRobust::AllreduceRobust(void) {
global_lazycheck = NULL; global_lazycheck = NULL;
use_local_model = -1; use_local_model = -1;
recover_counter = 0; recover_counter = 0;
env_vars.push_back("rabit_global_replica");
env_vars.push_back("rabit_local_replica");
} }
void AllreduceRobust::Init(void) { void AllreduceRobust::Init(void) {
AllreduceBase::Init(); AllreduceBase::Init();

View File

@ -56,6 +56,10 @@ class EmptyEngine : public IEngine {
virtual int GetWorldSize(void) const { virtual int GetWorldSize(void) const {
return 1; return 1;
} }
/*! \brief whether it is distributed */
virtual bool IsDistributed(void) const {
return false;
}
/*! \brief get the host name of current node */ /*! \brief get the host name of current node */
virtual std::string GetHost(void) const { virtual std::string GetHost(void) const {
return std::string(""); return std::string("");

View File

@ -59,6 +59,10 @@ class MPIEngine : public IEngine {
virtual int GetWorldSize(void) const { virtual int GetWorldSize(void) const {
return MPI::COMM_WORLD.Get_size(); return MPI::COMM_WORLD.Get_size();
} }
/*! \brief whether it is distributed */
virtual bool IsDistributed(void) const {
return true;
}
/*! \brief get the host name of current node */ /*! \brief get the host name of current node */
virtual std::string GetHost(void) const { virtual std::string GetHost(void) const {
int len; int len;

View File

@ -0,0 +1,12 @@
Trackers
=====
This folder contains tracker scripts that can be used to submit rabit jobs to different platforms;
the example guidelines are in the scripts themselves
***Supported Platforms***
* Local demo: [rabit_demo.py](rabit_demo.py)
* MPI: [rabit_mpi.py](rabit_mpi.py)
* Yarn (Hadoop): [rabit_yarn.py](rabit_yarn.py)
- It is also possible to submit via hadoop streaming with rabit_hadoop_streaming.py
- However, it is highly recommended to use rabit_yarn.py because it allocates resources more precisely and fits machine learning scenarios better

View File

@ -31,35 +31,38 @@ nrep=0
rc=254 rc=254
while [ $rc -eq 254 ]; while [ $rc -eq 254 ];
do do
export rabit_num_trial=$nrep
%s
%s %s
%s %s rabit_num_trial=$nrep
rc=$?; rc=$?;
nrep=$((nrep+1)); nrep=$((nrep+1));
done done
""" """
def exec_cmd(cmd, taskid): def exec_cmd(cmd, taskid, worker_env):
if cmd[0].find('/') == -1 and os.path.exists(cmd[0]) and os.name != 'nt': if cmd[0].find('/') == -1 and os.path.exists(cmd[0]) and os.name != 'nt':
cmd[0] = './' + cmd[0] cmd[0] = './' + cmd[0]
cmd = ' '.join(cmd) cmd = ' '.join(cmd)
arg = ' rabit_task_id=%d' % (taskid) env = {}
cmd = cmd + arg for k, v in worker_env.items():
env[k] = str(v)
env['rabit_task_id'] = str(taskid)
env['PYTHONPATH'] = WRAPPER_PATH
ntrial = 0 ntrial = 0
while True: while True:
if os.name == 'nt': if os.name == 'nt':
prep = 'SET PYTHONPATH=\"%s\"\n' % WRAPPER_PATH env['rabit_num_trial'] = str(ntrial)
ret = subprocess.call(prep + cmd + ('rabit_num_trial=%d' % ntrial), shell=True) ret = subprocess.call(cmd, shell=True, env = env)
if ret == 254: if ret == 254:
ntrial += 1 ntrial += 1
continue continue
else: else:
prep = 'PYTHONPATH=\"%s\" ' % WRAPPER_PATH
if args.verbose != 0: if args.verbose != 0:
bash = keepalive % (echo % cmd, prep, cmd) bash = keepalive % (echo % cmd, cmd)
else: else:
bash = keepalive % ('', prep, cmd) bash = keepalive % ('', cmd)
ret = subprocess.call(bash, shell=True, executable='bash') ret = subprocess.call(bash, shell=True, executable='bash', env = env)
if ret == 0: if ret == 0:
if args.verbose != 0: if args.verbose != 0:
print 'Thread %d exit with 0' % taskid print 'Thread %d exit with 0' % taskid
@ -73,7 +76,7 @@ def exec_cmd(cmd, taskid):
# Note: this submit script is only used for demo purpose # Note: this submit script is only used for demo purpose
# submission script using pyhton multi-threading # submission script using pyhton multi-threading
# #
def mthread_submit(nslave, worker_args): def mthread_submit(nslave, worker_args, worker_envs):
""" """
customized submit script that submits nslave jobs; each must contain args as a parameter customized submit script that submits nslave jobs; each must contain args as a parameter
note this can be a lambda function containing additional parameters in input note this can be a lambda function containing additional parameters in input
@ -84,7 +87,7 @@ def mthread_submit(nslave, worker_args):
""" """
procs = {} procs = {}
for i in range(nslave): for i in range(nslave):
procs[i] = Thread(target = exec_cmd, args = (args.command + worker_args, i)) procs[i] = Thread(target = exec_cmd, args = (args.command + worker_args, i, worker_envs))
procs[i].daemon = True procs[i].daemon = True
procs[i].start() procs[i].start()
for i in range(nslave): for i in range(nslave):

View File

@ -1,7 +1,11 @@
#!/usr/bin/python #!/usr/bin/python
""" """
Deprecated
This is a script to submit rabit job using hadoop streaming. This is a script to submit rabit job using hadoop streaming.
It will submit the rabit process as mappers of MapReduce. It will submit the rabit process as mappers of MapReduce.
This script is deprecated, it is highly recommended to use rabit_yarn.py instead
""" """
import argparse import argparse
import sys import sys
@ -34,13 +38,11 @@ if hadoop_binary == None or hadoop_streaming_jar == None:
', or modify rabit_hadoop.py line 16', stacklevel = 2) ', or modify rabit_hadoop.py line 16', stacklevel = 2)
parser = argparse.ArgumentParser(description='Rabit script to submit rabit jobs using Hadoop Streaming.'\ parser = argparse.ArgumentParser(description='Rabit script to submit rabit jobs using Hadoop Streaming.'\
'This script support both Hadoop 1.0 and Yarn(MRv2), Yarn is recommended') 'It is Highly recommended to use rabit_yarn.py instead')
parser.add_argument('-n', '--nworker', required=True, type=int, parser.add_argument('-n', '--nworker', required=True, type=int,
help = 'number of worker processes to be launched') help = 'number of worker processes to be launched')
parser.add_argument('-hip', '--host_ip', default='auto', type=str, parser.add_argument('-hip', '--host_ip', default='auto', type=str,
help = 'host IP address if cannot be automatically guessed, specify the IP of submission machine') help = 'host IP address if cannot be automatically guessed, specify the IP of submission machine')
parser.add_argument('-nt', '--nthread', default = -1, type=int,
help = 'number of thread in each mapper to be launched, set it if each rabit job is multi-threaded')
parser.add_argument('-i', '--input', required=True, parser.add_argument('-i', '--input', required=True,
help = 'input path in HDFS') help = 'input path in HDFS')
parser.add_argument('-o', '--output', required=True, parser.add_argument('-o', '--output', required=True,
@ -61,6 +63,8 @@ parser.add_argument('--jobname', default='auto', help = 'customize jobname in tr
parser.add_argument('--timeout', default=600000000, type=int, parser.add_argument('--timeout', default=600000000, type=int,
help = 'timeout (in million seconds) of each mapper job, automatically set to a very long time,'\ help = 'timeout (in million seconds) of each mapper job, automatically set to a very long time,'\
'normally you do not need to set this ') 'normally you do not need to set this ')
parser.add_argument('--vcores', default = -1, type=int,
help = 'number of vcores to request in each mapper, set it if each rabit job is multi-threaded')
parser.add_argument('-mem', '--memory_mb', default=-1, type=int, parser.add_argument('-mem', '--memory_mb', default=-1, type=int,
help = 'maximum memory used by the process. Guide: set it large (near mapred.cluster.max.map.memory.mb)'\ help = 'maximum memory used by the process. Guide: set it large (near mapred.cluster.max.map.memory.mb)'\
'if you are running multi-threading rabit,'\ 'if you are running multi-threading rabit,'\
@ -91,10 +95,14 @@ out = out.split('\n')[0].split()
assert out[0] == 'Hadoop', 'cannot parse hadoop version string' assert out[0] == 'Hadoop', 'cannot parse hadoop version string'
hadoop_version = out[1].split('.') hadoop_version = out[1].split('.')
use_yarn = int(hadoop_version[0]) >= 2 use_yarn = int(hadoop_version[0]) >= 2
if use_yarn:
warnings.warn('It is highly recommended to use rabit_yarn.py to submit jobs to yarn instead', stacklevel = 2)
print 'Current Hadoop Version is %s' % out[1] print 'Current Hadoop Version is %s' % out[1]
def hadoop_streaming(nworker, worker_args, use_yarn): def hadoop_streaming(nworker, worker_args, worker_envs, use_yarn):
worker_envs['CLASSPATH'] = '`$HADOOP_HOME/bin/hadoop classpath --glob` '
worker_envs['LD_LIBRARY_PATH'] = '{LD_LIBRARY_PATH}:$HADOOP_HDFS_HOME/lib/native:$JAVA_HOME/jre/lib/amd64/server'
fset = set() fset = set()
if args.auto_file_cache: if args.auto_file_cache:
for i in range(len(args.command)): for i in range(len(args.command)):
@ -113,6 +121,7 @@ def hadoop_streaming(nworker, worker_args, use_yarn):
if os.path.exists(f): if os.path.exists(f):
fset.add(f) fset.add(f)
kmap = {} kmap = {}
kmap['env'] = 'mapred.child.env'
# setup keymaps # setup keymaps
if use_yarn: if use_yarn:
kmap['nworker'] = 'mapreduce.job.maps' kmap['nworker'] = 'mapreduce.job.maps'
@ -129,12 +138,14 @@ def hadoop_streaming(nworker, worker_args, use_yarn):
cmd = '%s jar %s' % (args.hadoop_binary, args.hadoop_streaming_jar) cmd = '%s jar %s' % (args.hadoop_binary, args.hadoop_streaming_jar)
cmd += ' -D%s=%d' % (kmap['nworker'], nworker) cmd += ' -D%s=%d' % (kmap['nworker'], nworker)
cmd += ' -D%s=%s' % (kmap['jobname'], args.jobname) cmd += ' -D%s=%s' % (kmap['jobname'], args.jobname)
if args.nthread != -1: envstr = ','.join('%s=%s' % (k, str(v)) for k, v in worker_envs.items())
cmd += ' -D%s=\"%s\"' % (kmap['env'], envstr)
if args.vcores != -1:
if kmap['nthread'] is None: if kmap['nthread'] is None:
warnings.warn('nthread can only be set in Yarn(Hadoop version greater than 2.0),'\ warnings.warn('nthread can only be set in Yarn(Hadoop version greater than 2.0),'\
'it is recommended to use Yarn to submit rabit jobs', stacklevel = 2) 'it is recommended to use Yarn to submit rabit jobs', stacklevel = 2)
else: else:
cmd += ' -D%s=%d' % (kmap['nthread'], args.nthread) cmd += ' -D%s=%d' % (kmap['nthread'], args.vcores)
cmd += ' -D%s=%d' % (kmap['timeout'], args.timeout) cmd += ' -D%s=%d' % (kmap['timeout'], args.timeout)
if args.memory_mb != -1: if args.memory_mb != -1:
cmd += ' -D%s=%d' % (kmap['timeout'], args.timeout) cmd += ' -D%s=%d' % (kmap['timeout'], args.timeout)
@ -150,5 +161,5 @@ def hadoop_streaming(nworker, worker_args, use_yarn):
print cmd print cmd
subprocess.check_call(cmd, shell = True) subprocess.check_call(cmd, shell = True)
fun_submit = lambda nworker, worker_args: hadoop_streaming(nworker, worker_args, int(hadoop_version[0]) >= 2) fun_submit = lambda nworker, worker_args, worker_envs: hadoop_streaming(nworker, worker_args, worker_envs, int(hadoop_version[0]) >= 2)
tracker.submit(args.nworker, [], fun_submit = fun_submit, verbose = args.verbose, hostIP = args.host_ip) tracker.submit(args.nworker, [], fun_submit = fun_submit, verbose = args.verbose, hostIP = args.host_ip)

View File

@ -22,7 +22,7 @@ args = parser.parse_args()
# #
# submission script using MPI # submission script using MPI
# #
def mpi_submit(nslave, worker_args): def mpi_submit(nslave, worker_args, worker_envs):
""" """
customized submit script that submits nslave jobs; each must contain args as a parameter customized submit script that submits nslave jobs; each must contain args as a parameter
note this can be a lambda function containing additional parameters in input note this can be a lambda function containing additional parameters in input
@ -31,6 +31,7 @@ def mpi_submit(nslave, worker_args):
args arguments to launch each job args arguments to launch each job
this usually includes the parameters of master_uri and parameters passed into submit this usually includes the parameters of master_uri and parameters passed into submit
""" """
worker_args += ['%s=%s' % (k, str(v)) for k, v in worker_envs.items()]
sargs = ' '.join(args.command + worker_args) sargs = ' '.join(args.command + worker_args)
if args.hostfile is None: if args.hostfile is None:
cmd = ' '.join(['mpirun -n %d' % (nslave)] + args.command + worker_args) cmd = ' '.join(['mpirun -n %d' % (nslave)] + args.command + worker_args)

View File

@ -134,19 +134,25 @@ class Tracker:
sock.listen(16) sock.listen(16)
self.sock = sock self.sock = sock
self.verbose = verbose self.verbose = verbose
if hostIP == 'auto':
hostIP = 'dns'
self.hostIP = hostIP self.hostIP = hostIP
self.log_print('start listen on %s:%d' % (socket.gethostname(), self.port), 1) self.log_print('start listen on %s:%d' % (socket.gethostname(), self.port), 1)
def __del__(self): def __del__(self):
self.sock.close() self.sock.close()
def slave_args(self): def slave_envs(self):
if self.hostIP == 'auto': """
get environment variables for slaves
can be passed in as args or envs
"""
if self.hostIP == 'dns':
host = socket.gethostname() host = socket.gethostname()
elif self.hostIP == 'ip': elif self.hostIP == 'ip':
host = socket.gethostbyname(socket.getfqdn()) host = socket.gethostbyname(socket.getfqdn())
else: else:
host = self.hostIP host = self.hostIP
return ['rabit_tracker_uri=%s' % host, return {'rabit_tracker_uri': host,
'rabit_tracker_port=%s' % self.port] 'rabit_tracker_port': self.port}
def get_neighbor(self, rank, nslave): def get_neighbor(self, rank, nslave):
rank = rank + 1 rank = rank + 1
ret = [] ret = []
@ -261,9 +267,9 @@ class Tracker:
wait_conn[rank] = s wait_conn[rank] = s
self.log_print('@tracker All nodes finishes job', 2) self.log_print('@tracker All nodes finishes job', 2)
def submit(nslave, args, fun_submit, verbose, hostIP): def submit(nslave, args, fun_submit, verbose, hostIP = 'auto'):
master = Tracker(verbose = verbose, hostIP = hostIP) master = Tracker(verbose = verbose, hostIP = hostIP)
submit_thread = Thread(target = fun_submit, args = (nslave, args + master.slave_args())) submit_thread = Thread(target = fun_submit, args = (nslave, args, master.slave_envs()))
submit_thread.daemon = True submit_thread.daemon = True
submit_thread.start() submit_thread.start()
master.accept_slaves(nslave) master.accept_slaves(nslave)

View File

@ -0,0 +1,129 @@
#!/usr/bin/python
"""
This is a script to submit rabit job via Yarn
rabit will run as a Yarn application
"""
import argparse
import sys
import os
import time
import subprocess
import warnings
import rabit_tracker as tracker
WRAPPER_PATH = os.path.dirname(__file__) + '/../wrapper'
YARN_JAR_PATH = os.path.dirname(__file__) + '/../yarn/rabit-yarn.jar'
if not os.path.exists(YARN_JAR_PATH):
warnings.warn("cannot find \"%s\", I will try to run build" % YARN_JAR_PATH)
cmd = 'cd %s;./build.sh' % (os.path.dirname(__file__) + '/../yarn/')
print cmd
subprocess.check_call(cmd, shell = True, env = os.environ)
assert os.path.exists(YARN_JAR_PATH), "failed to build rabit-yarn.jar, try it manually"
hadoop_binary = 'hadoop'
# code
hadoop_home = os.getenv('HADOOP_HOME')
if hadoop_home != None:
if hadoop_binary == None:
hadoop_binary = hadoop_home + '/bin/hadoop'
assert os.path.exists(hadoop_binary), "HADOOP_HOME does not contain the hadoop binary"
parser = argparse.ArgumentParser(description='Rabit script to submit rabit jobs to Yarn.')
parser.add_argument('-n', '--nworker', required=True, type=int,
help = 'number of worker processes to be launched')
parser.add_argument('-hip', '--host_ip', default='auto', type=str,
help = 'host IP address if cannot be automatically guessed, specify the IP of submission machine')
parser.add_argument('-v', '--verbose', default=0, choices=[0, 1], type=int,
help = 'print more messages into the console')
parser.add_argument('-ac', '--auto_file_cache', default=1, choices=[0, 1], type=int,
help = 'whether automatically cache the files in the command to hadoop localfile, this is on by default')
parser.add_argument('-f', '--files', default = [], action='append',
help = 'the cached file list in mapreduce,'\
' the submission script will automatically cache all the files which appears in command'\
' This will also cause rewritten of all the file names in the command to current path,'\
' for example `../../kmeans ../kmeans.conf` will be rewritten to `./kmeans kmeans.conf`'\
' because the two files are cached to running folder.'\
' You may need this option to cache additional files.'\
' You can also use it to manually cache files when auto_file_cache is off')
parser.add_argument('--jobname', default='auto', help = 'customize jobname in tracker')
parser.add_argument('--tempdir', default='/tmp', help = 'temporary directory in HDFS that can be used to store intermediate results')
parser.add_argument('--vcores', default = 1, type=int,
help = 'number of vcores to request in each mapper, set it if each rabit job is multi-threaded')
parser.add_argument('-mem', '--memory_mb', default=1024, type=int,
help = 'maximum memory used by the process. Guide: set it large (near mapred.cluster.max.map.memory.mb)'\
'if you are running multi-threading rabit,'\
'so that each node can occupy all the mapper slots in a machine for maximum performance')
parser.add_argument('command', nargs='+',
help = 'command for rabit program')
args = parser.parse_args()
if args.jobname == 'auto':
args.jobname = ('Rabit[nworker=%d]:' % args.nworker) + args.command[0].split('/')[-1];
if hadoop_binary == None:
parser.add_argument('-hb', '--hadoop_binary', required = True,
help="path to hadoop binary file")
else:
parser.add_argument('-hb', '--hadoop_binary', default = hadoop_binary,
help="path to hadoop binary file")
args = parser.parse_args()
if args.jobname == 'auto':
args.jobname = ('Rabit[nworker=%d]:' % args.nworker) + args.command[0].split('/')[-1];
# detect hadoop version
(out, err) = subprocess.Popen('%s version' % args.hadoop_binary, shell = True, stdout=subprocess.PIPE).communicate()
out = out.split('\n')[0].split()
assert out[0] == 'Hadoop', 'cannot parse hadoop version string'
hadoop_version = out[1].split('.')
(classpath, err) = subprocess.Popen('%s classpath --glob' % args.hadoop_binary, shell = True, stdout=subprocess.PIPE).communicate()
if int(hadoop_version[0]) < 2:
print 'Current Hadoop Version is %s, rabit_yarn will need Yarn(Hadoop 2.0)' % out[1]
def submit_yarn(nworker, worker_args, worker_env):
fset = set([YARN_JAR_PATH])
if args.auto_file_cache != 0:
for i in range(len(args.command)):
f = args.command[i]
if os.path.exists(f):
fset.add(f)
if i == 0:
args.command[i] = './' + args.command[i].split('/')[-1]
else:
args.command[i] = args.command[i].split('/')[-1]
if args.command[0].endswith('.py'):
flst = [WRAPPER_PATH + '/rabit.py',
WRAPPER_PATH + '/librabit_wrapper.so',
WRAPPER_PATH + '/librabit_wrapper_mock.so']
for f in flst:
if os.path.exists(f):
fset.add(f)
cmd = 'java -cp `%s classpath`:%s org.apache.hadoop.yarn.rabit.Client ' % (args.hadoop_binary, YARN_JAR_PATH)
env = os.environ.copy()
for k, v in worker_env.items():
env[k] = str(v)
env['rabit_cpu_vcores'] = str(args.vcores)
env['rabit_memory_mb'] = str(args.memory_mb)
env['rabit_world_size'] = str(args.nworker)
if args.files != None:
for flst in args.files:
for f in flst.split('#'):
fset.add(f)
for f in fset:
cmd += ' -file %s' % f
cmd += ' -jobname %s ' % args.jobname
cmd += ' -tempdir %s ' % args.tempdir
cmd += (' '.join(args.command + worker_args))
if args.verbose != 0:
print cmd
subprocess.check_call(cmd, shell = True, env = env)
tracker.submit(args.nworker, [], fun_submit = submit_yarn, verbose = args.verbose, hostIP = args.host_ip)

4
subtree/rabit/yarn/.gitignore vendored Normal file
View File

@ -0,0 +1,4 @@
bin
.classpath
.project
*.jar

View File

@ -0,0 +1,5 @@
rabit-yarn
=====
* This folder contains the application code that allows rabit to run on Yarn.
* You can use [../tracker/rabit_yarn.py](../tracker/rabit_yarn.py) to submit the job
- run ```./build.sh``` to build the jar before using the script

View File

@ -0,0 +1 @@
folder used to hold generated class files

8
subtree/rabit/yarn/build.sh Executable file
View File

@ -0,0 +1,8 @@
#!/bin/bash
if [ -z "$HADOOP_PREFIX" ]; then
echo "cannot found $HADOOP_PREFIX in the environment variable, please set it properly"
exit 1
fi
CPATH=`${HADOOP_PREFIX}/bin/hadoop classpath`
javac -cp $CPATH -d bin src/org/apache/hadoop/yarn/rabit/*
jar cf rabit-yarn.jar -C bin .

View File

@ -0,0 +1,508 @@
package org.apache.hadoop.yarn.rabit;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Collection;
import java.util.Collections;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.yarn.util.ConverterUtils;
import org.apache.hadoop.yarn.util.Records;
import org.apache.hadoop.yarn.conf.YarnConfiguration;
import org.apache.hadoop.yarn.api.ApplicationConstants;
import org.apache.hadoop.yarn.api.protocolrecords.RegisterApplicationMasterResponse;
import org.apache.hadoop.yarn.api.records.Container;
import org.apache.hadoop.yarn.api.records.ContainerExitStatus;
import org.apache.hadoop.yarn.api.records.ContainerLaunchContext;
import org.apache.hadoop.yarn.api.records.ContainerState;
import org.apache.hadoop.yarn.api.records.FinalApplicationStatus;
import org.apache.hadoop.yarn.api.records.LocalResource;
import org.apache.hadoop.yarn.api.records.LocalResourceType;
import org.apache.hadoop.yarn.api.records.LocalResourceVisibility;
import org.apache.hadoop.yarn.api.records.Priority;
import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.api.records.ContainerId;
import org.apache.hadoop.yarn.api.records.ContainerStatus;
import org.apache.hadoop.yarn.api.records.NodeReport;
import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest;
import org.apache.hadoop.yarn.client.api.async.NMClientAsync;
import org.apache.hadoop.yarn.client.api.async.AMRMClientAsync;
/**
* application master that requests containers from YARN and launches rabit worker tasks
*
* @author Tianqi Chen
*/
public class ApplicationMaster {
// logger
private static final Log LOG = LogFactory.getLog(ApplicationMaster.class);
// configuration
private Configuration conf = new YarnConfiguration();
// hdfs handler
private FileSystem dfs;
// number of cores allocated for each task
private int numVCores = 1;
// memory (in MB) requested for each task
private int numMemoryMB = 10;
// priority of the app master
private int appPriority = 0;
// total number of tasks
private int numTasks = 1;
// maximum number of attempts to try in each task
private int maxNumAttempt = 3;
// command to launch
private String command = "";
// application tracker hostname
private String appHostName = "";
// application tracker URL reported to the resource manager
private String appTrackerUrl = "";
// tracker port
private int appTrackerPort = 0;
// whether we have started aborting the application due to a fatal error
private boolean startAbort = false;
// worker resources
private Map<String, LocalResource> workerResources = new java.util.HashMap<String, LocalResource>();
// record the aborting reason
private String abortDiagnosis = "";
// resource manager
private AMRMClientAsync<ContainerRequest> rmClient = null;
// node manager
private NMClientAsync nmClient = null;
// list of tasks that pending for resources to be allocated
private final Queue<TaskRecord> pendingTasks = new java.util.LinkedList<TaskRecord>();
// map containerId -> task record of tasks that are currently running
private final Map<ContainerId, TaskRecord> runningTasks = new java.util.HashMap<ContainerId, TaskRecord>();
// collection of finished tasks
private final Collection<TaskRecord> finishedTasks = new java.util.LinkedList<TaskRecord>();
// collection of killed tasks
private final Collection<TaskRecord> killedTasks = new java.util.LinkedList<TaskRecord>();
public static void main(String[] args) throws Exception {
new ApplicationMaster().run(args);
}
private ApplicationMaster() throws IOException {
dfs = FileSystem.get(conf);
}
/**
* get integer argument from environment variable
*
* @param name
* name of key
* @param required
* whether this is required
* @param defv
* default value
* @return the requested result
*/
private int getEnvInteger(String name, boolean required, int defv)
throws IOException {
String value = System.getenv(name);
if (value == null) {
if (required) {
throw new IOException("environment variable " + name
+ " not set");
} else {
return defv;
}
}
return Integer.valueOf(value);
}
/**
* initialize from arguments and command lines
*
* @param args
*/
private void initArgs(String args[]) throws IOException {
LOG.info("Invoke initArgs");
// cached maps
Map<String, Path> cacheFiles = new java.util.HashMap<String, Path>();
for (int i = 0; i < args.length; ++i) {
if (args[i].equals("-file")) {
String[] arr = args[++i].split("#");
Path path = new Path(arr[0]);
if (arr.length == 1) {
cacheFiles.put(path.getName(), path);
} else {
cacheFiles.put(arr[1], path);
}
} else {
this.command += args[i] + " ";
}
}
for (Map.Entry<String, Path> e : cacheFiles.entrySet()) {
LocalResource r = Records.newRecord(LocalResource.class);
FileStatus status = dfs.getFileStatus(e.getValue());
r.setResource(ConverterUtils.getYarnUrlFromPath(e.getValue()));
r.setSize(status.getLen());
r.setTimestamp(status.getModificationTime());
r.setType(LocalResourceType.FILE);
r.setVisibility(LocalResourceVisibility.APPLICATION);
workerResources.put(e.getKey(), r);
}
numVCores = this.getEnvInteger("rabit_cpu_vcores", true, numVCores);
numMemoryMB = this.getEnvInteger("rabit_memory_mb", true, numMemoryMB);
numTasks = this.getEnvInteger("rabit_world_size", true, numTasks);
maxNumAttempt = this.getEnvInteger("rabit_max_attempt", false, maxNumAttempt);
}
/**
* called to start the application
*/
private void run(String args[]) throws Exception {
this.initArgs(args);
this.rmClient = AMRMClientAsync.createAMRMClientAsync(1000,
new RMCallbackHandler());
this.nmClient = NMClientAsync
.createNMClientAsync(new NMCallbackHandler());
this.rmClient.init(conf);
this.rmClient.start();
this.nmClient.init(conf);
this.nmClient.start();
RegisterApplicationMasterResponse response = this.rmClient
.registerApplicationMaster(this.appHostName,
this.appTrackerPort, this.appTrackerUrl);
boolean success = false;
String diagnostics = "";
try {
// list of tasks waiting to be submitted
java.util.Collection<TaskRecord> tasks = new java.util.LinkedList<TaskRecord>();
// add waiting tasks
for (int i = 0; i < this.numTasks; ++i) {
tasks.add(new TaskRecord(i));
}
Resource maxResource = response.getMaximumResourceCapability();
if (maxResource.getMemory() < this.numMemoryMB) {
LOG.warn("[Rabit] memory requested exceed bound "
+ maxResource.getMemory());
this.numMemoryMB = maxResource.getMemory();
}
if (maxResource.getVirtualCores() < this.numVCores) {
LOG.warn("[Rabit] memory requested exceed bound "
+ maxResource.getVirtualCores());
this.numVCores = maxResource.getVirtualCores();
}
this.submitTasks(tasks);
LOG.info("[Rabit] ApplicationMaster started");
while (!this.doneAllJobs()) {
try {
Thread.sleep(100);
} catch (InterruptedException e) {
}
}
assert (killedTasks.size() + finishedTasks.size() == numTasks);
success = finishedTasks.size() == numTasks;
LOG.info("Application completed. Stopping running containers");
nmClient.stop();
diagnostics = "Diagnostics." + ", num_tasks" + this.numTasks
+ ", finished=" + this.finishedTasks.size() + ", failed="
+ this.killedTasks.size() + "\n" + this.abortDiagnosis;
LOG.info(diagnostics);
} catch (Exception e) {
diagnostics = e.toString();
}
rmClient.unregisterApplicationMaster(
success ? FinalApplicationStatus.SUCCEEDED
: FinalApplicationStatus.FAILED, diagnostics,
appTrackerUrl);
if (!success) throw new Exception("Application not successful");
}
/**
* check if the job finishes
*
* @return whether we finished all the jobs
*/
private synchronized boolean doneAllJobs() {
return pendingTasks.size() == 0 && runningTasks.size() == 0;
}
/**
* submit tasks to request containers for the tasks
*
* @param tasks
* a collection of tasks we want to ask container for
*/
private synchronized void submitTasks(Collection<TaskRecord> tasks) {
for (TaskRecord r : tasks) {
Resource resource = Records.newRecord(Resource.class);
resource.setMemory(numMemoryMB);
resource.setVirtualCores(numVCores);
Priority priority = Records.newRecord(Priority.class);
priority.setPriority(this.appPriority);
r.containerRequest = new ContainerRequest(resource, null, null,
priority);
rmClient.addContainerRequest(r.containerRequest);
pendingTasks.add(r);
}
}
/**
* launch the task on container
*
* @param container
* container to run the task
* @param task
* the task
*/
private void launchTask(Container container, TaskRecord task) {
task.container = container;
task.containerRequest = null;
ContainerLaunchContext ctx = Records
.newRecord(ContainerLaunchContext.class);
String cmd =
// use this to setup CLASSPATH correctly for libhdfs
"CLASSPATH=${CLASSPATH}:`${HADOOP_PREFIX}/bin/hadoop classpath --glob` "
+ this.command + " 1>"
+ ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout"
+ " 2>" + ApplicationConstants.LOG_DIR_EXPANSION_VAR
+ "/stderr";
LOG.info(cmd);
ctx.setCommands(Collections.singletonList(cmd));
LOG.info(workerResources);
ctx.setLocalResources(this.workerResources);
// setup environment variables
Map<String, String> env = new java.util.HashMap<String, String>();
// build the class path; it is not exported below because the launch command already calls hadoop classpath
StringBuilder cpath = new StringBuilder("${CLASSPATH}:./*");
for (String c : conf.getStrings(
YarnConfiguration.YARN_APPLICATION_CLASSPATH,
YarnConfiguration.DEFAULT_YARN_APPLICATION_CLASSPATH)) {
cpath.append(':');
cpath.append(c.trim());
}
// the worker already obtains the class path from the hadoop command, so CLASSPATH is left unset here; a better solution may come later
// env.put("CLASSPATH", cpath.toString());
// setup LD_LIBRARY_PATH for libhdfs
env.put("LD_LIBRARY_PATH",
"${LD_LIBRARY_PATH}:$HADOOP_HDFS_HOME/lib/native:$JAVA_HOME/jre/lib/amd64/server");
env.put("PYTHONPATH", "${PYTHONPATH}:.");
// inherit all rabit variables
for (Map.Entry<String, String> e : System.getenv().entrySet()) {
if (e.getKey().startsWith("rabit_")) {
env.put(e.getKey(), e.getValue());
}
}
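// pass the task id and attempt count to the worker via rabit's rabit_* environment-variable convention,
// so a restarted worker can present the same task identity to the tracker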
env.put("rabit_task_id", String.valueOf(task.taskId));
env.put("rabit_num_trial", String.valueOf(task.attemptCounter));
ctx.setEnvironment(env);
synchronized (this) {
assert (!this.runningTasks.containsKey(container.getId()));
this.runningTasks.put(container.getId(), task);
this.nmClient.startContainerAsync(container, ctx);
}
}
/**
* free the containers that have not yet been launched
*
* @param containers
*/
private synchronized void freeUnusedContainers(
Collection<Container> containers) {
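// no-op in this patch: surplus containers are simply not launched and are not explicitly released here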
}
/**
* handle method for AMRMClientAsync.CallbackHandler container allocation
*
* @param containers
*/
private synchronized void onContainersAllocated(List<Container> containers) {
if (this.startAbort) {
this.freeUnusedContainers(containers);
return;
}
Collection<Container> freelist = new java.util.LinkedList<Container>();
for (Container c : containers) {
TaskRecord task;
task = pendingTasks.poll();
if (task == null) {
freelist.add(c);
continue;
}
this.launchTask(c, task);
}
this.freeUnusedContainers(freelist);
}
/**
* start aborting the job
*
* @param msg
* the fatal message
*/
private synchronized void abortJob(String msg) {
if (!this.startAbort)
this.abortDiagnosis = msg;
this.startAbort = true;
for (TaskRecord r : this.runningTasks.values()) {
if (!r.abortRequested) {
nmClient.stopContainerAsync(r.container.getId(),
r.container.getNodeId());
r.abortRequested = true;
}
}
this.killedTasks.addAll(this.pendingTasks);
for (TaskRecord r : this.pendingTasks) {
rmClient.removeContainerRequest(r.containerRequest);
}
this.pendingTasks.clear();
LOG.info(msg);
}
/**
* handle non-fatal task failures by resubmitting the failed tasks
*
* @param failed the collection of container ids whose tasks failed
*/
private synchronized void handleFailure(Collection<ContainerId> failed) {
Collection<TaskRecord> tasks = new java.util.LinkedList<TaskRecord>();
for (ContainerId cid : failed) {
TaskRecord r = runningTasks.remove(cid);
if (r == null)
continue;
r.attemptCounter += 1;
r.container = null;
tasks.add(r);
if (r.attemptCounter >= this.maxNumAttempt) {
this.abortJob("[Rabit] Task " + r.taskId + " failed more than "
+ r.attemptCounter + " times");
}
}
if (this.startAbort) {
this.killedTasks.addAll(tasks);
} else {
this.submitTasks(tasks);
}
}
/**
* handle method for AMRMClientAsync.CallbackHandler container completion
*
* @param status
* list of status
*/
private synchronized void onContainersCompleted(List<ContainerStatus> status) {
Collection<ContainerId> failed = new java.util.LinkedList<ContainerId>();
for (ContainerStatus s : status) {
assert (s.getState().equals(ContainerState.COMPLETE));
int exstatus = s.getExitStatus();
TaskRecord r = runningTasks.get(s.getContainerId());
if (r == null)
continue;
if (exstatus == ContainerExitStatus.SUCCESS) {
finishedTasks.add(r);
runningTasks.remove(s.getContainerId());
} else {
switch (exstatus) {
case ContainerExitStatus.KILLED_EXCEEDED_PMEM:
this.abortJob("[Rabit] Task "
+ r.taskId
+ " killed because of exceeding allocated physical memory");
break;
case ContainerExitStatus.KILLED_EXCEEDED_VMEM:
this.abortJob("[Rabit] Task "
+ r.taskId
+ " killed because of exceeding allocated virtual memory");
break;
default:
LOG.info("[Rabit] Task " + r.taskId
+ " exited with status " + exstatus);
failed.add(s.getContainerId());
}
}
}
this.handleFailure(failed);
}
/**
* callback handler for resource manager
*/
private class RMCallbackHandler implements AMRMClientAsync.CallbackHandler {
@Override
public float getProgress() {
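// report progress as the fraction of tasks that are no longer waiting in the pending queue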
return 1.0f - (float) (pendingTasks.size()) / numTasks;
}
@Override
public void onContainersAllocated(List<Container> containers) {
ApplicationMaster.this.onContainersAllocated(containers);
}
@Override
public void onContainersCompleted(List<ContainerStatus> status) {
ApplicationMaster.this.onContainersCompleted(status);
}
@Override
public void onError(Throwable ex) {
ApplicationMaster.this.abortJob("[Rabit] Resource manager Error "
+ ex.toString());
}
@Override
public void onNodesUpdated(List<NodeReport> nodereport) {
}
@Override
public void onShutdownRequest() {
ApplicationMaster.this
.abortJob("[Rabit] Get shutdown request, start to shutdown...");
}
}
private class NMCallbackHandler implements NMClientAsync.CallbackHandler {
@Override
public void onContainerStarted(ContainerId cid,
Map<String, ByteBuffer> services) {
LOG.debug("onContainerStarted Invoked");
}
@Override
public void onContainerStatusReceived(ContainerId cid,
ContainerStatus status) {
LOG.debug("onContainerStatusReceived Invoked");
}
@Override
public void onContainerStopped(ContainerId cid) {
LOG.debug("onContainerStopped Invoked");
}
@Override
public void onGetContainerStatusError(ContainerId cid, Throwable ex) {
LOG.debug("onGetContainerStatusError Invoked: " + ex.toString());
ApplicationMaster.this
.handleFailure(Collections.singletonList(cid));
}
@Override
public void onStartContainerError(ContainerId cid, Throwable ex) {
LOG.debug("onStartContainerError Invoked: " + ex.toString());
ApplicationMaster.this
.handleFailure(Collections.singletonList(cid));
}
@Override
public void onStopContainerError(ContainerId cid, Throwable ex) {
LOG.info("onStopContainerError Invoked: " + ex.toString());
}
}
}

subtree/rabit/yarn/src/org/apache/hadoop/yarn/rabit/Client.java Normal file

@ -0,0 +1,233 @@
package org.apache.hadoop.yarn.rabit;
import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.permission.FsPermission;
import org.apache.hadoop.yarn.api.ApplicationConstants;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.api.records.ApplicationReport;
import org.apache.hadoop.yarn.api.records.ApplicationSubmissionContext;
import org.apache.hadoop.yarn.api.records.ContainerLaunchContext;
import org.apache.hadoop.yarn.api.records.FinalApplicationStatus;
import org.apache.hadoop.yarn.api.records.LocalResource;
import org.apache.hadoop.yarn.api.records.LocalResourceType;
import org.apache.hadoop.yarn.api.records.LocalResourceVisibility;
import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.api.records.YarnApplicationState;
import org.apache.hadoop.yarn.client.api.YarnClient;
import org.apache.hadoop.yarn.client.api.YarnClientApplication;
import org.apache.hadoop.yarn.conf.YarnConfiguration;
import org.apache.hadoop.yarn.util.ConverterUtils;
import org.apache.hadoop.yarn.util.Records;
public class Client {
// logger
private static final Log LOG = LogFactory.getLog(Client.class);
// permission for temp file
private static final FsPermission permTemp = new FsPermission("777");
// configuration
private YarnConfiguration conf = new YarnConfiguration();
// hdfs handler
private FileSystem dfs;
// cached maps
private Map<String, String> cacheFiles = new java.util.HashMap<String, String>();
// -file argument string describing the cached files, forwarded to the application master
private String cacheFileArg = "";
// args to pass to application master
private String appArgs = "";
// HDFS path used to store temporary files
private String tempdir = "/tmp";
// job name
private String jobName = "";
/**
* constructor
* @throws IOException
*/
private Client() throws IOException {
dfs = FileSystem.get(conf);
}
/**
* set up the cached files as YARN local resources, copying local files
* into a temporary HDFS directory when necessary
*
* @param appId the application id, used to name the temporary directory
* @return the resource map
* @throws IOException
*/
private Map<String, LocalResource> setupCacheFiles(ApplicationId appId) throws IOException {
// create temporary rabit directory
Path tmpPath = new Path(this.tempdir);
if (!dfs.exists(tmpPath)) {
dfs.mkdirs(tmpPath, permTemp);
LOG.info("HDFS temp directory do not exist, creating.. " + tmpPath);
}
tmpPath = new Path(tmpPath + "/temp-rabit-yarn-" + appId);
if (dfs.exists(tmpPath)) {
dfs.delete(tmpPath, true);
}
// create temporary directory
FileSystem.mkdirs(dfs, tmpPath, permTemp);
StringBuilder cstr = new StringBuilder();
Map<String, LocalResource> rmap = new java.util.HashMap<String, LocalResource>();
for (Map.Entry<String, String> e : cacheFiles.entrySet()) {
LocalResource r = Records.newRecord(LocalResource.class);
Path path = new Path(e.getValue());
// copy local data to temporary folder in HDFS
if (!e.getValue().startsWith("hdfs://")) {
Path dst = new Path("hdfs://" + tmpPath + "/"+ path.getName());
dfs.copyFromLocalFile(false, true, path, dst);
dfs.setPermission(dst, permTemp);
dfs.deleteOnExit(dst);
path = dst;
}
FileStatus status = dfs.getFileStatus(path);
r.setResource(ConverterUtils.getYarnUrlFromPath(path));
r.setSize(status.getLen());
r.setTimestamp(status.getModificationTime());
r.setType(LocalResourceType.FILE);
r.setVisibility(LocalResourceVisibility.APPLICATION);
rmap.put(e.getKey(), r);
cstr.append(" -file \"");
cstr.append(path.toString());
cstr.append('#');
cstr.append(e.getKey());
cstr.append("\"");
}
dfs.deleteOnExit(tmpPath);
this.cacheFileArg = cstr.toString();
return rmap;
}
/**
* get the environment variables for container
*
* @return the env variable for child class
*/
private Map<String, String> getEnvironment() {
// Setup environment variables
Map<String, String> env = new java.util.HashMap<String, String>();
String cpath = "${CLASSPATH}:./*";
for (String c : conf.getStrings(
YarnConfiguration.YARN_APPLICATION_CLASSPATH,
YarnConfiguration.DEFAULT_YARN_APPLICATION_CLASSPATH)) {
cpath += ':';
cpath += c.trim();
}
env.put("CLASSPATH", cpath);
for (Map.Entry<String, String> e : System.getenv().entrySet()) {
if (e.getKey().startsWith("rabit_")) {
env.put(e.getKey(), e.getValue());
}
}
LOG.debug(env);
return env;
}
/**
* initialize the settings
*
* @param args
*/
private void initArgs(String[] args) {
// pass all arguments through directly, except the ones handled here (-file, -jobname, -tempdir)
StringBuilder sargs = new StringBuilder("");
for (int i = 0; i < args.length; ++i) {
if (args[i].equals("-file")) {
String[] arr = args[++i].split("#");
if (arr.length == 1) {
cacheFiles.put(new Path(arr[0]).getName(), arr[0]);
} else {
cacheFiles.put(arr[1], arr[0]);
}
} else if(args[i].equals("-jobname")) {
this.jobName = args[++i];
} else if(args[i].equals("-tempdir")) {
this.tempdir = args[++i];
} else {
sargs.append(" ");
sargs.append(args[i]);
}
}
this.appArgs = sargs.toString();
}
private void run(String[] args) throws Exception {
if (args.length == 0) {
System.out.println("Usage: [options] [commands..]");
System.out.println("options: [-file filename]");
return;
}
this.initArgs(args);
// Create yarnClient
YarnConfiguration conf = new YarnConfiguration();
YarnClient yarnClient = YarnClient.createYarnClient();
yarnClient.init(conf);
yarnClient.start();
// Create application via yarnClient
YarnClientApplication app = yarnClient.createApplication();
// Set up the container launch context for the application master
ContainerLaunchContext amContainer = Records
.newRecord(ContainerLaunchContext.class);
ApplicationSubmissionContext appContext = app
.getApplicationSubmissionContext();
// Submit application
ApplicationId appId = appContext.getApplicationId();
// setup cache-files and environment variables
amContainer.setLocalResources(this.setupCacheFiles(appId));
amContainer.setEnvironment(this.getEnvironment());
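// command that launches the ApplicationMaster inside the AM container; the -file arguments gathered in
// setupCacheFiles are forwarded so the AM can localize the same resources for each worker container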
String cmd = "$JAVA_HOME/bin/java"
+ " -Xmx256M"
+ " org.apache.hadoop.yarn.rabit.ApplicationMaster"
+ this.cacheFileArg + ' ' + this.appArgs + " 1>"
+ ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout"
+ " 2>" + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr";
LOG.debug(cmd);
amContainer.setCommands(Collections.singletonList(cmd));
// Set up resource type requirements for ApplicationMaster
Resource capability = Records.newRecord(Resource.class);
capability.setMemory(256);
capability.setVirtualCores(1);
LOG.info("jobname=" + this.jobName);
appContext.setApplicationName(jobName + ":RABIT-YARN");
appContext.setAMContainerSpec(amContainer);
appContext.setResource(capability);
appContext.setQueue("default");
LOG.info("Submitting application " + appId);
yarnClient.submitApplication(appContext);
ApplicationReport appReport = yarnClient.getApplicationReport(appId);
YarnApplicationState appState = appReport.getYarnApplicationState();
while (appState != YarnApplicationState.FINISHED
&& appState != YarnApplicationState.KILLED
&& appState != YarnApplicationState.FAILED) {
Thread.sleep(100);
appReport = yarnClient.getApplicationReport(appId);
appState = appReport.getYarnApplicationState();
}
System.out.println("Application " + appId + " finished with"
+ " state " + appState + " at " + appReport.getFinishTime());
if (!appReport.getFinalApplicationStatus().equals(
FinalApplicationStatus.SUCCEEDED)) {
System.err.println(appReport.getDiagnostics());
}
}
public static void main(String[] args) throws Exception {
new Client().run(args);
}
}

subtree/rabit/yarn/src/org/apache/hadoop/yarn/rabit/TaskRecord.java Normal file

@ -0,0 +1,24 @@
package org.apache.hadoop.yarn.rabit;
import org.apache.hadoop.yarn.api.records.Container;
import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest;
/**
* data structure to hold the task information
*/
public class TaskRecord {
// task id of the task
public int taskId = 0;
// number of failed attempts to run the task
public int attemptCounter = 0;
// container request, can be null if task is already running
public ContainerRequest containerRequest = null;
// running container, can be null if the task is not launched
public Container container = null;
// whether we have requested this task to be aborted
public boolean abortRequested = false;
public TaskRecord(int taskId) {
this.taskId = taskId;
}
}

wrapper/xgboost_wrapper.cpp

@ -112,7 +112,7 @@ using namespace xgboost::wrapper;
extern "C"{ extern "C"{
void* XGDMatrixCreateFromFile(const char *fname, int silent) { void* XGDMatrixCreateFromFile(const char *fname, int silent) {
return LoadDataMatrix(fname, silent != 0, false); return LoadDataMatrix(fname, silent != 0, false, false);
} }
void* XGDMatrixCreateFromCSR(const bst_ulong *indptr, void* XGDMatrixCreateFromCSR(const bst_ulong *indptr,
const unsigned *indices, const unsigned *indices,