Expose setCustomObj & setCustomEval for XGBoostClassifier & XGBoostRegressor. (#3486)
This commit is contained in:
parent
b6dcbf0e07
commit
c004cea788
@ -22,6 +22,7 @@ import scala.collection.mutable
|
||||
|
||||
import ml.dmlc.xgboost4j.java.Rabit
|
||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
|
||||
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
|
||||
import ml.dmlc.xgboost4j.scala.spark.params._
|
||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||
|
||||
@ -134,6 +135,10 @@ class XGBoostClassifier (
|
||||
|
||||
def setNumEarlyStoppingRounds(value: Int): this.type = set(numEarlyStoppingRounds, value)
|
||||
|
||||
def setCustomObj(value: ObjectiveTrait): this.type = set(customObj, value)
|
||||
|
||||
def setCustomEval(value: EvalTrait): this.type = set(customEval, value)
|
||||
|
||||
// called at the start of fit/train when 'eval_metric' is not defined
|
||||
private def setupDefaultEvalMetric(): String = {
|
||||
require(isDefined(objective), "Users must set \'objective\' via xgboostParams.")
|
||||
|
||||
@ -23,6 +23,7 @@ import ml.dmlc.xgboost4j.java.Rabit
|
||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||
import ml.dmlc.xgboost4j.scala.spark.params.{DefaultXGBoostParamsReader, _}
|
||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
|
||||
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
|
||||
|
||||
import org.apache.hadoop.fs.Path
|
||||
import org.apache.spark.TaskContext
|
||||
@ -136,6 +137,10 @@ class XGBoostRegressor (
|
||||
|
||||
def setNumEarlyStoppingRounds(value: Int): this.type = set(numEarlyStoppingRounds, value)
|
||||
|
||||
def setCustomObj(value: ObjectiveTrait): this.type = set(customObj, value)
|
||||
|
||||
def setCustomEval(value: EvalTrait): this.type = set(customEval, value)
|
||||
|
||||
// called at the start of fit/train when 'eval_metric' is not defined
|
||||
private def setupDefaultEvalMetric(): String = {
|
||||
require(isDefined(objective), "Users must set \'objective\' via xgboostParams.")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user