xgboost/R-package/tests/testthat/test_interaction_constraints.R
Andrew Thia 9254c58e4d [TREE] add interaction constraints (#3466)
* add interaction constraints

* enable both interaction and monotonic constraints at the same time

* fix lint

* add R test, fix lint, update demo

* Use dmlc::JSONReader to express interaction constraints as nested lists; use sparse arrays for bookkeeping

* Add Python test for interaction constraints

* make R interaction constraints parameter based on feature index instead of column names, fix R coding style

* Fix lint

* Add BlueTea88 to CONTRIBUTORS.md

* Short-circuit when no constraint is specified; address review comments

* Add tutorial for feature interaction constraints

* allow interaction constraints to be passed as string, remove redundant column_names argument

* Fix typo

* Address review comments

* Add comments to Python test
2018-09-04 09:35:39 -07:00

39 lines
1.2 KiB
R

require(xgboost)
# Group the expectations in this file under a single testthat label.
context("interaction constraints")
# Simulated regression data shared by the tests in this file. The seed is
# fixed so the fitted model, and therefore the test outcome, is reproducible.
set.seed(1024)
n_obs <- 1000
x1 <- rnorm(n_obs, mean = 1)
x2 <- rnorm(n_obs, mean = 1)
# x3 only takes the levels 1, 2 and 3, so a test can move every row to the
# same x3 level and then step it up uniformly.
x3 <- sample(c(1, 2, 3), size = n_obs, replace = TRUE)
# The true response contains an x1 * x2 * x3 interaction. A model constrained
# to let only x1 and x2 interact cannot learn the x3 part of that term.
# NOTE(review): rnorm(n_obs, mean = 0.001) is unit-variance noise with a tiny
# offset; if small noise was intended, this should be sd = 0.001 - confirm.
y <- x1 + x2 + x3 + x1 * x2 * x3 + rnorm(n_obs, mean = 0.001) + 3 * sin(x1)
# Training matrix with columns in feature-index order: x1, x2, x3.
train <- matrix(c(x1, x2, x3), ncol = 3)
test_that("interaction constraints for regression", {
  # Fit a model that only allows interaction between x1 and x2. Features are
  # given by 0-based column index of `train` (0 = x1, 1 = x2). x3 is not in
  # the allowed group, so the constraint should keep it from interacting
  # with x1 and x2. That is what this test checks.
  bst <- xgboost(data = train, label = y, max_depth = 3,
                 eta = 0.1, nthread = 2, nrounds = 100, verbose = 0,
                 interaction_constraints = list(c(0, 1)))

  # Predict three times with every row's x3 forced to the same level (1, 2,
  # then 3), keeping the observed x1 and x2. Overwriting column 3 of a copy
  # of `train` keeps the row count tied to the training data instead of a
  # hard-coded size.
  preds <- lapply(c(1, 2, 3), function(level) {
    tmat <- train
    tmat[, 3] <- level
    predict(bst, tmat)
  })

  # x3 may not interact with x1 or x2, so its effect is additive. Raising x3
  # by one step from a common starting level must shift every prediction by
  # the same amount, up to a small numerical tolerance.
  tol <- 1e-4
  diff1 <- preds[[2]] - preds[[1]]
  diff2 <- preds[[3]] - preds[[2]]
  # `info` is only reported when an expectation fails, so it describes the
  # violation rather than the success.
  expect_true(
    all(abs(diff1 - diff1[1]) < tol),
    info = "Interaction constraint violated: x3 step 1 -> 2 shifted predictions unequally"
  )
  expect_true(
    all(abs(diff2 - diff2[1]) < tol),
    info = "Interaction constraint violated: x3 step 2 -> 3 shifted predictions unequally"
  )
})