[jvm-packages] support missing value when constructing dmatrix with iterator (#10628)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
Copyright (c) 2014-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -29,21 +29,27 @@ public class DMatrix {
|
||||
protected long handle = 0;
|
||||
|
||||
/**
|
||||
* sparse matrix type (CSR or CSC)
|
||||
* Create DMatrix from iterator.
|
||||
*
|
||||
* @param iter The data iterator of mini batch to provide the data.
|
||||
* @param cacheInfo Cache path information, used for external memory setting, can be null.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public static enum SparseType {
|
||||
CSR,
|
||||
CSC;
|
||||
public DMatrix(Iterator<LabeledPoint> iter, String cacheInfo) throws XGBoostError {
|
||||
this(iter, cacheInfo, Float.NaN);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create DMatrix from iterator.
|
||||
*
|
||||
* @param iter The data iterator of mini batch to provide the data.
|
||||
* @param iter The data iterator of mini batch to provide the data.
|
||||
* @param cacheInfo Cache path information, used for external memory setting, can be null.
|
||||
* @param missing the missing value
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public DMatrix(Iterator<LabeledPoint> iter, String cacheInfo) throws XGBoostError {
|
||||
public DMatrix(Iterator<LabeledPoint> iter,
|
||||
String cacheInfo,
|
||||
float missing) throws XGBoostError {
|
||||
if (iter == null) {
|
||||
throw new NullPointerException("iter: null");
|
||||
}
|
||||
@@ -51,7 +57,8 @@ public class DMatrix {
|
||||
int batchSize = 32 << 10;
|
||||
Iterator<DataBatch> batchIter = new DataBatch.BatchIterator(iter, batchSize);
|
||||
long[] out = new long[1];
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromDataIter(batchIter, cacheInfo, out));
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromDataIter(
|
||||
batchIter, cacheInfo, missing, out));
|
||||
handle = out[0];
|
||||
}
|
||||
|
||||
@@ -72,10 +79,11 @@ public class DMatrix {
|
||||
|
||||
/**
|
||||
* Create DMatrix from Sparse matrix in CSR/CSC format.
|
||||
*
|
||||
* @param headers The row index of the matrix.
|
||||
* @param indices The indices of presenting entries.
|
||||
* @param data The data content.
|
||||
* @param st Type of sparsity.
|
||||
* @param data The data content.
|
||||
* @param st Type of sparsity.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
@Deprecated
|
||||
@@ -86,12 +94,13 @@ public class DMatrix {
|
||||
|
||||
/**
|
||||
* Create DMatrix from Sparse matrix in CSR/CSC format.
|
||||
* @param headers The row index of the matrix.
|
||||
* @param indices The indices of presenting entries.
|
||||
* @param data The data content.
|
||||
* @param st Type of sparsity.
|
||||
* @param shapeParam when st is CSR, it specifies the column number, otherwise it is taken as
|
||||
* row number
|
||||
*
|
||||
* @param headers The row index of the matrix.
|
||||
* @param indices The indices of presenting entries.
|
||||
* @param data The data content.
|
||||
* @param st Type of sparsity.
|
||||
* @param shapeParam when st is CSR, it specifies the column number, otherwise it is taken as
|
||||
* row number
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public DMatrix(long[] headers, int[] indices, float[] data, DMatrix.SparseType st,
|
||||
@@ -121,7 +130,6 @@ public class DMatrix {
|
||||
* @param nrow number of rows
|
||||
* @param ncol number of columns
|
||||
* @throws XGBoostError native error
|
||||
*
|
||||
* @deprecated Please specify the missing value explicitly using
|
||||
* {@link DMatrix(float[], int, int, float)}
|
||||
*/
|
||||
@@ -144,9 +152,10 @@ public class DMatrix {
|
||||
|
||||
/**
|
||||
* create DMatrix from dense matrix
|
||||
* @param data data values
|
||||
* @param nrow number of rows
|
||||
* @param ncol number of columns
|
||||
*
|
||||
* @param data data values
|
||||
* @param nrow number of rows
|
||||
* @param ncol number of columns
|
||||
* @param missing the specified value to represent the missing value
|
||||
*/
|
||||
public DMatrix(float[] data, int nrow, int ncol, float missing) throws XGBoostError {
|
||||
@@ -157,13 +166,14 @@ public class DMatrix {
|
||||
|
||||
/**
|
||||
* create DMatrix from dense matrix
|
||||
* @param matrix instance of BigDenseMatrix
|
||||
*
|
||||
* @param matrix instance of BigDenseMatrix
|
||||
* @param missing the specified value to represent the missing value
|
||||
*/
|
||||
public DMatrix(BigDenseMatrix matrix, float missing) throws XGBoostError {
|
||||
long[] out = new long[1];
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromMatRef(matrix.address, matrix.nrow,
|
||||
matrix.ncol, missing, out));
|
||||
matrix.ncol, missing, out));
|
||||
handle = out[0];
|
||||
}
|
||||
|
||||
@@ -176,10 +186,11 @@ public class DMatrix {
|
||||
|
||||
/**
|
||||
* Create the normal DMatrix from column array interface
|
||||
* @param columnBatch the XGBoost ColumnBatch to provide the cuda array interface
|
||||
*
|
||||
* @param columnBatch the XGBoost ColumnBatch to provide the array interface
|
||||
* of feature columns
|
||||
* @param missing missing value
|
||||
* @param nthread threads number
|
||||
* @param missing missing value
|
||||
* @param nthread threads number
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public DMatrix(ColumnBatch columnBatch, float missing, int nthread) throws XGBoostError {
|
||||
@@ -194,36 +205,30 @@ public class DMatrix {
|
||||
}
|
||||
|
||||
/**
|
||||
* Set label of DMatrix from cuda array interface
|
||||
*
|
||||
* @param column the XGBoost Column to provide the cuda array interface
|
||||
* of label column
|
||||
* @throws XGBoostError native error
|
||||
* flatten a mat to array
|
||||
*/
|
||||
public void setLabel(Column column) throws XGBoostError {
|
||||
setXGBDMatrixInfo("label", column.getArrayInterfaceJson());
|
||||
private static float[] flatten(float[][] mat) {
|
||||
int size = 0;
|
||||
for (float[] array : mat) size += array.length;
|
||||
float[] result = new float[size];
|
||||
int pos = 0;
|
||||
for (float[] ar : mat) {
|
||||
System.arraycopy(ar, 0, result, pos, ar.length);
|
||||
pos += ar.length;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set weight of DMatrix from cuda array interface
|
||||
* Set query id of DMatrix from array interface
|
||||
*
|
||||
* @param column the XGBoost Column to provide the cuda array interface
|
||||
* of weight column
|
||||
* @param column the XGBoost Column to provide the array interface
|
||||
* of query id column
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void setWeight(Column column) throws XGBoostError {
|
||||
setXGBDMatrixInfo("weight", column.getArrayInterfaceJson());
|
||||
}
|
||||
|
||||
/**
|
||||
* Set base margin of DMatrix from cuda array interface
|
||||
*
|
||||
* @param column the XGBoost Column to provide the cuda array interface
|
||||
* of base margin column
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void setBaseMargin(Column column) throws XGBoostError {
|
||||
setXGBDMatrixInfo("base_margin", column.getArrayInterfaceJson());
|
||||
public void setQueryId(Column column) throws XGBoostError {
|
||||
setXGBDMatrixInfo("qid", column.getArrayInterfaceJson());
|
||||
}
|
||||
|
||||
private void setXGBDMatrixInfo(String type, String json) throws XGBoostError {
|
||||
@@ -257,17 +262,9 @@ public class DMatrix {
|
||||
return outValue[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* Set feature names
|
||||
* @param values feature names to be set
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public void setFeatureNames(String[] values) throws XGBoostError {
|
||||
setXGBDMatrixFeatureInfo("feature_name", values);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get feature names
|
||||
*
|
||||
* @return an array of feature names to be returned
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
@@ -276,16 +273,18 @@ public class DMatrix {
|
||||
}
|
||||
|
||||
/**
|
||||
* Set feature types
|
||||
* @param values feature types to be set
|
||||
* Set feature names
|
||||
*
|
||||
* @param values feature names to be set
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public void setFeatureTypes(String[] values) throws XGBoostError {
|
||||
setXGBDMatrixFeatureInfo("feature_type", values);
|
||||
public void setFeatureNames(String[] values) throws XGBoostError {
|
||||
setXGBDMatrixFeatureInfo("feature_name", values);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get feature types
|
||||
*
|
||||
* @return an array of feature types to be returned
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
@@ -294,46 +293,23 @@ public class DMatrix {
|
||||
}
|
||||
|
||||
/**
|
||||
* set label of dmatrix
|
||||
* Set feature types
|
||||
*
|
||||
* @param labels labels
|
||||
* @param values feature types to be set
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public void setFeatureTypes(String[] values) throws XGBoostError {
|
||||
setXGBDMatrixFeatureInfo("feature_type", values);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get group sizes of DMatrix
|
||||
*
|
||||
* @return group size as array
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void setLabel(float[] labels) throws XGBoostError {
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(handle, "label", labels));
|
||||
}
|
||||
|
||||
/**
|
||||
* set weight of each instance
|
||||
*
|
||||
* @param weights weights
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void setWeight(float[] weights) throws XGBoostError {
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(handle, "weight", weights));
|
||||
}
|
||||
|
||||
/**
|
||||
* Set base margin (initial prediction).
|
||||
*
|
||||
* The margin must have the same number of elements as the number of
|
||||
* rows in this matrix.
|
||||
*/
|
||||
public void setBaseMargin(float[] baseMargin) throws XGBoostError {
|
||||
if (baseMargin.length != rowNum()) {
|
||||
throw new IllegalArgumentException(String.format(
|
||||
"base margin must have exactly %s elements, got %s",
|
||||
rowNum(), baseMargin.length));
|
||||
}
|
||||
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(handle, "base_margin", baseMargin));
|
||||
}
|
||||
|
||||
/**
|
||||
* Set base margin (initial prediction).
|
||||
*/
|
||||
public void setBaseMargin(float[][] baseMargin) throws XGBoostError {
|
||||
setBaseMargin(flatten(baseMargin));
|
||||
public int[] getGroup() throws XGBoostError {
|
||||
return getIntInfo("group_ptr");
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -347,13 +323,13 @@ public class DMatrix {
|
||||
}
|
||||
|
||||
/**
|
||||
* Get group sizes of DMatrix
|
||||
* Set query ids (used for ranking)
|
||||
*
|
||||
* @param qid the query ids
|
||||
* @throws XGBoostError native error
|
||||
* @return group size as array
|
||||
*/
|
||||
public int[] getGroup() throws XGBoostError {
|
||||
return getIntInfo("group_ptr");
|
||||
public void setQueryId(int[] qid) throws XGBoostError {
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetUIntInfo(handle, "qid", qid));
|
||||
}
|
||||
|
||||
private float[] getFloatInfo(String field) throws XGBoostError {
|
||||
@@ -378,6 +354,27 @@ public class DMatrix {
|
||||
return getFloatInfo("label");
|
||||
}
|
||||
|
||||
/**
|
||||
* Set label of DMatrix from array interface
|
||||
*
|
||||
* @param column the XGBoost Column to provide the array interface
|
||||
* of label column
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void setLabel(Column column) throws XGBoostError {
|
||||
setXGBDMatrixInfo("label", column.getArrayInterfaceJson());
|
||||
}
|
||||
|
||||
/**
|
||||
* set label of dmatrix
|
||||
*
|
||||
* @param labels labels
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void setLabel(float[] labels) throws XGBoostError {
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(handle, "label", labels));
|
||||
}
|
||||
|
||||
/**
|
||||
* get weight of the DMatrix
|
||||
*
|
||||
@@ -388,6 +385,27 @@ public class DMatrix {
|
||||
return getFloatInfo("weight");
|
||||
}
|
||||
|
||||
/**
|
||||
* Set weight of DMatrix from array interface
|
||||
*
|
||||
* @param column the XGBoost Column to provide the array interface
|
||||
* of weight column
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void setWeight(Column column) throws XGBoostError {
|
||||
setXGBDMatrixInfo("weight", column.getArrayInterfaceJson());
|
||||
}
|
||||
|
||||
/**
|
||||
* set weight of each instance
|
||||
*
|
||||
* @param weights weights
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void setWeight(float[] weights) throws XGBoostError {
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(handle, "weight", weights));
|
||||
}
|
||||
|
||||
/**
|
||||
* Get base margin of the DMatrix.
|
||||
*/
|
||||
@@ -395,6 +413,40 @@ public class DMatrix {
|
||||
return getFloatInfo("base_margin");
|
||||
}
|
||||
|
||||
/**
|
||||
* Set base margin of DMatrix from array interface
|
||||
*
|
||||
* @param column the XGBoost Column to provide the array interface
|
||||
* of base margin column
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void setBaseMargin(Column column) throws XGBoostError {
|
||||
setXGBDMatrixInfo("base_margin", column.getArrayInterfaceJson());
|
||||
}
|
||||
|
||||
/**
|
||||
* Set base margin (initial prediction).
|
||||
* <p>
|
||||
* The margin must have the same number of elements as the number of
|
||||
* rows in this matrix.
|
||||
*/
|
||||
public void setBaseMargin(float[] baseMargin) throws XGBoostError {
|
||||
if (baseMargin.length != rowNum()) {
|
||||
throw new IllegalArgumentException(String.format(
|
||||
"base margin must have exactly %s elements, got %s",
|
||||
rowNum(), baseMargin.length));
|
||||
}
|
||||
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(handle, "base_margin", baseMargin));
|
||||
}
|
||||
|
||||
/**
|
||||
* Set base margin (initial prediction).
|
||||
*/
|
||||
public void setBaseMargin(float[][] baseMargin) throws XGBoostError {
|
||||
setBaseMargin(flatten(baseMargin));
|
||||
}
|
||||
|
||||
/**
|
||||
* Slice the DMatrix and return a new DMatrix that only contains `rowIndex`.
|
||||
*
|
||||
@@ -448,22 +500,6 @@ public class DMatrix {
|
||||
return handle;
|
||||
}
|
||||
|
||||
/**
|
||||
* flatten a mat to array
|
||||
*/
|
||||
private static float[] flatten(float[][] mat) {
|
||||
int size = 0;
|
||||
for (float[] array : mat) size += array.length;
|
||||
float[] result = new float[size];
|
||||
int pos = 0;
|
||||
for (float[] ar : mat) {
|
||||
System.arraycopy(ar, 0, result, pos, ar.length);
|
||||
pos += ar.length;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void finalize() {
|
||||
dispose();
|
||||
@@ -475,4 +511,12 @@ public class DMatrix {
|
||||
handle = 0;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* sparse matrix type (CSR or CSC)
|
||||
*/
|
||||
public enum SparseType {
|
||||
CSR,
|
||||
CSC
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,7 +54,7 @@ class XGBoostJNI {
|
||||
public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out);
|
||||
|
||||
final static native int XGDMatrixCreateFromDataIter(java.util.Iterator<DataBatch> iter,
|
||||
String cache_info, long[] out);
|
||||
String cache_info, float missing, long[] out);
|
||||
|
||||
public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices,
|
||||
float[] data, int shapeParam,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
Copyright (c) 2014-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -19,7 +19,7 @@ package ml.dmlc.xgboost4j.scala
|
||||
import _root_.scala.collection.JavaConverters._
|
||||
|
||||
import ml.dmlc.xgboost4j.LabeledPoint
|
||||
import ml.dmlc.xgboost4j.java.{Column, ColumnBatch, DataBatch, XGBoostError, DMatrix => JDMatrix}
|
||||
import ml.dmlc.xgboost4j.java.{Column, ColumnBatch, DMatrix => JDMatrix, XGBoostError}
|
||||
|
||||
class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
|
||||
/**
|
||||
@@ -33,14 +33,17 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
|
||||
}
|
||||
|
||||
/**
|
||||
* init DMatrix from Iterator of LabeledPoint
|
||||
*
|
||||
* @param dataIter An iterator of LabeledPoint
|
||||
* @param cacheInfo Cache path information, used for external memory setting, null by default.
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
def this(dataIter: Iterator[LabeledPoint], cacheInfo: String = null) {
|
||||
this(new JDMatrix(dataIter.asJava, cacheInfo))
|
||||
* init DMatrix from Iterator of LabeledPoint
|
||||
*
|
||||
* @param dataIter An iterator of LabeledPoint
|
||||
* @param cacheInfo Cache path information, used for external memory setting, null by default.
|
||||
* @param missing Which value will be treated as the missing value
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
def this(dataIter: Iterator[LabeledPoint],
|
||||
cacheInfo: String = null,
|
||||
missing: Float = Float.NaN) {
|
||||
this(new JDMatrix(dataIter.asJava, cacheInfo, missing))
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -60,12 +63,12 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
|
||||
/**
|
||||
* create DMatrix from sparse matrix
|
||||
*
|
||||
* @param headers index to headers (rowHeaders for CSR or colHeaders for CSC)
|
||||
* @param indices Indices (colIndexs for CSR or rowIndexs for CSC)
|
||||
* @param data non zero values (sequence by row for CSR or by col for CSC)
|
||||
* @param st sparse matrix type (CSR or CSC)
|
||||
* @param headers index to headers (rowHeaders for CSR or colHeaders for CSC)
|
||||
* @param indices Indices (colIndexs for CSR or rowIndexs for CSC)
|
||||
* @param data non zero values (sequence by row for CSR or by col for CSC)
|
||||
* @param st sparse matrix type (CSR or CSC)
|
||||
* @param shapeParam when st is CSR, it specifies the column number, otherwise it is taken as
|
||||
* row number
|
||||
* row number
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def this(headers: Array[Long], indices: Array[Int], data: Array[Float], st: JDMatrix.SparseType,
|
||||
@@ -76,14 +79,14 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
|
||||
/**
|
||||
* create DMatrix from sparse matrix
|
||||
*
|
||||
* @param headers index to headers (rowHeaders for CSR or colHeaders for CSC)
|
||||
* @param indices Indices (colIndexs for CSR or rowIndexs for CSC)
|
||||
* @param data non zero values (sequence by row for CSR or by col for CSC)
|
||||
* @param st sparse matrix type (CSR or CSC)
|
||||
* @param headers index to headers (rowHeaders for CSR or colHeaders for CSC)
|
||||
* @param indices Indices (colIndexs for CSR or rowIndexs for CSC)
|
||||
* @param data non zero values (sequence by row for CSR or by col for CSC)
|
||||
* @param st sparse matrix type (CSR or CSC)
|
||||
* @param shapeParam when st is CSR, it specifies the column number, otherwise it is taken as
|
||||
* row number
|
||||
* @param missing missing value
|
||||
* @param nthread The number of threads used for constructing DMatrix
|
||||
* row number
|
||||
* @param missing missing value
|
||||
* @param nthread The number of threads used for constructing DMatrix
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def this(headers: Array[Long], indices: Array[Int], data: Array[Float], st: JDMatrix.SparseType,
|
||||
@@ -93,10 +96,11 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
|
||||
|
||||
/**
|
||||
* Create the normal DMatrix from column array interface
|
||||
*
|
||||
* @param columnBatch the XGBoost ColumnBatch to provide the cuda array interface
|
||||
* of feature columns
|
||||
* @param missing missing value
|
||||
* @param nthread The number of threads used for constructing DMatrix
|
||||
* @param missing missing value
|
||||
* @param nthread The number of threads used for constructing DMatrix
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def this(columnBatch: ColumnBatch, missing: Float, nthread: Int) {
|
||||
@@ -119,9 +123,9 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
|
||||
/**
|
||||
* create DMatrix from dense matrix
|
||||
*
|
||||
* @param data data values
|
||||
* @param nrow number of rows
|
||||
* @param ncol number of columns
|
||||
* @param data data values
|
||||
* @param nrow number of rows
|
||||
* @param ncol number of columns
|
||||
* @param missing the specified value to represent the missing value
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
@@ -181,6 +185,16 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
|
||||
jDMatrix.setGroup(group)
|
||||
}
|
||||
|
||||
/**
|
||||
* Set query ids (used for ranking)
|
||||
*
|
||||
* @param qid query ids
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def setQueryId(qid: Array[Int]): Unit = {
|
||||
jDMatrix.setQueryId(qid)
|
||||
}
|
||||
|
||||
/**
|
||||
* Set label of DMatrix from cuda array interface
|
||||
*/
|
||||
@@ -205,8 +219,17 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
|
||||
jDMatrix.setBaseMargin(column)
|
||||
}
|
||||
|
||||
/**
|
||||
* set query id of dmatrix from column array interface
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def setQueryId(column: Column): Unit = {
|
||||
jDMatrix.setQueryId(column)
|
||||
}
|
||||
|
||||
/**
|
||||
* set feature names
|
||||
*
|
||||
* @param values feature names
|
||||
* @throws ml.dmlc.xgboost4j.java.XGBoostError
|
||||
*/
|
||||
@@ -217,6 +240,7 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
|
||||
|
||||
/**
|
||||
* set feature types
|
||||
*
|
||||
* @param values feature types
|
||||
* @throws ml.dmlc.xgboost4j.java.XGBoostError
|
||||
*/
|
||||
@@ -265,6 +289,7 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
|
||||
|
||||
/**
|
||||
* get feature names
|
||||
*
|
||||
* @throws ml.dmlc.xgboost4j.java.XGBoostError
|
||||
* @return
|
||||
*/
|
||||
@@ -275,6 +300,7 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
|
||||
|
||||
/**
|
||||
* get feature types
|
||||
*
|
||||
* @throws ml.dmlc.xgboost4j.java.XGBoostError
|
||||
* @return
|
||||
*/
|
||||
|
||||
Reference in New Issue
Block a user