refinement of R package

This commit is contained in:
unknown
2014-08-27 12:57:37 -07:00
parent 0fe5470a4f
commit d747172d37
6 changed files with 89 additions and 60 deletions

View File

@@ -1,5 +1,5 @@
# constructing DMatrix
xgb.DMatrix <- function(data, missing=0.0, ...) {
xgb.DMatrix <- function(data, info=list(), missing=0.0, ...) {
if (typeof(data) == "character") {
handle <- .Call("XGDMatrixCreateFromFile_R", data, as.integer(FALSE), PACKAGE="xgboost")
} else if(is.matrix(data)) {
@@ -11,7 +11,7 @@ xgb.DMatrix <- function(data, missing=0.0, ...) {
}
dmat <- structure(handle, class="xgb.DMatrix")
info = list(...)
info = append(info,list(...))
if (length(info)==0)
return(dmat)
for (i in 1:length(info)) {

View File

@@ -0,0 +1,12 @@
# save model or DMatrix to file
xgb.DMatrix.save <- function(handle, fname) {
if (typeof(fname) != "character") {
stop("xgb.save: fname must be character")
}
if (class(handle) == "xgb.DMatrix") {
.Call("XGDMatrixSaveBinary_R", handle, fname, as.integer(FALSE), PACKAGE="xgboost")
return(TRUE)
}
stop("xgb.save: the input must be either xgb.DMatrix or xgb.Booster")
return(FALSE)
}

View File

@@ -7,10 +7,6 @@ xgb.save <- function(handle, fname) {
.Call("XGBoosterSaveModel_R", handle, fname, PACKAGE="xgboost")
return(TRUE)
}
if (class(handle) == "xgb.DMatrix") {
.Call("XGDMatrixSaveBinary_R", handle, fname, as.integer(FALSE), PACKAGE="xgboost")
return(TRUE)
}
stop("xgb.save: the input must be either xgb.DMatrix or xgb.Booster")
return(FALSE)
}

View File

@@ -1,49 +1,41 @@
# Main function for xgboost-package
xgboost = function(x=NULL,y=NULL,DMatrix=NULL, file=NULL, validation=NULL,
nrounds=10, obj=NULL, feval=NULL, margin=NULL, verbose = T, ...)
xgboost = function(data=NULL, label = NULL, params=list(), nrounds=10,
verbose = 1, ...)
{
if (!is.null(DMatrix))
dtrain = DMatrix
inClass = class(data)
if (inClass=='dgCMatrix' || inClass=='matrix')
{
if (is.null(label))
stop('xgboost: need label when data is a matrix')
dtrain = xgb.DMatrix(data, label=y)
}
else
{
if (is.null(x) && is.null(y))
{
if (is.null(file))
stop('xgboost need input data, either R objects, local files or DMatrix object.')
dtrain = xgb.DMatrix(file)
}
if (!is.null(label))
warning('xgboost: label will be ignored.')
if (inClass=='character')
dtrain = xgb.DMatrix(data)
else if (inClass=='xgb.DMatrix')
dtrain = data
else
dtrain = xgb.DMatrix(x, label=y)
if (!is.null(margin))
{
succ <- xgb.setinfo(dtrain, "base_margin", margin)
if (!succ)
warning('Attemp to use margin failed.')
}
stop('xgboost: Invalid input of data')
}
params = list(...)
if (verbose>1)
silent = 0
else
silent = 1
watchlist=list()
if (verbose)
{
if (!is.null(validation))
{
if (class(validation)!='xgb.DMatrix')
dtest = xgb.DMatrix(validation)
else
dtest = validation
watchlist = list(eval=dtest,train=dtrain)
}
else
watchlist = list(train=dtrain)
}
params = append(params, list(silent=silent))
params = append(params, list(...))
bst <- xgb.train(params, dtrain, nrounds, watchlist, obj, feval)
if (verbose>0)
watchlist = list(train=dtrain)
else
watchlist = list()
bst <- xgb.train(params, dtrain, nrounds, watchlist)
return(bst)
}