refinement of R package
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
# constructing DMatrix
|
||||
xgb.DMatrix <- function(data, missing=0.0, ...) {
|
||||
xgb.DMatrix <- function(data, info=list(), missing=0.0, ...) {
|
||||
if (typeof(data) == "character") {
|
||||
handle <- .Call("XGDMatrixCreateFromFile_R", data, as.integer(FALSE), PACKAGE="xgboost")
|
||||
} else if(is.matrix(data)) {
|
||||
@@ -11,7 +11,7 @@ xgb.DMatrix <- function(data, missing=0.0, ...) {
|
||||
}
|
||||
dmat <- structure(handle, class="xgb.DMatrix")
|
||||
|
||||
info = list(...)
|
||||
info = append(info,list(...))
|
||||
if (length(info)==0)
|
||||
return(dmat)
|
||||
for (i in 1:length(info)) {
|
||||
|
||||
12
R-package/R/xgb.DMatrix.save.R
Normal file
12
R-package/R/xgb.DMatrix.save.R
Normal file
@@ -0,0 +1,12 @@
|
||||
# save model or DMatrix to file
|
||||
xgb.DMatrix.save <- function(handle, fname) {
|
||||
if (typeof(fname) != "character") {
|
||||
stop("xgb.save: fname must be character")
|
||||
}
|
||||
if (class(handle) == "xgb.DMatrix") {
|
||||
.Call("XGDMatrixSaveBinary_R", handle, fname, as.integer(FALSE), PACKAGE="xgboost")
|
||||
return(TRUE)
|
||||
}
|
||||
stop("xgb.save: the input must be either xgb.DMatrix or xgb.Booster")
|
||||
return(FALSE)
|
||||
}
|
||||
@@ -7,10 +7,6 @@ xgb.save <- function(handle, fname) {
|
||||
.Call("XGBoosterSaveModel_R", handle, fname, PACKAGE="xgboost")
|
||||
return(TRUE)
|
||||
}
|
||||
if (class(handle) == "xgb.DMatrix") {
|
||||
.Call("XGDMatrixSaveBinary_R", handle, fname, as.integer(FALSE), PACKAGE="xgboost")
|
||||
return(TRUE)
|
||||
}
|
||||
stop("xgb.save: the input must be either xgb.DMatrix or xgb.Booster")
|
||||
return(FALSE)
|
||||
}
|
||||
|
||||
@@ -1,49 +1,41 @@
|
||||
# Main function for xgboost-package
|
||||
|
||||
xgboost = function(x=NULL,y=NULL,DMatrix=NULL, file=NULL, validation=NULL,
|
||||
nrounds=10, obj=NULL, feval=NULL, margin=NULL, verbose = T, ...)
|
||||
xgboost = function(data=NULL, label = NULL, params=list(), nrounds=10,
|
||||
verbose = 1, ...)
|
||||
{
|
||||
if (!is.null(DMatrix))
|
||||
dtrain = DMatrix
|
||||
inClass = class(data)
|
||||
if (inClass=='dgCMatrix' || inClass=='matrix')
|
||||
{
|
||||
if (is.null(label))
|
||||
stop('xgboost: need label when data is a matrix')
|
||||
dtrain = xgb.DMatrix(data, label=y)
|
||||
}
|
||||
else
|
||||
{
|
||||
if (is.null(x) && is.null(y))
|
||||
{
|
||||
if (is.null(file))
|
||||
stop('xgboost need input data, either R objects, local files or DMatrix object.')
|
||||
dtrain = xgb.DMatrix(file)
|
||||
}
|
||||
if (!is.null(label))
|
||||
warning('xgboost: label will be ignored.')
|
||||
if (inClass=='character')
|
||||
dtrain = xgb.DMatrix(data)
|
||||
else if (inClass=='xgb.DMatrix')
|
||||
dtrain = data
|
||||
else
|
||||
dtrain = xgb.DMatrix(x, label=y)
|
||||
if (!is.null(margin))
|
||||
{
|
||||
succ <- xgb.setinfo(dtrain, "base_margin", margin)
|
||||
if (!succ)
|
||||
warning('Attemp to use margin failed.')
|
||||
}
|
||||
stop('xgboost: Invalid input of data')
|
||||
}
|
||||
|
||||
params = list(...)
|
||||
if (verbose>1)
|
||||
silent = 0
|
||||
else
|
||||
silent = 1
|
||||
|
||||
watchlist=list()
|
||||
if (verbose)
|
||||
{
|
||||
if (!is.null(validation))
|
||||
{
|
||||
if (class(validation)!='xgb.DMatrix')
|
||||
dtest = xgb.DMatrix(validation)
|
||||
else
|
||||
dtest = validation
|
||||
watchlist = list(eval=dtest,train=dtrain)
|
||||
}
|
||||
|
||||
else
|
||||
watchlist = list(train=dtrain)
|
||||
}
|
||||
params = append(params, list(silent=silent))
|
||||
params = append(params, list(...))
|
||||
|
||||
bst <- xgb.train(params, dtrain, nrounds, watchlist, obj, feval)
|
||||
if (verbose>0)
|
||||
watchlist = list(train=dtrain)
|
||||
else
|
||||
watchlist = list()
|
||||
|
||||
bst <- xgb.train(params, dtrain, nrounds, watchlist)
|
||||
|
||||
return(bst)
|
||||
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user