From 886955148ddcc3cc711ff1307d4625f3406e3931 Mon Sep 17 00:00:00 2001 From: terrytangyuan Date: Mon, 7 Sep 2015 21:55:17 -0400 Subject: [PATCH] TST: Added test for models with custom objective --- .../tests/testthat/test_custom_objective.R | 43 +++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 R-package/tests/testthat/test_custom_objective.R diff --git a/R-package/tests/testthat/test_custom_objective.R b/R-package/tests/testthat/test_custom_objective.R new file mode 100644 index 000000000..a941e39c8 --- /dev/null +++ b/R-package/tests/testthat/test_custom_objective.R @@ -0,0 +1,43 @@ +context('Test models with custom objective') + +require(xgboost) + +test_that("custom objective works", { + data(agaricus.train, package='xgboost') + data(agaricus.test, package='xgboost') + dtrain <- xgb.DMatrix(agaricus.train$data, label = agaricus.train$label) + dtest <- xgb.DMatrix(agaricus.test$data, label = agaricus.test$label) + + watchlist <- list(eval = dtest, train = dtrain) + num_round <- 2 + + logregobj <- function(preds, dtrain) { + labels <- getinfo(dtrain, "label") + preds <- 1/(1 + exp(-preds)) + grad <- preds - labels + hess <- preds * (1 - preds) + return(list(grad = grad, hess = hess)) + } + evalerror <- function(preds, dtrain) { + labels <- getinfo(dtrain, "label") + err <- as.numeric(sum(labels != (preds > 0)))/length(labels) + return(list(metric = "error", value = err)) + } + + param <- list(max.depth=2, eta=1, nthread = 2, silent=1, + objective=logregobj, eval_metric=evalerror) + + bst <- xgb.train(param, dtrain, num_round, watchlist) + attr(dtrain, 'label') <- getinfo(dtrain, 'label') + + logregobjattr <- function(preds, dtrain) { + labels <- attr(dtrain, 'label') + preds <- 1/(1 + exp(-preds)) + grad <- preds - labels + hess <- preds * (1 - preds) + return(list(grad = grad, hess = hess)) + } + param <- list(max.depth=2, eta=1, nthread = 2, silent=1, + objective=logregobjattr, eval_metric=evalerror) + bst <- xgb.train(param, dtrain, num_round, watchlist) +}) \ No newline at end of file