[jvm-packages]add feature size for LabelPoint and DataBatch (#5303)
* fix type error * Validate number of features. * resolve comments * add feature size for LabelPoint and DataBatch * pass the feature size to native * move feature size validating tests into a separate suite * resolve comments Co-authored-by: fis <jm.yuan@outlook.com>
This commit is contained in:
@@ -38,15 +38,11 @@ object DataUtils extends Serializable {
|
||||
|
||||
/**
|
||||
* Returns feature of the point as [[org.apache.spark.ml.linalg.Vector]].
|
||||
*
|
||||
* If the point is sparse, the dimensionality of the resulting sparse
|
||||
* vector would be [[Int.MaxValue]]. This is the only safe value, since
|
||||
* XGBoost does not store the dimensionality explicitly.
|
||||
*/
|
||||
def features: Vector = if (labeledPoint.indices == null) {
|
||||
Vectors.dense(labeledPoint.values.map(_.toDouble))
|
||||
} else {
|
||||
Vectors.sparse(Int.MaxValue, labeledPoint.indices, labeledPoint.values.map(_.toDouble))
|
||||
Vectors.sparse(labeledPoint.size, labeledPoint.indices, labeledPoint.values.map(_.toDouble))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -68,9 +64,9 @@ object DataUtils extends Serializable {
|
||||
*/
|
||||
def asXGB: XGBLabeledPoint = v match {
|
||||
case v: DenseVector =>
|
||||
XGBLabeledPoint(0.0f, null, v.values.map(_.toFloat))
|
||||
XGBLabeledPoint(0.0f, v.size, null, v.values.map(_.toFloat))
|
||||
case v: SparseVector =>
|
||||
XGBLabeledPoint(0.0f, v.indices, v.values.map(_.toFloat))
|
||||
XGBLabeledPoint(0.0f, v.size, v.indices, v.values.map(_.toFloat))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -162,18 +158,18 @@ object DataUtils extends Serializable {
|
||||
df => df.select(selectedColumns: _*).rdd.map {
|
||||
case row @ Row(label: Float, features: Vector, weight: Float, group: Int,
|
||||
baseMargin: Float) =>
|
||||
val (indices, values) = features match {
|
||||
case v: SparseVector => (v.indices, v.values.map(_.toFloat))
|
||||
case v: DenseVector => (null, v.values.map(_.toFloat))
|
||||
val (size, indices, values) = features match {
|
||||
case v: SparseVector => (v.size, v.indices, v.values.map(_.toFloat))
|
||||
case v: DenseVector => (v.size, null, v.values.map(_.toFloat))
|
||||
}
|
||||
val xgbLp = XGBLabeledPoint(label, indices, values, weight, group, baseMargin)
|
||||
val xgbLp = XGBLabeledPoint(label, size, indices, values, weight, group, baseMargin)
|
||||
attachPartitionKey(row, deterministicPartition, numWorkers, xgbLp)
|
||||
case row @ Row(label: Float, features: Vector, weight: Float, baseMargin: Float) =>
|
||||
val (indices, values) = features match {
|
||||
case v: SparseVector => (v.indices, v.values.map(_.toFloat))
|
||||
case v: DenseVector => (null, v.values.map(_.toFloat))
|
||||
val (size, indices, values) = features match {
|
||||
case v: SparseVector => (v.size, v.indices, v.values.map(_.toFloat))
|
||||
case v: DenseVector => (v.size, null, v.values.map(_.toFloat))
|
||||
}
|
||||
val xgbLp = XGBLabeledPoint(label, indices, values, weight, baseMargin = baseMargin)
|
||||
val xgbLp = XGBLabeledPoint(label, size, indices, values, weight, baseMargin = baseMargin)
|
||||
attachPartitionKey(row, deterministicPartition, numWorkers, xgbLp)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user