[jvm-packages] Create demo and test for xgboost4j early stopping. (#7252)
parent 0ee11dac77
commit fbd58bf190
@@ -162,17 +162,17 @@ Example of setting a missing value (e.g. -999) to the "missing" parameter in XGB
doing this with missing values encoded as NaN, you will want to set ``setHandleInvalid = "keep"`` on VectorAssembler
in order to keep the NaN values in the dataset. You would then set the "missing" parameter to whatever you want to be
treated as missing. However, this may cause a large amount of memory use if your dataset is very sparse. For example:

.. code-block:: scala

  val assembler = new VectorAssembler().setInputCols(feature_names.toArray).setOutputCol("features").setHandleInvalid("keep")

  // conversion to dense vector using Array()
  val featurePipeline = new Pipeline().setStages(Array(assembler))
  val featureModel = featurePipeline.fit(df_training)
  val featureDf = featureModel.transform(df_training)

  val xgbParam = Map("eta" -> 0.1f,
      "max_depth" -> 2,
      "objective" -> "multi:softprob",
@@ -181,10 +181,10 @@ Example of setting a missing value (e.g. -999) to the "missing" parameter in XGB
      "num_workers" -> 2,
      "allow_non_zero_for_missing" -> "true",
      "missing" -> -999)

  val xgb = new XGBoostClassifier(xgbParam)
  val xgbclassifier = xgb.fit(featureDf)

2. Before calling VectorAssembler you can transform the values you want to represent missing into an irregular value
that is not 0, NaN, or Null and set the "missing" parameter to 0. The irregular value should ideally be chosen to be
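As a rough sketch of the substitution step in this second approach (assuming Double-typed feature columns and the same ``df_training`` and ``feature_names`` as in the snippet above; the placeholder value is hypothetical, and the "missing" parameter is then set as described):

.. code-block:: scala

  import org.apache.spark.sql.functions.{col, when}

  // Hypothetical placeholder: any value that is not 0, NaN, or null and cannot
  // occur as a real feature value in this dataset.
  val irregularValue = -999999.0

  // Replace null/NaN entries with the placeholder before assembling, so that
  // VectorAssembler never sees values it cannot handle.
  val cleanedDf = feature_names.foldLeft(df_training) { (df, c) =>
    df.withColumn(c, when(col(c).isNull || col(c).isNaN, irregularValue).otherwise(col(c)))
  }

  val assembler = new VectorAssembler()
    .setInputCols(feature_names.toArray)
    .setOutputCol("features")
  val featureDf = assembler.transform(cleanedDf)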
@@ -10,6 +10,7 @@ XGBoost4J Code Examples
* [Cross validation](src/main/java/ml/dmlc/xgboost4j/java/example/CrossValidation.java)
* [Predicting leaf indices](src/main/java/ml/dmlc/xgboost4j/java/example/PredictLeafIndices.java)
* [External Memory](src/main/java/ml/dmlc/xgboost4j/java/example/ExternalMemory.java)
* [Early Stopping](src/main/java/ml/dmlc/xgboost4j/java/example/EarlyStopping.java)

## Scala API
@@ -1,5 +1,5 @@
/*
- Copyright (c) 2014 by Contributors
+ Copyright (c) 2014-2021 by Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -115,7 +115,7 @@ public class BasicWalkThrough {
DataLoader.CSRSparseData spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train");

DMatrix trainMat2 = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data,
-   DMatrix.SparseType.CSR);
+   DMatrix.SparseType.CSR, 127);
trainMat2.setLabel(spData.labels);

//specify watchList
@@ -0,0 +1,67 @@
/*
 Copyright (c) 2021 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */
package ml.dmlc.xgboost4j.java.example;

import java.io.IOException;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;

import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;
import ml.dmlc.xgboost4j.java.example.util.DataLoader;

public class EarlyStopping {
  public static void main(String[] args) throws IOException, XGBoostError {
    DataLoader.CSRSparseData trainCSR =
        DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train");
    DataLoader.CSRSparseData testCSR =
        DataLoader.loadSVMFile("../../demo/data/agaricus.txt.test");

    Map<String, Object> paramMap = new HashMap<String, Object>() {
      {
        put("max_depth", 3);
        put("objective", "binary:logistic");
        put("maximize_evaluation_metrics", "false");
      }
    };

    DMatrix trainXy = new DMatrix(trainCSR.rowHeaders, trainCSR.colIndex, trainCSR.data,
        DMatrix.SparseType.CSR, 127);
    trainXy.setLabel(trainCSR.labels);
    DMatrix testXy = new DMatrix(testCSR.rowHeaders, testCSR.colIndex, testCSR.data,
        DMatrix.SparseType.CSR, 127);
    testXy.setLabel(testCSR.labels);

    int nRounds = 128;
    int nEarlyStoppingRounds = 4;

    Map<String, DMatrix> watches = new LinkedHashMap<>();
    watches.put("training", trainXy);
    watches.put("test", testXy);

    float[][] metrics = new float[watches.size()][nRounds];
    Booster booster = XGBoost.train(trainXy, paramMap, nRounds,
        watches, metrics, null, null, nEarlyStoppingRounds);

    int bestIter = Integer.valueOf(booster.getAttr("best_iteration"));
    float bestScore = Float.valueOf(booster.getAttr("best_score"));

    System.out.printf("Best iter: %d, Best score: %f\n", bestIter, bestScore);
  }
}
@@ -16,8 +16,6 @@
package ml.dmlc.xgboost4j.java;

import java.io.*;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
@@ -347,6 +345,34 @@ public class BoosterImplTest {
    }
  }

  @Test
  public void testEarlyStoppingAttributes() throws XGBoostError, IOException {
    DMatrix trainMat = new DMatrix(this.train_uri);
    DMatrix testMat = new DMatrix(this.test_uri);
    Map<String, Object> paramMap = new HashMap<String, Object>() {
      {
        put("max_depth", 3);
        put("objective", "binary:logistic");
        put("maximize_evaluation_metrics", "false");
      }
    };
    Map<String, DMatrix> watches = new LinkedHashMap<>();
    watches.put("training", trainMat);
    watches.put("test", testMat);

    int round = 30;
    int earlyStoppingRound = 4;
    float[][] metrics = new float[watches.size()][round];

    Booster booster = XGBoost.train(trainMat, paramMap, round,
        watches, metrics, null, null, earlyStoppingRound);

    int bestIter = Integer.valueOf(booster.getAttr("best_iteration"));
    float bestScore = Float.valueOf(booster.getAttr("best_score"));
    TestCase.assertEquals(bestIter, round - 1);
    TestCase.assertEquals(bestScore, metrics[watches.size() - 1][round - 1]);
  }

  private void testWithQuantileHisto(DMatrix trainingSet, Map<String, DMatrix> watches, int round,
                                     Map<String, Object> paramMap, float threshold) throws XGBoostError {
    float[][] metrics = new float[watches.size()][round];