Fix bug in Cross Validation when showsd = FALSE
This commit is contained in:
parent
9f5929497a
commit
a17e29b130
@ -114,10 +114,11 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
|
||||
|
||||
colnames <- str_split(string = history[1], pattern = "\t")[[1]] %>% .[2:length(.)] %>% str_extract(".*:") %>% str_replace(":","") %>% str_replace("-", ".")
|
||||
colnamesMean <- paste(colnames, "mean")
|
||||
colnamesStd <- paste(colnames, "std")
|
||||
if(showsd) colnamesStd <- paste(colnames, "std")
|
||||
|
||||
colnames <- c()
|
||||
for(i in 1:length(colnamesMean)) colnames <- c(colnames, colnamesMean[i], colnamesStd[i])
|
||||
if(showsd) for(i in 1:length(colnamesMean)) colnames <- c(colnames, colnamesMean[i], colnamesStd[i])
|
||||
else colnames <- colnamesMean
|
||||
|
||||
type <- rep(x = "numeric", times = length(colnames))
|
||||
dt <- read.table(text = "", colClasses = type, col.names = colnames) %>% as.data.table
|
||||
|
||||
@ -19,7 +19,7 @@ cat('running cross validation, disable standard deviation display\n')
|
||||
# [iteration] metric_name:mean_value+std_value
|
||||
# std_value is standard deviation of the metric
|
||||
xgb.cv(param, dtrain, nround, nfold=5,
|
||||
metrics={'error'}, , showsd = FALSE)
|
||||
metrics={'error'}, showsd = FALSE)
|
||||
|
||||
###
|
||||
# you can also do cross validation with cutomized loss function
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user