introduce input split

This commit is contained in:
tqchen 2015-03-06 21:08:04 -08:00
parent d4ec037f2e
commit 1a573f987b
14 changed files with 496 additions and 166 deletions

View File

@ -8,8 +8,7 @@ It also contain links to the Machine Learning packages that uses rabit.
Toolkits
====
* [KMeans Clustering](kmeans)
* [Linear and Logistic Regression](linear)
* [Linear and Logistic Regression](linear)
* [XGBoost: eXtreme Gradient Boosting](https://github.com/tqchen/xgboost/tree/master/multi-node)
- xgboost is a very fast boosted tree(also known as GBDT) library, that can run more than
10 times faster than existing packages

View File

@ -3,8 +3,8 @@
export CC = gcc
export CXX = g++
export MPICXX = mpicxx
export LDFLAGS= -pthread -lm -L../../lib
export CFLAGS = -Wall -msse2 -Wno-unknown-pragmas -fPIC -I../../include
export LDFLAGS= -pthread -lm -L../../lib -lrt
export CFLAGS = -Wall -msse2 -Wno-unknown-pragmas -fPIC -I../../include
.PHONY: clean all lib mpi
all: $(BIN) $(MOCKBIN)
@ -16,9 +16,10 @@ libmpi:
cd ../..;make lib/librabit_mpi.a;cd -
$(BIN) :
$(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^) $(LDFLAGS) -lrabit
$(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^) -lrabit $(LDFLAGS)
$(MOCKBIN) :
$(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^) $(LDFLAGS) -lrabit_mock
$(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^) -lrabit_mock $(LDFLAGS)
$(OBJ) :
$(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) )

View File

@ -1,5 +1,5 @@
#ifndef RABIT_LEARN_UTILS_BASE64_H_
#define RABIT_LEARN_UTILS_BASE64_H_
#ifndef RABIT_LEARN_IO_BASE64_H_
#define RABIT_LEARN_IO_BASE64_H_
/*!
* \file base64.h
* \brief data stream support to input and output from/to base64 stream
@ -8,10 +8,11 @@
*/
#include <cctype>
#include <cstdio>
#include <rabit/io.h>
#include "./io.h"
#include "./utils.h"
namespace rabit {
namespace utils {
namespace io {
/*! \brief namespace of base64 decoding and encoding table */
namespace base64 {
const char DecodeTable[] = {
@ -34,7 +35,8 @@ static const char EncodeTable[] =
/*! \brief the stream that reads from base64, note we take from file pointers */
class Base64InStream: public IStream {
public:
explicit Base64InStream(FILE *fp) : fp(fp) {
explicit Base64InStream(IStream *fs) : reader_(256) {
reader_.set_stream(fs);
num_prev = 0; tmp_ch = 0;
}
/*!
@ -44,7 +46,7 @@ class Base64InStream: public IStream {
inline void InitPosition(void) {
// get a charater
do {
tmp_ch = fgetc(fp);
tmp_ch = reader_.GetChar();
} while (isspace(tmp_ch));
}
/*! \brief whether current position is end of a base64 stream */
@ -85,19 +87,19 @@ class Base64InStream: public IStream {
nvalue = DecodeTable[tmp_ch] << 18;
{
// second byte
Check((tmp_ch = fgetc(fp), tmp_ch != EOF && !isspace(tmp_ch)),
utils::Check((tmp_ch = reader_.GetChar(), tmp_ch != EOF && !isspace(tmp_ch)),
"invalid base64 format");
nvalue |= DecodeTable[tmp_ch] << 12;
*cptr++ = (nvalue >> 16) & 0xFF; --tlen;
}
{
// third byte
Check((tmp_ch = fgetc(fp), tmp_ch != EOF && !isspace(tmp_ch)),
utils::Check((tmp_ch = reader_.GetChar(), tmp_ch != EOF && !isspace(tmp_ch)),
"invalid base64 format");
// handle termination
if (tmp_ch == '=') {
Check((tmp_ch = fgetc(fp), tmp_ch == '='), "invalid base64 format");
Check((tmp_ch = fgetc(fp), tmp_ch == EOF || isspace(tmp_ch)),
utils::Check((tmp_ch = reader_.GetChar(), tmp_ch == '='), "invalid base64 format");
utils::Check((tmp_ch = reader_.GetChar(), tmp_ch == EOF || isspace(tmp_ch)),
"invalid base64 format");
break;
}
@ -110,10 +112,10 @@ class Base64InStream: public IStream {
}
{
// fourth byte
Check((tmp_ch = fgetc(fp), tmp_ch != EOF && !isspace(tmp_ch)),
utils::Check((tmp_ch = reader_.GetChar(), tmp_ch != EOF && !isspace(tmp_ch)),
"invalid base64 format");
if (tmp_ch == '=') {
Check((tmp_ch = fgetc(fp), tmp_ch == EOF || isspace(tmp_ch)),
utils::Check((tmp_ch = reader_.GetChar(), tmp_ch == EOF || isspace(tmp_ch)),
"invalid base64 format");
break;
}
@ -125,10 +127,10 @@ class Base64InStream: public IStream {
}
}
// get next char
tmp_ch = fgetc(fp);
tmp_ch = reader_.GetChar();
}
if (kStrictCheck) {
Check(tlen == 0, "Base64InStream: read incomplete");
utils::Check(tlen == 0, "Base64InStream: read incomplete");
}
return size - tlen;
}
@ -137,7 +139,7 @@ class Base64InStream: public IStream {
}
private:
FILE *fp;
StreamBufferReader reader_;
int tmp_ch;
int num_prev;
unsigned char buf_prev[2];
@ -147,7 +149,7 @@ class Base64InStream: public IStream {
/*! \brief the stream that write to base64, note we take from file pointers */
class Base64OutStream: public IStream {
public:
explicit Base64OutStream(FILE *fp) : fp(fp) {
explicit Base64OutStream(IStream *fp) : fp(fp) {
buf_top = 0;
}
virtual void Write(const void *ptr, size_t size) {
@ -160,16 +162,16 @@ class Base64OutStream: public IStream {
}
if (buf_top == 3) {
// flush 4 bytes out
fputc(EncodeTable[buf[1] >> 2], fp);
fputc(EncodeTable[((buf[1] << 4) | (buf[2] >> 4)) & 0x3F], fp);
fputc(EncodeTable[((buf[2] << 2) | (buf[3] >> 6)) & 0x3F], fp);
fputc(EncodeTable[buf[3] & 0x3F], fp);
PutChar(EncodeTable[buf[1] >> 2]);
PutChar(EncodeTable[((buf[1] << 4) | (buf[2] >> 4)) & 0x3F]);
PutChar(EncodeTable[((buf[2] << 2) | (buf[3] >> 6)) & 0x3F]);
PutChar(EncodeTable[buf[3] & 0x3F]);
buf_top = 0;
}
}
}
virtual size_t Read(void *ptr, size_t size) {
Error("Base64OutStream do not support read");
utils::Error("Base64OutStream do not support read");
return 0;
}
/*!
@ -179,25 +181,37 @@ class Base64OutStream: public IStream {
inline void Finish(char endch = EOF) {
using base64::EncodeTable;
if (buf_top == 1) {
fputc(EncodeTable[buf[1] >> 2], fp);
fputc(EncodeTable[(buf[1] << 4) & 0x3F], fp);
fputc('=', fp);
fputc('=', fp);
PutChar(EncodeTable[buf[1] >> 2]);
PutChar(EncodeTable[(buf[1] << 4) & 0x3F]);
PutChar('=');
PutChar('=');
}
if (buf_top == 2) {
fputc(EncodeTable[buf[1] >> 2], fp);
fputc(EncodeTable[((buf[1] << 4) | (buf[2] >> 4)) & 0x3F], fp);
fputc(EncodeTable[(buf[2] << 2) & 0x3F], fp);
fputc('=', fp);
PutChar(EncodeTable[buf[1] >> 2]);
PutChar(EncodeTable[((buf[1] << 4) | (buf[2] >> 4)) & 0x3F]);
PutChar(EncodeTable[(buf[2] << 2) & 0x3F]);
PutChar('=');
}
buf_top = 0;
if (endch != EOF) fputc(endch, fp);
if (endch != EOF) PutChar(endch);
this->Flush();
}
private:
FILE *fp;
private:
IStream *fp;
int buf_top;
unsigned char buf[4];
std::string out_buf;
const static size_t kBufferSize = 256;
inline void PutChar(char ch) {
out_buf += ch;
if (out_buf.length() >= kBufferSize) Flush();
}
inline void Flush(void) {
fp->Write(BeginPtr(out_buf), out_buf.length());
out_buf.clear();
}
};
} // namespace utils
} // namespace rabit

36
rabit-learn/io/io-inl.h Normal file
View File

@ -0,0 +1,36 @@
#ifndef RABIT_LEARN_IO_IO_INL_H_
#define RABIT_LEARN_IO_IO_INL_H_
/*!
* \file io-inl.h
* \brief Input/Output utils that handles read/write
* of files in distrubuted enviroment
* \author Tianqi Chen
*/
#include <cstring>
#include "./line_split.h"
namespace rabit {
namespace io {
/*!
* \brief create input split given a uri
* \param uri the uri of the input, can contain hdfs prefix
* \param part the part id of current input
* \param nsplit total number of splits
*/
inline InputSplit *CreateInputSplit(const char *uri,
unsigned part,
unsigned nsplit) {
if (!strcmp(uri, "stdin")) {
return new SingleFileSplit(uri);
}
if (!strncmp(uri, "file://", 7)) {
return new FileSplit(uri, part, nsplit);
}
if (!strncmp(uri, "hdfs://", 7)) {
utils::Error("HDFS reading is not yet supported");
return NULL;
}
return new FileSplit(uri, part, nsplit);
}
} // namespace io
} // namespace rabit
#endif // RABIT_LEARN_IO_IO_INL_H_

53
rabit-learn/io/io.h Normal file
View File

@ -0,0 +1,53 @@
#ifndef RABIT_LEARN_IO_IO_H_
#define RABIT_LEARN_IO_IO_H_
/*!
* \file io.h
* \brief Input/Output utils that handles read/write
* of files in distrubuted enviroment
* \author Tianqi Chen
*/
#include "../../include/rabit_serializable.h"
/*! \brief io interface */
namespace rabit {
/*!
* \brief namespace to handle input split and filesystem interfacing
*/
namespace io {
/*!
* \brief user facing input split helper,
* can be used to get the partition of data used by current node
*/
class InputSplit {
public:
/*!
* \brief get next line, store into out_data
* \param out_data the string that stores the line data,
* \n is not included
* \return true of next line was found, false if we read all the lines
*/
virtual bool NextLine(std::string *out_data) = 0;
/*! \brief destructor*/
virtual ~InputSplit(void) {}
};
/*!
* \brief create input split given a uri
* \param uri the uri of the input, can contain hdfs prefix
* \param part the part id of current input
* \param nsplit total number of splits
*/
inline InputSplit *CreateInputSplit(const char *uri,
unsigned part,
unsigned nsplit);
/*!
* \brief create an stream, the stream must be able to close
* the underlying resources(files) when deleted
*
* \param uri the uri of the input, can contain hdfs prefix
* \param mode can be 'w' or 'r' for read or write
*/
inline IStream *CreateStream(const char *uri, const char *mode);
} // namespace io
} // namespace rabit
#include "./io-inl.h"
#endif // RABIT_LEARN_IO_IO_H_

217
rabit-learn/io/line_split.h Normal file
View File

@ -0,0 +1,217 @@
#ifndef RABIT_LEARN_IO_LINE_SPLIT_H_
#define RABIT_LEARN_IO_LINE_SPLIT_H_
/*!
* \file line_split.h
* \brief base implementation of line-spliter
* \author Tianqi Chen
*/
#include <vector>
#include <utility>
#include <cstring>
#include <iostream>
#include <fstream>
#include "../../include/rabit.h"
#include "./io.h"
#include "./utils.h"
namespace rabit {
namespace io {
class LineSplitBase : public InputSplit {
public:
virtual ~LineSplitBase() {
if (fs_ != NULL) delete fs_;
}
virtual bool NextLine(std::string *out_data) {
if (file_ptr_ >= file_ptr_end_ &&
offset_curr_ >= offset_end_) return false;
out_data->clear();
while (true) {
char c = reader_.GetChar();
if (reader_.AtEnd()) {
if (out_data->length() != 0) return true;
file_ptr_ += 1;
if (offset_curr_ != file_offset_[file_ptr_]) {
utils::Error("warning:file size not calculated correctly\n");
offset_curr_ = file_offset_[file_ptr_];
}
if (offset_curr_ >= offset_end_) return false;
utils::Assert(file_ptr_ + 1 < file_offset_.size(),
"boundary check");
delete fs_;
fs_ = this->GetFile(file_ptr_);
reader_.set_stream(fs_);
} else {
++offset_curr_;
if (c != '\r' && c != '\n' && c != EOF) {
*out_data += c;
} else {
if (out_data->length() != 0) return true;
if (file_ptr_ >= file_ptr_end_ &&
offset_curr_ >= offset_end_) return false;
}
}
}
}
protected:
// constructor
LineSplitBase(void)
: fs_(NULL), reader_(kBufferSize) {
}
/*!
* \brief initialize the line spliter,
* \param file_size, size of each files
* \param rank the current rank of the data
* \param nsplit number of split we will divide the data into
*/
inline void Init(const std::vector<size_t> &file_size,
unsigned rank, unsigned nsplit) {
file_offset_.resize(file_size.size() + 1);
file_offset_[0] = 0;
for (size_t i = 0; i < file_size.size(); ++i) {
file_offset_[i + 1] = file_offset_[i] + file_size[i];
}
size_t ntotal = file_offset_.back();
size_t nstep = (ntotal + nsplit - 1) / nsplit;
offset_begin_ = std::min(nstep * rank, ntotal);
offset_end_ = std::min(nstep * (rank + 1), ntotal);
offset_curr_ = offset_begin_;
if (offset_begin_ == offset_end_) return;
file_ptr_ = std::upper_bound(file_offset_.begin(),
file_offset_.end(),
offset_begin_) - file_offset_.begin() - 1;
file_ptr_end_ = std::upper_bound(file_offset_.begin(),
file_offset_.end(),
offset_end_) - file_offset_.begin() - 1;
fs_ = GetFile(file_ptr_);
reader_.set_stream(fs_);
// try to set the starting position correctly
if (file_offset_[file_ptr_] != offset_begin_) {
fs_->Seek(offset_begin_ - file_offset_[file_ptr_]);
while (true) {
char c = reader_.GetChar();
if (!reader_.AtEnd()) ++offset_curr_;
if (c == '\n' || c == '\r' || c == EOF) return;
}
}
}
/*!
* \brief get the seek stream of given file_index
* \return the corresponding seek stream at head of file
*/
virtual utils::ISeekStream *GetFile(size_t file_index) = 0;
/*!
* \brief split names given
* \param out_fname output file names
* \param uri_ the iput uri file
* \param dlm deliminetr
*/
inline static void SplitNames(std::vector<std::string> *out_fname,
const char *uri_,
const char *dlm) {
std::string uri = uri_;
char *p = strtok(BeginPtr(uri), dlm);
while (p != NULL) {
out_fname->push_back(std::string(p));
p = strtok(NULL, dlm);
}
}
private:
/*! \brief current input stream */
utils::ISeekStream *fs_;
/*! \brief file pointer of which file to read on */
size_t file_ptr_;
/*! \brief file pointer where the end of file lies */
size_t file_ptr_end_;
/*! \brief get the current offset */
size_t offset_curr_;
/*! \brief beginning of offset */
size_t offset_begin_;
/*! \brief end of the offset */
size_t offset_end_;
/*! \brief byte-offset of each file */
std::vector<size_t> file_offset_;
/*! \brief buffer reader */
StreamBufferReader reader_;
/*! \brief buffer size */
const static size_t kBufferSize = 256;
};
/*! \brief line split from single file */
class SingleFileSplit : public InputSplit {
public:
explicit SingleFileSplit(const char *fname) {
if (!strcmp(fname, "stdin")) {
use_stdin_ = true;
}
if (!use_stdin_) {
fp_ = utils::FopenCheck(fname, "r");
}
end_of_file_ = false;
}
virtual ~SingleFileSplit(void) {
if (!use_stdin_) fclose(fp_);
}
virtual bool NextLine(std::string *out_data) {
if (end_of_file_) return false;
out_data->clear();
while (true) {
char c = fgetc(fp_);
if (c == EOF) {
end_of_file_ = true;
}
if (c != '\r' && c != '\n' && c != EOF) {
*out_data += c;
} else {
if (out_data->length() != 0) return true;
if (end_of_file_) return false;
}
}
return false;
}
private:
FILE *fp_;
bool use_stdin_;
bool end_of_file_;
};
/*! \brief line split from normal file system */
class FileSplit : public LineSplitBase {
public:
explicit FileSplit(const char *uri, unsigned rank, unsigned nsplit) {
LineSplitBase::SplitNames(&fnames_, uri, "#");
std::vector<size_t> fsize;
for (size_t i = 0; i < fnames_.size(); ++i) {
if (!strncmp(fnames_[i].c_str(), "file://", 7)) {
std::string tmp = fnames_[i].c_str() + 7;
fnames_[i] = tmp;
}
fsize.push_back(GetFileSize(fnames_[i].c_str()));
}
LineSplitBase::Init(fsize, rank, nsplit);
}
virtual ~FileSplit(void) {}
protected:
virtual utils::ISeekStream *GetFile(size_t file_index) {
utils::Assert(file_index < fnames_.size(), "file index exceed bound");
return new FileStream(fnames_[file_index].c_str(), "rb");
}
// get file size
inline static size_t GetFileSize(const char *fname) {
FILE *fp = utils::FopenCheck(fname, "rb");
// NOTE: fseek may not be good, but serves as ok solution
fseek(fp, 0, SEEK_END);
size_t fsize = static_cast<size_t>(ftell(fp));
fclose(fp);
return fsize;
}
private:
// file names
std::vector<std::string> fnames_;
};
} // namespace io
} // namespace rabit
#endif // RABIT_LEARN_IO_LINE_SPLIT_H_

102
rabit-learn/io/utils.h Normal file
View File

@ -0,0 +1,102 @@
#ifndef RABIT_LEARN_IO_UTILS_H_
#define RABIT_LEARN_IO_UTILS_H_
/*!
* \file utils.h
* \brief some helper utils that can be used to implement IO
* \author Tianqi Chen
*/
namespace rabit {
namespace io {
/*! \brief buffer reader of the stream that allows you to get */
class StreamBufferReader {
public:
StreamBufferReader(size_t buffer_size)
:stream_(NULL),
read_len_(1), read_ptr_(1) {
buffer_.resize(buffer_size);
}
/*!
* \brief set input stream
*/
inline void set_stream(IStream *stream) {
stream_ = stream;
read_len_ = read_ptr_ = 1;
}
/*!
* \brief allows quick read using get char
*/
inline char GetChar(void) {
while (true) {
if (read_ptr_ < read_len_) {
return buffer_[read_ptr_++];
} else {
read_len_ = stream_->Read(&buffer_[0], buffer_.length());
if (read_len_ == 0) return EOF;
read_ptr_ = 0;
}
}
}
inline bool AtEnd(void) const {
return read_len_ == 0;
}
private:
/*! \brief the underlying stream */
IStream *stream_;
/*! \brief buffer to hold data */
std::string buffer_;
/*! \brief length of valid data in buffer */
size_t read_len_;
/*! \brief pointer in the buffer */
size_t read_ptr_;
};
/*! \brief implementation of file i/o stream */
class FileStream : public utils::ISeekStream {
public:
explicit FileStream(const char *fname, const char *mode)
: use_stdio(false) {
#ifndef RABIT_STRICT_CXX98_
if (!strcmp(fname, "stdin")) {
use_stdio = true; fp = stdin;
}
if (!strcmp(fname, "stdout")) {
use_stdio = true; fp = stdout;
}
#endif
if (!use_stdio) {
fp = utils::FopenCheck(fname, mode);
}
}
virtual ~FileStream(void) {
this->Close();
}
virtual size_t Read(void *ptr, size_t size) {
return std::fread(ptr, 1, size, fp);
}
virtual void Write(const void *ptr, size_t size) {
std::fwrite(ptr, size, 1, fp);
}
virtual void Seek(size_t pos) {
std::fseek(fp, static_cast<long>(pos), SEEK_SET);
}
virtual size_t Tell(void) {
return std::ftell(fp);
}
virtual bool AtEnd(void) const {
return feof(fp) != 0;
}
inline void Close(void) {
if (fp != NULL && !use_stdio) {
std::fclose(fp); fp = NULL;
}
}
private:
FILE *fp;
bool use_stdio;
};
} // namespace io
} // namespace rabit
#endif // RABIT_LEARN_IO_UTILS_H_

2
rabit-learn/linear/.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
mushroom.row*
*.model

View File

@ -1,6 +1,6 @@
#include "./linear.h"
#include "../utils/io.h"
#include "../utils/base64.h"
#include "../io/io.h"
#include "../io/base64.h"
namespace rabit {
namespace linear {
@ -74,23 +74,20 @@ class LinearObjFunction : public solver::IObjFunction<float> {
printf("Finishing writing to %s\n", name_pred.c_str());
}
inline void LoadModel(const char *fname) {
FILE *fp = utils::FopenCheck(fname, "rb");
io::FileStream fi(fname, "rb");
std::string header; header.resize(4);
// check header for different binary encode
// can be base64 or binary
utils::FileStream fi(fp);
utils::Check(fi.Read(&header[0], 4) != 0, "invalid model");
// base64 format
// base64 format
if (header == "bs64") {
utils::Base64InStream bsin(fp);
io::Base64InStream bsin(&fi);
bsin.InitPosition();
model.Load(bsin);
fclose(fp);
return;
} else if (header == "binf") {
model.Load(fi);
fclose(fp);
return;
return;
} else {
utils::Error("invalid model file");
}
@ -98,27 +95,16 @@ class LinearObjFunction : public solver::IObjFunction<float> {
inline void SaveModel(const char *fname,
const float *wptr,
bool save_base64 = false) {
FILE *fp;
bool use_stdout = false;
if (!strcmp(fname, "stdout")) {
fp = stdout;
use_stdout = true;
} else {
fp = utils::FopenCheck(fname, "wb");
}
utils::FileStream fo(fp);
if (save_base64 != 0|| use_stdout) {
io::FileStream fo(fname, "wb");
if (save_base64 != 0 || !strcmp(fname, "stdout")) {
fo.Write("bs64\t", 5);
utils::Base64OutStream bout(fp);
io::Base64OutStream bout(&fo);
model.Save(bout, wptr);
bout.Finish('\n');
} else {
fo.Write("binf", 4);
model.Save(fo, wptr);
}
if (!use_stdout) {
fclose(fp);
}
}
inline void LoadData(const char *fname) {
dtrain.Load(fname);

View File

@ -5,11 +5,7 @@ then
exit -1
fi
rm -rf mushroom.row* *.model
rm -rf *.model
k=$1
# split the lib svm file into k subfiles
python splitrows.py ../data/agaricus.txt.train mushroom $k
# run xgboost mpi
../../tracker/rabit_demo.py -n $k linear.mock mushroom.row\%d "${*:2}" reg_L1=1 mock=0,1,1,0 mock=1,1,1,0 mock=0,2,1,1
../../tracker/rabit_demo.py -n $k linear.mock ../data/agaricus.txt.train "${*:2}" reg_L1=1 mock=0,1,1,0 mock=1,1,1,0 mock=0,2,1,1

View File

@ -5,13 +5,10 @@ then
exit -1
fi
rm -rf mushroom.row* *.model
rm -rf *.model
k=$1
# split the lib svm file into k subfiles
python splitrows.py ../data/agaricus.txt.train mushroom $k
# run xgboost mpi
../../tracker/rabit_demo.py -n $k linear.rabit mushroom.row\%d "${*:2}" reg_L1=1
# run linear model, the program will automatically split the inputs
../../tracker/rabit_demo.py -n $k linear.rabit ../data/agaricus.txt.train reg_L1=1
./linear.rabit ../data/agaricus.txt.test task=pred model_in=final.model

View File

@ -1,24 +0,0 @@
#!/usr/bin/python
import sys
import random
# split libsvm file into different rows
if len(sys.argv) < 4:
print ('Usage:<fin> <fo> k')
exit(0)
random.seed(10)
k = int(sys.argv[3])
fi = open( sys.argv[1], 'r' )
fos = []
for i in range(k):
fos.append(open( sys.argv[2]+'.row%d' % i, 'w' ))
for l in open(sys.argv[1]):
i = random.randint(0, k-1)
fos[i].write(l)
for f in fos:
f.close()

View File

@ -14,7 +14,9 @@
#include <cstring>
#include <limits>
#include <cmath>
#include <sstream>
#include <rabit.h>
#include "../io/io.h"
namespace rabit {
// typedef index type
@ -45,49 +47,37 @@ struct SparseMat {
}
// load data from LibSVM format
inline void Load(const char *fname) {
FILE *fi;
if (!strcmp(fname, "stdin")) {
fi = stdin;
} else {
if (strchr(fname, '%') != NULL) {
char s_tmp[256];
snprintf(s_tmp, sizeof(s_tmp), fname, rabit::GetRank());
fi = utils::FopenCheck(s_tmp, "r");
} else {
fi = utils::FopenCheck(fname, "r");
}
}
io::InputSplit *in =
io::CreateInputSplit
(fname, rabit::GetRank(),
rabit::GetWorldSize());
row_ptr.clear();
row_ptr.push_back(0);
data.clear();
feat_dim = 0;
float label; bool init = true;
char tmp[1024];
while (fscanf(fi, "%s", tmp) == 1) {
std::string line;
while (in->NextLine(&line)) {
float label;
std::istringstream ss(line);
ss >> label;
Entry e;
unsigned long fidx;
if (sscanf(tmp, "%lu:%f", &fidx, &e.fvalue) == 2) {
while (!ss.eof()) {
if (!(ss >> fidx)) break;
ss.ignore(32, ':');
if (!(ss >> e.fvalue)) break;
e.findex = static_cast<index_t>(fidx);
data.push_back(e);
feat_dim = std::max(fidx, feat_dim);
} else {
if (!init) {
labels.push_back(label);
row_ptr.push_back(data.size());
}
utils::Check(sscanf(tmp, "%f", &label) == 1, "invalid LibSVM format");
init = false;
}
labels.push_back(label);
row_ptr.push_back(data.size());
}
// last row
labels.push_back(label);
row_ptr.push_back(data.size());
delete in;
feat_dim += 1;
utils::Check(feat_dim < std::numeric_limits<index_t>::max(),
"feature dimension exceed limit of index_t"\
"consider change the index_t to unsigned long");
// close the filed
if (fi != stdin) fclose(fi);
}
inline size_t NumRow(void) const {
return row_ptr.size() - 1;
@ -98,6 +88,7 @@ struct SparseMat {
std::vector<Entry> data;
std::vector<float> labels;
};
// dense matrix
struct Matrix {
inline void Init(size_t nrow, size_t ncol, float v = 0.0f) {

View File

@ -1,40 +0,0 @@
#ifndef RABIT_LEARN_UTILS_IO_H_
#define RABIT_LEARN_UTILS_IO_H_
/*!
* \file io.h
* \brief additional stream interface
* \author Tianqi Chen
*/
namespace rabit {
namespace utils {
/*! \brief implementation of file i/o stream */
class FileStream : public ISeekStream {
public:
explicit FileStream(FILE *fp) : fp(fp) {}
explicit FileStream(void) {
this->fp = NULL;
}
virtual size_t Read(void *ptr, size_t size) {
return std::fread(ptr, size, 1, fp);
}
virtual void Write(const void *ptr, size_t size) {
std::fwrite(ptr, size, 1, fp);
}
virtual void Seek(size_t pos) {
std::fseek(fp, static_cast<long>(pos), SEEK_SET);
}
virtual size_t Tell(void) {
return std::ftell(fp);
}
inline void Close(void) {
if (fp != NULL){
std::fclose(fp); fp = NULL;
}
}
private:
FILE *fp;
};
} // namespace utils
} // namespace rabit
#endif // RABIT_LEARN_UTILS_IO_H_