diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala index 4ae41e8fa..574cecf1d 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala @@ -141,6 +141,9 @@ class XGBoostClassifier ( def setCustomEval(value: EvalTrait): this.type = set(customEval, value) + def setSinglePrecisionHistogram(value: Boolean): this.type = + set(singlePrecisionHistogram, value) + // called at the start of fit/train when 'eval_metric' is not defined private def setupDefaultEvalMetric(): String = { require(isDefined(objective), "Users must set \'objective\' via xgboostParams.") diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala index 188b11176..d8c278d46 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala @@ -145,6 +145,9 @@ class XGBoostRegressor ( def setCustomEval(value: EvalTrait): this.type = set(customEval, value) + def setSinglePrecisionHistogram(value: Boolean): this.type = + set(singlePrecisionHistogram, value) + // called at the start of fit/train when 'eval_metric' is not defined private def setupDefaultEvalMetric(): String = { require(isDefined(objective), "Users must set \'objective\' via xgboostParams.") diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/BoosterParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/BoosterParams.scala index 6ec588d4b..1a7dd2a73 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/BoosterParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/BoosterParams.scala @@ -18,7 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark.params import scala.collection.immutable.HashSet -import org.apache.spark.ml.param.{DoubleParam, IntParam, Param, Params} +import org.apache.spark.ml.param.{DoubleParam, IntParam, BooleanParam, Param, Params} private[spark] trait BoosterParams extends Params { @@ -173,6 +173,14 @@ private[spark] trait BoosterParams extends Params { final def getMaxBins: Int = $(maxBins) + /** + * whether to build histograms using single precision floating point values + */ + final val singlePrecisionHistogram = new BooleanParam(this, "singlePrecisionHistogram", + "whether to use single precision to build histograms") + + final def getSinglePrecisionHistogram: Boolean = $(singlePrecisionHistogram) + /** * This is only used for approximate greedy algorithm. * This roughly translated into O(1 / sketch_eps) number of bins. Compared to directly select