Implement feature score for linear model. (#7048)
* Add feature score support for linear model. * Port R interface to the new implementation. * Add linear model support in Python. Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -1,10 +1,10 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014-2021 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
@@ -32,6 +32,9 @@ import org.junit.Test;
|
||||
* @author hzx
|
||||
*/
|
||||
public class BoosterImplTest {
|
||||
private String train_uri = "../../demo/data/agaricus.txt.train?indexing_mode=1";
|
||||
private String test_uri = "../../demo/data/agaricus.txt.test?indexing_mode=1";
|
||||
|
||||
public static class EvalError implements IEvaluation {
|
||||
@Override
|
||||
public String getMetric() {
|
||||
@@ -87,8 +90,8 @@ public class BoosterImplTest {
|
||||
@Test
|
||||
public void testBoosterBasic() throws XGBoostError, IOException {
|
||||
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
DMatrix trainMat = new DMatrix(this.train_uri);
|
||||
DMatrix testMat = new DMatrix(this.test_uri);
|
||||
|
||||
Booster booster = trainBooster(trainMat, testMat);
|
||||
|
||||
@@ -103,8 +106,8 @@ public class BoosterImplTest {
|
||||
|
||||
@Test
|
||||
public void saveLoadModelWithPath() throws XGBoostError, IOException {
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
DMatrix trainMat = new DMatrix(this.train_uri);
|
||||
DMatrix testMat = new DMatrix(this.test_uri);
|
||||
IEvaluation eval = new EvalError();
|
||||
|
||||
Booster booster = trainBooster(trainMat, testMat);
|
||||
@@ -121,8 +124,8 @@ public class BoosterImplTest {
|
||||
|
||||
@Test
|
||||
public void saveLoadModelWithStream() throws XGBoostError, IOException {
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
DMatrix trainMat = new DMatrix(this.train_uri);
|
||||
DMatrix testMat = new DMatrix(this.test_uri);
|
||||
|
||||
Booster booster = trainBooster(trainMat, testMat);
|
||||
|
||||
@@ -310,8 +313,8 @@ public class BoosterImplTest {
|
||||
|
||||
@Test
|
||||
public void testBoosterEarlyStop() throws XGBoostError, IOException {
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
DMatrix trainMat = new DMatrix(this.train_uri);
|
||||
DMatrix testMat = new DMatrix(this.test_uri);
|
||||
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||
{
|
||||
put("max_depth", 3);
|
||||
@@ -363,8 +366,8 @@ public class BoosterImplTest {
|
||||
|
||||
@Test
|
||||
public void testQuantileHistoDepthWise() throws XGBoostError {
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
DMatrix trainMat = new DMatrix(this.train_uri);
|
||||
DMatrix testMat = new DMatrix(this.test_uri);
|
||||
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||
{
|
||||
put("max_depth", 3);
|
||||
@@ -383,8 +386,8 @@ public class BoosterImplTest {
|
||||
|
||||
@Test
|
||||
public void testQuantileHistoLossGuide() throws XGBoostError {
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
DMatrix trainMat = new DMatrix(this.train_uri);
|
||||
DMatrix testMat = new DMatrix(this.test_uri);
|
||||
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||
{
|
||||
put("max_depth", 3);
|
||||
@@ -404,8 +407,8 @@ public class BoosterImplTest {
|
||||
|
||||
@Test
|
||||
public void testQuantileHistoLossGuideMaxBin() throws XGBoostError {
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
DMatrix trainMat = new DMatrix(this.train_uri);
|
||||
DMatrix testMat = new DMatrix(this.test_uri);
|
||||
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||
{
|
||||
put("max_depth", 3);
|
||||
@@ -425,8 +428,8 @@ public class BoosterImplTest {
|
||||
|
||||
@Test
|
||||
public void testDumpModelJson() throws XGBoostError {
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
DMatrix trainMat = new DMatrix(this.train_uri);
|
||||
DMatrix testMat = new DMatrix(this.test_uri);
|
||||
|
||||
Booster booster = trainBooster(trainMat, testMat);
|
||||
String[] dump = booster.getModelDump("", false, "json");
|
||||
@@ -441,8 +444,8 @@ public class BoosterImplTest {
|
||||
|
||||
@Test
|
||||
public void testGetFeatureScore() throws XGBoostError {
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
DMatrix trainMat = new DMatrix(this.train_uri);
|
||||
DMatrix testMat = new DMatrix(this.test_uri);
|
||||
|
||||
Booster booster = trainBooster(trainMat, testMat);
|
||||
String[] featureNames = new String[126];
|
||||
@@ -453,8 +456,8 @@ public class BoosterImplTest {
|
||||
|
||||
@Test
|
||||
public void testGetFeatureImportanceGain() throws XGBoostError {
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
DMatrix trainMat = new DMatrix(this.train_uri);
|
||||
DMatrix testMat = new DMatrix(this.test_uri);
|
||||
|
||||
Booster booster = trainBooster(trainMat, testMat);
|
||||
String[] featureNames = new String[126];
|
||||
@@ -465,8 +468,8 @@ public class BoosterImplTest {
|
||||
|
||||
@Test
|
||||
public void testGetFeatureImportanceTotalGain() throws XGBoostError {
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
DMatrix trainMat = new DMatrix(this.train_uri);
|
||||
DMatrix testMat = new DMatrix(this.test_uri);
|
||||
|
||||
Booster booster = trainBooster(trainMat, testMat);
|
||||
String[] featureNames = new String[126];
|
||||
@@ -477,8 +480,8 @@ public class BoosterImplTest {
|
||||
|
||||
@Test
|
||||
public void testGetFeatureImportanceCover() throws XGBoostError {
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
DMatrix trainMat = new DMatrix(this.train_uri);
|
||||
DMatrix testMat = new DMatrix(this.test_uri);
|
||||
|
||||
Booster booster = trainBooster(trainMat, testMat);
|
||||
String[] featureNames = new String[126];
|
||||
@@ -489,8 +492,8 @@ public class BoosterImplTest {
|
||||
|
||||
@Test
|
||||
public void testGetFeatureImportanceTotalCover() throws XGBoostError {
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
DMatrix trainMat = new DMatrix(this.train_uri);
|
||||
DMatrix testMat = new DMatrix(this.test_uri);
|
||||
|
||||
Booster booster = trainBooster(trainMat, testMat);
|
||||
String[] featureNames = new String[126];
|
||||
@@ -501,7 +504,7 @@ public class BoosterImplTest {
|
||||
|
||||
@Test
|
||||
public void testQuantileHistoDepthwiseMaxDepth() throws XGBoostError {
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix trainMat = new DMatrix(this.train_uri);
|
||||
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||
{
|
||||
put("max_depth", 3);
|
||||
@@ -519,8 +522,8 @@ public class BoosterImplTest {
|
||||
|
||||
@Test
|
||||
public void testQuantileHistoDepthwiseMaxDepthMaxBin() throws XGBoostError {
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
DMatrix trainMat = new DMatrix(this.train_uri);
|
||||
DMatrix testMat = new DMatrix(this.test_uri);
|
||||
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||
{
|
||||
put("max_depth", 3);
|
||||
@@ -545,7 +548,7 @@ public class BoosterImplTest {
|
||||
@Test
|
||||
public void testCV() throws XGBoostError {
|
||||
//load train mat
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix trainMat = new DMatrix(this.train_uri);
|
||||
|
||||
//set params
|
||||
Map<String, Object> param = new HashMap<String, Object>() {
|
||||
@@ -573,8 +576,8 @@ public class BoosterImplTest {
|
||||
*/
|
||||
@Test
|
||||
public void testTrainFromExistingModel() throws XGBoostError, IOException {
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
DMatrix trainMat = new DMatrix(this.train_uri);
|
||||
DMatrix testMat = new DMatrix(this.test_uri);
|
||||
IEvaluation eval = new EvalError();
|
||||
|
||||
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||
@@ -624,8 +627,8 @@ public class BoosterImplTest {
|
||||
*/
|
||||
@Test
|
||||
public void testSetAndGetAttrs() throws XGBoostError {
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
DMatrix trainMat = new DMatrix(this.train_uri);
|
||||
DMatrix testMat = new DMatrix(this.test_uri);
|
||||
|
||||
Booster booster = trainBooster(trainMat, testMat);
|
||||
booster.setAttr("testKey1", "testValue1");
|
||||
@@ -654,10 +657,10 @@ public class BoosterImplTest {
|
||||
*/
|
||||
@Test
|
||||
public void testGetNumFeature() throws XGBoostError {
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
DMatrix trainMat = new DMatrix(this.train_uri);
|
||||
DMatrix testMat = new DMatrix(this.test_uri);
|
||||
|
||||
Booster booster = trainBooster(trainMat, testMat);
|
||||
TestCase.assertEquals(booster.getNumFeature(), 127);
|
||||
TestCase.assertEquals(booster.getNumFeature(), 126);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user