[jvm-packages] update API docs (#1713)

* add back train method but mark as deprecated
* fix scalastyle error
* update java doc
* update

parent d321375df5
commit 6082184cd1
@@ -120,7 +120,7 @@ object XGBoost extends Serializable {
   }

   /**
-   *
+   * train XGBoost model with the DataFrame-represented data
    * @param trainingData the trainingset represented as DataFrame
    * @param params Map containing the parameters to configure XGBoost
    * @param round the number of iterations
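
A minimal usage sketch of this DataFrame-based entry point. Only (trainingData, params, round) is documented in the hunk above; the nWorkers argument and any further parameters are assumptions about this snapshot of the API:

    import ml.dmlc.xgboost4j.scala.spark.XGBoost
    import org.apache.spark.sql.DataFrame

    def trainOnDataFrame(trainingData: DataFrame) = {
      // params: Map containing the parameters to configure XGBoost
      val params = Map(
        "eta" -> 0.1,
        "max_depth" -> 6,
        "objective" -> "binary:logistic")
      // round: the number of iterations, per the scaladoc above
      XGBoost.train(trainingData, params, round = 100, nWorkers = 4)
    }
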
@@ -173,7 +173,7 @@ object XGBoost extends Serializable {
   }

   /**
-   *
+   * train XGBoost model with the RDD-represented data
    * @param trainingData the trainingset represented as RDD
    * @param params Map containing the configuration entries
    * @param round the number of iterations
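
The RDD-based overload takes the training set as an RDD of labeled points instead; the element type and trailing arguments here are likewise assumptions:

    import ml.dmlc.xgboost4j.scala.spark.XGBoost
    import org.apache.spark.ml.feature.LabeledPoint
    import org.apache.spark.rdd.RDD

    def trainOnRDD(trainingData: RDD[LabeledPoint]) =
      XGBoost.train(trainingData, Map("objective" -> "reg:linear"), round = 50, nWorkers = 4)
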
@@ -218,7 +218,7 @@ object XGBoost extends Serializable {
   }

   /**
-   *
+   * various of train()
    * @param trainingData the trainingset represented as RDD
    * @param params Map containing the configuration entries
    * @param round the number of iterations
@@ -26,6 +26,9 @@ import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.{DataFrame, Dataset}

+/**
+ * class of the XGBoost model used for classification task
+ */
 class XGBoostClassificationModel private[spark](
     override val uid: String, booster: Booster)
   extends XGBoostModel(booster) {
@@ -37,12 +40,19 @@ class XGBoostClassificationModel private[spark](

   // scalastyle:off

+  /**
+   * whether to output raw margin
+   */
   final val outputMargin: Param[Boolean] = new Param[Boolean](this, "outputMargin", "whether to output untransformed margin value ")

   setDefault(outputMargin, false)

   def setOutputMargin(value: Boolean): XGBoostModel = set(outputMargin, value).asInstanceOf[XGBoostClassificationModel]

+  /**
+   * the name of the column storing the raw prediction value, either probabilities (as default) or
+   * raw margin value
+   */
   final val rawPredictionCol: Param[String] = new Param[String](this, "rawPredictionCol", "Column name for raw prediction output of xgboost. If outputMargin is true, the column contains untransformed margin value; otherwise it is the probability for each class (by default).")

   setDefault(rawPredictionCol, "probabilities")
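
How the two params interact, as a hedged sketch; clf stands in for an already-trained XGBoostClassificationModel (a hypothetical variable). Since setOutputMargin is typed to return XGBoostModel, it is applied last in the chain:

    val clf: XGBoostClassificationModel = ???
    val withMargins = clf
      .setRawPredictionCol("margins")   // column for the raw output
      .setOutputMargin(true)            // emit untransformed margins instead of probabilities
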
@@ -51,6 +61,9 @@ class XGBoostClassificationModel private[spark](

   def setRawPredictionCol(value: String): XGBoostClassificationModel = set(rawPredictionCol, value).asInstanceOf[XGBoostClassificationModel]

+  /**
+   * Thresholds in multi-class classification
+   */
   final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold", (t: Array[Double]) => t.forall(_ >= 0))

   def getThresholds: Array[Double] = $(thresholds)
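
The selection rule in the param description ("the class with largest value p/t is predicted") can be spelled out on a toy 3-class example, independent of the model:

    val probabilities = Array(0.2, 0.5, 0.3)
    val thresholds = Array(0.5, 1.0, 0.5)      // per-class thresholds, all >= 0
    val predicted = probabilities.zip(thresholds)
      .map { case (p, t) => p / t }            // p/t = 0.4, 0.5, 0.6
      .zipWithIndex
      .maxBy(_._1)._2                          // class 2 is predicted
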
@@ -29,9 +29,7 @@ import org.apache.spark.sql.types.DoubleType
 import org.apache.spark.sql.{Dataset, Row}

 /**
- * the estimator wrapping XGBoost to produce a training model
- *
- * @param xgboostParams the parameters configuring XGBoost
+ * XGBoost Estimator to produce a XGBoost model
  */
 class XGBoostEstimator private[spark](
     override val uid: String, private[spark] var xgboostParams: Map[String, Any])
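
The estimator carries the raw xgboostParams map into Spark's Estimator API. A hedged sketch, assuming a public constructor taking just the params map exists at this snapshot (the primary constructor shown above is private[spark]), with trainingDF a hypothetical DataFrame holding features and label columns:

    val estimator = new XGBoostEstimator(Map(
      "eta" -> 0.1,
      "max_depth" -> 6,
      "objective" -> "binary:logistic"))
    val model = estimator.fit(trainingDF)
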
@@ -30,6 +30,9 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.types.{ArrayType, FloatType}
 import org.apache.spark.{SparkContext, TaskContext}

+/**
+ * the base class of [[XGBoostClassificationModel]] and [[XGBoostRegressionModel]]
+ */
 abstract class XGBoostModel(protected var _booster: Booster)
   extends PredictionModel[MLVector, XGBoostModel] with Serializable with Params {

@@ -76,24 +79,18 @@ abstract class XGBoostModel(protected var _booster: Booster)
    * @param iter the current iteration, -1 to be null to use customized evaluation functions
    * @return the average metric over all partitions
    */
-  @deprecated(message = "this API is deprecated from 0.7," +
-    " use eval(booster: Booster, evalDataset: RDD[MLLabeledPoint], evalName: String,iter: Int) or" +
-    " eval(booster: Booster, evalDataset: RDD[MLLabeledPoint], evalName: String," +
-    " evalFunc: EvalTrait) instead", since = "0.7")
   def eval(evalDataset: RDD[MLLabeledPoint], evalName: String, evalFunc: EvalTrait = null,
       iter: Int = -1, useExternalCache: Boolean = false): String = {
     require(evalFunc != null || iter != -1, "you have to specify the value of either eval or iter")
     if (evalFunc == null) {
-      eval(_booster, evalDataset, evalName, iter)
+      eval(evalDataset, evalName, iter)
     } else {
-      eval(_booster, evalDataset, evalName, evalFunc)
+      eval(evalDataset, evalName, evalFunc)
     }
   }

   // TODO: refactor to remove duplicate code in two variations of eval()
-  def eval(
-      booster: Booster, evalDataset: RDD[MLLabeledPoint], evalName: String,
-      iter: Int): String = {
+  private def eval(evalDataset: RDD[MLLabeledPoint], evalName: String, iter: Int): String = {
     val broadcastBooster = evalDataset.sparkContext.broadcast(_booster)
     val broadcastUseExternalCache = evalDataset.sparkContext.broadcast($(useExternalMemory))
     val appName = evalDataset.context.appName
@@ -125,11 +122,10 @@ abstract class XGBoostModel(protected var _booster: Booster)
     s"$evalPrefix = $evalMetricMean"
   }

-  def eval(
-      booster: Booster, evalDataset: RDD[MLLabeledPoint], evalName: String,
-      evalFunc: EvalTrait): String = {
+  private def eval(evalDataset: RDD[MLLabeledPoint], evalName: String, evalFunc: EvalTrait):
+      String = {
     require(evalFunc != null, "you have to specify the value of either eval or iter")
-    val broadcastBooster = evalDataset.sparkContext.broadcast(booster)
+    val broadcastBooster = evalDataset.sparkContext.broadcast(_booster)
     val broadcastUseExternalCache = evalDataset.sparkContext.broadcast($(useExternalMemory))
     val appName = evalDataset.context.appName
     val allEvalMetrics = evalDataset.mapPartitions {
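
After this refactoring only the (evalDataset, evalName, ...) entry point stays public, and the Booster is taken from the model itself rather than passed in. A hedged sketch of the two public call shapes, where model, testRDD, and myEval are hypothetical placeholders:

    val errAtIter10 = model.eval(testRDD, "test", iter = 10)        // built-in metric at iteration 10
    val customErr = model.eval(testRDD, "test", evalFunc = myEval)  // user-supplied EvalTrait
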
@@ -26,6 +26,9 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{ArrayType, FloatType}

+/**
+ * class of XGBoost model used for regression task
+ */
 class XGBoostRegressionModel private[spark](override val uid: String, booster: Booster)
   extends XGBoostModel(booster) {

@@ -21,26 +21,49 @@ import scala.collection.immutable.HashSet
 import ml.dmlc.xgboost4j.scala.spark.XGBoostEstimator
 import org.apache.spark.ml.param.{DoubleParam, IntParam, Param, Params}

-private[spark] trait BoosterParams extends Params {
+trait BoosterParams extends Params {
   this: XGBoostEstimator =>

+  /**
+   * Booster to use, options: {'gbtree', 'gblinear', 'dart'}
+   */
   val boosterType = new Param[String](this, "booster",
     s"Booster to use, options: {'gbtree', 'gblinear', 'dart'}",
     (value: String) => BoosterParams.supportedBoosters.contains(value.toLowerCase))

-  // Tree Booster parameters
+  /**
+   * step size shrinkage used in update to prevents overfitting. After each boosting step, we
+   * can directly get the weights of new features and eta actually shrinks the feature weights
+   * to make the boosting process more conservative. [default=0.3] range: [0,1]
+   */
   val eta = new DoubleParam(this, "eta", "step size shrinkage used in update to prevents" +
     " overfitting. After each boosting step, we can directly get the weights of new features." +
     " and eta actually shrinks the feature weights to make the boosting process more conservative.",
     (value: Double) => value >= 0 && value <= 1)

+  /**
+   * minimum loss reduction required to make a further partition on a leaf node of the tree.
+   * the larger, the more conservative the algorithm will be. [default=0] range: [0,
+   * Double.MaxValue]
+   */
   val gamma = new DoubleParam(this, "gamma", "minimum loss reduction required to make a further" +
-    " partition on a leaf node of the tree. the larger, the more conservative the algorithm will" +
-    " be.", (value: Double) => value >= 0)
+    " partition on a leaf node of the tree. the larger, the more conservative the algorithm" +
+    " will be.", (value: Double) => value >= 0)

+  /**
+   * maximum depth of a tree, increase this value will make model more complex / likely to be
+   * overfitting. [default=6] range: [1, Int.MaxValue]
+   */
   val maxDepth = new IntParam(this, "max_depth", "maximum depth of a tree, increase this value" +
-    " will make model more complex / likely to be overfitting.", (value: Int) => value >= 1)
+    " will make model more complex/likely to be overfitting.", (value: Int) => value >= 1)

+  /**
+   * minimum sum of instance weight(hessian) needed in a child. If the tree partition step results
+   * in a leaf node with the sum of instance weight less than min_child_weight, then the building
+   * process will give up further partitioning. In linear regression mode, this simply corresponds
+   * to minimum number of instances needed to be in each node. The larger, the more conservative
+   * the algorithm will be. [default=1] range: [0, Double.MaxValue]
+   */
   val minChildWeight = new DoubleParam(this, "min_child_weight", "minimum sum of instance" +
     " weight(hessian) needed in a child. If the tree partition step results in a leaf node with" +
     " the sum of instance weight less than min_child_weight, then the building process will" +
@@ -48,6 +71,13 @@ private[spark] trait BoosterParams extends Params {
     " number of instances needed to be in each node. The larger, the more conservative" +
     " the algorithm will be.", (value: Double) => value >= 0)

+  /**
+   * Maximum delta step we allow each tree's weight estimation to be. If the value is set to 0, it
+   * means there is no constraint. If it is set to a positive value, it can help making the update
+   * step more conservative. Usually this parameter is not needed, but it might help in logistic
+   * regression when class is extremely imbalanced. Set it to value of 1-10 might help control the
+   * update. [default=0] range: [0, Double.MaxValue]
+   */
   val maxDeltaStep = new DoubleParam(this, "max_delta_step", "Maximum delta step we allow each" +
     " tree's weight" +
     " estimation to be. If the value is set to 0, it means there is no constraint. If it is set" +
@@ -56,54 +86,109 @@ private[spark] trait BoosterParams extends Params {
     " imbalanced. Set it to value of 1-10 might help control the update",
     (value: Double) => value >= 0)

+  /**
+   * subsample ratio of the training instance. Setting it to 0.5 means that XGBoost randomly
+   * collected half of the data instances to grow trees and this will prevent overfitting.
+   * [default=1] range:(0,1]
+   */
   val subSample = new DoubleParam(this, "subsample", "subsample ratio of the training instance." +
     " Setting it to 0.5 means that XGBoost randomly collected half of the data instances to" +
     " grow trees and this will prevent overfitting.", (value: Double) => value <= 1 && value > 0)

+  /**
+   * subsample ratio of columns when constructing each tree. [default=1] range: (0,1]
+   */
   val colSampleByTree = new DoubleParam(this, "colsample_bytree", "subsample ratio of columns" +
     " when constructing each tree.", (value: Double) => value <= 1 && value > 0)

+  /**
+   * subsample ratio of columns for each split, in each level. [default=1] range: (0,1]
+   */
   val colSampleByLevel = new DoubleParam(this, "colsample_bylevel", "subsample ratio of columns" +
     " for each split, in each level.", (value: Double) => value <= 1 && value > 0)

+  /**
+   * L2 regularization term on weights, increase this value will make model more conservative.
+   * [default=1]
+   */
   val lambda = new DoubleParam(this, "lambda", "L2 regularization term on weights, increase this" +
     " value will make model more conservative.", (value: Double) => value >= 0)

+  /**
+   * L1 regularization term on weights, increase this value will make model more conservative.
+   * [default=0]
+   */
   val alpha = new DoubleParam(this, "alpha", "L1 regularization term on weights, increase this" +
     " value will make model more conservative.", (value: Double) => value >= 0)

+  /**
+   * The tree construction algorithm used in XGBoost. options: {'auto', 'exact', 'approx'}
+   * [default='auto']
+   */
   val treeMethod = new Param[String](this, "tree_method",
     "The tree construction algorithm used in XGBoost, options: {'auto', 'exact', 'approx'}",
     (value: String) => BoosterParams.supportedTreeMethods.contains(value))

+  /**
+   * This is only used for approximate greedy algorithm.
+   * This roughly translated into O(1 / sketch_eps) number of bins. Compared to directly select
+   * number of bins, this comes with theoretical guarantee with sketch accuracy.
+   * [default=0.03] range: (0, 1)
+   */
   val sketchEps = new DoubleParam(this, "sketch_eps",
     "This is only used for approximate greedy algorithm. This roughly translated into" +
     " O(1 / sketch_eps) number of bins. Compared to directly select number of bins, this comes" +
     " with theoretical guarantee with sketch accuracy.",
     (value: Double) => value < 1 && value > 0)

+  /**
+   * Control the balance of positive and negative weights, useful for unbalanced classes. A typical
+   * value to consider: sum(negative cases) / sum(positive cases). [default=0]
+   */
   val scalePosWeight = new DoubleParam(this, "scale_pos_weight", "Control the balance of positive" +
     " and negative weights, useful for unbalanced classes. A typical value to consider:" +
     " sum(negative cases) / sum(positive cases)")

   // Dart boosters

+  /**
+   * Parameter for Dart booster.
+   * Type of sampling algorithm. "uniform": dropped trees are selected uniformly.
+   * "weighted": dropped trees are selected in proportion to weight. [default="uniform"]
+   */
   val sampleType = new Param[String](this, "sample_type", "type of sampling algorithm, options:" +
     " {'uniform', 'weighted'}",
     (value: String) => BoosterParams.supportedSampleType.contains(value))

+  /**
+   * Parameter of Dart booster.
+   * type of normalization algorithm, options: {'tree', 'forest'}. [default="tree"]
+   */
   val normalizeType = new Param[String](this, "normalize_type", "type of normalization" +
     " algorithm, options: {'tree', 'forest'}",
     (value: String) => BoosterParams.supportedNormalizeType.contains(value))

+  /**
+   * Parameter of Dart booster.
+   * dropout rate. [default=0.0] range: [0.0, 1.0]
+   */
   val rateDrop = new DoubleParam(this, "rate_drop", "dropout rate", (value: Double) =>
     value >= 0 && value <= 1)

+  /**
+   * Parameter of Dart booster.
+   * probability of skip dropout. If a dropout is skipped, new trees are added in the same manner
+   * as gbtree. [default=0.0] range: [0.0, 1.0]
+   */
   val skipDrop = new DoubleParam(this, "skip_drop", "probability of skip dropout. If" +
     " a dropout is skipped, new trees are added in the same manner as gbtree.",
     (value: Double) => value >= 0 && value <= 1)

   // linear booster
+  /**
+   * Parameter of linear booster
+   * L2 regularization term on bias, default 0(no L1 reg on bias because it is not important)
+   */
   val lambdaBias = new DoubleParam(this, "lambda_bias", "L2 regularization term on bias, default" +
     " 0 (no L1 reg on bias because it is not important)", (value: Double) => value >= 0)

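
The scaladoc added above mirrors what can be supplied through the estimator's raw params map. A hedged sketch exercising several of the documented ranges (key names follow the Param names in the diff; passing them as a plain map is an assumption about this snapshot):

    val boosterParams = Map(
      "booster" -> "gbtree",
      "eta" -> 0.3,               // [0, 1]
      "gamma" -> 0.0,             // [0, Double.MaxValue]
      "max_depth" -> 6,           // [1, Int.MaxValue]
      "min_child_weight" -> 1.0,
      "subsample" -> 1.0,         // (0, 1]
      "colsample_bytree" -> 1.0,  // (0, 1]
      "lambda" -> 1.0,            // L2 regularization
      "alpha" -> 0.0)             // L1 regularization
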
@@ -19,30 +19,54 @@ package ml.dmlc.xgboost4j.scala.spark.params
 import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
 import org.apache.spark.ml.param._

-private[spark] trait GeneralParams extends Params {
+trait GeneralParams extends Params {

+  /**
+   * The number of rounds for boosting
+   */
   val round = new IntParam(this, "num_round", "The number of rounds for boosting",
     ParamValidators.gtEq(1))

+  /**
+   * number of workers used to train xgboost model. default: 1
+   */
   val nWorkers = new IntParam(this, "nworkers", "number of workers used to run xgboost",
     ParamValidators.gtEq(1))

+  /**
+   * number of threads used by per worker. default 1
+   */
   val numThreadPerTask = new IntParam(this, "nthread", "number of threads used by per worker",
     ParamValidators.gtEq(1))

+  /**
+   * whether to use external memory as cache. default: false
+   */
   val useExternalMemory = new BooleanParam(this, "use_external_memory", "whether to use external" +
     "memory as cache")

+  /**
+   * 0 means printing running messages, 1 means silent mode. default: 0
+   */
   val silent = new IntParam(this, "silent",
     "0 means printing running messages, 1 means silent mode.",
     (value: Int) => value >= 0 && value <= 1)

+  /**
+   * customized objective function provided by user. default: null
+   */
   val customObj = new Param[ObjectiveTrait](this, "custom_obj", "customized objective function " +
-    "provided by the user")
+    "provided by user")

+  /**
+   * customized evaluation function provided by user. default: null
+   */
   val customEval = new Param[EvalTrait](this, "custom_obj", "customized evaluation function " +
-    "provided by the user")
+    "provided by user")

+  /**
+   * the value treated as missing. default: Float.NaN
+   */
   val missing = new FloatParam(this, "missing", "the value treated as missing")

   setDefault(round -> 1, nWorkers -> 1, numThreadPerTask -> 1,
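
The GeneralParams counterparts as raw configuration entries, a hedged sketch with defaults taken from the scaladoc above (key names follow the Param names in the diff):

    val generalParams = Map(
      "num_round" -> 100,             // boosting rounds, >= 1
      "nworkers" -> 4,                // workers used to run xgboost, >= 1
      "nthread" -> 1,                 // threads per worker
      "use_external_memory" -> false,
      "silent" -> 0,                  // 0 = print running messages
      "missing" -> Float.NaN)
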
@@ -20,15 +20,28 @@ import scala.collection.immutable.HashSet

 import org.apache.spark.ml.param.{DoubleParam, Param, Params}

-private[spark] trait LearningTaskParams extends Params {
+trait LearningTaskParams extends Params {

+  /**
+   * Specify the learning task and the corresponding learning objective.
+   * options: reg:linear, reg:logistic, binary:logistic, binary:logitraw, count:poisson,
+   * multi:softmax, multi:softprob, rank:pairwise, reg:gamma. default: reg:linear
+   */
   val objective = new Param[String](this, "objective", "objective function used for training," +
     s" options: {${LearningTaskParams.supportedObjective.mkString(",")}",
     (value: String) => LearningTaskParams.supportedObjective.contains(value))

+  /**
+   * the initial prediction score of all instances, global bias. default=0.5
+   */
   val baseScore = new DoubleParam(this, "base_score", "the initial prediction score of all" +
     " instances, global bias")

+  /**
+   * evaluation metrics for validation data, a default metric will be assigned according to
+   * objective(rmse for regression, and error for classification, mean average precision for
+   * ranking). options: rmse, mae, logloss, error, merror, mlogloss, auc, ndcg, map, gamma-deviance
+   */
   val evalMetric = new Param[String](this, "eval_metric", "evaluation metrics for validation" +
     " data, a default metric will be assigned according to objective (rmse for regression, and" +
     " error for classification, mean average precision for ranking), options: " +
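
And the learning-task side, again as a hedged sketch built only from the documented options:

    val learningTaskParams = Map(
      "objective" -> "binary:logistic",  // one of the supportedObjective values
      "base_score" -> 0.5,               // global bias
      "eval_metric" -> "error")          // if omitted, a default follows the objective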