move stream to rabit part, support rabit on yarn
This commit is contained in:
@@ -14,10 +14,11 @@
|
||||
|
||||
namespace xgboost {
|
||||
namespace io {
|
||||
DataMatrix* LoadDataMatrix(const char *fname, bool silent, bool savebuffer) {
|
||||
if (!std::strcmp(fname, "stdin")) {
|
||||
DataMatrix* LoadDataMatrix(const char *fname, bool silent,
|
||||
bool savebuffer, bool loadsplit) {
|
||||
if (!std::strcmp(fname, "stdin") || loadsplit) {
|
||||
DMatrixSimple *dmat = new DMatrixSimple();
|
||||
dmat->LoadText(fname, silent);
|
||||
dmat->LoadText(fname, silent, loadsplit);
|
||||
return dmat;
|
||||
}
|
||||
int magic;
|
||||
|
||||
@@ -19,9 +19,14 @@ typedef learner::DMatrix DataMatrix;
|
||||
* \param fname file name to be loaded
|
||||
* \param silent whether print message during loading
|
||||
* \param savebuffer whether temporal buffer the file if the file is in text format
|
||||
* \param loadsplit whether we only load a split of input files
|
||||
* such that each worker node get a split of the data
|
||||
* \return a loaded DMatrix
|
||||
*/
|
||||
DataMatrix* LoadDataMatrix(const char *fname, bool silent = false, bool savebuffer = true);
|
||||
DataMatrix* LoadDataMatrix(const char *fname,
|
||||
bool silent,
|
||||
bool savebuffer,
|
||||
bool loadsplit);
|
||||
/*!
|
||||
* \brief save DataMatrix into stream,
|
||||
* note: the saved dmatrix format may not be in exactly same as input
|
||||
|
||||
@@ -11,12 +11,14 @@
|
||||
#include <string>
|
||||
#include <cstring>
|
||||
#include <vector>
|
||||
#include <sstream>
|
||||
#include <algorithm>
|
||||
#include "../data.h"
|
||||
#include "../utils/utils.h"
|
||||
#include "../learner/dmatrix.h"
|
||||
#include "./io.h"
|
||||
#include "./simple_fmatrix-inl.hpp"
|
||||
#include "../sync/sync.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace io {
|
||||
@@ -77,63 +79,59 @@ class DMatrixSimple : public DataMatrix {
|
||||
return row_ptr_.size() - 2;
|
||||
}
|
||||
/*!
|
||||
* \brief load from text file
|
||||
* \param fname name of text data
|
||||
* \brief load split of input, used in distributed mode
|
||||
* \param uri the uri of input
|
||||
* \param loadsplit whether loadsplit of data or all the data
|
||||
* \param silent whether print information or not
|
||||
*/
|
||||
inline void LoadText(const char* fname, bool silent = false) {
|
||||
using namespace std;
|
||||
inline void LoadText(const char *uri, bool silent = false, bool loadsplit = false) {
|
||||
int rank = 0, npart = 1;
|
||||
if (loadsplit) {
|
||||
rank = rabit::GetRank();
|
||||
npart = rabit::GetWorldSize();
|
||||
}
|
||||
rabit::io::InputSplit *in =
|
||||
rabit::io::CreateInputSplit(uri, rank, npart);
|
||||
this->Clear();
|
||||
FILE* file;
|
||||
if (!strcmp(fname, "stdin")) {
|
||||
file = stdin;
|
||||
} else {
|
||||
file = utils::FopenCheck(fname, "r");
|
||||
}
|
||||
float label; bool init = true;
|
||||
char tmp[1024];
|
||||
std::vector<RowBatch::Entry> feats;
|
||||
while (fscanf(file, "%s", tmp) == 1) {
|
||||
RowBatch::Entry e;
|
||||
if (sscanf(tmp, "%u:%f", &e.index, &e.fvalue) == 2) {
|
||||
std::string line;
|
||||
while (in->NextLine(&line)) {
|
||||
float label;
|
||||
std::istringstream ss(line);
|
||||
std::vector<RowBatch::Entry> feats;
|
||||
ss >> label;
|
||||
while (!ss.eof()) {
|
||||
RowBatch::Entry e;
|
||||
if (!(ss >> e.index)) break;
|
||||
ss.ignore(32, ':');
|
||||
if (!(ss >> e.fvalue)) break;
|
||||
feats.push_back(e);
|
||||
} else {
|
||||
if (!init) {
|
||||
info.labels.push_back(label);
|
||||
this->AddRow(feats);
|
||||
}
|
||||
feats.clear();
|
||||
utils::Check(sscanf(tmp, "%f", &label) == 1, "invalid LibSVM format");
|
||||
init = false;
|
||||
}
|
||||
info.labels.push_back(label);
|
||||
this->AddRow(feats);
|
||||
}
|
||||
|
||||
info.labels.push_back(label);
|
||||
this->AddRow(feats);
|
||||
|
||||
delete in;
|
||||
if (!silent) {
|
||||
utils::Printf("%lux%lu matrix with %lu entries is loaded from %s\n",
|
||||
static_cast<unsigned long>(info.num_row()),
|
||||
static_cast<unsigned long>(info.num_col()),
|
||||
static_cast<unsigned long>(row_data_.size()), fname);
|
||||
}
|
||||
if (file != stdin) {
|
||||
fclose(file);
|
||||
static_cast<unsigned long>(row_data_.size()), uri);
|
||||
}
|
||||
// try to load in additional file
|
||||
std::string name = fname;
|
||||
std::string gname = name + ".group";
|
||||
if (info.TryLoadGroup(gname.c_str(), silent)) {
|
||||
utils::Check(info.group_ptr.back() == info.num_row(),
|
||||
"DMatrix: group data does not match the number of rows in features");
|
||||
}
|
||||
std::string wname = name + ".weight";
|
||||
if (info.TryLoadFloatInfo("weight", wname.c_str(), silent)) {
|
||||
utils::Check(info.weights.size() == info.num_row(),
|
||||
"DMatrix: weight data does not match the number of rows in features");
|
||||
}
|
||||
std::string mname = name + ".base_margin";
|
||||
if (info.TryLoadFloatInfo("base_margin", mname.c_str(), silent)) {
|
||||
if (!loadsplit) {
|
||||
std::string name = uri;
|
||||
std::string gname = name + ".group";
|
||||
if (info.TryLoadGroup(gname.c_str(), silent)) {
|
||||
utils::Check(info.group_ptr.back() == info.num_row(),
|
||||
"DMatrix: group data does not match the number of rows in features");
|
||||
}
|
||||
std::string wname = name + ".weight";
|
||||
if (info.TryLoadFloatInfo("weight", wname.c_str(), silent)) {
|
||||
utils::Check(info.weights.size() == info.num_row(),
|
||||
"DMatrix: weight data does not match the number of rows in features");
|
||||
}
|
||||
std::string mname = name + ".base_margin";
|
||||
if (info.TryLoadFloatInfo("base_margin", mname.c_str(), silent)) {
|
||||
}
|
||||
}
|
||||
}
|
||||
/*!
|
||||
|
||||
Reference in New Issue
Block a user