Merge pull request #5 from dmlc/master

update from dmlc/xgboost
This commit is contained in:
yanqingmen 2015-10-07 13:55:57 +08:00
commit 3453b6e715
40 changed files with 1057 additions and 234 deletions

View File

@ -30,6 +30,8 @@ addons:
- wget - wget
- libcurl4-openssl-dev - libcurl4-openssl-dev
- unzip - unzip
- python-numpy
- python-scipy
before_install: before_install:
- scripts/travis_osx_install.sh - scripts/travis_osx_install.sh

View File

@ -34,6 +34,7 @@ List of Contributors
* [Zygmunt Zając](https://github.com/zygmuntz) * [Zygmunt Zając](https://github.com/zygmuntz)
- Zygmunt is the master behind the early stopping feature frequently used by kagglers. - Zygmunt is the master behind the early stopping feature frequently used by kagglers.
* [Ajinkya Kale](https://github.com/ajkl) * [Ajinkya Kale](https://github.com/ajkl)
* [Yuan Tang](https://github.com/terrytangyuan)
* [Boliang Chen](https://github.com/cblsjtu) * [Boliang Chen](https://github.com/cblsjtu)
* [Vadim Khotilovich](https://github.com/khotilov) * [Vadim Khotilovich](https://github.com/khotilov)
* [Yangqing Men](https://github.com/yanqingmen) * [Yangqing Men](https://github.com/yanqingmen)
@ -48,3 +49,4 @@ List of Contributors
- Masaaki is the initial creator of xgboost python plotting module. - Masaaki is the initial creator of xgboost python plotting module.
* [Hongliang Liu](https://github.com/phunterlau) * [Hongliang Liu](https://github.com/phunterlau)
- Hongliang is the maintainer of xgboost python PyPI package for pip installation. - Hongliang is the maintainer of xgboost python PyPI package for pip installation.
* [Huayi Zhang](https://github.com/irachex)

View File

@ -1,6 +1,6 @@
export CC = gcc export CC = $(if $(shell which gcc-5),gcc-5,gcc)
export CXX = g++ export CXX = $(if $(shell which g++-5),g++-5,g++)
#build on the fly
export MPICXX = mpicxx export MPICXX = mpicxx
export LDFLAGS= -pthread -lm export LDFLAGS= -pthread -lm
export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -funroll-loops export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -funroll-loops
@ -21,9 +21,17 @@ endif
ifeq ($(no_omp),1) ifeq ($(no_omp),1)
CFLAGS += -DDISABLE_OPENMP CFLAGS += -DDISABLE_OPENMP
else else
CFLAGS += -fopenmp #CFLAGS += -fopenmp
ifeq ($(omp_mac_static),1)
#CFLAGS += -fopenmp -Bstatic
CFLAGS += -static-libgcc -static-libstdc++ -L. -fopenmp
#LDFLAGS += -Wl,--whole-archive -lpthread -Wl --no-whole-archive
else
CFLAGS += -fopenmp
endif
endif endif
# by default use c++11 # by default use c++11
ifeq ($(cxx11),1) ifeq ($(cxx11),1)
CFLAGS += -std=c++11 CFLAGS += -std=c++11

View File

@ -23,7 +23,8 @@ Suggests:
ggplot2 (>= 1.0.0), ggplot2 (>= 1.0.0),
DiagrammeR (>= 0.6), DiagrammeR (>= 0.6),
Ckmeans.1d.dp (>= 3.3.1), Ckmeans.1d.dp (>= 3.3.1),
vcd (>= 1.3) vcd (>= 1.3),
testthat
Depends: Depends:
R (>= 2.10) R (>= 2.10)
Imports: Imports:

View File

@ -103,17 +103,21 @@ xgb.Booster.check <- function(bst, saveraw = TRUE)
## ----the following are low level iteratively function, not needed if ## ----the following are low level iteratively function, not needed if
## you do not want to use them --------------------------------------- ## you do not want to use them ---------------------------------------
# get dmatrix from data, label # get dmatrix from data, label
xgb.get.DMatrix <- function(data, label = NULL, missing = NULL) { xgb.get.DMatrix <- function(data, label = NULL, missing = NULL, weight = NULL) {
inClass <- class(data) inClass <- class(data)
if (inClass == "dgCMatrix" || inClass == "matrix") { if (inClass == "dgCMatrix" || inClass == "matrix") {
if (is.null(label)) { if (is.null(label)) {
stop("xgboost: need label when data is a matrix") stop("xgboost: need label when data is a matrix")
} }
dtrain <- xgb.DMatrix(data, label = label)
if (is.null(missing)){ if (is.null(missing)){
dtrain <- xgb.DMatrix(data, label = label) dtrain <- xgb.DMatrix(data, label = label)
} else { } else {
dtrain <- xgb.DMatrix(data, label = label, missing = missing) dtrain <- xgb.DMatrix(data, label = label, missing = missing)
} }
if (!is.null(weight)){
xgb.setinfo(dtrain, "weight", weight)
}
} else { } else {
if (!is.null(label)) { if (!is.null(label)) {
warning("xgboost: label will be ignored.") warning("xgboost: label will be ignored.")
@ -122,6 +126,9 @@ xgb.get.DMatrix <- function(data, label = NULL, missing = NULL) {
dtrain <- xgb.DMatrix(data) dtrain <- xgb.DMatrix(data)
} else if (inClass == "xgb.DMatrix") { } else if (inClass == "xgb.DMatrix") {
dtrain <- data dtrain <- data
} else if (inClass == "data.frame") {
stop("xgboost only support numerical matrix input,
use 'data.frame' to transform the data.")
} else { } else {
stop("xgboost: Invalid input of data") stop("xgboost: Invalid input of data")
} }

View File

@ -72,6 +72,8 @@
#' keeps getting worse consecutively for \code{k} rounds. #' keeps getting worse consecutively for \code{k} rounds.
#' @param maximize If \code{feval} and \code{early.stop.round} are set, then \code{maximize} must be set as well. #' @param maximize If \code{feval} and \code{early.stop.round} are set, then \code{maximize} must be set as well.
#' \code{maximize=TRUE} means the larger the evaluation score the better. #' \code{maximize=TRUE} means the larger the evaluation score the better.
#' @param save_period save the model to disk after every \code{save_period} rounds; 0 means no periodic saving.
#' @param save_name the name or path of the periodically saved model file.
#' @param ... other parameters to pass to \code{params}. #' @param ... other parameters to pass to \code{params}.
#' #'
#' @details #' @details
@ -120,7 +122,8 @@
#' #'
xgb.train <- function(params=list(), data, nrounds, watchlist = list(), xgb.train <- function(params=list(), data, nrounds, watchlist = list(),
obj = NULL, feval = NULL, verbose = 1, print.every.n=1L, obj = NULL, feval = NULL, verbose = 1, print.every.n=1L,
early.stop.round = NULL, maximize = NULL, ...) { early.stop.round = NULL, maximize = NULL,
save_period = 0, save_name = "xgboost.model", ...) {
dtrain <- data dtrain <- data
if (typeof(params) != "list") { if (typeof(params) != "list") {
stop("xgb.train: first argument params must be list") stop("xgb.train: first argument params must be list")
@ -215,6 +218,11 @@ xgb.train <- function(params=list(), data, nrounds, watchlist = list(),
} }
} }
} }
if (save_period > 0) {
if (i %% save_period == 0) {
xgb.save(bst, save_name)
}
}
} }
bst <- xgb.Booster.check(bst) bst <- xgb.Booster.check(bst)
if (!is.null(early.stop.round)) { if (!is.null(early.stop.round)) {
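Taken together with the roxygen entries above, a minimal usage sketch of the new `save_period` / `save_name` arguments (the data set is the one shipped with the package; the checkpoint file name is an arbitrary choice for illustration):

```r
# Illustrative sketch only: checkpoint the booster to disk every 5 rounds.
require(xgboost)
data(agaricus.train, package = "xgboost")
dtrain <- xgb.DMatrix(agaricus.train$data, label = agaricus.train$label)
bst <- xgb.train(params = list(max.depth = 2, eta = 1, objective = "binary:logistic"),
                 data = dtrain, nrounds = 10,
                 save_period = 5, save_name = "xgboost_checkpoint.model")
```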

View File

@ -31,11 +31,14 @@
#' @param print.every.n Print every N progress messages when \code{verbose>0}. Default is 1 which means all messages are printed. #' @param print.every.n Print every N progress messages when \code{verbose>0}. Default is 1 which means all messages are printed.
#' @param missing Missing is only used when input is dense matrix, pick a float #' @param missing Missing is only used when input is dense matrix, pick a float
#' value that represents missing value. Sometimes a data use 0 or other extreme value to represents missing values. #' value that represents missing value. Sometimes a data use 0 or other extreme value to represents missing values.
#' @param weight a vector indicating the weight for each row of the input.
#' @param early.stop.round If \code{NULL}, the early stopping function is not triggered. #' @param early.stop.round If \code{NULL}, the early stopping function is not triggered.
#' If set to an integer \code{k}, training with a validation set will stop if the performance #' If set to an integer \code{k}, training with a validation set will stop if the performance
#' keeps getting worse consecutively for \code{k} rounds. #' keeps getting worse consecutively for \code{k} rounds.
#' @param maximize If \code{feval} and \code{early.stop.round} are set, then \code{maximize} must be set as well. #' @param maximize If \code{feval} and \code{early.stop.round} are set, then \code{maximize} must be set as well.
#' \code{maximize=TRUE} means the larger the evaluation score the better. #' \code{maximize=TRUE} means the larger the evaluation score the better.
#' @param save_period save the model to disk after every \code{save_period} rounds; 0 means no periodic saving.
#' @param save_name the name or path of the periodically saved model file.
#' @param ... other parameters to pass to \code{params}. #' @param ... other parameters to pass to \code{params}.
#' #'
#' @details #' @details
@ -56,14 +59,11 @@
#' #'
#' @export #' @export
#' #'
xgboost <- function(data = NULL, label = NULL, missing = NULL, params = list(), nrounds, xgboost <- function(data = NULL, label = NULL, missing = NULL, weight = NULL,
params = list(), nrounds,
verbose = 1, print.every.n = 1L, early.stop.round = NULL, verbose = 1, print.every.n = 1L, early.stop.round = NULL,
maximize = NULL, ...) { maximize = NULL, save_period = 0, save_name = "xgboost.model", ...) {
if (is.null(missing)) {
dtrain <- xgb.get.DMatrix(data, label)
} else {
dtrain <- xgb.get.DMatrix(data, label, missing)
}
dtrain <- xgb.get.DMatrix(data, label, missing, weight)
params <- append(params, list(...)) params <- append(params, list(...))
@ -74,7 +74,8 @@ xgboost <- function(data = NULL, label = NULL, missing = NULL, params = list(),
} }
bst <- xgb.train(params, dtrain, nrounds, watchlist, verbose = verbose, print.every.n=print.every.n, bst <- xgb.train(params, dtrain, nrounds, watchlist, verbose = verbose, print.every.n=print.every.n,
early.stop.round = early.stop.round) early.stop.round = early.stop.round, maximize = maximize,
save_period = save_period, save_name = save_name)
return(bst) return(bst)
} }
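To illustrate the new `weight` argument threading through `xgboost()` into `xgb.get.DMatrix()` above, here is a hedged sketch (the weights themselves are made up purely for demonstration):

```r
# Illustrative sketch only: give the second half of the training rows double weight.
require(xgboost)
data(agaricus.train, package = "xgboost")
train <- agaricus.train
w <- rep(1, nrow(train$data))
w[(length(w) %/% 2):length(w)] <- 2
bst <- xgboost(data = train$data, label = train$label, weight = w,
               max.depth = 2, eta = 1, nrounds = 2,
               objective = "binary:logistic")
```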

View File

@ -1,4 +1,5 @@
basic_walkthrough Basic feature walkthrough basic_walkthrough Basic feature walkthrough
caret_wrapper Use xgboost to train in caret library
custom_objective Cutomize loss function, and evaluation metric custom_objective Cutomize loss function, and evaluation metric
boost_from_prediction Boosting from existing prediction boost_from_prediction Boosting from existing prediction
predict_first_ntree Predicting using first n trees predict_first_ntree Predicting using first n trees

View File

@ -1,6 +1,7 @@
XGBoost R Feature Walkthrough XGBoost R Feature Walkthrough
==== ====
* [Basic walkthrough of wrappers](basic_walkthrough.R) * [Basic walkthrough of wrappers](basic_walkthrough.R)
* [Train a xgboost model from caret library](caret_wrapper.R)
* [Cutomize loss function, and evaluation metric](custom_objective.R) * [Cutomize loss function, and evaluation metric](custom_objective.R)
* [Boosting from existing prediction](boost_from_prediction.R) * [Boosting from existing prediction](boost_from_prediction.R)
* [Predicting using first n trees](predict_first_ntree.R) * [Predicting using first n trees](predict_first_ntree.R)

View File

@ -0,0 +1,35 @@
# install development version of caret library that contains xgboost models
devtools::install_github("topepo/caret/pkg/caret")
require(caret)
require(xgboost)
require(data.table)
require(vcd)
require(e1071)
# Load Arthritis dataset in memory.
data(Arthritis)
# Create a copy of the dataset with the data.table package (data.table is 100% compliant with R data frames, but its syntax is a lot more consistent and its performance is really good).
df <- data.table(Arthritis, keep.rownames = F)
# Let's add some new categorical features to see if it helps. Of course these features are highly correlated to the Age feature. Usually that's not a good thing in ML, but tree algorithms (including boosted trees) are able to select the best features, even in the case of highly correlated features.
# For the first feature we create groups of age by rounding the real age. Note that we transform it to a factor (categorical data) so the algorithm treats them as independent values.
df[,AgeDiscret:= as.factor(round(Age/10,0))]
# Here is an even stronger simplification of the real age with an arbitrary split at 30 years old. I chose this value arbitrarily. We will see later if simplifying the information based on arbitrary values is a good strategy (I am sure you already have an idea of how well it will work!).
df[,AgeCat:= as.factor(ifelse(Age > 30, "Old", "Young"))]
# We remove ID as there is nothing to learn from this feature (it will just add some noise as the dataset is small).
df[,ID:=NULL]
#-------------Basic Training using XGBoost in caret Library-----------------
# Set up control parameters for caret::train
# Here we use 10-fold cross-validation, repeating twice, and using random search for tuning hyper-parameters.
fitControl <- trainControl(method = "cv", number = 10, repeats = 2, search = "random")
# train an xgbTree model using caret::train
model <- train(factor(Improved)~., data = df, method = "xgbTree", trControl = fitControl)
# Instead of trees for our boosters, you can also fit a linear or logistic regression model using xgbLinear
# model <- train(factor(Improved)~., data = df, method = "xgbLinear", trControl = fitControl)
# See model results
print(model)
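As a possible follow-up (not part of the demo itself), the caret-fitted model can then be used for prediction like any other `train` object; this is just an illustrative sketch on the same data frame:

```r
# Illustrative sketch only: class predictions from the caret-fitted xgbTree model.
pred <- predict(model, newdata = df)
head(pred)
```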

View File

@ -9,3 +9,4 @@ demo(create_sparse_matrix)
demo(predict_leaf_indices) demo(predict_leaf_indices)
demo(early_stopping) demo(early_stopping)
demo(poisson_regression) demo(poisson_regression)
demo(caret_wrapper)

View File

@ -6,7 +6,8 @@
\usage{ \usage{
xgb.train(params = list(), data, nrounds, watchlist = list(), obj = NULL, xgb.train(params = list(), data, nrounds, watchlist = list(), obj = NULL,
feval = NULL, verbose = 1, print.every.n = 1L, feval = NULL, verbose = 1, print.every.n = 1L,
early.stop.round = NULL, maximize = NULL, ...) early.stop.round = NULL, maximize = NULL, save_period = 0,
save_name = "xgboost.model", ...)
} }
\arguments{ \arguments{
\item{params}{the list of parameters. \item{params}{the list of parameters.
@ -87,6 +88,10 @@ keeps getting worse consecutively for \code{k} rounds.}
\item{maximize}{If \code{feval} and \code{early.stop.round} are set, then \code{maximize} must be set as well. \item{maximize}{If \code{feval} and \code{early.stop.round} are set, then \code{maximize} must be set as well.
\code{maximize=TRUE} means the larger the evaluation score the better.} \code{maximize=TRUE} means the larger the evaluation score the better.}
\item{save_period}{save the model to disk after every \code{save_period} rounds; 0 means no periodic saving.}
\item{save_name}{the name or path of the periodically saved model file.}
\item{...}{other parameters to pass to \code{params}.} \item{...}{other parameters to pass to \code{params}.}
} }
\description{ \description{

View File

@ -4,9 +4,10 @@
\alias{xgboost} \alias{xgboost}
\title{eXtreme Gradient Boosting (Tree) library} \title{eXtreme Gradient Boosting (Tree) library}
\usage{ \usage{
xgboost(data = NULL, label = NULL, missing = NULL, params = list(), xgboost(data = NULL, label = NULL, missing = NULL, weight = NULL,
nrounds, verbose = 1, print.every.n = 1L, early.stop.round = NULL, params = list(), nrounds, verbose = 1, print.every.n = 1L,
maximize = NULL, ...) early.stop.round = NULL, maximize = NULL, save_period = 0,
save_name = "xgboost.model", ...)
} }
\arguments{ \arguments{
\item{data}{takes \code{matrix}, \code{dgCMatrix}, local data file or \item{data}{takes \code{matrix}, \code{dgCMatrix}, local data file or
@ -18,6 +19,8 @@ if data is local data file or \code{xgb.DMatrix}.}
\item{missing}{Missing is only used when input is dense matrix, pick a float \item{missing}{Missing is only used when input is dense matrix, pick a float
value that represents missing value. Sometimes a data use 0 or other extreme value to represents missing values.} value that represents missing value. Sometimes a data use 0 or other extreme value to represents missing values.}
\item{weight}{a vector indicating the weight for each row of the input.}
\item{params}{the list of parameters. \item{params}{the list of parameters.
Commonly used ones are: Commonly used ones are:
@ -51,6 +54,10 @@ keeps getting worse consecutively for \code{k} rounds.}
\item{maximize}{If \code{feval} and \code{early.stop.round} are set, then \code{maximize} must be set as well. \item{maximize}{If \code{feval} and \code{early.stop.round} are set, then \code{maximize} must be set as well.
\code{maximize=TRUE} means the larger the evaluation score the better.} \code{maximize=TRUE} means the larger the evaluation score the better.}
\item{save_period}{save the model to disk after every \code{save_period} rounds; 0 means no periodic saving.}
\item{save_name}{the name or path of the periodically saved model file.}
\item{...}{other parameters to pass to \code{params}.} \item{...}{other parameters to pass to \code{params}.}
} }
\description{ \description{

View File

@ -0,0 +1,4 @@
library(testthat)
library(xgboost)
test_check("xgboost")

View File

@ -0,0 +1,33 @@
require(xgboost)
context("basic functions")
data(agaricus.train, package='xgboost')
data(agaricus.test, package='xgboost')
train = agaricus.train
test = agaricus.test
test_that("train and predict", {
bst = xgboost(data = train$data, label = train$label, max.depth = 2,
eta = 1, nthread = 2, nround = 2, objective = "binary:logistic")
pred = predict(bst, test$data)
})
test_that("early stopping", {
res = xgb.cv(data = train$data, label = train$label, max.depth = 2, nfold = 5,
eta = 0.3, nthread = 2, nround = 20, objective = "binary:logistic",
early.stop.round = 3, maximize = FALSE)
expect_true(nrow(res)<20)
bst = xgboost(data = train$data, label = train$label, max.depth = 2,
eta = 0.3, nthread = 2, nround = 20, objective = "binary:logistic",
early.stop.round = 3, maximize = FALSE)
pred = predict(bst, test$data)
})
test_that("save_period", {
bst = xgboost(data = train$data, label = train$label, max.depth = 2,
eta = 0.3, nthread = 2, nround = 20, objective = "binary:logistic",
save_period = 10, save_name = "xgb.model")
pred = predict(bst, test$data)
})

View File

@ -0,0 +1,47 @@
context('Test models with custom objective')
require(xgboost)
test_that("custom objective works", {
data(agaricus.train, package='xgboost')
data(agaricus.test, package='xgboost')
dtrain <- xgb.DMatrix(agaricus.train$data, label = agaricus.train$label)
dtest <- xgb.DMatrix(agaricus.test$data, label = agaricus.test$label)
watchlist <- list(eval = dtest, train = dtrain)
num_round <- 2
logregobj <- function(preds, dtrain) {
labels <- getinfo(dtrain, "label")
preds <- 1/(1 + exp(-preds))
grad <- preds - labels
hess <- preds * (1 - preds)
return(list(grad = grad, hess = hess))
}
evalerror <- function(preds, dtrain) {
labels <- getinfo(dtrain, "label")
err <- as.numeric(sum(labels != (preds > 0)))/length(labels)
return(list(metric = "error", value = err))
}
param <- list(max.depth=2, eta=1, nthread = 2, silent=1,
objective=logregobj, eval_metric=evalerror)
bst <- xgb.train(param, dtrain, num_round, watchlist)
expect_equal(class(bst), "xgb.Booster")
expect_equal(length(bst$raw), 1064)
attr(dtrain, 'label') <- getinfo(dtrain, 'label')
logregobjattr <- function(preds, dtrain) {
labels <- attr(dtrain, 'label')
preds <- 1/(1 + exp(-preds))
grad <- preds - labels
hess <- preds * (1 - preds)
return(list(grad = grad, hess = hess))
}
param <- list(max.depth=2, eta=1, nthread = 2, silent=1,
objective=logregobjattr, eval_metric=evalerror)
bst <- xgb.train(param, dtrain, num_round, watchlist)
expect_equal(class(bst), "xgb.Booster")
expect_equal(length(bst$raw), 1064)
})

View File

@ -0,0 +1,19 @@
context('Test generalized linear models')
require(xgboost)
test_that("glm works", {
data(agaricus.train, package='xgboost')
data(agaricus.test, package='xgboost')
dtrain <- xgb.DMatrix(agaricus.train$data, label = agaricus.train$label)
dtest <- xgb.DMatrix(agaricus.test$data, label = agaricus.test$label)
expect_equal(class(dtrain), "xgb.DMatrix")
expect_equal(class(dtest), "xgb.DMatrix")
param <- list(objective = "binary:logistic", booster = "gblinear",
nthread = 2, alpha = 0.0001, lambda = 1)
watchlist <- list(eval = dtest, train = dtrain)
num_round <- 2
bst <- xgb.train(param, dtrain, num_round, watchlist)
ypred <- predict(bst, dtest)
expect_equal(length(getinfo(dtest, 'label')), 1611)
})

View File

@ -0,0 +1,32 @@
context('Test helper functions')
require(xgboost)
require(data.table)
require(Matrix)
require(vcd)
data(Arthritis)
data(agaricus.train, package='xgboost')
df <- data.table(Arthritis, keep.rownames = F)
df[,AgeDiscret:= as.factor(round(Age/10,0))]
df[,AgeCat:= as.factor(ifelse(Age > 30, "Old", "Young"))]
df[,ID:=NULL]
sparse_matrix = sparse.model.matrix(Improved~.-1, data = df)
output_vector = df[,Y:=0][Improved == "Marked",Y:=1][,Y]
bst <- xgboost(data = sparse_matrix, label = output_vector, max.depth = 9,
eta = 1, nthread = 2, nround = 10,objective = "binary:logistic")
test_that("xgb.dump works", {
capture.output(print(xgb.dump(bst)))
})
test_that("xgb.importance works", {
xgb.dump(bst, 'xgb.model.dump', with.stats = T)
importance <- xgb.importance(sparse_matrix@Dimnames[[2]], 'xgb.model.dump')
expect_equal(dim(importance), c(7, 4))
})
test_that("xgb.plot.tree works", {
xgb.plot.tree(agaricus.train$data@Dimnames[[2]], model = bst)
})

View File

@ -0,0 +1,13 @@
context('Test poisson regression model')
require(xgboost)
test_that("poisson regression works", {
data(mtcars)
bst = xgboost(data=as.matrix(mtcars[,-11]),label=mtcars[,11],
objective='count:poisson',nrounds=5)
expect_equal(class(bst), "xgb.Booster")
pred = predict(bst,as.matrix(mtcars[,-11]))
expect_equal(length(pred), 32)
sqrt(mean((pred-mtcars[,11])^2))
})

View File

@ -160,7 +160,7 @@ bstDense <- xgboost(data = as.matrix(train$data), label = train$label, max.depth
#### xgb.DMatrix #### xgb.DMatrix
**XGBoost** offers a way to group them in a `xgb.DMatrix`. You can even add other meta data in it. It will be usefull for the most advanced features we will discover later. **XGBoost** offers a way to group them in a `xgb.DMatrix`. You can even add other meta data in it. It will be useful for the most advanced features we will discover later.
```{r trainingDmatrix, message=F, warning=F} ```{r trainingDmatrix, message=F, warning=F}
dtrain <- xgb.DMatrix(data = train$data, label = train$label) dtrain <- xgb.DMatrix(data = train$data, label = train$label)
@ -169,7 +169,7 @@ bstDMatrix <- xgboost(data = dtrain, max.depth = 2, eta = 1, nthread = 2, nround
#### Verbose option #### Verbose option
**XGBoost** has severa features to help you to view how the learning progress internally. The purpose is to help you to set the best parameters, which is the key of your model quality. **XGBoost** has several features to help you to view how the learning progress internally. The purpose is to help you to set the best parameters, which is the key of your model quality.
One of the simplest way to see the training progress is to set the `verbose` option (see below for more advanced technics). One of the simplest way to see the training progress is to set the `verbose` option (see below for more advanced technics).
@ -194,7 +194,7 @@ Basic prediction using XGBoost
Perform the prediction Perform the prediction
---------------------- ----------------------
The pupose of the model we have built is to classify new data. As explained before, we will use the `test` dataset for this step. The purpose of the model we have built is to classify new data. As explained before, we will use the `test` dataset for this step.
```{r predicting, message=F, warning=F} ```{r predicting, message=F, warning=F}
pred <- predict(bst, test$data) pred <- predict(bst, test$data)
@ -267,7 +267,7 @@ Measure learning progress with xgb.train
Both `xgboost` (simple) and `xgb.train` (advanced) functions train models. Both `xgboost` (simple) and `xgb.train` (advanced) functions train models.
One of the special feature of `xgb.train` is the capacity to follow the progress of the learning after each round. Because of the way boosting works, there is a time when having too many rounds lead to an overfitting. You can see this feature as a cousin of cross-validation method. The following technics will help you to avoid overfitting or optimizing the learning time in stopping it as soon as possible. One of the special feature of `xgb.train` is the capacity to follow the progress of the learning after each round. Because of the way boosting works, there is a time when having too many rounds lead to an overfitting. You can see this feature as a cousin of cross-validation method. The following techniques will help you to avoid overfitting or optimizing the learning time in stopping it as soon as possible.
One way to measure progress in learning of a model is to provide to **XGBoost** a second dataset already classified. Therefore it can learn on the first dataset and test its model on the second one. Some metrics are measured after each round during the learning. One way to measure progress in learning of a model is to provide to **XGBoost** a second dataset already classified. Therefore it can learn on the first dataset and test its model on the second one. Some metrics are measured after each round during the learning.
@ -285,7 +285,7 @@ bst <- xgb.train(data=dtrain, max.depth=2, eta=1, nthread = 2, nround=2, watchli
Both training and test error related metrics are very similar, and in some way, it makes sense: what we have learned from the training dataset matches the observations from the test dataset. Both training and test error related metrics are very similar, and in some way, it makes sense: what we have learned from the training dataset matches the observations from the test dataset.
If with your own dataset you have not such results, you should think about how you did to divide your dataset in training and test. May be there is something to fix. Again, `caret` package may [help](http://topepo.github.io/caret/splitting.html). If with your own dataset you have not such results, you should think about how you divided your dataset in training and test. May be there is something to fix. Again, `caret` package may [help](http://topepo.github.io/caret/splitting.html).
For a better understanding of the learning progression, you may want to have some specific metric or even use multiple evaluation metrics. For a better understanding of the learning progression, you may want to have some specific metric or even use multiple evaluation metrics.
@ -306,7 +306,7 @@ bst <- xgb.train(data=dtrain, booster = "gblinear", max.depth=2, nthread = 2, nr
In this specific case, *linear boosting* gets sligtly better performance metrics than decision trees based algorithm. In this specific case, *linear boosting* gets sligtly better performance metrics than decision trees based algorithm.
In simple cases, it will happem because there is nothing better than a linear algorithm to catch a linear link. However, decision trees are much better to catch a non linear link between predictors and outcome. Because there is no silver bullet, we advise you to check both algorithms with your own datasets to have an idea of what to use. In simple cases, it will happen because there is nothing better than a linear algorithm to catch a linear link. However, decision trees are much better to catch a non linear link between predictors and outcome. Because there is no silver bullet, we advise you to check both algorithms with your own datasets to have an idea of what to use.
Manipulating xgb.DMatrix Manipulating xgb.DMatrix
------------------------ ------------------------
@ -368,7 +368,7 @@ xgb.plot.tree(model = bst)
Save and load models Save and load models
-------------------- --------------------
May be your dataset is big, and it takes time to train a model on it? May be you are not a big fan of loosing time in redoing the same task again and again? In these very rare cases, you will want to save your model and load it when required. Maybe your dataset is big, and it takes time to train a model on it? May be you are not a big fan of losing time in redoing the same task again and again? In these very rare cases, you will want to save your model and load it when required.
Hopefully for you, **XGBoost** implements such functions. Hopefully for you, **XGBoost** implements such functions.
@ -379,7 +379,7 @@ xgb.save(bst, "xgboost.model")
> `xgb.save` function should return `r TRUE` if everything goes well and crashes otherwise. > `xgb.save` function should return `r TRUE` if everything goes well and crashes otherwise.
An interesting test to see how identic is our saved model with the original one would be to compare the two predictions. An interesting test to see how identical our saved model is to the original one would be to compare the two predictions.
```{r loadModel, message=F, warning=F} ```{r loadModel, message=F, warning=F}
# load binary model to R # load binary model to R

View File

@ -19,7 +19,7 @@ Contents
* [Build Instruction](doc/build.md) * [Build Instruction](doc/build.md)
* [Features](#features) * [Features](#features)
* [Distributed XGBoost](multi-node) * [Distributed XGBoost](multi-node)
* [Usecases](doc/README.md#highlight-links) * [Usecases](doc/index.md#highlight-links)
* [Bug Reporting](#bug-reporting) * [Bug Reporting](#bug-reporting)
* [Contributing to XGBoost](#contributing-to-xgboost) * [Contributing to XGBoost](#contributing-to-xgboost)
* [Committers and Contributors](CONTRIBUTORS.md) * [Committers and Contributors](CONTRIBUTORS.md)
@ -29,6 +29,7 @@ Contents
What's New What's New
---------- ----------
* XGBoost helps Owen Zhang to win the [Avito Context Ad Click competition](https://www.kaggle.com/c/avito-context-ad-clicks). Check out the [interview from Kaggle](http://blog.kaggle.com/2015/08/26/avito-winners-interview-1st-place-owen-zhang/).
* XGBoost helps Chenglong Chen to win [Kaggle CrowdFlower Competition](https://www.kaggle.com/c/crowdflower-search-relevance) * XGBoost helps Chenglong Chen to win [Kaggle CrowdFlower Competition](https://www.kaggle.com/c/crowdflower-search-relevance)
Check out the [winning solution](https://github.com/ChenglongChen/Kaggle_CrowdFlower) Check out the [winning solution](https://github.com/ChenglongChen/Kaggle_CrowdFlower)
* XGBoost-0.4 release, see [CHANGES.md](CHANGES.md#xgboost-04) * XGBoost-0.4 release, see [CHANGES.md](CHANGES.md#xgboost-04)

View File

@ -6,6 +6,18 @@
# See additional instruction in doc/build.md # See additional instruction in doc/build.md
#for building a static OpenMP lib on Mac for easier installation
#doesn't work with Xcode clang/LLVM since Apple doesn't ship OpenMP support;
#needs brew-installed gcc 4.9+ with OpenMP. By default the static link is OFF
static_omp=0
if ((${static_omp}==1)); then
rm libgomp.a
ln -s `g++ -print-file-name=libgomp.a`
make clean
make omp_mac_static=1
echo "Successfully build multi-thread static link xgboost"
exit 0
fi
if make; then if make; then
echo "Successfully build multi-thread xgboost" echo "Successfully build multi-thread xgboost"

View File

@ -1,6 +1,6 @@
Frequent Asked Questions Frequently Asked Questions
======================== ========================
This document contains the frequent asked question to xgboost. This document contains frequently asked questions about xgboost.
How to tune parameters How to tune parameters
---------------------- ----------------------
@ -13,7 +13,7 @@ See [Introduction to Boosted Trees](model.md)
I have a big dataset I have a big dataset
-------------------- --------------------
XGBoost is designed to be memory efficient. Usually it could handle problems as long as the data fit into your memory XGBoost is designed to be memory efficient. Usually it can handle problems as long as the data fit into your memory
(This usually means millions of instances). (This usually means millions of instances).
If you are running out of memory, checkout [external memory version](external_memory.md) or If you are running out of memory, checkout [external memory version](external_memory.md) or
[distributed version](https://github.com/dmlc/wormhole/tree/master/learn/xgboost) of xgboost. [distributed version](https://github.com/dmlc/wormhole/tree/master/learn/xgboost) of xgboost.
@ -23,30 +23,30 @@ Running xgboost on Platform X (Hadoop/Yarn, Mesos)
-------------------------------------------------- --------------------------------------------------
The distributed version of XGBoost is designed to be portable to various environment. The distributed version of XGBoost is designed to be portable to various environment.
Distributed XGBoost can be ported to any platform that supports [rabit](https://github.com/dmlc/rabit). Distributed XGBoost can be ported to any platform that supports [rabit](https://github.com/dmlc/rabit).
You can directly run xgboost on Yarn. In theory Mesos and other resource allocation engine can be easily supported as well. You can directly run xgboost on Yarn. In theory Mesos and other resource allocation engines can be easily supported as well.
Why not implement distributed xgboost on top of X (Spark, Hadoop) Why not implement distributed xgboost on top of X (Spark, Hadoop)
----------------------------------------------------------------- -----------------------------------------------------------------
The first fact we need to know is going distributed does not necessarily solve all the problems. The first fact we need to know is going distributed does not necessarily solve all the problems.
Instead, it creates more problems such as more communication over head and fault tolerance. Instead, it creates more problems such as more communication overhead and fault tolerance.
The ultimate question will still come back into how to push the limit of each computation node The ultimate question will still come back to how to push the limit of each computation node
and use less resources to complete the task (thus with less communication and chance of failure). and use less resources to complete the task (thus with less communication and chance of failure).
To achieve these, we decide to reuse the optimizations in the single node xgboost and build distributed version on top of it. To achieve these, we decide to reuse the optimizations in the single node xgboost and build distributed version on top of it.
The demand of communication in machine learning is rather simple, in a sense that we can depend on a limited set of API (in our case rabit). The demand of communication in machine learning is rather simple, in the sense that we can depend on a limited set of API (in our case rabit).
Such design allows us to reuse most of the code, and being portable to major platforms such as Hadoop/Yarn, MPI, SGE. Such design allows us to reuse most of the code, while being portable to major platforms such as Hadoop/Yarn, MPI, SGE.
Most importantly, pushs the limit of the computation resources we can use. Most importantly, it pushes the limit of the computation resources we can use.
How can I port the model to my own system How can I port the model to my own system
----------------------------------------- -----------------------------------------
The model and data format of XGBoost is exchangable. The model and data format of XGBoost is exchangable,
Which means the model trained by one langauge can be loaded in another. which means the model trained by one language can be loaded in another.
This means you can train the model using R, while running prediction using This means you can train the model using R, while running prediction using
Java or C++, which are more common in production system. Java or C++, which are more common in production systems.
You can also train the model using distributed version, You can also train the model using distributed versions,
and load them in from python to do some interactive analysis. and load them in from Python to do some interactive analysis.
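As a small hedged illustration of that exchangeability (the file name is arbitrary), a model trained and saved from R is written in the same binary format that the other language bindings load:

```r
# Illustrative sketch only: save a trained booster and load it back from disk.
require(xgboost)
data(agaricus.train, package = "xgboost")
bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label,
               max.depth = 2, eta = 1, nrounds = 2,
               objective = "binary:logistic")
xgb.save(bst, "xgboost.model")   # binary model file, readable by other bindings
bst2 <- xgb.load("xgboost.model")
```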
Do you support LambdaMART Do you support LambdaMART

View File

@ -2,29 +2,29 @@ Introduction to Boosted Trees
============================= =============================
XGBoost is short for "Extreme Gradient Boosting", where the term "Gradient Boosting" is proposed in the paper _Greedy Function Approximation: A Gradient Boosting Machine_, Friedman. Based on this original model. This is a tutorial on boosted trees, most of content are based on this [slide](http://homes.cs.washington.edu/~tqchen/pdf/BoostedTree.pdf) by the author of xgboost. XGBoost is short for "Extreme Gradient Boosting", where the term "Gradient Boosting" is proposed in the paper _Greedy Function Approximation: A Gradient Boosting Machine_, Friedman. Based on this original model. This is a tutorial on boosted trees, most of content are based on this [slide](http://homes.cs.washington.edu/~tqchen/pdf/BoostedTree.pdf) by the author of xgboost.
The GBM(boosted trees) has been around for really a while, and there are a lot of materials on the topic. This tutorial tries to explain boosted trees in a self-contained and principled way of supervised learning. We think this explaination is cleaner, more formal, and motivates the variant used in xgboost. The GBM(boosted trees) has been around for really a while, and there are a lot of materials on the topic. This tutorial tries to explain boosted trees in a self-contained and principled way of supervised learning. We think this explanation is cleaner, more formal, and motivates the variant used in xgboost.
Elements of Supervised Learning Elements of Supervised Learning
------------------------------- -------------------------------
XGBoost is used for supervised learning problems, where we use the training data ``$ x_i $`` to predict a target variable ``$ y_i $``. XGBoost is used for supervised learning problems, where we use the training data ``$ x_i $`` to predict a target variable ``$ y_i $``.
Before we get dived into trees, let us start from reviwing the basic elements in supervised learning. Before we dive into trees, let us start by reviewing the basic elements in supervised learning.
### Model and Parameters ### Model and Parameters
The ***model*** in supervised learning usually refers to the mathematical structure on how to given the prediction ``$ y_i $`` given ``$ x_i $``. The ***model*** in supervised learning usually refers to the mathematical structure on how to given the prediction ``$ y_i $`` given ``$ x_i $``.
For example, a common model is *linear model*, where the prediction is given by ``$ \hat{y}_i = \sum_j w_j x_{ij} $``, a linear combination of weighted input features. For example, a common model is *linear model*, where the prediction is given by ``$ \hat{y}_i = \sum_j w_j x_{ij} $``, a linear combination of weighted input features.
The prediction value can have different interpretations, depending on the task. The prediction value can have different interpretations, depending on the task.
For example, it can be logistic transformed to get the probability of postitive class in logistic regression, it can also be used as ranking score when we want to rank the outputs. For example, it can be logistic transformed to get the probability of positive class in logistic regression, and it can also be used as ranking score when we want to rank the outputs.
The ***parameters*** are the undermined part that we need to learn from data. In linear regression problem, the parameters are the co-efficients ``$ w $``. The ***parameters*** are the undermined part that we need to learn from data. In linear regression problem, the parameters are the co-efficients ``$ w $``.
Usually we will use ``$ \Theta $`` to denote the parameters. Usually we will use ``$ \Theta $`` to denote the parameters.
### Object Function : Training Loss + Regularization ### Objective Function : Training Loss + Regularization
Based on different understanding or assumption of ``$ y_i $``, we can have different problems as regression, classification, ordering, etc. Based on different understanding or assumption of ``$ y_i $``, we can have different problems as regression, classification, ordering, etc.
We need to find a way to find the best parameters given the training data. In order to do so, we need to define a so called ***objective function***, We need to find a way to find the best parameters given the training data. In order to do so, we need to define a so called ***objective function***,
to measure the performance of the model under certain set of parameters. to measure the performance of the model under certain set of parameters.
A very important about objective functions, is they ***must always*** contains two parts: training loss and regularization. A very important fact about objective functions, is they ***must always*** contains two parts: training loss and regularization.
```math ```math
Obj(\Theta) = L(\Theta) + \Omega(\Theta) Obj(\Theta) = L(\Theta) + \Omega(\Theta)
@ -42,8 +42,8 @@ Another commonly used loss function is logistic loss for logistic regression
L(\theta) = \sum_i[ y_i\ln (1+e^{-\hat{y}_i}) + (1-y_i)\ln (1+e^{\hat{y}_i})] L(\theta) = \sum_i[ y_i\ln (1+e^{-\hat{y}_i}) + (1-y_i)\ln (1+e^{\hat{y}_i})]
``` ```
The ***regularization term*** is usually people forget to add. The regularization term controls the complexity of the model, this helps us to avoid overfitting. The ***regularization term*** is what people usually forget to add. The regularization term controls the complexity of the model, which helps us to avoid overfitting.
This sounds a bit abstract, let us consider the following problem in the following picture. You are asked to *fit* visually a step function given the input data points This sounds a bit abstract, so let us consider the following problem in the following picture. You are asked to *fit* visually a step function given the input data points
on the upper left corner of the image, which solution among the tree you think is the best fit? on the upper left corner of the image, which solution among the tree you think is the best fit?
![Step function](img/step_fit.png) ![Step function](img/step_fit.png)
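To make the regularization term concrete for the tree models discussed in this document, the complexity of a tree ``$ f $`` with ``$ T $`` leaves and leaf weights ``$ w $`` is measured as follows (this is the form used later in the xgboost objective, shown here only as a pointer forward):

```math
\Omega(f) = \gamma T + \frac{1}{2}\lambda \sum_{j=1}^T w_j^2
```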
@ -55,12 +55,12 @@ The tradeoff between the two is also referred as bias-variance tradeoff in machi
### Why introduce the general principle ### Why introduce the general principle
The elements introduced in above forms the basic elements of supervised learning, and they are naturally the building blocks of machine learning toolkits. The elements introduced in above forms the basic elements of supervised learning, and they are naturally the building blocks of machine learning toolkits.
For example, you should be able to answer what is the difference and common parts between boosted trees and random forest. For example, you should be able to answer what is the difference and common parts between boosted trees and random forest.
Understanding the process in a formalized way also helps us to understand the objective what we are learning and getting the reason behind the heurestics such as Understanding the process in a formalized way also helps us to understand the objective that we are learning and the reason behind the heurestics such as
pruning and smoothing. pruning and smoothing.
Tree Ensemble Tree Ensemble
------------- -------------
Now we have introduce the elements of supervised learning, let us getting started with real trees. Now that we have introduced the elements of supervised learning, let us get started with real trees.
To begin with, let us first learn what is the ***model*** of xgboost: tree ensembles. To begin with, let us first learn what is the ***model*** of xgboost: tree ensembles.
The tree ensemble model is a set of classification and regression trees (CART). Here's a simple example of a CART The tree ensemble model is a set of classification and regression trees (CART). Here's a simple example of a CART
that classifies is someone will like computer games. that classifies is someone will like computer games.
@ -69,17 +69,17 @@ that classifies is someone will like computer games.
We classify the members in thie family into different leaves, and assign them the score on corresponding leaf. We classify the members in thie family into different leaves, and assign them the score on corresponding leaf.
A CART is a bit different from decision trees, where the leaf only contain decision values. In CART, a real score A CART is a bit different from decision trees, where the leaf only contain decision values. In CART, a real score
is associated with each of the leaves, this allows gives us richer interpretations that go beyond classification. is associated with each of the leaves, which gives us richer interpretations that go beyond classification.
This also makes the unified optimization step easier, as we will see in later part of this tutorial. This also makes the unified optimization step easier, as we will see in later part of this tutorial.
Usually, a single tree is not so strong enough to be used in practice. What is actually used is the so called Usually, a single tree is not so strong enough to be used in practice. What is actually used is the so called
tree ensemble model, that sumes the prediction of multiple trees together. tree ensemble model, that sums the prediction of multiple trees together.
![TwoCART](img/twocart.png) ![TwoCART](img/twocart.png)
Here is an example of tree ensemble of two trees. The prediction scores of each individual tree are summed up to get the final score. Here is an example of tree ensemble of two trees. The prediction scores of each individual tree are summed up to get the final score.
If you look at the example, an important fact is that the two trees tries to *complement* each other. If you look at the example, an important fact is that the two trees try to *complement* each other.
Mathematically, we can write our model into the form Mathematically, we can write our model in the form
```math ```math
\hat{y}_i = \sum_{k=1}^K f_k(x_i), f_k \in \mathcal{F} \hat{y}_i = \sum_{k=1}^K f_k(x_i), f_k \in \mathcal{F}
@ -219,7 +219,7 @@ This formula can be decomposited as 1) the score on the new left leaf 2) the sco
We can find an important fact here: if the gain is smaller than ``$\gamma$``, we would better not to add that branch. This is exactly the ***prunning*** techniques in tree based We can find an important fact here: if the gain is smaller than ``$\gamma$``, we would better not to add that branch. This is exactly the ***prunning*** techniques in tree based
models! By using the principles of supervised learning, we can naturally comes up with the reason these techniques :) models! By using the principles of supervised learning, we can naturally comes up with the reason these techniques :)
For real valued data, we usually want to search for an optimal split. To efficiently doing so, we place all the instances in a sorted way, like the following picture. For real valued data, we usually want to search for an optimal split. To efficiently do so, we place all the instances in a sorted way, like the following picture.
![Best split](img/split_find.png) ![Best split](img/split_find.png)
Then a left to right scan is sufficient to calculate the structure score of all possible split solutions, and we can find the best split efficiently. Then a left to right scan is sufficient to calculate the structure score of all possible split solutions, and we can find the best split efficiently.
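For reference, the gain mentioned above has a closed form in terms of the sums of first-order statistics ``$ G $`` and second-order statistics ``$ H $`` over the left and right candidate leaves (consistent with the structure-score derivation in the referenced slides):

```math
Gain = \frac{1}{2}\left[\frac{G_L^2}{H_L+\lambda} + \frac{G_R^2}{H_R+\lambda} - \frac{(G_L+G_R)^2}{H_L+H_R+\lambda}\right] - \gamma
```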

View File

@ -46,6 +46,10 @@ Parameters for Tree Booster
* colsample_bytree [default=1] * colsample_bytree [default=1]
- subsample ratio of columns when constructing each tree. - subsample ratio of columns when constructing each tree.
- range: (0,1] - range: (0,1]
* lambda [default=1]
- L2 regularization term on weights
* alpha [default=0]
- L1 regularization term on weights
Parameters for Linear Booster Parameters for Linear Booster
----------------------------- -----------------------------
@ -105,7 +109,7 @@ The following parameters are only used in the console version of xgboost
* task [default=train] options: train, pred, eval, dump * task [default=train] options: train, pred, eval, dump
- train: training using data - train: training using data
- pred: making prediction for test:data - pred: making prediction for test:data
- eval: for evaluating statistics specified by eval[name]=filenam - eval: for evaluating statistics specified by eval[name]=filename
- dump: for dump the learned model into text format(preliminary) - dump: for dump the learned model into text format(preliminary)
* model_in [default=NULL] * model_in [default=NULL]
- path to input model, needed for test, eval, dump, if it is specified in training, xgboost will continue training from the input model - path to input model, needed for test, eval, dump, if it is specified in training, xgboost will continue training from the input model
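A hedged sketch of passing the regularization parameters documented above from the R wrapper (data set from the package; the particular values are illustrative, not recommendations):

```r
# Illustrative sketch only: L2 (lambda) and L1 (alpha) regularization on leaf weights.
require(xgboost)
data(agaricus.train, package = "xgboost")
bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label,
               nrounds = 10, objective = "binary:logistic",
               max.depth = 4, eta = 0.3,
               lambda = 1, alpha = 0.1)
```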

View File

@ -1,4 +1,4 @@
# pylint: disable=invalid-name # pylint: disable=invalid-name, exec-used
"""Setup xgboost package.""" """Setup xgboost package."""
from __future__ import absolute_import from __future__ import absolute_import
import sys import sys
@ -17,9 +17,17 @@ if 'pip' in __file__:
output = build_sh.communicate() output = build_sh.communicate()
print(output) print(output)
import xgboost
LIB_PATH = xgboost.core.find_lib_path()
CURRENT_DIR = os.path.dirname(__file__)
# We can not import `xgboost.libpath` in setup.py directly since xgboost/__init__.py
# imports `xgboost.core`, which in turn imports `numpy` and `scipy`, both of which are listed in
# `install_requires` and may not be installed yet. That's why we're using `exec` here.
libpath_py = os.path.join(CURRENT_DIR, 'xgboost/libpath.py')
libpath = {'__file__': libpath_py}
exec(compile(open(libpath_py, "rb").read(), libpath_py, 'exec'), libpath, libpath)
LIB_PATH = libpath['find_lib_path']()
#print LIB_PATH #print LIB_PATH
#to deploy to pip, please use #to deploy to pip, please use
@ -27,9 +35,9 @@ LIB_PATH = xgboost.core.find_lib_path()
#python setup.py register sdist upload #python setup.py register sdist upload
#and be sure to test it firstly using "python setup.py register sdist upload -r pypitest" #and be sure to test it firstly using "python setup.py register sdist upload -r pypitest"
setup(name='xgboost', setup(name='xgboost',
version=xgboost.__version__, version=open(os.path.join(CURRENT_DIR, 'xgboost/VERSION')).read().strip(),
#version='0.4a13', #version='0.4a13',
description=xgboost.__doc__, description=open(os.path.join(CURRENT_DIR, 'README.md')).read(),
install_requires=[ install_requires=[
'numpy', 'numpy',
'scipy', 'scipy',

View File

@ -0,0 +1 @@
0.4

View File

@ -5,12 +5,16 @@ Contributors: https://github.com/dmlc/xgboost/blob/master/CONTRIBUTORS.md
""" """
from __future__ import absolute_import from __future__ import absolute_import
import os
from .core import DMatrix, Booster from .core import DMatrix, Booster
from .training import train, cv from .training import train, cv
from .sklearn import XGBModel, XGBClassifier, XGBRegressor from .sklearn import XGBModel, XGBClassifier, XGBRegressor
from .plotting import plot_importance, plot_tree, to_graphviz from .plotting import plot_importance, plot_tree, to_graphviz
__version__ = '0.4' VERSION_FILE = os.path.join(os.path.dirname(__file__), 'VERSION')
__version__ = open(VERSION_FILE).read().strip()
__all__ = ['DMatrix', 'Booster', __all__ = ['DMatrix', 'Booster',
'train', 'cv', 'train', 'cv',

View File

@ -1,67 +1,81 @@
# coding: utf-8 # coding: utf-8
# pylint: disable=too-many-arguments # pylint: disable=too-many-arguments, too-many-branches
"""Core XGBoost Library.""" """Core XGBoost Library."""
from __future__ import absolute_import from __future__ import absolute_import
import os import os
import sys import sys
import ctypes import ctypes
import platform
import collections import collections
import numpy as np import numpy as np
import scipy.sparse import scipy.sparse
from .libpath import find_lib_path
class XGBoostLibraryNotFound(Exception):
"""Error throwed by when xgboost is not found"""
pass
class XGBoostError(Exception): class XGBoostError(Exception):
"""Error throwed by xgboost trainer.""" """Error throwed by xgboost trainer."""
pass pass
PY3 = (sys.version_info[0] == 3)
if sys.version_info[0] == 3: if PY3:
# pylint: disable=invalid-name # pylint: disable=invalid-name, redefined-builtin
STRING_TYPES = str, STRING_TYPES = str,
else: else:
# pylint: disable=invalid-name # pylint: disable=invalid-name
STRING_TYPES = basestring, STRING_TYPES = basestring,
def find_lib_path():
    """Load find the path to xgboost dynamic library files.

    Returns
    -------
    lib_path: list(string)
       List of all found library path to xgboost
    """
    curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
    #make pythonpack hack: copy this directory one level upper for setup.py
    dll_path = [curr_path, os.path.join(curr_path, '../../wrapper/')
                , os.path.join(curr_path, './wrapper/')]
    if os.name == 'nt':
        if platform.architecture()[0] == '64bit':
            dll_path.append(os.path.join(curr_path, '../../windows/x64/Release/'))
            #hack for pip installation when copy all parent source directory here
            dll_path.append(os.path.join(curr_path, './windows/x64/Release/'))
        else:
            dll_path.append(os.path.join(curr_path, '../../windows/Release/'))
            #hack for pip installation when copy all parent source directory here
            dll_path.append(os.path.join(curr_path, './windows/Release/'))
    if os.name == 'nt':
        dll_path = [os.path.join(p, 'xgboost_wrapper.dll') for p in dll_path]
    else:
        dll_path = [os.path.join(p, 'libxgboostwrapper.so') for p in dll_path]
    lib_path = [p for p in dll_path if os.path.exists(p) and os.path.isfile(p)]
    if len(lib_path) == 0 and not os.environ.get('XGBOOST_BUILD_DOC', False):
        raise XGBoostLibraryNotFound(
            'Cannot find XGBoost Libarary in the candicate path, ' +
            'did you run build.sh in root path?\n'
            'List of candidates:\n' + ('\n'.join(dll_path)))
    return lib_path


def from_pystr_to_cstr(data):
    """Convert a list of Python str to C pointer

    Parameters
    ----------
    data : list
        list of str
    """
    if isinstance(data, list):
        pointers = (ctypes.c_char_p * len(data))()
        if PY3:
            data = [bytes(d, 'utf-8') for d in data]
        else:
            data = [d.encode('utf-8') if isinstance(d, unicode) else d
                    for d in data]
        pointers[:] = data
        return pointers
    else:
        # copy from above when we actually use it
        raise NotImplementedError


def from_cstr_to_pystr(data, length):
    """Revert C pointer to Python str

    Parameters
    ----------
    data : ctypes pointer
        pointer to data
    length : ctypes pointer
        pointer to length of data
    """
    if PY3:
        res = []
        for i in range(length.value):
            try:
                res.append(str(data[i].decode('ascii')))
            except UnicodeDecodeError:
                res.append(str(data[i].decode('utf-8')))
    else:
        res = []
        for i in range(length.value):
            try:
                res.append(str(data[i].decode('ascii')))
            except UnicodeDecodeError:
                res.append(unicode(data[i].decode('utf-8')))
    return res
def _load_lib(): def _load_lib():
@ -124,6 +138,28 @@ def c_array(ctype, values):
return (ctype * len(values))(*values) return (ctype * len(values))(*values)
def _maybe_from_pandas(data, feature_names, feature_types):
""" Extract internal data from pd.DataFrame """
try:
import pandas as pd
except ImportError:
return data, feature_names, feature_types
if not isinstance(data, pd.DataFrame):
return data, feature_names, feature_types
dtypes = data.dtypes
if not all(dtype.name in ('int64', 'float64', 'bool') for dtype in dtypes):
raise ValueError('DataFrame.dtypes must be int, float or bool')
if feature_names is None:
feature_names = data.columns.format()
if feature_types is None:
mapper = {'int64': 'int', 'float64': 'q', 'bool': 'i'}
feature_types = [mapper[dtype.name] for dtype in dtypes]
data = data.values.astype('float')
return data, feature_names, feature_types
class DMatrix(object): class DMatrix(object):
"""Data Matrix used in XGBoost. """Data Matrix used in XGBoost.
@ -131,13 +167,19 @@ class DMatrix(object):
which is optimized for both memory efficiency and training speed. which is optimized for both memory efficiency and training speed.
You can construct DMatrix from numpy.arrays You can construct DMatrix from numpy.arrays
""" """
def __init__(self, data, label=None, missing=0.0, weight=None, silent=False):
_feature_names = None # for previous version's pickle
_feature_types = None
def __init__(self, data, label=None, missing=0.0,
weight=None, silent=False,
feature_names=None, feature_types=None):
""" """
Data matrix used in XGBoost. Data matrix used in XGBoost.
Parameters Parameters
---------- ----------
data : string/numpy array/scipy.sparse data : string/numpy array/scipy.sparse/pd.DataFrame
Data source of DMatrix. Data source of DMatrix.
When data is string type, it represents the path libsvm format txt file, When data is string type, it represents the path libsvm format txt file,
or binary file that xgboost can read from. or binary file that xgboost can read from.
@ -149,11 +191,22 @@ class DMatrix(object):
Weight for each instance. Weight for each instance.
silent : boolean, optional silent : boolean, optional
Whether print messages during construction Whether print messages during construction
feature_names : list, optional
Labels for features.
feature_types : list, optional
Labels for features.
""" """
# force into void_p, mac need to pass things in as void_p # force into void_p, mac need to pass things in as void_p
if data is None: if data is None:
self.handle = None self.handle = None
return return
klass = getattr(getattr(data, '__class__', None), '__name__', None)
if klass == 'DataFrame':
# check the class name first to avoid an unnecessary pandas import
data, feature_names, feature_types = _maybe_from_pandas(data, feature_names,
feature_types)
if isinstance(data, STRING_TYPES): if isinstance(data, STRING_TYPES):
self.handle = ctypes.c_void_p() self.handle = ctypes.c_void_p()
_check_call(_LIB.XGDMatrixCreateFromFile(c_str(data), _check_call(_LIB.XGDMatrixCreateFromFile(c_str(data),
@ -163,7 +216,7 @@ class DMatrix(object):
self._init_from_csr(data) self._init_from_csr(data)
elif isinstance(data, scipy.sparse.csc_matrix): elif isinstance(data, scipy.sparse.csc_matrix):
self._init_from_csc(data) self._init_from_csc(data)
elif isinstance(data, np.ndarray) and len(data.shape) == 2: elif isinstance(data, np.ndarray):
self._init_from_npy2d(data, missing) self._init_from_npy2d(data, missing)
else: else:
try: try:
@ -176,6 +229,9 @@ class DMatrix(object):
if weight is not None: if weight is not None:
self.set_weight(weight) self.set_weight(weight)
self.feature_names = feature_names
self.feature_types = feature_types
def _init_from_csr(self, csr): def _init_from_csr(self, csr):
""" """
Initialize data from a CSR matrix. Initialize data from a CSR matrix.
@ -206,6 +262,8 @@ class DMatrix(object):
""" """
Initialize data from a 2-D numpy matrix. Initialize data from a 2-D numpy matrix.
""" """
if len(mat.shape) != 2:
raise ValueError('Input numpy.ndarray must be 2 dimensional')
data = np.array(mat.reshape(mat.size), dtype=np.float32) data = np.array(mat.reshape(mat.size), dtype=np.float32)
self.handle = ctypes.c_void_p() self.handle = ctypes.c_void_p()
_check_call(_LIB.XGDMatrixCreateFromMat(data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), _check_call(_LIB.XGDMatrixCreateFromMat(data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
@ -391,6 +449,18 @@ class DMatrix(object):
ctypes.byref(ret))) ctypes.byref(ret)))
return ret.value return ret.value
def num_col(self):
"""Get the number of columns (features) in the DMatrix.
Returns
-------
number of columns : int
"""
ret = ctypes.c_uint()
_check_call(_LIB.XGDMatrixNumCol(self.handle,
ctypes.byref(ret)))
return ret.value
def slice(self, rindex): def slice(self, rindex):
"""Slice the DMatrix and return a new DMatrix that only contains `rindex`. """Slice the DMatrix and return a new DMatrix that only contains `rindex`.
@ -404,7 +474,7 @@ class DMatrix(object):
res : DMatrix res : DMatrix
A new DMatrix containing only selected indices. A new DMatrix containing only selected indices.
""" """
res = DMatrix(None) res = DMatrix(None, feature_names=self.feature_names)
res.handle = ctypes.c_void_p() res.handle = ctypes.c_void_p()
_check_call(_LIB.XGDMatrixSliceDMatrix(self.handle, _check_call(_LIB.XGDMatrixSliceDMatrix(self.handle,
c_array(ctypes.c_int, rindex), c_array(ctypes.c_int, rindex),
@ -412,6 +482,88 @@ class DMatrix(object):
ctypes.byref(res.handle))) ctypes.byref(res.handle)))
return res return res
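A short sketch (assuming a built xgboost package) of the new num_col accessor together with slice, which now carries feature_names over to the sliced matrix.

import numpy as np
import xgboost as xgb

dm = xgb.DMatrix(np.random.randn(6, 3), feature_names=['f0', 'f1', 'f2'])
sub = dm.slice([0, 2, 4])            # keep rows 0, 2 and 4
print(sub.num_row(), sub.num_col())  # 3 3
print(sub.feature_names)             # ['f0', 'f1', 'f2'], preserved by slice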
@property
def feature_names(self):
"""Get feature names (column labels).
Returns
-------
feature_names : list or None
"""
return self._feature_names
@property
def feature_types(self):
"""Get feature types (column types).
Returns
-------
feature_types : list or None
"""
return self._feature_types
@feature_names.setter
def feature_names(self, feature_names):
"""Set feature names (column labels).
Parameters
----------
feature_names : list or None
Labels for features. None will reset existing feature names
"""
if not feature_names is None:
# validate feature name
if not isinstance(feature_names, list):
feature_names = list(feature_names)
if len(feature_names) != len(set(feature_names)):
raise ValueError('feature_names must be unique')
if len(feature_names) != self.num_col():
msg = 'feature_names must have the same length as data'
raise ValueError(msg)
# prohibit symbols that may break parsing, e.g. ``[]=.``
if not all(isinstance(f, STRING_TYPES) and f.isalnum()
for f in feature_names):
raise ValueError('all feature_names must be alphanumerics')
else:
# reset feature_types also
self.feature_types = None
self._feature_names = feature_names
@feature_types.setter
def feature_types(self, feature_types):
"""Set feature types (column types).
This is only used for display purposes and does not
affect the learning process.
Parameters
----------
feature_types : list or None
Types for features. None will reset existing feature types
"""
if not feature_types is None:
if self.feature_names is None:
msg = 'Unable to set feature types before setting names'
raise ValueError(msg)
if isinstance(feature_types, STRING_TYPES):
# single string will be applied to all columns
feature_types = [feature_types] * self.num_col()
if not isinstance(feature_types, list):
feature_types = list(feature_types)
if len(feature_types) != self.num_col():
msg = 'feature_types must have the same length as data'
raise ValueError(msg)
# prohibit symbols that may break parsing, e.g. ``[]=.``
valid = ('q', 'i', 'int', 'float')
if not all(isinstance(f, STRING_TYPES) and f in valid
for f in feature_types):
raise ValueError('all feature_types must be one of {i, q, int, float}')
self._feature_types = feature_types
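The setters above can be exercised directly; a minimal sketch (assuming a built xgboost package) of the broadcast and reset behaviour.

import numpy as np
import xgboost as xgb

dm = xgb.DMatrix(np.random.randn(4, 3))
dm.feature_names = ['a', 'b', 'c']
dm.feature_types = 'q'         # a single string is broadcast to every column
print(dm.feature_types)        # ['q', 'q', 'q']
dm.feature_names = None        # resetting the names also resets the types
print(dm.feature_types)        # None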
class Booster(object): class Booster(object):
""""A Booster of of XGBoost. """"A Booster of of XGBoost.
@ -419,6 +571,9 @@ class Booster(object):
Booster is the model of xgboost, that contains low level routines for Booster is the model of xgboost, that contains low level routines for
training, prediction and evaluation. training, prediction and evaluation.
""" """
feature_names = None
def __init__(self, params=None, cache=(), model_file=None): def __init__(self, params=None, cache=(), model_file=None):
# pylint: disable=invalid-name # pylint: disable=invalid-name
"""Initialize the Booster. """Initialize the Booster.
@ -435,6 +590,8 @@ class Booster(object):
for d in cache: for d in cache:
if not isinstance(d, DMatrix): if not isinstance(d, DMatrix):
raise TypeError('invalid cache item: {}'.format(type(d).__name__)) raise TypeError('invalid cache item: {}'.format(type(d).__name__))
self._validate_features(d)
dmats = c_array(ctypes.c_void_p, [d.handle for d in cache]) dmats = c_array(ctypes.c_void_p, [d.handle for d in cache])
self.handle = ctypes.c_void_p() self.handle = ctypes.c_void_p()
_check_call(_LIB.XGBoosterCreate(dmats, len(cache), ctypes.byref(self.handle))) _check_call(_LIB.XGBoosterCreate(dmats, len(cache), ctypes.byref(self.handle)))
@ -519,6 +676,8 @@ class Booster(object):
""" """
if not isinstance(dtrain, DMatrix): if not isinstance(dtrain, DMatrix):
raise TypeError('invalid training matrix: {}'.format(type(dtrain).__name__)) raise TypeError('invalid training matrix: {}'.format(type(dtrain).__name__))
self._validate_features(dtrain)
if fobj is None: if fobj is None:
_check_call(_LIB.XGBoosterUpdateOneIter(self.handle, iteration, dtrain.handle)) _check_call(_LIB.XGBoosterUpdateOneIter(self.handle, iteration, dtrain.handle))
else: else:
@ -543,6 +702,8 @@ class Booster(object):
raise ValueError('grad / hess length mismatch: {} / {}'.format(len(grad), len(hess))) raise ValueError('grad / hess length mismatch: {} / {}'.format(len(grad), len(hess)))
if not isinstance(dtrain, DMatrix): if not isinstance(dtrain, DMatrix):
raise TypeError('invalid training matrix: {}'.format(type(dtrain).__name__)) raise TypeError('invalid training matrix: {}'.format(type(dtrain).__name__))
self._validate_features(dtrain)
_check_call(_LIB.XGBoosterBoostOneIter(self.handle, dtrain.handle, _check_call(_LIB.XGBoosterBoostOneIter(self.handle, dtrain.handle,
c_array(ctypes.c_float, grad), c_array(ctypes.c_float, grad),
c_array(ctypes.c_float, hess), c_array(ctypes.c_float, hess),
@ -572,6 +733,8 @@ class Booster(object):
raise TypeError('expected DMatrix, got {}'.format(type(d[0]).__name__)) raise TypeError('expected DMatrix, got {}'.format(type(d[0]).__name__))
if not isinstance(d[1], STRING_TYPES): if not isinstance(d[1], STRING_TYPES):
raise TypeError('expected string, got {}'.format(type(d[1]).__name__)) raise TypeError('expected string, got {}'.format(type(d[1]).__name__))
self._validate_features(d[0])
dmats = c_array(ctypes.c_void_p, [d[0].handle for d in evals]) dmats = c_array(ctypes.c_void_p, [d[0].handle for d in evals])
evnames = c_array(ctypes.c_char_p, [c_str(d[1]) for d in evals]) evnames = c_array(ctypes.c_char_p, [c_str(d[1]) for d in evals])
msg = ctypes.c_char_p() msg = ctypes.c_char_p()
@ -605,6 +768,7 @@ class Booster(object):
result: str result: str
Evaluation result string. Evaluation result string.
""" """
self._validate_features(data)
return self.eval_set([(data, name)], iteration) return self.eval_set([(data, name)], iteration)
def predict(self, data, output_margin=False, ntree_limit=0, pred_leaf=False): def predict(self, data, output_margin=False, ntree_limit=0, pred_leaf=False):
@ -642,6 +806,9 @@ class Booster(object):
option_mask |= 0x01 option_mask |= 0x01
if pred_leaf: if pred_leaf:
option_mask |= 0x02 option_mask |= 0x02
self._validate_features(data)
length = ctypes.c_ulong() length = ctypes.c_ulong()
preds = ctypes.POINTER(ctypes.c_float)() preds = ctypes.POINTER(ctypes.c_float)()
_check_call(_LIB.XGBoosterPredict(self.handle, data.handle, _check_call(_LIB.XGBoosterPredict(self.handle, data.handle,
@ -694,8 +861,11 @@ class Booster(object):
fname : string or a memory buffer fname : string or a memory buffer
Input file name or memory buffer(see also save_raw) Input file name or memory buffer(see also save_raw)
""" """
if isinstance(fname, str): # assume file name if isinstance(fname, STRING_TYPES): # assume file name
_LIB.XGBoosterLoadModel(self.handle, c_str(fname)) if os.path.exists(fname):
_LIB.XGBoosterLoadModel(self.handle, c_str(fname))
else:
raise ValueError("No such file: {0}".format(fname))
else: else:
buf = fname buf = fname
length = ctypes.c_ulong(len(buf)) length = ctypes.c_ulong(len(buf))
@ -731,16 +901,36 @@ class Booster(object):
""" """
Returns the model dump as a list of strings. Returns the model dump as a list of strings.
""" """
length = ctypes.c_ulong() length = ctypes.c_ulong()
sarr = ctypes.POINTER(ctypes.c_char_p)() sarr = ctypes.POINTER(ctypes.c_char_p)()
_check_call(_LIB.XGBoosterDumpModel(self.handle, if self.feature_names is not None and fmap == '':
c_str(fmap), flen = int(len(self.feature_names))
int(with_stats),
ctypes.byref(length), fname = from_pystr_to_cstr(self.feature_names)
ctypes.byref(sarr)))
res = [] if self.feature_types is None:
for i in range(length.value): # use quantitative as default
res.append(str(sarr[i].decode('ascii'))) # {'q': quantitative, 'i': indicator}
ftype = from_pystr_to_cstr(['q'] * flen)
else:
ftype = from_pystr_to_cstr(self.feature_types)
_check_call(_LIB.XGBoosterDumpModelWithFeatures(self.handle,
flen,
fname,
ftype,
int(with_stats),
ctypes.byref(length),
ctypes.byref(sarr)))
else:
if fmap != '' and not os.path.exists(fmap):
raise ValueError("No such file: {0}".format(fmap))
_check_call(_LIB.XGBoosterDumpModel(self.handle,
c_str(fmap),
int(with_stats),
ctypes.byref(length),
ctypes.byref(sarr)))
res = from_cstr_to_pystr(sarr, length)
return res return res
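With feature_names stored on the training DMatrix, the dump (and therefore get_fscore) uses those names instead of f0, f1, ...; a hedged end-to-end sketch, assuming a built xgboost package.

import numpy as np
import xgboost as xgb

X = np.random.randn(100, 3)
y = (X[:, 0] > 0).astype(int)
dtrain = xgb.DMatrix(X, label=y, feature_names=['age', 'fare', 'alone'])
bst = xgb.train({'max_depth': 2, 'objective': 'binary:logistic', 'silent': 1},
                dtrain, num_boost_round=3)
print(bst.get_dump()[0])   # splits read [age<...] rather than [f0<...]
print(bst.get_fscore())    # keys use the stored feature names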
def get_fscore(self, fmap=''): def get_fscore(self, fmap=''):
@ -765,3 +955,19 @@ class Booster(object):
else: else:
fmap[fid] += 1 fmap[fid] += 1
return fmap return fmap
def _validate_features(self, data):
"""
Validate that the Booster's and the data's feature_names are identical.
Set feature_names and feature_types from the DMatrix when the Booster has none.
"""
if self.feature_names is None:
self.feature_names = data.feature_names
self.feature_types = data.feature_types
else:
# Booster can't accept data with different feature names
if self.feature_names != data.feature_names:
msg = 'feature_names mismatch: {0} {1}'
raise ValueError(msg.format(self.feature_names,
data.feature_names))
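A small sketch (assuming a built xgboost package) of the mismatch check: predicting with a DMatrix whose feature_names differ from those seen at training time now raises.

import numpy as np
import xgboost as xgb

dtrain = xgb.DMatrix(np.random.randn(50, 2), label=np.random.randint(2, size=50),
                     feature_names=['a', 'b'])
bst = xgb.train({'objective': 'binary:logistic', 'silent': 1}, dtrain, 2)
other = xgb.DMatrix(np.random.randn(5, 2), feature_names=['x', 'y'])
try:
    bst.predict(other)
except ValueError as err:
    print(err)   # feature_names mismatch: ['a', 'b'] ['x', 'y']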

View File

@ -0,0 +1,44 @@
# coding: utf-8
"""Find the path to xgboost dynamic library files."""
import os
import platform
class XGBoostLibraryNotFound(Exception):
"""Error throwed by when xgboost is not found"""
pass
def find_lib_path():
"""Load find the path to xgboost dynamic library files.
Returns
-------
lib_path: list(string)
List of all found library paths to xgboost
"""
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
# pythonpack hack: setup.py copies this directory one level up
dll_path = [curr_path, os.path.join(curr_path, '../../wrapper/'),
os.path.join(curr_path, './wrapper/')]
if os.name == 'nt':
if platform.architecture()[0] == '64bit':
dll_path.append(os.path.join(curr_path, '../../windows/x64/Release/'))
# hack for pip installation, where the whole parent source directory is copied here
dll_path.append(os.path.join(curr_path, './windows/x64/Release/'))
else:
dll_path.append(os.path.join(curr_path, '../../windows/Release/'))
# hack for pip installation, where the whole parent source directory is copied here
dll_path.append(os.path.join(curr_path, './windows/Release/'))
if os.name == 'nt':
dll_path = [os.path.join(p, 'xgboost_wrapper.dll') for p in dll_path]
else:
dll_path = [os.path.join(p, 'libxgboostwrapper.so') for p in dll_path]
lib_path = [p for p in dll_path if os.path.exists(p) and os.path.isfile(p)]
if len(lib_path) == 0 and not os.environ.get('XGBOOST_BUILD_DOC', False):
raise XGBoostLibraryNotFound(
'Cannot find XGBoost Library in the candidate path, ' +
'did you run build.sh in root path?\n'
'List of candidates:\n' + ('\n'.join(dll_path)))
return lib_path
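A sketch of how the loader is expected to consume this helper, assuming the new file is importable as xgboost.libpath (the actual loading lives in _load_lib() in core.py).

import ctypes
from xgboost.libpath import find_lib_path  # assumed module path for this new file

lib_file = find_lib_path()[0]              # first existing candidate
lib = ctypes.cdll.LoadLibrary(lib_file)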

View File

@ -92,7 +92,7 @@ def plot_importance(booster, ax=None, height=0.2,
_NODEPAT = re.compile(r'(\d+):\[(.+)\]') _NODEPAT = re.compile(r'(\d+):\[(.+)\]')
_LEAFPAT = re.compile(r'(\d+):(leaf=.+)') _LEAFPAT = re.compile(r'(\d+):(leaf=.+)')
_EDGEPAT = re.compile(r'yes=(\d+),no=(\d+),missing=(\d+)') _EDGEPAT = re.compile(r'yes=(\d+),no=(\d+),missing=(\d+)')
_EDGEPAT2 = re.compile(r'yes=(\d+),no=(\d+)')
def _parse_node(graph, text): def _parse_node(graph, text):
"""parse dumped node""" """parse dumped node"""
@ -111,15 +111,24 @@ def _parse_node(graph, text):
def _parse_edge(graph, node, text, yes_color='#0000FF', no_color='#FF0000'): def _parse_edge(graph, node, text, yes_color='#0000FF', no_color='#FF0000'):
"""parse dumped edge""" """parse dumped edge"""
match = _EDGEPAT.match(text) try:
match = _EDGEPAT.match(text)
if match is not None:
yes, no, missing = match.groups()
if yes == missing:
graph.edge(node, yes, label='yes, missing', color=yes_color)
graph.edge(node, no, label='no', color=no_color)
else:
graph.edge(node, yes, label='yes', color=yes_color)
graph.edge(node, no, label='no, missing', color=no_color)
return
except ValueError:
pass
match = _EDGEPAT2.match(text)
if match is not None: if match is not None:
yes, no, missing = match.groups() yes, no = match.groups()
if yes == missing: graph.edge(node, yes, label='yes', color=yes_color)
graph.edge(node, yes, label='yes, missing', color=yes_color) graph.edge(node, no, label='no', color=no_color)
graph.edge(node, no, label='no', color=no_color)
else:
graph.edge(node, yes, label='yes', color=yes_color)
graph.edge(node, no, label='no, missing', color=no_color)
return return
raise ValueError('Unable to parse edge: {0}'.format(text)) raise ValueError('Unable to parse edge: {0}'.format(text))
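For orientation, a self-contained sketch of the two edge patterns against sample edge strings as they appear in a text dump.

import re

_EDGEPAT = re.compile(r'yes=(\d+),no=(\d+),missing=(\d+)')
_EDGEPAT2 = re.compile(r'yes=(\d+),no=(\d+)')

print(_EDGEPAT.match('yes=1,no=2,missing=1').groups())  # ('1', '2', '1')
# dumps without a missing branch fall through to the second pattern
print(_EDGEPAT2.match('yes=1,no=2').groups())           # ('1', '2')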

View File

@ -319,7 +319,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
if len(class_probs.shape) > 1: if len(class_probs.shape) > 1:
column_indexes = np.argmax(class_probs, axis=1) column_indexes = np.argmax(class_probs, axis=1)
else: else:
column_indexes = np.repeat(0, data.shape[0]) column_indexes = np.repeat(0, class_probs.shape[0])
column_indexes[class_probs > 0.5] = 1 column_indexes[class_probs > 0.5] = 1
return self._le.inverse_transform(column_indexes) return self._le.inverse_transform(column_indexes)

View File

@ -1,5 +1,6 @@
# coding: utf-8 # coding: utf-8
# pylint: disable=too-many-locals, too-many-arguments, invalid-name # pylint: disable=too-many-locals, too-many-arguments, invalid-name
# pylint: disable=too-many-branches
"""Training Library containing training routines.""" """Training Library containing training routines."""
from __future__ import absolute_import from __future__ import absolute_import
@ -118,7 +119,7 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
sys.stderr.write(msg + '\n') sys.stderr.write(msg + '\n')
if evals_result is not None: if evals_result is not None:
res = re.findall(":-([0-9.]+).", msg) res = re.findall(":-?([0-9.]+).", msg)
for key, val in zip(evals_name, res): for key, val in zip(evals_name, res):
evals_result[key].append(val) evals_result[key].append(val)
@ -179,16 +180,16 @@ def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None):
return ret return ret
def aggcv(rlist, show_stdv=True): def aggcv(rlist, show_stdv=True, show_progress=None, as_pandas=True):
# pylint: disable=invalid-name # pylint: disable=invalid-name
""" """
Aggregate cross-validation results. Aggregate cross-validation results.
""" """
cvmap = {} cvmap = {}
ret = rlist[0].split()[0] idx = rlist[0].split()[0]
for line in rlist: for line in rlist:
arr = line.split() arr = line.split()
assert ret == arr[0] assert idx == arr[0]
for it in arr[1:]: for it in arr[1:]:
if not isinstance(it, STRING_TYPES): if not isinstance(it, STRING_TYPES):
it = it.decode() it = it.decode()
@ -196,19 +197,50 @@ def aggcv(rlist, show_stdv=True):
if k not in cvmap: if k not in cvmap:
cvmap[k] = [] cvmap[k] = []
cvmap[k].append(float(v)) cvmap[k].append(float(v))
msg = idx
if show_stdv:
fmt = '\tcv-{0}:{1}+{2}'
else:
fmt = '\tcv-{0}:{1}'
index = []
results = []
for k, v in sorted(cvmap.items(), key=lambda x: x[0]): for k, v in sorted(cvmap.items(), key=lambda x: x[0]):
v = np.array(v) v = np.array(v)
if not isinstance(ret, STRING_TYPES): if not isinstance(msg, STRING_TYPES):
ret = ret.decode() msg = msg.decode()
if show_stdv: mean, std = np.mean(v), np.std(v)
ret += '\tcv-%s:%f+%f' % (k, np.mean(v), np.std(v)) msg += fmt.format(k, mean, std)
else:
ret += '\tcv-%s:%f' % (k, np.mean(v)) index.extend([k + '-mean', k + '-std'])
return ret results.extend([mean, std])
if as_pandas:
try:
import pandas as pd
results = pd.Series(results, index=index)
except ImportError:
if show_progress is None:
show_progress = True
else:
# if show_progress is default (None),
# result will be np.ndarray as it can't hold column name
if show_progress is None:
show_progress = True
if show_progress:
sys.stderr.write(msg + '\n')
return results
def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(), def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(),
obj=None, feval=None, fpreproc=None, show_stdv=True, seed=0): obj=None, feval=None, fpreproc=None, as_pandas=True,
show_progress=None, show_stdv=True, seed=0):
# pylint: disable = invalid-name # pylint: disable = invalid-name
"""Cross-validation with given paramaters. """Cross-validation with given paramaters.
@ -231,8 +263,15 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(),
fpreproc : function fpreproc : function
Preprocessing function that takes (dtrain, dtest, param) and returns Preprocessing function that takes (dtrain, dtest, param) and returns
transformed versions of those. transformed versions of those.
show_stdv : bool as_pandas : bool, default True
Whether to display the standard deviation. Return pd.DataFrame when pandas is installed.
If False or pandas is not installed, return np.ndarray
show_progress : bool or None, default None
Whether to display the progress. If None, progress will be displayed
when np.ndarray is returned.
show_stdv : bool, default True
Whether to display the standard deviation in progress.
Results are not affected and always contain the std.
seed : int seed : int
Seed used to generate the folds (passed to numpy.random.seed). Seed used to generate the folds (passed to numpy.random.seed).
@ -245,8 +284,19 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(),
for i in range(num_boost_round): for i in range(num_boost_round):
for fold in cvfolds: for fold in cvfolds:
fold.update(i, obj) fold.update(i, obj)
res = aggcv([f.eval(i, feval) for f in cvfolds], show_stdv) res = aggcv([f.eval(i, feval) for f in cvfolds],
sys.stderr.write(res + '\n') show_stdv=show_stdv, show_progress=show_progress,
as_pandas=as_pandas)
results.append(res) results.append(res)
if as_pandas:
try:
import pandas as pd
results = pd.DataFrame(results)
except ImportError:
results = np.array(results)
else:
results = np.array(results)
return results return results
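A hedged sketch of the new cv return types (assumes a built xgboost package; pandas is optional).

import numpy as np
import xgboost as xgb

dtrain = xgb.DMatrix(np.random.randn(200, 4), label=np.random.randint(2, size=200))
params = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic'}
# pd.DataFrame with *-mean / *-std columns when pandas is installed
res = xgb.cv(params, dtrain, num_boost_round=5, nfold=3)
# plain np.ndarray of shape (num_boost_round, 4); progress goes to stderr
res_arr = xgb.cv(params, dtrain, num_boost_round=5, nfold=3, as_pandas=False)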

View File

@ -5,15 +5,3 @@ if [ ${TRAVIS_OS_NAME} != "osx" ]; then
fi fi
brew update brew update
if [ ${TASK} == "python-package" ]; then
brew install python git graphviz
easy_install pip
pip install numpy scipy matplotlib nose
fi
if [ ${TASK} == "python-package3" ]; then
brew install python3 git graphviz
sudo pip3 install --upgrade setuptools
pip3 install numpy scipy matplotlib nose graphviz
fi

View File

@ -33,30 +33,44 @@ if [ ${TASK} == "R-package" ]; then
scripts/travis_R_script.sh || exit -1 scripts/travis_R_script.sh || exit -1
fi fi
if [ ${TASK} == "python-package" ]; then if [ ${TASK} == "python-package" -o ${TASK} == "python-package3" ]; then
sudo apt-get install graphviz
sudo apt-get install python-numpy python-scipy python-matplotlib python-nose
sudo python -m pip install graphviz
make all CXX=${CXX} || exit -1
nosetests tests/python || exit -1
fi
if [ ${TASK} == "python-package3" ]; then if [ ${TRAVIS_OS_NAME} == "osx" ]; then
sudo apt-get install graphviz brew install graphviz
# python3-matplotlib is unavailale on Ubuntu 12.04 if [ ${TASK} == "python-package3" ]; then
sudo apt-get install python3-dev wget -O conda.sh https://repo.continuum.io/miniconda/Miniconda3-latest-MacOSX-x86_64.sh
sudo apt-get install python3-numpy python3-scipy python3-nose python3-setuptools else
wget -O conda.sh https://repo.continuum.io/miniconda/Miniconda-latest-MacOSX-x86_64.sh
make all CXX=${CXX} || exit -1 fi
if [ ${TRAVIS_OS_NAME} != "osx" ]; then
sudo easy_install3 pip
sudo easy_install3 -U distribute
sudo pip install graphviz matplotlib
nosetests3 tests/python || exit -1
else else
nosetests tests/python || exit -1 sudo apt-get install graphviz
if [ ${TASK} == "python-package3" ]; then
wget -O conda.sh https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh
else
wget -O conda.sh https://repo.continuum.io/miniconda/Miniconda-latest-Linux-x86_64.sh
fi
fi fi
bash conda.sh -b -p $HOME/miniconda
export PATH="$HOME/miniconda/bin:$PATH"
hash -r
conda config --set always_yes yes --set changeps1 no
conda update -q conda
# Useful for debugging any issues with conda
conda info -a
if [ ${TASK} == "python-package3" ]; then
conda create -n myenv python=3.4
else
conda create -n myenv python=2.7
fi
source activate myenv
conda install numpy scipy pandas matplotlib nose
python -m pip install graphviz
make all CXX=${CXX} || exit -1
python -m nose tests/python || exit -1
python --version
fi fi
# only test java under linux for now # only test java under linux for now

View File

@ -1 +1 @@
This folder contains tetstcases for xgboost. This folder contains testcases for xgboost.

View File

@ -1,75 +1,239 @@
# -*- coding: utf-8 -*-
import numpy as np import numpy as np
import xgboost as xgb import xgboost as xgb
import unittest
dpath = 'demo/data/' dpath = 'demo/data/'
def test_basic(): class TestBasic(unittest.TestCase):
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
param = {'max_depth':2, 'eta':1, 'silent':1, 'objective':'binary:logistic' }
# specify validations set to watch performance
watchlist = [(dtest,'eval'), (dtrain,'train')]
num_round = 2
bst = xgb.train(param, dtrain, num_round, watchlist)
# this is prediction
preds = bst.predict(dtest)
labels = dtest.get_label()
err = sum(1 for i in range(len(preds)) if int(preds[i]>0.5)!=labels[i]) / float(len(preds))
# error must be smaller than 10%
assert err < 0.1
# save dmatrix into binary buffer def test_basic(self):
dtest.save_binary('dtest.buffer') dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
# save model dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
bst.save_model('xgb.model') param = {'max_depth':2, 'eta':1, 'silent':1, 'objective':'binary:logistic' }
# load model and data in # specify validations set to watch performance
bst2 = xgb.Booster(model_file='xgb.model') watchlist = [(dtest,'eval'), (dtrain,'train')]
dtest2 = xgb.DMatrix('dtest.buffer') num_round = 2
preds2 = bst2.predict(dtest2) bst = xgb.train(param, dtrain, num_round, watchlist)
# assert they are the same # this is prediction
assert np.sum(np.abs(preds2-preds)) == 0 preds = bst.predict(dtest)
labels = dtest.get_label()
err = sum(1 for i in range(len(preds)) if int(preds[i]>0.5)!=labels[i]) / float(len(preds))
# error must be smaller than 10%
assert err < 0.1
def test_plotting(): # save dmatrix into binary buffer
bst2 = xgb.Booster(model_file='xgb.model') dtest.save_binary('dtest.buffer')
# plotting # save model
bst.save_model('xgb.model')
# load model and data in
bst2 = xgb.Booster(model_file='xgb.model')
dtest2 = xgb.DMatrix('dtest.buffer')
preds2 = bst2.predict(dtest2)
# assert they are the same
assert np.sum(np.abs(preds2-preds)) == 0
import matplotlib def test_dmatrix_init(self):
matplotlib.use('Agg') data = np.random.randn(5, 5)
from matplotlib.axes import Axes # different length
from graphviz import Digraph self.assertRaises(ValueError, xgb.DMatrix, data,
feature_names=list('abcdef'))
# contains duplicates
self.assertRaises(ValueError, xgb.DMatrix, data,
feature_names=['a', 'b', 'c', 'd', 'd'])
# contains symbol
self.assertRaises(ValueError, xgb.DMatrix, data,
feature_names=['a', 'b', 'c', 'd', 'e=1'])
ax = xgb.plot_importance(bst2) dm = xgb.DMatrix(data)
assert isinstance(ax, Axes) dm.feature_names = list('abcde')
assert ax.get_title() == 'Feature importance' assert dm.feature_names == list('abcde')
assert ax.get_xlabel() == 'F score'
assert ax.get_ylabel() == 'Features'
assert len(ax.patches) == 4
ax = xgb.plot_importance(bst2, color='r', dm.feature_types = 'q'
title='t', xlabel='x', ylabel='y') assert dm.feature_types == list('qqqqq')
assert isinstance(ax, Axes)
assert ax.get_title() == 't' dm.feature_types = list('qiqiq')
assert ax.get_xlabel() == 'x' assert dm.feature_types == list('qiqiq')
assert ax.get_ylabel() == 'y'
assert len(ax.patches) == 4 def incorrect_type_set():
for p in ax.patches: dm.feature_types = list('abcde')
assert p.get_facecolor() == (1.0, 0, 0, 1.0) # red self.assertRaises(ValueError, incorrect_type_set)
# reset
dm.feature_names = None
assert dm.feature_names is None
assert dm.feature_types is None
def test_feature_names(self):
data = np.random.randn(100, 5)
target = np.array([0, 1] * 50)
cases = [['Feature1', 'Feature2', 'Feature3', 'Feature4', 'Feature5'],
[u'要因1', u'要因2', u'要因3', u'要因4', u'要因5']]
for features in cases:
dm = xgb.DMatrix(data, label=target,
feature_names=features)
assert dm.feature_names == features
assert dm.num_row() == 100
assert dm.num_col() == 5
params={'objective': 'multi:softprob',
'eval_metric': 'mlogloss',
'eta': 0.3,
'num_class': 3}
bst = xgb.train(params, dm, num_boost_round=10)
scores = bst.get_fscore()
assert list(sorted(k for k in scores)) == features
dummy = np.random.randn(5, 5)
dm = xgb.DMatrix(dummy, feature_names=features)
bst.predict(dm)
# different feature name must raises error
dm = xgb.DMatrix(dummy, feature_names=list('abcde'))
self.assertRaises(ValueError, bst.predict, dm)
def test_pandas(self):
import pandas as pd
df = pd.DataFrame([[1, 2., True], [2, 3., False]], columns=['a', 'b', 'c'])
dm = xgb.DMatrix(df, label=pd.Series([1, 2]))
assert dm.feature_names == ['a', 'b', 'c']
assert dm.feature_types == ['int', 'q', 'i']
assert dm.num_row() == 2
assert dm.num_col() == 3
# overwrite feature_names and feature_types
dm = xgb.DMatrix(df, label=pd.Series([1, 2]),
feature_names=['x', 'y', 'z'], feature_types=['q', 'q', 'q'])
assert dm.feature_names == ['x', 'y', 'z']
assert dm.feature_types == ['q', 'q', 'q']
assert dm.num_row() == 2
assert dm.num_col() == 3
# incorrect dtypes
df = pd.DataFrame([[1, 2., 'x'], [2, 3., 'y']], columns=['a', 'b', 'c'])
self.assertRaises(ValueError, xgb.DMatrix, df)
# numeric columns
df = pd.DataFrame([[1, 2., True], [2, 3., False]])
dm = xgb.DMatrix(df, label=pd.Series([1, 2]))
assert dm.feature_names == ['0', '1', '2']
assert dm.feature_types == ['int', 'q', 'i']
assert dm.num_row() == 2
assert dm.num_col() == 3
df = pd.DataFrame([[1, 2., 1], [2, 3., 1]], columns=[4, 5, 6])
dm = xgb.DMatrix(df, label=pd.Series([1, 2]))
assert dm.feature_names == ['4', '5', '6']
assert dm.feature_types == ['int', 'q', 'int']
assert dm.num_row() == 2
assert dm.num_col() == 3
def test_load_file_invalid(self):
self.assertRaises(ValueError, xgb.Booster,
model_file='incorrect_path')
self.assertRaises(ValueError, xgb.Booster,
model_file=u'不正なパス')
def test_dmatrix_numpy_init(self):
data = np.random.randn(5, 5)
dm = xgb.DMatrix(data)
assert dm.num_row() == 5
assert dm.num_col() == 5
data = np.matrix([[1, 2], [3, 4]])
dm = xgb.DMatrix(data)
assert dm.num_row() == 2
assert dm.num_col() == 2
# 0d array
self.assertRaises(ValueError, xgb.DMatrix, np.array(1))
# 1d array
self.assertRaises(ValueError, xgb.DMatrix, np.array([1, 2, 3]))
# 3d array
data = np.random.randn(5, 5, 5)
self.assertRaises(ValueError, xgb.DMatrix, data)
# object dtype
data = np.array([['a', 'b'], ['c', 'd']])
self.assertRaises(ValueError, xgb.DMatrix, data)
def test_cv(self):
dm = xgb.DMatrix(dpath + 'agaricus.txt.train')
params = {'max_depth':2, 'eta':1, 'silent':1, 'objective':'binary:logistic' }
import pandas as pd
cv = xgb.cv(params, dm, num_boost_round=10, nfold=10)
assert isinstance(cv, pd.DataFrame)
exp = pd.Index([u'test-error-mean', u'test-error-std',
u'train-error-mean', u'train-error-std'])
assert cv.columns.equals(exp)
# show progress log (result is the same as above)
cv = xgb.cv(params, dm, num_boost_round=10, nfold=10,
show_progress=True)
assert isinstance(cv, pd.DataFrame)
exp = pd.Index([u'test-error-mean', u'test-error-std',
u'train-error-mean', u'train-error-std'])
assert cv.columns.equals(exp)
cv = xgb.cv(params, dm, num_boost_round=10, nfold=10,
show_progress=True, show_stdv=False)
assert isinstance(cv, pd.DataFrame)
exp = pd.Index([u'test-error-mean', u'test-error-std',
u'train-error-mean', u'train-error-std'])
assert cv.columns.equals(exp)
# return np.ndarray
cv = xgb.cv(params, dm, num_boost_round=10, nfold=10, as_pandas=False)
assert isinstance(cv, np.ndarray)
assert cv.shape == (10, 4)
def test_plotting(self):
bst2 = xgb.Booster(model_file='xgb.model')
# plotting
import matplotlib
matplotlib.use('Agg')
from matplotlib.axes import Axes
from graphviz import Digraph
ax = xgb.plot_importance(bst2)
assert isinstance(ax, Axes)
assert ax.get_title() == 'Feature importance'
assert ax.get_xlabel() == 'F score'
assert ax.get_ylabel() == 'Features'
assert len(ax.patches) == 4
ax = xgb.plot_importance(bst2, color='r',
title='t', xlabel='x', ylabel='y')
assert isinstance(ax, Axes)
assert ax.get_title() == 't'
assert ax.get_xlabel() == 'x'
assert ax.get_ylabel() == 'y'
assert len(ax.patches) == 4
for p in ax.patches:
assert p.get_facecolor() == (1.0, 0, 0, 1.0) # red
ax = xgb.plot_importance(bst2, color=['r', 'r', 'b', 'b'], ax = xgb.plot_importance(bst2, color=['r', 'r', 'b', 'b'],
title=None, xlabel=None, ylabel=None) title=None, xlabel=None, ylabel=None)
assert isinstance(ax, Axes) assert isinstance(ax, Axes)
assert ax.get_title() == '' assert ax.get_title() == ''
assert ax.get_xlabel() == '' assert ax.get_xlabel() == ''
assert ax.get_ylabel() == '' assert ax.get_ylabel() == ''
assert len(ax.patches) == 4 assert len(ax.patches) == 4
assert ax.patches[0].get_facecolor() == (1.0, 0, 0, 1.0) # red assert ax.patches[0].get_facecolor() == (1.0, 0, 0, 1.0) # red
assert ax.patches[1].get_facecolor() == (1.0, 0, 0, 1.0) # red assert ax.patches[1].get_facecolor() == (1.0, 0, 0, 1.0) # red
assert ax.patches[2].get_facecolor() == (0, 0, 1.0, 1.0) # blue assert ax.patches[2].get_facecolor() == (0, 0, 1.0, 1.0) # blue
assert ax.patches[3].get_facecolor() == (0, 0, 1.0, 1.0) # blue assert ax.patches[3].get_facecolor() == (0, 0, 1.0, 1.0) # blue
g = xgb.to_graphviz(bst2, num_trees=0) g = xgb.to_graphviz(bst2, num_trees=0)
assert isinstance(g, Digraph) assert isinstance(g, Digraph)
ax = xgb.plot_tree(bst2, num_trees=0) ax = xgb.plot_tree(bst2, num_trees=0)
assert isinstance(ax, Axes) assert isinstance(ax, Axes)

View File

@ -0,0 +1,39 @@
import numpy as np
import xgboost as xgb
dpath = 'demo/data/'
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
def test_glm():
param = {'silent':1, 'objective':'binary:logistic', 'booster':'gblinear', 'alpha': 0.0001, 'lambda': 1 }
watchlist = [(dtest,'eval'), (dtrain,'train')]
num_round = 4
bst = xgb.train(param, dtrain, num_round, watchlist)
assert isinstance(bst, xgb.core.Booster)
preds = bst.predict(dtest)
labels = dtest.get_label()
err = sum(1 for i in range(len(preds)) if int(preds[i]>0.5)!=labels[i]) / float(len(preds))
assert err < 0.1
def test_custom_objective():
param = {'max_depth':2, 'eta':1, 'silent':1 }
watchlist = [(dtest,'eval'), (dtrain,'train')]
num_round = 2
def logregobj(preds, dtrain):
labels = dtrain.get_label()
preds = 1.0 / (1.0 + np.exp(-preds))
grad = preds - labels
hess = preds * (1.0-preds)
return grad, hess
def evalerror(preds, dtrain):
labels = dtrain.get_label()
return 'error', float(sum(labels != (preds > 0.0))) / len(labels)
bst = xgb.train(param, dtrain, num_round, watchlist, logregobj, evalerror)
assert isinstance(bst, xgb.core.Booster)
preds = bst.predict(dtest)
labels = dtest.get_label()
err = sum(1 for i in range(len(preds)) if int(preds[i]>0.5)!=labels[i]) / float(len(preds))
assert err < 0.1

View File

@ -435,6 +435,7 @@ int XGDMatrixGetUIntInfo(const DMatrixHandle handle,
*out_dptr = BeginPtr(vec); *out_dptr = BeginPtr(vec);
API_END(); API_END();
} }
int XGDMatrixNumRow(const DMatrixHandle handle, int XGDMatrixNumRow(const DMatrixHandle handle,
bst_ulong *out) { bst_ulong *out) {
API_BEGIN(); API_BEGIN();
@ -442,6 +443,13 @@ int XGDMatrixNumRow(const DMatrixHandle handle,
API_END(); API_END();
} }
int XGDMatrixNumCol(const DMatrixHandle handle,
bst_ulong *out) {
API_BEGIN();
*out = static_cast<size_t>(static_cast<const DataMatrix*>(handle)->info.num_col());
API_END();
}
// xgboost implementation // xgboost implementation
int XGBoosterCreate(DMatrixHandle dmats[], int XGBoosterCreate(DMatrixHandle dmats[],
bst_ulong len, bst_ulong len,
@ -572,3 +580,20 @@ int XGBoosterDumpModel(BoosterHandle handle,
featmap, with_stats != 0, len); featmap, with_stats != 0, len);
API_END(); API_END();
} }
int XGBoosterDumpModelWithFeatures(BoosterHandle handle,
int fnum,
const char **fname,
const char **ftype,
int with_stats,
bst_ulong *len,
const char ***out_models) {
API_BEGIN();
utils::FeatMap featmap;
for (int i = 0; i < fnum; ++i) {
featmap.PushBack(i, fname[i], ftype[i]);
}
*out_models = static_cast<Booster*>(handle)->GetModelDump(
featmap, with_stats != 0, len);
API_END();
}
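On the Python side, DMatrix.num_col() wraps the new XGDMatrixNumCol entry point; a rough ctypes sketch, assuming _LIB is importable from xgboost.core as in this package layout.

import ctypes
import numpy as np
import xgboost as xgb
from xgboost.core import _LIB   # handle returned by _load_lib()

dm = xgb.DMatrix(np.random.randn(4, 7))
out = ctypes.c_ulong()          # bst_ulong on the C side
_LIB.XGDMatrixNumCol(dm.handle, ctypes.byref(out))
print(out.value)                # 7, same as dm.num_col()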

View File

@ -184,6 +184,13 @@ XGB_DLL int XGDMatrixGetUIntInfo(const DMatrixHandle handle,
*/ */
XGB_DLL int XGDMatrixNumRow(DMatrixHandle handle, XGB_DLL int XGDMatrixNumRow(DMatrixHandle handle,
bst_ulong *out); bst_ulong *out);
/*!
* \brief get number of columns
* \param handle the handle to the DMatrix
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGDMatrixNumCol(DMatrixHandle handle,
bst_ulong *out);
// --- start XGBoost class // --- start XGBoost class
/*! /*!
* \brief create xgboost learner * \brief create xgboost learner
@ -324,4 +331,24 @@ XGB_DLL int XGBoosterDumpModel(BoosterHandle handle,
int with_stats, int with_stats,
bst_ulong *out_len, bst_ulong *out_len,
const char ***out_dump_array); const char ***out_dump_array);
/*!
* \brief dump model, return array of strings representing model dump
* \param handle handle
* \param fnum number of features
* \param fname names of features
* \param ftype types of features
* \param with_stats whether to dump with statistics
* \param len length of output array
* \param out_models pointer to hold the dump of each model
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterDumpModelWithFeatures(BoosterHandle handle,
int fnum,
const char **fname,
const char **ftype,
int with_stats,
bst_ulong *len,
const char ***out_models);
#endif // XGBOOST_WRAPPER_H_ #endif // XGBOOST_WRAPPER_H_