[jvm-packages] [breaking] rework xgboost4j-spark and xgboost4j-spark-gpu (#10639)
- Introduce an abstract XGBoost Estimator
- Update to the latest XGBoost parameters
- Add all XGBoost parameters supported in XGBoost4j-Spark
- Add setters and getters for these parameters
- Remove the deprecated parameters
- Address the missing value handling
- Remove any ETL operations in XGBoost
- Rework the GPU plugin
- Expand sanity tests for CPU and GPU consistency
parent d94f6679fc
commit 67c8c96784
@@ -38,6 +38,7 @@ Contents
   XGBoost4J-Spark-GPU Tutorial <xgboost4j_spark_gpu_tutorial>
   Code Examples <https://github.com/dmlc/xgboost/tree/master/jvm-packages/xgboost4j-example>
   API docs <api>
+  How to migrate to XGBoost-Spark jvm 3.x <xgboost_spark_migration>

 .. note::

doc/jvm/xgboost_spark_migration.rst (new file, 162 lines)
@@ -0,0 +1,162 @@
########################################################
Migration Guide: How to migrate to XGBoost-Spark jvm 3.x
########################################################

XGBoost-Spark jvm packages underwent significant modifications in version 3.0,
which may cause compatibility issues with existing user code.

This guide will walk you through the process of updating your code to ensure
it's compatible with XGBoost-Spark 3.0 and later versions.

**********************
XGBoost Spark Packages
**********************

XGBoost-Spark 3.0 introduced a single uber package named xgboost-spark_2.12-3.0.0.jar, which bundles
both xgboost4j and xgboost4j-spark. This means you can now simply use ``xgboost-spark`` for your application.

* For CPU

  .. code-block:: xml

    <dependency>
      <groupId>ml.dmlc</groupId>
      <artifactId>xgboost-spark_${scala.binary.version}</artifactId>
      <version>3.0.0</version>
    </dependency>

* For GPU

  .. code-block:: xml

    <dependency>
      <groupId>ml.dmlc</groupId>
      <artifactId>xgboost-spark-gpu_${scala.binary.version}</artifactId>
      <version>3.0.0</version>
    </dependency>

When submitting the XGBoost application to the Spark cluster, you only need to specify the single ``xgboost-spark`` package.

* For CPU

  .. code-block:: bash

    spark-submit \
      --jars xgboost-spark_2.12-3.0.0.jar \
      ...

* For GPU

  .. code-block:: bash

    spark-submit \
      --jars xgboost-spark-gpu_2.12-3.0.0.jar \
      ...
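For GPU training, the reworked plugin in this change only activates when the spark-rapids plugin is on
the classpath and enabled (see ``GpuXGBoostPlugin.isEnabled`` in this PR); otherwise the job falls back to
the CPU pipeline. The following is only a minimal sketch of the Spark settings that check inspects, shown
on the ``SparkSession`` builder for illustration; they can equally be passed with ``--conf`` on ``spark-submit``:

.. code-block:: scala

  import org.apache.spark.sql.SparkSession

  // Settings consulted by GpuXGBoostPlugin.isEnabled:
  //  - spark.plugins must contain com.nvidia.spark.SQLPlugin
  //  - spark.rapids.sql.enabled defaults to true and must not be disabled
  val spark = SparkSession.builder()
    .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
    .config("spark.rapids.sql.enabled", "true")
    .getOrCreate()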
***************
XGBoost Ranking
***************

Learning to rank using XGBoostRegressor has been replaced by a dedicated ``XGBoostRanker``, which is specifically designed
to support ranking algorithms.

.. code-block:: scala

  // before 3.0
  val regressor = new XGBoostRegressor().setObjective("rank:ndcg")

  // after 3.0
  val ranker = new XGBoostRanker()
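Ranking still needs per-query group information. A minimal sketch of fitting the new ranker, assuming
``XGBoostRanker`` exposes a ``setGroupCol`` setter for the ``groupCol`` parameter (``HasGroupCol``) introduced
by this rework, and a DataFrame with hypothetical ``features``, ``label`` and ``group`` columns:

.. code-block:: scala

  // group identifies which query each row belongs to
  val ranker = new XGBoostRanker()
    .setFeaturesCol("features")
    .setLabelCol("label")
    .setGroupCol("group")
  val model = ranker.fit(trainDf)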
******************************
XGBoost Constructor Parameters
******************************

XGBoost Spark now categorizes parameters into two groups: XGBoost-Spark parameters and XGBoost parameters.
When constructing an XGBoost estimator, only XGBoost-specific parameters are permitted. XGBoost-Spark specific
parameters must be configured using the estimator's setter methods. It's worth noting that
`XGBoost Parameters <https://xgboost.readthedocs.io/en/stable/parameter.html>`_
can be set both during construction and through the estimator's setter methods.

.. code-block:: scala

  // before 3.0
  val xgboost_paras = Map(
    "eta" -> "1",
    "max_depth" -> "6",
    "objective" -> "binary:logistic",
    "num_round" -> 5,
    "num_workers" -> 1,
    "features" -> "feature_column",
    "label" -> "label_column",
  )
  val classifier = new XGBoostClassifier(xgboost_paras)

  // after 3.0
  val xgboost_paras = Map(
    "eta" -> "1",
    "max_depth" -> "6",
    "objective" -> "binary:logistic",
  )
  val classifier = new XGBoostClassifier(xgboost_paras)
    .setNumRound(5)
    .setNumWorkers(1)
    .setFeaturesCol("feature_column")
    .setLabelCol("label_column")

  // Or you can use setters to set all parameters
  val classifier = new XGBoostClassifier()
    .setNumRound(5)
    .setNumWorkers(1)
    .setFeaturesCol("feature_column")
    .setLabelCol("label_column")
    .setEta(1)
    .setMaxDepth(6)
    .setObjective("binary:logistic")

******************
Removed Parameters
******************

Starting from 3.0, the following parameters are removed.

- cacheTrainingSet

  If you wish to cache the training dataset, you have the option to implement caching
  in your code prior to fitting the data to an estimator.

  .. code-block:: scala

    val df = input.cache()
    val model = new XGBoostClassifier().fit(df)

- trainTestRatio

  The following method can be employed to do the evaluation instead.

  .. code-block:: scala

    val Array(train, eval) = trainDf.randomSplit(Array(0.7, 0.3))
    val classifier = new XGBoostClassifier().setEvalDataset(eval)
    val model = classifier.fit(train)

- tracker_conf

  The following method can be used to configure the RabitTracker instead.

  .. code-block:: scala

    val classifier = new XGBoostClassifier()
      .setRabitTrackerTimeout(100)
      .setRabitTrackerHostIp("192.168.0.2")
      .setRabitTrackerPort(19203)

- rabitRingReduceThreshold
- rabitTimeout
- rabitConnectRetry
- singlePrecisionHistogram
- lambdaBias
- objectiveType
@@ -46,7 +46,7 @@
     <use.cuda>OFF</use.cuda>
     <cudf.version>24.06.0</cudf.version>
     <spark.rapids.version>24.06.0</spark.rapids.version>
-    <cudf.classifier>cuda12</cudf.classifier>
+    <spark.rapids.classifier>cuda12</spark.rapids.classifier>
     <scalatest.version>3.2.19</scalatest.version>
     <scala-collection-compat.version>2.12.0</scala-collection-compat.version>
     <skip.native.build>false</skip.native.build>
@@ -54,6 +54,7 @@
       <groupId>com.nvidia</groupId>
       <artifactId>rapids-4-spark_${scala.binary.version}</artifactId>
       <version>${spark.rapids.version}</version>
+      <classifier>${spark.rapids.classifier}</classifier>
       <scope>provided</scope>
     </dependency>
     <dependency>
@@ -35,11 +35,39 @@ public class QuantileDMatrix extends DMatrix {
       float missing,
       int maxBin,
       int nthread) throws XGBoostError {
+    this(iter, null, missing, maxBin, nthread);
+  }
+
+  /**
+   * Create QuantileDMatrix from iterator based on the cuda array interface
+   *
+   * @param iter       the XGBoost ColumnBatch batch to provide the corresponding cuda array
+   *                   interface
+   * @param refDMatrix The reference QuantileDMatrix that provides quantile information, needed
+   *                   when creating validation/test dataset with QuantileDMatrix. Supplying the
+   *                   training DMatrix as a reference means that the same quantisation
+   *                   applied to the training data is applied to the validation/test data
+   * @param missing    the missing value
+   * @param maxBin     the max bin
+   * @param nthread    the parallelism
+   * @throws XGBoostError
+   */
+  public QuantileDMatrix(
+      Iterator<ColumnBatch> iter,
+      QuantileDMatrix refDMatrix,
+      float missing,
+      int maxBin,
+      int nthread) throws XGBoostError {
     super(0);
     long[] out = new long[1];
     String conf = getConfig(missing, maxBin, nthread);
+    long[] ref = null;
+    if (refDMatrix != null) {
+      ref = new long[1];
+      ref[0] = refDMatrix.getHandle();
+    }
     XGBoostJNI.checkCall(XGBoostJNI.XGQuantileDMatrixCreateFromCallback(
-        iter, null, conf, out));
+        iter, ref, conf, out));
     handle = out[0];
   }
@@ -87,4 +115,5 @@ public class QuantileDMatrix extends DMatrix {
     return String.format("{\"missing\":%f,\"max_bin\":%d,\"nthread\":%d}",
         missing, maxBin, nthread);
   }
+
 }
Deleted file: GpuColumnBatch.java
@@ -1,68 +0,0 @@
/*
 Copyright (c) 2021 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.java.nvidia.spark;

import java.util.List;

import ai.rapids.cudf.ColumnVector;
import ai.rapids.cudf.Table;
import org.apache.spark.sql.types.*;

/**
 * Wrapper of CudfTable with schema for scala
 */
public class GpuColumnBatch implements AutoCloseable {
  private final StructType schema;
  private Table table; // the original Table

  public GpuColumnBatch(Table table, StructType schema) {
    this.table = table;
    this.schema = schema;
  }

  @Override
  public void close() {
    if (table != null) {
      table.close();
      table = null;
    }
  }

  /** Slice the columns indicated by indices into a Table */
  public Table slice(List<Integer> indices) {
    if (indices == null || indices.size() == 0) {
      return null;
    }

    int len = indices.size();
    ColumnVector[] cv = new ColumnVector[len];
    for (int i = 0; i < len; i++) {
      int index = indices.get(i);
      if (index >= table.getNumberOfColumns()) {
        throw new RuntimeException("Wrong index");
      }
      cv[i] = table.getColumn(index);
    }

    return new Table(cv);
  }

  public StructType getSchema() {
    return schema;
  }
}
@@ -1 +0,0 @@
-ml.dmlc.xgboost4j.scala.rapids.spark.GpuPreXGBoost
@@ -0,0 +1 @@
+ml.dmlc.xgboost4j.scala.spark.GpuXGBoostPlugin
@@ -1,5 +1,5 @@
 /*
- Copyright (c) 2021 by Contributors
+ Copyright (c) 2021-2024 by Contributors

  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -16,17 +16,17 @@

 package ml.dmlc.xgboost4j.scala

-import _root_.scala.collection.JavaConverters._
-
 import ml.dmlc.xgboost4j.java.{Column, ColumnBatch, XGBoostError, QuantileDMatrix => JQuantileDMatrix}

+import scala.collection.JavaConverters._
+
 class QuantileDMatrix private[scala](
   private[scala] override val jDMatrix: JQuantileDMatrix) extends DMatrix(jDMatrix) {

   /**
-   * Create QuantileDMatrix from iterator based on the cuda array interface
+   * Create QuantileDMatrix from iterator based on the array interface
    *
-   * @param iter the XGBoost ColumnBatch batch to provide the corresponding cuda array interface
+   * @param iter the XGBoost ColumnBatch batch to provide the corresponding array interface
    * @param missing the missing value
    * @param maxBin the max bin
    * @param nthread the parallelism
@@ -36,6 +36,27 @@ class QuantileDMatrix private[scala](
     this(new JQuantileDMatrix(iter.asJava, missing, maxBin, nthread))
   }

+  /**
+   * Create QuantileDMatrix from iterator based on the array interface
+   *
+   * @param iter the XGBoost ColumnBatch batch to provide the corresponding array interface
+   * @param refDMatrix The reference QuantileDMatrix that provides quantile information, needed
+   *                   when creating validation/test dataset with QuantileDMatrix. Supplying the
+   *                   training DMatrix as a reference means that the same quantisation applied
+   *                   to the training data is applied to the validation/test data
+   * @param missing the missing value
+   * @param maxBin the max bin
+   * @param nthread the parallelism
+   * @throws XGBoostError
+   */
+  def this(iter: Iterator[ColumnBatch],
+           ref: QuantileDMatrix,
+           missing: Float,
+           maxBin: Int,
+           nthread: Int) {
+    this(new JQuantileDMatrix(iter.asJava, ref.jDMatrix, missing, maxBin, nthread))
+  }
+
   /**
    * set label of dmatrix
    *
@@ -84,7 +105,7 @@ class QuantileDMatrix private[scala](
     throw new XGBoostError("QuantileDMatrix does not support setGroup.")

   /**
-   * Set label of DMatrix from cuda array interface
+   * Set label of DMatrix from array interface
    */
   @throws(classOf[XGBoostError])
   override def setLabel(column: Column): Unit =
@@ -104,4 +125,9 @@ class QuantileDMatrix private[scala](
   override def setBaseMargin(column: Column): Unit =
     throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.")

+  @throws(classOf[XGBoostError])
+  override def setQueryId(column: Column): Unit = {
+    throw new XGBoostError("QuantileDMatrix does not support setQueryId.")
+  }
+
 }
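Note: the new `ref` constructor is what lets a validation/test QuantileDMatrix reuse the quantile cuts
computed from the training data. A minimal usage sketch of the Scala constructors added above, with
hypothetical `trainBatches`/`evalBatches` iterators of ColumnBatch that are not part of this diff:

    // Build the training matrix first, then pass it as the reference for the eval matrix
    // so both are quantised with the same histogram cuts (missing = NaN, maxBin = 256, nthread = 1).
    val trainDM = new QuantileDMatrix(trainBatches, Float.NaN, 256, 1)
    val evalDM = new QuantileDMatrix(evalBatches, trainDM, Float.NaN, 256, 1)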
@ -1,602 +0,0 @@
|
|||||||
/*
|
|
||||||
Copyright (c) 2021-2024 by Contributors
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala.rapids.spark
|
|
||||||
|
|
||||||
import scala.collection.JavaConverters._
|
|
||||||
import ml.dmlc.xgboost4j.java.nvidia.spark.GpuColumnBatch
|
|
||||||
import ml.dmlc.xgboost4j.java.CudfColumnBatch
|
|
||||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, QuantileDMatrix}
|
|
||||||
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon
|
|
||||||
import ml.dmlc.xgboost4j.scala.spark.{PreXGBoost, PreXGBoostProvider, Watches, XGBoost, XGBoostClassificationModel, XGBoostClassifier, XGBoostExecutionParams, XGBoostRegressionModel, XGBoostRegressor}
|
|
||||||
import org.apache.commons.logging.LogFactory
|
|
||||||
import org.apache.spark.{SparkContext, TaskContext}
|
|
||||||
import org.apache.spark.ml.{Estimator, Model}
|
|
||||||
import org.apache.spark.rdd.RDD
|
|
||||||
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
|
|
||||||
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
|
|
||||||
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
|
|
||||||
import org.apache.spark.sql.functions.{col, collect_list, struct}
|
|
||||||
import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType}
|
|
||||||
import org.apache.spark.sql.vectorized.ColumnarBatch
|
|
||||||
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* GpuPreXGBoost brings Rapids-Plugin to XGBoost4j-Spark to accelerate XGBoost4j
|
|
||||||
* training and transform process
|
|
||||||
*/
|
|
||||||
class GpuPreXGBoost extends PreXGBoostProvider {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Whether the provider is enabled or not
|
|
||||||
*
|
|
||||||
* @param dataset the input dataset
|
|
||||||
* @return Boolean
|
|
||||||
*/
|
|
||||||
override def providerEnabled(dataset: Option[Dataset[_]]): Boolean = {
|
|
||||||
GpuPreXGBoost.providerEnabled(dataset)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Convert the Dataset[_] to RDD[() => Watches] which will be fed to XGBoost
|
|
||||||
*
|
|
||||||
* @param estimator [[XGBoostClassifier]] or [[XGBoostRegressor]]
|
|
||||||
* @param dataset the training data
|
|
||||||
* @param params all user defined and defaulted params
|
|
||||||
* @return [[XGBoostExecutionParams]] => (RDD[[() => Watches]], Option[ RDD[_] ])
|
|
||||||
* RDD[() => Watches] will be used as the training input
|
|
||||||
* Option[ RDD[_] ] is the optional cached RDD
|
|
||||||
*/
|
|
||||||
override def buildDatasetToRDD(estimator: Estimator[_],
|
|
||||||
dataset: Dataset[_],
|
|
||||||
params: Map[String, Any]):
|
|
||||||
XGBoostExecutionParams => (RDD[() => Watches], Option[RDD[_]]) = {
|
|
||||||
GpuPreXGBoost.buildDatasetToRDD(estimator, dataset, params)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Transform Dataset
|
|
||||||
*
|
|
||||||
* @param model [[XGBoostClassificationModel]] or [[XGBoostRegressionModel]]
|
|
||||||
* @param dataset the input Dataset to transform
|
|
||||||
* @return the transformed DataFrame
|
|
||||||
*/
|
|
||||||
override def transformDataset(model: Model[_], dataset: Dataset[_]): DataFrame = {
|
|
||||||
GpuPreXGBoost.transformDataset(model, dataset)
|
|
||||||
}
|
|
||||||
|
|
||||||
override def transformSchema(
|
|
||||||
xgboostEstimator: XGBoostEstimatorCommon,
|
|
||||||
schema: StructType): StructType = {
|
|
||||||
GpuPreXGBoost.transformSchema(xgboostEstimator, schema)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
class BoosterFlag extends Serializable {
|
|
||||||
// indicate if the GPU parameters are set.
|
|
||||||
var isGpuParamsSet = false
|
|
||||||
}
|
|
||||||
|
|
||||||
object GpuPreXGBoost extends PreXGBoostProvider {
|
|
||||||
|
|
||||||
private val logger = LogFactory.getLog("XGBoostSpark")
|
|
||||||
private val FEATURES_COLS = "features_cols"
|
|
||||||
private val TRAIN_NAME = "train"
|
|
||||||
|
|
||||||
override def providerEnabled(dataset: Option[Dataset[_]]): Boolean = {
|
|
||||||
// RuntimeConfig
|
|
||||||
val optionConf = dataset.map(ds => Some(ds.sparkSession.conf))
|
|
||||||
.getOrElse(SparkSession.getActiveSession.map(ss => ss.conf))
|
|
||||||
|
|
||||||
if (optionConf.isDefined) {
|
|
||||||
val conf = optionConf.get
|
|
||||||
val rapidsEnabled = try {
|
|
||||||
conf.get("spark.rapids.sql.enabled").toBoolean
|
|
||||||
} catch {
|
|
||||||
// Rapids plugin has default "spark.rapids.sql.enabled" to true
|
|
||||||
case _: NoSuchElementException => true
|
|
||||||
case _: Throwable => false // Any exception will return false
|
|
||||||
}
|
|
||||||
rapidsEnabled && conf.get("spark.sql.extensions", "")
|
|
||||||
.split(",")
|
|
||||||
.contains("com.nvidia.spark.rapids.SQLExecPlugin")
|
|
||||||
} else false
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Convert the Dataset[_] to RDD[() => Watches] which will be fed to XGBoost
|
|
||||||
*
|
|
||||||
* @param estimator supports XGBoostClassifier and XGBoostRegressor
|
|
||||||
* @param dataset the training data
|
|
||||||
* @param params all user defined and defaulted params
|
|
||||||
* @return [[XGBoostExecutionParams]] => (RDD[[() => Watches]], Option[ RDD[_] ])
|
|
||||||
* RDD[() => Watches] will be used as the training input to build DMatrix
|
|
||||||
* Option[ RDD[_] ] is the optional cached RDD
|
|
||||||
*/
|
|
||||||
override def buildDatasetToRDD(
|
|
||||||
estimator: Estimator[_],
|
|
||||||
dataset: Dataset[_],
|
|
||||||
params: Map[String, Any]):
|
|
||||||
XGBoostExecutionParams => (RDD[() => Watches], Option[RDD[_]]) = {
|
|
||||||
|
|
||||||
val (Seq(labelName, weightName, marginName), feturesCols, groupName, evalSets) =
|
|
||||||
estimator match {
|
|
||||||
case est: XGBoostEstimatorCommon =>
|
|
||||||
require(
|
|
||||||
est.isDefined(est.device) &&
|
|
||||||
(est.getDevice.equals("cuda") || est.getDevice.equals("gpu")) ||
|
|
||||||
est.isDefined(est.treeMethod) && est.getTreeMethod.equals("gpu_hist"),
|
|
||||||
s"GPU train requires `device` set to `cuda` or `gpu`."
|
|
||||||
)
|
|
||||||
val groupName = estimator match {
|
|
||||||
case regressor: XGBoostRegressor => if (regressor.isDefined(regressor.groupCol)) {
|
|
||||||
regressor.getGroupCol } else ""
|
|
||||||
case _: XGBoostClassifier => ""
|
|
||||||
case _ => throw new RuntimeException("Unsupported estimator: " + estimator)
|
|
||||||
}
|
|
||||||
// Check schema and cast columns' type
|
|
||||||
(GpuUtils.getColumnNames(est)(est.labelCol, est.weightCol, est.baseMarginCol),
|
|
||||||
est.getFeaturesCols, groupName, est.getEvalSets(params))
|
|
||||||
case _ => throw new RuntimeException("Unsupported estimator: " + estimator)
|
|
||||||
}
|
|
||||||
|
|
||||||
val castedDF = GpuUtils.prepareColumnType(dataset, feturesCols, labelName, weightName,
|
|
||||||
marginName)
|
|
||||||
|
|
||||||
// Check columns and build column data batch
|
|
||||||
val trainingData = GpuUtils.buildColumnDataBatch(feturesCols,
|
|
||||||
labelName, weightName, marginName, groupName, castedDF)
|
|
||||||
|
|
||||||
// eval map
|
|
||||||
val evalDataMap = evalSets.map {
|
|
||||||
case (name, df) =>
|
|
||||||
val castDF = GpuUtils.prepareColumnType(df, feturesCols, labelName,
|
|
||||||
weightName, marginName)
|
|
||||||
(name, GpuUtils.buildColumnDataBatch(feturesCols, labelName, weightName,
|
|
||||||
marginName, groupName, castDF))
|
|
||||||
}
|
|
||||||
|
|
||||||
xgbExecParams: XGBoostExecutionParams =>
|
|
||||||
val dataMap = prepareInputData(trainingData, evalDataMap, xgbExecParams.numWorkers,
|
|
||||||
xgbExecParams.cacheTrainingSet)
|
|
||||||
(buildRDDWatches(dataMap, xgbExecParams, evalDataMap.isEmpty), None)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Transform Dataset
|
|
||||||
*
|
|
||||||
* @param model supporting [[XGBoostClassificationModel]] and [[XGBoostRegressionModel]]
|
|
||||||
* @param dataset the input Dataset to transform
|
|
||||||
* @return the transformed DataFrame
|
|
||||||
*/
|
|
||||||
override def transformDataset(model: Model[_], dataset: Dataset[_]): DataFrame = {
|
|
||||||
|
|
||||||
val (booster, predictFunc, schema, featureColNames, missing) = model match {
|
|
||||||
case m: XGBoostClassificationModel =>
|
|
||||||
Seq(XGBoostClassificationModel._rawPredictionCol,
|
|
||||||
XGBoostClassificationModel._probabilityCol, m.leafPredictionCol, m.contribPredictionCol)
|
|
||||||
|
|
||||||
// predict and turn to Row
|
|
||||||
val predictFunc =
|
|
||||||
(booster: Booster, dm: DMatrix, originalRowItr: Iterator[Row]) => {
|
|
||||||
val Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr) =
|
|
||||||
m.producePredictionItrs(booster, dm)
|
|
||||||
m.produceResultIterator(originalRowItr, rawPredictionItr, probabilityItr,
|
|
||||||
predLeafItr, predContribItr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// prepare the final Schema
|
|
||||||
var schema = StructType(dataset.schema.fields ++
|
|
||||||
Seq(StructField(name = XGBoostClassificationModel._rawPredictionCol, dataType =
|
|
||||||
ArrayType(FloatType, containsNull = false), nullable = false)) ++
|
|
||||||
Seq(StructField(name = XGBoostClassificationModel._probabilityCol, dataType =
|
|
||||||
ArrayType(FloatType, containsNull = false), nullable = false)))
|
|
||||||
|
|
||||||
if (m.isDefined(m.leafPredictionCol)) {
|
|
||||||
schema = schema.add(StructField(name = m.getLeafPredictionCol, dataType =
|
|
||||||
ArrayType(FloatType, containsNull = false), nullable = false))
|
|
||||||
}
|
|
||||||
if (m.isDefined(m.contribPredictionCol)) {
|
|
||||||
schema = schema.add(StructField(name = m.getContribPredictionCol, dataType =
|
|
||||||
ArrayType(FloatType, containsNull = false), nullable = false))
|
|
||||||
}
|
|
||||||
|
|
||||||
(m._booster, predictFunc, schema, m.getFeaturesCols, m.getMissing)
|
|
||||||
|
|
||||||
case m: XGBoostRegressionModel =>
|
|
||||||
Seq(XGBoostRegressionModel._originalPredictionCol, m.leafPredictionCol,
|
|
||||||
m.contribPredictionCol)
|
|
||||||
|
|
||||||
// predict and turn to Row
|
|
||||||
val predictFunc =
|
|
||||||
(booster: Booster, dm: DMatrix, originalRowItr: Iterator[Row]) => {
|
|
||||||
val Array(rawPredictionItr, predLeafItr, predContribItr) =
|
|
||||||
m.producePredictionItrs(booster, dm)
|
|
||||||
m.produceResultIterator(originalRowItr, rawPredictionItr, predLeafItr,
|
|
||||||
predContribItr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// prepare the final Schema
|
|
||||||
var schema = StructType(dataset.schema.fields ++
|
|
||||||
Seq(StructField(name = XGBoostRegressionModel._originalPredictionCol, dataType =
|
|
||||||
ArrayType(FloatType, containsNull = false), nullable = false)))
|
|
||||||
|
|
||||||
if (m.isDefined(m.leafPredictionCol)) {
|
|
||||||
schema = schema.add(StructField(name = m.getLeafPredictionCol, dataType =
|
|
||||||
ArrayType(FloatType, containsNull = false), nullable = false))
|
|
||||||
}
|
|
||||||
if (m.isDefined(m.contribPredictionCol)) {
|
|
||||||
schema = schema.add(StructField(name = m.getContribPredictionCol, dataType =
|
|
||||||
ArrayType(FloatType, containsNull = false), nullable = false))
|
|
||||||
}
|
|
||||||
|
|
||||||
(m._booster, predictFunc, schema, m.getFeaturesCols, m.getMissing)
|
|
||||||
}
|
|
||||||
|
|
||||||
val sc = dataset.sparkSession.sparkContext
|
|
||||||
|
|
||||||
// Prepare some vars will be passed to executors.
|
|
||||||
val bOrigSchema = sc.broadcast(dataset.schema)
|
|
||||||
val bRowSchema = sc.broadcast(schema)
|
|
||||||
val bBooster = sc.broadcast(booster)
|
|
||||||
val bBoosterFlag = sc.broadcast(new BoosterFlag)
|
|
||||||
|
|
||||||
// Small vars so don't need to broadcast them
|
|
||||||
val isLocal = sc.isLocal
|
|
||||||
val featureIds = featureColNames.distinct.map(dataset.schema.fieldIndex)
|
|
||||||
|
|
||||||
// start transform by df->rd->mapPartition
|
|
||||||
val rowRDD: RDD[Row] = GpuUtils.toColumnarRdd(dataset.asInstanceOf[DataFrame]).mapPartitions {
|
|
||||||
tableIters =>
|
|
||||||
// UnsafeProjection is not serializable so do it on the executor side
|
|
||||||
val toUnsafe = UnsafeProjection.create(bOrigSchema.value)
|
|
||||||
|
|
||||||
// booster is visible for all spark tasks in the same executor
|
|
||||||
val booster = bBooster.value
|
|
||||||
val boosterFlag = bBoosterFlag.value
|
|
||||||
|
|
||||||
synchronized {
|
|
||||||
// there are two kind of race conditions,
|
|
||||||
// 1. multi-taskes set parameters at a time
|
|
||||||
// 2. one task sets parameter and another task reads the parameter
|
|
||||||
// both of them can cause potential un-expected behavior, moreover,
|
|
||||||
// it may cause executor crash
|
|
||||||
// So add synchronized to allow only one task to set parameter if it is not set.
|
|
||||||
// and rely on BlockManager to ensure the same booster only be called once to
|
|
||||||
// set parameter.
|
|
||||||
if (!boosterFlag.isGpuParamsSet) {
|
|
||||||
// set some params of gpu related to booster
|
|
||||||
// - gpu id
|
|
||||||
// - predictor: Force to gpu predictor since native doesn't save predictor.
|
|
||||||
val gpuId = if (!isLocal) XGBoost.getGPUAddrFromResources else 0
|
|
||||||
booster.setParam("device", s"cuda:$gpuId")
|
|
||||||
logger.info("GPU transform on device: " + gpuId)
|
|
||||||
boosterFlag.isGpuParamsSet = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Iterator on Row
|
|
||||||
new Iterator[Row] {
|
|
||||||
// Convert InternalRow to Row
|
|
||||||
private val converter: InternalRow => Row = CatalystTypeConverters
|
|
||||||
.createToScalaConverter(bOrigSchema.value)
|
|
||||||
.asInstanceOf[InternalRow => Row]
|
|
||||||
// GPU batches read in must be closed by the receiver (us)
|
|
||||||
@transient var currentBatch: ColumnarBatch = null
|
|
||||||
|
|
||||||
// Iterator on Row
|
|
||||||
var iter: Iterator[Row] = null
|
|
||||||
|
|
||||||
TaskContext.get().addTaskCompletionListener[Unit](_ => {
|
|
||||||
closeCurrentBatch() // close the last ColumnarBatch
|
|
||||||
})
|
|
||||||
|
|
||||||
private def closeCurrentBatch(): Unit = {
|
|
||||||
if (currentBatch != null) {
|
|
||||||
currentBatch.close()
|
|
||||||
currentBatch = null
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
def loadNextBatch(): Unit = {
|
|
||||||
closeCurrentBatch()
|
|
||||||
if (tableIters.hasNext) {
|
|
||||||
val dataTypes = bOrigSchema.value.fields.map(x => x.dataType)
|
|
||||||
iter = withResource(tableIters.next()) { table =>
|
|
||||||
val gpuColumnBatch = new GpuColumnBatch(table, bOrigSchema.value)
|
|
||||||
// Create DMatrix
|
|
||||||
val feaTable = gpuColumnBatch.slice(GpuUtils.seqIntToSeqInteger(featureIds).asJava)
|
|
||||||
if (feaTable == null) {
|
|
||||||
throw new RuntimeException("Something wrong for feature indices")
|
|
||||||
}
|
|
||||||
try {
|
|
||||||
val cudfColumnBatch = new CudfColumnBatch(feaTable, null, null, null, null)
|
|
||||||
val dm = new DMatrix(cudfColumnBatch, missing, 1)
|
|
||||||
if (dm == null) {
|
|
||||||
Iterator.empty
|
|
||||||
} else {
|
|
||||||
try {
|
|
||||||
currentBatch = new ColumnarBatch(
|
|
||||||
GpuUtils.extractBatchToHost(table, dataTypes),
|
|
||||||
table.getRowCount().toInt)
|
|
||||||
val rowIterator = currentBatch.rowIterator().asScala
|
|
||||||
.map(toUnsafe)
|
|
||||||
.map(converter(_))
|
|
||||||
predictFunc(booster, dm, rowIterator)
|
|
||||||
|
|
||||||
} finally {
|
|
||||||
dm.delete()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} finally {
|
|
||||||
feaTable.close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
iter = null
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
override def hasNext: Boolean = {
|
|
||||||
val itHasNext = iter != null && iter.hasNext
|
|
||||||
if (!itHasNext) { // Don't have extra Row for current ColumnarBatch
|
|
||||||
loadNextBatch()
|
|
||||||
iter != null && iter.hasNext
|
|
||||||
} else {
|
|
||||||
itHasNext
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
override def next(): Row = {
|
|
||||||
if (iter == null || !iter.hasNext) {
|
|
||||||
loadNextBatch()
|
|
||||||
}
|
|
||||||
if (iter == null) {
|
|
||||||
throw new NoSuchElementException()
|
|
||||||
}
|
|
||||||
iter.next()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bOrigSchema.unpersist(blocking = false)
|
|
||||||
bRowSchema.unpersist(blocking = false)
|
|
||||||
bBooster.unpersist(blocking = false)
|
|
||||||
dataset.sparkSession.createDataFrame(rowRDD, schema)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Transform schema
|
|
||||||
*
|
|
||||||
* @param est supporting XGBoostClassifier/XGBoostClassificationModel and
|
|
||||||
* XGBoostRegressor/XGBoostRegressionModel
|
|
||||||
* @param schema the input schema
|
|
||||||
* @return the transformed schema
|
|
||||||
*/
|
|
||||||
override def transformSchema(
|
|
||||||
est: XGBoostEstimatorCommon,
|
|
||||||
schema: StructType): StructType = {
|
|
||||||
|
|
||||||
val fit = est match {
|
|
||||||
case _: XGBoostClassifier | _: XGBoostRegressor => true
|
|
||||||
case _ => false
|
|
||||||
}
|
|
||||||
|
|
||||||
val Seq(label, weight, margin) = GpuUtils.getColumnNames(est)(est.labelCol, est.weightCol,
|
|
||||||
est.baseMarginCol)
|
|
||||||
|
|
||||||
GpuUtils.validateSchema(schema, est.getFeaturesCols, label, weight, margin, fit)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Repartition all the Columnar Dataset (training and evaluation) to nWorkers,
|
|
||||||
* and assemble them into a map
|
|
||||||
*/
|
|
||||||
private def prepareInputData(
|
|
||||||
trainingData: ColumnDataBatch,
|
|
||||||
evalSetsMap: Map[String, ColumnDataBatch],
|
|
||||||
nWorkers: Int,
|
|
||||||
isCacheData: Boolean): Map[String, ColumnDataBatch] = {
|
|
||||||
// Cache is not supported
|
|
||||||
if (isCacheData) {
|
|
||||||
logger.warn("the cache param will be ignored by GPU pipeline!")
|
|
||||||
}
|
|
||||||
|
|
||||||
(Map(TRAIN_NAME -> trainingData) ++ evalSetsMap).map {
|
|
||||||
case (name, colData) =>
|
|
||||||
// No light cost way to get number of partitions from DataFrame, so always repartition
|
|
||||||
val newDF = colData.groupColName
|
|
||||||
.map(gn => repartitionForGroup(gn, colData.rawDF, nWorkers))
|
|
||||||
.getOrElse(repartitionInputData(colData.rawDF, nWorkers))
|
|
||||||
name -> ColumnDataBatch(newDF, colData.colIndices, colData.groupColName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private def repartitionInputData(dataFrame: DataFrame, nWorkers: Int): DataFrame = {
|
|
||||||
// we can't involve any coalesce operation here, since Barrier mode will check
|
|
||||||
// the RDD patterns which does not allow coalesce.
|
|
||||||
dataFrame.repartition(nWorkers)
|
|
||||||
}
|
|
||||||
|
|
||||||
private def repartitionForGroup(
|
|
||||||
groupName: String,
|
|
||||||
dataFrame: DataFrame,
|
|
||||||
nWorkers: Int): DataFrame = {
|
|
||||||
// Group the data first
|
|
||||||
logger.info("Start groupBy for LTR")
|
|
||||||
val schema = dataFrame.schema
|
|
||||||
val groupedDF = dataFrame
|
|
||||||
.groupBy(groupName)
|
|
||||||
.agg(collect_list(struct(schema.fieldNames.map(col): _*)) as "list")
|
|
||||||
|
|
||||||
implicit val encoder = ExpressionEncoder(RowEncoder.encoderFor(schema, false))
|
|
||||||
// Expand the grouped rows after repartition
|
|
||||||
repartitionInputData(groupedDF, nWorkers).mapPartitions(iter => {
|
|
||||||
new Iterator[Row] {
|
|
||||||
var iterInRow: Iterator[Any] = Iterator.empty
|
|
||||||
|
|
||||||
override def hasNext: Boolean = {
|
|
||||||
if (iter.hasNext && !iterInRow.hasNext) {
|
|
||||||
// the first is groupId, second is list
|
|
||||||
iterInRow = iter.next.getSeq(1).iterator
|
|
||||||
}
|
|
||||||
iterInRow.hasNext
|
|
||||||
}
|
|
||||||
|
|
||||||
override def next(): Row = {
|
|
||||||
iterInRow.next.asInstanceOf[Row]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
private def buildRDDWatches(
|
|
||||||
dataMap: Map[String, ColumnDataBatch],
|
|
||||||
xgbExeParams: XGBoostExecutionParams,
|
|
||||||
noEvalSet: Boolean): RDD[() => Watches] = {
|
|
||||||
|
|
||||||
val sc = dataMap(TRAIN_NAME).rawDF.sparkSession.sparkContext
|
|
||||||
val maxBin = xgbExeParams.toMap.getOrElse("max_bin", 256).asInstanceOf[Int]
|
|
||||||
// Start training
|
|
||||||
if (noEvalSet) {
|
|
||||||
// Get the indices here at driver side to avoid passing the whole Map to executor(s)
|
|
||||||
val colIndicesForTrain = dataMap(TRAIN_NAME).colIndices
|
|
||||||
GpuUtils.toColumnarRdd(dataMap(TRAIN_NAME).rawDF).mapPartitions({
|
|
||||||
iter =>
|
|
||||||
val iterColBatch = iter.map(table => new GpuColumnBatch(table, null))
|
|
||||||
Iterator(() => buildWatches(
|
|
||||||
PreXGBoost.getCacheDirName(xgbExeParams.useExternalMemory), xgbExeParams.missing,
|
|
||||||
colIndicesForTrain, iterColBatch, maxBin))
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
// Train with evaluation sets
|
|
||||||
// Get the indices here at driver side to avoid passing the whole Map to executor(s)
|
|
||||||
val nameAndColIndices = dataMap.map(nc => (nc._1, nc._2.colIndices))
|
|
||||||
coPartitionForGpu(dataMap, sc, xgbExeParams.numWorkers).mapPartitions {
|
|
||||||
nameAndColumnBatchIter =>
|
|
||||||
Iterator(() => buildWatchesWithEval(
|
|
||||||
PreXGBoost.getCacheDirName(xgbExeParams.useExternalMemory), xgbExeParams.missing,
|
|
||||||
nameAndColIndices, nameAndColumnBatchIter, maxBin))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private def buildWatches(
|
|
||||||
cachedDirName: Option[String],
|
|
||||||
missing: Float,
|
|
||||||
indices: ColumnIndices,
|
|
||||||
iter: Iterator[GpuColumnBatch],
|
|
||||||
maxBin: Int): Watches = {
|
|
||||||
|
|
||||||
val (dm, time) = GpuUtils.time {
|
|
||||||
buildDMatrix(iter, indices, missing, maxBin)
|
|
||||||
}
|
|
||||||
logger.debug("Benchmark[Train: Build DMatrix incrementally] " + time)
|
|
||||||
val (aDMatrix, aName) = if (dm == null) {
|
|
||||||
(Array.empty[DMatrix], Array.empty[String])
|
|
||||||
} else {
|
|
||||||
(Array(dm), Array("train"))
|
|
||||||
}
|
|
||||||
new Watches(aDMatrix, aName, cachedDirName)
|
|
||||||
}
|
|
||||||
|
|
||||||
private def buildWatchesWithEval(
|
|
||||||
cachedDirName: Option[String],
|
|
||||||
missing: Float,
|
|
||||||
indices: Map[String, ColumnIndices],
|
|
||||||
nameAndColumns: Iterator[(String, Iterator[GpuColumnBatch])],
|
|
||||||
maxBin: Int): Watches = {
|
|
||||||
val dms = nameAndColumns.map {
|
|
||||||
case (name, iter) => (name, {
|
|
||||||
val (dm, time) = GpuUtils.time {
|
|
||||||
buildDMatrix(iter, indices(name), missing, maxBin)
|
|
||||||
}
|
|
||||||
logger.debug(s"Benchmark[Train build $name DMatrix] " + time)
|
|
||||||
dm
|
|
||||||
})
|
|
||||||
}.filter(_._2 != null).toArray
|
|
||||||
|
|
||||||
new Watches(dms.map(_._2), dms.map(_._1), cachedDirName)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Build QuantileDMatrix based on GpuColumnBatches
|
|
||||||
*
|
|
||||||
* @param iter a sequence of GpuColumnBatch
|
|
||||||
* @param indices indicate the feature, label, weight, base margin column ids.
|
|
||||||
* @param missing the missing value
|
|
||||||
* @param maxBin the maxBin
|
|
||||||
* @return DMatrix
|
|
||||||
*/
|
|
||||||
private def buildDMatrix(
|
|
||||||
iter: Iterator[GpuColumnBatch],
|
|
||||||
indices: ColumnIndices,
|
|
||||||
missing: Float,
|
|
||||||
maxBin: Int): DMatrix = {
|
|
||||||
val rapidsIterator = new RapidsIterator(iter, indices)
|
|
||||||
new QuantileDMatrix(rapidsIterator, missing, maxBin, 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// zip all the Columnar RDDs into one RDD containing named column data batch.
|
|
||||||
private def coPartitionForGpu(
|
|
||||||
dataMap: Map[String, ColumnDataBatch],
|
|
||||||
sc: SparkContext,
|
|
||||||
nWorkers: Int): RDD[(String, Iterator[GpuColumnBatch])] = {
|
|
||||||
val emptyDataRdd = sc.parallelize(
|
|
||||||
Array.fill[(String, Iterator[GpuColumnBatch])](nWorkers)(null), nWorkers)
|
|
||||||
|
|
||||||
dataMap.foldLeft(emptyDataRdd) {
|
|
||||||
case (zippedRdd, (name, gdfColData)) =>
|
|
||||||
zippedRdd.zipPartitions(GpuUtils.toColumnarRdd(gdfColData.rawDF)) {
|
|
||||||
(itWrapper, iterCol) =>
|
|
||||||
val itCol = iterCol.map(table => new GpuColumnBatch(table, null))
|
|
||||||
(itWrapper.toArray :+ (name -> itCol)).filter(x => x != null).toIterator
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private[this] class RapidsIterator(
|
|
||||||
base: Iterator[GpuColumnBatch],
|
|
||||||
indices: ColumnIndices) extends Iterator[CudfColumnBatch] {
|
|
||||||
|
|
||||||
override def hasNext: Boolean = base.hasNext
|
|
||||||
|
|
||||||
override def next(): CudfColumnBatch = {
|
|
||||||
// Since we have sliced original Table into different tables. Needs to close the original one.
|
|
||||||
withResource(base.next()) { gpuColumnBatch =>
|
|
||||||
val weights = indices.weightId.map(Seq(_)).getOrElse(Seq.empty)
|
|
||||||
val margins = indices.marginId.map(Seq(_)).getOrElse(Seq.empty)
|
|
||||||
|
|
||||||
new CudfColumnBatch(
|
|
||||||
gpuColumnBatch.slice(GpuUtils.seqIntToSeqInteger(indices.featureIds).asJava),
|
|
||||||
gpuColumnBatch.slice(GpuUtils.seqIntToSeqInteger(Seq(indices.labelId)).asJava),
|
|
||||||
gpuColumnBatch.slice(GpuUtils.seqIntToSeqInteger(weights).asJava),
|
|
||||||
gpuColumnBatch.slice(GpuUtils.seqIntToSeqInteger(margins).asJava),
|
|
||||||
null);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/** Executes the provided code block and then closes the resource */
|
|
||||||
def withResource[T <: AutoCloseable, V](r: T)(block: T => V): V = {
|
|
||||||
try {
|
|
||||||
block(r)
|
|
||||||
} finally {
|
|
||||||
r.close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@ -1,178 +0,0 @@
|
|||||||
/*
|
|
||||||
Copyright (c) 2021 by Contributors
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala.rapids.spark
|
|
||||||
|
|
||||||
import ai.rapids.cudf.Table
|
|
||||||
import com.nvidia.spark.rapids.{ColumnarRdd, GpuColumnVectorUtils}
|
|
||||||
import ml.dmlc.xgboost4j.scala.spark.util.Utils
|
|
||||||
|
|
||||||
import org.apache.spark.rdd.RDD
|
|
||||||
import org.apache.spark.sql.{DataFrame, Dataset}
|
|
||||||
import org.apache.spark.ml.param.{Param, Params}
|
|
||||||
import org.apache.spark.sql.functions.col
|
|
||||||
import org.apache.spark.sql.types.{DataType, FloatType, NumericType, StructType}
|
|
||||||
import org.apache.spark.sql.vectorized.ColumnVector
|
|
||||||
|
|
||||||
private[spark] object GpuUtils {
|
|
||||||
|
|
||||||
def extractBatchToHost(table: Table, types: Array[DataType]): Array[ColumnVector] = {
|
|
||||||
// spark-rapids has shimmed the GpuColumnVector from 22.10
|
|
||||||
GpuColumnVectorUtils.extractHostColumns(table, types)
|
|
||||||
}
|
|
||||||
|
|
||||||
def toColumnarRdd(df: DataFrame): RDD[Table] = ColumnarRdd(df)
|
|
||||||
|
|
||||||
def seqIntToSeqInteger(x: Seq[Int]): Seq[Integer] = x.map(new Integer(_))
|
|
||||||
|
|
||||||
/** APIs for gpu column data related */
|
|
||||||
def buildColumnDataBatch(featureNames: Seq[String],
|
|
||||||
labelName: String,
|
|
||||||
weightName: String,
|
|
||||||
marginName: String,
|
|
||||||
groupName: String,
|
|
||||||
dataFrame: DataFrame): ColumnDataBatch = {
|
|
||||||
// Some check first
|
|
||||||
val schema = dataFrame.schema
|
|
||||||
val featureNameSet = featureNames.distinct
|
|
||||||
GpuUtils.validateSchema(schema, featureNameSet, labelName, weightName, marginName)
|
|
||||||
|
|
||||||
// group column
|
|
||||||
val (opGroup, groupId) = if (groupName.isEmpty) {
|
|
||||||
(None, None)
|
|
||||||
} else {
|
|
||||||
GpuUtils.checkNumericType(schema, groupName)
|
|
||||||
(Some(groupName), Some(schema.fieldIndex(groupName)))
|
|
||||||
}
|
|
||||||
// weight and base margin columns
|
|
||||||
val Seq(weightId, marginId) = Seq(weightName, marginName).map {
|
|
||||||
name =>
|
|
||||||
if (name.isEmpty) None else Some(schema.fieldIndex(name))
|
|
||||||
}
|
|
||||||
|
|
||||||
val colsIndices = ColumnIndices(featureNameSet.map(schema.fieldIndex),
|
|
||||||
schema.fieldIndex(labelName), weightId, marginId, groupId)
|
|
||||||
ColumnDataBatch(dataFrame, colsIndices, opGroup)
|
|
||||||
}
|
|
||||||
|
|
||||||
def checkNumericType(schema: StructType, colName: String,
|
|
||||||
msg: String = ""): Unit = {
|
|
||||||
val actualDataType = schema(colName).dataType
|
|
||||||
val message = if (msg != null && msg.trim.length > 0) " " + msg else ""
|
|
||||||
require(actualDataType.isInstanceOf[NumericType],
|
|
||||||
s"Column $colName must be of NumericType but found: " +
|
|
||||||
s"${actualDataType.catalogString}.$message")
|
|
||||||
}
|
|
||||||
|
|
||||||
/** Check and Cast the columns to FloatType */
|
|
||||||
def prepareColumnType(
|
|
||||||
dataset: Dataset[_],
|
|
||||||
featureNames: Seq[String],
|
|
||||||
labelName: String = "",
|
|
||||||
weightName: String = "",
|
|
||||||
marginName: String = "",
|
|
||||||
fitting: Boolean = true): DataFrame = {
|
|
||||||
// check first
|
|
||||||
val featureNameSet = featureNames.distinct
|
|
||||||
validateSchema(dataset.schema, featureNameSet, labelName, weightName, marginName, fitting)
|
|
||||||
|
|
||||||
val castToFloat = (df: DataFrame, colName: String) => {
|
|
||||||
if (df.schema(colName).dataType.isInstanceOf[FloatType]) {
|
|
||||||
df
|
|
||||||
} else {
|
|
||||||
val colMeta = df.schema(colName).metadata
|
|
||||||
df.withColumn(colName, col(colName).as(colName, colMeta).cast(FloatType))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
val colNames = if (fitting) {
|
|
||||||
var names = featureNameSet :+ labelName
|
|
||||||
if (weightName.nonEmpty) {
|
|
||||||
names = names :+ weightName
|
|
||||||
}
|
|
||||||
if (marginName.nonEmpty) {
|
|
||||||
names = names :+ marginName
|
|
||||||
}
|
|
||||||
names
|
|
||||||
} else {
|
|
||||||
featureNameSet
|
|
||||||
}
|
|
||||||
colNames.foldLeft(dataset.asInstanceOf[DataFrame])(
|
|
||||||
(ds, colName) => castToFloat(ds, colName))
|
|
||||||
}
|
|
||||||
|
|
||||||
/** Validate input schema */
|
|
||||||
def validateSchema(schema: StructType,
|
|
||||||
featureNames: Seq[String],
|
|
||||||
labelName: String = "",
|
|
||||||
weightName: String = "",
|
|
||||||
marginName: String = "",
|
|
||||||
fitting: Boolean = true): StructType = {
|
|
||||||
val msg = if (fitting) "train" else "transform"
|
|
||||||
// feature columns
|
|
||||||
require(featureNames.nonEmpty, s"Gpu $msg requires features columns. " +
|
|
||||||
"please refer to `setFeaturesCol(value: Array[String])`!")
|
|
||||||
featureNames.foreach(fn => checkNumericType(schema, fn))
|
|
||||||
if (fitting) {
|
|
||||||
require(labelName.nonEmpty, "label column is not set.")
|
|
||||||
checkNumericType(schema, labelName)
|
|
||||||
|
|
||||||
if (weightName.nonEmpty) {
|
|
||||||
checkNumericType(schema, weightName)
|
|
||||||
}
|
|
||||||
if (marginName.nonEmpty) {
|
|
||||||
checkNumericType(schema, marginName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
schema
|
|
||||||
}
|
|
||||||
|
|
||||||
def time[R](block: => R): (R, Float) = {
|
|
||||||
val t0 = System.currentTimeMillis
|
|
||||||
val result = block // call-by-name
|
|
||||||
val t1 = System.currentTimeMillis
|
|
||||||
(result, (t1 - t0).toFloat / 1000)
|
|
||||||
}
|
|
||||||
|
|
||||||
/** Get column names from Parameter */
|
|
||||||
def getColumnNames(params: Params)(cols: Param[String]*): Seq[String] = {
|
|
||||||
// get column name, null | undefined will be casted to ""
|
|
||||||
def getColumnName(params: Params)(param: Param[String]): String = {
|
|
||||||
if (params.isDefined(param)) {
|
|
||||||
val colName = params.getOrDefault(param)
|
|
||||||
if (colName != null) colName else ""
|
|
||||||
} else ""
|
|
||||||
}
|
|
||||||
|
|
||||||
val getName = getColumnName(params)(_)
|
|
||||||
cols.map(getName)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A container to contain the column ids
|
|
||||||
*/
|
|
||||||
private[spark] case class ColumnIndices(
|
|
||||||
featureIds: Seq[Int],
|
|
||||||
labelId: Int,
|
|
||||||
weightId: Option[Int],
|
|
||||||
marginId: Option[Int],
|
|
||||||
groupId: Option[Int])
|
|
||||||
|
|
||||||
private[spark] case class ColumnDataBatch(
|
|
||||||
rawDF: DataFrame,
|
|
||||||
colIndices: ColumnIndices,
|
|
||||||
groupColName: Option[String])
|
|
||||||
@ -0,0 +1,315 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2024 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ml.dmlc.xgboost4j.scala.spark
|
||||||
|
|
||||||
|
import scala.collection.mutable.ArrayBuffer
|
||||||
|
import scala.jdk.CollectionConverters._
|
||||||
|
|
||||||
|
import ai.rapids.cudf.Table
|
||||||
|
import com.nvidia.spark.rapids.{ColumnarRdd, GpuColumnVectorUtils}
|
||||||
|
import org.apache.commons.logging.LogFactory
|
||||||
|
import org.apache.spark.TaskContext
|
||||||
|
import org.apache.spark.ml.param.Param
|
||||||
|
import org.apache.spark.rdd.RDD
|
||||||
|
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
|
||||||
|
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
|
||||||
|
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
|
||||||
|
import org.apache.spark.sql.types.{DataType, FloatType, IntegerType}
|
||||||
|
import org.apache.spark.sql.vectorized.ColumnarBatch
|
||||||
|
|
||||||
|
import ml.dmlc.xgboost4j.java.CudfColumnBatch
|
||||||
|
import ml.dmlc.xgboost4j.scala.{DMatrix, QuantileDMatrix}
|
||||||
|
import ml.dmlc.xgboost4j.scala.spark.Utils.withResource
|
||||||
|
import ml.dmlc.xgboost4j.scala.spark.params.HasGroupCol
|
||||||
|
|
||||||
|
/**
|
||||||
|
* GpuXGBoostPlugin is the XGBoost plugin which leverages spark-rapids
|
||||||
|
* to accelerate the XGBoost from ETL to train.
|
||||||
|
*/
|
||||||
|
class GpuXGBoostPlugin extends XGBoostPlugin {
|
||||||
|
|
||||||
|
private val logger = LogFactory.getLog("XGBoostSparkGpuPlugin")
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Whether the plugin is enabled or not, if not enabled, fallback
|
||||||
|
   * to the regular CPU pipeline
   *
   * @param dataset the input dataset
   * @return Boolean
   */
  override def isEnabled(dataset: Dataset[_]): Boolean = {
    val conf = dataset.sparkSession.conf
    val hasRapidsPlugin = conf.get("spark.plugins", "").split(",").contains(
      "com.nvidia.spark.SQLPlugin")
    val rapidsEnabled = try {
      conf.get("spark.rapids.sql.enabled").toBoolean
    } catch {
      // Rapids plugin has default "spark.rapids.sql.enabled" to true
      case _: NoSuchElementException => true
      case _: Throwable => false // Any exception will return false
    }
    hasRapidsPlugin && rapidsEnabled
  }

  // TODO, support numeric type
  private[spark] def preprocess[T <: XGBoostEstimator[T, M], M <: XGBoostModel[M]](
      estimator: XGBoostEstimator[T, M], dataset: Dataset[_]): Dataset[_] = {

    // Columns to be selected for XGBoost training
    val selectedCols: ArrayBuffer[Column] = ArrayBuffer.empty
    val schema = dataset.schema

    def selectCol(c: Param[String], targetType: DataType = FloatType) = {
      // TODO support numeric types
      if (estimator.isDefinedNonEmpty(c)) {
        selectedCols.append(estimator.castIfNeeded(schema, estimator.getOrDefault(c), targetType))
      }
    }

    Seq(estimator.labelCol, estimator.weightCol, estimator.baseMarginCol)
      .foreach(p => selectCol(p))
    estimator match {
      case p: HasGroupCol => selectCol(p.groupCol, IntegerType)
      case _ =>
    }

    // TODO support array/vector feature
    estimator.getFeaturesCols.foreach { name =>
      val col = estimator.castIfNeeded(dataset.schema, name)
      selectedCols.append(col)
    }
    val input = dataset.select(selectedCols.toArray: _*)
    estimator.repartitionIfNeeded(input)
  }

  // visible for testing
  private[spark] def validate[T <: XGBoostEstimator[T, M], M <: XGBoostModel[M]](
      estimator: XGBoostEstimator[T, M],
      dataset: Dataset[_]): Unit = {
    require(estimator.getTreeMethod == "gpu_hist" || estimator.getDevice != "cpu",
      "Using Spark-Rapids to accelerate XGBoost must set device=cuda")
  }

  /**
   * Convert Dataset to RDD[Watches] which will be fed into XGBoost
   *
   * @param estimator which estimator to be handled.
   * @param dataset   to be converted.
   * @return RDD[Watches]
   */
  override def buildRddWatches[T <: XGBoostEstimator[T, M], M <: XGBoostModel[M]](
      estimator: XGBoostEstimator[T, M],
      dataset: Dataset[_]): RDD[Watches] = {

    validate(estimator, dataset)

    val train = preprocess(estimator, dataset)
    val schema = train.schema

    val indices = estimator.buildColumnIndices(schema)

    val maxBin = estimator.getMaxBins
    val nthread = estimator.getNthread
    val missing = estimator.getMissing

    /** build QuantileDMatrix on the executor side */
    def buildQuantileDMatrix(iter: Iterator[Table],
                             ref: Option[QuantileDMatrix] = None): QuantileDMatrix = {
      val colBatchIter = iter.map { table =>
        withResource(new GpuColumnBatch(table)) { batch =>
          new CudfColumnBatch(
            batch.select(indices.featureIds.get),
            batch.select(indices.labelId),
            batch.select(indices.weightId.getOrElse(-1)),
            batch.select(indices.marginId.getOrElse(-1)),
            batch.select(indices.groupId.getOrElse(-1)))
        }
      }
      ref.map(r => new QuantileDMatrix(colBatchIter, r, missing, maxBin, nthread)).getOrElse(
        new QuantileDMatrix(colBatchIter, missing, maxBin, nthread)
      )
    }

    estimator.getEvalDataset().map { evalDs =>
      val evalProcessed = preprocess(estimator, evalDs)
      ColumnarRdd(train.toDF()).zipPartitions(ColumnarRdd(evalProcessed.toDF())) {
        (trainIter, evalIter) =>
          new Iterator[Watches] {
            override def hasNext: Boolean = trainIter.hasNext
            override def next(): Watches = {
              val trainDM = buildQuantileDMatrix(trainIter)
              val evalDM = buildQuantileDMatrix(evalIter, Some(trainDM))
              new Watches(Array(trainDM, evalDM),
                Array(Utils.TRAIN_NAME, Utils.VALIDATION_NAME), None)
            }
          }
      }
    }.getOrElse(
      ColumnarRdd(train.toDF()).mapPartitions { iter =>
        new Iterator[Watches] {
          override def hasNext: Boolean = iter.hasNext
          override def next(): Watches = {
            val dm = buildQuantileDMatrix(iter)
            new Watches(Array(dm), Array(Utils.TRAIN_NAME), None)
          }
        }
      }
    )
  }

  override def transform[M <: XGBoostModel[M]](model: XGBoostModel[M],
                                               dataset: Dataset[_]): DataFrame = {
    val sc = dataset.sparkSession.sparkContext

    val (transformedSchema, pred) = model.preprocess(dataset)
    val bBooster = sc.broadcast(model.nativeBooster)
    val bOriginalSchema = sc.broadcast(dataset.schema)

    val featureIds = model.getFeaturesCols.distinct.map(dataset.schema.fieldIndex).toList
    val isLocal = sc.isLocal
    val missing = model.getMissing
    val nThread = model.getNthread

    val rdd = ColumnarRdd(dataset.asInstanceOf[DataFrame]).mapPartitions { tableIters =>
      // booster is visible for all spark tasks in the same executor
      val booster = bBooster.value
      val originalSchema = bOriginalSchema.value

      // UnsafeProjection is not serializable so do it on the executor side
      val toUnsafe = UnsafeProjection.create(originalSchema)

      if (!booster.deviceIsSet) {
        booster.deviceIsSet.synchronized {
          if (!booster.deviceIsSet) {
            booster.deviceIsSet = true
            val gpuId = if (!isLocal) XGBoost.getGPUAddrFromResources else 0
            booster.setParam("device", s"cuda:$gpuId")
            logger.info("GPU transform on GPU device: cuda:" + gpuId)
          }
        }
      }

      // Iterator on Row
      new Iterator[Row] {
        // Convert InternalRow to Row
        private val converter: InternalRow => Row = CatalystTypeConverters
          .createToScalaConverter(originalSchema)
          .asInstanceOf[InternalRow => Row]

        // GPU batches read in must be closed by the receiver
        @transient var currentBatch: ColumnarBatch = null

        // Iterator on Row
        var iter: Iterator[Row] = null

        TaskContext.get().addTaskCompletionListener[Unit](_ => {
          closeCurrentBatch() // close the last ColumnarBatch
        })

        private def closeCurrentBatch(): Unit = {
          if (currentBatch != null) {
            currentBatch.close()
            currentBatch = null
          }
        }

        def loadNextBatch(): Unit = {
          closeCurrentBatch()
          if (tableIters.hasNext) {
            val dataTypes = originalSchema.fields.map(x => x.dataType)
            iter = withResource(tableIters.next()) { table =>
              // Create DMatrix
              val featureTable = new GpuColumnBatch(table).select(featureIds)
              if (featureTable == null) {
                val msg = featureIds.mkString(",")
                throw new RuntimeException(s"Couldn't create feature table for the " +
                  s"feature indices $msg")
              }
              try {
                val cudfColumnBatch = new CudfColumnBatch(featureTable, null, null, null, null)
                val dm = new DMatrix(cudfColumnBatch, missing, nThread)
                if (dm == null) {
                  Iterator.empty
                } else {
                  try {
                    currentBatch = new ColumnarBatch(
                      GpuColumnVectorUtils.extractHostColumns(table, dataTypes),
                      table.getRowCount().toInt)
                    val rowIterator = currentBatch.rowIterator().asScala.map(toUnsafe)
                      .map(converter(_))
                    model.predictInternal(booster, dm, pred, rowIterator).toIterator
                  } finally {
                    dm.delete()
                  }
                }
              } finally {
                featureTable.close()
              }
            }
          } else {
            iter = null
          }
        }

        override def hasNext: Boolean = {
          val itHasNext = iter != null && iter.hasNext
          if (!itHasNext) { // Don't have extra Row for current ColumnarBatch
            loadNextBatch()
            iter != null && iter.hasNext
          } else {
            itHasNext
          }
        }

        override def next(): Row = {
          if (iter == null || !iter.hasNext) {
            loadNextBatch()
          }
          if (iter == null) {
            throw new NoSuchElementException()
          }
          iter.next()
        }
      }
    }
    bBooster.unpersist(false)
    bOriginalSchema.unpersist(false)

    val output = dataset.sparkSession.createDataFrame(rdd, transformedSchema)
    model.postTransform(output, pred).toDF()
  }
}

private class GpuColumnBatch(table: Table) extends AutoCloseable {

  def select(index: Int): Table = {
    select(Seq(index))
  }

  def select(indices: Seq[Int]): Table = {
    if (!indices.forall(index => index < table.getNumberOfColumns && index >= 0)) {
      return null
    }
    new Table(indices.map(table.getColumn): _*)
  }

  override def close(): Unit = {
    if (Option(table).isDefined) {
      table.close()
    }
  }
}
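
Note: `isEnabled` above only returns true when the spark-rapids SQL plugin is listed in `spark.plugins` and `spark.rapids.sql.enabled` is not explicitly set to false (an unset key counts as enabled). A minimal sketch of a session that would satisfy this check is shown below; the master URL and app name are placeholders, not part of this change.

.. code-block:: scala

    import org.apache.spark.sql.SparkSession

    // Hypothetical session setup; the two RAPIDS settings are the ones
    // the plugin's isEnabled check inspects.
    val spark = SparkSession.builder()
      .master("local[*]")                                      // placeholder
      .config("spark.plugins", "com.nvidia.spark.SQLPlugin")   // required for hasRapidsPlugin
      .config("spark.rapids.sql.enabled", "true")              // default is true when unset
      .appName("xgboost-gpu-example")                          // placeholder
      .getOrCreate()
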
@@ -16,9 +16,7 @@

 package ml.dmlc.xgboost4j.java;

-import java.util.Arrays;
-import java.util.LinkedList;
-import java.util.List;
+import java.util.*;

 import ai.rapids.cudf.Table;
 import junit.framework.TestCase;
@@ -122,8 +120,7 @@ public class DMatrixTest {
     tables.add(new CudfColumnBatch(X_0, y_0, w_0, m_0, q_0));
     tables.add(new CudfColumnBatch(X_1, y_1, w_1, m_1, q_1));

-    DMatrix dmat = new QuantileDMatrix(tables.iterator(), 0.0f, 256, 1);
+    QuantileDMatrix dmat = new QuantileDMatrix(tables.iterator(), 0.0f, 256, 1);

     float[] anchorLabel = convertFloatTofloat(label1, label2);
     float[] anchorWeight = convertFloatTofloat(weight1, weight2);
     float[] anchorBaseMargin = convertFloatTofloat(baseMargin1, baseMargin2);
@@ -135,6 +132,57 @@ public class DMatrixTest {
     }
   }

+  private Float[] generateFloatArray(int size, long seed) {
+    Float[] array = new Float[size];
+    Random random = new Random(seed);
+    for (int i = 0; i < size; i++) {
+      array[i] = random.nextFloat();
+    }
+    return array;
+  }
+
+  @Test
+  public void testGetQuantileCut() throws XGBoostError {
+
+    int rows = 100;
+    try (
+      Table X_0 = new Table.TestBuilder()
+        .column(generateFloatArray(rows, 1l))
+        .column(generateFloatArray(rows, 2l))
+        .column(generateFloatArray(rows, 3l))
+        .column(generateFloatArray(rows, 4l))
+        .column(generateFloatArray(rows, 5l))
+        .build();
+      Table y_0 = new Table.TestBuilder().column(generateFloatArray(rows, 6l)).build();
+
+      Table X_1 = new Table.TestBuilder()
+        .column(generateFloatArray(rows, 11l))
+        .column(generateFloatArray(rows, 12l))
+        .column(generateFloatArray(rows, 13l))
+        .column(generateFloatArray(rows, 14l))
+        .column(generateFloatArray(rows, 15l))
+        .build();
+      Table y_1 = new Table.TestBuilder().column(generateFloatArray(rows, 16l)).build();
+    ) {
+      List<ColumnBatch> tables = new LinkedList<>();
+      tables.add(new CudfColumnBatch(X_0, y_0, null, null, null));
+      QuantileDMatrix train = new QuantileDMatrix(tables.iterator(), 0.0f, 256, 1);
+
+      tables.clear();
+      tables.add(new CudfColumnBatch(X_1, y_1, null, null, null));
+      QuantileDMatrix eval = new QuantileDMatrix(tables.iterator(), train, 0.0f, 256, 1);
+
+      DMatrix.QuantileCut trainCut = train.getQuantileCut();
+      DMatrix.QuantileCut evalCut = eval.getQuantileCut();
+
+      TestCase.assertTrue(trainCut.getIndptr().length == evalCut.getIndptr().length);
+      TestCase.assertTrue(Arrays.equals(trainCut.getIndptr(), evalCut.getIndptr()));
+
+      TestCase.assertTrue(trainCut.getValues().length == evalCut.getValues().length);
+      TestCase.assertTrue(Arrays.equals(trainCut.getValues(), evalCut.getValues()));
+    }
+  }
+
   private float[] convertFloatTofloat(Float[]... datas) {
     int totalLength = 0;
     for (Float[] data : datas) {
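
The new `testGetQuantileCut` case exercises the reference-based `QuantileDMatrix` constructor: the evaluation matrix reuses the histogram cuts computed for the training matrix instead of re-sketching them. The same call pattern appears in the Scala plugin above; a rough sketch, assuming `trainBatches` and `evalBatches` are iterators of `CudfColumnBatch` and using the test's missing value, max bin and thread settings:

.. code-block:: scala

    import ml.dmlc.xgboost4j.scala.QuantileDMatrix

    // trainBatches / evalBatches are assumed to exist; in the test they are
    // built from cuDF Tables wrapped in CudfColumnBatch.
    val train = new QuantileDMatrix(trainBatches, 0.0f, 256, 1)
    // Passing the training matrix as `ref` makes eval reuse the same quantile cuts.
    val eval = new QuantileDMatrix(evalBatches, train, 0.0f, 256, 1)
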
@@ -16,11 +16,13 @@

 package ml.dmlc.xgboost4j.scala

+import scala.collection.mutable.ArrayBuffer
+
 import ai.rapids.cudf.Table
-import ml.dmlc.xgboost4j.java.CudfColumnBatch
 import org.scalatest.funsuite.AnyFunSuite

-import scala.collection.mutable.ArrayBuffer
+import ml.dmlc.xgboost4j.java.CudfColumnBatch
+import ml.dmlc.xgboost4j.scala.spark.Utils.withResource

 class QuantileDMatrixSuite extends AnyFunSuite {

@@ -73,13 +75,4 @@ class QuantileDMatrixSuite extends AnyFunSuite {
       }
     }
   }
-
-  /** Executes the provided code block and then closes the resource */
-  private def withResource[T <: AutoCloseable, V](r: T)(block: T => V): V = {
-    try {
-      block(r)
-    } finally {
-      r.close()
-    }
-  }
 }
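
The suite's private `withResource` helper is dropped in favour of the shared one imported from `Utils`. For reference, the loan pattern it implements is the usual close-on-exit wrapper; a self-contained version is sketched below (the `Resources` object name is only for illustration, the real helper lives in `ml.dmlc.xgboost4j.scala.spark.Utils`).

.. code-block:: scala

    object Resources {
      /** Run `block` on resource `r`, closing `r` even if `block` throws. */
      def withResource[T <: AutoCloseable, V](r: T)(block: T => V): V = {
        try {
          block(r)
        } finally {
          r.close()
        }
      }
    }
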
@ -1,288 +0,0 @@
|
|||||||
/*
|
|
||||||
Copyright (c) 2021-2023 by Contributors
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala.rapids.spark
|
|
||||||
|
|
||||||
import java.nio.file.{Files, Path}
|
|
||||||
import java.sql.{Date, Timestamp}
|
|
||||||
import java.util.{Locale, TimeZone}
|
|
||||||
|
|
||||||
import org.scalatest.BeforeAndAfterAll
|
|
||||||
import org.scalatest.funsuite.AnyFunSuite
|
|
||||||
|
|
||||||
import org.apache.spark.{GpuTestUtils, SparkConf}
|
|
||||||
import org.apache.spark.internal.Logging
|
|
||||||
import org.apache.spark.network.util.JavaUtils
|
|
||||||
import org.apache.spark.sql.{Row, SparkSession}
|
|
||||||
|
|
||||||
trait GpuTestSuite extends AnyFunSuite with TmpFolderSuite {
|
|
||||||
import SparkSessionHolder.withSparkSession
|
|
||||||
|
|
||||||
protected def getResourcePath(resource: String): String = {
|
|
||||||
require(resource.startsWith("/"), "resource must start with /")
|
|
||||||
getClass.getResource(resource).getPath
|
|
||||||
}
|
|
||||||
|
|
||||||
def enableCsvConf(): SparkConf = {
|
|
||||||
new SparkConf()
|
|
||||||
.set("spark.rapids.sql.csv.read.float.enabled", "true")
|
|
||||||
.set("spark.rapids.sql.csv.read.double.enabled", "true")
|
|
||||||
}
|
|
||||||
|
|
||||||
def withGpuSparkSession[U](conf: SparkConf = new SparkConf())(f: SparkSession => U): U = {
|
|
||||||
// set "spark.rapids.sql.explain" to "ALL" to check if the operators
|
|
||||||
// can be replaced by GPU
|
|
||||||
val c = conf.clone()
|
|
||||||
.set("spark.rapids.sql.enabled", "true")
|
|
||||||
withSparkSession(c, f)
|
|
||||||
}
|
|
||||||
|
|
||||||
def withCpuSparkSession[U](conf: SparkConf = new SparkConf())(f: SparkSession => U): U = {
|
|
||||||
val c = conf.clone()
|
|
||||||
.set("spark.rapids.sql.enabled", "false") // Just to be sure
|
|
||||||
withSparkSession(c, f)
|
|
||||||
}
|
|
||||||
|
|
||||||
def compareResults(
|
|
||||||
sort: Boolean,
|
|
||||||
floatEpsilon: Double,
|
|
||||||
fromLeft: Array[Row],
|
|
||||||
fromRight: Array[Row]): Boolean = {
|
|
||||||
if (sort) {
|
|
||||||
val left = fromLeft.map(_.toSeq).sortWith(seqLt)
|
|
||||||
val right = fromRight.map(_.toSeq).sortWith(seqLt)
|
|
||||||
compare(left, right, floatEpsilon)
|
|
||||||
} else {
|
|
||||||
compare(fromLeft, fromRight, floatEpsilon)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// we guarantee that the types will be the same
|
|
||||||
private def seqLt(a: Seq[Any], b: Seq[Any]): Boolean = {
|
|
||||||
if (a.length < b.length) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
// lengths are the same
|
|
||||||
for (i <- a.indices) {
|
|
||||||
val v1 = a(i)
|
|
||||||
val v2 = b(i)
|
|
||||||
if (v1 != v2) {
|
|
||||||
// null is always < anything but null
|
|
||||||
if (v1 == null) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
if (v2 == null) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
(v1, v2) match {
|
|
||||||
case (i1: Int, i2: Int) => if (i1 < i2) {
|
|
||||||
return true
|
|
||||||
} else if (i1 > i2) {
|
|
||||||
return false
|
|
||||||
}// else equal go on
|
|
||||||
case (i1: Long, i2: Long) => if (i1 < i2) {
|
|
||||||
return true
|
|
||||||
} else if (i1 > i2) {
|
|
||||||
return false
|
|
||||||
} // else equal go on
|
|
||||||
case (i1: Float, i2: Float) => if (i1.isNaN() && !i2.isNaN()) return false
|
|
||||||
else if (!i1.isNaN() && i2.isNaN()) return true
|
|
||||||
else if (i1 < i2) {
|
|
||||||
return true
|
|
||||||
} else if (i1 > i2) {
|
|
||||||
return false
|
|
||||||
} // else equal go on
|
|
||||||
case (i1: Date, i2: Date) => if (i1.before(i2)) {
|
|
||||||
return true
|
|
||||||
} else if (i1.after(i2)) {
|
|
||||||
return false
|
|
||||||
} // else equal go on
|
|
||||||
case (i1: Double, i2: Double) => if (i1.isNaN() && !i2.isNaN()) return false
|
|
||||||
else if (!i1.isNaN() && i2.isNaN()) return true
|
|
||||||
else if (i1 < i2) {
|
|
||||||
return true
|
|
||||||
} else if (i1 > i2) {
|
|
||||||
return false
|
|
||||||
} // else equal go on
|
|
||||||
case (i1: Short, i2: Short) => if (i1 < i2) {
|
|
||||||
return true
|
|
||||||
} else if (i1 > i2) {
|
|
||||||
return false
|
|
||||||
} // else equal go on
|
|
||||||
case (i1: Timestamp, i2: Timestamp) => if (i1.before(i2)) {
|
|
||||||
return true
|
|
||||||
} else if (i1.after(i2)) {
|
|
||||||
return false
|
|
||||||
} // else equal go on
|
|
||||||
case (s1: String, s2: String) =>
|
|
||||||
val cmp = s1.compareTo(s2)
|
|
||||||
if (cmp < 0) {
|
|
||||||
return true
|
|
||||||
} else if (cmp > 0) {
|
|
||||||
return false
|
|
||||||
} // else equal go on
|
|
||||||
case (o1, _) =>
|
|
||||||
throw new UnsupportedOperationException(o1.getClass + " is not supported yet")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// They are equal...
|
|
||||||
false
|
|
||||||
}
|
|
||||||
|
|
||||||
private def compare(expected: Any, actual: Any, epsilon: Double = 0.0): Boolean = {
|
|
||||||
def doublesAreEqualWithinPercentage(expected: Double, actual: Double): (String, Boolean) = {
|
|
||||||
if (!compare(expected, actual)) {
|
|
||||||
if (expected != 0) {
|
|
||||||
val v = Math.abs((expected - actual) / expected)
|
|
||||||
(s"\n\nABS($expected - $actual) / ABS($actual) == $v is not <= $epsilon ", v <= epsilon)
|
|
||||||
} else {
|
|
||||||
val v = Math.abs(expected - actual)
|
|
||||||
(s"\n\nABS($expected - $actual) == $v is not <= $epsilon ", v <= epsilon)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
("SUCCESS", true)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
(expected, actual) match {
|
|
||||||
case (a: Float, b: Float) if a.isNaN && b.isNaN => true
|
|
||||||
case (a: Double, b: Double) if a.isNaN && b.isNaN => true
|
|
||||||
case (null, null) => true
|
|
||||||
case (null, _) => false
|
|
||||||
case (_, null) => false
|
|
||||||
case (a: Array[_], b: Array[_]) =>
|
|
||||||
a.length == b.length && a.zip(b).forall { case (l, r) => compare(l, r, epsilon) }
|
|
||||||
case (a: Map[_, _], b: Map[_, _]) =>
|
|
||||||
a.size == b.size && a.keys.forall { aKey =>
|
|
||||||
b.keys.find(bKey => compare(aKey, bKey))
|
|
||||||
.exists(bKey => compare(a(aKey), b(bKey), epsilon))
|
|
||||||
}
|
|
||||||
case (a: Iterable[_], b: Iterable[_]) =>
|
|
||||||
a.size == b.size && a.zip(b).forall { case (l, r) => compare(l, r, epsilon) }
|
|
||||||
case (a: Product, b: Product) =>
|
|
||||||
compare(a.productIterator.toSeq, b.productIterator.toSeq, epsilon)
|
|
||||||
case (a: Row, b: Row) =>
|
|
||||||
compare(a.toSeq, b.toSeq, epsilon)
|
|
||||||
// 0.0 == -0.0, turn float/double to bits before comparison, to distinguish 0.0 and -0.0.
|
|
||||||
case (a: Double, b: Double) if epsilon <= 0 =>
|
|
||||||
java.lang.Double.doubleToRawLongBits(a) == java.lang.Double.doubleToRawLongBits(b)
|
|
||||||
case (a: Double, b: Double) if epsilon > 0 =>
|
|
||||||
val ret = doublesAreEqualWithinPercentage(a, b)
|
|
||||||
if (!ret._2) {
|
|
||||||
System.err.println(ret._1 + " (double)")
|
|
||||||
}
|
|
||||||
ret._2
|
|
||||||
case (a: Float, b: Float) if epsilon <= 0 =>
|
|
||||||
java.lang.Float.floatToRawIntBits(a) == java.lang.Float.floatToRawIntBits(b)
|
|
||||||
case (a: Float, b: Float) if epsilon > 0 =>
|
|
||||||
val ret = doublesAreEqualWithinPercentage(a, b)
|
|
||||||
if (!ret._2) {
|
|
||||||
System.err.println(ret._1 + " (float)")
|
|
||||||
}
|
|
||||||
ret._2
|
|
||||||
case (a, b) => a == b
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
trait TmpFolderSuite extends BeforeAndAfterAll { self: AnyFunSuite =>
|
|
||||||
protected var tempDir: Path = _
|
|
||||||
|
|
||||||
override def beforeAll(): Unit = {
|
|
||||||
super.beforeAll()
|
|
||||||
tempDir = Files.createTempDirectory(getClass.getName)
|
|
||||||
}
|
|
||||||
|
|
||||||
override def afterAll(): Unit = {
|
|
||||||
JavaUtils.deleteRecursively(tempDir.toFile)
|
|
||||||
super.afterAll()
|
|
||||||
}
|
|
||||||
|
|
||||||
protected def createTmpFolder(prefix: String): Path = {
|
|
||||||
Files.createTempDirectory(tempDir, prefix)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
object SparkSessionHolder extends Logging {
|
|
||||||
|
|
||||||
private var spark = createSparkSession()
|
|
||||||
private var origConf = spark.conf.getAll
|
|
||||||
private var origConfKeys = origConf.keys.toSet
|
|
||||||
|
|
||||||
private def setAllConfs(confs: Array[(String, String)]): Unit = confs.foreach {
|
|
||||||
case (key, value) if spark.conf.get(key, null) != value =>
|
|
||||||
spark.conf.set(key, value)
|
|
||||||
case _ => // No need to modify it
|
|
||||||
}
|
|
||||||
|
|
||||||
private def createSparkSession(): SparkSession = {
|
|
||||||
GpuTestUtils.cleanupAnyExistingSession()
|
|
||||||
|
|
||||||
// Timezone is fixed to UTC to allow timestamps to work by default
|
|
||||||
TimeZone.setDefault(TimeZone.getTimeZone("UTC"))
|
|
||||||
// Add Locale setting
|
|
||||||
Locale.setDefault(Locale.US)
|
|
||||||
|
|
||||||
val builder = SparkSession.builder()
|
|
||||||
.master("local[2]")
|
|
||||||
.config("spark.sql.adaptive.enabled", "false")
|
|
||||||
.config("spark.rapids.sql.enabled", "false")
|
|
||||||
.config("spark.rapids.sql.test.enabled", "false")
|
|
||||||
.config("spark.plugins", "com.nvidia.spark.SQLPlugin")
|
|
||||||
.config("spark.rapids.memory.gpu.pooling.enabled", "false") // Disable RMM for unit tests.
|
|
||||||
.config("spark.sql.files.maxPartitionBytes", "1000")
|
|
||||||
.appName("XGBoost4j-Spark-Gpu unit test")
|
|
||||||
|
|
||||||
builder.getOrCreate()
|
|
||||||
}
|
|
||||||
|
|
||||||
private def reinitSession(): Unit = {
|
|
||||||
spark = createSparkSession()
|
|
||||||
origConf = spark.conf.getAll
|
|
||||||
origConfKeys = origConf.keys.toSet
|
|
||||||
}
|
|
||||||
|
|
||||||
def sparkSession: SparkSession = {
|
|
||||||
if (SparkSession.getActiveSession.isEmpty) {
|
|
||||||
reinitSession()
|
|
||||||
}
|
|
||||||
spark
|
|
||||||
}
|
|
||||||
|
|
||||||
def resetSparkSessionConf(): Unit = {
|
|
||||||
if (SparkSession.getActiveSession.isEmpty) {
|
|
||||||
reinitSession()
|
|
||||||
} else {
|
|
||||||
setAllConfs(origConf.toArray)
|
|
||||||
val currentKeys = spark.conf.getAll.keys.toSet
|
|
||||||
val toRemove = currentKeys -- origConfKeys
|
|
||||||
toRemove.foreach(spark.conf.unset)
|
|
||||||
}
|
|
||||||
logDebug(s"RESET CONF TO: ${spark.conf.getAll}")
|
|
||||||
}
|
|
||||||
|
|
||||||
def withSparkSession[U](conf: SparkConf, f: SparkSession => U): U = {
|
|
||||||
resetSparkSessionConf
|
|
||||||
logDebug(s"SETTING CONF: ${conf.getAll.toMap}")
|
|
||||||
setAllConfs(conf.getAll)
|
|
||||||
logDebug(s"RUN WITH CONF: ${spark.conf.getAll}\n")
|
|
||||||
spark.sparkContext.setLogLevel("WARN")
|
|
||||||
f(spark)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,232 +0,0 @@
|
|||||||
/*
|
|
||||||
Copyright (c) 2021-2022 by Contributors
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala.rapids.spark
|
|
||||||
|
|
||||||
import java.io.File
|
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.scala.spark.{XGBoostClassificationModel, XGBoostClassifier}
|
|
||||||
|
|
||||||
import org.apache.spark.ml.feature.VectorAssembler
|
|
||||||
import org.apache.spark.sql.functions.{col, udf, when}
|
|
||||||
import org.apache.spark.sql.types.{FloatType, StructField, StructType}
|
|
||||||
|
|
||||||
class GpuXGBoostClassifierSuite extends GpuTestSuite {
|
|
||||||
private val dataPath = if (new java.io.File("../../demo/data/veterans_lung_cancer.csv").isFile) {
|
|
||||||
"../../demo/data/veterans_lung_cancer.csv"
|
|
||||||
} else {
|
|
||||||
"../demo/data/veterans_lung_cancer.csv"
|
|
||||||
}
|
|
||||||
|
|
||||||
val labelName = "label_col"
|
|
||||||
val schema = StructType(Seq(
|
|
||||||
StructField("f1", FloatType), StructField("f2", FloatType), StructField("f3", FloatType),
|
|
||||||
StructField("f4", FloatType), StructField("f5", FloatType), StructField("f6", FloatType),
|
|
||||||
StructField("f7", FloatType), StructField("f8", FloatType), StructField("f9", FloatType),
|
|
||||||
StructField("f10", FloatType), StructField("f11", FloatType), StructField("f12", FloatType),
|
|
||||||
StructField(labelName, FloatType)
|
|
||||||
))
|
|
||||||
val featureNames = schema.fieldNames.filter(s => !s.equals(labelName))
|
|
||||||
|
|
||||||
test("The transform result should be same for several runs on same model") {
|
|
||||||
withGpuSparkSession(enableCsvConf()) { spark =>
|
|
||||||
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "binary:logistic",
|
|
||||||
"num_round" -> 10, "num_workers" -> 1, "tree_method" -> "gpu_hist",
|
|
||||||
"features_cols" -> featureNames, "label_col" -> labelName)
|
|
||||||
val Array(originalDf, testDf) = spark.read.option("header", "true").schema(schema)
|
|
||||||
.csv(dataPath).withColumn("f2", when(col("f2").isin(Float.PositiveInfinity), 0))
|
|
||||||
.randomSplit(Array(0.7, 0.3), seed = 1)
|
|
||||||
// Get a model
|
|
||||||
val model = new XGBoostClassifier(xgbParam)
|
|
||||||
.fit(originalDf)
|
|
||||||
val left = model.transform(testDf).collect()
|
|
||||||
val right = model.transform(testDf).collect()
|
|
||||||
// The left should be same with right
|
|
||||||
assert(compareResults(true, 0.000001, left, right))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
test("use weight") {
|
|
||||||
withGpuSparkSession(enableCsvConf()) { spark =>
|
|
||||||
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "binary:logistic",
|
|
||||||
"num_round" -> 10, "num_workers" -> 1, "tree_method" -> "gpu_hist",
|
|
||||||
"features_cols" -> featureNames, "label_col" -> labelName)
|
|
||||||
val Array(originalDf, testDf) = spark.read.option("header", "true").schema(schema)
|
|
||||||
.csv(dataPath).withColumn("f2", when(col("f2").isin(Float.PositiveInfinity), 0))
|
|
||||||
.randomSplit(Array(0.7, 0.3), seed = 1)
|
|
||||||
val getWeightFromF1 = udf({ f1: Float => if (f1.toInt % 2 == 0) 1.0f else 0.001f })
|
|
||||||
val dfWithWeight = originalDf.withColumn("weight", getWeightFromF1(col("f1")))
|
|
||||||
|
|
||||||
val model = new XGBoostClassifier(xgbParam)
|
|
||||||
.fit(originalDf)
|
|
||||||
val model2 = new XGBoostClassifier(xgbParam)
|
|
||||||
.setWeightCol("weight")
|
|
||||||
.fit(dfWithWeight)
|
|
||||||
|
|
||||||
val left = model.transform(testDf).collect()
|
|
||||||
val right = model2.transform(testDf).collect()
|
|
||||||
// left should be different with right
|
|
||||||
assert(!compareResults(true, 0.000001, left, right))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
test("Save model and transform GPU dataset") {
|
|
||||||
// Train a model on GPU
|
|
||||||
val (gpuModel, testDf) = withGpuSparkSession(enableCsvConf()) { spark =>
|
|
||||||
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "binary:logistic",
|
|
||||||
"num_round" -> 10, "num_workers" -> 1)
|
|
||||||
val Array(rawInput, testDf) = spark.read.option("header", "true").schema(schema)
|
|
||||||
.csv(dataPath).withColumn("f2", when(col("f2").isin(Float.PositiveInfinity), 0))
|
|
||||||
.randomSplit(Array(0.7, 0.3), seed = 1)
|
|
||||||
|
|
||||||
val classifier = new XGBoostClassifier(xgbParam)
|
|
||||||
.setFeaturesCol(featureNames)
|
|
||||||
.setLabelCol(labelName)
|
|
||||||
.setTreeMethod("gpu_hist")
|
|
||||||
(classifier.fit(rawInput), testDf)
|
|
||||||
}
|
|
||||||
|
|
||||||
val xgbrModel = new File(tempDir.toFile, "xgbrModel").getPath
|
|
||||||
gpuModel.write.overwrite().save(xgbrModel)
|
|
||||||
val gpuModelFromFile = XGBoostClassificationModel.load(xgbrModel)
|
|
||||||
|
|
||||||
// transform on GPU
|
|
||||||
withGpuSparkSession() { spark =>
|
|
||||||
val left = gpuModel
|
|
||||||
.transform(testDf)
|
|
||||||
.select(labelName, "rawPrediction", "probability", "prediction")
|
|
||||||
.collect()
|
|
||||||
|
|
||||||
val right = gpuModelFromFile
|
|
||||||
.transform(testDf)
|
|
||||||
.select(labelName, "rawPrediction", "probability", "prediction")
|
|
||||||
.collect()
|
|
||||||
|
|
||||||
assert(compareResults(true, 0.000001, left, right))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
test("Model trained on CPU can transform GPU dataset") {
|
|
||||||
// Train a model on CPU
|
|
||||||
val cpuModel = withCpuSparkSession() { spark =>
|
|
||||||
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "binary:logistic",
|
|
||||||
"num_round" -> 10, "num_workers" -> 1)
|
|
||||||
val Array(rawInput, _) = spark.read.option("header", "true").schema(schema)
|
|
||||||
.csv(dataPath).withColumn("f2", when(col("f2").isin(Float.PositiveInfinity), 0))
|
|
||||||
.randomSplit(Array(0.7, 0.3), seed = 1)
|
|
||||||
|
|
||||||
val vectorAssembler = new VectorAssembler()
|
|
||||||
.setHandleInvalid("keep")
|
|
||||||
.setInputCols(featureNames)
|
|
||||||
.setOutputCol("features")
|
|
||||||
val trainingDf = vectorAssembler.transform(rawInput).select("features", labelName)
|
|
||||||
|
|
||||||
val classifier = new XGBoostClassifier(xgbParam)
|
|
||||||
.setFeaturesCol("features")
|
|
||||||
.setLabelCol(labelName)
|
|
||||||
.setTreeMethod("auto")
|
|
||||||
classifier.fit(trainingDf)
|
|
||||||
}
|
|
||||||
|
|
||||||
val xgbrModel = new File(tempDir.toFile, "xgbrModel").getPath
|
|
||||||
cpuModel.write.overwrite().save(xgbrModel)
|
|
||||||
val cpuModelFromFile = XGBoostClassificationModel.load(xgbrModel)
|
|
||||||
|
|
||||||
// transform on GPU
|
|
||||||
withGpuSparkSession() { spark =>
|
|
||||||
val Array(_, testDf) = spark.read.option("header", "true").schema(schema)
|
|
||||||
.csv(dataPath).withColumn("f2", when(col("f2").isin(Float.PositiveInfinity), 0))
|
|
||||||
.randomSplit(Array(0.7, 0.3), seed = 1)
|
|
||||||
|
|
||||||
// Since CPU model does not know the information about the features cols that GPU transform
|
|
||||||
// pipeline requires. End user needs to setFeaturesCol(features: Array[String]) in the model
|
|
||||||
// manually
|
|
||||||
val thrown = intercept[NoSuchElementException](cpuModel
|
|
||||||
.transform(testDf)
|
|
||||||
.collect())
|
|
||||||
assert(thrown.getMessage.contains("Failed to find a default value for featuresCols"))
|
|
||||||
|
|
||||||
val left = cpuModel
|
|
||||||
.setFeaturesCol(featureNames)
|
|
||||||
.transform(testDf)
|
|
||||||
.collect()
|
|
||||||
|
|
||||||
val right = cpuModelFromFile
|
|
||||||
.setFeaturesCol(featureNames)
|
|
||||||
.transform(testDf)
|
|
||||||
.collect()
|
|
||||||
|
|
||||||
assert(compareResults(true, 0.000001, left, right))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
test("Model trained on GPU can transform CPU dataset") {
|
|
||||||
// Train a model on GPU
|
|
||||||
val gpuModel = withGpuSparkSession(enableCsvConf()) { spark =>
|
|
||||||
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "binary:logistic",
|
|
||||||
"num_round" -> 10, "num_workers" -> 1)
|
|
||||||
val Array(rawInput, _) = spark.read.option("header", "true").schema(schema)
|
|
||||||
.csv(dataPath).withColumn("f2", when(col("f2").isin(Float.PositiveInfinity), 0))
|
|
||||||
.randomSplit(Array(0.7, 0.3), seed = 1)
|
|
||||||
|
|
||||||
val classifier = new XGBoostClassifier(xgbParam)
|
|
||||||
.setFeaturesCol(featureNames)
|
|
||||||
.setLabelCol(labelName)
|
|
||||||
.setTreeMethod("gpu_hist")
|
|
||||||
classifier.fit(rawInput)
|
|
||||||
}
|
|
||||||
|
|
||||||
val xgbrModel = new File(tempDir.toFile, "xgbrModel").getPath
|
|
||||||
gpuModel.write.overwrite().save(xgbrModel)
|
|
||||||
val gpuModelFromFile = XGBoostClassificationModel.load(xgbrModel)
|
|
||||||
|
|
||||||
// transform on CPU
|
|
||||||
withCpuSparkSession() { spark =>
|
|
||||||
val Array(_, rawInput) = spark.read.option("header", "true").schema(schema)
|
|
||||||
.csv(dataPath).withColumn("f2", when(col("f2").isin(Float.PositiveInfinity), 0))
|
|
||||||
.randomSplit(Array(0.7, 0.3), seed = 1)
|
|
||||||
|
|
||||||
val featureColName = "feature_col"
|
|
||||||
val vectorAssembler = new VectorAssembler()
|
|
||||||
.setHandleInvalid("keep")
|
|
||||||
.setInputCols(featureNames)
|
|
||||||
.setOutputCol(featureColName)
|
|
||||||
val testDf = vectorAssembler.transform(rawInput).select(featureColName, labelName)
|
|
||||||
|
|
||||||
// Since GPU model does not know the information about the features col name that CPU
|
|
||||||
// transform pipeline requires. End user needs to setFeaturesCol in the model manually
|
|
||||||
intercept[IllegalArgumentException](
|
|
||||||
gpuModel
|
|
||||||
.transform(testDf)
|
|
||||||
.collect())
|
|
||||||
|
|
||||||
val left = gpuModel
|
|
||||||
.setFeaturesCol(featureColName)
|
|
||||||
.transform(testDf)
|
|
||||||
.select(labelName, "rawPrediction", "probability", "prediction")
|
|
||||||
.collect()
|
|
||||||
|
|
||||||
val right = gpuModelFromFile
|
|
||||||
.setFeaturesCol(featureColName)
|
|
||||||
.transform(testDf)
|
|
||||||
.select(labelName, "rawPrediction", "probability", "prediction")
|
|
||||||
.collect()
|
|
||||||
|
|
||||||
assert(compareResults(true, 0.000001, left, right))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -1,212 +0,0 @@
/*
 Copyright (c) 2021-2023 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.rapids.spark

import java.io.File

import ml.dmlc.xgboost4j.scala.spark.{XGBoostClassifier}

import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.StringType

class GpuXGBoostGeneralSuite extends GpuTestSuite {

  private val labelName = "label_col"
  private val weightName = "weight_col"
  private val baseMarginName = "margin_col"
  private val featureNames = Array("f1", "f2", "f3")
  private val allColumnNames = featureNames :+ weightName :+ baseMarginName :+ labelName
  private val trainingData = Seq(
    // f1, f2, f3, weight, margin, label
    (1.0f, 2.0f, 3.0f, 1.0f, 0.5f, 0),
    (2.0f, 3.0f, 4.0f, 2.0f, 0.6f, 0),
    (1.2f, 2.1f, 3.1f, 1.1f, 0.51f, 0),
    (2.3f, 3.1f, 4.1f, 2.1f, 0.61f, 0),
    (3.0f, 4.0f, 5.0f, 1.5f, 0.3f, 1),
    (4.0f, 5.0f, 6.0f, 2.5f, 0.4f, 1),
    (3.1f, 4.1f, 5.1f, 1.6f, 0.4f, 1),
    (4.1f, 5.1f, 6.1f, 2.6f, 0.5f, 1),
    (5.0f, 6.0f, 7.0f, 1.0f, 0.2f, 2),
    (6.0f, 7.0f, 8.0f, 1.3f, 0.6f, 2),
    (5.1f, 6.1f, 7.1f, 1.2f, 0.1f, 2),
    (6.1f, 7.1f, 8.1f, 1.4f, 0.7f, 2),
    (6.2f, 7.2f, 8.2f, 1.5f, 0.8f, 2))

  test("MLlib way setting features_cols should work") {
    withGpuSparkSession() { spark =>
      import spark.implicits._
      val trainingDf = trainingData.toDF(allColumnNames: _*)
      val xgbParam = Map(
        "eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob",
        "num_class" -> 3, "num_round" -> 5, "num_workers" -> 1,
        "tree_method" -> "hist", "device" -> "cuda",
        "features_cols" -> featureNames, "label_col" -> labelName
      )
      new XGBoostClassifier(xgbParam)
        .fit(trainingDf)
    }
  }

  test("disorder feature columns should work") {
    withGpuSparkSession() { spark =>
      import spark.implicits._
      var trainingDf = trainingData.toDF(allColumnNames: _*)

      trainingDf = trainingDf.select(labelName, "f2", weightName, "f3", baseMarginName, "f1")

      val xgbParam = Map(
        "eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob",
        "num_class" -> 3, "num_round" -> 5, "num_workers" -> 1,
        "tree_method" -> "hist", "device" -> "cuda"
      )
      new XGBoostClassifier(xgbParam)
        .setFeaturesCol(featureNames)
        .setLabelCol(labelName)
        .fit(trainingDf)
    }
  }

  test("Throw exception when feature/label columns are not numeric type") {
    withGpuSparkSession() { spark =>
      import spark.implicits._
      val originalDf = trainingData.toDF(allColumnNames: _*)
      var trainingDf = originalDf.withColumn("f2", col("f2").cast(StringType))

      val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob",
        "num_class" -> 3, "num_round" -> 5, "num_workers" -> 1, "tree_method" -> "gpu_hist")
      val thrown1 = intercept[IllegalArgumentException] {
        new XGBoostClassifier(xgbParam)
          .setFeaturesCol(featureNames)
          .setLabelCol(labelName)
          .fit(trainingDf)
      }
      assert(thrown1.getMessage.contains("Column f2 must be of NumericType but found: string."))

      trainingDf = originalDf.withColumn(labelName, col(labelName).cast(StringType))
      val thrown2 = intercept[IllegalArgumentException] {
        new XGBoostClassifier(xgbParam)
          .setFeaturesCol(featureNames)
          .setLabelCol(labelName)
          .fit(trainingDf)
      }
      assert(thrown2.getMessage.contains(
        s"Column $labelName must be of NumericType but found: string."))
    }
  }

  test("Throw exception when features_cols or label_col is not set") {
    withGpuSparkSession() { spark =>
      import spark.implicits._
      val trainingDf = trainingData.toDF(allColumnNames: _*)
      val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob",
        "num_class" -> 3, "num_round" -> 5, "num_workers" -> 1, "tree_method" -> "gpu_hist")

      // GPU train requires featuresCols. If not specified,
      // then NoSuchElementException will be thrown
      val thrown = intercept[NoSuchElementException] {
        new XGBoostClassifier(xgbParam)
          .setLabelCol(labelName)
          .fit(trainingDf)
      }
      assert(thrown.getMessage.contains("Failed to find a default value for featuresCols"))

      val thrown1 = intercept[IllegalArgumentException] {
        new XGBoostClassifier(xgbParam)
          .setFeaturesCol(featureNames)
          .fit(trainingDf)
      }
      assert(thrown1.getMessage.contains("label does not exist."))
    }
  }

  test("Throw exception when device is not set to cuda") {
    withGpuSparkSession() { spark =>
      import spark.implicits._
      val trainingDf = trainingData.toDF(allColumnNames: _*)
      val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob",
        "num_class" -> 3, "num_round" -> 5, "num_workers" -> 1, "tree_method" -> "hist")
      val thrown = intercept[IllegalArgumentException] {
        new XGBoostClassifier(xgbParam)
          .setFeaturesCol(featureNames)
          .setLabelCol(labelName)
          .fit(trainingDf)
      }
      assert(thrown.getMessage.contains("GPU train requires `device` set to `cuda`"))
    }
  }

  test("Train with eval") {
    withGpuSparkSession() { spark =>
      import spark.implicits._
      val Array(trainingDf, eval1, eval2) = trainingData.toDF(allColumnNames: _*)
        .randomSplit(Array(0.6, 0.2, 0.2), seed = 1)
      val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob",
        "num_class" -> 3, "num_round" -> 5, "num_workers" -> 1, "tree_method" -> "gpu_hist")
      val model1 = new XGBoostClassifier(xgbParam)
        .setFeaturesCol(featureNames)
        .setLabelCol(labelName)
        .setEvalSets(Map("eval1" -> eval1, "eval2" -> eval2))
        .fit(trainingDf)

      assert(model1.summary.validationObjectiveHistory.length === 2)
      assert(model1.summary.validationObjectiveHistory.map(_._1).toSet === Set("eval1", "eval2"))
      assert(model1.summary.validationObjectiveHistory(0)._2.length === 5)
      assert(model1.summary.validationObjectiveHistory(1)._2.length === 5)
      assert(model1.summary.trainObjectiveHistory !== model1.summary.validationObjectiveHistory(0))
      assert(model1.summary.trainObjectiveHistory !== model1.summary.validationObjectiveHistory(1))
    }
  }

  test("test persistence of XGBoostClassifier and XGBoostClassificationModel") {
    val xgbcPath = new File(tempDir.toFile, "xgbc").getPath
    withGpuSparkSession() { spark =>
      val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob",
        "num_class" -> 3, "num_round" -> 5, "num_workers" -> 1, "tree_method" -> "gpu_hist",
        "features_cols" -> featureNames, "label_col" -> labelName)
      val xgbc = new XGBoostClassifier(xgbParam)
      xgbc.write.overwrite().save(xgbcPath)
      val paramMap2 = XGBoostClassifier.load(xgbcPath).MLlib2XGBoostParams
      xgbParam.foreach {
        case (k, v: Array[String]) =>
          assert(v.sameElements(paramMap2(k).asInstanceOf[Array[String]]))
        case (k, v) =>
          assert(v.toString == paramMap2(k).toString)
      }
    }
  }

  test("device ordinal should not be specified") {
    withGpuSparkSession() { spark =>
      import spark.implicits._
      val trainingDf = trainingData.toDF(allColumnNames: _*)
      val params = Map(
        "objective" -> "multi:softprob",
        "num_class" -> 3,
        "num_round" -> 5,
        "num_workers" -> 1
      )
      val thrown = intercept[IllegalArgumentException] {
        new XGBoostClassifier(params)
          .setFeaturesCol(featureNames)
          .setLabelCol(labelName)
          .setDevice("cuda:1")
          .fit(trainingDf)
      }
      assert(thrown.getMessage.contains("device given invalid value cuda:1"))
    }
  }
}
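
The removed cases above still mixed the old `tree_method = "gpu_hist"` spelling with the newer `device = "cuda"` form; the reworked plugin validates the device setting instead (see `validate` and the "GPU train requires `device` set to `cuda`" assertion). A minimal sketch of the parameter style the surviving GPU tests use; `trainDf` and the column names are placeholders for illustration only:

.. code-block:: scala

    import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

    // Hypothetical training call; trainDf is an assumed DataFrame with
    // numeric columns f1, f2, f3 and label_col.
    val classifier = new XGBoostClassifier(Map(
        "objective" -> "multi:softprob",
        "num_class" -> 3,
        "num_round" -> 5,
        "num_workers" -> 1))
      .setFeaturesCol(Array("f1", "f2", "f3"))
      .setLabelCol("label_col")
      .setDevice("cuda") // replaces tree_method = "gpu_hist"
    val model = classifier.fit(trainDf)
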
@ -1,258 +0,0 @@
|
|||||||
/*
|
|
||||||
Copyright (c) 2021-2023 by Contributors
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala.rapids.spark
|
|
||||||
|
|
||||||
import java.io.File
|
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.scala.spark.{XGBoostRegressionModel, XGBoostRegressor}
|
|
||||||
|
|
||||||
import org.apache.spark.ml.feature.VectorAssembler
|
|
||||||
import org.apache.spark.sql.functions.{col, udf}
|
|
||||||
import org.apache.spark.sql.types.{FloatType, IntegerType, StructField, StructType}
|
|
||||||
|
|
||||||
class GpuXGBoostRegressorSuite extends GpuTestSuite {
|
|
||||||
|
|
||||||
val labelName = "label_col"
|
|
||||||
val groupName = "group_col"
|
|
||||||
val schema = StructType(Seq(
|
|
||||||
StructField(labelName, FloatType),
|
|
||||||
StructField("f1", FloatType),
|
|
||||||
StructField("f2", FloatType),
|
|
||||||
StructField("f3", FloatType),
|
|
||||||
StructField(groupName, IntegerType)))
|
|
||||||
val featureNames = schema.fieldNames.filter(s =>
|
|
||||||
!(s.equals(labelName) || s.equals(groupName)))
|
|
||||||
|
|
||||||
test("The transform result should be same for several runs on same model") {
|
|
||||||
withGpuSparkSession(enableCsvConf()) { spark =>
|
|
||||||
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "reg:squarederror",
|
|
||||||
"num_round" -> 10, "num_workers" -> 1, "tree_method" -> "hist", "device" -> "cuda",
|
|
||||||
"features_cols" -> featureNames, "label_col" -> labelName)
|
|
||||||
val Array(originalDf, testDf) = spark.read.option("header", "true").schema(schema)
|
|
||||||
.csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1)
|
|
||||||
// Get a model
|
|
||||||
val model = new XGBoostRegressor(xgbParam)
|
|
||||||
.fit(originalDf)
|
|
||||||
val left = model.transform(testDf).collect()
|
|
||||||
val right = model.transform(testDf).collect()
|
|
||||||
// The left should be same with right
|
|
||||||
assert(compareResults(true, 0.000001, left, right))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
test("Tree method gpu_hist still works") {
|
|
||||||
withGpuSparkSession(enableCsvConf()) { spark =>
|
|
||||||
val params = Map(
|
|
||||||
"tree_method" -> "gpu_hist",
|
|
||||||
"features_cols" -> featureNames,
|
|
||||||
"label_col" -> labelName,
|
|
||||||
"num_round" -> 10,
|
|
||||||
"num_workers" -> 1
|
|
||||||
)
|
|
||||||
val Array(originalDf, testDf) = spark.read.option("header", "true").schema(schema)
|
|
||||||
.csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1)
|
|
||||||
// Get a model
|
|
||||||
val model = new XGBoostRegressor(params).fit(originalDf)
|
|
||||||
val left = model.transform(testDf).collect()
|
|
||||||
val right = model.transform(testDf).collect()
|
|
||||||
// The left should be same with right
|
|
||||||
assert(compareResults(true, 0.000001, left, right))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
test("use weight") {
|
|
||||||
withGpuSparkSession(enableCsvConf()) { spark =>
|
|
||||||
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "reg:squarederror",
|
|
||||||
"num_round" -> 10, "num_workers" -> 1, "tree_method" -> "hist", "device" -> "cuda",
|
|
||||||
"features_cols" -> featureNames, "label_col" -> labelName)
|
|
||||||
val Array(originalDf, testDf) = spark.read.option("header", "true").schema(schema)
|
|
||||||
.csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1)
|
|
||||||
val getWeightFromF1 = udf({ f1: Float => if (f1.toInt % 2 == 0) 1.0f else 0.001f })
|
|
||||||
val dfWithWeight = originalDf.withColumn("weight", getWeightFromF1(col("f1")))
|
|
||||||
|
|
||||||
val model = new XGBoostRegressor(xgbParam)
|
|
||||||
.fit(originalDf)
|
|
||||||
val model2 = new XGBoostRegressor(xgbParam)
|
|
||||||
.setWeightCol("weight")
|
|
||||||
.fit(dfWithWeight)
|
|
||||||
|
|
||||||
val left = model.transform(testDf).collect()
|
|
||||||
val right = model2.transform(testDf).collect()
|
|
||||||
// left should be different with right
|
|
||||||
assert(!compareResults(true, 0.000001, left, right))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
test("Save model and transform GPU dataset") {
|
|
||||||
// Train a model on GPU
|
|
||||||
val (gpuModel, testDf) = withGpuSparkSession(enableCsvConf()) { spark =>
|
|
||||||
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "binary:logistic",
|
|
||||||
"num_round" -> 10, "num_workers" -> 1)
|
|
||||||
val Array(rawInput, testDf) = spark.read.option("header", "true").schema(schema)
|
|
||||||
.csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1)
|
|
||||||
|
|
||||||
val classifier = new XGBoostRegressor(xgbParam)
|
|
||||||
.setFeaturesCol(featureNames)
|
|
||||||
.setLabelCol(labelName)
|
|
||||||
.setTreeMethod("hist")
|
|
||||||
.setDevice("cuda")
|
|
||||||
(classifier.fit(rawInput), testDf)
|
|
||||||
}
|
|
||||||
|
|
||||||
val xgbrModel = new File(tempDir.toFile, "xgbrModel").getPath
|
|
||||||
gpuModel.write.overwrite().save(xgbrModel)
|
|
||||||
val gpuModelFromFile = XGBoostRegressionModel.load(xgbrModel)
|
|
||||||
|
|
||||||
// transform on GPU
|
|
||||||
withGpuSparkSession() { spark =>
|
|
||||||
val left = gpuModel
|
|
||||||
.transform(testDf)
|
|
||||||
.select(labelName, "prediction")
|
|
||||||
.collect()
|
|
||||||
|
|
||||||
val right = gpuModelFromFile
|
|
||||||
.transform(testDf)
|
|
||||||
.select(labelName, "prediction")
|
|
||||||
.collect()
|
|
||||||
|
|
||||||
assert(compareResults(true, 0.000001, left, right))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
test("Model trained on CPU can transform GPU dataset") {
|
|
||||||
// Train a model on CPU
|
|
||||||
val cpuModel = withCpuSparkSession() { spark =>
|
|
||||||
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "reg:squarederror",
|
|
||||||
"num_round" -> 10, "num_workers" -> 1)
|
|
||||||
val Array(rawInput, _) = spark.read.option("header", "true").schema(schema)
|
|
||||||
.csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1)
|
|
||||||
|
|
||||||
val vectorAssembler = new VectorAssembler()
|
|
||||||
.setHandleInvalid("keep")
|
|
||||||
.setInputCols(featureNames)
|
|
||||||
.setOutputCol("features")
|
|
||||||
val trainingDf = vectorAssembler.transform(rawInput).select("features", labelName)
|
|
||||||
|
|
||||||
val classifier = new XGBoostRegressor(xgbParam)
|
|
||||||
.setFeaturesCol("features")
|
|
||||||
.setLabelCol(labelName)
|
|
||||||
.setTreeMethod("auto")
|
|
||||||
classifier.fit(trainingDf)
|
|
||||||
}
|
|
||||||
|
|
||||||
val xgbrModel = new File(tempDir.toFile, "xgbrModel").getPath
|
|
||||||
cpuModel.write.overwrite().save(xgbrModel)
|
|
||||||
val cpuModelFromFile = XGBoostRegressionModel.load(xgbrModel)
|
|
||||||
|
|
||||||
// transform on GPU
|
|
||||||
withGpuSparkSession() { spark =>
|
|
||||||
val Array(_, testDf) = spark.read.option("header", "true").schema(schema)
|
|
||||||
.csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1)
|
|
||||||
|
|
||||||
// Since CPU model does not know the information about the features cols that GPU transform
|
|
||||||
// pipeline requires. End user needs to setFeaturesCol(features: Array[String]) in the model
|
|
||||||
// manually
|
|
||||||
val thrown = intercept[NoSuchElementException](cpuModel
|
|
||||||
.transform(testDf)
|
|
||||||
.collect())
|
|
||||||
assert(thrown.getMessage.contains("Failed to find a default value for featuresCols"))
|
|
||||||
|
|
||||||
val left = cpuModel
|
|
||||||
.setFeaturesCol(featureNames)
|
|
||||||
.transform(testDf)
|
|
||||||
.collect()
|
|
||||||
|
|
||||||
val right = cpuModelFromFile
|
|
||||||
.setFeaturesCol(featureNames)
|
|
||||||
.transform(testDf)
|
|
||||||
.collect()
|
|
||||||
|
|
||||||
assert(compareResults(true, 0.000001, left, right))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
test("Model trained on GPU can transform CPU dataset") {
|
|
||||||
// Train a model on GPU
|
|
||||||
val gpuModel = withGpuSparkSession(enableCsvConf()) { spark =>
|
|
||||||
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "reg:squarederror",
|
|
||||||
"num_round" -> 10, "num_workers" -> 1)
|
|
||||||
val Array(rawInput, _) = spark.read.option("header", "true").schema(schema)
|
|
||||||
.csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1)
|
|
||||||
|
|
||||||
val classifier = new XGBoostRegressor(xgbParam)
|
|
||||||
.setFeaturesCol(featureNames)
|
|
||||||
.setLabelCol(labelName)
|
|
||||||
.setDevice("cuda")
|
|
||||||
classifier.fit(rawInput)
|
|
||||||
}
|
|
||||||
|
|
||||||
val xgbrModel = new File(tempDir.toFile, "xgbrModel").getPath
|
|
||||||
gpuModel.write.overwrite().save(xgbrModel)
|
|
||||||
val gpuModelFromFile = XGBoostRegressionModel.load(xgbrModel)
|
|
||||||
|
|
||||||
// transform on CPU
|
|
||||||
withCpuSparkSession() { spark =>
|
|
||||||
val Array(_, rawInput) = spark.read.option("header", "true").schema(schema)
|
|
||||||
.csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1)
|
|
||||||
|
|
||||||
val featureColName = "feature_col"
|
|
||||||
val vectorAssembler = new VectorAssembler()
|
|
||||||
.setHandleInvalid("keep")
|
|
||||||
.setInputCols(featureNames)
|
|
||||||
.setOutputCol(featureColName)
|
|
||||||
val testDf = vectorAssembler.transform(rawInput).select(featureColName, labelName)
|
|
||||||
|
|
||||||
// Since GPU model does not know the information about the features col name that CPU
|
|
||||||
// transform pipeline requires. End user needs to setFeaturesCol in the model manually
|
|
||||||
intercept[IllegalArgumentException](
|
|
||||||
gpuModel
|
|
||||||
.transform(testDf)
|
|
||||||
.collect())
|
|
||||||
|
|
||||||
val left = gpuModel
|
|
||||||
.setFeaturesCol(featureColName)
|
|
||||||
.transform(testDf)
|
|
||||||
.select(labelName, "prediction")
|
|
||||||
.collect()
|
|
||||||
|
|
||||||
val right = gpuModelFromFile
|
|
||||||
.setFeaturesCol(featureColName)
|
|
||||||
.transform(testDf)
|
|
||||||
.select(labelName, "prediction")
|
|
||||||
.collect()
|
|
||||||
|
|
||||||
assert(compareResults(true, 0.000001, left, right))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
test("Ranking: train with Group") {
|
|
||||||
withGpuSparkSession(enableCsvConf()) { spark =>
|
|
||||||
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "rank:ndcg",
|
|
||||||
"num_round" -> 10, "num_workers" -> 1, "tree_method" -> "gpu_hist",
|
|
||||||
"features_cols" -> featureNames, "label_col" -> labelName)
|
|
||||||
val Array(trainingDf, testDf) = spark.read.option("header", "true").schema(schema)
|
|
||||||
.csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1)
|
|
||||||
|
|
||||||
val model = new XGBoostRegressor(xgbParam)
|
|
||||||
.setGroupCol(groupName)
|
|
||||||
.fit(trainingDf)
|
|
||||||
|
|
||||||
val ret = model.transform(testDf).collect()
|
|
||||||
assert(testDf.count() === ret.length)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -0,0 +1,145 @@
/*
 Copyright (c) 2021-2024 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.rapids.spark

import java.nio.file.{Files, Path}
import java.sql.{Date, Timestamp}
import java.util.{Locale, TimeZone}

import org.apache.spark.{GpuTestUtils, SparkConf}
import org.apache.spark.internal.Logging
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.sql.{Row, SparkSession}
import org.scalatest.BeforeAndAfterAll
import org.scalatest.funsuite.AnyFunSuite

trait GpuTestSuite extends AnyFunSuite with TmpFolderSuite {

  import SparkSessionHolder.withSparkSession

  protected def getResourcePath(resource: String): String = {
    require(resource.startsWith("/"), "resource must start with /")
    getClass.getResource(resource).getPath
  }

  def enableCsvConf(): SparkConf = {
    new SparkConf()
      .set("spark.rapids.sql.csv.read.float.enabled", "true")
      .set("spark.rapids.sql.csv.read.double.enabled", "true")
  }

  def withGpuSparkSession[U](conf: SparkConf = new SparkConf())(f: SparkSession => U): U = {
    // set "spark.rapids.sql.explain" to "ALL" to check if the operators
    // can be replaced by GPU
    val c = conf.clone()
      .set("spark.rapids.sql.enabled", "true")
    withSparkSession(c, f)
  }

  def withCpuSparkSession[U](conf: SparkConf = new SparkConf())(f: SparkSession => U): U = {
    val c = conf.clone()
      .set("spark.rapids.sql.enabled", "false") // Just to be sure
    withSparkSession(c, f)
  }
}

trait TmpFolderSuite extends BeforeAndAfterAll {
  self: AnyFunSuite =>
  protected var tempDir: Path = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    tempDir = Files.createTempDirectory(getClass.getName)
  }

  override def afterAll(): Unit = {
    JavaUtils.deleteRecursively(tempDir.toFile)
    super.afterAll()
  }

  protected def createTmpFolder(prefix: String): Path = {
    Files.createTempDirectory(tempDir, prefix)
  }
}

object SparkSessionHolder extends Logging {

  private var spark = createSparkSession()
  private var origConf = spark.conf.getAll
  private var origConfKeys = origConf.keys.toSet

  private def setAllConfs(confs: Array[(String, String)]): Unit = confs.foreach {
    case (key, value) if spark.conf.get(key, null) != value =>
      spark.conf.set(key, value)
    case _ => // No need to modify it
  }

  private def createSparkSession(): SparkSession = {
    GpuTestUtils.cleanupAnyExistingSession()

    // Timezone is fixed to UTC to allow timestamps to work by default
    TimeZone.setDefault(TimeZone.getTimeZone("UTC"))
    // Add Locale setting
    Locale.setDefault(Locale.US)

    val builder = SparkSession.builder()
      .master("local[2]")
      .config("spark.sql.adaptive.enabled", "false")
      .config("spark.rapids.sql.test.enabled", "false")
      .config("spark.stage.maxConsecutiveAttempts", "1")
      .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
      .config("spark.rapids.memory.gpu.pooling.enabled", "false") // Disable RMM for unit tests.
      .config("spark.sql.files.maxPartitionBytes", "1000")
      .appName("XGBoost4j-Spark-Gpu unit test")

    builder.getOrCreate()
  }

  private def reinitSession(): Unit = {
    spark = createSparkSession()
    origConf = spark.conf.getAll
    origConfKeys = origConf.keys.toSet
  }

  def sparkSession: SparkSession = {
    if (SparkSession.getActiveSession.isEmpty) {
      reinitSession()
    }
    spark
  }

  def resetSparkSessionConf(): Unit = {
    if (SparkSession.getActiveSession.isEmpty) {
      reinitSession()
    } else {
      setAllConfs(origConf.toArray)
      val currentKeys = spark.conf.getAll.keys.toSet
      val toRemove = currentKeys -- origConfKeys
      toRemove.foreach(spark.conf.unset)
    }
    logDebug(s"RESET CONF TO: ${spark.conf.getAll}")
  }

  def withSparkSession[U](conf: SparkConf, f: SparkSession => U): U = {
    resetSparkSessionConf
    logDebug(s"SETTING CONF: ${conf.getAll.toMap}")
    setAllConfs(conf.getAll)
    logDebug(s"RUN WITH CONF: ${spark.conf.getAll}\n")
    spark.sparkContext.setLogLevel("WARN")
    f(spark)
  }
}
@ -0,0 +1,523 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2024 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ml.dmlc.xgboost4j.scala.spark
|
||||||
|
|
||||||
|
import ai.rapids.cudf.Table
|
||||||
|
import ml.dmlc.xgboost4j.java.CudfColumnBatch
|
||||||
|
import ml.dmlc.xgboost4j.scala.{DMatrix, QuantileDMatrix, XGBoost => ScalaXGBoost}
|
||||||
|
import ml.dmlc.xgboost4j.scala.rapids.spark.GpuTestSuite
|
||||||
|
import ml.dmlc.xgboost4j.scala.rapids.spark.SparkSessionHolder.withSparkSession
|
||||||
|
import ml.dmlc.xgboost4j.scala.spark.Utils.withResource
|
||||||
|
import org.apache.spark.ml.linalg.DenseVector
|
||||||
|
import org.apache.spark.sql.{Dataset, SparkSession}
|
||||||
|
import org.apache.spark.SparkConf
|
||||||
|
|
||||||
|
import java.io.File
|
||||||
|
import scala.collection.mutable.ArrayBuffer
|
||||||
|
|
||||||
|
class GpuXGBoostPluginSuite extends GpuTestSuite {
|
||||||
|
|
||||||
|
test("params") {
|
||||||
|
withGpuSparkSession() { spark =>
|
||||||
|
import spark.implicits._
|
||||||
|
val df = Seq((1.0f, 2.0f, 1.0f, 2.0f, 0.0f, 0.0f),
|
||||||
|
(2.0f, 3.0f, 2.0f, 3.0f, 1.0f, 0.1f),
|
||||||
|
(3.0f, 4.0f, 5.0f, 6.0f, 0.0f, 0.1f),
|
||||||
|
(4.0f, 5.0f, 6.0f, 7.0f, 0.0f, 0.1f),
|
||||||
|
(5.0f, 6.0f, 7.0f, 8.0f, 0.0f, 0.1f)
|
||||||
|
).toDF("c1", "c2", "weight", "margin", "label", "other")
|
||||||
|
val xgbParams: Map[String, Any] = Map(
|
||||||
|
"max_depth" -> 5,
|
||||||
|
"eta" -> 0.2,
|
||||||
|
"objective" -> "binary:logistic"
|
||||||
|
)
|
||||||
|
val features = Array("c1", "c2")
|
||||||
|
val estimator = new XGBoostClassifier(xgbParams)
|
||||||
|
.setFeaturesCol(features)
|
||||||
|
.setMissing(0.2f)
|
||||||
|
.setAlpha(0.97)
|
||||||
|
.setLeafPredictionCol("leaf")
|
||||||
|
.setContribPredictionCol("contrib")
|
||||||
|
.setNumRound(3)
|
||||||
|
.setDevice("cuda")
|
||||||
|
|
||||||
|
assert(estimator.getMaxDepth === 5)
|
||||||
|
assert(estimator.getEta === 0.2)
|
||||||
|
assert(estimator.getObjective === "binary:logistic")
|
||||||
|
assert(estimator.getFeaturesCols === features)
|
||||||
|
assert(estimator.getMissing === 0.2f)
|
||||||
|
assert(estimator.getAlpha === 0.97)
|
||||||
|
assert(estimator.getDevice === "cuda")
|
||||||
|
assert(estimator.getNumRound === 3)
|
||||||
|
|
||||||
|
estimator.setEta(0.66).setMaxDepth(7)
|
||||||
|
assert(estimator.getMaxDepth === 7)
|
||||||
|
assert(estimator.getEta === 0.66)
|
||||||
|
|
||||||
|
val model = estimator.fit(df)
|
||||||
|
assert(model.getMaxDepth === 7)
|
||||||
|
assert(model.getEta === 0.66)
|
||||||
|
assert(model.getObjective === "binary:logistic")
|
||||||
|
assert(model.getFeaturesCols === features)
|
||||||
|
assert(model.getMissing === 0.2f)
|
||||||
|
assert(model.getAlpha === 0.97)
|
||||||
|
assert(model.getLeafPredictionCol === "leaf")
|
||||||
|
assert(model.getContribPredictionCol === "contrib")
|
||||||
|
assert(model.getDevice === "cuda")
|
||||||
|
assert(model.getNumRound === 3)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
test("isEnabled") {
|
||||||
|
def checkIsEnabled(spark: SparkSession, expected: Boolean): Unit = {
|
||||||
|
import spark.implicits._
|
||||||
|
val df = Seq((1.0f, 2.0f, 0.0f),
|
||||||
|
(2.0f, 3.0f, 1.0f)
|
||||||
|
).toDF("c1", "c2", "label")
|
||||||
|
val classifier = new XGBoostClassifier()
|
||||||
|
assert(classifier.getPlugin.isDefined)
|
||||||
|
assert(classifier.getPlugin.get.isEnabled(df) === expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
// spark.rapids.sql.enabled is not set explicitly, default to true
|
||||||
|
withSparkSession(new SparkConf(), spark => {checkIsEnabled(spark, true)})
|
||||||
|
|
||||||
|
// set spark.rapids.sql.enabled to false
|
||||||
|
withCpuSparkSession() { spark =>
|
||||||
|
checkIsEnabled(spark, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
// set spark.rapids.sql.enabled to true
|
||||||
|
withGpuSparkSession() { spark =>
|
||||||
|
checkIsEnabled(spark, true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
test("parameter validation") {
|
||||||
|
withGpuSparkSession() { spark =>
|
||||||
|
import spark.implicits._
|
||||||
|
val df = Seq((1.0f, 2.0f, 1.0f, 2.0f, 0.0f, 0.0f),
|
||||||
|
(2.0f, 3.0f, 2.0f, 3.0f, 1.0f, 0.1f),
|
||||||
|
(3.0f, 4.0f, 5.0f, 6.0f, 0.0f, 0.1f),
|
||||||
|
(4.0f, 5.0f, 6.0f, 7.0f, 0.0f, 0.1f),
|
||||||
|
(5.0f, 6.0f, 7.0f, 8.0f, 0.0f, 0.1f)
|
||||||
|
).toDF("c1", "c2", "weight", "margin", "label", "other")
|
||||||
|
val classifier = new XGBoostClassifier()
|
||||||
|
|
||||||
|
val plugin = classifier.getPlugin.get.asInstanceOf[GpuXGBoostPlugin]
|
||||||
|
intercept[IllegalArgumentException] {
|
||||||
|
plugin.validate(classifier, df)
|
||||||
|
}
|
||||||
|
classifier.setDevice("cuda")
|
||||||
|
plugin.validate(classifier, df)
|
||||||
|
|
||||||
|
classifier.setDevice("gpu")
|
||||||
|
plugin.validate(classifier, df)
|
||||||
|
|
||||||
|
classifier.setDevice("cpu")
|
||||||
|
classifier.setTreeMethod("gpu_hist")
|
||||||
|
plugin.validate(classifier, df)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
test("preprocess") {
|
||||||
|
withGpuSparkSession() { spark =>
|
||||||
|
import spark.implicits._
|
||||||
|
val df = Seq((1.0f, 2.0f, 1.0f, 2.0f, 0.0f, 0.0f),
|
||||||
|
(2.0f, 3.0f, 2.0f, 3.0f, 1.0f, 0.1f),
|
||||||
|
(3.0f, 4.0f, 5.0f, 6.0f, 0.0f, 0.1f),
|
||||||
|
(4.0f, 5.0f, 6.0f, 7.0f, 0.0f, 0.1f),
|
||||||
|
(5.0f, 6.0f, 7.0f, 8.0f, 0.0f, 0.1f)
|
||||||
|
).toDF("c1", "c2", "weight", "margin", "label", "other")
|
||||||
|
.repartition(5)
|
||||||
|
|
||||||
|
assert(df.schema.names.contains("other"))
|
||||||
|
assert(df.rdd.getNumPartitions === 5)
|
||||||
|
|
||||||
|
val features = Array("c1", "c2")
|
||||||
|
var classifier = new XGBoostClassifier()
|
||||||
|
.setNumWorkers(3)
|
||||||
|
.setFeaturesCol(features)
|
||||||
|
assert(classifier.getPlugin.isDefined)
|
||||||
|
assert(classifier.getPlugin.get.isInstanceOf[GpuXGBoostPlugin])
|
||||||
|
var out = classifier.getPlugin.get.asInstanceOf[GpuXGBoostPlugin]
|
||||||
|
.preprocess(classifier, df)
|
||||||
|
|
||||||
|
assert(out.schema.names.contains("c1") && out.schema.names.contains("c2"))
|
||||||
|
assert(out.schema.names.contains(classifier.getLabelCol))
|
||||||
|
assert(!out.schema.names.contains("weight") && !out.schema.names.contains("margin"))
|
||||||
|
assert(out.rdd.getNumPartitions === 3)
|
||||||
|
|
||||||
|
classifier = new XGBoostClassifier()
|
||||||
|
.setNumWorkers(4)
|
||||||
|
.setFeaturesCol(features)
|
||||||
|
.setWeightCol("weight")
|
||||||
|
.setBaseMarginCol("margin")
|
||||||
|
.setDevice("cuda")
|
||||||
|
out = classifier.getPlugin.get.asInstanceOf[GpuXGBoostPlugin]
|
||||||
|
.preprocess(classifier, df)
|
||||||
|
|
||||||
|
assert(out.schema.names.contains("c1") && out.schema.names.contains("c2"))
|
||||||
|
assert(out.schema.names.contains(classifier.getLabelCol))
|
||||||
|
assert(out.schema.names.contains("weight") && out.schema.names.contains("margin"))
|
||||||
|
assert(out.rdd.getNumPartitions === 4)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// test distributed
|
||||||
|
test("build RDD Watches") {
|
||||||
|
withGpuSparkSession() { spark =>
|
||||||
|
import spark.implicits._
|
||||||
|
|
||||||
|
// dataPoint -> (missing, rowNum, nonMissing)
|
||||||
|
Map(0.0f -> (0.0f, 5, 9), Float.NaN -> (0.0f, 5, 9)).foreach {
|
||||||
|
case (data, (missing, expectedRowNum, expectedNonMissing)) =>
|
||||||
|
val df = Seq(
|
||||||
|
(1.0f, 2.0f, 1.0f, 2.0f, 0.0f, 0.0f),
|
||||||
|
(2.0f, 3.0f, 2.0f, 3.0f, 1.0f, 0.1f),
|
||||||
|
(3.0f, data, 5.0f, 6.0f, 0.0f, 0.1f),
|
||||||
|
(4.0f, 5.0f, 6.0f, 7.0f, 0.0f, 0.1f),
|
||||||
|
(5.0f, 6.0f, 7.0f, 8.0f, 1.0f, 0.1f)
|
||||||
|
).toDF("c1", "c2", "weight", "margin", "label", "other")
|
||||||
|
|
||||||
|
val features = Array("c1", "c2")
|
||||||
|
val classifier = new XGBoostClassifier()
|
||||||
|
.setNumWorkers(2)
|
||||||
|
.setWeightCol("weight")
|
||||||
|
.setBaseMarginCol("margin")
|
||||||
|
.setFeaturesCol(features)
|
||||||
|
.setDevice("cuda")
|
||||||
|
.setMissing(missing)
|
||||||
|
|
||||||
|
val rdd = classifier.getPlugin.get.buildRddWatches(classifier, df)
|
||||||
|
val result = rdd.mapPartitions { iter =>
|
||||||
|
val watches = iter.next()
|
||||||
|
val size = watches.size
|
||||||
|
val labels = watches.datasets(0).getLabel
|
||||||
|
val weight = watches.datasets(0).getWeight
|
||||||
|
val margins = watches.datasets(0).getBaseMargin
|
||||||
|
val rowNumber = watches.datasets(0).rowNum
|
||||||
|
val nonMissing = watches.datasets(0).nonMissingNum
|
||||||
|
Iterator.single(size, rowNumber, nonMissing, labels, weight, margins)
|
||||||
|
}.collect()
|
||||||
|
|
||||||
|
val labels: ArrayBuffer[Float] = ArrayBuffer.empty
|
||||||
|
val weight: ArrayBuffer[Float] = ArrayBuffer.empty
|
||||||
|
val margins: ArrayBuffer[Float] = ArrayBuffer.empty
|
||||||
|
val rowNumber: ArrayBuffer[Long] = ArrayBuffer.empty
|
||||||
|
val nonMissing: ArrayBuffer[Long] = ArrayBuffer.empty
|
||||||
|
|
||||||
|
for (row <- result) {
|
||||||
|
assert(row._1 === 1)
|
||||||
|
rowNumber.append(row._2)
|
||||||
|
nonMissing.append(row._3)
|
||||||
|
labels.append(row._4: _*)
|
||||||
|
weight.append(row._5: _*)
|
||||||
|
margins.append(row._6: _*)
|
||||||
|
}
|
||||||
|
assert(labels.sorted === Array(0.0f, 1.0f, 0.0f, 0.0f, 1.0f).sorted)
|
||||||
|
assert(weight.sorted === Array(1.0f, 2.0f, 5.0f, 6.0f, 7.0f).sorted)
|
||||||
|
assert(margins.sorted === Array(2.0f, 3.0f, 6.0f, 7.0f, 8.0f).sorted)
|
||||||
|
assert(rowNumber.sum === expectedRowNum)
|
||||||
|
assert(nonMissing.sum === expectedNonMissing)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
test("build RDD Watches with Eval") {
|
||||||
|
withGpuSparkSession() { spark =>
|
||||||
|
import spark.implicits._
|
||||||
|
val train = Seq(
|
||||||
|
(1.0f, 2.0f, 1.0f, 2.0f, 0.0f, 0.0f),
|
||||||
|
(2.0f, 3.0f, 2.0f, 3.0f, 1.0f, 0.1f)
|
||||||
|
).toDF("c1", "c2", "weight", "margin", "label", "other")
|
||||||
|
|
||||||
|
// dataPoint -> (missing, rowNum, nonMissing)
|
||||||
|
Map(0.0f -> (0.0f, 5, 9), Float.NaN -> (0.0f, 5, 9)).foreach {
|
||||||
|
case (data, (missing, expectedRowNum, expectedNonMissing)) =>
|
||||||
|
val eval = Seq(
|
||||||
|
(1.0f, 2.0f, 1.0f, 2.0f, 0.0f, 0.0f),
|
||||||
|
(2.0f, 3.0f, 2.0f, 3.0f, 1.0f, 0.1f),
|
||||||
|
(3.0f, data, 5.0f, 6.0f, 0.0f, 0.1f),
|
||||||
|
(4.0f, 5.0f, 6.0f, 7.0f, 0.0f, 0.1f),
|
||||||
|
(5.0f, 6.0f, 7.0f, 8.0f, 1.0f, 0.1f)
|
||||||
|
).toDF("c1", "c2", "weight", "margin", "label", "other")
|
||||||
|
|
||||||
|
val features = Array("c1", "c2")
|
||||||
|
val classifier = new XGBoostClassifier()
|
||||||
|
.setNumWorkers(2)
|
||||||
|
.setWeightCol("weight")
|
||||||
|
.setBaseMarginCol("margin")
|
||||||
|
.setFeaturesCol(features)
|
||||||
|
.setDevice("cuda")
|
||||||
|
.setMissing(missing)
|
||||||
|
.setEvalDataset(eval)
|
||||||
|
|
||||||
|
val rdd = classifier.getPlugin.get.buildRddWatches(classifier, train)
|
||||||
|
val result = rdd.mapPartitions { iter =>
|
||||||
|
val watches = iter.next()
|
||||||
|
val size = watches.size
|
||||||
|
val labels = watches.datasets(1).getLabel
|
||||||
|
val weight = watches.datasets(1).getWeight
|
||||||
|
val margins = watches.datasets(1).getBaseMargin
|
||||||
|
val rowNumber = watches.datasets(1).rowNum
|
||||||
|
val nonMissing = watches.datasets(1).nonMissingNum
|
||||||
|
Iterator.single(size, rowNumber, nonMissing, labels, weight, margins)
|
||||||
|
}.collect()
|
||||||
|
|
||||||
|
val labels: ArrayBuffer[Float] = ArrayBuffer.empty
|
||||||
|
val weight: ArrayBuffer[Float] = ArrayBuffer.empty
|
||||||
|
val margins: ArrayBuffer[Float] = ArrayBuffer.empty
|
||||||
|
val rowNumber: ArrayBuffer[Long] = ArrayBuffer.empty
|
||||||
|
val nonMissing: ArrayBuffer[Long] = ArrayBuffer.empty
|
||||||
|
|
||||||
|
for (row <- result) {
|
||||||
|
assert(row._1 === 2)
|
||||||
|
rowNumber.append(row._2)
|
||||||
|
nonMissing.append(row._3)
|
||||||
|
labels.append(row._4: _*)
|
||||||
|
weight.append(row._5: _*)
|
||||||
|
margins.append(row._6: _*)
|
||||||
|
}
|
||||||
|
assert(labels.sorted === Array(0.0f, 1.0f, 0.0f, 0.0f, 1.0f).sorted)
|
||||||
|
assert(weight.sorted === Array(1.0f, 2.0f, 5.0f, 6.0f, 7.0f).sorted)
|
||||||
|
assert(margins.sorted === Array(2.0f, 3.0f, 6.0f, 7.0f, 8.0f).sorted)
|
||||||
|
assert(rowNumber.sum === expectedRowNum)
|
||||||
|
assert(nonMissing.sum === expectedNonMissing)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
test("transformed schema") {
|
||||||
|
withGpuSparkSession() { spark =>
|
||||||
|
import spark.implicits._
|
||||||
|
val df = Seq(
|
||||||
|
(1.0f, 2.0f, 1.0f, 2.0f, 0.0f, 0.0f),
|
||||||
|
(2.0f, 3.0f, 2.0f, 3.0f, 1.0f, 0.1f),
|
||||||
|
(3.0f, 4.0f, 5.0f, 6.0f, 0.0f, 0.1f),
|
||||||
|
(4.0f, 5.0f, 6.0f, 7.0f, 0.0f, 0.1f),
|
||||||
|
(5.0f, 6.0f, 7.0f, 8.0f, 1.0f, 0.1f)
|
||||||
|
).toDF("c1", "c2", "weight", "margin", "label", "other")
|
||||||
|
|
||||||
|
val estimator = new XGBoostClassifier()
|
||||||
|
.setNumWorkers(1)
|
||||||
|
.setNumRound(2)
|
||||||
|
.setFeaturesCol(Array("c1", "c2"))
|
||||||
|
.setLabelCol("label")
|
||||||
|
.setDevice("cuda")
|
||||||
|
|
||||||
|
assert(estimator.getPlugin.isDefined && estimator.getPlugin.get.isEnabled(df))
|
||||||
|
|
||||||
|
val out = estimator.fit(df).transform(df)
|
||||||
|
// Transform should not discard the other columns of the transforming dataframe
|
||||||
|
Seq("c1", "c2", "weight", "margin", "label", "other").foreach { v =>
|
||||||
|
assert(out.schema.names.contains(v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Transform for XGBoostClassifier needs to add extra columns
|
||||||
|
Seq("rawPrediction", "probability", "prediction").foreach { v =>
|
||||||
|
assert(out.schema.names.contains(v))
|
||||||
|
}
|
||||||
|
assert(out.schema.names.length === 9)
|
||||||
|
|
||||||
|
val out1 = estimator.setLeafPredictionCol("leaf").setContribPredictionCol("contrib")
|
||||||
|
.fit(df)
|
||||||
|
.transform(df)
|
||||||
|
Seq("leaf", "contrib").foreach { v =>
|
||||||
|
assert(out1.schema.names.contains(v))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private def checkEqual(left: Array[Array[Float]],
|
||||||
|
right: Array[Array[Float]],
|
||||||
|
epsilon: Float = 1e-4f): Unit = {
|
||||||
|
assert(left.size === right.size)
|
||||||
|
left.zip(right).foreach { case (leftValue, rightValue) =>
|
||||||
|
leftValue.zip(rightValue).foreach { case (l, r) =>
|
||||||
|
assert(math.abs(l - r) < epsilon)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Seq("binary:logistic", "multi:softprob").foreach { case objective =>
|
||||||
|
test(s"$objective: XGBoost-Spark should match xgboost4j") {
|
||||||
|
withGpuSparkSession() { spark =>
|
||||||
|
import spark.implicits._
|
||||||
|
|
||||||
|
val numRound = 100
|
||||||
|
var xgboostParams: Map[String, Any] = Map(
|
||||||
|
"objective" -> objective,
|
||||||
|
"device" -> "cuda"
|
||||||
|
)
|
||||||
|
|
||||||
|
val (trainPath, testPath) = if (objective == "binary:logistic") {
|
||||||
|
(writeFile(Classification.train.toDF("label", "weight", "c1", "c2", "c3")),
|
||||||
|
writeFile(Classification.test.toDF("label", "weight", "c1", "c2", "c3")))
|
||||||
|
} else {
|
||||||
|
xgboostParams = xgboostParams ++ Map("num_class" -> 6)
|
||||||
|
(writeFile(MultiClassification.train.toDF("label", "weight", "c1", "c2", "c3")),
|
||||||
|
writeFile(MultiClassification.test.toDF("label", "weight", "c1", "c2", "c3")))
|
||||||
|
}
|
||||||
|
|
||||||
|
val df = spark.read.parquet(trainPath)
|
||||||
|
val testdf = spark.read.parquet(testPath)
|
||||||
|
|
||||||
|
val features = Array("c1", "c2", "c3")
|
||||||
|
val featuresIndices = features.map(df.schema.fieldIndex)
|
||||||
|
val label = "label"
|
||||||
|
|
||||||
|
val classifier = new XGBoostClassifier(xgboostParams)
|
||||||
|
.setFeaturesCol(features)
|
||||||
|
.setLabelCol(label)
|
||||||
|
.setNumRound(numRound)
|
||||||
|
.setLeafPredictionCol("leaf")
|
||||||
|
.setContribPredictionCol("contrib")
|
||||||
|
.setDevice("cuda")
|
||||||
|
|
||||||
|
val xgb4jModel = withResource(new GpuColumnBatch(
|
||||||
|
Table.readParquet(new File(trainPath)))) { batch =>
|
||||||
|
val cb = new CudfColumnBatch(batch.select(featuresIndices),
|
||||||
|
batch.select(df.schema.fieldIndex(label)), null, null, null
|
||||||
|
)
|
||||||
|
val qdm = new QuantileDMatrix(Seq(cb).iterator, classifier.getMissing,
|
||||||
|
classifier.getMaxBins, classifier.getNthread)
|
||||||
|
ScalaXGBoost.train(qdm, xgboostParams, numRound)
|
||||||
|
}
|
||||||
|
|
||||||
|
val (xgb4jLeaf, xgb4jContrib, xgb4jProb, xgb4jRaw) = withResource(new GpuColumnBatch(
|
||||||
|
Table.readParquet(new File(testPath)))) { batch =>
|
||||||
|
val cb = new CudfColumnBatch(batch.select(featuresIndices), null, null, null, null
|
||||||
|
)
|
||||||
|
val qdm = new DMatrix(cb, classifier.getMissing, classifier.getNthread)
|
||||||
|
(xgb4jModel.predictLeaf(qdm), xgb4jModel.predictContrib(qdm),
|
||||||
|
xgb4jModel.predict(qdm), xgb4jModel.predict(qdm, outPutMargin = true))
|
||||||
|
}
|
||||||
|
|
||||||
|
val rows = classifier.fit(df).transform(testdf).collect()
|
||||||
|
|
||||||
|
// Check Leaf
|
||||||
|
val xgbSparkLeaf = rows.map(row => row.getAs[DenseVector]("leaf").toArray.map(_.toFloat))
|
||||||
|
checkEqual(xgb4jLeaf, xgbSparkLeaf)
|
||||||
|
|
||||||
|
// Check contrib
|
||||||
|
val xgbSparkContrib = rows.map(row =>
|
||||||
|
row.getAs[DenseVector]("contrib").toArray.map(_.toFloat))
|
||||||
|
checkEqual(xgb4jContrib, xgbSparkContrib)
|
||||||
|
|
||||||
|
// Check probability
|
||||||
|
var xgbSparkProb = rows.map(row =>
|
||||||
|
row.getAs[DenseVector]("probability").toArray.map(_.toFloat))
|
||||||
|
if (objective == "binary:logistic") {
|
||||||
|
xgbSparkProb = xgbSparkProb.map(v => Array(v(1)))
|
||||||
|
}
|
||||||
|
checkEqual(xgb4jProb, xgbSparkProb)
|
||||||
|
|
||||||
|
// Check raw
|
||||||
|
var xgbSparkRaw = rows.map(row =>
|
||||||
|
row.getAs[DenseVector]("rawPrediction").toArray.map(_.toFloat))
|
||||||
|
if (objective == "binary:logistic") {
|
||||||
|
xgbSparkRaw = xgbSparkRaw.map(v => Array(v(1)))
|
||||||
|
}
|
||||||
|
checkEqual(xgb4jRaw, xgbSparkRaw)
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
test(s"Regression: XGBoost-Spark should match xgboost4j") {
|
||||||
|
withGpuSparkSession() { spark =>
|
||||||
|
import spark.implicits._
|
||||||
|
|
||||||
|
val trainPath = writeFile(Regression.train.toDF("label", "weight", "c1", "c2", "c3"))
|
||||||
|
val testPath = writeFile(Regression.test.toDF("label", "weight", "c1", "c2", "c3"))
|
||||||
|
|
||||||
|
val df = spark.read.parquet(trainPath)
|
||||||
|
val testdf = spark.read.parquet(testPath)
|
||||||
|
|
||||||
|
val features = Array("c1", "c2", "c3")
|
||||||
|
val featuresIndices = features.map(df.schema.fieldIndex)
|
||||||
|
val label = "label"
|
||||||
|
|
||||||
|
val numRound = 100
|
||||||
|
val xgboostParams: Map[String, Any] = Map(
|
||||||
|
"device" -> "cuda"
|
||||||
|
)
|
||||||
|
|
||||||
|
val regressor = new XGBoostRegressor(xgboostParams)
|
||||||
|
.setFeaturesCol(features)
|
||||||
|
.setLabelCol(label)
|
||||||
|
.setNumRound(numRound)
|
||||||
|
.setLeafPredictionCol("leaf")
|
||||||
|
.setContribPredictionCol("contrib")
|
||||||
|
.setDevice("cuda")
|
||||||
|
|
||||||
|
val xgb4jModel = withResource(new GpuColumnBatch(
|
||||||
|
Table.readParquet(new File(trainPath)))) { batch =>
|
||||||
|
val cb = new CudfColumnBatch(batch.select(featuresIndices),
|
||||||
|
batch.select(df.schema.fieldIndex(label)), null, null, null
|
||||||
|
)
|
||||||
|
val qdm = new QuantileDMatrix(Seq(cb).iterator, regressor.getMissing,
|
||||||
|
regressor.getMaxBins, regressor.getNthread)
|
||||||
|
ScalaXGBoost.train(qdm, xgboostParams, numRound)
|
||||||
|
}
|
||||||
|
|
||||||
|
val (xgb4jLeaf, xgb4jContrib, xgb4jPred) = withResource(new GpuColumnBatch(
|
||||||
|
Table.readParquet(new File(testPath)))) { batch =>
|
||||||
|
val cb = new CudfColumnBatch(batch.select(featuresIndices), null, null, null, null
|
||||||
|
)
|
||||||
|
val qdm = new DMatrix(cb, regressor.getMissing, regressor.getNthread)
|
||||||
|
(xgb4jModel.predictLeaf(qdm), xgb4jModel.predictContrib(qdm),
|
||||||
|
xgb4jModel.predict(qdm))
|
||||||
|
}
|
||||||
|
|
||||||
|
val rows = regressor.fit(df).transform(testdf).collect()
|
||||||
|
|
||||||
|
// Check Leaf
|
||||||
|
val xgbSparkLeaf = rows.map(row => row.getAs[DenseVector]("leaf").toArray.map(_.toFloat))
|
||||||
|
checkEqual(xgb4jLeaf, xgbSparkLeaf)
|
||||||
|
|
||||||
|
// Check contrib
|
||||||
|
val xgbSparkContrib = rows.map(row =>
|
||||||
|
row.getAs[DenseVector]("contrib").toArray.map(_.toFloat))
|
||||||
|
checkEqual(xgb4jContrib, xgbSparkContrib)
|
||||||
|
|
||||||
|
// Check prediction
|
||||||
|
val xgbSparkPred = rows.map(row =>
|
||||||
|
Array(row.getAs[Double]("prediction").toFloat))
|
||||||
|
checkEqual(xgb4jPred, xgbSparkPred)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def writeFile(df: Dataset[_]): String = {
|
||||||
|
def listFiles(directory: String): Array[String] = {
|
||||||
|
val dir = new File(directory)
|
||||||
|
if (dir.exists && dir.isDirectory) {
|
||||||
|
dir.listFiles.filter(f => f.isFile && f.getName.startsWith("part-")).map(_.getName)
|
||||||
|
} else {
|
||||||
|
Array.empty[String]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val dir = createTmpFolder("gpu_").toAbsolutePath.toString
|
||||||
|
df.coalesce(1).write.parquet(s"$dir/data")
|
||||||
|
|
||||||
|
val file = listFiles(s"$dir/data")(0)
|
||||||
|
s"$dir/data/$file"
|
||||||
|
}
|
||||||
|
|
||||||
|
}
@ -0,0 +1,86 @@
/*
 Copyright (c) 2014-2024 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.spark

import scala.util.Random

trait TrainTestData {

  protected def generateClassificationDataset(
      numRows: Int,
      numClass: Int,
      seed: Int = 1): Seq[(Int, Float, Float, Float, Float)] = {
    val random = new Random()
    random.setSeed(seed)
    (1 to numRows).map { _ =>
      val label = random.nextInt(numClass)
      // label, weight, c1, c2, c3
      (label, random.nextFloat().abs, random.nextGaussian().toFloat, random.nextGaussian().toFloat,
        random.nextGaussian().toFloat)
    }
  }

  protected def generateRegressionDataset(
      numRows: Int,
      seed: Int = 11): Seq[(Float, Float, Float, Float, Float)] = {
    val random = new Random()
    random.setSeed(seed)
    (1 to numRows).map { _ =>
      // label, weight, c1, c2, c3
      (random.nextFloat(), random.nextFloat().abs, random.nextGaussian().toFloat,
        random.nextGaussian().toFloat,
        random.nextGaussian().toFloat)
    }
  }

  protected def generateRankDataset(
      numRows: Int,
      numClass: Int,
      maxGroup: Int = 12,
      seed: Int = 99): Seq[(Int, Float, Int, Float, Float, Float)] = {
    val random = new Random()
    random.setSeed(seed)
    (1 to numRows).map { _ =>
      val group = random.nextInt(maxGroup)
      // label, weight, group, c1, c2, c3
      (random.nextInt(numClass), group.toFloat, group,
        random.nextGaussian().toFloat,
        random.nextGaussian().toFloat,
        random.nextGaussian().toFloat)
    }
  }
}

object Classification extends TrainTestData {
  val train = generateClassificationDataset(300, 2, 3)
  val test = generateClassificationDataset(150, 2, 5)
}

object MultiClassification extends TrainTestData {
  val train = generateClassificationDataset(300, 4, 11)
  val test = generateClassificationDataset(150, 4, 12)
}

object Regression extends TrainTestData {
  val train = generateRegressionDataset(300, 222)
  val test = generateRegressionDataset(150, 223)
}

object Ranking extends TrainTestData {
  val train = generateRankDataset(300, 10, 555)
  val test = generateRankDataset(150, 10, 556)
}
@ -1,602 +0,0 @@
|
|||||||
/*
|
|
||||||
Copyright (c) 2021-2023 by Contributors
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala.spark
|
|
||||||
|
|
||||||
import java.nio.file.Files
|
|
||||||
import java.util.ServiceLoader
|
|
||||||
|
|
||||||
import scala.collection.JavaConverters._
|
|
||||||
import scala.collection.{AbstractIterator, Iterator, mutable}
|
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
|
|
||||||
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils.PackedParams
|
|
||||||
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon
|
|
||||||
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils
|
|
||||||
|
|
||||||
import org.apache.spark.rdd.RDD
|
|
||||||
import org.apache.spark.sql.{DataFrame, Dataset, Row}
|
|
||||||
import org.apache.spark.sql.functions.{col, lit}
|
|
||||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
|
||||||
import org.apache.commons.logging.LogFactory
|
|
||||||
|
|
||||||
import org.apache.spark.TaskContext
|
|
||||||
import org.apache.spark.ml.{Estimator, Model}
|
|
||||||
import org.apache.spark.ml.linalg.Vector
|
|
||||||
import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType}
|
|
||||||
import org.apache.spark.storage.StorageLevel
|
|
||||||
|
|
||||||
/**
|
|
||||||
* PreXGBoost serves preparing data before training and transform
|
|
||||||
*/
|
|
||||||
object PreXGBoost extends PreXGBoostProvider {
|
|
||||||
|
|
||||||
private val logger = LogFactory.getLog("XGBoostSpark")
|
|
||||||
|
|
||||||
private lazy val defaultBaseMarginColumn = lit(Float.NaN)
|
|
||||||
private lazy val defaultWeightColumn = lit(1.0)
|
|
||||||
private lazy val defaultGroupColumn = lit(-1)
|
|
||||||
|
|
||||||
// Find the correct PreXGBoostProvider by ServiceLoader
|
|
||||||
private val optionProvider: Option[PreXGBoostProvider] = {
|
|
||||||
val classLoader = Option(Thread.currentThread().getContextClassLoader)
|
|
||||||
.getOrElse(getClass.getClassLoader)
|
|
||||||
|
|
||||||
val serviceLoader = ServiceLoader.load(classOf[PreXGBoostProvider], classLoader)
|
|
||||||
|
|
||||||
// For now, we only trust GpuPreXGBoost.
|
|
||||||
serviceLoader.asScala.filter(x => x.getClass.getName.equals(
|
|
||||||
"ml.dmlc.xgboost4j.scala.rapids.spark.GpuPreXGBoost")).toList match {
|
|
||||||
case Nil => None
|
|
||||||
case head::Nil =>
|
|
||||||
Some(head)
|
|
||||||
case _ => None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Transform schema
|
|
||||||
*
|
|
||||||
* @param xgboostEstimator supporting XGBoostClassifier/XGBoostClassificationModel and
|
|
||||||
* XGBoostRegressor/XGBoostRegressionModel
|
|
||||||
* @param schema the input schema
|
|
||||||
* @return the transformed schema
|
|
||||||
*/
|
|
||||||
override def transformSchema(
|
|
||||||
xgboostEstimator: XGBoostEstimatorCommon,
|
|
||||||
schema: StructType): StructType = {
|
|
||||||
|
|
||||||
if (optionProvider.isDefined && optionProvider.get.providerEnabled(None)) {
|
|
||||||
return optionProvider.get.transformSchema(xgboostEstimator, schema)
|
|
||||||
}
|
|
||||||
|
|
||||||
xgboostEstimator match {
|
|
||||||
case est: XGBoostClassifier => est.transformSchemaInternal(schema)
|
|
||||||
case model: XGBoostClassificationModel => model.transformSchemaInternal(schema)
|
|
||||||
case reg: XGBoostRegressor => reg.transformSchemaInternal(schema)
|
|
||||||
case model: XGBoostRegressionModel => model.transformSchemaInternal(schema)
|
|
||||||
case _ => throw new RuntimeException("Unsupporting " + xgboostEstimator)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Convert the Dataset[_] to RDD[() => Watches] which will be fed to XGBoost
|
|
||||||
*
|
|
||||||
* @param estimator supports XGBoostClassifier and XGBoostRegressor
|
|
||||||
* @param dataset the training data
|
|
||||||
* @param params all user defined and defaulted params
|
|
||||||
* @return [[XGBoostExecutionParams]] => (RDD[[() => Watches]], Option[ RDD[_] ])
|
|
||||||
* RDD[() => Watches] will be used as the training input
|
|
||||||
* Option[RDD[_]\] is the optional cached RDD
|
|
||||||
*/
|
|
||||||
override def buildDatasetToRDD(
|
|
||||||
estimator: Estimator[_],
|
|
||||||
dataset: Dataset[_],
|
|
||||||
params: Map[String, Any]): XGBoostExecutionParams =>
|
|
||||||
(RDD[() => Watches], Option[RDD[_]]) = {
|
|
||||||
|
|
||||||
if (optionProvider.isDefined && optionProvider.get.providerEnabled(Some(dataset))) {
|
|
||||||
return optionProvider.get.buildDatasetToRDD(estimator, dataset, params)
|
|
||||||
}
|
|
||||||
|
|
||||||
val (packedParams, evalSet, xgbInput) = estimator match {
|
|
||||||
case est: XGBoostEstimatorCommon =>
|
|
||||||
// get weight column, if weight is not defined, default to lit(1.0)
|
|
||||||
val weight = if (!est.isDefined(est.weightCol) || est.getWeightCol.isEmpty) {
|
|
||||||
defaultWeightColumn
|
|
||||||
} else col(est.getWeightCol)
|
|
||||||
|
|
||||||
// get base-margin column, if base-margin is not defined, default to lit(Float.NaN)
|
|
||||||
val baseMargin = if (!est.isDefined(est.baseMarginCol) || est.getBaseMarginCol.isEmpty) {
|
|
||||||
defaultBaseMarginColumn
|
|
||||||
} else col(est.getBaseMarginCol)
|
|
||||||
|
|
||||||
val group = est match {
|
|
||||||
case regressor: XGBoostRegressor =>
|
|
||||||
// get group column, if group is not defined, default to lit(-1)
|
|
||||||
Some(
|
|
||||||
if (!regressor.isDefined(regressor.groupCol) || regressor.getGroupCol.isEmpty) {
|
|
||||||
defaultGroupColumn
|
|
||||||
} else col(regressor.getGroupCol)
|
|
||||||
)
|
|
||||||
case _ => None
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
val (xgbInput, featuresName) = est.vectorize(dataset)
|
|
||||||
|
|
||||||
val evalSets = est.getEvalSets(params).transform((_, df) => {
|
|
||||||
val (dfTransformed, _) = est.vectorize(df)
|
|
||||||
dfTransformed
|
|
||||||
})
|
|
||||||
|
|
||||||
(PackedParams(col(est.getLabelCol), col(featuresName), weight, baseMargin, group,
|
|
||||||
est.getNumWorkers, est.needDeterministicRepartitioning), evalSets, xgbInput)
|
|
||||||
|
|
||||||
case _ => throw new RuntimeException("Unsupporting " + estimator)
|
|
||||||
}
|
|
||||||
|
|
||||||
// transform the training Dataset[_] to RDD[XGBLabeledPoint]
|
|
||||||
val trainingSet: RDD[XGBLabeledPoint] = DataUtils.convertDataFrameToXGBLabeledPointRDDs(
|
|
||||||
packedParams, xgbInput.asInstanceOf[DataFrame]).head
|
|
||||||
|
|
||||||
// transform the eval Dataset[_] to RDD[XGBLabeledPoint]
|
|
||||||
val evalRDDMap = evalSet.map {
|
|
||||||
case (name, dataFrame) => (name,
|
|
||||||
DataUtils.convertDataFrameToXGBLabeledPointRDDs(packedParams,
|
|
||||||
dataFrame.asInstanceOf[DataFrame]).head)
|
|
||||||
}
|
|
||||||
|
|
||||||
val hasGroup = packedParams.group.map(_ != defaultGroupColumn).getOrElse(false)
|
|
||||||
|
|
||||||
xgbExecParams: XGBoostExecutionParams =>
|
|
||||||
composeInputData(trainingSet, hasGroup, packedParams.numWorkers) match {
|
|
||||||
case Left(trainingData) =>
|
|
||||||
val cachedRDD = if (xgbExecParams.cacheTrainingSet) {
|
|
||||||
Some(trainingData.persist(StorageLevel.MEMORY_AND_DISK))
|
|
||||||
} else None
|
|
||||||
(trainForRanking(trainingData, xgbExecParams, evalRDDMap), cachedRDD)
|
|
||||||
case Right(trainingData) =>
|
|
||||||
val cachedRDD = if (xgbExecParams.cacheTrainingSet) {
|
|
||||||
Some(trainingData.persist(StorageLevel.MEMORY_AND_DISK))
|
|
||||||
} else None
|
|
||||||
(trainForNonRanking(trainingData, xgbExecParams, evalRDDMap), cachedRDD)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Transform Dataset
|
|
||||||
*
|
|
||||||
* @param model supporting [[XGBoostClassificationModel]] and [[XGBoostRegressionModel]]
|
|
||||||
* @param dataset the input Dataset to transform
|
|
||||||
* @return the transformed DataFrame
|
|
||||||
*/
|
|
||||||
override def transformDataset(model: Model[_], dataset: Dataset[_]): DataFrame = {
|
|
||||||
|
|
||||||
if (optionProvider.isDefined && optionProvider.get.providerEnabled(Some(dataset))) {
|
|
||||||
return optionProvider.get.transformDataset(model, dataset)
|
|
||||||
}
|
|
||||||
|
|
||||||
/** get the necessary parameters */
|
|
||||||
val (booster, inferBatchSize, xgbInput, featuresCol, useExternalMemory, missing,
|
|
||||||
allowNonZeroForMissing, predictFunc, schema) =
|
|
||||||
model match {
|
|
||||||
case m: XGBoostClassificationModel =>
|
|
||||||
val (xgbInput, featuresName) = m.vectorize(dataset)
|
|
||||||
// predict and turn to Row
|
|
||||||
val predictFunc =
|
|
||||||
(booster: Booster, dm: DMatrix, originalRowItr: Iterator[Row]) => {
|
|
||||||
val Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr) =
|
|
||||||
m.producePredictionItrs(booster, dm)
|
|
||||||
m.produceResultIterator(originalRowItr, rawPredictionItr, probabilityItr,
|
|
||||||
predLeafItr, predContribItr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// prepare the final Schema
|
|
||||||
var schema = StructType(xgbInput.schema.fields ++
|
|
||||||
Seq(StructField(name = XGBoostClassificationModel._rawPredictionCol, dataType =
|
|
||||||
ArrayType(FloatType, containsNull = false), nullable = false)) ++
|
|
||||||
Seq(StructField(name = XGBoostClassificationModel._probabilityCol, dataType =
|
|
||||||
ArrayType(FloatType, containsNull = false), nullable = false)))
|
|
||||||
|
|
||||||
if (m.isDefined(m.leafPredictionCol)) {
|
|
||||||
schema = schema.add(StructField(name = m.getLeafPredictionCol, dataType =
|
|
||||||
ArrayType(FloatType, containsNull = false), nullable = false))
|
|
||||||
}
|
|
||||||
if (m.isDefined(m.contribPredictionCol)) {
|
|
||||||
schema = schema.add(StructField(name = m.getContribPredictionCol, dataType =
|
|
||||||
ArrayType(FloatType, containsNull = false), nullable = false))
|
|
||||||
}
|
|
||||||
|
|
||||||
(m._booster, m.getInferBatchSize, xgbInput, featuresName, m.getUseExternalMemory,
|
|
||||||
m.getMissing, m.getAllowNonZeroForMissingValue, predictFunc, schema)
|
|
||||||
|
|
||||||
case m: XGBoostRegressionModel =>
|
|
||||||
// predict and turn to Row
|
|
||||||
val (xgbInput, featuresName) = m.vectorize(dataset)
|
|
||||||
val predictFunc =
|
|
||||||
(booster: Booster, dm: DMatrix, originalRowItr: Iterator[Row]) => {
|
|
||||||
val Array(rawPredictionItr, predLeafItr, predContribItr) =
|
|
||||||
m.producePredictionItrs(booster, dm)
|
|
||||||
m.produceResultIterator(originalRowItr, rawPredictionItr, predLeafItr, predContribItr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// prepare the final Schema
|
|
||||||
var schema = StructType(xgbInput.schema.fields ++
|
|
||||||
Seq(StructField(name = XGBoostRegressionModel._originalPredictionCol, dataType =
|
|
||||||
ArrayType(FloatType, containsNull = false), nullable = false)))
|
|
||||||
|
|
||||||
if (m.isDefined(m.leafPredictionCol)) {
|
|
||||||
schema = schema.add(StructField(name = m.getLeafPredictionCol, dataType =
|
|
||||||
ArrayType(FloatType, containsNull = false), nullable = false))
|
|
||||||
}
|
|
||||||
if (m.isDefined(m.contribPredictionCol)) {
|
|
||||||
schema = schema.add(StructField(name = m.getContribPredictionCol, dataType =
|
|
||||||
ArrayType(FloatType, containsNull = false), nullable = false))
|
|
||||||
}
|
|
||||||
|
|
||||||
(m._booster, m.getInferBatchSize, xgbInput, featuresName, m.getUseExternalMemory,
|
|
||||||
m.getMissing, m.getAllowNonZeroForMissingValue, predictFunc, schema)
|
|
||||||
}
|
|
||||||
|
|
||||||
val bBooster = xgbInput.sparkSession.sparkContext.broadcast(booster)
|
|
||||||
val appName = xgbInput.sparkSession.sparkContext.appName
|
|
||||||
|
|
||||||
val resultRDD = xgbInput.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
|
|
||||||
new AbstractIterator[Row] {
|
|
||||||
private var batchCnt = 0
|
|
||||||
|
|
||||||
private val batchIterImpl = rowIterator.grouped(inferBatchSize).flatMap { batchRow =>
|
|
||||||
val features = batchRow.iterator.map(row => row.getAs[Vector](featuresCol))
|
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
|
|
||||||
val cacheInfo = {
|
|
||||||
if (useExternalMemory) {
|
|
||||||
s"$appName-${TaskContext.get().stageId()}-dtest_cache-" +
|
|
||||||
s"${TaskContext.getPartitionId()}-batch-$batchCnt"
|
|
||||||
} else {
|
|
||||||
null
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
val dm = new DMatrix(
|
|
||||||
processMissingValues(features.map(_.asXGB), missing, allowNonZeroForMissing),
|
|
||||||
cacheInfo)
|
|
||||||
|
|
||||||
try {
|
|
||||||
predictFunc(bBooster.value, dm, batchRow.iterator)
|
|
||||||
} finally {
|
|
||||||
batchCnt += 1
|
|
||||||
dm.delete()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
override def hasNext: Boolean = batchIterImpl.hasNext
|
|
||||||
|
|
||||||
override def next(): Row = batchIterImpl.next()
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bBooster.unpersist(blocking = false)
|
|
||||||
xgbInput.sparkSession.createDataFrame(resultRDD, schema)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Converting the RDD[XGBLabeledPoint] to the function to build RDD[() => Watches]
|
|
||||||
*
|
|
||||||
* @param trainingSet the input training RDD[XGBLabeledPoint]
|
|
||||||
* @param evalRDDMap the eval set
|
|
||||||
* @param hasGroup if has group
|
|
||||||
* @return function to build (RDD[() => Watches], the cached RDD)
|
|
||||||
*/
|
|
||||||
private[spark] def buildRDDLabeledPointToRDDWatches(
|
|
||||||
trainingSet: RDD[XGBLabeledPoint],
|
|
||||||
evalRDDMap: Map[String, RDD[XGBLabeledPoint]] = Map(),
|
|
||||||
hasGroup: Boolean = false):
|
|
||||||
XGBoostExecutionParams => (RDD[() => Watches], Option[RDD[_]]) = {
|
|
||||||
|
|
||||||
xgbExecParams: XGBoostExecutionParams =>
|
|
||||||
composeInputData(trainingSet, hasGroup, xgbExecParams.numWorkers) match {
|
|
||||||
case Left(trainingData) =>
|
|
||||||
val cachedRDD = if (xgbExecParams.cacheTrainingSet) {
|
|
||||||
Some(trainingData.persist(StorageLevel.MEMORY_AND_DISK))
|
|
||||||
} else None
|
|
||||||
(trainForRanking(trainingData, xgbExecParams, evalRDDMap), cachedRDD)
|
|
||||||
case Right(trainingData) =>
|
|
||||||
val cachedRDD = if (xgbExecParams.cacheTrainingSet) {
|
|
||||||
Some(trainingData.persist(StorageLevel.MEMORY_AND_DISK))
|
|
||||||
} else None
|
|
||||||
(trainForNonRanking(trainingData, xgbExecParams, evalRDDMap), cachedRDD)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Transform RDD according to group column
|
|
||||||
*
|
|
||||||
* @param trainingData the input XGBLabeledPoint RDD
|
|
||||||
* @param hasGroup if has group column
|
|
||||||
* @param nWorkers total xgboost number workers to run xgboost tasks
|
|
||||||
* @return Either: the left is RDD with group, and the right is RDD without group
|
|
||||||
*/
|
|
||||||
private def composeInputData(
|
|
||||||
trainingData: RDD[XGBLabeledPoint],
|
|
||||||
hasGroup: Boolean,
|
|
||||||
nWorkers: Int): Either[RDD[Array[XGBLabeledPoint]], RDD[XGBLabeledPoint]] = {
|
|
||||||
if (hasGroup) {
|
|
||||||
Left(repartitionForTrainingGroup(trainingData, nWorkers))
|
|
||||||
} else {
|
|
||||||
Right(trainingData)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Repartition trainingData with group directly may cause data chaos, since the same group data
|
|
||||||
* may be split into different partitions.
|
|
||||||
*
|
|
||||||
* The first step is to aggregate the same group into same partition
|
|
||||||
* The second step is to repartition to nWorkers
|
|
||||||
*
|
|
||||||
* TODO, Could we repartition trainingData on group?
|
|
||||||
*/
|
|
||||||
private[spark] def repartitionForTrainingGroup(trainingData: RDD[XGBLabeledPoint],
|
|
||||||
nWorkers: Int): RDD[Array[XGBLabeledPoint]] = {
|
|
||||||
val allGroups = aggByGroupInfo(trainingData)
|
|
||||||
logger.info(s"repartitioning training group set to $nWorkers partitions")
|
|
||||||
allGroups.repartition(nWorkers)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Build RDD[() => Watches] for Ranking
|
|
||||||
* @param trainingData the training data RDD
|
|
||||||
* @param xgbExecutionParams xgboost execution params
|
|
||||||
* @param evalSetsMap the eval RDD
|
|
||||||
* @return RDD[() => Watches]
|
|
||||||
*/
|
|
||||||
private def trainForRanking(
|
|
||||||
trainingData: RDD[Array[XGBLabeledPoint]],
|
|
||||||
xgbExecutionParam: XGBoostExecutionParams,
|
|
||||||
evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[() => Watches] = {
|
|
||||||
if (evalSetsMap.isEmpty) {
|
|
||||||
trainingData.mapPartitions(labeledPointGroups => {
|
|
||||||
val buildWatches = () => Watches.buildWatchesWithGroup(xgbExecutionParam,
|
|
||||||
DataUtils.processMissingValuesWithGroup(labeledPointGroups, xgbExecutionParam.missing,
|
|
||||||
xgbExecutionParam.allowNonZeroForMissing),
|
|
||||||
getCacheDirName(xgbExecutionParam.useExternalMemory))
|
|
||||||
Iterator.single(buildWatches)
|
|
||||||
}).cache()
|
|
||||||
} else {
|
|
||||||
coPartitionGroupSets(trainingData, evalSetsMap, xgbExecutionParam.numWorkers).mapPartitions(
|
|
||||||
labeledPointGroupSets => {
|
|
||||||
val buildWatches = () => Watches.buildWatchesWithGroup(
|
|
||||||
labeledPointGroupSets.map {
|
|
||||||
case (name, iter) => (name, DataUtils.processMissingValuesWithGroup(iter,
|
|
||||||
xgbExecutionParam.missing, xgbExecutionParam.allowNonZeroForMissing))
|
|
||||||
},
|
|
||||||
getCacheDirName(xgbExecutionParam.useExternalMemory))
|
|
||||||
Iterator.single(buildWatches)
|
|
||||||
}).cache()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private def coPartitionGroupSets(
|
|
||||||
aggedTrainingSet: RDD[Array[XGBLabeledPoint]],
|
|
||||||
evalSets: Map[String, RDD[XGBLabeledPoint]],
|
|
||||||
nWorkers: Int): RDD[(String, Iterator[Array[XGBLabeledPoint]])] = {
|
|
||||||
val repartitionedDatasets = Map("train" -> aggedTrainingSet) ++ evalSets.map {
|
|
||||||
case (name, rdd) => {
|
|
||||||
val aggedRdd = aggByGroupInfo(rdd)
|
|
||||||
if (aggedRdd.getNumPartitions != nWorkers) {
|
|
||||||
name -> aggedRdd.repartition(nWorkers)
|
|
||||||
} else {
|
|
||||||
name -> aggedRdd
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
repartitionedDatasets.foldLeft(aggedTrainingSet.sparkContext.parallelize(
|
|
||||||
Array.fill[(String, Iterator[Array[XGBLabeledPoint]])](nWorkers)(null), nWorkers)) {
|
|
||||||
case (rddOfIterWrapper, (name, rddOfIter)) =>
|
|
||||||
rddOfIterWrapper.zipPartitions(rddOfIter) {
|
|
||||||
(itrWrapper, itr) =>
|
|
||||||
if (!itr.hasNext) {
|
|
||||||
logger.error("when specifying eval sets as dataframes, you have to ensure that " +
|
|
||||||
"the number of elements in each dataframe is larger than the number of workers")
|
|
||||||
throw new Exception("too few elements in evaluation sets")
|
|
||||||
}
|
|
||||||
val itrArray = itrWrapper.toArray
|
|
||||||
if (itrArray.head != null) {
|
|
||||||
new IteratorWrapper(itrArray :+ (name -> itr))
|
|
||||||
} else {
|
|
||||||
new IteratorWrapper(Array(name -> itr))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private def aggByGroupInfo(trainingData: RDD[XGBLabeledPoint]) = {
|
|
||||||
val normalGroups: RDD[Array[XGBLabeledPoint]] = trainingData.mapPartitions(
|
|
||||||
// LabeledPointGroupIterator returns (Boolean, Array[XGBLabeledPoint])
|
|
||||||
new LabeledPointGroupIterator(_)).filter(!_.isEdgeGroup).map(_.points)
|
|
||||||
|
|
||||||
// edge groups with partition id.
|
|
||||||
val edgeGroups: RDD[(Int, XGBLabeledPointGroup)] = trainingData.mapPartitions(
|
|
||||||
new LabeledPointGroupIterator(_)).filter(_.isEdgeGroup).map(
|
|
||||||
group => (TaskContext.getPartitionId(), group))
|
|
||||||
|
|
||||||
// group chunks from different partitions together by group id in XGBLabeledPoint.
|
|
||||||
// use groupBy instead of aggregateBy since all groups within a partition have unique group ids.
|
|
||||||
val stitchedGroups: RDD[Array[XGBLabeledPoint]] = edgeGroups.groupBy(_._2.groupId).map(
|
|
||||||
groups => {
|
|
||||||
val it: Iterable[(Int, XGBLabeledPointGroup)] = groups._2
|
|
||||||
// sorted by partition id and merge list of Array[XGBLabeledPoint] into one array
|
|
||||||
it.toArray.sortBy(_._1).flatMap(_._2.points)
|
|
||||||
})
|
|
||||||
normalGroups.union(stitchedGroups)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Build RDD[() => Watches] for Non-Ranking
|
|
||||||
* @param trainingData the training data RDD
|
|
||||||
* @param xgbExecutionParams xgboost execution params
|
|
||||||
* @param evalSetsMap the eval RDD
|
|
||||||
* @return RDD[() => Watches]
|
|
||||||
*/
|
|
||||||
private def trainForNonRanking(
|
|
||||||
trainingData: RDD[XGBLabeledPoint],
|
|
||||||
xgbExecutionParams: XGBoostExecutionParams,
|
|
||||||
evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[() => Watches] = {
|
|
||||||
if (evalSetsMap.isEmpty) {
|
|
||||||
trainingData.mapPartitions { labeledPoints => {
|
|
||||||
val buildWatches = () => Watches.buildWatches(xgbExecutionParams,
|
|
||||||
DataUtils.processMissingValues(labeledPoints, xgbExecutionParams.missing,
|
|
||||||
xgbExecutionParams.allowNonZeroForMissing),
|
|
||||||
getCacheDirName(xgbExecutionParams.useExternalMemory))
|
|
||||||
Iterator.single(buildWatches)
|
|
||||||
}}.cache()
|
|
||||||
} else {
|
|
||||||
coPartitionNoGroupSets(trainingData, evalSetsMap, xgbExecutionParams.numWorkers).
|
|
||||||
mapPartitions {
|
|
||||||
nameAndLabeledPointSets =>
|
|
||||||
val buildWatches = () => Watches.buildWatches(
|
|
||||||
nameAndLabeledPointSets.map {
|
|
||||||
case (name, iter) => (name, DataUtils.processMissingValues(iter,
|
|
||||||
xgbExecutionParams.missing, xgbExecutionParams.allowNonZeroForMissing))
|
|
||||||
},
|
|
||||||
getCacheDirName(xgbExecutionParams.useExternalMemory))
|
|
||||||
Iterator.single(buildWatches)
|
|
||||||
}.cache()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private def coPartitionNoGroupSets(
|
|
||||||
trainingData: RDD[XGBLabeledPoint],
|
|
||||||
evalSets: Map[String, RDD[XGBLabeledPoint]],
|
|
||||||
nWorkers: Int) = {
|
|
||||||
// eval_sets is supposed to be set by the caller of [[trainDistributed]]
|
|
||||||
val allDatasets = Map("train" -> trainingData) ++ evalSets
|
|
||||||
val repartitionedDatasets = allDatasets.map { case (name, rdd) =>
|
|
||||||
if (rdd.getNumPartitions != nWorkers) {
|
|
||||||
-            (name, rdd.repartition(nWorkers))
-          } else {
-            (name, rdd)
-          }
-      }
-    repartitionedDatasets.foldLeft(trainingData.sparkContext.parallelize(
-      Array.fill[(String, Iterator[XGBLabeledPoint])](nWorkers)(null), nWorkers)) {
-      case (rddOfIterWrapper, (name, rddOfIter)) =>
-        rddOfIterWrapper.zipPartitions(rddOfIter) {
-          (itrWrapper, itr) =>
-            if (!itr.hasNext) {
-              logger.error("when specifying eval sets as dataframes, you have to ensure that " +
-                "the number of elements in each dataframe is larger than the number of workers")
-              throw new Exception("too few elements in evaluation sets")
-            }
-            val itrArray = itrWrapper.toArray
-            if (itrArray.head != null) {
-              new IteratorWrapper(itrArray :+ (name -> itr))
-            } else {
-              new IteratorWrapper(Array(name -> itr))
-            }
-        }
-    }
-  }
-
-  private[scala] def getCacheDirName(useExternalMemory: Boolean): Option[String] = {
-    val taskId = TaskContext.getPartitionId().toString
-    if (useExternalMemory) {
-      val dir = Files.createTempDirectory(s"${TaskContext.get().stageId()}-cache-$taskId")
-      Some(dir.toAbsolutePath.toString)
-    } else {
-      None
-    }
-  }
-}
-
-class IteratorWrapper[T](arrayOfXGBLabeledPoints: Array[(String, Iterator[T])])
-  extends Iterator[(String, Iterator[T])] {
-
-  private var currentIndex = 0
-
-  override def hasNext: Boolean = currentIndex <= arrayOfXGBLabeledPoints.length - 1
-
-  override def next(): (String, Iterator[T]) = {
-    currentIndex += 1
-    arrayOfXGBLabeledPoints(currentIndex - 1)
-  }
-}
-
-/**
- * Training data group in a RDD partition.
- *
- * @param groupId The group id
- * @param points Array of XGBLabeledPoint within the same group.
- * @param isEdgeGroup whether it is a first or last group in a RDD partition.
- */
-private[spark] case class XGBLabeledPointGroup(
-    groupId: Int,
-    points: Array[XGBLabeledPoint],
-    isEdgeGroup: Boolean)
-
-/**
- * Within each RDD partition, group the <code>XGBLabeledPoint</code> by group id.</p>
- * And the first and the last groups may not have all the items due to the data partition.
- * <code>LabeledPointGroupIterator</code> organizes data in a tuple format:
- * (isFistGroup || isLastGroup, Array[XGBLabeledPoint]).</p>
- * The edge groups across partitions can be stitched together later.
- * @param base collection of <code>XGBLabeledPoint</code>
- */
-private[spark] class LabeledPointGroupIterator(base: Iterator[XGBLabeledPoint])
-  extends AbstractIterator[XGBLabeledPointGroup] {
-
-  private var firstPointOfNextGroup: XGBLabeledPoint = null
-  private var isNewGroup = false
-
-  override def hasNext: Boolean = {
-    base.hasNext || isNewGroup
-  }
-
-  override def next(): XGBLabeledPointGroup = {
-    val builder = mutable.ArrayBuilder.make[XGBLabeledPoint]
-    var isFirstGroup = true
-    if (firstPointOfNextGroup != null) {
-      builder += firstPointOfNextGroup
-      isFirstGroup = false
-    }
-
-    isNewGroup = false
-    while (!isNewGroup && base.hasNext) {
-      val point = base.next()
-      val groupId = if (firstPointOfNextGroup != null) firstPointOfNextGroup.group else point.group
-      firstPointOfNextGroup = point
-      if (point.group == groupId) {
-        // add to current group
-        builder += point
-      } else {
-        // start a new group
-        isNewGroup = true
-      }
-    }
-
-    val isLastGroup = !isNewGroup
-    val result = builder.result()
-    val group = XGBLabeledPointGroup(result(0).group, result, isFirstGroup || isLastGroup)
-
-    group
-  }
-}
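// Illustrative sketch, not part of this commit: the removed LabeledPointGroupIterator splits a
// partition-local, group-sorted stream of points into runs that share a group id, flagging the
// first and last run so edge groups can be stitched with neighbouring partitions later. The same
// idea on plain (groupId, value) pairs, assuming the input is already sorted by group id:
object GroupRunsSketch {
  def groupRuns[T](sorted: Seq[(Int, T)]): Seq[(Int, Seq[T])] =
    if (sorted.isEmpty) Seq.empty
    else {
      val (gid, _) = sorted.head
      // take the leading run with the same group id, then recurse on the remainder
      val (run, rest) = sorted.span(_._1 == gid)
      (gid, run.map(_._2)) +: groupRuns(rest)
    }

  def main(args: Array[String]): Unit = {
    // groups 0 and 2 would be "edge groups" if this were one slice of a larger, partitioned RDD
    println(groupRuns(Seq(0 -> "a", 0 -> "b", 1 -> "c", 2 -> "d", 2 -> "e")))
    // List((0,List(a, b)), (1,List(c)), (2,List(d, e)))
  }
}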
@@ -1,72 +0,0 @@
-/*
- Copyright (c) 2021-2022 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-
-package ml.dmlc.xgboost4j.scala.spark
-
-import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon
-
-import org.apache.spark.ml.{Estimator, Model, PipelineStage}
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.{DataFrame, Dataset}
-
-/**
- * PreXGBoost implementation provider
- */
-private[scala] trait PreXGBoostProvider {
-
-  /**
-   * Whether the provider is enabled or not
-   * @param dataset the input dataset
-   * @return Boolean
-   */
-  def providerEnabled(dataset: Option[Dataset[_]]): Boolean = false
-
-  /**
-   * Transform schema
-   * @param xgboostEstimator supporting XGBoostClassifier/XGBoostClassificationModel and
-   *                         XGBoostRegressor/XGBoostRegressionModel
-   * @param schema the input schema
-   * @return the transformed schema
-   */
-  def transformSchema(xgboostEstimator: XGBoostEstimatorCommon, schema: StructType): StructType
-
-  /**
-   * Convert the Dataset[_] to RDD[() => Watches] which will be fed to XGBoost
-   *
-   * @param estimator supports XGBoostClassifier and XGBoostRegressor
-   * @param dataset the training data
-   * @param params all user defined and defaulted params
-   * @return [[XGBoostExecutionParams]] => (RDD[[() => Watches]], Option[ RDD[_] ])
-   *         RDD[() => Watches] will be used as the training input to build DMatrix
-   *         Option[ RDD[_] ] is the optional cached RDD
-   */
-  def buildDatasetToRDD(
-      estimator: Estimator[_],
-      dataset: Dataset[_],
-      params: Map[String, Any]):
-    XGBoostExecutionParams => (RDD[() => Watches], Option[RDD[_]])
-
-  /**
-   * Transform Dataset
-   *
-   * @param model supporting [[XGBoostClassificationModel]] and [[XGBoostRegressionModel]]
-   * @param dataset the input Dataset to transform
-   * @return the transformed DataFrame
-   */
-  def transformDataset(model: Model[_], dataset: Dataset[_]): DataFrame
-
-}
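// Illustrative sketch, not part of this commit: the PreXGBoostProvider trait deleted above acted
// as an optional plugin seam -- the default path is used unless a provider reports itself enabled.
// A generic, self-contained version of that pattern with hypothetical names (the real plugin class
// names and loading mechanism in this repo may differ), loading the implementation reflectively so
// the core module does not depend on the plugin at compile time:
trait DataPrepProvider {
  def providerEnabled: Boolean = false
  def prepare(input: Seq[Double]): Seq[Double] = input
}

object DataPrepProvider {
  // className is a stand-in for a plugin implementation that may or may not be on the classpath
  def load(className: String): DataPrepProvider =
    try {
      Class.forName(className).getDeclaredConstructor().newInstance() match {
        case p: DataPrepProvider if p.providerEnabled => p
        case _ => new DataPrepProvider {} // fall back to the default path
      }
    } catch {
      case _: ReflectiveOperationException => new DataPrepProvider {} // plugin not available
    }
}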
@@ -1,5 +1,5 @@
 /*
- Copyright (c) 2014-2022 by Contributors
+ Copyright (c) 2014-2024 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -14,12 +14,49 @@
 limitations under the License.
 */

-package ml.dmlc.xgboost4j.scala.spark.util
+package ml.dmlc.xgboost4j.scala.spark

+import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
+import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.json4s.{DefaultFormats, FullTypeHints, JField, JValue, NoTypeHints, TypeHints}

-// based on org.apache.spark.util copy /paste
-object Utils {
+import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
+
+private[scala] object Utils {

+  private[spark] implicit class XGBLabeledPointFeatures(
+      val labeledPoint: XGBLabeledPoint
+  ) extends AnyVal {
+    /** Converts the point to [[MLLabeledPoint]]. */
+    private[spark] def asML: MLLabeledPoint = {
+      MLLabeledPoint(labeledPoint.label, labeledPoint.features)
+    }
+
+    /**
+     * Returns feature of the point as [[org.apache.spark.ml.linalg.Vector]].
+     */
+    def features: Vector = if (labeledPoint.indices == null) {
+      Vectors.dense(labeledPoint.values.map(_.toDouble))
+    } else {
+      Vectors.sparse(labeledPoint.size, labeledPoint.indices, labeledPoint.values.map(_.toDouble))
+    }
+  }
+
+  private[spark] implicit class MLVectorToXGBLabeledPoint(val v: Vector) extends AnyVal {
+    /**
+     * Converts a [[Vector]] to a data point with a dummy label.
+     *
+     * This is needed for constructing a [[ml.dmlc.xgboost4j.scala.DMatrix]]
+     * for prediction.
+     */
+    // TODO support sparsevector
+    def asXGB: XGBLabeledPoint = v match {
+      case v: DenseVector =>
+        XGBLabeledPoint(0.0f, v.size, null, v.values.map(_.toFloat))
+      case v: SparseVector =>
+        XGBLabeledPoint(0.0f, v.size, v.indices, v.toDense.values.map(_.toFloat))
+    }
+  }
+
   def getSparkClassLoader: ClassLoader = getClass.getClassLoader

@@ -27,6 +64,7 @@ object Utils {
     Option(Thread.currentThread().getContextClassLoader).getOrElse(getSparkClassLoader)

   // scalastyle:off classforname

   /** Preferred alternative to Class.forName(className) */
   def classForName(className: String): Class[_] = {
     Class.forName(className, true, getContextOrSparkClassLoader)
@@ -35,6 +73,7 @@ object Utils {

   /**
    * Get the TypeHints according to the value
+   *
    * @param value the instance of class to be serialized
    * @return if value is null,
    *         return NoTypeHints
@@ -53,6 +92,7 @@ object Utils {

   /**
    * Get the TypeHints according to the saved jsonClass field
+   *
    * @param json
    * @return TypeHints
    */
@@ -68,4 +108,17 @@ object Utils {
       FullTypeHints(List(Utils.classForName(className)))
     }.getOrElse(NoTypeHints)
   }

+  val TRAIN_NAME = "train"
+  val VALIDATION_NAME = "eval"
+
+  /** Executes the provided code block and then closes the resource */
+  def withResource[T <: AutoCloseable, V](r: T)(block: T => V): V = {
+    try {
+      block(r)
+    } finally {
+      r.close()
+    }
+  }
 }
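// Illustrative usage sketch, not part of this commit, for the helpers added to Utils above.
// withResource closes the resource even when the block throws; asXGB/features round-trip between
// Spark ml Vectors and the xgboost4j LabeledPoint. Utils is private[scala], so this only compiles
// from inside the ml.dmlc.xgboost4j.scala package tree:
package ml.dmlc.xgboost4j.scala.spark

import java.io.ByteArrayInputStream

import org.apache.spark.ml.linalg.Vectors

import ml.dmlc.xgboost4j.scala.spark.Utils._

object UtilsUsageSketch {
  def main(args: Array[String]): Unit = {
    // the stream is closed after the block, even if in.read() were to throw
    val firstByte = withResource(new ByteArrayInputStream(Array[Byte](7, 8, 9))) { in => in.read() }

    val dense = Vectors.dense(1.0, 0.0, 2.5)
    val point = dense.asXGB    // dummy-labelled XGBLabeledPoint, e.g. for building a DMatrix at prediction time
    val back = point.features  // back to an ml.linalg Vector
    println(s"$firstByte ${back.size}")
  }
}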
@@ -18,227 +18,30 @@ package ml.dmlc.xgboost4j.scala.spark

 import java.io.File

-import scala.collection.mutable
-import scala.util.Random
-import scala.collection.JavaConverters._
-
-import ml.dmlc.xgboost4j.java.{Communicator, ITracker, XGBoostError, RabitTracker}
-import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager
-import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
-import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
 import org.apache.commons.io.FileUtils
 import org.apache.commons.logging.LogFactory
-import org.apache.hadoop.fs.FileSystem
+import org.apache.spark.{SparkConf, SparkContext, TaskContext}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.resource.{ResourceProfileBuilder, TaskResourceRequests}
-import org.apache.spark.{SparkConf, SparkContext, TaskContext}
-import org.apache.spark.sql.SparkSession
+
+import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker}
+import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}

-/**
- * Rabit tracker configurations.
- *
- * @param timeout The number of seconds before timeout waiting for workers to connect. and
- *                for the tracker to shutdown.
- * @param hostIp The Rabit Tracker host IP address.
- *               This is only needed if the host IP cannot be automatically guessed.
- * @param port The port number for the tracker to listen to. Use a system allocated one by
- *             default.
- */
-case class TrackerConf(timeout: Int, hostIp: String = "", port: Int = 0)
-
-object TrackerConf {
-  def apply(): TrackerConf = TrackerConf(0)
-}
-
-private[scala] case class XGBoostExecutionInputParams(trainTestRatio: Double, seed: Long)
-
-private[scala] case class XGBoostExecutionParams(
+private[spark] case class RuntimeParams(
     numWorkers: Int,
     numRounds: Int,
-    useExternalMemory: Boolean,
-    obj: ObjectiveTrait,
-    eval: EvalTrait,
-    missing: Float,
-    allowNonZeroForMissing: Boolean,
     trackerConf: TrackerConf,
-    checkpointParam: Option[ExternalCheckpointParams],
-    xgbInputParams: XGBoostExecutionInputParams,
     earlyStoppingRounds: Int,
-    cacheTrainingSet: Boolean,
-    device: Option[String],
+    device: String,
     isLocal: Boolean,
-    featureNames: Option[Array[String]],
-    featureTypes: Option[Array[String]],
-    runOnGpu: Boolean) {
-
-  private var rawParamMap: Map[String, Any] = _
-
-  def setRawParamMap(inputMap: Map[String, Any]): Unit = {
-    rawParamMap = inputMap
-  }
-
-  def toMap: Map[String, Any] = {
-    rawParamMap
-  }
-}
+    runOnGpu: Boolean,
+    obj: Option[ObjectiveTrait] = None,
+    eval: Option[EvalTrait] = None)

-private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], sc: SparkContext){
-
-  private val logger = LogFactory.getLog("XGBoostSpark")
-
-  private val isLocal = sc.isLocal
-
-  private val overridedParams = overrideParams(rawParams, sc)
-
-  validateSparkSslConf()
-
-  /**
-   * Check to see if Spark expects SSL encryption (`spark.ssl.enabled` set to true).
-   * If so, throw an exception unless this safety measure has been explicitly overridden
-   * via conf `xgboost.spark.ignoreSsl`.
-   */
-  private def validateSparkSslConf(): Unit = {
-    val (sparkSslEnabled: Boolean, xgboostSparkIgnoreSsl: Boolean) =
-      SparkSession.getActiveSession match {
-        case Some(ss) =>
-          (ss.conf.getOption("spark.ssl.enabled").getOrElse("false").toBoolean,
-            ss.conf.getOption("xgboost.spark.ignoreSsl").getOrElse("false").toBoolean)
-        case None =>
-          (sc.getConf.getBoolean("spark.ssl.enabled", false),
-            sc.getConf.getBoolean("xgboost.spark.ignoreSsl", false))
-      }
-    if (sparkSslEnabled) {
-      if (xgboostSparkIgnoreSsl) {
-        logger.warn(s"spark-xgboost is being run without encrypting data in transit! " +
-          s"Spark Conf spark.ssl.enabled=true was overridden with xgboost.spark.ignoreSsl=true.")
-      } else {
-        throw new Exception("xgboost-spark found spark.ssl.enabled=true to encrypt data " +
-          "in transit, but xgboost-spark sends non-encrypted data over the wire for efficiency. " +
-          "To override this protection and still use xgboost-spark at your own risk, " +
-          "you can set the SparkSession conf to use xgboost.spark.ignoreSsl=true.")
-      }
-    }
-  }
-
-  /**
-   * we should not include any nested structure in the output of this function as the map is
-   * eventually to be feed to xgboost4j layer
-   */
-  private def overrideParams(
-      params: Map[String, Any],
-      sc: SparkContext): Map[String, Any] = {
-    val coresPerTask = sc.getConf.getInt("spark.task.cpus", 1)
-    var overridedParams = params
-    if (overridedParams.contains("nthread")) {
-      val nThread = overridedParams("nthread").toString.toInt
-      require(nThread <= coresPerTask,
-        s"the nthread configuration ($nThread) must be no larger than " +
-          s"spark.task.cpus ($coresPerTask)")
-    } else {
-      overridedParams = overridedParams + ("nthread" -> coresPerTask)
-    }
-
-    val numEarlyStoppingRounds = overridedParams.getOrElse(
-      "num_early_stopping_rounds", 0).asInstanceOf[Int]
-    overridedParams += "num_early_stopping_rounds" -> numEarlyStoppingRounds
-    if (numEarlyStoppingRounds > 0 && overridedParams.getOrElse("custom_eval", null) != null) {
-      throw new IllegalArgumentException("custom_eval does not support early stopping")
-    }
-    overridedParams
-  }
-
-  /**
-   * The Map parameters accepted by estimator's constructor may have string type,
-   * Eg, Map("num_workers" -> "6", "num_round" -> 5), we need to convert these
-   * kind of parameters into the correct type in the function.
-   *
-   * @return XGBoostExecutionParams
-   */
-  def buildXGBRuntimeParams: XGBoostExecutionParams = {
-
-    val obj = overridedParams.getOrElse("custom_obj", null).asInstanceOf[ObjectiveTrait]
-    val eval = overridedParams.getOrElse("custom_eval", null).asInstanceOf[EvalTrait]
-    if (obj != null) {
-      require(overridedParams.get("objective_type").isDefined, "parameter \"objective_type\" " +
-        "is not defined, you have to specify the objective type as classification or regression" +
-        " with a customized objective function")
-    }
-
-    var trainTestRatio = 1.0
-    if (overridedParams.contains("train_test_ratio")) {
-      logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" +
-        " pass a training and multiple evaluation datasets by passing 'eval_sets' and " +
-        "'eval_set_names'")
-      trainTestRatio = overridedParams.get("train_test_ratio").get.asInstanceOf[Double]
-    }
-
-    val nWorkers = overridedParams("num_workers").asInstanceOf[Int]
-    val round = overridedParams("num_round").asInstanceOf[Int]
-    val useExternalMemory = overridedParams
-      .getOrElse("use_external_memory", false).asInstanceOf[Boolean]
-
-    val missing = overridedParams.getOrElse("missing", Float.NaN).asInstanceOf[Float]
-    val allowNonZeroForMissing = overridedParams
-      .getOrElse("allow_non_zero_for_missing", false)
-      .asInstanceOf[Boolean]
-
-    val treeMethod: Option[String] = overridedParams.get("tree_method").map(_.toString)
-    val device: Option[String] = overridedParams.get("device").map(_.toString)
-    val deviceIsGpu = device.exists(_ == "cuda")
-
-    require(!(treeMethod.exists(_ == "approx") && deviceIsGpu),
-      "The tree method \"approx\" is not yet supported for Spark GPU cluster")
-
-    // back-compatible with "gpu_hist"
-    val runOnGpu = treeMethod.exists(_ == "gpu_hist") || deviceIsGpu
-
-    val trackerConf = overridedParams.get("tracker_conf") match {
-      case None => TrackerConf()
-      case Some(conf: TrackerConf) => conf
-      case _ => throw new IllegalArgumentException("parameter \"tracker_conf\" must be an " +
-        "instance of TrackerConf.")
-    }
-
-    val checkpointParam = ExternalCheckpointParams.extractParams(overridedParams)
-
-    val seed = overridedParams.getOrElse("seed", System.nanoTime()).asInstanceOf[Long]
-    val inputParams = XGBoostExecutionInputParams(trainTestRatio, seed)
-
-    val earlyStoppingRounds = overridedParams.getOrElse(
-      "num_early_stopping_rounds", 0).asInstanceOf[Int]
-
-    val cacheTrainingSet = overridedParams.getOrElse("cache_training_set", false)
-      .asInstanceOf[Boolean]
-
-    val featureNames = if (overridedParams.contains("feature_names")) {
-      Some(overridedParams("feature_names").asInstanceOf[Array[String]])
-    } else None
-    val featureTypes = if (overridedParams.contains("feature_types")){
-      Some(overridedParams("feature_types").asInstanceOf[Array[String]])
-    } else None
-
-    val xgbExecParam = XGBoostExecutionParams(nWorkers, round, useExternalMemory, obj, eval,
-      missing, allowNonZeroForMissing, trackerConf,
-      checkpointParam,
-      inputParams,
-      earlyStoppingRounds,
-      cacheTrainingSet,
-      device,
-      isLocal,
-      featureNames,
-      featureTypes,
-      runOnGpu
-    )
-    xgbExecParam.setRawParamMap(overridedParams)
-    xgbExecParam
-  }
-}

 /**
  * A trait to manage stage-level scheduling
  */
-private[spark] trait XGBoostStageLevel extends Serializable {
+private[spark] trait StageLevelScheduling extends Serializable {
   private val logger = LogFactory.getLog("XGBoostSpark")

   private[spark] def isStandaloneOrLocalCluster(conf: SparkConf): Boolean = {
@@ -255,8 +58,7 @@ private[spark] trait XGBoostStageLevel extends Serializable {
    * @param conf spark configurations
    * @return Boolean to skip stage-level scheduling or not
    */
-  private[spark] def skipStageLevelScheduling(
-      sparkVersion: String,
+  private[spark] def skipStageLevelScheduling(sparkVersion: String,
       runOnGpu: Boolean,
       conf: SparkConf): Boolean = {
     if (runOnGpu) {
@@ -313,14 +115,13 @@ private[spark] trait XGBoostStageLevel extends Serializable {
    * on a single executor simultaneously.
    *
    * @param sc the spark context
-   * @param rdd which rdd to be applied with new resource profile
-   * @return the original rdd or the changed rdd
+   * @param rdd the rdd to be applied with new resource profile
+   * @return the original rdd or the modified rdd
    */
-  private[spark] def tryStageLevelScheduling(
-      sc: SparkContext,
-      xgbExecParams: XGBoostExecutionParams,
-      rdd: RDD[(Booster, Map[String, Array[Float]])]
-  ): RDD[(Booster, Map[String, Array[Float]])] = {
+  private[spark] def tryStageLevelScheduling[T](sc: SparkContext,
+      xgbExecParams: RuntimeParams,
+      rdd: RDD[T]
+  ): RDD[T] = {

     val conf = sc.getConf
     if (skipStageLevelScheduling(sc.version, xgbExecParams.runOnGpu, conf)) {
@@ -360,7 +161,7 @@ private[spark] trait XGBoostStageLevel extends Serializable {
     }
   }

-object XGBoost extends XGBoostStageLevel {
+private[spark] object XGBoost extends StageLevelScheduling {
   private val logger = LogFactory.getLog("XGBoostSpark")

   def getGPUAddrFromResources: Int = {
@@ -383,172 +184,118 @@ object XGBoost extends XGBoostStageLevel {
     }
   }

-  private def buildWatchesAndCheck(buildWatchesFun: () => Watches): Watches = {
-    val watches = buildWatchesFun()
-    // to workaround the empty partitions in training dataset,
-    // this might not be the best efficient implementation, see
-    // (https://github.com/dmlc/xgboost/issues/1277)
-    if (!watches.toMap.contains("train")) {
-      throw new XGBoostError(
-        s"detected an empty partition in the training data, partition ID:" +
-          s" ${TaskContext.getPartitionId()}")
-    }
-    watches
-  }
-
-  private def buildDistributedBooster(
-      buildWatches: () => Watches,
-      xgbExecutionParam: XGBoostExecutionParams,
-      rabitEnv: java.util.Map[String, Object],
-      obj: ObjectiveTrait,
-      eval: EvalTrait,
-      prevBooster: Booster): Iterator[(Booster, Map[String, Array[Float]])] = {
-
-    var watches: Watches = null
-    val taskId = TaskContext.getPartitionId().toString
-    val attempt = TaskContext.get().attemptNumber.toString
-    rabitEnv.put("DMLC_TASK_ID", taskId)
-    val numRounds = xgbExecutionParam.numRounds
-    val makeCheckpoint = xgbExecutionParam.checkpointParam.isDefined && taskId.toInt == 0
-
-    try {
-      Communicator.init(rabitEnv)
-      watches = buildWatchesAndCheck(buildWatches)
-
-      val numEarlyStoppingRounds = xgbExecutionParam.earlyStoppingRounds
-      val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](numRounds))
-      val externalCheckpointParams = xgbExecutionParam.checkpointParam
-
-      var params = xgbExecutionParam.toMap
-      if (xgbExecutionParam.runOnGpu) {
-        val gpuId = if (xgbExecutionParam.isLocal) {
-          // For local mode, force gpu id to primary device
-          0
-        } else {
-          getGPUAddrFromResources
-        }
-        logger.info("Leveraging gpu device " + gpuId + " to train")
-        params = params + ("device" -> s"cuda:$gpuId")
-      }
-      val booster = if (makeCheckpoint) {
-        SXGBoost.trainAndSaveCheckpoint(
-          watches.toMap("train"), params, numRounds,
-          watches.toMap, metrics, obj, eval,
-          earlyStoppingRound = numEarlyStoppingRounds, prevBooster, externalCheckpointParams)
-      } else {
-        SXGBoost.train(watches.toMap("train"), params, numRounds,
-          watches.toMap, metrics, obj, eval,
-          earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
-      }
-      if (TaskContext.get().partitionId() == 0) {
-        Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
-      } else {
-        Iterator.empty
-      }
-    } catch {
-      case xgbException: XGBoostError =>
-        logger.error(s"XGBooster worker $taskId has failed $attempt times due to ", xgbException)
-        throw xgbException
-    } finally {
-      Communicator.shutdown()
-      if (watches != null) watches.delete()
-    }
-  }
+  /**
+   * Train a XGBoost Boost on the dataset in the Watches
+   *
+   * @param watches holds the dataset to be trained
+   * @param runtimeParams XGBoost runtime parameters
+   * @param xgboostParams XGBoost library paramters
+   * @return a booster and the metrics
+   */
+  private def trainBooster(watches: Watches,
+      runtimeParams: RuntimeParams,
+      xgboostParams: Map[String, Any]
+  ): (Booster, Array[Array[Float]]) = {
+
+    val numEarlyStoppingRounds = runtimeParams.earlyStoppingRounds
+    val metrics = Array.tabulate(watches.size)(_ =>
+      Array.ofDim[Float](runtimeParams.numRounds))
+
+    var params = xgboostParams
+    if (runtimeParams.runOnGpu) {
+      val gpuId = if (runtimeParams.isLocal) {
+        TaskContext.get().partitionId() % runtimeParams.numWorkers
+      } else {
+        getGPUAddrFromResources
+      }
+      logger.info("Leveraging gpu device " + gpuId + " to train")
+      params = params + ("device" -> s"cuda:$gpuId")
+    }
+    val booster = SXGBoost.train(watches.toMap("train"), params, runtimeParams.numRounds,
+      watches.toMap, metrics, runtimeParams.obj.getOrElse(null),
+      runtimeParams.eval.getOrElse(null), earlyStoppingRound = numEarlyStoppingRounds)
+    (booster, metrics)
+  }

-  // Executes the provided code block inside a tracker and then stops the tracker
-  private def withTracker[T](nWorkers: Int, conf: TrackerConf)(block: ITracker => T): T = {
-    val tracker = new RabitTracker(nWorkers, conf.hostIp, conf.port, conf.timeout)
-    require(tracker.start(), "FAULT: Failed to start tracker")
-    try {
-      block(tracker)
-    } finally {
-      tracker.stop()
-    }
-  }
-
-  /**
-   * @return A tuple of the booster and the metrics used to build training summary
-   */
-  @throws(classOf[XGBoostError])
-  private[spark] def trainDistributed(
-      sc: SparkContext,
-      buildTrainingData: XGBoostExecutionParams => (RDD[() => Watches], Option[RDD[_]]),
-      params: Map[String, Any]):
-      (Booster, Map[String, Array[Float]]) = {
-
-    logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}")
-
-    val xgbParamsFactory = new XGBoostExecutionParamsFactory(params, sc)
-    val runtimeParams = xgbParamsFactory.buildXGBRuntimeParams
-
-    val prevBooster = runtimeParams.checkpointParam.map { checkpointParam =>
-      val checkpointManager = new ExternalCheckpointManager(
-        checkpointParam.checkpointPath,
-        FileSystem.get(sc.hadoopConfiguration))
-      checkpointManager.cleanUpHigherVersions(runtimeParams.numRounds)
-      checkpointManager.loadCheckpointAsScalaBooster()
-    }.orNull
-
-    // Get the training data RDD and the cachedRDD
-    val (trainingRDD, optionalCachedRDD) = buildTrainingData(runtimeParams)
-
-    try {
-      val (booster, metrics) = withTracker(
-        runtimeParams.numWorkers,
-        runtimeParams.trackerConf
-      ) { tracker =>
-        val rabitEnv = tracker.getWorkerArgs()
-
-        val boostersAndMetrics = trainingRDD.barrier().mapPartitions { iter =>
-          var optionWatches: Option[() => Watches] = None
-
-          // take the first Watches to train
-          if (iter.hasNext) {
-            optionWatches = Some(iter.next())
-          }
-
-          optionWatches.map { buildWatches =>
-            buildDistributedBooster(buildWatches,
-              runtimeParams, rabitEnv, runtimeParams.obj, runtimeParams.eval, prevBooster)
-          }.getOrElse(throw new RuntimeException("No Watches to train"))
-        }
-
-        val boostersAndMetricsWithRes = tryStageLevelScheduling(sc, runtimeParams,
-          boostersAndMetrics)
-        // The repartition step is to make training stage as ShuffleMapStage, so that when one
-        // of the training task fails the training stage can retry. ResultStage won't retry when
-        // it fails.
-        val (booster, metrics) = boostersAndMetricsWithRes.repartition(1).collect()(0)
-        (booster, metrics)
-      }
-
-      // we should delete the checkpoint directory after a successful training
-      runtimeParams.checkpointParam.foreach {
-        cpParam =>
-          if (!runtimeParams.checkpointParam.get.skipCleanCheckpoint) {
-            val checkpointManager = new ExternalCheckpointManager(
-              cpParam.checkpointPath,
-              FileSystem.get(sc.hadoopConfiguration))
-            checkpointManager.cleanPath()
-          }
-      }
-      (booster, metrics)
-    } catch {
-      case t: Throwable =>
-        // if the job was aborted due to an exception
-        logger.error("the job was aborted due to ", t)
-        throw t
-    } finally {
-      optionalCachedRDD.foreach(_.unpersist())
-    }
-  }
+  /**
+   * Train a XGBoost booster with parameters on the dataset
+   *
+   * @param input the input dataset for training
+   * @param runtimeParams the runtime parameters for jvm
+   * @param xgboostParams the xgboost parameters to pass to xgboost library
+   * @return the booster and the metrics
+   */
+  def train(input: RDD[Watches],
+      runtimeParams: RuntimeParams,
+      xgboostParams: Map[String, Any]): (Booster, Map[String, Array[Float]]) = {
+
+    val sc = input.sparkContext
+    logger.info(s"Running XGBoost ${spark.VERSION} with parameters: $xgboostParams")
+
+    // TODO Rabit tracker exception handling.
+    val trackerConf = runtimeParams.trackerConf
+
+    val tracker = new RabitTracker(runtimeParams.numWorkers,
+      trackerConf.hostIp, trackerConf.port, trackerConf.timeout)
+    require(tracker.start(), "FAULT: Failed to start tracker")
+
+    try {
+      val rabitEnv = tracker.getWorkerArgs()
+
+      val boostersAndMetrics = input.barrier().mapPartitions { iter =>
+        val partitionId = TaskContext.getPartitionId()
+        rabitEnv.put("DMLC_TASK_ID", partitionId.toString)
+        try {
+          Communicator.init(rabitEnv)
+          require(iter.hasNext, "Failed to create DMatrix")
+          val watches = iter.next()
+          try {
+            val (booster, metrics) = trainBooster(watches, runtimeParams, xgboostParams)
+            if (partitionId == 0) {
+              Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
+            } else {
+              Iterator.empty
+            }
+          } finally {
+            if (watches != null) {
+              watches.delete()
+            }
+          }
+        } finally {
+          // If shutdown throws exception, then the real exception for
+          // training will be swallowed,
+          try {
+            Communicator.shutdown()
+          } catch {
+            case e: Throwable =>
+              logger.error("Communicator.shutdown error: ", e)
+          }
+        }
+      }
+
+      val rdd = tryStageLevelScheduling(sc, runtimeParams, boostersAndMetrics)
+      // The repartition step is to make training stage as ShuffleMapStage, so that when one
+      // of the training task fails the training stage can retry. ResultStage won't retry when
+      // it fails.
+      val (booster, metrics) = rdd.repartition(1).collect()(0)
+      (booster, metrics)
+    } catch {
+      case t: Throwable =>
+        // if the job was aborted due to an exception
+        logger.error("XGBoost job was aborted due to ", t)
+        throw t
+    } finally {
+      try {
+        tracker.stop()
+      } catch {
+        case t: Throwable => logger.error(t)
+      }
+    }
+  }

 }

-class Watches private[scala] (
-    val datasets: Array[DMatrix],
+class Watches private[scala](val datasets: Array[DMatrix],
     val names: Array[String],
     val cacheDirName: Option[String]) {

@@ -568,211 +315,14 @@ class Watches private[scala] (
   override def toString: String = toMap.toString
 }

-private object Watches {
-
-  private def fromBaseMarginsToArray(baseMargins: Iterator[Float]): Option[Array[Float]] = {
-    val builder = new mutable.ArrayBuilder.ofFloat()
-    var nTotal = 0
-    var nUndefined = 0
-    while (baseMargins.hasNext) {
-      nTotal += 1
-      val baseMargin = baseMargins.next()
-      if (baseMargin.isNaN) {
-        nUndefined += 1 // don't waste space for all-NaNs.
-      } else {
-        builder += baseMargin
-      }
-    }
-    if (nUndefined == nTotal) {
-      None
-    } else if (nUndefined == 0) {
-      Some(builder.result())
-    } else {
-      throw new IllegalArgumentException(
-        s"Encountered a partition with $nUndefined NaN base margin values. " +
-          s"If you want to specify base margin, ensure all values are non-NaN.")
-    }
-  }
-
-  def buildWatches(
-      nameAndLabeledPointSets: Iterator[(String, Iterator[XGBLabeledPoint])],
-      cachedDirName: Option[String]): Watches = {
-    val dms = nameAndLabeledPointSets.map {
-      case (name, labeledPoints) =>
-        val baseMargins = new mutable.ArrayBuilder.ofFloat
-        val duplicatedItr = labeledPoints.map(labeledPoint => {
-          baseMargins += labeledPoint.baseMargin
-          labeledPoint
-        })
-        val dMatrix = new DMatrix(duplicatedItr, cachedDirName.map(_ + s"/$name").orNull)
-        val baseMargin = fromBaseMarginsToArray(baseMargins.result().iterator)
-        if (baseMargin.isDefined) {
-          dMatrix.setBaseMargin(baseMargin.get)
-        }
-        (name, dMatrix)
-    }.toArray
-    new Watches(dms.map(_._2), dms.map(_._1), cachedDirName)
-  }
-
-  def buildWatches(
-      xgbExecutionParams: XGBoostExecutionParams,
-      labeledPoints: Iterator[XGBLabeledPoint],
-      cacheDirName: Option[String]): Watches = {
-    val trainTestRatio = xgbExecutionParams.xgbInputParams.trainTestRatio
-    val seed = xgbExecutionParams.xgbInputParams.seed
-    val r = new Random(seed)
-    val testPoints = mutable.ArrayBuffer.empty[XGBLabeledPoint]
-    val trainBaseMargins = new mutable.ArrayBuilder.ofFloat
-    val testBaseMargins = new mutable.ArrayBuilder.ofFloat
-    val trainPoints = labeledPoints.filter { labeledPoint =>
-      val accepted = r.nextDouble() <= trainTestRatio
-      if (!accepted) {
-        testPoints += labeledPoint
-        testBaseMargins += labeledPoint.baseMargin
-      } else {
-        trainBaseMargins += labeledPoint.baseMargin
-      }
-      accepted
-    }
-    val trainMatrix = new DMatrix(trainPoints, cacheDirName.map(_ + "/train").orNull)
-    val testMatrix = new DMatrix(testPoints.iterator, cacheDirName.map(_ + "/test").orNull)
-
-    val trainMargin = fromBaseMarginsToArray(trainBaseMargins.result().iterator)
-    val testMargin = fromBaseMarginsToArray(testBaseMargins.result().iterator)
-    if (trainMargin.isDefined) trainMatrix.setBaseMargin(trainMargin.get)
-    if (testMargin.isDefined) testMatrix.setBaseMargin(testMargin.get)
-
-    if (xgbExecutionParams.featureNames.isDefined) {
-      trainMatrix.setFeatureNames(xgbExecutionParams.featureNames.get)
-      testMatrix.setFeatureNames(xgbExecutionParams.featureNames.get)
-    }
-
-    if (xgbExecutionParams.featureTypes.isDefined) {
-      trainMatrix.setFeatureTypes(xgbExecutionParams.featureTypes.get)
-      testMatrix.setFeatureTypes(xgbExecutionParams.featureTypes.get)
-    }
-
-    new Watches(Array(trainMatrix, testMatrix), Array("train", "test"), cacheDirName)
-  }
-
-  def buildWatchesWithGroup(
-      nameAndlabeledPointGroupSets: Iterator[(String, Iterator[Array[XGBLabeledPoint]])],
-      cachedDirName: Option[String]): Watches = {
-    val dms = nameAndlabeledPointGroupSets.map {
-      case (name, labeledPointsGroups) =>
-        val baseMargins = new mutable.ArrayBuilder.ofFloat
-        val groupsInfo = new mutable.ArrayBuilder.ofInt
-        val weights = new mutable.ArrayBuilder.ofFloat
-        val iter = labeledPointsGroups.filter(labeledPointGroup => {
-          var groupWeight = -1.0f
-          var groupSize = 0
-          labeledPointGroup.map { labeledPoint => {
-            if (groupWeight < 0) {
-              groupWeight = labeledPoint.weight
-            } else if (groupWeight != labeledPoint.weight) {
-              throw new IllegalArgumentException("the instances in the same group have to be" +
-                s" assigned with the same weight (unexpected weight ${labeledPoint.weight}")
-            }
-            baseMargins += labeledPoint.baseMargin
-            groupSize += 1
-            labeledPoint
-          }
-          }
-          weights += groupWeight
-          groupsInfo += groupSize
-          true
-        })
-        val dMatrix = new DMatrix(iter.flatMap(_.iterator), cachedDirName.map(_ + s"/$name").orNull)
-        val baseMargin = fromBaseMarginsToArray(baseMargins.result().iterator)
-        if (baseMargin.isDefined) {
-          dMatrix.setBaseMargin(baseMargin.get)
-        }
-        dMatrix.setGroup(groupsInfo.result())
-        dMatrix.setWeight(weights.result())
-        (name, dMatrix)
-    }.toArray
-    new Watches(dms.map(_._2), dms.map(_._1), cachedDirName)
-  }
-
-  def buildWatchesWithGroup(
-      xgbExecutionParams: XGBoostExecutionParams,
-      labeledPointGroups: Iterator[Array[XGBLabeledPoint]],
-      cacheDirName: Option[String]): Watches = {
-    val trainTestRatio = xgbExecutionParams.xgbInputParams.trainTestRatio
-    val seed = xgbExecutionParams.xgbInputParams.seed
-    val r = new Random(seed)
-    val testPoints = mutable.ArrayBuilder.make[XGBLabeledPoint]
-    val trainBaseMargins = new mutable.ArrayBuilder.ofFloat
-    val testBaseMargins = new mutable.ArrayBuilder.ofFloat
-
-    val trainGroups = new mutable.ArrayBuilder.ofInt
-    val testGroups = new mutable.ArrayBuilder.ofInt
-
-    val trainWeights = new mutable.ArrayBuilder.ofFloat
-    val testWeights = new mutable.ArrayBuilder.ofFloat
-
-    val trainLabelPointGroups = labeledPointGroups.filter { labeledPointGroup =>
-      val accepted = r.nextDouble() <= trainTestRatio
-      if (!accepted) {
-        var groupWeight = -1.0f
-        var groupSize = 0
-        labeledPointGroup.foreach(labeledPoint => {
-          testPoints += labeledPoint
-          testBaseMargins += labeledPoint.baseMargin
-          if (groupWeight < 0) {
-            groupWeight = labeledPoint.weight
-          } else if (labeledPoint.weight != groupWeight) {
-            throw new IllegalArgumentException("the instances in the same group have to be" +
-              s" assigned with the same weight (unexpected weight ${labeledPoint.weight}")
-          }
-          groupSize += 1
-        })
-        testWeights += groupWeight
-        testGroups += groupSize
-      } else {
-        var groupWeight = -1.0f
-        var groupSize = 0
-        labeledPointGroup.foreach { labeledPoint => {
-          if (groupWeight < 0) {
-            groupWeight = labeledPoint.weight
-          } else if (labeledPoint.weight != groupWeight) {
-            throw new IllegalArgumentException("the instances in the same group have to be" +
-              s" assigned with the same weight (unexpected weight ${labeledPoint.weight}")
-          }
-          trainBaseMargins += labeledPoint.baseMargin
-          groupSize += 1
-        }}
-        trainWeights += groupWeight
-        trainGroups += groupSize
-      }
-      accepted
-    }
-
-    val trainPoints = trainLabelPointGroups.flatMap(_.iterator)
-    val trainMatrix = new DMatrix(trainPoints, cacheDirName.map(_ + "/train").orNull)
-    trainMatrix.setGroup(trainGroups.result())
-    trainMatrix.setWeight(trainWeights.result())
-
-    val testMatrix = new DMatrix(testPoints.result().iterator, cacheDirName.map(_ + "/test").orNull)
-    if (trainTestRatio < 1.0) {
-      testMatrix.setGroup(testGroups.result())
-      testMatrix.setWeight(testWeights.result())
-    }
-
-    val trainMargin = fromBaseMarginsToArray(trainBaseMargins.result().iterator)
-    val testMargin = fromBaseMarginsToArray(testBaseMargins.result().iterator)
-    if (trainMargin.isDefined) trainMatrix.setBaseMargin(trainMargin.get)
-    if (testMargin.isDefined) testMatrix.setBaseMargin(testMargin.get)
-
-    if (xgbExecutionParams.featureNames.isDefined) {
-      trainMatrix.setFeatureNames(xgbExecutionParams.featureNames.get)
-      testMatrix.setFeatureNames(xgbExecutionParams.featureNames.get)
-    }
-    if (xgbExecutionParams.featureTypes.isDefined) {
-      trainMatrix.setFeatureTypes(xgbExecutionParams.featureTypes.get)
-      testMatrix.setFeatureTypes(xgbExecutionParams.featureTypes.get)
-    }
-
-    new Watches(Array(trainMatrix, testMatrix), Array("train", "test"), cacheDirName)
-  }
-}
+/**
+ * Rabit tracker configurations.
+ *
+ * @param timeout The number of seconds before timeout waiting for workers to connect. and
+ *                for the tracker to shutdown.
+ * @param hostIp The Rabit Tracker host IP address.
+ *               This is only needed if the host IP cannot be automatically guessed.
+ * @param port The port number for the tracker to listen to. Use a system allocated one by
+ *             default.
+ */
+private[spark] case class TrackerConf(timeout: Int = 0, hostIp: String = "", port: Int = 0)
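// Illustrative sketch, not part of this commit: the StageLevelScheduling trait above attaches a
// ResourceProfile to the training RDD so that, where the cluster manager supports stage-level
// scheduling, each barrier training task claims a whole GPU. A minimal standalone version using
// Spark's public API (executor GPU configuration and resource discovery are assumed to be set up
// separately; the helper name is hypothetical):
import org.apache.spark.rdd.RDD
import org.apache.spark.resource.{ResourceProfileBuilder, TaskResourceRequests}

object OneGpuPerTaskSketch {
  def withOneGpuPerTask[T](rdd: RDD[T], taskCpus: Int = 1): RDD[T] = {
    val requests = new TaskResourceRequests().cpus(taskCpus).resource("gpu", 1)
    val profile = new ResourceProfileBuilder().require(requests).build
    rdd.withResources(profile) // tasks of this stage now request one GPU each
  }
}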
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014-2022 by Contributors
|
Copyright (c) 2014-2024 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -16,490 +16,190 @@
|
|||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala.spark
|
package ml.dmlc.xgboost4j.scala.spark
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.scala.spark.params._
|
import scala.collection.mutable
|
||||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, EvalTrait, ObjectiveTrait, XGBoost => SXGBoost}
|
|
||||||
import org.apache.hadoop.fs.Path
|
|
||||||
|
|
||||||
import org.apache.spark.ml.classification._
|
|
||||||
import org.apache.spark.ml.linalg._
|
|
||||||
import org.apache.spark.ml.util._
|
|
||||||
import org.apache.spark.sql._
|
|
||||||
import org.apache.spark.sql.functions._
|
|
||||||
import scala.collection.{Iterator, mutable}
|
|
||||||
|
|
||||||
|
import org.apache.spark.ml.classification.{ProbabilisticClassificationModel, ProbabilisticClassifier}
|
||||||
|
import org.apache.spark.ml.linalg.{Vector, Vectors}
|
||||||
import org.apache.spark.ml.param.ParamMap
|
import org.apache.spark.ml.param.ParamMap
|
||||||
import org.apache.spark.ml.util.{DefaultXGBoostParamsReader, DefaultXGBoostParamsWriter, XGBoostWriter}
|
import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable, MLReadable, MLReader}
|
||||||
import org.apache.spark.sql.types.StructType
|
import org.apache.spark.ml.xgboost.{SparkUtils, XGBProbabilisticClassifierParams}
|
||||||
|
import org.apache.spark.sql.Dataset
|
||||||
|
import org.apache.spark.sql.functions.{col, udf}
|
||||||
|
import org.json4s.DefaultFormats
|
||||||
|
|
||||||
class XGBoostClassifier (
|
import ml.dmlc.xgboost4j.scala.Booster
|
||||||
override val uid: String,
|
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams.{BINARY_CLASSIFICATION_OBJS, MULTICLASSIFICATION_OBJS}
|
||||||
|
|
||||||
|
class XGBoostClassifier(override val uid: String,
|
||||||
private[spark] val xgboostParams: Map[String, Any])
|
private[spark] val xgboostParams: Map[String, Any])
|
||||||
extends ProbabilisticClassifier[Vector, XGBoostClassifier, XGBoostClassificationModel]
|
extends ProbabilisticClassifier[Vector, XGBoostClassifier, XGBoostClassificationModel]
|
||||||
with XGBoostClassifierParams with DefaultParamsWritable {
|
with XGBoostEstimator[XGBoostClassifier, XGBoostClassificationModel]
|
||||||
|
with XGBProbabilisticClassifierParams[XGBoostClassifier] {
|
||||||
|
|
||||||
def this() = this(Identifiable.randomUID("xgbc"), Map[String, Any]())
|
def this() = this(XGBoostClassifier._uid, Map.empty)
|
||||||
|
|
||||||
def this(uid: String) = this(uid, Map[String, Any]())
|
def this(uid: String) = this(uid, Map.empty)
|
||||||
|
|
||||||
def this(xgboostParams: Map[String, Any]) = this(
|
def this(xgboostParams: Map[String, Any]) = this(XGBoostClassifier._uid, xgboostParams)
|
||||||
Identifiable.randomUID("xgbc"), xgboostParams)
|
|
||||||
|
|
||||||
XGBoost2MLlibParams(xgboostParams)
|
xgboost2SparkParams(xgboostParams)
|
||||||
|
|
||||||
def setWeightCol(value: String): this.type = set(weightCol, value)
|
private var numberClasses = 0
|
||||||
|
|
||||||
def setBaseMarginCol(value: String): this.type = set(baseMarginCol, value)
|
private def validateObjective(dataset: Dataset[_]): Unit = {
|
||||||
|
// If the objective is set explicitly, it must be in BINARY_CLASSIFICATION_OBJS and
|
||||||
def setNumClass(value: Int): this.type = set(numClass, value)
|
// MULTICLASSIFICATION_OBJS
|
||||||
|
val obj = if (isSet(objective)) {
|
||||||
// setters for general params
|
val tmpObj = getObjective
|
||||||
def setNumRound(value: Int): this.type = set(numRound, value)
|
val supportedObjs = BINARY_CLASSIFICATION_OBJS.toSeq ++ MULTICLASSIFICATION_OBJS.toSeq
|
||||||
|
require(supportedObjs.contains(tmpObj),
|
||||||
def setNumWorkers(value: Int): this.type = set(numWorkers, value)
|
s"Wrong objective for XGBoostClassifier, supported objs: ${supportedObjs.mkString(",")}")
|
||||||
|
Some(tmpObj)
|
||||||
def setNthread(value: Int): this.type = set(nthread, value)
|
|
||||||
|
|
||||||
def setUseExternalMemory(value: Boolean): this.type = set(useExternalMemory, value)
|
|
||||||
|
|
||||||
def setSilent(value: Int): this.type = set(silent, value)
|
|
||||||
|
|
||||||
def setMissing(value: Float): this.type = set(missing, value)
|
|
||||||
|
|
||||||
def setCheckpointPath(value: String): this.type = set(checkpointPath, value)
|
|
||||||
|
|
||||||
def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
|
|
||||||
|
|
||||||
def setSeed(value: Long): this.type = set(seed, value)
|
|
||||||
|
|
||||||
def setEta(value: Double): this.type = set(eta, value)
|
|
||||||
|
|
||||||
def setGamma(value: Double): this.type = set(gamma, value)
|
|
||||||
|
|
||||||
def setMaxDepth(value: Int): this.type = set(maxDepth, value)
|
|
||||||
|
|
||||||
def setMinChildWeight(value: Double): this.type = set(minChildWeight, value)
|
|
||||||
|
|
||||||
def setMaxDeltaStep(value: Double): this.type = set(maxDeltaStep, value)
|
|
||||||
|
|
||||||
def setSubsample(value: Double): this.type = set(subsample, value)
|
|
||||||
|
|
||||||
def setColsampleBytree(value: Double): this.type = set(colsampleBytree, value)
|
|
||||||
|
|
||||||
def setColsampleBylevel(value: Double): this.type = set(colsampleBylevel, value)
|
|
||||||
|
|
||||||
def setLambda(value: Double): this.type = set(lambda, value)
|
|
||||||
|
|
||||||
def setAlpha(value: Double): this.type = set(alpha, value)
|
|
||||||
|
|
||||||
def setTreeMethod(value: String): this.type = set(treeMethod, value)
|
|
||||||
|
|
||||||
def setDevice(value: String): this.type = set(device, value)
|
|
||||||
|
|
||||||
def setGrowPolicy(value: String): this.type = set(growPolicy, value)
|
|
||||||
|
|
||||||
def setMaxBins(value: Int): this.type = set(maxBins, value)
|
|
||||||
|
|
||||||
def setMaxLeaves(value: Int): this.type = set(maxLeaves, value)
|
|
||||||
|
|
||||||
def setScalePosWeight(value: Double): this.type = set(scalePosWeight, value)
|
|
||||||
|
|
||||||
def setSampleType(value: String): this.type = set(sampleType, value)
|
|
||||||
|
|
||||||
def setNormalizeType(value: String): this.type = set(normalizeType, value)
|
|
||||||
|
|
||||||
def setRateDrop(value: Double): this.type = set(rateDrop, value)
|
|
||||||
|
|
||||||
def setSkipDrop(value: Double): this.type = set(skipDrop, value)
|
|
||||||
|
|
||||||
def setLambdaBias(value: Double): this.type = set(lambdaBias, value)
|
|
||||||
|
|
||||||
// setters for learning params
|
|
||||||
def setObjective(value: String): this.type = set(objective, value)
|
|
||||||
|
|
||||||
def setObjectiveType(value: String): this.type = set(objectiveType, value)
|
|
||||||
|
|
||||||
def setBaseScore(value: Double): this.type = set(baseScore, value)
|
|
||||||
|
|
||||||
def setEvalMetric(value: String): this.type = set(evalMetric, value)
|
|
||||||
|
|
||||||
def setTrainTestRatio(value: Double): this.type = set(trainTestRatio, value)
|
|
||||||
|
|
||||||
def setNumEarlyStoppingRounds(value: Int): this.type = set(numEarlyStoppingRounds, value)
|
|
||||||
|
|
||||||
def setMaximizeEvaluationMetrics(value: Boolean): this.type =
|
|
||||||
set(maximizeEvaluationMetrics, value)
|
|
||||||
|
|
||||||
def setCustomObj(value: ObjectiveTrait): this.type = set(customObj, value)
|
|
||||||
|
|
||||||
def setCustomEval(value: EvalTrait): this.type = set(customEval, value)
|
|
||||||
|
|
||||||
def setAllowNonZeroForMissing(value: Boolean): this.type = set(
|
|
||||||
allowNonZeroForMissing,
|
|
||||||
value
|
|
||||||
)
|
|
||||||
|
|
||||||
def setSinglePrecisionHistogram(value: Boolean): this.type =
|
|
||||||
set(singlePrecisionHistogram, value)
|
|
||||||
|
|
||||||
def setFeatureNames(value: Array[String]): this.type =
|
|
||||||
set(featureNames, value)
|
|
||||||
|
|
||||||
def setFeatureTypes(value: Array[String]): this.type =
|
|
||||||
set(featureTypes, value)
|
|
||||||
|
|
||||||
// called at the start of fit/train when 'eval_metric' is not defined
|
|
||||||
private def setupDefaultEvalMetric(): String = {
|
|
||||||
require(isDefined(objective), "Users must set \'objective\' via xgboostParams.")
|
|
||||||
if ($(objective).startsWith("multi")) {
|
|
||||||
// multi
|
|
||||||
"mlogloss"
|
|
||||||
} else {
|
} else {
|
||||||
// binary
|
None
|
||||||
"logloss"
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Callback from PreXGBoost
|
def inferNumClasses: Int = {
|
||||||
private[spark] def transformSchemaInternal(schema: StructType): StructType = {
|
var num = getNumClass
|
||||||
if (isFeaturesColSet(schema)) {
|
// Infer num class if num class is not set explicitly.
|
||||||
// User has vectorized the features into VectorUDT.
|
// Note that user sets the num classes explicitly, we're not checking that.
|
||||||
super.transformSchema(schema)
|
if (num == 0) {
|
||||||
|
num = SparkUtils.getNumClasses(dataset, getLabelCol)
|
||||||
|
}
|
||||||
|
require(num > 0)
|
||||||
|
num
|
||||||
|
}
|
||||||
|
|
||||||
|
// objective is set explicitly.
|
||||||
|
if (obj.isDefined) {
|
||||||
|
if (MULTICLASSIFICATION_OBJS.contains(getObjective)) {
|
||||||
|
numberClasses = inferNumClasses
|
||||||
|
setNumClass(numberClasses)
|
||||||
} else {
|
} else {
|
||||||
transformSchemaWithFeaturesCols(true, schema)
|
numberClasses = 2
|
||||||
|
// binary classification doesn't require num_class be set
|
||||||
|
require(!isSet(numClass), "num_class is not allowed for binary classification")
|
||||||
}
|
}
|
||||||
}
|
} else {
|
||||||
|
// infer the objective according to the num_class
|
||||||
override def transformSchema(schema: StructType): StructType = {
|
numberClasses = inferNumClasses
|
||||||
PreXGBoost.transformSchema(this, schema)
|
if (numberClasses <= 2) {
|
||||||
}
|
|
||||||
|
|
||||||
override protected def train(dataset: Dataset[_]): XGBoostClassificationModel = {
|
|
||||||
val _numClasses = getNumClasses(dataset)
|
|
||||||
if (isDefined(numClass) && $(numClass) != _numClasses) {
|
|
||||||
throw new Exception("The number of classes in dataset doesn't match " +
|
|
||||||
"\'num_class\' in xgboost params.")
|
|
||||||
}
|
|
||||||
|
|
||||||
if (_numClasses == 2) {
|
|
||||||
if (!isDefined(objective)) {
|
|
||||||
// If user doesn't set objective, force it to binary:logistic
|
|
||||||
setObjective("binary:logistic")
|
setObjective("binary:logistic")
|
||||||
}
|
logger.warn("Inferred for binary classification, set the objective to binary:logistic")
|
||||||
} else if (_numClasses > 2) {
|
require(!isSet(numClass), "num_class is not allowed for binary classification")
|
||||||
if (!isDefined(objective)) {
|
} else {
|
||||||
// If user doesn't set objective, force it to multi:softprob
|
logger.warn("Inferred for multi classification, set the objective to multi:softprob")
|
||||||
setObjective("multi:softprob")
|
setObjective("multi:softprob")
|
||||||
|
setNumClass(numberClasses)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!isDefined(evalMetric) || $(evalMetric).isEmpty) {
|
/**
|
||||||
set(evalMetric, setupDefaultEvalMetric())
|
* Validate the parameters before training, throw exception if possible
|
||||||
|
*/
|
||||||
|
override protected[spark] def validate(dataset: Dataset[_]): Unit = {
|
||||||
|
super.validate(dataset)
|
||||||
|
validateObjective(dataset)
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isDefined(customObj) && $(customObj) != null) {
|
override protected def createModel(booster: Booster, summary: XGBoostTrainingSummary):
|
||||||
set(objectiveType, "classification")
|
XGBoostClassificationModel = {
|
||||||
|
new XGBoostClassificationModel(uid, numberClasses, booster, Option(summary))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Packing with all params plus params user defined
|
|
||||||
val derivedXGBParamMap = xgboostParams ++ MLlib2XGBoostParams
|
|
||||||
val buildTrainingData = PreXGBoost.buildDatasetToRDD(this, dataset, derivedXGBParamMap)
|
|
||||||
transformSchema(dataset.schema, logging = true)
|
|
||||||
|
|
||||||
// All non-null param maps in XGBoostClassifier are in derivedXGBParamMap.
|
|
||||||
val (_booster, _metrics) = XGBoost.trainDistributed(dataset.sparkSession.sparkContext,
|
|
||||||
buildTrainingData, derivedXGBParamMap)
|
|
||||||
|
|
||||||
val model = new XGBoostClassificationModel(uid, _numClasses, _booster)
|
|
||||||
val summary = XGBoostTrainingSummary(_metrics)
|
|
||||||
model.setSummary(summary)
|
|
||||||
model
|
|
||||||
}
|
|
||||||
|
|
||||||
override def copy(extra: ParamMap): XGBoostClassifier = defaultCopy(extra)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
object XGBoostClassifier extends DefaultParamsReadable[XGBoostClassifier] {
|
object XGBoostClassifier extends DefaultParamsReadable[XGBoostClassifier] {
|
||||||
|
private val _uid = Identifiable.randomUID("xgbc")
|
||||||
override def load(path: String): XGBoostClassifier = super.load(path)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
-class XGBoostClassificationModel private[ml](
-    override val uid: String,
-    override val numClasses: Int,
-    private[scala] val _booster: Booster)
-  extends ProbabilisticClassificationModel[Vector, XGBoostClassificationModel]
-    with XGBoostClassifierParams with InferenceParams
-    with MLWritable with Serializable {
+class XGBoostClassificationModel private[ml](
+    val uid: String,
+    val numClasses: Int,
+    val nativeBooster: Booster,
+    val summary: Option[XGBoostTrainingSummary] = None
+) extends ProbabilisticClassificationModel[Vector, XGBoostClassificationModel]
+  with XGBoostModel[XGBoostClassificationModel]
+  with XGBProbabilisticClassifierParams[XGBoostClassificationModel] {
|
|
||||||
import XGBoostClassificationModel._
|
def this(uid: String) = this(uid, 0, null)
|
||||||
|
|
||||||
// only called in copy()
|
override protected[spark] def postTransform(dataset: Dataset[_],
|
||||||
def this(uid: String) = this(uid, 2, null)
|
pred: PredictedColumns): Dataset[_] = {
|
||||||
|
var output = super.postTransform(dataset, pred)
|
||||||
|
|
||||||
/**
|
// Always use probability col to get the prediction
|
||||||
* Get the native booster instance of this model.
|
|
||||||
* This is used to call low-level APIs on native booster, such as "getFeatureScore".
|
|
||||||
*/
|
|
||||||
def nativeBooster: Booster = _booster
|
|
||||||
|
|
||||||
private var trainingSummary: Option[XGBoostTrainingSummary] = None
|
if (isDefinedNonEmpty(predictionCol) && pred.predTmp) {
|
||||||
|
if (getObjective == "multi:softmax") {
|
||||||
/**
|
|
||||||
* Returns summary (e.g. train/test objective history) of model on the
|
|
||||||
* training set. An exception is thrown if no summary is available.
|
|
||||||
*/
|
|
||||||
def summary: XGBoostTrainingSummary = trainingSummary.getOrElse {
|
|
||||||
throw new IllegalStateException("No training summary available for this XGBoostModel")
|
|
||||||
}
|
|
||||||
|
|
||||||
private[spark] def setSummary(summary: XGBoostTrainingSummary): this.type = {
|
|
||||||
trainingSummary = Some(summary)
|
|
||||||
this
|
|
||||||
}
|
|
||||||
|
|
||||||
def setLeafPredictionCol(value: String): this.type = set(leafPredictionCol, value)
|
|
||||||
|
|
||||||
def setContribPredictionCol(value: String): this.type = set(contribPredictionCol, value)
|
|
||||||
|
|
||||||
def setTreeLimit(value: Int): this.type = set(treeLimit, value)
|
|
||||||
|
|
||||||
def setMissing(value: Float): this.type = set(missing, value)
|
|
||||||
|
|
||||||
def setAllowNonZeroForMissing(value: Boolean): this.type = set(
|
|
||||||
allowNonZeroForMissing,
|
|
||||||
value
|
|
||||||
)
|
|
||||||
|
|
||||||
def setInferBatchSize(value: Int): this.type = set(inferBatchSize, value)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Single instance prediction.
|
|
||||||
* Note: The performance is not ideal, use it carefully!
|
|
||||||
*/
|
|
||||||
override def predict(features: Vector): Double = {
|
|
||||||
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
|
|
||||||
val dm = new DMatrix(processMissingValues(
|
|
||||||
Iterator(features.asXGB),
|
|
||||||
$(missing),
|
|
||||||
$(allowNonZeroForMissing)
|
|
||||||
))
|
|
||||||
val probability = _booster.predict(data = dm)(0).map(_.toDouble)
|
|
||||||
if (numClasses == 2) {
|
|
||||||
math.round(probability(0))
|
|
||||||
} else {
|
|
||||||
probability2prediction(Vectors.dense(probability))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Actually we don't use this function at all, to make it pass compiler check.
|
|
||||||
override def predictRaw(features: Vector): Vector = {
|
|
||||||
throw new Exception("XGBoost-Spark does not support \'predictRaw\'")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Actually we don't use this function at all, to make it pass compiler check.
|
|
||||||
override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
|
|
||||||
throw new Exception("XGBoost-Spark does not support \'raw2probabilityInPlace\'")
|
|
||||||
}
|
|
||||||
|
|
||||||
private[scala] def produceResultIterator(
|
|
||||||
originalRowItr: Iterator[Row],
|
|
||||||
rawPredictionItr: Iterator[Row],
|
|
||||||
probabilityItr: Iterator[Row],
|
|
||||||
predLeafItr: Iterator[Row],
|
|
||||||
predContribItr: Iterator[Row]): Iterator[Row] = {
|
|
||||||
// the following implementation is to be improved
|
|
||||||
if (isDefined(leafPredictionCol) && $(leafPredictionCol).nonEmpty &&
|
|
||||||
isDefined(contribPredictionCol) && $(contribPredictionCol).nonEmpty) {
|
|
||||||
originalRowItr.zip(rawPredictionItr).zip(probabilityItr).zip(predLeafItr).zip(predContribItr).
|
|
||||||
map { case ((((originals: Row, rawPrediction: Row), probability: Row), leaves: Row),
|
|
||||||
contribs: Row) =>
|
|
||||||
Row.fromSeq(originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq ++ leaves.toSeq ++
|
|
||||||
contribs.toSeq)
|
|
||||||
}
|
|
||||||
} else if (isDefined(leafPredictionCol) && $(leafPredictionCol).nonEmpty &&
|
|
||||||
(!isDefined(contribPredictionCol) || $(contribPredictionCol).isEmpty)) {
|
|
||||||
originalRowItr.zip(rawPredictionItr).zip(probabilityItr).zip(predLeafItr).
|
|
||||||
map { case (((originals: Row, rawPrediction: Row), probability: Row), leaves: Row) =>
|
|
||||||
Row.fromSeq(originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq ++ leaves.toSeq)
|
|
||||||
}
|
|
||||||
} else if ((!isDefined(leafPredictionCol) || $(leafPredictionCol).isEmpty) &&
|
|
||||||
isDefined(contribPredictionCol) && $(contribPredictionCol).nonEmpty) {
|
|
||||||
originalRowItr.zip(rawPredictionItr).zip(probabilityItr).zip(predContribItr).
|
|
||||||
map { case (((originals: Row, rawPrediction: Row), probability: Row), contribs: Row) =>
|
|
||||||
Row.fromSeq(originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq ++ contribs.toSeq)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
originalRowItr.zip(rawPredictionItr).zip(probabilityItr).map {
|
|
||||||
case ((originals: Row, rawPrediction: Row), probability: Row) =>
|
|
||||||
Row.fromSeq(originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private[scala] def producePredictionItrs(booster: Booster, dm: DMatrix):
|
|
||||||
Array[Iterator[Row]] = {
|
|
||||||
val rawPredictionItr = {
|
|
||||||
booster.predict(dm, outPutMargin = true, $(treeLimit)).
|
|
||||||
map(Row(_)).iterator
|
|
||||||
}
|
|
||||||
val probabilityItr = {
|
|
||||||
booster.predict(dm, outPutMargin = false, $(treeLimit)).
|
|
||||||
map(Row(_)).iterator
|
|
||||||
}
|
|
||||||
val predLeafItr = {
|
|
||||||
if (isDefined(leafPredictionCol)) {
|
|
||||||
booster.predictLeaf(dm, $(treeLimit)).map(Row(_)).iterator
|
|
||||||
} else {
|
|
||||||
Iterator()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
val predContribItr = {
|
|
||||||
if (isDefined(contribPredictionCol)) {
|
|
||||||
booster.predictContrib(dm, $(treeLimit)).map(Row(_)).iterator
|
|
||||||
} else {
|
|
||||||
Iterator()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr)
|
|
||||||
}
|
|
||||||
|
|
||||||
private[spark] def transformSchemaInternal(schema: StructType): StructType = {
|
|
||||||
if (isFeaturesColSet(schema)) {
|
|
||||||
// User has vectorized the features into VectorUDT.
|
|
||||||
super.transformSchema(schema)
|
|
||||||
} else {
|
|
||||||
transformSchemaWithFeaturesCols(false, schema)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
override def transformSchema(schema: StructType): StructType = {
|
|
||||||
PreXGBoost.transformSchema(this, schema)
|
|
||||||
}
|
|
||||||
|
|
||||||
override def transform(dataset: Dataset[_]): DataFrame = {
|
|
||||||
transformSchema(dataset.schema, logging = true)
|
|
||||||
if (isDefined(thresholds)) {
|
|
||||||
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
|
|
||||||
".transform() called with non-matching numClasses and thresholds.length." +
|
|
||||||
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Output selected columns only.
|
|
||||||
// This is a bit complicated since it tries to avoid repeated computation.
|
|
||||||
var outputData = PreXGBoost.transformDataset(this, dataset)
|
|
||||||
var numColsOutput = 0
|
|
||||||
|
|
||||||
val rawPredictionUDF = udf { rawPrediction: mutable.WrappedArray[Float] =>
|
|
||||||
val raw = rawPrediction.map(_.toDouble).toArray
|
|
||||||
val rawPredictions = if (numClasses == 2) Array(-raw(0), raw(0)) else raw
|
|
||||||
Vectors.dense(rawPredictions)
|
|
||||||
}
|
|
||||||
|
|
||||||
if ($(rawPredictionCol).nonEmpty) {
|
|
||||||
outputData = outputData
|
|
||||||
.withColumn(getRawPredictionCol, rawPredictionUDF(col(_rawPredictionCol)))
|
|
||||||
numColsOutput += 1
|
|
||||||
}
|
|
||||||
|
|
||||||
if (getObjective.equals("multi:softmax")) {
|
|
||||||
// For objective=multi:softmax scenario, there is no probability predicted from xgboost.
|
// For objective=multi:softmax scenario, there is no probability predicted from xgboost.
|
||||||
// Instead, the probability column will be filled with real prediction
|
// Instead, the probability column will be filled with real prediction
|
||||||
val predictUDF = udf { probability: mutable.WrappedArray[Float] =>
|
val predictUDF = udf { probability: mutable.WrappedArray[Float] =>
|
||||||
probability(0)
|
probability(0)
|
||||||
}
|
}
|
||||||
if ($(predictionCol).nonEmpty) {
|
output = output.withColumn(getPredictionCol, predictUDF(col(TMP_TRANSFORMED_COL)))
|
||||||
outputData = outputData
|
} else {
|
||||||
.withColumn($(predictionCol), predictUDF(col(_probabilityCol)))
|
val predCol = udf { probability: mutable.WrappedArray[Float] =>
|
||||||
numColsOutput += 1
|
val prob = probability.map(_.toDouble).toArray
|
||||||
|
val probabilities = if (numClasses == 2) Array(1.0 - prob(0), prob(0)) else prob
|
||||||
|
probability2prediction(Vectors.dense(probabilities))
|
||||||
|
}
|
||||||
|
output = output.withColumn(getPredictionCol, predCol(col(TMP_TRANSFORMED_COL)))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
if (isDefinedNonEmpty(probabilityCol) && pred.predTmp) {
|
||||||
val probabilityUDF = udf { probability: mutable.WrappedArray[Float] =>
|
val probabilityUDF = udf { probability: mutable.WrappedArray[Float] =>
|
||||||
val prob = probability.map(_.toDouble).toArray
|
val prob = probability.map(_.toDouble).toArray
|
||||||
val probabilities = if (numClasses == 2) Array(1.0 - prob(0), prob(0)) else prob
|
val probabilities = if (numClasses == 2) Array(1.0 - prob(0), prob(0)) else prob
|
||||||
Vectors.dense(probabilities)
|
Vectors.dense(probabilities)
|
||||||
}
|
}
|
||||||
if ($(probabilityCol).nonEmpty) {
|
output = output.withColumn(TMP_TRANSFORMED_COL,
|
||||||
outputData = outputData
|
probabilityUDF(output.col(TMP_TRANSFORMED_COL)))
|
||||||
.withColumn(getProbabilityCol, probabilityUDF(col(_probabilityCol)))
|
.withColumnRenamed(TMP_TRANSFORMED_COL, getProbabilityCol)
|
||||||
numColsOutput += 1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
val predictUDF = udf { probability: mutable.WrappedArray[Float] =>
|
if (pred.predRaw) {
|
||||||
// From XGBoost probability to MLlib prediction
|
val rawPredictionUDF = udf { raw: mutable.WrappedArray[Float] =>
|
||||||
val prob = probability.map(_.toDouble).toArray
|
val rawF = raw.map(_.toDouble).toArray
|
||||||
val probabilities = if (numClasses == 2) Array(1.0 - prob(0), prob(0)) else prob
|
val rawPredictions = if (numClasses == 2) Array(-rawF(0), rawF(0)) else rawF
|
||||||
probability2prediction(Vectors.dense(probabilities))
|
Vectors.dense(rawPredictions)
|
||||||
}
|
|
||||||
if ($(predictionCol).nonEmpty) {
|
|
||||||
outputData = outputData
|
|
||||||
.withColumn($(predictionCol), predictUDF(col(_probabilityCol)))
|
|
||||||
numColsOutput += 1
|
|
||||||
}
|
}
|
||||||
|
output = output.withColumn(getRawPredictionCol,
|
||||||
|
rawPredictionUDF(output.col(getRawPredictionCol)))
|
||||||
}
|
}
|
||||||
|
|
||||||
if (numColsOutput == 0) {
|
output.drop(TMP_TRANSFORMED_COL)
|
||||||
this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" +
|
|
||||||
" since no output columns were set.")
|
|
||||||
}
|
|
||||||
outputData
|
|
||||||
.toDF
|
|
||||||
.drop(col(_rawPredictionCol))
|
|
||||||
.drop(col(_probabilityCol))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
override def copy(extra: ParamMap): XGBoostClassificationModel = {
|
override def copy(extra: ParamMap): XGBoostClassificationModel = {
|
||||||
val newModel = copyValues(new XGBoostClassificationModel(uid, numClasses, _booster), extra)
|
val newModel = copyValues(new XGBoostClassificationModel(uid, numClasses,
|
||||||
newModel.setSummary(summary).setParent(parent)
|
nativeBooster, summary), extra)
|
||||||
|
newModel.setParent(parent)
|
||||||
}
|
}
|
||||||
|
|
||||||
override def write: MLWriter =
|
override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
|
||||||
new XGBoostClassificationModel.XGBoostClassificationModelWriter(this)
|
throw new Exception("XGBoost-Spark does not support \'raw2probabilityInPlace\'")
|
||||||
|
}
|
||||||
|
|
||||||
|
override def predictRaw(features: Vector): Vector =
|
||||||
|
throw new Exception("XGBoost-Spark does not support \'predictRaw\'")
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel] {
|
object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel] {
|
||||||
|
|
||||||
private[scala] val _rawPredictionCol = "_rawPrediction"
|
override def read: MLReader[XGBoostClassificationModel] = new ModelReader
|
||||||
private[scala] val _probabilityCol = "_probability"
|
|
||||||
|
|
||||||
override def read: MLReader[XGBoostClassificationModel] = new XGBoostClassificationModelReader
|
|
||||||
|
|
||||||
override def load(path: String): XGBoostClassificationModel = super.load(path)
|
|
||||||
|
|
||||||
private[XGBoostClassificationModel]
|
|
||||||
class XGBoostClassificationModelWriter(instance: XGBoostClassificationModel)
|
|
||||||
extends XGBoostWriter {
|
|
||||||
|
|
||||||
override protected def saveImpl(path: String): Unit = {
|
|
||||||
// Save metadata and Params
|
|
||||||
DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc)
|
|
||||||
|
|
||||||
// Save model data
|
|
||||||
val dataPath = new Path(path, "data").toString
|
|
||||||
val internalPath = new Path(dataPath, "XGBoostClassificationModel")
|
|
||||||
val outputStream = internalPath.getFileSystem(sc.hadoopConfiguration).create(internalPath)
|
|
||||||
instance._booster.saveModel(outputStream, getModelFormat())
|
|
||||||
outputStream.close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
-  private class XGBoostClassificationModelReader extends MLReader[XGBoostClassificationModel] {
-
-    /** Checked against metadata when loading model */
-    private val className = classOf[XGBoostClassificationModel].getName
-
-    override def load(path: String): XGBoostClassificationModel = {
-      implicit val sc = super.sparkSession.sparkContext
-      val metadata = DefaultXGBoostParamsReader.loadMetadata(path, sc, className)
-      val dataPath = new Path(path, "data").toString
-      val internalPath = new Path(dataPath, "XGBoostClassificationModel")
-      val dataInStream = internalPath.getFileSystem(sc.hadoopConfiguration).open(internalPath)
-      val numClasses = DefaultXGBoostParamsReader.getNumClass(metadata, dataInStream)
-      val booster = SXGBoost.loadModel(dataInStream)
-      val model = new XGBoostClassificationModel(metadata.uid, numClasses, booster)
-      DefaultXGBoostParamsReader.getAndSetParams(model, metadata)
-      model
-    }
-  }
+  private class ModelReader extends XGBoostModelReader[XGBoostClassificationModel] {
+    override def load(path: String): XGBoostClassificationModel = {
+      val xgbModel = loadBooster(path)
+      val meta = SparkUtils.loadMetadata(path, sc)
+      implicit val format = DefaultFormats
+      val numClasses = (meta.params \ "numClass").extractOpt[Int].getOrElse(2)
+      val model = new XGBoostClassificationModel(meta.uid, numClasses, xgbModel)
+      meta.getAndSetParams(model)
+      model
+    }
+  }
 }
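With the reworked reader above, loading no longer goes through ``DefaultXGBoostParamsReader``; metadata and the native booster are handled by the shared ``XGBoostModelReader`` machinery. A minimal save/load round trip could look like the sketch below (paths, column names, and the ``trainDf``/``testDf`` DataFrames are assumptions, and the Map constructor is assumed to mirror the regressor's):

.. code-block:: scala

    // Hypothetical DataFrames with "features" (VectorUDT) and "label" columns.
    val model = new XGBoostClassifier(Map("objective" -> "multi:softprob", "num_class" -> 3))
      .setFeaturesCol("features")
      .setLabelCol("label")
      .fit(trainDf)

    model.write.overwrite().save("/tmp/xgb-classifier")
    val restored = XGBoostClassificationModel.load("/tmp/xgb-classifier")
    restored.transform(testDf).select("prediction", "probability").show(5)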
@@ -0,0 +1,641 @@
/*
|
||||||
|
Copyright (c) 2024 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ml.dmlc.xgboost4j.scala.spark
|
||||||
|
|
||||||
|
import java.util.ServiceLoader
|
||||||
|
|
||||||
|
import scala.collection.mutable
|
||||||
|
import scala.collection.mutable.ArrayBuffer
|
||||||
|
import scala.jdk.CollectionConverters._
|
||||||
|
|
||||||
|
import org.apache.commons.logging.LogFactory
|
||||||
|
import org.apache.hadoop.fs.Path
|
||||||
|
import org.apache.spark.ml.{Estimator, Model}
|
||||||
|
import org.apache.spark.ml.functions.array_to_vector
|
||||||
|
import org.apache.spark.ml.linalg.{SparseVector, Vector}
|
||||||
|
import org.apache.spark.ml.param.{Param, ParamMap}
|
||||||
|
import org.apache.spark.ml.util.{DefaultParamsWritable, MLReader, MLWritable, MLWriter}
|
||||||
|
import org.apache.spark.ml.xgboost.{SparkUtils, XGBProbabilisticClassifierParams}
|
||||||
|
import org.apache.spark.rdd.RDD
|
||||||
|
import org.apache.spark.sql._
|
||||||
|
import org.apache.spark.sql.functions.{col, udf}
|
||||||
|
import org.apache.spark.sql.types._
|
||||||
|
|
||||||
|
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||||
|
import ml.dmlc.xgboost4j.java.{Booster => JBooster}
|
||||||
|
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
|
||||||
|
import ml.dmlc.xgboost4j.scala.spark.Utils.MLVectorToXGBLabeledPoint
|
||||||
|
import ml.dmlc.xgboost4j.scala.spark.params._
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Hold the column index
|
||||||
|
*/
|
||||||
|
private[spark] case class ColumnIndices(
|
||||||
|
labelId: Int,
|
||||||
|
featureId: Option[Int], // the feature type is VectorUDT or Array
|
||||||
|
featureIds: Option[Seq[Int]], // the feature type is columnar
|
||||||
|
weightId: Option[Int],
|
||||||
|
marginId: Option[Int],
|
||||||
|
groupId: Option[Int])
|
||||||
|
|
||||||
|
private[spark] trait NonParamVariables[T <: XGBoostEstimator[T, M], M <: XGBoostModel[M]] {
|
||||||
|
|
||||||
|
private var dataset: Option[Dataset[_]] = None
|
||||||
|
|
||||||
|
def setEvalDataset(ds: Dataset[_]): T = {
|
||||||
|
this.dataset = Some(ds)
|
||||||
|
this.asInstanceOf[T]
|
||||||
|
}
|
||||||
|
|
||||||
|
def getEvalDataset(): Option[Dataset[_]] = {
|
||||||
|
this.dataset
|
||||||
|
}
|
||||||
|
}

private[spark] trait PluginMixin {
  // Find the XGBoostPlugin by ServiceLoader
  private val plugin: Option[XGBoostPlugin] = {
    val classLoader = Option(Thread.currentThread().getContextClassLoader)
      .getOrElse(getClass.getClassLoader)

    val serviceLoader = ServiceLoader.load(classOf[XGBoostPlugin], classLoader)

    // For now, we only trust GpuXGBoostPlugin.
    serviceLoader.asScala.filter(x => x.getClass.getName.equals(
      "ml.dmlc.xgboost4j.scala.spark.GpuXGBoostPlugin")).toList match {
      case Nil => None
      case head :: Nil =>
        Some(head)
      case _ => None
    }
  }

  /** Visible for testing */
  protected[spark] def getPlugin: Option[XGBoostPlugin] = plugin

  protected def isPluginEnabled(dataset: Dataset[_]): Boolean = {
    plugin.map(_.isEnabled(dataset)).getOrElse(false)
  }
}

|
private[spark] trait XGBoostEstimator[
|
||||||
|
Learner <: XGBoostEstimator[Learner, M], M <: XGBoostModel[M]] extends Estimator[M]
|
||||||
|
with XGBoostParams[Learner] with SparkParams[Learner] with ParamUtils[Learner]
|
||||||
|
with NonParamVariables[Learner, M] with ParamMapConversion with DefaultParamsWritable
|
||||||
|
with PluginMixin {
|
||||||
|
|
||||||
|
protected val logger = LogFactory.getLog("XGBoostSpark")
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Cast the field in schema to the desired data type.
|
||||||
|
*
|
||||||
|
* @param dataset the input dataset
|
||||||
|
* @param name which column will be casted to float if possible.
|
||||||
|
* @param targetType the targetd data type
|
||||||
|
* @return Dataset
|
||||||
|
*/
|
||||||
|
private[spark] def castIfNeeded(schema: StructType,
|
||||||
|
name: String,
|
||||||
|
targetType: DataType = FloatType): Column = {
|
||||||
|
if (!(schema(name).dataType == targetType)) {
|
||||||
|
val meta = schema(name).metadata
|
||||||
|
col(name).as(name, meta).cast(targetType)
|
||||||
|
} else {
|
||||||
|
col(name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Repartition the dataset to the numWorkers if needed.
|
||||||
|
*
|
||||||
|
* @param dataset to be repartition
|
||||||
|
* @return the repartitioned dataset
|
||||||
|
*/
|
||||||
|
private[spark] def repartitionIfNeeded(dataset: Dataset[_]): Dataset[_] = {
|
||||||
|
val numPartitions = dataset.rdd.getNumPartitions
|
||||||
|
if (getForceRepartition || getNumWorkers != numPartitions) {
|
||||||
|
dataset.repartition(getNumWorkers)
|
||||||
|
} else {
|
||||||
|
dataset
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Build the columns indices.
|
||||||
|
*/
|
||||||
|
private[spark] def buildColumnIndices(schema: StructType): ColumnIndices = {
|
||||||
|
// Get feature id(s)
|
||||||
|
val (featureIds: Option[Seq[Int]], featureId: Option[Int]) =
|
||||||
|
if (getFeaturesCols.length != 0) {
|
||||||
|
(Some(getFeaturesCols.map(schema.fieldIndex).toSeq), None)
|
||||||
|
} else {
|
||||||
|
(None, Some(schema.fieldIndex(getFeaturesCol)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// function to get the column id according to the parameter
|
||||||
|
def columnId(param: Param[String]): Option[Int] = {
|
||||||
|
if (isDefinedNonEmpty(param)) {
|
||||||
|
Some(schema.fieldIndex($(param)))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Special handle for group
|
||||||
|
val groupId: Option[Int] = this match {
|
||||||
|
case p: HasGroupCol => columnId(p.groupCol)
|
||||||
|
case _ => None
|
||||||
|
}
|
||||||
|
|
||||||
|
ColumnIndices(
|
||||||
|
labelId = columnId(labelCol).get,
|
||||||
|
featureId = featureId,
|
||||||
|
featureIds = featureIds,
|
||||||
|
columnId(weightCol),
|
||||||
|
columnId(baseMarginCol),
|
||||||
|
groupId)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Preprocess the dataset to meet the xgboost input requirement
|
||||||
|
*
|
||||||
|
* @param dataset
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
private[spark] def preprocess(dataset: Dataset[_]): (Dataset[_], ColumnIndices) = {
|
||||||
|
|
||||||
|
// Columns to be selected for XGBoost training
|
||||||
|
val selectedCols: ArrayBuffer[Column] = ArrayBuffer.empty
|
||||||
|
val schema = dataset.schema
|
||||||
|
|
||||||
|
def selectCol(c: Param[String], targetType: DataType) = {
|
||||||
|
if (isDefinedNonEmpty(c)) {
|
||||||
|
// Validation col should be a boolean column.
|
||||||
|
if (c == featuresCol) {
|
||||||
|
selectedCols.append(col($(c)))
|
||||||
|
} else {
|
||||||
|
selectedCols.append(castIfNeeded(schema, $(c), targetType))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Seq(labelCol, featuresCol, weightCol, baseMarginCol).foreach(p => selectCol(p, FloatType))
|
||||||
|
this match {
|
||||||
|
case p: HasGroupCol => selectCol(p.groupCol, IntegerType)
|
||||||
|
case _ =>
|
||||||
|
}
|
||||||
|
val input = repartitionIfNeeded(dataset.select(selectedCols.toArray: _*))
|
||||||
|
|
||||||
|
val columnIndices = buildColumnIndices(input.schema)
|
||||||
|
(input, columnIndices)
|
||||||
|
}
|
||||||
|
|
||||||
|
/** visible for testing */
|
||||||
|
private[spark] def toXGBLabeledPoint(dataset: Dataset[_],
|
||||||
|
columnIndexes: ColumnIndices): RDD[XGBLabeledPoint] = {
|
||||||
|
val isSetMissing = isSet(missing)
|
||||||
|
dataset.toDF().rdd.map { row =>
|
||||||
|
val features = row.getAs[Vector](columnIndexes.featureId.get)
|
||||||
|
val label = row.getFloat(columnIndexes.labelId)
|
||||||
|
val weight = columnIndexes.weightId.map(row.getFloat).getOrElse(1.0f)
|
||||||
|
val baseMargin = columnIndexes.marginId.map(row.getFloat).getOrElse(Float.NaN)
|
||||||
|
val group = columnIndexes.groupId.map(row.getInt).getOrElse(-1)
|
||||||
|
// To make "0" meaningful, we convert sparse vector if possible to dense to create DMatrix.
|
||||||
|
features match {
|
||||||
|
case _: SparseVector => if (!isSetMissing) {
|
||||||
|
throw new IllegalArgumentException("We've detected sparse vectors in the dataset that " +
|
||||||
|
"need conversion to dense format. However, we can't assume 0 for missing values as " +
|
||||||
|
"it may be meaningful. Please specify the missing value explicitly to ensure " +
|
||||||
|
"accurate data representation for analysis.")
|
||||||
|
}
|
||||||
|
case _ =>
|
||||||
|
}
|
||||||
|
val values = features.toArray.map(_.toFloat)
|
||||||
|
XGBLabeledPoint(label, values.length, null, values, weight, group, baseMargin)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
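As the check above enforces, datasets whose feature column contains ``SparseVector`` rows can only be densified once the user has stated what "missing" means. A small sketch (the ``setMissing`` setter name is assumed from the regenerated parameter setters, and the column names are illustrative):

.. code-block:: scala

    val estimator = new XGBoostClassifier()
      .setFeaturesCol("features")
      .setLabelCol("label")
      // Declare explicitly that 0.0 encodes a missing value before sparse rows are densified.
      .setMissing(0.0f)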
|
||||||
|
/**
|
||||||
|
* Convert the dataframe to RDD, visible to testing
|
||||||
|
*
|
||||||
|
* @param dataset
|
||||||
|
* @param columnsOrder the order of columns including weight/group/base margin ...
|
||||||
|
* @return RDD[Watches]
|
||||||
|
*/
|
||||||
|
private[spark] def toRdd(dataset: Dataset[_],
|
||||||
|
columnIndices: ColumnIndices): RDD[Watches] = {
|
||||||
|
val trainRDD = toXGBLabeledPoint(dataset, columnIndices)
|
||||||
|
|
||||||
|
val featureNames = if (getFeatureNames.isEmpty) None else Some(getFeatureNames)
|
||||||
|
val featureTypes = if (getFeatureTypes.isEmpty) None else Some(getFeatureTypes)
|
||||||
|
|
||||||
|
val missing = getMissing
|
||||||
|
|
||||||
|
// Transform the labeledpoint to get margins/groups and build DMatrix
|
||||||
|
// TODO support basemargin for multiclassification
|
||||||
|
// TODO and optimization, move it into JNI.
|
||||||
|
def buildDMatrix(iter: Iterator[XGBLabeledPoint]) = {
|
||||||
|
val dmatrix = if (columnIndices.marginId.isDefined || columnIndices.groupId.isDefined) {
|
||||||
|
val margins = new mutable.ArrayBuilder.ofFloat
|
||||||
|
val groups = new mutable.ArrayBuilder.ofInt
|
||||||
|
val groupWeights = new mutable.ArrayBuilder.ofFloat
|
||||||
|
var prevGroup = -101010
|
||||||
|
var prevWeight = -1.0f
|
||||||
|
var groupSize = 0
|
||||||
|
val transformedIter = iter.map { labeledPoint =>
|
||||||
|
if (columnIndices.marginId.isDefined) {
|
||||||
|
margins += labeledPoint.baseMargin
|
||||||
|
}
|
||||||
|
if (columnIndices.groupId.isDefined) {
|
||||||
|
if (prevGroup != labeledPoint.group) {
|
||||||
|
// starting with new group
|
||||||
|
if (prevGroup != -101010) {
|
||||||
|
// write the previous group
|
||||||
|
groups += groupSize
|
||||||
|
groupWeights += prevWeight
|
||||||
|
}
|
||||||
|
groupSize = 1
|
||||||
|
prevWeight = labeledPoint.weight
|
||||||
|
prevGroup = labeledPoint.group
|
||||||
|
} else {
|
||||||
|
// for the same group
|
||||||
|
if (prevWeight != labeledPoint.weight) {
|
||||||
|
throw new IllegalArgumentException("the instances in the same group have to be" +
|
||||||
|
s" assigned with the same weight (unexpected weight ${labeledPoint.weight}")
|
||||||
|
}
|
||||||
|
groupSize = groupSize + 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
labeledPoint
|
||||||
|
}
|
||||||
|
val dm = new DMatrix(transformedIter, null, missing)
|
||||||
|
columnIndices.marginId.foreach(_ => dm.setBaseMargin(margins.result()))
|
||||||
|
if (columnIndices.groupId.isDefined) {
|
||||||
|
if (prevGroup != -101011) {
|
||||||
|
// write the last group
|
||||||
|
groups += groupSize
|
||||||
|
groupWeights += prevWeight
|
||||||
|
}
|
||||||
|
dm.setGroup(groups.result())
|
||||||
|
// The new DMatrix() will set the weights for each instance. But ranking requires
|
||||||
|
// 1 weight for each group, so need to reset the weight.
|
||||||
|
// This is definitely optimized by moving setting group/base margin into JNI.
|
||||||
|
dm.setWeight(groupWeights.result())
|
||||||
|
}
|
||||||
|
dm
|
||||||
|
} else {
|
||||||
|
new DMatrix(iter, null, missing)
|
||||||
|
}
|
||||||
|
featureTypes.foreach(dmatrix.setFeatureTypes)
|
||||||
|
featureNames.foreach(dmatrix.setFeatureNames)
|
||||||
|
dmatrix
|
||||||
|
}
|
||||||
|
|
||||||
|
getEvalDataset().map { eval =>
|
||||||
|
val (evalDf, _) = preprocess(eval)
|
||||||
|
val evalRDD = toXGBLabeledPoint(evalDf, columnIndices)
|
||||||
|
trainRDD.zipPartitions(evalRDD) { (left, right) =>
|
||||||
|
new Iterator[Watches] {
|
||||||
|
override def hasNext: Boolean = left.hasNext
|
||||||
|
override def next(): Watches = {
|
||||||
|
val trainDMatrix = buildDMatrix(left)
|
||||||
|
val evalDMatrix = buildDMatrix(right)
|
||||||
|
new Watches(Array(trainDMatrix, evalDMatrix),
|
||||||
|
Array(Utils.TRAIN_NAME, Utils.VALIDATION_NAME), None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}.getOrElse(
|
||||||
|
trainRDD.mapPartitions { iter =>
|
||||||
|
new Iterator[Watches] {
|
||||||
|
override def hasNext: Boolean = iter.hasNext
|
||||||
|
override def next(): Watches = {
|
||||||
|
val dm = buildDMatrix(iter)
|
||||||
|
new Watches(Array(dm), Array(Utils.TRAIN_NAME), None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
  protected def createModel(booster: Booster, summary: XGBoostTrainingSummary): M

  private[spark] def getRuntimeParameters(isLocal: Boolean): RuntimeParams = {
    val runOnGpu = if (getDevice != "cpu" || getTreeMethod == "gpu_hist") true else false
    RuntimeParams(
      getNumWorkers,
      getNumRound,
      TrackerConf(getRabitTrackerTimeout, getRabitTrackerHostIp, getRabitTrackerPort),
      getNumEarlyStoppingRounds,
      getDevice,
      isLocal,
      runOnGpu,
      Option(getCustomObj),
      Option(getCustomEval)
    )
  }
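Note that ``runOnGpu`` is derived purely from the ``device`` and ``tree_method`` parameters, so moving a pipeline onto GPUs is a parameter change rather than a different estimator. A hedged sketch (setter and Map constructor names assumed):

.. code-block:: scala

    // Either of these should flip runOnGpu to true in getRuntimeParameters.
    val viaSetter = new XGBoostClassifier().setDevice("cuda")
    val viaParams = new XGBoostClassifier(Map("device" -> "cuda"))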

  /**
   * Check to see if Spark expects SSL encryption (`spark.ssl.enabled` set to true).
   * If so, throw an exception unless this safety measure has been explicitly overridden
   * via conf `xgboost.spark.ignoreSsl`.
   */
  private def validateSparkSslConf(spark: SparkSession): Unit = {
    val sparkSslEnabled = spark.conf.getOption("spark.ssl.enabled").getOrElse("false").toBoolean
    val xgbIgnoreSsl = spark.conf.getOption("xgboost.spark.ignoreSsl").getOrElse("false").toBoolean

    if (sparkSslEnabled) {
      if (xgbIgnoreSsl) {
        logger.warn(s"spark-xgboost is being run without encrypting data in transit! " +
          s"Spark Conf spark.ssl.enabled=true was overridden with xgboost.spark.ignoreSsl=true.")
      } else {
        throw new Exception("xgboost-spark found spark.ssl.enabled=true to encrypt data " +
          "in transit, but xgboost-spark sends non-encrypted data over the wire for efficiency. " +
          "To override this protection and still use xgboost-spark at your own risk, " +
          "you can set the SparkSession conf to use xgboost.spark.ignoreSsl=true.")
      }
    }
  }
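If encryption is enabled cluster-wide and you nevertheless accept plain-text transfer of the training data, the escape hatch is the ``xgboost.spark.ignoreSsl`` conf checked above. A sketch:

.. code-block:: scala

    // Only set xgboost.spark.ignoreSsl if you accept that training data travels unencrypted.
    val spark = SparkSession.builder()
      .appName("xgboost-with-ssl-override")
      .config("spark.ssl.enabled", "true")
      .config("xgboost.spark.ignoreSsl", "true")
      .getOrCreate()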

  /**
   * Validate the parameters before training, throw exception if possible
   */
  protected[spark] def validate(dataset: Dataset[_]): Unit = {
    validateSparkSslConf(dataset.sparkSession)
    val schema = dataset.schema
    SparkUtils.checkNumericType(schema, $(labelCol))
    if (isDefinedNonEmpty(weightCol)) {
      SparkUtils.checkNumericType(schema, $(weightCol))
    }

    if (isDefinedNonEmpty(baseMarginCol)) {
      SparkUtils.checkNumericType(schema, $(baseMarginCol))
    }

    val taskCpus = dataset.sparkSession.sparkContext.getConf.getInt("spark.task.cpus", 1)
    if (isDefined(nthread)) {
      require(getNthread <= taskCpus,
        s"the nthread configuration ($getNthread) must be no larger than " +
          s"spark.task.cpus ($taskCpus)")
    } else {
      setNthread(taskCpus)
    }
  }
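``nthread`` defaults to ``spark.task.cpus`` and must never exceed it, otherwise ``validate`` throws. An explicit, consistent setup could look like the sketch below (setter name assumed):

.. code-block:: scala

    val spark = SparkSession.builder()
      .config("spark.task.cpus", "4")
      .getOrCreate()

    // nthread must be <= spark.task.cpus, i.e. at most 4 here.
    val estimator = new XGBoostClassifier().setNthread(4)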

  protected def train(dataset: Dataset[_]): M = {
    validate(dataset)

    val rdd = if (isPluginEnabled(dataset)) {
      getPlugin.get.buildRddWatches(this, dataset)
    } else {
      val (input, columnIndexes) = preprocess(dataset)
      toRdd(input, columnIndexes)
    }

    val xgbParams = getXGBoostParams

    val runtimeParams = getRuntimeParameters(dataset.sparkSession.sparkContext.isLocal)

    val (booster, metrics) = XGBoost.train(rdd, runtimeParams, xgbParams)

    val summary = XGBoostTrainingSummary(metrics)
    copyValues(createModel(booster, summary))
  }

  override def copy(extra: ParamMap): Learner = defaultCopy(extra).asInstanceOf[Learner]
}
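Since ``train`` consumes the optional evaluation dataset held by ``NonParamVariables``, early stopping can be wired up without any ETL inside XGBoost itself. A sketch with assumed column names and setters:

.. code-block:: scala

    val Array(trainDf, evalDf) = df.randomSplit(Array(0.8, 0.2), seed = 42)

    val classifier = new XGBoostClassifier()
      .setFeaturesCol("features")
      .setLabelCol("label")
      .setNumRound(100)
      .setNumEarlyStoppingRounds(10)
      .setEvalDataset(evalDf)   // becomes the "validation" watch in toRdd

    val model = classifier.fit(trainDf)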
|
/**
|
||||||
|
* Indicate what to be predicted
|
||||||
|
*
|
||||||
|
* @param predLeaf predicate leaf
|
||||||
|
* @param predContrib predicate contribution
|
||||||
|
* @param predRaw predicate raw
|
||||||
|
* @param predTmp predicate probability for classification, and raw for regression
|
||||||
|
*/
|
||||||
|
private[spark] case class PredictedColumns(
|
||||||
|
predLeaf: Boolean,
|
||||||
|
predContrib: Boolean,
|
||||||
|
predRaw: Boolean,
|
||||||
|
predTmp: Boolean)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* XGBoost base model
|
||||||
|
*/
|
||||||
|
private[spark] trait XGBoostModel[M <: XGBoostModel[M]] extends Model[M] with MLWritable
|
||||||
|
with XGBoostParams[M] with SparkParams[M] with ParamUtils[M] with PluginMixin {
|
||||||
|
|
||||||
|
protected val TMP_TRANSFORMED_COL = "_tmp_xgb_transformed_col"
|
||||||
|
|
||||||
|
override def copy(extra: ParamMap): M = defaultCopy(extra).asInstanceOf[M]
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the native XGBoost Booster
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
def nativeBooster: Booster
|
||||||
|
|
||||||
|
def summary: Option[XGBoostTrainingSummary]
|
||||||
|
|
||||||
|
protected[spark] def postTransform(dataset: Dataset[_], pred: PredictedColumns): Dataset[_] = {
|
||||||
|
var output = dataset
|
||||||
|
// Convert leaf/contrib to the vector from array
|
||||||
|
if (pred.predLeaf) {
|
||||||
|
output = output.withColumn(getLeafPredictionCol,
|
||||||
|
array_to_vector(output.col(getLeafPredictionCol)))
|
||||||
|
}
|
||||||
|
|
||||||
|
if (pred.predContrib) {
|
||||||
|
output = output.withColumn(getContribPredictionCol,
|
||||||
|
array_to_vector(output.col(getContribPredictionCol)))
|
||||||
|
}
|
||||||
|
output
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Preprocess the schema before transforming.
|
||||||
|
*
|
||||||
|
* @return the transformed schema and the
|
||||||
|
*/
|
||||||
|
private[spark] def preprocess(dataset: Dataset[_]): (StructType, PredictedColumns) = {
|
||||||
|
// Be careful about the order of columns
|
||||||
|
var schema = dataset.schema
|
||||||
|
|
||||||
|
/** If the parameter is defined, add it to schema and turn true */
|
||||||
|
def addToSchema(param: Param[String], colName: Option[String] = None): Boolean = {
|
||||||
|
if (isDefinedNonEmpty(param)) {
|
||||||
|
val name = colName.getOrElse($(param))
|
||||||
|
schema = schema.add(StructField(name, ArrayType(FloatType)))
|
||||||
|
true
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val predLeaf = addToSchema(leafPredictionCol)
|
||||||
|
val predContrib = addToSchema(contribPredictionCol)
|
||||||
|
|
||||||
|
var predRaw = false
|
||||||
|
// For classification case, the transformed col is probability,
|
||||||
|
// while for others, it's the prediction value.
|
||||||
|
var predTmp = false
|
||||||
|
this match {
|
||||||
|
case p: XGBProbabilisticClassifierParams[_] => // classification case
|
||||||
|
predRaw = addToSchema(p.rawPredictionCol)
|
||||||
|
predTmp = addToSchema(p.probabilityCol, Some(TMP_TRANSFORMED_COL))
|
||||||
|
|
||||||
|
if (isDefinedNonEmpty(predictionCol)) {
|
||||||
|
// Let's use transformed col to calculate the prediction
|
||||||
|
if (!predTmp) {
|
||||||
|
// Add the transformed col for prediction
|
||||||
|
schema = schema.add(
|
||||||
|
StructField(TMP_TRANSFORMED_COL, ArrayType(FloatType)))
|
||||||
|
predTmp = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case _ =>
|
||||||
|
// Rename TMP_TRANSFORMED_COL to prediction in the postTransform.
|
||||||
|
predTmp = addToSchema(predictionCol, Some(TMP_TRANSFORMED_COL))
|
||||||
|
}
|
||||||
|
(schema, PredictedColumns(predLeaf, predContrib, predRaw, predTmp))
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Predict */
|
||||||
|
private[spark] def predictInternal(booster: Booster, dm: DMatrix, pred: PredictedColumns,
|
||||||
|
batchRow: Iterator[Row]): Seq[Row] = {
|
||||||
|
var tmpOut = batchRow.toSeq.map(_.toSeq)
|
||||||
|
val zip = (left: Seq[Seq[_]], right: Array[Array[Float]]) => left.zip(right).map {
|
||||||
|
case (a, b) => a ++ Seq(b)
|
||||||
|
}
|
||||||
|
if (pred.predLeaf) {
|
||||||
|
tmpOut = zip(tmpOut, booster.predictLeaf(dm))
|
||||||
|
}
|
||||||
|
if (pred.predContrib) {
|
||||||
|
tmpOut = zip(tmpOut, booster.predictContrib(dm))
|
||||||
|
}
|
||||||
|
if (pred.predRaw) {
|
||||||
|
tmpOut = zip(tmpOut, booster.predict(dm, outPutMargin = true))
|
||||||
|
}
|
||||||
|
if (pred.predTmp) {
|
||||||
|
tmpOut = zip(tmpOut, booster.predict(dm, outPutMargin = false))
|
||||||
|
}
|
||||||
|
tmpOut.map(Row.fromSeq)
|
||||||
|
}
|
||||||
|
|
||||||
|
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
|
|
||||||
|
if (getPlugin.isDefined) {
|
||||||
|
return getPlugin.get.transform(this, dataset)
|
||||||
|
}
|
||||||
|
|
||||||
|
val (schema, pred) = preprocess(dataset)
|
||||||
|
val bBooster = dataset.sparkSession.sparkContext.broadcast(nativeBooster)
|
||||||
|
// TODO configurable
|
||||||
|
val inferBatchSize = 32 << 10
|
||||||
|
// Broadcast the booster to each executor.
|
||||||
|
val featureName = getFeaturesCol
|
||||||
|
val missing = getMissing
|
||||||
|
|
||||||
|
val output = dataset.toDF().mapPartitions { rowIter =>
|
||||||
|
rowIter.grouped(inferBatchSize).flatMap { batchRow =>
|
||||||
|
val features = batchRow.iterator.map(row => row.getAs[Vector](
|
||||||
|
row.fieldIndex(featureName)))
|
||||||
|
// DMatrix used to prediction
|
||||||
|
val dm = new DMatrix(features.map(_.asXGB), null, missing)
|
||||||
|
try {
|
||||||
|
predictInternal(bBooster.value, dm, pred, batchRow.toIterator)
|
||||||
|
} finally {
|
||||||
|
dm.delete()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}(Encoders.row(schema))
|
||||||
|
bBooster.unpersist(blocking = false)
|
||||||
|
postTransform(output, pred).toDF()
|
||||||
|
}
|
||||||
|
|
||||||
|
override def write: MLWriter = new XGBoostModelWriter(this)
|
||||||
|
|
||||||
|
protected def predictSingleInstance(features: Vector): Array[Float] = {
|
||||||
|
if (nativeBooster == null) {
|
||||||
|
throw new IllegalArgumentException("The model has not been trained")
|
||||||
|
}
|
||||||
|
val dm = new DMatrix(Iterator(features.asXGB), null, getMissing)
|
||||||
|
nativeBooster.predict(data = dm)(0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|

/**
 * Class to write the model
 *
 * @param instance model to be written
 */
private[spark] class XGBoostModelWriter(instance: XGBoostModel[_]) extends MLWriter {

  override protected def saveImpl(path: String): Unit = {
    if (Option(instance.nativeBooster).isEmpty) {
      throw new RuntimeException("The XGBoost model has not been trained")
    }
    SparkUtils.saveMetadata(instance, path, sc)

    // Save model data
    val dataPath = new Path(path, "data").toString
    val internalPath = new Path(dataPath, "model")
    val outputStream = internalPath.getFileSystem(sc.hadoopConfiguration).create(internalPath)
    val format = optionMap.getOrElse("format", JBooster.DEFAULT_FORMAT)
    try {
      instance.nativeBooster.saveModel(outputStream, format)
    } finally {
      outputStream.close()
    }
  }
}
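Because ``saveImpl`` consults ``optionMap``, the booster serialization format can be chosen per save call, with ``JBooster.DEFAULT_FORMAT`` as the fallback. A sketch (the accepted format strings come from the native booster, e.g. "json" or "ubj"):

.. code-block:: scala

    model.write.option("format", "json").overwrite().save("/tmp/xgboost-model-json")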

private[spark] abstract class XGBoostModelReader[M <: XGBoostModel[M]] extends MLReader[M] {

  protected def loadBooster(path: String): Booster = {
    val dataPath = new Path(path, "data").toString
    val internalPath = new Path(dataPath, "model")
    val dataInStream = internalPath.getFileSystem(sc.hadoopConfiguration).open(internalPath)
    try {
      SXGBoost.loadModel(dataInStream)
    } finally {
      dataInStream.close()
    }
  }
}

// Trait for Ranker and Regressor Model
private[spark] trait RankerRegressorBaseModel[M <: XGBoostModel[M]] extends XGBoostModel[M] {

  override protected[spark] def postTransform(dataset: Dataset[_],
      pred: PredictedColumns): Dataset[_] = {
    var output = super.postTransform(dataset, pred)
    if (isDefinedNonEmpty(predictionCol) && pred.predTmp) {
      val predictUDF = udf { (originalPrediction: mutable.WrappedArray[Float]) =>
        originalPrediction(0).toDouble
      }
      output = output
        .withColumn($(predictionCol), predictUDF(col(TMP_TRANSFORMED_COL)))
        .drop(TMP_TRANSFORMED_COL)
    }
    output
  }

}
@@ -0,0 +1,49 @@
/*
 Copyright (c) 2024 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */
package ml.dmlc.xgboost4j.scala.spark

import java.io.Serializable

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}

trait XGBoostPlugin extends Serializable {
  /**
   * Whether the plugin is enabled or not, if not enabled, fallback
   * to the regular CPU pipeline
   *
   * @param dataset the input dataset
   * @return Boolean
   */
  def isEnabled(dataset: Dataset[_]): Boolean

  /**
   * Convert Dataset to RDD[Watches] which will be fed into XGBoost
   *
   * @param estimator which estimator to be handled.
   * @param dataset to be converted.
   * @return RDD[Watches]
   */
  def buildRddWatches[T <: XGBoostEstimator[T, M], M <: XGBoostModel[M]](
      estimator: XGBoostEstimator[T, M],
      dataset: Dataset[_]): RDD[Watches]

  /**
   * Transform the dataset
   */
  def transform[M <: XGBoostModel[M]](model: XGBoostModel[M], dataset: Dataset[_]): DataFrame

}
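A plugin only has to implement the three members above and be discoverable through ``ServiceLoader`` (a provider file named after the trait's fully qualified name under ``META-INF/services``). Note that ``PluginMixin`` currently accepts only ``ml.dmlc.xgboost4j.scala.spark.GpuXGBoostPlugin``, so the skeleton below is purely illustrative and every name in it is hypothetical:

.. code-block:: scala

    package ml.dmlc.xgboost4j.scala.spark

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.{DataFrame, Dataset}

    // Hypothetical skeleton; a real plugin builds device-specific Watches here.
    class ExamplePlugin extends XGBoostPlugin {
      override def isEnabled(dataset: Dataset[_]): Boolean =
        dataset.sparkSession.conf.getOption("example.xgboost.plugin").contains("true")

      override def buildRddWatches[T <: XGBoostEstimator[T, M], M <: XGBoostModel[M]](
          estimator: XGBoostEstimator[T, M],
          dataset: Dataset[_]): RDD[Watches] =
        throw new UnsupportedOperationException("illustrative sketch only")

      override def transform[M <: XGBoostModel[M]](
          model: XGBoostModel[M], dataset: Dataset[_]): DataFrame =
        throw new UnsupportedOperationException("illustrative sketch only")
    }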
@@ -1,5 +1,5 @@
/*
|
/*
|
||||||
Copyright (c) 2014-2022 by Contributors
|
Copyright (c) 2014-2024 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@@ -16,405 +16,90 @@
|
|
||||||
package ml.dmlc.xgboost4j.scala.spark
|
package ml.dmlc.xgboost4j.scala.spark
|
||||||
|
|
||||||
import scala.collection.{Iterator, mutable}
|
import org.apache.spark.ml.{PredictionModel, Predictor}
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.scala.spark.params._
|
|
||||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
|
|
||||||
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
|
|
||||||
import org.apache.hadoop.fs.Path
|
|
||||||
|
|
||||||
import org.apache.spark.ml.linalg.Vector
|
import org.apache.spark.ml.linalg.Vector
|
||||||
import org.apache.spark.ml.util._
|
import org.apache.spark.ml.param.ParamMap
|
||||||
import org.apache.spark.ml._
|
import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable, MLReadable, MLReader}
|
||||||
import org.apache.spark.ml.param._
|
import org.apache.spark.ml.xgboost.SparkUtils
|
||||||
import org.apache.spark.sql._
|
import org.apache.spark.sql.Dataset
|
||||||
import org.apache.spark.sql.functions._
|
|
||||||
|
|
||||||
import org.apache.spark.ml.util.{DefaultXGBoostParamsReader, DefaultXGBoostParamsWriter, XGBoostWriter}
|
import ml.dmlc.xgboost4j.scala.Booster
|
||||||
import org.apache.spark.sql.types.StructType
|
import ml.dmlc.xgboost4j.scala.spark.XGBoostRegressor._uid
|
||||||
|
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams.REGRESSION_OBJS
|
||||||
|
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
|
||||||
|
|
||||||
-class XGBoostRegressor (
-    override val uid: String,
-    private val xgboostParams: Map[String, Any])
-  extends Predictor[Vector, XGBoostRegressor, XGBoostRegressionModel]
-    with XGBoostRegressorParams with DefaultParamsWritable {
-
-  def this() = this(Identifiable.randomUID("xgbr"), Map[String, Any]())
-
-  def this(uid: String) = this(uid, Map[String, Any]())
-
-  def this(xgboostParams: Map[String, Any]) = this(
-    Identifiable.randomUID("xgbr"), xgboostParams)
-
-  XGBoost2MLlibParams(xgboostParams)
+class XGBoostRegressor(override val uid: String,
+                       private val xgboostParams: Map[String, Any])
+  extends Predictor[Vector, XGBoostRegressor, XGBoostRegressionModel]
+  with XGBoostEstimator[XGBoostRegressor, XGBoostRegressionModel] {
+
+  def this() = this(_uid, Map[String, Any]())
+
+  def this(uid: String) = this(uid, Map[String, Any]())
+
+  def this(xgboostParams: Map[String, Any]) = this(_uid, xgboostParams)
+
+  xgboost2SparkParams(xgboostParams)
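Under the new estimator hierarchy the regressor is configured either through the XGBoost parameter map passed to the constructor or through the generated setters; a small sketch with assumed column names and an assumed ``trainDf``:

.. code-block:: scala

    val regressor = new XGBoostRegressor(Map(
        "objective" -> "reg:squarederror",
        "max_depth" -> 6,
        "eta" -> 0.1))
      .setFeaturesCol("features")
      .setLabelCol("label")
      .setNumRound(100)

    val model = regressor.fit(trainDf)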
|
|
||||||
def setWeightCol(value: String): this.type = set(weightCol, value)
|
/**
|
||||||
|
* Validate the parameters before training, throw exception if possible
|
||||||
|
*/
|
||||||
|
override protected[spark] def validate(dataset: Dataset[_]): Unit = {
|
||||||
|
super.validate(dataset)
|
||||||
|
|
||||||
def setBaseMarginCol(value: String): this.type = set(baseMarginCol, value)
|
// If the objective is set explicitly, it must be in REGRESSION_OBJS
|
||||||
|
if (isSet(objective)) {
|
||||||
def setGroupCol(value: String): this.type = set(groupCol, value)
|
val tmpObj = getObjective
|
||||||
|
require(REGRESSION_OBJS.contains(tmpObj),
|
||||||
// setters for general params
|
s"Wrong objective for XGBoostRegressor, supported objs: ${REGRESSION_OBJS.mkString(",")}")
|
||||||
def setNumRound(value: Int): this.type = set(numRound, value)
|
|
||||||
|
|
||||||
def setNumWorkers(value: Int): this.type = set(numWorkers, value)
|
|
||||||
|
|
||||||
def setNthread(value: Int): this.type = set(nthread, value)
|
|
||||||
|
|
||||||
def setUseExternalMemory(value: Boolean): this.type = set(useExternalMemory, value)
|
|
||||||
|
|
||||||
def setSilent(value: Int): this.type = set(silent, value)
|
|
||||||
|
|
||||||
def setMissing(value: Float): this.type = set(missing, value)
|
|
||||||
|
|
||||||
def setCheckpointPath(value: String): this.type = set(checkpointPath, value)
|
|
||||||
|
|
||||||
def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
|
|
||||||
|
|
||||||
def setSeed(value: Long): this.type = set(seed, value)
|
|
||||||
|
|
||||||
def setEta(value: Double): this.type = set(eta, value)
|
|
||||||
|
|
||||||
def setGamma(value: Double): this.type = set(gamma, value)
|
|
||||||
|
|
||||||
def setMaxDepth(value: Int): this.type = set(maxDepth, value)
|
|
||||||
|
|
||||||
def setMinChildWeight(value: Double): this.type = set(minChildWeight, value)
|
|
||||||
|
|
||||||
def setMaxDeltaStep(value: Double): this.type = set(maxDeltaStep, value)
|
|
||||||
|
|
||||||
def setSubsample(value: Double): this.type = set(subsample, value)
|
|
||||||
|
|
||||||
def setColsampleBytree(value: Double): this.type = set(colsampleBytree, value)
|
|
||||||
|
|
||||||
def setColsampleBylevel(value: Double): this.type = set(colsampleBylevel, value)
|
|
||||||
|
|
||||||
def setLambda(value: Double): this.type = set(lambda, value)
|
|
||||||
|
|
||||||
def setAlpha(value: Double): this.type = set(alpha, value)
|
|
||||||
|
|
||||||
def setTreeMethod(value: String): this.type = set(treeMethod, value)
|
|
||||||
|
|
||||||
def setDevice(value: String): this.type = set(device, value)
|
|
||||||
|
|
||||||
def setGrowPolicy(value: String): this.type = set(growPolicy, value)
|
|
||||||
|
|
||||||
def setMaxBins(value: Int): this.type = set(maxBins, value)
|
|
||||||
|
|
||||||
def setMaxLeaves(value: Int): this.type = set(maxLeaves, value)
|
|
||||||
|
|
||||||
def setScalePosWeight(value: Double): this.type = set(scalePosWeight, value)
|
|
||||||
|
|
||||||
def setSampleType(value: String): this.type = set(sampleType, value)
|
|
||||||
|
|
||||||
def setNormalizeType(value: String): this.type = set(normalizeType, value)
|
|
||||||
|
|
||||||
def setRateDrop(value: Double): this.type = set(rateDrop, value)
|
|
||||||
|
|
||||||
def setSkipDrop(value: Double): this.type = set(skipDrop, value)
|
|
||||||
|
|
||||||
def setLambdaBias(value: Double): this.type = set(lambdaBias, value)
|
|
||||||
|
|
||||||
// setters for learning params
|
|
||||||
def setObjective(value: String): this.type = set(objective, value)
|
|
||||||
|
|
||||||
def setObjectiveType(value: String): this.type = set(objectiveType, value)
|
|
||||||
|
|
||||||
def setBaseScore(value: Double): this.type = set(baseScore, value)
|
|
||||||
|
|
||||||
def setEvalMetric(value: String): this.type = set(evalMetric, value)
|
|
||||||
|
|
||||||
def setTrainTestRatio(value: Double): this.type = set(trainTestRatio, value)
|
|
||||||
|
|
||||||
def setNumEarlyStoppingRounds(value: Int): this.type = set(numEarlyStoppingRounds, value)
|
|
||||||
|
|
||||||
def setMaximizeEvaluationMetrics(value: Boolean): this.type =
|
|
||||||
set(maximizeEvaluationMetrics, value)
|
|
||||||
|
|
||||||
def setCustomObj(value: ObjectiveTrait): this.type = set(customObj, value)
|
|
||||||
|
|
||||||
def setCustomEval(value: EvalTrait): this.type = set(customEval, value)
|
|
||||||
|
|
||||||
def setAllowNonZeroForMissing(value: Boolean): this.type = set(
|
|
||||||
allowNonZeroForMissing,
|
|
||||||
value
|
|
||||||
)
|
|
||||||
|
|
||||||
def setSinglePrecisionHistogram(value: Boolean): this.type =
|
|
||||||
set(singlePrecisionHistogram, value)
|
|
||||||
|
|
||||||
def setFeatureNames(value: Array[String]): this.type =
|
|
||||||
set(featureNames, value)
|
|
||||||
|
|
||||||
def setFeatureTypes(value: Array[String]): this.type =
|
|
||||||
set(featureTypes, value)
|
|
||||||
|
|
||||||
// called at the start of fit/train when 'eval_metric' is not defined
|
|
||||||
private def setupDefaultEvalMetric(): String = {
|
|
||||||
require(isDefined(objective), "Users must set \'objective\' via xgboostParams.")
|
|
||||||
if ($(objective).startsWith("rank")) {
|
|
||||||
"map"
|
|
||||||
} else {
|
|
||||||
"rmse"
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private[spark] def transformSchemaInternal(schema: StructType): StructType = {
|
override protected def createModel(
|
||||||
if (isFeaturesColSet(schema)) {
|
(XGBoostRegressor.scala, split diff continued: the 2.x implementation on the left is removed, the reworked 3.0 implementation on the right is added.)

Removed (2.x side, continuing mid-method):

      // User has vectorized the features into VectorUDT.
      super.transformSchema(schema)
    } else {
      transformSchemaWithFeaturesCols(false, schema)
    }
  }

  override def transformSchema(schema: StructType): StructType = {
    PreXGBoost.transformSchema(this, schema)
  }

  override protected def train(dataset: Dataset[_]): XGBoostRegressionModel = {

    if (!isDefined(objective)) {
      // If user doesn't set objective, force it to reg:squarederror
      setObjective("reg:squarederror")
    }

    if (!isDefined(evalMetric) || $(evalMetric).isEmpty) {
      set(evalMetric, setupDefaultEvalMetric())
    }

    if (isDefined(customObj) && $(customObj) != null) {
      set(objectiveType, "regression")
    }

    transformSchema(dataset.schema, logging = true)

    // Packing with all params plus params user defined
    val derivedXGBParamMap = xgboostParams ++ MLlib2XGBoostParams
    val buildTrainingData = PreXGBoost.buildDatasetToRDD(this, dataset, derivedXGBParamMap)

    // All non-null param maps in XGBoostRegressor are in derivedXGBParamMap.
    val (_booster, _metrics) = XGBoost.trainDistributed(dataset.sparkSession.sparkContext,
      buildTrainingData, derivedXGBParamMap)

    val model = new XGBoostRegressionModel(uid, _booster)
    val summary = XGBoostTrainingSummary(_metrics)
    model.setSummary(summary)
    model
  }

  override def copy(extra: ParamMap): XGBoostRegressor = defaultCopy(extra)
}

object XGBoostRegressor extends DefaultParamsReadable[XGBoostRegressor] {

  override def load(path: String): XGBoostRegressor = super.load(path)
}

class XGBoostRegressionModel private[ml] (
    override val uid: String,
    private[scala] val _booster: Booster)
  extends PredictionModel[Vector, XGBoostRegressionModel]
    with XGBoostRegressorParams with InferenceParams
    with MLWritable with Serializable {

  import XGBoostRegressionModel._

  // only called in copy()
  def this(uid: String) = this(uid, null)

  /** Get the native booster instance of this model.
   *  This is used to call low-level APIs on native booster, such as "getFeatureScore". */
  def nativeBooster: Booster = _booster

  private var trainingSummary: Option[XGBoostTrainingSummary] = None

  /** Returns summary (e.g. train/test objective history) of model on the
   *  training set. An exception is thrown if no summary is available. */
  def summary: XGBoostTrainingSummary = trainingSummary.getOrElse {
    throw new IllegalStateException("No training summary available for this XGBoostModel")
  }

  private[spark] def setSummary(summary: XGBoostTrainingSummary): this.type = {
    trainingSummary = Some(summary)
    this
  }

  def setLeafPredictionCol(value: String): this.type = set(leafPredictionCol, value)

  def setContribPredictionCol(value: String): this.type = set(contribPredictionCol, value)

  def setTreeLimit(value: Int): this.type = set(treeLimit, value)

  def setMissing(value: Float): this.type = set(missing, value)

  def setAllowNonZeroForMissing(value: Boolean): this.type = set(allowNonZeroForMissing, value)

  def setInferBatchSize(value: Int): this.type = set(inferBatchSize, value)

  /** Single instance prediction.
   *  Note: The performance is not ideal, use it carefully! */
  override def predict(features: Vector): Double = {
    import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
    val dm = new DMatrix(processMissingValues(
      Iterator(features.asXGB),
      $(missing),
      $(allowNonZeroForMissing)
    ))
    _booster.predict(data = dm)(0)(0)
  }

  private[scala] def produceResultIterator(
      originalRowItr: Iterator[Row],
      predictionItr: Iterator[Row],
      predLeafItr: Iterator[Row],
      predContribItr: Iterator[Row]): Iterator[Row] = {
    // the following implementation is to be improved
    if (isDefined(leafPredictionCol) && $(leafPredictionCol).nonEmpty &&
      isDefined(contribPredictionCol) && $(contribPredictionCol).nonEmpty) {
      originalRowItr.zip(predictionItr).zip(predLeafItr).zip(predContribItr).
        map { case (((originals: Row, prediction: Row), leaves: Row), contribs: Row) =>
          Row.fromSeq(originals.toSeq ++ prediction.toSeq ++ leaves.toSeq ++ contribs.toSeq)
        }
    } else if (isDefined(leafPredictionCol) && $(leafPredictionCol).nonEmpty &&
      (!isDefined(contribPredictionCol) || $(contribPredictionCol).isEmpty)) {
      originalRowItr.zip(predictionItr).zip(predLeafItr).
        map { case ((originals: Row, prediction: Row), leaves: Row) =>
          Row.fromSeq(originals.toSeq ++ prediction.toSeq ++ leaves.toSeq)
        }
    } else if ((!isDefined(leafPredictionCol) || $(leafPredictionCol).isEmpty) &&
      isDefined(contribPredictionCol) && $(contribPredictionCol).nonEmpty) {
      originalRowItr.zip(predictionItr).zip(predContribItr).
        map { case ((originals: Row, prediction: Row), contribs: Row) =>
          Row.fromSeq(originals.toSeq ++ prediction.toSeq ++ contribs.toSeq)
        }
    } else {
      originalRowItr.zip(predictionItr).map {
        case (originals: Row, originalPrediction: Row) =>
          Row.fromSeq(originals.toSeq ++ originalPrediction.toSeq)
      }
    }
  }

  private[scala] def producePredictionItrs(booster: Booster, dm: DMatrix): Array[Iterator[Row]] = {
    val originalPredictionItr = {
      booster.predict(dm, outPutMargin = false, $(treeLimit)).map(Row(_)).iterator
    }
    val predLeafItr = {
      if (isDefined(leafPredictionCol)) {
        booster.predictLeaf(dm, $(treeLimit)).map(Row(_)).iterator
      } else {
        Iterator()
      }
    }
    val predContribItr = {
      if (isDefined(contribPredictionCol)) {
        booster.predictContrib(dm, $(treeLimit)).map(Row(_)).iterator
      } else {
        Iterator()
      }
    }
    Array(originalPredictionItr, predLeafItr, predContribItr)
  }

  private[spark] def transformSchemaInternal(schema: StructType): StructType = {
    if (isFeaturesColSet(schema)) {
      // User has vectorized the features into VectorUDT.
      super.transformSchema(schema)
    } else {
      transformSchemaWithFeaturesCols(false, schema)
    }
  }

  override def transformSchema(schema: StructType): StructType = {
    PreXGBoost.transformSchema(this, schema)
  }

  override def transform(dataset: Dataset[_]): DataFrame = {
    transformSchema(dataset.schema, logging = true)
    // Output selected columns only.
    // This is a bit complicated since it tries to avoid repeated computation.
    var outputData = PreXGBoost.transformDataset(this, dataset)
    var numColsOutput = 0

    val predictUDF = udf { (originalPrediction: mutable.WrappedArray[Float]) =>
      originalPrediction(0).toDouble
    }

    if ($(predictionCol).nonEmpty) {
      outputData = outputData
        .withColumn($(predictionCol), predictUDF(col(_originalPredictionCol)))
      numColsOutput += 1
    }

    if (numColsOutput == 0) {
      this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" +
        " since no output columns were set.")
    }
    outputData.toDF.drop(col(_originalPredictionCol))
  }

  override def copy(extra: ParamMap): XGBoostRegressionModel = {
    val newModel = copyValues(new XGBoostRegressionModel(uid, _booster), extra)
    newModel.setSummary(summary).setParent(parent)
  }

  override def write: MLWriter =
    new XGBoostRegressionModel.XGBoostRegressionModelWriter(this)
}

object XGBoostRegressionModel extends MLReadable[XGBoostRegressionModel] {

  private[scala] val _originalPredictionCol = "_originalPrediction"

  override def read: MLReader[XGBoostRegressionModel] = new XGBoostRegressionModelReader

  override def load(path: String): XGBoostRegressionModel = super.load(path)

  private[XGBoostRegressionModel]
  class XGBoostRegressionModelWriter(instance: XGBoostRegressionModel) extends XGBoostWriter {

    override protected def saveImpl(path: String): Unit = {
      // Save metadata and Params
      DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc)
      // Save model data
      val dataPath = new Path(path, "data").toString
      val internalPath = new Path(dataPath, "XGBoostRegressionModel")
      val outputStream = internalPath.getFileSystem(sc.hadoopConfiguration).create(internalPath)
      instance._booster.saveModel(outputStream, getModelFormat())
      outputStream.close()
    }
  }

  private class XGBoostRegressionModelReader extends MLReader[XGBoostRegressionModel] {

    /** Checked against metadata when loading model */
    private val className = classOf[XGBoostRegressionModel].getName

    override def load(path: String): XGBoostRegressionModel = {
      implicit val sc = super.sparkSession.sparkContext

      val metadata = DefaultXGBoostParamsReader.loadMetadata(path, sc, className)

      val dataPath = new Path(path, "data").toString
      val internalPath = new Path(dataPath, "XGBoostRegressionModel")
      val dataInStream = internalPath.getFileSystem(sc.hadoopConfiguration).open(internalPath)

      val booster = SXGBoost.loadModel(dataInStream)
      val model = new XGBoostRegressionModel(metadata.uid, booster)
      DefaultXGBoostParamsReader.getAndSetParams(model, metadata)
      model
    }
  }
}

Added (3.0 side, continuing mid-method):

      booster: Booster,
      summary: XGBoostTrainingSummary): XGBoostRegressionModel = {
    new XGBoostRegressionModel(uid, booster, Option(summary))
  }

  override protected def validateAndTransformSchema(
      schema: StructType,
      fitting: Boolean,
      featuresDataType: DataType): StructType =
    SparkUtils.appendColumn(schema, $(predictionCol), DoubleType)
}

object XGBoostRegressor extends DefaultParamsReadable[XGBoostRegressor] {
  private val _uid = Identifiable.randomUID("xgbr")
}

class XGBoostRegressionModel private[ml](val uid: String,
    val nativeBooster: Booster,
    val summary: Option[XGBoostTrainingSummary] = None)
  extends PredictionModel[Vector, XGBoostRegressionModel]
    with RankerRegressorBaseModel[XGBoostRegressionModel] {

  def this(uid: String) = this(uid, null)

  override def copy(extra: ParamMap): XGBoostRegressionModel = {
    val newModel = copyValues(new XGBoostRegressionModel(uid, nativeBooster, summary), extra)
    newModel.setParent(parent)
  }

  override def predict(features: Vector): Double = {
    val values = predictSingleInstance(features)
    values(0)
  }
}

object XGBoostRegressionModel extends MLReadable[XGBoostRegressionModel] {

  override def read: MLReader[XGBoostRegressionModel] = new ModelReader

  private class ModelReader extends XGBoostModelReader[XGBoostRegressionModel] {
    override def load(path: String): XGBoostRegressionModel = {
      val xgbModel = loadBooster(path)
      val meta = SparkUtils.loadMetadata(path, sc)
      val model = new XGBoostRegressionModel(meta.uid, xgbModel, None)
      meta.getAndSetParams(model)
      model
    }
  }
}
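To make the before/after behaviour above concrete, here is a minimal usage sketch that is not part of the commit. It assumes a running SparkSession, the default "label"/"features" columns, and a writable path /tmp/xgb-regression-model; the calls shown (fit, nativeBooster, Spark ML write/load, transform) exist on both the old and the reworked regressor.

    import org.apache.spark.ml.linalg.Vectors
    import org.apache.spark.sql.SparkSession
    import ml.dmlc.xgboost4j.scala.spark.{XGBoostRegressor, XGBoostRegressionModel}

    object RegressionRoundTrip {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().appName("xgboost-regressor-sketch").getOrCreate()
        import spark.implicits._

        // Tiny in-memory dataset; a real job would load and vectorize its own features.
        val train = Seq(
          (1.0, Vectors.dense(1.0, 2.0, 3.0)),
          (0.0, Vectors.dense(4.0, 5.0, 6.0))
        ).toDF("label", "features")

        val regressor = new XGBoostRegressor()
          .setNumRound(10)
          .setNumWorkers(1)

        val model: XGBoostRegressionModel = regressor.fit(train)

        // The low-level native booster stays reachable for APIs such as getFeatureScore.
        val booster = model.nativeBooster

        // Spark ML persistence is unchanged by the rework: write, then load elsewhere.
        model.write.overwrite().save("/tmp/xgb-regression-model")
        val reloaded = XGBoostRegressionModel.load("/tmp/xgb-regression-model")
        reloaded.transform(train).show(false)

        spark.stop()
      }
    }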
@@ -22,17 +22,17 @@ class XGBoostTrainingSummary private(

   override def toString: String = {
     val train = trainObjectiveHistory.mkString(",")
-    val vaidationObjectiveHistoryString = {
+    val validationObjectiveHistoryString = {
       validationObjectiveHistory.map {
         case (name, metrics) =>
           s"${name}ObjectiveHistory=${metrics.mkString(",")}"
       }.mkString(";")
     }
-    s"XGBoostTrainingSummary(trainObjectiveHistory=$train; $vaidationObjectiveHistoryString)"
+    s"XGBoostTrainingSummary(trainObjectiveHistory=$train; $validationObjectiveHistoryString)"
   }
 }

-private[xgboost4j] object XGBoostTrainingSummary {
+private[spark] object XGBoostTrainingSummary {
   def apply(metrics: Map[String, Array[Float]]): XGBoostTrainingSummary = {
     new XGBoostTrainingSummary(
       trainObjectiveHistory = metrics("train"),
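A small illustration, not from the commit, of how the summary can be consumed; it assumes a fitted 3.0-style model whose summary field is an Option[XGBoostTrainingSummary] as in the constructor shown above.

    import ml.dmlc.xgboost4j.scala.spark.XGBoostRegressionModel

    // Prints whatever objective history the estimator recorded during fit().
    def printTrainingHistory(model: XGBoostRegressionModel): Unit = {
      model.summary.foreach { s =>
        s.trainObjectiveHistory.zipWithIndex.foreach { case (metric, round) =>
          println(s"round=$round train=$metric")
        }
        // toString now interpolates validationObjectiveHistoryString (the corrected field name).
        println(s)
      }
    }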
@@ -1,295 +0,0 @@
(BoosterParams.scala is deleted by this commit; its former contents, reconstructed from the split view, follow.)

/* Copyright (c) 2014-2022 by Contributors. Licensed under the Apache License, Version 2.0. */
package ml.dmlc.xgboost4j.scala.spark.params

import scala.collection.immutable.HashSet

import org.apache.spark.ml.param.{DoubleParam, IntParam, BooleanParam, Param, Params}

private[spark] trait BoosterParams extends Params {

  // [default=0.3] range: [0,1]
  final val eta = new DoubleParam(this, "eta",
    "step size shrinkage used in update to prevents overfitting. After each boosting step, we can directly get the weights of new features. and eta actually shrinks the feature weights to make the boosting process more conservative.",
    (value: Double) => value >= 0 && value <= 1)
  final def getEta: Double = $(eta)

  // [default=0] range: [0, Double.MaxValue]
  final val gamma = new DoubleParam(this, "gamma",
    "minimum loss reduction required to make a further partition on a leaf node of the tree. the larger, the more conservative the algorithm will be.",
    (value: Double) => value >= 0)
  final def getGamma: Double = $(gamma)

  // [default=6] range: [1, Int.MaxValue]
  final val maxDepth = new IntParam(this, "maxDepth",
    "maximum depth of a tree, increase this value will make model more complex/likely to be overfitting.",
    (value: Int) => value >= 0)
  final def getMaxDepth: Int = $(maxDepth)

  final val maxLeaves = new IntParam(this, "maxLeaves",
    "Maximum number of nodes to be added. Only relevant when grow_policy=lossguide is set.",
    (value: Int) => value >= 0)
  final def getMaxLeaves: Int = $(maxLeaves)

  // [default=1] range: [0, Double.MaxValue]
  final val minChildWeight = new DoubleParam(this, "minChildWeight",
    "minimum sum of instance weight(hessian) needed in a child. If the tree partition step results in a leaf node with the sum of instance weight less than min_child_weight, then the building process will give up further partitioning. In linear regression mode, this simply corresponds to minimum number of instances needed to be in each node. The larger, the more conservative the algorithm will be.",
    (value: Double) => value >= 0)
  final def getMinChildWeight: Double = $(minChildWeight)

  // [default=0] range: [0, Double.MaxValue]
  final val maxDeltaStep = new DoubleParam(this, "maxDeltaStep",
    "Maximum delta step we allow each tree's weight estimation to be. If the value is set to 0, it means there is no constraint. If it is set to a positive value, it can help making the update step more conservative. Usually this parameter is not needed, but it might help in logistic regression when class is extremely imbalanced. Set it to value of 1-10 might help control the update",
    (value: Double) => value >= 0)
  final def getMaxDeltaStep: Double = $(maxDeltaStep)

  // [default=1] range: (0,1]
  final val subsample = new DoubleParam(this, "subsample",
    "subsample ratio of the training instance. Setting it to 0.5 means that XGBoost randomly collected half of the data instances to grow trees and this will prevent overfitting.",
    (value: Double) => value <= 1 && value > 0)
  final def getSubsample: Double = $(subsample)

  // [default=1] range: (0,1]
  final val colsampleBytree = new DoubleParam(this, "colsampleBytree",
    "subsample ratio of columns when constructing each tree.", (value: Double) => value <= 1 && value > 0)
  final def getColsampleBytree: Double = $(colsampleBytree)

  // [default=1] range: (0,1]
  final val colsampleBylevel = new DoubleParam(this, "colsampleBylevel",
    "subsample ratio of columns for each split, in each level.", (value: Double) => value <= 1 && value > 0)
  final def getColsampleBylevel: Double = $(colsampleBylevel)

  // [default=1]
  final val lambda = new DoubleParam(this, "lambda",
    "L2 regularization term on weights, increase this value will make model more conservative.",
    (value: Double) => value >= 0)
  final def getLambda: Double = $(lambda)

  // [default=0]
  final val alpha = new DoubleParam(this, "alpha",
    "L1 regularization term on weights, increase this value will make model more conservative.",
    (value: Double) => value >= 0)
  final def getAlpha: Double = $(alpha)

  // [default='auto']
  final val treeMethod = new Param[String](this, "treeMethod",
    "The tree construction algorithm used in XGBoost, options: {'auto', 'exact', 'approx', 'hist', 'gpu_hist'}",
    (value: String) => BoosterParams.supportedTreeMethods.contains(value))
  final def getTreeMethod: String = $(treeMethod)

  final val device = new Param[String](
    this, "device", "The device for running XGBoost algorithms, options: cpu, cuda",
    (value: String) => BoosterParams.supportedDevices.contains(value)
  )
  final def getDevice: String = $(device)

  // growth policy for fast histogram algorithm
  final val growPolicy = new Param[String](this, "growPolicy",
    "Controls a way new nodes are added to the tree. Currently supported only if tree_method is set to hist. Choices: depthwise, lossguide. depthwise: split at nodes closest to the root. lossguide: split at nodes with highest loss change.",
    (value: String) => BoosterParams.supportedGrowthPolicies.contains(value))
  final def getGrowPolicy: String = $(growPolicy)

  final val maxBins = new IntParam(this, "maxBin", "maximum number of bins in histogram",
    (value: Int) => value > 0)
  final def getMaxBins: Int = $(maxBins)

  final val singlePrecisionHistogram = new BooleanParam(this, "singlePrecisionHistogram",
    "whether to use single precision to build histograms")
  final def getSinglePrecisionHistogram: Boolean = $(singlePrecisionHistogram)

  // [default=1]
  final val scalePosWeight = new DoubleParam(this, "scalePosWeight",
    "Control the balance of positive and negative weights, useful for unbalanced classes. A typical value to consider: sum(negative cases) / sum(positive cases)")
  final def getScalePosWeight: Double = $(scalePosWeight)

  // Dart boosters

  // [default="uniform"]
  final val sampleType = new Param[String](this, "sampleType",
    "type of sampling algorithm, options: {'uniform', 'weighted'}",
    (value: String) => BoosterParams.supportedSampleType.contains(value))
  final def getSampleType: String = $(sampleType)

  // [default="tree"]
  final val normalizeType = new Param[String](this, "normalizeType",
    "type of normalization algorithm, options: {'tree', 'forest'}",
    (value: String) => BoosterParams.supportedNormalizeType.contains(value))
  final def getNormalizeType: String = $(normalizeType)

  // [default=0.0] range: [0.0, 1.0]
  final val rateDrop = new DoubleParam(this, "rateDrop", "dropout rate",
    (value: Double) => value >= 0 && value <= 1)
  final def getRateDrop: Double = $(rateDrop)

  // [default=0.0] range: [0.0, 1.0]
  final val skipDrop = new DoubleParam(this, "skipDrop",
    "probability of skip dropout. If a dropout is skipped, new trees are added in the same manner as gbtree.",
    (value: Double) => value >= 0 && value <= 1)
  final def getSkipDrop: Double = $(skipDrop)

  // linear booster
  final val lambdaBias = new DoubleParam(this, "lambdaBias",
    "L2 regularization term on bias, default 0 (no L1 reg on bias because it is not important)",
    (value: Double) => value >= 0)
  final def getLambdaBias: Double = $(lambdaBias)

  final val treeLimit = new IntParam(this, name = "treeLimit",
    doc = "number of trees used in the prediction; defaults to 0 (use all trees).")
  setDefault(treeLimit, 0)
  final def getTreeLimit: Int = $(treeLimit)

  final val monotoneConstraints = new Param[String](this, name = "monotoneConstraints",
    doc = "a list in length of number of features, 1 indicate monotonic increasing, - 1 means decreasing, 0 means no constraint. If it is shorter than number of features, 0 will be padded ")
  final def getMonotoneConstraints: String = $(monotoneConstraints)

  final val interactionConstraints = new Param[String](this, name = "interactionConstraints",
    doc = "Constraints for interaction representing permitted interactions. The constraints must be specified in the form of a nest list, e.g. [[0, 1], [2, 3, 4]], where each inner list is a group of indices of features that are allowed to interact with each other. See tutorial for more information")
  final def getInteractionConstraints: String = $(interactionConstraints)
}

private[scala] object BoosterParams {

  val supportedBoosters = HashSet("gbtree", "gblinear", "dart")
  val supportedTreeMethods = HashSet("auto", "exact", "approx", "hist", "gpu_hist")
  val supportedGrowthPolicies = HashSet("depthwise", "lossguide")
  val supportedSampleType = HashSet("uniform", "weighted")
  val supportedNormalizeType = HashSet("tree", "forest")
  val supportedDevices = HashSet("cpu", "cuda")
}
@@ -16,20 +16,18 @@
 package ml.dmlc.xgboost4j.scala.spark.params

-import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
-import ml.dmlc.xgboost4j.scala.spark.TrackerConf
-import ml.dmlc.xgboost4j.scala.spark.util.Utils
-
 import org.apache.spark.ml.param.{Param, ParamPair, Params}
-import org.json4s.{DefaultFormats, Extraction, NoTypeHints}
+import org.json4s.{DefaultFormats, Extraction}
 import org.json4s.jackson.JsonMethods.{compact, parse, render}
 import org.json4s.jackson.Serialization
+
+import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
+import ml.dmlc.xgboost4j.scala.spark.Utils

 /**
  * General spark parameter that includes TypeHints for (de)serialization using json4s.
  */
-class CustomGeneralParam[T: Manifest](
-    parent: Params,
+class CustomGeneralParam[T: Manifest](parent: Params,
     name: String,
     doc: String) extends Param[T](parent, name, doc) {

@@ -52,33 +50,10 @@ class CustomGeneralParam[T: Manifest](
   }
 }

-class CustomEvalParam(
-    parent: Params,
+class CustomEvalParam(parent: Params,
     name: String,
     doc: String) extends CustomGeneralParam[EvalTrait](parent, name, doc)

-class CustomObjParam(
-    parent: Params,
+class CustomObjParam(parent: Params,
     name: String,
     doc: String) extends CustomGeneralParam[ObjectiveTrait](parent, name, doc)
-
-class TrackerConfParam(
-    parent: Params,
-    name: String,
-    doc: String) extends Param[TrackerConf](parent, name, doc) {
-
-  /** Creates a param pair with the given value (for Java). */
-  override def w(value: TrackerConf): ParamPair[TrackerConf] = super.w(value)
-
-  override def jsonEncode(value: TrackerConf): String = {
-    import org.json4s.jackson.Serialization
-    implicit val formats = Serialization.formats(NoTypeHints)
-    compact(render(Extraction.decompose(value)))
-  }
-
-  override def jsonDecode(json: String): TrackerConf = {
-    implicit val formats = DefaultFormats
-    val parsedValue = parse(json)
-    parsedValue.extract[TrackerConf]
-  }
-}
@@ -0,0 +1,61 @@
(DartBoosterParams.scala is a new file added by this commit.)

/* Copyright (c) 2024 by Contributors. Licensed under the Apache License, Version 2.0. */

package ml.dmlc.xgboost4j.scala.spark.params

import org.apache.spark.ml.param._

/**
 * Dart booster parameters, more details can be found at
 * https://xgboost.readthedocs.io/en/stable/parameter.html#additional-parameters-for-dart-booster-booster-dart
 */
private[spark] trait DartBoosterParams extends Params {

  final val sampleType = new Param[String](this, "sample_type", "Type of sampling algorithm, " +
    "options: {'uniform', 'weighted'}", ParamValidators.inArray(Array("uniform", "weighted")))

  final def getSampleType: String = $(sampleType)

  final val normalizeType = new Param[String](this, "normalize_type", "type of normalization" +
    " algorithm, options: {'tree', 'forest'}",
    ParamValidators.inArray(Array("tree", "forest")))

  final def getNormalizeType: String = $(normalizeType)

  final val rateDrop = new DoubleParam(this, "rate_drop", "Dropout rate (a fraction of previous " +
    "trees to drop during the dropout)",
    ParamValidators.inRange(0, 1, true, true))

  final def getRateDrop: Double = $(rateDrop)

  final val oneDrop = new BooleanParam(this, "one_drop", "When this flag is enabled, at least " +
    "one tree is always dropped during the dropout (allows Binomial-plus-one or epsilon-dropout " +
    "from the original DART paper)")

  final def getOneDrop: Boolean = $(oneDrop)

  final val skipDrop = new DoubleParam(this, "skip_drop", "Probability of skipping the dropout " +
    "procedure during a boosting iteration.\nIf a dropout is skipped, new trees are added " +
    "in the same manner as gbtree.\nNote that non-zero skip_drop has higher priority than " +
    "rate_drop or one_drop.",
    ParamValidators.inRange(0, 1, true, true))

  final def getSkipDrop: Double = $(skipDrop)

  setDefault(sampleType -> "uniform", normalizeType -> "tree", rateDrop -> 0, skipDrop -> 0)

}
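For orientation, a hedged sketch (not from the commit) of how these dart knobs would typically reach the native learner; it assumes the Map-based estimator constructor, which forwards keys under the same native names declared above ("sample_type", "rate_drop", and so on).

    import ml.dmlc.xgboost4j.scala.spark.XGBoostRegressor

    // Configure a dart booster through pass-through parameters rather than dedicated setters.
    val dartParams: Map[String, Any] = Map(
      "booster" -> "dart",
      "sample_type" -> "weighted",
      "normalize_type" -> "forest",
      "rate_drop" -> 0.1,
      "skip_drop" -> 0.5,
      "one_drop" -> true
    )

    val dartRegressor = new XGBoostRegressor(dartParams)
      .setNumRound(50)
      .setNumWorkers(1)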
@@ -16,303 +16,45 @@
(GeneralParams.scala, split diff: the 2.x parameter traits on the left are removed, the slimmed-down 3.0 trait on the right is added. Both sides keep `package ml.dmlc.xgboost4j.scala.spark.params`.)

Removed (2.x):

import com.google.common.base.CaseFormat
import ml.dmlc.xgboost4j.scala.spark.TrackerConf

import org.apache.spark.ml.param._
import scala.collection.mutable

private[spark] trait GeneralParams extends Params {

  /** The number of rounds for boosting */
  final val numRound = new IntParam(this, "numRound", "The number of rounds for boosting",
    ParamValidators.gtEq(1))
  setDefault(numRound, 1)
  final def getNumRound: Int = $(numRound)

  /** number of workers used to train xgboost model. default: 1 */
  final val numWorkers = new IntParam(this, "numWorkers", "number of workers used to run xgboost",
    ParamValidators.gtEq(1))
  setDefault(numWorkers, 1)
  final def getNumWorkers: Int = $(numWorkers)

  /** number of threads used by per worker. default 1 */
  final val nthread = new IntParam(this, "nthread", "number of threads used by per worker",
    ParamValidators.gtEq(1))
  setDefault(nthread, 1)
  final def getNthread: Int = $(nthread)

  /** whether to use external memory as cache. default: false */
  final val useExternalMemory = new BooleanParam(this, "useExternalMemory",
    "whether to use external memory as cache")
  setDefault(useExternalMemory, false)
  final def getUseExternalMemory: Boolean = $(useExternalMemory)

  /** Deprecated. Please use verbosity instead. default: 0 */
  final val silent = new IntParam(this, "silent",
    "Deprecated. Please use verbosity instead. 0 means printing running messages, 1 means silent mode.",
    (value: Int) => value >= 0 && value <= 1)
  final def getSilent: Int = $(silent)

  /** Verbosity of printing messages. Valid values are 0 (silent), 1 (warning), 2 (info), 3 (debug). default: 1 */
  final val verbosity = new IntParam(this, "verbosity",
    "Verbosity of printing messages. Valid values are 0 (silent), 1 (warning), 2 (info), 3 (debug).",
    (value: Int) => value >= 0 && value <= 3)
  final def getVerbosity: Int = $(verbosity)

  /** customized objective function provided by user. default: null */
  final val customObj = new CustomObjParam(this, "customObj", "customized objective function provided by user")

  /** customized evaluation function provided by user. default: null */
  final val customEval = new CustomEvalParam(this, "customEval", "customized evaluation function provided by user")

  /** the value treated as missing. default: Float.NaN */
  final val missing = new FloatParam(this, "missing", "the value treated as missing")
  setDefault(missing, Float.NaN)
  final def getMissing: Float = $(missing)

  /** Allows for having a non-zero value for missing when training or predicting on a Sparse or Empty vector. */
  final val allowNonZeroForMissing = new BooleanParam(this, "allowNonZeroForMissing",
    "Allow to have a non-zero value for missing when training or predicting on a Sparse or Empty vector. Should only be used if did not use Spark's VectorAssembler class to construct the feature vector but instead used a method that preserves zeros in your vector.")
  setDefault(allowNonZeroForMissing, false)
  final def getAllowNonZeroForMissingValue: Boolean = $(allowNonZeroForMissing)

  /** The hdfs folder to load and save checkpoint boosters. default: `empty_string` */
  final val checkpointPath = new Param[String](this, "checkpointPath",
    "the hdfs folder to load and save checkpoints. If there are existing checkpoints in checkpoint_path. The job will load the checkpoint with highest version as the starting point for training. If checkpoint_interval is also set, the job will save a checkpoint every a few rounds.")
  final def getCheckpointPath: String = $(checkpointPath)

  /** Param for set checkpoint interval (>= 1) or disable checkpoint (-1). */
  final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval",
    "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the trained model will get checkpointed every 10 iterations. Note: `checkpoint_path` must also be set if the checkpoint interval is greater than 0.",
    (interval: Int) => interval == -1 || interval >= 1)
  final def getCheckpointInterval: Int = $(checkpointInterval)

  /**
   * Rabit tracker configurations, provided as TrackerConf(timeout, hostIp, port):
   * timeout is the maximum wait time in seconds for all workers to connect to the tracker and for
   * tracker shutdown (0, the default, waits indefinitely); it should account for data loading and
   * pre-processing due to potential lazy execution, or Spark can be forced to materialize the data
   * before XGBoost.train(). hostIp is only needed when the tracker host IP cannot be guessed
   * automatically, and port defaults to a system-allocated one.
   */
  final val trackerConf = new TrackerConfParam(this, "trackerConf", "Rabit tracker configurations")
  setDefault(trackerConf, TrackerConf())

  /** Random seed for the C++ part of XGBoost and train/test splitting. */
  final val seed = new LongParam(this, "seed", "random seed")
  setDefault(seed, 0L)
  final def getSeed: Long = $(seed)

  /** Feature's name, it will be set to DMatrix and Booster, and in the final native json model. In native code, the parameter name is feature_name. */
  final val featureNames = new StringArrayParam(this, "feature_names", "an array of feature names")
  final def getFeatureNames: Array[String] = $(featureNames)

  /** Feature types, q is numeric and c is categorical. In native code, the parameter name is feature_type. */
  final val featureTypes = new StringArrayParam(this, "feature_types", "an array of feature types")
  final def getFeatureTypes: Array[String] = $(featureTypes)
}

trait HasLeafPredictionCol extends Params {
  /** Param for leaf prediction column name. @group param */
  final val leafPredictionCol: Param[String] = new Param[String](this, "leafPredictionCol",
    "name of the predictLeaf results")

  /** @group getParam */
  final def getLeafPredictionCol: String = $(leafPredictionCol)
}

trait HasContribPredictionCol extends Params {
  /** Param for contribution prediction column name. @group param */
  final val contribPredictionCol: Param[String] = new Param[String](this, "contribPredictionCol",
    "name of the predictContrib results")

  /** @group getParam */
  final def getContribPredictionCol: String = $(contribPredictionCol)
}

trait HasBaseMarginCol extends Params {
  /** Param for initial prediction (aka base margin) column name. @group param */
  final val baseMarginCol: Param[String] = new Param[String](this, "baseMarginCol",
    "Initial prediction (aka base margin) column name.")

  /** @group getParam */
  final def getBaseMarginCol: String = $(baseMarginCol)
}

trait HasGroupCol extends Params {
  /** Param for group column name. @group param */
  final val groupCol: Param[String] = new Param[String](this, "groupCol", "group column name.")

  /** @group getParam */
  final def getGroupCol: String = $(groupCol)
}

trait HasNumClass extends Params {
  /** number of classes */
  final val numClass = new IntParam(this, "numClass", "number of classes")

  /** @group getParam */
  final def getNumClass: Int = $(numClass)
}

/** Trait for shared param featuresCols. */
trait HasFeaturesCols extends Params {
  /** Param for the names of feature columns. @group param */
  final val featuresCols: StringArrayParam = new StringArrayParam(this, "featuresCols",
    "an array of feature column names.")

  /** @group getParam */
  final def getFeaturesCols: Array[String] = $(featuresCols)

  /** Check if featuresCols is valid */
  def isFeaturesColsValid: Boolean = {
    isDefined(featuresCols) && $(featuresCols) != Array.empty
  }
}

private[spark] trait ParamMapFuncs extends Params {

  def XGBoost2MLlibParams(xgboostParams: Map[String, Any]): Unit = {
    for ((paramName, paramValue) <- xgboostParams) {
      if ((paramName == "booster" && paramValue != "gbtree") ||
        (paramName == "updater" && paramValue != "grow_histmaker,prune" &&
          paramValue != "grow_quantile_histmaker" && paramValue != "grow_gpu_hist")) {
        throw new IllegalArgumentException(s"you specified $paramName as $paramValue," +
          s" XGBoost-Spark only supports gbtree as booster type and grow_histmaker or" +
          s" grow_quantile_histmaker or grow_gpu_hist as the updater type")
      }
      val name = CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, paramName)
      params.find(_.name == name).foreach {
        case _: DoubleParam => set(name, paramValue.toString.toDouble)
        case _: BooleanParam => set(name, paramValue.toString.toBoolean)
        case _: IntParam => set(name, paramValue.toString.toInt)
        case _: FloatParam => set(name, paramValue.toString.toFloat)
        case _: LongParam => set(name, paramValue.toString.toLong)
        case _: Param[_] => set(name, paramValue)
      }
    }
  }

  def MLlib2XGBoostParams: Map[String, Any] = {
    val xgboostParams = new mutable.HashMap[String, Any]()
    for (param <- params) {
      if (isDefined(param)) {
        val name = CaseFormat.LOWER_CAMEL.to(CaseFormat.LOWER_UNDERSCORE, param.name)
        xgboostParams += name -> $(param)
      }
    }
    xgboostParams.toMap
  }
}

Added (3.0):

import org.apache.spark.ml.param._

/**
 * General xgboost parameters, more details can be found
 * at https://xgboost.readthedocs.io/en/stable/parameter.html#general-parameters
 */
private[spark] trait GeneralParams extends Params {

  final val booster = new Param[String](this, "booster",
    "Which booster to use. Can be gbtree, gblinear or dart; gbtree and dart use tree based models while gblinear uses linear functions.",
    ParamValidators.inArray(Array("gbtree", "dart")))

  final def getBooster: String = $(booster)

  final val device = new Param[String](this, "device",
    "Device for XGBoost to run. User can set it to one of the following values: {cpu, cuda, gpu}",
    ParamValidators.inArray(Array("cpu", "cuda", "gpu")))

  final def getDevice: String = $(device)

  final val verbosity = new IntParam(this, "verbosity",
    "Verbosity of printing messages. Valid values are 0 (silent), 1 (warning), 2 (info), 3 (debug). Sometimes XGBoost tries to change configurations based on heuristics, which is displayed as warning message. If there's unexpected behaviour, please try to increase value of verbosity.",
    ParamValidators.inRange(0, 3, true, true))

  final def getVerbosity: Int = $(verbosity)

  final val validateParameters = new BooleanParam(this, "validate_parameters",
    "When set to True, XGBoost will perform validation of input parameters to check whether a parameter is used or not. A warning is emitted when there's unknown parameter.")

  final def getValidateParameters: Boolean = $(validateParameters)

  final val nthread = new IntParam(this, "nthread", "Number of threads used by per worker",
    ParamValidators.gtEq(1))

  final def getNthread: Int = $(nthread)

  setDefault(booster -> "gbtree", device -> "cpu", verbosity -> 1, validateParameters -> false,
    nthread -> 1)
}
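As a rough orientation for code that used the removed knobs, a sketch (not from the commit) of the surviving general parameters set through the same pass-through map; deployment-level settings such as numRound and numWorkers keep dedicated setters.

    import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

    // Only booster, device, verbosity, validate_parameters and nthread remain in the
    // GeneralParams trait; silent, useExternalMemory, trackerConf, checkpointPath and the
    // other 2.x knobs are gone from it.
    val classifier = new XGBoostClassifier(Map(
      "booster" -> "gbtree",
      "device" -> "cpu",
      "verbosity" -> 2,
      "validate_parameters" -> true,
      "nthread" -> 1
    )).setNumRound(100)
      .setNumWorkers(2)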
@@ -1,32 +0,0 @@
(InferenceParams.scala is deleted by this commit; its former contents follow.)

/* Copyright (c) 2014-2022 by Contributors. Licensed under the Apache License, Version 2.0. */

package ml.dmlc.xgboost4j.scala.spark.params

import org.apache.spark.ml.param.{IntParam, Params}

private[spark] trait InferenceParams extends Params {

  /** batch size of inference iteration */
  final val inferBatchSize = new IntParam(this, "batchSize", "batch size of inference iteration")

  /** @group getParam */
  final def getInferBatchSize: Int = $(inferBatchSize)

  setDefault(inferBatchSize, 32 << 10)
}
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014-2022 by Contributors
|
Copyright (c) 2014-2024 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -20,98 +20,124 @@ import scala.collection.immutable.HashSet
|
|||||||
|
|
||||||
import org.apache.spark.ml.param._
|
import org.apache.spark.ml.param._
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Specify the learning task and the corresponding learning objective.
|
||||||
|
* More details can be found at
|
||||||
|
* https://xgboost.readthedocs.io/en/stable/parameter.html#learning-task-parameters
|
||||||
|
*/
|
||||||
private[spark] trait LearningTaskParams extends Params {
|
private[spark] trait LearningTaskParams extends Params {
|
||||||
|
|
||||||
/**
|
|
||||||
* Specify the learning task and the corresponding learning objective.
|
|
||||||
* options: reg:squarederror, reg:squaredlogerror, reg:logistic, binary:logistic, binary:logitraw,
|
|
||||||
* count:poisson, multi:softmax, multi:softprob, rank:ndcg, reg:gamma.
|
|
||||||
* default: reg:squarederror
|
|
||||||
*/
|
|
||||||
final val objective = new Param[String](this, "objective",
|
final val objective = new Param[String](this, "objective",
|
||||||
"objective function used for training")
|
"Objective function used for training",
|
||||||
|
ParamValidators.inArray(LearningTaskParams.SUPPORTED_OBJECTIVES.toArray))
|
||||||
|
|
||||||
final def getObjective: String = $(objective)
|
final def getObjective: String = $(objective)
|
||||||
|
|
||||||
/**
|
final val numClass = new IntParam(this, "num_class", "Number of classes, used by " +
|
||||||
* The learning objective type of the specified custom objective and eval.
|
"multi:softmax and multi:softprob objectives", ParamValidators.gtEq(0))
|
||||||
* Corresponding type will be assigned if custom objective is defined
|
|
||||||
* options: regression, classification. default: null
|
|
||||||
*/
|
|
||||||
final val objectiveType = new Param[String](this, "objectiveType", "objective type used for " +
|
|
||||||
s"training, options: {${LearningTaskParams.supportedObjectiveType.mkString(",")}",
|
|
||||||
(value: String) => LearningTaskParams.supportedObjectiveType.contains(value))
|
|
||||||
|
|
||||||
final def getObjectiveType: String = $(objectiveType)
|
final def getNumClass: Int = $(numClass)
|
||||||
|
|
||||||
|
final val baseScore = new DoubleParam(this, "base_score", "The initial prediction score of " +
|
||||||
/**
|
"all instances, global bias. The parameter is automatically estimated for selected " +
|
||||||
* the initial prediction score of all instances, global bias. default=0.5
|
"objectives before training. To disable the estimation, specify a real number argument. " +
|
||||||
*/
|
"For sufficient number of iterations, changing this value will not have too much effect.")
|
||||||
final val baseScore = new DoubleParam(this, "baseScore", "the initial prediction score of all" +
|
|
||||||
" instances, global bias")
|
|
||||||
|
|
||||||
final def getBaseScore: Double = $(baseScore)
|
final def getBaseScore: Double = $(baseScore)
|
||||||
|
|
||||||
/**
|
final val evalMetric = new Param[String](this, "eval_metric", "Evaluation metrics for " +
|
||||||
* evaluation metrics for validation data, a default metric will be assigned according to
|
"validation data, a default metric will be assigned according to objective (rmse for " +
|
||||||
* objective(rmse for regression, and error for classification, mean average precision for
|
"regression, and logloss for classification, mean average precision for rank:map, etc.)" +
|
||||||
* ranking). options: rmse, rmsle, mae, mape, logloss, error, merror, mlogloss, auc, aucpr, ndcg,
|
"User can add multiple evaluation metrics. Python users: remember to pass the metrics in " +
|
||||||
* map, gamma-deviance
|
"as list of parameters pairs instead of map, so that latter eval_metric won't override " +
|
||||||
*/
|
"previous ones", ParamValidators.inArray(LearningTaskParams.SUPPORTED_EVAL_METRICS.toArray))
|
||||||
final val evalMetric = new Param[String](this, "evalMetric", "evaluation metrics for " +
|
|
||||||
"validation data, a default metric will be assigned according to objective " +
|
|
||||||
"(rmse for regression, and error for classification, mean average precision for ranking)")
|
|
||||||
|
|
||||||
final def getEvalMetric: String = $(evalMetric)
|
final def getEvalMetric: String = $(evalMetric)
|
||||||
|
|
||||||
/**
|
final val seed = new LongParam(this, "seed", "Random number seed.")
|
||||||
* Fraction of training points to use for testing.
|
|
||||||
*/
|
|
||||||
@Deprecated
|
|
||||||
final val trainTestRatio = new DoubleParam(this, "trainTestRatio",
|
|
||||||
"fraction of training points to use for testing",
|
|
||||||
ParamValidators.inRange(0, 1))
|
|
||||||
setDefault(trainTestRatio, 1.0)
|
|
||||||
|
|
||||||
@Deprecated
|
final def getSeed: Long = $(seed)

final val seedPerIteration = new BooleanParam(this, "seed_per_iteration", "Seed PRNG " +
  "deterministically via iterator number.")

final def getSeedPerIteration: Boolean = $(seedPerIteration)

// Parameters for Tweedie Regression (objective=reg:tweedie)
final val tweedieVariancePower = new DoubleParam(this, "tweedie_variance_power", "Parameter " +
  "that controls the variance of the Tweedie distribution var(y) ~ E(y)^tweedie_variance_power.",
  ParamValidators.inRange(1, 2, false, false))

final def getTweedieVariancePower: Double = $(tweedieVariancePower)

// Parameter for using Pseudo-Huber (reg:pseudohubererror)
final val huberSlope = new DoubleParam(this, "huber_slope", "A parameter used for Pseudo-Huber " +
  "loss to define the (delta) term.")

final def getHuberSlope: Double = $(huberSlope)

// Parameter for using Quantile Loss (reg:quantileerror) TODO

// Parameter for using AFT Survival Loss (survival:aft) and Negative
// Log Likelihood of AFT metric (aft-nloglik)
final val aftLossDistribution = new Param[String](this, "aft_loss_distribution", "Probability " +
  "Density Function",
  ParamValidators.inArray(Array("normal", "logistic", "extreme")))

final def getAftLossDistribution: String = $(aftLossDistribution)

// Parameters for learning to rank (rank:ndcg, rank:map, rank:pairwise)
final val lambdarankPairMethod = new Param[String](this, "lambdarank_pair_method", "pairs for " +
  "pair-wise learning",
  ParamValidators.inArray(Array("mean", "topk")))

final def getLambdarankPairMethod: String = $(lambdarankPairMethod)

final val lambdarankNumPairPerSample = new IntParam(this, "lambdarank_num_pair_per_sample",
  "It specifies the number of pairs sampled for each document when pair method is mean, or" +
  " the truncation level for queries when the pair method is topk. For example, to train " +
  "with ndcg@6, set lambdarank_num_pair_per_sample to 6 and lambdarank_pair_method to topk",
  ParamValidators.gtEq(1))

final def getLambdarankNumPairPerSample: Int = $(lambdarankNumPairPerSample)

final val lambdarankUnbiased = new BooleanParam(this, "lambdarank_unbiased", "Specify " +
  "whether we need to debias input click data.")

final def getLambdarankUnbiased: Boolean = $(lambdarankUnbiased)

final val lambdarankBiasNorm = new DoubleParam(this, "lambdarank_bias_norm", "Lp " +
  "normalization for position debiasing, default is L2. Only relevant when " +
  "lambdarankUnbiased is set to true.")

final def getLambdarankBiasNorm: Double = $(lambdarankBiasNorm)

final val ndcgExpGain = new BooleanParam(this, "ndcg_exp_gain", "Whether we should " +
  "use exponential gain function for NDCG.")

final def getNdcgExpGain: Boolean = $(ndcgExpGain)

setDefault(objective -> "reg:squarederror", numClass -> 0, seed -> 0, seedPerIteration -> false,
  tweedieVariancePower -> 1.5, huberSlope -> 1, lambdarankPairMethod -> "mean",
  lambdarankUnbiased -> false, lambdarankBiasNorm -> 2, ndcgExpGain -> true)
}

private[spark] object LearningTaskParams {

  val SUPPORTED_OBJECTIVES = HashSet("reg:squarederror", "reg:squaredlogerror", "reg:logistic",
    "reg:pseudohubererror", "reg:absoluteerror", "reg:quantileerror", "binary:logistic",
    "binary:logitraw", "binary:hinge", "count:poisson", "survival:cox", "survival:aft",
    "multi:softmax", "multi:softprob", "rank:ndcg", "rank:map", "rank:pairwise", "reg:gamma",
    "reg:tweedie")

  val BINARY_CLASSIFICATION_OBJS = HashSet("binary:logistic", "binary:hinge", "binary:logitraw")

  val MULTICLASSIFICATION_OBJS = HashSet("multi:softmax", "multi:softprob")

  val RANKER_OBJS = HashSet("rank:ndcg", "rank:map", "rank:pairwise")

  val REGRESSION_OBJS = SUPPORTED_OBJECTIVES -- BINARY_CLASSIFICATION_OBJS --
    MULTICLASSIFICATION_OBJS -- RANKER_OBJS

  val SUPPORTED_EVAL_METRICS = HashSet("rmse", "rmsle", "mae", "mape", "mphe", "logloss", "error",
    "error@t", "merror", "mlogloss", "auc", "aucpr", "pre", "ndcg", "map", "ndcg@n", "map@n",
    "pre@n", "ndcg-", "map-", "ndcg@n-", "map@n-", "poisson-nloglik", "gamma-nloglik",
    "cox-nloglik", "gamma-deviance", "tweedie-nloglik", "aft-nloglik",
    "interval-regression-accuracy")
}
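These learning-task parameters are surfaced through the setters added further down in this commit. A minimal, hedged sketch of configuring an AFT survival task (assuming the reworked XGBoostRegressor mixes in those setters; values are illustrative):

// Hedged sketch, not part of this commit: wiring the new learning-task params.
import ml.dmlc.xgboost4j.scala.spark.XGBoostRegressor

val aftRegressor = new XGBoostRegressor()
  .setObjective("survival:aft")        // one of LearningTaskParams.SUPPORTED_OBJECTIVES
  .setAftLossDistribution("normal")    // "normal", "logistic" or "extreme"
  .setEvalMetric("aft-nloglik")
  .setSeed(42L)
  .setSeedPerIteration(true)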
|||||||
@ -1,36 +0,0 @@
|
|||||||
/*
|
|
||||||
Copyright (c) 2014 by Contributors
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala.spark.params
|
|
||||||
|
|
||||||
import org.apache.spark.sql.DataFrame
|
|
||||||
|
|
||||||
trait NonParamVariables {
|
|
||||||
protected var evalSetsMap: Map[String, DataFrame] = Map.empty
|
|
||||||
|
|
||||||
def setEvalSets(evalSets: Map[String, DataFrame]): this.type = {
|
|
||||||
evalSetsMap = evalSets
|
|
||||||
this
|
|
||||||
}
|
|
||||||
|
|
||||||
def getEvalSets(params: Map[String, Any]): Map[String, DataFrame] = {
|
|
||||||
if (params.contains("eval_sets")) {
|
|
||||||
params("eval_sets").asInstanceOf[Map[String, DataFrame]]
|
|
||||||
} else {
|
|
||||||
evalSetsMap
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
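For reference, the removed NonParamVariables trait above carried evaluation DataFrames outside the Param machinery. A sketch of the pre-3.0 call pattern it enabled (oldEstimator and validationDf are placeholders):

// Pre-3.0 pattern enabled by the removed trait; no longer available in 3.0.
val withEval = oldEstimator.setEvalSets(Map("validation" -> validationDf))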
@ -0,0 +1,65 @@
/*
 Copyright (c) 2014-2022 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.spark.params

import scala.collection.mutable

import org.apache.spark.ml.param._

private[spark] trait ParamMapConversion extends NonXGBoostParams {

  /**
   * Convert XGBoost parameters to Spark Parameters
   *
   * @param xgboostParams XGBoost style parameters
   */
  def xgboost2SparkParams(xgboostParams: Map[String, Any]): Unit = {
    for ((name, paramValue) <- xgboostParams) {
      params.find(_.name == name).foreach {
        case _: DoubleParam =>
          set(name, paramValue.toString.toDouble)
        case _: BooleanParam =>
          set(name, paramValue.toString.toBoolean)
        case _: IntParam =>
          set(name, paramValue.toString.toInt)
        case _: FloatParam =>
          set(name, paramValue.toString.toFloat)
        case _: LongParam =>
          set(name, paramValue.toString.toLong)
        case _: Param[_] =>
          set(name, paramValue)
      }
    }
  }

  /**
   * Convert the user-supplied parameters to the XGBoost parameters.
   *
   * Note that this also contains jvm-specific parameters.
   */
  def getXGBoostParams: Map[String, Any] = {
    val xgboostParams = new mutable.HashMap[String, Any]()

    // Only pass user-supplied parameters to xgboost.
    for (param <- params) {
      if (isSet(param) && !nonXGBoostParams.contains(param.name)) {
        xgboostParams += param.name -> $(param)
      }
    }
    xgboostParams.toMap
  }
}
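A hedged sketch of how this conversion trait is meant to be used (the estimator value is a placeholder; the parameter names come from the booster traits in this commit):

// XGBoost-style names are matched against Spark Params and coerced by Param type.
estimator.xgboost2SparkParams(Map(
  "max_depth" -> 8,                // IntParam: coerced via toString.toInt
  "eta" -> 0.1,                    // DoubleParam: coerced via toString.toDouble
  "objective" -> "binary:logistic"
))

// Only user-set params that are not registered as non-XGBoost params reach the native layer.
val nativeParams: Map[String, Any] = estimator.getXGBoostParams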
@ -18,25 +18,27 @@ package ml.dmlc.xgboost4j.scala.spark.params

import org.apache.spark.ml.param._

private[spark] trait RabitParams extends Params with NonXGBoostParams {

  final val rabitTrackerTimeout = new IntParam(this, "rabitTrackerTimeout", "The number of " +
    "seconds to wait for workers to connect and for the tracker to shut down before timing out.",
    ParamValidators.gtEq(0))

  final def getRabitTrackerTimeout: Int = $(rabitTrackerTimeout)

  final val rabitTrackerHostIp = new Param[String](this, "rabitTrackerHostIp", "The Rabit " +
    "Tracker host IP address. This is only needed if the host IP cannot be automatically " +
    "guessed.")

  final def getRabitTrackerHostIp: String = $(rabitTrackerHostIp)

  final val rabitTrackerPort = new IntParam(this, "rabitTrackerPort", "The port number for the " +
    "tracker to listen to. Use a system allocated one by default.",
    ParamValidators.gtEq(0))

  final def getRabitTrackerPort: Int = $(rabitTrackerPort)

  setDefault(rabitTrackerTimeout -> 0, rabitTrackerHostIp -> "", rabitTrackerPort -> 0)

  addNonXGBoostParam(rabitTrackerTimeout, rabitTrackerHostIp, rabitTrackerPort)
}
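A hedged sketch of overriding these tracker settings through the setters defined on SparkParams later in this commit (the host IP and port values are illustrative):

estimator
  .setRabitTrackerTimeout(600)        // seconds to wait for workers and tracker shutdown
  .setRabitTrackerHostIp("10.0.0.5")  // only needed when the IP cannot be guessed
  .setRabitTrackerPort(29500)         // default 0 lets the system allocate a port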
@ -0,0 +1,238 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2024 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ml.dmlc.xgboost4j.scala.spark.params
|
||||||
|
|
||||||
|
import scala.collection.immutable.HashSet
|
||||||
|
|
||||||
|
import org.apache.spark.ml.param._
|
||||||
|
|
||||||
|
/**
|
||||||
|
* TreeBoosterParams defines the XGBoost TreeBooster parameters for Spark
|
||||||
|
*
|
||||||
|
* The details can be found at
|
||||||
|
* https://xgboost.readthedocs.io/en/stable/parameter.html#parameters-for-tree-booster
|
||||||
|
*/
|
||||||
|
private[spark] trait TreeBoosterParams extends Params {
|
||||||
|
|
||||||
|
final val eta = new DoubleParam(this, "eta", "Step size shrinkage used in update to prevents " +
|
||||||
|
"overfitting. After each boosting step, we can directly get the weights of new features, " +
|
||||||
|
"and eta shrinks the feature weights to make the boosting process more conservative.",
|
||||||
|
ParamValidators.inRange(0, 1, lowerInclusive = true, upperInclusive = true))
|
||||||
|
|
||||||
|
final def getEta: Double = $(eta)
|
||||||
|
|
||||||
|
final val gamma = new DoubleParam(this, "gamma", "Minimum loss reduction required to make a " +
|
||||||
|
"further partition on a leaf node of the tree. The larger gamma is, the more conservative " +
|
||||||
|
"the algorithm will be.",
|
||||||
|
ParamValidators.gtEq(0))
|
||||||
|
|
||||||
|
final def getGamma: Double = $(gamma)
|
||||||
|
|
||||||
|
final val maxDepth = new IntParam(this, "max_depth", "Maximum depth of a tree. Increasing this " +
|
||||||
|
"value will make the model more complex and more likely to overfit. 0 indicates no limit " +
|
||||||
|
"on depth. Beware that XGBoost aggressively consumes memory when training a deep tree. " +
|
||||||
|
"exact tree method requires non-zero value.",
|
||||||
|
ParamValidators.gtEq(0))
|
||||||
|
|
||||||
|
final def getMaxDepth: Int = $(maxDepth)
|
||||||
|
|
||||||
|
final val minChildWeight = new DoubleParam(this, "min_child_weight", "Minimum sum of instance " +
|
||||||
|
"weight (hessian) needed in a child. If the tree partition step results in a leaf node " +
|
||||||
|
"with the sum of instance weight less than min_child_weight, then the building process " +
|
||||||
|
"will give up further partitioning. In linear regression task, this simply corresponds " +
|
||||||
|
"to minimum number of instances needed to be in each node. The larger min_child_weight " +
|
||||||
|
"is, the more conservative the algorithm will be.",
|
||||||
|
ParamValidators.gtEq(0))
|
||||||
|
|
||||||
|
final def getMinChildWeight: Double = $(minChildWeight)
|
||||||
|
|
||||||
|
final val maxDeltaStep = new DoubleParam(this, "max_delta_step", "Maximum delta step we allow " +
|
||||||
|
"each leaf output to be. If the value is set to 0, it means there is no constraint. If it " +
|
||||||
|
"is set to a positive value, it can help making the update step more conservative. Usually " +
|
||||||
|
"this parameter is not needed, but it might help in logistic regression when class is " +
|
||||||
|
"extremely imbalanced. Set it to value of 1-10 might help control the update.",
|
||||||
|
ParamValidators.gtEq(0))
|
||||||
|
|
||||||
|
final def getMaxDeltaStep: Double = $(maxDeltaStep)
|
||||||
|
|
||||||
|
final val subsample = new DoubleParam(this, "subsample", "Subsample ratio of the training " +
|
||||||
|
"instances. Setting it to 0.5 means that XGBoost would randomly sample half of the " +
|
||||||
|
"training data prior to growing trees, and this will prevent overfitting. Subsampling " +
|
||||||
|
"will occur once in every boosting iteration.",
|
||||||
|
ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true))
|
||||||
|
|
||||||
|
final def getSubsample: Double = $(subsample)
|
||||||
|
|
||||||
|
final val samplingMethod = new Param[String](this, "sampling_method", "The method to use to " +
|
||||||
|
"sample the training instances. The supported sampling methods are:\n" +
|
||||||
|
"uniform: each training instance has an equal probability of being selected. Typically set " +
|
||||||
|
"subsample >= 0.5 for good results.\n" +
|
||||||
|
"gradient_based: the selection probability for each training instance is proportional to " +
|
||||||
|
"the regularized absolute value of gradients. subsample may be set to as low as 0.1 " +
|
||||||
|
"without loss of model accuracy. Note that this sampling method is only supported when " +
|
||||||
|
"tree_method is set to hist and the device is cuda; other tree methods only support " +
|
||||||
|
"uniform sampling.",
|
||||||
|
ParamValidators.inArray(Array("uniform", "gradient_based")))
|
||||||
|
|
||||||
|
final def getSamplingMethod: String = $(samplingMethod)
|
||||||
|
|
||||||
|
final val colsampleBytree = new DoubleParam(this, "colsample_bytree", "Subsample ratio of " +
|
||||||
|
"columns when constructing each tree. Subsampling occurs once for every tree constructed.",
|
||||||
|
ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true))
|
||||||
|
|
||||||
|
final def getColsampleBytree: Double = $(colsampleBytree)
|
||||||
|
|
||||||
|
|
||||||
|
final val colsampleBylevel = new DoubleParam(this, "colsample_bylevel", "Subsample ratio of " +
|
||||||
|
"columns for each level. Subsampling occurs once for every new depth level reached in a " +
|
||||||
|
"tree. Columns are subsampled from the set of columns chosen for the current tree.",
|
||||||
|
ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true))
|
||||||
|
|
||||||
|
final def getColsampleBylevel: Double = $(colsampleBylevel)
|
||||||
|
|
||||||
|
|
||||||
|
final val colsampleBynode = new DoubleParam(this, "colsample_bynode", "Subsample ratio of " +
|
||||||
|
"columns for each node (split). Subsampling occurs once every time a new split is " +
|
||||||
|
"evaluated. Columns are subsampled from the set of columns chosen for the current level.",
|
||||||
|
ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true))
|
||||||
|
|
||||||
|
final def getColsampleBynode: Double = $(colsampleBynode)
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* L2 regularization term on weights, increase this value will make model more conservative.
|
||||||
|
* [default=1]
|
||||||
|
*/
|
||||||
|
final val lambda = new DoubleParam(this, "lambda", "L2 regularization term on weights. " +
|
||||||
|
"Increasing this value will make model more conservative.", ParamValidators.gtEq(0))
|
||||||
|
|
||||||
|
final def getLambda: Double = $(lambda)
|
||||||
|
|
||||||
|
final val alpha = new DoubleParam(this, "alpha", "L1 regularization term on weights. " +
|
||||||
|
"Increasing this value will make model more conservative.", ParamValidators.gtEq(0))
|
||||||
|
|
||||||
|
final def getAlpha: Double = $(alpha)
|
||||||
|
|
||||||
|
final val treeMethod = new Param[String](this, "tree_method", "The tree construction " +
|
||||||
|
"algorithm used in XGBoost, options: {'auto', 'exact', 'approx', 'hist', 'gpu_hist'}",
|
||||||
|
ParamValidators.inArray(BoosterParams.supportedTreeMethods.toArray))
|
||||||
|
|
||||||
|
final def getTreeMethod: String = $(treeMethod)
|
||||||
|
|
||||||
|
final val scalePosWeight = new DoubleParam(this, "scale_pos_weight", "Control the balance of " +
|
||||||
|
"positive and negative weights, useful for unbalanced classes. A typical value to consider: " +
|
||||||
|
"sum(negative instances) / sum(positive instances)")
|
||||||
|
|
||||||
|
final def getScalePosWeight: Double = $(scalePosWeight)
|
||||||
|
|
||||||
|
final val updater = new Param[String](this, "updater", "A comma separated string defining the " +
|
||||||
|
"sequence of tree updaters to run, providing a modular way to construct and to modify the " +
|
||||||
|
"trees. This is an advanced parameter that is usually set automatically, depending on some " +
|
||||||
|
"other parameters. However, it could be also set explicitly by a user. " +
|
||||||
|
"The following updaters exist:\n" +
|
||||||
|
"grow_colmaker: non-distributed column-based construction of trees.\n" +
|
||||||
|
"grow_histmaker: distributed tree construction with row-based data splitting based on " +
|
||||||
|
"global proposal of histogram counting.\n" +
|
||||||
|
"grow_quantile_histmaker: Grow tree using quantized histogram.\n" +
|
||||||
|
"grow_gpu_hist: Enabled when tree_method is set to hist along with device=cuda.\n" +
|
||||||
|
"grow_gpu_approx: Enabled when tree_method is set to approx along with device=cuda.\n" +
|
||||||
|
"sync: synchronizes trees in all distributed nodes.\n" +
|
||||||
|
"refresh: refreshes tree's statistics and or leaf values based on the current data. Note " +
|
||||||
|
"that no random subsampling of data rows is performed.\n" +
|
||||||
|
"prune: prunes the splits where loss < min_split_loss (or gamma) and nodes that have depth " +
|
||||||
|
"greater than max_depth.",
|
||||||
|
(value: String) => value.split(",").forall(
|
||||||
|
ParamValidators.inArray(BoosterParams.supportedUpdaters.toArray)))
|
||||||
|
|
||||||
|
final def getUpdater: String = $(updater)
|
||||||
|
|
||||||
|
final val refreshLeaf = new BooleanParam(this, "refresh_leaf", "This is a parameter of the " +
|
||||||
|
"refresh updater. When this flag is 1, tree leafs as well as tree nodes' stats are updated. " +
|
||||||
|
"When it is 0, only node stats are updated.")
|
||||||
|
|
||||||
|
final def getRefreshLeaf: Boolean = $(refreshLeaf)
|
||||||
|
|
||||||
|
// TODO set updater/refreshLeaf default value
|
||||||
|
final val processType = new Param[String](this, "process_type", "A type of boosting process to " +
|
||||||
|
"run. options: {default, update}",
|
||||||
|
ParamValidators.inArray(Array("default", "update")))
|
||||||
|
|
||||||
|
final def getProcessType: String = $(processType)
|
||||||
|
|
||||||
|
final val growPolicy = new Param[String](this, "grow_policy", "Controls a way new nodes are " +
|
||||||
|
"added to the tree. Currently supported only if tree_method is set to hist or approx. " +
|
||||||
|
"Choices: depthwise, lossguide. depthwise: split at nodes closest to the root. " +
|
||||||
|
"lossguide: split at nodes with highest loss change.",
|
||||||
|
ParamValidators.inArray(Array("depthwise", "lossguide")))
|
||||||
|
|
||||||
|
final def getGrowPolicy: String = $(growPolicy)
|
||||||
|
|
||||||
|
|
||||||
|
final val maxLeaves = new IntParam(this, "max_leaves", "Maximum number of nodes to be added. " +
|
||||||
|
"Not used by exact tree method", ParamValidators.gtEq(0))
|
||||||
|
|
||||||
|
final def getMaxLeaves: Int = $(maxLeaves)
|
||||||
|
|
||||||
|
final val maxBins = new IntParam(this, "max_bin", "Maximum number of discrete bins to bucket " +
|
||||||
|
"continuous features. Increasing this number improves the optimality of splits at the cost " +
|
||||||
|
"of higher computation time. Only used if tree_method is set to hist or approx.",
|
||||||
|
ParamValidators.gt(0))
|
||||||
|
|
||||||
|
final def getMaxBins: Int = $(maxBins)
|
||||||
|
|
||||||
|
final val numParallelTree = new IntParam(this, "num_parallel_tree", "Number of parallel trees " +
|
||||||
|
"constructed during each iteration. This option is used to support boosted random forest.",
|
||||||
|
ParamValidators.gt(0))
|
||||||
|
|
||||||
|
final def getNumParallelTree: Int = $(numParallelTree)
|
||||||
|
|
||||||
|
final val monotoneConstraints = new IntArrayParam(this, "monotone_constraints", "Constraint of " +
|
||||||
|
"variable monotonicity.")
|
||||||
|
|
||||||
|
final def getMonotoneConstraints: Array[Int] = $(monotoneConstraints)
|
||||||
|
|
||||||
|
final val interactionConstraints = new Param[String](this,
|
||||||
|
name = "interaction_constraints",
|
||||||
|
doc = "Constraints for interaction representing permitted interactions. The constraints" +
|
||||||
|
" must be specified in the form of a nest list, e.g. [[0, 1], [2, 3, 4]]," +
|
||||||
|
" where each inner list is a group of indices of features that are allowed to interact" +
|
||||||
|
" with each other. See tutorial for more information")
|
||||||
|
|
||||||
|
final def getInteractionConstraints: String = $(interactionConstraints)
|
||||||
|
|
||||||
|
|
||||||
|
final val maxCachedHistNode = new IntParam(this, "max_cached_hist_node", "Maximum number of " +
|
||||||
|
"cached nodes for CPU histogram.",
|
||||||
|
ParamValidators.gt(0))
|
||||||
|
|
||||||
|
final def getMaxCachedHistNode: Int = $(maxCachedHistNode)
|
||||||
|
|
||||||
|
setDefault(eta -> 0.3, gamma -> 0, maxDepth -> 6, minChildWeight -> 1, maxDeltaStep -> 0,
|
||||||
|
subsample -> 1, samplingMethod -> "uniform", colsampleBytree -> 1, colsampleBylevel -> 1,
|
||||||
|
colsampleBynode -> 1, lambda -> 1, alpha -> 0, treeMethod -> "auto", scalePosWeight -> 1,
|
||||||
|
processType -> "default", growPolicy -> "depthwise", maxLeaves -> 0, maxBins -> 256,
|
||||||
|
numParallelTree -> 1, maxCachedHistNode -> 65536)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
private[spark] object BoosterParams {
|
||||||
|
|
||||||
|
val supportedTreeMethods = HashSet("auto", "exact", "approx", "hist", "gpu_hist")
|
||||||
|
|
||||||
|
val supportedUpdaters = HashSet("grow_colmaker", "grow_histmaker", "grow_quantile_histmaker",
|
||||||
|
"grow_gpu_hist", "grow_gpu_approx", "sync", "refresh", "prune")
|
||||||
|
}
|
||||||
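A hedged sketch of typical tree-booster settings applied through the corresponding setters (assuming the reworked XGBoostClassifier mixes in XGBoostParams; values are illustrative):

import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

val classifier = new XGBoostClassifier()
  .setTreeMethod("hist")       // one of BoosterParams.supportedTreeMethods
  .setMaxDepth(8)
  .setEta(0.1)
  .setSubsample(0.8)
  .setColsampleBytree(0.8)
  .setMaxBins(256)
  .setGrowPolicy("lossguide")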
@ -1,119 +0,0 @@
|
|||||||
/*
|
|
||||||
Copyright (c) 2014-2022 by Contributors
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala.spark.params
|
|
||||||
|
|
||||||
import org.apache.spark.ml.feature.VectorAssembler
|
|
||||||
import org.apache.spark.ml.param.{Param, ParamValidators}
|
|
||||||
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasHandleInvalid, HasLabelCol, HasWeightCol}
|
|
||||||
import org.apache.spark.ml.util.XGBoostSchemaUtils
|
|
||||||
import org.apache.spark.sql.Dataset
|
|
||||||
import org.apache.spark.sql.types.StructType
|
|
||||||
|
|
||||||
private[scala] sealed trait XGBoostEstimatorCommon extends GeneralParams with LearningTaskParams
|
|
||||||
with BoosterParams with RabitParams with ParamMapFuncs with NonParamVariables with HasWeightCol
|
|
||||||
with HasBaseMarginCol with HasLeafPredictionCol with HasContribPredictionCol with HasFeaturesCol
|
|
||||||
with HasLabelCol with HasFeaturesCols with HasHandleInvalid {
|
|
||||||
|
|
||||||
def needDeterministicRepartitioning: Boolean = {
|
|
||||||
isDefined(checkpointPath) && getCheckpointPath != null && getCheckpointPath.nonEmpty &&
|
|
||||||
isDefined(checkpointInterval) && getCheckpointInterval > 0
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with
|
|
||||||
* invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the
|
|
||||||
* output). Column lengths are taken from the size of ML Attribute Group, which can be set using
|
|
||||||
* `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred
|
|
||||||
* from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'.
|
|
||||||
* Default: "error"
|
|
||||||
* @group param
|
|
||||||
*/
|
|
||||||
override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
|
|
||||||
"""Param for how to handle invalid data (NULL and NaN values). Options are 'skip' (filter out
|
|
||||||
|rows with invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN
|
|
||||||
|in the output). Column lengths are taken from the size of ML Attribute Group, which can be
|
|
||||||
|set using `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also
|
|
||||||
|be inferred from first rows of the data since it is safe to do so but only in case of 'error'
|
|
||||||
|or 'skip'.""".stripMargin.replaceAll("\n", " "),
|
|
||||||
ParamValidators.inArray(Array("skip", "error", "keep")))
|
|
||||||
|
|
||||||
setDefault(handleInvalid, "error")
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Specify an array of feature column names which must be numeric types.
|
|
||||||
*/
|
|
||||||
def setFeaturesCol(value: Array[String]): this.type = set(featuresCols, value)
|
|
||||||
|
|
||||||
/** Set the handleInvalid for VectorAssembler */
|
|
||||||
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Check if schema has a field named with the value of "featuresCol" param and it's data type
|
|
||||||
* must be VectorUDT
|
|
||||||
*/
|
|
||||||
def isFeaturesColSet(schema: StructType): Boolean = {
|
|
||||||
schema.fieldNames.contains(getFeaturesCol) &&
|
|
||||||
XGBoostSchemaUtils.isVectorUDFType(schema(getFeaturesCol).dataType)
|
|
||||||
}
|
|
||||||
|
|
||||||
/** check the features columns type */
|
|
||||||
def transformSchemaWithFeaturesCols(fit: Boolean, schema: StructType): StructType = {
|
|
||||||
if (isFeaturesColsValid) {
|
|
||||||
if (fit) {
|
|
||||||
XGBoostSchemaUtils.checkNumericType(schema, $(labelCol))
|
|
||||||
}
|
|
||||||
$(featuresCols).foreach(feature =>
|
|
||||||
XGBoostSchemaUtils.checkFeatureColumnType(schema(feature).dataType))
|
|
||||||
schema
|
|
||||||
} else {
|
|
||||||
throw new IllegalArgumentException("featuresCol or featuresCols must be specified")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Vectorize the features columns if necessary.
|
|
||||||
*
|
|
||||||
* @param input the input dataset
|
|
||||||
* @return (output dataset and the feature column name)
|
|
||||||
*/
|
|
||||||
def vectorize(input: Dataset[_]): (Dataset[_], String) = {
|
|
||||||
val schema = input.schema
|
|
||||||
if (isFeaturesColSet(schema)) {
|
|
||||||
// Dataset already has vectorized.
|
|
||||||
(input, getFeaturesCol)
|
|
||||||
} else if (isFeaturesColsValid) {
|
|
||||||
val featuresName = if (!schema.fieldNames.contains(getFeaturesCol)) {
|
|
||||||
getFeaturesCol
|
|
||||||
} else {
|
|
||||||
"features_" + uid
|
|
||||||
}
|
|
||||||
val vectorAssembler = new VectorAssembler()
|
|
||||||
.setHandleInvalid($(handleInvalid))
|
|
||||||
.setInputCols(getFeaturesCols)
|
|
||||||
.setOutputCol(featuresName)
|
|
||||||
(vectorAssembler.transform(input).select(featuresName, getLabelCol), featuresName)
|
|
||||||
} else {
|
|
||||||
// never reach here, since transformSchema will take care of the case
|
|
||||||
// that featuresCols is invalid
|
|
||||||
(input, getFeaturesCol)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
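With the internal VectorAssembler step shown above removed, callers that relied on it can assemble the vector column themselves before fitting. A hedged sketch (column names and df are illustrative):

import org.apache.spark.ml.feature.VectorAssembler

val assembled = new VectorAssembler()
  .setInputCols(Array("f0", "f1", "f2"))
  .setOutputCol("features")
  .setHandleInvalid("keep")   // the removed trait defaulted handleInvalid to "error"
  .transform(df)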
|
|
||||||
private[scala] trait XGBoostClassifierParams extends XGBoostEstimatorCommon with HasNumClass
|
|
||||||
|
|
||||||
private[scala] trait XGBoostRegressorParams extends XGBoostEstimatorCommon with HasGroupCol
|
|
||||||
@ -0,0 +1,359 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2024 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ml.dmlc.xgboost4j.scala.spark.params
|
||||||
|
|
||||||
|
import scala.collection.mutable.ArrayBuffer
|
||||||
|
|
||||||
|
import org.apache.spark.ml.param._
|
||||||
|
import org.apache.spark.ml.param.shared._
|
||||||
|
import org.apache.spark.sql.types.StructType
|
||||||
|
|
||||||
|
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
|
||||||
|
|
||||||
|
trait HasLeafPredictionCol extends Params {
|
||||||
|
/**
|
||||||
|
* Param for leaf prediction column name.
|
||||||
|
*
|
||||||
|
* @group param
|
||||||
|
*/
|
||||||
|
final val leafPredictionCol: Param[String] = new Param[String](this, "leafPredictionCol",
|
||||||
|
"name of the predictLeaf results")
|
||||||
|
|
||||||
|
/** @group getParam */
|
||||||
|
final def getLeafPredictionCol: String = $(leafPredictionCol)
|
||||||
|
}
|
||||||
|
|
||||||
|
trait HasContribPredictionCol extends Params {
|
||||||
|
/**
|
||||||
|
* Param for contribution prediction column name.
|
||||||
|
*
|
||||||
|
* @group param
|
||||||
|
*/
|
||||||
|
final val contribPredictionCol: Param[String] = new Param[String](this, "contribPredictionCol",
|
||||||
|
"name of the predictContrib results")
|
||||||
|
|
||||||
|
/** @group getParam */
|
||||||
|
final def getContribPredictionCol: String = $(contribPredictionCol)
|
||||||
|
}
|
||||||
|
|
||||||
|
trait HasBaseMarginCol extends Params {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Param for initial prediction (aka base margin) column name.
|
||||||
|
*
|
||||||
|
* @group param
|
||||||
|
*/
|
||||||
|
final val baseMarginCol: Param[String] = new Param[String](this, "baseMarginCol",
|
||||||
|
"Initial prediction (aka base margin) column name.")
|
||||||
|
|
||||||
|
/** @group getParam */
|
||||||
|
final def getBaseMarginCol: String = $(baseMarginCol)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
trait HasGroupCol extends Params {
|
||||||
|
|
||||||
|
final val groupCol: Param[String] = new Param[String](this, "groupCol", "group column name.")
|
||||||
|
|
||||||
|
/** @group getParam */
|
||||||
|
final def getGroupCol: String = $(groupCol)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Trait for shared param featuresCols.
|
||||||
|
*/
|
||||||
|
trait HasFeaturesCols extends Params {
|
||||||
|
/**
|
||||||
|
* Param for the names of feature columns.
|
||||||
|
*
|
||||||
|
* @group param
|
||||||
|
*/
|
||||||
|
final val featuresCols: StringArrayParam = new StringArrayParam(this, "featuresCols",
|
||||||
|
"An array of feature column names.")
|
||||||
|
|
||||||
|
/** @group getParam */
|
||||||
|
final def getFeaturesCols: Array[String] = $(featuresCols)
|
||||||
|
|
||||||
|
/** Check if featuresCols is valid */
|
||||||
|
def isFeaturesColsValid: Boolean = {
|
||||||
|
isDefined(featuresCols) && $(featuresCols) != Array.empty
|
||||||
|
}
|
||||||
|
}
|
||||||
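A hedged sketch of the multi-column path: the Array[String] overload of setFeaturesCol defined on SparkParams below populates this featuresCols param (column names are illustrative):

estimator.setFeaturesCol(Array("f0", "f1", "f2"))   // sets featuresCols rather than featuresCol
val featureColumns: Array[String] = estimator.getFeaturesCols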
|
|
||||||
|
/**
|
||||||
|
* A trait to hold non-xgboost parameters
|
||||||
|
*/
|
||||||
|
trait NonXGBoostParams extends Params {
|
||||||
|
private val paramNames: ArrayBuffer[String] = ArrayBuffer.empty
|
||||||
|
|
||||||
|
protected def addNonXGBoostParam(ps: Param[_]*): Unit = {
|
||||||
|
ps.foreach(p => paramNames.append(p.name))
|
||||||
|
}
|
||||||
|
|
||||||
|
protected lazy val nonXGBoostParams: Array[String] = paramNames.toSet.toArray
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* XGBoost spark-specific parameters which should not be passed
|
||||||
|
* into the xgboost library
|
||||||
|
*
|
||||||
|
* @tparam T should be the XGBoost estimators or models
|
||||||
|
*/
|
||||||
|
private[spark] trait SparkParams[T <: Params] extends HasFeaturesCols with HasFeaturesCol
|
||||||
|
with HasLabelCol with HasBaseMarginCol with HasWeightCol with HasPredictionCol
|
||||||
|
with HasLeafPredictionCol with HasContribPredictionCol
|
||||||
|
with RabitParams with NonXGBoostParams with SchemaValidationTrait {
|
||||||
|
|
||||||
|
final val numWorkers = new IntParam(this, "numWorkers", "Number of workers used to train xgboost",
|
||||||
|
ParamValidators.gtEq(1))
|
||||||
|
|
||||||
|
final def getNumRound: Int = $(numRound)
|
||||||
|
|
||||||
|
final val forceRepartition = new BooleanParam(this, "forceRepartition", "If the number of partitions " +
|
||||||
|
"is equal to numWorkers, xgboost won't repartition the dataset. Set forceRepartition to " +
|
||||||
|
"true to force repartition.")
|
||||||
|
|
||||||
|
final def getForceRepartition: Boolean = $(forceRepartition)
|
||||||
|
|
||||||
|
final val numRound = new IntParam(this, "numRound", "The number of rounds for boosting",
|
||||||
|
ParamValidators.gtEq(1))
|
||||||
|
|
||||||
|
final val numEarlyStoppingRounds = new IntParam(this, "numEarlyStoppingRounds", "Stop training " +
|
||||||
|
"when the eval metric has not improved for this many consecutive rounds; 0 disables it",
|
||||||
|
ParamValidators.gtEq(0))
|
||||||
|
|
||||||
|
final def getNumEarlyStoppingRounds: Int = $(numEarlyStoppingRounds)
|
||||||
|
|
||||||
|
final val inferBatchSize = new IntParam(this, "inferBatchSize", "batch size in rows " +
|
||||||
|
"to be grouped for inference",
|
||||||
|
ParamValidators.gtEq(1))
|
||||||
|
|
||||||
|
/** @group getParam */
|
||||||
|
final def getInferBatchSize: Int = $(inferBatchSize)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* the value treated as missing. default: Float.NaN
|
||||||
|
*/
|
||||||
|
final val missing = new FloatParam(this, "missing", "The value treated as missing")
|
||||||
|
|
||||||
|
final def getMissing: Float = $(missing)
|
||||||
|
|
||||||
|
final val customObj = new CustomObjParam(this, "customObj", "customized objective function " +
|
||||||
|
"provided by user")
|
||||||
|
|
||||||
|
final def getCustomObj: ObjectiveTrait = $(customObj)
|
||||||
|
|
||||||
|
final val customEval = new CustomEvalParam(this, "customEval",
|
||||||
|
"customized evaluation function provided by user")
|
||||||
|
|
||||||
|
final def getCustomEval: EvalTrait = $(customEval)
|
||||||
|
|
||||||
|
/** Feature names; they will be set on the DMatrix and Booster, and stored in the final native JSON model.
|
||||||
|
* In native code, the parameter name is feature_name.
|
||||||
|
* */
|
||||||
|
final val featureNames = new StringArrayParam(this, "feature_names",
|
||||||
|
"an array of feature names")
|
||||||
|
|
||||||
|
final def getFeatureNames: Array[String] = $(featureNames)
|
||||||
|
|
||||||
|
/** Feature types, q is numeric and c is categorical.
|
||||||
|
* In native code, the parameter name is feature_type
|
||||||
|
* */
|
||||||
|
final val featureTypes = new StringArrayParam(this, "feature_types",
|
||||||
|
"an array of feature types")
|
||||||
|
|
||||||
|
final def getFeatureTypes: Array[String] = $(featureTypes)
|
||||||
|
|
||||||
|
setDefault(numRound -> 100, numWorkers -> 1, inferBatchSize -> (32 << 10),
|
||||||
|
numEarlyStoppingRounds -> 0, forceRepartition -> false, missing -> Float.NaN,
|
||||||
|
featuresCols -> Array.empty, customObj -> null, customEval -> null,
|
||||||
|
featureNames -> Array.empty, featureTypes -> Array.empty)
|
||||||
|
|
||||||
|
addNonXGBoostParam(numWorkers, numRound, numEarlyStoppingRounds, inferBatchSize, featuresCol,
|
||||||
|
labelCol, baseMarginCol, weightCol, predictionCol, leafPredictionCol, contribPredictionCol,
|
||||||
|
forceRepartition, featuresCols, customEval, customObj, featureTypes, featureNames)
|
||||||
|
|
||||||
|
final def getNumWorkers: Int = $(numWorkers)
|
||||||
|
|
||||||
|
def setNumWorkers(value: Int): T = set(numWorkers, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setForceRepartition(value: Boolean): T = set(forceRepartition, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setNumRound(value: Int): T = set(numRound, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setFeaturesCol(value: Array[String]): T = set(featuresCols, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setBaseMarginCol(value: String): T = set(baseMarginCol, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setWeightCol(value: String): T = set(weightCol, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setLeafPredictionCol(value: String): T = set(leafPredictionCol, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setContribPredictionCol(value: String): T = set(contribPredictionCol, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setInferBatchSize(value: Int): T = set(inferBatchSize, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setMissing(value: Float): T = set(missing, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setCustomObj(value: ObjectiveTrait): T = set(customObj, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setCustomEval(value: EvalTrait): T = set(customEval, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setRabitTrackerTimeout(value: Int): T = set(rabitTrackerTimeout, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setRabitTrackerHostIp(value: String): T = set(rabitTrackerHostIp, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setRabitTrackerPort(value: Int): T = set(rabitTrackerPort, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setFeatureNames(value: Array[String]): T = set(featureNames, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setFeatureTypes(value: Array[String]): T = set(featureTypes, value).asInstanceOf[T]
|
||||||
|
}
|
||||||
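A hedged sketch of the Spark-side knobs above, starting from the defaults set in the trait (numRound = 100, numWorkers = 1, inferBatchSize = 32768, missing = NaN; values below are illustrative):

estimator
  .setNumWorkers(4)
  .setNumRound(300)
  .setMissing(0.0f)
  .setInferBatchSize(16 << 10)
  .setForceRepartition(true)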
|
|
||||||
|
private[spark] trait SchemaValidationTrait {
|
||||||
|
|
||||||
|
def validateAndTransformSchema(schema: StructType,
|
||||||
|
fitting: Boolean): StructType = schema
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* XGBoost ranking spark-specific parameters
|
||||||
|
*
|
||||||
|
* @tparam T should be XGBoostRanker or XGBoostRankingModel
|
||||||
|
*/
|
||||||
|
private[spark] trait RankerParams[T <: Params] extends HasGroupCol with NonXGBoostParams {
|
||||||
|
def setGroupCol(value: String): T = set(groupCol, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
addNonXGBoostParam(groupCol)
|
||||||
|
}
|
||||||
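A hedged sketch of the ranking-specific setup (XGBoostRanker is the new dedicated estimator; the query-id column name is illustrative):

import ml.dmlc.xgboost4j.scala.spark.XGBoostRanker

val ranker = new XGBoostRanker()
  .setGroupCol("query_id")      // rows sharing a group value form one query group
  .setObjective("rank:ndcg")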
|
|
||||||
|
/**
|
||||||
|
* XGBoost-specific parameters to pass into the xgboost library
|
||||||
|
*
|
||||||
|
* @tparam T should be the XGBoost estimators or models
|
||||||
|
*/
|
||||||
|
private[spark] trait XGBoostParams[T <: Params] extends TreeBoosterParams
|
||||||
|
with LearningTaskParams with GeneralParams with DartBoosterParams {
|
||||||
|
|
||||||
|
// Setters for TreeBoosterParams
|
||||||
|
def setEta(value: Double): T = set(eta, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setGamma(value: Double): T = set(gamma, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setMaxDepth(value: Int): T = set(maxDepth, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setMinChildWeight(value: Double): T = set(minChildWeight, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setMaxDeltaStep(value: Double): T = set(maxDeltaStep, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setSubsample(value: Double): T = set(subsample, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setSamplingMethod(value: String): T = set(samplingMethod, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setColsampleBytree(value: Double): T = set(colsampleBytree, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setColsampleBylevel(value: Double): T = set(colsampleBylevel, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setColsampleBynode(value: Double): T = set(colsampleBynode, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setLambda(value: Double): T = set(lambda, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setAlpha(value: Double): T = set(alpha, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setTreeMethod(value: String): T = set(treeMethod, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setScalePosWeight(value: Double): T = set(scalePosWeight, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setUpdater(value: String): T = set(updater, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setRefreshLeaf(value: Boolean): T = set(refreshLeaf, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setProcessType(value: String): T = set(processType, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setGrowPolicy(value: String): T = set(growPolicy, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setMaxLeaves(value: Int): T = set(maxLeaves, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setMaxBins(value: Int): T = set(maxBins, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setNumParallelTree(value: Int): T = set(numParallelTree, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setInteractionConstraints(value: String): T =
|
||||||
|
set(interactionConstraints, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setMaxCachedHistNode(value: Int): T = set(maxCachedHistNode, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
// Setters for LearningTaskParams
|
||||||
|
|
||||||
|
def setObjective(value: String): T = set(objective, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setNumClass(value: Int): T = set(numClass, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setBaseScore(value: Double): T = set(baseScore, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setEvalMetric(value: String): T = set(evalMetric, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setSeed(value: Long): T = set(seed, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setSeedPerIteration(value: Boolean): T = set(seedPerIteration, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setTweedieVariancePower(value: Double): T = set(tweedieVariancePower, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setHuberSlope(value: Double): T = set(huberSlope, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setAftLossDistribution(value: String): T = set(aftLossDistribution, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setLambdarankPairMethod(value: String): T = set(lambdarankPairMethod, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setLambdarankNumPairPerSample(value: Int): T =
|
||||||
|
set(lambdarankNumPairPerSample, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setLambdarankUnbiased(value: Boolean): T = set(lambdarankUnbiased, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setLambdarankBiasNorm(value: Double): T = set(lambdarankBiasNorm, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setNdcgExpGain(value: Boolean): T = set(ndcgExpGain, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
// Setters for Dart
|
||||||
|
def setSampleType(value: String): T = set(sampleType, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setNormalizeType(value: String): T = set(normalizeType, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setRateDrop(value: Double): T = set(rateDrop, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setOneDrop(value: Boolean): T = set(oneDrop, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setSkipDrop(value: Double): T = set(skipDrop, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
// Setters for GeneralParams
|
||||||
|
def setBooster(value: String): T = set(booster, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setDevice(value: String): T = set(device, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setVerbosity(value: Int): T = set(verbosity, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setValidateParameters(value: Boolean): T = set(validateParameters, value).asInstanceOf[T]
|
||||||
|
|
||||||
|
def setNthread(value: Int): T = set(nthread, value).asInstanceOf[T]
|
||||||
|
}
|
||||||
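A hedged sketch combining the general and DART setters above (values are illustrative; in upstream XGBoost sample_type accepts uniform/weighted and normalize_type accepts tree/forest):

estimator
  .setBooster("dart")
  .setSampleType("weighted")
  .setNormalizeType("forest")
  .setRateDrop(0.1)
  .setSkipDrop(0.5)
  .setDevice("cuda")            // device=cuda now selects GPU training together with tree_method
  .setNthread(8)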
|
|
||||||
|
private[spark] trait ParamUtils[T <: Params] extends Params {
|
||||||
|
|
||||||
|
def isDefinedNonEmpty(param: Param[String]): Boolean = {
|
||||||
|
isDefined(param) && $(param).nonEmpty
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -1,229 +0,0 @@
|
|||||||
/*
|
|
||||||
Copyright (c) 2014-2022 by Contributors
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala.spark.util
|
|
||||||
|
|
||||||
import scala.collection.mutable
|
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
|
||||||
|
|
||||||
import org.apache.spark.HashPartitioner
|
|
||||||
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
|
|
||||||
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
|
|
||||||
import org.apache.spark.rdd.RDD
|
|
||||||
import org.apache.spark.sql.types.{FloatType, IntegerType}
|
|
||||||
import org.apache.spark.sql.{Column, DataFrame, Row}
|
|
||||||
|
|
||||||
object DataUtils extends Serializable {
|
|
||||||
private[spark] implicit class XGBLabeledPointFeatures(
|
|
||||||
val labeledPoint: XGBLabeledPoint
|
|
||||||
) extends AnyVal {
|
|
||||||
/** Converts the point to [[MLLabeledPoint]]. */
|
|
||||||
private[spark] def asML: MLLabeledPoint = {
|
|
||||||
MLLabeledPoint(labeledPoint.label, labeledPoint.features)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns feature of the point as [[org.apache.spark.ml.linalg.Vector]].
|
|
||||||
*/
|
|
||||||
def features: Vector = if (labeledPoint.indices == null) {
|
|
||||||
Vectors.dense(labeledPoint.values.map(_.toDouble))
|
|
||||||
} else {
|
|
||||||
Vectors.sparse(labeledPoint.size, labeledPoint.indices, labeledPoint.values.map(_.toDouble))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private[spark] implicit class MLLabeledPointToXGBLabeledPoint(
|
|
||||||
val labeledPoint: MLLabeledPoint
|
|
||||||
) extends AnyVal {
|
|
||||||
/** Converts an [[MLLabeledPoint]] to an [[XGBLabeledPoint]]. */
|
|
||||||
def asXGB: XGBLabeledPoint = {
|
|
||||||
labeledPoint.features.asXGB.copy(label = labeledPoint.label.toFloat)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private[spark] implicit class MLVectorToXGBLabeledPoint(val v: Vector) extends AnyVal {
|
|
||||||
/**
|
|
||||||
* Converts a [[Vector]] to a data point with a dummy label.
|
|
||||||
*
|
|
||||||
* This is needed for constructing a [[ml.dmlc.xgboost4j.scala.DMatrix]]
|
|
||||||
* for prediction.
|
|
||||||
*/
|
|
||||||
def asXGB: XGBLabeledPoint = v match {
|
|
||||||
case v: DenseVector =>
|
|
||||||
XGBLabeledPoint(0.0f, v.size, null, v.values.map(_.toFloat))
|
|
||||||
case v: SparseVector =>
|
|
||||||
XGBLabeledPoint(0.0f, v.size, v.indices, v.values.map(_.toFloat))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
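For context, a hedged sketch of what the removed implicit enabled (it is private[spark], so this only compiled inside that package):

import org.apache.spark.ml.linalg.Vectors
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._

val sparse = Vectors.sparse(4, Array(1, 3), Array(0.5, 2.0))
val point = sparse.asXGB   // label fixed to 0.0f, indices/values copied as floats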
|
|
||||||
private def attachPartitionKey(
|
|
||||||
row: Row,
|
|
||||||
deterministicPartition: Boolean,
|
|
||||||
numWorkers: Int,
|
|
||||||
xgbLp: XGBLabeledPoint): (Int, XGBLabeledPoint) = {
|
|
||||||
if (deterministicPartition) {
|
|
||||||
(math.abs(row.hashCode() % numWorkers), xgbLp)
|
|
||||||
} else {
|
|
||||||
(1, xgbLp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private def repartitionRDDs(
|
|
||||||
deterministicPartition: Boolean,
|
|
||||||
numWorkers: Int,
|
|
||||||
arrayOfRDDs: Array[RDD[(Int, XGBLabeledPoint)]]): Array[RDD[XGBLabeledPoint]] = {
|
|
||||||
if (deterministicPartition) {
|
|
||||||
arrayOfRDDs.map {rdd => rdd.partitionBy(new HashPartitioner(numWorkers))}.map {
|
|
||||||
rdd => rdd.map(_._2)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
arrayOfRDDs.map(rdd => {
|
|
||||||
if (rdd.getNumPartitions != numWorkers) {
|
|
||||||
rdd.map(_._2).repartition(numWorkers)
|
|
||||||
} else {
|
|
||||||
rdd.map(_._2)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/** Packed parameters used by [[convertDataFrameToXGBLabeledPointRDDs]] */
|
|
||||||
private[spark] case class PackedParams(labelCol: Column,
|
|
||||||
featuresCol: Column,
|
|
||||||
weight: Column,
|
|
||||||
baseMargin: Column,
|
|
||||||
group: Option[Column],
|
|
||||||
numWorkers: Int,
|
|
||||||
deterministicPartition: Boolean)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* convertDataFrameToXGBLabeledPointRDDs converts DataFrames to an array of RDD[XGBLabeledPoint]
|
|
||||||
*
|
|
||||||
* First, it converts each instance of the input into an XGBLabeledPoint.
|
|
||||||
* Second, it repartitions the RDD to the number of workers.
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
private[spark] def convertDataFrameToXGBLabeledPointRDDs(
|
|
||||||
packedParams: PackedParams,
|
|
||||||
dataFrames: DataFrame*): Array[RDD[XGBLabeledPoint]] = {
|
|
||||||
|
|
||||||
packedParams match {
|
|
||||||
case j @ PackedParams(labelCol, featuresCol, weight, baseMargin, group, numWorkers,
|
|
||||||
deterministicPartition) =>
|
|
||||||
val selectedColumns = group.map(groupCol => Seq(labelCol.cast(FloatType),
|
|
||||||
featuresCol,
|
|
||||||
weight.cast(FloatType),
|
|
||||||
groupCol.cast(IntegerType),
|
|
||||||
baseMargin.cast(FloatType))).getOrElse(Seq(labelCol.cast(FloatType),
|
|
||||||
featuresCol,
|
|
||||||
weight.cast(FloatType),
|
|
||||||
baseMargin.cast(FloatType)))
|
|
||||||
val arrayOfRDDs = dataFrames.toArray.map {
|
|
||||||
df => df.select(selectedColumns: _*).rdd.map {
|
|
||||||
case row @ Row(label: Float, features: Vector, weight: Float, group: Int,
|
|
||||||
baseMargin: Float) =>
|
|
||||||
val (size, indices, values) = features match {
|
|
||||||
case v: SparseVector => (v.size, v.indices, v.values.map(_.toFloat))
|
|
||||||
case v: DenseVector => (v.size, null, v.values.map(_.toFloat))
|
|
||||||
}
|
|
||||||
val xgbLp = XGBLabeledPoint(label, size, indices, values, weight, group, baseMargin)
|
|
||||||
attachPartitionKey(row, deterministicPartition, numWorkers, xgbLp)
|
|
||||||
case row @ Row(label: Float, features: Vector, weight: Float, baseMargin: Float) =>
|
|
||||||
val (size, indices, values) = features match {
|
|
||||||
case v: SparseVector => (v.size, v.indices, v.values.map(_.toFloat))
|
|
||||||
case v: DenseVector => (v.size, null, v.values.map(_.toFloat))
|
|
||||||
}
|
|
||||||
val xgbLp = XGBLabeledPoint(label, size, indices, values, weight,
|
|
||||||
baseMargin = baseMargin)
|
|
||||||
attachPartitionKey(row, deterministicPartition, numWorkers, xgbLp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
repartitionRDDs(deterministicPartition, numWorkers, arrayOfRDDs)
|
|
||||||
|
|
||||||
case _ => throw new IllegalArgumentException("Wrong PackedParams") // never reach here
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
private[spark] def processMissingValues(
|
|
||||||
xgbLabelPoints: Iterator[XGBLabeledPoint],
|
|
||||||
missing: Float,
|
|
||||||
allowNonZeroMissing: Boolean): Iterator[XGBLabeledPoint] = {
|
|
||||||
if (!missing.isNaN) {
|
|
||||||
removeMissingValues(verifyMissingSetting(xgbLabelPoints, missing, allowNonZeroMissing),
|
|
||||||
missing, (v: Float) => v != missing)
|
|
||||||
} else {
|
|
||||||
removeMissingValues(verifyMissingSetting(xgbLabelPoints, missing, allowNonZeroMissing),
|
|
||||||
missing, (v: Float) => !v.isNaN)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private[spark] def processMissingValuesWithGroup(
|
|
||||||
xgbLabelPointGroups: Iterator[Array[XGBLabeledPoint]],
|
|
||||||
missing: Float,
|
|
||||||
allowNonZeroMissing: Boolean): Iterator[Array[XGBLabeledPoint]] = {
|
|
||||||
if (!missing.isNaN) {
|
|
||||||
xgbLabelPointGroups.map {
|
|
||||||
labeledPoints => processMissingValues(
|
|
||||||
labeledPoints.iterator,
|
|
||||||
missing,
|
|
||||||
allowNonZeroMissing
|
|
||||||
).toArray
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
xgbLabelPointGroups
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private def removeMissingValues(
|
|
||||||
xgbLabelPoints: Iterator[XGBLabeledPoint],
|
|
||||||
missing: Float,
|
|
||||||
keepCondition: Float => Boolean): Iterator[XGBLabeledPoint] = {
|
|
||||||
xgbLabelPoints.map { labeledPoint =>
|
|
||||||
val indicesBuilder = new mutable.ArrayBuilder.ofInt()
|
|
||||||
val valuesBuilder = new mutable.ArrayBuilder.ofFloat()
|
|
||||||
for ((value, i) <- labeledPoint.values.zipWithIndex if keepCondition(value)) {
|
|
||||||
indicesBuilder += (if (labeledPoint.indices == null) i else labeledPoint.indices(i))
|
|
||||||
valuesBuilder += value
|
|
||||||
}
|
|
||||||
labeledPoint.copy(indices = indicesBuilder.result(), values = valuesBuilder.result())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
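A standalone sketch of the filtering idea in the removed helper above: keep only the entries whose value differs from the configured missing value (values are illustrative):

val values  = Array(1.0f, 0.0f, 3.5f, 0.0f)
val missing = 0.0f
val kept = values.zipWithIndex.collect {
  case (v, i) if v != missing => (i, v)   // surviving (index, value) pairs
}
// kept == Array((0, 1.0f), (2, 3.5f))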
|
|
||||||
private def verifyMissingSetting(
|
|
||||||
xgbLabelPoints: Iterator[XGBLabeledPoint],
|
|
||||||
missing: Float,
|
|
||||||
allowNonZeroMissing: Boolean): Iterator[XGBLabeledPoint] = {
|
|
||||||
if (missing != 0.0f && !allowNonZeroMissing) {
|
|
||||||
xgbLabelPoints.map(labeledPoint => {
|
|
||||||
if (labeledPoint.indices != null) {
|
|
||||||
throw new RuntimeException(s"you can only specify missing value as 0.0 (the currently" +
|
|
||||||
s" set value $missing) when you have SparseVector or Empty vector as your feature" +
|
|
||||||
s" format. If you didn't use Spark's VectorAssembler class to build your feature " +
|
|
||||||
s"vector but instead did so in a way that preserves zeros in your feature vector " +
|
|
|
s"you can avoid this check by using the 'allow_non_zero_for_missing' parameter" +
|
|
||||||
s" (only use if you know what you are doing)")
|
|
||||||
}
|
|
||||||
labeledPoint
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
xgbLabelPoints
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
@ -1,147 +0,0 @@
|
|||||||
/*
|
|
||||||
Copyright (c) 2022 by Contributors
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.apache.spark.ml.util
|
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.java.{Booster => JBooster}
|
|
||||||
import ml.dmlc.xgboost4j.scala.spark
|
|
||||||
import org.apache.commons.logging.LogFactory
|
|
||||||
import org.apache.hadoop.fs.FSDataInputStream
|
|
||||||
import org.json4s.DefaultFormats
|
|
||||||
import org.json4s.JsonAST.JObject
|
|
||||||
import org.json4s.JsonDSL._
|
|
||||||
import org.json4s.jackson.JsonMethods.{compact, render}
|
|
||||||
|
|
||||||
import org.apache.spark.SparkContext
|
|
||||||
import org.apache.spark.ml.param.Params
|
|
||||||
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
|
|
||||||
|
|
||||||
abstract class XGBoostWriter extends MLWriter {
|
|
||||||
def getModelFormat(): String = {
|
|
||||||
optionMap.getOrElse("format", JBooster.DEFAULT_FORMAT)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
object DefaultXGBoostParamsWriter {
|
|
||||||
|
|
||||||
val XGBOOST_VERSION_TAG = "xgboostVersion"
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Saves metadata + Params to: path + "/metadata" using [[DefaultParamsWriter.saveMetadata]]
|
|
||||||
*/
|
|
||||||
def saveMetadata(
|
|
||||||
instance: Params,
|
|
||||||
path: String,
|
|
||||||
sc: SparkContext): Unit = {
|
|
||||||
// save xgboost version to distinguish the old model.
|
|
||||||
val extraMetadata: JObject = Map(XGBOOST_VERSION_TAG -> ml.dmlc.xgboost4j.scala.spark.VERSION)
|
|
||||||
DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
object DefaultXGBoostParamsReader {
|
|
||||||
|
|
||||||
private val logger = LogFactory.getLog("XGBoostSpark")
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Load metadata saved using [[DefaultParamsReader.loadMetadata()]]
|
|
||||||
*
|
|
||||||
* @param expectedClassName If non empty, this is checked against the loaded metadata.
|
|
||||||
* @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
|
|
||||||
*/
|
|
||||||
def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = {
|
|
||||||
DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Extract Params from metadata, and set them in the instance.
|
|
||||||
* This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]].
|
|
||||||
*
|
|
||||||
* And it will auto-skip the parameter not defined.
|
|
||||||
*
|
|
||||||
* This API is mainly copied from DefaultParamsReader
|
|
||||||
*/
|
|
||||||
def getAndSetParams(instance: Params, metadata: Metadata): Unit = {
|
|
||||||
|
|
||||||
// XGBoost didn't set the default parameters since the save/load code is copied
|
|
||||||
// from spark 2.3.x, which means it just used the default values
|
|
||||||
// as the same with XGBoost version instead of them in model.
|
|
||||||
// For the compatibility, here we still don't set the default parameters.
|
|
||||||
// setParams(instance, metadata, isDefault = true)
|
|
||||||
|
|
||||||
setParams(instance, metadata, isDefault = false)
|
|
||||||
}
|
|
||||||
|
|
||||||
/** This API is only for XGBoostClassificationModel */
|
|
||||||
def getNumClass(metadata: Metadata, dataInStream: FSDataInputStream): Int = {
|
|
||||||
implicit val format = DefaultFormats
|
|
||||||
|
|
||||||
// The xgboostVersion in the meta can specify if the model is the old xgboost in-compatible
|
|
||||||
// or the new xgboost compatible.
|
|
||||||
val xgbVerOpt = (metadata.metadata \ DefaultXGBoostParamsWriter.XGBOOST_VERSION_TAG)
|
|
||||||
.extractOpt[String]
|
|
||||||
|
|
||||||
// For binary:logistic, the numClass parameter can't be set to 2 or not be set.
|
|
||||||
// For multi:softprob or multi:softmax, the numClass parameter must be set correctly,
|
|
||||||
// or else, XGBoost will throw exception.
|
|
||||||
// So it's safe to get numClass from meta data.
|
|
||||||
xgbVerOpt
|
|
||||||
.map { _ => (metadata.params \ "numClass").extractOpt[Int].getOrElse(2) }
|
|
||||||
.getOrElse(dataInStream.readInt())
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
private def setParams(
|
|
||||||
instance: Params,
|
|
||||||
metadata: Metadata,
|
|
||||||
isDefault: Boolean): Unit = {
|
|
||||||
val paramsToSet = if (isDefault) metadata.defaultParams else metadata.params
|
|
||||||
paramsToSet match {
|
|
||||||
case JObject(pairs) =>
|
|
||||||
pairs.foreach { case (paramName, jsonValue) =>
|
|
||||||
val finalName = handleBrokenlyChangedName(paramName)
|
|
||||||
// For the deleted parameters, we'd better to remove it instead of throwing an exception.
|
|
||||||
// So we need to check if the parameter exists instead of blindly setting it.
|
|
||||||
if (instance.hasParam(finalName)) {
|
|
||||||
val param = instance.getParam(finalName)
|
|
||||||
val value = param.jsonDecode(compact(render(jsonValue)))
|
|
||||||
instance.set(param, handleBrokenlyChangedValue(paramName, value))
|
|
||||||
} else {
|
|
||||||
logger.warn(s"$finalName is no longer used in ${spark.VERSION}")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case _ =>
|
|
||||||
throw new IllegalArgumentException(
|
|
||||||
s"Cannot recognize JSON metadata: ${metadata.metadataJson}.")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private val paramNameCompatibilityMap: Map[String, String] = Map("silent" -> "verbosity")
|
|
||||||
|
|
||||||
/** This is really not good to do this transformation, but it is needed since there're
|
|
||||||
* some tests based on 0.82 saved model in which the objective is "reg:linear" */
|
|
||||||
private val paramValueCompatibilityMap: Map[String, Map[Any, Any]] =
|
|
||||||
Map("objective" -> Map("reg:linear" -> "reg:squarederror"))
|
|
||||||
|
|
||||||
private def handleBrokenlyChangedName(paramName: String): String = {
|
|
||||||
paramNameCompatibilityMap.getOrElse(paramName, paramName)
|
|
||||||
}
|
|
||||||
|
|
||||||
private def handleBrokenlyChangedValue[T](paramName: String, value: T): T = {
|
|
||||||
paramValueCompatibilityMap.getOrElse(paramName, Map()).getOrElse(value, value).asInstanceOf[T]
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
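The compatibility maps in the removed reader only ever rename a parameter ("silent" becomes "verbosity") or rewrite one legacy value ("reg:linear" becomes "reg:squarederror"); everything else passes through unchanged. A small self-contained sketch of that lookup behaviour, with names that are illustrative rather than part of the library API:

.. code-block:: scala

   object ParamCompatSketch {
     private val nameMap = Map("silent" -> "verbosity")
     private val valueMap = Map("objective" -> Map[Any, Any]("reg:linear" -> "reg:squarederror"))

     // Rename the parameter if it has a newer name, otherwise keep it as-is.
     def newName(name: String): String = nameMap.getOrElse(name, name)

     // Rewrite the value only for parameters that have a known legacy spelling.
     def newValue(name: String, value: Any): Any =
       valueMap.getOrElse(name, Map.empty[Any, Any]).getOrElse(value, value)

     def main(args: Array[String]): Unit = {
       println(newName("silent"))                         // verbosity
       println(newValue("objective", "reg:linear"))       // reg:squarederror
       println(newValue("objective", "binary:logistic"))  // unchanged
     }
   }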
@ -1,50 +0,0 @@
/*
 Copyright (c) 2022-2023 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package org.apache.spark.ml.util

import org.apache.spark.sql.types.{BooleanType, DataType, NumericType, StructType}
import org.apache.spark.ml.linalg.VectorUDT

object XGBoostSchemaUtils {

  /** check if the dataType is VectorUDT */
  def isVectorUDFType(dataType: DataType): Boolean = {
    dataType match {
      case _: VectorUDT => true
      case _ => false
    }
  }

  /** The feature columns will be vectorized by VectorAssembler first, which only
   * supports Numeric, Boolean and VectorUDT types */
  def checkFeatureColumnType(dataType: DataType): Unit = {
    dataType match {
      case _: NumericType | BooleanType =>
      case _: VectorUDT =>
      case d => throw new UnsupportedOperationException(s"featuresCols only supports Numeric, " +
        s"boolean and VectorUDT types, found: ${d}")
    }
  }

  def checkNumericType(
      schema: StructType,
      colName: String,
      msg: String = ""): Unit = {
    SchemaUtils.checkNumericType(schema, colName, msg)
  }

}
@ -0,0 +1,93 @@
/*
 Copyright (c) 2024 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package org.apache.spark.ml.xgboost

import org.apache.spark.SparkContext
import org.apache.spark.ml.classification.ProbabilisticClassifierParams
import org.apache.spark.ml.linalg.VectorUDT
import org.apache.spark.ml.param.Params
import org.apache.spark.ml.util.{DatasetUtils, DefaultParamsReader, DefaultParamsWriter, SchemaUtils}
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
import org.json4s.{JObject, JValue}

import ml.dmlc.xgboost4j.scala.spark.params.NonXGBoostParams

/**
 * XGBoost classification spark-specific parameters which should not be passed
 * into the xgboost library
 *
 * @tparam T should be XGBoostClassifier or XGBoostClassificationModel
 */
trait XGBProbabilisticClassifierParams[T <: Params]
  extends ProbabilisticClassifierParams with NonXGBoostParams {

  /**
   * XGBoost doesn't use validateAndTransformSchema since spark validateAndTransformSchema
   * needs to ensure the feature is vector type
   */
  override protected def validateAndTransformSchema(
      schema: StructType,
      fitting: Boolean,
      featuresDataType: DataType): StructType = {
    var outputSchema = SparkUtils.appendColumn(schema, $(predictionCol), DoubleType)
    outputSchema = SparkUtils.appendVectorUDTColumn(outputSchema, $(rawPredictionCol))
    outputSchema = SparkUtils.appendVectorUDTColumn(outputSchema, $(probabilityCol))
    outputSchema
  }

  addNonXGBoostParam(rawPredictionCol, probabilityCol, thresholds)
}

/** Utils to access the spark internal functions */
object SparkUtils {

  def getNumClasses(dataset: Dataset[_], labelCol: String, maxNumClasses: Int = 100): Int = {
    DatasetUtils.getNumClasses(dataset, labelCol, maxNumClasses)
  }

  def checkNumericType(schema: StructType, colName: String, msg: String = ""): Unit = {
    SchemaUtils.checkNumericType(schema, colName, msg)
  }

  def saveMetadata(instance: Params,
      path: String,
      sc: SparkContext,
      extraMetadata: Option[JObject] = None,
      paramMap: Option[JValue] = None): Unit = {
    DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, paramMap)
  }

  def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = {
    DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
  }

  def appendColumn(schema: StructType,
      colName: String,
      dataType: DataType,
      nullable: Boolean = false): StructType = {
    SchemaUtils.appendColumn(schema, colName, dataType, nullable)
  }

  def appendVectorUDTColumn(schema: StructType,
      colName: String,
      dataType: DataType = new VectorUDT,
      nullable: Boolean = false): StructType = {
    SchemaUtils.appendColumn(schema, colName, dataType, nullable)
  }
}
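The validateAndTransformSchema override above simply appends three output columns (a double prediction plus vector-typed raw prediction and probability) to whatever input schema arrives. The same effect can be reproduced with Spark's public StructType API alone; the sketch below assumes only a Spark dependency on the classpath and uses made-up column names for illustration.

.. code-block:: scala

   import org.apache.spark.ml.linalg.VectorUDT
   import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

   object SchemaAppendSketch {
     def main(args: Array[String]): Unit = {
       val input = StructType(Seq(
         StructField("features", new VectorUDT, nullable = false),
         StructField("label", DoubleType, nullable = false)))

       // Append the three prediction-related columns that the trait adds.
       val output = input
         .add(StructField("prediction", DoubleType, nullable = false))
         .add(StructField("rawPrediction", new VectorUDT, nullable = false))
         .add(StructField("probability", new VectorUDT, nullable = false))

       output.printTreeString()
     }
   }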
@ -16,21 +16,11 @@
 
 package ml.dmlc.xgboost4j.scala.spark
 
-import java.util.concurrent.LinkedBlockingDeque
-
-import scala.util.Random
-
-import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker}
-import ml.dmlc.xgboost4j.scala.DMatrix
 import org.scalatest.funsuite.AnyFunSuite
 
-class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
-
-  private def getXGBoostExecutionParams(paramMap: Map[String, Any]): XGBoostExecutionParams = {
-    val classifier = new XGBoostClassifier(paramMap)
-    val xgbParamsFactory = new XGBoostExecutionParamsFactory(classifier.MLlib2XGBoostParams, sc)
-    xgbParamsFactory.buildXGBRuntimeParams
-  }
+import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker}
+
+class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
 
   test("test Java RabitTracker wrapper's exception handling: it should not hang forever.") {
     /*
@ -113,9 +103,11 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
       "max_depth" -> "6",
       "silent" -> "1",
       "objective" -> "binary:logistic")
-    val trainingDF = buildDataFrame(Classification.train)
-    val model = new XGBoostClassifier(paramMap ++ Array("num_round" -> 10,
-      "num_workers" -> numWorkers)).fit(trainingDF)
+    val trainingDF = smallBinaryClassificationVector
+    val model = new XGBoostClassifier(paramMap)
+      .setNumWorkers(numWorkers)
+      .setNumRound(10)
+      .fit(trainingDF)
     val prediction = model.transform(trainingDF)
     // a partial evaluation of dataframe will cause rabit initialized but not shutdown in some
     // threads
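The hunk above is the shape most call sites take after the rework: execution knobs such as the number of workers and boosting rounds move out of the constructor parameter map and onto fluent setters. A hedged sketch of that call pattern follows; it assumes the 3.x xgboost-spark artifact is on the classpath and uses a tiny made-up DataFrame purely for illustration.

.. code-block:: scala

   import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
   import org.apache.spark.ml.linalg.Vectors
   import org.apache.spark.sql.SparkSession

   object FluentSetterSketch {
     def main(args: Array[String]): Unit = {
       val spark = SparkSession.builder().master("local[1]").appName("fluent-setters").getOrCreate()

       // Illustrative training data: label plus assembled feature vector.
       val train = spark.createDataFrame(Seq(
         (1.0, Vectors.dense(1.0, 2.0, 3.0)),
         (0.0, Vectors.dense(0.0, 0.0, 0.0)),
         (1.0, Vectors.dense(2.0, 0.0, 4.0)),
         (0.0, Vectors.dense(0.2, 1.2, 2.0))
       )).toDF("label", "features")

       // The objective still goes through the parameter map; run-time knobs use setters.
       val model = new XGBoostClassifier(Map("objective" -> "binary:logistic"))
         .setNumWorkers(1)
         .setNumRound(5)
         .fit(train)

       model.transform(train).select("label", "prediction").show()
       spark.stop()
     }
   }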
@ -16,10 +16,12 @@
 
 package ml.dmlc.xgboost4j.scala.spark
 
+import scala.collection.mutable.ListBuffer
+
+import org.apache.commons.logging.LogFactory
+
 import ml.dmlc.xgboost4j.java.XGBoostError
 import ml.dmlc.xgboost4j.scala.{DMatrix, ObjectiveTrait}
-import org.apache.commons.logging.LogFactory
-import scala.collection.mutable.ListBuffer
-
 
 /**
@ -1,114 +0,0 @@
/*
 Copyright (c) 2014-2022 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.spark

import org.apache.spark.ml.linalg.Vectors
import org.scalatest.funsuite.AnyFunSuite
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils.PackedParams

import org.apache.spark.sql.functions._

class DeterministicPartitioningSuite extends AnyFunSuite with TmpFolderPerSuite with PerTest {

  test("perform deterministic partitioning when checkpointInternal and" +
    " checkpointPath is set (Classifier)") {
    val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
    val paramMap = Map("eta" -> "1", "max_depth" -> 2,
      "objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
      "checkpoint_interval" -> 2, "num_workers" -> numWorkers)
    val xgbClassifier = new XGBoostClassifier(paramMap)
    assert(xgbClassifier.needDeterministicRepartitioning)
  }

  test("perform deterministic partitioning when checkpointInternal and" +
    " checkpointPath is set (Regressor)") {
    val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
    val paramMap = Map("eta" -> "1", "max_depth" -> 2,
      "objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
      "checkpoint_interval" -> 2, "num_workers" -> numWorkers)
    val xgbRegressor = new XGBoostRegressor(paramMap)
    assert(xgbRegressor.needDeterministicRepartitioning)
  }

  test("deterministic partitioning takes effect with various parts of data") {
    val trainingDF = buildDataFrame(Classification.train)
    // the test idea is that, we apply a chain of repartitions over trainingDFs but they
    // have to produce the identical RDDs
    val transformedDFs = (1 until 6).map(shuffleCount => {
      var resultDF = trainingDF
      for (i <- 0 until shuffleCount) {
        resultDF = resultDF.repartition(numWorkers)
      }
      resultDF
    })
    val transformedRDDs = transformedDFs.map(df => DataUtils.convertDataFrameToXGBLabeledPointRDDs(
      PackedParams(col("label"),
        col("features"),
        lit(1.0),
        lit(Float.NaN),
        None,
        numWorkers,
        deterministicPartition = true),
      df
    ).head)
    val resultsMaps = transformedRDDs.map(rdd => rdd.mapPartitionsWithIndex {
      case (partitionIndex, labelPoints) =>
        Iterator((partitionIndex, labelPoints.toList))
    }.collect().toMap)
    resultsMaps.foldLeft(resultsMaps.head) { case (map1, map2) =>
      assert(map1.keys.toSet === map2.keys.toSet)
      for ((parIdx, labeledPoints) <- map1) {
        val sortedA = labeledPoints.sortBy(_.hashCode())
        val sortedB = map2(parIdx).sortBy(_.hashCode())
        assert(sortedA.length === sortedB.length)
        assert(sortedA.indices.forall(idx =>
          sortedA(idx).values.toSet === sortedB(idx).values.toSet))
      }
      map2
    }
  }

  test("deterministic partitioning has a uniform repartition on dataset with missing values") {
    val N = 10000
    val dataset = (0 until N).map{ n =>
      (n, n % 2, Vectors.sparse(3, Array(0, 1, 2), Array(Double.NaN, n, Double.NaN)))
    }

    val df = ss.createDataFrame(sc.parallelize(dataset)).toDF("id", "label", "features")

    val dfRepartitioned = DataUtils.convertDataFrameToXGBLabeledPointRDDs(
      PackedParams(col("label"),
        col("features"),
        lit(1.0),
        lit(Float.NaN),
        None,
        10,
        deterministicPartition = true), df
    ).head

    val partitionsSizes = dfRepartitioned
      .mapPartitions(iter => Array(iter.size.toDouble).iterator, true)
      .collect()
    val partitionMean = partitionsSizes.sum / partitionsSizes.length
    val squaredDiffSum = partitionsSizes
      .map(partitionSize => Math.pow(partitionSize - partitionMean, 2))
    val standardDeviation = math.sqrt(squaredDiffSum.sum / squaredDiffSum.length)

    assert(standardDeviation < math.sqrt(N.toDouble))
  }
}
@ -16,9 +16,10 @@
 
 package ml.dmlc.xgboost4j.scala.spark
 
+import org.apache.commons.logging.LogFactory
+
 import ml.dmlc.xgboost4j.java.XGBoostError
 import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait}
-import org.apache.commons.logging.LogFactory
 
 class EvalError extends EvalTrait {
@ -1,131 +0,0 @@
/*
 Copyright (c) 2014-2023 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.spark

import java.io.File

import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, ExternalCheckpointManager, XGBoost => SXGBoost}
import org.scalatest.funsuite.AnyFunSuite
import org.apache.hadoop.fs.{FileSystem, Path}

class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite with PerTest {

  private def produceParamMap(checkpointPath: String, checkpointInterval: Int):
    Map[String, Any] = {
    Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
      "objective" -> "binary:logistic", "num_workers" -> sc.defaultParallelism,
      "checkpoint_path" -> checkpointPath, "checkpoint_interval" -> checkpointInterval)
  }

  private def createNewModels():
    (String, XGBoostClassificationModel, XGBoostClassificationModel) = {
    val tmpPath = createTmpFolder("test").toAbsolutePath.toString
    val (model2, model4) = {
      val training = buildDataFrame(Classification.train)
      val paramMap = produceParamMap(tmpPath, 2)
      (new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training),
        new XGBoostClassifier(paramMap ++ Seq("num_round" -> 4)).fit(training))
    }
    (tmpPath, model2, model4)
  }

  test("test update/load models") {
    val (tmpPath, model2, model4) = createNewModels()
    val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))

    manager.updateCheckpoint(model2._booster.booster)
    var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
    assert(files.length == 1)
    assert(files.head.getPath.getName == "1.ubj")
    assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 2)

    manager.updateCheckpoint(model4._booster)
    files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
    assert(files.length == 1)
    assert(files.head.getPath.getName == "3.ubj")
    assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 4)
  }

  test("test cleanUpHigherVersions") {
    val (tmpPath, model2, model4) = createNewModels()

    val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
    manager.updateCheckpoint(model4._booster)
    manager.cleanUpHigherVersions(3)
    assert(new File(s"$tmpPath/3.ubj").exists())

    manager.cleanUpHigherVersions(2)
    assert(!new File(s"$tmpPath/3.ubj").exists())
  }

  test("test checkpoint rounds") {
    import scala.collection.JavaConverters._
    val (tmpPath, model2, model4) = createNewModels()
    val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
    assertResult(Seq(2))(manager.getCheckpointRounds(0, 0, 3).asScala)
    assertResult(Seq(0, 2, 4, 6))(manager.getCheckpointRounds(0, 2, 7).asScala)
    assertResult(Seq(0, 2, 4, 6, 7))(manager.getCheckpointRounds(0, 2, 8).asScala)
  }


  private def trainingWithCheckpoint(cacheData: Boolean, skipCleanCheckpoint: Boolean): Unit = {
    val eval = new EvalError()
    val training = buildDataFrame(Classification.train)
    val testDM = new DMatrix(Classification.test.iterator)

    val tmpPath = createTmpFolder("model1").toAbsolutePath.toString

    val paramMap = produceParamMap(tmpPath, 2)

    val cacheDataMap = if (cacheData) Map("cacheTrainingSet" -> true) else Map()
    val skipCleanCheckpointMap =
      if (skipCleanCheckpoint) Map("skip_clean_checkpoint" -> true) else Map()

    val finalParamMap = paramMap ++ cacheDataMap ++ skipCleanCheckpointMap

    val prevModel = new XGBoostClassifier(finalParamMap ++ Seq("num_round" -> 5)).fit(training)

    def error(model: Booster): Float = eval.eval(model.predict(testDM, outPutMargin = true), testDM)

    if (skipCleanCheckpoint) {
      // Check only one model is kept after training
      val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
      assert(files.length == 1)
      assert(files.head.getPath.getName == "4.ubj")
      val tmpModel = SXGBoost.loadModel(s"$tmpPath/4.ubj")
      // Train next model based on prev model
      val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
      assert(error(tmpModel) >= error(prevModel._booster))
      assert(error(prevModel._booster) > error(nextModel._booster))
      assert(error(nextModel._booster) < 0.1)
    } else {
      assert(!FileSystem.get(sc.hadoopConfiguration).exists(new Path(tmpPath)))
    }
  }

  test("training with checkpoint boosters") {
    trainingWithCheckpoint(cacheData = false, skipCleanCheckpoint = true)
  }

  test("training with checkpoint boosters with cached training dataset") {
    trainingWithCheckpoint(cacheData = true, skipCleanCheckpoint = true)
  }

  test("the checkpoint file should be cleaned after a successful training") {
    trainingWithCheckpoint(cacheData = false, skipCleanCheckpoint = false)
  }
}
@ -1,70 +0,0 @@
/*
 Copyright (c) 2014-2022 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.spark

import org.apache.spark.Partitioner
import org.apache.spark.ml.feature.VectorAssembler
import org.scalatest.funsuite.AnyFunSuite
import org.apache.spark.sql.functions._

import scala.util.Random

class FeatureSizeValidatingSuite extends AnyFunSuite with PerTest {

  test("transform throwing exception if feature size of dataset is greater than model's") {
    val modelPath = getClass.getResource("/model/0.82/model").getPath
    val model = XGBoostClassificationModel.read.load(modelPath)
    val r = new Random(0)
    // 0.82/model was trained with 251 features. and transform will throw exception
    // if feature size of data is not equal to 251
    var df = ss.createDataFrame(Seq.fill(100)(r.nextInt(2)).map(i => (i, i))).
      toDF("feature", "label")
    for (x <- 1 to 252) {
      df = df.withColumn(s"feature_${x}", lit(1))
    }
    val assembler = new VectorAssembler()
      .setInputCols(df.columns.filter(!_.contains("label")))
      .setOutputCol("features")
    val thrown = intercept[Exception] {
      model.transform(assembler.transform(df)).show()
    }
    assert(thrown.getMessage.contains(
      "Number of columns does not match number of features in booster"))
  }

  test("train throwing exception if feature size of dataset is different on distributed train") {
    val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "binary:logistic",
      "num_round" -> 5, "num_workers" -> 2, "use_external_memory" -> true, "missing" -> 0)
    import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
    val sparkSession = ss
    import sparkSession.implicits._
    val repartitioned = sc.parallelize(Synthetic.trainWithDiffFeatureSize, 2)
      .map(lp => (lp.label, lp)).partitionBy(
      new Partitioner {
        override def numPartitions: Int = 2

        override def getPartition(key: Any): Int = key.asInstanceOf[Float].toInt
      }
    ).map(_._2).zipWithIndex().map {
      case (lp, id) =>
        (id, lp.label, lp.features)
    }.toDF("id", "label", "features")
    val xgb = new XGBoostClassifier(paramMap)
    xgb.fit(repartitioned)
  }
}
@ -1,235 +0,0 @@
/*
 Copyright (c) 2014-2022 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.spark

import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.DataFrame
import org.scalatest.funsuite.AnyFunSuite
import scala.util.Random

import org.apache.spark.SparkException

class MissingValueHandlingSuite extends AnyFunSuite with PerTest {
  test("dense vectors containing missing value") {
    def buildDenseDataFrame(): DataFrame = {
      val numRows = 100
      val numCols = 5
      val data = (0 until numRows).map { x =>
        val label = Random.nextInt(2)
        val values = Array.tabulate[Double](numCols) { c =>
          if (c == numCols - 1) 0 else Random.nextDouble
        }
        (label, Vectors.dense(values))
      }
      ss.createDataFrame(sc.parallelize(data.toList)).toDF("label", "features")
    }
    val denseDF = buildDenseDataFrame().repartition(4)
    val paramMap = List("eta" -> "1", "max_depth" -> "2",
      "objective" -> "binary:logistic", "missing" -> 0, "num_workers" -> numWorkers).toMap
    val model = new XGBoostClassifier(paramMap).fit(denseDF)
    model.transform(denseDF).collect()
  }

  test("handle Float.NaN as missing value correctly") {
    val spark = ss
    import spark.implicits._
    val testDF = Seq(
      (1.0f, 0.0f, Float.NaN, 1.0),
      (1.0f, 0.0f, 1.0f, 1.0),
      (0.0f, 1.0f, 0.0f, 0.0),
      (1.0f, 0.0f, 1.0f, 1.0),
      (1.0f, Float.NaN, 0.0f, 0.0),
      (0.0f, 1.0f, 0.0f, 1.0),
      (Float.NaN, 0.0f, 0.0f, 1.0)
    ).toDF("col1", "col2", "col3", "label")
    val vectorAssembler = new VectorAssembler()
      .setInputCols(Array("col1", "col2", "col3"))
      .setOutputCol("features")
      .setHandleInvalid("keep")

    val inputDF = vectorAssembler.transform(testDF).select("features", "label")
    val paramMap = List("eta" -> "1", "max_depth" -> "2",
      "objective" -> "binary:logistic", "missing" -> Float.NaN, "num_workers" -> 1).toMap
    val model = new XGBoostClassifier(paramMap).fit(inputDF)
    model.transform(inputDF).collect()
  }

  test("specify a non-zero missing value but with dense vector does not stop" +
    " application") {
    val spark = ss
    import spark.implicits._
    // spark uses 1.5 * (nnz + 1.0) < size as the condition to decide whether using sparse or dense
    // vector,
    val testDF = Seq(
      (1.0f, 0.0f, -1.0f, 1.0),
      (1.0f, 0.0f, 1.0f, 1.0),
      (0.0f, 1.0f, 0.0f, 0.0),
      (1.0f, 0.0f, 1.0f, 1.0),
      (1.0f, -1.0f, 0.0f, 0.0),
      (0.0f, 1.0f, 0.0f, 1.0),
      (-1.0f, 0.0f, 0.0f, 1.0)
    ).toDF("col1", "col2", "col3", "label")
    val vectorAssembler = new VectorAssembler()
      .setInputCols(Array("col1", "col2", "col3"))
      .setOutputCol("features")
    val inputDF = vectorAssembler.transform(testDF).select("features", "label")
    val paramMap = List("eta" -> "1", "max_depth" -> "2",
      "objective" -> "binary:logistic", "missing" -> -1.0f, "num_workers" -> 1).toMap
    val model = new XGBoostClassifier(paramMap).fit(inputDF)
    model.transform(inputDF).collect()
  }

  test("specify a non-zero missing value and meet an empty vector we should" +
    " stop the application") {
    val spark = ss
    import spark.implicits._
    val testDF = Seq(
      (1.0f, 0.0f, -1.0f, 1.0),
      (1.0f, 0.0f, 1.0f, 1.0),
      (0.0f, 1.0f, 0.0f, 0.0),
      (1.0f, 0.0f, 1.0f, 1.0),
      (1.0f, -1.0f, 0.0f, 0.0),
      (0.0f, 0.0f, 0.0f, 1.0), // empty vector
      (-1.0f, 0.0f, 0.0f, 1.0)
    ).toDF("col1", "col2", "col3", "label")
    val vectorAssembler = new VectorAssembler()
      .setInputCols(Array("col1", "col2", "col3"))
      .setOutputCol("features")
    val inputDF = vectorAssembler.transform(testDF).select("features", "label")
    val paramMap = List("eta" -> "1", "max_depth" -> "2",
      "objective" -> "binary:logistic", "missing" -> -1.0f, "num_workers" -> 1).toMap
    intercept[SparkException] {
      new XGBoostClassifier(paramMap).fit(inputDF)
    }
  }

  test("specify a non-zero missing value and meet a Sparse vector we should" +
    " stop the application") {
    val spark = ss
    import spark.implicits._
    // spark uses 1.5 * (nnz + 1.0) < size as the condition to decide whether using sparse or dense
    // vector,
    val testDF = Seq(
      (1.0f, 0.0f, -1.0f, 1.0f, 1.0),
      (1.0f, 0.0f, 1.0f, 1.0f, 1.0),
      (0.0f, 1.0f, 0.0f, 1.0f, 0.0),
      (1.0f, 0.0f, 1.0f, 1.0f, 1.0),
      (1.0f, -1.0f, 0.0f, 1.0f, 0.0),
      (0.0f, 0.0f, 0.0f, 1.0f, 1.0),
      (-1.0f, 0.0f, 0.0f, 1.0f, 1.0)
    ).toDF("col1", "col2", "col3", "col4", "label")
    val vectorAssembler = new VectorAssembler()
      .setInputCols(Array("col1", "col2", "col3", "col4"))
      .setOutputCol("features")
    val inputDF = vectorAssembler.transform(testDF).select("features", "label")
    inputDF.show()
    val paramMap = List("eta" -> "1", "max_depth" -> "2",
      "objective" -> "binary:logistic", "missing" -> -1.0f, "num_workers" -> 1).toMap
    intercept[SparkException] {
      new XGBoostClassifier(paramMap).fit(inputDF)
    }
  }

  test("specify a non-zero missing value but set allow_non_zero_for_missing " +
    "does not stop application") {
    val spark = ss
    import spark.implicits._
    // spark uses 1.5 * (nnz + 1.0) < size as the condition to decide whether using sparse or dense
    // vector,
    val testDF = Seq(
      (7.0f, 0.0f, -1.0f, 1.0f, 1.0),
      (1.0f, 0.0f, 1.0f, 1.0f, 1.0),
      (0.0f, 1.0f, 0.0f, 1.0f, 0.0),
      (1.0f, 0.0f, 1.0f, 1.0f, 1.0),
      (1.0f, -1.0f, 0.0f, 1.0f, 0.0),
      (0.0f, 0.0f, 0.0f, 1.0f, 1.0),
      (-1.0f, 0.0f, 0.0f, 1.0f, 1.0)
    ).toDF("col1", "col2", "col3", "col4", "label")
    val vectorAssembler = new VectorAssembler()
      .setInputCols(Array("col1", "col2", "col3", "col4"))
      .setOutputCol("features")
    val inputDF = vectorAssembler.transform(testDF).select("features", "label")
    inputDF.show()
    val paramMap = List("eta" -> "1", "max_depth" -> "2",
      "objective" -> "binary:logistic", "missing" -> -1.0f,
      "num_workers" -> 1, "allow_non_zero_for_missing" -> "true").toMap
    val model = new XGBoostClassifier(paramMap).fit(inputDF)
    model.transform(inputDF).collect()
  }

  // https://github.com/dmlc/xgboost/pull/5929
  test("handle the empty last row correctly with a missing value as 0") {
    val spark = ss
    import spark.implicits._
    // spark uses 1.5 * (nnz + 1.0) < size as the condition to decide whether using sparse or dense
    // vector,
    val testDF = Seq(
      (7.0f, 0.0f, -1.0f, 1.0f, 1.0),
      (1.0f, 0.0f, 1.0f, 1.0f, 1.0),
      (0.0f, 1.0f, 0.0f, 1.0f, 0.0),
      (1.0f, 0.0f, 1.0f, 1.0f, 1.0),
      (1.0f, -1.0f, 0.0f, 1.0f, 0.0),
      (0.0f, 0.0f, 0.0f, 1.0f, 1.0),
      (0.0f, 0.0f, 0.0f, 0.0f, 0.0)
    ).toDF("col1", "col2", "col3", "col4", "label")
    val vectorAssembler = new VectorAssembler()
      .setInputCols(Array("col1", "col2", "col3", "col4"))
      .setOutputCol("features")
    val inputDF = vectorAssembler.transform(testDF).select("features", "label")
    inputDF.show()
    val paramMap = List("eta" -> "1", "max_depth" -> "2",
      "objective" -> "binary:logistic", "missing" -> 0.0f,
      "num_workers" -> 1, "allow_non_zero_for_missing" -> "true").toMap
    val model = new XGBoostClassifier(paramMap).fit(inputDF)
    model.transform(inputDF).collect()
  }

  test("Getter and setter for AllowNonZeroForMissingValue works") {
    {
      val paramMap = Map("eta" -> "1", "max_depth" -> "6",
        "objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers)
      val training = buildDataFrame(Classification.train)
      val classifier = new XGBoostClassifier(paramMap)
      classifier.setAllowNonZeroForMissing(true)
      assert(classifier.getAllowNonZeroForMissingValue)
      classifier.setAllowNonZeroForMissing(false)
      assert(!classifier.getAllowNonZeroForMissingValue)
      val model = classifier.fit(training)
      model.setAllowNonZeroForMissing(true)
      assert(model.getAllowNonZeroForMissingValue)
      model.setAllowNonZeroForMissing(false)
      assert(!model.getAllowNonZeroForMissingValue)
    }

    {
      val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
        "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers)
      val training = buildDataFrame(Regression.train)
      val regressor = new XGBoostRegressor(paramMap)
      regressor.setAllowNonZeroForMissing(true)
      assert(regressor.getAllowNonZeroForMissingValue)
      regressor.setAllowNonZeroForMissing(false)
      assert(!regressor.getAllowNonZeroForMissingValue)
      val model = regressor.fit(training)
      model.setAllowNonZeroForMissing(true)
      assert(model.getAllowNonZeroForMissingValue)
      model.setAllowNonZeroForMissing(false)
      assert(!model.getAllowNonZeroForMissingValue)
    }
  }
}
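The deleted suite above repeats one pattern in every test: assemble raw columns with VectorAssembler, keep NaN values via setHandleInvalid("keep"), and let training treat NaN (or another marker) as missing. A minimal standalone sketch of just the assembly step, assuming only a local Spark session and using illustrative column names, no XGBoost training involved:

.. code-block:: scala

   import org.apache.spark.ml.feature.VectorAssembler
   import org.apache.spark.sql.SparkSession

   object AssembleWithNaNSketch {
     def main(args: Array[String]): Unit = {
       val spark = SparkSession.builder().master("local[1]").appName("nan-assembly").getOrCreate()
       import spark.implicits._

       val df = Seq(
         (1.0f, 0.0f, Float.NaN, 1.0),
         (0.0f, 1.0f, 0.0f, 0.0)
       ).toDF("col1", "col2", "col3", "label")

       // "keep" lets NaN flow into the assembled vector instead of failing the row.
       val assembled = new VectorAssembler()
         .setInputCols(Array("col1", "col2", "col3"))
         .setOutputCol("features")
         .setHandleInvalid("keep")
         .transform(df)

       assembled.select("features", "label").show(truncate = false)
       spark.stop()
     }
   }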
@ -1,104 +0,0 @@
/*
 Copyright (c) 2014-2022 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.spark

import org.scalatest.BeforeAndAfterAll
import org.scalatest.funsuite.AnyFunSuite

import org.apache.spark.SparkException
import org.apache.spark.ml.param.ParamMap

class ParameterSuite extends AnyFunSuite with PerTest with BeforeAndAfterAll {
  test("XGBoost and Spark parameters synchronize correctly") {
    val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic",
      "objective_type" -> "classification")
    // from xgboost params to spark params
    val xgb = new XGBoostClassifier(xgbParamMap)
    assert(xgb.getEta === 1.0)
    assert(xgb.getObjective === "binary:logistic")
    assert(xgb.getObjectiveType === "classification")
    // from spark to xgboost params
    val xgbCopy = xgb.copy(ParamMap.empty)
    assert(xgbCopy.MLlib2XGBoostParams("eta").toString.toDouble === 1.0)
    assert(xgbCopy.MLlib2XGBoostParams("objective").toString === "binary:logistic")
    assert(xgbCopy.MLlib2XGBoostParams("objective_type").toString === "classification")
    val xgbCopy2 = xgb.copy(ParamMap.empty.put(xgb.evalMetric, "logloss"))
    assert(xgbCopy2.MLlib2XGBoostParams("eval_metric").toString === "logloss")
  }

  test("fail training elegantly with unsupported objective function") {
    val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "wrong_objective_function", "num_class" -> "6", "num_round" -> 5,
      "num_workers" -> numWorkers)
    val trainingDF = buildDataFrame(MultiClassification.train)
    val xgb = new XGBoostClassifier(paramMap)
    intercept[SparkException] {
      xgb.fit(trainingDF)
    }
  }

  test("fail training elegantly with unsupported eval metrics") {
    val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "multi:softmax", "num_class" -> "6", "num_round" -> 5,
      "num_workers" -> numWorkers, "eval_metric" -> "wrong_eval_metrics")
    val trainingDF = buildDataFrame(MultiClassification.train)
    val xgb = new XGBoostClassifier(paramMap)
    intercept[SparkException] {
      xgb.fit(trainingDF)
    }
  }

  test("custom_eval does not support early stopping") {
    val paramMap = Map("eta" -> "0.1", "custom_eval" -> new EvalError, "silent" -> "1",
      "objective" -> "multi:softmax", "num_class" -> "6", "num_round" -> 5,
      "num_workers" -> numWorkers, "num_early_stopping_rounds" -> 2)
    val trainingDF = buildDataFrame(MultiClassification.train)

    val thrown = intercept[IllegalArgumentException] {
      new XGBoostClassifier(paramMap).fit(trainingDF)
    }

    assert(thrown.getMessage.contains("custom_eval does not support early stopping"))
  }

  test("early stopping should work without custom_eval setting") {
    val paramMap = Map("eta" -> "0.1", "silent" -> "1",
      "objective" -> "multi:softmax", "num_class" -> "6", "num_round" -> 5,
      "num_workers" -> numWorkers, "num_early_stopping_rounds" -> 2)
    val trainingDF = buildDataFrame(MultiClassification.train)

    new XGBoostClassifier(paramMap).fit(trainingDF)
  }

  test("Default parameters") {
    val classifier = new XGBoostClassifier()
    intercept[NoSuchElementException] {
      classifier.getBaseScore
    }
  }

  test("approx can't be used for gpu train") {
    val paramMap = Map("tree_method" -> "approx", "device" -> "cuda")
    val trainingDF = buildDataFrame(MultiClassification.train)
    val xgb = new XGBoostClassifier(paramMap)
    val thrown = intercept[IllegalArgumentException] {
      xgb.fit(trainingDF)
    }
    assert(thrown.getMessage.contains("The tree method \"approx\" is not yet supported " +
      "for Spark GPU cluster"))
  }
}
@ -1,5 +1,5 @@
 /*
- Copyright (c) 2014-2022 by Contributors
+ Copyright (c) 2014-2024 by Contributors
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@ -18,24 +18,25 @@ package ml.dmlc.xgboost4j.scala.spark
 
 import java.io.{File, FileInputStream}
 
-import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
+import org.apache.commons.io.IOUtils
 
 import org.apache.spark.SparkContext
+import org.apache.spark.ml.linalg.Vectors
 import org.apache.spark.sql._
 import org.scalatest.BeforeAndAfterEach
 import org.scalatest.funsuite.AnyFunSuite
-import scala.math.min
-import scala.util.Random
 
-import org.apache.commons.io.IOUtils
+import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
+import ml.dmlc.xgboost4j.scala.spark.Utils.{withResource, XGBLabeledPointFeatures}
 
-trait PerTest extends BeforeAndAfterEach { self: AnyFunSuite =>
+trait PerTest extends BeforeAndAfterEach {
+  self: AnyFunSuite =>
 
-  protected val numWorkers: Int = min(Runtime.getRuntime.availableProcessors(), 4)
+  protected val numWorkers: Int = 4
 
   @transient private var currentSession: SparkSession = _
 
   def ss: SparkSession = getOrCreateSession
 
   implicit def sc: SparkContext = ss.sparkContext
 
   protected def sparkSessionBuilder: SparkSession.Builder = SparkSession.builder()
@ -45,10 +46,11 @@ trait PerTest extends BeforeAndAfterEach { self: AnyFunSuite =>
     .config("spark.driver.memory", "512m")
     .config("spark.barrier.sync.timeout", 10)
     .config("spark.task.cpus", 1)
+    .config("spark.stage.maxConsecutiveAttempts", 1)
 
   override def beforeEach(): Unit = getOrCreateSession
 
-  override def afterEach() {
+  override def afterEach(): Unit = {
     if (currentSession != null) {
       currentSession.stop()
       cleanExternalCache(currentSession.sparkContext.appName)
@ -74,42 +76,25 @@ trait PerTest extends BeforeAndAfterEach { self: AnyFunSuite =>
   protected def buildDataFrame(
       labeledPoints: Seq[XGBLabeledPoint],
       numPartitions: Int = numWorkers): DataFrame = {
-    import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
     val it = labeledPoints.iterator.zipWithIndex
       .map { case (labeledPoint: XGBLabeledPoint, id: Int) =>
-        (id, labeledPoint.label, labeledPoint.features)
+        (id, labeledPoint.label, labeledPoint.features, labeledPoint.weight)
       }
 
     ss.createDataFrame(sc.parallelize(it.toList, numPartitions))
-      .toDF("id", "label", "features")
-  }
-
-  protected def buildDataFrameWithRandSort(
-      labeledPoints: Seq[XGBLabeledPoint],
-      numPartitions: Int = numWorkers): DataFrame = {
-    val df = buildDataFrame(labeledPoints, numPartitions)
-    val rndSortedRDD = df.rdd.mapPartitions { iter =>
-      iter.map(_ -> Random.nextDouble()).toList
-        .sortBy(_._2)
-        .map(_._1).iterator
-    }
-    ss.createDataFrame(rndSortedRDD, df.schema)
+      .toDF("id", "label", "features", "weight")
   }
 
   protected def buildDataFrameWithGroup(
       labeledPoints: Seq[XGBLabeledPoint],
       numPartitions: Int = numWorkers): DataFrame = {
-    import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
     val it = labeledPoints.iterator.zipWithIndex
       .map { case (labeledPoint: XGBLabeledPoint, id: Int) =>
-        (id, labeledPoint.label, labeledPoint.features, labeledPoint.group)
+        (id, labeledPoint.label, labeledPoint.features, labeledPoint.group, labeledPoint.weight)
       }
 
     ss.createDataFrame(sc.parallelize(it.toList, numPartitions))
-      .toDF("id", "label", "features", "group")
+      .toDF("id", "label", "features", "group", "weight")
   }
 
 
   protected def compareTwoFiles(lhs: String, rhs: String): Boolean = {
     withResource(new FileInputStream(lhs)) { lfis =>
       withResource(new FileInputStream(rhs)) { rfis =>
@ -118,12 +103,32 @@ trait PerTest extends BeforeAndAfterEach { self: AnyFunSuite =>
     }
   }
 
-  /** Executes the provided code block and then closes the resource */
-  protected def withResource[T <: AutoCloseable, V](r: T)(block: T => V): V = {
-    try {
-      block(r)
-    } finally {
-      r.close()
-    }
-  }
+  def smallBinaryClassificationVector: DataFrame = ss.createDataFrame(sc.parallelize(Seq(
+    (1.0, 0.5, 1.0, Vectors.dense(1.0, 2.0, 3.0)),
+    (0.0, 0.4, -3.0, Vectors.dense(0.0, 0.0, 0.0)),
+    (0.0, 0.3, 1.0, Vectors.dense(0.0, 3.0, 0.0)),
+    (1.0, 1.2, 0.2, Vectors.dense(2.0, 0.0, 4.0)),
+    (0.0, -0.5, 0.0, Vectors.dense(0.2, 1.2, 2.0)),
+    (1.0, -0.4, -2.1, Vectors.dense(0.5, 2.2, 1.7))
+  ))).toDF("label", "margin", "weight", "features")
+
+  def smallMultiClassificationVector: DataFrame = ss.createDataFrame(sc.parallelize(Seq(
+    (1.0, 0.5, 1.0, Vectors.dense(1.0, 2.0, 3.0)),
+    (0.0, 0.4, -3.0, Vectors.dense(0.0, 0.0, 0.0)),
+    (2.0, 0.3, 1.0, Vectors.dense(0.0, 3.0, 0.0)),
+    (1.0, 1.2, 0.2, Vectors.dense(2.0, 0.0, 4.0)),
+    (0.0, -0.5, 0.0, Vectors.dense(0.2, 1.2, 2.0)),
+    (2.0, -0.4, -2.1, Vectors.dense(0.5, 2.2, 1.7))
+  ))).toDF("label", "margin", "weight", "features")
+
+  def smallGroupVector: DataFrame = ss.createDataFrame(sc.parallelize(Seq(
+    (1.0, 0, 0.5, 2.0, Vectors.dense(1.0, 2.0, 3.0)),
+    (0.0, 1, 0.4, 1.0, Vectors.dense(0.0, 0.0, 0.0)),
+    (0.0, 1, 0.3, 1.0, Vectors.dense(0.0, 3.0, 0.0)),
+    (1.0, 0, 1.2, 2.0, Vectors.dense(2.0, 0.0, 4.0)),
+    (1.0, 2, -0.5, 3.0, Vectors.dense(0.2, 1.2, 2.0)),
+    (0.0, 2, -0.4, 3.0, Vectors.dense(0.5, 2.2, 1.7))
+  ))).toDF("label", "group", "margin", "weight", "features")
+
 }
@@ -1,195 +0,0 @@
/*
 Copyright (c) 2014-2022 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.spark

import java.io.File
import java.util.Arrays

import ml.dmlc.xgboost4j.scala.DMatrix

import scala.util.Random
import org.apache.spark.ml.feature._
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.sql.functions._
import org.scalatest.funsuite.AnyFunSuite

class PersistenceSuite extends AnyFunSuite with TmpFolderPerSuite with PerTest {

  test("test persistence of XGBoostClassifier and XGBoostClassificationModel") {
    val eval = new EvalError()
    val trainingDF = buildDataFrame(Classification.train)
    val testDM = new DMatrix(Classification.test.iterator)

    val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "binary:logistic", "num_round" -> "10", "num_workers" -> numWorkers)
    val xgbc = new XGBoostClassifier(paramMap)
    val xgbcPath = new File(tempDir.toFile, "xgbc").getPath
    xgbc.write.overwrite().save(xgbcPath)
    val xgbc2 = XGBoostClassifier.load(xgbcPath)
    val paramMap2 = xgbc2.MLlib2XGBoostParams
    paramMap.foreach {
      case (k, v) => assert(v.toString == paramMap2(k).toString)
    }

    val model = xgbc.fit(trainingDF)
    val evalResults = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
    assert(evalResults < 0.1)
    val xgbcModelPath = new File(tempDir.toFile, "xgbcModel").getPath
    model.write.overwrite.save(xgbcModelPath)
    val model2 = XGBoostClassificationModel.load(xgbcModelPath)
    assert(Arrays.equals(model._booster.toByteArray, model2._booster.toByteArray))

    assert(model.getEta === model2.getEta)
    assert(model.getNumRound === model2.getNumRound)
    assert(model.getRawPredictionCol === model2.getRawPredictionCol)
    val evalResults2 = eval.eval(model2._booster.predict(testDM, outPutMargin = true), testDM)
    assert(evalResults === evalResults2)
  }

  test("test persistence of XGBoostRegressor and XGBoostRegressionModel") {
    val eval = new EvalError()
    val trainingDF = buildDataFrame(Regression.train)
    val testDM = new DMatrix(Regression.test.iterator)

    val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "reg:squarederror", "num_round" -> "10", "num_workers" -> numWorkers)
    val xgbr = new XGBoostRegressor(paramMap)
    val xgbrPath = new File(tempDir.toFile, "xgbr").getPath
    xgbr.write.overwrite().save(xgbrPath)
    val xgbr2 = XGBoostRegressor.load(xgbrPath)
    val paramMap2 = xgbr2.MLlib2XGBoostParams
    paramMap.foreach {
      case (k, v) => assert(v.toString == paramMap2(k).toString)
    }

    val model = xgbr.fit(trainingDF)
    val evalResults = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
    assert(evalResults < 0.1)
    val xgbrModelPath = new File(tempDir.toFile, "xgbrModel").getPath
    model.write.overwrite.save(xgbrModelPath)
    val model2 = XGBoostRegressionModel.load(xgbrModelPath)
    assert(Arrays.equals(model._booster.toByteArray, model2._booster.toByteArray))

    assert(model.getEta === model2.getEta)
    assert(model.getNumRound === model2.getNumRound)
    assert(model.getPredictionCol === model2.getPredictionCol)
    val evalResults2 = eval.eval(model2._booster.predict(testDM, outPutMargin = true), testDM)
    assert(evalResults === evalResults2)
  }

  test("test persistence of MLlib pipeline with XGBoostClassificationModel") {
    val r = new Random(0)
    // maybe move to shared context, but requires session to import implicits
    val df = ss.createDataFrame(Seq.fill(100)(r.nextInt(2)).map(i => (i, i))).
      toDF("feature", "label")

    val assembler = new VectorAssembler()
      .setInputCols(df.columns.filter(!_.contains("label")))
      .setOutputCol("features")

    val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "binary:logistic", "num_round" -> "10", "num_workers" -> numWorkers)
    val xgb = new XGBoostClassifier(paramMap)

    // Construct MLlib pipeline, save and load
    val pipeline = new Pipeline().setStages(Array(assembler, xgb))
    val pipePath = new File(tempDir.toFile, "pipeline").getPath
    pipeline.write.overwrite().save(pipePath)
    val pipeline2 = Pipeline.read.load(pipePath)
    val xgb2 = pipeline2.getStages(1).asInstanceOf[XGBoostClassifier]
    val paramMap2 = xgb2.MLlib2XGBoostParams
    paramMap.foreach {
      case (k, v) => assert(v.toString == paramMap2(k).toString)
    }

    // Model training, save and load
    val pipeModel = pipeline.fit(df)
    val pipeModelPath = new File(tempDir.toFile, "pipelineModel").getPath
    pipeModel.write.overwrite.save(pipeModelPath)
    val pipeModel2 = PipelineModel.load(pipeModelPath)

    val xgbModel = pipeModel.stages(1).asInstanceOf[XGBoostClassificationModel]
    val xgbModel2 = pipeModel2.stages(1).asInstanceOf[XGBoostClassificationModel]

    assert(Arrays.equals(xgbModel._booster.toByteArray, xgbModel2._booster.toByteArray))

    assert(xgbModel.getEta === xgbModel2.getEta)
    assert(xgbModel.getNumRound === xgbModel2.getNumRound)
    assert(xgbModel.getRawPredictionCol === xgbModel2.getRawPredictionCol)
  }

  test("test persistence of XGBoostClassifier and XGBoostClassificationModel " +
    "using custom Eval and Obj") {
    val trainingDF = buildDataFrame(Classification.train)
    val testDM = new DMatrix(Classification.test.iterator)
    val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
      "custom_eval" -> new EvalError, "custom_obj" -> new CustomObj(1),
      "num_round" -> "10", "num_workers" -> numWorkers, "objective" -> "binary:logistic")

    val xgbc = new XGBoostClassifier(paramMap)
    val xgbcPath = new File(tempDir.toFile, "xgbc").getPath
    xgbc.write.overwrite().save(xgbcPath)
    val xgbc2 = XGBoostClassifier.load(xgbcPath)
    val paramMap2 = xgbc2.MLlib2XGBoostParams
    paramMap.foreach {
      case ("custom_eval", v) => assert(v.isInstanceOf[EvalError])
      case ("custom_obj", v) =>
        assert(v.isInstanceOf[CustomObj])
        assert(v.asInstanceOf[CustomObj].customParameter ==
          paramMap2("custom_obj").asInstanceOf[CustomObj].customParameter)
      case (_, _) =>
    }

    val eval = new EvalError()

    val model = xgbc.fit(trainingDF)
    val evalResults = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
    assert(evalResults < 0.1)
    val xgbcModelPath = new File(tempDir.toFile, "xgbcModel").getPath
    model.write.overwrite.save(xgbcModelPath)
    val model2 = XGBoostClassificationModel.load(xgbcModelPath)
    assert(Arrays.equals(model._booster.toByteArray, model2._booster.toByteArray))

    assert(model.getEta === model2.getEta)
    assert(model.getNumRound === model2.getNumRound)
    assert(model.getRawPredictionCol === model2.getRawPredictionCol)
    val evalResults2 = eval.eval(model2._booster.predict(testDM, outPutMargin = true), testDM)
    assert(evalResults === evalResults2)
  }

  test("cross-version model loading (0.82)") {
    val modelPath = getClass.getResource("/model/0.82/model").getPath
    val model = XGBoostClassificationModel.read.load(modelPath)
    val r = new Random(0)
    var df = ss.createDataFrame(Seq.fill(100)(r.nextInt(2)).map(i => (i, i))).
      toDF("feature", "label")
    // 0.82/model was trained with 251 features. and transform will throw exception
    // if feature size of data is not equal to 251
    for (x <- 1 to 250) {
      df = df.withColumn(s"feature_${x}", lit(1))
    }
    val assembler = new VectorAssembler()
      .setInputCols(df.columns.filter(!_.contains("label")))
      .setOutputCol("features")
    df = assembler.transform(df)
    for (x <- 1 to 250) {
      df = df.drop(s"feature_${x}")
    }
    model.transform(df).show()
  }
}
@@ -1,5 +1,5 @@
 /*
- Copyright (c) 2014 by Contributors
+ Copyright (c) 2014-2024 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -16,8 +16,9 @@

 package ml.dmlc.xgboost4j.scala.spark

-import scala.collection.mutable
 import scala.io.Source
+import scala.util.Random

 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}

 trait TrainTestData {
@@ -31,8 +32,8 @@ trait TrainTestData {
     Source.fromInputStream(is).getLines()
   }

-  protected def getLabeledPoints(resource: String, featureSize: Int, zeroBased: Boolean):
-    Seq[XGBLabeledPoint] = {
+  protected def getLabeledPoints(resource: String, featureSize: Int,
+      zeroBased: Boolean): Seq[XGBLabeledPoint] = {
     getResourceLines(resource).map { line =>
       val labelAndFeatures = line.split(" ")
       val label = labelAndFeatures.head.toFloat
@@ -65,10 +66,32 @@ trait TrainTestData {
 object Classification extends TrainTestData {
   val train: Seq[XGBLabeledPoint] = getLabeledPoints("/agaricus.txt.train", 126, zeroBased = false)
   val test: Seq[XGBLabeledPoint] = getLabeledPoints("/agaricus.txt.test", 126, zeroBased = false)
+
+  Random.setSeed(10)
+  val randomWeights = Array.fill(train.length)(Random.nextFloat())
+  val trainWithWeight = train.zipWithIndex.map { case (v, index) =>
+    XGBLabeledPoint(v.label, v.size, v.indices, v.values,
+      randomWeights(index), v.group, v.baseMargin)
+  }
 }

 object MultiClassification extends TrainTestData {
-  val train: Seq[XGBLabeledPoint] = getLabeledPoints("/dermatology.data")
+  private def split(): (Seq[XGBLabeledPoint], Seq[XGBLabeledPoint]) = {
+    val tmp: Seq[XGBLabeledPoint] = getLabeledPoints("/dermatology.data")
+    Random.setSeed(100)
+    val randomizedTmp = Random.shuffle(tmp)
+    val splitIndex = (randomizedTmp.length * 0.8).toInt
+    (randomizedTmp.take(splitIndex), randomizedTmp.drop(splitIndex))
+  }
+
+  val (train, test) = split()
+  Random.setSeed(10)
+  val randomWeights = Array.fill(train.length)(Random.nextFloat())
+  val trainWithWeight = train.zipWithIndex.map { case (v, index) =>
+    XGBLabeledPoint(v.label, v.size, v.indices, v.values,
+      randomWeights(index), v.group, v.baseMargin)
+  }

   private def getLabeledPoints(resource: String): Seq[XGBLabeledPoint] = {
     getResourceLines(resource).map { line =>
@@ -92,31 +115,25 @@ object Regression extends TrainTestData {
     "/machine.txt.train", MACHINE_COL_NUM, zeroBased = true)
   val test: Seq[XGBLabeledPoint] = getLabeledPoints(
     "/machine.txt.test", MACHINE_COL_NUM, zeroBased = true)
-}

-object Ranking extends TrainTestData {
+  Random.setSeed(10)
+  val randomWeights = Array.fill(train.length)(Random.nextFloat())
+  val trainWithWeight = train.zipWithIndex.map { case (v, index) =>
+    XGBLabeledPoint(v.label, v.size, v.indices, v.values,
+      randomWeights(index), v.group, v.baseMargin)
+  }
+}
+
+object Ranking extends TrainTestData {
   val RANK_COL_NUM = 3
   val train: Seq[XGBLabeledPoint] = getLabeledPointsWithGroup("/rank.train.csv")
+  // use the group as the weight
+  val trainWithWeight = train.map { labelPoint =>
+    XGBLabeledPoint(labelPoint.label, labelPoint.size, labelPoint.indices, labelPoint.values,
+      labelPoint.group, labelPoint.group, labelPoint.baseMargin)
+  }
+  val trainGroups = train.map(_.group)
   val test: Seq[XGBLabeledPoint] = getLabeledPoints(
     "/rank.test.txt", RANK_COL_NUM, zeroBased = false)

-  private def getGroups(resource: String): Seq[Int] = {
-    getResourceLines(resource).map(_.toInt).toList
-  }
-}
-
-object Synthetic extends {
-  val TRAIN_COL_NUM = 3
-  val TRAIN_WRONG_COL_NUM = 2
-  val train: Seq[XGBLabeledPoint] = Seq(
-    XGBLabeledPoint(1.0f, TRAIN_COL_NUM, Array(0, 1), Array(1.0f, 2.0f)),
-    XGBLabeledPoint(0.0f, TRAIN_COL_NUM, Array(0, 1, 2), Array(1.0f, 2.0f, 3.0f)),
-    XGBLabeledPoint(0.0f, TRAIN_COL_NUM, Array(0, 1, 2), Array(1.0f, 2.0f, 3.0f)),
-    XGBLabeledPoint(1.0f, TRAIN_COL_NUM, Array(0, 1), Array(1.0f, 2.0f))
-  )
-
-  val trainWithDiffFeatureSize: Seq[XGBLabeledPoint] = Seq(
-    XGBLabeledPoint(1.0f, TRAIN_WRONG_COL_NUM, Array(0, 1), Array(1.0f, 2.0f)),
-    XGBLabeledPoint(0.0f, TRAIN_COL_NUM, Array(0, 1, 2), Array(1.0f, 2.0f, 3.0f))
-  )
-}
+}
@@ -16,241 +16,212 @@

package ml.dmlc.xgboost4j.scala.spark

import java.io.File

import org.apache.spark.ml.linalg.DenseVector
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.DataFrame
import org.scalatest.funsuite.AnyFunSuite

import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams.{BINARY_CLASSIFICATION_OBJS, MULTICLASSIFICATION_OBJS}
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostParams

class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite {

  test("XGBoostClassifier copy") {
    val classifier = new XGBoostClassifier().setNthread(2).setNumWorkers(10)
    val classifierCopied = classifier.copy(ParamMap.empty)

    assert(classifier.uid === classifierCopied.uid)
    assert(classifier.getNthread === classifierCopied.getNthread)
    assert(classifier.getNumWorkers === classifier.getNumWorkers)
  }

  test("XGBoostClassification copy") {
    val model = new XGBoostClassificationModel("hello").setNthread(2).setNumWorkers(10)
    val modelCopied = model.copy(ParamMap.empty)
    assert(model.uid === modelCopied.uid)
    assert(model.getNthread === modelCopied.getNthread)
    assert(model.getNumWorkers === modelCopied.getNumWorkers)
  }

  test("read/write") {
    val trainDf = smallBinaryClassificationVector
    val xgbParams: Map[String, Any] = Map(
      "max_depth" -> 5,
      "eta" -> 0.2,
      "objective" -> "binary:logistic"
    )

    def check(xgboostParams: XGBoostParams[_]): Unit = {
      assert(xgboostParams.getMaxDepth === 5)
      assert(xgboostParams.getEta === 0.2)
      assert(xgboostParams.getObjective === "binary:logistic")
    }

    val classifierPath = new File(tempDir.toFile, "classifier").getPath
    val classifier = new XGBoostClassifier(xgbParams).setNumRound(2)
    check(classifier)

    classifier.write.overwrite().save(classifierPath)
    val loadedClassifier = XGBoostClassifier.load(classifierPath)
    check(loadedClassifier)

    val model = loadedClassifier.fit(trainDf)
    check(model)
    assert(model.numClasses === 2)

    val modelPath = new File(tempDir.toFile, "model").getPath
    model.write.overwrite().save(modelPath)
    val modelLoaded = XGBoostClassificationModel.load(modelPath)
    assert(modelLoaded.numClasses === 2)
    check(modelLoaded)
  }

  test("XGBoostClassificationModel transformed schema") {
    val trainDf = smallBinaryClassificationVector
    val classifier = new XGBoostClassifier().setNumRound(1)
    val model = classifier.fit(trainDf)
    var out = model.transform(trainDf)

    // Transform should not discard the other columns of the transforming dataframe
    Seq("label", "margin", "weight", "features").foreach { v =>
      assert(out.schema.names.contains(v))
    }

    // Transform needs to add extra columns
    Seq("rawPrediction", "probability", "prediction").foreach { v =>
      assert(out.schema.names.contains(v))
    }
    assert(out.schema.names.length === 7)

    model.setRawPredictionCol("").setProbabilityCol("")
    out = model.transform(trainDf)

    // rawPrediction="", probability=""
    Seq("rawPrediction", "probability").foreach { v =>
      assert(!out.schema.names.contains(v))
    }
    assert(out.schema.names.contains("prediction"))

    model.setLeafPredictionCol("leaf").setContribPredictionCol("contrib")
    out = model.transform(trainDf)

    assert(out.schema.names.contains("leaf"))
    assert(out.schema.names.contains("contrib"))

    val out1 = classifier.setLeafPredictionCol("leaf1")
      .setContribPredictionCol("contrib1")
      .fit(trainDf).transform(trainDf)

    assert(out1.schema.names.contains("leaf1"))
    assert(out1.schema.names.contains("contrib1"))
  }

  test("Supported objectives") {
    val classifier = new XGBoostClassifier()
    val df = smallMultiClassificationVector
    (BINARY_CLASSIFICATION_OBJS.toSeq ++ MULTICLASSIFICATION_OBJS.toSeq).foreach { obj =>
      classifier.setObjective(obj)
      classifier.validate(df)
    }

    classifier.setObjective("reg:squaredlogerror")
    intercept[IllegalArgumentException](
      classifier.validate(df)
    )
  }

  test("BinaryClassification infer objective and num_class") {
    val trainDf = smallBinaryClassificationVector
    var classifier = new XGBoostClassifier()
    assert(classifier.getObjective === "reg:squarederror")
    assert(classifier.getNumClass === 0)
    classifier.validate(trainDf)
    assert(classifier.getObjective === "binary:logistic")
    assert(!classifier.isSet(classifier.numClass))

    // Infer objective according num class
    classifier = new XGBoostClassifier()
    classifier.setNumClass(2)
    intercept[IllegalArgumentException](
      classifier.validate(trainDf)
    )

    // Infer to num class according to num class
    classifier = new XGBoostClassifier()
    classifier.setObjective("binary:logistic")
    classifier.validate(trainDf)
    assert(classifier.getObjective === "binary:logistic")
    assert(!classifier.isSet(classifier.numClass))
  }

  test("MultiClassification infer objective and num_class") {
    val trainDf = smallMultiClassificationVector
    var classifier = new XGBoostClassifier()
    assert(classifier.getObjective === "reg:squarederror")
    assert(classifier.getNumClass === 0)
    classifier.validate(trainDf)
    assert(classifier.getObjective === "multi:softprob")
    assert(classifier.getNumClass === 3)

    // Infer to objective according to num class
    classifier = new XGBoostClassifier()
    classifier.setNumClass(3)
    classifier.validate(trainDf)
    assert(classifier.getObjective === "multi:softprob")
    assert(classifier.getNumClass === 3)

    // Infer to num class according to objective
    classifier = new XGBoostClassifier()
    classifier.setObjective("multi:softmax")
    classifier.validate(trainDf)
    assert(classifier.getObjective === "multi:softmax")
    assert(classifier.getNumClass === 3)
  }

  test("XGBoost-Spark binary classification output should match XGBoost4j") {
    val trainingDM = new DMatrix(Classification.train.iterator)
    val testDM = new DMatrix(Classification.test.iterator)
    val trainingDF = buildDataFrame(Classification.train)
    val testDF = buildDataFrame(Classification.test)
    val paramMap = Map("objective" -> "binary:logistic")
    checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF, 5, paramMap)
  }

  test("XGBoost-Spark binary classification output with weight should match XGBoost4j") {
    val trainingDM = new DMatrix(Classification.trainWithWeight.iterator)
    trainingDM.setWeight(Classification.randomWeights)
    val testDM = new DMatrix(Classification.test.iterator)
    val trainingDF = buildDataFrame(Classification.trainWithWeight)
    val testDF = buildDataFrame(Classification.test)
    val paramMap = Map("objective" -> "binary:logistic")
    checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF,
      5, paramMap, Some("weight"))
  }

  Seq("multi:softprob", "multi:softmax").foreach { objective =>
    test(s"XGBoost-Spark multi classification with $objective output should match XGBoost4j") {
      val trainingDM = new DMatrix(MultiClassification.train.iterator)
      val testDM = new DMatrix(MultiClassification.test.iterator)
      val trainingDF = buildDataFrame(MultiClassification.train)
      val testDF = buildDataFrame(MultiClassification.test)
      val paramMap = Map("objective" -> "multi:softprob", "num_class" -> 6)
      checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF, 5, paramMap)
    }
  }

  test("XGBoost-Spark multi classification output with weight should match XGBoost4j") {
    val trainingDM = new DMatrix(MultiClassification.trainWithWeight.iterator)
    trainingDM.setWeight(MultiClassification.randomWeights)
    val testDM = new DMatrix(MultiClassification.test.iterator)
    val trainingDF = buildDataFrame(MultiClassification.trainWithWeight)
    val testDF = buildDataFrame(MultiClassification.test)
    val paramMap = Map("objective" -> "multi:softprob", "num_class" -> 6)
    checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF, 5, paramMap, Some("weight"))
  }

  private def checkResultsWithXGBoost4j(
@@ -258,223 +229,73 @@ class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite {
      testDM: DMatrix,
      trainingDF: DataFrame,
      testDF: DataFrame,
      round: Int = 5,
      xgbParams: Map[String, Any] = Map.empty,
      weightCol: Option[String] = None): Unit = {
    val paramMap = Map(
      "eta" -> "1",
      "max_depth" -> "6",
      "base_score" -> 0.5,
      "max_bin" -> 16) ++ xgbParams
    val xgb4jModel = ScalaXGBoost.train(trainingDM, paramMap, round)

    val classifier = new XGBoostClassifier(paramMap)
      .setNumRound(round)
      .setNumWorkers(numWorkers)
      .setLeafPredictionCol("leaf")
      .setContribPredictionCol("contrib")
    weightCol.foreach(weight => classifier.setWeightCol(weight))

    def checkEqual(left: Array[Array[Float]], right: Map[Int, Array[Float]]) = {
      assert(left.size === right.size)
      left.zipWithIndex.foreach { case (leftValue, index) =>
        assert(leftValue.sameElements(right(index)))
      }
    }

    val xgbSparkModel = classifier.fit(trainingDF)
    val rows = xgbSparkModel.transform(testDF).collect()

    // Check Leaf
    val xgb4jLeaf = xgb4jModel.predictLeaf(testDM)
    val xgbSparkLeaf = rows.map(row =>
      (row.getAs[Int]("id"), row.getAs[DenseVector]("leaf").toArray.map(_.toFloat))).toMap
    checkEqual(xgb4jLeaf, xgbSparkLeaf)

    // Check contrib
    val xgb4jContrib = xgb4jModel.predictContrib(testDM)
    val xgbSparkContrib = rows.map(row =>
      (row.getAs[Int]("id"), row.getAs[DenseVector]("contrib").toArray.map(_.toFloat))).toMap
    checkEqual(xgb4jContrib, xgbSparkContrib)

    def checkEqualForBinary(left: Array[Array[Float]], right: Map[Int, Array[Float]]) = {
      assert(left.size === right.size)
      left.zipWithIndex.foreach { case (leftValue, index) =>
        assert(leftValue.length === 1)
        assert(leftValue.length === right(index).length - 1)
        assert(leftValue(0) === right(index)(1))
      }
    }

    // Check probability
    val xgb4jProb = xgb4jModel.predict(testDM)
    val xgbSparkProb = rows.map(row =>
      (row.getAs[Int]("id"), row.getAs[DenseVector]("probability").toArray.map(_.toFloat))).toMap
    if (BINARY_CLASSIFICATION_OBJS.contains(classifier.getObjective)) {
      checkEqualForBinary(xgb4jProb, xgbSparkProb)
    } else {
      checkEqual(xgb4jProb, xgbSparkProb)
    }

    // Check rawPrediction
    val xgb4jRawPred = xgb4jModel.predict(testDM, outPutMargin = true)
    val xgbSparkRawPred = rows.map(row =>
      (row.getAs[Int]("id"), row.getAs[DenseVector]("rawPrediction").toArray.map(_.toFloat))).toMap
    if (BINARY_CLASSIFICATION_OBJS.contains(classifier.getObjective)) {
      checkEqualForBinary(xgb4jRawPred, xgbSparkRawPred)
    } else {
      checkEqual(xgb4jRawPred, xgbSparkRawPred)
    }
  }
}
@@ -1,75 +0,0 @@
/*
 Copyright (c) 2014-2022 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.spark

import ml.dmlc.xgboost4j.java.Communicator
import ml.dmlc.xgboost4j.scala.Booster
import scala.collection.JavaConverters._

import org.apache.spark.sql._
import org.scalatest.funsuite.AnyFunSuite

import org.apache.spark.SparkException

class XGBoostCommunicatorRegressionSuite extends AnyFunSuite with PerTest {
  val predictionErrorMin = 0.00001f
  val maxFailure = 2;

  override def sparkSessionBuilder: SparkSession.Builder = super.sparkSessionBuilder
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    .config("spark.kryo.classesToRegister", classOf[Booster].getName)
    .master(s"local[${numWorkers},${maxFailure}]")

  test("test classification prediction parity w/o ring reduce") {
    val training = buildDataFrame(Classification.train)
    val testDF = buildDataFrame(Classification.test)

    val xgbSettings = Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
      "objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers)

    val model1 = new XGBoostClassifier(xgbSettings).fit(training)
    val prediction1 = model1.transform(testDF).select("prediction").collect()

    val model2 = new XGBoostClassifier(xgbSettings ++ Map("rabit_ring_reduce_threshold" -> 1))
      .fit(training)

    val prediction2 = model2.transform(testDF).select("prediction").collect()
    // check parity w/o rabit cache
    prediction1.zip(prediction2).foreach { case (Row(p1: Double), Row(p2: Double)) =>
      assert(p1 == p2)
    }
  }

  test("test regression prediction parity w/o ring reduce") {
    val training = buildDataFrame(Regression.train)
    val testDF = buildDataFrame(Regression.test)
    val xgbSettings = Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
      "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers)
    val model1 = new XGBoostRegressor(xgbSettings).fit(training)

    val prediction1 = model1.transform(testDF).select("prediction").collect()

    val model2 = new XGBoostRegressor(xgbSettings ++ Map("rabit_ring_reduce_threshold" -> 1)
      ).fit(training)
    // check the equality of single instance prediction
    val prediction2 = model2.transform(testDF).select("prediction").collect()
    // check parity w/o rabit cache
    prediction1.zip(prediction2).foreach { case (Row(p1: Double), Row(p2: Double)) =>
      assert(math.abs(p1 - p2) < predictionErrorMin)
    }
  }
}
@@ -1,81 +0,0 @@
/*
 Copyright (c) 2014-2022 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.spark

import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}

import org.apache.spark.sql._
import org.scalatest.funsuite.AnyFunSuite

class XGBoostConfigureSuite extends AnyFunSuite with PerTest {

  override def sparkSessionBuilder: SparkSession.Builder = super.sparkSessionBuilder
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    .config("spark.kryo.classesToRegister", classOf[Booster].getName)

  test("nthread configuration must be no larger than spark.task.cpus") {
    val training = buildDataFrame(Classification.train)
    val paramMap = Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
      "objective" -> "binary:logistic", "num_workers" -> numWorkers,
      "nthread" -> (sc.getConf.getInt("spark.task.cpus", 1) + 1))
    intercept[IllegalArgumentException] {
      new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training)
    }
  }

  test("kryoSerializer test") {
    // TODO write an isolated test for Booster.
    val training = buildDataFrame(Classification.train)
    val testDM = new DMatrix(Classification.test.iterator, null)
    val paramMap = Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
      "objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers)

    val model = new XGBoostClassifier(paramMap).fit(training)
    val eval = new EvalError()
    assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
  }

  test("Check for Spark encryption over-the-wire") {
    val originalSslConfOpt = ss.conf.getOption("spark.ssl.enabled")
    ss.conf.set("spark.ssl.enabled", true)

    val paramMap = Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
      "objective" -> "binary:logistic", "num_round" -> 2, "num_workers" -> numWorkers)
    val training = buildDataFrame(Classification.train)

    withClue("xgboost-spark should throw an exception when spark.ssl.enabled = true but " +
      "xgboost.spark.ignoreSsl != true") {
      val thrown = intercept[Exception] {
        new XGBoostClassifier(paramMap).fit(training)
      }
      assert(thrown.getMessage.contains("xgboost.spark.ignoreSsl") &&
        thrown.getMessage.contains("spark.ssl.enabled"))
    }

    // Confirm that this check can be overridden.
    ss.conf.set("xgboost.spark.ignoreSsl", true)
    new XGBoostClassifier(paramMap).fit(training)

    originalSslConfOpt match {
      case None =>
        ss.conf.unset("spark.ssl.enabled")
      case Some(originalSslConf) =>
        ss.conf.set("spark.ssl.enabled", originalSslConf)
    }
    ss.conf.unset("xgboost.spark.ignoreSsl")
  }
}
@@ -0,0 +1,512 @@
/*
 Copyright (c) 2024 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.spark

import java.io.File
import java.util.Arrays
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vectors}
import org.apache.spark.SparkException
import org.json4s.{DefaultFormats, Formats}
import org.json4s.jackson.parseJson
import org.scalatest.funsuite.AnyFunSuite

import ml.dmlc.xgboost4j.scala.DMatrix
import ml.dmlc.xgboost4j.scala.spark.Utils.TRAIN_NAME

class XGBoostEstimatorSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite {

  test("params") {
    val df = smallBinaryClassificationVector
    val xgbParams: Map[String, Any] = Map(
      "max_depth" -> 5,
      "eta" -> 0.2,
      "objective" -> "binary:logistic"
    )
    val estimator = new XGBoostClassifier(xgbParams)
      .setFeaturesCol("features")
      .setMissing(0.2f)
      .setAlpha(0.97)
      .setLeafPredictionCol("leaf")
      .setContribPredictionCol("contrib")
      .setNumRound(1)

    assert(estimator.getMaxDepth === 5)
    assert(estimator.getEta === 0.2)
    assert(estimator.getObjective === "binary:logistic")
    assert(estimator.getFeaturesCol === "features")
    assert(estimator.getMissing === 0.2f)
    assert(estimator.getAlpha === 0.97)

    estimator.setEta(0.66).setMaxDepth(7)
    assert(estimator.getMaxDepth === 7)
    assert(estimator.getEta === 0.66)

    val model = estimator.fit(df)
    assert(model.getMaxDepth === 7)
    assert(model.getEta === 0.66)
    assert(model.getObjective === "binary:logistic")
    assert(model.getFeaturesCol === "features")
    assert(model.getMissing === 0.2f)
    assert(model.getAlpha === 0.97)
    assert(model.getLeafPredictionCol === "leaf")
    assert(model.getContribPredictionCol === "contrib")
  }

  test("nthread") {
    val classifier = new XGBoostClassifier().setNthread(100)

    intercept[IllegalArgumentException](
      classifier.validate(smallBinaryClassificationVector)
    )
  }

  test("RuntimeParameter") {
    var runtimeParams = new XGBoostClassifier(
      Map("device" -> "cpu"))
      .getRuntimeParameters(true)
    assert(!runtimeParams.runOnGpu)

    runtimeParams = new XGBoostClassifier(
      Map("device" -> "cuda")).setNumWorkers(1).setNumRound(1)
      .getRuntimeParameters(true)
    assert(runtimeParams.runOnGpu)

    runtimeParams = new XGBoostClassifier(
      Map("device" -> "cpu", "tree_method" -> "gpu_hist")).setNumWorkers(1).setNumRound(1)
      .getRuntimeParameters(true)
    assert(runtimeParams.runOnGpu)

    runtimeParams = new XGBoostClassifier(
      Map("device" -> "cuda", "tree_method" -> "gpu_hist")).setNumWorkers(1).setNumRound(1)
      .getRuntimeParameters(true)
    assert(runtimeParams.runOnGpu)
  }

  test("missing value exception for sparse vector") {
    val sparse1 = Vectors.dense(0.0, 0.0, 0.0).toSparse
    assert(sparse1.isInstanceOf[SparseVector])
    val sparse2 = Vectors.dense(0.5, 2.2, 1.7).toSparse
    assert(sparse2.isInstanceOf[SparseVector])

    val sparseInput = ss.createDataFrame(sc.parallelize(Seq(
      (1.0, sparse1),
      (2.0, sparse2)
    ))).toDF("label", "features")

    val classifier = new XGBoostClassifier()
    val (input, columnIndexes) = classifier.preprocess(sparseInput)
    val rdd = classifier.toXGBLabeledPoint(input, columnIndexes)

    val exception = intercept[SparkException] {
      rdd.collect()
    }
    assert(exception.getMessage.contains("We've detected sparse vectors in the dataset " +
      "that need conversion to dense format"))

    // explicitly set missing value, no exception
|
||||||
|
classifier.setMissing(Float.NaN)
|
||||||
|
val rdd1 = classifier.toXGBLabeledPoint(input, columnIndexes)
|
||||||
|
rdd1.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
test("missing value for dense vector no need to set missing explicitly") {
|
||||||
|
val dense1 = Vectors.dense(0.0, 0.0, 0.0)
|
||||||
|
assert(dense1.isInstanceOf[DenseVector])
|
||||||
|
val dense2 = Vectors.dense(0.5, 2.2, 1.7)
|
||||||
|
assert(dense2.isInstanceOf[DenseVector])
|
||||||
|
|
||||||
|
val sparseInput = ss.createDataFrame(sc.parallelize(Seq(
|
||||||
|
(1.0, dense1),
|
||||||
|
(2.0, dense2)
|
||||||
|
))).toDF("label", "features")
|
||||||
|
|
||||||
|
val classifier = new XGBoostClassifier()
|
||||||
|
val (input, columnIndexes) = classifier.preprocess(sparseInput)
|
||||||
|
val rdd = classifier.toXGBLabeledPoint(input, columnIndexes)
|
||||||
|
rdd.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
test("test persistence of XGBoostClassifier and XGBoostClassificationModel " +
|
||||||
|
"using custom Eval and Obj") {
|
||||||
|
val trainingDF = buildDataFrame(Classification.train)
|
||||||
|
val testDM = new DMatrix(Classification.test.iterator)
|
||||||
|
|
||||||
|
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6",
|
||||||
|
"verbosity" -> "1", "objective" -> "binary:logistic")
|
||||||
|
|
||||||
|
val xgbc = new XGBoostClassifier(paramMap)
|
||||||
|
.setCustomObj(new CustomObj(1))
|
||||||
|
.setCustomEval(new EvalError)
|
||||||
|
.setNumRound(10)
|
||||||
|
.setNumWorkers(numWorkers)
|
||||||
|
|
||||||
|
val xgbcPath = new File(tempDir.toFile, "xgbc").getPath
|
||||||
|
xgbc.write.overwrite().save(xgbcPath)
|
||||||
|
val xgbc2 = XGBoostClassifier.load(xgbcPath)
|
||||||
|
|
||||||
|
assert(xgbc.getCustomObj.asInstanceOf[CustomObj].customParameter === 1)
|
||||||
|
assert(xgbc2.getCustomObj.asInstanceOf[CustomObj].customParameter === 1)
|
||||||
|
|
||||||
|
val eval = new EvalError()
|
||||||
|
|
||||||
|
val model = xgbc.fit(trainingDF)
|
||||||
|
val evalResults = eval.eval(model.nativeBooster.predict(testDM, outPutMargin = true), testDM)
|
||||||
|
assert(evalResults < 0.1)
|
||||||
|
val xgbcModelPath = new File(tempDir.toFile, "xgbcModel").getPath
|
||||||
|
model.write.overwrite.save(xgbcModelPath)
|
||||||
|
val model2 = XGBoostClassificationModel.load(xgbcModelPath)
|
||||||
|
assert(Arrays.equals(model.nativeBooster.toByteArray, model2.nativeBooster.toByteArray))
|
||||||
|
|
||||||
|
assert(model.getEta === model2.getEta)
|
||||||
|
assert(model.getNumRound === model2.getNumRound)
|
||||||
|
assert(model.getRawPredictionCol === model2.getRawPredictionCol)
|
||||||
|
val evalResults2 = eval.eval(model2.nativeBooster.predict(testDM, outPutMargin = true), testDM)
|
||||||
|
assert(evalResults === evalResults2)
|
||||||
|
}
|
||||||
|
|
||||||
|
test("Check for Spark encryption over-the-wire") {
|
||||||
|
val originalSslConfOpt = ss.conf.getOption("spark.ssl.enabled")
|
||||||
|
ss.conf.set("spark.ssl.enabled", true)
|
||||||
|
|
||||||
|
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
|
||||||
|
"objective" -> "binary:logistic")
|
||||||
|
val training = smallBinaryClassificationVector
|
||||||
|
|
||||||
|
withClue("xgboost-spark should throw an exception when spark.ssl.enabled = true but " +
|
||||||
|
"xgboost.spark.ignoreSsl != true") {
|
||||||
|
val thrown = intercept[Exception] {
|
||||||
|
new XGBoostClassifier(paramMap).setNumRound(2).setNumWorkers(numWorkers).fit(training)
|
||||||
|
}
|
||||||
|
assert(thrown.getMessage.contains("xgboost.spark.ignoreSsl") &&
|
||||||
|
thrown.getMessage.contains("spark.ssl.enabled"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Confirm that this check can be overridden.
|
||||||
|
ss.conf.set("xgboost.spark.ignoreSsl", true)
|
||||||
|
new XGBoostClassifier(paramMap).setNumRound(2).setNumWorkers(numWorkers).fit(training)
|
||||||
|
|
||||||
|
originalSslConfOpt match {
|
||||||
|
case None =>
|
||||||
|
ss.conf.unset("spark.ssl.enabled")
|
||||||
|
case Some(originalSslConf) =>
|
||||||
|
ss.conf.set("spark.ssl.enabled", originalSslConf)
|
||||||
|
}
|
||||||
|
ss.conf.unset("xgboost.spark.ignoreSsl")
|
||||||
|
}
|
||||||
|
|
||||||
|
test("nthread configuration must be no larger than spark.task.cpus") {
|
||||||
|
val training = smallBinaryClassificationVector
|
||||||
|
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
|
||||||
|
"objective" -> "binary:logistic")
|
||||||
|
intercept[IllegalArgumentException] {
|
||||||
|
new XGBoostClassifier(paramMap)
|
||||||
|
.setNumWorkers(numWorkers)
|
||||||
|
.setNumRound(2)
|
||||||
|
.setNthread(sc.getConf.getInt("spark.task.cpus", 1) + 1)
|
||||||
|
.fit(training)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
test("preprocess dataset") {
|
||||||
|
val dataset = ss.createDataFrame(sc.parallelize(Seq(
|
||||||
|
(1.0, 0, 0.5, 1.0, Vectors.dense(1.0, 2.0, 3.0), "a"),
|
||||||
|
(0.0, 2, -0.5, 0.0, Vectors.dense(0.2, 1.2, 2.0), "b"),
|
||||||
|
(2.0, 2, -0.4, -2.1, Vectors.dense(0.5, 2.2, 1.7), "c")
|
||||||
|
))).toDF("label", "group", "margin", "weight", "features", "other")
|
||||||
|
|
||||||
|
val classifier = new XGBoostClassifier()
|
||||||
|
.setLabelCol("label")
|
||||||
|
.setFeaturesCol("features")
|
||||||
|
.setBaseMarginCol("margin")
|
||||||
|
.setWeightCol("weight")
|
||||||
|
|
||||||
|
val (df, indices) = classifier.preprocess(dataset)
|
||||||
|
var schema = df.schema
|
||||||
|
assert(!schema.names.contains("group") && !schema.names.contains("other"))
|
||||||
|
assert(indices.labelId == schema.fieldIndex("label") &&
|
||||||
|
indices.groupId.isEmpty &&
|
||||||
|
indices.marginId.get == schema.fieldIndex("margin") &&
|
||||||
|
indices.weightId.get == schema.fieldIndex("weight") &&
|
||||||
|
indices.featureId.get == schema.fieldIndex("features") &&
|
||||||
|
indices.featureIds.isEmpty)
|
||||||
|
|
||||||
|
classifier.setWeightCol("")
|
||||||
|
val (df1, indices1) = classifier.preprocess(dataset)
|
||||||
|
schema = df1.schema
|
||||||
|
Seq("weight", "group", "other").foreach(v => assert(!schema.names.contains(v)))
|
||||||
|
assert(indices1.labelId == schema.fieldIndex("label") &&
|
||||||
|
indices1.groupId.isEmpty &&
|
||||||
|
indices1.marginId.get == schema.fieldIndex("margin") &&
|
||||||
|
indices1.weightId.isEmpty &&
|
||||||
|
indices1.featureId.get == schema.fieldIndex("features") &&
|
||||||
|
indices1.featureIds.isEmpty)
|
||||||
|
}
|
||||||
|
|
||||||
|
test("to XGBoostLabeledPoint RDD") {
|
||||||
|
val data = Array(
|
||||||
|
Array(1.0, 2.0, 3.0, 4.0, 5.0),
|
||||||
|
Array(0.0, 0.0, 0.0, 0.0, 2.0),
|
||||||
|
Array(12.0, 13.0, 14.0, 14.0, 15.0),
|
||||||
|
Array(20.5, 21.2, 0.0, 0.0, 2.0)
|
||||||
|
)
|
||||||
|
val dataset = ss.createDataFrame(sc.parallelize(Seq(
|
||||||
|
(1.0, 0, 0.5, 1.0, Vectors.dense(data(0)), "a"),
|
||||||
|
(2.0, 2, -0.5, 0.0, Vectors.dense(data(1)).toSparse, "b"),
|
||||||
|
(3.0, 2, -0.5, 0.0, Vectors.dense(data(2)), "b"),
|
||||||
|
(4.0, 2, -0.4, -2.1, Vectors.dense(data(3)), "c")
|
||||||
|
))).toDF("label", "group", "margin", "weight", "features", "other")
|
||||||
|
|
||||||
|
val classifier = new XGBoostClassifier()
|
||||||
|
.setLabelCol("label")
|
||||||
|
.setFeaturesCol("features")
|
||||||
|
.setWeightCol("weight")
|
||||||
|
.setNumWorkers(2)
|
||||||
|
.setMissing(Float.NaN)
|
||||||
|
|
||||||
|
val (df, indices) = classifier.preprocess(dataset)
|
||||||
|
val rdd = classifier.toXGBLabeledPoint(df, indices)
|
||||||
|
val result = rdd.collect().sortBy(x => x.label)
|
||||||
|
|
||||||
|
assert(result.length == data.length)
|
||||||
|
|
||||||
|
def toArray(index: Int): Array[Float] = {
|
||||||
|
val labelPoint = result(index)
|
||||||
|
if (labelPoint.indices != null) {
|
||||||
|
Vectors.sparse(labelPoint.size,
|
||||||
|
labelPoint.indices,
|
||||||
|
labelPoint.values.map(_.toDouble)).toArray.map(_.toFloat)
|
||||||
|
} else {
|
||||||
|
labelPoint.values
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert(result(0).label === 1.0f && result(0).baseMargin.isNaN &&
|
||||||
|
result(0).weight === 1.0f && toArray(0) === data(0).map(_.toFloat))
|
||||||
|
assert(result(1).label == 2.0f && result(1).baseMargin.isNaN &&
|
||||||
|
result(1).weight === 0.0f && toArray(1) === data(1).map(_.toFloat))
|
||||||
|
assert(result(2).label === 3.0f && result(2).baseMargin.isNaN &&
|
||||||
|
result(2).weight == 0.0f && toArray(2) === data(2).map(_.toFloat))
|
||||||
|
assert(result(3).label === 4.0f && result(3).baseMargin.isNaN &&
|
||||||
|
result(3).weight === -2.1f && toArray(3) === data(3).map(_.toFloat))
|
||||||
|
}
|
||||||
|
|
||||||
|
Seq((Float.NaN, 2), (0.0f, 7 + 2), (15.0f, 1 + 2), (10101011.0f, 0 + 2)).foreach {
|
||||||
|
case (missing, expectedMissingValue) =>
|
||||||
|
test(s"to RDD watches with missing $missing") {
|
||||||
|
val data = Array(
|
||||||
|
Array(1.0, 2.0, 3.0, 4.0, 5.0),
|
||||||
|
Array(1.0, Float.NaN, 0.0, 0.0, 2.0),
|
||||||
|
Array(12.0, 13.0, Float.NaN, 14.0, 15.0),
|
||||||
|
Array(0.0, 0.0, 0.0, 0.0, 0.0)
|
||||||
|
)
|
||||||
|
val dataset = ss.createDataFrame(sc.parallelize(Seq(
|
||||||
|
(1.0, 0, 0.5, 1.0, Vectors.dense(data(0)), "a"),
|
||||||
|
(2.0, 2, -0.5, 0.0, Vectors.dense(data(1)).toSparse, "b"),
|
||||||
|
(3.0, 3, -0.5, 0.0, Vectors.dense(data(2)), "b"),
|
||||||
|
(4.0, 4, -0.4, -2.1, Vectors.dense(data(3)), "c")
|
||||||
|
))).toDF("label", "group", "margin", "weight", "features", "other")
|
||||||
|
|
||||||
|
val classifier = new XGBoostClassifier()
|
||||||
|
.setLabelCol("label")
|
||||||
|
.setFeaturesCol("features")
|
||||||
|
.setWeightCol("weight")
|
||||||
|
.setBaseMarginCol("margin")
|
||||||
|
.setMissing(missing)
|
||||||
|
.setNumWorkers(2)
|
||||||
|
|
||||||
|
val (df, indices) = classifier.preprocess(dataset)
|
||||||
|
val rdd = classifier.toRdd(df, indices)
|
||||||
|
val result = rdd.mapPartitions { iter =>
|
||||||
|
if (iter.hasNext) {
|
||||||
|
val watches = iter.next()
|
||||||
|
val size = watches.size
|
||||||
|
val trainDM = watches.toMap(TRAIN_NAME)
|
||||||
|
val rowNum = trainDM.rowNum
|
||||||
|
val labels = trainDM.getLabel
|
||||||
|
val weight = trainDM.getWeight
|
||||||
|
val margins = trainDM.getBaseMargin
|
||||||
|
val nonMissing = trainDM.nonMissingNum
|
||||||
|
watches.delete()
|
||||||
|
Iterator.single((size, rowNum, labels, weight, margins, nonMissing))
|
||||||
|
} else {
|
||||||
|
Iterator.empty
|
||||||
|
}
|
||||||
|
}.collect()
|
||||||
|
|
||||||
|
val labels: ArrayBuffer[Float] = ArrayBuffer.empty
|
||||||
|
val weight: ArrayBuffer[Float] = ArrayBuffer.empty
|
||||||
|
val margins: ArrayBuffer[Float] = ArrayBuffer.empty
|
||||||
|
var nonMissingValues = 0L
|
||||||
|
var totalRows = 0L
|
||||||
|
|
||||||
|
for (row <- result) {
|
||||||
|
assert(row._1 === 1)
|
||||||
|
totalRows = totalRows + row._2
|
||||||
|
labels.append(row._3: _*)
|
||||||
|
weight.append(row._4: _*)
|
||||||
|
margins.append(row._5: _*)
|
||||||
|
nonMissingValues = nonMissingValues + row._6
|
||||||
|
}
|
||||||
|
assert(totalRows === 4)
|
||||||
|
assert(nonMissingValues === data.size * data(0).length - expectedMissingValue)
|
||||||
|
assert(labels.toArray.sorted === Array(1.0f, 2.0f, 3.0f, 4.0f).sorted)
|
||||||
|
assert(weight.toArray.sorted === Array(0.0f, 0.0f, 1.0f, -2.1f).sorted)
|
||||||
|
assert(margins.toArray.sorted === Array(-0.5f, -0.5f, -0.4f, 0.5f).sorted)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
test("to RDD watches with eval") {
|
||||||
|
val trainData = Array(
|
||||||
|
Array(-1.0, -2.0, -3.0, -4.0, -5.0),
|
||||||
|
Array(2.0, 2.0, 2.0, 3.0, -2.0),
|
||||||
|
Array(-12.0, -13.0, -14.0, -14.0, -15.0),
|
||||||
|
Array(-20.5, -21.2, 0.0, 0.0, 2.0)
|
||||||
|
)
|
||||||
|
val trainDataset = ss.createDataFrame(sc.parallelize(Seq(
|
||||||
|
(11.0, 0, 0.15, 11.0, Vectors.dense(trainData(0)), "a"),
|
||||||
|
(12.0, 12, -0.15, 10.0, Vectors.dense(trainData(1)).toSparse, "b"),
|
||||||
|
(13.0, 12, -0.15, 10.0, Vectors.dense(trainData(2)), "b"),
|
||||||
|
(14.0, 12, -0.14, -12.1, Vectors.dense(trainData(3)), "c")
|
||||||
|
))).toDF("label", "group", "margin", "weight", "features", "other")
|
||||||
|
val evalData = Array(
|
||||||
|
Array(1.0, 2.0, 3.0, 4.0, 5.0),
|
||||||
|
Array(0.0, 0.0, 0.0, 0.0, 2.0),
|
||||||
|
Array(12.0, 13.0, 14.0, 14.0, 15.0),
|
||||||
|
Array(20.5, 21.2, 0.0, 0.0, 2.0)
|
||||||
|
)
|
||||||
|
val evalDataset = ss.createDataFrame(sc.parallelize(Seq(
|
||||||
|
(1.0, 0, 0.5, 1.0, Vectors.dense(evalData(0)), "a"),
|
||||||
|
(2.0, 2, -0.5, 0.0, Vectors.dense(evalData(1)).toSparse, "b"),
|
||||||
|
(3.0, 2, -0.5, 0.0, Vectors.dense(evalData(2)), "b"),
|
||||||
|
(4.0, 2, -0.4, -2.1, Vectors.dense(evalData(3)), "c")
|
||||||
|
))).toDF("label", "group", "margin", "weight", "features", "other")
|
||||||
|
|
||||||
|
val classifier = new XGBoostClassifier()
|
||||||
|
.setLabelCol("label")
|
||||||
|
.setFeaturesCol("features")
|
||||||
|
.setWeightCol("weight")
|
||||||
|
.setBaseMarginCol("margin")
|
||||||
|
.setEvalDataset(evalDataset)
|
||||||
|
.setNumWorkers(2)
|
||||||
|
.setMissing(Float.NaN)
|
||||||
|
|
||||||
|
val (df, indices) = classifier.preprocess(trainDataset)
|
||||||
|
val rdd = classifier.toRdd(df, indices)
|
||||||
|
val result = rdd.mapPartitions { iter =>
|
||||||
|
if (iter.hasNext) {
|
||||||
|
val watches = iter.next()
|
||||||
|
val size = watches.size
|
||||||
|
val evalDM = watches.toMap(Utils.VALIDATION_NAME)
|
||||||
|
val rowNum = evalDM.rowNum
|
||||||
|
val labels = evalDM.getLabel
|
||||||
|
val weight = evalDM.getWeight
|
||||||
|
val margins = evalDM.getBaseMargin
|
||||||
|
watches.delete()
|
||||||
|
Iterator.single((size, rowNum, labels, weight, margins))
|
||||||
|
} else {
|
||||||
|
Iterator.empty
|
||||||
|
}
|
||||||
|
}.collect()
|
||||||
|
|
||||||
|
val labels: ArrayBuffer[Float] = ArrayBuffer.empty
|
||||||
|
val weight: ArrayBuffer[Float] = ArrayBuffer.empty
|
||||||
|
val margins: ArrayBuffer[Float] = ArrayBuffer.empty
|
||||||
|
|
||||||
|
var totalRows = 0L
|
||||||
|
for (row <- result) {
|
||||||
|
assert(row._1 === 2)
|
||||||
|
totalRows = totalRows + row._2
|
||||||
|
labels.append(row._3: _*)
|
||||||
|
weight.append(row._4: _*)
|
||||||
|
margins.append(row._5: _*)
|
||||||
|
}
|
||||||
|
assert(totalRows === 4)
|
||||||
|
assert(labels.toArray.sorted === Array(1.0f, 2.0f, 3.0f, 4.0f).sorted)
|
||||||
|
assert(weight.toArray.sorted === Array(0.0f, 0.0f, 1.0f, -2.1f).sorted)
|
||||||
|
assert(margins.toArray.sorted === Array(-0.5f, -0.5f, -0.4f, 0.5f).sorted)
|
||||||
|
}
|
||||||
|
|
||||||
|
test("XGBoost-Spark model format should match xgboost4j") {
|
||||||
|
val trainingDF = buildDataFrame(MultiClassification.train)
|
||||||
|
|
||||||
|
Seq(new XGBoostClassifier()).foreach { est =>
|
||||||
|
est.setNumRound(5)
|
||||||
|
val model = est.fit(trainingDF)
|
||||||
|
|
||||||
|
// test json
|
||||||
|
val modelPath = new File(tempDir.toFile, "xgbc").getPath
|
||||||
|
model.write.overwrite().option("format", "json").save(modelPath)
|
||||||
|
val nativeJsonModelPath = new File(tempDir.toFile, "nativeModel.json").getPath
|
||||||
|
model.nativeBooster.saveModel(nativeJsonModelPath)
|
||||||
|
assert(compareTwoFiles(new File(modelPath, "data/model").getPath,
|
||||||
|
nativeJsonModelPath))
|
||||||
|
|
||||||
|
// test ubj
|
||||||
|
val modelUbjPath = new File(tempDir.toFile, "xgbcUbj").getPath
|
||||||
|
model.write.overwrite().save(modelUbjPath)
|
||||||
|
val nativeUbjModelPath = new File(tempDir.toFile, "nativeModel.ubj").getPath
|
||||||
|
model.nativeBooster.saveModel(nativeUbjModelPath)
|
||||||
|
assert(compareTwoFiles(new File(modelUbjPath, "data/model").getPath,
|
||||||
|
nativeUbjModelPath))
|
||||||
|
|
||||||
|
// json file should be indifferent with ubj file
|
||||||
|
val modelJsonPath = new File(tempDir.toFile, "xgbcJson").getPath
|
||||||
|
model.write.overwrite().option("format", "json").save(modelJsonPath)
|
||||||
|
val nativeUbjModelPath1 = new File(tempDir.toFile, "nativeModel1.ubj").getPath
|
||||||
|
model.nativeBooster.saveModel(nativeUbjModelPath1)
|
||||||
|
assert(!compareTwoFiles(new File(modelJsonPath, "data/model").getPath,
|
||||||
|
nativeUbjModelPath1))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
test("native json model file should store feature_name and feature_type") {
|
||||||
|
val featureNames = (1 to 33).map(idx => s"feature_${idx}").toArray
|
||||||
|
val featureTypes = (1 to 33).map(idx => "q").toArray
|
||||||
|
val trainingDF = buildDataFrame(MultiClassification.train)
|
||||||
|
val xgb = new XGBoostClassifier()
|
||||||
|
.setNumWorkers(numWorkers)
|
||||||
|
.setFeatureNames(featureNames)
|
||||||
|
.setFeatureTypes(featureTypes)
|
||||||
|
.setNumRound(2)
|
||||||
|
val model = xgb.fit(trainingDF)
|
||||||
|
val modelStr = new String(model.nativeBooster.toByteArray("json"))
|
||||||
|
val jsonModel = parseJson(modelStr)
|
||||||
|
implicit val formats: Formats = DefaultFormats
|
||||||
|
val featureNamesInModel = (jsonModel \ "learner" \ "feature_names").extract[List[String]]
|
||||||
|
val featureTypesInModel = (jsonModel \ "learner" \ "feature_types").extract[List[String]]
|
||||||
|
assert(featureNamesInModel.length == 33)
|
||||||
|
assert(featureTypesInModel.length == 33)
|
||||||
|
assert(featureNames sameElements featureNamesInModel)
|
||||||
|
assert(featureTypes sameElements featureTypesInModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
test("Exception with clear message") {
|
||||||
|
val df = smallMultiClassificationVector
|
||||||
|
val classifier = new XGBoostClassifier()
|
||||||
|
.setNumRound(2)
|
||||||
|
.setObjective("multi:softprob")
|
||||||
|
.setNumClass(2)
|
||||||
|
|
||||||
|
val exception = intercept[SparkException] {
|
||||||
|
classifier.fit(df)
|
||||||
|
}
|
||||||
|
|
||||||
|
exception.getMessage.contains("SoftmaxMultiClassObj: label must be in [0, num_class).")
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -1,376 +0,0 @@
|
|||||||
/*
|
|
||||||
Copyright (c) 2014-2022 by Contributors
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala.spark
|
|
||||||
|
|
||||||
import scala.util.Random
|
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
|
||||||
import ml.dmlc.xgboost4j.scala.DMatrix
|
|
||||||
|
|
||||||
import org.apache.spark.{SparkException, TaskContext}
|
|
||||||
import org.scalatest.funsuite.AnyFunSuite
|
|
||||||
|
|
||||||
import org.apache.spark.ml.feature.VectorAssembler
|
|
||||||
import org.apache.spark.sql.functions.lit
|
|
||||||
|
|
||||||
class XGBoostGeneralSuite extends AnyFunSuite with TmpFolderPerSuite with PerTest {
|
|
||||||
|
|
||||||
test("distributed training with the specified worker number") {
|
|
||||||
val trainingRDD = sc.parallelize(Classification.train)
|
|
||||||
val buildTrainingRDD = PreXGBoost.buildRDDLabeledPointToRDDWatches(trainingRDD)
|
|
||||||
val (booster, metrics) = XGBoost.trainDistributed(
|
|
||||||
sc,
|
|
||||||
buildTrainingRDD,
|
|
||||||
List("eta" -> "1", "max_depth" -> "6",
|
|
||||||
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
|
|
||||||
"custom_eval" -> null, "custom_obj" -> null, "use_external_memory" -> false,
|
|
||||||
"missing" -> Float.NaN).toMap)
|
|
||||||
assert(booster != null)
|
|
||||||
}
|
|
||||||
|
|
||||||
test("training with external memory cache") {
|
|
||||||
val eval = new EvalError()
|
|
||||||
val training = buildDataFrame(Classification.train)
|
|
||||||
val testDM = new DMatrix(Classification.test.iterator)
|
|
||||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6",
|
|
||||||
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
|
|
||||||
"use_external_memory" -> true)
|
|
||||||
val model = new XGBoostClassifier(paramMap).fit(training)
|
|
||||||
assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
|
|
||||||
}
|
|
||||||
|
|
||||||
test("test with quantile hist with monotone_constraints (lossguide)") {
|
|
||||||
val eval = new EvalError()
|
|
||||||
val training = buildDataFrame(Classification.train)
|
|
||||||
val testDM = new DMatrix(Classification.test.iterator)
|
|
||||||
val paramMap = Map("eta" -> "1",
|
|
||||||
"max_depth" -> "6",
|
|
||||||
"objective" -> "binary:logistic", "tree_method" -> "hist", "grow_policy" -> "lossguide",
|
|
||||||
"num_round" -> 5, "num_workers" -> numWorkers, "monotone_constraints" -> "(1, 0)")
|
|
||||||
val model = new XGBoostClassifier(paramMap).fit(training)
|
|
||||||
assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
|
|
||||||
}
|
|
||||||
|
|
||||||
test("test with quantile hist with interaction_constraints (lossguide)") {
|
|
||||||
val eval = new EvalError()
|
|
||||||
val training = buildDataFrame(Classification.train)
|
|
||||||
val testDM = new DMatrix(Classification.test.iterator)
|
|
||||||
val paramMap = Map("eta" -> "1",
|
|
||||||
"max_depth" -> "6",
|
|
||||||
"objective" -> "binary:logistic", "tree_method" -> "hist", "grow_policy" -> "lossguide",
|
|
||||||
"num_round" -> 5, "num_workers" -> numWorkers, "interaction_constraints" -> "[[1,2],[2,3,4]]")
|
|
||||||
val model = new XGBoostClassifier(paramMap).fit(training)
|
|
||||||
assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
|
|
||||||
}
|
|
||||||
|
|
||||||
test("test with quantile hist with monotone_constraints (depthwise)") {
|
|
||||||
val eval = new EvalError()
|
|
||||||
val training = buildDataFrame(Classification.train)
|
|
||||||
val testDM = new DMatrix(Classification.test.iterator)
|
|
||||||
val paramMap = Map("eta" -> "1",
|
|
||||||
"max_depth" -> "6",
|
|
||||||
"objective" -> "binary:logistic", "tree_method" -> "hist", "grow_policy" -> "depthwise",
|
|
||||||
"num_round" -> 5, "num_workers" -> numWorkers, "monotone_constraints" -> "(1, 0)")
|
|
||||||
val model = new XGBoostClassifier(paramMap).fit(training)
|
|
||||||
assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
|
|
||||||
}
|
|
||||||
|
|
||||||
test("test with quantile hist with interaction_constraints (depthwise)") {
|
|
||||||
val eval = new EvalError()
|
|
||||||
val training = buildDataFrame(Classification.train)
|
|
||||||
val testDM = new DMatrix(Classification.test.iterator)
|
|
||||||
val paramMap = Map("eta" -> "1",
|
|
||||||
"max_depth" -> "6",
|
|
||||||
"objective" -> "binary:logistic", "tree_method" -> "hist", "grow_policy" -> "depthwise",
|
|
||||||
"num_round" -> 5, "num_workers" -> numWorkers, "interaction_constraints" -> "[[1,2],[2,3,4]]")
|
|
||||||
val model = new XGBoostClassifier(paramMap).fit(training)
|
|
||||||
assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
|
|
||||||
}
|
|
||||||
|
|
||||||
test("test with quantile hist depthwise") {
|
|
||||||
val eval = new EvalError()
|
|
||||||
val training = buildDataFrame(Classification.train)
|
|
||||||
val testDM = new DMatrix(Classification.test.iterator)
|
|
||||||
val paramMap = Map("eta" -> "1",
|
|
||||||
"max_depth" -> "6",
|
|
||||||
"objective" -> "binary:logistic", "tree_method" -> "hist", "grow_policy" -> "depthwise",
|
|
||||||
"num_round" -> 5, "num_workers" -> numWorkers)
|
|
||||||
val model = new XGBoostClassifier(paramMap).fit(training)
|
|
||||||
assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
|
|
||||||
}
|
|
||||||
|
|
||||||
test("test with quantile hist lossguide") {
|
|
||||||
val eval = new EvalError()
|
|
||||||
val training = buildDataFrame(Classification.train)
|
|
||||||
val testDM = new DMatrix(Classification.test.iterator)
|
|
||||||
val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0",
|
|
||||||
"objective" -> "binary:logistic", "tree_method" -> "hist", "grow_policy" -> "lossguide",
|
|
||||||
"max_leaves" -> "8", "num_round" -> 5,
|
|
||||||
"num_workers" -> numWorkers)
|
|
||||||
val model = new XGBoostClassifier(paramMap).fit(training)
|
|
||||||
val x = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
|
|
||||||
assert(x < 0.1)
|
|
||||||
}
|
|
||||||
|
|
||||||
test("test with quantile hist lossguide with max bin") {
|
|
||||||
val eval = new EvalError()
|
|
||||||
val training = buildDataFrame(Classification.train)
|
|
||||||
val testDM = new DMatrix(Classification.test.iterator)
|
|
||||||
val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0",
|
|
||||||
"objective" -> "binary:logistic", "tree_method" -> "hist",
|
|
||||||
"grow_policy" -> "lossguide", "max_leaves" -> "8", "max_bin" -> "16",
|
|
||||||
"eval_metric" -> "error", "num_round" -> 5, "num_workers" -> numWorkers)
|
|
||||||
val model = new XGBoostClassifier(paramMap).fit(training)
|
|
||||||
val x = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
|
|
||||||
assert(x < 0.1)
|
|
||||||
}
|
|
||||||
|
|
||||||
test("test with quantile hist depthwidth with max depth") {
|
|
||||||
val eval = new EvalError()
|
|
||||||
val training = buildDataFrame(Classification.train)
|
|
||||||
val testDM = new DMatrix(Classification.test.iterator)
|
|
||||||
val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "6",
|
|
||||||
"objective" -> "binary:logistic", "tree_method" -> "hist",
|
|
||||||
"grow_policy" -> "depthwise", "max_depth" -> "2",
|
|
||||||
"eval_metric" -> "error", "num_round" -> 10, "num_workers" -> numWorkers)
|
|
||||||
val model = new XGBoostClassifier(paramMap).fit(training)
|
|
||||||
val x = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
|
|
||||||
assert(x < 0.1)
|
|
||||||
}
|
|
||||||
|
|
||||||
test("test with quantile hist depthwidth with max depth and max bin") {
|
|
||||||
val eval = new EvalError()
|
|
||||||
val training = buildDataFrame(Classification.train)
|
|
||||||
val testDM = new DMatrix(Classification.test.iterator)
|
|
||||||
val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "6",
|
|
||||||
"objective" -> "binary:logistic", "tree_method" -> "hist",
|
|
||||||
"grow_policy" -> "depthwise", "max_depth" -> "2", "max_bin" -> "2",
|
|
||||||
"eval_metric" -> "error", "num_round" -> 10, "num_workers" -> numWorkers)
|
|
||||||
val model = new XGBoostClassifier(paramMap).fit(training)
|
|
||||||
val x = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
|
|
||||||
assert(x < 0.1)
|
|
||||||
}
|
|
||||||
|
|
||||||
test("repartitionForTrainingGroup with group data") {
|
|
||||||
// test different splits to cover the corner cases.
|
|
||||||
for (split <- 1 to 20) {
|
|
||||||
val trainingRDD = sc.parallelize(Ranking.train, split)
|
|
||||||
val traingGroupsRDD = PreXGBoost.repartitionForTrainingGroup(trainingRDD, 4)
|
|
||||||
val trainingGroups: Array[Array[XGBLabeledPoint]] = traingGroupsRDD.collect()
|
|
||||||
// check the the order of the groups with group id.
|
|
||||||
// Ranking.train has 20 groups
|
|
||||||
assert(trainingGroups.length == 20)
|
|
||||||
|
|
||||||
// compare all points
|
|
||||||
val allPoints = trainingGroups.sortBy(_(0).group).flatten
|
|
||||||
assert(allPoints.length == Ranking.train.size)
|
|
||||||
for (i <- 0 to Ranking.train.size - 1) {
|
|
||||||
assert(allPoints(i).group == Ranking.train(i).group)
|
|
||||||
assert(allPoints(i).label == Ranking.train(i).label)
|
|
||||||
assert(allPoints(i).values.sameElements(Ranking.train(i).values))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
test("repartitionForTrainingGroup with group data which has empty partition") {
|
|
||||||
val trainingRDD = sc.parallelize(Ranking.train, 5).mapPartitions(it => {
|
|
||||||
// make one partition empty for testing
|
|
||||||
it.filter(_ => TaskContext.getPartitionId() != 3)
|
|
||||||
})
|
|
||||||
PreXGBoost.repartitionForTrainingGroup(trainingRDD, 4)
|
|
||||||
}
|
|
||||||
|
|
||||||
test("distributed training with group data") {
|
|
||||||
val trainingRDD = sc.parallelize(Ranking.train, 5)
|
|
||||||
val buildTrainingRDD = PreXGBoost.buildRDDLabeledPointToRDDWatches(trainingRDD, hasGroup = true)
|
|
||||||
val (booster, _) = XGBoost.trainDistributed(
|
|
||||||
sc,
|
|
||||||
buildTrainingRDD,
|
|
||||||
List("eta" -> "1", "max_depth" -> "6",
|
|
||||||
"objective" -> "rank:ndcg", "num_round" -> 5, "num_workers" -> numWorkers,
|
|
||||||
"custom_eval" -> null, "custom_obj" -> null, "use_external_memory" -> false,
|
|
||||||
"missing" -> Float.NaN).toMap)
|
|
||||||
|
|
||||||
assert(booster != null)
|
|
||||||
}
|
|
||||||
|
|
||||||
test("training summary") {
|
|
||||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6",
|
|
||||||
"objective" -> "binary:logistic", "num_round" -> 5, "nWorkers" -> numWorkers)
|
|
||||||
|
|
||||||
val trainingDF = buildDataFrame(Classification.train)
|
|
||||||
val xgb = new XGBoostClassifier(paramMap)
|
|
||||||
val model = xgb.fit(trainingDF)
|
|
||||||
|
|
||||||
assert(model.summary.trainObjectiveHistory.length === 5)
|
|
||||||
assert(model.summary.validationObjectiveHistory.isEmpty)
|
|
||||||
}
|
|
||||||
|
|
||||||
test("train/test split") {
|
|
||||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6",
|
|
||||||
"objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
|
|
||||||
"num_round" -> 5, "num_workers" -> numWorkers)
|
|
||||||
val training = buildDataFrame(Classification.train)
|
|
||||||
|
|
||||||
val xgb = new XGBoostClassifier(paramMap)
|
|
||||||
val model = xgb.fit(training)
|
|
||||||
assert(model.summary.validationObjectiveHistory.length === 1)
|
|
||||||
assert(model.summary.validationObjectiveHistory(0)._1 === "test")
|
|
||||||
assert(model.summary.validationObjectiveHistory(0)._2.length === 5)
|
|
||||||
assert(model.summary.trainObjectiveHistory !== model.summary.validationObjectiveHistory(0))
|
|
||||||
}
|
|
||||||
|
|
||||||
test("train with multiple validation datasets (non-ranking)") {
|
|
||||||
val training = buildDataFrame(Classification.train)
|
|
||||||
val Array(train, eval1, eval2) = training.randomSplit(Array(0.6, 0.2, 0.2))
|
|
||||||
val paramMap1 = Map("eta" -> "1", "max_depth" -> "6",
|
|
||||||
"objective" -> "binary:logistic",
|
|
||||||
"num_round" -> 5, "num_workers" -> numWorkers)
|
|
||||||
|
|
||||||
val xgb1 = new XGBoostClassifier(paramMap1).setEvalSets(Map("eval1" -> eval1, "eval2" -> eval2))
|
|
||||||
val model1 = xgb1.fit(train)
|
|
||||||
assert(model1.summary.validationObjectiveHistory.length === 2)
|
|
||||||
assert(model1.summary.validationObjectiveHistory.map(_._1).toSet === Set("eval1", "eval2"))
|
|
||||||
assert(model1.summary.validationObjectiveHistory(0)._2.length === 5)
|
|
||||||
assert(model1.summary.validationObjectiveHistory(1)._2.length === 5)
|
|
||||||
assert(model1.summary.trainObjectiveHistory !== model1.summary.validationObjectiveHistory(0))
|
|
||||||
assert(model1.summary.trainObjectiveHistory !== model1.summary.validationObjectiveHistory(1))
|
|
||||||
|
|
||||||
val paramMap2 = Map("eta" -> "1", "max_depth" -> "6",
|
|
||||||
"objective" -> "binary:logistic",
|
|
||||||
"num_round" -> 5, "num_workers" -> numWorkers,
|
|
||||||
"eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2))
|
|
||||||
val xgb2 = new XGBoostClassifier(paramMap2)
|
|
||||||
val model2 = xgb2.fit(train)
|
|
||||||
assert(model2.summary.validationObjectiveHistory.length === 2)
|
|
||||||
assert(model2.summary.validationObjectiveHistory.map(_._1).toSet === Set("eval1", "eval2"))
|
|
||||||
assert(model2.summary.validationObjectiveHistory(0)._2.length === 5)
|
|
||||||
assert(model2.summary.validationObjectiveHistory(1)._2.length === 5)
|
|
||||||
assert(model2.summary.trainObjectiveHistory !== model2.summary.validationObjectiveHistory(0))
|
|
||||||
assert(model2.summary.trainObjectiveHistory !== model2.summary.validationObjectiveHistory(1))
|
|
||||||
}
|
|
||||||
|
|
||||||
test("train with multiple validation datasets (ranking)") {
|
|
||||||
val training = buildDataFrameWithGroup(Ranking.train, 5)
|
|
||||||
val Array(train, eval1, eval2) = training.randomSplit(Array(0.6, 0.2, 0.2), 0)
|
|
||||||
val paramMap1 = Map("eta" -> "1", "max_depth" -> "6",
|
|
||||||
"objective" -> "rank:ndcg",
|
|
||||||
"num_round" -> 5, "num_workers" -> numWorkers, "group_col" -> "group")
|
|
||||||
val xgb1 = new XGBoostRegressor(paramMap1).setEvalSets(Map("eval1" -> eval1, "eval2" -> eval2))
|
|
||||||
val model1 = xgb1.fit(train)
|
|
||||||
assert(model1 != null)
|
|
||||||
assert(model1.summary.validationObjectiveHistory.length === 2)
|
|
||||||
assert(model1.summary.validationObjectiveHistory.map(_._1).toSet === Set("eval1", "eval2"))
|
|
||||||
assert(model1.summary.validationObjectiveHistory(0)._2.length === 5)
|
|
||||||
assert(model1.summary.validationObjectiveHistory(1)._2.length === 5)
|
|
||||||
assert(model1.summary.trainObjectiveHistory !== model1.summary.validationObjectiveHistory(0))
|
|
||||||
assert(model1.summary.trainObjectiveHistory !== model1.summary.validationObjectiveHistory(1))
|
|
||||||
|
|
||||||
val paramMap2 = Map("eta" -> "1", "max_depth" -> "6",
|
|
||||||
"objective" -> "rank:ndcg",
|
|
||||||
"num_round" -> 5, "num_workers" -> numWorkers, "group_col" -> "group",
|
|
||||||
"eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2))
|
|
||||||
val xgb2 = new XGBoostRegressor(paramMap2)
|
|
||||||
val model2 = xgb2.fit(train)
|
|
||||||
assert(model2 != null)
|
|
||||||
assert(model2.summary.validationObjectiveHistory.length === 2)
|
|
||||||
assert(model2.summary.validationObjectiveHistory.map(_._1).toSet === Set("eval1", "eval2"))
|
|
||||||
assert(model2.summary.validationObjectiveHistory(0)._2.length === 5)
|
|
||||||
assert(model2.summary.validationObjectiveHistory(1)._2.length === 5)
|
|
||||||
assert(model2.summary.trainObjectiveHistory !== model2.summary.validationObjectiveHistory(0))
|
|
||||||
assert(model2.summary.trainObjectiveHistory !== model2.summary.validationObjectiveHistory(1))
|
|
||||||
}
|
|
||||||
|
|
||||||
test("infer with different batch sizes") {
|
|
||||||
val regModel = new XGBoostRegressor(Map(
|
|
||||||
"eta" -> "1",
|
|
||||||
"max_depth" -> "6",
|
|
||||||
"silent" -> "1",
|
|
||||||
"objective" -> "reg:squarederror",
|
|
||||||
"num_round" -> 5,
|
|
||||||
"num_workers" -> numWorkers))
|
|
||||||
.fit(buildDataFrame(Regression.train))
|
|
||||||
val regDF = buildDataFrame(Regression.test)
|
|
||||||
|
|
||||||
val regRet1 = regModel.transform(regDF).collect()
|
|
||||||
val regRet2 = regModel.setInferBatchSize(1).transform(regDF).collect()
|
|
||||||
val regRet3 = regModel.setInferBatchSize(10).transform(regDF).collect()
|
|
||||||
val regRet4 = regModel.setInferBatchSize(32 << 15).transform(regDF).collect()
|
|
||||||
assert(regRet1 sameElements regRet2)
|
|
||||||
assert(regRet1 sameElements regRet3)
|
|
||||||
assert(regRet1 sameElements regRet4)
|
|
||||||
|
|
||||||
val clsModel = new XGBoostClassifier(Map(
|
|
||||||
"eta" -> "1",
|
|
||||||
"max_depth" -> "6",
|
|
||||||
"silent" -> "1",
|
|
||||||
"objective" -> "binary:logistic",
|
|
||||||
"num_round" -> 5,
|
|
||||||
"num_workers" -> numWorkers))
|
|
||||||
.fit(buildDataFrame(Classification.train))
|
|
||||||
val clsDF = buildDataFrame(Classification.test)
|
|
||||||
|
|
||||||
val clsRet1 = clsModel.transform(clsDF).collect()
|
|
||||||
val clsRet2 = clsModel.setInferBatchSize(1).transform(clsDF).collect()
|
|
||||||
val clsRet3 = clsModel.setInferBatchSize(10).transform(clsDF).collect()
|
|
||||||
val clsRet4 = clsModel.setInferBatchSize(32 << 15).transform(clsDF).collect()
|
|
||||||
assert(clsRet1 sameElements clsRet2)
|
|
||||||
assert(clsRet1 sameElements clsRet3)
|
|
||||||
assert(clsRet1 sameElements clsRet4)
|
|
||||||
}
|
|
||||||
|
|
||||||
test("chaining the prediction") {
|
|
||||||
val modelPath = getClass.getResource("/model/0.82/model").getPath
|
|
||||||
val model = XGBoostClassificationModel.read.load(modelPath)
|
|
||||||
val r = new Random(0)
|
|
||||||
var df = ss.createDataFrame(Seq.fill(100000)(1).map(i => (i, i))).
|
|
||||||
toDF("feature", "label").repartition(5)
|
|
||||||
// 0.82/model was trained with 251 features. and transform will throw exception
|
|
||||||
// if feature size of data is not equal to 251
|
|
||||||
for (x <- 1 to 250) {
|
|
||||||
df = df.withColumn(s"feature_${x}", lit(1))
|
|
||||||
}
|
|
||||||
val assembler = new VectorAssembler()
|
|
||||||
.setInputCols(df.columns.filter(!_.contains("label")))
|
|
||||||
.setOutputCol("features")
|
|
||||||
df = assembler.transform(df)
|
|
||||||
for (x <- 1 to 250) {
|
|
||||||
df = df.drop(s"feature_${x}")
|
|
||||||
}
|
|
||||||
val df1 = model.transform(df).withColumnRenamed(
|
|
||||||
"prediction", "prediction1").withColumnRenamed(
|
|
||||||
"rawPrediction", "rawPrediction1").withColumnRenamed(
|
|
||||||
"probability", "probability1")
|
|
||||||
val df2 = model.transform(df1)
|
|
||||||
df1.collect()
|
|
||||||
df2.collect()
|
|
||||||
}
|
|
||||||
|
|
||||||
test("throw exception for empty partition in trainingset") {
|
|
||||||
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
|
|
||||||
"objective" -> "binary:logistic", "num_class" -> "2", "num_round" -> 5,
|
|
||||||
"num_workers" -> numWorkers, "tree_method" -> "auto", "allow_non_zero_for_missing" -> true)
|
|
||||||
// The Dmatrix will be empty
|
|
||||||
val trainingDF = buildDataFrame(Seq(XGBLabeledPoint(1.0f, 4,
|
|
||||||
Array(0, 1, 2, 3), Array(0, 1, 2, 3))))
|
|
||||||
val xgb = new XGBoostClassifier(paramMap)
|
|
||||||
intercept[SparkException] {
|
|
||||||
xgb.fit(trainingDF)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@ -18,32 +18,116 @@ package ml.dmlc.xgboost4j.scala.spark
|
|||||||
|
|
||||||
import java.io.File
|
import java.io.File
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
|
import org.apache.spark.ml.linalg.DenseVector
|
||||||
|
import org.apache.spark.ml.param.ParamMap
|
||||||
import org.apache.spark.ml.linalg.{Vector, Vectors}
|
import org.apache.spark.sql.DataFrame
|
||||||
import org.apache.spark.sql.functions._
|
|
||||||
import org.apache.spark.sql.{DataFrame, Row}
|
|
||||||
import org.scalatest.funsuite.AnyFunSuite
|
import org.scalatest.funsuite.AnyFunSuite
|
||||||
|
|
||||||
import org.apache.spark.ml.feature.VectorAssembler
|
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
|
||||||
|
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams.REGRESSION_OBJS
|
||||||
|
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostParams
|
||||||
|
|
||||||
class XGBoostRegressorSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite {
|
class XGBoostRegressorSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite {
|
||||||
protected val treeMethod: String = "auto"
|
test("XGBoostRegressor copy") {
|
||||||
|
val regressor = new XGBoostRegressor().setNthread(2).setNumWorkers(10)
|
||||||
|
val regressortCopied = regressor.copy(ParamMap.empty)
|
||||||
|
|
||||||
test("XGBoost-Spark XGBoostRegressor output should match XGBoost4j") {
|
assert(regressor.uid === regressortCopied.uid)
|
||||||
|
assert(regressor.getNthread === regressortCopied.getNthread)
|
||||||
|
assert(regressor.getNumWorkers === regressor.getNumWorkers)
|
||||||
|
}
|
||||||
|
|
||||||
|
test("XGBoostRegressionModel copy") {
|
||||||
|
val model = new XGBoostRegressionModel("hello").setNthread(2).setNumWorkers(10)
|
||||||
|
val modelCopied = model.copy(ParamMap.empty)
|
||||||
|
assert(model.uid === modelCopied.uid)
|
||||||
|
assert(model.getNthread === modelCopied.getNthread)
|
||||||
|
assert(model.getNumWorkers === modelCopied.getNumWorkers)
|
||||||
|
}
|
||||||
|
|
||||||
|
test("read/write") {
|
||||||
|
val trainDf = smallBinaryClassificationVector
|
||||||
|
val xgbParams: Map[String, Any] = Map(
|
||||||
|
"max_depth" -> 5,
|
||||||
|
"eta" -> 0.2
|
||||||
|
)
|
||||||
|
|
||||||
|
def check(xgboostParams: XGBoostParams[_]): Unit = {
|
||||||
|
assert(xgboostParams.getMaxDepth === 5)
|
||||||
|
assert(xgboostParams.getEta === 0.2)
|
||||||
|
assert(xgboostParams.getObjective === "reg:squarederror")
|
||||||
|
}
|
||||||
|
|
||||||
|
val regressorPath = new File(tempDir.toFile, "regressor").getPath
|
||||||
|
val regressor = new XGBoostRegressor(xgbParams).setNumRound(1)
|
||||||
|
check(regressor)
|
||||||
|
|
||||||
|
regressor.write.overwrite().save(regressorPath)
|
||||||
|
val loadedRegressor = XGBoostRegressor.load(regressorPath)
|
||||||
|
check(loadedRegressor)
|
||||||
|
|
||||||
|
val model = loadedRegressor.fit(trainDf)
|
||||||
|
check(model)
|
||||||
|
|
||||||
|
val modelPath = new File(tempDir.toFile, "model").getPath
|
||||||
|
model.write.overwrite().save(modelPath)
|
||||||
|
val modelLoaded = XGBoostRegressionModel.load(modelPath)
|
||||||
|
check(modelLoaded)
|
||||||
|
}
|
||||||
|
|
||||||
|
test("XGBoostRegressionModel transformed schema") {
|
||||||
|
val trainDf = smallBinaryClassificationVector
|
||||||
|
val regressor = new XGBoostRegressor().setNumRound(1)
|
||||||
|
val model = regressor.fit(trainDf)
|
||||||
|
var out = model.transform(trainDf)
|
||||||
|
// Transform should not discard the other columns of the transforming dataframe
|
||||||
|
Seq("label", "margin", "weight", "features").foreach { v =>
|
||||||
|
assert(out.schema.names.contains(v))
|
||||||
|
}
|
||||||
|
// Regressor does not have extra columns
|
||||||
|
Seq("rawPrediction", "probability").foreach { v =>
|
||||||
|
assert(!out.schema.names.contains(v))
|
||||||
|
}
|
||||||
|
assert(out.schema.names.contains("prediction"))
|
||||||
|
assert(out.schema.names.length === 5)
|
||||||
|
model.setLeafPredictionCol("leaf").setContribPredictionCol("contrib")
|
||||||
|
out = model.transform(trainDf)
|
||||||
|
assert(out.schema.names.contains("leaf"))
|
||||||
|
assert(out.schema.names.contains("contrib"))
|
||||||
|
}
|
||||||
|
|
||||||
|
test("Supported objectives") {
|
||||||
|
val regressor = new XGBoostRegressor()
|
||||||
|
val df = smallMultiClassificationVector
|
||||||
|
REGRESSION_OBJS.foreach { obj =>
|
||||||
|
regressor.setObjective(obj)
|
||||||
|
regressor.validate(df)
|
||||||
|
}
|
||||||
|
|
||||||
|
regressor.setObjective("binary:logistic")
|
||||||
|
intercept[IllegalArgumentException](
|
||||||
|
regressor.validate(df)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
test("XGBoost-Spark output should match XGBoost4j") {
|
||||||
val trainingDM = new DMatrix(Regression.train.iterator)
|
val trainingDM = new DMatrix(Regression.train.iterator)
|
||||||
val testDM = new DMatrix(Regression.test.iterator)
|
val testDM = new DMatrix(Regression.test.iterator)
|
||||||
val trainingDF = buildDataFrame(Regression.train)
|
val trainingDF = buildDataFrame(Regression.train)
|
||||||
val testDF = buildDataFrame(Regression.test)
|
val testDF = buildDataFrame(Regression.test)
|
||||||
checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF)
|
val paramMap = Map("objective" -> "reg:squarederror")
|
||||||
|
checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF, 5, paramMap)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("XGBoostRegressor should make correct predictions after upstream random sort") {
|
test("XGBoost-Spark output with weight should match XGBoost4j") {
|
||||||
val trainingDM = new DMatrix(Regression.train.iterator)
|
val trainingDM = new DMatrix(Regression.trainWithWeight.iterator)
|
||||||
|
trainingDM.setWeight(Regression.randomWeights)
|
||||||
val testDM = new DMatrix(Regression.test.iterator)
|
val testDM = new DMatrix(Regression.test.iterator)
|
||||||
val trainingDF = buildDataFrameWithRandSort(Regression.train)
|
val trainingDF = buildDataFrame(Regression.trainWithWeight)
|
||||||
val testDF = buildDataFrameWithRandSort(Regression.test)
|
val testDF = buildDataFrame(Regression.test)
|
||||||
checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF)
|
val paramMap = Map("objective" -> "reg:squarederror")
|
||||||
|
checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF,
|
||||||
|
5, paramMap, Some("weight"))
|
||||||
}
|
}
|
||||||
|
|
||||||
private def checkResultsWithXGBoost4j(
|
private def checkResultsWithXGBoost4j(
|
||||||
@ -51,306 +135,51 @@ class XGBoostRegressorSuite extends AnyFunSuite with PerTest with TmpFolderPerSu
|
|||||||
testDM: DMatrix,
|
testDM: DMatrix,
|
||||||
trainingDF: DataFrame,
|
trainingDF: DataFrame,
|
||||||
testDF: DataFrame,
|
testDF: DataFrame,
|
||||||
round: Int = 5): Unit = {
|
round: Int = 5,
|
||||||
|
xgbParams: Map[String, Any] = Map.empty,
|
||||||
|
weightCol: Option[String] = None): Unit = {
|
||||||
val paramMap = Map(
|
val paramMap = Map(
|
||||||
"eta" -> "1",
|
"eta" -> "1",
|
||||||
"max_depth" -> "6",
|
"max_depth" -> "6",
|
||||||
"silent" -> "1",
|
"base_score" -> 0.5,
|
||||||
"objective" -> "reg:squarederror",
|
"max_bin" -> 16) ++ xgbParams
|
||||||
"max_bin" -> 64,
|
val xgb4jModel = ScalaXGBoost.train(trainingDM, paramMap, round)
|
||||||
"tree_method" -> treeMethod)
|
|
||||||
|
|
||||||
val model1 = ScalaXGBoost.train(trainingDM, paramMap, round)
|
val regressor = new XGBoostRegressor(paramMap)
|
||||||
val prediction1 = model1.predict(testDM)
|
|
||||||
|
|
||||||
val model2 = new XGBoostRegressor(paramMap ++ Array("num_round" -> round,
|
|
||||||
"num_workers" -> numWorkers)).fit(trainingDF)
|
|
||||||
|
|
||||||
val prediction2 = model2.transform(testDF).
|
|
||||||
collect().map(row => (row.getAs[Int]("id"), row.getAs[Double]("prediction"))).toMap
|
|
||||||
|
|
||||||
assert(prediction1.indices.count { i =>
|
|
||||||
math.abs(prediction1(i)(0) - prediction2(i)) > 0.01
|
|
||||||
} < prediction1.length * 0.1)
|
|
||||||
|
|
||||||
|
|
||||||
// check the equality of single instance prediction
|
|
||||||
val firstOfDM = testDM.slice(Array(0))
|
|
||||||
val firstOfDF = testDF.filter(_.getAs[Int]("id") == 0)
|
|
||||||
.head()
|
|
||||||
.getAs[Vector]("features")
|
|
||||||
val prediction3 = model1.predict(firstOfDM)(0)(0)
|
|
||||||
val prediction4 = model2.predict(firstOfDF)
|
|
||||||
assert(math.abs(prediction3 - prediction4) <= 0.01f)
|
|
||||||
}
|
|
||||||
|
|
||||||
test("Set params in XGBoost and MLlib way should produce same model") {
|
|
||||||
val trainingDF = buildDataFrame(Regression.train)
|
|
||||||
val testDF = buildDataFrame(Regression.test)
|
|
||||||
val round = 5
|
|
||||||
|
|
||||||
val paramMap = Map(
|
|
||||||
"eta" -> "1",
|
|
||||||
"max_depth" -> "6",
|
|
||||||
"silent" -> "1",
|
|
||||||
"objective" -> "reg:squarederror",
|
|
||||||
"num_round" -> round,
|
|
||||||
"tree_method" -> treeMethod,
|
|
||||||
"num_workers" -> numWorkers)
|
|
||||||
|
|
||||||
// Set params in XGBoost way
|
|
||||||
val model1 = new XGBoostRegressor(paramMap).fit(trainingDF)
|
|
||||||
// Set params in MLlib way
|
|
||||||
val model2 = new XGBoostRegressor()
|
|
||||||
.setEta(1)
|
|
||||||
.setMaxDepth(6)
|
|
||||||
.setSilent(1)
|
|
||||||
.setObjective("reg:squarederror")
|
|
||||||
.setNumRound(round)
|
.setNumRound(round)
|
||||||
.setTreeMethod(treeMethod)
|
|
||||||
.setNumWorkers(numWorkers)
|
.setNumWorkers(numWorkers)
|
||||||
.fit(trainingDF)
|
.setLeafPredictionCol("leaf")
|
||||||
|
.setContribPredictionCol("contrib")
|
||||||
|
weightCol.foreach(weight => regressor.setWeightCol(weight))
|
||||||
|
|
||||||
val prediction1 = model1.transform(testDF).select("prediction").collect()
|
def checkEqual(left: Array[Array[Float]], right: Map[Int, Array[Float]]) = {
|
||||||
val prediction2 = model2.transform(testDF).select("prediction").collect()
|
assert(left.size === right.size)
|
||||||
|
left.zipWithIndex.foreach { case (leftValue, index) =>
|
||||||
prediction1.zip(prediction2).foreach { case (Row(p1: Double), Row(p2: Double)) =>
|
assert(leftValue.sameElements(right(index)))
|
||||||
assert(math.abs(p1 - p2) <= 0.01f)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
test("ranking: use group data") {
|
val xgbSparkModel = regressor.fit(trainingDF)
|
||||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
val rows = xgbSparkModel.transform(testDF).collect()
|
||||||
"objective" -> "rank:ndcg", "num_workers" -> numWorkers, "num_round" -> 5,
|
|
||||||
"group_col" -> "group", "tree_method" -> treeMethod)
|
|
||||||
|
|
||||||
val trainingDF = buildDataFrameWithGroup(Ranking.train)
|
// Check Leaf
|
||||||
val testDF = buildDataFrame(Ranking.test)
|
val xgb4jLeaf = xgb4jModel.predictLeaf(testDM)
|
||||||
val model = new XGBoostRegressor(paramMap).fit(trainingDF)
|
val xgbSparkLeaf = rows.map(row =>
|
||||||
|
(row.getAs[Int]("id"), row.getAs[DenseVector]("leaf").toArray.map(_.toFloat))).toMap
|
||||||
|
checkEqual(xgb4jLeaf, xgbSparkLeaf)
|
||||||
|
|
||||||
val prediction = model.transform(testDF).collect()
|
// Check contrib
|
||||||
assert(testDF.count() === prediction.length)
|
val xgb4jContrib = xgb4jModel.predictContrib(testDM)
|
||||||
|
val xgbSparkContrib = rows.map(row =>
|
||||||
|
(row.getAs[Int]("id"), row.getAs[DenseVector]("contrib").toArray.map(_.toFloat))).toMap
|
||||||
|
checkEqual(xgb4jContrib, xgbSparkContrib)
|
||||||
|
|
||||||
|
// Check prediction
|
||||||
|
val xgb4jPred = xgb4jModel.predict(testDM)
|
||||||
|
    val xgbSparkPred = rows.map(row => {
      val pred = row.getAs[Double]("prediction").toFloat
      (row.getAs[Int]("id"), Array(pred))}).toMap
    checkEqual(xgb4jPred, xgbSparkPred)
  }

  test("use weight") {
    val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers,
      "tree_method" -> treeMethod)

    val getWeightFromId = udf({id: Int => if (id == 0) 1.0f else 0.001f})
    val trainingDF = buildDataFrame(Regression.train)
      .withColumn("weight", getWeightFromId(col("id")))
    val testDF = buildDataFrame(Regression.test)

    val model = new XGBoostRegressor(paramMap).setWeightCol("weight").fit(trainingDF)
    val prediction = model.transform(testDF).collect()
    val first = prediction.head.getAs[Double]("prediction")
    prediction.foreach(x => assert(math.abs(x.getAs[Double]("prediction") - first) <= 0.01f))
  }

  test("objective will be set if not specifying it") {
    val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
      "num_round" -> 5, "num_workers" -> numWorkers, "tree_method" -> treeMethod)
    val training = buildDataFrame(Regression.train)
    val xgb = new XGBoostRegressor(paramMap)
    assert(!xgb.isDefined(xgb.objective))
    xgb.fit(training)
    assert(xgb.getObjective == "reg:squarederror")

    val paramMap1 = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
      "num_round" -> 5, "num_workers" -> numWorkers, "tree_method" -> treeMethod,
      "objective" -> "reg:squaredlogerror")
    val xgb1 = new XGBoostRegressor(paramMap1)
    assert(xgb1.getObjective == "reg:squaredlogerror")
    xgb1.fit(training)
    assert(xgb1.getObjective == "reg:squaredlogerror")
  }

  test("test predictionLeaf") {
    val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers,
      "tree_method" -> treeMethod)
    val training = buildDataFrame(Regression.train)
    val testDF = buildDataFrame(Regression.test)
    val groundTruth = testDF.count()
    val xgb = new XGBoostRegressor(paramMap)
    val model = xgb.fit(training)
    model.setLeafPredictionCol("predictLeaf")
    val resultDF = model.transform(testDF)
    assert(resultDF.count === groundTruth)
    assert(resultDF.columns.contains("predictLeaf"))
  }

  test("test predictionLeaf with empty column name") {
    val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers,
      "tree_method" -> treeMethod)
    val training = buildDataFrame(Regression.train)
    val testDF = buildDataFrame(Regression.test)
    val xgb = new XGBoostRegressor(paramMap)
    val model = xgb.fit(training)
    model.setLeafPredictionCol("")
    val resultDF = model.transform(testDF)
    assert(!resultDF.columns.contains("predictLeaf"))
  }

  test("test predictionContrib") {
    val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers,
      "tree_method" -> treeMethod)
    val training = buildDataFrame(Regression.train)
    val testDF = buildDataFrame(Regression.test)
    val groundTruth = testDF.count()
    val xgb = new XGBoostRegressor(paramMap)
    val model = xgb.fit(training)
    model.setContribPredictionCol("predictContrib")
    val resultDF = model.transform(testDF)
    assert(resultDF.count === groundTruth)
    assert(resultDF.columns.contains("predictContrib"))
  }

  test("test predictionContrib with empty column name") {
    val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers,
      "tree_method" -> treeMethod)
    val training = buildDataFrame(Regression.train)
    val testDF = buildDataFrame(Regression.test)
    val xgb = new XGBoostRegressor(paramMap)
    val model = xgb.fit(training)
    model.setContribPredictionCol("")
    val resultDF = model.transform(testDF)
    assert(!resultDF.columns.contains("predictContrib"))
  }

  test("test predictionLeaf and predictionContrib") {
    val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers,
      "tree_method" -> treeMethod)
    val training = buildDataFrame(Regression.train)
    val testDF = buildDataFrame(Regression.test)
    val groundTruth = testDF.count()
    val xgb = new XGBoostRegressor(paramMap)
    val model = xgb.fit(training)
    model.setLeafPredictionCol("predictLeaf")
    model.setContribPredictionCol("predictContrib")
    val resultDF = model.transform(testDF)
    assert(resultDF.count === groundTruth)
    assert(resultDF.columns.contains("predictLeaf"))
    assert(resultDF.columns.contains("predictContrib"))
  }

  test("featuresCols with features column can work") {
    val spark = ss
    import spark.implicits._
    val xgbInput = Seq(
      (Vectors.dense(1.0, 7.0), true, 10.1, 100.2, 0),
      (Vectors.dense(2.0, 20.0), false, 2.1, 2.2, 1))
      .toDF("f1", "f2", "f3", "features", "label")

    val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> 1)

    val featuresName = Array("f1", "f2", "f3", "features")
    val xgbClassifier = new XGBoostRegressor(paramMap)
      .setFeaturesCol(featuresName)
      .setLabelCol("label")

    val model = xgbClassifier.fit(xgbInput)
    assert(model.getFeaturesCols.sameElements(featuresName))

    val df = model.transform(xgbInput)
    assert(df.schema.fieldNames.contains("features_" + model.uid))
    df.show()

    val newFeatureName = "features_new"
    // transform also can work for vectorized dataset
    val vectorizedInput = new VectorAssembler()
      .setInputCols(featuresName)
      .setOutputCol(newFeatureName)
      .transform(xgbInput)
      .select(newFeatureName, "label")

    val df1 = model
      .setFeaturesCol(newFeatureName)
      .transform(vectorizedInput)
    assert(df1.schema.fieldNames.contains(newFeatureName))
    df1.show()
  }

  test("featuresCols without features column can work") {
    val spark = ss
    import spark.implicits._
    val xgbInput = Seq(
      (Vectors.dense(1.0, 7.0), true, 10.1, 100.2, 0),
      (Vectors.dense(2.0, 20.0), false, 2.1, 2.2, 1))
      .toDF("f1", "f2", "f3", "f4", "label")

    val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> 1)

    val featuresName = Array("f1", "f2", "f3", "f4")
    val xgbClassifier = new XGBoostRegressor(paramMap)
      .setFeaturesCol(featuresName)
      .setLabelCol("label")
      .setEvalSets(Map("eval" -> xgbInput))

    val model = xgbClassifier.fit(xgbInput)
    assert(model.getFeaturesCols.sameElements(featuresName))

    // transform should work for the dataset which includes the feature column names.
    val df = model.transform(xgbInput)
    assert(df.schema.fieldNames.contains("features"))
    df.show()

    // transform also can work for vectorized dataset
    val vectorizedInput = new VectorAssembler()
      .setInputCols(featuresName)
      .setOutputCol("features")
      .transform(xgbInput)
      .select("features", "label")

    val df1 = model.transform(vectorizedInput)
    df1.show()
  }

  test("XGBoostRegressionModel should be compatible") {
    val trainingDF = buildDataFrame(Regression.train)
    val paramMap = Map(
      "eta" -> "1",
      "max_depth" -> "6",
      "silent" -> "1",
      "objective" -> "reg:squarederror",
      "num_round" -> 5,
      "tree_method" -> treeMethod,
      "num_workers" -> numWorkers)

    val model = new XGBoostRegressor(paramMap).fit(trainingDF)

    val modelPath = new File(tempDir.toFile, "xgbc").getPath
    model.write.option("format", "json").save(modelPath)
    val nativeJsonModelPath = new File(tempDir.toFile, "nativeModel.json").getPath
    model.nativeBooster.saveModel(nativeJsonModelPath)
    assert(compareTwoFiles(new File(modelPath, "data/XGBoostRegressionModel").getPath,
      nativeJsonModelPath))

    // test default "ubj"
    val modelUbjPath = new File(tempDir.toFile, "xgbcUbj").getPath
    model.write.save(modelUbjPath)

    val nativeUbjModelPath = new File(tempDir.toFile, "nativeModel.ubj").getPath
    model.nativeBooster.saveModel(nativeUbjModelPath)

    assert(compareTwoFiles(new File(modelUbjPath, "data/XGBoostRegressionModel").getPath,
      nativeUbjModelPath))

    // test the deprecated format
    val modelDeprecatedPath = new File(tempDir.toFile, "modelDeprecated").getPath
    model.write.option("format", "deprecated").save(modelDeprecatedPath)

    val nativeDeprecatedModelPath = new File(tempDir.toFile, "nativeModel.deprecated").getPath
    model.nativeBooster.saveModel(nativeDeprecatedModelPath)

    assert(compareTwoFiles(new File(modelDeprecatedPath, "data/XGBoostRegressionModel").getPath,
      nativeDeprecatedModelPath))
  }
}
@@ -1,5 +1,5 @@
 /*
- Copyright (c) 2023 by Contributors
+ Copyright (c) 2023-2024 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -16,40 +16,18 @@

 package ml.dmlc.xgboost4j.scala.spark

-import ml.dmlc.xgboost4j.scala.Booster
 import org.apache.spark.SparkConf
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.SparkSession
 import org.scalatest.funsuite.AnyFunSuite

+import ml.dmlc.xgboost4j.scala.Booster

 class XGBoostSuite extends AnyFunSuite with PerTest {

   // Do not create spark context
   override def beforeEach(): Unit = {}

-  test("XGBoost execution parameters") {
-    var xgbExecutionParams = new XGBoostExecutionParamsFactory(
-      Map("device" -> "cpu", "num_workers" -> 1, "num_round" -> 1), sc)
-      .buildXGBRuntimeParams
-    assert(!xgbExecutionParams.runOnGpu)
-
-    xgbExecutionParams = new XGBoostExecutionParamsFactory(
-      Map("device" -> "cuda", "num_workers" -> 1, "num_round" -> 1), sc)
-      .buildXGBRuntimeParams
-    assert(xgbExecutionParams.runOnGpu)
-
-    xgbExecutionParams = new XGBoostExecutionParamsFactory(
-      Map("device" -> "cpu", "tree_method" -> "gpu_hist", "num_workers" -> 1, "num_round" -> 1), sc)
-      .buildXGBRuntimeParams
-    assert(xgbExecutionParams.runOnGpu)
-
-    xgbExecutionParams = new XGBoostExecutionParamsFactory(
-      Map("device" -> "cuda", "tree_method" -> "gpu_hist",
-        "num_workers" -> 1, "num_round" -> 1), sc)
-      .buildXGBRuntimeParams
-    assert(xgbExecutionParams.runOnGpu)
-  }
-
   test("skip stage-level scheduling") {
     val conf = new SparkConf()
       .setMaster("spark://foo")
@@ -101,7 +79,7 @@ class XGBoostSuite extends AnyFunSuite with PerTest {
   }

-  object FakedXGBoost extends XGBoostStageLevel {
+  object FakedXGBoost extends StageLevelScheduling {

     // Do not skip stage-level scheduling for testing purposes.
     override private[spark] def skipStageLevelScheduling(
@@ -129,12 +107,12 @@ class XGBoostSuite extends AnyFunSuite with PerTest {
     val df = ss.range(1, 10)
     val rdd = df.rdd

-    val xgbExecutionParams = new XGBoostExecutionParamsFactory(
-      Map("device" -> "cuda", "num_workers" -> 1, "num_round" -> 1), sc)
-      .buildXGBRuntimeParams
-    assert(xgbExecutionParams.runOnGpu)
+    val runtimeParams = new XGBoostClassifier(
+      Map("device" -> "cuda")).setNumWorkers(1).setNumRound(1)
+      .getRuntimeParameters(true)
+    assert(runtimeParams.runOnGpu)

-    val finalRDD = FakedXGBoost.tryStageLevelScheduling(ss.sparkContext, xgbExecutionParams,
+    val finalRDD = FakedXGBoost.tryStageLevelScheduling(ss.sparkContext, runtimeParams,
       rdd.asInstanceOf[RDD[(Booster, Map[String, Array[Float]])]])

     val taskResources = finalRDD.getResourceProfile().taskResources
@@ -519,4 +519,39 @@ public class DMatrix {
     CSR,
     CSC
   }
+
+  /**
+   * A class to hold the quantile information
+   */
+  public class QuantileCut {
+    // cut ptr
+    long[] indptr;
+    // cut values
+    float[] values;
+
+    QuantileCut(long[] indptr, float[] values) {
+      this.indptr = indptr;
+      this.values = values;
+    }
+
+    public long[] getIndptr() {
+      return indptr;
+    }
+
+    public float[] getValues() {
+      return values;
+    }
+  }
+
+  /**
+   * Get the Quantile Cut.
+   * @return QuantileCut
+   * @throws XGBoostError
+   */
+  public QuantileCut getQuantileCut() throws XGBoostError {
+    long[][] indptr = new long[1][];
+    float[][] values = new float[1][];
+    XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixGetQuantileCut(this.handle, indptr, values));
+    return new QuantileCut(indptr[0], values[0]);
+  }
 }
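For illustration, a minimal sketch (not part of this diff) of how the accessor added above might be consumed; it assumes ml.dmlc.xgboost4j.java.DMatrix is in scope and that trainMat is a hypothetical DMatrix that has already been used for training, so the histogram cuts exist:

    // Hedged usage sketch: read the per-feature quantile cuts from a trained DMatrix.
    DMatrix.QuantileCut cut = trainMat.getQuantileCut();
    long[] indptr = cut.getIndptr();   // offsets into `values`, one span per feature
    float[] values = cut.getValues();  // cut values for all features, laid out back to back
    for (int f = 0; f + 1 < indptr.length; ++f) {
      int begin = (int) indptr[f];
      int end = (int) indptr[f + 1];
      System.out.println("feature " + f + " has " + (end - begin) + " cut points");
    }

The indptr/values layout mirrors what the new DMatrixTest case further below asserts: entries i and i + 1 of indptr delimit the strictly increasing cut values for feature i.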
@@ -1,75 +0,0 @@
-package ml.dmlc.xgboost4j.java;
-
-import java.util.Iterator;
-
-/**
- * QuantileDMatrix will only be used to train
- */
-public class QuantileDMatrix extends DMatrix {
-  /**
-   * Create QuantileDMatrix from iterator based on the cuda array interface
-   *
-   * @param iter the XGBoost ColumnBatch batch to provide the corresponding cuda array interface
-   * @param missing the missing value
-   * @param maxBin the max bin
-   * @param nthread the parallelism
-   * @throws XGBoostError
-   */
-  public QuantileDMatrix(
-      Iterator<ColumnBatch> iter,
-      float missing,
-      int maxBin,
-      int nthread) throws XGBoostError {
-    super(0);
-    long[] out = new long[1];
-    String conf = getConfig(missing, maxBin, nthread);
-    XGBoostJNI.checkCall(XGBoostJNI.XGQuantileDMatrixCreateFromCallback(
-        iter, (java.util.Iterator<ColumnBatch>)null, conf, out));
-    handle = out[0];
-  }
-
-  @Override
-  public void setLabel(Column column) throws XGBoostError {
-    throw new XGBoostError("QuantileDMatrix does not support setLabel.");
-  }
-
-  @Override
-  public void setWeight(Column column) throws XGBoostError {
-    throw new XGBoostError("QuantileDMatrix does not support setWeight.");
-  }
-
-  @Override
-  public void setBaseMargin(Column column) throws XGBoostError {
-    throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.");
-  }
-
-  @Override
-  public void setLabel(float[] labels) throws XGBoostError {
-    throw new XGBoostError("QuantileDMatrix does not support setLabel.");
-  }
-
-  @Override
-  public void setWeight(float[] weights) throws XGBoostError {
-    throw new XGBoostError("QuantileDMatrix does not support setWeight.");
-  }
-
-  @Override
-  public void setBaseMargin(float[] baseMargin) throws XGBoostError {
-    throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.");
-  }
-
-  @Override
-  public void setBaseMargin(float[][] baseMargin) throws XGBoostError {
-    throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.");
-  }
-
-  @Override
-  public void setGroup(int[] group) throws XGBoostError {
-    throw new XGBoostError("QuantileDMatrix does not support setGroup.");
-  }
-
-  private String getConfig(float missing, int maxBin, int nthread) {
-    return String.format("{\"missing\":%f,\"max_bin\":%d,\"nthread\":%d}",
-        missing, maxBin, nthread);
-  }
-}
@@ -172,7 +172,7 @@ class XGBoostJNI {
       long handle, String field, String json);

   public final static native int XGQuantileDMatrixCreateFromCallback(
-    java.util.Iterator<ColumnBatch> iter, java.util.Iterator<ColumnBatch> ref, String config, long[] out);
+    java.util.Iterator<ColumnBatch> iter, long[] ref, String config, long[] out);

   public final static native int XGDMatrixCreateFromArrayInterfaceColumns(
     String featureJson, float missing, int nthread, long[] out);
@@ -180,4 +180,7 @@ class XGBoostJNI {
   public final static native int XGBoosterSetStrFeatureInfo(long handle, String field, String[] features);

   public final static native int XGBoosterGetStrFeatureInfo(long handle, String field, String[] out);
+
+  public final static native int XGDMatrixGetQuantileCut(long handle, long[][] outIndptr, float[][] outValues);
+
 }
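For illustration, a sketch (assumptions, not part of this commit) of what the reworked binding above implies for a package-internal caller: the reference DMatrix used to align the quantile cuts is now passed as its raw native handle in a long[] instead of a second ColumnBatch iterator. Here batches (an Iterator<ColumnBatch>) and trainMat are hypothetical, a handle accessor on DMatrix is assumed to be visible to the caller, and the config string follows the format used by the removed getConfig helper above:

    // Hedged sketch: build a quantized DMatrix whose cuts are aligned with trainMat.
    long[] out = new long[1];
    String conf = String.format("{\"missing\":%f,\"max_bin\":%d,\"nthread\":%d}", 0.0f, 256, 1);
    long[] ref = new long[]{trainMat.getHandle()};  // raw handle of the reference DMatrix (assumed accessor)
    XGBoostJNI.checkCall(
        XGBoostJNI.XGQuantileDMatrixCreateFromCallback(batches, ref, conf, out));
    long quantileHandle = out[0];  // handle of the newly created quantized DMatrix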
@@ -365,4 +365,8 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
   override def read(kryo: Kryo, input: Input): Unit = {
     booster = kryo.readObject(input, classOf[JBooster])
   }
+
+  // a flag to indicate if the device is set for the GPU transform
+  var deviceIsSet = false
+
 }
@@ -16,7 +16,7 @@

 package ml.dmlc.xgboost4j.scala

-import _root_.scala.collection.JavaConverters._
+import scala.collection.JavaConverters._

 import ml.dmlc.xgboost4j.LabeledPoint
 import ml.dmlc.xgboost4j.java.{Column, ColumnBatch, DMatrix => JDMatrix, XGBoostError}
@@ -1,7 +1,6 @@
-//
-// Created by bobwang on 2021/9/8.
-//
+/**
+ * Copyright 2021-2024, XGBoost Contributors
+ */

 #ifndef XGBOOST_USE_CUDA

 #include <jni.h>
@@ -21,7 +20,7 @@ XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass j
   API_END();
 }
 XGB_DLL int XGQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
-                                                    jobject jdata_iter, jobject jref_iter,
+                                                    jobject jdata_iter, jlongArray jref,
                                                     char const *config, jlongArray jout) {
   API_BEGIN();
   common::AssertGPUSupport();
@@ -1,10 +1,13 @@
+/**
+ * Copyright 2021-2024, XGBoost Contributors
+ */
 #include <jni.h>
+#include <xgboost/c_api.h>

-#include "../../../../src/common/device_helpers.cuh"
 #include "../../../../src/common/cuda_pinned_allocator.h"
+#include "../../../../src/common/device_vector.cuh"  // for device_vector
 #include "../../../../src/data/array_interface.h"
 #include "jvm_utils.h"
-#include <xgboost/c_api.h>

 namespace xgboost {
 namespace jni {
@@ -396,6 +399,9 @@ void Reset(DataIterHandle self) {
 int Next(DataIterHandle self) {
   return static_cast<xgboost::jni::DataIteratorProxy *>(self)->Next();
 }
+
+template <typename T>
+using Deleter = std::function<void(T *)>;
 }  // anonymous namespace

 XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
@@ -413,17 +419,23 @@ XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass j
 }

 XGB_DLL int XGQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
-                                                    jobject jdata_iter, jobject jref_iter,
+                                                    jobject jdata_iter, jlongArray jref,
                                                     char const *config, jlongArray jout) {
   xgboost::jni::DataIteratorProxy proxy(jdata_iter);
   DMatrixHandle result;
+  DMatrixHandle ref{nullptr};

-  std::unique_ptr<xgboost::jni::DataIteratorProxy> ref_proxy{nullptr};
-  if (jref_iter) {
-    ref_proxy = std::make_unique<xgboost::jni::DataIteratorProxy>(jref_iter);
+  if (jref != NULL) {
+    std::unique_ptr<jlong, Deleter<jlong>> refptr{jenv->GetLongArrayElements(jref, nullptr),
+                                                  [&](jlong *ptr) {
+                                                    jenv->ReleaseLongArrayElements(jref, ptr, 0);
+                                                    jenv->DeleteLocalRef(jref);
+                                                  }};
+    ref = reinterpret_cast<DMatrixHandle>(refptr.get()[0]);
   }

   auto ret = XGQuantileDMatrixCreateFromCallback(
-      &proxy, proxy.GetDMatrixHandle(), ref_proxy.get(), Reset, Next, config, &result);
+      &proxy, proxy.GetDMatrixHandle(), ref, Reset, Next, config, &result);
   setHandle(jenv, jout, result);
   return ret;
 }
@@ -20,6 +20,7 @@
 #include <xgboost/c_api.h>
 #include <xgboost/json.h>
 #include <xgboost/logging.h>
+#include <xgboost/string_view.h>  // for StringView

 #include <algorithm>  // for copy_n
 #include <cstddef>
@@ -30,8 +31,9 @@
 #include <type_traits>
 #include <vector>

-#include "../../../src/c_api/c_api_error.h"
-#include "../../../src/c_api/c_api_utils.h"
+#include "../../../../src/c_api/c_api_error.h"
+#include "../../../../src/c_api/c_api_utils.h"
+#include "../../../../src/data/array_interface.h"  // for ArrayInterface

 #define JVM_CHECK_CALL(__expr) \
   { \
@@ -1330,16 +1332,16 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDeviceQuantileDM
 /*
  * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
  * Method:    XGQuantileDMatrixCreateFromCallback
- * Signature: (Ljava/util/Iterator;Ljava/util/Iterator;Ljava/lang/String;[J)I
+ * Signature: (Ljava/util/Iterator;[JLjava/lang/String;[J)I
  */
 JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGQuantileDMatrixCreateFromCallback(
-    JNIEnv *jenv, jclass jcls, jobject jdata_iter, jobject jref_iter, jstring jconf,
+    JNIEnv *jenv, jclass jcls, jobject jdata_iter, jlongArray jref, jstring jconf,
     jlongArray jout) {
   std::unique_ptr<char const, Deleter<char const>> conf{jenv->GetStringUTFChars(jconf, nullptr),
                                                         [&](char const *ptr) {
                                                           jenv->ReleaseStringUTFChars(jconf, ptr);
                                                         }};
-  return xgboost::jni::XGQuantileDMatrixCreateFromCallbackImpl(jenv, jcls, jdata_iter, jref_iter,
+  return xgboost::jni::XGQuantileDMatrixCreateFromCallbackImpl(jenv, jcls, jdata_iter, jref,
                                                                conf.get(), jout);
 }
@@ -1517,3 +1519,44 @@ Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetStrFeatureInfo(

   return ret;
 }
+
+/*
+ * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
+ * Method:    XGDMatrixGetQuantileCut
+ * Signature: (J[[J[[F)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetQuantileCut(
+    JNIEnv *jenv, jclass, jlong jhandle, jobjectArray j_indptr, jobjectArray j_values) {
+  using namespace xgboost;  // NOLINT
+  auto handle = reinterpret_cast<DMatrixHandle>(jhandle);
+
+  char const *str_indptr;
+  char const *str_data;
+  Json config{Object{}};
+  auto str_config = Json::Dump(config);
+
+  auto ret = XGDMatrixGetQuantileCut(handle, str_config.c_str(), &str_indptr, &str_data);
+
+  ArrayInterface<1> indptr{StringView{str_indptr}};
+  ArrayInterface<1> data{StringView{str_data}};
+  CHECK_GE(indptr.Shape(0), 2);
+
+  // Cut ptr
+  auto j_indptr_array = jenv->NewLongArray(indptr.Shape(0));
+  CHECK_EQ(indptr.type, ArrayInterfaceHandler::Type::kU8);
+  CHECK_LT(indptr(indptr.Shape(0) - 1),
+           static_cast<std::uint64_t>(std::numeric_limits<std::int64_t>::max()));
+  static_assert(sizeof(jlong) == sizeof(std::uint64_t));
+  jenv->SetLongArrayRegion(j_indptr_array, 0, indptr.Shape(0),
+                           static_cast<jlong const *>(indptr.data));
+  jenv->SetObjectArrayElement(j_indptr, 0, j_indptr_array);

+  // Cut values
+  auto n_cuts = indptr(indptr.Shape(0) - 1);
+  jfloatArray jcuts_array = jenv->NewFloatArray(n_cuts);
+  CHECK_EQ(data.type, ArrayInterfaceHandler::Type::kF4);
+  jenv->SetFloatArrayRegion(jcuts_array, 0, n_cuts, static_cast<float const *>(data.data));
+  jenv->SetObjectArrayElement(j_values, 0, jcuts_array);
+
+  return ret;
+}
@@ -402,10 +402,10 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetInfoFr
 /*
  * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
  * Method:    XGQuantileDMatrixCreateFromCallback
- * Signature: (Ljava/util/Iterator;Ljava/util/Iterator;Ljava/lang/String;[J)I
+ * Signature: (Ljava/util/Iterator;[JLjava/lang/String;[J)I
  */
 JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGQuantileDMatrixCreateFromCallback
-  (JNIEnv *, jclass, jobject, jobject, jstring, jlongArray);
+  (JNIEnv *, jclass, jobject, jlongArray, jstring, jlongArray);

 /*
  * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
@@ -431,6 +431,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetStrFea
 JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetStrFeatureInfo
   (JNIEnv *, jclass, jlong, jstring, jobjectArray);

+/*
+ * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
+ * Method:    XGDMatrixGetQuantileCut
+ * Signature: (J[[J[[F)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetQuantileCut
+  (JNIEnv *, jclass, jlong, jobjectArray, jobjectArray);
+
 #ifdef __cplusplus
 }
 #endif
@@ -258,8 +258,7 @@ public class DMatrixTest {
     TestCase.assertTrue(Arrays.equals(weights, dmat0.getWeight()));
   }

-  @Test
-  public void testCreateFromDenseMatrixWithMissingValue() throws XGBoostError {
+  private DMatrix createFromDenseMatrix() throws XGBoostError {
     //create DMatrix from 10*5 dense matrix
     int nrow = 10;
     int ncol = 5;
@@ -280,12 +279,17 @@ public class DMatrixTest {
       label0[i] = random.nextFloat();
     }

-    DMatrix dmat0 = new DMatrix(data0, nrow, ncol, -0.1f);
-    dmat0.setLabel(label0);
+    DMatrix dm = new DMatrix(data0, nrow, ncol, -0.1f);
+    dm.setLabel(label0);
+    return dm;
+  }
+
+  @Test
+  public void testCreateFromDenseMatrixWithMissingValue() throws XGBoostError {
+    DMatrix dm = createFromDenseMatrix();
     //check
-    TestCase.assertTrue(dmat0.rowNum() == 10);
-    TestCase.assertTrue(dmat0.getLabel().length == 10);
+    TestCase.assertTrue(dm.rowNum() == 10);
+    TestCase.assertTrue(dm.getLabel().length == 10);
   }

   @Test
@@ -493,4 +497,28 @@ public class DMatrixTest {
     TestCase.assertTrue(Arrays.equals(qidExpected1, dmat0.getGroup()));

   }
+
+  @Test
+  public void getGetQuantileCut() throws XGBoostError {
+    DMatrix Xy = createFromDenseMatrix();
+    Map<String, Object> params = new HashMap<String, Object>();
+    HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
+    watches.put("train", Xy);
+    XGBoost.train(Xy, params, 1, watches, null, null);  // Create the cuts
+    DMatrix.QuantileCut cuts = Xy.getQuantileCut();
+    TestCase.assertEquals(cuts.indptr.length, 6);
+    for (int i = 1; i < cuts.indptr.length; ++i) {
+      // Number of bins for each feature + min value.
+      TestCase.assertTrue(cuts.indptr[i] - cuts.indptr[i - 1] >= 5);
+      TestCase.assertTrue(cuts.indptr[i] - cuts.indptr[i - 1] <= Xy.rowNum() + 1);
+    }
+    TestCase.assertEquals(cuts.values.length, cuts.indptr[cuts.indptr.length - 1]);
+    for (int i = 1; i < cuts.indptr.length; ++i) {
+      long begin = cuts.indptr[i - 1];
+      long end = cuts.indptr[i];
+      for (long j = begin + 1; j < end; ++j) {
+        TestCase.assertTrue(cuts.values[(int) j] > cuts.values[(int) j - 1]);
+      }
+    }
+  }
 }