Merge remote branch 'src/master'

Vadim Khotilovich 2015-04-07 17:16:19 -05:00
commit 0405676734
43 changed files with 815 additions and 438 deletions

View File

@@ -16,18 +16,28 @@ ifeq ($(cxx11),1)
 else
 endif
-ifeq ($(hdfs),1)
-	CFLAGS+= -DRABIT_USE_HDFS=1 -I$(HADOOP_HDFS_HOME)/include -I$(JAVA_HOME)/include
-	LDFLAGS+= -L$(HADOOP_HDFS_HOME)/lib/native -L$(JAVA_HOME)/jre/lib/amd64/server -lhdfs -ljvm
+# handling dmlc
+ifdef dmlc
+	ifndef config
+		ifneq ("$(wildcard $(dmlc)/config.mk)","")
+			config = $(dmlc)/config.mk
+		else
+			config = $(dmlc)/make/config.mk
+		endif
+	endif
+	include $(config)
+	include $(dmlc)/make/dmlc.mk
+	LDFLAGS+= $(DMLC_LDFLAGS)
+	LIBDMLC=$(dmlc)/libdmlc.a
 else
-	CFLAGS+= -DRABIT_USE_HDFS=0
+	LIBDMLC=dmlc_simple.o
 endif
 # specify tensor path
 BIN = xgboost
 MOCKBIN = xgboost.mock
-OBJ = updater.o gbm.o io.o main.o
-MPIBIN = xgboost.mpi
+OBJ = updater.o gbm.o io.o main.o dmlc_simple.o
+MPIBIN =
 SLIB = wrapper/libxgboostwrapper.so
 .PHONY: clean all mpi python Rpack
@@ -38,23 +48,22 @@ mpi: $(MPIBIN)
 python: wrapper/libxgboostwrapper.so
 # now the wrapper takes in two files. io and wrapper part
 updater.o: src/tree/updater.cpp src/tree/*.hpp src/*.h src/tree/*.h src/utils/*.h
+dmlc_simple.o: src/io/dmlc_simple.cpp src/utils/*.h
 gbm.o: src/gbm/gbm.cpp src/gbm/*.hpp src/gbm/*.h
 io.o: src/io/io.cpp src/io/*.hpp src/utils/*.h src/learner/dmatrix.h src/*.h
 main.o: src/xgboost_main.cpp src/utils/*.h src/*.h src/learner/*.hpp src/learner/*.h
-xgboost.mpi: updater.o gbm.o io.o main.o subtree/rabit/lib/librabit_mpi.a
-xgboost.mock: updater.o gbm.o io.o main.o subtree/rabit/lib/librabit_mock.a
-xgboost: updater.o gbm.o io.o main.o subtree/rabit/lib/librabit.a
-wrapper/libxgboostwrapper.so: wrapper/xgboost_wrapper.cpp src/utils/*.h src/*.h src/learner/*.hpp src/learner/*.h updater.o gbm.o io.o subtree/rabit/lib/librabit.a
+xgboost: updater.o gbm.o io.o main.o subtree/rabit/lib/librabit.a $(LIBDMLC)
+wrapper/libxgboostwrapper.so: wrapper/xgboost_wrapper.cpp src/utils/*.h src/*.h src/learner/*.hpp src/learner/*.h updater.o gbm.o io.o subtree/rabit/lib/librabit.a $(LIBDMLC)
 # dependency on rabit
 subtree/rabit/lib/librabit.a: subtree/rabit/src/engine.cc
 	cd subtree/rabit;make lib/librabit.a; cd ../..
 subtree/rabit/lib/librabit_empty.a: subtree/rabit/src/engine_empty.cc
 	cd subtree/rabit;make lib/librabit_empty.a; cd ../..
 subtree/rabit/lib/librabit_mock.a: subtree/rabit/src/engine_mock.cc
 	cd subtree/rabit;make lib/librabit_mock.a; cd ../..
 subtree/rabit/lib/librabit_mpi.a: subtree/rabit/src/engine_mpi.cc
 	cd subtree/rabit;make lib/librabit_mpi.a; cd ../..
 $(BIN) :
 	$(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc %.a, $^) $(LDFLAGS)

View File

@@ -4,4 +4,5 @@ PKGROOT=../../
 PKG_CPPFLAGS= -DXGBOOST_CUSTOMIZE_MSG_ -DXGBOOST_CUSTOMIZE_PRNG_ -DXGBOOST_STRICT_CXX98_ -DRABIT_CUSTOMIZE_MSG_ -DRABIT_STRICT_CXX98_ -I$(PKGROOT)
 PKG_CXXFLAGS= $(SHLIB_OPENMP_CFLAGS)
 PKG_LIBS = $(SHLIB_OPENMP_CFLAGS)
-OBJECTS= xgboost_R.o xgboost_assert.o $(PKGROOT)/wrapper/xgboost_wrapper.o $(PKGROOT)/src/io/io.o $(PKGROOT)/src/gbm/gbm.o $(PKGROOT)/src/tree/updater.o $(PKGROOT)/subtree/rabit/src/engine_empty.o
+OBJECTS= xgboost_R.o xgboost_assert.o $(PKGROOT)/wrapper/xgboost_wrapper.o $(PKGROOT)/src/io/io.o $(PKGROOT)/src/gbm/gbm.o $(PKGROOT)/src/tree/updater.o $(PKGROOT)/subtree/rabit/src/engine_empty.o $(PKGROOT)/src/io/dmlc_simple.o

View File

@@ -15,5 +15,5 @@ xgblib:
 PKG_CPPFLAGS= -DXGBOOST_CUSTOMIZE_MSG_ -DXGBOOST_CUSTOMIZE_PRNG_ -DXGBOOST_STRICT_CXX98_ -DRABIT_CUSTOMIZE_MSG_ -DRABIT_STRICT_CXX98_ -I$(PKGROOT) -I../..
 PKG_CXXFLAGS= $(SHLIB_OPENMP_CFLAGS)
 PKG_LIBS = $(SHLIB_OPENMP_CFLAGS)
-OBJECTS= xgboost_R.o xgboost_assert.o $(PKGROOT)/wrapper/xgboost_wrapper.o $(PKGROOT)/src/io/io.o $(PKGROOT)/src/gbm/gbm.o $(PKGROOT)/src/tree/updater.o $(PKGROOT)/subtree/rabit/src/engine_empty.o
+OBJECTS= xgboost_R.o xgboost_assert.o $(PKGROOT)/wrapper/xgboost_wrapper.o $(PKGROOT)/src/io/io.o $(PKGROOT)/src/gbm/gbm.o $(PKGROOT)/src/tree/updater.o $(PKGROOT)/subtree/rabit/src/engine_empty.o $(PKGROOT)/src/io/dmlc_simple.o
 $(OBJECTS) : xgblib

View File

@@ -26,7 +26,8 @@ Learning about the model: [Introduction to Boosted Trees](http://homes.cs.washin
 What's New
 ==========
-* [Distributed XGBoost now runs on YARN](multi-node/hadoop)!
+* XGBoost now supports HDFS and S3
+* [Distributed XGBoost now runs on YARN](https://github.com/dmlc/wormhole/tree/master/learn/xgboost)!
 * [xgboost user group](https://groups.google.com/forum/#!forum/xgboost-user/) for tracking changes, sharing your experience on xgboost
 * [Distributed XGBoost](multi-node) is now available!!
 * New features in the latest changes :)
@@ -35,8 +36,6 @@ What's New
 - Predict leaf index, see [demo/guide-python/predict_leaf_indices.py](demo/guide-python/predict_leaf_indices.py)
 * XGBoost wins [Tradeshift Text Classification](https://kaggle2.blob.core.windows.net/forum-message-attachments/60041/1813/TradeshiftTextClassification.pdf?sv=2012-02-12&se=2015-01-02T13%3A55%3A16Z&sr=b&sp=r&sig=5MHvyjCLESLexYcvbSRFumGQXCS7MVmfdBIY3y01tMk%3D)
 * XGBoost wins [HEP meets ML Award in Higgs Boson Challenge](http://atlas.ch/news/2014/machine-learning-wins-the-higgs-challenge.html)
-* Thanks to Bing Xu, [XGBoost.jl](https://github.com/antinucleon/XGBoost.jl) allows you to use xgboost from Julia
-* Thanks to Tong He, the new [R package](R-package) is available
 Features
 ========
@@ -87,6 +86,16 @@ Build
 ```
 Then run ```bash build.sh``` normally. This solution is given by [Phil Culliton](https://www.kaggle.com/c/otto-group-product-classification-challenge/forums/t/12947/achieve-0-50776-on-the-leaderboard-in-a-minute-with-xgboost/68308#post68308).
+Build with HDFS and S3 Support
+=====
+* To build xgboost with HDFS/S3 support and distributed learning, it is recommended to build with dmlc, with the following steps
+  - ```git clone https://github.com/dmlc/dmlc-core```
+  - Follow the instructions in dmlc-core/make/config.mk to compile libdmlc.a
+  - In the root folder of xgboost, type ```make dmlc=dmlc-core```
+* This allows xgboost to directly load data and save models from/to HDFS and S3
+  - Simply prefix the filename with s3:// or hdfs://
+* This build of xgboost can also be used for distributed learning
 Version
 =======
 * This version is xgboost-0.3; the code has been refactored from 0.2x to be cleaner and more flexible

View File

@@ -1,17 +1,10 @@
 Distributed XGBoost
 ======
-This folder contains information of Distributed XGBoost (Distributed GBDT).
+Distributed XGBoost is now part of [Wormhole](https://github.com/dmlc/wormhole).
+Check out this [link](https://github.com/dmlc/wormhole/tree/master/learn/xgboost) for usage examples, builds, and job submission.
 * The distributed version is built on Rabit: [Reliable Allreduce and Broadcast Library](https://github.com/dmlc/rabit)
   - Rabit is a portable library that provides fault-tolerance for Allreduce calls for distributed machine learning
   - This makes xgboost portable and fault-tolerant against node failures
-* You can run Distributed XGBoost on platforms including Hadoop (see [hadoop folder](hadoop)) and MPI
-  - Rabit only replies a platform to start the programs, so it should be easy to port xgboost to most platforms
-Build
-=====
-* In the root folder, type ```make```
-  - If you have C++11 compiler, it is recommended to use ```make cxx11=1```
 Notes
 ====
@@ -27,11 +20,9 @@ Notes
 Solvers
 =====
-There are two solvers in distributed xgboost. You can check for local demo of the two solvers, see [row-split](row-split) and [col-split](col-split)
 * Column-based solver split data by column, each node work on subset of columns,
   it uses exactly the same algorithm as single node version.
 * Row-based solver split data by row, each node work on subset of rows,
   it uses an approximate histogram count algorithm, and will only examine subset of
   potential split points as opposed to all split points.
   - This is the mode used by current hadoop version, since usually data was stored by rows in many industry system

View File

@@ -1,40 +0,0 @@
Distributed XGBoost: Hadoop Yarn Version
====
* The script in this fold shows an example of how to run distributed xgboost on hadoop platform with YARN
* It relies on [Rabit Library](https://github.com/dmlc/rabit) (Reliable Allreduce and Broadcast Interface) and Yarn. Rabit provides an interface to aggregate gradient values and split statistics, that allow xgboost to run reliably on hadoop. You do not need to care how to update model in each iteration, just use the script ```rabit_yarn.py```. For those who want to know how it exactly works, plz refer to the main page of [Rabit](https://github.com/dmlc/rabit).
* Quick start: run ```bash run_mushroom.sh <n_hadoop_workers> <n_thread_per_worker> <path_in_HDFS>```
- This is the hadoop version of binary classification example in the demo folder.
- More info of the usage of xgboost can be refered to [wiki page](https://github.com/dmlc/xgboost/wiki)
Before you run the script
====
* Make sure you have set up the hadoop environment.
- Check variable $HADOOP_PREFIX exists (e.g. run ```echo $HADOOP_PREFIX```)
- Compile xgboost with hdfs support by typing ```make hdfs=1```
How to Use
====
* Input data format: LIBSVM format. The example here uses generated data in demo/data folder.
* Put the training data in HDFS (hadoop distributed file system).
* Use rabit ```rabit_yarn.py``` to submit training task to yarn
* Get the final model file from HDFS, and locally do prediction as well as visualization of model.
Single machine vs Hadoop version
====
If you have used xgboost (single machine version) before, this section will show you how to run xgboost on hadoop with a slight modification on conf file.
* IO: instead of reading and writing file locally, we now use HDFS, put ```hdfs://``` prefix to the address of file you like to access
* File cache: ```rabit_yarn.py``` also provide several ways to cache necesary files, including binary file (xgboost), conf file
- ```rabit_yarn.py``` will automatically cache files in the command line. For example, ```rabit_yarn.py -n 3 $localPath/xgboost mushroom.hadoop.conf``` will cache "xgboost" and "mushroom.hadoop.conf".
- You could also use "-f" to manually cache one or more files, like ```-f file1 -f file2``` or ```-f file1#file2``` (use "#" to spilt file names).
- The local path of cached files in command is "./".
- Since the cached files will be packaged and delivered to hadoop slave nodes, the cached file should not be large.
* Hadoop version also support evaluting each training round. You just need to modify parameters "eval_train".
* More details of submission can be referred to the usage of ```rabit_yarn.py```.
* The model saved by hadoop version is compatible with single machine version.
Notes
====
* The code has been tested on YARN.
* The code is optimized with multi-threading, so you will want to run one xgboost per node/worker for best performance.
- You will want to set <n_thread_per_worker> to be number of cores you have on each machine.
* It is also possible to submit job with hadoop streaming, however, YARN is highly recommended for efficiency reason

View File

@@ -1,36 +0,0 @@
# General Parameters, see comment for each definition
# choose the booster, can be gbtree or gblinear
booster = gbtree
# choose logistic regression loss function for binary classification
objective = binary:logistic
# Tree Booster Parameters
# step size shrinkage
eta = 1.0
# minimum loss reduction required to make a further partition
gamma = 1.0
# minimum sum of instance weight(hessian) needed in a child
min_child_weight = 1
# maximum depth of a tree
max_depth = 3
# Task Parameters
# the number of round to do boosting
num_round = 2
# 0 means do not save any model except the final round model
save_period = 0
# evaluate on training data as well each round
# eval_train = 1
# The path of validation data, used to monitor training process, here [test] sets name of the validation set
# eval[test] = "agaricus.txt.test"
# Plz donot modify the following parameters
# The path of training data, with prefix hdfs
#data = hdfs:/data/
# The path of model file
#model_out =
# split pattern of xgboost
dsplit = row
# evaluate on training data as well each round
eval_train = 1

View File

@@ -1,28 +0,0 @@
#!/bin/bash
if [ "$#" -lt 3 ];
then
echo "Usage: <nworkers> <nthreads> <path_in_HDFS>"
exit -1
fi
# put the local training file to HDFS
hadoop fs -mkdir $3/data
hadoop fs -put ../../demo/data/agaricus.txt.train $3/data
hadoop fs -put ../../demo/data/agaricus.txt.test $3/data
# running rabit, pass address in hdfs
../../subtree/rabit/tracker/rabit_yarn.py -n $1 --vcores $2 ../../xgboost mushroom.hadoop.conf nthread=$2\
data=hdfs://$3/data/agaricus.txt.train\
eval[test]=hdfs://$3/data/agaricus.txt.test\
model_out=hdfs://$3/mushroom.final.model
# get the final model file
hadoop fs -get $3/mushroom.final.model final.model
# output prediction task=pred
../../xgboost mushroom.hadoop.conf task=pred model_in=final.model test:data=../../demo/data/agaricus.txt.test
# print the boosters of final.model in dump.raw.txt
../../xgboost mushroom.hadoop.conf task=dump model_in=final.model name_dump=dump.raw.txt
# use the feature map in printing for better visualization
../../xgboost mushroom.hadoop.conf task=dump model_in=final.model fmap=../../demo/data/featmap.txt name_dump=dump.nice.txt
cat dump.nice.txt

View File

@@ -1,18 +0,0 @@
Distributed XGBoost: Row Split Version
====
* You might be interested to checkout the [Hadoop example](../hadoop)
* Machine Rabit: run ```bash machine-row-rabit.sh <n-mpi-process>```
- machine-col-rabit.sh starts xgboost job using rabit
How to Use
====
* First split the data by rows
* In the config, specify data file as containing a wildcard %d, where %d is the rank of the node, each node will load their part of data
* Enable ow split mode by ```dsplit=row```
Notes
====
* The code is multi-threaded, so you want to run one xgboost-mpi per node
* Row-based solver split data by row, each node work on subset of rows, it uses an approximate histogram count algorithm,
and will only examine subset of potential split points as opposed to all split points.

View File

@@ -1,20 +0,0 @@
#!/bin/bash
if [[ $# -ne 1 ]]
then
echo "Usage: nprocess"
exit -1
fi
rm -rf train-machine.row* *.model
k=$1
# make machine data
cd ../../demo/regression/
python mapfeat.py
python mknfold.py machine.txt 1
cd -
# split the lib svm file into k subfiles
python splitrows.py ../../demo/regression/machine.txt.train train-machine $k
# run xgboost mpi
../../subtree/rabit/tracker/rabit_demo.py -n $k ../../xgboost.mock machine-row.conf dsplit=row num_round=3 mock=1,1,1,0 mock=0,0,3,0 mock=2,2,3,0

View File

@@ -1,24 +0,0 @@
#!/bin/bash
if [[ $# -ne 1 ]]
then
echo "Usage: nprocess"
exit -1
fi
rm -rf train-machine.row* *.model
k=$1
# make machine data
cd ../../demo/regression/
python mapfeat.py
python mknfold.py machine.txt 1
cd -
# split the lib svm file into k subfiles
python splitrows.py ../../demo/regression/machine.txt.train train-machine $k
# run xgboost mpi
../../subtree/rabit/tracker/rabit_demo.py -n $k ../../xgboost machine-row.conf dsplit=row num_round=3 eval_train=1
# run xgboost-mpi save model 0001, continue to run from existing model
../../subtree/rabit/tracker/rabit_demo.py -n $k ../../xgboost machine-row.conf dsplit=row num_round=1
../../subtree/rabit/tracker/rabit_demo.py -n $k ../../xgboost machine-row.conf dsplit=row num_round=2 model_in=0001.model

View File

@@ -1,30 +0,0 @@
# General Parameters, see comment for each definition
# choose the tree booster, can also change to gblinear
booster = gbtree
# this is the only difference with classification, use reg:linear to do linear classification
# when labels are in [0,1] we can also use reg:logistic
objective = reg:linear
# Tree Booster Parameters
# step size shrinkage
eta = 1.0
# minimum loss reduction required to make a further partition
gamma = 1.0
# minimum sum of instance weight(hessian) needed in a child
min_child_weight = 1
# maximum depth of a tree
max_depth = 3
# Task parameters
# the number of round to do boosting
num_round = 2
# 0 means do not save any model except the final round model
save_period = 0
use_buffer = 0
# The path of training data
data = "train-machine.row%d"
# The path of validation data, used to monitor training process, here [test] sets name of the validation set
eval[test] = "../../demo/regression/machine.txt.test"
# The path of test data
test:data = "../../demo/regression/machine.txt.test"

View File

@@ -1,24 +0,0 @@
#!/usr/bin/python
import sys
import random
# split libsvm file into different rows
if len(sys.argv) < 4:
print ('Usage:<fin> <fo> k')
exit(0)
random.seed(10)
k = int(sys.argv[3])
fi = open( sys.argv[1], 'r' )
fos = []
for i in range(k):
fos.append(open( sys.argv[2]+'.row%d' % i, 'w' ))
for l in open(sys.argv[1]):
i = random.randint(0, k-1)
fos[i].write(l)
for f in fos:
f.close()

View File

@@ -206,6 +206,10 @@ class GBTree : public IGradBooster {
     for (size_t i = 0; i < trees.size(); ++i) {
       delete trees[i];
     }
+    for (size_t i = 0; i < updaters.size(); ++i) {
+      delete updaters[i];
+    }
+    updaters.clear();
     trees.clear();
     pred_buffer.clear();
     pred_counter.clear();
@@ -444,12 +448,12 @@ class GBTree : public IGradBooster {
     int reserved[31];
     /*! \brief constructor */
     ModelParam(void) {
-      std::memset(this, 0, sizeof(ModelParam));
       num_trees = 0;
       num_roots = num_feature = 0;
       num_pbuffer = 0;
       num_output_group = 1;
       size_leaf_vector = 0;
+      std::memset(reserved, 0, sizeof(reserved));
     }
     /*!
      * \brief set parameters from outside

src/io/dmlc_simple.cpp (new file, 127 lines)
View File

@@ -0,0 +1,127 @@
#define _CRT_SECURE_NO_WARNINGS
#define _CRT_SECURE_NO_DEPRECATE
#define NOMINMAX
#include "../utils/io.h"
// implements a single no split version of DMLC
// in case we want to avoid dependency on dmlc-core
namespace xgboost {
namespace utils {
class SingleFileSplit : public dmlc::InputSplit {
public:
explicit SingleFileSplit(const char *fname)
: use_stdin_(false) {
if (!std::strcmp(fname, "stdin")) {
#ifndef XGBOOST_STRICT_CXX98_
use_stdin_ = true; fp_ = stdin;
#endif
}
if (!use_stdin_) {
fp_ = utils::FopenCheck(fname, "r");
}
end_of_file_ = false;
}
virtual ~SingleFileSplit(void) {
if (!use_stdin_) std::fclose(fp_);
}
virtual bool ReadLine(std::string *out_data) {
if (end_of_file_) return false;
out_data->clear();
while (true) {
char c = std::fgetc(fp_);
if (c == EOF) {
end_of_file_ = true;
}
if (c != '\r' && c != '\n' && c != EOF) {
*out_data += c;
} else {
if (out_data->length() != 0) return true;
if (end_of_file_) return false;
}
}
return false;
}
private:
std::FILE *fp_;
bool use_stdin_;
bool end_of_file_;
};
class StdFile : public dmlc::IStream {
public:
explicit StdFile(const char *fname, const char *mode)
: use_stdio(false) {
using namespace std;
#ifndef XGBOOST_STRICT_CXX98_
if (!strcmp(fname, "stdin")) {
use_stdio = true; fp = stdin;
}
if (!strcmp(fname, "stdout")) {
use_stdio = true; fp = stdout;
}
#endif
if (!strncmp(fname, "file://", 7)) fname += 7;
if (!use_stdio) {
std::string flag = mode;
if (flag == "w") flag = "wb";
if (flag == "r") flag = "rb";
fp = utils::FopenCheck(fname, flag.c_str());
}
}
virtual ~StdFile(void) {
this->Close();
}
virtual size_t Read(void *ptr, size_t size) {
return std::fread(ptr, 1, size, fp);
}
virtual void Write(const void *ptr, size_t size) {
std::fwrite(ptr, size, 1, fp);
}
virtual void Seek(size_t pos) {
std::fseek(fp, static_cast<long>(pos), SEEK_SET);
}
virtual size_t Tell(void) {
return std::ftell(fp);
}
virtual bool AtEnd(void) const {
return std::feof(fp) != 0;
}
inline void Close(void) {
if (fp != NULL && !use_stdio) {
std::fclose(fp); fp = NULL;
}
}
private:
std::FILE *fp;
bool use_stdio;
};
} // namespace utils
} // namespace xgboost
namespace dmlc {
InputSplit* InputSplit::Create(const char *uri,
unsigned part,
unsigned nsplit) {
using namespace xgboost;
const char *msg = "xgboost is compiled in local mode\n"\
"to use hdfs, s3 or distributed version, compile with make dmlc=1";
utils::Check(strncmp(uri, "s3://", 5) != 0, msg);
utils::Check(strncmp(uri, "hdfs://", 7) != 0, msg);
utils::Check(nsplit == 1, msg);
return new utils::SingleFileSplit(uri);
}
IStream *IStream::Create(const char *uri, const char * const flag) {
using namespace xgboost;
const char *msg = "xgboost is compiled in local mode\n"\
"to use hdfs, s3 or distributed version, compile with make dmlc=1";
utils::Check(strncmp(uri, "s3://", 5) != 0, msg);
utils::Check(strncmp(uri, "hdfs://", 7) != 0, msg);
return new utils::StdFile(uri, flag);
}
} // namespace dmlc
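
Taken together, this file is the whole no-dmlc fallback: `dmlc::InputSplit::Create` and `dmlc::IStream::Create` resolve to the single-file classes above and refuse s3:// and hdfs:// URIs. A minimal caller sketch under those assumptions (the file name is hypothetical; with a `make dmlc=...` build the same call would also accept an hdfs:// or s3:// URI):

```cpp
#include <cstdio>
#include <string>
#include "dmlc/io.h"  // the interface header added later in this commit

// Sketch: count the lines of a local LIBSVM file through the fallback split.
// "train.libsvm" is a hypothetical path; part=0, nsplit=1 is the only split
// configuration the local fallback accepts.
int main() {
  dmlc::InputSplit *split = dmlc::InputSplit::Create("train.libsvm", 0, 1);
  std::string line;
  size_t n = 0;
  while (split->ReadLine(&line)) ++n;
  std::printf("%lu lines\n", static_cast<unsigned long>(n));
  delete split;
  return 0;
}
```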

View File

@@ -16,7 +16,10 @@ namespace xgboost {
 namespace io {
 DataMatrix* LoadDataMatrix(const char *fname, bool silent,
                            bool savebuffer, bool loadsplit) {
-  if (!std::strcmp(fname, "stdin") || loadsplit) {
+  if (!std::strcmp(fname, "stdin") ||
+      !std::strncmp(fname, "s3://", 5) ||
+      !std::strncmp(fname, "hdfs://", 7) ||
+      loadsplit) {
     DMatrixSimple *dmat = new DMatrixSimple();
     dmat->LoadText(fname, silent, loadsplit);
     return dmat;

View File

@@ -90,11 +90,11 @@ class DMatrixSimple : public DataMatrix {
       rank = rabit::GetRank();
       npart = rabit::GetWorldSize();
     }
-    rabit::io::InputSplit *in =
-        rabit::io::CreateInputSplit(uri, rank, npart);
+    dmlc::InputSplit *in =
+        dmlc::InputSplit::Create(uri, rank, npart);
     this->Clear();
     std::string line;
-    while (in->NextLine(&line)) {
+    while (in->ReadLine(&line)) {
       float label;
       std::istringstream ss(line);
       std::vector<RowBatch::Entry> feats;

View File

@@ -192,8 +192,10 @@ class FMatrixS : public IFMatrix{
     bst_omp_uint ncol = static_cast<bst_omp_uint>(this->NumCol());
     #pragma omp parallel for schedule(static)
     for (bst_omp_uint i = 0; i < ncol; ++i) {
-      std::sort(&col_data_[0] + col_ptr_[i],
-                &col_data_[0] + col_ptr_[i + 1], Entry::CmpValue);
+      if (col_ptr_[i] < col_ptr_[i + 1]) {
+        std::sort(BeginPtr(col_data_) + col_ptr_[i],
+                  BeginPtr(col_data_) + col_ptr_[i + 1], Entry::CmpValue);
+      }
     }
   }
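
The rewritten loop avoids two pitfalls: `&col_data_[0]` indexes an element that need not exist when the vector is empty, while `BeginPtr` returns a safe pointer, and the emptiness guard skips `std::sort` on zero-length segments entirely. A self-contained sketch of the same pattern (this `BeginPtr` is a stand-in for xgboost's helper, assumed to live in `src/utils/utils.h`):

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Stand-in for xgboost's utils::BeginPtr: never forms &v[0] on an empty vector.
template <typename T>
inline T *BeginPtr(std::vector<T> &v) {
  return v.empty() ? NULL : &v[0];
}

// ptr holds CSC-style segment boundaries; empty segments are skipped so we
// never build iterators from a potentially invalid base pointer.
void SortSegments(std::vector<float> &data, const std::vector<size_t> &ptr) {
  for (size_t i = 0; i + 1 < ptr.size(); ++i) {
    if (ptr[i] < ptr[i + 1]) {
      std::sort(BeginPtr(data) + ptr[i], BeginPtr(data) + ptr[i + 1]);
    }
  }
}
```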

View File

@@ -119,17 +119,29 @@ struct EvalMClassBase : public IEvaluator {
     utils::Check(preds.size() % info.labels.size() == 0,
                  "label and prediction size not match");
     const size_t nclass = preds.size() / info.labels.size();
+    utils::Check(nclass > 1,
+                 "mlogloss and merror are only used for multi-class classification,"\
+                 " use logloss for binary classification");
     const bst_omp_uint ndata = static_cast<bst_omp_uint>(info.labels.size());
     float sum = 0.0, wsum = 0.0;
+    int label_error = 0;
     #pragma omp parallel for reduction(+: sum, wsum) schedule(static)
     for (bst_omp_uint i = 0; i < ndata; ++i) {
       const float wt = info.GetWeight(i);
+      int label = static_cast<int>(info.labels[i]);
+      if (label >= 0 && label < static_cast<int>(nclass)) {
         sum += Derived::EvalRow(info.labels[i],
                                 BeginPtr(preds) + i * nclass,
                                 nclass) * wt;
         wsum += wt;
+      } else {
+        label_error = label;
       }
     }
+    utils::Check(label_error >= 0 && label_error < static_cast<int>(nclass),
+                 "MultiClassEvaluation: label must be in [0, num_class)," \
+                 " num_class=%d but found %d in label",
+                 static_cast<int>(nclass), label_error);
     float dat[2]; dat[0] = sum, dat[1] = wsum;
     if (distributed) {
       rabit::Allreduce<rabit::op::Sum>(dat, 2);
@@ -143,7 +155,7 @@ struct EvalMClassBase : public IEvaluator {
    * \param pred prediction value of current instance
    * \param nclass number of class in the prediction
    */
-  inline static float EvalRow(float label,
+  inline static float EvalRow(int label,
                               const float *pred,
                               size_t nclass);
   /*!
@@ -154,13 +166,15 @@ struct EvalMClassBase : public IEvaluator {
   inline static float GetFinal(float esum, float wsum) {
     return esum / wsum;
   }
+  // used to store error message
+  const char *error_msg_;
 };
 /*! \brief match error */
 struct EvalMatchError : public EvalMClassBase<EvalMatchError> {
   virtual const char *Name(void) const {
     return "merror";
   }
-  inline static float EvalRow(float label,
+  inline static float EvalRow(int label,
                               const float *pred,
                               size_t nclass) {
     return FindMaxIndex(pred, nclass) != static_cast<int>(label);
@@ -171,12 +185,11 @@ struct EvalMultiLogLoss : public EvalMClassBase<EvalMultiLogLoss> {
   virtual const char *Name(void) const {
     return "mlogloss";
   }
-  inline static float EvalRow(float label,
+  inline static float EvalRow(int label,
                               const float *pred,
                               size_t nclass) {
     const float eps = 1e-16f;
     size_t k = static_cast<size_t>(label);
-    utils::Check(k < nclass, "mlogloss: label must be in [0, num_class)");
     if (pred[k] > eps) {
       return -std::log(pred[k]);
     } else {
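
Both this change and the matching one in the softmax objective further down move error detection out of the OpenMP loop body: raising an error (which may throw or exit) from inside a `#pragma omp parallel for` region is unsafe, so the loop only records an offending label and the check fires once afterwards. A distilled sketch of the pattern (the write to `label_error` is racy, but any recorded out-of-range value suffices to trigger the check, just as in the diff):

```cpp
#include <cstdio>
#include <cstdlib>
#include <vector>

// Validate labels in parallel without erroring inside the OpenMP region:
// workers only record a bad value; the single-threaded check runs after.
void CheckLabels(const std::vector<float> &labels, int nclass) {
  int label_error = 0;  // stays 0 when every label is in range
  #pragma omp parallel for schedule(static)
  for (int i = 0; i < static_cast<int>(labels.size()); ++i) {
    int label = static_cast<int>(labels[i]);
    if (label < 0 || label >= nclass) label_error = label;
  }
  if (label_error < 0 || label_error >= nclass) {
    std::fprintf(stderr, "label must be in [0, num_class), found %d\n",
                 label_error);
    std::exit(1);
  }
}
```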

View File

@@ -163,7 +163,21 @@ class BoostLearner : public rabit::ISerializable {
                          bool calc_num_feature = true) {
     utils::Check(fi.Read(&mparam, sizeof(ModelParam)) != 0,
                  "BoostLearner: wrong model format");
-    utils::Check(fi.Read(&name_obj_), "BoostLearner: wrong model format");
+    {
+      // backward compatibility code for compatible with old model type
+      // for new model, Read(&name_obj_) is suffice
+      size_t len;
+      utils::Check(fi.Read(&len, sizeof(len)) != 0, "BoostLearner: wrong model format");
+      if (len >= std::numeric_limits<unsigned>::max()) {
+        int gap;
+        utils::Check(fi.Read(&gap, sizeof(gap)) != 0, "BoostLearner: wrong model format");
+        len = len >> 32UL;
+      }
+      if (len != 0) {
+        name_obj_.resize(len);
+        utils::Check(fi.Read(&name_obj_[0], len) != 0, "BoostLearner: wrong model format");
+      }
+    }
     utils::Check(fi.Read(&name_gbm_), "BoostLearner: wrong model format");
     // delete existing gbm if any
     if (obj_ != NULL) delete obj_;
@@ -193,7 +207,7 @@ class BoostLearner : public rabit::ISerializable {
    * \param fname file name
    */
   inline void LoadModel(const char *fname) {
-    utils::IStream *fi = rabit::io::CreateStream(fname, "r");
+    utils::IStream *fi = utils::IStream::Create(fname, "r");
     std::string header; header.resize(4);
     // check header for different binary encode
     // can be base64 or binary
@@ -207,7 +221,7 @@ class BoostLearner : public rabit::ISerializable {
       this->LoadModel(*fi);
     } else {
       delete fi;
-      fi = rabit::io::CreateStream(fname, "r");
+      fi = utils::IStream::Create(fname, "r");
       this->LoadModel(*fi);
     }
     delete fi;
@@ -224,7 +238,7 @@ class BoostLearner : public rabit::ISerializable {
    * \param save_base64 whether save in base64 format
    */
   inline void SaveModel(const char *fname, bool save_base64 = false) const {
-    utils::IStream *fo = rabit::io::CreateStream(fname, "w");
+    utils::IStream *fo = utils::IStream::Create(fname, "w");
     if (save_base64 != 0 || !strcmp(fname, "stdout")) {
       fo->Write("bs64\t", 5);
       utils::Base64OutStream bout(fo);

View File

@@ -197,6 +197,7 @@ class SoftmaxMultiClassObj : public IObjFunction {
     gpair.resize(preds.size());
     const unsigned nstep = static_cast<unsigned>(info.labels.size() * nclass);
     const bst_omp_uint ndata = static_cast<bst_omp_uint>(preds.size() / nclass);
+    int label_error = 0;
     #pragma omp parallel
     {
       std::vector<float> rec(nclass);
@@ -208,8 +209,9 @@ class SoftmaxMultiClassObj : public IObjFunction {
         Softmax(&rec);
         const unsigned j = i % nstep;
         int label = static_cast<int>(info.labels[j]);
-        utils::Check(label >= 0 && label < nclass,
-                     "SoftmaxMultiClassObj: label must be in [0, num_class)");
+        if (label < 0 || label >= nclass) {
+          label_error = label; label = 0;
+        }
         const float wt = info.GetWeight(j);
         for (int k = 0; k < nclass; ++k) {
           float p = rec[k];
@@ -222,6 +224,9 @@ class SoftmaxMultiClassObj : public IObjFunction {
         }
       }
     }
+    utils::Check(label_error >= 0 && label_error < nclass,
+                 "SoftmaxMultiClassObj: label must be in [0, num_class),"\
+                 " num_class=%d but found %d in label", nclass, label_error);
   }
   virtual void PredTransform(std::vector<float> *io_preds) {
     this->Transform(io_preds, output_prob);

View File

@@ -7,7 +7,6 @@
  * \author Tianqi Chen
  */
 #include "../../subtree/rabit/include/rabit.h"
-#include "../../subtree/rabit/rabit-learn/io/io.h"
 #endif  // XGBOOST_SYNC_H_

View File

@@ -406,7 +406,8 @@ class ColMaker: public IUpdater {
           c.SetSubstract(snode[nid].stats, e.stats);
           if (e.stats.sum_hess >= param.min_child_weight && c.sum_hess >= param.min_child_weight) {
             bst_float loss_chg = static_cast<bst_float>(e.stats.CalcGain(param) + c.CalcGain(param) - snode[nid].root_gain);
-            const float delta = d_step == +1 ? rt_eps : -rt_eps;
+            const float gap = std::abs(e.last_fvalue) + rt_eps;
+            const float delta = d_step == +1 ? gap : -gap;
             e.best.Update(loss_chg, fid, e.last_fvalue + delta, d_step == -1);
           }
         }

src/utils/base64-inl.h (new file, 264 lines)
View File

@@ -0,0 +1,264 @@
#ifndef XGBOOST_UTILS_BASE64_INL_H_
#define XGBOOST_UTILS_BASE64_INL_H_
/*!
* \file base64.h
* \brief data stream support to input and output from/to base64 stream
* base64 is easier to store and pass as text format in mapreduce
* \author Tianqi Chen
*/
#include <cctype>
#include <cstdio>
#include "./io.h"
namespace xgboost {
namespace utils {
/*! \brief buffer reader of the stream that allows you to get */
class StreamBufferReader {
public:
StreamBufferReader(size_t buffer_size)
:stream_(NULL),
read_len_(1), read_ptr_(1) {
buffer_.resize(buffer_size);
}
/*!
* \brief set input stream
*/
inline void set_stream(IStream *stream) {
stream_ = stream;
read_len_ = read_ptr_ = 1;
}
/*!
* \brief allows quick read using get char
*/
inline char GetChar(void) {
while (true) {
if (read_ptr_ < read_len_) {
return buffer_[read_ptr_++];
} else {
read_len_ = stream_->Read(&buffer_[0], buffer_.length());
if (read_len_ == 0) return EOF;
read_ptr_ = 0;
}
}
}
/*! \brief whether we are reaching the end of file */
inline bool AtEnd(void) const {
return read_len_ == 0;
}
private:
/*! \brief the underlying stream */
IStream *stream_;
/*! \brief buffer to hold data */
std::string buffer_;
/*! \brief length of valid data in buffer */
size_t read_len_;
/*! \brief pointer in the buffer */
size_t read_ptr_;
};
/*! \brief namespace of base64 decoding and encoding table */
namespace base64 {
const char DecodeTable[] = {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
62, // '+'
0, 0, 0,
63, // '/'
52, 53, 54, 55, 56, 57, 58, 59, 60, 61, // '0'-'9'
0, 0, 0, 0, 0, 0, 0,
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, // 'A'-'Z'
0, 0, 0, 0, 0, 0,
26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,
39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, // 'a'-'z'
};
static const char EncodeTable[] =
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
} // namespace base64
/*! \brief the stream that reads from base64, note we take from file pointers */
class Base64InStream: public IStream {
public:
explicit Base64InStream(IStream *fs) : reader_(256) {
reader_.set_stream(fs);
num_prev = 0; tmp_ch = 0;
}
/*!
* \brief initialize the stream position to beginning of next base64 stream
* call this function before actually start read
*/
inline void InitPosition(void) {
// get a charater
do {
tmp_ch = reader_.GetChar();
} while (isspace(tmp_ch));
}
/*! \brief whether current position is end of a base64 stream */
inline bool IsEOF(void) const {
return num_prev == 0 && (tmp_ch == EOF || isspace(tmp_ch));
}
virtual size_t Read(void *ptr, size_t size) {
using base64::DecodeTable;
if (size == 0) return 0;
// use tlen to record left size
size_t tlen = size;
unsigned char *cptr = static_cast<unsigned char*>(ptr);
// if anything left, load from previous buffered result
if (num_prev != 0) {
if (num_prev == 2) {
if (tlen >= 2) {
*cptr++ = buf_prev[0];
*cptr++ = buf_prev[1];
tlen -= 2;
num_prev = 0;
} else {
// assert tlen == 1
*cptr++ = buf_prev[0]; --tlen;
buf_prev[0] = buf_prev[1];
num_prev = 1;
}
} else {
// assert num_prev == 1
*cptr++ = buf_prev[0]; --tlen; num_prev = 0;
}
}
if (tlen == 0) return size;
int nvalue;
// note: everything goes with 4 bytes in Base64
// so we process 4 bytes a unit
while (tlen && tmp_ch != EOF && !isspace(tmp_ch)) {
// first byte
nvalue = DecodeTable[tmp_ch] << 18;
{
// second byte
utils::Check((tmp_ch = reader_.GetChar(), tmp_ch != EOF && !isspace(tmp_ch)),
"invalid base64 format");
nvalue |= DecodeTable[tmp_ch] << 12;
*cptr++ = (nvalue >> 16) & 0xFF; --tlen;
}
{
// third byte
utils::Check((tmp_ch = reader_.GetChar(), tmp_ch != EOF && !isspace(tmp_ch)),
"invalid base64 format");
// handle termination
if (tmp_ch == '=') {
utils::Check((tmp_ch = reader_.GetChar(), tmp_ch == '='), "invalid base64 format");
utils::Check((tmp_ch = reader_.GetChar(), tmp_ch == EOF || isspace(tmp_ch)),
"invalid base64 format");
break;
}
nvalue |= DecodeTable[tmp_ch] << 6;
if (tlen) {
*cptr++ = (nvalue >> 8) & 0xFF; --tlen;
} else {
buf_prev[num_prev++] = (nvalue >> 8) & 0xFF;
}
}
{
// fourth byte
utils::Check((tmp_ch = reader_.GetChar(), tmp_ch != EOF && !isspace(tmp_ch)),
"invalid base64 format");
if (tmp_ch == '=') {
utils::Check((tmp_ch = reader_.GetChar(), tmp_ch == EOF || isspace(tmp_ch)),
"invalid base64 format");
break;
}
nvalue |= DecodeTable[tmp_ch];
if (tlen) {
*cptr++ = nvalue & 0xFF; --tlen;
} else {
buf_prev[num_prev ++] = nvalue & 0xFF;
}
}
// get next char
tmp_ch = reader_.GetChar();
}
if (kStrictCheck) {
utils::Check(tlen == 0, "Base64InStream: read incomplete");
}
return size - tlen;
}
virtual void Write(const void *ptr, size_t size) {
utils::Error("Base64InStream do not support write");
}
private:
StreamBufferReader reader_;
int tmp_ch;
int num_prev;
unsigned char buf_prev[2];
// whether we need to do strict check
static const bool kStrictCheck = false;
};
/*! \brief the stream that write to base64, note we take from file pointers */
class Base64OutStream: public IStream {
public:
explicit Base64OutStream(IStream *fp) : fp(fp) {
buf_top = 0;
}
virtual void Write(const void *ptr, size_t size) {
using base64::EncodeTable;
size_t tlen = size;
const unsigned char *cptr = static_cast<const unsigned char*>(ptr);
while (tlen) {
while (buf_top < 3 && tlen != 0) {
buf[++buf_top] = *cptr++; --tlen;
}
if (buf_top == 3) {
// flush 4 bytes out
PutChar(EncodeTable[buf[1] >> 2]);
PutChar(EncodeTable[((buf[1] << 4) | (buf[2] >> 4)) & 0x3F]);
PutChar(EncodeTable[((buf[2] << 2) | (buf[3] >> 6)) & 0x3F]);
PutChar(EncodeTable[buf[3] & 0x3F]);
buf_top = 0;
}
}
}
virtual size_t Read(void *ptr, size_t size) {
utils::Error("Base64OutStream do not support read");
return 0;
}
/*!
* \brief finish writing of all current base64 stream, do some post processing
* \param endch charater to put to end of stream, if it is EOF, then nothing will be done
*/
inline void Finish(char endch = EOF) {
using base64::EncodeTable;
if (buf_top == 1) {
PutChar(EncodeTable[buf[1] >> 2]);
PutChar(EncodeTable[(buf[1] << 4) & 0x3F]);
PutChar('=');
PutChar('=');
}
if (buf_top == 2) {
PutChar(EncodeTable[buf[1] >> 2]);
PutChar(EncodeTable[((buf[1] << 4) | (buf[2] >> 4)) & 0x3F]);
PutChar(EncodeTable[(buf[2] << 2) & 0x3F]);
PutChar('=');
}
buf_top = 0;
if (endch != EOF) PutChar(endch);
this->Flush();
}
private:
IStream *fp;
int buf_top;
unsigned char buf[4];
std::string out_buf;
const static size_t kBufferSize = 256;
inline void PutChar(char ch) {
out_buf += ch;
if (out_buf.length() >= kBufferSize) Flush();
}
inline void Flush(void) {
if (out_buf.length() != 0) {
fp->Write(&out_buf[0], out_buf.length());
out_buf.clear();
}
}
};
} // namespace utils
} // namespace rabit
#endif // RABIT_LEARN_UTILS_BASE64_INL_H_
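
As a quick sanity sketch of how the two streams above compose: bytes written through `Base64OutStream` and finished with `Finish()` decode back through `Base64InStream`. This assumes rabit's `MemoryBufferStream` (an `IStream` over a `std::string`, re-exported by xgboost's `utils/io.h`) and its constructor and `Seek` behavior; treat the details as illustrative, not authoritative:

```cpp
#include <cassert>
#include <cstring>
#include <string>
#include "utils/io.h"  // assumed to typedef MemoryBufferStream and pull in base64-inl.h

void Base64RoundTrip(void) {
  std::string buf;
  xgboost::utils::MemoryBufferStream ms(&buf);  // in-memory IStream (assumption)
  xgboost::utils::Base64OutStream bout(&ms);
  const char msg[] = "hello base64";
  bout.Write(msg, sizeof(msg));
  bout.Finish('\n');            // emit '=' padding and flush the text form

  ms.Seek(0);                   // rewind the in-memory buffer
  xgboost::utils::Base64InStream bin(&ms);
  bin.InitPosition();           // prime the first non-space character
  char out[sizeof(msg)];
  bin.Read(out, sizeof(out));
  assert(std::memcmp(msg, out, sizeof(msg)) == 0);
}
```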

View File

@@ -18,8 +18,6 @@ typedef rabit::IStream IStream;
 typedef rabit::utils::ISeekStream ISeekStream;
 typedef rabit::utils::MemoryFixSizeBuffer MemoryFixSizeBuffer;
 typedef rabit::utils::MemoryBufferStream MemoryBufferStream;
-typedef rabit::io::Base64InStream Base64InStream;
-typedef rabit::io::Base64OutStream Base64OutStream;
 /*! \brief implementation of file i/o stream */
 class FileStream : public ISeekStream {
@@ -54,4 +52,6 @@ class FileStream : public ISeekStream {
 };
 } // namespace utils
 } // namespace xgboost
+#include "./base64-inl.h"
 #endif

View File

@@ -0,0 +1,4 @@
This folder is part of the dmlc-core library; it allows rabit to share a unified stream interface with other dmlc projects.
- Since the dependency is interface-only, dmlc-core is not required to compile rabit.
- To compile a project that uses dmlc-core functions, linking against libdmlc.a (provided by dmlc-core) is required.

View File

@@ -0,0 +1,162 @@
/*!
* Copyright (c) 2015 by Contributors
* \file io.h
* \brief defines serializable interface of dmlc
*/
#ifndef DMLC_IO_H_
#define DMLC_IO_H_
#include <cstdio>
#include <string>
#include <vector>
/*! \brief namespace for dmlc */
namespace dmlc {
/*!
* \brief interface of stream I/O for serialization
*/
class IStream {
public:
/*!
* \brief reads data from a stream
* \param ptr pointer to a memory buffer
* \param size block size
* \return the size of data read
*/
virtual size_t Read(void *ptr, size_t size) = 0;
/*!
* \brief writes data to a stream
* \param ptr pointer to a memory buffer
* \param size block size
*/
virtual void Write(const void *ptr, size_t size) = 0;
/*! \brief virtual destructor */
virtual ~IStream(void) {}
/*!
* \brief generic factory function
* create an stream, the stream will close the underlying files
* upon deletion
* \param uri the uri of the input currently we support
* hdfs://, s3://, and file:// by default file:// will be used
* \param flag can be "w", "r", "a"
*/
static IStream *Create(const char *uri, const char* const flag);
// helper functions to write/read different data structures
/*!
* \brief writes a vector
* \param vec vector to be written/serialized
*/
template<typename T>
inline void Write(const std::vector<T> &vec);
/*!
* \brief loads a vector
* \param out_vec vector to be loaded/deserialized
* \return whether the load was successful
*/
template<typename T>
inline bool Read(std::vector<T> *out_vec);
/*!
* \brief writes a string
* \param str the string to be written/serialized
*/
inline void Write(const std::string &str);
/*!
* \brief loads a string
* \param out_str string to be loaded/deserialized
* \return whether the load/deserialization was successful
*/
inline bool Read(std::string *out_str);
};
/*! \brief interface of i/o stream that support seek */
class ISeekStream: public IStream {
public:
// virtual destructor
virtual ~ISeekStream(void) {}
/*! \brief seek to certain position of the file */
virtual void Seek(size_t pos) = 0;
/*! \brief tell the position of the stream */
virtual size_t Tell(void) = 0;
/*! \return whether we are at end of file */
virtual bool AtEnd(void) const = 0;
};
/*! \brief interface for serializable objects */
class ISerializable {
public:
/*!
* \brief load the model from a stream
* \param fi stream where to load the model from
*/
virtual void Load(IStream &fi) = 0;
/*!
* \brief saves the model to a stream
* \param fo stream where to save the model to
*/
virtual void Save(IStream &fo) const = 0;
};
/*!
* \brief input split header, used to create input split on input dataset
* this class can be used to obtain filesystem invariant splits from input files
*/
class InputSplit {
public:
/*!
* \brief read next line, store into out_data
* \param out_data the string that stores the line data, \n is not included
* \return true of next line was found, false if we read all the lines
*/
virtual bool ReadLine(std::string *out_data) = 0;
/*! \brief destructor*/
virtual ~InputSplit(void) {}
/*!
* \brief factory function:
* create input split given a uri
* \param uri the uri of the input, can contain hdfs prefix
* \param part_index the part id of current input
* \param num_parts total number of splits
*/
static InputSplit* Create(const char *uri,
unsigned part_index,
unsigned num_parts);
};
// implementations of inline functions
template<typename T>
inline void IStream::Write(const std::vector<T> &vec) {
size_t sz = vec.size();
this->Write(&sz, sizeof(sz));
if (sz != 0) {
this->Write(&vec[0], sizeof(T) * sz);
}
}
template<typename T>
inline bool IStream::Read(std::vector<T> *out_vec) {
size_t sz;
if (this->Read(&sz, sizeof(sz)) == 0) return false;
out_vec->resize(sz);
if (sz != 0) {
if (this->Read(&(*out_vec)[0], sizeof(T) * sz) == 0) return false;
}
return true;
}
inline void IStream::Write(const std::string &str) {
size_t sz = str.length();
this->Write(&sz, sizeof(sz));
if (sz != 0) {
this->Write(&str[0], sizeof(char) * sz);
}
}
inline bool IStream::Read(std::string *out_str) {
size_t sz;
if (this->Read(&sz, sizeof(sz)) == 0) return false;
out_str->resize(sz);
if (sz != 0) {
if (this->Read(&(*out_str)[0], sizeof(char) * sz) == 0) {
return false;
}
}
return true;
}
} // namespace dmlc
#endif // DMLC_IO_H_
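
The inline helpers give every `IStream` implementation length-prefixed (de)serialization of vectors and strings for free. A short round-trip sketch, assuming a concrete stream such as the fallback `StdFile` behind `IStream::Create` in `dmlc_simple.cpp` above ("tmp.bin" is a hypothetical local path):

```cpp
#include <string>
#include <vector>
#include "dmlc/io.h"

// Save and reload a weight vector plus a name through any dmlc::IStream.
void RoundTrip(void) {
  std::vector<float> w(10, 0.5f);
  std::string name = "linear";
  dmlc::IStream *fo = dmlc::IStream::Create("tmp.bin", "w");
  fo->Write(w);      // writes a size_t length, then the raw elements
  fo->Write(name);
  delete fo;         // closes the underlying file

  dmlc::IStream *fi = dmlc::IStream::Create("tmp.bin", "r");
  std::vector<float> w2;
  std::string name2;
  fi->Read(&w2);     // returns false on a truncated or malformed stream
  fi->Read(&name2);
  delete fi;
}
```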

View File

@@ -16,19 +16,8 @@
 namespace rabit {
 namespace utils {
-/*! \brief interface of i/o stream that support seek */
-class ISeekStream: public IStream {
- public:
-  // virtual destructor
-  virtual ~ISeekStream(void) {}
-  /*! \brief seek to certain position of the file */
-  virtual void Seek(size_t pos) = 0;
-  /*! \brief tell the position of the stream */
-  virtual size_t Tell(void) = 0;
-  /*! \return whether we are at end of file */
-  virtual bool AtEnd(void) const = 0;
-};
+/*! \brief re-use definition of dmlc::ISeekStream */
+typedef dmlc::ISeekStream ISeekStream;
 /*! \brief fixed size memory buffer */
 struct MemoryFixSizeBuffer : public ISeekStream {
  public:

View File

@@ -9,98 +9,19 @@
 #include <vector>
 #include <string>
 #include "./rabit/utils.h"
+#include "./dmlc/io.h"
 namespace rabit {
-/*!
- * \brief interface of stream I/O, used by ISerializable
- * \sa ISerializable
- */
-class IStream {
- public:
-  /*!
-   * \brief reads data from a stream
-   * \param ptr pointer to a memory buffer
-   * \param size block size
-   * \return the size of data read
-   */
-  virtual size_t Read(void *ptr, size_t size) = 0;
-  /*!
-   * \brief writes data to a stream
-   * \param ptr pointer to a memory buffer
-   * \param size block size
-   */
-  virtual void Write(const void *ptr, size_t size) = 0;
-  /*! \brief virtual destructor */
-  virtual ~IStream(void) {}
- public:
-  // helper functions to write/read different data structures
-  /*!
-   * \brief writes a vector
-   * \param vec vector to be written/serialized
-   */
-  template<typename T>
-  inline void Write(const std::vector<T> &vec) {
-    uint64_t sz = static_cast<uint64_t>(vec.size());
-    this->Write(&sz, sizeof(sz));
-    if (sz != 0) {
-      this->Write(&vec[0], sizeof(T) * sz);
-    }
-  }
-  /*!
-   * \brief loads a vector
-   * \param out_vec vector to be loaded/deserialized
-   * \return whether the load was successful
-   */
-  template<typename T>
-  inline bool Read(std::vector<T> *out_vec) {
-    uint64_t sz;
-    if (this->Read(&sz, sizeof(sz)) == 0) return false;
-    out_vec->resize(sz);
-    if (sz != 0) {
-      if (this->Read(&(*out_vec)[0], sizeof(T) * sz) == 0) return false;
-    }
-    return true;
-  }
-  /*!
-   * \brief writes a string
-   * \param str the string to be written/serialized
-   */
-  inline void Write(const std::string &str) {
-    uint64_t sz = static_cast<uint64_t>(str.length());
-    this->Write(&sz, sizeof(sz));
-    if (sz != 0) {
-      this->Write(&str[0], sizeof(char) * sz);
-    }
-  }
-  /*!
-   * \brief loads a string
-   * \param out_str string to be loaded/deserialized
-   * \return whether the load/deserialization was successful
-   */
-  inline bool Read(std::string *out_str) {
-    uint64_t sz;
-    if (this->Read(&sz, sizeof(sz)) == 0) return false;
-    out_str->resize(sz);
-    if (sz != 0) {
-      if (this->Read(&(*out_str)[0], sizeof(char) * sz) == 0) return false;
-    }
-    return true;
-  }
-};
-/*! \brief interface for serializable objects */
-class ISerializable {
- public:
-  /*!
-   * \brief load the model from a stream
-   * \param fi stream where to load the model from
-   */
-  virtual void Load(IStream &fi) = 0;
-  /*!
-   * \brief saves the model to a stream
-   * \param fo stream where to save the model to
-   */
-  virtual void Save(IStream &fo) const = 0;
-};
+/*!
+ * \brief defines stream used in rabit
+ * see definition of IStream in dmlc/io.h
+ */
+typedef dmlc::IStream IStream;
+/*!
+ * \brief defines serializable objects used in rabit
+ * see definition of ISerializable in dmlc/io.h
+ */
+typedef dmlc::ISerializable ISerializable;
 } // namespace rabit
 #endif  // RABIT_RABIT_SERIALIZABLE_H_

View File

@@ -9,10 +9,13 @@
 #include <cstring>
 #include "./io.h"
+#if RABIT_USE_WORMHOLE == 0
 #if RABIT_USE_HDFS
 #include "./hdfs-inl.h"
 #endif
 #include "./file-inl.h"
+#endif
 namespace rabit {
 namespace io {
@@ -25,6 +28,9 @@ namespace io {
 inline InputSplit *CreateInputSplit(const char *uri,
                                     unsigned part,
                                     unsigned nsplit) {
+#if RABIT_USE_WORMHOLE
+  return dmlc::InputSplit::Create(uri, part, nsplit);
+#else
   using namespace std;
   if (!strcmp(uri, "stdin")) {
     return new SingleFileSplit(uri);
@@ -40,7 +46,28 @@ inline InputSplit *CreateInputSplit(const char *uri,
 #endif
   }
   return new LineSplitter(new FileProvider(uri), part, nsplit);
+#endif
 }
+template<typename TStream>
+class StreamAdapter : public IStream {
+ public:
+  explicit StreamAdapter(TStream *stream)
+      : stream_(stream) {
+  }
+  virtual ~StreamAdapter(void) {
+    delete stream_;
+  }
+  virtual size_t Read(void *ptr, size_t size) {
+    return stream_->Read(ptr, size);
+  }
+  virtual void Write(const void *ptr, size_t size) {
+    stream_->Write(ptr, size);
+  }
+ private:
+  TStream *stream_;
+};
 /*!
  * \brief create an stream, the stream must be able to close
  * the underlying resources(files) when deleted
@@ -49,6 +76,9 @@ inline InputSplit *CreateInputSplit(const char *uri,
  * \param mode can be 'w' or 'r' for read or write
  */
 inline IStream *CreateStream(const char *uri, const char *mode) {
+#if RABIT_USE_WORMHOLE
+  return new StreamAdapter<dmlc::IStream>(dmlc::IStream::Create(uri, mode));
+#else
   using namespace std;
   if (!strncmp(uri, "file://", 7)) {
     return new FileStream(uri + 7, mode);
@@ -62,6 +92,7 @@ inline IStream *CreateStream(const char *uri, const char *mode) {
 #endif
   }
   return new FileStream(uri, mode);
+#endif
 }
 } // namespace io
 } // namespace rabit

View File

@@ -13,6 +13,13 @@
 #define RABIT_USE_HDFS 0
 #endif
+#ifndef RABIT_USE_WORMHOLE
+#define RABIT_USE_WORMHOLE 0
+#endif
+#if RABIT_USE_WORMHOLE
+#include <dmlc/io.h>
+#endif
 /*! \brief io interface */
 namespace rabit {
 /*!
@@ -20,6 +27,10 @@ namespace rabit {
  */
 namespace io {
 /*! \brief reused ISeekStream's definition */
+#if RABIT_USE_WORMHOLE
+typedef dmlc::ISeekStream ISeekStream;
+typedef dmlc::InputSplit InputSplit;
+#else
 typedef utils::ISeekStream ISeekStream;
 /*!
  * \brief user facing input split helper,
@@ -33,10 +44,11 @@ class InputSplit {
    * \n is not included
    * \return true of next line was found, false if we read all the lines
    */
-  virtual bool NextLine(std::string *out_data) = 0;
+  virtual bool ReadLine(std::string *out_data) = 0;
   /*! \brief destructor*/
   virtual ~InputSplit(void) {}
 };
+#endif
 /*!
  * \brief create input split given a uri
  * \param uri the uri of the input, can contain hdfs prefix

View File

@@ -51,7 +51,7 @@ class LineSplitter : public InputSplit {
     delete provider_;
   }
   // get next line
-  virtual bool NextLine(std::string *out_data) {
+  virtual bool ReadLine(std::string *out_data) {
     if (file_ptr_ >= file_ptr_end_ &&
         offset_curr_ >= offset_end_) return false;
     out_data->clear();
@@ -178,7 +178,7 @@ class SingleFileSplit : public InputSplit {
   virtual ~SingleFileSplit(void) {
     if (!use_stdin_) std::fclose(fp_);
   }
-  virtual bool NextLine(std::string *out_data) {
+  virtual bool ReadLine(std::string *out_data) {
     if (end_of_file_) return false;
     out_data->clear();
     while (true) {

View File

@@ -12,7 +12,7 @@ hadoop fs -mkdir $2/data
 hadoop fs -put ../data/agaricus.txt.train $2/data
 # submit to hadoop
-../../tracker/rabit_yarn.py -n $1 --vcores 1 ./linear.rabit hdfs://$2/data/agaricus.txt.train model_out=hdfs://$2/mushroom.linear.model "${*:3}"
+../../wormhole/tracker/dmlc_yarn.py -n $1 --vcores 1 ./linear.rabit hdfs://$2/data/agaricus.txt.train model_out=hdfs://$2/mushroom.linear.model "${*:3}"
 # get the final model file
 hadoop fs -get $2/mushroom.linear.model ./linear.model

View File

@@ -3,6 +3,15 @@
 export LDFLAGS= -L../../lib -pthread -lm -lrt
 export CFLAGS = -Wall -msse2 -Wno-unknown-pragmas -fPIC -I../../include
+# setup dmlc
+ifeq ($(USE_DMLC),1)
+	include ../../dmlc-core/make/dmlc.mk
+	CFLAGS+= -DRABIT_USE_DMLC=1 -I ../../dmlc-core/include $(DMLC_CFLAGS)
+	LDFLAGS+= -L../../dmlc-core -ldmlc $(DMLC_LDFLAGS)
+else
+	CFLAGS+= -DRABIT_USE_DMLC=0
+endif
 # setup opencv
 ifeq ($(USE_HDFS),1)
 	CFLAGS+= -DRABIT_USE_HDFS=1 -I$(HADOOP_HDFS_HOME)/include -I$(JAVA_HOME)/include
@@ -11,6 +20,7 @@ else
 	CFLAGS+= -DRABIT_USE_HDFS=0
 endif
 .PHONY: clean all lib mpi
 all: $(BIN) $(MOCKBIN)

View File

@@ -17,5 +17,8 @@ export MPICXX = mpicxx
 # whether use HDFS support during compile
 USE_HDFS = 1
+# whether to use dmlc's io utils
+USE_DMLC = 0
 # path to libjvm.so
 LIBJVM=$(JAVA_HOME)/jre/lib/amd64/server

View File

@@ -56,7 +56,7 @@ struct SparseMat {
     data.clear();
     feat_dim = 0;
     std::string line;
-    while (in->NextLine(&line)) {
+    while (in->ReadLine(&line)) {
       float label;
       std::istringstream ss(line);
       ss >> label;
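
Only the stream call changes here; the loop still consumes LIBSVM-formatted lines ("label index:value index:value ..."). For readers unfamiliar with the format, a rough Python equivalent of one parsing step (a sketch only; the real loader also tracks feat_dim and fills a CSR-style buffer):

    def parse_libsvm_line(line):
        """Parse one line such as '1 3:0.5 10:1.2' into (label, [(index, value), ...])."""
        tokens = line.split()
        label = float(tokens[0])
        features = []
        for tok in tokens[1:]:
            index, value = tok.split(':')
            features.append((int(index), float(value)))
        return label, features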

View File

@@ -31,6 +31,7 @@ AllreduceBase::AllreduceBase(void) {
   // tracker URL
   task_id = "NULL";
   err_link = NULL;
+  dmlc_role = "worker";
   this->SetParam("rabit_reduce_buffer", "256MB");
   // setup possible environment variables of interest
   env_vars.push_back("rabit_task_id");
@@ -39,6 +40,12 @@ AllreduceBase::AllreduceBase(void) {
   env_vars.push_back("rabit_reduce_ring_mincount");
   env_vars.push_back("rabit_tracker_uri");
   env_vars.push_back("rabit_tracker_port");
+  // also include the dmlc support variables
+  env_vars.push_back("DMLC_TASK_ID");
+  env_vars.push_back("DMLC_ROLE");
+  env_vars.push_back("DMLC_NUM_ATTEMPT");
+  env_vars.push_back("DMLC_TRACKER_URI");
+  env_vars.push_back("DMLC_TRACKER_PORT");
 }
 // initialization function
@@ -86,6 +93,10 @@ void AllreduceBase::Init(void) {
       this->SetParam("rabit_world_size", num_task);
     }
   }
+  if (dmlc_role != "worker") {
+    fprintf(stderr, "Rabit Module currently only work with dmlc worker, quit this program by exit 0\n");
+    exit(0);
+  }
   // clear the setting before start reconnection
   this->rank = -1;
   //---------------------
@@ -150,6 +161,10 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
   if (!strcmp(name, "rabit_tracker_uri")) tracker_uri = val;
   if (!strcmp(name, "rabit_tracker_port")) tracker_port = atoi(val);
   if (!strcmp(name, "rabit_task_id")) task_id = val;
+  if (!strcmp(name, "DMLC_TRACKER_URI")) tracker_uri = val;
+  if (!strcmp(name, "DMLC_TRACKER_PORT")) tracker_port = atoi(val);
+  if (!strcmp(name, "DMLC_TASK_ID")) task_id = val;
+  if (!strcmp(name, "DMLC_ROLE")) dmlc_role = val;
   if (!strcmp(name, "rabit_world_size")) world_size = atoi(val);
   if (!strcmp(name, "rabit_hadoop_mode")) hadoop_mode = atoi(val);
   if (!strcmp(name, "rabit_reduce_ring_mincount")) {
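
Taken together, these three hunks let a launcher configure rabit purely through DMLC_* environment variables, and make non-worker roles (e.g. a server in the same dmlc job) exit cleanly before they try to reach the tracker. A hypothetical launcher-side sketch of the variables a worker now picks up (all values are illustrative; in practice wormhole's dmlc_yarn.py sets them):

    import os
    import subprocess

    env = dict(os.environ,
               DMLC_ROLE='worker',            # any other role makes Init() exit(0)
               DMLC_TASK_ID='0',              # maps onto rabit_task_id
               DMLC_NUM_ATTEMPT='0',          # retry counter, read by AllreduceMock
               DMLC_TRACKER_URI='192.0.2.1',  # maps onto rabit_tracker_uri
               DMLC_TRACKER_PORT='9091')      # maps onto rabit_tracker_port
    subprocess.call(['./xgboost', 'mushroom.conf'], env=env)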

View File

@@ -496,6 +496,8 @@ class AllreduceBase : public IEngine {
   std::string host_uri;
   // uri of tracker
   std::string tracker_uri;
+  // role in dmlc jobs
+  std::string dmlc_role;
   // port of tracker address
   int tracker_port;
   // port of slave process

View File

@@ -31,6 +31,7 @@ class AllreduceMock : public AllreduceRobust {
     AllreduceRobust::SetParam(name, val);
     // additional parameters
     if (!strcmp(name, "rabit_num_trial")) num_trial = atoi(val);
+    if (!strcmp(name, "DMLC_NUM_ATTEMPT")) num_trial = atoi(val);
     if (!strcmp(name, "report_stats")) report_stats = atoi(val);
     if (!strcmp(name, "force_local")) force_local = atoi(val);
     if (!strcmp(name, "mock")) {

View File

@@ -87,7 +87,7 @@ def get_world_size():
     """
     Returns the total number of processes
     """
-    ret = rbtlib.RabitGetWorlSize()
+    ret = rbtlib.RabitGetWorldSize()
     check_err__()
     return ret
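
Worth noting why the original misspelling survived until now: ctypes resolves a symbol only when the attribute is first accessed, not when the shared library is loaded, so RabitGetWorlSize failed at call time rather than import time. A minimal reproduction (library path is illustrative):

    import ctypes

    rbtlib = ctypes.cdll.LoadLibrary('./librabit_wrapper.so')
    try:
        rbtlib.RabitGetWorlSize        # misspelled: lookup raises AttributeError here
    except AttributeError:
        pass                           # nothing complained at load time
    world_size = rbtlib.RabitGetWorldSize()  # the symbol the library actually exports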

View File

@@ -20,6 +20,7 @@
   </ItemGroup>
   <ItemGroup>
     <ClCompile Include="..\..\src\gbm\gbm.cpp" />
+    <ClCompile Include="..\..\src\io\dmlc_simple.cpp" />
     <ClCompile Include="..\..\src\io\io.cpp" />
     <ClCompile Include="..\..\src\tree\updater.cpp" />
     <ClCompile Include="..\..\src\xgboost_main.cpp" />

View File

@@ -20,6 +20,7 @@
   </ItemGroup>
   <ItemGroup>
     <ClCompile Include="..\..\src\gbm\gbm.cpp" />
+    <ClCompile Include="..\..\src\io\dmlc_simple.cpp" />
     <ClCompile Include="..\..\src\io\io.cpp" />
     <ClCompile Include="..\..\src\tree\updater.cpp" />
     <ClCompile Include="..\..\subtree\rabit\src\engine_empty.cc" />

View File

@@ -26,7 +26,6 @@ except ImportError:
     SKLEARN_INSTALLED = False
 __all__ = ['DMatrix', 'CVPack', 'Booster', 'aggcv', 'cv', 'mknfold', 'train']
 if sys.version_info[0] == 3:
@@ -632,7 +631,6 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None, ea
     return bst
 class CVPack(object):
     def __init__(self, dtrain, dtest, param):
         self.dtrain = dtrain
@@ -765,34 +763,40 @@ class XGBModel(BaseEstimator):
         if not SKLEARN_INSTALLED:
             raise Exception('sklearn needs to be installed in order to use this module')
         self.max_depth = max_depth
-        self.eta = learning_rate
-        self.silent = 1 if silent else 0
-        self.n_rounds = n_estimators
+        self.learning_rate = learning_rate
+        self.silent = silent
+        self.n_estimators = n_estimators
         self.objective = objective
         self._Booster = Booster()
     def get_params(self, deep=True):
         return {'max_depth': self.max_depth,
-                'learning_rate': self.eta,
-                'n_estimators': self.n_rounds,
-                'silent': True if self.silent == 1 else False,
+                'learning_rate': self.learning_rate,
+                'n_estimators': self.n_estimators,
+                'silent': self.silent,
                 'objective': self.objective
                 }
     def get_xgb_params(self):
-        return {'eta': self.eta, 'max_depth': self.max_depth, 'silent': self.silent, 'objective': self.objective}
+        return {'eta': self.learning_rate,
+                'max_depth': self.max_depth,
+                'silent': 1 if self.silent else 0,
+                'objective': self.objective
+                }
     def fit(self, X, y):
         trainDmatrix = DMatrix(X, label=y)
-        self._Booster = train(self.get_xgb_params(), trainDmatrix, self.n_rounds)
+        self._Booster = train(self.get_xgb_params(), trainDmatrix, self.n_estimators)
         return self
     def predict(self, X):
         testDmatrix = DMatrix(X)
         return self._Booster.predict(testDmatrix)
 class XGBClassifier(XGBModel, ClassifierMixin):
-    def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100, silent=True):
-        super().__init__(max_depth, learning_rate, n_estimators, silent, objective="binary:logistic")
+    def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100, silent=True, objective="binary:logistic"):
+        super(XGBClassifier, self).__init__(max_depth, learning_rate, n_estimators, silent, objective)
     def fit(self, X, y, sample_weight=None):
         y_values = list(np.unique(y))
@@ -812,7 +816,7 @@ class XGBClassifier(XGBModel, ClassifierMixin):
         else:
             trainDmatrix = DMatrix(X, label=training_labels)
-        self._Booster = train(xgb_options, trainDmatrix, self.n_rounds)
+        self._Booster = train(xgb_options, trainDmatrix, self.n_estimators)
         return self
@@ -834,9 +838,8 @@ class XGBClassifier(XGBModel, ClassifierMixin):
         else:
             classone_probs = class_probs
         classzero_probs = 1.0 - classone_probs
-        return np.vstack((classzero_probs,classone_probs)).transpose()
+        return np.vstack((classzero_probs, classone_probs)).transpose()
 class XGBRegressor(XGBModel, RegressorMixin):
     pass
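
The scikit-learn refactor above is driven by sklearn's estimator contract: clone() and the model-selection tools rebuild an estimator from get_params(), so the keys returned there must match the __init__ keyword names (learning_rate, n_estimators), while the xgboost-native spellings (eta, silent as 0/1) are produced only in get_xgb_params(). A quick check of what that enables (hypothetical data; assumes the wrapper imports as xgboost):

    import numpy as np
    from sklearn.base import clone
    from xgboost import XGBClassifier

    X = np.random.rand(100, 5)
    y = (X[:, 0] > 0.5).astype(int)

    clf = XGBClassifier(max_depth=4, learning_rate=0.05, n_estimators=20)
    clf2 = clone(clf)              # works because get_params() keys now
                                   # mirror the constructor signature
    clf2.fit(X, y)
    proba = clf2.predict_proba(X)  # shape (100, 2): P(class 0), P(class 1)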