[jvm-packages] add feature size for LabeledPoint and DataBatch (#5303)

* fix type error

* Validate number of features.

* resolve comments

* add feature size for LabeledPoint and DataBatch

* pass the feature size to native

* move feature size validating tests into a separate suite

* resolve comments

Co-authored-by: fis <jm.yuan@outlook.com>
Bobby Wang 2020-04-08 07:49:52 +08:00 committed by GitHub
parent 8bc595ea1e
commit ad826e913f
17 changed files with 193 additions and 75 deletions
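
In short: every LabeledPoint now carries the number of feature columns it spans, DataBatch forwards that count to the native library, and the learner enforces it instead of warning. A minimal sketch of the new call shape, assuming xgboost4j is on the classpath:

import ml.dmlc.xgboost4j.LabeledPoint

// A dense point over 3 features (indices == null means dense), and a sparse
// point over the same 3-dimensional space with only features 0 and 2 set.
val dense = LabeledPoint(1.0f, 3, null, Array(0.5f, 1.2f, 0.0f))
val sparse = LabeledPoint(0.0f, 3, Array(0, 2), Array(0.5f, 0.7f))
assert(dense.size == sparse.size) // rows agree on dimensionality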

View File

@@ -29,7 +29,6 @@ struct GenericParameter : public XGBoostParameter<GenericParameter> {
size_t gpu_page_size;
bool enable_experimental_json_serialization {false};
bool validate_parameters {false};
bool validate_features {true};
void CheckDeprecated() {
if (this->n_gpus != 0) {
@@ -75,9 +74,6 @@ struct GenericParameter : public XGBoostParameter<GenericParameter> {
DMLC_DECLARE_FIELD(validate_parameters)
.set_default(false)
.describe("Enable checking whether parameters are used or not.");
DMLC_DECLARE_FIELD(validate_features)
.set_default(false)
.describe("Enable validating input DMatrix.");
DMLC_DECLARE_FIELD(n_gpus)
.set_default(0)
.set_range(0, 1)

View File

@@ -49,7 +49,7 @@ object XGBoost {
Rabit.init(workerEnvs)
val mapper = (x: LabeledVector) => {
val (index, value) = x.vector.toSeq.unzip
LabeledPoint(x.label.toFloat, index.toArray, value.map(_.toFloat).toArray)
LabeledPoint(x.label.toFloat, x.vector.size, index.toArray, value.map(_.toFloat).toArray)
}
val dataIter = for (x <- it.iterator().asScala) yield mapper(x)
val trainMat = new DMatrix(dataIter, null)

View File

@@ -56,7 +56,7 @@ class XGBoostModel (booster: Booster) extends Serializable {
(it: Iterator[Vector]) => {
val mapper = (x: Vector) => {
val (index, value) = x.toSeq.unzip
LabeledPoint(0.0f, index.toArray, value.map(_.toFloat).toArray)
LabeledPoint(0.0f, x.size, index.toArray, value.map(_.toFloat).toArray)
}
val dataIter = for (x <- it) yield mapper(x)
val dmat = new DMatrix(dataIter, null)

View File

@@ -38,15 +38,11 @@ object DataUtils extends Serializable {
/**
* Returns feature of the point as [[org.apache.spark.ml.linalg.Vector]].
*
* If the point is sparse, the dimensionality of the resulting sparse
* vector would be [[Int.MaxValue]]. This is the only safe value, since
* XGBoost does not store the dimensionality explicitly.
*/
def features: Vector = if (labeledPoint.indices == null) {
Vectors.dense(labeledPoint.values.map(_.toDouble))
} else {
Vectors.sparse(Int.MaxValue, labeledPoint.indices, labeledPoint.values.map(_.toDouble))
Vectors.sparse(labeledPoint.size, labeledPoint.indices, labeledPoint.values.map(_.toDouble))
}
}
@@ -68,9 +64,9 @@ object DataUtils extends Serializable {
*/
def asXGB: XGBLabeledPoint = v match {
case v: DenseVector =>
XGBLabeledPoint(0.0f, null, v.values.map(_.toFloat))
XGBLabeledPoint(0.0f, v.size, null, v.values.map(_.toFloat))
case v: SparseVector =>
XGBLabeledPoint(0.0f, v.indices, v.values.map(_.toFloat))
XGBLabeledPoint(0.0f, v.size, v.indices, v.values.map(_.toFloat))
}
}
@@ -162,18 +158,18 @@ object DataUtils extends Serializable {
df => df.select(selectedColumns: _*).rdd.map {
case row @ Row(label: Float, features: Vector, weight: Float, group: Int,
baseMargin: Float) =>
val (indices, values) = features match {
case v: SparseVector => (v.indices, v.values.map(_.toFloat))
case v: DenseVector => (null, v.values.map(_.toFloat))
val (size, indices, values) = features match {
case v: SparseVector => (v.size, v.indices, v.values.map(_.toFloat))
case v: DenseVector => (v.size, null, v.values.map(_.toFloat))
}
val xgbLp = XGBLabeledPoint(label, indices, values, weight, group, baseMargin)
val xgbLp = XGBLabeledPoint(label, size, indices, values, weight, group, baseMargin)
attachPartitionKey(row, deterministicPartition, numWorkers, xgbLp)
case row @ Row(label: Float, features: Vector, weight: Float, baseMargin: Float) =>
val (indices, values) = features match {
case v: SparseVector => (v.indices, v.values.map(_.toFloat))
case v: DenseVector => (null, v.values.map(_.toFloat))
val (size, indices, values) = features match {
case v: SparseVector => (v.size, v.indices, v.values.map(_.toFloat))
case v: DenseVector => (v.size, null, v.values.map(_.toFloat))
}
val xgbLp = XGBLabeledPoint(label, indices, values, weight, baseMargin = baseMargin)
val xgbLp = XGBLabeledPoint(label, size, indices, values, weight, baseMargin = baseMargin)
attachPartitionKey(row, deterministicPartition, numWorkers, xgbLp)
}
}
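
For reference, the conversion in isolation: a sketch of what the asXGB implicit above now produces for dense and sparse Spark vectors (toXGB is a hypothetical stand-in for the implicit):

import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}

// The vector's true dimensionality now rides along with the point,
// replacing the old Int.MaxValue placeholder for sparse data.
def toXGB(v: Vector): XGBLabeledPoint = v match {
  case d: DenseVector => XGBLabeledPoint(0.0f, d.size, null, d.values.map(_.toFloat))
  case s: SparseVector => XGBLabeledPoint(0.0f, s.size, s.indices, s.values.map(_.toFloat))
}

val p = toXGB(Vectors.sparse(251, Array(0, 10), Array(1.0, 2.0)))
assert(p.size == 251) // previously reported as Int.MaxValue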

View File

@@ -0,0 +1,71 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.java.XGBoostError
import org.apache.spark.Partitioner
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.SparkSession
import org.scalatest.FunSuite
import scala.util.Random
class FeatureSizeValidatingSuite extends FunSuite with PerTest {
test("transform throwing exception if feature size of dataset is different with model's") {
val modelPath = getClass.getResource("/model/0.82/model").getPath
val model = XGBoostClassificationModel.read.load(modelPath)
val r = new Random(0)
// 0.82/model was trained with 251 features, and transform will throw an exception
// if the feature size of the data is not equal to 251
val df = ss.createDataFrame(Seq.fill(100)(r.nextInt(2)).map(i => (i, i))).
toDF("feature", "label")
val assembler = new VectorAssembler()
.setInputCols(df.columns.filter(!_.contains("label")))
.setOutputCol("features")
val thrown = intercept[Exception] {
model.transform(assembler.transform(df)).show()
}
assert(thrown.getMessage.contains(
"Number of columns does not match number of features in booster"))
}
test("train throwing exception if feature size of dataset is different on distributed train") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic",
"num_round" -> 5, "num_workers" -> 2, "use_external_memory" -> true, "missing" -> 0)
import DataUtils._
val sparkSession = SparkSession.builder().getOrCreate()
import sparkSession.implicits._
val repartitioned = sc.parallelize(Synthetic.trainWithDiffFeatureSize, 2)
.map(lp => (lp.label, lp)).partitionBy(
new Partitioner {
override def numPartitions: Int = 2
override def getPartition(key: Any): Int = key.asInstanceOf[Float].toInt
}
).map(_._2).zipWithIndex().map {
case (lp, id) =>
(id, lp.label, lp.features)
}.toDF("id", "label", "features")
val xgb = new XGBoostClassifier(paramMap)
intercept[XGBoostError] {
xgb.fit(repartitioned)
}
}
}

View File

@@ -19,13 +19,12 @@ package ml.dmlc.xgboost4j.scala.spark
import java.io.File
import java.util.Arrays
import scala.io.Source
import ml.dmlc.xgboost4j.scala.DMatrix
import scala.util.Random
import org.apache.spark.ml.feature._
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.sql.functions._
import org.scalatest.FunSuite
class PersistenceSuite extends FunSuite with TmpFolderPerSuite with PerTest {
@@ -138,12 +137,21 @@ class PersistenceSuite extends FunSuite with TmpFolderPerSuite with PerTest {
val modelPath = getClass.getResource("/model/0.82/model").getPath
val model = XGBoostClassificationModel.read.load(modelPath)
val r = new Random(0)
val df = ss.createDataFrame(Seq.fill(100)(r.nextInt(2)).map(i => (i, i))).
var df = ss.createDataFrame(Seq.fill(100)(r.nextInt(2)).map(i => (i, i))).
toDF("feature", "label")
// 0.82/model was trained with 251 features, and transform will throw an exception
// if the feature size of the data is not equal to 251
for (x <- 1 to 250) {
df = df.withColumn(s"feature_${x}", lit(1))
}
val assembler = new VectorAssembler()
.setInputCols(df.columns.filter(!_.contains("label")))
.setOutputCol("features")
model.transform(assembler.transform(df)).show()
df = assembler.transform(df)
for (x <- 1 to 250) {
df = df.drop(s"feature_${x}")
}
model.transform(df).show()
}
}

View File

@@ -31,11 +31,12 @@ trait TrainTestData {
Source.fromInputStream(is).getLines()
}
protected def getLabeledPoints(resource: String, zeroBased: Boolean): Seq[XGBLabeledPoint] = {
protected def getLabeledPoints(resource: String, featureSize: Int, zeroBased: Boolean):
Seq[XGBLabeledPoint] = {
getResourceLines(resource).map { line =>
val labelAndFeatures = line.split(" ")
val label = labelAndFeatures.head.toFloat
val values = new Array[Float](126)
val values = new Array[Float](featureSize)
for (feature <- labelAndFeatures.tail) {
val idAndValue = feature.split(":")
if (!zeroBased) {
@@ -45,7 +46,7 @@ trait TrainTestData {
}
}
XGBLabeledPoint(label, null, values)
XGBLabeledPoint(label, featureSize, null, values)
}.toList
}
@@ -56,14 +57,14 @@
val label = original.head.toFloat
val group = original.last.toInt
val values = original.slice(1, length - 1).map(_.toFloat)
XGBLabeledPoint(label, null, values, 1f, group, Float.NaN)
XGBLabeledPoint(label, values.size, null, values, 1f, group, Float.NaN)
}.toList
}
}
object Classification extends TrainTestData {
val train: Seq[XGBLabeledPoint] = getLabeledPoints("/agaricus.txt.train", zeroBased = false)
val test: Seq[XGBLabeledPoint] = getLabeledPoints("/agaricus.txt.test", zeroBased = false)
val train: Seq[XGBLabeledPoint] = getLabeledPoints("/agaricus.txt.train", 126, zeroBased = false)
val test: Seq[XGBLabeledPoint] = getLabeledPoints("/agaricus.txt.test", 126, zeroBased = false)
}
object MultiClassification extends TrainTestData {
@@ -80,19 +81,24 @@ object MultiClassification extends TrainTestData {
values(i) = featuresAndLabel(i).toFloat
}
XGBLabeledPoint(label, null, values.take(values.length - 1))
XGBLabeledPoint(label, values.length - 1, null, values.take(values.length - 1))
}.toList
}
}
object Regression extends TrainTestData {
val train: Seq[XGBLabeledPoint] = getLabeledPoints("/machine.txt.train", zeroBased = true)
val test: Seq[XGBLabeledPoint] = getLabeledPoints("/machine.txt.test", zeroBased = true)
val MACHINE_COL_NUM = 36
val train: Seq[XGBLabeledPoint] = getLabeledPoints(
"/machine.txt.train", MACHINE_COL_NUM, zeroBased = true)
val test: Seq[XGBLabeledPoint] = getLabeledPoints(
"/machine.txt.test", MACHINE_COL_NUM, zeroBased = true)
}
object Ranking extends TrainTestData {
val RANK_COL_NUM = 3
val train: Seq[XGBLabeledPoint] = getLabeledPointsWithGroup("/rank.train.csv")
val test: Seq[XGBLabeledPoint] = getLabeledPoints("/rank.test.txt", zeroBased = false)
val test: Seq[XGBLabeledPoint] = getLabeledPoints(
"/rank.test.txt", RANK_COL_NUM, zeroBased = false)
private def getGroups(resource: String): Seq[Int] = {
getResourceLines(resource).map(_.toInt).toList
@@ -100,10 +106,17 @@ object Ranking extends TrainTestData {
}
object Synthetic extends TrainTestData {
val TRAIN_COL_NUM = 3
val TRAIN_WRONG_COL_NUM = 2
val train: Seq[XGBLabeledPoint] = Seq(
XGBLabeledPoint(1.0f, Array(0, 1), Array(1.0f, 2.0f)),
XGBLabeledPoint(0.0f, Array(0, 1, 2), Array(1.0f, 2.0f, 3.0f)),
XGBLabeledPoint(0.0f, Array(0, 1, 2), Array(1.0f, 2.0f, 3.0f)),
XGBLabeledPoint(1.0f, Array(0, 1), Array(1.0f, 2.0f))
XGBLabeledPoint(1.0f, TRAIN_COL_NUM, Array(0, 1), Array(1.0f, 2.0f)),
XGBLabeledPoint(0.0f, TRAIN_COL_NUM, Array(0, 1, 2), Array(1.0f, 2.0f, 3.0f)),
XGBLabeledPoint(0.0f, TRAIN_COL_NUM, Array(0, 1, 2), Array(1.0f, 2.0f, 3.0f)),
XGBLabeledPoint(1.0f, TRAIN_COL_NUM, Array(0, 1), Array(1.0f, 2.0f))
)
val trainWithDiffFeatureSize: Seq[XGBLabeledPoint] = Seq(
XGBLabeledPoint(1.0f, TRAIN_WRONG_COL_NUM, Array(0, 1), Array(1.0f, 2.0f)),
XGBLabeledPoint(0.0f, TRAIN_COL_NUM, Array(0, 1, 2), Array(1.0f, 2.0f, 3.0f))
)
}

View File

@@ -17,12 +17,9 @@
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql._
import org.scalatest.FunSuite
import org.apache.spark.Partitioner
class XGBoostClassifierSuite extends FunSuite with PerTest {
@@ -308,4 +305,5 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
val xgb = new XGBoostClassifier(paramMap)
xgb.fit(repartitioned)
}
}

View File

@@ -16,19 +16,13 @@
package ml.dmlc.xgboost4j.scala.spark
import java.nio.file.Files
import scala.util.Random
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import ml.dmlc.xgboost4j.scala.DMatrix
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.TaskContext
import org.apache.spark.{TaskContext}
import org.scalatest.FunSuite
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.functions.lit
class XGBoostGeneralSuite extends FunSuite with TmpFolderPerSuite with PerTest {
@@ -350,12 +344,21 @@ class XGBoostGeneralSuite extends FunSuite with TmpFolderPerSuite with PerTest {
val modelPath = getClass.getResource("/model/0.82/model").getPath
val model = XGBoostClassificationModel.read.load(modelPath)
val r = new Random(0)
val df = ss.createDataFrame(Seq.fill(100000)(1).map(i => (i, i))).
var df = ss.createDataFrame(Seq.fill(100000)(1).map(i => (i, i))).
toDF("feature", "label").repartition(5)
// 0.82/model was trained with 251 features, and transform will throw an exception
// if the feature size of the data is not equal to 251
for (x <- 1 to 250) {
df = df.withColumn(s"feature_${x}", lit(1))
}
val assembler = new VectorAssembler()
.setInputCols(df.columns.filter(!_.contains("label")))
.setOutputCol("features")
val df1 = model.transform(assembler.transform(df)).withColumnRenamed(
df = assembler.transform(df)
for (x <- 1 to 250) {
df = df.drop(s"feature_${x}")
}
val df1 = model.transform(df).withColumnRenamed(
"prediction", "prediction1").withColumnRenamed(
"rawPrediction", "rawPrediction1").withColumnRenamed(
"probability", "probability1")
@@ -363,4 +366,5 @@ class XGBoostGeneralSuite extends FunSuite with TmpFolderPerSuite with PerTest {
df1.collect()
df2.collect()
}
}

View File

@@ -69,8 +69,7 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
test("test regression prediction parity w/o ring reduce") {
val training = buildDataFrame(Regression.train)
val testDM = new DMatrix(Regression.test.iterator, null)
val testDF = buildDataFrame(Classification.test)
val testDF = buildDataFrame(Regression.test)
val xgbSettings = Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers)
val model1 = new XGBoostRegressor(xgbSettings).fit(training)

View File

@@ -49,7 +49,6 @@ public class Booster implements Serializable, KryoSerializable {
*/
Booster(Map<String, Object> params, DMatrix[] cacheMats) throws XGBoostError {
init(cacheMats);
setParam("validate_features", "0");
setParams(params);
}

View File

@@ -27,14 +27,17 @@ class DataBatch {
final int[] featureIndex;
/** value of each non-missing entry in the sparse matrix */
final float[] featureValue;
/** feature columns */
final int featureCols;
DataBatch(long[] rowOffset, float[] weight, float[] label, int[] featureIndex,
float[] featureValue) {
float[] featureValue, int featureCols) {
this.rowOffset = rowOffset;
this.weight = weight;
this.label = label;
this.featureIndex = featureIndex;
this.featureValue = featureValue;
this.featureCols = featureCols;
}
static class BatchIterator implements Iterator<DataBatch> {
@@ -56,9 +59,15 @@
try {
int numRows = 0;
int numElem = 0;
int numCol = -1;
List<LabeledPoint> batch = new ArrayList<>(batchSize);
while (base.hasNext() && batch.size() < batchSize) {
LabeledPoint labeledPoint = base.next();
if (numCol == -1) {
numCol = labeledPoint.size();
} else if (numCol != labeledPoint.size()) {
throw new RuntimeException("Feature size is not the same");
}
batch.add(labeledPoint);
numElem += labeledPoint.values().length;
numRows++;
@@ -91,7 +100,7 @@
}
rowOffset[batch.size()] = offset;
return new DataBatch(rowOffset, weight, label, featureIndex, featureValue);
return new DataBatch(rowOffset, weight, label, featureIndex, featureValue, numCol);
} catch (RuntimeException runtimeError) {
logger.error(runtimeError);
return null;
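
The iterator above pins the batch width to the first row's size and rejects any disagreeing row. The same guard written out in Scala, as a sketch (checkedWidth is a hypothetical helper):

import ml.dmlc.xgboost4j.LabeledPoint

// All rows in a batch must agree on a single width; anything else aborts,
// matching the RuntimeException thrown in DataBatch.BatchIterator above.
def checkedWidth(batch: Seq[LabeledPoint]): Int =
  batch.map(_.size).distinct match {
    case Seq(width) => width
    case widths => throw new RuntimeException(
      s"Feature size is not the same: ${widths.mkString(", ")}")
  }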

View File

@@ -20,6 +20,7 @@ package ml.dmlc.xgboost4j
* Labeled training data point.
*
* @param label Label of this point.
* @param size Feature dimensionality.
* @param indices Feature indices of this point or `null` if the data is dense.
* @param values Feature values of this point.
* @param weight Weight of this point.
@@ -28,6 +29,7 @@ package ml.dmlc.xgboost4j
*/
case class LabeledPoint(
label: Float,
size: Int,
indices: Array[Int],
values: Array[Float],
weight: Float = 1f,
@@ -36,8 +38,11 @@ case class LabeledPoint(
require(indices == null || indices.length == values.length,
"indices and values must have the same number of elements")
def this(label: Float, indices: Array[Int], values: Array[Float]) = {
require(indices == null || size >= indices.length,
"feature dimensionality must be greater than or equal to the number of indices")
def this(label: Float, size: Int, indices: Array[Int], values: Array[Float]) = {
// [[weight]] default duplicated to disambiguate the constructor call.
this(label, indices, values, 1.0f)
this(label, size, indices, values, 1.0f)
}
}
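
To make the new contract concrete: the require bounds the number of sparse indices, not their values, so an out-of-range index still passes. A small sketch:

import ml.dmlc.xgboost4j.LabeledPoint

// Two indices demand a dimensionality of at least 2; the index value 5
// itself is not range-checked here.
LabeledPoint(1.0f, 2, Array(0, 5), Array(1.0f, 2.0f)) // passes (2 >= 2)
// A one-dimensional point cannot hold two indices: this throws
// IllegalArgumentException from the require above.
LabeledPoint(1.0f, 1, Array(0, 1), Array(1.0f, 2.0f))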

View File

@@ -91,9 +91,11 @@ XGB_EXTERN_C int XGBoost4jCallbackDataIterNext(
batch, jenv->GetFieldID(batchClass, "featureIndex", "[I"));
jfloatArray jvalue = (jfloatArray)jenv->GetObjectField(
batch, jenv->GetFieldID(batchClass, "featureValue", "[F"));
jint jcols = jenv->GetIntField(
batch, jenv->GetFieldID(batchClass, "featureCols", "I"));
XGBoostBatchCSR cbatch;
cbatch.size = jenv->GetArrayLength(joffset) - 1;
cbatch.columns = std::numeric_limits<size_t>::max();
cbatch.columns = jcols;
cbatch.offset = reinterpret_cast<jlong *>(
jenv->GetLongArrayElements(joffset, 0));
if (jlabel != nullptr) {

View File

@@ -45,7 +45,7 @@ public class DMatrixTest {
java.util.List<LabeledPoint> blist = new java.util.LinkedList<LabeledPoint>();
for (int i = 0; i < nrep; ++i) {
LabeledPoint p = new LabeledPoint(
0.1f + i, new int[]{0, 2, 3}, new float[]{3, 4, 5});
0.1f + i, 4, new int[]{0, 2, 3}, new float[]{3, 4, 5});
blist.add(p);
labelall.add(p.label());
}
@@ -57,6 +57,33 @@
}
}
@Test
public void testCreateFromDataIteratorWithDiffFeatureSize() throws XGBoostError {
//create DMatrix from DataIterator
java.util.ArrayList<Float> labelall = new java.util.ArrayList<Float>();
int nrep = 3000;
java.util.List<LabeledPoint> blist = new java.util.LinkedList<LabeledPoint>();
int featureSize = 4;
for (int i = 0; i < nrep; ++i) {
// set some rows with wrong feature size
if (i % 10 == 1) {
featureSize = 5;
}
LabeledPoint p = new LabeledPoint(
0.1f + i, featureSize, new int[]{0, 2, 3}, new float[]{3, 4, 5});
blist.add(p);
labelall.add(p.label());
}
boolean success = true;
try {
DMatrix dmat = new DMatrix(blist.iterator(), null);
} catch (XGBoostError e) {
success = false;
}
TestCase.assertTrue(success == false);
}
@Test
public void testCreateFromFile() throws XGBoostError {
//create DMatrix from file

View File

@@ -125,6 +125,7 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
} else {
info_.num_col_ = adapter->NumColumns();
}
// Synchronise worker columns
rabit::Allreduce<rabit::op::Max>(&info_.num_col_, 1);

View File

@@ -1063,19 +1063,9 @@ class LearnerImpl : public LearnerIO {
return tparam_.dsplit == DataSplitMode::kRow ||
tparam_.dsplit == DataSplitMode::kAuto;
};
bool const valid_features =
!row_based_split() ||
(learner_model_param_.num_feature == p_fmat->Info().num_col_);
std::string const msg {
"Number of columns does not match number of features in booster."
};
if (generic_parameters_.validate_features) {
CHECK_EQ(learner_model_param_.num_feature, p_fmat->Info().num_col_) << msg;
} else if (!valid_features) {
// Remove this and make the equality check fatal once spark can fix all failing tests.
LOG(WARNING) << msg << " "
<< "Columns: " << p_fmat->Info().num_col_ << " "
<< "Features: " << learner_model_param_.num_feature;
if (row_based_split()) {
CHECK_EQ(learner_model_param_.num_feature, p_fmat->Info().num_col_)
<< "Number of columns does not match number of features in booster.";
}
}
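
With the validate_features escape hatch removed, a column/feature mismatch now fails hard on both training and prediction for row-based splits. A hedged sketch of the user-visible effect (paths are hypothetical; assumes a booster trained on 251 features, like the 0.82 test model above):

import ml.dmlc.xgboost4j.java.XGBoostError
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => SXGBoost}

val booster = SXGBoost.loadModel("/tmp/model-251-features")
val narrow = new DMatrix("/tmp/ten-columns.libsvm")
try {
  booster.predict(narrow)
} catch {
  case e: XGBoostError =>
    // "Number of columns does not match number of features in booster."
    println(e.getMessage)
}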