From 3abbd7b4c7d9b54cccb24d407e0d0d6999042761 Mon Sep 17 00:00:00 2001 From: terrytangyuan Date: Sat, 24 Oct 2015 16:39:58 -0400 Subject: [PATCH 1/3] Added test_lint to test code quality --- R-package/R/getinfo.xgb.DMatrix.R | 5 ++--- R-package/R/predict.xgb.Booster.handle.R | 6 +++--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/R-package/R/getinfo.xgb.DMatrix.R b/R-package/R/getinfo.xgb.DMatrix.R index 26523699a..dc734bce1 100644 --- a/R-package/R/getinfo.xgb.DMatrix.R +++ b/R-package/R/getinfo.xgb.DMatrix.R @@ -35,7 +35,7 @@ getinfo <- function(object, ...){ #' @param ... other parameters #' @rdname getinfo #' @method getinfo xgb.DMatrix -setMethod("getinfo", signature = "xgb.DMatrix", +setMethod("getinfo", signature = "xgb.DMatrix", definition = function(object, name) { if (typeof(name) != "character") { stop("xgb.getinfo: name must be character") @@ -43,7 +43,7 @@ setMethod("getinfo", signature = "xgb.DMatrix", if (class(object) != "xgb.DMatrix") { stop("xgb.setinfo: first argument dtrain must be xgb.DMatrix") } - if (name != "label" && name != "weight" && + if (name != "label" && name != "weight" && name != "base_margin" && name != "nrow") { stop(paste("xgb.getinfo: unknown info name", name)) } @@ -54,4 +54,3 @@ setMethod("getinfo", signature = "xgb.DMatrix", } return(ret) }) - diff --git a/R-package/R/predict.xgb.Booster.handle.R b/R-package/R/predict.xgb.Booster.handle.R index 685318f12..5788283da 100644 --- a/R-package/R/predict.xgb.Booster.handle.R +++ b/R-package/R/predict.xgb.Booster.handle.R @@ -5,14 +5,14 @@ #' @param object Object of class "xgb.Boost.handle" #' @param ... Parameters pass to \code{predict.xgb.Booster} #' -setMethod("predict", signature = "xgb.Booster.handle", +setMethod("predict", signature = "xgb.Booster.handle", definition = function(object, ...) { if (class(object) != "xgb.Booster.handle"){ stop("predict: model in prediction must be of class xgb.Booster.handle") } - + bst <- xgb.handleToBooster(object) - + ret = predict(bst, ...) return(ret) }) From 537b34dc6fdd183ec68a6fd658a905fc185b6ad5 Mon Sep 17 00:00:00 2001 From: terrytangyuan Date: Sat, 24 Oct 2015 16:43:44 -0400 Subject: [PATCH 2/3] Code: Some Lint fixes --- R-package/R/predict.xgb.Booster.R | 7 +++---- R-package/R/predict.xgb.Booster.handle.R | 3 +-- R-package/R/setinfo.xgb.DMatrix.R | 2 +- R-package/R/slice.xgb.DMatrix.R | 6 +++--- R-package/R/utils.R | 25 ++++++++++++------------ R-package/R/xgb.cv.R | 18 ++++++++--------- 6 files changed, 29 insertions(+), 32 deletions(-) diff --git a/R-package/R/predict.xgb.Booster.R b/R-package/R/predict.xgb.Booster.R index 902260258..9cc1867da 100644 --- a/R-package/R/predict.xgb.Booster.R +++ b/R-package/R/predict.xgb.Booster.R @@ -30,8 +30,8 @@ setClass("xgb.Booster", #' pred <- predict(bst, test$data) #' @export #' -setMethod("predict", signature = "xgb.Booster", - definition = function(object, newdata, missing = NA, +setMethod("predict", signature = "xgb.Booster", + definition = function(object, newdata, missing = NA, outputmargin = FALSE, ntreelimit = NULL, predleaf = FALSE) { if (class(object) != "xgb.Booster"){ stop("predict: model in prediction must be of class xgb.Booster") @@ -55,7 +55,7 @@ setMethod("predict", signature = "xgb.Booster", if (predleaf) { option <- option + 2 } - ret <- .Call("XGBoosterPredict_R", object$handle, newdata, as.integer(option), + ret <- .Call("XGBoosterPredict_R", object$handle, newdata, as.integer(option), as.integer(ntreelimit), PACKAGE = "xgboost") if (predleaf){ len <- getinfo(newdata, "nrow") @@ -68,4 +68,3 @@ setMethod("predict", signature = "xgb.Booster", } return(ret) }) - diff --git a/R-package/R/predict.xgb.Booster.handle.R b/R-package/R/predict.xgb.Booster.handle.R index 5788283da..3e4013b75 100644 --- a/R-package/R/predict.xgb.Booster.handle.R +++ b/R-package/R/predict.xgb.Booster.handle.R @@ -13,7 +13,6 @@ setMethod("predict", signature = "xgb.Booster.handle", bst <- xgb.handleToBooster(object) - ret = predict(bst, ...) + ret <- predict(bst, ...) return(ret) }) - diff --git a/R-package/R/setinfo.xgb.DMatrix.R b/R-package/R/setinfo.xgb.DMatrix.R index 61019d8e2..4bee161b7 100644 --- a/R-package/R/setinfo.xgb.DMatrix.R +++ b/R-package/R/setinfo.xgb.DMatrix.R @@ -32,7 +32,7 @@ setinfo <- function(object, ...){ #' @param ... other parameters #' @rdname setinfo #' @method setinfo xgb.DMatrix -setMethod("setinfo", signature = "xgb.DMatrix", +setMethod("setinfo", signature = "xgb.DMatrix", definition = function(object, name, info) { xgb.setinfo(object, name, info) }) diff --git a/R-package/R/slice.xgb.DMatrix.R b/R-package/R/slice.xgb.DMatrix.R index b70a8ee92..d8ef8cb9c 100644 --- a/R-package/R/slice.xgb.DMatrix.R +++ b/R-package/R/slice.xgb.DMatrix.R @@ -23,14 +23,14 @@ slice <- function(object, ...){ #' @param ... other parameters #' @rdname slice #' @method slice xgb.DMatrix -setMethod("slice", signature = "xgb.DMatrix", +setMethod("slice", signature = "xgb.DMatrix", definition = function(object, idxset, ...) { if (class(object) != "xgb.DMatrix") { stop("slice: first argument dtrain must be xgb.DMatrix") } - ret <- .Call("XGDMatrixSliceDMatrix_R", object, idxset, + ret <- .Call("XGDMatrixSliceDMatrix_R", object, idxset, PACKAGE = "xgboost") - + attr_list <- attributes(object) nr <- xgb.numrow(object) len <- sapply(attr_list,length) diff --git a/R-package/R/utils.R b/R-package/R/utils.R index eecc5e260..459eb068e 100644 --- a/R-package/R/utils.R +++ b/R-package/R/utils.R @@ -17,28 +17,28 @@ xgb.setinfo <- function(dmat, name, info) { if (name == "label") { if (length(info)!=xgb.numrow(dmat)) stop("The length of labels must equal to the number of rows in the input data") - .Call("XGDMatrixSetInfo_R", dmat, name, as.numeric(info), + .Call("XGDMatrixSetInfo_R", dmat, name, as.numeric(info), PACKAGE = "xgboost") return(TRUE) } if (name == "weight") { if (length(info)!=xgb.numrow(dmat)) stop("The length of weights must equal to the number of rows in the input data") - .Call("XGDMatrixSetInfo_R", dmat, name, as.numeric(info), + .Call("XGDMatrixSetInfo_R", dmat, name, as.numeric(info), PACKAGE = "xgboost") return(TRUE) } if (name == "base_margin") { # if (length(info)!=xgb.numrow(dmat)) # stop("The length of base margin must equal to the number of rows in the input data") - .Call("XGDMatrixSetInfo_R", dmat, name, as.numeric(info), + .Call("XGDMatrixSetInfo_R", dmat, name, as.numeric(info), PACKAGE = "xgboost") return(TRUE) } if (name == "group") { if (sum(info)!=xgb.numrow(dmat)) stop("The sum of groups must equal to the number of rows in the input data") - .Call("XGDMatrixSetInfo_R", dmat, name, as.integer(info), + .Call("XGDMatrixSetInfo_R", dmat, name, as.integer(info), PACKAGE = "xgboost") return(TRUE) } @@ -68,7 +68,7 @@ xgb.Booster <- function(params = list(), cachelist = list(), modelfile = NULL) { if (typeof(modelfile) == "character") { .Call("XGBoosterLoadModel_R", handle, modelfile, PACKAGE = "xgboost") } else if (typeof(modelfile) == "raw") { - .Call("XGBoosterLoadModelFromRaw_R", handle, modelfile, PACKAGE = "xgboost") + .Call("XGBoosterLoadModelFromRaw_R", handle, modelfile, PACKAGE = "xgboost") } else { stop("xgb.Booster: modelfile must be character or raw vector") } @@ -142,8 +142,7 @@ xgb.iter.boost <- function(booster, dtrain, gpair) { if (class(dtrain) != "xgb.DMatrix") { stop("xgb.iter.update: second argument must be type xgb.DMatrix") } - .Call("XGBoosterBoostOneIter_R", booster, dtrain, gpair$grad, gpair$hess, - PACKAGE = "xgboost") + .Call("XGBoosterBoostOneIter_R", booster, dtrain, gpair$grad, gpair$hess, PACKAGE = "xgboost") return(TRUE) } @@ -159,7 +158,7 @@ xgb.iter.update <- function(booster, dtrain, iter, obj = NULL) { if (is.null(obj)) { .Call("XGBoosterUpdateOneIter_R", booster, as.integer(iter), dtrain, PACKAGE = "xgboost") - } else { + } else { pred <- predict(booster, dtrain) gpair <- obj(pred, dtrain) succ <- xgb.iter.boost(booster, dtrain, gpair) @@ -192,7 +191,7 @@ xgb.iter.eval <- function(booster, watchlist, iter, feval = NULL, prediction = F } msg <- .Call("XGBoosterEvalOneIter_R", booster, as.integer(iter), watchlist, evnames, PACKAGE = "xgboost") - } else { + } else { msg <- paste("[", iter, "]", sep="") for (j in 1:length(watchlist)) { w <- watchlist[j] @@ -253,10 +252,10 @@ xgb.cv.mknfold <- function(dall, nfold, param, stratified, folds) { kstep <- length(randidx) %/% nfold folds <- list() for (i in 1:(nfold-1)) { - folds[[i]] = randidx[1:kstep] - randidx = setdiff(randidx, folds[[i]]) + folds[[i]] <- randidx[1:kstep] + randidx <- setdiff(randidx, folds[[i]]) } - folds[[nfold]] = randidx + folds[[nfold]] <- randidx } } ret <- list() @@ -270,7 +269,7 @@ xgb.cv.mknfold <- function(dall, nfold, param, stratified, folds) { } dtrain <- slice(dall, didx) bst <- xgb.Booster(param, list(dtrain, dtest)) - watchlist = list(train=dtrain, test=dtest) + watchlist <- list(train=dtrain, test=dtest) ret[[k]] <- list(dtrain=dtrain, booster=bst, watchlist=watchlist, index=folds[[k]]) } return (ret) diff --git a/R-package/R/xgb.cv.R b/R-package/R/xgb.cv.R index 9811bba38..173ebd279 100644 --- a/R-package/R/xgb.cv.R +++ b/R-package/R/xgb.cv.R @@ -91,15 +91,15 @@ #' print(history) #' @export #' -xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing = NA, - prediction = FALSE, showsd = TRUE, metrics=list(), +xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing = NA, + prediction = FALSE, showsd = TRUE, metrics=list(), obj = NULL, feval = NULL, stratified = TRUE, folds = NULL, verbose = T, print.every.n=1L, early.stop.round = NULL, maximize = NULL, ...) { if (typeof(params) != "list") { stop("xgb.cv: first argument params must be list") } if(!is.null(folds)) { - if(class(folds)!="list" | length(folds) < 2) { + if(class(folds) != "list" | length(folds) < 2) { stop("folds must be a list with 2 or more elements that are vectors of indices for each CV-fold") } nfold <- length(folds) @@ -108,22 +108,22 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing = stop("nfold must be bigger than 1") } dtrain <- xgb.get.DMatrix(data, label, missing) - dot.params = list(...) - nms.params = names(params) - nms.dot.params = names(dot.params) - if (length(intersect(nms.params,nms.dot.params))>0) + dot.params <- list(...) + nms.params <- names(params) + nms.dot.params <- names(dot.params) + if (length(intersect(nms.params,nms.dot.params)) > 0) stop("Duplicated defined term in parameters. Please check your list of params.") params <- append(params, dot.params) params <- append(params, list(silent=1)) for (mc in metrics) { params <- append(params, list("eval_metric"=mc)) } - + # customized objective and evaluation metric interface if (!is.null(params$objective) && !is.null(obj)) stop("xgb.cv: cannot assign two different objectives") if (!is.null(params$objective)) - if (class(params$objective)=='function') { + if (class(params$objective) == 'function') { obj = params$objective params[['objective']] = NULL } From 139feaf97aaae68866132cf2b18c98b1b3e1fc0d Mon Sep 17 00:00:00 2001 From: terrytangyuan Date: Sat, 24 Oct 2015 16:50:03 -0400 Subject: [PATCH 3/3] Code: Lint fixes on trailing spaces --- R-package/R/xgb.DMatrix.R | 14 ++++---- R-package/R/xgb.DMatrix.save.R | 4 +-- R-package/R/xgb.cv.R | 58 ++++++++++++++++----------------- R-package/R/xgb.dump.R | 6 ++-- R-package/R/xgb.importance.R | 34 +++++++++---------- R-package/R/xgb.load.R | 6 ++-- R-package/R/xgb.model.dt.tree.R | 50 ++++++++++++++-------------- 7 files changed, 86 insertions(+), 86 deletions(-) diff --git a/R-package/R/xgb.DMatrix.R b/R-package/R/xgb.DMatrix.R index 970fab394..20a3276c0 100644 --- a/R-package/R/xgb.DMatrix.R +++ b/R-package/R/xgb.DMatrix.R @@ -20,26 +20,26 @@ #' xgb.DMatrix <- function(data, info = list(), missing = NA, ...) { if (typeof(data) == "character") { - handle <- .Call("XGDMatrixCreateFromFile_R", data, as.integer(FALSE), + handle <- .Call("XGDMatrixCreateFromFile_R", data, as.integer(FALSE), PACKAGE = "xgboost") } else if (is.matrix(data)) { - handle <- .Call("XGDMatrixCreateFromMat_R", data, missing, + handle <- .Call("XGDMatrixCreateFromMat_R", data, missing, PACKAGE = "xgboost") } else if (class(data) == "dgCMatrix") { - handle <- .Call("XGDMatrixCreateFromCSC_R", data@p, data@i, data@x, + handle <- .Call("XGDMatrixCreateFromCSC_R", data@p, data@i, data@x, PACKAGE = "xgboost") } else { - stop(paste("xgb.DMatrix: does not support to construct from ", + stop(paste("xgb.DMatrix: does not support to construct from ", typeof(data))) } dmat <- structure(handle, class = "xgb.DMatrix") - + info <- append(info, list(...)) - if (length(info) == 0) + if (length(info) == 0) return(dmat) for (i in 1:length(info)) { p <- info[i] xgb.setinfo(dmat, names(p), p[[1]]) } return(dmat) -} +} diff --git a/R-package/R/xgb.DMatrix.save.R b/R-package/R/xgb.DMatrix.save.R index d58dc09de..7a9ac611d 100644 --- a/R-package/R/xgb.DMatrix.save.R +++ b/R-package/R/xgb.DMatrix.save.R @@ -18,10 +18,10 @@ xgb.DMatrix.save <- function(DMatrix, fname) { stop("xgb.save: fname must be character") } if (class(DMatrix) == "xgb.DMatrix") { - .Call("XGDMatrixSaveBinary_R", DMatrix, fname, as.integer(FALSE), + .Call("XGDMatrixSaveBinary_R", DMatrix, fname, as.integer(FALSE), PACKAGE = "xgboost") return(TRUE) } stop("xgb.DMatrix.save: the input must be xgb.DMatrix") return(FALSE) -} +} diff --git a/R-package/R/xgb.cv.R b/R-package/R/xgb.cv.R index 173ebd279..3f1be704f 100644 --- a/R-package/R/xgb.cv.R +++ b/R-package/R/xgb.cv.R @@ -151,21 +151,21 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing = } if (maximize) { - bestScore = 0 + bestScore <- 0 } else { - bestScore = Inf + bestScore <- Inf } - bestInd = 0 - earlyStopflag = FALSE + bestInd <- 0 + earlyStopflag <- FALSE if (length(metrics)>1) warning('Only the first metric is used for early stopping process.') } - + xgb_folds <- xgb.cv.mknfold(dtrain, nfold, params, stratified, folds) - obj_type = params[['objective']] - mat_pred = FALSE - if (!is.null(obj_type) && obj_type=='multi:softprob') + obj_type <- params[['objective']] + mat_pred <- FALSE + if (!is.null(obj_type) && obj_type == 'multi:softprob') { num_class = params[['num_class']] if (is.null(num_class)) @@ -187,20 +187,20 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing = ret <- xgb.cv.aggcv(msg, showsd) history <- c(history, ret) if(verbose) - if (0==(i-1L)%%print.every.n) + if (0 == (i-1L)%%print.every.n) cat(ret, "\n", sep="") # early_Stopping if (!is.null(early.stop.round)){ - score = strsplit(ret,'\\s+')[[1]][1+length(metrics)+2] - score = strsplit(score,'\\+|:')[[1]][[2]] - score = as.numeric(score) - if ((maximize && score>bestScore) || (!maximize && score bestScore) || (!maximize && score < bestScore)) { + bestScore <- score + bestInd <- i } else { - if (i-bestInd>=early.stop.round) { - earlyStopflag = TRUE + if (i-bestInd >= early.stop.round) { + earlyStopflag <- TRUE cat('Stopping. Best iteration:',bestInd) break } @@ -211,36 +211,36 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing = if (prediction) { for (k in 1:nfold) { - fd = xgb_folds[[k]] + fd <- xgb_folds[[k]] if (!is.null(early.stop.round) && earlyStopflag) { - res = xgb.iter.eval(fd$booster, fd$watchlist, bestInd - 1, feval, prediction) + res <- xgb.iter.eval(fd$booster, fd$watchlist, bestInd - 1, feval, prediction) } else { - res = xgb.iter.eval(fd$booster, fd$watchlist, nrounds - 1, feval, prediction) + res <- xgb.iter.eval(fd$booster, fd$watchlist, nrounds - 1, feval, prediction) } if (mat_pred) { - pred_mat = matrix(res[[2]],num_class,length(fd$index)) - predictValues[fd$index,] = t(pred_mat) + pred_mat <- matrix(res[[2]],num_class,length(fd$index)) + predictValues[fd$index,] <- t(pred_mat) } else { - predictValues[fd$index] = res[[2]] + predictValues[fd$index] <- res[[2]] } } } - - + + colnames <- str_split(string = history[1], pattern = "\t")[[1]] %>% .[2:length(.)] %>% str_extract(".*:") %>% str_replace(":","") %>% str_replace("-", ".") colnamesMean <- paste(colnames, "mean") if(showsd) colnamesStd <- paste(colnames, "std") - + colnames <- c() if(showsd) for(i in 1:length(colnamesMean)) colnames <- c(colnames, colnamesMean[i], colnamesStd[i]) else colnames <- colnamesMean - + type <- rep(x = "numeric", times = length(colnames)) dt <- utils::read.table(text = "", colClasses = type, col.names = colnames) %>% as.data.table split <- str_split(string = history, pattern = "\t") - + for(line in split) dt <- line[2:length(line)] %>% str_extract_all(pattern = "\\d*\\.+\\d*") %>% unlist %>% as.numeric %>% as.list %>% {rbindlist(list(dt, .), use.names = F, fill = F)} - + if (prediction) { return(list(dt = dt,pred = predictValues)) } diff --git a/R-package/R/xgb.dump.R b/R-package/R/xgb.dump.R index fae1c7d2b..856ec0888 100644 --- a/R-package/R/xgb.dump.R +++ b/R-package/R/xgb.dump.R @@ -49,13 +49,13 @@ xgb.dump <- function(model = NULL, fname = NULL, fmap = "", with.stats=FALSE) { if (!(class(fmap) %in% c("character", "NULL") && length(fname) <= 1)) { stop("fmap: argument must be type character (when provided)") } - + longString <- .Call("XGBoosterDumpModel_R", model$handle, fmap, as.integer(with.stats), PACKAGE = "xgboost") - + dt <- fread(paste(longString, collapse = ""), sep = "\n", header = F) setnames(dt, "Lines") - + if(is.null(fname)) { result <- dt[Lines != "0"][, Lines := str_replace(Lines, "^\t+", "")][Lines != ""][, paste(Lines)] return(result) diff --git a/R-package/R/xgb.importance.R b/R-package/R/xgb.importance.R index f7696d53e..0b0703587 100644 --- a/R-package/R/xgb.importance.R +++ b/R-package/R/xgb.importance.R @@ -66,42 +66,42 @@ #' xgb.importance(train$data@@Dimnames[[2]], model = bst, data = train$data, label = train$label) #' #' @export -xgb.importance <- function(feature_names = NULL, filename_dump = NULL, model = NULL, data = NULL, label = NULL, target = function(x) ((x + label) == 2)){ - if (!class(feature_names) %in% c("character", "NULL")) { +xgb.importance <- function(feature_names = NULL, filename_dump = NULL, model = NULL, data = NULL, label = NULL, target = function(x) ((x + label) == 2)){ + if (!class(feature_names) %in% c("character", "NULL")) { stop("feature_names: Has to be a vector of character or NULL if the model dump already contains feature name. Look at this function documentation to see where to get feature names.") } - + if (!(class(filename_dump) %in% c("character", "NULL") && length(filename_dump) <= 1)) { stop("filename_dump: Has to be a path to the model dump file.") } - + if (!class(model) %in% c("xgb.Booster", "NULL")) { stop("model: Has to be an object of class xgb.Booster model generaged by the xgb.train function.") } - + if((is.null(data) & !is.null(label)) |(!is.null(data) & is.null(label))) { stop("data/label: Provide the two arguments if you want co-occurence computation or none of them if you are not interested but not one of them only.") } - + if(class(label) == "numeric"){ if(sum(label == 0) / length(label) > 0.5) label <- as(label, "sparseVector") } - + if(is.null(model)){ - text <- readLines(filename_dump) + text <- readLines(filename_dump) } else { text <- xgb.dump(model = model, with.stats = T) - } - + } + if(text[2] == "bias:"){ result <- readLines(filename_dump) %>% linearDump(feature_names, .) if(!is.null(data) | !is.null(label)) warning("data/label: these parameters should only be provided with decision tree based models.") } else { result <- treeDump(feature_names, text = text, keepDetail = !is.null(data)) - + # Co-occurence computation if(!is.null(data) & !is.null(label) & nrow(result) > 0) { - # Take care of missing column + # Take care of missing column a <- data[, result[MissingNo == T,Feature], drop=FALSE] != 0 # Bind the two Matrix and reorder columns c <- data[, result[MissingNo == F,Feature], drop=FALSE] %>% cBind(a,.) %>% .[,result[,Feature]] @@ -109,19 +109,19 @@ xgb.importance <- function(feature_names = NULL, filename_dump = NULL, model = N # Apply split d <- data[, result[,Feature], drop=FALSE] < as.numeric(result[,Split]) apply(c & d, 2, . %>% target %>% sum) -> vec - + result <- result[, "RealCover":= as.numeric(vec), with = F][, "RealCover %" := RealCover / sum(label)][,MissingNo:=NULL] - } + } } result } treeDump <- function(feature_names, text, keepDetail){ if(keepDetail) groupBy <- c("Feature", "Split", "MissingNo") else groupBy <- "Feature" - + result <- xgb.model.dt.tree(feature_names = feature_names, text = text)[,"MissingNo":= Missing == No ][Feature!="Leaf",.(Gain = sum(Quality), Cover = sum(Cover), Frequence = .N), by = groupBy, with = T][,`:=`(Gain = Gain/sum(Gain), Cover = Cover/sum(Cover), Frequence = Frequence/sum(Frequence))][order(Gain, decreasing = T)] - - result + + result } linearDump <- function(feature_names, text){ diff --git a/R-package/R/xgb.load.R b/R-package/R/xgb.load.R index b69a719cf..2a2598dd8 100644 --- a/R-package/R/xgb.load.R +++ b/R-package/R/xgb.load.R @@ -17,9 +17,9 @@ #' @export #' xgb.load <- function(modelfile) { - if (is.null(modelfile)) + if (is.null(modelfile)) stop("xgb.load: modelfile cannot be NULL") - + handle <- xgb.Booster(modelfile = modelfile) # re-use modelfile if it is raw so we donot need to serialize if (typeof(modelfile) == "raw") { @@ -29,4 +29,4 @@ xgb.load <- function(modelfile) { } bst <- xgb.Booster.check(bst) return(bst) -} +} diff --git a/R-package/R/xgb.model.dt.tree.R b/R-package/R/xgb.model.dt.tree.R index d083566a5..cef988962 100644 --- a/R-package/R/xgb.model.dt.tree.R +++ b/R-package/R/xgb.model.dt.tree.R @@ -56,8 +56,8 @@ #' #' @export xgb.model.dt.tree <- function(feature_names = NULL, filename_dump = NULL, model = NULL, text = NULL, n_first_tree = NULL){ - - if (!class(feature_names) %in% c("character", "NULL")) { + + if (!class(feature_names) %in% c("character", "NULL")) { stop("feature_names: Has to be a vector of character or NULL if the model dump already contains feature name. Look at this function documentation to see where to get feature names.") } if (!(class(filename_dump) %in% c("character", "NULL") && length(filename_dump) <= 1)) { @@ -67,59 +67,59 @@ xgb.model.dt.tree <- function(feature_names = NULL, filename_dump = NULL, model } else if(is.null(filename_dump) && is.null(model) && is.null(text)){ stop("filename_dump & model & text: no path to dump model, no model, no text dump, have been provided.") } - + if (!class(model) %in% c("xgb.Booster", "NULL")) { stop("model: Has to be an object of class xgb.Booster model generaged by the xgb.train function.") } - - if (!class(text) %in% c("character", "NULL")) { + + if (!class(text) %in% c("character", "NULL")) { stop("text: Has to be a vector of character or NULL if a path to the model dump has already been provided.") } - + if (!class(n_first_tree) %in% c("numeric", "NULL") | length(n_first_tree) > 1) { stop("n_first_tree: Has to be a numeric vector of size 1.") } - + if(!is.null(model)){ text = xgb.dump(model = model, with.stats = T) } else if(!is.null(filename_dump)){ - text <- readLines(filename_dump) %>% str_trim(side = "both") + text <- readLines(filename_dump) %>% str_trim(side = "both") } - + position <- str_match(text, "booster") %>% is.na %>% not %>% which %>% c(length(text)+1) - + extract <- function(x, pattern) str_extract(x, pattern) %>% str_split("=") %>% lapply(function(x) x[2] %>% as.numeric) %>% unlist - + n_round <- min(length(position) - 1, n_first_tree) - + addTreeId <- function(x, i) paste(i,x,sep = "-") - + allTrees <- data.table() - - anynumber_regex<-"[-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?" + + anynumber_regex<-"[-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?" for(i in 1:n_round){ - + tree <- text[(position[i]+1):(position[i+1]-1)] - + # avoid tree made of a leaf only (no split) if(length(tree) <2) next - + treeID <- i-1 - + notLeaf <- str_match(tree, "leaf") %>% is.na leaf <- notLeaf %>% not %>% tree[.] branch <- notLeaf %>% tree[.] idBranch <- str_extract(branch, "\\d*:") %>% str_replace(":", "") %>% addTreeId(treeID) idLeaf <- str_extract(leaf, "\\d*:") %>% str_replace(":", "") %>% addTreeId(treeID) - featureBranch <- str_extract(branch, "f\\d*<") %>% str_replace("<", "") %>% str_replace("f", "") %>% as.numeric + featureBranch <- str_extract(branch, "f\\d*<") %>% str_replace("<", "") %>% str_replace("f", "") %>% as.numeric if(!is.null(feature_names)){ featureBranch <- feature_names[featureBranch + 1] } featureLeaf <- rep("Leaf", length(leaf)) - splitBranch <- str_extract(branch, paste0("<",anynumber_regex,"\\]")) %>% str_replace("<", "") %>% str_replace("\\]", "") - splitLeaf <- rep(NA, length(leaf)) + splitBranch <- str_extract(branch, paste0("<",anynumber_regex,"\\]")) %>% str_replace("<", "") %>% str_replace("\\]", "") + splitLeaf <- rep(NA, length(leaf)) yesBranch <- extract(branch, "yes=\\d*") %>% addTreeId(treeID) - yesLeaf <- rep(NA, length(leaf)) + yesLeaf <- rep(NA, length(leaf)) noBranch <- extract(branch, "no=\\d*") %>% addTreeId(treeID) noLeaf <- rep(NA, length(leaf)) missingBranch <- extract(branch, "missing=\\d+") %>% addTreeId(treeID) @@ -129,10 +129,10 @@ xgb.model.dt.tree <- function(feature_names = NULL, filename_dump = NULL, model coverBranch <- extract(branch, "cover=\\d*\\.*\\d*") coverLeaf <- extract(leaf, "cover=\\d*\\.*\\d*") dt <- data.table(ID = c(idBranch, idLeaf), Feature = c(featureBranch, featureLeaf), Split = c(splitBranch, splitLeaf), Yes = c(yesBranch, yesLeaf), No = c(noBranch, noLeaf), Missing = c(missingBranch, missingLeaf), Quality = c(qualityBranch, qualityLeaf), Cover = c(coverBranch, coverLeaf))[order(ID)][,Tree:=treeID] - + allTrees <- rbindlist(list(allTrees, dt), use.names = T, fill = F) } - + yes <- allTrees[!is.na(Yes), Yes] set(allTrees, i = which(allTrees[, Feature] != "Leaf"),