[jvm-packages] [breaking] rework xgboost4j-spark and xgboost4j-spark-gpu (#10639)
- Introduce an abstract XGBoost Estimator - Update to the latest XGBoost parameters - Add all XGBoost parameters supported in XGBoost4j-spark. - Add setter and getter for these parameters. - Remove the deprecated parameters - Address the missing value handling - Remove any ETL operations in XGBoost - Rework the GPU plugin - Expand sanity tests for CPU and GPU consistency
This commit is contained in:
@@ -519,4 +519,39 @@ public class DMatrix {
|
||||
CSR,
|
||||
CSC
|
||||
}
|
||||
|
||||
/**
|
||||
* A class to hold the quantile information
|
||||
*/
|
||||
public class QuantileCut {
  // Cut pointers. The accompanying test treats consecutive entries as the [begin, end)
  // range of one feature's cuts inside `values`. Assigned once, never reassigned.
  final long[] indptr;
  // Cut values for all features, stored back to back. Assigned once, never reassigned.
  final float[] values;

  /**
   * Wrap the given arrays without copying them.
   *
   * @param indptr the cut pointers
   * @param values the cut values
   */
  QuantileCut(long[] indptr, float[] values) {
    this.indptr = indptr;
    this.values = values;
  }

  /**
   * @return the cut pointer array held by this object (the same instance, not a copy)
   */
  public long[] getIndptr() {
    return indptr;
  }

  /**
   * @return the cut value array held by this object (the same instance, not a copy)
   */
  public float[] getValues() {
    return values;
  }
}
|
||||
|
||||
/**
|
||||
* Get the Quantile Cut.
|
||||
* @return QuantileCut
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public QuantileCut getQuantileCut() throws XGBoostError {
|
||||
long[][] indptr = new long[1][];
|
||||
float[][] values = new float[1][];
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixGetQuantileCut(this.handle, indptr, values));
|
||||
return new QuantileCut(indptr[0], values[0]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,75 +0,0 @@
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.util.Iterator;
|
||||
|
||||
/**
|
||||
* QuantileDMatrix will only be used to train
|
||||
*/
|
||||
public class QuantileDMatrix extends DMatrix {
|
||||
/**
|
||||
* Create QuantileDMatrix from iterator based on the cuda array interface
|
||||
*
|
||||
* @param iter the XGBoost ColumnBatch batch to provide the corresponding cuda array interface
|
||||
* @param missing the missing value
|
||||
* @param maxBin the max bin
|
||||
* @param nthread the parallelism
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public QuantileDMatrix(
|
||||
Iterator<ColumnBatch> iter,
|
||||
float missing,
|
||||
int maxBin,
|
||||
int nthread) throws XGBoostError {
|
||||
super(0);
|
||||
long[] out = new long[1];
|
||||
String conf = getConfig(missing, maxBin, nthread);
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGQuantileDMatrixCreateFromCallback(
|
||||
iter, (java.util.Iterator<ColumnBatch>)null, conf, out));
|
||||
handle = out[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setLabel(Column column) throws XGBoostError {
|
||||
throw new XGBoostError("QuantileDMatrix does not support setLabel.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setWeight(Column column) throws XGBoostError {
|
||||
throw new XGBoostError("QuantileDMatrix does not support setWeight.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setBaseMargin(Column column) throws XGBoostError {
|
||||
throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setLabel(float[] labels) throws XGBoostError {
|
||||
throw new XGBoostError("QuantileDMatrix does not support setLabel.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setWeight(float[] weights) throws XGBoostError {
|
||||
throw new XGBoostError("QuantileDMatrix does not support setWeight.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setBaseMargin(float[] baseMargin) throws XGBoostError {
|
||||
throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setBaseMargin(float[][] baseMargin) throws XGBoostError {
|
||||
throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setGroup(int[] group) throws XGBoostError {
|
||||
throw new XGBoostError("QuantileDMatrix does not support setGroup.");
|
||||
}
|
||||
|
||||
private String getConfig(float missing, int maxBin, int nthread) {
|
||||
return String.format("{\"missing\":%f,\"max_bin\":%d,\"nthread\":%d}",
|
||||
missing, maxBin, nthread);
|
||||
}
|
||||
}
|
||||
@@ -172,7 +172,7 @@ class XGBoostJNI {
|
||||
long handle, String field, String json);
|
||||
|
||||
public final static native int XGQuantileDMatrixCreateFromCallback(
|
||||
java.util.Iterator<ColumnBatch> iter, java.util.Iterator<ColumnBatch> ref, String config, long[] out);
|
||||
java.util.Iterator<ColumnBatch> iter, long[] ref, String config, long[] out);
|
||||
|
||||
public final static native int XGDMatrixCreateFromArrayInterfaceColumns(
|
||||
String featureJson, float missing, int nthread, long[] out);
|
||||
@@ -180,4 +180,7 @@ class XGBoostJNI {
|
||||
public final static native int XGBoosterSetStrFeatureInfo(long handle, String field, String[] features);
|
||||
|
||||
public final static native int XGBoosterGetStrFeatureInfo(long handle, String field, String[] out);
|
||||
|
||||
/**
 * Native binding that fetches the quantile cuts of the DMatrix identified by {@code handle}.
 * The native side stores a newly created {@code long[]} (cut pointers) in
 * {@code outIndptr[0]} and a newly created {@code float[]} (cut values) in
 * {@code outValues[0]}, so both holders must have at least one slot.
 *
 * @return the status code of the underlying XGDMatrixGetQuantileCut C API call
 */
public final static native int XGDMatrixGetQuantileCut(long handle, long[][] outIndptr, float[][] outValues);
|
||||
|
||||
}
|
||||
|
||||
@@ -365,4 +365,8 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
|
||||
override def read(kryo: Kryo, input: Input): Unit = {
|
||||
booster = kryo.readObject(input, classOf[JBooster])
|
||||
}
|
||||
|
||||
// a flag to indicate if the device is set for the GPU transform
|
||||
var deviceIsSet = false
|
||||
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import _root_.scala.collection.JavaConverters._
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import ml.dmlc.xgboost4j.LabeledPoint
|
||||
import ml.dmlc.xgboost4j.java.{Column, ColumnBatch, DMatrix => JDMatrix, XGBoostError}
|
||||
|
||||
@@ -1,107 +0,0 @@
|
||||
/*
|
||||
Copyright (c) 2021 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import _root_.scala.collection.JavaConverters._
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{Column, ColumnBatch, XGBoostError, QuantileDMatrix => JQuantileDMatrix}
|
||||
|
||||
/**
 * Scala wrapper around the Java QuantileDMatrix. Every meta-info setter overridden here is
 * rejected with an XGBoostError.
 */
class QuantileDMatrix private[scala](
    private[scala] override val jDMatrix: JQuantileDMatrix) extends DMatrix(jDMatrix) {

  /**
   * Create QuantileDMatrix from iterator based on the cuda array interface
   *
   * @param iter    the XGBoost ColumnBatch batch to provide the corresponding cuda array interface
   * @param missing the missing value
   * @param maxBin  the max bin
   * @param nthread the parallelism
   * @throws XGBoostError
   */
  // Written with `=`: procedure syntax for constructors is deprecated in Scala 2.13 and
  // rejected by Scala 3.
  def this(iter: Iterator[ColumnBatch], missing: Float, maxBin: Int, nthread: Int) =
    this(new JQuantileDMatrix(iter.asJava, missing, maxBin, nthread))

  /** Raise the error shared by every setter this class rejects. */
  private def unsupported(operation: String): Nothing =
    throw new XGBoostError("QuantileDMatrix does not support " + operation + ".")

  /**
   * set label of dmatrix
   *
   * @param labels labels
   */
  @throws(classOf[XGBoostError])
  override def setLabel(labels: Array[Float]): Unit = unsupported("setLabel")

  /**
   * set weight of each instance
   *
   * @param weights weights
   */
  @throws(classOf[XGBoostError])
  override def setWeight(weights: Array[Float]): Unit = unsupported("setWeight")

  /**
   * if specified, xgboost will start from this init margin
   * can be used to specify initial prediction to boost from
   *
   * @param baseMargin base margin
   */
  @throws(classOf[XGBoostError])
  override def setBaseMargin(baseMargin: Array[Float]): Unit = unsupported("setBaseMargin")

  /**
   * if specified, xgboost will start from this init margin
   * can be used to specify initial prediction to boost from
   *
   * @param baseMargin base margin
   */
  @throws(classOf[XGBoostError])
  override def setBaseMargin(baseMargin: Array[Array[Float]]): Unit =
    unsupported("setBaseMargin")

  /**
   * Set group sizes of DMatrix (used for ranking)
   *
   * @param group group size as array
   */
  @throws(classOf[XGBoostError])
  override def setGroup(group: Array[Int]): Unit = unsupported("setGroup")

  /**
   * Set label of DMatrix from cuda array interface
   */
  @throws(classOf[XGBoostError])
  override def setLabel(column: Column): Unit = unsupported("setLabel")

  /**
   * set weight of dmatrix from column array interface
   */
  @throws(classOf[XGBoostError])
  override def setWeight(column: Column): Unit = unsupported("setWeight")

  /**
   * set base margin of dmatrix from column array interface
   */
  @throws(classOf[XGBoostError])
  override def setBaseMargin(column: Column): Unit = unsupported("setBaseMargin")

}
|
||||
@@ -1,7 +1,6 @@
|
||||
//
|
||||
// Created by bobwang on 2021/9/8.
|
||||
//
|
||||
|
||||
/**
|
||||
* Copyright 2021-2024, XGBoost Contributors
|
||||
*/
|
||||
#ifndef XGBOOST_USE_CUDA
|
||||
|
||||
#include <jni.h>
|
||||
@@ -21,7 +20,7 @@ XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass j
|
||||
API_END();
|
||||
}
|
||||
XGB_DLL int XGQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
|
||||
jobject jdata_iter, jobject jref_iter,
|
||||
jobject jdata_iter, jlongArray jref,
|
||||
char const *config, jlongArray jout) {
|
||||
API_BEGIN();
|
||||
common::AssertGPUSupport();
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
/**
|
||||
* Copyright 2021-2024, XGBoost Contributors
|
||||
*/
|
||||
#include <jni.h>
|
||||
#include <xgboost/c_api.h>
|
||||
|
||||
#include "../../../../src/common/device_helpers.cuh"
|
||||
#include "../../../../src/common/cuda_pinned_allocator.h"
|
||||
#include "../../../../src/common/device_vector.cuh" // for device_vector
|
||||
#include "../../../../src/data/array_interface.h"
|
||||
#include "jvm_utils.h"
|
||||
#include <xgboost/c_api.h>
|
||||
|
||||
namespace xgboost {
|
||||
namespace jni {
|
||||
@@ -396,6 +399,9 @@ void Reset(DataIterHandle self) {
|
||||
int Next(DataIterHandle self) {
|
||||
return static_cast<xgboost::jni::DataIteratorProxy *>(self)->Next();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
using Deleter = std::function<void(T *)>;
|
||||
} // anonymous namespace
|
||||
|
||||
XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
|
||||
@@ -413,17 +419,23 @@ XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass j
|
||||
}
|
||||
|
||||
XGB_DLL int XGQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
|
||||
jobject jdata_iter, jobject jref_iter,
|
||||
jobject jdata_iter, jlongArray jref,
|
||||
char const *config, jlongArray jout) {
|
||||
xgboost::jni::DataIteratorProxy proxy(jdata_iter);
|
||||
DMatrixHandle result;
|
||||
DMatrixHandle ref{nullptr};
|
||||
|
||||
std::unique_ptr<xgboost::jni::DataIteratorProxy> ref_proxy{nullptr};
|
||||
if (jref_iter) {
|
||||
ref_proxy = std::make_unique<xgboost::jni::DataIteratorProxy>(jref_iter);
|
||||
if (jref != NULL) {
|
||||
std::unique_ptr<jlong, Deleter<jlong>> refptr{jenv->GetLongArrayElements(jref, nullptr),
|
||||
[&](jlong *ptr) {
|
||||
jenv->ReleaseLongArrayElements(jref, ptr, 0);
|
||||
jenv->DeleteLocalRef(jref);
|
||||
}};
|
||||
ref = reinterpret_cast<DMatrixHandle>(refptr.get()[0]);
|
||||
}
|
||||
|
||||
auto ret = XGQuantileDMatrixCreateFromCallback(
|
||||
&proxy, proxy.GetDMatrixHandle(), ref_proxy.get(), Reset, Next, config, &result);
|
||||
&proxy, proxy.GetDMatrixHandle(), ref, Reset, Next, config, &result);
|
||||
setHandle(jenv, jout, result);
|
||||
return ret;
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@
|
||||
#include <xgboost/c_api.h>
|
||||
#include <xgboost/json.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <xgboost/string_view.h> // for StringView
|
||||
|
||||
#include <algorithm> // for copy_n
|
||||
#include <cstddef>
|
||||
@@ -30,8 +31,9 @@
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "../../../src/c_api/c_api_error.h"
|
||||
#include "../../../src/c_api/c_api_utils.h"
|
||||
#include "../../../../src/c_api/c_api_error.h"
|
||||
#include "../../../../src/c_api/c_api_utils.h"
|
||||
#include "../../../../src/data/array_interface.h" // for ArrayInterface
|
||||
|
||||
#define JVM_CHECK_CALL(__expr) \
|
||||
{ \
|
||||
@@ -1330,16 +1332,16 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDeviceQuantileDM
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGQuantileDMatrixCreateFromCallback
|
||||
* Signature: (Ljava/util/Iterator;Ljava/util/Iterator;Ljava/lang/String;[J)I
|
||||
* Signature: (Ljava/util/Iterator;[JLjava/lang/String;[J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGQuantileDMatrixCreateFromCallback(
|
||||
JNIEnv *jenv, jclass jcls, jobject jdata_iter, jobject jref_iter, jstring jconf,
|
||||
JNIEnv *jenv, jclass jcls, jobject jdata_iter, jlongArray jref, jstring jconf,
|
||||
jlongArray jout) {
|
||||
std::unique_ptr<char const, Deleter<char const>> conf{jenv->GetStringUTFChars(jconf, nullptr),
|
||||
[&](char const *ptr) {
|
||||
jenv->ReleaseStringUTFChars(jconf, ptr);
|
||||
}};
|
||||
return xgboost::jni::XGQuantileDMatrixCreateFromCallbackImpl(jenv, jcls, jdata_iter, jref_iter,
|
||||
return xgboost::jni::XGQuantileDMatrixCreateFromCallbackImpl(jenv, jcls, jdata_iter, jref,
|
||||
conf.get(), jout);
|
||||
}
|
||||
|
||||
@@ -1517,3 +1519,44 @@ Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetStrFeatureInfo(
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixGetQuantileCut
|
||||
* Signature: (J[[J[[F)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetQuantileCut(
    JNIEnv *jenv, jclass, jlong jhandle, jobjectArray j_indptr, jobjectArray j_values) {
  // Fetches the quantile cuts of the DMatrix behind `jhandle` and stores them as a new
  // long[] in j_indptr[0] and a new float[] in j_values[0]. Returns the C API status code.
  using namespace xgboost;  // NOLINT
  auto handle = reinterpret_cast<DMatrixHandle>(jhandle);

  char const *str_indptr{nullptr};
  char const *str_data{nullptr};
  Json config{Object{}};
  auto str_config = Json::Dump(config);

  auto ret = XGDMatrixGetQuantileCut(handle, str_config.c_str(), &str_indptr, &str_data);
  if (ret != 0) {
    // On failure the output strings are not populated; parsing them would read invalid
    // pointers. Hand the status back so the Java side can raise the error.
    return ret;
  }

  ArrayInterface<1> indptr{StringView{str_indptr}};
  ArrayInterface<1> data{StringView{str_data}};
  CHECK_GE(indptr.Shape(0), 2);

  // Cut ptr
  auto j_indptr_array = jenv->NewLongArray(indptr.Shape(0));
  if (j_indptr_array == nullptr) {
    // Allocation failed and a Java exception is pending; it is raised once we return.
    return -1;
  }
  CHECK_EQ(indptr.type, ArrayInterfaceHandler::Type::kU8);
  CHECK_LT(indptr(indptr.Shape(0) - 1),
           static_cast<std::uint64_t>(std::numeric_limits<std::int64_t>::max()));
  // The uint64 buffer is copied bit-for-bit into the jlong array, so the sizes must match.
  static_assert(sizeof(jlong) == sizeof(std::uint64_t));
  jenv->SetLongArrayRegion(j_indptr_array, 0, indptr.Shape(0),
                           static_cast<jlong const *>(indptr.data));
  jenv->SetObjectArrayElement(j_indptr, 0, j_indptr_array);

  // Cut values
  // The last cut pointer is used as the number of cut values to copy.
  auto n_cuts = indptr(indptr.Shape(0) - 1);
  jfloatArray jcuts_array = jenv->NewFloatArray(n_cuts);
  if (jcuts_array == nullptr) {
    // Allocation failed and a Java exception is pending; it is raised once we return.
    return -1;
  }
  CHECK_EQ(data.type, ArrayInterfaceHandler::Type::kF4);
  jenv->SetFloatArrayRegion(jcuts_array, 0, n_cuts, static_cast<float const *>(data.data));
  jenv->SetObjectArrayElement(j_values, 0, jcuts_array);

  return ret;
}
|
||||
|
||||
@@ -402,10 +402,10 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetInfoFr
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGQuantileDMatrixCreateFromCallback
|
||||
* Signature: (Ljava/util/Iterator;Ljava/util/Iterator;Ljava/lang/String;[J)I
|
||||
* Signature: (Ljava/util/Iterator;[JLjava/lang/String;[J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGQuantileDMatrixCreateFromCallback
|
||||
(JNIEnv *, jclass, jobject, jobject, jstring, jlongArray);
|
||||
(JNIEnv *, jclass, jobject, jlongArray, jstring, jlongArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
@@ -431,6 +431,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetStrFea
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetStrFeatureInfo
|
||||
(JNIEnv *, jclass, jlong, jstring, jobjectArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixGetQuantileCut
|
||||
* Signature: (J[[J[[F)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetQuantileCut
|
||||
(JNIEnv *, jclass, jlong, jobjectArray, jobjectArray);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -258,8 +258,7 @@ public class DMatrixTest {
|
||||
TestCase.assertTrue(Arrays.equals(weights, dmat0.getWeight()));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCreateFromDenseMatrixWithMissingValue() throws XGBoostError {
|
||||
private DMatrix createFromDenseMatrix() throws XGBoostError {
|
||||
//create DMatrix from 10*5 dense matrix
|
||||
int nrow = 10;
|
||||
int ncol = 5;
|
||||
@@ -280,12 +279,17 @@ public class DMatrixTest {
|
||||
label0[i] = random.nextFloat();
|
||||
}
|
||||
|
||||
DMatrix dmat0 = new DMatrix(data0, nrow, ncol, -0.1f);
|
||||
dmat0.setLabel(label0);
|
||||
DMatrix dm = new DMatrix(data0, nrow, ncol, -0.1f);
|
||||
dm.setLabel(label0);
|
||||
return dm;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCreateFromDenseMatrixWithMissingValue() throws XGBoostError {
|
||||
DMatrix dm = createFromDenseMatrix();
|
||||
//check
|
||||
TestCase.assertTrue(dmat0.rowNum() == 10);
|
||||
TestCase.assertTrue(dmat0.getLabel().length == 10);
|
||||
TestCase.assertTrue(dm.rowNum() == 10);
|
||||
TestCase.assertTrue(dm.getLabel().length == 10);
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -493,4 +497,28 @@ public class DMatrixTest {
|
||||
TestCase.assertTrue(Arrays.equals(qidExpected1, dmat0.getGroup()));
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void getGetQuantileCut() throws XGBoostError {
|
||||
DMatrix Xy = createFromDenseMatrix();
|
||||
Map<String, Object> params = new HashMap<String, Object>();
|
||||
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
|
||||
watches.put("train", Xy);
|
||||
XGBoost.train(Xy, params, 1, watches, null, null); // Create the cuts
|
||||
DMatrix.QuantileCut cuts = Xy.getQuantileCut();
|
||||
TestCase.assertEquals(cuts.indptr.length, 6);
|
||||
for (int i = 1; i < cuts.indptr.length; ++i) {
|
||||
// Number of bins for each feature + min value.
|
||||
TestCase.assertTrue(cuts.indptr[i] - cuts.indptr[i - 1] >= 5);
|
||||
TestCase.assertTrue(cuts.indptr[i] - cuts.indptr[i - 1] <= Xy.rowNum() + 1);
|
||||
}
|
||||
TestCase.assertEquals(cuts.values.length, cuts.indptr[cuts.indptr.length - 1]);
|
||||
for (int i = 1; i < cuts.indptr.length; ++i) {
|
||||
long begin = cuts.indptr[i - 1];
|
||||
long end = cuts.indptr[i];
|
||||
for (long j = begin + 1; j < end; ++j) {
|
||||
TestCase.assertTrue(cuts.values[(int) j] > cuts.values[(int) j - 1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user