[jvm-packages] Add Rapids plugin support (#7491)
* Add GPU pre-processing pipeline.
This commit is contained in:
@@ -17,6 +17,7 @@
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import java.nio.file.Files
|
||||
import java.util.ServiceLoader
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.collection.{AbstractIterator, Iterator, mutable}
|
||||
@@ -24,7 +25,6 @@ import scala.collection.{AbstractIterator, Iterator, mutable}
|
||||
import ml.dmlc.xgboost4j.java.Rabit
|
||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
|
||||
import ml.dmlc.xgboost4j.scala.spark.DataUtils.PackedParams
|
||||
import ml.dmlc.xgboost4j.scala.spark.XGBoostRegressionModel._originalPredictionCol
|
||||
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon
|
||||
|
||||
import org.apache.spark.rdd.RDD
|
||||
@@ -35,7 +35,7 @@ import org.apache.commons.logging.LogFactory
|
||||
|
||||
import org.apache.spark.TaskContext
|
||||
import org.apache.spark.broadcast.Broadcast
|
||||
import org.apache.spark.ml.{Estimator, Model}
|
||||
import org.apache.spark.ml.{Estimator, Model, PipelineStage}
|
||||
import org.apache.spark.ml.linalg.Vector
|
||||
import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType}
|
||||
import org.apache.spark.storage.StorageLevel
|
||||
@@ -43,7 +43,7 @@ import org.apache.spark.storage.StorageLevel
|
||||
/**
|
||||
* PreXGBoost prepares the data before training and transform
|
||||
*/
|
||||
object PreXGBoost {
|
||||
object PreXGBoost extends PreXGBoostProvider {
|
||||
|
||||
private val logger = LogFactory.getLog("XGBoostSpark")
|
||||
|
||||
@@ -51,6 +51,48 @@ object PreXGBoost {
|
||||
private lazy val defaultWeightColumn = lit(1.0)
|
||||
private lazy val defaultGroupColumn = lit(-1)
|
||||
|
||||
// Look up an alternative PreXGBoostProvider implementation through java.util.ServiceLoader.
// The current thread's context class loader is preferred; when it is null, the class loader
// of this object is used instead.
private val optionProvider: Option[PreXGBoostProvider] = {
  // For now, this is the only implementation that is trusted.
  val trustedProviderName = "ml.dmlc.xgboost4j.scala.rapids.spark.GpuPreXGBoost"

  val contextLoader = Thread.currentThread().getContextClassLoader
  val loader = if (contextLoader != null) contextLoader else getClass.getClassLoader

  val trusted = ServiceLoader.load(classOf[PreXGBoostProvider], loader)
    .asScala
    .filter(provider => provider.getClass.getName.equals(trustedProviderName))
    .toList

  // Exactly one trusted provider must be registered; none or several both yield None.
  if (trusted.lengthCompare(1) == 0) trusted.headOption else None
}
|
||||
|
||||
/**
 * Compute the output schema for an XGBoost estimator or model.
 *
 * If a provider was discovered and reports itself enabled (it is queried without a
 * dataset), the call is forwarded to that provider. Otherwise the schema is produced by
 * the `transformSchemaInternal` method of the concrete estimator/model.
 *
 * @param xgboostEstimator supporting XGBoostClassifier/XGBoostClassificationModel and
 *                         XGBoostRegressor/XGBoostRegressionModel
 * @param schema the input schema
 * @return the transformed schema
 * @throws RuntimeException if `xgboostEstimator` is none of the supported types
 */
override def transformSchema(
    xgboostEstimator: XGBoostEstimatorCommon,
    schema: StructType): StructType = {
  optionProvider.filter(_.providerEnabled(None)) match {
    case Some(provider) =>
      provider.transformSchema(xgboostEstimator, schema)
    case None =>
      xgboostEstimator match {
        case classifier: XGBoostClassifier =>
          classifier.transformSchemaInternal(schema)
        case classificationModel: XGBoostClassificationModel =>
          classificationModel.transformSchemaInternal(schema)
        case regressor: XGBoostRegressor =>
          regressor.transformSchemaInternal(schema)
        case regressionModel: XGBoostRegressionModel =>
          regressionModel.transformSchemaInternal(schema)
        case _ =>
          throw new RuntimeException("Unsupporting " + xgboostEstimator)
      }
  }
}
|
||||
|
||||
/**
|
||||
* Convert the Dataset[_] to RDD[Watches] which will be fed to XGBoost
|
||||
*
|
||||
@@ -61,11 +103,15 @@ object PreXGBoost {
|
||||
* RDD[Watches] will be used as the training input
|
||||
* Option[RDD[_]\] is the optional cached RDD
|
||||
*/
|
||||
def buildDatasetToRDD(
|
||||
override def buildDatasetToRDD(
|
||||
estimator: Estimator[_],
|
||||
dataset: Dataset[_],
|
||||
params: Map[String, Any]): XGBoostExecutionParams => (RDD[Watches], Option[RDD[_]]) = {
|
||||
|
||||
if (optionProvider.isDefined && optionProvider.get.providerEnabled(Some(dataset))) {
|
||||
return optionProvider.get.buildDatasetToRDD(estimator, dataset, params)
|
||||
}
|
||||
|
||||
val (packedParams, evalSet) = estimator match {
|
||||
case est: XGBoostEstimatorCommon =>
|
||||
// get weight column, if weight is not defined, default to lit(1.0)
|
||||
@@ -131,7 +177,11 @@ object PreXGBoost {
|
||||
* @param dataset the input Dataset to transform
|
||||
* @return the transformed DataFrame
|
||||
*/
|
||||
def transformDataFrame(model: Model[_], dataset: Dataset[_]): DataFrame = {
|
||||
override def transformDataset(model: Model[_], dataset: Dataset[_]): DataFrame = {
|
||||
|
||||
if (optionProvider.isDefined && optionProvider.get.providerEnabled(Some(dataset))) {
|
||||
return optionProvider.get.transformDataset(model, dataset)
|
||||
}
|
||||
|
||||
/** get the necessary parameters */
|
||||
val (booster, inferBatchSize, featuresCol, useExternalMemory, missing, allowNonZeroForMissing,
|
||||
@@ -467,7 +517,7 @@ object PreXGBoost {
|
||||
}
|
||||
}
|
||||
|
||||
private def getCacheDirName(useExternalMemory: Boolean): Option[String] = {
|
||||
private[scala] def getCacheDirName(useExternalMemory: Boolean): Option[String] = {
|
||||
val taskId = TaskContext.getPartitionId().toString
|
||||
if (useExternalMemory) {
|
||||
val dir = Files.createTempDirectory(s"${TaskContext.get().stageId()}-cache-$taskId")
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
/*
|
||||
Copyright (c) 2021 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon
|
||||
|
||||
import org.apache.spark.ml.{Estimator, Model, PipelineStage}
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.types.StructType
|
||||
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||
|
||||
/**
 * Contract for a PreXGBoost implementation: schema transformation, conversion of the
 * training Dataset into the RDD consumed by XGBoost, and Dataset transformation.
 */
private[scala] trait PreXGBoostProvider {

  /**
   * Tell whether this provider should handle the request. Disabled unless overridden.
   *
   * @param dataset the input dataset, when one is available to the caller
   * @return true if the provider is enabled, false otherwise
   */
  def providerEnabled(dataset: Option[Dataset[_]]): Boolean = false

  /**
   * Transform schema
   *
   * @param xgboostEstimator supporting XGBoostClassifier/XGBoostClassificationModel and
   *                         XGBoostRegressor/XGBoostRegressionModel
   * @param schema the input schema
   * @return the transformed schema
   */
  def transformSchema(
      xgboostEstimator: XGBoostEstimatorCommon,
      schema: StructType): StructType

  /**
   * Convert the Dataset[_] to RDD[Watches] which will be fed to XGBoost
   *
   * @param estimator supports XGBoostClassifier and XGBoostRegressor
   * @param dataset the training data
   * @param params all user defined and defaulted params
   * @return a function from [[XGBoostExecutionParams]] to a pair:
   *         the RDD[Watches] used as the training input, and
   *         the optional cached RDD
   */
  def buildDatasetToRDD(
      estimator: Estimator[_],
      dataset: Dataset[_],
      params: Map[String, Any]): XGBoostExecutionParams => (RDD[Watches], Option[RDD[_]])

  /**
   * Transform Dataset
   *
   * @param model supporting [[XGBoostClassificationModel]] and [[XGBoostRegressionModel]]
   * @param dataset the input Dataset to transform
   * @return the transformed DataFrame
   */
  def transformDataset(
      model: Model[_],
      dataset: Dataset[_]): DataFrame

}
|
||||
@@ -53,12 +53,12 @@ object TrackerConf {
|
||||
def apply(): TrackerConf = TrackerConf(0L, "python")
|
||||
}
|
||||
|
||||
private[this] case class XGBoostExecutionEarlyStoppingParams(numEarlyStoppingRounds: Int,
|
||||
private[scala] case class XGBoostExecutionEarlyStoppingParams(numEarlyStoppingRounds: Int,
|
||||
maximizeEvalMetrics: Boolean)
|
||||
|
||||
private[this] case class XGBoostExecutionInputParams(trainTestRatio: Double, seed: Long)
|
||||
private[scala] case class XGBoostExecutionInputParams(trainTestRatio: Double, seed: Long)
|
||||
|
||||
private[spark] case class XGBoostExecutionParams(
|
||||
private[scala] case class XGBoostExecutionParams(
|
||||
numWorkers: Int,
|
||||
numRounds: Int,
|
||||
useExternalMemory: Boolean,
|
||||
@@ -257,7 +257,7 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
|
||||
object XGBoost extends Serializable {
|
||||
private val logger = LogFactory.getLog("XGBoostSpark")
|
||||
|
||||
private def getGPUAddrFromResources: Int = {
|
||||
def getGPUAddrFromResources: Int = {
|
||||
val tc = TaskContext.get()
|
||||
if (tc == null) {
|
||||
throw new RuntimeException("Something wrong for task context")
|
||||
@@ -473,7 +473,7 @@ object XGBoost extends Serializable {
|
||||
|
||||
}
|
||||
|
||||
class Watches private(
|
||||
class Watches private[scala] (
|
||||
val datasets: Array[DMatrix],
|
||||
val names: Array[String],
|
||||
val cacheDirName: Option[String]) {
|
||||
|
||||
@@ -19,6 +19,7 @@ package ml.dmlc.xgboost4j.scala.spark
|
||||
import ml.dmlc.xgboost4j.scala.spark.params._
|
||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, EvalTrait, ObjectiveTrait, XGBoost => SXGBoost}
|
||||
import org.apache.hadoop.fs.Path
|
||||
|
||||
import org.apache.spark.broadcast.Broadcast
|
||||
import org.apache.spark.ml.classification._
|
||||
import org.apache.spark.ml.linalg._
|
||||
@@ -27,9 +28,10 @@ import org.apache.spark.ml.util._
|
||||
import org.apache.spark.sql._
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.json4s.DefaultFormats
|
||||
|
||||
import scala.collection.{Iterator, mutable}
|
||||
|
||||
import org.apache.spark.sql.types.StructType
|
||||
|
||||
class XGBoostClassifier (
|
||||
override val uid: String,
|
||||
private[spark] val xgboostParams: Map[String, Any])
|
||||
@@ -142,6 +144,13 @@ class XGBoostClassifier (
|
||||
def setSinglePrecisionHistogram(value: Boolean): this.type =
|
||||
set(singlePrecisionHistogram, value)
|
||||
|
||||
/**
 * Set the `featuresCols` param to the given column names.
 *
 * This API is only used in the GPU train pipeline of xgboost4j-spark-gpu, which requires
 * that all feature columns be of numeric types.
 */
def setFeaturesCols(value: Seq[String]): this.type = {
  set(featuresCols, value)
}
|
||||
|
||||
// called at the start of fit/train when 'eval_metric' is not defined
|
||||
private def setupDefaultEvalMetric(): String = {
|
||||
require(isDefined(objective), "Users must set \'objective\' via xgboostParams.")
|
||||
@@ -154,6 +163,15 @@ class XGBoostClassifier (
|
||||
}
|
||||
}
|
||||
|
||||
// Callback used by PreXGBoost.transformSchema: runs the schema transformation inherited
// from the parent class.
private[spark] def transformSchemaInternal(schema: StructType): StructType =
  super.transformSchema(schema)
|
||||
|
||||
// Schema transformation is routed through PreXGBoost, which either forwards it to an
// enabled provider or calls back into transformSchemaInternal.
override def transformSchema(schema: StructType): StructType =
  PreXGBoost.transformSchema(this, schema)
|
||||
|
||||
override protected def train(dataset: Dataset[_]): XGBoostClassificationModel = {
|
||||
|
||||
if (!isDefined(evalMetric) || $(evalMetric).isEmpty) {
|
||||
@@ -196,7 +214,7 @@ object XGBoostClassifier extends DefaultParamsReadable[XGBoostClassifier] {
|
||||
class XGBoostClassificationModel private[ml](
|
||||
override val uid: String,
|
||||
override val numClasses: Int,
|
||||
private[spark] val _booster: Booster)
|
||||
private[scala] val _booster: Booster)
|
||||
extends ProbabilisticClassificationModel[Vector, XGBoostClassificationModel]
|
||||
with XGBoostClassifierParams with InferenceParams
|
||||
with MLWritable with Serializable {
|
||||
@@ -242,6 +260,13 @@ class XGBoostClassificationModel private[ml](
|
||||
|
||||
def setInferBatchSize(value: Int): this.type = set(inferBatchSize, value)
|
||||
|
||||
/**
 * Set the `featuresCols` param to the given column names.
 *
 * This API is only used in the GPU pipeline of xgboost4j-spark-gpu, which requires
 * that all feature columns be of numeric types.
 */
def setFeaturesCols(value: Seq[String]): this.type = {
  set(featuresCols, value)
}
|
||||
|
||||
/**
|
||||
* Single instance prediction.
|
||||
* Note: The performance is not ideal, use it carefully!
|
||||
@@ -271,7 +296,7 @@ class XGBoostClassificationModel private[ml](
|
||||
throw new Exception("XGBoost-Spark does not support \'raw2probabilityInPlace\'")
|
||||
}
|
||||
|
||||
private[spark] def produceResultIterator(
|
||||
private[scala] def produceResultIterator(
|
||||
originalRowItr: Iterator[Row],
|
||||
rawPredictionItr: Iterator[Row],
|
||||
probabilityItr: Iterator[Row],
|
||||
@@ -306,7 +331,7 @@ class XGBoostClassificationModel private[ml](
|
||||
}
|
||||
}
|
||||
|
||||
private[spark] def producePredictionItrs(broadcastBooster: Broadcast[Booster], dm: DMatrix):
|
||||
private[scala] def producePredictionItrs(broadcastBooster: Broadcast[Booster], dm: DMatrix):
|
||||
Array[Iterator[Row]] = {
|
||||
val rawPredictionItr = {
|
||||
broadcastBooster.value.predict(dm, outPutMargin = true, $(treeLimit)).
|
||||
@@ -333,6 +358,14 @@ class XGBoostClassificationModel private[ml](
|
||||
Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr)
|
||||
}
|
||||
|
||||
// Callback used by PreXGBoost.transformSchema: runs the schema transformation inherited
// from the parent class.
private[spark] def transformSchemaInternal(schema: StructType): StructType =
  super.transformSchema(schema)
|
||||
|
||||
// Schema transformation is routed through PreXGBoost, which either forwards it to an
// enabled provider or calls back into transformSchemaInternal.
override def transformSchema(schema: StructType): StructType =
  PreXGBoost.transformSchema(this, schema)
|
||||
|
||||
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||
transformSchema(dataset.schema, logging = true)
|
||||
if (isDefined(thresholds)) {
|
||||
@@ -343,7 +376,7 @@ class XGBoostClassificationModel private[ml](
|
||||
|
||||
// Output selected columns only.
|
||||
// This is a bit complicated since it tries to avoid repeated computation.
|
||||
var outputData = PreXGBoost.transformDataFrame(this, dataset)
|
||||
var outputData = PreXGBoost.transformDataset(this, dataset)
|
||||
var numColsOutput = 0
|
||||
|
||||
val rawPredictionUDF = udf { rawPrediction: mutable.WrappedArray[Float] =>
|
||||
@@ -404,8 +437,8 @@ class XGBoostClassificationModel private[ml](
|
||||
|
||||
object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel] {
|
||||
|
||||
private[spark] val _rawPredictionCol = "_rawPrediction"
|
||||
private[spark] val _probabilityCol = "_probability"
|
||||
private[scala] val _rawPredictionCol = "_rawPrediction"
|
||||
private[scala] val _probabilityCol = "_probability"
|
||||
|
||||
override def read: MLReader[XGBoostClassificationModel] = new XGBoostClassificationModelReader
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ import org.apache.spark.sql.functions._
|
||||
import org.json4s.DefaultFormats
|
||||
|
||||
import org.apache.spark.broadcast.Broadcast
|
||||
import org.apache.spark.sql.types.StructType
|
||||
|
||||
class XGBoostRegressor (
|
||||
override val uid: String,
|
||||
@@ -145,6 +146,13 @@ class XGBoostRegressor (
|
||||
def setSinglePrecisionHistogram(value: Boolean): this.type =
|
||||
set(singlePrecisionHistogram, value)
|
||||
|
||||
/**
 * Set the `featuresCols` param to the given column names.
 *
 * This API is only used in the GPU train pipeline of xgboost4j-spark-gpu, which requires
 * that all feature columns be of numeric types.
 */
def setFeaturesCols(value: Seq[String]): this.type = {
  set(featuresCols, value)
}
|
||||
|
||||
// called at the start of fit/train when 'eval_metric' is not defined
|
||||
private def setupDefaultEvalMetric(): String = {
|
||||
require(isDefined(objective), "Users must set \'objective\' via xgboostParams.")
|
||||
@@ -155,6 +163,14 @@ class XGBoostRegressor (
|
||||
}
|
||||
}
|
||||
|
||||
// Callback used by PreXGBoost.transformSchema: runs the schema transformation inherited
// from the parent class.
private[spark] def transformSchemaInternal(schema: StructType): StructType =
  super.transformSchema(schema)
|
||||
|
||||
// Schema transformation is routed through PreXGBoost, which either forwards it to an
// enabled provider or calls back into transformSchemaInternal.
override def transformSchema(schema: StructType): StructType =
  PreXGBoost.transformSchema(this, schema)
|
||||
|
||||
override protected def train(dataset: Dataset[_]): XGBoostRegressionModel = {
|
||||
|
||||
if (!isDefined(evalMetric) || $(evalMetric).isEmpty) {
|
||||
@@ -191,7 +207,7 @@ object XGBoostRegressor extends DefaultParamsReadable[XGBoostRegressor] {
|
||||
|
||||
class XGBoostRegressionModel private[ml] (
|
||||
override val uid: String,
|
||||
private[spark] val _booster: Booster)
|
||||
private[scala] val _booster: Booster)
|
||||
extends PredictionModel[Vector, XGBoostRegressionModel]
|
||||
with XGBoostRegressorParams with InferenceParams
|
||||
with MLWritable with Serializable {
|
||||
@@ -237,6 +253,13 @@ class XGBoostRegressionModel private[ml] (
|
||||
|
||||
def setInferBatchSize(value: Int): this.type = set(inferBatchSize, value)
|
||||
|
||||
/**
 * Set the `featuresCols` param to the given column names.
 *
 * This API is only used in the GPU pipeline of xgboost4j-spark-gpu, which requires
 * that all feature columns be of numeric types.
 */
def setFeaturesCols(value: Seq[String]): this.type = {
  set(featuresCols, value)
}
|
||||
|
||||
/**
|
||||
* Single instance prediction.
|
||||
* Note: The performance is not ideal, use it carefully!
|
||||
@@ -251,7 +274,7 @@ class XGBoostRegressionModel private[ml] (
|
||||
_booster.predict(data = dm)(0)(0)
|
||||
}
|
||||
|
||||
private[spark] def produceResultIterator(
|
||||
private[scala] def produceResultIterator(
|
||||
originalRowItr: Iterator[Row],
|
||||
predictionItr: Iterator[Row],
|
||||
predLeafItr: Iterator[Row],
|
||||
@@ -283,7 +306,7 @@ class XGBoostRegressionModel private[ml] (
|
||||
}
|
||||
}
|
||||
|
||||
private[spark] def producePredictionItrs(booster: Broadcast[Booster], dm: DMatrix):
|
||||
private[scala] def producePredictionItrs(booster: Broadcast[Booster], dm: DMatrix):
|
||||
Array[Iterator[Row]] = {
|
||||
val originalPredictionItr = {
|
||||
booster.value.predict(dm, outPutMargin = false, $(treeLimit)).map(Row(_)).iterator
|
||||
@@ -307,11 +330,19 @@ class XGBoostRegressionModel private[ml] (
|
||||
Array(originalPredictionItr, predLeafItr, predContribItr)
|
||||
}
|
||||
|
||||
// Callback used by PreXGBoost.transformSchema: runs the schema transformation inherited
// from the parent class.
private[spark] def transformSchemaInternal(schema: StructType): StructType =
  super.transformSchema(schema)
|
||||
|
||||
// Schema transformation is routed through PreXGBoost, which either forwards it to an
// enabled provider or calls back into transformSchemaInternal.
override def transformSchema(schema: StructType): StructType =
  PreXGBoost.transformSchema(this, schema)
|
||||
|
||||
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||
transformSchema(dataset.schema, logging = true)
|
||||
// Output selected columns only.
|
||||
// This is a bit complicated since it tries to avoid repeated computation.
|
||||
var outputData = PreXGBoost.transformDataFrame(this, dataset)
|
||||
var outputData = PreXGBoost.transformDataset(this, dataset)
|
||||
var numColsOutput = 0
|
||||
|
||||
val predictUDF = udf { (originalPrediction: mutable.WrappedArray[Float]) =>
|
||||
@@ -342,7 +373,7 @@ class XGBoostRegressionModel private[ml] (
|
||||
|
||||
object XGBoostRegressionModel extends MLReadable[XGBoostRegressionModel] {
|
||||
|
||||
private[spark] val _originalPredictionCol = "_originalPrediction"
|
||||
private[scala] val _originalPredictionCol = "_originalPrediction"
|
||||
|
||||
override def read: MLReader[XGBoostRegressionModel] = new XGBoostRegressionModelReader
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014,2021 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -289,7 +289,7 @@ private[spark] trait BoosterParams extends Params {
|
||||
rateDrop -> 0.0, skipDrop -> 0.0, lambdaBias -> 0, treeLimit -> 0)
|
||||
}
|
||||
|
||||
private[spark] object BoosterParams {
|
||||
private[scala] object BoosterParams {
|
||||
|
||||
val supportedBoosters = HashSet("gbtree", "gblinear", "dart")
|
||||
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
/*
|
||||
Copyright (c) 2021 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark.params
|
||||
|
||||
import org.json4s.DefaultFormats
|
||||
import org.json4s.jackson.JsonMethods.{compact, parse, render}
|
||||
|
||||
import org.apache.spark.ml.param.{BooleanParam, Param, Params}
|
||||
|
||||
/**
 * Params mixin contributing `featuresCols`, a sequence of feature column names.
 */
trait GpuParams extends Params {

  /**
   * Param for the names of feature columns.
   * @group param
   */
  final val featuresCols: StringSeqParam =
    new StringSeqParam(
      parent = this,
      name = "featuresCols",
      doc = "a sequence of feature column names.")

  // The default is an empty sequence of column names.
  setDefault(featuresCols, Seq.empty[String])

  /** @group getParam */
  final def getFeaturesCols: Seq[String] = getOrDefault(featuresCols)

}
|
||||
|
||||
/**
 * A [[Param]] whose value is a sequence of strings, converted to and from JSON text
 * with json4s.
 *
 * @param parent forwarded to the [[Param]] constructor
 * @param name forwarded to the [[Param]] constructor
 * @param doc forwarded to the [[Param]] constructor
 */
class StringSeqParam(
    parent: Params,
    name: String,
    doc: String)
  extends Param[Seq[String]](parent, name, doc) {

  /** Render the sequence as compact JSON text. */
  override def jsonEncode(value: Seq[String]): String = {
    // JsonDSL supplies the implicit conversion from Seq[String] to a JSON value.
    import org.json4s.JsonDSL._
    val rendered = render(value)
    compact(rendered)
  }

  /** Parse JSON text and extract it as a sequence of strings. */
  override def jsonDecode(json: String): Seq[String] = {
    implicit val formats = DefaultFormats
    val parsed = parse(json)
    parsed.extract[Seq[String]]
  }
}
|
||||
@@ -18,16 +18,16 @@ package ml.dmlc.xgboost4j.scala.spark.params
|
||||
|
||||
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasWeightCol}
|
||||
|
||||
private[spark] sealed trait XGBoostEstimatorCommon extends GeneralParams with LearningTaskParams
|
||||
private[scala] sealed trait XGBoostEstimatorCommon extends GeneralParams with LearningTaskParams
|
||||
with BoosterParams with RabitParams with ParamMapFuncs with NonParamVariables with HasWeightCol
|
||||
with HasBaseMarginCol with HasLeafPredictionCol with HasContribPredictionCol with HasFeaturesCol
|
||||
with HasLabelCol {
|
||||
with HasLabelCol with GpuParams {
|
||||
|
||||
def needDeterministicRepartitioning: Boolean = {
|
||||
getCheckpointPath != null && getCheckpointPath.nonEmpty && getCheckpointInterval > 0
|
||||
}
|
||||
}
|
||||
|
||||
private[spark] trait XGBoostClassifierParams extends XGBoostEstimatorCommon with HasNumClass
|
||||
private[scala] trait XGBoostClassifierParams extends XGBoostEstimatorCommon with HasNumClass
|
||||
|
||||
private[spark] trait XGBoostRegressorParams extends XGBoostEstimatorCommon with HasGroupCol
|
||||
private[scala] trait XGBoostRegressorParams extends XGBoostEstimatorCommon with HasGroupCol
|
||||
|
||||
Reference in New Issue
Block a user