Merge commit '57b5d7873f4f0953357e9d98e9c60cff8373d7ec'

This commit is contained in:
tqchen 2015-03-09 13:28:38 -07:00
commit 9f7c6fe271
43 changed files with 1797 additions and 235 deletions

View File

@ -2,7 +2,7 @@ ifndef CXX
export CXX = g++
endif
export MPICXX = mpicxx
export LDFLAGS= -Llib
export LDFLAGS= -Llib -lrt
export WARNFLAGS= -Wall -Wextra -Wno-unused-parameter -Wno-unknown-pragmas -pedantic
export CFLAGS = -O3 -msse2 -fPIC $(WARNFLAGS)
@ -50,7 +50,7 @@ $(ALIB):
ar cr $@ $+
$(SLIB) :
$(CXX) $(CFLAGS) -shared -o $@ $(filter %.cpp %.o %.c %.cc %.a, $^)
$(CXX) $(CFLAGS) -shared -o $@ $(filter %.cpp %.o %.c %.cc %.a, $^) $(LDFLAGS)
clean:
$(RM) $(OBJ) $(MPIOBJ) $(ALIB) $(MPIALIB) *~ src/*~ include/*~ include/*/*~ wrapper/*~

View File

@ -13,7 +13,7 @@ All these features comes from the facts about small rabbit:)
* Portable: rabit is light weight and runs everywhere
- Rabit is a library instead of a framework, a program only needs to link the library to run
- Rabit only replies on a mechanism to start program, which was provided by most framework
- You can run rabit programs on many platforms, including Hadoop, MPI using the same code
- You can run rabit programs on many platforms, including Yarn(Hadoop), MPI using the same code
* Scalable and Flexible: rabit runs fast
* Rabit program use Allreduce to communicate, and do not suffer the cost between iterations of MapReduce abstraction.
- Programs can call rabit functions in any order, as opposed to frameworks where callbacks are offered and called by the framework, i.e. inversion of control principle.

View File

@ -341,12 +341,11 @@ Rabit is a portable library that can run on multiple platforms.
* This script will restart the program when it exits with -2, so it can be used for [mock test](#link-against-mock-test-library)
#### Running Rabit on Hadoop
* You can use [../tracker/rabit_hadoop.py](../tracker/rabit_hadoop.py) to run rabit programs on hadoop
* This will start n rabit programs as mappers of MapReduce
* Each program can read its portion of data from stdin
* Yarn(Hadoop 2.0 or higher) is highly recommended, since Yarn allows specifying number of cpus and memory of each mapper:
* You can use [../tracker/rabit_yarn.py](../tracker/rabit_yarn.py) to run rabit programs as Yarn application
* This will start rabit programs as yarn applications
- This allows multi-threading programs in each node, which can be more efficient
- An easy multi-threading solution could be to use OpenMP with rabit code
* It is also possible to run rabit program via hadoop streaming, however, YARN is highly recommended.
#### Running Rabit using MPI
* You can submit rabit programs to an MPI cluster using [../tracker/rabit_mpi.py](../tracker/rabit_mpi.py).
@ -358,15 +357,15 @@ tracker scripts, such as [../tracker/rabit_hadoop.py](../tracker/rabit_hadoop.py
You will need to implement a platform dependent submission function with the following definition
```python
def fun_submit(nworkers, worker_args):
def fun_submit(nworkers, worker_args, worker_envs):
"""
customized submit script, that submits nslave jobs,
each must contain args as parameter
note this can be a lambda closure
Parameters
nworkers number of worker processes to start
worker_args tracker information which must be passed to the arguments
this usually includes the parameters of master_uri and port, etc.
worker_args addtiional arguments that needs to be passed to worker
worker_envs enviroment variables that need to be set to the worker
"""
```
The submission function should start nworkers processes in the platform, and append worker_args to the end of the other arguments.
@ -374,7 +373,7 @@ Then you can simply call ```tracker.submit``` with fun_submit to submit jobs to
Note that the current rabit tracker does not restart a worker when it dies, the restart of a node is done by the platform, otherwise we should write the fail-restart logic in the custom script.
* Fail-restart is usually provided by most platforms.
* For example, mapreduce will restart a mapper when it fails
- rabit-yarn provides such functionality in YARN
Fault Tolerance
=====

View File

@ -23,6 +23,8 @@ class ISeekStream: public IStream {
virtual void Seek(size_t pos) = 0;
/*! \brief tell the position of the stream */
virtual size_t Tell(void) = 0;
/*! \return whether we are at end of file */
virtual bool AtEnd(void) const = 0;
};
/*! \brief fixed size memory buffer */
@ -55,7 +57,9 @@ struct MemoryFixSizeBuffer : public ISeekStream {
virtual size_t Tell(void) {
return curr_ptr_;
}
virtual bool AtEnd(void) const {
return curr_ptr_ == buffer_size_;
}
private:
/*! \brief in memory buffer */
char *p_buffer_;
@ -95,7 +99,9 @@ struct MemoryBufferStream : public ISeekStream {
virtual size_t Tell(void) {
return curr_ptr_;
}
virtual bool AtEnd(void) const {
return curr_ptr_ == p_buffer_->length();
}
private:
/*! \brief in memory buffer */
std::string *p_buffer_;

View File

@ -3,9 +3,13 @@
* \brief This file defines the utils for timing
* \author Tianqi Chen, Nacho, Tianyi
*/
#ifndef RABIT_TIMER_H
#define RABIT_TIMER_H
#ifndef RABIT_TIMER_H_
#define RABIT_TIMER_H_
#include <time.h>
#ifdef __MACH__
#include <mach/clock.h>
#include <mach/mach.h>
#endif
#include "./utils.h"
namespace rabit {
@ -14,10 +18,19 @@ namespace utils {
* \brief return time in seconds, not cross platform, avoid to use this in most places
*/
inline double GetTime(void) {
#ifdef __MACH__
clock_serv_t cclock;
mach_timespec_t mts;
host_get_clock_service(mach_host_self(), CALENDAR_CLOCK, &cclock);
utils::Check(clock_get_time(cclock, &mts) == 0, "failed to get time");
mach_port_deallocate(mach_task_self(), cclock);
return static_cast<double>(mts.tv_sec) + static_cast<double>(mts.tv_nsec) * 1e-9;
#else
timespec ts;
utils::Check(clock_gettime(CLOCK_REALTIME, &ts) == 0, "failed to get time");
return static_cast<double>(ts.tv_sec) + static_cast<double>(ts.tv_nsec) * 1e-9;
#endif
}
}
}
#endif
} // namespace utils
} // namespace rabit
#endif // RABIT_TIMER_H_

View File

@ -5,15 +5,13 @@ It also contain links to the Machine Learning packages that uses rabit.
* Contribution of toolkits, examples, benchmarks is more than welcomed!
Toolkits
====
* [KMeans Clustering](kmeans)
* [Linear and Logistic Regression](linear)
* [Linear and Logistic Regression](linear)
* [XGBoost: eXtreme Gradient Boosting](https://github.com/tqchen/xgboost/tree/master/multi-node)
- xgboost is a very fast boosted tree(also known as GBDT) library, that can run more than
10 times faster than existing packages
- Rabit carries xgboost to distributed enviroment, inheritating all the benefits of xgboost
single node version, and scale it to even larger problems

View File

@ -1,5 +1,5 @@
#ifndef RABIT_LEARN_UTILS_BASE64_H_
#define RABIT_LEARN_UTILS_BASE64_H_
#ifndef RABIT_LEARN_IO_BASE64_INL_H_
#define RABIT_LEARN_IO_BASE64_INL_H_
/*!
* \file base64.h
* \brief data stream support to input and output from/to base64 stream
@ -8,10 +8,11 @@
*/
#include <cctype>
#include <cstdio>
#include <rabit/io.h>
#include "./io.h"
#include "./buffer_reader-inl.h"
namespace rabit {
namespace utils {
namespace io {
/*! \brief namespace of base64 decoding and encoding table */
namespace base64 {
const char DecodeTable[] = {
@ -34,7 +35,8 @@ static const char EncodeTable[] =
/*! \brief the stream that reads from base64, note we take from file pointers */
class Base64InStream: public IStream {
public:
explicit Base64InStream(FILE *fp) : fp(fp) {
explicit Base64InStream(IStream *fs) : reader_(256) {
reader_.set_stream(fs);
num_prev = 0; tmp_ch = 0;
}
/*!
@ -44,7 +46,7 @@ class Base64InStream: public IStream {
inline void InitPosition(void) {
// get a charater
do {
tmp_ch = fgetc(fp);
tmp_ch = reader_.GetChar();
} while (isspace(tmp_ch));
}
/*! \brief whether current position is end of a base64 stream */
@ -85,19 +87,19 @@ class Base64InStream: public IStream {
nvalue = DecodeTable[tmp_ch] << 18;
{
// second byte
Check((tmp_ch = fgetc(fp), tmp_ch != EOF && !isspace(tmp_ch)),
utils::Check((tmp_ch = reader_.GetChar(), tmp_ch != EOF && !isspace(tmp_ch)),
"invalid base64 format");
nvalue |= DecodeTable[tmp_ch] << 12;
*cptr++ = (nvalue >> 16) & 0xFF; --tlen;
}
{
// third byte
Check((tmp_ch = fgetc(fp), tmp_ch != EOF && !isspace(tmp_ch)),
utils::Check((tmp_ch = reader_.GetChar(), tmp_ch != EOF && !isspace(tmp_ch)),
"invalid base64 format");
// handle termination
if (tmp_ch == '=') {
Check((tmp_ch = fgetc(fp), tmp_ch == '='), "invalid base64 format");
Check((tmp_ch = fgetc(fp), tmp_ch == EOF || isspace(tmp_ch)),
utils::Check((tmp_ch = reader_.GetChar(), tmp_ch == '='), "invalid base64 format");
utils::Check((tmp_ch = reader_.GetChar(), tmp_ch == EOF || isspace(tmp_ch)),
"invalid base64 format");
break;
}
@ -110,10 +112,10 @@ class Base64InStream: public IStream {
}
{
// fourth byte
Check((tmp_ch = fgetc(fp), tmp_ch != EOF && !isspace(tmp_ch)),
utils::Check((tmp_ch = reader_.GetChar(), tmp_ch != EOF && !isspace(tmp_ch)),
"invalid base64 format");
if (tmp_ch == '=') {
Check((tmp_ch = fgetc(fp), tmp_ch == EOF || isspace(tmp_ch)),
utils::Check((tmp_ch = reader_.GetChar(), tmp_ch == EOF || isspace(tmp_ch)),
"invalid base64 format");
break;
}
@ -125,10 +127,10 @@ class Base64InStream: public IStream {
}
}
// get next char
tmp_ch = fgetc(fp);
tmp_ch = reader_.GetChar();
}
if (kStrictCheck) {
Check(tlen == 0, "Base64InStream: read incomplete");
utils::Check(tlen == 0, "Base64InStream: read incomplete");
}
return size - tlen;
}
@ -137,7 +139,7 @@ class Base64InStream: public IStream {
}
private:
FILE *fp;
StreamBufferReader reader_;
int tmp_ch;
int num_prev;
unsigned char buf_prev[2];
@ -147,7 +149,7 @@ class Base64InStream: public IStream {
/*! \brief the stream that write to base64, note we take from file pointers */
class Base64OutStream: public IStream {
public:
explicit Base64OutStream(FILE *fp) : fp(fp) {
explicit Base64OutStream(IStream *fp) : fp(fp) {
buf_top = 0;
}
virtual void Write(const void *ptr, size_t size) {
@ -160,16 +162,16 @@ class Base64OutStream: public IStream {
}
if (buf_top == 3) {
// flush 4 bytes out
fputc(EncodeTable[buf[1] >> 2], fp);
fputc(EncodeTable[((buf[1] << 4) | (buf[2] >> 4)) & 0x3F], fp);
fputc(EncodeTable[((buf[2] << 2) | (buf[3] >> 6)) & 0x3F], fp);
fputc(EncodeTable[buf[3] & 0x3F], fp);
PutChar(EncodeTable[buf[1] >> 2]);
PutChar(EncodeTable[((buf[1] << 4) | (buf[2] >> 4)) & 0x3F]);
PutChar(EncodeTable[((buf[2] << 2) | (buf[3] >> 6)) & 0x3F]);
PutChar(EncodeTable[buf[3] & 0x3F]);
buf_top = 0;
}
}
}
virtual size_t Read(void *ptr, size_t size) {
Error("Base64OutStream do not support read");
utils::Error("Base64OutStream do not support read");
return 0;
}
/*!
@ -179,26 +181,38 @@ class Base64OutStream: public IStream {
inline void Finish(char endch = EOF) {
using base64::EncodeTable;
if (buf_top == 1) {
fputc(EncodeTable[buf[1] >> 2], fp);
fputc(EncodeTable[(buf[1] << 4) & 0x3F], fp);
fputc('=', fp);
fputc('=', fp);
PutChar(EncodeTable[buf[1] >> 2]);
PutChar(EncodeTable[(buf[1] << 4) & 0x3F]);
PutChar('=');
PutChar('=');
}
if (buf_top == 2) {
fputc(EncodeTable[buf[1] >> 2], fp);
fputc(EncodeTable[((buf[1] << 4) | (buf[2] >> 4)) & 0x3F], fp);
fputc(EncodeTable[(buf[2] << 2) & 0x3F], fp);
fputc('=', fp);
PutChar(EncodeTable[buf[1] >> 2]);
PutChar(EncodeTable[((buf[1] << 4) | (buf[2] >> 4)) & 0x3F]);
PutChar(EncodeTable[(buf[2] << 2) & 0x3F]);
PutChar('=');
}
buf_top = 0;
if (endch != EOF) fputc(endch, fp);
if (endch != EOF) PutChar(endch);
this->Flush();
}
private:
FILE *fp;
private:
IStream *fp;
int buf_top;
unsigned char buf[4];
std::string out_buf;
const static size_t kBufferSize = 256;
inline void PutChar(char ch) {
out_buf += ch;
if (out_buf.length() >= kBufferSize) Flush();
}
inline void Flush(void) {
fp->Write(BeginPtr(out_buf), out_buf.length());
out_buf.clear();
}
};
} // namespace utils
} // namespace rabit
#endif // RABIT_LEARN_UTILS_BASE64_H_
#endif // RABIT_LEARN_UTILS_BASE64_INL_H_

View File

@ -0,0 +1,57 @@
#ifndef RABIT_LEARN_IO_BUFFER_READER_INL_H_
#define RABIT_LEARN_IO_BUFFER_READER_INL_H_
/*!
* \file buffer_reader-inl.h
* \brief implementation of stream buffer reader
* \author Tianqi Chen
*/
#include "./io.h"
namespace rabit {
namespace io {
/*! \brief buffer reader of the stream that allows you to get */
class StreamBufferReader {
public:
StreamBufferReader(size_t buffer_size)
:stream_(NULL),
read_len_(1), read_ptr_(1) {
buffer_.resize(buffer_size);
}
/*!
* \brief set input stream
*/
inline void set_stream(IStream *stream) {
stream_ = stream;
read_len_ = read_ptr_ = 1;
}
/*!
* \brief allows quick read using get char
*/
inline char GetChar(void) {
while (true) {
if (read_ptr_ < read_len_) {
return buffer_[read_ptr_++];
} else {
read_len_ = stream_->Read(&buffer_[0], buffer_.length());
if (read_len_ == 0) return EOF;
read_ptr_ = 0;
}
}
}
inline bool AtEnd(void) const {
return read_len_ == 0;
}
private:
/*! \brief the underlying stream */
IStream *stream_;
/*! \brief buffer to hold data */
std::string buffer_;
/*! \brief length of valid data in buffer */
size_t read_len_;
/*! \brief pointer in the buffer */
size_t read_ptr_;
};
} // namespace io
} // namespace rabit
#endif // RABIT_LEARN_IO_BUFFER_READER_INL_H_

View File

@ -0,0 +1,106 @@
#ifndef RABIT_LEARN_IO_FILE_INL_H_
#define RABIT_LEARN_IO_FILE_INL_H_
/*!
* \file file-inl.h
* \brief normal filesystem I/O
* \author Tianqi Chen
*/
#include <string>
#include <vector>
#include <cstdio>
#include "./io.h"
#include "./line_split-inl.h"
/*! \brief io interface */
namespace rabit {
namespace io {
/*! \brief implementation of file i/o stream */
class FileStream : public utils::ISeekStream {
public:
explicit FileStream(const char *fname, const char *mode)
: use_stdio(false) {
#ifndef RABIT_STRICT_CXX98_
if (!strcmp(fname, "stdin")) {
use_stdio = true; fp = stdin;
}
if (!strcmp(fname, "stdout")) {
use_stdio = true; fp = stdout;
}
#endif
if (!strncmp(fname, "file://", 7)) fname += 7;
if (!use_stdio) {
std::string flag = mode;
if (flag == "w") flag = "wb";
if (flag == "r") flag = "rb";
fp = utils::FopenCheck(fname, flag.c_str());
}
}
virtual ~FileStream(void) {
this->Close();
}
virtual size_t Read(void *ptr, size_t size) {
return std::fread(ptr, 1, size, fp);
}
virtual void Write(const void *ptr, size_t size) {
std::fwrite(ptr, size, 1, fp);
}
virtual void Seek(size_t pos) {
std::fseek(fp, static_cast<long>(pos), SEEK_SET);
}
virtual size_t Tell(void) {
return std::ftell(fp);
}
virtual bool AtEnd(void) const {
return feof(fp) != 0;
}
inline void Close(void) {
if (fp != NULL && !use_stdio) {
std::fclose(fp); fp = NULL;
}
}
private:
FILE *fp;
bool use_stdio;
};
/*! \brief line split from normal file system */
class FileSplit : public LineSplitBase {
public:
explicit FileSplit(const char *uri, unsigned rank, unsigned nsplit) {
LineSplitBase::SplitNames(&fnames_, uri, "#");
std::vector<size_t> fsize;
for (size_t i = 0; i < fnames_.size(); ++i) {
if (!strncmp(fnames_[i].c_str(), "file://", 7)) {
std::string tmp = fnames_[i].c_str() + 7;
fnames_[i] = tmp;
}
fsize.push_back(GetFileSize(fnames_[i].c_str()));
}
LineSplitBase::Init(fsize, rank, nsplit);
}
virtual ~FileSplit(void) {}
protected:
virtual utils::ISeekStream *GetFile(size_t file_index) {
utils::Assert(file_index < fnames_.size(), "file index exceed bound");
return new FileStream(fnames_[file_index].c_str(), "rb");
}
// get file size
inline static size_t GetFileSize(const char *fname) {
FILE *fp = utils::FopenCheck(fname, "rb");
// NOTE: fseek may not be good, but serves as ok solution
fseek(fp, 0, SEEK_END);
size_t fsize = static_cast<size_t>(ftell(fp));
fclose(fp);
return fsize;
}
private:
// file names
std::vector<std::string> fnames_;
};
} // namespace io
} // namespace rabit
#endif // RABIT_LEARN_IO_FILE_INL_H_

View File

@ -0,0 +1,140 @@
#ifndef RABIT_LEARN_IO_HDFS_INL_H_
#define RABIT_LEARN_IO_HDFS_INL_H_
/*!
* \file hdfs-inl.h
* \brief HDFS I/O
* \author Tianqi Chen
*/
#include <string>
#include <vector>
#include <hdfs.h>
#include <errno.h>
#include "./io.h"
#include "./line_split-inl.h"
/*! \brief io interface */
namespace rabit {
namespace io {
class HDFSStream : public utils::ISeekStream {
public:
HDFSStream(hdfsFS fs, const char *fname, const char *mode)
: fs_(fs), at_end_(false) {
int flag;
if (!strcmp(mode, "r")) {
flag = O_RDONLY;
} else if (!strcmp(mode, "w")) {
flag = O_WRONLY;
} else if (!strcmp(mode, "a")) {
flag = O_WRONLY | O_APPEND;
} else {
utils::Error("HDFSStream: unknown flag %s", mode);
}
fp_ = hdfsOpenFile(fs_, fname, flag, 0, 0, 0);
utils::Check(fp_ != NULL,
"HDFSStream: fail to open %s", fname);
}
virtual ~HDFSStream(void) {
this->Close();
}
virtual size_t Read(void *ptr, size_t size) {
tSize nread = hdfsRead(fs_, fp_, ptr, size);
if (nread == -1) {
int errsv = errno;
utils::Error("HDFSStream.Read Error:%s", strerror(errsv));
}
if (nread == 0) {
at_end_ = true;
}
return static_cast<size_t>(nread);
}
virtual void Write(const void *ptr, size_t size) {
const char *buf = reinterpret_cast<const char*>(ptr);
while (size != 0) {
tSize nwrite = hdfsWrite(fs_, fp_, buf, size);
if (nwrite == -1) {
int errsv = errno;
utils::Error("HDFSStream.Write Error:%s", strerror(errsv));
}
size_t sz = static_cast<size_t>(nwrite);
buf += sz; size -= sz;
}
}
virtual void Seek(size_t pos) {
if (hdfsSeek(fs_, fp_, pos) != 0) {
int errsv = errno;
utils::Error("HDFSStream.Seek Error:%s", strerror(errsv));
}
}
virtual size_t Tell(void) {
tOffset offset = hdfsTell(fs_, fp_);
if (offset == -1) {
int errsv = errno;
utils::Error("HDFSStream.Tell Error:%s", strerror(errsv));
}
return static_cast<size_t>(offset);
}
virtual bool AtEnd(void) const {
return at_end_;
}
inline void Close(void) {
if (fp_ != NULL) {
if (hdfsCloseFile(fs_, fp_) == -1) {
int errsv = errno;
utils::Error("HDFSStream.Close Error:%s", strerror(errsv));
}
fp_ = NULL;
}
}
private:
hdfsFS fs_;
hdfsFile fp_;
bool at_end_;
};
/*! \brief line split from normal file system */
class HDFSSplit : public LineSplitBase {
public:
explicit HDFSSplit(const char *uri, unsigned rank, unsigned nsplit) {
fs_ = hdfsConnect("default", 0);
std::vector<std::string> paths;
LineSplitBase::SplitNames(&paths, uri, "#");
// get the files
std::vector<size_t> fsize;
for (size_t i = 0; i < paths.size(); ++i) {
hdfsFileInfo *info = hdfsGetPathInfo(fs_, paths[i].c_str());
if (info->mKind == 'D') {
int nentry;
hdfsFileInfo *files = hdfsListDirectory(fs_, info->mName, &nentry);
for (int i = 0; i < nentry; ++i) {
if (files[i].mKind == 'F') {
fsize.push_back(files[i].mSize);
fnames_.push_back(std::string(files[i].mName));
}
}
hdfsFreeFileInfo(files, nentry);
} else {
fsize.push_back(info->mSize);
fnames_.push_back(std::string(info->mName));
}
hdfsFreeFileInfo(info, 1);
}
LineSplitBase::Init(fsize, rank, nsplit);
}
virtual ~HDFSSplit(void) {}
protected:
virtual utils::ISeekStream *GetFile(size_t file_index) {
utils::Assert(file_index < fnames_.size(), "file index exceed bound");
return new HDFSStream(fs_, fnames_[file_index].c_str(), "r");
}
private:
// hdfs handle
hdfsFS fs_;
// file names
std::vector<std::string> fnames_;
};
} // namespace io
} // namespace rabit
#endif // RABIT_LEARN_IO_HDFS_INL_H_

View File

@ -0,0 +1,65 @@
#ifndef RABIT_LEARN_IO_IO_INL_H_
#define RABIT_LEARN_IO_IO_INL_H_
/*!
* \file io-inl.h
* \brief Input/Output utils that handles read/write
* of files in distrubuted enviroment
* \author Tianqi Chen
*/
#include <cstring>
#include "./io.h"
#if RABIT_USE_HDFS
#include "./hdfs-inl.h"
#endif
#include "./file-inl.h"
namespace rabit {
namespace io {
/*!
* \brief create input split given a uri
* \param uri the uri of the input, can contain hdfs prefix
* \param part the part id of current input
* \param nsplit total number of splits
*/
inline InputSplit *CreateInputSplit(const char *uri,
unsigned part,
unsigned nsplit) {
if (!strcmp(uri, "stdin")) {
return new SingleFileSplit(uri);
}
if (!strncmp(uri, "file://", 7)) {
return new FileSplit(uri, part, nsplit);
}
if (!strncmp(uri, "hdfs://", 7)) {
#if RABIT_USE_HDFS
return new HDFSSplit(uri, part, nsplit);
#else
utils::Error("Please compile with RABIT_USE_HDFS=1");
#endif
}
return new FileSplit(uri, part, nsplit);
}
/*!
* \brief create an stream, the stream must be able to close
* the underlying resources(files) when deleted
*
* \param uri the uri of the input, can contain hdfs prefix
* \param mode can be 'w' or 'r' for read or write
*/
inline IStream *CreateStream(const char *uri, const char *mode) {
if (!strncmp(uri, "file://", 7)) {
return new FileStream(uri + 7, mode);
}
if (!strncmp(uri, "hdfs://", 7)) {
#if RABIT_USE_HDFS
return new HDFSStream(hdfsConnect("default", 0), uri, mode);
#else
utils::Error("Please compile with RABIT_USE_HDFS=1");
#endif
}
return new FileStream(uri, mode);
}
} // namespace io
} // namespace rabit
#endif // RABIT_LEARN_IO_IO_INL_H_

View File

@ -0,0 +1,61 @@
#ifndef RABIT_LEARN_IO_IO_H_
#define RABIT_LEARN_IO_IO_H_
/*!
* \file io.h
* \brief Input/Output utils that handles read/write
* of files in distrubuted enviroment
* \author Tianqi Chen
*/
#include "../../include/rabit_serializable.h"
/*! \brief whether compile with HDFS support */
#ifndef RABIT_USE_HDFS
#define RABIT_USE_HDFS 0
#endif
/*! \brief io interface */
namespace rabit {
/*!
* \brief namespace to handle input split and filesystem interfacing
*/
namespace io {
typedef utils::ISeekStream ISeekStream;
/*!
* \brief user facing input split helper,
* can be used to get the partition of data used by current node
*/
class InputSplit {
public:
/*!
* \brief get next line, store into out_data
* \param out_data the string that stores the line data,
* \n is not included
* \return true of next line was found, false if we read all the lines
*/
virtual bool NextLine(std::string *out_data) = 0;
/*! \brief destructor*/
virtual ~InputSplit(void) {}
};
/*!
* \brief create input split given a uri
* \param uri the uri of the input, can contain hdfs prefix
* \param part the part id of current input
* \param nsplit total number of splits
*/
inline InputSplit *CreateInputSplit(const char *uri,
unsigned part,
unsigned nsplit);
/*!
* \brief create an stream, the stream must be able to close
* the underlying resources(files) when deleted
*
* \param uri the uri of the input, can contain hdfs prefix
* \param mode can be 'w' or 'r' for read or write
*/
inline IStream *CreateStream(const char *uri, const char *mode);
} // namespace io
} // namespace rabit
#include "./io-inl.h"
#include "./base64-inl.h"
#endif // RABIT_LEARN_IO_IO_H_

View File

@ -0,0 +1,181 @@
#ifndef RABIT_LEARN_IO_LINE_SPLIT_INL_H_
#define RABIT_LEARN_IO_LINE_SPLIT_INL_H_
/*!
* \file line_split-inl.h
* \brief base implementation of line-spliter
* \author Tianqi Chen
*/
#include <vector>
#include <utility>
#include <cstring>
#include <string>
#include "../../include/rabit.h"
#include "./io.h"
#include "./buffer_reader-inl.h"
namespace rabit {
namespace io {
class LineSplitBase : public InputSplit {
public:
virtual ~LineSplitBase() {
if (fs_ != NULL) delete fs_;
}
virtual bool NextLine(std::string *out_data) {
if (file_ptr_ >= file_ptr_end_ &&
offset_curr_ >= offset_end_) return false;
out_data->clear();
while (true) {
char c = reader_.GetChar();
if (reader_.AtEnd()) {
if (out_data->length() != 0) return true;
file_ptr_ += 1;
if (offset_curr_ != file_offset_[file_ptr_]) {
utils::Error("warning:file size not calculated correctly\n");
offset_curr_ = file_offset_[file_ptr_];
}
if (offset_curr_ >= offset_end_) return false;
utils::Assert(file_ptr_ + 1 < file_offset_.size(),
"boundary check");
delete fs_;
fs_ = this->GetFile(file_ptr_);
reader_.set_stream(fs_);
} else {
++offset_curr_;
if (c != '\r' && c != '\n' && c != EOF) {
*out_data += c;
} else {
if (out_data->length() != 0) return true;
if (file_ptr_ >= file_ptr_end_ &&
offset_curr_ >= offset_end_) return false;
}
}
}
}
protected:
// constructor
LineSplitBase(void)
: fs_(NULL), reader_(kBufferSize) {
}
/*!
* \brief initialize the line spliter,
* \param file_size, size of each files
* \param rank the current rank of the data
* \param nsplit number of split we will divide the data into
*/
inline void Init(const std::vector<size_t> &file_size,
unsigned rank, unsigned nsplit) {
file_offset_.resize(file_size.size() + 1);
file_offset_[0] = 0;
for (size_t i = 0; i < file_size.size(); ++i) {
file_offset_[i + 1] = file_offset_[i] + file_size[i];
}
size_t ntotal = file_offset_.back();
size_t nstep = (ntotal + nsplit - 1) / nsplit;
offset_begin_ = std::min(nstep * rank, ntotal);
offset_end_ = std::min(nstep * (rank + 1), ntotal);
offset_curr_ = offset_begin_;
if (offset_begin_ == offset_end_) return;
file_ptr_ = std::upper_bound(file_offset_.begin(),
file_offset_.end(),
offset_begin_) - file_offset_.begin() - 1;
file_ptr_end_ = std::upper_bound(file_offset_.begin(),
file_offset_.end(),
offset_end_) - file_offset_.begin() - 1;
fs_ = GetFile(file_ptr_);
reader_.set_stream(fs_);
// try to set the starting position correctly
if (file_offset_[file_ptr_] != offset_begin_) {
fs_->Seek(offset_begin_ - file_offset_[file_ptr_]);
while (true) {
char c = reader_.GetChar();
if (!reader_.AtEnd()) ++offset_curr_;
if (c == '\n' || c == '\r' || c == EOF) return;
}
}
}
/*!
* \brief get the seek stream of given file_index
* \return the corresponding seek stream at head of file
*/
virtual utils::ISeekStream *GetFile(size_t file_index) = 0;
/*!
* \brief split names given
* \param out_fname output file names
* \param uri_ the iput uri file
* \param dlm deliminetr
*/
inline static void SplitNames(std::vector<std::string> *out_fname,
const char *uri_,
const char *dlm) {
std::string uri = uri_;
char *p = strtok(BeginPtr(uri), dlm);
while (p != NULL) {
out_fname->push_back(std::string(p));
p = strtok(NULL, dlm);
}
}
private:
/*! \brief current input stream */
utils::ISeekStream *fs_;
/*! \brief file pointer of which file to read on */
size_t file_ptr_;
/*! \brief file pointer where the end of file lies */
size_t file_ptr_end_;
/*! \brief get the current offset */
size_t offset_curr_;
/*! \brief beginning of offset */
size_t offset_begin_;
/*! \brief end of the offset */
size_t offset_end_;
/*! \brief byte-offset of each file */
std::vector<size_t> file_offset_;
/*! \brief buffer reader */
StreamBufferReader reader_;
/*! \brief buffer size */
const static size_t kBufferSize = 256;
};
/*! \brief line split from single file */
class SingleFileSplit : public InputSplit {
public:
explicit SingleFileSplit(const char *fname) {
if (!strcmp(fname, "stdin")) {
#ifndef RABIT_STRICT_CXX98_
use_stdin_ = true; fp_ = stdin;
#endif
}
if (!use_stdin_) {
fp_ = utils::FopenCheck(fname, "r");
}
end_of_file_ = false;
}
virtual ~SingleFileSplit(void) {
if (!use_stdin_) fclose(fp_);
}
virtual bool NextLine(std::string *out_data) {
if (end_of_file_) return false;
out_data->clear();
while (true) {
char c = fgetc(fp_);
if (c == EOF) {
end_of_file_ = true;
}
if (c != '\r' && c != '\n' && c != EOF) {
*out_data += c;
} else {
if (out_data->length() != 0) return true;
if (end_of_file_) return false;
}
}
return false;
}
private:
FILE *fp_;
bool use_stdin_;
bool end_of_file_;
};
} // namespace io
} // namespace rabit
#endif // RABIT_LEARN_IO_LINE_SPLIT_INL_H_

View File

@ -6,11 +6,10 @@ MPIBIN = kmeans.mpi
OBJ = kmeans.o
# common build script for programs
include ../common.mk
include ../make/common.mk
# dependenies here
kmeans.rabit: kmeans.o lib
kmeans.mock: kmeans.o lib
kmeans.mpi: kmeans.o libmpi
kmeans.o: kmeans.cc ../../src/*.h

View File

@ -0,0 +1,2 @@
mushroom.row*
*.model

View File

@ -6,7 +6,8 @@ MPIBIN =
OBJ = linear.o
# common build script for programs
include ../common.mk
include ../make/config.mk
include ../make/common.mk
CFLAGS+=-fopenmp
linear.o: linear.cc ../../src/*.h linear.h ../solver/*.h
# dependenies here

View File

@ -2,11 +2,24 @@ Linear and Logistic Regression
====
* input format: LibSVM
* Local Example: [run-linear.sh](run-linear.sh)
* Runnig on Hadoop: [run-hadoop.sh](run-hadoop.sh)
- Set input data to stdin, and model_out=stdout
* Runnig on YARN: [run-yarn.sh](run-yarn.sh)
- You will need to have YARN
- Modify ```../make/config.mk``` to set USE_HDFS=1 to compile with HDFS support
- Run build.sh on [../../yarn](../../yarn) on to build yarn jar file
Multi-Threading Optimization
====
* The code can be multi-threaded, we encourage you to use it
- Simply add ```nthread=k``` where k is the number of threads you want to use
* If you submit with YARN
- Use ```--vcores``` and ```-mem``` to request CPU and memory resources
- Some scheduler in YARN do not honor CPU request, you can request more memory to grab working slots
* Usually multi-threading improves speed in general
- You can use less workers and assign more resources to each of worker
- This usually means less communication overhead and faster running time
Parameters
===
====
All the parameters can be set by param=value
#### Important Parameters

View File

@ -1,6 +1,5 @@
#include "./linear.h"
#include "../utils/io.h"
#include "../utils/base64.h"
#include "../io/io.h"
namespace rabit {
namespace linear {
@ -55,7 +54,9 @@ class LinearObjFunction : public solver::IObjFunction<float> {
}
if (task == "train") {
lbfgs.Run();
this->SaveModel(model_out.c_str(), lbfgs.GetWeight());
if (rabit::GetRank() == 0) {
this->SaveModel(model_out.c_str(), lbfgs.GetWeight());
}
} else if (task == "pred") {
this->TaskPred();
} else {
@ -74,51 +75,37 @@ class LinearObjFunction : public solver::IObjFunction<float> {
printf("Finishing writing to %s\n", name_pred.c_str());
}
inline void LoadModel(const char *fname) {
FILE *fp = utils::FopenCheck(fname, "rb");
IStream *fi = io::CreateStream(fname, "r");
std::string header; header.resize(4);
// check header for different binary encode
// can be base64 or binary
utils::FileStream fi(fp);
utils::Check(fi.Read(&header[0], 4) != 0, "invalid model");
// base64 format
utils::Check(fi->Read(&header[0], 4) != 0, "invalid model");
// base64 format
if (header == "bs64") {
utils::Base64InStream bsin(fp);
io::Base64InStream bsin(fi);
bsin.InitPosition();
model.Load(bsin);
fclose(fp);
return;
} else if (header == "binf") {
model.Load(fi);
fclose(fp);
return;
model.Load(*fi);
} else {
utils::Error("invalid model file");
}
delete fi;
}
inline void SaveModel(const char *fname,
const float *wptr,
bool save_base64 = false) {
FILE *fp;
bool use_stdout = false;
if (!strcmp(fname, "stdout")) {
fp = stdout;
use_stdout = true;
} else {
fp = utils::FopenCheck(fname, "wb");
}
utils::FileStream fo(fp);
if (save_base64 != 0|| use_stdout) {
fo.Write("bs64\t", 5);
utils::Base64OutStream bout(fp);
IStream *fo = io::CreateStream(fname, "w");
if (save_base64 != 0 || !strcmp(fname, "stdout")) {
fo->Write("bs64\t", 5);
io::Base64OutStream bout(fo);
model.Save(bout, wptr);
bout.Finish('\n');
} else {
fo.Write("binf", 4);
model.Save(fo, wptr);
}
if (!use_stdout) {
fclose(fp);
fo->Write("binf", 4);
model.Save(*fo, wptr);
}
delete fo;
}
inline void LoadData(const char *fname) {
dtrain.Load(fname);

View File

@ -12,7 +12,7 @@ hadoop fs -mkdir $2/data
hadoop fs -put ../data/agaricus.txt.train $2/data
# submit to hadoop
../../tracker/rabit_hadoop.py --host_ip ip -n $1 -i $2/data/agaricus.txt.train -o $2/mushroom.linear.model linear.rabit stdin model_out=stdout "${*:3}"
../../tracker/rabit_hadoop_streaming.py -n $1 --vcores 1 -i $2/data/agaricus.txt.train -o $2/mushroom.linear.model linear.rabit stdin model_out=stdout "${*:3}"
# get the final model file
hadoop fs -get $2/mushroom.linear.model/part-00000 ./linear.model

View File

@ -5,11 +5,7 @@ then
exit -1
fi
rm -rf mushroom.row* *.model
rm -rf *.model
k=$1
# split the lib svm file into k subfiles
python splitrows.py ../data/agaricus.txt.train mushroom $k
# run xgboost mpi
../../tracker/rabit_demo.py -n $k linear.mock mushroom.row\%d "${*:2}" reg_L1=1 mock=0,1,1,0 mock=1,1,1,0 mock=0,2,1,1
../../tracker/rabit_demo.py -n $k linear.mock ../data/agaricus.txt.train "${*:2}" reg_L1=1 mock=0,1,1,0 mock=1,1,1,0 mock=0,2,1,1

View File

@ -5,13 +5,10 @@ then
exit -1
fi
rm -rf mushroom.row* *.model
rm -rf *.model
k=$1
# split the lib svm file into k subfiles
python splitrows.py ../data/agaricus.txt.train mushroom $k
# run xgboost mpi
../../tracker/rabit_demo.py -n $k linear.rabit mushroom.row\%d "${*:2}" reg_L1=1
# run linear model, the program will automatically split the inputs
../../tracker/rabit_demo.py -n $k linear.rabit ../data/agaricus.txt.train reg_L1=1
./linear.rabit ../data/agaricus.txt.test task=pred model_in=final.model

View File

@ -0,0 +1,19 @@
#!/bin/bash
if [ "$#" -lt 3 ];
then
echo "Usage: <nworkers> <path_in_HDFS> [param=val]"
exit -1
fi
# put the local training file to HDFS
hadoop fs -rm -r -f $2/data
hadoop fs -rm -r -f $2/mushroom.linear.model
hadoop fs -mkdir $2/data
# submit to hadoop
../../tracker/rabit_yarn.py -n $1 --vcores 1 linear.rabit hdfs://$2/data/agaricus.txt.train model_out=hdfs://$2/mushroom.linear.model "${*:3}"
# get the final model file
hadoop fs -get $2/mushroom.linear.model ./linear.model
./linear.rabit ../data/agaricus.txt.test task=pred model_in=linear.model

View File

@ -1,24 +0,0 @@
#!/usr/bin/python
import sys
import random
# split libsvm file into different rows
if len(sys.argv) < 4:
print ('Usage:<fin> <fo> k')
exit(0)
random.seed(10)
k = int(sys.argv[3])
fi = open( sys.argv[1], 'r' )
fos = []
for i in range(k):
fos.append(open( sys.argv[2]+'.row%d' % i, 'w' ))
for l in open(sys.argv[1]):
i = random.randint(0, k-1)
fos[i].write(l)
for f in fos:
f.close()

View File

@ -1,13 +1,20 @@
# this is the common build script for rabit programs
# you do not have to use it
export CC = gcc
export CXX = g++
export MPICXX = mpicxx
export LDFLAGS= -pthread -lm -L../../lib
export CFLAGS = -Wall -msse2 -Wno-unknown-pragmas -fPIC -I../../include
# you do not have to use it
export LDFLAGS= -L../../lib -pthread -lm -lrt
export CFLAGS = -Wall -msse2 -Wno-unknown-pragmas -fPIC -I../../include
# setup opencv
ifeq ($(USE_HDFS),1)
CFLAGS+= -DRABIT_USE_HDFS=1 -I$(HADOOP_HDFS_HOME)/include -I$(JAVA_HOME)/include
LDFLAGS+= -L$(HADOOP_HDFS_HOME)/lib/native -L$(LIBJVM) -lhdfs -ljvm
else
CFLAGS+= -DRABIT_USE_HDFS=0
endif
.PHONY: clean all lib mpi
all: $(BIN) $(MOCKBIN)
mpi: $(MPIBIN)
lib:
@ -15,10 +22,12 @@ lib:
libmpi:
cd ../..;make lib/librabit_mpi.a;cd -
$(BIN) :
$(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^) $(LDFLAGS) -lrabit
$(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^) -lrabit $(LDFLAGS)
$(MOCKBIN) :
$(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^) $(LDFLAGS) -lrabit_mock
$(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^) -lrabit_mock $(LDFLAGS)
$(OBJ) :
$(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) )

View File

@ -0,0 +1,21 @@
#-----------------------------------------------------
# rabit-learn: the configuration compile script
#
# This is the default configuration setup for rabit-learn
# If you want to change configuration, do the following steps:
#
# - copy this file to the root of rabit-learn folder
# - modify the configuration you want
# - type make or make -j n for parallel build
#----------------------------------------------------
# choice of compiler
export CC = gcc
export CXX = g++
export MPICXX = mpicxx
# whether use HDFS support during compile
USE_HDFS = 1
# path to libjvm.so
LIBJVM=$(JAVA_HOME)/jre/lib/amd64/server

View File

@ -14,7 +14,9 @@
#include <cstring>
#include <limits>
#include <cmath>
#include <sstream>
#include <rabit.h>
#include "../io/io.h"
namespace rabit {
// typedef index type
@ -45,49 +47,37 @@ struct SparseMat {
}
// load data from LibSVM format
inline void Load(const char *fname) {
FILE *fi;
if (!strcmp(fname, "stdin")) {
fi = stdin;
} else {
if (strchr(fname, '%') != NULL) {
char s_tmp[256];
snprintf(s_tmp, sizeof(s_tmp), fname, rabit::GetRank());
fi = utils::FopenCheck(s_tmp, "r");
} else {
fi = utils::FopenCheck(fname, "r");
}
}
io::InputSplit *in =
io::CreateInputSplit
(fname, rabit::GetRank(),
rabit::GetWorldSize());
row_ptr.clear();
row_ptr.push_back(0);
data.clear();
feat_dim = 0;
float label; bool init = true;
char tmp[1024];
while (fscanf(fi, "%s", tmp) == 1) {
std::string line;
while (in->NextLine(&line)) {
float label;
std::istringstream ss(line);
ss >> label;
Entry e;
unsigned long fidx;
if (sscanf(tmp, "%lu:%f", &fidx, &e.fvalue) == 2) {
while (!ss.eof()) {
if (!(ss >> fidx)) break;
ss.ignore(32, ':');
if (!(ss >> e.fvalue)) break;
e.findex = static_cast<index_t>(fidx);
data.push_back(e);
feat_dim = std::max(fidx, feat_dim);
} else {
if (!init) {
labels.push_back(label);
row_ptr.push_back(data.size());
}
utils::Check(sscanf(tmp, "%f", &label) == 1, "invalid LibSVM format");
init = false;
}
labels.push_back(label);
row_ptr.push_back(data.size());
}
// last row
labels.push_back(label);
row_ptr.push_back(data.size());
delete in;
feat_dim += 1;
utils::Check(feat_dim < std::numeric_limits<index_t>::max(),
"feature dimension exceed limit of index_t"\
"consider change the index_t to unsigned long");
// close the filed
if (fi != stdin) fclose(fi);
}
inline size_t NumRow(void) const {
return row_ptr.size() - 1;
@ -98,6 +88,7 @@ struct SparseMat {
std::vector<Entry> data;
std::vector<float> labels;
};
// dense matrix
struct Matrix {
inline void Init(size_t nrow, size_t ncol, float v = 0.0f) {

View File

@ -1,40 +0,0 @@
#ifndef RABIT_LEARN_UTILS_IO_H_
#define RABIT_LEARN_UTILS_IO_H_
/*!
* \file io.h
* \brief additional stream interface
* \author Tianqi Chen
*/
namespace rabit {
namespace utils {
/*! \brief implementation of file i/o stream */
class FileStream : public ISeekStream {
public:
explicit FileStream(FILE *fp) : fp(fp) {}
explicit FileStream(void) {
this->fp = NULL;
}
virtual size_t Read(void *ptr, size_t size) {
return std::fread(ptr, size, 1, fp);
}
virtual void Write(const void *ptr, size_t size) {
std::fwrite(ptr, size, 1, fp);
}
virtual void Seek(size_t pos) {
std::fseek(fp, static_cast<long>(pos), SEEK_SET);
}
virtual size_t Tell(void) {
return std::ftell(fp);
}
inline void Close(void) {
if (fp != NULL){
std::fclose(fp); fp = NULL;
}
}
private:
FILE *fp;
};
} // namespace utils
} // namespace rabit
#endif // RABIT_LEARN_UTILS_IO_H_

View File

@ -29,11 +29,24 @@ AllreduceBase::AllreduceBase(void) {
task_id = "NULL";
err_link = NULL;
this->SetParam("rabit_reduce_buffer", "256MB");
// setup possible enviroment variable of intrest
env_vars.push_back("rabit_task_id");
env_vars.push_back("rabit_num_trial");
env_vars.push_back("rabit_reduce_buffer");
env_vars.push_back("rabit_tracker_uri");
env_vars.push_back("rabit_tracker_port");
}
// initialization function
void AllreduceBase::Init(void) {
// setup from enviroment variables
// handler to get variables from env
for (size_t i = 0; i < env_vars.size(); ++i) {
const char *value = getenv(env_vars[i].c_str());
if (value != NULL) {
this->SetParam(env_vars[i].c_str(), value);
}
}
{
// handling for hadoop
const char *task_id = getenv("mapred_tip_id");

View File

@ -413,6 +413,8 @@ class AllreduceBase : public IEngine {
// pointer to links in the ring
LinkRecord *ring_prev, *ring_next;
//----- meta information-----
// list of enviroment variables that are of possible interest
std::vector<std::string> env_vars;
// unique identifier of the possible job this process is doing
// used to assign ranks, optional, default to NULL
std::string task_id;

View File

@ -27,7 +27,9 @@ AllreduceRobust::AllreduceRobust(void) {
result_buffer_round = 1;
global_lazycheck = NULL;
use_local_model = -1;
recover_counter = 0;
recover_counter = 0;
env_vars.push_back("rabit_global_replica");
env_vars.push_back("rabit_local_replica");
}
void AllreduceRobust::Init(void) {
AllreduceBase::Init();

View File

@ -0,0 +1,12 @@
Trackers
=====
This folder contains tracker scripts that can be used to submit yarn jobs to different platforms,
the example guidelines are in the script themselfs
***Supported Platforms***
* Local demo: [rabit_demo.py](rabit_demo.py)
* MPI: [rabit_mpi.py](rabit_mpi.py)
* Yarn (Hadoop): [rabit_yarn.py](rabit_yarn.py)
- It is also possible to submit via hadoop streaming with rabit_hadoop_streaming.py
- However, it is higly recommended to use rabit_yarn.py because this will allocate resources more precisely and fits machine learning scenarios

View File

@ -31,35 +31,38 @@ nrep=0
rc=254
while [ $rc -eq 254 ];
do
export rabit_num_trial=$nrep
%s
%s %s rabit_num_trial=$nrep
%s
rc=$?;
nrep=$((nrep+1));
done
"""
def exec_cmd(cmd, taskid):
def exec_cmd(cmd, taskid, worker_env):
if cmd[0].find('/') == -1 and os.path.exists(cmd[0]) and os.name != 'nt':
cmd[0] = './' + cmd[0]
cmd = ' '.join(cmd)
arg = ' rabit_task_id=%d' % (taskid)
cmd = cmd + arg
env = {}
for k, v in worker_env.items():
env[k] = str(v)
env['rabit_task_id'] = str(taskid)
env['PYTHONPATH'] = WRAPPER_PATH
ntrial = 0
while True:
if os.name == 'nt':
prep = 'SET PYTHONPATH=\"%s\"\n' % WRAPPER_PATH
ret = subprocess.call(prep + cmd + ('rabit_num_trial=%d' % ntrial), shell=True)
env['rabit_num_trial'] = str(ntrial)
ret = subprocess.call(cmd, shell=True, env = env)
if ret == 254:
ntrial += 1
continue
else:
prep = 'PYTHONPATH=\"%s\" ' % WRAPPER_PATH
if args.verbose != 0:
bash = keepalive % (echo % cmd, prep, cmd)
if args.verbose != 0:
bash = keepalive % (echo % cmd, cmd)
else:
bash = keepalive % ('', prep, cmd)
ret = subprocess.call(bash, shell=True, executable='bash')
bash = keepalive % ('', cmd)
ret = subprocess.call(bash, shell=True, executable='bash', env = env)
if ret == 0:
if args.verbose != 0:
print 'Thread %d exit with 0' % taskid
@ -73,7 +76,7 @@ def exec_cmd(cmd, taskid):
# Note: this submit script is only used for demo purpose
# submission script using pyhton multi-threading
#
def mthread_submit(nslave, worker_args):
def mthread_submit(nslave, worker_args, worker_envs):
"""
customized submit script, that submit nslave jobs, each must contain args as parameter
note this can be a lambda function containing additional parameters in input
@ -84,7 +87,7 @@ def mthread_submit(nslave, worker_args):
"""
procs = {}
for i in range(nslave):
procs[i] = Thread(target = exec_cmd, args = (args.command + worker_args, i))
procs[i] = Thread(target = exec_cmd, args = (args.command + worker_args, i, worker_envs))
procs[i].daemon = True
procs[i].start()
for i in range(nslave):

View File

@ -1,7 +1,11 @@
#!/usr/bin/python
"""
Deprecated
This is a script to submit rabit job using hadoop streaming.
It will submit the rabit process as mappers of MapReduce.
This script is deprecated, it is highly recommended to use rabit_yarn.py instead
"""
import argparse
import sys
@ -34,13 +38,11 @@ if hadoop_binary == None or hadoop_streaming_jar == None:
', or modify rabit_hadoop.py line 16', stacklevel = 2)
parser = argparse.ArgumentParser(description='Rabit script to submit rabit jobs using Hadoop Streaming.'\
'This script support both Hadoop 1.0 and Yarn(MRv2), Yarn is recommended')
'It is Highly recommended to use rabit_yarn.py instead')
parser.add_argument('-n', '--nworker', required=True, type=int,
help = 'number of worker proccess to be launched')
parser.add_argument('-hip', '--host_ip', default='auto', type=str,
help = 'host IP address if cannot be automatically guessed, specify the IP of submission machine')
parser.add_argument('-nt', '--nthread', default = -1, type=int,
help = 'number of thread in each mapper to be launched, set it if each rabit job is multi-threaded')
parser.add_argument('-i', '--input', required=True,
help = 'input path in HDFS')
parser.add_argument('-o', '--output', required=True,
@ -61,6 +63,8 @@ parser.add_argument('--jobname', default='auto', help = 'customize jobname in tr
parser.add_argument('--timeout', default=600000000, type=int,
help = 'timeout (in million seconds) of each mapper job, automatically set to a very long time,'\
'normally you do not need to set this ')
parser.add_argument('--vcores', default = -1, type=int,
help = 'number of vcpores to request in each mapper, set it if each rabit job is multi-threaded')
parser.add_argument('-mem', '--memory_mb', default=-1, type=int,
help = 'maximum memory used by the process. Guide: set it large (near mapred.cluster.max.map.memory.mb)'\
'if you are running multi-threading rabit,'\
@ -91,10 +95,14 @@ out = out.split('\n')[0].split()
assert out[0] == 'Hadoop', 'cannot parse hadoop version string'
hadoop_version = out[1].split('.')
use_yarn = int(hadoop_version[0]) >= 2
if use_yarn:
warnings.warn('It is highly recommended to use rabit_yarn.py to submit jobs to yarn instead', stacklevel = 2)
print 'Current Hadoop Version is %s' % out[1]
def hadoop_streaming(nworker, worker_args, use_yarn):
def hadoop_streaming(nworker, worker_args, worker_envs, use_yarn):
worker_envs['CLASSPATH'] = '`$HADOOP_HOME/bin/hadoop classpath --glob` '
worker_envs['LD_LIBRARY_PATH'] = '{LD_LIBRARY_PATH}:$HADOOP_HDFS_HOME/lib/native:$JAVA_HOME/jre/lib/amd64/server'
fset = set()
if args.auto_file_cache:
for i in range(len(args.command)):
@ -113,6 +121,7 @@ def hadoop_streaming(nworker, worker_args, use_yarn):
if os.path.exists(f):
fset.add(f)
kmap = {}
kmap['env'] = 'mapred.child.env'
# setup keymaps
if use_yarn:
kmap['nworker'] = 'mapreduce.job.maps'
@ -129,12 +138,14 @@ def hadoop_streaming(nworker, worker_args, use_yarn):
cmd = '%s jar %s' % (args.hadoop_binary, args.hadoop_streaming_jar)
cmd += ' -D%s=%d' % (kmap['nworker'], nworker)
cmd += ' -D%s=%s' % (kmap['jobname'], args.jobname)
if args.nthread != -1:
envstr = ','.join('%s=%s' % (k, str(v)) for k, v in worker_envs.items())
cmd += ' -D%s=\"%s\"' % (kmap['env'], envstr)
if args.vcores != -1:
if kmap['nthread'] is None:
warnings.warn('nthread can only be set in Yarn(Hadoop version greater than 2.0),'\
'it is recommended to use Yarn to submit rabit jobs', stacklevel = 2)
else:
cmd += ' -D%s=%d' % (kmap['nthread'], args.nthread)
cmd += ' -D%s=%d' % (kmap['nthread'], args.vcores)
cmd += ' -D%s=%d' % (kmap['timeout'], args.timeout)
if args.memory_mb != -1:
cmd += ' -D%s=%d' % (kmap['timeout'], args.timeout)
@ -150,5 +161,5 @@ def hadoop_streaming(nworker, worker_args, use_yarn):
print cmd
subprocess.check_call(cmd, shell = True)
fun_submit = lambda nworker, worker_args: hadoop_streaming(nworker, worker_args, int(hadoop_version[0]) >= 2)
fun_submit = lambda nworker, worker_args, worker_envs: hadoop_streaming(nworker, worker_args, worker_envs, int(hadoop_version[0]) >= 2)
tracker.submit(args.nworker, [], fun_submit = fun_submit, verbose = args.verbose, hostIP = args.host_ip)

View File

@ -22,7 +22,7 @@ args = parser.parse_args()
#
# submission script using MPI
#
def mpi_submit(nslave, worker_args):
def mpi_submit(nslave, worker_args, worker_envs):
"""
customized submit script, that submit nslave jobs, each must contain args as parameter
note this can be a lambda function containing additional parameters in input
@ -31,6 +31,7 @@ def mpi_submit(nslave, worker_args):
args arguments to launch each job
this usually includes the parameters of master_uri and parameters passed into submit
"""
worker_args += ['%s=%s' % (k, str(v)) for k, v in worker_envs.items()]
sargs = ' '.join(args.command + worker_args)
if args.hostfile is None:
cmd = ' '.join(['mpirun -n %d' % (nslave)] + args.command + worker_args)

View File

@ -134,19 +134,25 @@ class Tracker:
sock.listen(16)
self.sock = sock
self.verbose = verbose
if hostIP == 'auto':
hostIP = 'dns'
self.hostIP = hostIP
self.log_print('start listen on %s:%d' % (socket.gethostname(), self.port), 1)
def __del__(self):
self.sock.close()
def slave_args(self):
if self.hostIP == 'auto':
def slave_envs(self):
"""
get enviroment variables for slaves
can be passed in as args or envs
"""
if self.hostIP == 'dns':
host = socket.gethostname()
elif self.hostIP == 'ip':
host = socket.gethostbyname(socket.getfqdn())
else:
host = self.hostIP
return ['rabit_tracker_uri=%s' % host,
'rabit_tracker_port=%s' % self.port]
return {'rabit_tracker_uri': host,
'rabit_tracker_port': self.port}
def get_neighbor(self, rank, nslave):
rank = rank + 1
ret = []
@ -261,9 +267,9 @@ class Tracker:
wait_conn[rank] = s
self.log_print('@tracker All nodes finishes job', 2)
def submit(nslave, args, fun_submit, verbose, hostIP):
def submit(nslave, args, fun_submit, verbose, hostIP = 'auto'):
master = Tracker(verbose = verbose, hostIP = hostIP)
submit_thread = Thread(target = fun_submit, args = (nslave, args + master.slave_args()))
submit_thread = Thread(target = fun_submit, args = (nslave, args, master.slave_envs()))
submit_thread.daemon = True
submit_thread.start()
master.accept_slaves(nslave)

View File

@ -0,0 +1,122 @@
#!/usr/bin/python
"""
This is a script to submit rabit job via Yarn
rabit will run as a Yarn application
"""
import argparse
import sys
import os
import time
import subprocess
import warnings
import rabit_tracker as tracker
WRAPPER_PATH = os.path.dirname(__file__) + '/../wrapper'
YARN_JAR_PATH = os.path.dirname(__file__) + '/../yarn/rabit-yarn.jar'
assert os.path.exists(YARN_JAR_PATH), ("cannot find \"%s\", please run build.sh on the yarn folder" % YARN_JAR_PATH)
hadoop_binary = 'hadoop'
# code
hadoop_home = os.getenv('HADOOP_HOME')
if hadoop_home != None:
if hadoop_binary == None:
hadoop_binary = hadoop_home + '/bin/hadoop'
assert os.path.exists(hadoop_binary), "HADOOP_HOME does not contain the hadoop binary"
parser = argparse.ArgumentParser(description='Rabit script to submit rabit jobs to Yarn.')
parser.add_argument('-n', '--nworker', required=True, type=int,
help = 'number of worker proccess to be launched')
parser.add_argument('-hip', '--host_ip', default='auto', type=str,
help = 'host IP address if cannot be automatically guessed, specify the IP of submission machine')
parser.add_argument('-v', '--verbose', default=0, choices=[0, 1], type=int,
help = 'print more messages into the console')
parser.add_argument('-ac', '--auto_file_cache', default=1, choices=[0, 1], type=int,
help = 'whether automatically cache the files in the command to hadoop localfile, this is on by default')
parser.add_argument('-f', '--files', default = [], action='append',
help = 'the cached file list in mapreduce,'\
' the submission script will automatically cache all the files which appears in command'\
' This will also cause rewritten of all the file names in the command to current path,'\
' for example `../../kmeans ../kmeans.conf` will be rewritten to `./kmeans kmeans.conf`'\
' because the two files are cached to running folder.'\
' You may need this option to cache additional files.'\
' You can also use it to manually cache files when auto_file_cache is off')
parser.add_argument('--jobname', default='auto', help = 'customize jobname in tracker')
parser.add_argument('--tempdir', default='/tmp', help = 'temporary directory in HDFS that can be used to store intermediate results')
parser.add_argument('--vcores', default = 1, type=int,
help = 'number of vcpores to request in each mapper, set it if each rabit job is multi-threaded')
parser.add_argument('-mem', '--memory_mb', default=1024, type=int,
help = 'maximum memory used by the process. Guide: set it large (near mapred.cluster.max.map.memory.mb)'\
'if you are running multi-threading rabit,'\
'so that each node can occupy all the mapper slots in a machine for maximum performance')
parser.add_argument('command', nargs='+',
help = 'command for rabit program')
args = parser.parse_args()
if args.jobname == 'auto':
args.jobname = ('Rabit[nworker=%d]:' % args.nworker) + args.command[0].split('/')[-1];
if hadoop_binary == None:
parser.add_argument('-hb', '--hadoop_binary', required = True,
help="path to hadoop binary file")
else:
parser.add_argument('-hb', '--hadoop_binary', default = hadoop_binary,
help="path to hadoop binary file")
args = parser.parse_args()
if args.jobname == 'auto':
args.jobname = ('Rabit[nworker=%d]:' % args.nworker) + args.command[0].split('/')[-1];
# detech hadoop version
(out, err) = subprocess.Popen('%s version' % args.hadoop_binary, shell = True, stdout=subprocess.PIPE).communicate()
out = out.split('\n')[0].split()
assert out[0] == 'Hadoop', 'cannot parse hadoop version string'
hadoop_version = out[1].split('.')
(classpath, err) = subprocess.Popen('%s classpath --glob' % args.hadoop_binary, shell = True, stdout=subprocess.PIPE).communicate()
if hadoop_version < 2:
print 'Current Hadoop Version is %s, rabit_yarn will need Yarn(Hadoop 2.0)' % out[1]
def submit_yarn(nworker, worker_args, worker_env):
fset = set([YARN_JAR_PATH])
if args.auto_file_cache != 0:
for i in range(len(args.command)):
f = args.command[i]
if os.path.exists(f):
fset.add(f)
if i == 0:
args.command[i] = './' + args.command[i].split('/')[-1]
else:
args.command[i] = args.command[i].split('/')[-1]
if args.command[0].endswith('.py'):
flst = [WRAPPER_PATH + '/rabit.py',
WRAPPER_PATH + '/librabit_wrapper.so',
WRAPPER_PATH + '/librabit_wrapper_mock.so']
for f in flst:
if os.path.exists(f):
fset.add(f)
cmd = 'java -cp `%s classpath`:%s org.apache.hadoop.yarn.rabit.Client ' % (args.hadoop_binary, YARN_JAR_PATH)
env = os.environ.copy()
for k, v in worker_env.items():
env[k] = str(v)
env['rabit_cpu_vcores'] = str(args.vcores)
env['rabit_memory_mb'] = str(args.memory_mb)
env['rabit_world_size'] = str(args.nworker)
if args.files != None:
for flst in args.files:
for f in flst.split('#'):
fset.add(f)
for f in fset:
cmd += ' -file %s' % f
cmd += ' -jobname %s ' % args.jobname
cmd += ' -tempdir %s ' % args.tempdir
cmd += (' '.join(args.command + worker_args))
print cmd
subprocess.check_call(cmd, shell = True, env = env)
tracker.submit(args.nworker, [], fun_submit = submit_yarn, verbose = args.verbose, hostIP = args.host_ip)

4
subtree/rabit/yarn/.gitignore vendored Normal file
View File

@ -0,0 +1,4 @@
bin
.classpath
.project
*.jar

View File

@ -0,0 +1,5 @@
rabit-yarn
=====
* This folder contains Application code to allow rabit run on Yarn.
* You can use [../tracker/rabit_yarn.py](../tracker/rabit_yarn.py) to submit the job
- run ```./build.sh``` to build the jar, before using the script

View File

@ -0,0 +1 @@
foler used to hold generated class files

4
subtree/rabit/yarn/build.sh Executable file
View File

@ -0,0 +1,4 @@
#!/bin/bash
CPATH=`${HADOOP_PREFIX}/bin/hadoop classpath`
javac -cp $CPATH -d bin src/org/apache/hadoop/yarn/rabit/*
jar cf rabit-yarn.jar -C bin .

View File

@ -0,0 +1,508 @@
package org.apache.hadoop.yarn.rabit;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Collection;
import java.util.Collections;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.yarn.util.ConverterUtils;
import org.apache.hadoop.yarn.util.Records;
import org.apache.hadoop.yarn.conf.YarnConfiguration;
import org.apache.hadoop.yarn.api.ApplicationConstants;
import org.apache.hadoop.yarn.api.protocolrecords.RegisterApplicationMasterResponse;
import org.apache.hadoop.yarn.api.records.Container;
import org.apache.hadoop.yarn.api.records.ContainerExitStatus;
import org.apache.hadoop.yarn.api.records.ContainerLaunchContext;
import org.apache.hadoop.yarn.api.records.ContainerState;
import org.apache.hadoop.yarn.api.records.FinalApplicationStatus;
import org.apache.hadoop.yarn.api.records.LocalResource;
import org.apache.hadoop.yarn.api.records.LocalResourceType;
import org.apache.hadoop.yarn.api.records.LocalResourceVisibility;
import org.apache.hadoop.yarn.api.records.Priority;
import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.api.records.ContainerId;
import org.apache.hadoop.yarn.api.records.ContainerStatus;
import org.apache.hadoop.yarn.api.records.NodeReport;
import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest;
import org.apache.hadoop.yarn.client.api.async.NMClientAsync;
import org.apache.hadoop.yarn.client.api.async.AMRMClientAsync;
/**
* application master for allocating resources of rabit client
*
* @author Tianqi Chen
*/
public class ApplicationMaster {
// logger
private static final Log LOG = LogFactory.getLog(ApplicationMaster.class);
// configuration
private Configuration conf = new YarnConfiguration();
// hdfs handler
private FileSystem dfs;
// number of cores allocated for each task
private int numVCores = 1;
// memory needed requested for the task
private int numMemoryMB = 10;
// priority of the app master
private int appPriority = 0;
// total number of tasks
private int numTasks = 1;
// maximum number of attempts to try in each task
private int maxNumAttempt = 3;
// command to launch
private String command = "";
// application tracker hostname
private String appHostName = "";
// tracker URL to do
private String appTrackerUrl = "";
// tracker port
private int appTrackerPort = 0;
// whether we start to abort the application, due to whatever fatal reasons
private boolean startAbort = false;
// worker resources
private Map<String, LocalResource> workerResources = new java.util.HashMap<String, LocalResource>();
// record the aborting reason
private String abortDiagnosis = "";
// resource manager
private AMRMClientAsync<ContainerRequest> rmClient = null;
// node manager
private NMClientAsync nmClient = null;
// list of tasks that pending for resources to be allocated
private final Queue<TaskRecord> pendingTasks = new java.util.LinkedList<TaskRecord>();
// map containerId->task record of tasks that was running
private final Map<ContainerId, TaskRecord> runningTasks = new java.util.HashMap<ContainerId, TaskRecord>();
// collection of tasks
private final Collection<TaskRecord> finishedTasks = new java.util.LinkedList<TaskRecord>();
// collection of killed tasks
private final Collection<TaskRecord> killedTasks = new java.util.LinkedList<TaskRecord>();
public static void main(String[] args) throws Exception {
new ApplicationMaster().run(args);
}
private ApplicationMaster() throws IOException {
dfs = FileSystem.get(conf);
}
/**
* get integer argument from environment variable
*
* @param name
* name of key
* @param required
* whether this is required
* @param defv
* default value
* @return the requested result
*/
private int getEnvInteger(String name, boolean required, int defv)
throws IOException {
String value = System.getenv(name);
if (value == null) {
if (required) {
throw new IOException("environment variable " + name
+ " not set");
} else {
return defv;
}
}
return Integer.valueOf(value);
}
/**
* initialize from arguments and command lines
*
* @param args
*/
private void initArgs(String args[]) throws IOException {
LOG.info("Invoke initArgs");
// cached maps
Map<String, Path> cacheFiles = new java.util.HashMap<String, Path>();
for (int i = 0; i < args.length; ++i) {
if (args[i].equals("-file")) {
String[] arr = args[++i].split("#");
Path path = new Path(arr[0]);
if (arr.length == 1) {
cacheFiles.put(path.getName(), path);
} else {
cacheFiles.put(arr[1], path);
}
} else {
this.command += args[i] + " ";
}
}
for (Map.Entry<String, Path> e : cacheFiles.entrySet()) {
LocalResource r = Records.newRecord(LocalResource.class);
FileStatus status = dfs.getFileStatus(e.getValue());
r.setResource(ConverterUtils.getYarnUrlFromPath(e.getValue()));
r.setSize(status.getLen());
r.setTimestamp(status.getModificationTime());
r.setType(LocalResourceType.FILE);
r.setVisibility(LocalResourceVisibility.APPLICATION);
workerResources.put(e.getKey(), r);
}
numVCores = this.getEnvInteger("rabit_cpu_vcores", true, numVCores);
numMemoryMB = this.getEnvInteger("rabit_memory_mb", true, numMemoryMB);
numTasks = this.getEnvInteger("rabit_world_size", true, numTasks);
maxNumAttempt = this.getEnvInteger("rabit_max_attempt", false, maxNumAttempt);
}
/**
* called to start the application
*/
private void run(String args[]) throws Exception {
this.initArgs(args);
this.rmClient = AMRMClientAsync.createAMRMClientAsync(1000,
new RMCallbackHandler());
this.nmClient = NMClientAsync
.createNMClientAsync(new NMCallbackHandler());
this.rmClient.init(conf);
this.rmClient.start();
this.nmClient.init(conf);
this.nmClient.start();
RegisterApplicationMasterResponse response = this.rmClient
.registerApplicationMaster(this.appHostName,
this.appTrackerPort, this.appTrackerUrl);
boolean success = false;
String diagnostics = "";
try {
// list of tasks that waits to be submit
java.util.Collection<TaskRecord> tasks = new java.util.LinkedList<TaskRecord>();
// add waiting tasks
for (int i = 0; i < this.numTasks; ++i) {
tasks.add(new TaskRecord(i));
}
Resource maxResource = response.getMaximumResourceCapability();
if (maxResource.getMemory() < this.numMemoryMB) {
LOG.warn("[Rabit] memory requested exceed bound "
+ maxResource.getMemory());
this.numMemoryMB = maxResource.getMemory();
}
if (maxResource.getVirtualCores() < this.numVCores) {
LOG.warn("[Rabit] memory requested exceed bound "
+ maxResource.getVirtualCores());
this.numVCores = maxResource.getVirtualCores();
}
this.submitTasks(tasks);
LOG.info("[Rabit] ApplicationMaster started");
while (!this.doneAllJobs()) {
try {
Thread.sleep(100);
} catch (InterruptedException e) {
}
}
assert (killedTasks.size() + finishedTasks.size() == numTasks);
success = finishedTasks.size() == numTasks;
LOG.info("Application completed. Stopping running containers");
nmClient.stop();
diagnostics = "Diagnostics." + ", num_tasks" + this.numTasks
+ ", finished=" + this.finishedTasks.size() + ", failed="
+ this.killedTasks.size() + "\n" + this.abortDiagnosis;
LOG.info(diagnostics);
} catch (Exception e) {
diagnostics = e.toString();
}
rmClient.unregisterApplicationMaster(
success ? FinalApplicationStatus.SUCCEEDED
: FinalApplicationStatus.FAILED, diagnostics,
appTrackerUrl);
if (!success) throw new Exception("Application not successful");
}
/**
* check if the job finishes
*
* @return whether we finished all the jobs
*/
private synchronized boolean doneAllJobs() {
return pendingTasks.size() == 0 && runningTasks.size() == 0;
}
/**
* submit tasks to request containers for the tasks
*
* @param tasks
* a collection of tasks we want to ask container for
*/
private synchronized void submitTasks(Collection<TaskRecord> tasks) {
for (TaskRecord r : tasks) {
Resource resource = Records.newRecord(Resource.class);
resource.setMemory(numMemoryMB);
resource.setVirtualCores(numVCores);
Priority priority = Records.newRecord(Priority.class);
priority.setPriority(this.appPriority);
r.containerRequest = new ContainerRequest(resource, null, null,
priority);
rmClient.addContainerRequest(r.containerRequest);
pendingTasks.add(r);
}
}
/**
* launch the task on container
*
* @param container
* container to run the task
* @param task
* the task
*/
private void launchTask(Container container, TaskRecord task) {
task.container = container;
task.containerRequest = null;
ContainerLaunchContext ctx = Records
.newRecord(ContainerLaunchContext.class);
String cmd =
// use this to setup CLASSPATH correctly for libhdfs
"CLASSPATH=${CLASSPATH}:`${HADOOP_PREFIX}/bin/hadoop classpath --glob` "
+ this.command + " 1>"
+ ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout"
+ " 2>" + ApplicationConstants.LOG_DIR_EXPANSION_VAR
+ "/stderr";
LOG.info(cmd);
ctx.setCommands(Collections.singletonList(cmd));
LOG.info(workerResources);
ctx.setLocalResources(this.workerResources);
// setup environment variables
Map<String, String> env = new java.util.HashMap<String, String>();
// setup class path, this is kind of duplicated, ignoring
StringBuilder cpath = new StringBuilder("${CLASSPATH}:./*");
for (String c : conf.getStrings(
YarnConfiguration.YARN_APPLICATION_CLASSPATH,
YarnConfiguration.DEFAULT_YARN_APPLICATION_CLASSPATH)) {
cpath.append(':');
cpath.append(c.trim());
}
// already use hadoop command to get class path in worker, maybe a better solution in future
// env.put("CLASSPATH", cpath.toString());
// setup LD_LIBARY_PATH path for libhdfs
env.put("LD_LIBRARY_PATH",
"${LD_LIBRARY_PATH}:$HADOOP_HDFS_HOME/lib/native:$JAVA_HOME/jre/lib/amd64/server");
env.put("PYTHONPATH", "${PYTHONPATH}:.");
// inherit all rabit variables
for (Map.Entry<String, String> e : System.getenv().entrySet()) {
if (e.getKey().startsWith("rabit_")) {
env.put(e.getKey(), e.getValue());
}
}
env.put("rabit_task_id", String.valueOf(task.taskId));
env.put("rabit_num_trial", String.valueOf(task.attemptCounter));
ctx.setEnvironment(env);
synchronized (this) {
assert (!this.runningTasks.containsKey(container.getId()));
this.runningTasks.put(container.getId(), task);
this.nmClient.startContainerAsync(container, ctx);
}
}
/**
* free the containers that have not yet been launched
*
* @param containers
*/
private synchronized void freeUnusedContainers(
Collection<Container> containers) {
}
/**
* handle method for AMRMClientAsync.CallbackHandler container allocation
*
* @param containers
*/
private synchronized void onContainersAllocated(List<Container> containers) {
if (this.startAbort) {
this.freeUnusedContainers(containers);
return;
}
Collection<Container> freelist = new java.util.LinkedList<Container>();
for (Container c : containers) {
TaskRecord task;
task = pendingTasks.poll();
if (task == null) {
freelist.add(c);
continue;
}
this.launchTask(c, task);
}
this.freeUnusedContainers(freelist);
}
/**
* start aborting the job
*
* @param msg
* the fatal message
*/
private synchronized void abortJob(String msg) {
if (!this.startAbort)
this.abortDiagnosis = msg;
this.startAbort = true;
for (TaskRecord r : this.runningTasks.values()) {
if (!r.abortRequested) {
nmClient.stopContainerAsync(r.container.getId(),
r.container.getNodeId());
r.abortRequested = true;
}
}
this.killedTasks.addAll(this.pendingTasks);
for (TaskRecord r : this.pendingTasks) {
rmClient.removeContainerRequest(r.containerRequest);
}
this.pendingTasks.clear();
LOG.info(msg);
}
/**
* handle non fatal failures
*
* @param cid
*/
private synchronized void handleFailure(Collection<ContainerId> failed) {
Collection<TaskRecord> tasks = new java.util.LinkedList<TaskRecord>();
for (ContainerId cid : failed) {
TaskRecord r = runningTasks.remove(cid);
if (r == null)
continue;
r.attemptCounter += 1;
r.container = null;
tasks.add(r);
if (r.attemptCounter >= this.maxNumAttempt) {
this.abortJob("[Rabit] Task " + r.taskId + " failed more than "
+ r.attemptCounter + "times");
}
}
if (this.startAbort) {
this.killedTasks.addAll(tasks);
} else {
this.submitTasks(tasks);
}
}
/**
* handle method for AMRMClientAsync.CallbackHandler container allocation
*
* @param status
* list of status
*/
private synchronized void onContainersCompleted(List<ContainerStatus> status) {
Collection<ContainerId> failed = new java.util.LinkedList<ContainerId>();
for (ContainerStatus s : status) {
assert (s.getState().equals(ContainerState.COMPLETE));
int exstatus = s.getExitStatus();
TaskRecord r = runningTasks.get(s.getContainerId());
if (r == null)
continue;
if (exstatus == ContainerExitStatus.SUCCESS) {
finishedTasks.add(r);
runningTasks.remove(s.getContainerId());
} else {
switch (exstatus) {
case ContainerExitStatus.KILLED_EXCEEDED_PMEM:
this.abortJob("[Rabit] Task "
+ r.taskId
+ " killed because of exceeding allocated physical memory");
break;
case ContainerExitStatus.KILLED_EXCEEDED_VMEM:
this.abortJob("[Rabit] Task "
+ r.taskId
+ " killed because of exceeding allocated virtual memory");
break;
default:
LOG.info("[Rabit] Task " + r.taskId
+ " exited with status " + exstatus);
failed.add(s.getContainerId());
}
}
}
this.handleFailure(failed);
}
/**
* callback handler for resource manager
*/
private class RMCallbackHandler implements AMRMClientAsync.CallbackHandler {
@Override
public float getProgress() {
return 1.0f - (float) (pendingTasks.size()) / numTasks;
}
@Override
public void onContainersAllocated(List<Container> containers) {
ApplicationMaster.this.onContainersAllocated(containers);
}
@Override
public void onContainersCompleted(List<ContainerStatus> status) {
ApplicationMaster.this.onContainersCompleted(status);
}
@Override
public void onError(Throwable ex) {
ApplicationMaster.this.abortJob("[Rabit] Resource manager Error "
+ ex.toString());
}
@Override
public void onNodesUpdated(List<NodeReport> nodereport) {
}
@Override
public void onShutdownRequest() {
ApplicationMaster.this
.abortJob("[Rabit] Get shutdown request, start to shutdown...");
}
}
private class NMCallbackHandler implements NMClientAsync.CallbackHandler {
@Override
public void onContainerStarted(ContainerId cid,
Map<String, ByteBuffer> services) {
LOG.debug("onContainerStarted Invoked");
}
@Override
public void onContainerStatusReceived(ContainerId cid,
ContainerStatus status) {
LOG.debug("onContainerStatusReceived Invoked");
}
@Override
public void onContainerStopped(ContainerId cid) {
LOG.debug("onContainerStopped Invoked");
}
@Override
public void onGetContainerStatusError(ContainerId cid, Throwable ex) {
LOG.debug("onGetContainerStatusError Invoked: " + ex.toString());
ApplicationMaster.this
.handleFailure(Collections.singletonList(cid));
}
@Override
public void onStartContainerError(ContainerId cid, Throwable ex) {
LOG.debug("onStartContainerError Invoked: " + ex.toString());
ApplicationMaster.this
.handleFailure(Collections.singletonList(cid));
}
@Override
public void onStopContainerError(ContainerId cid, Throwable ex) {
LOG.info("onStopContainerError Invoked: " + ex.toString());
}
}
}

View File

@ -0,0 +1,233 @@
package org.apache.hadoop.yarn.rabit;
import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.permission.FsPermission;
import org.apache.hadoop.yarn.api.ApplicationConstants;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.api.records.ApplicationReport;
import org.apache.hadoop.yarn.api.records.ApplicationSubmissionContext;
import org.apache.hadoop.yarn.api.records.ContainerLaunchContext;
import org.apache.hadoop.yarn.api.records.FinalApplicationStatus;
import org.apache.hadoop.yarn.api.records.LocalResource;
import org.apache.hadoop.yarn.api.records.LocalResourceType;
import org.apache.hadoop.yarn.api.records.LocalResourceVisibility;
import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.api.records.YarnApplicationState;
import org.apache.hadoop.yarn.client.api.YarnClient;
import org.apache.hadoop.yarn.client.api.YarnClientApplication;
import org.apache.hadoop.yarn.conf.YarnConfiguration;
import org.apache.hadoop.yarn.util.ConverterUtils;
import org.apache.hadoop.yarn.util.Records;
public class Client {
// logger
private static final Log LOG = LogFactory.getLog(Client.class);
// permission for temp file
private static final FsPermission permTemp = new FsPermission("777");
// configuration
private YarnConfiguration conf = new YarnConfiguration();
// hdfs handler
private FileSystem dfs;
// cached maps
private Map<String, String> cacheFiles = new java.util.HashMap<String, String>();
// enviroment variable to setup cachefiles
private String cacheFileArg = "";
// args to pass to application master
private String appArgs = "";
// HDFS Path to store temporal result
private String tempdir = "/tmp";
// job name
private String jobName = "";
/**
* constructor
* @throws IOException
*/
private Client() throws IOException {
dfs = FileSystem.get(conf);
}
/**
* ge
*
* @param fmaps
* the file maps
* @return the resource map
* @throws IOException
*/
private Map<String, LocalResource> setupCacheFiles(ApplicationId appId) throws IOException {
// create temporary rabit directory
Path tmpPath = new Path(this.tempdir);
if (!dfs.exists(tmpPath)) {
dfs.mkdirs(tmpPath, permTemp);
LOG.info("HDFS temp directory do not exist, creating.. " + tmpPath);
}
tmpPath = new Path(tmpPath + "/temp-rabit-yarn-" + appId);
if (dfs.exists(tmpPath)) {
dfs.delete(tmpPath, true);
}
// create temporary directory
FileSystem.mkdirs(dfs, tmpPath, permTemp);
StringBuilder cstr = new StringBuilder();
Map<String, LocalResource> rmap = new java.util.HashMap<String, LocalResource>();
for (Map.Entry<String, String> e : cacheFiles.entrySet()) {
LocalResource r = Records.newRecord(LocalResource.class);
Path path = new Path(e.getValue());
// copy local data to temporary folder in HDFS
if (!e.getValue().startsWith("hdfs://")) {
Path dst = new Path("hdfs://" + tmpPath + "/"+ path.getName());
dfs.copyFromLocalFile(false, true, path, dst);
dfs.setPermission(dst, permTemp);
dfs.deleteOnExit(dst);
path = dst;
}
FileStatus status = dfs.getFileStatus(path);
r.setResource(ConverterUtils.getYarnUrlFromPath(path));
r.setSize(status.getLen());
r.setTimestamp(status.getModificationTime());
r.setType(LocalResourceType.FILE);
r.setVisibility(LocalResourceVisibility.APPLICATION);
rmap.put(e.getKey(), r);
cstr.append(" -file \"");
cstr.append(path.toString());
cstr.append('#');
cstr.append(e.getKey());
cstr.append("\"");
}
dfs.deleteOnExit(tmpPath);
this.cacheFileArg = cstr.toString();
return rmap;
}
/**
* get the environment variables for container
*
* @return the env variable for child class
*/
private Map<String, String> getEnvironment() {
// Setup environment variables
Map<String, String> env = new java.util.HashMap<String, String>();
String cpath = "${CLASSPATH}:./*";
for (String c : conf.getStrings(
YarnConfiguration.YARN_APPLICATION_CLASSPATH,
YarnConfiguration.DEFAULT_YARN_APPLICATION_CLASSPATH)) {
cpath += ':';
cpath += c.trim();
}
env.put("CLASSPATH", cpath);
for (Map.Entry<String, String> e : System.getenv().entrySet()) {
if (e.getKey().startsWith("rabit_")) {
env.put(e.getKey(), e.getValue());
}
}
LOG.debug(env);
return env;
}
/**
* initialize the settings
*
* @param args
*/
private void initArgs(String[] args) {
// directly pass all arguments except args0
StringBuilder sargs = new StringBuilder("");
for (int i = 0; i < args.length; ++i) {
if (args[i].equals("-file")) {
String[] arr = args[++i].split("#");
if (arr.length == 1) {
cacheFiles.put(new Path(arr[0]).getName(), arr[0]);
} else {
cacheFiles.put(arr[1], arr[0]);
}
} else if(args[i].equals("-jobname")) {
this.jobName = args[++i];
} else if(args[i].equals("-tempdir")) {
this.tempdir = args[++i];
} else {
sargs.append(" ");
sargs.append(args[i]);
}
}
this.appArgs = sargs.toString();
}
private void run(String[] args) throws Exception {
if (args.length == 0) {
System.out.println("Usage: [options] [commands..]");
System.out.println("options: [-file filename]");
return;
}
this.initArgs(args);
// Create yarnClient
YarnConfiguration conf = new YarnConfiguration();
YarnClient yarnClient = YarnClient.createYarnClient();
yarnClient.init(conf);
yarnClient.start();
// Create application via yarnClient
YarnClientApplication app = yarnClient.createApplication();
// Set up the container launch context for the application master
ContainerLaunchContext amContainer = Records
.newRecord(ContainerLaunchContext.class);
ApplicationSubmissionContext appContext = app
.getApplicationSubmissionContext();
// Submit application
ApplicationId appId = appContext.getApplicationId();
// setup cache-files and environment variables
amContainer.setLocalResources(this.setupCacheFiles(appId));
amContainer.setEnvironment(this.getEnvironment());
String cmd = "$JAVA_HOME/bin/java"
+ " -Xmx256M"
+ " org.apache.hadoop.yarn.rabit.ApplicationMaster"
+ this.cacheFileArg + ' ' + this.appArgs + " 1>"
+ ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout"
+ " 2>" + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr";
LOG.debug(cmd);
amContainer.setCommands(Collections.singletonList(cmd));
// Set up resource type requirements for ApplicationMaster
Resource capability = Records.newRecord(Resource.class);
capability.setMemory(256);
capability.setVirtualCores(1);
LOG.info("jobname=" + this.jobName);
appContext.setApplicationName(jobName + ":RABIT-YARN");
appContext.setAMContainerSpec(amContainer);
appContext.setResource(capability);
appContext.setQueue("default");
LOG.info("Submitting application " + appId);
yarnClient.submitApplication(appContext);
ApplicationReport appReport = yarnClient.getApplicationReport(appId);
YarnApplicationState appState = appReport.getYarnApplicationState();
while (appState != YarnApplicationState.FINISHED
&& appState != YarnApplicationState.KILLED
&& appState != YarnApplicationState.FAILED) {
Thread.sleep(100);
appReport = yarnClient.getApplicationReport(appId);
appState = appReport.getYarnApplicationState();
}
System.out.println("Application " + appId + " finished with"
+ " state " + appState + " at " + appReport.getFinishTime());
if (!appReport.getFinalApplicationStatus().equals(
FinalApplicationStatus.SUCCEEDED)) {
System.err.println(appReport.getDiagnostics());
}
}
public static void main(String[] args) throws Exception {
new Client().run(args);
}
}

View File

@ -0,0 +1,24 @@
package org.apache.hadoop.yarn.rabit;
import org.apache.hadoop.yarn.api.records.Container;
import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest;
/**
* data structure to hold the task information
*/
public class TaskRecord {
// task id of the task
public int taskId = 0;
// number of failed attempts to run the task
public int attemptCounter = 0;
// container request, can be null if task is already running
public ContainerRequest containerRequest = null;
// running container, can be null if the task is not launched
public Container container = null;
// whether we have requested abortion of this task
public boolean abortRequested = false;
public TaskRecord(int taskId) {
this.taskId = taskId;
}
}