move stream to rabit part, support rabit on yarn

This commit is contained in:
tqchen
2015-03-09 14:43:46 -07:00
parent 9f7c6fe271
commit a8d5af39fd
14 changed files with 134 additions and 500 deletions

View File

@@ -14,10 +14,11 @@
namespace xgboost {
namespace io {
DataMatrix* LoadDataMatrix(const char *fname, bool silent, bool savebuffer) {
if (!std::strcmp(fname, "stdin")) {
DataMatrix* LoadDataMatrix(const char *fname, bool silent,
bool savebuffer, bool loadsplit) {
if (!std::strcmp(fname, "stdin") || loadsplit) {
DMatrixSimple *dmat = new DMatrixSimple();
dmat->LoadText(fname, silent);
dmat->LoadText(fname, silent, loadsplit);
return dmat;
}
int magic;

View File

@@ -19,9 +19,14 @@ typedef learner::DMatrix DataMatrix;
* \param fname file name to be loaded
* \param silent whether to print messages during loading
* \param savebuffer whether to temporarily buffer the file if the file is in text format
* \param loadsplit whether we only load a split of the input files,
such that each worker node gets a split of the data
* \return a loaded DMatrix
*/
// NOTE(review): the one-line declaration supplies defaults for `silent` and
// `savebuffer`; the four-parameter declaration supplies no defaults, so every
// call made through it has to pass silent, savebuffer and loadsplit explicitly.
DataMatrix* LoadDataMatrix(const char *fname, bool silent = false, bool savebuffer = true);
DataMatrix* LoadDataMatrix(const char *fname,
bool silent,
bool savebuffer,
bool loadsplit);
/*!
* \brief save DataMatrix into stream,
* note: the saved dmatrix format may not be in exactly same as input

View File

@@ -11,12 +11,14 @@
#include <string>
#include <cstring>
#include <vector>
#include <sstream>
#include <algorithm>
#include "../data.h"
#include "../utils/utils.h"
#include "../learner/dmatrix.h"
#include "./io.h"
#include "./simple_fmatrix-inl.hpp"
#include "../sync/sync.h"
namespace xgboost {
namespace io {
@@ -77,63 +79,59 @@ class DMatrixSimple : public DataMatrix {
return row_ptr_.size() - 2;
}
/*!
* \brief load from text file
* \param fname name of text data
* \brief load split of input, used in distributed mode
* \param uri the uri of input
* \param loadsplit whether to load only a split of the data or all the data
* \param silent whether to print information or not
*/
// NOTE(review): this listing interleaves two variants of LoadText without
// change markers: a FILE*/fscanf reader that takes `fname`, and an
// InputSplit/istringstream reader that takes `uri` plus `loadsplit`.
// The comments below say which reader a group of lines belongs to.
inline void LoadText(const char* fname, bool silent = false) {
using namespace std;
inline void LoadText(const char *uri, bool silent = false, bool loadsplit = false) {
// Default: read the input as partition 0 of 1. With loadsplit set, read the
// partition indexed by this process's rabit rank out of GetWorldSize() parts.
int rank = 0, npart = 1;
if (loadsplit) {
rank = rabit::GetRank();
npart = rabit::GetWorldSize();
}
// Line source for partition `rank` of `npart`; it is released with
// `delete in` once the read loop has finished.
rabit::io::InputSplit *in =
rabit::io::CreateInputSplit(uri, rank, npart);
// Clear() runs before any row is added (presumably resets existing contents -- confirm).
this->Clear();
// FILE*-based reader: the literal name "stdin" selects standard input,
// any other name is opened for reading through utils::FopenCheck.
FILE* file;
if (!strcmp(fname, "stdin")) {
file = stdin;
} else {
file = utils::FopenCheck(fname, "r");
}
float label; bool init = true;
char tmp[1024];
std::vector<RowBatch::Entry> feats;
// FILE*-based reader: consume whitespace-delimited tokens. A token that
// matches index:value is a feature of the current row; any other token must
// parse as a float label and starts a new row.
// NOTE(review): "%s" has no field width, so a token of 1024 or more
// characters overruns `tmp`.
while (fscanf(file, "%s", tmp) == 1) {
RowBatch::Entry e;
if (sscanf(tmp, "%u:%f", &e.index, &e.fvalue) == 2) {
// InputSplit-based reader: one record per line, a label followed by
// index:value pairs.
std::string line;
while (in->NextLine(&line)) {
float label;
std::istringstream ss(line);
std::vector<RowBatch::Entry> feats;
// NOTE(review): `label` has no initializer and this extraction is not
// checked, so a blank or malformed line is still appended as a row
// (label 0 under C++11 extraction rules, indeterminate before that).
ss >> label;
while (!ss.eof()) {
RowBatch::Entry e;
if (!(ss >> e.index)) break;
// Discard up to 32 characters, stopping after the ':' that separates
// the feature index from its value.
ss.ignore(32, ':');
if (!(ss >> e.fvalue)) break;
feats.push_back(e);
} else {
// FILE*-based reader: a label token closes the previous row (if there was
// one) before the new label is parsed.
if (!init) {
info.labels.push_back(label);
this->AddRow(feats);
}
feats.clear();
utils::Check(sscanf(tmp, "%f", &label) == 1, "invalid LibSVM format");
init = false;
}
info.labels.push_back(label);
this->AddRow(feats);
}
// FILE*-based reader: appears to flush the last pending row once the token
// loop has ended.
info.labels.push_back(label);
this->AddRow(feats);
delete in;
if (!silent) {
utils::Printf("%lux%lu matrix with %lu entries is loaded from %s\n",
static_cast<unsigned long>(info.num_row()),
static_cast<unsigned long>(info.num_col()),
static_cast<unsigned long>(row_data_.size()), fname);
}
if (file != stdin) {
fclose(file);
static_cast<unsigned long>(row_data_.size()), uri);
}
// try to load in additional file
std::string name = fname;
std::string gname = name + ".group";
if (info.TryLoadGroup(gname.c_str(), silent)) {
utils::Check(info.group_ptr.back() == info.num_row(),
"DMatrix: group data does not match the number of rows in features");
}
std::string wname = name + ".weight";
if (info.TryLoadFloatInfo("weight", wname.c_str(), silent)) {
utils::Check(info.weights.size() == info.num_row(),
"DMatrix: weight data does not match the number of rows in features");
}
std::string mname = name + ".base_margin";
if (info.TryLoadFloatInfo("base_margin", mname.c_str(), silent)) {
// InputSplit-based reader: the .group / .weight / .base_margin side files
// are looked up next to `uri` only when loadsplit is false.
if (!loadsplit) {
std::string name = uri;
std::string gname = name + ".group";
if (info.TryLoadGroup(gname.c_str(), silent)) {
utils::Check(info.group_ptr.back() == info.num_row(),
"DMatrix: group data does not match the number of rows in features");
}
std::string wname = name + ".weight";
if (info.TryLoadFloatInfo("weight", wname.c_str(), silent)) {
utils::Check(info.weights.size() == info.num_row(),
"DMatrix: weight data does not match the number of rows in features");
}
std::string mname = name + ".base_margin";
// Unlike group and weight, no size check follows a successful
// base_margin load: the branch body is empty.
if (info.TryLoadFloatInfo("base_margin", mname.c_str(), silent)) {
}
}
}
/*!