add test module in R

This commit is contained in:
hetong007 2015-08-27 15:56:35 -07:00
parent b0be833c75
commit 4554da0537
3 changed files with 40 additions and 1 deletions

View File

@ -23,7 +23,8 @@ Suggests:
ggplot2 (>= 1.0.0),
DiagrammeR (>= 0.6),
Ckmeans.1d.dp (>= 3.3.1),
vcd (>= 1.3)
vcd (>= 1.3),
testthat
Depends:
R (>= 2.10)
Imports:

View File

@ -0,0 +1,4 @@
library(testthat)
library(xgboost)
test_check("xgboost")

View File

@ -0,0 +1,34 @@
require(xgboost)
context("basic functions")
test_that("data loading", {
data(agaricus.train, package='xgboost')
data(agaricus.test, package='xgboost')
})
test_that("train and prediction",{
train = agaricus.train
test = agaricus.test
bst = xgboost(data = train$data, label = train$label, max.depth = 2,
eta = 1, nthread = 2, nround = 2, objective = "binary:logistic")
pred = predict(bst, test$data)
})
test_that("early stopping", {
res = xgb.cv(data = train$data, label = train$label, max.depth = 2, nfold = 5,
eta = 0.3, nthread = 2, nround = 20, objective = "binary:logistic",
early.stop.round = 3, maximize = FALSE)
expect_true(nrow(res)<20)
bst = xgboost(data = train$data, label = train$label, max.depth = 2,
eta = 0.3, nthread = 2, nround = 20, objective = "binary:logistic",
early.stop.round = 3, maximize = FALSE)
pred = predict(bst, test$data)
})
test_that("save_period", {
bst = xgboost(data = train$data, label = train$label, max.depth = 2,
eta = 0.3, nthread = 2, nround = 20, objective = "binary:logistic",
save_period = 10, save_name = "xgb.model")
pred = predict(bst, test$data)
})