introduce input split
This commit is contained in:
2
rabit-learn/linear/.gitignore
vendored
Normal file
2
rabit-learn/linear/.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
mushroom.row*
|
||||
*.model
|
||||
@@ -1,6 +1,6 @@
|
||||
#include "./linear.h"
|
||||
#include "../utils/io.h"
|
||||
#include "../utils/base64.h"
|
||||
#include "../io/io.h"
|
||||
#include "../io/base64.h"
|
||||
|
||||
namespace rabit {
|
||||
namespace linear {
|
||||
@@ -74,23 +74,20 @@ class LinearObjFunction : public solver::IObjFunction<float> {
|
||||
printf("Finishing writing to %s\n", name_pred.c_str());
|
||||
}
|
||||
inline void LoadModel(const char *fname) {
|
||||
FILE *fp = utils::FopenCheck(fname, "rb");
|
||||
io::FileStream fi(fname, "rb");
|
||||
std::string header; header.resize(4);
|
||||
// check header for different binary encode
|
||||
// can be base64 or binary
|
||||
utils::FileStream fi(fp);
|
||||
utils::Check(fi.Read(&header[0], 4) != 0, "invalid model");
|
||||
// base64 format
|
||||
// base64 format
|
||||
if (header == "bs64") {
|
||||
utils::Base64InStream bsin(fp);
|
||||
io::Base64InStream bsin(&fi);
|
||||
bsin.InitPosition();
|
||||
model.Load(bsin);
|
||||
fclose(fp);
|
||||
return;
|
||||
} else if (header == "binf") {
|
||||
model.Load(fi);
|
||||
fclose(fp);
|
||||
return;
|
||||
return;
|
||||
} else {
|
||||
utils::Error("invalid model file");
|
||||
}
|
||||
@@ -98,27 +95,16 @@ class LinearObjFunction : public solver::IObjFunction<float> {
|
||||
inline void SaveModel(const char *fname,
|
||||
const float *wptr,
|
||||
bool save_base64 = false) {
|
||||
FILE *fp;
|
||||
bool use_stdout = false;
|
||||
if (!strcmp(fname, "stdout")) {
|
||||
fp = stdout;
|
||||
use_stdout = true;
|
||||
} else {
|
||||
fp = utils::FopenCheck(fname, "wb");
|
||||
}
|
||||
utils::FileStream fo(fp);
|
||||
if (save_base64 != 0|| use_stdout) {
|
||||
io::FileStream fo(fname, "wb");
|
||||
if (save_base64 != 0 || !strcmp(fname, "stdout")) {
|
||||
fo.Write("bs64\t", 5);
|
||||
utils::Base64OutStream bout(fp);
|
||||
io::Base64OutStream bout(&fo);
|
||||
model.Save(bout, wptr);
|
||||
bout.Finish('\n');
|
||||
} else {
|
||||
fo.Write("binf", 4);
|
||||
model.Save(fo, wptr);
|
||||
}
|
||||
if (!use_stdout) {
|
||||
fclose(fp);
|
||||
}
|
||||
}
|
||||
inline void LoadData(const char *fname) {
|
||||
dtrain.Load(fname);
|
||||
|
||||
@@ -5,11 +5,7 @@ then
|
||||
exit -1
|
||||
fi
|
||||
|
||||
rm -rf mushroom.row* *.model
|
||||
rm -rf *.model
|
||||
k=$1
|
||||
|
||||
# split the lib svm file into k subfiles
|
||||
python splitrows.py ../data/agaricus.txt.train mushroom $k
|
||||
|
||||
# run xgboost mpi
|
||||
../../tracker/rabit_demo.py -n $k linear.mock mushroom.row\%d "${*:2}" reg_L1=1 mock=0,1,1,0 mock=1,1,1,0 mock=0,2,1,1
|
||||
../../tracker/rabit_demo.py -n $k linear.mock ../data/agaricus.txt.train "${*:2}" reg_L1=1 mock=0,1,1,0 mock=1,1,1,0 mock=0,2,1,1
|
||||
|
||||
@@ -5,13 +5,10 @@ then
|
||||
exit -1
|
||||
fi
|
||||
|
||||
rm -rf mushroom.row* *.model
|
||||
rm -rf *.model
|
||||
k=$1
|
||||
|
||||
# split the lib svm file into k subfiles
|
||||
python splitrows.py ../data/agaricus.txt.train mushroom $k
|
||||
|
||||
# run xgboost mpi
|
||||
../../tracker/rabit_demo.py -n $k linear.rabit mushroom.row\%d "${*:2}" reg_L1=1
|
||||
# run linear model, the program will automatically split the inputs
|
||||
../../tracker/rabit_demo.py -n $k linear.rabit ../data/agaricus.txt.train reg_L1=1
|
||||
|
||||
./linear.rabit ../data/agaricus.txt.test task=pred model_in=final.model
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
#!/usr/bin/python
|
||||
import sys
|
||||
import random
|
||||
|
||||
# split libsvm file into different rows
|
||||
if len(sys.argv) < 4:
|
||||
print ('Usage:<fin> <fo> k')
|
||||
exit(0)
|
||||
|
||||
random.seed(10)
|
||||
|
||||
k = int(sys.argv[3])
|
||||
fi = open( sys.argv[1], 'r' )
|
||||
fos = []
|
||||
|
||||
for i in range(k):
|
||||
fos.append(open( sys.argv[2]+'.row%d' % i, 'w' ))
|
||||
|
||||
for l in open(sys.argv[1]):
|
||||
i = random.randint(0, k-1)
|
||||
fos[i].write(l)
|
||||
|
||||
for f in fos:
|
||||
f.close()
|
||||
Reference in New Issue
Block a user