[breakinig] [jvm-packages] change DeviceQuantileDmatrix into QuantileDMatrix (#8461)

This commit is contained in:
Bobby Wang
2022-12-05 12:23:21 +08:00
committed by GitHub
parent 78d65a1928
commit f1e9bbcee5
13 changed files with 150 additions and 94 deletions

View File

@@ -71,7 +71,7 @@ public abstract class ColumnBatch implements AutoCloseable {
/**
* Get the cuda array interface of the label columns.
* The returned value must not be null or empty if we're creating
* {@link DeviceQuantileDMatrix#DeviceQuantileDMatrix(Iterator, float, int, int)}
* {@link QuantileDMatrix#QuantileDMatrix(Iterator, float, int, int)}
*/
public abstract String getLabelsArrayInterface();

View File

@@ -1,68 +0,0 @@
package ml.dmlc.xgboost4j.java;
import java.util.Iterator;
/**
* DeviceQuantileDMatrix will only be used to train
*/
public class DeviceQuantileDMatrix extends DMatrix {
/**
* Create DeviceQuantileDMatrix from iterator based on the cuda array interface
* @param iter the XGBoost ColumnBatch batch to provide the corresponding cuda array interface
* @param missing the missing value
* @param maxBin the max bin
* @param nthread the parallelism
* @throws XGBoostError
*/
public DeviceQuantileDMatrix(
Iterator<ColumnBatch> iter,
float missing,
int maxBin,
int nthread) throws XGBoostError {
super(0);
long[] out = new long[1];
XGBoostJNI.checkCall(XGBoostJNI.XGDeviceQuantileDMatrixCreateFromCallback(
iter, missing, maxBin, nthread, out));
handle = out[0];
}
@Override
public void setLabel(Column column) throws XGBoostError {
throw new XGBoostError("DeviceQuantileDMatrix does not support setLabel.");
}
@Override
public void setWeight(Column column) throws XGBoostError {
throw new XGBoostError("DeviceQuantileDMatrix does not support setWeight.");
}
@Override
public void setBaseMargin(Column column) throws XGBoostError {
throw new XGBoostError("DeviceQuantileDMatrix does not support setBaseMargin.");
}
@Override
public void setLabel(float[] labels) throws XGBoostError {
throw new XGBoostError("DeviceQuantileDMatrix does not support setLabel.");
}
@Override
public void setWeight(float[] weights) throws XGBoostError {
throw new XGBoostError("DeviceQuantileDMatrix does not support setWeight.");
}
@Override
public void setBaseMargin(float[] baseMargin) throws XGBoostError {
throw new XGBoostError("DeviceQuantileDMatrix does not support setBaseMargin.");
}
@Override
public void setBaseMargin(float[][] baseMargin) throws XGBoostError {
throw new XGBoostError("DeviceQuantileDMatrix does not support setBaseMargin.");
}
@Override
public void setGroup(int[] group) throws XGBoostError {
throw new XGBoostError("DeviceQuantileDMatrix does not support setGroup.");
}
}

View File

@@ -0,0 +1,75 @@
package ml.dmlc.xgboost4j.java;
import java.util.Iterator;
/**
* QuantileDMatrix will only be used to train
*/
public class QuantileDMatrix extends DMatrix {
/**
* Create QuantileDMatrix from iterator based on the cuda array interface
*
* @param iter the XGBoost ColumnBatch batch to provide the corresponding cuda array interface
* @param missing the missing value
* @param maxBin the max bin
* @param nthread the parallelism
* @throws XGBoostError
*/
public QuantileDMatrix(
Iterator<ColumnBatch> iter,
float missing,
int maxBin,
int nthread) throws XGBoostError {
super(0);
long[] out = new long[1];
String conf = getConfig(missing, maxBin, nthread);
XGBoostJNI.checkCall(XGBoostJNI.XGQuantileDMatrixCreateFromCallback(
iter, (java.util.Iterator<ColumnBatch>)null, conf, out));
handle = out[0];
}
@Override
public void setLabel(Column column) throws XGBoostError {
throw new XGBoostError("QuantileDMatrix does not support setLabel.");
}
@Override
public void setWeight(Column column) throws XGBoostError {
throw new XGBoostError("QuantileDMatrix does not support setWeight.");
}
@Override
public void setBaseMargin(Column column) throws XGBoostError {
throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.");
}
@Override
public void setLabel(float[] labels) throws XGBoostError {
throw new XGBoostError("QuantileDMatrix does not support setLabel.");
}
@Override
public void setWeight(float[] weights) throws XGBoostError {
throw new XGBoostError("QuantileDMatrix does not support setWeight.");
}
@Override
public void setBaseMargin(float[] baseMargin) throws XGBoostError {
throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.");
}
@Override
public void setBaseMargin(float[][] baseMargin) throws XGBoostError {
throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.");
}
@Override
public void setGroup(int[] group) throws XGBoostError {
throw new XGBoostError("QuantileDMatrix does not support setGroup.");
}
private String getConfig(float missing, int maxBin, int nthread) {
return String.format("{\"missing\":%f,\"max_bin\":%d,\"nthread\":%d}",
missing, maxBin, nthread);
}
}

View File

@@ -149,9 +149,13 @@ class XGBoostJNI {
public final static native int XGDMatrixSetInfoFromInterface(
long handle, String field, String json);
@Deprecated
public final static native int XGDeviceQuantileDMatrixCreateFromCallback(
java.util.Iterator<ColumnBatch> iter, float missing, int nthread, int maxBin, long[] out);
public final static native int XGQuantileDMatrixCreateFromCallback(
java.util.Iterator<ColumnBatch> iter, java.util.Iterator<ColumnBatch> ref, String config, long[] out);
public final static native int XGDMatrixCreateFromArrayInterfaceColumns(
String featureJson, float missing, int nthread, long[] out);

View File

@@ -18,13 +18,13 @@ package ml.dmlc.xgboost4j.scala
import _root_.scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.java.{Column, ColumnBatch, XGBoostError, DeviceQuantileDMatrix => JDeviceQuantileDMatrix}
import ml.dmlc.xgboost4j.java.{Column, ColumnBatch, XGBoostError, QuantileDMatrix => JQuantileDMatrix}
class DeviceQuantileDMatrix private[scala](
private[scala] override val jDMatrix: JDeviceQuantileDMatrix) extends DMatrix(jDMatrix) {
class QuantileDMatrix private[scala](
private[scala] override val jDMatrix: JQuantileDMatrix) extends DMatrix(jDMatrix) {
/**
* Create DeviceQuantileDMatrix from iterator based on the cuda array interface
* Create QuantileDMatrix from iterator based on the cuda array interface
*
* @param iter the XGBoost ColumnBatch batch to provide the corresponding cuda array interface
* @param missing the missing value
@@ -33,7 +33,7 @@ class DeviceQuantileDMatrix private[scala](
* @throws XGBoostError
*/
def this(iter: Iterator[ColumnBatch], missing: Float, maxBin: Int, nthread: Int) {
this(new JDeviceQuantileDMatrix(iter.asJava, missing, maxBin, nthread))
this(new JQuantileDMatrix(iter.asJava, missing, maxBin, nthread))
}
/**
@@ -43,7 +43,7 @@ class DeviceQuantileDMatrix private[scala](
*/
@throws(classOf[XGBoostError])
override def setLabel(labels: Array[Float]): Unit =
throw new XGBoostError("DeviceQuantileDMatrix does not support setLabel.")
throw new XGBoostError("QuantileDMatrix does not support setLabel.")
/**
* set weight of each instance
@@ -52,7 +52,7 @@ class DeviceQuantileDMatrix private[scala](
*/
@throws(classOf[XGBoostError])
override def setWeight(weights: Array[Float]): Unit =
throw new XGBoostError("DeviceQuantileDMatrix does not support setWeight.")
throw new XGBoostError("QuantileDMatrix does not support setWeight.")
/**
* if specified, xgboost will start from this init margin
@@ -62,7 +62,7 @@ class DeviceQuantileDMatrix private[scala](
*/
@throws(classOf[XGBoostError])
override def setBaseMargin(baseMargin: Array[Float]): Unit =
throw new XGBoostError("DeviceQuantileDMatrix does not support setBaseMargin.")
throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.")
/**
* if specified, xgboost will start from this init margin
@@ -72,7 +72,7 @@ class DeviceQuantileDMatrix private[scala](
*/
@throws(classOf[XGBoostError])
override def setBaseMargin(baseMargin: Array[Array[Float]]): Unit =
throw new XGBoostError("DeviceQuantileDMatrix does not support setBaseMargin.")
throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.")
/**
* Set group sizes of DMatrix (used for ranking)
@@ -81,27 +81,27 @@ class DeviceQuantileDMatrix private[scala](
*/
@throws(classOf[XGBoostError])
override def setGroup(group: Array[Int]): Unit =
throw new XGBoostError("DeviceQuantileDMatrix does not support setGroup.")
throw new XGBoostError("QuantileDMatrix does not support setGroup.")
/**
* Set label of DMatrix from cuda array interface
*/
@throws(classOf[XGBoostError])
override def setLabel(column: Column): Unit =
throw new XGBoostError("DeviceQuantileDMatrix does not support setLabel.")
throw new XGBoostError("QuantileDMatrix does not support setLabel.")
/**
* set weight of dmatrix from column array interface
*/
@throws(classOf[XGBoostError])
override def setWeight(column: Column): Unit =
throw new XGBoostError("DeviceQuantileDMatrix does not support setWeight.")
throw new XGBoostError("QuantileDMatrix does not support setWeight.")
/**
* set base margin of dmatrix from column array interface
*/
@throws(classOf[XGBoostError])
override def setBaseMargin(column: Column): Unit =
throw new XGBoostError("DeviceQuantileDMatrix does not support setBaseMargin.")
throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.")
}

View File

@@ -962,6 +962,9 @@ namespace jni {
jfloat jmissing,
jint jmax_bin, jint jnthread,
jlongArray jout);
XGB_DLL int XGQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
jobject jdata_iter, jobject jref_iter,
char const *config, jlongArray jout);
} // namespace jni
} // namespace xgboost
@@ -977,6 +980,18 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDeviceQuantileDM
jenv, jcls, jiter, jmissing, jmax_bin, jnthread, jout);
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGQuantileDMatrixCreateFromCallback
* Signature: (Ljava/util/Iterator;Ljava/util/Iterator;Ljava/lang/String;[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGQuantileDMatrixCreateFromCallback
(JNIEnv *jenv, jclass jcls, jobject jdata_iter, jobject jref_iter, jstring jconf, jlongArray jout) {
char const *conf = jenv->GetStringUTFChars(jconf, 0);
return xgboost::jni::XGQuantileDMatrixCreateFromCallbackImpl(jenv, jcls, jdata_iter, jref_iter,
conf, jout);
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixSetInfoFromInterface

View File

@@ -345,6 +345,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDeviceQuantileDM
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGQuantileDMatrixCreateFromCallback
* Signature: (Ljava/util/Iterator;Ljava/util/Iterator;Ljava/lang/String;[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGQuantileDMatrixCreateFromCallback
(JNIEnv *, jclass, jobject, jobject, jstring, jlongArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixCreateFromArrayInterfaceColumns
* Signature: (Ljava/lang/String;FI[J)I
*/