[BLOCKING][jvm-packages] fix non-deterministic order within a partition (in the case of an upstream shuffle) on prediction (#4388)
* [jvm-packages][hot-fix] fix column mismatch caused by zip actions at XGBoostModel.transformInternal
* apply minibatch in prediction
* an iterator-compatible minibatch prediction
* regressor impl
* continued work on mini-batch prediction of xgboost4j-spark
* Update Booster.java
This commit is contained in:
Parent: 503cc42f48
Commit: 2d875ec019
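Background on the fix: transformInternal previously built a separate prediction RDD from the feature column and then zipPartitions-ed it back with the input RDD. The two RDDs are evaluated independently, so when a partition's row order is not deterministic (for example after an upstream shuffle) a row can be zipped with another row's prediction. The patch below keeps everything in a single mapPartitions pass: each partition's row iterator is consumed in fixed-size batches (controlled by the new inferBatchSize parameter, default 32 << 10), each batch is turned into one DMatrix, predicted, and joined back to its own rows before the next batch is read. A minimal sketch of that pattern, assuming a hypothetical predictBatch helper that stands in for the DMatrix/Booster calls in the real patch:

// Minimal sketch only; `predictBatch` is a hypothetical stand-in for building a
// DMatrix from the batch and calling the broadcast Booster.
import org.apache.spark.sql.Row

def predictPartition(
    rows: Iterator[Row],
    batchSize: Int,
    predictBatch: Seq[Row] => Seq[Row]): Iterator[Row] = {
  // Rows and predictions stay aligned because each batch is predicted and
  // emitted before the next batch is pulled from the partition's iterator,
  // so no second evaluation of the RDD is ever needed.
  rows.grouped(batchSize).flatMap { batch =>
    batch.zip(predictBatch(batch)).map { case (row, pred) =>
      Row.fromSeq(row.toSeq ++ pred.toSeq)
    }
  }
}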
XGBoostClassifier.scala
@@ -16,30 +16,26 @@
 
 package ml.dmlc.xgboost4j.scala.spark
 
-import scala.collection.Iterator
-import scala.collection.JavaConverters._
-import scala.collection.mutable
-
 import ml.dmlc.xgboost4j.java.Rabit
-import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
-import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
 import ml.dmlc.xgboost4j.scala.spark.params._
+import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, EvalTrait, ObjectiveTrait, XGBoost => SXGBoost}
 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
 import org.apache.hadoop.fs.Path
 
 import org.apache.spark.TaskContext
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.ml.classification._
 import org.apache.spark.ml.linalg._
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared.HasWeightCol
 import org.apache.spark.ml.util._
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
-import org.apache.spark.sql._
 import org.json4s.DefaultFormats
 
-import org.apache.spark.broadcast.Broadcast
+import scala.collection.JavaConverters._
+import scala.collection.{AbstractIterator, Iterator, mutable}
 
 private[spark] trait XGBoostClassifierParams extends GeneralParams with LearningTaskParams
   with BoosterParams with HasWeightCol with HasBaseMarginCol with HasNumClass with ParamMapFuncs
@@ -216,7 +212,8 @@ class XGBoostClassificationModel private[ml](
     override val numClasses: Int,
     private[spark] val _booster: Booster)
   extends ProbabilisticClassificationModel[Vector, XGBoostClassificationModel]
-    with XGBoostClassifierParams with MLWritable with Serializable {
+    with XGBoostClassifierParams with InferenceParams
+    with MLWritable with Serializable {
 
   import XGBoostClassificationModel._
 
@@ -250,6 +247,8 @@
 
   def setTreeLimit(value: Int): this.type = set(treeLimit, value)
 
+  def setInferBatchSize(value: Int): this.type = set(inferBatchSize, value)
+
   /**
    * Single instance prediction.
    * Note: The performance is not ideal, use it carefully!
@@ -287,46 +286,53 @@
     val bBooster = dataset.sparkSession.sparkContext.broadcast(_booster)
     val appName = dataset.sparkSession.sparkContext.appName
 
-    val inputRDD = dataset.asInstanceOf[Dataset[Row]].rdd
-    val predictionRDD = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
-      if (rowIterator.hasNext) {
-        val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
-        Rabit.init(rabitEnv.asJava)
-        val featuresIterator = rowIterator.map(row => row.getAs[Vector](
-          $(featuresCol))).toList.iterator
-        import DataUtils._
-        val cacheInfo = {
-          if ($(useExternalMemory)) {
-            s"$appName-${TaskContext.get().stageId()}-dtest_cache-${TaskContext.getPartitionId()}"
-          } else {
-            null
-          }
-        }
-        val dm = new DMatrix(
-          XGBoost.processMissingValues(featuresIterator.map(_.asXGB), $(missing)),
-          cacheInfo)
-        try {
-          val Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr) =
-            producePredictionItrs(bBooster, dm)
-          Rabit.shutdown()
-          Iterator(rawPredictionItr, probabilityItr, predLeafItr,
-            predContribItr)
-        } finally {
-          dm.delete()
-        }
-      } else {
-        Iterator()
-      }
-    }
-    val resultRDD = inputRDD.zipPartitions(predictionRDD, preservesPartitioning = true) {
-      case (inputIterator, predictionItr) =>
-        if (inputIterator.hasNext) {
-          produceResultIterator(inputIterator, predictionItr.next(), predictionItr.next(),
-            predictionItr.next(), predictionItr.next())
-        } else {
-          Iterator()
-        }
-    }
+    val resultRDD = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
+      new AbstractIterator[Row] {
+        private var batchCnt = 0
+
+        private val batchIterImpl = rowIterator.grouped($(inferBatchSize)).flatMap { batchRow =>
+          if (batchCnt == 0) {
+            val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
+            Rabit.init(rabitEnv.asJava)
+          }
+
+          val features = batchRow.iterator.map(row => row.getAs[Vector]($(featuresCol)))
+
+          import DataUtils._
+          val cacheInfo = {
+            if ($(useExternalMemory)) {
+              s"$appName-${TaskContext.get().stageId()}-dtest_cache-" +
+                s"${TaskContext.getPartitionId()}-batch-$batchCnt"
+            } else {
+              null
+            }
+          }
+
+          val dm = new DMatrix(
+            XGBoost.processMissingValues(features.map(_.asXGB), $(missing)),
+            cacheInfo)
+          try {
+            val Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr) =
+              producePredictionItrs(bBooster, dm)
+            produceResultIterator(batchRow.iterator,
+              rawPredictionItr, probabilityItr, predLeafItr, predContribItr)
+          } finally {
+            batchCnt += 1
+            dm.delete()
+          }
+        }
+
+        override def hasNext: Boolean = batchIterImpl.hasNext
+
+        override def next(): Row = {
+          val ret = batchIterImpl.next()
+          if (!batchIterImpl.hasNext) {
+            Rabit.shutdown()
+          }
+          ret
+        }
+      }
+    }
 
     bBooster.unpersist(blocking = false)
     dataset.sparkSession.createDataFrame(resultRDD, generateResultSchema(schema))
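Note on the Rabit lifecycle in the new code path above: Rabit.init runs only when the first batch is processed (batchCnt == 0), and Rabit.shutdown is deferred into next() until batchIterImpl is exhausted, because the iterator returned from mapPartitions is consumed lazily by the downstream stage rather than inside this block.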
@@ -527,4 +533,3 @@ object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel]
     }
   }
 }
-
XGBoostRegressor.scala
@@ -16,10 +16,10 @@
 
 package ml.dmlc.xgboost4j.scala.spark
 
-import scala.collection.Iterator
+import scala.collection.{AbstractIterator, Iterator, mutable}
 import scala.collection.JavaConverters._
 
-import ml.dmlc.xgboost4j.java.Rabit
+import ml.dmlc.xgboost4j.java.{Rabit, XGBoost => JXGBoost}
 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
 import ml.dmlc.xgboost4j.scala.spark.params.{DefaultXGBoostParamsReader, _}
 import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
@@ -37,7 +37,7 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.json4s.DefaultFormats
-import scala.collection.mutable
+import scala.collection.mutable.ListBuffer
 
 import org.apache.spark.broadcast.Broadcast
 
@@ -207,7 +207,8 @@ class XGBoostRegressionModel private[ml] (
     override val uid: String,
     private[spark] val _booster: Booster)
   extends PredictionModel[Vector, XGBoostRegressionModel]
-    with XGBoostRegressorParams with MLWritable with Serializable {
+    with XGBoostRegressorParams with InferenceParams
+    with MLWritable with Serializable {
 
   import XGBoostRegressionModel._
 
@@ -241,6 +242,8 @@ class XGBoostRegressionModel private[ml] (
 
   def setTreeLimit(value: Int): this.type = set(treeLimit, value)
 
+  def setInferBatchSize(value: Int): this.type = set(inferBatchSize, value)
+
   /**
    * Single instance prediction.
    * Note: The performance is not ideal, use it carefully!
@@ -259,45 +262,53 @@ class XGBoostRegressionModel private[ml] (
 
     val bBooster = dataset.sparkSession.sparkContext.broadcast(_booster)
     val appName = dataset.sparkSession.sparkContext.appName
-    val inputRDD = dataset.asInstanceOf[Dataset[Row]].rdd
-    val predictionRDD = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
-      if (rowIterator.hasNext) {
-        val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
-        Rabit.init(rabitEnv.asJava)
-        val featuresIterator = rowIterator.map(row => row.getAs[Vector](
-          $(featuresCol))).toList.iterator
-        import DataUtils._
-        val cacheInfo = {
-          if ($(useExternalMemory)) {
-            s"$appName-${TaskContext.get().stageId()}-dtest_cache-${TaskContext.getPartitionId()}"
-          } else {
-            null
-          }
-        }
-        val dm = new DMatrix(
-          XGBoost.processMissingValues(featuresIterator.map(_.asXGB), $(missing)),
-          cacheInfo)
-        try {
-          val Array(originalPredictionItr, predLeafItr, predContribItr) =
-            producePredictionItrs(bBooster, dm)
-          Rabit.shutdown()
-          Iterator(originalPredictionItr, predLeafItr, predContribItr)
-        } finally {
-          dm.delete()
-        }
-      } else {
-        Iterator()
-      }
-    }
-    val resultRDD = inputRDD.zipPartitions(predictionRDD, preservesPartitioning = true) {
-      case (inputIterator, predictionItr) =>
-        if (inputIterator.hasNext) {
-          produceResultIterator(inputIterator, predictionItr.next(), predictionItr.next(),
-            predictionItr.next())
-        } else {
-          Iterator()
-        }
-    }
+    val resultRDD = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
+      new AbstractIterator[Row] {
+        private var batchCnt = 0
+
+        private val batchIterImpl = rowIterator.grouped($(inferBatchSize)).flatMap { batchRow =>
+          if (batchCnt == 0) {
+            val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
+            Rabit.init(rabitEnv.asJava)
+          }
+
+          val features = batchRow.iterator.map(row => row.getAs[Vector]($(featuresCol)))
+
+          import DataUtils._
+          val cacheInfo = {
+            if ($(useExternalMemory)) {
+              s"$appName-${TaskContext.get().stageId()}-dtest_cache-" +
+                s"${TaskContext.getPartitionId()}-batch-$batchCnt"
+            } else {
+              null
+            }
+          }
+
+          val dm = new DMatrix(
+            XGBoost.processMissingValues(features.map(_.asXGB), $(missing)),
+            cacheInfo)
+          try {
+            val Array(rawPredictionItr, predLeafItr, predContribItr) =
+              producePredictionItrs(bBooster, dm)
+            produceResultIterator(batchRow.iterator, rawPredictionItr, predLeafItr, predContribItr)
+          } finally {
+            batchCnt += 1
+            dm.delete()
+          }
+        }
+
+        override def hasNext: Boolean = batchIterImpl.hasNext
+
+        override def next(): Row = {
+          val ret = batchIterImpl.next()
+          if (!batchIterImpl.hasNext) {
+            Rabit.shutdown()
+          }
+          ret
+        }
+      }
+    }
+
     bBooster.unpersist(blocking = false)
     dataset.sparkSession.createDataFrame(resultRDD, generateResultSchema(schema))
   }
@@ -347,14 +358,14 @@ class XGBoostRegressionModel private[ml] (
     resultSchema
   }
 
-  private def producePredictionItrs(broadcastBooster: Broadcast[Booster], dm: DMatrix):
+  private def producePredictionItrs(booster: Broadcast[Booster], dm: DMatrix):
       Array[Iterator[Row]] = {
     val originalPredictionItr = {
-      broadcastBooster.value.predict(dm, outPutMargin = false, $(treeLimit)).map(Row(_)).iterator
+      booster.value.predict(dm, outPutMargin = false, $(treeLimit)).map(Row(_)).iterator
     }
     val predLeafItr = {
       if (isDefined(leafPredictionCol)) {
-        broadcastBooster.value.predictLeaf(dm, $(treeLimit)).
+        booster.value.predictLeaf(dm, $(treeLimit)).
           map(Row(_)).iterator
       } else {
         Iterator()
@@ -362,7 +373,7 @@ class XGBoostRegressionModel private[ml] (
     }
     val predContribItr = {
       if (isDefined(contribPredictionCol)) {
-        broadcastBooster.value.predictContrib(dm, $(treeLimit)).
+        booster.value.predictContrib(dm, $(treeLimit)).
           map(Row(_)).iterator
       } else {
         Iterator()
@@ -373,7 +384,6 @@ class XGBoostRegressionModel private[ml] (
 
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
-
     // Output selected columns only.
     // This is a bit complicated since it tries to avoid repeated computation.
     var outputData = transformInternal(dataset)
InferenceParams.scala (new file)
@@ -0,0 +1,32 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark.params
+
+import org.apache.spark.ml.param.{IntParam, Params}
+
+private[spark] trait InferenceParams extends Params {
+
+  /**
+   * batch size of inference iteration
+   */
+  final val inferBatchSize = new IntParam(this, "batchSize", "batch size of inference iteration")
+
+  /** @group getParam */
+  final def getInferBatchSize: Int = $(inferBatchSize)
+
+  setDefault(inferBatchSize, 32 << 10)
+}
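The parameter is exposed on both models through setInferBatchSize/getInferBatchSize. It only bounds how many rows go into each DMatrix built inside a partition, trading memory per prediction task against the number of prediction calls without changing the results (the new "infer with different batch sizes" test in XGBoostGeneralSuite below asserts this). An illustrative usage sketch, where `model` and `testDF` are placeholders for any fitted model and scoring DataFrame:

// Illustrative only: `model` is a fitted XGBoostClassificationModel or
// XGBoostRegressionModel and `testDF` a DataFrame with the configured features column.
val predictions = model
  .setInferBatchSize(8192) // rows per DMatrix built inside each partition
  .transform(testDF)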
PerTest.scala
@@ -19,11 +19,12 @@ package ml.dmlc.xgboost4j.scala.spark
 import java.io.File
 
 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
 
 import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.sql._
 import org.scalatest.{BeforeAndAfterEach, FunSuite}
 
+import scala.util.Random
+
 trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
 
   protected val numWorkers: Int = Runtime.getRuntime.availableProcessors()
@@ -80,6 +81,18 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
       .toDF("id", "label", "features")
   }
 
+  protected def buildDataFrameWithRandSort(
+      labeledPoints: Seq[XGBLabeledPoint],
+      numPartitions: Int = numWorkers): DataFrame = {
+    val df = buildDataFrame(labeledPoints, numPartitions)
+    val rndSortedRDD = df.rdd.mapPartitions { iter =>
+      iter.map(_ -> Random.nextDouble()).toList
+        .sortBy(_._2)
+        .map(_._1).iterator
+    }
+    ss.createDataFrame(rndSortedRDD, df.schema)
+  }
+
   protected def buildDataFrameWithGroup(
       labeledPoints: Seq[XGBLabeledPoint],
       numPartitions: Int = numWorkers): DataFrame = {
XGBoostClassifierSuite.scala
@@ -27,13 +27,28 @@ import org.apache.spark.Partitioner
 
 class XGBoostClassifierSuite extends FunSuite with PerTest {
 
-  test("XGBoost-Spark XGBoostClassifier ouput should match XGBoost4j") {
+  test("XGBoost-Spark XGBoostClassifier output should match XGBoost4j") {
     val trainingDM = new DMatrix(Classification.train.iterator)
     val testDM = new DMatrix(Classification.test.iterator)
     val trainingDF = buildDataFrame(Classification.train)
     val testDF = buildDataFrame(Classification.test)
-    val round = 5
+    checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF)
+  }
+
+  test("XGBoostClassifier should make correct predictions after upstream random sort") {
+    val trainingDM = new DMatrix(Classification.train.iterator)
+    val testDM = new DMatrix(Classification.test.iterator)
+    val trainingDF = buildDataFrameWithRandSort(Classification.train)
+    val testDF = buildDataFrameWithRandSort(Classification.test)
+    checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF)
+  }
+
+  private def checkResultsWithXGBoost4j(
+      trainingDM: DMatrix,
+      testDM: DMatrix,
+      trainingDF: DataFrame,
+      testDF: DataFrame,
+      round: Int = 5): Unit = {
     val paramMap = Map(
       "eta" -> "1",
       "max_depth" -> "6",
@@ -47,7 +62,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
       "num_workers" -> numWorkers)).fit(trainingDF)
 
     val prediction2 = model2.transform(testDF).
       collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("probability"))).toMap
 
     assert(testDF.count() === prediction2.size)
     // the vector length in probability column is 2 since we have to fit to the evaluator in Spark
@@ -60,7 +75,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
 
     val prediction3 = model1.predict(testDM, outPutMargin = true)
     val prediction4 = model2.transform(testDF).
       collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("rawPrediction"))).toMap
 
     assert(testDF.count() === prediction4.size)
     // the vector length in rawPrediction column is 2 since we have to fit to the evaluator in Spark
@@ -73,7 +88,9 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
 
     // check the equality of single instance prediction
     val firstOfDM = testDM.slice(Array(0))
-    val firstOfDF = testDF.head().getAs[Vector]("features")
+    val firstOfDF = testDF.filter(_.getAs[Int]("id") == 0)
+      .head()
+      .getAs[Vector]("features")
     val prediction5 = math.round(model1.predict(firstOfDM)(0)(0))
     val prediction6 = model2.predict(firstOfDF)
     assert(prediction5 === prediction6)
XGBoostGeneralSuite.scala
@@ -463,4 +463,42 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
     assert(model2.summary.trainObjectiveHistory !== model2.summary.validationObjectiveHistory(0))
     assert(model2.summary.trainObjectiveHistory !== model2.summary.validationObjectiveHistory(1))
   }
+
+  test("infer with different batch sizes") {
+    val regModel = new XGBoostRegressor(Map(
+      "eta" -> "1",
+      "max_depth" -> "6",
+      "silent" -> "1",
+      "objective" -> "reg:squarederror",
+      "num_round" -> 5,
+      "num_workers" -> numWorkers))
+      .fit(buildDataFrame(Regression.train))
+    val regDF = buildDataFrame(Regression.test)
+
+    val regRet1 = regModel.transform(regDF).collect()
+    val regRet2 = regModel.setInferBatchSize(1).transform(regDF).collect()
+    val regRet3 = regModel.setInferBatchSize(10).transform(regDF).collect()
+    val regRet4 = regModel.setInferBatchSize(32 << 15).transform(regDF).collect()
+    assert(regRet1 sameElements regRet2)
+    assert(regRet1 sameElements regRet3)
+    assert(regRet1 sameElements regRet4)
+
+    val clsModel = new XGBoostClassifier(Map(
+      "eta" -> "1",
+      "max_depth" -> "6",
+      "silent" -> "1",
+      "objective" -> "binary:logistic",
+      "num_round" -> 5,
+      "num_workers" -> numWorkers))
+      .fit(buildDataFrame(Classification.train))
+    val clsDF = buildDataFrame(Classification.test)
+
+    val clsRet1 = clsModel.transform(clsDF).collect()
+    val clsRet2 = clsModel.setInferBatchSize(1).transform(clsDF).collect()
+    val clsRet3 = clsModel.setInferBatchSize(10).transform(clsDF).collect()
+    val clsRet4 = clsModel.setInferBatchSize(32 << 15).transform(clsDF).collect()
+    assert(clsRet1 sameElements clsRet2)
+    assert(clsRet1 sameElements clsRet3)
+    assert(clsRet1 sameElements clsRet4)
+  }
 }
XGBoostRegressorSuite.scala
@@ -19,19 +19,34 @@ package ml.dmlc.xgboost4j.scala.spark
 import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.types._
 import org.scalatest.FunSuite
 
 class XGBoostRegressorSuite extends FunSuite with PerTest {
 
-  test("XGBoost-Spark XGBoostRegressor ouput should match XGBoost4j: regression") {
+  test("XGBoost-Spark XGBoostRegressor output should match XGBoost4j") {
     val trainingDM = new DMatrix(Regression.train.iterator)
     val testDM = new DMatrix(Regression.test.iterator)
     val trainingDF = buildDataFrame(Regression.train)
     val testDF = buildDataFrame(Regression.test)
-    val round = 5
+    checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF)
+  }
+
+  test("XGBoostRegressor should make correct predictions after upstream random sort") {
+    val trainingDM = new DMatrix(Regression.train.iterator)
+    val testDM = new DMatrix(Regression.test.iterator)
+    val trainingDF = buildDataFrameWithRandSort(Regression.train)
+    val testDF = buildDataFrameWithRandSort(Regression.test)
+    checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF)
+  }
+
+  private def checkResultsWithXGBoost4j(
+      trainingDM: DMatrix,
+      testDM: DMatrix,
+      trainingDF: DataFrame,
+      testDF: DataFrame,
+      round: Int = 5): Unit = {
     val paramMap = Map(
       "eta" -> "1",
       "max_depth" -> "6",
@@ -45,7 +60,7 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
       "num_workers" -> numWorkers)).fit(trainingDF)
 
     val prediction2 = model2.transform(testDF).
       collect().map(row => (row.getAs[Int]("id"), row.getAs[Double]("prediction"))).toMap
 
     assert(prediction1.indices.count { i =>
       math.abs(prediction1(i)(0) - prediction2(i)) > 0.01
@@ -54,7 +69,9 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
 
     // check the equality of single instance prediction
     val firstOfDM = testDM.slice(Array(0))
-    val firstOfDF = testDF.head().getAs[Vector]("features")
+    val firstOfDF = testDF.filter(_.getAs[Int]("id") == 0)
+      .head()
+      .getAs[Vector]("features")
     val prediction3 = model1.predict(firstOfDM)(0)(0)
     val prediction4 = model2.predict(firstOfDF)
     assert(math.abs(prediction3 - prediction4) <= 0.01f)