From 92b996513ef25a764d85e4c837edd73908e8b942 Mon Sep 17 00:00:00 2001 From: terrytangyuan Date: Sat, 5 Sep 2015 22:50:27 -0400 Subject: [PATCH 1/2] TST: Added R unit test for glm --- R-package/tests/testthat/test_glm.R | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 R-package/tests/testthat/test_glm.R diff --git a/R-package/tests/testthat/test_glm.R b/R-package/tests/testthat/test_glm.R new file mode 100644 index 000000000..485aad82d --- /dev/null +++ b/R-package/tests/testthat/test_glm.R @@ -0,0 +1,19 @@ +context('Test generalized linear models') + +require(xgboost) + +test_that("glm works", { + data(agaricus.train, package='xgboost') + data(agaricus.test, package='xgboost') + dtrain <- xgb.DMatrix(agaricus.train$data, label = agaricus.train$label) + dtest <- xgb.DMatrix(agaricus.test$data, label = agaricus.test$label) + expect_equal(class(dtrain), "xgb.DMatrix") + expect_equal(class(dtest), "xgb.DMatrix") + param <- list(objective = "binary:logistic", booster = "gblinear", + nthread = 2, alpha = 0.0001, lambda = 1) + watchlist <- list(eval = dtest, train = dtrain) + num_round <- 2 + expect_that(bst <- xgb.train(param, dtrain, num_round, watchlist), not(throws_error())) + expect_that(ypred <- predict(bst, dtest), not(throws_error())) + expect_equal(length(getinfo(dtest, 'label')), 1611) +}) From 339a53d9d4f8ac2a4d71b9034231f57477d24245 Mon Sep 17 00:00:00 2001 From: "Yuan Tang (Terry)" Date: Sun, 6 Sep 2015 20:00:25 -0400 Subject: [PATCH 2/2] fixed unit test in R --- R-package/tests/testthat/test_glm.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R-package/tests/testthat/test_glm.R b/R-package/tests/testthat/test_glm.R index 485aad82d..dc7b6efab 100644 --- a/R-package/tests/testthat/test_glm.R +++ b/R-package/tests/testthat/test_glm.R @@ -13,7 +13,7 @@ test_that("glm works", { nthread = 2, alpha = 0.0001, lambda = 1) watchlist <- list(eval = dtest, train = dtrain) num_round <- 2 - expect_that(bst <- xgb.train(param, dtrain, num_round, watchlist), not(throws_error())) - expect_that(ypred <- predict(bst, dtest), not(throws_error())) + bst <- xgb.train(param, dtrain, num_round, watchlist) + ypred <- predict(bst, dtest) expect_equal(length(getinfo(dtest, 'label')), 1611) })