[jvm-packages] unify the set features API (#7692)
xgboost4j-spark provides two separate APIs for setting feature columns, one for the CPU pipeline and one for the GPU pipeline, which can cause confusion. This PR removes the GPU-specific API and adds an overload of the CPU function setFeaturesCol that accepts an Array[String] parameter.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014,2021 by Contributors
|
||||
Copyright (c) 2014-2022 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -148,7 +148,7 @@ class XGBoostClassifier (
|
||||
* This API is only used in GPU train pipeline of xgboost4j-spark-gpu, which requires
|
||||
* all feature columns must be numeric types.
|
||||
*/
|
||||
def setFeaturesCols(value: Seq[String]): this.type =
|
||||
def setFeaturesCol(value: Array[String]): this.type =
|
||||
set(featuresCols, value)
|
||||
|
||||
// called at the start of fit/train when 'eval_metric' is not defined
|
||||
@@ -264,7 +264,7 @@ class XGBoostClassificationModel private[ml](
|
||||
* This API is only used in GPU train pipeline of xgboost4j-spark-gpu, which requires
|
||||
* all feature columns must be numeric types.
|
||||
*/
|
||||
def setFeaturesCols(value: Seq[String]): this.type =
|
||||
def setFeaturesCol(value: Array[String]): this.type =
|
||||
set(featuresCols, value)
|
||||
|
||||
/**
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014,2021 by Contributors
|
||||
Copyright (c) 2014-2022 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -150,7 +150,7 @@ class XGBoostRegressor (
|
||||
* This API is only used in GPU train pipeline of xgboost4j-spark-gpu, which requires
|
||||
* all feature columns must be numeric types.
|
||||
*/
|
||||
def setFeaturesCols(value: Seq[String]): this.type =
|
||||
def setFeaturesCols(value: Array[String]): this.type =
|
||||
set(featuresCols, value)
|
||||
|
||||
// called at the start of fit/train when 'eval_metric' is not defined
|
||||
@@ -257,7 +257,7 @@ class XGBoostRegressionModel private[ml] (
|
||||
* This API is only used in GPU train pipeline of xgboost4j-spark-gpu, which requires
|
||||
* all feature columns must be numeric types.
|
||||
*/
|
||||
def setFeaturesCols(value: Seq[String]): this.type =
|
||||
def setFeaturesCols(value: Array[String]): this.type =
|
||||
set(featuresCols, value)
|
||||
|
||||
/**
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2021 by Contributors
|
||||
Copyright (c) 2021-2022 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -16,38 +16,19 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark.params
|
||||
|
||||
import org.json4s.DefaultFormats
|
||||
import org.json4s.jackson.JsonMethods.{compact, parse, render}
|
||||
|
||||
import org.apache.spark.ml.param.{BooleanParam, Param, Params}
|
||||
import org.apache.spark.ml.param.{Params, StringArrayParam}
|
||||
|
||||
trait GpuParams extends Params {
|
||||
/**
|
||||
* Param for the names of feature columns.
|
||||
* Param for the names of feature columns for GPU pipeline.
|
||||
* @group param
|
||||
*/
|
||||
final val featuresCols: StringSeqParam = new StringSeqParam(this, "featuresCols",
|
||||
"a sequence of feature column names.")
|
||||
final val featuresCols: StringArrayParam = new StringArrayParam(this, "featuresCols",
|
||||
"an array of feature column names for GPU pipeline.")
|
||||
|
||||
setDefault(featuresCols, Seq.empty[String])
|
||||
setDefault(featuresCols, Array.empty[String])
|
||||
|
||||
/** @group getParam */
|
||||
final def getFeaturesCols: Seq[String] = $(featuresCols)
|
||||
final def getFeaturesCols: Array[String] = $(featuresCols)
|
||||
|
||||
}
|
||||
|
||||
class StringSeqParam(
|
||||
parent: Params,
|
||||
name: String,
|
||||
doc: String) extends Param[Seq[String]](parent, name, doc) {
|
||||
|
||||
override def jsonEncode(value: Seq[String]): String = {
|
||||
import org.json4s.JsonDSL._
|
||||
compact(render(value))
|
||||
}
|
||||
|
||||
override def jsonDecode(json: String): Seq[String] = {
|
||||
implicit val formats = DefaultFormats
|
||||
parse(json).extract[Seq[String]]
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user