[jvm-packages]add feature size for LabelPoint and DataBatch (#5303)
* fix type error
* Validate number of features.
* resolve comments
* add feature size for LabelPoint and DataBatch
* pass the feature size to native
* move feature size validating tests into a separate suite
* resolve comments

Co-authored-by: fis <jm.yuan@outlook.com>
This commit is contained in:
parent
8bc595ea1e
commit
ad826e913f
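
To orient the reader before the diff: `LabeledPoint` (and, downstream, `DataBatch`) now carries an explicit feature count that is forwarded to the native library instead of being inferred. A minimal before/after sketch in Scala; the feature counts and values here are illustrative, not taken from the diff:

    import ml.dmlc.xgboost4j.LabeledPoint

    // before: dimensionality was implicit; sparse data fell back to Int.MaxValue columns
    //   LabeledPoint(1.0f, Array(0, 2), Array(0.5f, 0.25f))

    // after: the second argument declares the feature size
    val sparse = LabeledPoint(1.0f, 4, Array(0, 2), Array(0.5f, 0.25f)) // 4 features, 2 present
    val dense  = LabeledPoint(0.0f, 3, null, Array(1.0f, 2.0f, 3.0f))  // dense: indices == null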
@@ -29,7 +29,6 @@ struct GenericParameter : public XGBoostParameter<GenericParameter> {
   size_t gpu_page_size;
   bool enable_experimental_json_serialization {false};
   bool validate_parameters {false};
-  bool validate_features {true};
 
   void CheckDeprecated() {
     if (this->n_gpus != 0) {
@@ -75,9 +74,6 @@ struct GenericParameter : public XGBoostParameter<GenericParameter> {
     DMLC_DECLARE_FIELD(validate_parameters)
         .set_default(false)
         .describe("Enable checking whether parameters are used or not.");
-    DMLC_DECLARE_FIELD(validate_features)
-        .set_default(false)
-        .describe("Enable validating input DMatrix.");
     DMLC_DECLARE_FIELD(n_gpus)
         .set_default(0)
         .set_range(0, 1)
 
@@ -49,7 +49,7 @@ object XGBoost {
     Rabit.init(workerEnvs)
     val mapper = (x: LabeledVector) => {
       val (index, value) = x.vector.toSeq.unzip
-      LabeledPoint(x.label.toFloat, index.toArray, value.map(_.toFloat).toArray)
+      LabeledPoint(x.label.toFloat, x.vector.size, index.toArray, value.map(_.toFloat).toArray)
     }
     val dataIter = for (x <- it.iterator().asScala) yield mapper(x)
     val trainMat = new DMatrix(dataIter, null)
 
@@ -56,7 +56,7 @@ class XGBoostModel (booster: Booster) extends Serializable {
     (it: Iterator[Vector]) => {
       val mapper = (x: Vector) => {
         val (index, value) = x.toSeq.unzip
-        LabeledPoint(0.0f, index.toArray, value.map(_.toFloat).toArray)
+        LabeledPoint(0.0f, x.size, index.toArray, value.map(_.toFloat).toArray)
       }
       val dataIter = for (x <- it) yield mapper(x)
       val dmat = new DMatrix(dataIter, null)
 
@@ -38,15 +38,11 @@ object DataUtils extends Serializable {
 
     /**
      * Returns feature of the point as [[org.apache.spark.ml.linalg.Vector]].
-     *
-     * If the point is sparse, the dimensionality of the resulting sparse
-     * vector would be [[Int.MaxValue]]. This is the only safe value, since
-     * XGBoost does not store the dimensionality explicitly.
      */
     def features: Vector = if (labeledPoint.indices == null) {
       Vectors.dense(labeledPoint.values.map(_.toDouble))
     } else {
-      Vectors.sparse(Int.MaxValue, labeledPoint.indices, labeledPoint.values.map(_.toDouble))
+      Vectors.sparse(labeledPoint.size, labeledPoint.indices, labeledPoint.values.map(_.toDouble))
     }
   }
 
@@ -68,9 +64,9 @@ object DataUtils extends Serializable {
      */
     def asXGB: XGBLabeledPoint = v match {
       case v: DenseVector =>
-        XGBLabeledPoint(0.0f, null, v.values.map(_.toFloat))
+        XGBLabeledPoint(0.0f, v.size, null, v.values.map(_.toFloat))
       case v: SparseVector =>
-        XGBLabeledPoint(0.0f, v.indices, v.values.map(_.toFloat))
+        XGBLabeledPoint(0.0f, v.size, v.indices, v.values.map(_.toFloat))
     }
   }
 
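The `asXGB` conversion above now carries the Spark vector's dimensionality into `XGBLabeledPoint.size`. A minimal usage sketch, assuming the `DataUtils` implicits are in scope (the vector values are illustrative):

    import org.apache.spark.ml.linalg.Vectors
    import ml.dmlc.xgboost4j.scala.spark.DataUtils._

    val v = Vectors.sparse(5, Array(1, 3), Array(0.5, 0.25))
    val lp = v.asXGB
    // lp.size == 5; previously a sparse vector's dimensionality had to be assumed (Int.MaxValue)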
@@ -162,18 +158,18 @@ object DataUtils extends Serializable {
     df => df.select(selectedColumns: _*).rdd.map {
       case row @ Row(label: Float, features: Vector, weight: Float, group: Int,
           baseMargin: Float) =>
-        val (indices, values) = features match {
-          case v: SparseVector => (v.indices, v.values.map(_.toFloat))
-          case v: DenseVector => (null, v.values.map(_.toFloat))
+        val (size, indices, values) = features match {
+          case v: SparseVector => (v.size, v.indices, v.values.map(_.toFloat))
+          case v: DenseVector => (v.size, null, v.values.map(_.toFloat))
         }
-        val xgbLp = XGBLabeledPoint(label, indices, values, weight, group, baseMargin)
+        val xgbLp = XGBLabeledPoint(label, size, indices, values, weight, group, baseMargin)
         attachPartitionKey(row, deterministicPartition, numWorkers, xgbLp)
       case row @ Row(label: Float, features: Vector, weight: Float, baseMargin: Float) =>
-        val (indices, values) = features match {
-          case v: SparseVector => (v.indices, v.values.map(_.toFloat))
-          case v: DenseVector => (null, v.values.map(_.toFloat))
+        val (size, indices, values) = features match {
+          case v: SparseVector => (v.size, v.indices, v.values.map(_.toFloat))
+          case v: DenseVector => (v.size, null, v.values.map(_.toFloat))
         }
-        val xgbLp = XGBLabeledPoint(label, indices, values, weight, baseMargin = baseMargin)
+        val xgbLp = XGBLabeledPoint(label, size, indices, values, weight, baseMargin = baseMargin)
         attachPartitionKey(row, deterministicPartition, numWorkers, xgbLp)
     }
   }
 
@@ -0,0 +1,71 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import ml.dmlc.xgboost4j.java.XGBoostError
+import org.apache.spark.Partitioner
+import org.apache.spark.ml.feature.VectorAssembler
+import org.apache.spark.sql.SparkSession
+import org.scalatest.FunSuite
+
+import scala.util.Random
+
+class FeatureSizeValidatingSuite extends FunSuite with PerTest {
+
+  test("transform throwing exception if feature size of dataset is different with model's") {
+    val modelPath = getClass.getResource("/model/0.82/model").getPath
+    val model = XGBoostClassificationModel.read.load(modelPath)
+    val r = new Random(0)
+    // 0.82/model was trained with 251 features, and transform will throw an exception
+    // if the feature size of the data is not equal to 251
+    val df = ss.createDataFrame(Seq.fill(100)(r.nextInt(2)).map(i => (i, i))).
+      toDF("feature", "label")
+    val assembler = new VectorAssembler()
+      .setInputCols(df.columns.filter(!_.contains("label")))
+      .setOutputCol("features")
+    val thrown = intercept[Exception] {
+      model.transform(assembler.transform(df)).show()
+    }
+    assert(thrown.getMessage.contains(
+      "Number of columns does not match number of features in booster"))
+  }
+
+  test("train throwing exception if feature size of dataset is different on distributed train") {
+    val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
+      "objective" -> "binary:logistic",
+      "num_round" -> 5, "num_workers" -> 2, "use_external_memory" -> true, "missing" -> 0)
+    import DataUtils._
+    val sparkSession = SparkSession.builder().getOrCreate()
+    import sparkSession.implicits._
+    val repartitioned = sc.parallelize(Synthetic.trainWithDiffFeatureSize, 2)
+      .map(lp => (lp.label, lp)).partitionBy(
+        new Partitioner {
+          override def numPartitions: Int = 2
+
+          override def getPartition(key: Any): Int = key.asInstanceOf[Float].toInt
+        }
+      ).map(_._2).zipWithIndex().map {
+        case (lp, id) =>
+          (id, lp.label, lp.features)
+      }.toDF("id", "label", "features")
+    val xgb = new XGBoostClassifier(paramMap)
+    intercept[XGBoostError] {
+      xgb.fit(repartitioned)
+    }
+  }
+
+}
@@ -19,13 +19,12 @@ package ml.dmlc.xgboost4j.scala.spark
 import java.io.File
 import java.util.Arrays
 
 import scala.io.Source
 
-import ml.dmlc.xgboost4j.scala.DMatrix
+import scala.util.Random
-
-import scala.util.Random
 import org.apache.spark.ml.feature._
 import org.apache.spark.ml.{Pipeline, PipelineModel}
+import org.apache.spark.sql.functions._
 import org.scalatest.FunSuite
 
 class PersistenceSuite extends FunSuite with TmpFolderPerSuite with PerTest {
@@ -138,12 +137,21 @@ class PersistenceSuite extends FunSuite with TmpFolderPerSuite with PerTest {
     val modelPath = getClass.getResource("/model/0.82/model").getPath
     val model = XGBoostClassificationModel.read.load(modelPath)
     val r = new Random(0)
-    val df = ss.createDataFrame(Seq.fill(100)(r.nextInt(2)).map(i => (i, i))).
+    var df = ss.createDataFrame(Seq.fill(100)(r.nextInt(2)).map(i => (i, i))).
       toDF("feature", "label")
+    // 0.82/model was trained with 251 features, and transform will throw an exception
+    // if the feature size of the data is not equal to 251
+    for (x <- 1 to 250) {
+      df = df.withColumn(s"feature_${x}", lit(1))
+    }
     val assembler = new VectorAssembler()
       .setInputCols(df.columns.filter(!_.contains("label")))
       .setOutputCol("features")
-    model.transform(assembler.transform(df)).show()
+    df = assembler.transform(df)
+    for (x <- 1 to 250) {
+      df = df.drop(s"feature_${x}")
+    }
+    model.transform(df).show()
   }
 }
 
@@ -31,11 +31,12 @@ trait TrainTestData {
     Source.fromInputStream(is).getLines()
   }
 
-  protected def getLabeledPoints(resource: String, zeroBased: Boolean): Seq[XGBLabeledPoint] = {
+  protected def getLabeledPoints(resource: String, featureSize: Int, zeroBased: Boolean):
+      Seq[XGBLabeledPoint] = {
     getResourceLines(resource).map { line =>
       val labelAndFeatures = line.split(" ")
       val label = labelAndFeatures.head.toFloat
-      val values = new Array[Float](126)
+      val values = new Array[Float](featureSize)
       for (feature <- labelAndFeatures.tail) {
         val idAndValue = feature.split(":")
         if (!zeroBased) {
@@ -45,7 +46,7 @@ trait TrainTestData {
         }
       }
 
-      XGBLabeledPoint(label, null, values)
+      XGBLabeledPoint(label, featureSize, null, values)
     }.toList
   }
 
@@ -56,14 +57,14 @@ trait TrainTestData {
       val label = original.head.toFloat
       val group = original.last.toInt
       val values = original.slice(1, length - 1).map(_.toFloat)
-      XGBLabeledPoint(label, null, values, 1f, group, Float.NaN)
+      XGBLabeledPoint(label, values.size, null, values, 1f, group, Float.NaN)
     }.toList
   }
 }
 
 object Classification extends TrainTestData {
-  val train: Seq[XGBLabeledPoint] = getLabeledPoints("/agaricus.txt.train", zeroBased = false)
-  val test: Seq[XGBLabeledPoint] = getLabeledPoints("/agaricus.txt.test", zeroBased = false)
+  val train: Seq[XGBLabeledPoint] = getLabeledPoints("/agaricus.txt.train", 126, zeroBased = false)
+  val test: Seq[XGBLabeledPoint] = getLabeledPoints("/agaricus.txt.test", 126, zeroBased = false)
 }
 
 object MultiClassification extends TrainTestData {
@@ -80,19 +81,24 @@ object MultiClassification extends TrainTestData {
       values(i) = featuresAndLabel(i).toFloat
     }
 
-    XGBLabeledPoint(label, null, values.take(values.length - 1))
+    XGBLabeledPoint(label, values.length - 1, null, values.take(values.length - 1))
   }.toList
 }
 
 object Regression extends TrainTestData {
-  val train: Seq[XGBLabeledPoint] = getLabeledPoints("/machine.txt.train", zeroBased = true)
-  val test: Seq[XGBLabeledPoint] = getLabeledPoints("/machine.txt.test", zeroBased = true)
+  val MACHINE_COL_NUM = 36
+  val train: Seq[XGBLabeledPoint] = getLabeledPoints(
+    "/machine.txt.train", MACHINE_COL_NUM, zeroBased = true)
+  val test: Seq[XGBLabeledPoint] = getLabeledPoints(
+    "/machine.txt.test", MACHINE_COL_NUM, zeroBased = true)
 }
 
 object Ranking extends TrainTestData {
+  val RANK_COL_NUM = 3
   val train: Seq[XGBLabeledPoint] = getLabeledPointsWithGroup("/rank.train.csv")
-  val test: Seq[XGBLabeledPoint] = getLabeledPoints("/rank.test.txt", zeroBased = false)
+  val test: Seq[XGBLabeledPoint] = getLabeledPoints(
+    "/rank.test.txt", RANK_COL_NUM, zeroBased = false)
 
   private def getGroups(resource: String): Seq[Int] = {
     getResourceLines(resource).map(_.toInt).toList
@@ -100,10 +106,17 @@ object Ranking extends TrainTestData {
 }
 
 object Synthetic extends {
+  val TRAIN_COL_NUM = 3
+  val TRAIN_WRONG_COL_NUM = 2
   val train: Seq[XGBLabeledPoint] = Seq(
-    XGBLabeledPoint(1.0f, Array(0, 1), Array(1.0f, 2.0f)),
-    XGBLabeledPoint(0.0f, Array(0, 1, 2), Array(1.0f, 2.0f, 3.0f)),
-    XGBLabeledPoint(0.0f, Array(0, 1, 2), Array(1.0f, 2.0f, 3.0f)),
-    XGBLabeledPoint(1.0f, Array(0, 1), Array(1.0f, 2.0f))
+    XGBLabeledPoint(1.0f, TRAIN_COL_NUM, Array(0, 1), Array(1.0f, 2.0f)),
+    XGBLabeledPoint(0.0f, TRAIN_COL_NUM, Array(0, 1, 2), Array(1.0f, 2.0f, 3.0f)),
+    XGBLabeledPoint(0.0f, TRAIN_COL_NUM, Array(0, 1, 2), Array(1.0f, 2.0f, 3.0f)),
+    XGBLabeledPoint(1.0f, TRAIN_COL_NUM, Array(0, 1), Array(1.0f, 2.0f))
   )
+
+  val trainWithDiffFeatureSize: Seq[XGBLabeledPoint] = Seq(
+    XGBLabeledPoint(1.0f, TRAIN_WRONG_COL_NUM, Array(0, 1), Array(1.0f, 2.0f)),
+    XGBLabeledPoint(0.0f, TRAIN_COL_NUM, Array(0, 1, 2), Array(1.0f, 2.0f, 3.0f))
+  )
 }
 
@@ -17,12 +17,9 @@
 package ml.dmlc.xgboost4j.scala.spark
 
 import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
-
 import org.apache.spark.ml.linalg._
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.sql._
 import org.scalatest.FunSuite
 
-import org.apache.spark.Partitioner
-
 class XGBoostClassifierSuite extends FunSuite with PerTest {
@@ -308,4 +305,5 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
     val xgb = new XGBoostClassifier(paramMap)
     xgb.fit(repartitioned)
   }
+
 }
@@ -16,19 +16,13 @@
 
 package ml.dmlc.xgboost4j.scala.spark
 
 import java.nio.file.Files
-
 import scala.util.Random
-
 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
 import ml.dmlc.xgboost4j.scala.DMatrix
 import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
 import org.apache.hadoop.fs.{FileSystem, Path}
-
-import org.apache.spark.TaskContext
+import org.apache.spark.{TaskContext}
 import org.scalatest.FunSuite
-
 import org.apache.spark.ml.feature.VectorAssembler
 import org.apache.spark.sql.functions.lit
 
 class XGBoostGeneralSuite extends FunSuite with TmpFolderPerSuite with PerTest {
@@ -350,12 +344,21 @@ class XGBoostGeneralSuite extends FunSuite with TmpFolderPerSuite with PerTest {
     val modelPath = getClass.getResource("/model/0.82/model").getPath
     val model = XGBoostClassificationModel.read.load(modelPath)
     val r = new Random(0)
-    val df = ss.createDataFrame(Seq.fill(100000)(1).map(i => (i, i))).
+    var df = ss.createDataFrame(Seq.fill(100000)(1).map(i => (i, i))).
       toDF("feature", "label").repartition(5)
+    // 0.82/model was trained with 251 features, and transform will throw an exception
+    // if the feature size of the data is not equal to 251
+    for (x <- 1 to 250) {
+      df = df.withColumn(s"feature_${x}", lit(1))
+    }
     val assembler = new VectorAssembler()
       .setInputCols(df.columns.filter(!_.contains("label")))
       .setOutputCol("features")
-    val df1 = model.transform(assembler.transform(df)).withColumnRenamed(
+    df = assembler.transform(df)
+    for (x <- 1 to 250) {
+      df = df.drop(s"feature_${x}")
+    }
+    val df1 = model.transform(df).withColumnRenamed(
       "prediction", "prediction1").withColumnRenamed(
       "rawPrediction", "rawPrediction1").withColumnRenamed(
       "probability", "probability1")
@@ -363,4 +366,5 @@ class XGBoostGeneralSuite extends FunSuite with TmpFolderPerSuite with PerTest {
     df1.collect()
     df2.collect()
   }
+
 }
@@ -69,8 +69,7 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
 
   test("test regression prediction parity w/o ring reduce") {
     val training = buildDataFrame(Regression.train)
     val testDM = new DMatrix(Regression.test.iterator, null)
-    val testDF = buildDataFrame(Classification.test)
+    val testDF = buildDataFrame(Regression.test)
     val xgbSettings = Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
       "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers)
     val model1 = new XGBoostRegressor(xgbSettings).fit(training)
 
@@ -49,7 +49,6 @@ public class Booster implements Serializable, KryoSerializable {
    */
   Booster(Map<String, Object> params, DMatrix[] cacheMats) throws XGBoostError {
     init(cacheMats);
-    setParam("validate_features", "0");
     setParams(params);
   }
 
@@ -27,14 +27,17 @@ class DataBatch {
   final int[] featureIndex;
   /** value of each non-missing entry in the sparse matrix */
   final float[] featureValue ;
+  /** feature columns */
+  final int featureCols;
 
   DataBatch(long[] rowOffset, float[] weight, float[] label, int[] featureIndex,
-            float[] featureValue) {
+            float[] featureValue, int featureCols) {
     this.rowOffset = rowOffset;
     this.weight = weight;
     this.label = label;
     this.featureIndex = featureIndex;
     this.featureValue = featureValue;
+    this.featureCols = featureCols;
   }
 
   static class BatchIterator implements Iterator<DataBatch> {
@@ -56,9 +59,15 @@ class DataBatch {
       try {
         int numRows = 0;
         int numElem = 0;
+        int numCol = -1;
         List<LabeledPoint> batch = new ArrayList<>(batchSize);
         while (base.hasNext() && batch.size() < batchSize) {
           LabeledPoint labeledPoint = base.next();
+          if (numCol == -1) {
+            numCol = labeledPoint.size();
+          } else if (numCol != labeledPoint.size()) {
+            throw new RuntimeException("Feature size is not the same");
+          }
           batch.add(labeledPoint);
           numElem += labeledPoint.values().length;
           numRows++;
@@ -91,7 +100,7 @@ class DataBatch {
         }
 
         rowOffset[batch.size()] = offset;
-        return new DataBatch(rowOffset, weight, label, featureIndex, featureValue);
+        return new DataBatch(rowOffset, weight, label, featureIndex, featureValue, numCol);
       } catch (RuntimeException runtimeError) {
         logger.error(runtimeError);
         return null;
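Together with the new `size` field on `LabeledPoint`, `BatchIterator` now fails fast when rows disagree on dimensionality. A minimal Scala sketch of the failure mode (values illustrative; this mirrors the Java test further below):

    import ml.dmlc.xgboost4j.LabeledPoint
    import ml.dmlc.xgboost4j.scala.DMatrix

    val rows = Iterator(
      LabeledPoint(1.0f, 3, Array(0, 1), Array(1.0f, 2.0f)),          // declares 3 features
      LabeledPoint(0.0f, 4, Array(0, 1, 2), Array(1.0f, 2.0f, 3.0f))  // declares 4: mismatch
    )
    // Constructing the matrix is expected to fail with an XGBoostError,
    // since BatchIterator rejects batches whose rows differ in feature size:
    // new DMatrix(rows, null)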
@@ -20,6 +20,7 @@ package ml.dmlc.xgboost4j
  * Labeled training data point.
  *
  * @param label Label of this point.
+ * @param size Feature dimensionality of this point.
  * @param indices Feature indices of this point or `null` if the data is dense.
  * @param values Feature values of this point.
  * @param weight Weight of this point.
@@ -28,6 +29,7 @@ package ml.dmlc.xgboost4j
  */
 case class LabeledPoint(
     label: Float,
+    size: Int,
     indices: Array[Int],
     values: Array[Float],
     weight: Float = 1f,
@@ -36,8 +38,11 @@ case class LabeledPoint(
   require(indices == null || indices.length == values.length,
     "indices and values must have the same number of elements")
 
-  def this(label: Float, indices: Array[Int], values: Array[Float]) = {
+  require(indices == null || size >= indices.length,
+    "feature dimensionality must be greater equal than size of indices")
+
+  def this(label: Float, size: Int, indices: Array[Int], values: Array[Float]) = {
     // [[weight]] default duplicated to disambiguate the constructor call.
-    this(label, indices, values, 1.0f)
+    this(label, size, indices, values, 1.0f)
   }
 }
 
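The case class now validates its declared dimensionality up front. A small usage sketch (feature counts illustrative):

    import ml.dmlc.xgboost4j.LabeledPoint

    // sparse point: 10 features declared, 3 present
    val sparse = new LabeledPoint(1.0f, 10, Array(0, 4, 7), Array(1.0f, 2.0f, 3.0f))
    // dense point: indices == null, size matches values.length
    val dense = LabeledPoint(0.0f, 3, null, Array(1.0f, 2.0f, 3.0f))

    // This would trip the new require, since size (2) < indices.length (3):
    // LabeledPoint(1.0f, 2, Array(0, 1, 2), Array(1.0f, 2.0f, 3.0f))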
@@ -91,9 +91,11 @@ XGB_EXTERN_C int XGBoost4jCallbackDataIterNext(
       batch, jenv->GetFieldID(batchClass, "featureIndex", "[I"));
   jfloatArray jvalue = (jfloatArray)jenv->GetObjectField(
       batch, jenv->GetFieldID(batchClass, "featureValue", "[F"));
+  jint jcols = jenv->GetIntField(
+      batch, jenv->GetFieldID(batchClass, "featureCols", "I"));
   XGBoostBatchCSR cbatch;
   cbatch.size = jenv->GetArrayLength(joffset) - 1;
-  cbatch.columns = std::numeric_limits<size_t>::max();
+  cbatch.columns = jcols;
   cbatch.offset = reinterpret_cast<jlong *>(
       jenv->GetLongArrayElements(joffset, 0));
   if (jlabel != nullptr) {
 
@@ -45,7 +45,7 @@ public class DMatrixTest {
     java.util.List<LabeledPoint> blist = new java.util.LinkedList<LabeledPoint>();
     for (int i = 0; i < nrep; ++i) {
       LabeledPoint p = new LabeledPoint(
-          0.1f + i, new int[]{0, 2, 3}, new float[]{3, 4, 5});
+          0.1f + i, 4, new int[]{0, 2, 3}, new float[]{3, 4, 5});
       blist.add(p);
       labelall.add(p.label());
     }
@@ -57,6 +57,33 @@ public class DMatrixTest {
     }
   }
 
+  @Test
+  public void testCreateFromDataIteratorWithDiffFeatureSize() throws XGBoostError {
+    // create DMatrix from DataIterator
+
+    java.util.ArrayList<Float> labelall = new java.util.ArrayList<Float>();
+    int nrep = 3000;
+    java.util.List<LabeledPoint> blist = new java.util.LinkedList<LabeledPoint>();
+    int featureSize = 4;
+    for (int i = 0; i < nrep; ++i) {
+      // set some rows with a wrong feature size
+      if (i % 10 == 1) {
+        featureSize = 5;
+      }
+      LabeledPoint p = new LabeledPoint(
+          0.1f + i, featureSize, new int[]{0, 2, 3}, new float[]{3, 4, 5});
+      blist.add(p);
+      labelall.add(p.label());
+    }
+    boolean success = true;
+    try {
+      DMatrix dmat = new DMatrix(blist.iterator(), null);
+    } catch (XGBoostError e) {
+      success = false;
+    }
+    TestCase.assertTrue(success == false);
+  }
+
   @Test
   public void testCreateFromFile() throws XGBoostError {
     //create DMatrix from file
 
@@ -125,6 +125,7 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
   } else {
     info_.num_col_ = adapter->NumColumns();
   }
 
+  // Synchronise worker columns
   rabit::Allreduce<rabit::op::Max>(&info_.num_col_, 1);
 
@@ -1063,19 +1063,9 @@ class LearnerImpl : public LearnerIO {
     return tparam_.dsplit == DataSplitMode::kRow ||
            tparam_.dsplit == DataSplitMode::kAuto;
   };
-  bool const valid_features =
-      !row_based_split() ||
-      (learner_model_param_.num_feature == p_fmat->Info().num_col_);
-  std::string const msg {
-    "Number of columns does not match number of features in booster."
-  };
-  if (generic_parameters_.validate_features) {
-    CHECK_EQ(learner_model_param_.num_feature, p_fmat->Info().num_col_) << msg;
-  } else if (!valid_features) {
-    // Remove this and make the equality check fatal once spark can fix all failing tests.
-    LOG(WARNING) << msg << " "
-                 << "Columns: " << p_fmat->Info().num_col_ << " "
-                 << "Features: " << learner_model_param_.num_feature;
+  if (row_based_split()) {
+    CHECK_EQ(learner_model_param_.num_feature, p_fmat->Info().num_col_)
+        << "Number of columns does not match number of features in booster.";
+  }
 }