commit 6ed82edad7
@@ -25,9 +25,13 @@ setClass("xgb.Booster")
 #' @export
 #'
 setMethod("predict", signature = "xgb.Booster",
-          definition = function(object, newdata, outputmargin = FALSE, ntreelimit = NULL) {
+          definition = function(object, newdata, missing = NULL, outputmargin = FALSE, ntreelimit = NULL) {
             if (class(newdata) != "xgb.DMatrix") {
-              newdata <- xgb.DMatrix(newdata)
+              if (is.null(missing)) {
+                newdata <- xgb.DMatrix(newdata)
+              } else {
+                newdata <- xgb.DMatrix(newdata, missing = missing)
+              }
             }
             if (is.null(ntreelimit)) {
               ntreelimit <- 0
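For context, a minimal usage sketch of the new `missing` argument to `predict` (the names `bst` and `newx` are illustrative, not part of the commit):

    # assume bst is a trained xgb.Booster and newx is a dense numeric matrix
    # whose absent values are encoded with the sentinel -999
    pred <- predict(bst, newx, missing = -999)

When `missing` is left NULL the conversion behaves exactly as before, so existing callers are unaffected.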
@@ -68,13 +68,17 @@ xgb.Booster <- function(params = list(), cachelist = list(), modelfile = NULL) {
 ## ----the following are low level iteratively function, not needed if
 ## you do not want to use them ---------------------------------------
 # get dmatrix from data, label
-xgb.get.DMatrix <- function(data, label = NULL) {
+xgb.get.DMatrix <- function(data, label = NULL, missing = NULL) {
   inClass <- class(data)
   if (inClass == "dgCMatrix" || inClass == "matrix") {
     if (is.null(label)) {
       stop("xgboost: need label when data is a matrix")
     }
-    dtrain <- xgb.DMatrix(data, label = label)
+    if (is.null(missing)) {
+      dtrain <- xgb.DMatrix(data, label = label)
+    } else {
+      dtrain <- xgb.DMatrix(data, label = label, missing = missing)
+    }
   } else {
     if (!is.null(label)) {
       warning("xgboost: label will be ignored.")
@@ -53,7 +53,7 @@
 #' "max.depth"=3, "eta"=1, "objective"="binary:logistic")
 #' @export
 #'
-xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL,
+xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing = NULL,
                    showsd = TRUE, metrics=list(), obj = NULL, feval = NULL, ...) {
   if (typeof(params) != "list") {
     stop("xgb.cv: first argument params must be list")
@@ -61,7 +61,11 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL,
   if (nfold <= 1) {
     stop("nfold must be bigger than 1")
   }
-  dtrain <- xgb.get.DMatrix(data, label)
+  if (is.null(missing)) {
+    dtrain <- xgb.get.DMatrix(data, label)
+  } else {
+    dtrain <- xgb.get.DMatrix(data, label, missing)
+  }
   params <- append(params, list(...))
   params <- append(params, list(silent=1))
   for (mc in metrics) {
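A hedged sketch of calling `xgb.cv` with the new argument (the data names `x` and `y` are illustrative):

    # x: feature matrix with NA marking absent entries; y: 0/1 labels
    res <- xgb.cv(params = list("objective" = "binary:logistic"),
                  data = x, nrounds = 10, nfold = 5,
                  label = y, missing = NA)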
@@ -43,9 +43,14 @@
 #'
 #' @export
 #'
-xgboost <- function(data = NULL, label = NULL, params = list(), nrounds,
+xgboost <- function(data = NULL, label = NULL, missing = NULL, params = list(), nrounds,
                     verbose = 1, ...) {
-  dtrain <- xgb.get.DMatrix(data, label)
+  if (is.null(missing)) {
+    dtrain <- xgb.get.DMatrix(data, label)
+  } else {
+    dtrain <- xgb.get.DMatrix(data, label, missing)
+  }
+
   params <- append(params, list(...))
 
   if (verbose > 0) {
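Correspondingly, a minimal sketch of the top-level `xgboost` entry point with `missing` (names illustrative; extra parameters such as `objective` travel through `...` into `params`, per the `append` call above):

    bst <- xgboost(data = x, label = y, missing = -999,
                   nrounds = 10, objective = "binary:logistic")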
@@ -37,3 +37,26 @@ print ('start training with user customized objective')
 # training with customized objective, we can also do step by step training
 # simply look at xgboost.py's implementation of train
 bst <- xgb.train(param, dtrain, num_round, watchlist, logregobj, evalerror)
+
+#
+# there can be cases where you want additional information
+# being considered besides the property of DMatrix you can get by getinfo
+# you can set additional information as attributes of DMatrix
+
+# set label attribute of dtrain to be label, we use label as an example, it can be anything
+attr(dtrain, 'label') <- getinfo(dtrain, 'label')
+# this is a new customized objective, where you can access things you set
+# same thing applies to customized evaluation function
+logregobjattr <- function(preds, dtrain) {
+  # now you can access the attribute in customized function
+  labels <- attr(dtrain, 'label')
+  preds <- 1/(1 + exp(-preds))
+  grad <- preds - labels
+  hess <- preds * (1 - preds)
+  return(list(grad = grad, hess = hess))
+}
+
+print('start training with user customized objective, with additional attributes in DMatrix')
+# training with customized objective, we can also do step by step training
+# simply look at xgboost.py's implementation of train
+bst <- xgb.train(param, dtrain, num_round, watchlist, logregobjattr, evalerror)
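The calls above assume `param`, `dtrain`, `num_round`, `watchlist`, `logregobj`, and `evalerror` are defined earlier in the demo script, outside this hunk. For reference, a minimal sketch of a customized evaluation function in the same style (a plausible reconstruction, not taken from this diff):

    # customized evaluation: return a list(metric = name, value = result)
    evalerror <- function(preds, dtrain) {
      labels <- getinfo(dtrain, 'label')
      err <- sum(as.numeric(preds > 0) != labels) / length(labels)
      return(list(metric = 'error', value = err))
    }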
@@ -12,8 +12,15 @@ Examples Code: [Learning to use xgboost by examples](demo)
 
 Notes on the Code: [Code Guide](src)
 
 Learning about the model: [Introduction to Boosted Trees](http://homes.cs.washington.edu/~tqchen/pdf/BoostedTree.pdf)
 * This slide is made by Tianqi Chen to introduce gradient boosting in a statistical view.
 * It present boosted tree learning as formal functional space optimization of defined objective.
 * The model presented is used by xgboost for boosted trees
 
+What's New
+=====
+
+* Thanks to Bing Xu, [XGBoost.jl](https://github.com/antinucleon/XGBoost.jl) allows you to use xgboost from Julia
+* See the updated [demo folder](demo) for feature walkthrough
+* Thanks to Tong He, the new [R package](R-package) is available
+
@@ -26,7 +33,6 @@ Features
 * Speed: XGBoost is very fast
   - IN [demo/higgs/speedtest.py](demo/kaggle-higgs/speedtest.py), kaggle higgs data it is faster(on our machine 20 times faster using 4 threads) than sklearn.ensemble.GradientBoostingClassifier
 * Layout of gradient boosting algorithm to support user defined objective
-* Python interface, works with numpy and scipy.sparse matrix
 
 Build
 =====
@@ -8,12 +8,30 @@ This folder contains the all example codes using xgboost.
 Features Walkthrough
 ====
 This is a list of short codes introducing different functionalities of xgboost and its wrapper.
-* Basic walkthrough of wrappers [python](guide-python/basic_walkthrough.py)
-* Cutomize loss function, and evaluation metric [python](guide-python/custom_objective.py)
-* Boosting from existing prediction [python](guide-python/boost_from_prediction.py)
-* Predicting using first n trees [python](guide-python/predict_first_ntree.py)
-* Generalized Linear Model [python](guide-python/generalized_linear_model.py)
-* Cross validation [python](guide-python/cross_validation.py)
+* Basic walkthrough of wrappers
+  [python](guide-python/basic_walkthrough.py)
+  [R](../R-package/demo/basic_walkthrough.R)
+  [Julia](https://github.com/antinucleon/XGBoost.jl/blob/master/demo/basic_walkthrough.jl)
+* Customize loss function, and evaluation metric
+  [python](guide-python/custom_objective.py)
+  [R](../R-package/demo/custom_objective.R)
+  [Julia](https://github.com/antinucleon/XGBoost.jl/blob/master/demo/custom_objective.jl)
+* Boosting from existing prediction
+  [python](guide-python/boost_from_prediction.py)
+  [R](../R-package/demo/boost_from_prediction.R)
+  [Julia](https://github.com/antinucleon/XGBoost.jl/blob/master/demo/boost_from_prediction.jl)
+* Predicting using first n trees
+  [python](guide-python/predict_first_ntree.py)
+  [R](../R-package/demo/boost_from_prediction.R)
+  [Julia](https://github.com/antinucleon/XGBoost.jl/blob/master/demo/boost_from_prediction.jl)
+* Generalized Linear Model
+  [python](guide-python/generalized_linear_model.py)
+  [R](../R-package/demo/generalized_linear_model.R)
+  [Julia](https://github.com/antinucleon/XGBoost.jl/blob/master/demo/generalized_linear_model.jl)
+* Cross validation
+  [python](guide-python/cross_validation.py)
+  [R](../R-package/demo/cross_validation.R)
+  [Julia](https://github.com/antinucleon/XGBoost.jl/blob/master/demo/cross_validation.jl)
 
 Basic Examples by Tasks
 ====
@@ -7,6 +7,10 @@ Python
 * To make the python module, type ```make``` in the root directory of project
 * Refer also to the walk through example in [demo folder](../demo/guide-python)
 
 R
 =====
 * See [R-package](../R-package)
+
+Julia
+=====
+* See [XGBoost.jl](https://github.com/antinucleon/XGBoost.jl)
@@ -436,7 +436,11 @@ def train(params, dtrain, num_boost_round = 10, evals = [], obj=None, feval=None):
     for i in range(num_boost_round):
         bst.update( dtrain, i, obj )
         if len(evals) != 0:
-            sys.stderr.write(bst.eval_set(evals, i, feval).decode()+'\n')
+            bst_eval_set = bst.eval_set(evals, i, feval)
+            if isinstance(bst_eval_set, str):
+                sys.stderr.write(bst_eval_set + '\n')
+            else:
+                sys.stderr.write(bst_eval_set.decode() + '\n')
     return bst
 
 class CVPack:
@@ -467,7 +471,7 @@ def mknfold(dall, nfold, param, seed, evals=[], fpreproc = None):
             dtrain, dtest, tparam = fpreproc(dtrain, dtest, param.copy())
         else:
             tparam = param
-        plst = tparam.items() + [('eval_metric', itm) for itm in evals]
+        plst = list(tparam.items()) + [('eval_metric', itm) for itm in evals]
         ret.append(CVPack(dtrain, dtest, plst))
     return ret
 
@@ -481,12 +485,16 @@ def aggcv(rlist, show_stdv=True):
         arr = line.split()
         assert ret == arr[0]
         for it in arr[1:]:
+            if not isinstance(it, str):
+                it = it.decode()
             k, v = it.split(':')
             if k not in cvmap:
                 cvmap[k] = []
             cvmap[k].append(float(v))
     for k, v in sorted(cvmap.items(), key = lambda x:x[0]):
         v = np.array(v)
+        if not isinstance(ret, str):
+            ret = ret.decode()
         if show_stdv:
             ret += '\tcv-%s:%f+%f' % (k, np.mean(v), np.std(v))
         else:
@@ -132,6 +132,7 @@ extern "C"{
                                        bst_ulong nrow,
                                        bst_ulong ncol,
                                        float missing) {
+    bool nan_missing = std::isnan(missing);
     DMatrixSimple *p_mat = new DMatrixSimple();
     DMatrixSimple &mat = *p_mat;
     mat.info.info.num_row = nrow;
@@ -139,9 +140,13 @@ extern "C"{
     for (bst_ulong i = 0; i < nrow; ++i, data += ncol) {
       bst_ulong nelem = 0;
       for (bst_ulong j = 0; j < ncol; ++j) {
-        if (data[j] != missing) {
-          mat.row_data_.push_back(RowBatch::Entry(j, data[j]));
-          ++nelem;
+        if (std::isnan(data[j])) {
+          utils::Check(nan_missing, "There are NAN in the matrix, however, you did not set missing=NAN");
+        } else {
+          if (nan_missing || data[j] != missing) {
+            mat.row_data_.push_back(RowBatch::Entry(j, data[j]));
+            ++nelem;
+          }
         }
       }
       mat.row_ptr_.push_back(mat.row_ptr_.back() + nelem);
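At the R level this means a dense matrix containing NA/NaN should be converted with `missing = NA` (R's numeric NA reaches the C side as a NaN), otherwise the `utils::Check` above is expected to fail. A hedged sketch:

    # NA entries require missing = NA so nan_missing is true in the C++ loop
    x <- matrix(c(1, 2, NA, 4), nrow = 2)
    dtest <- xgb.DMatrix(x, missing = NA)  # NA cells stored as absent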
@@ -6,7 +6,6 @@
  * \brief a C style wrapper of xgboost
  * can be used to create wrapper of other languages
  */
-#include <cstdio>
 #ifdef _MSC_VER
 #define XGB_DLL __declspec(dllexport)
 #else
@@ -15,8 +14,9 @@
 // manually define unsign long
 typedef unsigned long bst_ulong;
 
-
+#ifdef __cplusplus
 extern "C" {
+#endif
 /*!
  * \brief load a data matrix
  * \return a loaded data matrix
@@ -205,5 +205,7 @@ extern "C" {
  */
 XGB_DLL const char **XGBoosterDumpModel(void *handle, const char *fmap,
                                         bst_ulong *out_len);
+#ifdef __cplusplus
 }
+#endif
 #endif // XGBOOST_WRAPPER_H_