make wrapper ok

This commit is contained in:
tqchen
2014-11-23 14:03:59 -08:00
parent 69b2f31098
commit 5f08313cb2
15 changed files with 160 additions and 24 deletions

View File

@@ -13,6 +13,11 @@
namespace xgboost {
namespace io {
DataMatrix* LoadDataMatrix(const char *fname, bool silent, bool savebuffer) {
if (!strcmp(fname, "stdin")) {
DMatrixSimple *dmat = new DMatrixSimple();
dmat->LoadText(fname, silent);
return dmat;
}
std::string tmp_fname;
const char *fname_ext = NULL;
if (strchr(fname, ';') != NULL) {

View File

@@ -84,7 +84,12 @@ class DMatrixSimple : public DataMatrix {
inline void LoadText(const char* fname, bool silent = false) {
using namespace std;
this->Clear();
FILE* file = utils::FopenCheck(fname, "r");
FILE* file;
if (!strcmp(fname, "stdin")) {
file = stdin;
} else {
file = utils::FopenCheck(fname, "r");
}
float label; bool init = true;
char tmp[1024];
std::vector<RowBatch::Entry> feats;
@@ -112,7 +117,9 @@ class DMatrixSimple : public DataMatrix {
static_cast<unsigned long>(info.num_col()),
static_cast<unsigned long>(row_data_.size()), fname);
}
fclose(file);
if (file != stdin) {
fclose(file);
}
// try to load in additional file
std::string name = fname;
std::string gname = name + ".group";

View File

@@ -352,7 +352,7 @@ class SyncManager {
buffer_.resize(std::min(reduce_buffer_size, n));
// make sure align to type_nbytes
buffer_size = buffer_.size() * sizeof(uint64_t) / type_nbytes * type_nbytes;
utils::Assert(type_nbytes < buffer_size, "too large type_nbytes=%lu, buffer_size", type_nbytes, buffer_size);
utils::Assert(type_nbytes <= buffer_size, "too large type_nbytes=%lu, buffer_size=%lu", type_nbytes, buffer_size);
// set buffer head
buffer_head = reinterpret_cast<char*>(BeginPtr(buffer_));
}
@@ -487,6 +487,8 @@ void AllReduce<uint32_t>(uint32_t *sendrecvbuf, int count, ReduceOp op) {
typedef uint32_t DType;
switch(op) {
case kBitwiseOR: manager.AllReduce(sendrecvbuf, sizeof(DType), count, ReduceBitOR<DType>); return;
case kSum: manager.AllReduce(sendrecvbuf, sizeof(DType), count, ReduceSum<DType>); return;
case kMax: manager.AllReduce(sendrecvbuf, sizeof(DType), count, ReduceMax<DType>); return;
default: utils::Error("reduce op not supported");
}
}