Expose setCustomObj & setCustomEval for XGBoostClassifier & XGBoostRegressor. (#3486)
This commit is contained in:
parent
b6dcbf0e07
commit
c004cea788
@ -22,6 +22,7 @@ import scala.collection.mutable
|
|||||||
|
|
||||||
import ml.dmlc.xgboost4j.java.Rabit
|
import ml.dmlc.xgboost4j.java.Rabit
|
||||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
|
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
|
||||||
|
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
|
||||||
import ml.dmlc.xgboost4j.scala.spark.params._
|
import ml.dmlc.xgboost4j.scala.spark.params._
|
||||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||||
|
|
||||||
@ -134,6 +135,10 @@ class XGBoostClassifier (
|
|||||||
|
|
||||||
def setNumEarlyStoppingRounds(value: Int): this.type = set(numEarlyStoppingRounds, value)
|
def setNumEarlyStoppingRounds(value: Int): this.type = set(numEarlyStoppingRounds, value)
|
||||||
|
|
||||||
|
def setCustomObj(value: ObjectiveTrait): this.type = set(customObj, value)
|
||||||
|
|
||||||
|
def setCustomEval(value: EvalTrait): this.type = set(customEval, value)
|
||||||
|
|
||||||
// called at the start of fit/train when 'eval_metric' is not defined
|
// called at the start of fit/train when 'eval_metric' is not defined
|
||||||
private def setupDefaultEvalMetric(): String = {
|
private def setupDefaultEvalMetric(): String = {
|
||||||
require(isDefined(objective), "Users must set \'objective\' via xgboostParams.")
|
require(isDefined(objective), "Users must set \'objective\' via xgboostParams.")
|
||||||
|
|||||||
@ -23,6 +23,7 @@ import ml.dmlc.xgboost4j.java.Rabit
|
|||||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||||
import ml.dmlc.xgboost4j.scala.spark.params.{DefaultXGBoostParamsReader, _}
|
import ml.dmlc.xgboost4j.scala.spark.params.{DefaultXGBoostParamsReader, _}
|
||||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
|
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
|
||||||
|
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
|
||||||
|
|
||||||
import org.apache.hadoop.fs.Path
|
import org.apache.hadoop.fs.Path
|
||||||
import org.apache.spark.TaskContext
|
import org.apache.spark.TaskContext
|
||||||
@ -136,6 +137,10 @@ class XGBoostRegressor (
|
|||||||
|
|
||||||
def setNumEarlyStoppingRounds(value: Int): this.type = set(numEarlyStoppingRounds, value)
|
def setNumEarlyStoppingRounds(value: Int): this.type = set(numEarlyStoppingRounds, value)
|
||||||
|
|
||||||
|
def setCustomObj(value: ObjectiveTrait): this.type = set(customObj, value)
|
||||||
|
|
||||||
|
def setCustomEval(value: EvalTrait): this.type = set(customEval, value)
|
||||||
|
|
||||||
// called at the start of fit/train when 'eval_metric' is not defined
|
// called at the start of fit/train when 'eval_metric' is not defined
|
||||||
private def setupDefaultEvalMetric(): String = {
|
private def setupDefaultEvalMetric(): String = {
|
||||||
require(isDefined(objective), "Users must set \'objective\' via xgboostParams.")
|
require(isDefined(objective), "Users must set \'objective\' via xgboostParams.")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user