Merge commit '3d11f56880521c1d45504c965ae12886e9b72ace'
This commit is contained in:
@@ -33,9 +33,9 @@ static const char EncodeTable[] =
|
||||
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
|
||||
} // namespace base64
|
||||
/*! \brief the stream that reads from base64, note we take from file pointers */
|
||||
class Base64InStream: public IStream {
|
||||
class Base64InStream: public Stream {
|
||||
public:
|
||||
explicit Base64InStream(IStream *fs) : reader_(256) {
|
||||
explicit Base64InStream(Stream *fs) : reader_(256) {
|
||||
reader_.set_stream(fs);
|
||||
num_prev = 0; tmp_ch = 0;
|
||||
}
|
||||
@@ -147,9 +147,9 @@ class Base64InStream: public IStream {
|
||||
static const bool kStrictCheck = false;
|
||||
};
|
||||
/*! \brief the stream that writes to base64, note we take from file pointers */
|
||||
class Base64OutStream: public IStream {
|
||||
class Base64OutStream: public Stream {
|
||||
public:
|
||||
explicit Base64OutStream(IStream *fp) : fp(fp) {
|
||||
explicit Base64OutStream(Stream *fp) : fp(fp) {
|
||||
buf_top = 0;
|
||||
}
|
||||
virtual void Write(const void *ptr, size_t size) {
|
||||
@@ -198,7 +198,7 @@ class Base64OutStream: public IStream {
|
||||
}
|
||||
|
||||
private:
|
||||
IStream *fp;
|
||||
Stream *fp;
|
||||
int buf_top;
|
||||
unsigned char buf[4];
|
||||
std::string out_buf;
|
||||
|
||||
@@ -20,7 +20,7 @@ class StreamBufferReader {
|
||||
/*!
|
||||
* \brief set input stream
|
||||
*/
|
||||
inline void set_stream(IStream *stream) {
|
||||
inline void set_stream(Stream *stream) {
|
||||
stream_ = stream;
|
||||
read_len_ = read_ptr_ = 1;
|
||||
}
|
||||
@@ -45,7 +45,7 @@ class StreamBufferReader {
|
||||
|
||||
private:
|
||||
/*! \brief the underlying stream */
|
||||
IStream *stream_;
|
||||
Stream *stream_;
|
||||
/*! \brief buffer to hold data */
|
||||
std::string buffer_;
|
||||
/*! \brief length of valid data in buffer */
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
namespace rabit {
|
||||
namespace io {
|
||||
/*! \brief implementation of file i/o stream */
|
||||
class FileStream : public utils::ISeekStream {
|
||||
class FileStream : public utils::SeekStream {
|
||||
public:
|
||||
explicit FileStream(const char *fname, const char *mode)
|
||||
: use_stdio(false) {
|
||||
@@ -84,7 +84,7 @@ class FileProvider : public LineSplitter::IFileProvider {
|
||||
}
|
||||
// destructor
|
||||
virtual ~FileProvider(void) {}
|
||||
virtual utils::ISeekStream *Open(size_t file_index) {
|
||||
virtual utils::SeekStream *Open(size_t file_index) {
|
||||
utils::Assert(file_index < fnames_.size(), "file index exceed bound");
|
||||
return new FileStream(fnames_[file_index].c_str(), "rb");
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
/*! \brief io interface */
|
||||
namespace rabit {
|
||||
namespace io {
|
||||
class HDFSStream : public ISeekStream {
|
||||
class HDFSStream : public SeekStream {
|
||||
public:
|
||||
HDFSStream(hdfsFS fs,
|
||||
const char *fname,
|
||||
@@ -147,7 +147,7 @@ class HDFSProvider : public LineSplitter::IFileProvider {
|
||||
virtual const std::vector<size_t> &FileSize(void) const {
|
||||
return fsize_;
|
||||
}
|
||||
virtual ISeekStream *Open(size_t file_index) {
|
||||
virtual SeekStream *Open(size_t file_index) {
|
||||
utils::Assert(file_index < fnames_.size(), "file index exceed bound");
|
||||
return new HDFSStream(fs_, fnames_[file_index].c_str(), "r", false);
|
||||
}
|
||||
|
||||
@@ -50,7 +50,7 @@ inline InputSplit *CreateInputSplit(const char *uri,
|
||||
}
|
||||
|
||||
template<typename TStream>
|
||||
class StreamAdapter : public IStream {
|
||||
class StreamAdapter : public Stream {
|
||||
public:
|
||||
explicit StreamAdapter(TStream *stream)
|
||||
: stream_(stream) {
|
||||
@@ -75,9 +75,9 @@ class StreamAdapter : public IStream {
|
||||
* \param uri the uri of the input, can contain hdfs prefix
|
||||
* \param mode can be 'w' or 'r' for read or write
|
||||
*/
|
||||
inline IStream *CreateStream(const char *uri, const char *mode) {
|
||||
inline Stream *CreateStream(const char *uri, const char *mode) {
|
||||
#if RABIT_USE_WORMHOLE
|
||||
return new StreamAdapter<dmlc::IStream>(dmlc::IStream::Create(uri, mode));
|
||||
return new StreamAdapter<dmlc::Stream>(dmlc::Stream::Create(uri, mode));
|
||||
#else
|
||||
using namespace std;
|
||||
if (!strncmp(uri, "file://", 7)) {
|
||||
|
||||
@@ -26,12 +26,12 @@ namespace rabit {
|
||||
* \brief namespace to handle input split and filesystem interfacing
|
||||
*/
|
||||
namespace io {
|
||||
/*! \brief reused ISeekStream's definition */
|
||||
/*! \brief reused SeekStream's definition */
|
||||
#if RABIT_USE_WORMHOLE
|
||||
typedef dmlc::ISeekStream ISeekStream;
|
||||
typedef dmlc::SeekStream SeekStream;
|
||||
typedef dmlc::InputSplit InputSplit;
|
||||
#else
|
||||
typedef utils::ISeekStream ISeekStream;
|
||||
typedef utils::SeekStream SeekStream;
|
||||
/*!
|
||||
* \brief user facing input split helper,
|
||||
* can be used to get the partition of data used by current node
|
||||
@@ -65,7 +65,7 @@ inline InputSplit *CreateInputSplit(const char *uri,
|
||||
* \param uri the uri of the input, can contain hdfs prefix
|
||||
* \param mode can be 'w' or 'r' for read or write
|
||||
*/
|
||||
inline IStream *CreateStream(const char *uri, const char *mode);
|
||||
inline Stream *CreateStream(const char *uri, const char *mode);
|
||||
} // namespace io
|
||||
} // namespace rabit
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ class LineSplitter : public InputSplit {
|
||||
* \return the corresponding seek stream at head of the stream
|
||||
* the seek stream's resource can be freed by calling delete
|
||||
*/
|
||||
virtual ISeekStream *Open(size_t file_index) = 0;
|
||||
virtual SeekStream *Open(size_t file_index) = 0;
|
||||
/*!
|
||||
* \return const reference to size of each files
|
||||
*/
|
||||
@@ -142,7 +142,7 @@ class LineSplitter : public InputSplit {
|
||||
/*! \brief FileProvider */
|
||||
IFileProvider *provider_;
|
||||
/*! \brief current input stream */
|
||||
utils::ISeekStream *fs_;
|
||||
utils::SeekStream *fs_;
|
||||
/*! \brief file pointer of which file to read on */
|
||||
size_t file_ptr_;
|
||||
/*! \brief file pointer where the end of file lies */
|
||||
|
||||
@@ -7,22 +7,54 @@
|
||||
|
||||
using namespace rabit;
|
||||
|
||||
// simple dense matrix, mshadow or Eigen matrix was better
|
||||
// this was was OK
|
||||
struct Matrix {
|
||||
inline void Init(size_t nrow, size_t ncol, float v = 0.0f) {
|
||||
this->nrow = nrow;
|
||||
this->ncol = ncol;
|
||||
data.resize(nrow * ncol);
|
||||
std::fill(data.begin(), data.end(), v);
|
||||
}
|
||||
inline float *operator[](size_t i) {
|
||||
return &data[0] + i * ncol;
|
||||
}
|
||||
inline const float *operator[](size_t i) const {
|
||||
return &data[0] + i * ncol;
|
||||
}
|
||||
inline void Print(utils::Stream *fo) {
|
||||
for (size_t i = 0; i < data.size(); ++i) {
|
||||
std::ostringstream ss;
|
||||
ss << data[i];
|
||||
if ((i+1) % ncol == 0) {
|
||||
ss << '\n';
|
||||
} else {
|
||||
ss << ' ';
|
||||
}
|
||||
}
|
||||
std::string s = ss.str();
|
||||
}
|
||||
// number of data
|
||||
size_t nrow, ncol;
|
||||
std::vector<float> data;
|
||||
};
|
||||
|
||||
// kmeans model
|
||||
class Model : public rabit::ISerializable {
|
||||
class Model : public rabit::Serializable {
|
||||
public:
|
||||
// matrix of centroids
|
||||
Matrix centroids;
|
||||
// load from stream
|
||||
virtual void Load(rabit::IStream &fi) {
|
||||
fi.Read(&centroids.nrow, sizeof(centroids.nrow));
|
||||
fi.Read(&centroids.ncol, sizeof(centroids.ncol));
|
||||
fi.Read(&centroids.data);
|
||||
virtual void Load(rabit::Stream *fi) {
|
||||
fi->Read(&centroids.nrow, sizeof(centroids.nrow));
|
||||
fi->Read(&centroids.ncol, sizeof(centroids.ncol));
|
||||
fi->Read(&centroids.data);
|
||||
}
|
||||
/*! \brief save the model to the stream */
|
||||
virtual void Save(rabit::IStream &fo) const {
|
||||
fo.Write(&centroids.nrow, sizeof(centroids.nrow));
|
||||
fo.Write(&centroids.ncol, sizeof(centroids.ncol));
|
||||
fo.Write(centroids.data);
|
||||
virtual void Save(rabit::Stream *fo) const {
|
||||
fo->Write(&centroids.nrow, sizeof(centroids.nrow));
|
||||
fo->Write(&centroids.ncol, sizeof(centroids.ncol));
|
||||
fo->Write(centroids.data);
|
||||
}
|
||||
virtual void InitModel(unsigned num_cluster, unsigned feat_dim) {
|
||||
centroids.Init(num_cluster, feat_dim);
|
||||
@@ -153,7 +185,7 @@ int main(int argc, char *argv[]) {
|
||||
}
|
||||
}
|
||||
model.Normalize();
|
||||
rabit::CheckPoint(&model);
|
||||
rabit::LazyCheckPoint(&model);
|
||||
}
|
||||
// output the model file to somewhere
|
||||
if (rabit::GetRank() == 0) {
|
||||
|
||||
@@ -75,7 +75,7 @@ class LinearObjFunction : public solver::IObjFunction<float> {
|
||||
printf("Finishing writing to %s\n", name_pred.c_str());
|
||||
}
|
||||
inline void LoadModel(const char *fname) {
|
||||
IStream *fi = io::CreateStream(fname, "r");
|
||||
Stream *fi = io::CreateStream(fname, "r");
|
||||
std::string header; header.resize(4);
|
||||
// check header for different binary encode
|
||||
// can be base64 or binary
|
||||
@@ -84,9 +84,9 @@ class LinearObjFunction : public solver::IObjFunction<float> {
|
||||
if (header == "bs64") {
|
||||
io::Base64InStream bsin(fi);
|
||||
bsin.InitPosition();
|
||||
model.Load(bsin);
|
||||
model.Load(&bsin);
|
||||
} else if (header == "binf") {
|
||||
model.Load(*fi);
|
||||
model.Load(fi);
|
||||
} else {
|
||||
utils::Error("invalid model file");
|
||||
}
|
||||
@@ -95,15 +95,15 @@ class LinearObjFunction : public solver::IObjFunction<float> {
|
||||
inline void SaveModel(const char *fname,
|
||||
const float *wptr,
|
||||
bool save_base64 = false) {
|
||||
IStream *fo = io::CreateStream(fname, "w");
|
||||
Stream *fo = io::CreateStream(fname, "w");
|
||||
if (save_base64 != 0 || !strcmp(fname, "stdout")) {
|
||||
fo->Write("bs64\t", 5);
|
||||
io::Base64OutStream bout(fo);
|
||||
model.Save(bout, wptr);
|
||||
model.Save(&bout, wptr);
|
||||
bout.Finish('\n');
|
||||
} else {
|
||||
fo->Write("binf", 4);
|
||||
model.Save(*fo, wptr);
|
||||
model.Save(fo, wptr);
|
||||
}
|
||||
delete fo;
|
||||
}
|
||||
@@ -128,11 +128,11 @@ class LinearObjFunction : public solver::IObjFunction<float> {
|
||||
}
|
||||
}
|
||||
// load model
|
||||
virtual void Load(rabit::IStream &fi) {
|
||||
fi.Read(&model.param, sizeof(model.param));
|
||||
virtual void Load(rabit::Stream *fi) {
|
||||
fi->Read(&model.param, sizeof(model.param));
|
||||
}
|
||||
virtual void Save(rabit::IStream &fo) const {
|
||||
fo.Write(&model.param, sizeof(model.param));
|
||||
virtual void Save(rabit::Stream *fo) const {
|
||||
fo->Write(&model.param, sizeof(model.param));
|
||||
}
|
||||
virtual double Eval(const float *weight, size_t size) {
|
||||
if (nthread != 0) omp_set_num_threads(nthread);
|
||||
|
||||
@@ -113,17 +113,17 @@ struct LinearModel {
|
||||
if (weight != NULL) delete [] weight;
|
||||
}
|
||||
// load model
|
||||
inline void Load(rabit::IStream &fi) {
|
||||
fi.Read(&param, sizeof(param));
|
||||
inline void Load(rabit::Stream *fi) {
|
||||
fi->Read(&param, sizeof(param));
|
||||
if (weight == NULL) {
|
||||
weight = new float[param.num_feature + 1];
|
||||
}
|
||||
fi.Read(weight, sizeof(float) * (param.num_feature + 1));
|
||||
fi->Read(weight, sizeof(float) * (param.num_feature + 1));
|
||||
}
|
||||
inline void Save(rabit::IStream &fo, const float *wptr = NULL) {
|
||||
fo.Write(&param, sizeof(param));
|
||||
inline void Save(rabit::Stream *fo, const float *wptr = NULL) {
|
||||
fo->Write(&param, sizeof(param));
|
||||
if (wptr == NULL) wptr = weight;
|
||||
fo.Write(wptr, sizeof(float) * (param.num_feature + 1));
|
||||
fo->Write(wptr, sizeof(float) * (param.num_feature + 1));
|
||||
}
|
||||
inline float Predict(const SparseMat::Vector &v) const {
|
||||
return param.Predict(weight, v);
|
||||
|
||||
@@ -19,7 +19,7 @@ namespace solver {
|
||||
* to remember the state parameters that might need to remember
|
||||
*/
|
||||
template<typename DType>
|
||||
class IObjFunction : public rabit::ISerializable {
|
||||
class IObjFunction : public rabit::Serializable {
|
||||
public:
|
||||
// destructor
|
||||
virtual ~IObjFunction(void){}
|
||||
@@ -463,7 +463,7 @@ class LBFGSSolver {
|
||||
}
|
||||
}
|
||||
// global solver state
|
||||
struct GlobalState : public rabit::ISerializable {
|
||||
struct GlobalState : public rabit::Serializable {
|
||||
public:
|
||||
// memory size of L-BFGS
|
||||
size_t size_memory;
|
||||
@@ -514,28 +514,28 @@ class LBFGSSolver {
|
||||
MapIndex(j, offset_, size_memory)];
|
||||
}
|
||||
// load the shift array
|
||||
virtual void Load(rabit::IStream &fi) {
|
||||
fi.Read(&size_memory, sizeof(size_memory));
|
||||
fi.Read(&num_iteration, sizeof(num_iteration));
|
||||
fi.Read(&num_dim, sizeof(num_dim));
|
||||
fi.Read(&init_objval, sizeof(init_objval));
|
||||
fi.Read(&old_objval, sizeof(old_objval));
|
||||
fi.Read(&offset_, sizeof(offset_));
|
||||
fi.Read(&data);
|
||||
virtual void Load(rabit::Stream *fi) {
|
||||
fi->Read(&size_memory, sizeof(size_memory));
|
||||
fi->Read(&num_iteration, sizeof(num_iteration));
|
||||
fi->Read(&num_dim, sizeof(num_dim));
|
||||
fi->Read(&init_objval, sizeof(init_objval));
|
||||
fi->Read(&old_objval, sizeof(old_objval));
|
||||
fi->Read(&offset_, sizeof(offset_));
|
||||
fi->Read(&data);
|
||||
this->AllocSpace();
|
||||
fi.Read(weight, sizeof(DType) * num_dim);
|
||||
fi->Read(weight, sizeof(DType) * num_dim);
|
||||
obj->Load(fi);
|
||||
}
|
||||
// save the shift array
|
||||
virtual void Save(rabit::IStream &fo) const {
|
||||
fo.Write(&size_memory, sizeof(size_memory));
|
||||
fo.Write(&num_iteration, sizeof(num_iteration));
|
||||
fo.Write(&num_dim, sizeof(num_dim));
|
||||
fo.Write(&init_objval, sizeof(init_objval));
|
||||
fo.Write(&old_objval, sizeof(old_objval));
|
||||
fo.Write(&offset_, sizeof(offset_));
|
||||
fo.Write(data);
|
||||
fo.Write(weight, sizeof(DType) * num_dim);
|
||||
virtual void Save(rabit::Stream *fo) const {
|
||||
fo->Write(&size_memory, sizeof(size_memory));
|
||||
fo->Write(&num_iteration, sizeof(num_iteration));
|
||||
fo->Write(&num_dim, sizeof(num_dim));
|
||||
fo->Write(&init_objval, sizeof(init_objval));
|
||||
fo->Write(&old_objval, sizeof(old_objval));
|
||||
fo->Write(&offset_, sizeof(offset_));
|
||||
fo->Write(data);
|
||||
fo->Write(weight, sizeof(DType) * num_dim);
|
||||
obj->Save(fo);
|
||||
}
|
||||
inline void Shift(void) {
|
||||
@@ -556,7 +556,7 @@ class LBFGSSolver {
|
||||
}
|
||||
};
|
||||
/*! \brief rolling array that carries history information */
|
||||
struct HistoryArray : public rabit::ISerializable {
|
||||
struct HistoryArray : public rabit::Serializable {
|
||||
public:
|
||||
HistoryArray(void) : dptr_(NULL) {
|
||||
num_useful_ = 0;
|
||||
@@ -609,26 +609,26 @@ class LBFGSSolver {
|
||||
num_useful_ = num_useful;
|
||||
}
|
||||
// load the shift array
|
||||
virtual void Load(rabit::IStream &fi) {
|
||||
fi.Read(&num_col_, sizeof(num_col_));
|
||||
fi.Read(&stride_, sizeof(stride_));
|
||||
fi.Read(&size_memory_, sizeof(size_memory_));
|
||||
fi.Read(&num_useful_, sizeof(num_useful_));
|
||||
virtual void Load(rabit::Stream *fi) {
|
||||
fi->Read(&num_col_, sizeof(num_col_));
|
||||
fi->Read(&stride_, sizeof(stride_));
|
||||
fi->Read(&size_memory_, sizeof(size_memory_));
|
||||
fi->Read(&num_useful_, sizeof(num_useful_));
|
||||
this->Init(num_col_, size_memory_);
|
||||
for (size_t i = 0; i < num_useful_; ++i) {
|
||||
fi.Read((*this)[i], num_col_ * sizeof(DType));
|
||||
fi.Read((*this)[i + size_memory_], num_col_ * sizeof(DType));
|
||||
fi->Read((*this)[i], num_col_ * sizeof(DType));
|
||||
fi->Read((*this)[i + size_memory_], num_col_ * sizeof(DType));
|
||||
}
|
||||
}
|
||||
// save the shift array
|
||||
virtual void Save(rabit::IStream &fi) const {
|
||||
fi.Write(&num_col_, sizeof(num_col_));
|
||||
fi.Write(&stride_, sizeof(stride_));
|
||||
fi.Write(&size_memory_, sizeof(size_memory_));
|
||||
fi.Write(&num_useful_, sizeof(num_useful_));
|
||||
virtual void Save(rabit::Stream *fo) const {
|
||||
fo->Write(&num_col_, sizeof(num_col_));
|
||||
fo->Write(&stride_, sizeof(stride_));
|
||||
fo->Write(&size_memory_, sizeof(size_memory_));
|
||||
fo->Write(&num_useful_, sizeof(num_useful_));
|
||||
for (size_t i = 0; i < num_useful_; ++i) {
|
||||
fi.Write((*this)[i], num_col_ * sizeof(DType));
|
||||
fi.Write((*this)[i + size_memory_], num_col_ * sizeof(DType));
|
||||
fo->Write((*this)[i], num_col_ * sizeof(DType));
|
||||
fo->Write((*this)[i + size_memory_], num_col_ * sizeof(DType));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -93,43 +93,6 @@ struct SparseMat {
|
||||
std::vector<float> labels;
|
||||
};
|
||||
|
||||
// dense matrix
|
||||
struct Matrix {
|
||||
inline void Init(size_t nrow, size_t ncol, float v = 0.0f) {
|
||||
this->nrow = nrow;
|
||||
this->ncol = ncol;
|
||||
data.resize(nrow * ncol);
|
||||
std::fill(data.begin(), data.end(), v);
|
||||
}
|
||||
inline float *operator[](size_t i) {
|
||||
return &data[0] + i * ncol;
|
||||
}
|
||||
inline const float *operator[](size_t i) const {
|
||||
return &data[0] + i * ncol;
|
||||
}
|
||||
inline void Print(const char *fname) {
|
||||
FILE *fo;
|
||||
if (!strcmp(fname, "stdout")) {
|
||||
fo = stdout;
|
||||
} else {
|
||||
fo = utils::FopenCheck(fname, "w");
|
||||
}
|
||||
for (size_t i = 0; i < data.size(); ++i) {
|
||||
fprintf(fo, "%g", data[i]);
|
||||
if ((i+1) % ncol == 0) {
|
||||
fprintf(fo, "\n");
|
||||
} else {
|
||||
fprintf(fo, " ");
|
||||
}
|
||||
}
|
||||
// close the filed
|
||||
if (fo != stdout) fclose(fo);
|
||||
}
|
||||
// number of data
|
||||
size_t nrow, ncol;
|
||||
std::vector<float> data;
|
||||
};
|
||||
|
||||
/*!\brief computes a random number modulo the value */
|
||||
inline int Random(int value) {
|
||||
return rand() % value;
|
||||
|
||||
Reference in New Issue
Block a user