[jvm-packages] Create demo and test for xgboost4j early stopping. (#7252)
This commit is contained in:
parent
0ee11dac77
commit
fbd58bf190
@ -162,17 +162,17 @@ Example of setting a missing value (e.g. -999) to the "missing" parameter in XGB
|
|||||||
doing this with missing values encoded as NaN, you will want to set ``setHandleInvalid = "keep"`` on VectorAssembler
|
doing this with missing values encoded as NaN, you will want to set ``setHandleInvalid = "keep"`` on VectorAssembler
|
||||||
in order to keep the NaN values in the dataset. You would then set the "missing" parameter to whatever you want to be
|
in order to keep the NaN values in the dataset. You would then set the "missing" parameter to whatever you want to be
|
||||||
treated as missing. However this may cause a large amount of memory use if your dataset is very sparse. For example:
|
treated as missing. However this may cause a large amount of memory use if your dataset is very sparse. For example:
|
||||||
|
|
||||||
.. code-block:: scala
|
.. code-block:: scala
|
||||||
|
|
||||||
val assembler = new VectorAssembler().setInputCols(feature_names.toArray).setOutputCol("features").setHandleInvalid("keep")
|
val assembler = new VectorAssembler().setInputCols(feature_names.toArray).setOutputCol("features").setHandleInvalid("keep")
|
||||||
|
|
||||||
// conversion to dense vector using Array()
|
// conversion to dense vector using Array()
|
||||||
|
|
||||||
val featurePipeline = new Pipeline().setStages(Array(assembler))
|
val featurePipeline = new Pipeline().setStages(Array(assembler))
|
||||||
val featureModel = featurePipeline.fit(df_training)
|
val featureModel = featurePipeline.fit(df_training)
|
||||||
val featureDf = featureModel.transform(df_training)
|
val featureDf = featureModel.transform(df_training)
|
||||||
|
|
||||||
val xgbParam = Map("eta" -> 0.1f,
|
val xgbParam = Map("eta" -> 0.1f,
|
||||||
"max_depth" -> 2,
|
"max_depth" -> 2,
|
||||||
"objective" -> "multi:softprob",
|
"objective" -> "multi:softprob",
|
||||||
@ -181,10 +181,10 @@ Example of setting a missing value (e.g. -999) to the "missing" parameter in XGB
|
|||||||
"num_workers" -> 2,
|
"num_workers" -> 2,
|
||||||
"allow_non_zero_for_missing" -> "true",
|
"allow_non_zero_for_missing" -> "true",
|
||||||
"missing" -> -999)
|
"missing" -> -999)
|
||||||
|
|
||||||
val xgb = new XGBoostClassifier(xgbParam)
|
val xgb = new XGBoostClassifier(xgbParam)
|
||||||
val xgbclassifier = xgb.fit(featureDf)
|
val xgbclassifier = xgb.fit(featureDf)
|
||||||
|
|
||||||
|
|
||||||
2. Before calling VectorAssembler you can transform the values you want to represent missing into an irregular value
|
2. Before calling VectorAssembler you can transform the values you want to represent missing into an irregular value
|
||||||
that is not 0, NaN, or Null and set the "missing" parameter to 0. The irregular value should ideally be chosen to be
|
that is not 0, NaN, or Null and set the "missing" parameter to 0. The irregular value should ideally be chosen to be
|
||||||
|
|||||||
@ -10,6 +10,7 @@ XGBoost4J Code Examples
|
|||||||
* [Cross validation](src/main/java/ml/dmlc/xgboost4j/java/example/CrossValidation.java)
|
* [Cross validation](src/main/java/ml/dmlc/xgboost4j/java/example/CrossValidation.java)
|
||||||
* [Predicting leaf indices](src/main/java/ml/dmlc/xgboost4j/java/example/PredictLeafIndices.java)
|
* [Predicting leaf indices](src/main/java/ml/dmlc/xgboost4j/java/example/PredictLeafIndices.java)
|
||||||
* [External Memory](src/main/java/ml/dmlc/xgboost4j/java/example/ExternalMemory.java)
|
* [External Memory](src/main/java/ml/dmlc/xgboost4j/java/example/ExternalMemory.java)
|
||||||
|
* [Early Stopping](src/main/java/ml/dmlc/xgboost4j/java/example/EarlyStopping.java)
|
||||||
|
|
||||||
## Scala API
|
## Scala API
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 by Contributors
|
Copyright (c) 2014-2021 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -115,7 +115,7 @@ public class BasicWalkThrough {
|
|||||||
DataLoader.CSRSparseData spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train");
|
DataLoader.CSRSparseData spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train");
|
||||||
|
|
||||||
DMatrix trainMat2 = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data,
|
DMatrix trainMat2 = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data,
|
||||||
DMatrix.SparseType.CSR);
|
DMatrix.SparseType.CSR, 127);
|
||||||
trainMat2.setLabel(spData.labels);
|
trainMat2.setLabel(spData.labels);
|
||||||
|
|
||||||
//specify watchList
|
//specify watchList
|
||||||
|
|||||||
@ -0,0 +1,67 @@
|
|||||||
|
/*
 Copyright (c) 2021 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */
package ml.dmlc.xgboost4j.java.example;

import java.io.IOException;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;

import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;
import ml.dmlc.xgboost4j.java.example.util.DataLoader;

/**
 * Demo of early stopping with XGBoost4J: trains on the agaricus data with a
 * watch list and an early-stopping round count, then reads the best
 * iteration/score back from the booster's string attributes.
 */
public class EarlyStopping {
  public static void main(String[] args) throws IOException, XGBoostError {
    // Load the LIBSVM-format agaricus demo data as CSR sparse triplets.
    DataLoader.CSRSparseData trainCSR =
        DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train");
    DataLoader.CSRSparseData testCSR =
        DataLoader.loadSVMFile("../../demo/data/agaricus.txt.test");

    Map<String, Object> paramMap = new HashMap<String, Object>() {
      {
        put("max_depth", 3);
        put("objective", "binary:logistic");
        // The evaluation metric is minimized, so "improvement" means a decrease.
        put("maximize_evaluation_metrics", "false");
      }
    };

    // Build DMatrix objects from the CSR data. The trailing 127 is the shape
    // hint passed to the CSR constructor — presumably the column count of the
    // agaricus data; TODO confirm against DMatrix's CSR constructor docs.
    DMatrix trainXy = new DMatrix(trainCSR.rowHeaders, trainCSR.colIndex, trainCSR.data,
        DMatrix.SparseType.CSR, 127);
    trainXy.setLabel(trainCSR.labels);
    DMatrix testXy = new DMatrix(testCSR.rowHeaders, testCSR.colIndex, testCSR.data,
        DMatrix.SparseType.CSR, 127);
    testXy.setLabel(testCSR.labels);

    int nRounds = 128;
    int nEarlyStoppingRounds = 4;

    // LinkedHashMap keeps insertion order deterministic; the last watch
    // ("test") is the one whose metric drives early stopping.
    Map<String, DMatrix> watches = new LinkedHashMap<>();
    watches.put("training", trainXy);
    watches.put("test", testXy);

    // metrics[i][j] receives the round-j evaluation result of watch i.
    float[][] metrics = new float[watches.size()][nRounds];
    Booster booster = XGBoost.train(trainXy, paramMap, nRounds,
        watches, metrics, null, null, nEarlyStoppingRounds);

    // parseInt/parseFloat return primitives directly; valueOf would box only
    // to immediately unbox. (Either call throws NumberFormatException / NPE if
    // the attribute is absent — acceptable for a demo.)
    int bestIter = Integer.parseInt(booster.getAttr("best_iteration"));
    float bestScore = Float.parseFloat(booster.getAttr("best_score"));

    // %n is the platform-independent line separator for format strings.
    System.out.printf("Best iter: %d, Best score: %f%n", bestIter, bestScore);
  }
}
|
||||||
@ -16,8 +16,6 @@
|
|||||||
package ml.dmlc.xgboost4j.java;
|
package ml.dmlc.xgboost4j.java;
|
||||||
|
|
||||||
import java.io.*;
|
import java.io.*;
|
||||||
import java.nio.file.Files;
|
|
||||||
import java.nio.file.Path;
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.LinkedHashMap;
|
import java.util.LinkedHashMap;
|
||||||
@ -347,6 +345,34 @@ public class BoosterImplTest {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testEarlyStoppingAttributes() throws XGBoostError, IOException {
|
||||||
|
DMatrix trainMat = new DMatrix(this.train_uri);
|
||||||
|
DMatrix testMat = new DMatrix(this.test_uri);
|
||||||
|
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||||
|
{
|
||||||
|
put("max_depth", 3);
|
||||||
|
put("objective", "binary:logistic");
|
||||||
|
put("maximize_evaluation_metrics", "false");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Map<String, DMatrix> watches = new LinkedHashMap<>();
|
||||||
|
watches.put("training", trainMat);
|
||||||
|
watches.put("test", testMat);
|
||||||
|
|
||||||
|
int round = 30;
|
||||||
|
int earlyStoppingRound = 4;
|
||||||
|
float[][] metrics = new float[watches.size()][round];
|
||||||
|
|
||||||
|
Booster booster = XGBoost.train(trainMat, paramMap, round,
|
||||||
|
watches, metrics, null, null, earlyStoppingRound);
|
||||||
|
|
||||||
|
int bestIter = Integer.valueOf(booster.getAttr("best_iteration"));
|
||||||
|
float bestScore = Float.valueOf(booster.getAttr("best_score"));
|
||||||
|
TestCase.assertEquals(bestIter, round - 1);
|
||||||
|
TestCase.assertEquals(bestScore, metrics[watches.size() - 1][round - 1]);
|
||||||
|
}
|
||||||
|
|
||||||
private void testWithQuantileHisto(DMatrix trainingSet, Map<String, DMatrix> watches, int round,
|
private void testWithQuantileHisto(DMatrix trainingSet, Map<String, DMatrix> watches, int round,
|
||||||
Map<String, Object> paramMap, float threshold) throws XGBoostError {
|
Map<String, Object> paramMap, float threshold) throws XGBoostError {
|
||||||
float[][] metrics = new float[watches.size()][round];
|
float[][] metrics = new float[watches.size()][round];
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user