Merge pull request #580 from terrytangyuan/test

Fixed most of the lint issues
This commit is contained in:
Yuan (Terry) Tang 2015-10-29 00:54:16 -04:00
commit b9a9cd9db8
15 changed files with 191 additions and 197 deletions

View File

@ -48,7 +48,7 @@ setMethod("predict", signature = "xgb.Booster",
stop("predict: ntreelimit must be equal to or greater than 1") stop("predict: ntreelimit must be equal to or greater than 1")
} }
} }
option = 0 option <- 0
if (outputmargin) { if (outputmargin) {
option <- option + 1 option <- option + 1
} }

View File

@ -30,12 +30,12 @@ setMethod("slice", signature = "xgb.DMatrix",
} }
ret <- .Call("XGDMatrixSliceDMatrix_R", object, idxset, ret <- .Call("XGDMatrixSliceDMatrix_R", object, idxset,
PACKAGE = "xgboost") PACKAGE = "xgboost")
attr_list <- attributes(object) attr_list <- attributes(object)
nr <- xgb.numrow(object) nr <- xgb.numrow(object)
len <- sapply(attr_list,length) len <- sapply(attr_list,length)
ind <- which(len==nr) ind <- which(len == nr)
if (length(ind)>0) { if (length(ind) > 0) {
nms <- names(attr_list)[ind] nms <- names(attr_list)[ind]
for (i in 1:length(ind)) { for (i in 1:length(ind)) {
attr(ret,nms[i]) <- attr(object,nms[i])[idxset] attr(ret,nms[i]) <- attr(object,nms[i])[idxset]

View File

@ -1,4 +1,4 @@
#' @importClassesFrom Matrix dgCMatrix dgeMatrix #' @importClassesFrom Matrix dgCMatrix dgeMatrix
#' @import methods #' @import methods
# depends on matrix # depends on matrix
@ -15,14 +15,14 @@ xgb.setinfo <- function(dmat, name, info) {
stop("xgb.setinfo: first argument dtrain must be xgb.DMatrix") stop("xgb.setinfo: first argument dtrain must be xgb.DMatrix")
} }
if (name == "label") { if (name == "label") {
if (length(info)!=xgb.numrow(dmat)) if (length(info) != xgb.numrow(dmat))
stop("The length of labels must equal to the number of rows in the input data") stop("The length of labels must equal to the number of rows in the input data")
.Call("XGDMatrixSetInfo_R", dmat, name, as.numeric(info), .Call("XGDMatrixSetInfo_R", dmat, name, as.numeric(info),
PACKAGE = "xgboost") PACKAGE = "xgboost")
return(TRUE) return(TRUE)
} }
if (name == "weight") { if (name == "weight") {
if (length(info)!=xgb.numrow(dmat)) if (length(info) != xgb.numrow(dmat))
stop("The length of weights must equal to the number of rows in the input data") stop("The length of weights must equal to the number of rows in the input data")
.Call("XGDMatrixSetInfo_R", dmat, name, as.numeric(info), .Call("XGDMatrixSetInfo_R", dmat, name, as.numeric(info),
PACKAGE = "xgboost") PACKAGE = "xgboost")
@ -36,7 +36,7 @@ xgb.setinfo <- function(dmat, name, info) {
return(TRUE) return(TRUE)
} }
if (name == "group") { if (name == "group") {
if (sum(info)!=xgb.numrow(dmat)) if (sum(info) != xgb.numrow(dmat))
stop("The sum of groups must equal to the number of rows in the input data") stop("The sum of groups must equal to the number of rows in the input data")
.Call("XGDMatrixSetInfo_R", dmat, name, as.integer(info), .Call("XGDMatrixSetInfo_R", dmat, name, as.integer(info),
PACKAGE = "xgboost") PACKAGE = "xgboost")
@ -68,7 +68,7 @@ xgb.Booster <- function(params = list(), cachelist = list(), modelfile = NULL) {
if (typeof(modelfile) == "character") { if (typeof(modelfile) == "character") {
.Call("XGBoosterLoadModel_R", handle, modelfile, PACKAGE = "xgboost") .Call("XGBoosterLoadModel_R", handle, modelfile, PACKAGE = "xgboost")
} else if (typeof(modelfile) == "raw") { } else if (typeof(modelfile) == "raw") {
.Call("XGBoosterLoadModelFromRaw_R", handle, modelfile, PACKAGE = "xgboost") .Call("XGBoosterLoadModelFromRaw_R", handle, modelfile, PACKAGE = "xgboost")
} else { } else {
stop("xgb.Booster: modelfile must be character or raw vector") stop("xgb.Booster: modelfile must be character or raw vector")
} }
@ -122,7 +122,7 @@ xgb.get.DMatrix <- function(data, label = NULL, missing = NA, weight = NULL) {
} else if (inClass == "xgb.DMatrix") { } else if (inClass == "xgb.DMatrix") {
dtrain <- data dtrain <- data
} else if (inClass == "data.frame") { } else if (inClass == "data.frame") {
stop("xgboost only support numerical matrix input, stop("xgboost only support numerical matrix input,
use 'data.frame' to transform the data.") use 'data.frame' to transform the data.")
} else { } else {
stop("xgboost: Invalid input of data") stop("xgboost: Invalid input of data")
@ -156,7 +156,7 @@ xgb.iter.update <- function(booster, dtrain, iter, obj = NULL) {
} }
if (is.null(obj)) { if (is.null(obj)) {
.Call("XGBoosterUpdateOneIter_R", booster, as.integer(iter), dtrain, .Call("XGBoosterUpdateOneIter_R", booster, as.integer(iter), dtrain,
PACKAGE = "xgboost") PACKAGE = "xgboost")
} else { } else {
pred <- predict(booster, dtrain) pred <- predict(booster, dtrain)
@ -189,9 +189,9 @@ xgb.iter.eval <- function(booster, watchlist, iter, feval = NULL, prediction = F
} }
evnames <- append(evnames, names(w)) evnames <- append(evnames, names(w))
} }
msg <- .Call("XGBoosterEvalOneIter_R", booster, as.integer(iter), watchlist, msg <- .Call("XGBoosterEvalOneIter_R", booster, as.integer(iter), watchlist,
evnames, PACKAGE = "xgboost") evnames, PACKAGE = "xgboost")
} else { } else {
msg <- paste("[", iter, "]", sep="") msg <- paste("[", iter, "]", sep="")
for (j in 1:length(watchlist)) { for (j in 1:length(watchlist)) {
w <- watchlist[j] w <- watchlist[j]
@ -247,11 +247,11 @@ xgb.cv.mknfold <- function(dall, nfold, param, stratified, folds) {
if (length(unique(y)) <= 5) y <- factor(y) if (length(unique(y)) <= 5) y <- factor(y)
} }
folds <- xgb.createFolds(y, nfold) folds <- xgb.createFolds(y, nfold)
} else { } else {
# make simple non-stratified folds # make simple non-stratified folds
kstep <- length(randidx) %/% nfold kstep <- length(randidx) %/% nfold
folds <- list() folds <- list()
for (i in 1:(nfold-1)) { for (i in 1:(nfold - 1)) {
folds[[i]] <- randidx[1:kstep] folds[[i]] <- randidx[1:kstep]
randidx <- setdiff(randidx, folds[[i]]) randidx <- setdiff(randidx, folds[[i]])
} }
@ -261,7 +261,7 @@ xgb.cv.mknfold <- function(dall, nfold, param, stratified, folds) {
ret <- list() ret <- list()
for (k in 1:nfold) { for (k in 1:nfold) {
dtest <- slice(dall, folds[[k]]) dtest <- slice(dall, folds[[k]])
didx = c() didx <- c()
for (i in 1:nfold) { for (i in 1:nfold) {
if (i != k) { if (i != k) {
didx <- append(didx, folds[[i]]) didx <- append(didx, folds[[i]])
@ -282,7 +282,7 @@ xgb.cv.aggcv <- function(res, showsd = TRUE) {
kv <- strsplit(header[i], ":")[[1]] kv <- strsplit(header[i], ":")[[1]]
ret <- paste(ret, "\t", kv[1], ":", sep="") ret <- paste(ret, "\t", kv[1], ":", sep="")
stats <- c() stats <- c()
stats[1] <- as.numeric(kv[2]) stats[1] <- as.numeric(kv[2])
for (j in 2:length(res)) { for (j in 2:length(res)) {
tkv <- strsplit(res[[j]][i], ":")[[1]] tkv <- strsplit(res[[j]][i], ":")[[1]]
stats[j] <- as.numeric(tkv[2]) stats[j] <- as.numeric(tkv[2])
@ -310,9 +310,9 @@ xgb.createFolds <- function(y, k = 10)
## At most, we will use quantiles. If the sample ## At most, we will use quantiles. If the sample
## is too small, we just do regular unstratified ## is too small, we just do regular unstratified
## CV ## CV
cuts <- floor(length(y)/k) cuts <- floor(length(y) / k)
if(cuts < 2) cuts <- 2 if (cuts < 2) cuts <- 2
if(cuts > 5) cuts <- 5 if (cuts > 5) cuts <- 5
y <- cut(y, y <- cut(y,
unique(stats::quantile(y, probs = seq(0, 1, length = cuts))), unique(stats::quantile(y, probs = seq(0, 1, length = cuts))),
include.lowest = TRUE) include.lowest = TRUE)
@ -324,7 +324,7 @@ xgb.createFolds <- function(y, k = 10)
y <- factor(as.character(y)) y <- factor(as.character(y))
numInClass <- table(y) numInClass <- table(y)
foldVector <- vector(mode = "integer", length(y)) foldVector <- vector(mode = "integer", length(y))
## For each class, balance the fold allocation as far ## For each class, balance the fold allocation as far
## as possible, then resample the remainder. ## as possible, then resample the remainder.
## The final assignment of folds is also randomized. ## The final assignment of folds is also randomized.

View File

@ -118,23 +118,23 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
for (mc in metrics) { for (mc in metrics) {
params <- append(params, list("eval_metric"=mc)) params <- append(params, list("eval_metric"=mc))
} }
# customized objective and evaluation metric interface # customized objective and evaluation metric interface
if (!is.null(params$objective) && !is.null(obj)) if (!is.null(params$objective) && !is.null(obj))
stop("xgb.cv: cannot assign two different objectives") stop("xgb.cv: cannot assign two different objectives")
if (!is.null(params$objective)) if (!is.null(params$objective))
if (class(params$objective) == 'function') { if (class(params$objective) == 'function') {
obj = params$objective obj <- params$objective
params[['objective']] = NULL params[['objective']] <- NULL
} }
# if (!is.null(params$eval_metric) && !is.null(feval)) # if (!is.null(params$eval_metric) && !is.null(feval))
# stop("xgb.cv: cannot assign two different evaluation metrics") # stop("xgb.cv: cannot assign two different evaluation metrics")
if (!is.null(params$eval_metric)) if (!is.null(params$eval_metric))
if (class(params$eval_metric)=='function') { if (class(params$eval_metric) == 'function') {
feval = params$eval_metric feval <- params$eval_metric
params[['eval_metric']] = NULL params[['eval_metric']] <- NULL
} }
# Early Stopping # Early Stopping
if (!is.null(early.stop.round)){ if (!is.null(early.stop.round)){
if (!is.null(feval) && is.null(maximize)) if (!is.null(feval) && is.null(maximize))
@ -144,12 +144,12 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
if (is.null(maximize)) if (is.null(maximize))
{ {
if (params$eval_metric %in% c('rmse','logloss','error','merror','mlogloss')) { if (params$eval_metric %in% c('rmse','logloss','error','merror','mlogloss')) {
maximize = FALSE maximize <- FALSE
} else { } else {
maximize = TRUE maximize <- TRUE
} }
} }
if (maximize) { if (maximize) {
bestScore <- 0 bestScore <- 0
} else { } else {
@ -157,26 +157,26 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
} }
bestInd <- 0 bestInd <- 0
earlyStopflag <- FALSE earlyStopflag <- FALSE
if (length(metrics)>1) if (length(metrics) > 1)
warning('Only the first metric is used for early stopping process.') warning('Only the first metric is used for early stopping process.')
} }
xgb_folds <- xgb.cv.mknfold(dtrain, nfold, params, stratified, folds) xgb_folds <- xgb.cv.mknfold(dtrain, nfold, params, stratified, folds)
obj_type <- params[['objective']] obj_type <- params[['objective']]
mat_pred <- FALSE mat_pred <- FALSE
if (!is.null(obj_type) && obj_type == 'multi:softprob') if (!is.null(obj_type) && obj_type == 'multi:softprob')
{ {
num_class = params[['num_class']] num_class <- params[['num_class']]
if (is.null(num_class)) if (is.null(num_class))
stop('must set num_class to use softmax') stop('must set num_class to use softmax')
predictValues <- matrix(0,xgb.numrow(dtrain),num_class) predictValues <- matrix(0,xgb.numrow(dtrain),num_class)
mat_pred = TRUE mat_pred <- TRUE
} }
else else
predictValues <- rep(0,xgb.numrow(dtrain)) predictValues <- rep(0,xgb.numrow(dtrain))
history <- c() history <- c()
print.every.n = max(as.integer(print.every.n), 1L) print.every.n <- max(as.integer(print.every.n), 1L)
for (i in 1:nrounds) { for (i in 1:nrounds) {
msg <- list() msg <- list()
for (k in 1:nfold) { for (k in 1:nfold) {
@ -187,28 +187,27 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
ret <- xgb.cv.aggcv(msg, showsd) ret <- xgb.cv.aggcv(msg, showsd)
history <- c(history, ret) history <- c(history, ret)
if(verbose) if(verbose)
if (0 == (i-1L)%%print.every.n) if (0 == (i - 1L) %% print.every.n)
cat(ret, "\n", sep="") cat(ret, "\n", sep="")
# early_Stopping # early_Stopping
if (!is.null(early.stop.round)){ if (!is.null(early.stop.round)){
score <- strsplit(ret,'\\s+')[[1]][1+length(metrics)+2] score <- strsplit(ret,'\\s+')[[1]][1 + length(metrics) + 2]
score <- strsplit(score,'\\+|:')[[1]][[2]] score <- strsplit(score,'\\+|:')[[1]][[2]]
score <- as.numeric(score) score <- as.numeric(score)
if ((maximize && score > bestScore) || (!maximize && score < bestScore)) { if ( (maximize && score > bestScore) || (!maximize && score < bestScore)) {
bestScore <- score bestScore <- score
bestInd <- i bestInd <- i
} else { } else {
if (i-bestInd >= early.stop.round) { if (i - bestInd >= early.stop.round) {
earlyStopflag <- TRUE earlyStopflag <- TRUE
cat('Stopping. Best iteration:',bestInd) cat('Stopping. Best iteration:',bestInd)
break break
} }
} }
} }
} }
if (prediction) { if (prediction) {
for (k in 1:nfold) { for (k in 1:nfold) {
fd <- xgb_folds[[k]] fd <- xgb_folds[[k]]
@ -225,24 +224,23 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
} }
} }
} }
colnames <- str_split(string = history[1], pattern = "\t")[[1]] %>% .[2:length(.)] %>% str_extract(".*:") %>% str_replace(":","") %>% str_replace("-", ".") colnames <- str_split(string = history[1], pattern = "\t")[[1]] %>% .[2:length(.)] %>% str_extract(".*:") %>% str_replace(":","") %>% str_replace("-", ".")
colnamesMean <- paste(colnames, "mean") colnamesMean <- paste(colnames, "mean")
if(showsd) colnamesStd <- paste(colnames, "std") if(showsd) colnamesStd <- paste(colnames, "std")
colnames <- c() colnames <- c()
if(showsd) for(i in 1:length(colnamesMean)) colnames <- c(colnames, colnamesMean[i], colnamesStd[i]) if(showsd) for(i in 1:length(colnamesMean)) colnames <- c(colnames, colnamesMean[i], colnamesStd[i])
else colnames <- colnamesMean else colnames <- colnamesMean
type <- rep(x = "numeric", times = length(colnames)) type <- rep(x = "numeric", times = length(colnames))
dt <- utils::read.table(text = "", colClasses = type, col.names = colnames) %>% as.data.table dt <- utils::read.table(text = "", colClasses = type, col.names = colnames) %>% as.data.table
split <- str_split(string = history, pattern = "\t") split <- str_split(string = history, pattern = "\t")
for(line in split) dt <- line[2:length(line)] %>% str_extract_all(pattern = "\\d*\\.+\\d*") %>% unlist %>% as.numeric %>% as.list %>% {rbindlist(list(dt, .), use.names = F, fill = F)} for(line in split) dt <- line[2:length(line)] %>% str_extract_all(pattern = "\\d*\\.+\\d*") %>% unlist %>% as.numeric %>% as.list %>% {rbindlist( list( dt, .), use.names = F, fill = F)}
if (prediction) { if (prediction) {
return(list(dt = dt,pred = predictValues)) return( list( dt = dt,pred = predictValues))
} }
return(dt) return(dt)
} }

View File

@ -66,8 +66,8 @@
#' xgb.importance(train$data@@Dimnames[[2]], model = bst, data = train$data, label = train$label) #' xgb.importance(train$data@@Dimnames[[2]], model = bst, data = train$data, label = train$label)
#' #'
#' @export #' @export
xgb.importance <- function(feature_names = NULL, filename_dump = NULL, model = NULL, data = NULL, label = NULL, target = function(x) ((x + label) == 2)){ xgb.importance <- function(feature_names = NULL, filename_dump = NULL, model = NULL, data = NULL, label = NULL, target = function(x) ( (x + label) == 2)){
if (!class(feature_names) %in% c("character", "NULL")) { if (!class(feature_names) %in% c("character", "NULL")) {
stop("feature_names: Has to be a vector of character or NULL if the model dump already contains feature name. Look at this function documentation to see where to get feature names.") stop("feature_names: Has to be a vector of character or NULL if the model dump already contains feature name. Look at this function documentation to see where to get feature names.")
} }
@ -79,7 +79,7 @@ xgb.importance <- function(feature_names = NULL, filename_dump = NULL, model = N
stop("model: Has to be an object of class xgb.Booster model generaged by the xgb.train function.") stop("model: Has to be an object of class xgb.Booster model generaged by the xgb.train function.")
} }
if((is.null(data) & !is.null(label)) |(!is.null(data) & is.null(label))) { if((is.null(data) & !is.null(label)) | (!is.null(data) & is.null(label))) {
stop("data/label: Provide the two arguments if you want co-occurence computation or none of them if you are not interested but not one of them only.") stop("data/label: Provide the two arguments if you want co-occurence computation or none of them if you are not interested but not one of them only.")
} }
@ -98,7 +98,7 @@ xgb.importance <- function(feature_names = NULL, filename_dump = NULL, model = N
if(!is.null(data) | !is.null(label)) warning("data/label: these parameters should only be provided with decision tree based models.") if(!is.null(data) | !is.null(label)) warning("data/label: these parameters should only be provided with decision tree based models.")
} else { } else {
result <- treeDump(feature_names, text = text, keepDetail = !is.null(data)) result <- treeDump(feature_names, text = text, keepDetail = !is.null(data))
# Co-occurence computation # Co-occurence computation
if(!is.null(data) & !is.null(label) & nrow(result) > 0) { if(!is.null(data) & !is.null(label) & nrow(result) > 0) {
# Take care of missing column # Take care of missing column
@ -109,9 +109,9 @@ xgb.importance <- function(feature_names = NULL, filename_dump = NULL, model = N
# Apply split # Apply split
d <- data[, result[,Feature], drop=FALSE] < as.numeric(result[,Split]) d <- data[, result[,Feature], drop=FALSE] < as.numeric(result[,Split])
apply(c & d, 2, . %>% target %>% sum) -> vec apply(c & d, 2, . %>% target %>% sum) -> vec
result <- result[, "RealCover":= as.numeric(vec), with = F][, "RealCover %" := RealCover / sum(label)][,MissingNo:=NULL] result <- result[, "RealCover" := as.numeric(vec), with = F][, "RealCover %" := RealCover / sum(label)][,MissingNo := NULL]
} }
} }
result result
} }
@ -119,13 +119,13 @@ xgb.importance <- function(feature_names = NULL, filename_dump = NULL, model = N
treeDump <- function(feature_names, text, keepDetail){ treeDump <- function(feature_names, text, keepDetail){
if(keepDetail) groupBy <- c("Feature", "Split", "MissingNo") else groupBy <- "Feature" if(keepDetail) groupBy <- c("Feature", "Split", "MissingNo") else groupBy <- "Feature"
result <- xgb.model.dt.tree(feature_names = feature_names, text = text)[,"MissingNo":= Missing == No ][Feature!="Leaf",.(Gain = sum(Quality), Cover = sum(Cover), Frequence = .N), by = groupBy, with = T][,`:=`(Gain = Gain/sum(Gain), Cover = Cover/sum(Cover), Frequence = Frequence/sum(Frequence))][order(Gain, decreasing = T)] result <- xgb.model.dt.tree(feature_names = feature_names, text = text)[,"MissingNo" := Missing == No ][Feature != "Leaf",.(Gain = sum(Quality), Cover = sum(Cover), Frequence = .N), by = groupBy, with = T][,`:=`(Gain = Gain / sum(Gain), Cover = Cover / sum(Cover), Frequence = Frequence / sum(Frequence))][order(Gain, decreasing = T)]
result result
} }
linearDump <- function(feature_names, text){ linearDump <- function(feature_names, text){
which(text == "weight:") %>% {a=.+1;text[a:length(text)]} %>% as.numeric %>% data.table(Feature = feature_names, Weight = .) which(text == "weight:") %>% {a =. + 1; text[a:length(text)]} %>% as.numeric %>% data.table(Feature = feature_names, Weight = .)
} }
# Avoid error messages during CRAN check. # Avoid error messages during CRAN check.

View File

@ -57,7 +57,7 @@
#' @export #' @export
xgb.model.dt.tree <- function(feature_names = NULL, filename_dump = NULL, model = NULL, text = NULL, n_first_tree = NULL){ xgb.model.dt.tree <- function(feature_names = NULL, filename_dump = NULL, model = NULL, text = NULL, n_first_tree = NULL){
if (!class(feature_names) %in% c("character", "NULL")) { if (!class(feature_names) %in% c("character", "NULL")) {
stop("feature_names: Has to be a vector of character or NULL if the model dump already contains feature name. Look at this function documentation to see where to get feature names.") stop("feature_names: Has to be a vector of character or NULL if the model dump already contains feature name. Look at this function documentation to see where to get feature names.")
} }
if (!(class(filename_dump) %in% c("character", "NULL") && length(filename_dump) <= 1)) { if (!(class(filename_dump) %in% c("character", "NULL") && length(filename_dump) <= 1)) {
@ -81,12 +81,12 @@ xgb.model.dt.tree <- function(feature_names = NULL, filename_dump = NULL, model
} }
if(!is.null(model)){ if(!is.null(model)){
text = xgb.dump(model = model, with.stats = T) text <- xgb.dump(model = model, with.stats = T)
} else if(!is.null(filename_dump)){ } else if(!is.null(filename_dump)){
text <- readLines(filename_dump) %>% str_trim(side = "both") text <- readLines(filename_dump) %>% str_trim(side = "both")
} }
position <- str_match(text, "booster") %>% is.na %>% not %>% which %>% c(length(text)+1) position <- str_match(text, "booster") %>% is.na %>% not %>% which %>% c(length(text) + 1)
extract <- function(x, pattern) str_extract(x, pattern) %>% str_split("=") %>% lapply(function(x) x[2] %>% as.numeric) %>% unlist extract <- function(x, pattern) str_extract(x, pattern) %>% str_split("=") %>% lapply(function(x) x[2] %>% as.numeric) %>% unlist
@ -96,16 +96,16 @@ xgb.model.dt.tree <- function(feature_names = NULL, filename_dump = NULL, model
allTrees <- data.table() allTrees <- data.table()
anynumber_regex<-"[-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?" anynumber_regex <- "[-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?"
for(i in 1:n_round){ for (i in 1:n_round){
tree <- text[(position[i]+1):(position[i+1]-1)] tree <- text[(position[i] + 1):(position[i + 1] - 1)]
# avoid tree made of a leaf only (no split) # avoid tree made of a leaf only (no split)
if(length(tree) <2) next if(length(tree) < 2) next
treeID <- i-1 treeID <- i - 1
notLeaf <- str_match(tree, "leaf") %>% is.na notLeaf <- str_match(tree, "leaf") %>% is.na
leaf <- notLeaf %>% not %>% tree[.] leaf <- notLeaf %>% not %>% tree[.]
branch <- notLeaf %>% tree[.] branch <- notLeaf %>% tree[.]
@ -128,38 +128,38 @@ xgb.model.dt.tree <- function(feature_names = NULL, filename_dump = NULL, model
qualityLeaf <- extract(leaf, paste0("leaf=",anynumber_regex)) qualityLeaf <- extract(leaf, paste0("leaf=",anynumber_regex))
coverBranch <- extract(branch, "cover=\\d*\\.*\\d*") coverBranch <- extract(branch, "cover=\\d*\\.*\\d*")
coverLeaf <- extract(leaf, "cover=\\d*\\.*\\d*") coverLeaf <- extract(leaf, "cover=\\d*\\.*\\d*")
dt <- data.table(ID = c(idBranch, idLeaf), Feature = c(featureBranch, featureLeaf), Split = c(splitBranch, splitLeaf), Yes = c(yesBranch, yesLeaf), No = c(noBranch, noLeaf), Missing = c(missingBranch, missingLeaf), Quality = c(qualityBranch, qualityLeaf), Cover = c(coverBranch, coverLeaf))[order(ID)][,Tree:=treeID] dt <- data.table(ID = c(idBranch, idLeaf), Feature = c(featureBranch, featureLeaf), Split = c(splitBranch, splitLeaf), Yes = c(yesBranch, yesLeaf), No = c(noBranch, noLeaf), Missing = c(missingBranch, missingLeaf), Quality = c(qualityBranch, qualityLeaf), Cover = c(coverBranch, coverLeaf))[order(ID)][,Tree := treeID]
allTrees <- rbindlist(list(allTrees, dt), use.names = T, fill = F) allTrees <- rbindlist(list(allTrees, dt), use.names = T, fill = F)
} }
yes <- allTrees[!is.na(Yes), Yes] yes <- allTrees[!is.na(Yes), Yes]
set(allTrees, i = which(allTrees[, Feature] != "Leaf"), set(allTrees, i = which(allTrees[, Feature] != "Leaf"),
j = "Yes.Feature", j = "Yes.Feature",
value = allTrees[ID %in% yes, Feature]) value = allTrees[ID %in% yes, Feature])
set(allTrees, i = which(allTrees[, Feature] != "Leaf"), set(allTrees, i = which(allTrees[, Feature] != "Leaf"),
j = "Yes.Cover", j = "Yes.Cover",
value = allTrees[ID %in% yes, Cover]) value = allTrees[ID %in% yes, Cover])
set(allTrees, i = which(allTrees[, Feature] != "Leaf"), set(allTrees, i = which(allTrees[, Feature] != "Leaf"),
j = "Yes.Quality", j = "Yes.Quality",
value = allTrees[ID %in% yes, Quality]) value = allTrees[ID %in% yes, Quality])
no <- allTrees[!is.na(No), No] no <- allTrees[!is.na(No), No]
set(allTrees, i = which(allTrees[, Feature] != "Leaf"), set(allTrees, i = which(allTrees[, Feature] != "Leaf"),
j = "No.Feature", j = "No.Feature",
value = allTrees[ID %in% no, Feature]) value = allTrees[ID %in% no, Feature])
set(allTrees, i = which(allTrees[, Feature] != "Leaf"), set(allTrees, i = which(allTrees[, Feature] != "Leaf"),
j = "No.Cover", j = "No.Cover",
value = allTrees[ID %in% no, Cover]) value = allTrees[ID %in% no, Cover])
set(allTrees, i = which(allTrees[, Feature] != "Leaf"), set(allTrees, i = which(allTrees[, Feature] != "Leaf"),
j = "No.Quality", j = "No.Quality",
value = allTrees[ID %in% no, Quality]) value = allTrees[ID %in% no, Quality])
allTrees allTrees
} }

View File

@ -30,7 +30,7 @@
#' #'
#' @export #' @export
xgb.plot.importance <- function(importance_matrix = NULL, numberOfClusters = c(1:10)){ xgb.plot.importance <- function(importance_matrix = NULL, numberOfClusters = c(1:10)){
if (!"data.table" %in% class(importance_matrix)) { if (!"data.table" %in% class(importance_matrix)) {
stop("importance_matrix: Should be a data.table.") stop("importance_matrix: Should be a data.table.")
} }
if (!requireNamespace("ggplot2", quietly = TRUE)) { if (!requireNamespace("ggplot2", quietly = TRUE)) {
@ -42,13 +42,13 @@ xgb.plot.importance <- function(importance_matrix = NULL, numberOfClusters = c(1
# To avoid issues in clustering when co-occurences are used # To avoid issues in clustering when co-occurences are used
importance_matrix <- importance_matrix[, .(Gain = sum(Gain)), by = Feature] importance_matrix <- importance_matrix[, .(Gain = sum(Gain)), by = Feature]
clusters <- suppressWarnings(Ckmeans.1d.dp::Ckmeans.1d.dp(importance_matrix[,Gain], numberOfClusters)) clusters <- suppressWarnings(Ckmeans.1d.dp::Ckmeans.1d.dp(importance_matrix[,Gain], numberOfClusters))
importance_matrix[,"Cluster":=clusters$cluster %>% as.character] importance_matrix[,"Cluster" := clusters$cluster %>% as.character]
plot <- ggplot2::ggplot(importance_matrix, ggplot2::aes(x=stats::reorder(Feature, Gain), y = Gain, width= 0.05), environment = environment())+ ggplot2::geom_bar(ggplot2::aes(fill=Cluster), stat="identity", position="identity") + ggplot2::coord_flip() + ggplot2::xlab("Features") + ggplot2::ylab("Gain") + ggplot2::ggtitle("Feature importance") + ggplot2::theme(plot.title = ggplot2::element_text(lineheight=.9, face="bold"), panel.grid.major.y = ggplot2::element_blank() ) plot <- ggplot2::ggplot(importance_matrix, ggplot2::aes(x=stats::reorder(Feature, Gain), y = Gain, width = 0.05), environment = environment()) + ggplot2::geom_bar(ggplot2::aes(fill=Cluster), stat="identity", position="identity") + ggplot2::coord_flip() + ggplot2::xlab("Features") + ggplot2::ylab("Gain") + ggplot2::ggtitle("Feature importance") + ggplot2::theme(plot.title = ggplot2::element_text(lineheight=.9, face="bold"), panel.grid.major.y = ggplot2::element_blank() )
return(plot) return(plot)
} }
# Avoid error messages during CRAN check. # Avoid error messages during CRAN check.

View File

@ -54,40 +54,39 @@
#' #'
#' @export #' @export
#' #'
xgb.plot.tree <- function(feature_names = NULL, filename_dump = NULL, model = NULL, n_first_tree = NULL, CSSstyle = NULL, width = NULL, height = NULL){ xgb.plot.tree <- function(feature_names = NULL, filename_dump = NULL, model = NULL, n_first_tree = NULL, CSSstyle = NULL, width = NULL, height = NULL){
if (!(class(CSSstyle) %in% c("character", "NULL") && length(CSSstyle) <= 1)) { if (!(class(CSSstyle) %in% c("character", "NULL") && length(CSSstyle) <= 1)) {
stop("style: Has to be a character vector of size 1.") stop("style: Has to be a character vector of size 1.")
} }
if (!class(model) %in% c("xgb.Booster", "NULL")) { if (!class(model) %in% c("xgb.Booster", "NULL")) {
stop("model: Has to be an object of class xgb.Booster model generaged by the xgb.train function.") stop("model: Has to be an object of class xgb.Booster model generaged by the xgb.train function.")
} }
if (!requireNamespace("DiagrammeR", quietly = TRUE)) { if (!requireNamespace("DiagrammeR", quietly = TRUE)) {
stop("DiagrammeR package is required for xgb.plot.tree", call. = FALSE) stop("DiagrammeR package is required for xgb.plot.tree", call. = FALSE)
} }
if(is.null(model)){ if(is.null(model)){
allTrees <- xgb.model.dt.tree(feature_names = feature_names, filename_dump = filename_dump, n_first_tree = n_first_tree) allTrees <- xgb.model.dt.tree(feature_names = feature_names, filename_dump = filename_dump, n_first_tree = n_first_tree)
} else { } else {
allTrees <- xgb.model.dt.tree(feature_names = feature_names, model = model, n_first_tree = n_first_tree) allTrees <- xgb.model.dt.tree(feature_names = feature_names, model = model, n_first_tree = n_first_tree)
} }
allTrees[Feature!="Leaf" ,yesPath:= paste(ID,"(", Feature, "<br/>Cover: ", Cover, "<br/>Gain: ", Quality, ")-->|< ", Split, "|", Yes, ">", Yes.Feature, "]", sep = "")] allTrees[Feature != "Leaf" ,yesPath := paste(ID,"(", Feature, "<br/>Cover: ", Cover, "<br/>Gain: ", Quality, ")-->|< ", Split, "|", Yes, ">", Yes.Feature, "]", sep = "")]
allTrees[Feature!="Leaf" ,noPath:= paste(ID,"(", Feature, ")-->|>= ", Split, "|", No, ">", No.Feature, "]", sep = "")] allTrees[Feature != "Leaf" ,noPath := paste(ID,"(", Feature, ")-->|>= ", Split, "|", No, ">", No.Feature, "]", sep = "")]
if(is.null(CSSstyle)){ if(is.null(CSSstyle)){
CSSstyle <- "classDef greenNode fill:#A2EB86, stroke:#04C4AB, stroke-width:2px;classDef redNode fill:#FFA070, stroke:#FF5E5E, stroke-width:2px" CSSstyle <- "classDef greenNode fill:#A2EB86, stroke:#04C4AB, stroke-width:2px;classDef redNode fill:#FFA070, stroke:#FF5E5E, stroke-width:2px"
} }
yes <- allTrees[Feature!="Leaf", c(Yes)] %>% paste(collapse = ",") %>% paste("class ", ., " greenNode", sep = "") yes <- allTrees[Feature != "Leaf", c(Yes)] %>% paste(collapse = ",") %>% paste("class ", ., " greenNode", sep = "")
no <- allTrees[Feature!="Leaf", c(No)] %>% paste(collapse = ",") %>% paste("class ", ., " redNode", sep = "") no <- allTrees[Feature != "Leaf", c(No)] %>% paste(collapse = ",") %>% paste("class ", ., " redNode", sep = "")
path <- allTrees[Feature!="Leaf", c(yesPath, noPath)] %>% .[order(.)] %>% paste(sep = "", collapse = ";") %>% paste("graph LR", .,collapse = "", sep = ";") %>% paste(CSSstyle, yes, no, sep = ";") path <- allTrees[Feature != "Leaf", c(yesPath, noPath)] %>% .[order(.)] %>% paste(sep = "", collapse = ";") %>% paste("graph LR", .,collapse = "", sep = ";") %>% paste(CSSstyle, yes, no, sep = ";")
DiagrammeR::mermaid(path, width, height) DiagrammeR::mermaid(path, width, height)
} }

View File

@ -29,4 +29,4 @@ xgb.save <- function(model, fname) {
stop("xgb.save: the input must be xgb.Booster. Use xgb.DMatrix.save to save stop("xgb.save: the input must be xgb.Booster. Use xgb.DMatrix.save to save
xgb.DMatrix object.") xgb.DMatrix object.")
return(FALSE) return(FALSE)
} }

View File

@ -120,9 +120,9 @@
#' bst <- xgb.train(param, dtrain, nthread = 2, nround = 2, watchlist) #' bst <- xgb.train(param, dtrain, nthread = 2, nround = 2, watchlist)
#' @export #' @export
#' #'
xgb.train <- function(params=list(), data, nrounds, watchlist = list(), xgb.train <- function(params=list(), data, nrounds, watchlist = list(),
obj = NULL, feval = NULL, verbose = 1, print.every.n=1L, obj = NULL, feval = NULL, verbose = 1, print.every.n=1L,
early.stop.round = NULL, maximize = NULL, early.stop.round = NULL, maximize = NULL,
save_period = 0, save_name = "xgboost.model", ...) { save_period = 0, save_name = "xgboost.model", ...) {
dtrain <- data dtrain <- data
if (typeof(params) != "list") { if (typeof(params) != "list") {
@ -139,30 +139,30 @@ xgb.train <- function(params=list(), data, nrounds, watchlist = list(),
if (length(watchlist) != 0 && verbose == 0) { if (length(watchlist) != 0 && verbose == 0) {
warning('watchlist is provided but verbose=0, no evaluation information will be printed') warning('watchlist is provided but verbose=0, no evaluation information will be printed')
} }
dot.params = list(...) dot.params <- list(...)
nms.params = names(params) nms.params <- names(params)
nms.dot.params = names(dot.params) nms.dot.params <- names(dot.params)
if (length(intersect(nms.params,nms.dot.params))>0) if (length(intersect(nms.params,nms.dot.params)) > 0)
stop("Duplicated term in parameters. Please check your list of params.") stop("Duplicated term in parameters. Please check your list of params.")
params = append(params, dot.params) params <- append(params, dot.params)
# customized objective and evaluation metric interface # customized objective and evaluation metric interface
if (!is.null(params$objective) && !is.null(obj)) if (!is.null(params$objective) && !is.null(obj))
stop("xgb.train: cannot assign two different objectives") stop("xgb.train: cannot assign two different objectives")
if (!is.null(params$objective)) if (!is.null(params$objective))
if (class(params$objective)=='function') { if (class(params$objective) == 'function') {
obj = params$objective obj <- params$objective
params$objective = NULL params$objective <- NULL
} }
if (!is.null(params$eval_metric) && !is.null(feval)) if (!is.null(params$eval_metric) && !is.null(feval))
stop("xgb.train: cannot assign two different evaluation metrics") stop("xgb.train: cannot assign two different evaluation metrics")
if (!is.null(params$eval_metric)) if (!is.null(params$eval_metric))
if (class(params$eval_metric)=='function') { if (class(params$eval_metric) == 'function') {
feval = params$eval_metric feval <- params$eval_metric
params$eval_metric = NULL params$eval_metric <- NULL
} }
# Early stopping # Early stopping
if (!is.null(early.stop.round)){ if (!is.null(early.stop.round)){
if (!is.null(feval) && is.null(maximize)) if (!is.null(feval) && is.null(maximize))
@ -174,44 +174,43 @@ xgb.train <- function(params=list(), data, nrounds, watchlist = list(),
if (is.null(maximize)) if (is.null(maximize))
{ {
if (params$eval_metric %in% c('rmse','logloss','error','merror','mlogloss')) { if (params$eval_metric %in% c('rmse','logloss','error','merror','mlogloss')) {
maximize = FALSE maximize <- FALSE
} else { } else {
maximize = TRUE maximize <- TRUE
} }
} }
if (maximize) { if (maximize) {
bestScore = 0 bestScore <- 0
} else { } else {
bestScore = Inf bestScore <- Inf
} }
bestInd = 0 bestInd <- 0
earlyStopflag = FALSE earlyStopflag = FALSE
if (length(watchlist)>1) if (length(watchlist) > 1)
warning('Only the first data set in watchlist is used for early stopping process.') warning('Only the first data set in watchlist is used for early stopping process.')
} }
handle <- xgb.Booster(params, append(watchlist, dtrain)) handle <- xgb.Booster(params, append(watchlist, dtrain))
bst <- xgb.handleToBooster(handle) bst <- xgb.handleToBooster(handle)
print.every.n=max( as.integer(print.every.n), 1L) print.every.n <- max( as.integer(print.every.n), 1L)
for (i in 1:nrounds) { for (i in 1:nrounds) {
succ <- xgb.iter.update(bst$handle, dtrain, i - 1, obj) succ <- xgb.iter.update(bst$handle, dtrain, i - 1, obj)
if (length(watchlist) != 0) { if (length(watchlist) != 0) {
msg <- xgb.iter.eval(bst$handle, watchlist, i - 1, feval) msg <- xgb.iter.eval(bst$handle, watchlist, i - 1, feval)
if (0== ( (i-1) %% print.every.n)) if (0 == ( (i - 1) %% print.every.n))
cat(paste(msg, "\n", sep="")) cat(paste(msg, "\n", sep = ""))
if (!is.null(early.stop.round)) if (!is.null(early.stop.round))
{ {
score = strsplit(msg,':|\\s+')[[1]][3] score <- strsplit(msg,':|\\s+')[[1]][3]
score = as.numeric(score) score <- as.numeric(score)
if ((maximize && score>bestScore) || (!maximize && score<bestScore)) { if ( (maximize && score > bestScore) || (!maximize && score < bestScore)) {
bestScore = score bestScore <- score
bestInd = i bestInd <- i
} else { } else {
if (i-bestInd>=early.stop.round) { earlyStopflag = TRUE
earlyStopflag = TRUE if (i - bestInd >= early.stop.round) {
cat('Stopping. Best iteration:',bestInd) cat('Stopping. Best iteration:',bestInd)
break break
} }
@ -226,8 +225,8 @@ xgb.train <- function(params=list(), data, nrounds, watchlist = list(),
} }
bst <- xgb.Booster.check(bst) bst <- xgb.Booster.check(bst)
if (!is.null(early.stop.round)) { if (!is.null(early.stop.round)) {
bst$bestScore = bestScore bst$bestScore <- bestScore
bst$bestInd = bestInd bst$bestInd <- bestInd
} }
return(bst) return(bst)
} }

View File

@ -59,28 +59,26 @@
#' #'
#' @export #' @export
#' #'
xgboost <- function(data = NULL, label = NULL, missing = NA, weight = NULL, xgboost <- function(data = NULL, label = NULL, missing = NA, weight = NULL,
params = list(), nrounds, params = list(), nrounds,
verbose = 1, print.every.n = 1L, early.stop.round = NULL, verbose = 1, print.every.n = 1L, early.stop.round = NULL,
maximize = NULL, save_period = 0, save_name = "xgboost.model", ...) { maximize = NULL, save_period = 0, save_name = "xgboost.model", ...) {
dtrain <- xgb.get.DMatrix(data, label, missing, weight) dtrain <- xgb.get.DMatrix(data, label, missing, weight)
params <- append(params, list(...)) params <- append(params, list(...))
if (verbose > 0) { if (verbose > 0) {
watchlist <- list(train = dtrain) watchlist <- list(train = dtrain)
} else { } else {
watchlist <- list() watchlist <- list()
} }
bst <- xgb.train(params, dtrain, nrounds, watchlist, verbose = verbose, print.every.n=print.every.n, bst <- xgb.train(params, dtrain, nrounds, watchlist, verbose = verbose, print.every.n=print.every.n,
early.stop.round = early.stop.round, maximize = maximize, early.stop.round = early.stop.round, maximize = maximize,
save_period = save_period, save_name = save_name) save_period = save_period, save_name = save_name)
return(bst) return(bst)
} }
#' Training part from Mushroom Data Set #' Training part from Mushroom Data Set
#' #'
#' This data set is originally from the Mushroom data set, #' This data set is originally from the Mushroom data set,

View File

@ -4,30 +4,30 @@ context("basic functions")
data(agaricus.train, package='xgboost') data(agaricus.train, package='xgboost')
data(agaricus.test, package='xgboost') data(agaricus.test, package='xgboost')
train = agaricus.train train <- agaricus.train
test = agaricus.test test <- agaricus.test
test_that("train and predict", { test_that("train and predict", {
bst = xgboost(data = train$data, label = train$label, max.depth = 2, bst <- xgboost(data = train$data, label = train$label, max.depth = 2,
eta = 1, nthread = 2, nround = 2, objective = "binary:logistic") eta = 1, nthread = 2, nround = 2, objective = "binary:logistic")
pred = predict(bst, test$data) pred <- predict(bst, test$data)
}) })
test_that("early stopping", { test_that("early stopping", {
res = xgb.cv(data = train$data, label = train$label, max.depth = 2, nfold = 5, res <- xgb.cv(data = train$data, label = train$label, max.depth = 2, nfold = 5,
eta = 0.3, nthread = 2, nround = 20, objective = "binary:logistic", eta = 0.3, nthread = 2, nround = 20, objective = "binary:logistic",
early.stop.round = 3, maximize = FALSE) early.stop.round = 3, maximize = FALSE)
expect_true(nrow(res)<20) expect_true(nrow(res) < 20)
bst = xgboost(data = train$data, label = train$label, max.depth = 2, bst <- xgboost(data = train$data, label = train$label, max.depth = 2,
eta = 0.3, nthread = 2, nround = 20, objective = "binary:logistic", eta = 0.3, nthread = 2, nround = 20, objective = "binary:logistic",
early.stop.round = 3, maximize = FALSE) early.stop.round = 3, maximize = FALSE)
pred = predict(bst, test$data) pred <- predict(bst, test$data)
}) })
test_that("save_period", { test_that("save_period", {
bst = xgboost(data = train$data, label = train$label, max.depth = 2, bst <- xgboost(data = train$data, label = train$label, max.depth = 2,
eta = 0.3, nthread = 2, nround = 20, objective = "binary:logistic", eta = 0.3, nthread = 2, nround = 20, objective = "binary:logistic",
save_period = 10, save_name = "xgb.model") save_period = 10, save_name = "xgb.model")
pred = predict(bst, test$data) pred <- predict(bst, test$data)
}) })

View File

@ -7,40 +7,40 @@ test_that("custom objective works", {
data(agaricus.test, package='xgboost') data(agaricus.test, package='xgboost')
dtrain <- xgb.DMatrix(agaricus.train$data, label = agaricus.train$label) dtrain <- xgb.DMatrix(agaricus.train$data, label = agaricus.train$label)
dtest <- xgb.DMatrix(agaricus.test$data, label = agaricus.test$label) dtest <- xgb.DMatrix(agaricus.test$data, label = agaricus.test$label)
watchlist <- list(eval = dtest, train = dtrain) watchlist <- list(eval = dtest, train = dtrain)
num_round <- 2 num_round <- 2
logregobj <- function(preds, dtrain) { logregobj <- function(preds, dtrain) {
labels <- getinfo(dtrain, "label") labels <- getinfo(dtrain, "label")
preds <- 1/(1 + exp(-preds)) preds <- 1 / (1 + exp(-preds))
grad <- preds - labels grad <- preds - labels
hess <- preds * (1 - preds) hess <- preds * (1 - preds)
return(list(grad = grad, hess = hess)) return(list(grad = grad, hess = hess))
} }
evalerror <- function(preds, dtrain) { evalerror <- function(preds, dtrain) {
labels <- getinfo(dtrain, "label") labels <- getinfo(dtrain, "label")
err <- as.numeric(sum(labels != (preds > 0)))/length(labels) err <- as.numeric(sum(labels != (preds > 0))) / length(labels)
return(list(metric = "error", value = err)) return(list(metric = "error", value = err))
} }
param <- list(max.depth=2, eta=1, nthread = 2, silent=1, param <- list(max.depth=2, eta=1, nthread = 2, silent=1,
objective=logregobj, eval_metric=evalerror) objective=logregobj, eval_metric=evalerror)
bst <- xgb.train(param, dtrain, num_round, watchlist) bst <- xgb.train(param, dtrain, num_round, watchlist)
expect_equal(class(bst), "xgb.Booster") expect_equal(class(bst), "xgb.Booster")
expect_equal(length(bst$raw), 1064) expect_equal(length(bst$raw), 1064)
attr(dtrain, 'label') <- getinfo(dtrain, 'label') attr(dtrain, 'label') <- getinfo(dtrain, 'label')
logregobjattr <- function(preds, dtrain) { logregobjattr <- function(preds, dtrain) {
labels <- attr(dtrain, 'label') labels <- attr(dtrain, 'label')
preds <- 1/(1 + exp(-preds)) preds <- 1 / (1 + exp(-preds))
grad <- preds - labels grad <- preds - labels
hess <- preds * (1 - preds) hess <- preds * (1 - preds)
return(list(grad = grad, hess = hess)) return(list(grad = grad, hess = hess))
} }
param <- list(max.depth=2, eta=1, nthread = 2, silent=1, param <- list(max.depth=2, eta=1, nthread = 2, silent = 1,
objective=logregobjattr, eval_metric=evalerror) objective = logregobjattr, eval_metric = evalerror)
bst <- xgb.train(param, dtrain, num_round, watchlist) bst <- xgb.train(param, dtrain, num_round, watchlist)
expect_equal(class(bst), "xgb.Booster") expect_equal(class(bst), "xgb.Booster")
expect_equal(length(bst$raw), 1064) expect_equal(length(bst$raw), 1064)

View File

@ -8,11 +8,11 @@ require(vcd)
data(Arthritis) data(Arthritis)
data(agaricus.train, package='xgboost') data(agaricus.train, package='xgboost')
df <- data.table(Arthritis, keep.rownames = F) df <- data.table(Arthritis, keep.rownames = F)
df[,AgeDiscret:= as.factor(round(Age/10,0))] df[,AgeDiscret := as.factor(round(Age / 10,0))]
df[,AgeCat:= as.factor(ifelse(Age > 30, "Old", "Young"))] df[,AgeCat := as.factor(ifelse(Age > 30, "Old", "Young"))]
df[,ID:=NULL] df[,ID := NULL]
sparse_matrix = sparse.model.matrix(Improved~.-1, data = df) sparse_matrix <- sparse.model.matrix(Improved~.-1, data = df)
output_vector = df[,Y:=0][Improved == "Marked",Y:=1][,Y] output_vector <- df[,Y := 0][Improved == "Marked",Y := 1][,Y]
bst <- xgboost(data = sparse_matrix, label = output_vector, max.depth = 9, bst <- xgboost(data = sparse_matrix, label = output_vector, max.depth = 9,
eta = 1, nthread = 2, nround = 10,objective = "binary:logistic") eta = 1, nthread = 2, nround = 10,objective = "binary:logistic")

View File

@ -4,10 +4,10 @@ require(xgboost)
test_that("poisson regression works", { test_that("poisson regression works", {
data(mtcars) data(mtcars)
bst = xgboost(data=as.matrix(mtcars[,-11]),label=mtcars[,11], bst <- xgboost(data = as.matrix(mtcars[,-11]),label = mtcars[,11],
objective='count:poisson',nrounds=5) objective = 'count:poisson', nrounds=5)
expect_equal(class(bst), "xgb.Booster") expect_equal(class(bst), "xgb.Booster")
pred = predict(bst,as.matrix(mtcars[,-11])) pred <- predict(bst,as.matrix(mtcars[, -11]))
expect_equal(length(pred), 32) expect_equal(length(pred), 32)
sqrt(mean((pred-mtcars[,11])^2)) sqrt(mean( (pred - mtcars[,11]) ^ 2))
}) })