Merge pull request #561 from terrytangyuan/test

Added test for code quality check
This commit is contained in:
Tong He 2015-10-24 22:27:19 -07:00
commit 224f574420
13 changed files with 120 additions and 124 deletions

View File

@ -35,7 +35,7 @@ getinfo <- function(object, ...){
#' @param ... other parameters #' @param ... other parameters
#' @rdname getinfo #' @rdname getinfo
#' @method getinfo xgb.DMatrix #' @method getinfo xgb.DMatrix
setMethod("getinfo", signature = "xgb.DMatrix", setMethod("getinfo", signature = "xgb.DMatrix",
definition = function(object, name) { definition = function(object, name) {
if (typeof(name) != "character") { if (typeof(name) != "character") {
stop("xgb.getinfo: name must be character") stop("xgb.getinfo: name must be character")
@ -43,7 +43,7 @@ setMethod("getinfo", signature = "xgb.DMatrix",
if (class(object) != "xgb.DMatrix") { if (class(object) != "xgb.DMatrix") {
stop("xgb.setinfo: first argument dtrain must be xgb.DMatrix") stop("xgb.setinfo: first argument dtrain must be xgb.DMatrix")
} }
if (name != "label" && name != "weight" && if (name != "label" && name != "weight" &&
name != "base_margin" && name != "nrow") { name != "base_margin" && name != "nrow") {
stop(paste("xgb.getinfo: unknown info name", name)) stop(paste("xgb.getinfo: unknown info name", name))
} }
@ -54,4 +54,3 @@ setMethod("getinfo", signature = "xgb.DMatrix",
} }
return(ret) return(ret)
}) })

View File

@ -30,8 +30,8 @@ setClass("xgb.Booster",
#' pred <- predict(bst, test$data) #' pred <- predict(bst, test$data)
#' @export #' @export
#' #'
setMethod("predict", signature = "xgb.Booster", setMethod("predict", signature = "xgb.Booster",
definition = function(object, newdata, missing = NA, definition = function(object, newdata, missing = NA,
outputmargin = FALSE, ntreelimit = NULL, predleaf = FALSE) { outputmargin = FALSE, ntreelimit = NULL, predleaf = FALSE) {
if (class(object) != "xgb.Booster"){ if (class(object) != "xgb.Booster"){
stop("predict: model in prediction must be of class xgb.Booster") stop("predict: model in prediction must be of class xgb.Booster")
@ -55,7 +55,7 @@ setMethod("predict", signature = "xgb.Booster",
if (predleaf) { if (predleaf) {
option <- option + 2 option <- option + 2
} }
ret <- .Call("XGBoosterPredict_R", object$handle, newdata, as.integer(option), ret <- .Call("XGBoosterPredict_R", object$handle, newdata, as.integer(option),
as.integer(ntreelimit), PACKAGE = "xgboost") as.integer(ntreelimit), PACKAGE = "xgboost")
if (predleaf){ if (predleaf){
len <- getinfo(newdata, "nrow") len <- getinfo(newdata, "nrow")
@ -68,4 +68,3 @@ setMethod("predict", signature = "xgb.Booster",
} }
return(ret) return(ret)
}) })

View File

@ -5,15 +5,14 @@
#' @param object Object of class "xgb.Boost.handle" #' @param object Object of class "xgb.Boost.handle"
#' @param ... Parameters pass to \code{predict.xgb.Booster} #' @param ... Parameters pass to \code{predict.xgb.Booster}
#' #'
setMethod("predict", signature = "xgb.Booster.handle", setMethod("predict", signature = "xgb.Booster.handle",
definition = function(object, ...) { definition = function(object, ...) {
if (class(object) != "xgb.Booster.handle"){ if (class(object) != "xgb.Booster.handle"){
stop("predict: model in prediction must be of class xgb.Booster.handle") stop("predict: model in prediction must be of class xgb.Booster.handle")
} }
bst <- xgb.handleToBooster(object) bst <- xgb.handleToBooster(object)
ret = predict(bst, ...) ret <- predict(bst, ...)
return(ret) return(ret)
}) })

View File

@ -32,7 +32,7 @@ setinfo <- function(object, ...){
#' @param ... other parameters #' @param ... other parameters
#' @rdname setinfo #' @rdname setinfo
#' @method setinfo xgb.DMatrix #' @method setinfo xgb.DMatrix
setMethod("setinfo", signature = "xgb.DMatrix", setMethod("setinfo", signature = "xgb.DMatrix",
definition = function(object, name, info) { definition = function(object, name, info) {
xgb.setinfo(object, name, info) xgb.setinfo(object, name, info)
}) })

View File

@ -23,14 +23,14 @@ slice <- function(object, ...){
#' @param ... other parameters #' @param ... other parameters
#' @rdname slice #' @rdname slice
#' @method slice xgb.DMatrix #' @method slice xgb.DMatrix
setMethod("slice", signature = "xgb.DMatrix", setMethod("slice", signature = "xgb.DMatrix",
definition = function(object, idxset, ...) { definition = function(object, idxset, ...) {
if (class(object) != "xgb.DMatrix") { if (class(object) != "xgb.DMatrix") {
stop("slice: first argument dtrain must be xgb.DMatrix") stop("slice: first argument dtrain must be xgb.DMatrix")
} }
ret <- .Call("XGDMatrixSliceDMatrix_R", object, idxset, ret <- .Call("XGDMatrixSliceDMatrix_R", object, idxset,
PACKAGE = "xgboost") PACKAGE = "xgboost")
attr_list <- attributes(object) attr_list <- attributes(object)
nr <- xgb.numrow(object) nr <- xgb.numrow(object)
len <- sapply(attr_list,length) len <- sapply(attr_list,length)

View File

@ -17,28 +17,28 @@ xgb.setinfo <- function(dmat, name, info) {
if (name == "label") { if (name == "label") {
if (length(info)!=xgb.numrow(dmat)) if (length(info)!=xgb.numrow(dmat))
stop("The length of labels must equal to the number of rows in the input data") stop("The length of labels must equal to the number of rows in the input data")
.Call("XGDMatrixSetInfo_R", dmat, name, as.numeric(info), .Call("XGDMatrixSetInfo_R", dmat, name, as.numeric(info),
PACKAGE = "xgboost") PACKAGE = "xgboost")
return(TRUE) return(TRUE)
} }
if (name == "weight") { if (name == "weight") {
if (length(info)!=xgb.numrow(dmat)) if (length(info)!=xgb.numrow(dmat))
stop("The length of weights must equal to the number of rows in the input data") stop("The length of weights must equal to the number of rows in the input data")
.Call("XGDMatrixSetInfo_R", dmat, name, as.numeric(info), .Call("XGDMatrixSetInfo_R", dmat, name, as.numeric(info),
PACKAGE = "xgboost") PACKAGE = "xgboost")
return(TRUE) return(TRUE)
} }
if (name == "base_margin") { if (name == "base_margin") {
# if (length(info)!=xgb.numrow(dmat)) # if (length(info)!=xgb.numrow(dmat))
# stop("The length of base margin must equal to the number of rows in the input data") # stop("The length of base margin must equal to the number of rows in the input data")
.Call("XGDMatrixSetInfo_R", dmat, name, as.numeric(info), .Call("XGDMatrixSetInfo_R", dmat, name, as.numeric(info),
PACKAGE = "xgboost") PACKAGE = "xgboost")
return(TRUE) return(TRUE)
} }
if (name == "group") { if (name == "group") {
if (sum(info)!=xgb.numrow(dmat)) if (sum(info)!=xgb.numrow(dmat))
stop("The sum of groups must equal to the number of rows in the input data") stop("The sum of groups must equal to the number of rows in the input data")
.Call("XGDMatrixSetInfo_R", dmat, name, as.integer(info), .Call("XGDMatrixSetInfo_R", dmat, name, as.integer(info),
PACKAGE = "xgboost") PACKAGE = "xgboost")
return(TRUE) return(TRUE)
} }
@ -68,7 +68,7 @@ xgb.Booster <- function(params = list(), cachelist = list(), modelfile = NULL) {
if (typeof(modelfile) == "character") { if (typeof(modelfile) == "character") {
.Call("XGBoosterLoadModel_R", handle, modelfile, PACKAGE = "xgboost") .Call("XGBoosterLoadModel_R", handle, modelfile, PACKAGE = "xgboost")
} else if (typeof(modelfile) == "raw") { } else if (typeof(modelfile) == "raw") {
.Call("XGBoosterLoadModelFromRaw_R", handle, modelfile, PACKAGE = "xgboost") .Call("XGBoosterLoadModelFromRaw_R", handle, modelfile, PACKAGE = "xgboost")
} else { } else {
stop("xgb.Booster: modelfile must be character or raw vector") stop("xgb.Booster: modelfile must be character or raw vector")
} }
@ -142,8 +142,7 @@ xgb.iter.boost <- function(booster, dtrain, gpair) {
if (class(dtrain) != "xgb.DMatrix") { if (class(dtrain) != "xgb.DMatrix") {
stop("xgb.iter.update: second argument must be type xgb.DMatrix") stop("xgb.iter.update: second argument must be type xgb.DMatrix")
} }
.Call("XGBoosterBoostOneIter_R", booster, dtrain, gpair$grad, gpair$hess, .Call("XGBoosterBoostOneIter_R", booster, dtrain, gpair$grad, gpair$hess, PACKAGE = "xgboost")
PACKAGE = "xgboost")
return(TRUE) return(TRUE)
} }
@ -159,7 +158,7 @@ xgb.iter.update <- function(booster, dtrain, iter, obj = NULL) {
if (is.null(obj)) { if (is.null(obj)) {
.Call("XGBoosterUpdateOneIter_R", booster, as.integer(iter), dtrain, .Call("XGBoosterUpdateOneIter_R", booster, as.integer(iter), dtrain,
PACKAGE = "xgboost") PACKAGE = "xgboost")
} else { } else {
pred <- predict(booster, dtrain) pred <- predict(booster, dtrain)
gpair <- obj(pred, dtrain) gpair <- obj(pred, dtrain)
succ <- xgb.iter.boost(booster, dtrain, gpair) succ <- xgb.iter.boost(booster, dtrain, gpair)
@ -192,7 +191,7 @@ xgb.iter.eval <- function(booster, watchlist, iter, feval = NULL, prediction = F
} }
msg <- .Call("XGBoosterEvalOneIter_R", booster, as.integer(iter), watchlist, msg <- .Call("XGBoosterEvalOneIter_R", booster, as.integer(iter), watchlist,
evnames, PACKAGE = "xgboost") evnames, PACKAGE = "xgboost")
} else { } else {
msg <- paste("[", iter, "]", sep="") msg <- paste("[", iter, "]", sep="")
for (j in 1:length(watchlist)) { for (j in 1:length(watchlist)) {
w <- watchlist[j] w <- watchlist[j]
@ -253,10 +252,10 @@ xgb.cv.mknfold <- function(dall, nfold, param, stratified, folds) {
kstep <- length(randidx) %/% nfold kstep <- length(randidx) %/% nfold
folds <- list() folds <- list()
for (i in 1:(nfold-1)) { for (i in 1:(nfold-1)) {
folds[[i]] = randidx[1:kstep] folds[[i]] <- randidx[1:kstep]
randidx = setdiff(randidx, folds[[i]]) randidx <- setdiff(randidx, folds[[i]])
} }
folds[[nfold]] = randidx folds[[nfold]] <- randidx
} }
} }
ret <- list() ret <- list()
@ -270,7 +269,7 @@ xgb.cv.mknfold <- function(dall, nfold, param, stratified, folds) {
} }
dtrain <- slice(dall, didx) dtrain <- slice(dall, didx)
bst <- xgb.Booster(param, list(dtrain, dtest)) bst <- xgb.Booster(param, list(dtrain, dtest))
watchlist = list(train=dtrain, test=dtest) watchlist <- list(train=dtrain, test=dtest)
ret[[k]] <- list(dtrain=dtrain, booster=bst, watchlist=watchlist, index=folds[[k]]) ret[[k]] <- list(dtrain=dtrain, booster=bst, watchlist=watchlist, index=folds[[k]])
} }
return (ret) return (ret)

View File

@ -20,26 +20,26 @@
#' #'
xgb.DMatrix <- function(data, info = list(), missing = NA, ...) { xgb.DMatrix <- function(data, info = list(), missing = NA, ...) {
if (typeof(data) == "character") { if (typeof(data) == "character") {
handle <- .Call("XGDMatrixCreateFromFile_R", data, as.integer(FALSE), handle <- .Call("XGDMatrixCreateFromFile_R", data, as.integer(FALSE),
PACKAGE = "xgboost") PACKAGE = "xgboost")
} else if (is.matrix(data)) { } else if (is.matrix(data)) {
handle <- .Call("XGDMatrixCreateFromMat_R", data, missing, handle <- .Call("XGDMatrixCreateFromMat_R", data, missing,
PACKAGE = "xgboost") PACKAGE = "xgboost")
} else if (class(data) == "dgCMatrix") { } else if (class(data) == "dgCMatrix") {
handle <- .Call("XGDMatrixCreateFromCSC_R", data@p, data@i, data@x, handle <- .Call("XGDMatrixCreateFromCSC_R", data@p, data@i, data@x,
PACKAGE = "xgboost") PACKAGE = "xgboost")
} else { } else {
stop(paste("xgb.DMatrix: does not support to construct from ", stop(paste("xgb.DMatrix: does not support to construct from ",
typeof(data))) typeof(data)))
} }
dmat <- structure(handle, class = "xgb.DMatrix") dmat <- structure(handle, class = "xgb.DMatrix")
info <- append(info, list(...)) info <- append(info, list(...))
if (length(info) == 0) if (length(info) == 0)
return(dmat) return(dmat)
for (i in 1:length(info)) { for (i in 1:length(info)) {
p <- info[i] p <- info[i]
xgb.setinfo(dmat, names(p), p[[1]]) xgb.setinfo(dmat, names(p), p[[1]])
} }
return(dmat) return(dmat)
} }

View File

@ -18,10 +18,10 @@ xgb.DMatrix.save <- function(DMatrix, fname) {
stop("xgb.save: fname must be character") stop("xgb.save: fname must be character")
} }
if (class(DMatrix) == "xgb.DMatrix") { if (class(DMatrix) == "xgb.DMatrix") {
.Call("XGDMatrixSaveBinary_R", DMatrix, fname, as.integer(FALSE), .Call("XGDMatrixSaveBinary_R", DMatrix, fname, as.integer(FALSE),
PACKAGE = "xgboost") PACKAGE = "xgboost")
return(TRUE) return(TRUE)
} }
stop("xgb.DMatrix.save: the input must be xgb.DMatrix") stop("xgb.DMatrix.save: the input must be xgb.DMatrix")
return(FALSE) return(FALSE)
} }

View File

@ -91,15 +91,15 @@
#' print(history) #' print(history)
#' @export #' @export
#' #'
xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing = NA, xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing = NA,
prediction = FALSE, showsd = TRUE, metrics=list(), prediction = FALSE, showsd = TRUE, metrics=list(),
obj = NULL, feval = NULL, stratified = TRUE, folds = NULL, verbose = T, print.every.n=1L, obj = NULL, feval = NULL, stratified = TRUE, folds = NULL, verbose = T, print.every.n=1L,
early.stop.round = NULL, maximize = NULL, ...) { early.stop.round = NULL, maximize = NULL, ...) {
if (typeof(params) != "list") { if (typeof(params) != "list") {
stop("xgb.cv: first argument params must be list") stop("xgb.cv: first argument params must be list")
} }
if(!is.null(folds)) { if(!is.null(folds)) {
if(class(folds)!="list" | length(folds) < 2) { if(class(folds) != "list" | length(folds) < 2) {
stop("folds must be a list with 2 or more elements that are vectors of indices for each CV-fold") stop("folds must be a list with 2 or more elements that are vectors of indices for each CV-fold")
} }
nfold <- length(folds) nfold <- length(folds)
@ -108,22 +108,22 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
stop("nfold must be bigger than 1") stop("nfold must be bigger than 1")
} }
dtrain <- xgb.get.DMatrix(data, label, missing) dtrain <- xgb.get.DMatrix(data, label, missing)
dot.params = list(...) dot.params <- list(...)
nms.params = names(params) nms.params <- names(params)
nms.dot.params = names(dot.params) nms.dot.params <- names(dot.params)
if (length(intersect(nms.params,nms.dot.params))>0) if (length(intersect(nms.params,nms.dot.params)) > 0)
stop("Duplicated defined term in parameters. Please check your list of params.") stop("Duplicated defined term in parameters. Please check your list of params.")
params <- append(params, dot.params) params <- append(params, dot.params)
params <- append(params, list(silent=1)) params <- append(params, list(silent=1))
for (mc in metrics) { for (mc in metrics) {
params <- append(params, list("eval_metric"=mc)) params <- append(params, list("eval_metric"=mc))
} }
# customized objective and evaluation metric interface # customized objective and evaluation metric interface
if (!is.null(params$objective) && !is.null(obj)) if (!is.null(params$objective) && !is.null(obj))
stop("xgb.cv: cannot assign two different objectives") stop("xgb.cv: cannot assign two different objectives")
if (!is.null(params$objective)) if (!is.null(params$objective))
if (class(params$objective)=='function') { if (class(params$objective) == 'function') {
obj = params$objective obj = params$objective
params[['objective']] = NULL params[['objective']] = NULL
} }
@ -151,21 +151,21 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
} }
if (maximize) { if (maximize) {
bestScore = 0 bestScore <- 0
} else { } else {
bestScore = Inf bestScore <- Inf
} }
bestInd = 0 bestInd <- 0
earlyStopflag = FALSE earlyStopflag <- FALSE
if (length(metrics)>1) if (length(metrics)>1)
warning('Only the first metric is used for early stopping process.') warning('Only the first metric is used for early stopping process.')
} }
xgb_folds <- xgb.cv.mknfold(dtrain, nfold, params, stratified, folds) xgb_folds <- xgb.cv.mknfold(dtrain, nfold, params, stratified, folds)
obj_type = params[['objective']] obj_type <- params[['objective']]
mat_pred = FALSE mat_pred <- FALSE
if (!is.null(obj_type) && obj_type=='multi:softprob') if (!is.null(obj_type) && obj_type == 'multi:softprob')
{ {
num_class = params[['num_class']] num_class = params[['num_class']]
if (is.null(num_class)) if (is.null(num_class))
@ -187,20 +187,20 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
ret <- xgb.cv.aggcv(msg, showsd) ret <- xgb.cv.aggcv(msg, showsd)
history <- c(history, ret) history <- c(history, ret)
if(verbose) if(verbose)
if (0==(i-1L)%%print.every.n) if (0 == (i-1L)%%print.every.n)
cat(ret, "\n", sep="") cat(ret, "\n", sep="")
# early_Stopping # early_Stopping
if (!is.null(early.stop.round)){ if (!is.null(early.stop.round)){
score = strsplit(ret,'\\s+')[[1]][1+length(metrics)+2] score <- strsplit(ret,'\\s+')[[1]][1+length(metrics)+2]
score = strsplit(score,'\\+|:')[[1]][[2]] score <- strsplit(score,'\\+|:')[[1]][[2]]
score = as.numeric(score) score <- as.numeric(score)
if ((maximize && score>bestScore) || (!maximize && score<bestScore)) { if ((maximize && score > bestScore) || (!maximize && score < bestScore)) {
bestScore = score bestScore <- score
bestInd = i bestInd <- i
} else { } else {
if (i-bestInd>=early.stop.round) { if (i-bestInd >= early.stop.round) {
earlyStopflag = TRUE earlyStopflag <- TRUE
cat('Stopping. Best iteration:',bestInd) cat('Stopping. Best iteration:',bestInd)
break break
} }
@ -211,36 +211,36 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
if (prediction) { if (prediction) {
for (k in 1:nfold) { for (k in 1:nfold) {
fd = xgb_folds[[k]] fd <- xgb_folds[[k]]
if (!is.null(early.stop.round) && earlyStopflag) { if (!is.null(early.stop.round) && earlyStopflag) {
res = xgb.iter.eval(fd$booster, fd$watchlist, bestInd - 1, feval, prediction) res <- xgb.iter.eval(fd$booster, fd$watchlist, bestInd - 1, feval, prediction)
} else { } else {
res = xgb.iter.eval(fd$booster, fd$watchlist, nrounds - 1, feval, prediction) res <- xgb.iter.eval(fd$booster, fd$watchlist, nrounds - 1, feval, prediction)
} }
if (mat_pred) { if (mat_pred) {
pred_mat = matrix(res[[2]],num_class,length(fd$index)) pred_mat <- matrix(res[[2]],num_class,length(fd$index))
predictValues[fd$index,] = t(pred_mat) predictValues[fd$index,] <- t(pred_mat)
} else { } else {
predictValues[fd$index] = res[[2]] predictValues[fd$index] <- res[[2]]
} }
} }
} }
colnames <- str_split(string = history[1], pattern = "\t")[[1]] %>% .[2:length(.)] %>% str_extract(".*:") %>% str_replace(":","") %>% str_replace("-", ".") colnames <- str_split(string = history[1], pattern = "\t")[[1]] %>% .[2:length(.)] %>% str_extract(".*:") %>% str_replace(":","") %>% str_replace("-", ".")
colnamesMean <- paste(colnames, "mean") colnamesMean <- paste(colnames, "mean")
if(showsd) colnamesStd <- paste(colnames, "std") if(showsd) colnamesStd <- paste(colnames, "std")
colnames <- c() colnames <- c()
if(showsd) for(i in 1:length(colnamesMean)) colnames <- c(colnames, colnamesMean[i], colnamesStd[i]) if(showsd) for(i in 1:length(colnamesMean)) colnames <- c(colnames, colnamesMean[i], colnamesStd[i])
else colnames <- colnamesMean else colnames <- colnamesMean
type <- rep(x = "numeric", times = length(colnames)) type <- rep(x = "numeric", times = length(colnames))
dt <- utils::read.table(text = "", colClasses = type, col.names = colnames) %>% as.data.table dt <- utils::read.table(text = "", colClasses = type, col.names = colnames) %>% as.data.table
split <- str_split(string = history, pattern = "\t") split <- str_split(string = history, pattern = "\t")
for(line in split) dt <- line[2:length(line)] %>% str_extract_all(pattern = "\\d*\\.+\\d*") %>% unlist %>% as.numeric %>% as.list %>% {rbindlist(list(dt, .), use.names = F, fill = F)} for(line in split) dt <- line[2:length(line)] %>% str_extract_all(pattern = "\\d*\\.+\\d*") %>% unlist %>% as.numeric %>% as.list %>% {rbindlist(list(dt, .), use.names = F, fill = F)}
if (prediction) { if (prediction) {
return(list(dt = dt,pred = predictValues)) return(list(dt = dt,pred = predictValues))
} }

View File

@ -49,13 +49,13 @@ xgb.dump <- function(model = NULL, fname = NULL, fmap = "", with.stats=FALSE) {
if (!(class(fmap) %in% c("character", "NULL") && length(fname) <= 1)) { if (!(class(fmap) %in% c("character", "NULL") && length(fname) <= 1)) {
stop("fmap: argument must be type character (when provided)") stop("fmap: argument must be type character (when provided)")
} }
longString <- .Call("XGBoosterDumpModel_R", model$handle, fmap, as.integer(with.stats), PACKAGE = "xgboost") longString <- .Call("XGBoosterDumpModel_R", model$handle, fmap, as.integer(with.stats), PACKAGE = "xgboost")
dt <- fread(paste(longString, collapse = ""), sep = "\n", header = F) dt <- fread(paste(longString, collapse = ""), sep = "\n", header = F)
setnames(dt, "Lines") setnames(dt, "Lines")
if(is.null(fname)) { if(is.null(fname)) {
result <- dt[Lines != "0"][, Lines := str_replace(Lines, "^\t+", "")][Lines != ""][, paste(Lines)] result <- dt[Lines != "0"][, Lines := str_replace(Lines, "^\t+", "")][Lines != ""][, paste(Lines)]
return(result) return(result)

View File

@ -66,42 +66,42 @@
#' xgb.importance(train$data@@Dimnames[[2]], model = bst, data = train$data, label = train$label) #' xgb.importance(train$data@@Dimnames[[2]], model = bst, data = train$data, label = train$label)
#' #'
#' @export #' @export
xgb.importance <- function(feature_names = NULL, filename_dump = NULL, model = NULL, data = NULL, label = NULL, target = function(x) ((x + label) == 2)){ xgb.importance <- function(feature_names = NULL, filename_dump = NULL, model = NULL, data = NULL, label = NULL, target = function(x) ((x + label) == 2)){
if (!class(feature_names) %in% c("character", "NULL")) { if (!class(feature_names) %in% c("character", "NULL")) {
stop("feature_names: Has to be a vector of character or NULL if the model dump already contains feature name. Look at this function documentation to see where to get feature names.") stop("feature_names: Has to be a vector of character or NULL if the model dump already contains feature name. Look at this function documentation to see where to get feature names.")
} }
if (!(class(filename_dump) %in% c("character", "NULL") && length(filename_dump) <= 1)) { if (!(class(filename_dump) %in% c("character", "NULL") && length(filename_dump) <= 1)) {
stop("filename_dump: Has to be a path to the model dump file.") stop("filename_dump: Has to be a path to the model dump file.")
} }
if (!class(model) %in% c("xgb.Booster", "NULL")) { if (!class(model) %in% c("xgb.Booster", "NULL")) {
stop("model: Has to be an object of class xgb.Booster model generaged by the xgb.train function.") stop("model: Has to be an object of class xgb.Booster model generaged by the xgb.train function.")
} }
if((is.null(data) & !is.null(label)) |(!is.null(data) & is.null(label))) { if((is.null(data) & !is.null(label)) |(!is.null(data) & is.null(label))) {
stop("data/label: Provide the two arguments if you want co-occurence computation or none of them if you are not interested but not one of them only.") stop("data/label: Provide the two arguments if you want co-occurence computation or none of them if you are not interested but not one of them only.")
} }
if(class(label) == "numeric"){ if(class(label) == "numeric"){
if(sum(label == 0) / length(label) > 0.5) label <- as(label, "sparseVector") if(sum(label == 0) / length(label) > 0.5) label <- as(label, "sparseVector")
} }
if(is.null(model)){ if(is.null(model)){
text <- readLines(filename_dump) text <- readLines(filename_dump)
} else { } else {
text <- xgb.dump(model = model, with.stats = T) text <- xgb.dump(model = model, with.stats = T)
} }
if(text[2] == "bias:"){ if(text[2] == "bias:"){
result <- readLines(filename_dump) %>% linearDump(feature_names, .) result <- readLines(filename_dump) %>% linearDump(feature_names, .)
if(!is.null(data) | !is.null(label)) warning("data/label: these parameters should only be provided with decision tree based models.") if(!is.null(data) | !is.null(label)) warning("data/label: these parameters should only be provided with decision tree based models.")
} else { } else {
result <- treeDump(feature_names, text = text, keepDetail = !is.null(data)) result <- treeDump(feature_names, text = text, keepDetail = !is.null(data))
# Co-occurence computation # Co-occurence computation
if(!is.null(data) & !is.null(label) & nrow(result) > 0) { if(!is.null(data) & !is.null(label) & nrow(result) > 0) {
# Take care of missing column # Take care of missing column
a <- data[, result[MissingNo == T,Feature], drop=FALSE] != 0 a <- data[, result[MissingNo == T,Feature], drop=FALSE] != 0
# Bind the two Matrix and reorder columns # Bind the two Matrix and reorder columns
c <- data[, result[MissingNo == F,Feature], drop=FALSE] %>% cBind(a,.) %>% .[,result[,Feature]] c <- data[, result[MissingNo == F,Feature], drop=FALSE] %>% cBind(a,.) %>% .[,result[,Feature]]
@ -109,19 +109,19 @@ xgb.importance <- function(feature_names = NULL, filename_dump = NULL, model = N
# Apply split # Apply split
d <- data[, result[,Feature], drop=FALSE] < as.numeric(result[,Split]) d <- data[, result[,Feature], drop=FALSE] < as.numeric(result[,Split])
apply(c & d, 2, . %>% target %>% sum) -> vec apply(c & d, 2, . %>% target %>% sum) -> vec
result <- result[, "RealCover":= as.numeric(vec), with = F][, "RealCover %" := RealCover / sum(label)][,MissingNo:=NULL] result <- result[, "RealCover":= as.numeric(vec), with = F][, "RealCover %" := RealCover / sum(label)][,MissingNo:=NULL]
} }
} }
result result
} }
treeDump <- function(feature_names, text, keepDetail){ treeDump <- function(feature_names, text, keepDetail){
if(keepDetail) groupBy <- c("Feature", "Split", "MissingNo") else groupBy <- "Feature" if(keepDetail) groupBy <- c("Feature", "Split", "MissingNo") else groupBy <- "Feature"
result <- xgb.model.dt.tree(feature_names = feature_names, text = text)[,"MissingNo":= Missing == No ][Feature!="Leaf",.(Gain = sum(Quality), Cover = sum(Cover), Frequence = .N), by = groupBy, with = T][,`:=`(Gain = Gain/sum(Gain), Cover = Cover/sum(Cover), Frequence = Frequence/sum(Frequence))][order(Gain, decreasing = T)] result <- xgb.model.dt.tree(feature_names = feature_names, text = text)[,"MissingNo":= Missing == No ][Feature!="Leaf",.(Gain = sum(Quality), Cover = sum(Cover), Frequence = .N), by = groupBy, with = T][,`:=`(Gain = Gain/sum(Gain), Cover = Cover/sum(Cover), Frequence = Frequence/sum(Frequence))][order(Gain, decreasing = T)]
result result
} }
linearDump <- function(feature_names, text){ linearDump <- function(feature_names, text){

View File

@ -17,9 +17,9 @@
#' @export #' @export
#' #'
xgb.load <- function(modelfile) { xgb.load <- function(modelfile) {
if (is.null(modelfile)) if (is.null(modelfile))
stop("xgb.load: modelfile cannot be NULL") stop("xgb.load: modelfile cannot be NULL")
handle <- xgb.Booster(modelfile = modelfile) handle <- xgb.Booster(modelfile = modelfile)
# re-use modelfile if it is raw so we donot need to serialize # re-use modelfile if it is raw so we donot need to serialize
if (typeof(modelfile) == "raw") { if (typeof(modelfile) == "raw") {
@ -29,4 +29,4 @@ xgb.load <- function(modelfile) {
} }
bst <- xgb.Booster.check(bst) bst <- xgb.Booster.check(bst)
return(bst) return(bst)
} }

View File

@ -56,8 +56,8 @@
#' #'
#' @export #' @export
xgb.model.dt.tree <- function(feature_names = NULL, filename_dump = NULL, model = NULL, text = NULL, n_first_tree = NULL){ xgb.model.dt.tree <- function(feature_names = NULL, filename_dump = NULL, model = NULL, text = NULL, n_first_tree = NULL){
if (!class(feature_names) %in% c("character", "NULL")) { if (!class(feature_names) %in% c("character", "NULL")) {
stop("feature_names: Has to be a vector of character or NULL if the model dump already contains feature name. Look at this function documentation to see where to get feature names.") stop("feature_names: Has to be a vector of character or NULL if the model dump already contains feature name. Look at this function documentation to see where to get feature names.")
} }
if (!(class(filename_dump) %in% c("character", "NULL") && length(filename_dump) <= 1)) { if (!(class(filename_dump) %in% c("character", "NULL") && length(filename_dump) <= 1)) {
@ -67,59 +67,59 @@ xgb.model.dt.tree <- function(feature_names = NULL, filename_dump = NULL, model
} else if(is.null(filename_dump) && is.null(model) && is.null(text)){ } else if(is.null(filename_dump) && is.null(model) && is.null(text)){
stop("filename_dump & model & text: no path to dump model, no model, no text dump, have been provided.") stop("filename_dump & model & text: no path to dump model, no model, no text dump, have been provided.")
} }
if (!class(model) %in% c("xgb.Booster", "NULL")) { if (!class(model) %in% c("xgb.Booster", "NULL")) {
stop("model: Has to be an object of class xgb.Booster model generaged by the xgb.train function.") stop("model: Has to be an object of class xgb.Booster model generaged by the xgb.train function.")
} }
if (!class(text) %in% c("character", "NULL")) { if (!class(text) %in% c("character", "NULL")) {
stop("text: Has to be a vector of character or NULL if a path to the model dump has already been provided.") stop("text: Has to be a vector of character or NULL if a path to the model dump has already been provided.")
} }
if (!class(n_first_tree) %in% c("numeric", "NULL") | length(n_first_tree) > 1) { if (!class(n_first_tree) %in% c("numeric", "NULL") | length(n_first_tree) > 1) {
stop("n_first_tree: Has to be a numeric vector of size 1.") stop("n_first_tree: Has to be a numeric vector of size 1.")
} }
if(!is.null(model)){ if(!is.null(model)){
text = xgb.dump(model = model, with.stats = T) text = xgb.dump(model = model, with.stats = T)
} else if(!is.null(filename_dump)){ } else if(!is.null(filename_dump)){
text <- readLines(filename_dump) %>% str_trim(side = "both") text <- readLines(filename_dump) %>% str_trim(side = "both")
} }
position <- str_match(text, "booster") %>% is.na %>% not %>% which %>% c(length(text)+1) position <- str_match(text, "booster") %>% is.na %>% not %>% which %>% c(length(text)+1)
extract <- function(x, pattern) str_extract(x, pattern) %>% str_split("=") %>% lapply(function(x) x[2] %>% as.numeric) %>% unlist extract <- function(x, pattern) str_extract(x, pattern) %>% str_split("=") %>% lapply(function(x) x[2] %>% as.numeric) %>% unlist
n_round <- min(length(position) - 1, n_first_tree) n_round <- min(length(position) - 1, n_first_tree)
addTreeId <- function(x, i) paste(i,x,sep = "-") addTreeId <- function(x, i) paste(i,x,sep = "-")
allTrees <- data.table() allTrees <- data.table()
anynumber_regex<-"[-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?" anynumber_regex<-"[-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?"
for(i in 1:n_round){ for(i in 1:n_round){
tree <- text[(position[i]+1):(position[i+1]-1)] tree <- text[(position[i]+1):(position[i+1]-1)]
# avoid tree made of a leaf only (no split) # avoid tree made of a leaf only (no split)
if(length(tree) <2) next if(length(tree) <2) next
treeID <- i-1 treeID <- i-1
notLeaf <- str_match(tree, "leaf") %>% is.na notLeaf <- str_match(tree, "leaf") %>% is.na
leaf <- notLeaf %>% not %>% tree[.] leaf <- notLeaf %>% not %>% tree[.]
branch <- notLeaf %>% tree[.] branch <- notLeaf %>% tree[.]
idBranch <- str_extract(branch, "\\d*:") %>% str_replace(":", "") %>% addTreeId(treeID) idBranch <- str_extract(branch, "\\d*:") %>% str_replace(":", "") %>% addTreeId(treeID)
idLeaf <- str_extract(leaf, "\\d*:") %>% str_replace(":", "") %>% addTreeId(treeID) idLeaf <- str_extract(leaf, "\\d*:") %>% str_replace(":", "") %>% addTreeId(treeID)
featureBranch <- str_extract(branch, "f\\d*<") %>% str_replace("<", "") %>% str_replace("f", "") %>% as.numeric featureBranch <- str_extract(branch, "f\\d*<") %>% str_replace("<", "") %>% str_replace("f", "") %>% as.numeric
if(!is.null(feature_names)){ if(!is.null(feature_names)){
featureBranch <- feature_names[featureBranch + 1] featureBranch <- feature_names[featureBranch + 1]
} }
featureLeaf <- rep("Leaf", length(leaf)) featureLeaf <- rep("Leaf", length(leaf))
splitBranch <- str_extract(branch, paste0("<",anynumber_regex,"\\]")) %>% str_replace("<", "") %>% str_replace("\\]", "") splitBranch <- str_extract(branch, paste0("<",anynumber_regex,"\\]")) %>% str_replace("<", "") %>% str_replace("\\]", "")
splitLeaf <- rep(NA, length(leaf)) splitLeaf <- rep(NA, length(leaf))
yesBranch <- extract(branch, "yes=\\d*") %>% addTreeId(treeID) yesBranch <- extract(branch, "yes=\\d*") %>% addTreeId(treeID)
yesLeaf <- rep(NA, length(leaf)) yesLeaf <- rep(NA, length(leaf))
noBranch <- extract(branch, "no=\\d*") %>% addTreeId(treeID) noBranch <- extract(branch, "no=\\d*") %>% addTreeId(treeID)
noLeaf <- rep(NA, length(leaf)) noLeaf <- rep(NA, length(leaf))
missingBranch <- extract(branch, "missing=\\d+") %>% addTreeId(treeID) missingBranch <- extract(branch, "missing=\\d+") %>% addTreeId(treeID)
@ -129,10 +129,10 @@ xgb.model.dt.tree <- function(feature_names = NULL, filename_dump = NULL, model
coverBranch <- extract(branch, "cover=\\d*\\.*\\d*") coverBranch <- extract(branch, "cover=\\d*\\.*\\d*")
coverLeaf <- extract(leaf, "cover=\\d*\\.*\\d*") coverLeaf <- extract(leaf, "cover=\\d*\\.*\\d*")
dt <- data.table(ID = c(idBranch, idLeaf), Feature = c(featureBranch, featureLeaf), Split = c(splitBranch, splitLeaf), Yes = c(yesBranch, yesLeaf), No = c(noBranch, noLeaf), Missing = c(missingBranch, missingLeaf), Quality = c(qualityBranch, qualityLeaf), Cover = c(coverBranch, coverLeaf))[order(ID)][,Tree:=treeID] dt <- data.table(ID = c(idBranch, idLeaf), Feature = c(featureBranch, featureLeaf), Split = c(splitBranch, splitLeaf), Yes = c(yesBranch, yesLeaf), No = c(noBranch, noLeaf), Missing = c(missingBranch, missingLeaf), Quality = c(qualityBranch, qualityLeaf), Cover = c(coverBranch, coverLeaf))[order(ID)][,Tree:=treeID]
allTrees <- rbindlist(list(allTrees, dt), use.names = T, fill = F) allTrees <- rbindlist(list(allTrees, dt), use.names = T, fill = F)
} }
yes <- allTrees[!is.na(Yes), Yes] yes <- allTrees[!is.na(Yes), Yes]
set(allTrees, i = which(allTrees[, Feature] != "Leaf"), set(allTrees, i = which(allTrees[, Feature] != "Leaf"),