modify xgb.getinfo to getinfo

This commit is contained in:
unknown
2014-08-27 15:03:24 -07:00
parent a060a2e9a6
commit 0f0c12707c
14 changed files with 268 additions and 256 deletions

View File

@@ -1,38 +1,47 @@
# train a model using given parameters
xgb.train <- function(params, dtrain, nrounds=10, watchlist=list(), obj=NULL, feval=NULL) {
if (typeof(params) != "list") {
stop("xgb.train: first argument params must be list");
xgb.train <- function(params=list(), dtrain, nrounds = 10, watchlist = list(),
obj = NULL, feval = NULL, ...) {
if (typeof(params) != "list") {
stop("xgb.train: first argument params must be list")
}
if (class(dtrain) != "xgb.DMatrix") {
stop("xgb.train: second argument dtrain must be xgb.DMatrix")
}
params = append(params, list(...))
bst <- xgb.Booster(params, append(watchlist, dtrain))
for (i in 1:nrounds) {
if (is.null(obj)) {
succ <- xgb.iter.update(bst, dtrain, i - 1)
} else {
pred <- xgb.predict(bst, dtrain)
gpair <- obj(pred, dtrain)
succ <- xgb.iter.boost(bst, dtrain, gpair)
}
if (class(dtrain) != "xgb.DMatrix") {
stop("xgb.train: second argument dtrain must be xgb.DMatrix");
}
bst <- xgb.Booster(params, append(watchlist,dtrain))
for (i in 1:nrounds) {
if (is.null(obj)) {
succ <- xgb.iter.update(bst, dtrain, i-1)
} else {
pred <- xgb.predict(bst, dtrain)
gpair <- obj(pred, dtrain)
succ <- xgb.iter.boost(bst, dtrain, gpair)
}
if (length(watchlist) != 0) {
if (is.null(feval)) {
msg <- xgb.iter.eval(bst, watchlist, i-1)
cat(msg); cat("\n")
} else {
cat("["); cat(i); cat("]");
for (j in 1:length(watchlist)) {
w <- watchlist[j]
if (length(names(w)) == 0) {
stop("xgb.eval: name tag must be presented for every elements in watchlist")
}
ret <- feval(xgb.predict(bst, w[[1]]), w[[1]])
cat("\t"); cat(names(w)); cat("-"); cat(ret$metric);
cat(":"); cat(ret$value)
}
cat("\n")
}
if (length(watchlist) != 0) {
if (is.null(feval)) {
msg <- xgb.iter.eval(bst, watchlist, i - 1)
cat(msg)
cat("\n")
} else {
cat("[")
cat(i)
cat("]")
for (j in 1:length(watchlist)) {
w <- watchlist[j]
if (length(names(w)) == 0) {
stop("xgb.eval: name tag must be presented for every elements in watchlist")
}
ret <- feval(xgb.predict(bst, w[[1]]), w[[1]])
cat("\t")
cat(names(w))
cat("-")
cat(ret$metric)
cat(":")
cat(ret$value)
}
cat("\n")
}
}
return(bst)
}
}
return(bst)
}