fix examples
This commit is contained in:
@@ -81,13 +81,14 @@ object XGBoost {
|
||||
/**
|
||||
* Train a xgboost model with link.
|
||||
*
|
||||
* @param params The parameters to XGBoost.
|
||||
* @param dtrain The training data.
|
||||
* @param params The parameters to XGBoost.
|
||||
* @param round Number of rounds to train.
|
||||
*/
|
||||
def train(params: Map[String, Any],
|
||||
dtrain: DataSet[LabeledVector],
|
||||
round: Int): XGBoostModel = {
|
||||
def train(
|
||||
dtrain: DataSet[LabeledVector],
|
||||
params: Map[String, Any],
|
||||
round: Int): XGBoostModel = {
|
||||
val tracker = new RabitTracker(dtrain.getExecutionEnvironment.getParallelism)
|
||||
if (tracker.start()) {
|
||||
dtrain
|
||||
|
||||
@@ -37,6 +37,15 @@ class XGBoostModel (booster: Booster) extends Serializable {
|
||||
.create(new Path(modelPath)))
|
||||
}
|
||||
|
||||
/**
|
||||
* predict with the given DMatrix
|
||||
* @param testSet the local test set represented as DMatrix
|
||||
* @return prediction result
|
||||
*/
|
||||
def predict(testSet: DMatrix): Array[Array[Float]] = {
|
||||
booster.predict(testSet, true, 0)
|
||||
}
|
||||
|
||||
/**
|
||||
* Predict given vector dataset.
|
||||
*
|
||||
@@ -44,7 +53,7 @@ class XGBoostModel (booster: Booster) extends Serializable {
|
||||
* @return The prediction result.
|
||||
*/
|
||||
def predict(data: DataSet[Vector]) : DataSet[Array[Float]] = {
|
||||
val predictMap: Iterator[Vector] => TraversableOnce[Array[Float]] =
|
||||
val predictMap: Iterator[Vector] => Traversable[Array[Float]] =
|
||||
(it: Iterator[Vector]) => {
|
||||
val mapper = (x: Vector) => {
|
||||
val (index, value) = x.toSeq.unzip
|
||||
|
||||
Reference in New Issue
Block a user