[jvm-packages] Parameter tuning tool for XGBoost (#1664)
parent ac41845d4b
commit 016ab89484
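In short, this PR maps XGBoost's training parameters onto Spark ML Params of a new XGBoostEstimator, so a model can be tuned with Spark's built-in model-selection tools. A minimal sketch of the workflow this enables (trainingDF and the column names are placeholders; the full Rossmann example is in the diff below):

  import ml.dmlc.xgboost4j.scala.spark.XGBoostEstimator
  import org.apache.spark.ml.evaluation.RegressionEvaluator
  import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}

  // XGBoost parameters are passed as a plain Map and mirrored into Spark ML Params,
  // so they can be swept with ParamGridBuilder like any other estimator parameter.
  val xgb = new XGBoostEstimator(Map("objective" -> "reg:linear"))
    .setFeaturesCol("features").setLabelCol("label")
  val grid = new ParamGridBuilder()
    .addGrid(xgb.round, Array(20, 50))   // Params exposed by this PR
    .addGrid(xgb.eta, Array(0.1, 0.4))
    .build()
  val tuner = new TrainValidationSplit()
    .setEstimator(xgb)
    .setEvaluator(new RegressionEvaluator().setLabelCol("label"))
    .setEstimatorParamMaps(grid)
    .setTrainRatio(0.8)
  // val bestModel = tuner.fit(trainingDF)  // trainingDF: a DataFrame with "features"/"label"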
@@ -217,7 +217,7 @@
    <dependency>
      <groupId>org.scalatest</groupId>
      <artifactId>scalatest_${scala.binary.version}</artifactId>
-      <version>2.2.6</version>
+      <version>3.0.0</version>
      <scope>test</scope>
    </dependency>
  </dependencies>
@@ -0,0 +1,206 @@
/*
 Copyright (c) 2014 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.example.spark

import scala.collection.mutable
import scala.collection.mutable.ListBuffer
import scala.io.Source

import ml.dmlc.xgboost4j.scala.spark.{XGBoostEstimator, XGBoost}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.feature.{VectorAssembler, StringIndexer}
import org.apache.spark.ml.tuning._
import org.apache.spark.sql.{Dataset, DataFrame, SparkSession}

case class SalesRecord(storeId: Int, daysOfWeek: Int, date: String, sales: Int, customers: Int,
    open: Int, promo: Int, stateHoliday: String, schoolHoliday: String)

case class Store(storeId: Int, storeType: String, assortment: String, competitionDistance: Int,
    competitionOpenSinceMonth: Int, competitionOpenSinceYear: Int, promo2: Int,
    promo2SinceWeek: Int, promo2SinceYear: Int, promoInterval: String)

object Main {

  private def parseStoreFile(storeFilePath: String): List[Store] = {
    var isHeader = true
    val storeInstances = new ListBuffer[Store]
    for (line <- Source.fromFile(storeFilePath).getLines()) {
      if (isHeader) {
        isHeader = false
      } else {
        try {
          val strArray = line.split(",")
          if (strArray.length == 10) {
            val Array(storeIdStr, storeTypeStr, assortmentStr, competitionDistanceStr,
              competitionOpenSinceMonthStr, competitionOpenSinceYearStr, promo2Str,
              promo2SinceWeekStr, promo2SinceYearStr, promoIntervalStr) = line.split(",")
            storeInstances += Store(storeIdStr.toInt, storeTypeStr, assortmentStr,
              if (competitionDistanceStr == "") -1 else competitionDistanceStr.toInt,
              if (competitionOpenSinceMonthStr == "") -1 else competitionOpenSinceMonthStr.toInt,
              if (competitionOpenSinceYearStr == "") -1 else competitionOpenSinceYearStr.toInt,
              promo2Str.toInt,
              if (promo2Str == "0") -1 else promo2SinceWeekStr.toInt,
              if (promo2Str == "0") -1 else promo2SinceYearStr.toInt,
              promoIntervalStr.replace("\"", ""))
          } else {
            val Array(storeIdStr, storeTypeStr, assortmentStr, competitionDistanceStr,
              competitionOpenSinceMonthStr, competitionOpenSinceYearStr, promo2Str,
              promo2SinceWeekStr, promo2SinceYearStr, firstMonth, secondMonth, thirdMonth,
              forthMonth) = line.split(",")
            storeInstances += Store(storeIdStr.toInt, storeTypeStr, assortmentStr,
              if (competitionDistanceStr == "") -1 else competitionDistanceStr.toInt,
              if (competitionOpenSinceMonthStr == "") -1 else competitionOpenSinceMonthStr.toInt,
              if (competitionOpenSinceYearStr == "") -1 else competitionOpenSinceYearStr.toInt,
              promo2Str.toInt,
              if (promo2Str == "0") -1 else promo2SinceWeekStr.toInt,
              if (promo2Str == "0") -1 else promo2SinceYearStr.toInt,
              firstMonth.replace("\"", "") + "," + secondMonth + "," + thirdMonth + "," +
                forthMonth.replace("\"", ""))
          }
        } catch {
          case e: Exception =>
            e.printStackTrace()
            sys.exit(1)
        }
      }
    }
    storeInstances.toList
  }

  private def parseTrainingFile(trainingPath: String): List[SalesRecord] = {
    var isHeader = true
    val records = new ListBuffer[SalesRecord]
    for (line <- Source.fromFile(trainingPath).getLines()) {
      if (isHeader) {
        isHeader = false
      } else {
        val Array(storeIdStr, daysOfWeekStr, dateStr, salesStr, customerStr, openStr, promoStr,
          stateHolidayStr, schoolHolidayStr) = line.split(",")
        val salesRecord = SalesRecord(storeIdStr.toInt, daysOfWeekStr.toInt, dateStr,
          salesStr.toInt, customerStr.toInt, openStr.toInt, promoStr.toInt, stateHolidayStr,
          schoolHolidayStr)
        records += salesRecord
      }
    }
    records.toList
  }

  private def featureEngineering(ds: DataFrame): DataFrame = {
    import org.apache.spark.sql.functions._
    import ds.sparkSession.implicits._
    val stateHolidayIndexer = new StringIndexer()
      .setInputCol("stateHoliday")
      .setOutputCol("stateHolidayIndex")
    val schoolHolidayIndexer = new StringIndexer()
      .setInputCol("schoolHoliday")
      .setOutputCol("schoolHolidayIndex")
    val storeTypeIndexer = new StringIndexer()
      .setInputCol("storeType")
      .setOutputCol("storeTypeIndex")
    val assortmentIndexer = new StringIndexer()
      .setInputCol("assortment")
      .setOutputCol("assortmentIndex")
    val promoInterval = new StringIndexer()
      .setInputCol("promoInterval")
      .setOutputCol("promoIntervalIndex")
    val filteredDS = ds.filter($"sales" > 0).filter($"open" > 0)
    // parse date
    val dsWithDayCol =
      filteredDS.withColumn("day", udf((dateStr: String) =>
        dateStr.split("-")(2).toInt).apply(col("date")))
    val dsWithMonthCol =
      dsWithDayCol.withColumn("month", udf((dateStr: String) =>
        dateStr.split("-")(1).toInt).apply(col("date")))
    val dsWithYearCol =
      dsWithMonthCol.withColumn("year", udf((dateStr: String) =>
        dateStr.split("-")(0).toInt).apply(col("date")))
    val dsWithLogSales = dsWithYearCol.withColumn("logSales",
      udf((sales: Int) => math.log(sales)).apply(col("sales")))

    // fill with mean values
    val meanCompetitionDistance = dsWithLogSales.select(avg("competitionDistance")).first()(0).
      asInstanceOf[Double]
    println("====" + meanCompetitionDistance)
    val finalDS = dsWithLogSales.withColumn("transformedCompetitionDistance",
      udf((distance: Int) => if (distance > 0) distance.toDouble else meanCompetitionDistance).
        apply(col("competitionDistance")))

    val vectorAssembler = new VectorAssembler()
      .setInputCols(Array("storeId", "daysOfWeek", "promo", "competitionDistance", "promo2", "day",
        "month", "year", "transformedCompetitionDistance", "stateHolidayIndex",
        "schoolHolidayIndex", "storeTypeIndex", "assortmentIndex", "promoIntervalIndex"))
      .setOutputCol("features")

    val pipeline = new Pipeline().setStages(
      Array(stateHolidayIndexer, schoolHolidayIndexer, storeTypeIndexer, assortmentIndexer,
        promoInterval, vectorAssembler))

    pipeline.fit(finalDS).transform(finalDS).
      drop("stateHoliday", "schoolHoliday", "storeType", "assortment", "promoInterval", "sales",
        "promo2SinceWeek", "customers", "promoInterval", "competitionOpenSinceYear",
        "competitionOpenSinceMonth", "promo2SinceYear", "competitionDistance", "date")
  }

  private def crossValidation(
      xgboostParam: Map[String, Any],
      trainingData: Dataset[_]): TrainValidationSplitModel = {
    val xgbEstimator = new XGBoostEstimator(xgboostParam).setFeaturesCol("features").
      setLabelCol("logSales")
    val paramGrid = new ParamGridBuilder()
      .addGrid(xgbEstimator.round, Array(20, 50))
      .addGrid(xgbEstimator.eta, Array(0.1, 0.4))
      .build()
    val tv = new TrainValidationSplit()
      .setEstimator(xgbEstimator)
      .setEvaluator(new RegressionEvaluator().setLabelCol("logSales"))
      .setEstimatorParamMaps(paramGrid)
      .setTrainRatio(0.8) // hold out 20% of the training data for validation
    tv.fit(trainingData)
  }

  def main(args: Array[String]): Unit = {
    val sparkSession = SparkSession.builder().appName("rosseman").getOrCreate()
    import sparkSession.implicits._

    // parse training file to data frame
    val trainingPath = args(0)
    val allSalesRecords = parseTrainingFile(trainingPath)
    // create dataset
    val salesRecordsDF = allSalesRecords.toDF

    // parse store file to data frame
    val storeFilePath = args(1)
    val allStores = parseStoreFile(storeFilePath)
    val storesDS = allStores.toDF()

    val fullDataset = salesRecordsDF.join(storesDS, "storeId")
    val featureEngineeredDF = featureEngineering(fullDataset)
    // prediction
    val params = new mutable.HashMap[String, Any]()
    params += "eta" -> 0.1
    params += "max_depth" -> 6
    params += "silent" -> 1
    params += "ntreelimit" -> 1000
    params += "objective" -> "reg:linear"
    params += "subsample" -> 0.8
    params += "round" -> 100

    val bestModel = crossValidation(params.toMap, featureEngineeredDF)
  }
}
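Presumably this example targets the Kaggle Rossmann Store Sales files, whose columns match the two parsers above (args(0) is the sales CSV, args(1) the store CSV); a hypothetical invocation:

  spark-submit --class ml.dmlc.xgboost4j.scala.example.spark.Main \
    <xgboost4j-example jar> <path/to/train.csv> <path/to/store.csv>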
@@ -20,39 +20,38 @@ import ml.dmlc.xgboost4j.scala.Booster
import ml.dmlc.xgboost4j.scala.spark.{XGBoost, DataUtils}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.types._
-import org.apache.spark.sql.{SQLContext, Row}
+import org.apache.spark.sql.{SparkSession, SQLContext, Row}
import org.apache.spark.{SparkContext, SparkConf}

object SparkWithDataFrame {
  def main(args: Array[String]): Unit = {
-    if (args.length != 5) {
+    if (args.length != 4) {
      println(
-        "usage: program num_of_rounds num_workers training_path test_path model_path")
+        "usage: program num_of_rounds num_workers training_path test_path")
      sys.exit(1)
    }
    // create SparkSession
    val sparkConf = new SparkConf().setAppName("XGBoost-spark-example")
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    sparkConf.registerKryoClasses(Array(classOf[Booster]))
-    val sqlContext = new SQLContext(new SparkContext(sparkConf))
+    // val sqlContext = new SQLContext(new SparkContext(sparkConf))
+    val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate()
    // create training and testing dataframes
+    val numRound = args(0).toInt
    val inputTrainPath = args(2)
    val inputTestPath = args(3)
-    val outputModelPath = args(4)
-    // number of iterations
-    val numRound = args(0).toInt
    import DataUtils._
-    val trainRDDOfRows = MLUtils.loadLibSVMFile(sqlContext.sparkContext, inputTrainPath).
+    // build dataset
+    val trainRDDOfRows = MLUtils.loadLibSVMFile(sparkSession.sparkContext, inputTrainPath).
      map{ labeledPoint => Row(labeledPoint.features, labeledPoint.label)}
-    val trainDF = sqlContext.createDataFrame(trainRDDOfRows, StructType(
+    val trainDF = sparkSession.createDataFrame(trainRDDOfRows, StructType(
      Array(StructField("features", ArrayType(FloatType)), StructField("label", IntegerType))))
-    val testRDDOfRows = MLUtils.loadLibSVMFile(sqlContext.sparkContext, inputTestPath).
+    val testRDDOfRows = MLUtils.loadLibSVMFile(sparkSession.sparkContext, inputTestPath).
      zipWithIndex().map{ case (labeledPoint, id) =>
        Row(id, labeledPoint.features, labeledPoint.label)}
-    val testDF = sqlContext.createDataFrame(testRDDOfRows, StructType(
+    val testDF = sparkSession.createDataFrame(testRDDOfRows, StructType(
      Array(StructField("id", LongType),
        StructField("features", ArrayType(FloatType)), StructField("label", IntegerType))))
-    // training parameters
+    // start training
    val paramMap = List(
      "eta" -> 0.1f,
      "max_depth" -> 2,
@@ -49,7 +49,7 @@ object SparkWithRDD {
      "eta" -> 0.1f,
      "max_depth" -> 2,
      "objective" -> "binary:logistic").toMap
-    val xgboostModel = XGBoost.trainWithRDD(trainRDD, paramMap, numRound, nWorkers = args(1).toInt,
+    val xgboostModel = XGBoost.train(trainRDD, paramMap, numRound, nWorkers = args(1).toInt,
      useExternalMemory = true)
    xgboostModel.booster.predict(new DMatrix(testSet))
    // save model to HDFS path
@@ -25,7 +25,7 @@ import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
import org.apache.commons.logging.LogFactory
import org.apache.hadoop.fs.{FSDataInputStream, Path}
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
-import org.apache.spark.ml.linalg.{DenseVector, SparseVector}
+import org.apache.spark.ml.linalg.SparseVector
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Dataset
import org.apache.spark.{SparkContext, TaskContext}
@@ -67,14 +67,8 @@ object XGBoost extends Serializable {
    }
  }

-  private[spark] def buildDistributedBoosters(
-      trainingData: RDD[MLLabeledPoint],
-      xgBoostConfMap: Map[String, Any],
-      rabitEnv: mutable.Map[String, String],
-      numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait,
-      useExternalMemory: Boolean, missing: Float = Float.NaN): RDD[Booster] = {
-    import DataUtils._
-    val partitionedData = {
+  private def repartitionData(trainingData: RDD[MLLabeledPoint], numWorkers: Int):
+      RDD[MLLabeledPoint] = {
    if (numWorkers != trainingData.partitions.length) {
      logger.info(s"repartitioning training set to $numWorkers partitions")
      trainingData.repartition(numWorkers)
@@ -82,18 +76,27 @@ object XGBoost extends Serializable {
      trainingData
    }
  }
-  val appName = partitionedData.context.appName

+  private[spark] def buildDistributedBoosters(
+      trainingSet: RDD[MLLabeledPoint],
+      xgBoostConfMap: Map[String, Any],
+      rabitEnv: mutable.Map[String, String],
+      numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait,
+      useExternalMemory: Boolean, missing: Float = Float.NaN): RDD[Booster] = {
+    import DataUtils._
+    val partitionedTrainingSet = repartitionData(trainingSet, numWorkers)
+    val appName = partitionedTrainingSet.context.appName
    // to workaround the empty partitions in training dataset,
    // this might not be the most efficient implementation, see
    // (https://github.com/dmlc/xgboost/issues/1277)
-    partitionedData.mapPartitions {
+    partitionedTrainingSet.mapPartitions {
      trainingSamples =>
        rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
        Rabit.init(rabitEnv.asJava)
+        var booster: Booster = null
+        if (trainingSamples.hasNext) {
        val cacheFileName: String = {
-          if (useExternalMemory && trainingSamples.hasNext) {
+          if (useExternalMemory) {
            s"$appName-${TaskContext.get().stageId()}-" +
              s"dtrain_cache-${TaskContext.getPartitionId()}"
          } else {
@@ -146,14 +149,24 @@ object XGBoost extends Serializable {
      featureCol: String = "features",
      labelCol: String = "label"): XGBoostModel = {
    require(nWorkers > 0, "you must specify more than 0 workers")
-    val estimator = new XGBoostEstimator(params, round, nWorkers, obj, eval,
-      useExternalMemory, missing)
-    estimator.setFeaturesCol(featureCol).setLabelCol(labelCol).fit(trainingData)
+    val estimator = new XGBoostEstimator(params)
+    // assigning general parameters
+    estimator.
+      set(estimator.useExternalMemory, useExternalMemory).
+      set(estimator.round, round).
+      set(estimator.nWorkers, nWorkers).
+      set(estimator.customObj, obj).
+      set(estimator.customEval, eval).
+      set(estimator.missing, missing).
+      setFeaturesCol(featureCol).
+      setLabelCol(labelCol).
+      fit(trainingData)
  }

-  private[spark] def isClassificationTask(objective: Option[Any]): Boolean = {
-    objective.isDefined && {
-      val objStr = objective.get.toString
+  private[spark] def isClassificationTask(paramsMap: Map[String, Any]): Boolean = {
+    val objective = paramsMap.getOrElse("objective", paramsMap.getOrElse("obj_type", null))
+    objective != null && {
+      val objStr = objective.toString
      objStr == "classification" || (!objStr.startsWith("reg:") && objStr != "count:poisson" &&
        objStr != "rank:pairwise")
    }
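For reference, the rewritten predicate treats everything that is not explicitly regression, counting, or ranking as classification. Illustrative calls (isClassificationTask is private[spark]; shown only to spell out the logic above):

  XGBoost.isClassificationTask(Map("objective" -> "binary:logistic")) // true
  XGBoost.isClassificationTask(Map("objective" -> "multi:softmax"))   // true
  XGBoost.isClassificationTask(Map("objective" -> "reg:linear"))      // false: starts with "reg:"
  XGBoost.isClassificationTask(Map("objective" -> "count:poisson"))   // false
  XGBoost.isClassificationTask(Map("objective" -> "rank:pairwise"))   // false
  XGBoost.isClassificationTask(Map("obj_type" -> "classification"))   // true: custom objective
  XGBoost.isClassificationTask(Map())                                 // false: nothing defined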
@@ -162,7 +175,7 @@ object XGBoost extends Serializable {
  /**
   *
   * @param trainingData the trainingset represented as RDD
-   * @param configMap Map containing the configuration entries
+   * @param params Map containing the configuration entries
   * @param round the number of iterations
   * @param nWorkers the number of xgboost workers, 0 by default which means that the number of
   *                 workers equals to the partition number of trainingData RDD
@@ -174,19 +187,40 @@
   * @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training is failed
   * @return XGBoostModel when successful training
   */
+  @deprecated(since = "0.7", message = "this method is deprecated since 0.7, users are encouraged" +
+    " to switch to trainWithRDD")
-  def train(trainingData: RDD[MLLabeledPoint], configMap: Map[String, Any], round: Int,
+  def train(
+      trainingData: RDD[MLLabeledPoint], params: Map[String, Any], round: Int,
      nWorkers: Int, obj: ObjectiveTrait = null, eval: EvalTrait = null,
      useExternalMemory: Boolean = false, missing: Float = Float.NaN): XGBoostModel = {
    require(nWorkers > 0, "you must specify more than 0 workers")
-    trainWithRDD(trainingData, configMap, round, nWorkers, obj, eval, useExternalMemory, missing)
+    trainWithRDD(trainingData, params, round, nWorkers, obj, eval, useExternalMemory, missing)
  }

+  private def overrideParamMapAccordingtoTaskCPUs(
+      params: Map[String, Any],
+      sc: SparkContext): Map[String, Any] = {
+    val coresPerTask = sc.getConf.get("spark.task.cpus", "1").toInt
+    var overridedParams = params
+    if (overridedParams.contains("nthread")) {
+      val nThread = overridedParams("nthread").toString.toInt
+      require(nThread <= coresPerTask,
+        s"the nthread configuration ($nThread) must be no larger than " +
+          s"spark.task.cpus ($coresPerTask)")
+    } else {
+      overridedParams = params + ("nthread" -> coresPerTask)
+    }
+    overridedParams
+  }

+  private def startTracker(nWorkers: Int): RabitTracker = {
+    val tracker = new RabitTracker(nWorkers)
+    require(tracker.start(), "FAULT: Failed to start tracker")
+    tracker
+  }

  /**
   *
   * @param trainingData the trainingset represented as RDD
-   * @param configMap Map containing the configuration entries
+   * @param params Map containing the configuration entries
   * @param round the number of iterations
   * @param nWorkers the number of xgboost workers, 0 by default which means that the number of
   *                 workers equals to the partition number of trainingData RDD
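The extracted helper above centralizes a constraint worth noting: XGBoost's own "nthread" must not exceed the CPUs Spark allocates to each task. A configuration sketch under that rule (values illustrative, not from this commit):

  import org.apache.spark.SparkConf

  // Give each Spark task 4 cores, then let XGBoost use them all.
  val sparkConf = new SparkConf().set("spark.task.cpus", "4")
  // Omitting "nthread" lets overrideParamMapAccordingtoTaskCPUs fill in 4;
  // "nthread" -> 8 would trip the require() above when spark.task.cpus is 4.
  val params = Map("objective" -> "reg:linear", "nthread" -> 4)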
@@ -199,28 +233,18 @@ object XGBoost extends Serializable {
   * @return XGBoostModel when successful training
   */
  @throws(classOf[XGBoostError])
-  def trainWithRDD(trainingData: RDD[MLLabeledPoint], configMap: Map[String, Any], round: Int,
+  def trainWithRDD(
+      trainingData: RDD[MLLabeledPoint], params: Map[String, Any], round: Int,
      nWorkers: Int, obj: ObjectiveTrait = null, eval: EvalTrait = null,
      useExternalMemory: Boolean = false, missing: Float = Float.NaN): XGBoostModel = {
    require(nWorkers > 0, "you must specify more than 0 workers")
    if (obj != null) {
-      require(configMap.get("obj_type").isDefined, "parameter \"obj_type\" is not defined," +
+      require(params.get("obj_type").isDefined, "parameter \"obj_type\" is not defined," +
        " you have to specify the objective type as classification or regression with a" +
        " customized objective function")
    }
-    val tracker = new RabitTracker(nWorkers)
-    implicit val sc = trainingData.sparkContext
-    var overridedConfMap = configMap
-    if (overridedConfMap.contains("nthread")) {
-      val nThread = overridedConfMap("nthread").toString.toInt
-      val coresPerTask = sc.getConf.get("spark.task.cpus", "1").toInt
-      require(nThread <= coresPerTask,
-        s"the nthread configuration ($nThread) must be no larger than " +
-          s"spark.task.cpus ($coresPerTask)")
-    } else {
-      overridedConfMap = configMap + ("nthread" -> sc.getConf.get("spark.task.cpus", "1").toInt)
-    }
-    require(tracker.start(), "FAULT: Failed to start tracker")
+    val tracker = startTracker(nWorkers)
+    val overridedConfMap = overrideParamMapAccordingtoTaskCPUs(params, trainingData.sparkContext)
    val boosters = buildDistributedBoosters(trainingData, overridedConfMap,
      tracker.getWorkerEnvs.asScala, nWorkers, round, obj, eval, useExternalMemory, missing)
    val sparkJobThread = new Thread() {
@@ -230,16 +254,19 @@ object XGBoost extends Serializable {
      }
    }
    sparkJobThread.start()
-    val returnVal = tracker.waitFor()
-    logger.info(s"Rabit returns with exit code $returnVal")
-    if (returnVal == 0) {
-      convertBoosterToXGBoostModel(boosters.first(),
-        isClassificationTask(
-          if (obj == null) {
-            configMap.get("objective")
-          } else {
-            configMap.get("obj_type")
-          }))
+    val isClsTask = isClassificationTask(params)
+    val trackerReturnVal = tracker.waitFor()
+    logger.info(s"Rabit returns with exit code $trackerReturnVal")
+    postTrackerReturnProcessing(trackerReturnVal, boosters, overridedConfMap, sparkJobThread,
+      isClsTask)
  }

+  private def postTrackerReturnProcessing(
+      trackerReturnVal: Int, distributedBoosters: RDD[Booster],
+      configMap: Map[String, Any], sparkJobThread: Thread, isClassificationTask: Boolean):
+      XGBoostModel = {
+    if (trackerReturnVal == 0) {
+      convertBoosterToXGBoostModel(distributedBoosters.first(), isClassificationTask)
+    } else {
+      try {
+        if (sparkJobThread.isAlive) {
@@ -18,19 +18,22 @@ package ml.dmlc.xgboost4j.scala.spark

import scala.collection.mutable

-import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
-import org.apache.spark.ml.linalg.{Vector => MLVector, DenseVector => MLDenseVector}
+import ml.dmlc.xgboost4j.scala.Booster
+import org.apache.spark.ml.linalg.{DenseVector => MLDenseVector, Vector => MLVector}
import org.apache.spark.ml.param.{DoubleArrayParam, Param, ParamMap}
import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.{DataFrame, Dataset}

class XGBoostClassificationModel private[spark](
-    override val uid: String, _booster: Booster)
-  extends XGBoostModel(_booster) {
+    override val uid: String, booster: Booster)
+  extends XGBoostModel(booster) {

-  def this(_booster: Booster) = this(Identifiable.randomUID("XGBoostClassificationModel"), _booster)
+  def this(booster: Booster) = this(Identifiable.randomUID("XGBoostClassificationModel"), booster)

+  // only called in copy()
+  def this(uid: String) = this(uid, null)

  // scalastyle:off
@@ -57,16 +60,28 @@ class XGBoostClassificationModel private[spark](

  // scalastyle:on

+  // generate dataframe containing raw prediction column which is typed as Vector
  private def predictRaw(
      testSet: Dataset[_],
      temporalColName: Option[String] = None,
      forceTransformedScore: Option[Boolean] = None): DataFrame = {
    val predictRDD = produceRowRDD(testSet, forceTransformedScore.getOrElse($(outputMargin)))
-    testSet.sparkSession.createDataFrame(predictRDD, schema = {
-      StructType(testSet.schema.add(StructField(
-        temporalColName.getOrElse($(rawPredictionCol)),
-        ArrayType(FloatType, containsNull = false), nullable = false)))
+    val colName = temporalColName.getOrElse($(rawPredictionCol))
+    val tempColName = colName + "_arraytype"
+    val dsWithArrayTypedRawPredCol = testSet.sparkSession.createDataFrame(predictRDD, schema = {
+      testSet.schema.add(tempColName, ArrayType(FloatType, containsNull = false))
    })
+    val transformerForProbabilitiesArray =
+      (rawPredArray: mutable.WrappedArray[Float]) =>
+        if (numClasses == 2) {
+          Array(1 - rawPredArray(0), rawPredArray(0)).map(_.toDouble)
+        } else {
+          rawPredArray.map(_.toDouble).array
+        }
+    dsWithArrayTypedRawPredCol.withColumn(colName,
+      udf((rawPredArray: mutable.WrappedArray[Float]) =>
+        new MLDenseVector(transformerForProbabilitiesArray(rawPredArray))).apply(col(tempColName))).
+      drop(tempColName)
  }

  private def fromFeatureToPrediction(testSet: Dataset[_]): Dataset[_] = {
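One detail worth calling out: for a binary model the transformer above expands the single score XGBoost emits into a two-element probability vector, so the column fits Spark's evaluators (the test change later in this diff relies on this, hence its `values.length - 1`). Roughly:

  // Sketch of the binary branch above: one raw score p becomes [1 - p, p].
  val p = 0.8f
  val probabilities = Array(1 - p, p).map(_.toDouble) // ≈ Array(0.2, 0.8), wrapped in MLDenseVector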
@@ -77,28 +92,28 @@
    tempDF.select(allColumnNames(0), allColumnNames.tail: _*)
  }

-  private def argMax(vector: mutable.WrappedArray[Float]): Double = {
+  private def argMax(vector: Array[Double]): Double = {
    vector.zipWithIndex.maxBy(_._1)._2
  }

-  private def raw2prediction(rawPrediction: mutable.WrappedArray[Float]): Double = {
+  private def raw2prediction(rawPrediction: MLDenseVector): Double = {
    if (!isDefined(thresholds)) {
-      argMax(rawPrediction)
+      argMax(rawPrediction.values)
    } else {
      probability2prediction(rawPrediction)
    }
  }

-  private def probability2prediction(probability: mutable.WrappedArray[Float]): Double = {
+  private def probability2prediction(probability: MLDenseVector): Double = {
    if (!isDefined(thresholds)) {
-      argMax(probability)
+      argMax(probability.values)
    } else {
      val thresholds: Array[Double] = getThresholds
-      val scaledProbability: mutable.WrappedArray[Double] =
-        probability.zip(thresholds).map { case (p, t) =>
+      val scaledProbability =
+        probability.values.zip(thresholds).map { case (p, t) =>
          if (t == 0.0) Double.PositiveInfinity else p / t
        }
-      argMax(scaledProbability.map(_.toFloat))
+      argMax(scaledProbability)
    }
  }
@@ -144,7 +159,9 @@ class XGBoostClassificationModel private[spark](
  def numClasses: Int = numOfClasses

  override def copy(extra: ParamMap): XGBoostClassificationModel = {
-    defaultCopy(extra)
+    val clsModel = defaultCopy(extra).asInstanceOf[XGBoostClassificationModel]
+    clsModel._booster = booster
+    clsModel
  }

  override protected def predict(features: MLVector): Double = {
@@ -16,38 +16,97 @@

package ml.dmlc.xgboost4j.scala.spark

-import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
+import scala.collection.mutable

+import ml.dmlc.xgboost4j.scala.spark.params.{BoosterParams, GeneralParams, LearningTaskParams}
import org.apache.spark.ml.Predictor
import org.apache.spark.ml.feature.LabeledPoint
-import org.apache.spark.ml.linalg.{Vector => MLVector, VectorUDT}
-import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.linalg.{Vector => MLVector}
+import org.apache.spark.ml.param._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{StructType, DoubleType}
+import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.{Dataset, Row}

/**
 * the estimator wrapping XGBoost to produce a training model
 *
 * @param xgboostParams the parameters configuring XGBoost
- * @param round the number of iterations to train
- * @param nWorkers the total number of workers of xgboost
- * @param obj the customized objective function, default to be null and using the default in model
- * @param eval the customized eval function, default to be null and using the default in model
- * @param useExternalMemory whether to use external memory when training
- * @param missing the value taken as missing
 */
class XGBoostEstimator private[spark](
-    override val uid: String, xgboostParams: Map[String, Any], round: Int, nWorkers: Int,
-    obj: ObjectiveTrait, eval: EvalTrait, useExternalMemory: Boolean, missing: Float)
-  extends Predictor[MLVector, XGBoostEstimator, XGBoostModel] {
+    override val uid: String, private[spark] var xgboostParams: Map[String, Any])
+  extends Predictor[MLVector, XGBoostEstimator, XGBoostModel]
+    with LearningTaskParams with GeneralParams with BoosterParams {

-  def this(xgboostParams: Map[String, Any], round: Int, nWorkers: Int,
-      obj: ObjectiveTrait = null,
-      eval: EvalTrait = null, useExternalMemory: Boolean = false, missing: Float = Float.NaN) =
-    this(Identifiable.randomUID("XGBoostEstimator"), xgboostParams: Map[String, Any], round: Int,
-      nWorkers: Int, obj: ObjectiveTrait, eval: EvalTrait, useExternalMemory: Boolean,
-      missing: Float)
+  def this(xgboostParams: Map[String, Any]) =
+    this(Identifiable.randomUID("XGBoostEstimator"), xgboostParams: Map[String, Any])

+  def this(uid: String) = this(uid, Map[String, Any]())

+  // called in fromXGBParamMapToParams only when eval_metric is not defined
+  private def setupDefaultEvalMetric(): String = {
+    val objFunc = xgboostParams.getOrElse("objective", xgboostParams.getOrElse("obj_type", null))
+    if (objFunc == null) {
+      "rmse"
+    } else {
+      // compute default metric based on specified objective
+      val isClassificationTask = XGBoost.isClassificationTask(xgboostParams)
+      if (!isClassificationTask) {
+        // default metric for regression or ranking
+        if (objFunc.toString.startsWith("rank")) {
+          "map"
+        } else {
+          "rmse"
+        }
+      } else {
+        // default metric for classification
+        if (objFunc.toString.startsWith("multi")) {
+          // multi
+          "merror"
+        } else {
+          // binary
+          "error"
+        }
+      }
+    }
+  }

+  private def fromXGBParamMapToParams(): Unit = {
+    for ((paramName, paramValue) <- xgboostParams) {
+      params.find(_.name == paramName) match {
+        case None =>
+        case Some(_: DoubleParam) =>
+          set(paramName, paramValue.toString.toDouble)
+        case Some(_: BooleanParam) =>
+          set(paramName, paramValue.toString.toBoolean)
+        case Some(_: IntParam) =>
+          set(paramName, paramValue.toString.toInt)
+        case Some(_: FloatParam) =>
+          set(paramName, paramValue.toString.toFloat)
+        case Some(_: Param[_]) =>
+          set(paramName, paramValue)
+      }
+    }
+    if (xgboostParams.get("eval_metric").isEmpty) {
+      set("eval_metric", setupDefaultEvalMetric())
+    }
+  }

+  fromXGBParamMapToParams()

+  // only called when XGBParamMap is empty, i.e. in the constructor this(String)
+  // TODO: refactor to be functional
+  private def fromParamsToXGBParamMap(): Map[String, Any] = {
+    require(xgboostParams.isEmpty, "fromParamsToXGBParamMap can only be called when" +
+      " XGBParamMap is empty, i.e. in the constructor this(String)")
+    val xgbParamMap = new mutable.HashMap[String, Any]()
+    for (param <- params) {
+      xgbParamMap += param.name -> $(param)
+    }
+    xgboostParams = xgbParamMap.toMap
+    xgbParamMap.toMap
+  }

  /**
   * produce a XGBoostModel by fitting the given dataset
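The fromXGBParamMapToParams pass above is what makes grid search work: each recognized key in the constructor Map is coerced to its Spark Param's native type, and copy() mirrors Params back into the Map via fromParamsToXGBParamMap. A small illustration, with expected values inferred from the tests added at the end of this diff:

  val est = new XGBoostEstimator(Map("eta" -> "1", "max_depth" -> "6"))
  est.get(est.eta)      // Some(1.0) - "eta" is a DoubleParam
  est.get(est.maxDepth) // Some(6)   - "max_depth" is an IntParam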
@@ -27,10 +27,10 @@ import org.apache.spark.ml.linalg.{DenseVector => MLDenseVector, Vector => MLVec
import org.apache.spark.ml.param.{Param, Params}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
-import org.apache.spark.sql.types.{FloatType, ArrayType, DataType}
+import org.apache.spark.sql.types.{ArrayType, FloatType}
import org.apache.spark.{SparkContext, TaskContext}

-abstract class XGBoostModel(_booster: Booster)
+abstract class XGBoostModel(protected var _booster: Booster)
  extends PredictionModel[MLVector, XGBoostModel] with Serializable with Params {

  def setLabelCol(name: String): XGBoostModel = set(labelCol, name)
@@ -74,13 +74,28 @@ abstract class XGBoostModel(_booster: Booster)
   * @param evalFunc the customized evaluation function, null by default to use the default metric
   *                 of model
   * @param iter the current iteration, -1 to be null to use customized evaluation functions
   * @param useExternalCache if use external cache
   * @return the average metric over all partitions
   */
+  @deprecated(message = "this API is deprecated from 0.7," +
+    " use eval(booster: Booster, evalDataset: RDD[MLLabeledPoint], evalName: String, iter: Int) or" +
+    " eval(booster: Booster, evalDataset: RDD[MLLabeledPoint], evalName: String," +
+    " evalFunc: EvalTrait) instead", since = "0.7")
  def eval(evalDataset: RDD[MLLabeledPoint], evalName: String, evalFunc: EvalTrait = null,
      iter: Int = -1, useExternalCache: Boolean = false): String = {
+    require(evalFunc != null || iter != -1, "you have to specify the value of either eval or iter")
+    if (evalFunc == null) {
+      eval(_booster, evalDataset, evalName, iter)
+    } else {
+      eval(_booster, evalDataset, evalName, evalFunc)
+    }
+  }

+  // TODO: refactor to remove duplicate code in two variations of eval()
+  def eval(
+      booster: Booster, evalDataset: RDD[MLLabeledPoint], evalName: String,
+      iter: Int): String = {
    val broadcastBooster = evalDataset.sparkContext.broadcast(_booster)
+    val broadcastUseExternalCache = evalDataset.sparkContext.broadcast($(useExternalMemory))
    val appName = evalDataset.context.appName
    val allEvalMetrics = evalDataset.mapPartitions {
      labeledPointsPartition =>
@@ -88,7 +103,7 @@ abstract class XGBoostModel(_booster: Booster)
        val rabitEnv = Map("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString)
        Rabit.init(rabitEnv.asJava)
        val cacheFileName = {
-          if (useExternalCache) {
+          if (broadcastUseExternalCache.value) {
            s"$appName-${TaskContext.get().stageId()}-$evalName" +
              s"-deval_cache-${TaskContext.getPartitionId()}"
          } else {
@@ -97,16 +112,44 @@ abstract class XGBoostModel(_booster: Booster)
        }
        import DataUtils._
        val dMatrix = new DMatrix(labeledPointsPartition, cacheFileName)
-        if (iter == -1) {
-          val predictions = broadcastBooster.value.predict(dMatrix)
-          Rabit.shutdown()
-          Iterator(Some((evalName, evalFunc.eval(predictions, dMatrix))))
-        } else {
        val predStr = broadcastBooster.value.evalSet(Array(dMatrix), Array(evalName), iter)
        val Array(evName, predNumeric) = predStr.split(":")
        Rabit.shutdown()
        Iterator(Some(evName, predNumeric.toFloat))
      } else {
        Iterator(None)
      }
    }.filter(_.isDefined).collect()
    val evalPrefix = allEvalMetrics.map(_.get._1).head
    val evalMetricMean = allEvalMetrics.map(_.get._2).sum / allEvalMetrics.length
    s"$evalPrefix = $evalMetricMean"
  }

+  def eval(
+      booster: Booster, evalDataset: RDD[MLLabeledPoint], evalName: String,
+      evalFunc: EvalTrait): String = {
+    require(evalFunc != null, "you have to specify the value of either eval or iter")
+    val broadcastBooster = evalDataset.sparkContext.broadcast(booster)
+    val broadcastUseExternalCache = evalDataset.sparkContext.broadcast($(useExternalMemory))
+    val appName = evalDataset.context.appName
+    val allEvalMetrics = evalDataset.mapPartitions {
+      labeledPointsPartition =>
+        if (labeledPointsPartition.hasNext) {
+          val rabitEnv = Map("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString)
+          Rabit.init(rabitEnv.asJava)
+          val cacheFileName = {
+            if (broadcastUseExternalCache.value) {
+              s"$appName-${TaskContext.get().stageId()}-$evalName" +
+                s"-deval_cache-${TaskContext.getPartitionId()}"
+            } else {
+              null
+            }
+          }
+          import DataUtils._
+          val dMatrix = new DMatrix(labeledPointsPartition, cacheFileName)
+          val predictions = broadcastBooster.value.predict(dMatrix)
+          Rabit.shutdown()
+          Iterator(Some((evalName, evalFunc.eval(predictions, dMatrix))))
+        } else {
+          Iterator(None)
+        }
@@ -215,8 +258,7 @@ abstract class XGBoostModel(_booster: Booster)
      val testDataset = new DMatrix(vectorIterator, cachePrefix)
      val rawPredictResults = {
        if (!predLeaf) {
-          broadcastBooster.value.predict(testDataset, outputMargin).
-            map(Row(_)).iterator
+          broadcastBooster.value.predict(testDataset, outputMargin).map(Row(_)).iterator
        } else {
          broadcastBooster.value.predictLeaf(testDataset).map(Row(_)).iterator
        }
@@ -284,7 +326,5 @@ abstract class XGBoostModel(_booster: Booster)
    outputStream.close()
  }

-  // override protected def featuresDataType: DataType = new VectorUDT

  def booster: Booster = _booster
}
@@ -16,26 +16,35 @@

package ml.dmlc.xgboost4j.scala.spark

+import scala.collection.mutable

import ml.dmlc.xgboost4j.scala.Booster
import org.apache.spark.ml.linalg.{Vector => MLVector}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql._
+import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType}
+import org.apache.spark.sql.types.{ArrayType, FloatType}

-class XGBoostRegressionModel private[spark](override val uid: String, _booster: Booster)
-  extends XGBoostModel(_booster) {
+class XGBoostRegressionModel private[spark](override val uid: String, booster: Booster)
+  extends XGBoostModel(booster) {

  def this(_booster: Booster) = this(Identifiable.randomUID("XGBoostRegressionModel"), _booster)

+  // only called in copy()
+  def this(uid: String) = this(uid, null)

  override protected def transformImpl(testSet: Dataset[_]): DataFrame = {
    transformSchema(testSet.schema, logging = true)
    val predictRDD = produceRowRDD(testSet)
-    testSet.sparkSession.createDataFrame(predictRDD, schema =
-      StructType(testSet.schema.add(StructField($(predictionCol),
-        ArrayType(FloatType, containsNull = false), nullable = false)))
-    )
+    val tempPredColName = $(predictionCol) + "_temp"
+    val transformerForArrayTypedPredCol =
+      udf((regressionResults: mutable.WrappedArray[Float]) => regressionResults(0))
+    testSet.sparkSession.createDataFrame(predictRDD,
+      schema = testSet.schema.add(tempPredColName, ArrayType(FloatType, containsNull = false))
+    ).withColumn(
+      $(predictionCol),
+      transformerForArrayTypedPredCol.apply(col(tempPredColName))).drop(tempPredColName)
  }

  override protected def predict(features: MLVector): Double = {
@@ -43,6 +52,8 @@ class XGBoostRegressionModel private[spark](override val uid: String, _booster:
  }

  override def copy(extra: ParamMap): XGBoostRegressionModel = {
-    defaultCopy(extra)
+    val regModel = defaultCopy(extra).asInstanceOf[XGBoostRegressionModel]
+    regModel._booster = booster
+    regModel
  }
}
@@ -0,0 +1,150 @@
/*
 Copyright (c) 2014 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.spark.params

import scala.collection.immutable.HashSet

import ml.dmlc.xgboost4j.scala.spark.XGBoostEstimator
import org.apache.spark.ml.param.{DoubleParam, IntParam, Param, Params}

private[spark] trait BoosterParams extends Params {
  this: XGBoostEstimator =>

  val boosterType = new Param[String](this, "booster",
    s"Booster to use, options: {'gbtree', 'gblinear', 'dart'}",
    (value: String) => BoosterParams.supportedBoosters.contains(value.toLowerCase))

  // Tree Booster parameters
  val eta = new DoubleParam(this, "eta", "step size shrinkage used in update to prevent" +
    " overfitting. After each boosting step, we can directly get the weights of new features," +
    " and eta actually shrinks the feature weights to make the boosting process more conservative.",
    (value: Double) => value >= 0 && value <= 1)

  val gamma = new DoubleParam(this, "gamma", "minimum loss reduction required to make a further" +
    " partition on a leaf node of the tree. the larger, the more conservative the algorithm will" +
    " be.", (value: Double) => value >= 0)

  val maxDepth = new IntParam(this, "max_depth", "maximum depth of a tree, increase this value" +
    " will make model more complex / likely to be overfitting.", (value: Int) => value >= 1)

  val minChildWeight = new DoubleParam(this, "min_child_weight", "minimum sum of instance" +
    " weight(hessian) needed in a child. If the tree partition step results in a leaf node with" +
    " the sum of instance weight less than min_child_weight, then the building process will" +
    " give up further partitioning. In linear regression mode, this simply corresponds to minimum" +
    " number of instances needed to be in each node. The larger, the more conservative" +
    " the algorithm will be.", (value: Double) => value >= 0)

  val maxDeltaStep = new DoubleParam(this, "max_delta_step", "Maximum delta step we allow each" +
    " tree's weight" +
    " estimation to be. If the value is set to 0, it means there is no constraint. If it is set" +
    " to a positive value, it can help making the update step more conservative. Usually this" +
    " parameter is not needed, but it might help in logistic regression when class is extremely" +
    " imbalanced. Set it to value of 1-10 might help control the update",
    (value: Double) => value >= 0)

  val subSample = new DoubleParam(this, "subsample", "subsample ratio of the training instance." +
    " Setting it to 0.5 means that XGBoost randomly collected half of the data instances to" +
    " grow trees and this will prevent overfitting.", (value: Double) => value <= 1 && value > 0)

  val colSampleByTree = new DoubleParam(this, "colsample_bytree", "subsample ratio of columns" +
    " when constructing each tree.", (value: Double) => value <= 1 && value > 0)

  val colSampleByLevel = new DoubleParam(this, "colsample_bylevel", "subsample ratio of columns" +
    " for each split, in each level.", (value: Double) => value <= 1 && value > 0)

  val lambda = new DoubleParam(this, "lambda", "L2 regularization term on weights, increase this" +
    " value will make model more conservative.", (value: Double) => value >= 0)

  val alpha = new DoubleParam(this, "alpha", "L1 regularization term on weights, increase this" +
    " value will make model more conservative.", (value: Double) => value >= 0)

  val treeMethod = new Param[String](this, "tree_method",
    "The tree construction algorithm used in XGBoost, options: {'auto', 'exact', 'approx'}",
    (value: String) => BoosterParams.supportedTreeMethods.contains(value))

  val sketchEps = new DoubleParam(this, "sketch_eps",
    "This is only used for approximate greedy algorithm. This roughly translated into" +
    " O(1 / sketch_eps) number of bins. Compared to directly select number of bins, this comes" +
    " with theoretical guarantee with sketch accuracy.",
    (value: Double) => value < 1 && value > 0)

  val scalePosWeight = new DoubleParam(this, "scale_pos_weight", "Control the balance of positive" +
    " and negative weights, useful for unbalanced classes. A typical value to consider:" +
    " sum(negative cases) / sum(positive cases)")

  // Dart boosters

  val sampleType = new Param[String](this, "sample_type", "type of sampling algorithm, options:" +
    " {'uniform', 'weighted'}",
    (value: String) => BoosterParams.supportedSampleType.contains(value))

  val normalizeType = new Param[String](this, "normalize_type", "type of normalization" +
    " algorithm, options: {'tree', 'forest'}",
    (value: String) => BoosterParams.supportedNormalizeType.contains(value))

  val rateDrop = new DoubleParam(this, "rate_drop", "dropout rate", (value: Double) =>
    value >= 0 && value <= 1)

  val skipDrop = new DoubleParam(this, "skip_drop", "probability of skip dropout. If" +
    " a dropout is skipped, new trees are added in the same manner as gbtree.",
    (value: Double) => value >= 0 && value <= 1)

  // linear booster
  val lambdaBias = new DoubleParam(this, "lambda_bias", "L2 regularization term on bias, default" +
    " 0 (no L1 reg on bias because it is not important)", (value: Double) => value >= 0)

  setDefault(boosterType -> "gbtree", eta -> 0.3, gamma -> 0, maxDepth -> 6,
    minChildWeight -> 1, maxDeltaStep -> 0,
    subSample -> 1, colSampleByTree -> 1, colSampleByLevel -> 1,
    lambda -> 1, alpha -> 0, treeMethod -> "auto", sketchEps -> 0.03,
    scalePosWeight -> 0, sampleType -> "uniform", normalizeType -> "tree",
    rateDrop -> 0.0, skipDrop -> 0.0, lambdaBias -> 0)

  /**
   * Explains all params of this instance. See `explainParam()`.
   */
  override def explainParams(): String = {
    // TODO: filter some parameters according to the booster type
    val boosterTypeStr = $(boosterType)
    val validParamList = {
      if (boosterTypeStr == "gblinear") {
        // gblinear
        params.filter(param => param.name == "lambda" ||
          param.name == "alpha" || param.name == "lambda_bias")
      } else if (boosterTypeStr != "dart") {
        // gbtree
        params.filter(param => param.name != "sample_type" &&
          param.name != "normalize_type" && param.name != "rate_drop" && param.name != "skip_drop")
      } else {
        // dart
        params.filter(_.name != "lambda_bias")
      }
    }
    explainParam(boosterType) + "\n" ++ validParamList.map(explainParam).mkString("\n")
  }
}

private[spark] object BoosterParams {

  val supportedBoosters = HashSet("gbtree", "gblinear", "dart")

  val supportedTreeMethods = HashSet("auto", "exact", "approx")

  val supportedSampleType = HashSet("uniform", "weighted")

  val supportedNormalizeType = HashSet("tree", "forest")
}
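Because explainParams() filters by booster type, it doubles as a quick way to list only the parameters relevant to the configured booster; an illustrative use:

  // With booster=gblinear, only lambda, alpha and lambda_bias are explained.
  val est = new XGBoostEstimator(Map("booster" -> "gblinear"))
  println(est.explainParams())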
@@ -0,0 +1,47 @@
/*
 Copyright (c) 2014 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.spark.params

import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
import org.apache.spark.ml.param._

private[spark] trait GeneralParams extends Params {

  val round = new IntParam(this, "num_round", "The number of rounds for boosting",
    ParamValidators.gtEq(1))

  val nWorkers = new IntParam(this, "nthread", "number of workers used to run xgboost",
    ParamValidators.gtEq(1))

  val useExternalMemory = new BooleanParam(this, "use_external_memory", "whether to use external" +
    " memory as cache")

  val silent = new IntParam(this, "silent",
    "0 means printing running messages, 1 means silent mode.",
    (value: Int) => value >= 0 && value <= 1)

  val customObj = new Param[ObjectiveTrait](this, "custom_obj", "customized objective function " +
    "provided by the user")

  val customEval = new Param[EvalTrait](this, "custom_eval", "customized evaluation function " +
    "provided by the user")

  val missing = new FloatParam(this, "missing", "the value treated as missing")

  setDefault(round -> 1, nWorkers -> 1, useExternalMemory -> false, silent -> 0,
    customObj -> null, customEval -> null, missing -> Float.NaN)
}
@@ -0,0 +1,48 @@
/*
 Copyright (c) 2014 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.spark.params

import scala.collection.immutable.HashSet

import org.apache.spark.ml.param.{DoubleParam, Param, Params}

private[spark] trait LearningTaskParams extends Params {

  val objective = new Param[String](this, "objective", "objective function used for training," +
    s" options: {${LearningTaskParams.supportedObjective.mkString(",")}}",
    (value: String) => LearningTaskParams.supportedObjective.contains(value))

  val baseScore = new DoubleParam(this, "base_score", "the initial prediction score of all" +
    " instances, global bias")

  val evalMetric = new Param[String](this, "eval_metric", "evaluation metrics for validation" +
    " data, a default metric will be assigned according to objective (rmse for regression, and" +
    " error for classification, mean average precision for ranking), options:" +
    s" {${LearningTaskParams.supportedEvalMetrics.mkString(",")}}",
    (value: String) => LearningTaskParams.supportedEvalMetrics.contains(value))

  setDefault(objective -> "reg:linear", baseScore -> 0.5)
}

private[spark] object LearningTaskParams {
  val supportedObjective = HashSet("reg:linear", "reg:logistic", "binary:logistic",
    "binary:logitraw", "count:poisson", "multi:softmax", "multi:softprob", "rank:pairwise",
    "reg:gamma")

  val supportedEvalMetrics = HashSet("rmse", "mae", "logloss", "error", "merror", "mlogloss",
    "auc", "ndcg", "map", "gamma-deviance")
}
@@ -16,16 +16,12 @@

package ml.dmlc.xgboost4j.scala.spark

import java.io.File

-import scala.collection.mutable
import scala.collection.mutable.ListBuffer
import scala.io.Source

import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
import org.apache.spark.SparkContext
import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.linalg.DenseVector
+import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql._

class XGBoostDFSuite extends SharedSparkContext with Utils {
@@ -66,13 +62,15 @@ class XGBoostDFSuite extends SharedSparkContext with Utils {
      "id", "features", "label")
    val predResultsFromDF = xgBoostModelWithDF.setExternalMemory(true).transform(testDF).
      collect().map(row =>
-        (row.getAs[Int]("id"), row.getAs[mutable.WrappedArray[Float]]("probabilities"))
+        (row.getAs[Int]("id"), row.getAs[DenseVector]("probabilities"))
      ).toMap
    assert(testDF.count() === predResultsFromDF.size)
+    // the vector length in probabilities column is 2 since we have to fit to the evaluator in
+    // Spark
    for (i <- predResultFromSeq.indices) {
-      assert(predResultFromSeq(i).length === predResultsFromDF(i).length)
+      assert(predResultFromSeq(i).length === predResultsFromDF(i).values.length - 1)
      for (j <- predResultFromSeq(i).indices) {
-        assert(predResultFromSeq(i)(j) === predResultsFromDF(i)(j))
+        assert(predResultFromSeq(i)(j) === predResultsFromDF(i)(j + 1))
      }
    }
    cleanExternalCache("XGBoostDFSuite")
@@ -160,4 +158,29 @@ class XGBoostDFSuite extends SharedSparkContext with Utils {
    assert(predictionDF.columns.contains("final_prediction") === false)
    cleanExternalCache("XGBoostDFSuite")
  }

+  test("xgboost and spark parameters synchronize correctly") {
+    val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic")
+    // from xgboost params to spark params
+    val xgbEstimator = new XGBoostEstimator(xgbParamMap)
+    assert(xgbEstimator.get(xgbEstimator.eta).get === 1.0)
+    assert(xgbEstimator.get(xgbEstimator.objective).get === "binary:logistic")
+    // from spark to xgboost params
+    val xgbEstimatorCopy = xgbEstimator.copy(ParamMap.empty)
+    assert(xgbEstimatorCopy.xgboostParams.get("eta").get.toString.toDouble === 1.0)
+    assert(xgbEstimatorCopy.xgboostParams.get("objective").get.toString === "binary:logistic")
+  }

+  test("eval_metric is configured correctly") {
+    val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic")
+    val xgbEstimator = new XGBoostEstimator(xgbParamMap)
+    assert(xgbEstimator.get(xgbEstimator.evalMetric).get === "error")
+    val sparkParamMap = ParamMap.empty
+    val xgbEstimatorCopy = xgbEstimator.copy(sparkParamMap)
+    assert(xgbEstimatorCopy.xgboostParams.get("eval_metric") === Some("error"))
+    val xgbEstimatorCopy1 = xgbEstimator.copy(sparkParamMap.put(xgbEstimator.evalMetric, "logloss"))
+    assert(xgbEstimatorCopy1.xgboostParams.get("eval_metric") === Some("logloss"))
+  }
}
@@ -45,12 +45,11 @@ object XGBoost {
      watches: Map[String, DMatrix] = Map[String, DMatrix](),
      obj: ObjectiveTrait = null,
      eval: EvalTrait = null): Booster = {

    val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)}
    val xgboostInJava = JXGBoost.train(
      dtrain.jDMatrix,
-      params.map{
+      // we have to filter null value for customized obj and eval
+      params.filter(_._2 != null).map{
        case (key: String, value) => (key, value.toString)
      }.toMap[String, AnyRef].asJava,
      round, jWatches.asJava,