Bring XGBoost4J Intro up-to-date (#3574)
This commit is contained in:
parent
2e7c3a0ed5
commit
9c647d8130
@ -6,15 +6,15 @@ This tutorial introduces Java API for XGBoost.
|
||||
**************
|
||||
Data Interface
|
||||
**************
|
||||
Like the XGBoost python module, XGBoost4J uses DMatrix to handle data,
|
||||
LIBSVM txt format file, sparse matrix in CSR/CSC format, and dense matrix is
|
||||
Like the XGBoost python module, XGBoost4J uses DMatrix to handle data.
|
||||
LIBSVM txt format file, sparse matrix in CSR/CSC format, and dense matrix are
|
||||
supported.
|
||||
|
||||
* The first step is to import DMatrix:
|
||||
|
||||
.. code-block:: java
|
||||
|
||||
import org.dmlc.xgboost4j.java.DMatrix;
|
||||
import ml.dmlc.xgboost4j.java.DMatrix;
|
||||
|
||||
* Use DMatrix constructor to load data from a libsvm text format file:
|
||||
|
||||
@ -39,7 +39,8 @@ supported.
|
||||
long[] rowHeaders = new long[] {0,2,4,7};
|
||||
float[] data = new float[] {1f,2f,4f,3f,3f,1f,2f};
|
||||
int[] colIndex = new int[] {0,2,0,3,0,1,2};
|
||||
DMatrix dmat = new DMatrix(rowHeaders, colIndex, data, DMatrix.SparseType.CSR);
|
||||
int numColumn = 4;
|
||||
DMatrix dmat = new DMatrix(rowHeaders, colIndex, data, DMatrix.SparseType.CSR, numColumn);
|
||||
|
||||
... or in `Compressed Sparse Column (CSC) <https://en.wikipedia.org/wiki/Sparse_matrix#Compressed_sparse_column_(CSC_or_CCS)>`_ format:
|
||||
|
||||
@ -48,7 +49,8 @@ supported.
|
||||
long[] colHeaders = new long[] {0,3,4,6,7};
|
||||
float[] data = new float[] {1f,4f,3f,1f,2f,2f,3f};
|
||||
int[] rowIndex = new int[] {0,1,2,2,0,2,1};
|
||||
DMatrix dmat = new DMatrix(colHeaders, rowIndex, data, DMatrix.SparseType.CSC);
|
||||
int numRow = 3;
|
||||
DMatrix dmat = new DMatrix(colHeaders, rowIndex, data, DMatrix.SparseType.CSC, numRow);
|
||||
|
||||
* You may also load your data from a dense matrix. Let's assume we have a matrix of form
|
||||
|
||||
@ -66,7 +68,7 @@ supported.
|
||||
int nrow = 3;
|
||||
int ncol = 2;
|
||||
float missing = 0.0f;
|
||||
DMatrix dmat = new Matrix(data, nrow, ncol, missing);
|
||||
DMatrix dmat = new DMatrix(data, nrow, ncol, missing);
|
||||
|
||||
* To set weight:
|
||||
|
||||
@ -82,7 +84,7 @@ To set parameters, parameters are specified as a Map:
|
||||
|
||||
.. code-block:: java
|
||||
|
||||
Map<String, Object> params = new HashMap<>() {
|
||||
Map<String, Object> params = new HashMap<String, Object>() {
|
||||
{
|
||||
put("eta", 1.0);
|
||||
put("max_depth", 2);
|
||||
@ -101,8 +103,8 @@ With parameters and data, you are able to train a booster model.
|
||||
|
||||
.. code-block:: java
|
||||
|
||||
import org.dmlc.xgboost4j.java.Booster;
|
||||
import org.dmlc.xgboost4j.java.XGBoost;
|
||||
import ml.dmlc.xgboost4j.java.Booster;
|
||||
import ml.dmlc.xgboost4j.java.XGBoost;
|
||||
|
||||
* Training
|
||||
|
||||
@ -110,11 +112,13 @@ With parameters and data, you are able to train a booster model.
|
||||
|
||||
DMatrix trainMat = new DMatrix("train.svm.txt");
|
||||
DMatrix validMat = new DMatrix("valid.svm.txt");
|
||||
// Specify a watchList to see the performance
|
||||
// Any Iterable<Entry<String, DMatrix>> object could be used as watchList
|
||||
List<Entry<String, DMatrix>> watches = new ArrayList<>();
|
||||
watches.add(new SimpleEntry<>("train", trainMat));
|
||||
watches.add(new SimpleEntry<>("test", testMat));
|
||||
// Specify a watch list to see model accuracy on data sets
|
||||
Map<String, DMatrix> watches = new HashMap<String, DMatrix>() {
|
||||
{
|
||||
put("train", trainMat);
|
||||
put("test", testMat);
|
||||
}
|
||||
};
|
||||
int nround = 2;
|
||||
Booster booster = XGBoost.train(trainMat, params, nround, watches, null, null);
|
||||
|
||||
@ -130,15 +134,16 @@ With parameters and data, you are able to train a booster model.
|
||||
|
||||
.. code-block:: java
|
||||
|
||||
String[] model_dump = booster.getModelDump(null, false)
|
||||
// dump without feature map
|
||||
String[] model_dump = booster.getModelDump(null, false);
|
||||
// dump with feature map
|
||||
String[] model_dump_with_feature_map = booster.getModelDump("featureMap.txt", false)
|
||||
String[] model_dump_with_feature_map = booster.getModelDump("featureMap.txt", false);
|
||||
|
||||
* Load a model
|
||||
|
||||
.. code-block:: java
|
||||
|
||||
Booster booster = Booster.loadModel("model.bin");
|
||||
Booster booster = XGBoost.loadModel("model.bin");
|
||||
|
||||
**********
|
||||
Prediction
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user